WHICH INVARIANCE SHOULD WE TRANSFER? A CAUSAL MINIMAX LEARNING APPROACH

Abstract

A major barrier to deploying current machine learning models lies in their sensitivity to dataset shifts. To resolve this problem, most existing studies attempt to transfer stable information to unseen environments. Among these, graph-based methods causally decompose the data-generating process into stable and mutable mechanisms. By removing the effect of the mutable generation, they identify a set of stable predictors. However, a key question regarding robustness remains: which subset of the whole stable information should the model transfer, in order to achieve optimal generalization ability? To answer this question, we provide a comprehensive minimax analysis that fully characterizes the conditions under which a subset is optimal. In general cases, we propose to maximize over the mutable mechanisms (i.e., the source of dataset shifts), which provably identifies the worst-case risk over all environments. This enables us to select the optimal subset, namely the one with the minimal worst-case risk. To reduce computational cost, we propose to search only over equivalence classes with respect to the worst-case risk, instead of over all subsets. In cases where the search space is still large, we turn this subset selection problem into a sparse min-max optimization scheme, which enjoys simplicity and efficiency of implementation. The utility of our methods is demonstrated on the diagnosis of Alzheimer's Disease and gene function prediction.

1. INTRODUCTION

Current machine learning systems, which are commonly deployed based on their in-distribution performance, often encounter dataset shifts (Subbaswamy et al., 2019) such as covariate shift and label shift, due to changes in the data-generating process. When such a shift exists in the deployment environment, the model may give unreliable prediction results, which can cause severe consequences in safety-critical tasks such as healthcare (Hendrycks et al., 2021). At the heart of this unreliability issue are stability and robustness, which refer respectively to the insensitivity of prediction behavior and of generalization errors to shifts. For example, consider a system deployed to predict the Functional Activities Questionnaire (FAQ) score, which is commonly adopted (Mayo, 2016) to measure the severity of Alzheimer's Disease (AD). During prediction, the system can only access biomarkers or volumes of brain regions, with demographic information anonymized for privacy considerations. However, changes in demographics can cause shifts in the covariates. For the deployed model to be reliable, its predictions should be stable against demographic changes and, meanwhile, remain constantly accurate over all populations. To incorporate both aspects, this paper aims to find the most robust (i.e., min-max optimal (Müller et al., 2020)) predictor among the set of predictors that are stable over all distributions.

To achieve this goal, many studies have proposed to learn invariance that transfers to unseen data. Examples include ICP (Peters et al., 2016) and (Arjovsky et al., 2019; Liu et al., 2021; Ahuja et al., 2021), which assume the prediction mechanism given causal features or representations to be invariant, and Anchor Regression (Rothenhäusler et al., 2021), which explicitly attributes the variation to exogenous variables. In particular, Subbaswamy & Saria (2020) and Subbaswamy et al. (2019) causally decomposed the joint distribution into a mutable set M and a stable set S, whose causal mechanisms respectively change and remain unchanged across environments. They then proposed to intervene on M to obtain a set of stable predictors. Still, a question regarding robustness remains: which subset of the stable information should the model transfer, in order to be most robust against dataset shifts?

To give a comprehensive answer, we first provide a graphical condition that is sufficient for the whole stable set to be optimal. This condition can be easily tested via causal discovery. When this condition fails, we prove that the worst-case risk can be identified by maximizing over the generating mechanism of M, i.e., the only source of shift. This conclusion enables us to select the optimal subset in a more accurate way. Consider again the example of FAQ prediction in AD diagnosis: Fig. 1 (b) shows that our method is more reflective of the maximal mean squared error (MSE) than Subbaswamy et al. (2019), which explains our advantage in predicting FAQ across patient groups, shown in Fig. 1 (a). Besides, to reduce the search cost, we propose to search only over equivalence classes with respect to the worst-case risk. We find, however, that in some cases such a search can still be expensive. To improve efficiency in these cases, we turn this subset selection task into a sparse min-max optimization scheme, which alternates between a gradient-ascent step on the generating function of M and a sparse optimization with a Lasso-type penalty to detect the optimal subset. We demonstrate the utility of our methods on a synthetic dataset and two real-world applications: Alzheimer's Disease diagnosis and gene function prediction.

Contributions. We summarize our contributions as follows: 1. We propose to identify the optimal subset of invariance to transfer, guided by a comprehensive min-max analysis. To the best of our knowledge, we are the first to comprehensively study, in the robust learning literature, the problem of which part among all sources of invariance the model should transfer. 2. We introduce an equivalence relation with respect to the worst-case risk, in order to analyze the computational complexity, and propose a sparse min-max optimization method as a surrogate scheme to improve efficiency. 3. Our method significantly outperforms others in terms of subset selection and generalization robustness on Alzheimer's Disease diagnosis and gene function prediction.
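The alternating scheme mentioned above can be illustrated in a deliberately simplified toy setting. Everything below is a hypothetical sketch, not the paper's actual estimator: we assume a linear predictor, a mutable mechanism reduced to a single mean-shift parameter eta acting on a known mask of mutable coordinates, and a plain soft-thresholding (Lasso) step for sparsity. The inner step ascends on eta to approach the worst-case risk; the outer step descends on the weights and shrinks them, so coordinates whose risk blows up under the shift are driven toward zero.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: 5 features; the last two are "mutable" (their generating
# mechanism may shift across environments), the first three are stable
# and actually generate Y. The mask below is a simplifying assumption.
n, d = 500, 5
mutable = np.array([0.0, 0.0, 0.0, 1.0, 1.0])
X = rng.normal(size=(n, d))
y = X[:, :3] @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.normal(size=n)

def grads(w, eta):
    """Gradients of the MSE after a mean shift eta on mutable coordinates."""
    X_shift = X + eta * mutable
    r = X_shift @ w - y                      # residuals, shape (n,)
    g_w = 2.0 * X_shift.T @ r / n            # gradient w.r.t. predictor weights
    g_eta = 2.0 * np.mean(r) * (mutable @ w) # gradient w.r.t. shift magnitude
    return g_w, g_eta

w, eta = np.zeros(d), 0.0
lr_w, lr_eta, lam = 0.05, 0.05, 0.05
for _ in range(2000):
    g_w, g_eta = grads(w, eta)
    eta += lr_eta * g_eta                    # ascent on the mutable mechanism
    w -= lr_w * g_w                          # descent on the predictor
    # Lasso-type soft-thresholding step to select a sparse subset:
    w = np.sign(w) * np.maximum(np.abs(w) - lr_w * lam, 0.0)

print(np.round(w, 2))  # mutable coordinates are driven toward zero
```

In this toy run the weights on the two mutable coordinates vanish while the stable coefficients are recovered (up to the usual Lasso shrinkage bias), mimicking how the sparse penalty detects the subset with small worst-case risk.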



Figure 1: FAQ prediction in Alzheimer's Disease. (a) Maximal mean squared error (MSE) over test environments; (b) maximal MSE of predictors ranked in ascending order from left to right, according to the estimated worst-case risk of our method and the validation loss of the graph surgery estimator (Subbaswamy et al., 2019), respectively. As shown, our method is more reflective of the maximal MSE than the graph surgery method.
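The ranking criterion in panel (b) can be sketched in a toy linear setting. Everything here is a hypothetical illustration: we posit three features, of which X_2 is an anti-causal child of Y whose mechanism parameter s is mutable, and we estimate each candidate subset's worst-case risk by maximizing the MSE over s on a grid (a crude stand-in for maximizing over the generating mechanism of M). Subsets that lean on the mutable feature incur a large worst-case MSE, so the stable subset ranks first.

```python
import numpy as np
from itertools import combinations

rng = np.random.default_rng(1)

def sample_env(s, n=400):
    """X0, X1 cause Y; X2 is a mutable child of Y with mechanism parameter s."""
    X01 = rng.normal(size=(n, 2))
    y = 1.5 * X01[:, 0] - X01[:, 1] + 0.1 * rng.normal(size=n)
    x2 = s * y + 0.3 * rng.normal(size=n)
    return np.column_stack([X01, x2]), y

def worst_case_mse(subset, shifts=np.linspace(-3.0, 3.0, 13)):
    """Fit OLS on a reference environment (s = 1), then maximize the MSE
    over the mutable mechanism parameter s."""
    idx = list(subset)
    X_tr, y_tr = sample_env(1.0)
    w, *_ = np.linalg.lstsq(X_tr[:, idx], y_tr, rcond=None)
    worst = 0.0
    for s in shifts:
        X_te, y_te = sample_env(s)
        worst = max(worst, np.mean((X_te[:, idx] @ w - y_te) ** 2))
    return worst

# Rank every nonempty subset by its estimated worst-case risk.
subsets = [s for r in (1, 2, 3) for s in combinations(range(3), r)]
best = min(subsets, key=worst_case_mse)
print(best)  # the stable subset (0, 1) minimizes the worst-case risk
```

The design choice mirrors the figure: ordering candidates by estimated worst-case risk (rather than by validation loss in the training environments) is what makes the ranking track the maximal MSE under shift.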

Problem Setup & Notations. We consider the supervised regression scenario, where the system includes a target variable Y ∈ 𝒴, a multivariate predictive variable X := [X_1, ..., X_d] ∈ ℝ^d, and data collected from heterogeneous environments. In practice, different "environments" can refer to different groups of subjects or different experimental settings. We use {D^e | e ∈ E_Tr} to denote our training data, with D

