TARGET CONDITIONED REPRESENTATION INDEPENDENCE (TCRI): FROM DOMAIN-INVARIANT TO DOMAIN-GENERAL REPRESENTATIONS

Abstract

We propose a Target Conditioned Representation Independence (TCRI) objective for domain generalization. TCRI addresses a limitation of existing domain generalization methods: incomplete constraints. Specifically, TCRI implements regularizers motivated by conditional independence constraints that suffice to strictly learn the complete set of invariant mechanisms, which we show is necessary and sufficient for domain generalization. Empirically, we show that TCRI is effective on both synthetic and real-world data. TCRI is competitive with baselines in average accuracy while outperforming them in worst-domain accuracy, indicating the desired cross-domain stability.

1. INTRODUCTION

Machine learning algorithms are evaluated by their ability to generalize, i.e., to make reasonable predictions on unseen examples. Learning frameworks are often designed to exploit structure shared between the training data and the data expected at deployment. A common assumption is that training and test examples are drawn independently from the same distribution (i.i.d.). Under this assumption, Empirical Risk Minimization (ERM; Vapnik (1991)) and its variants give strong generalization guarantees and are effective in practice. Nevertheless, many practical problems exhibit distribution shifts between train and test domains, and ERM can fail in this setting (Arjovsky et al., 2019). This failure mode has impactful real-world implications: in safety-critical settings such as autonomous driving (Amodei et al., 2016; Filos et al., 2020), a lack of robustness to distribution shift can lead to human casualties; in ethical settings such as healthcare, distribution shifts can introduce biases that adversely affect subgroups of the population (Singh et al., 2021). To address this limitation, many works have developed approaches for learning under distribution shift. Among the various strategies for domain generalization, Invariant Causal Prediction (ICP; Peters et al. (2016)) has emerged as a popular one. ICP assumes that while some aspects of the data distribution may vary across domains, the causal structure (the data-generating mechanisms) remains fixed, and aims to learn those domain-general causal predictors. Following ICP, Arjovsky et al. (2019) propose Invariant Risk Minimization (IRM), which identifies invariant mechanisms by learning a representation of the observed features that yields a single optimal linear predictor shared across domains.
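To make the IRM idea concrete, the practical IRMv1 relaxation of Arjovsky et al. (2019) penalizes, per training domain, the gradient of that domain's risk with respect to a fixed scalar "dummy" classifier multiplying the representation. Below is a minimal NumPy sketch for a one-dimensional featurizer output and squared loss; the function name and setup are ours, for illustration only, not the paper's implementation.

```python
import numpy as np

def irm_penalty(z, y):
    """IRMv1-style penalty for one domain (illustrative sketch).

    z: featurizer outputs phi(x) for this domain (1-D array).
    y: targets (1-D array).
    With squared loss, the per-domain risk of a scalar dummy classifier w
    is R(w) = mean((w * z - y)^2). The penalty is the squared gradient of
    R with respect to w, evaluated at w = 1, which vanishes exactly when
    w = 1 is optimal for this domain.
    """
    grad = np.mean(2.0 * (z - y) * z)  # dR/dw at w = 1, computed in closed form
    return grad ** 2
```

In practice this penalty is summed over training domains and added, with a large multiplier, to the pooled empirical risk; a representation with a small penalty in every domain admits a shared optimal predictor across domains.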
However, recent work (Rosenfeld et al., 2020) has shown that the IRM objective does not necessarily identify strictly the causal predictors, i.e., the learned representation may include non-causal features. We therefore investigate the conditions necessary to learn the desired domain-general predictor and diagnose that the common domain-invariance Directed Acyclic Graph (DAG) constraint is insufficient to identify the set of causal mechanisms from observed domains both (i) strictly and (ii) wholly. This insight motivates us to specify appropriate conditions for learning domain-general models, which we propose to implement using regularizers.

Contributions. We show that neither a strict subset nor a strict superset of the invariant causal mechanisms is sufficient to learn domain-general predictors. Unlike previous work, we outline the constraints that identify the strict and complete set of causal mechanisms needed to achieve domain generality. We then propose regularizers that implement these constraints and empirically demonstrate the efficacy of our proposed algorithm against the state of the art on synthetic and real-world data. In doing so, we observe that conditional independence measures are effective for model selection, outperforming standard validation approaches. While our contributions are focused on methodology, our results also highlight gaps in standard evaluation using domain-average metrics, which we show can hide worst-case performance; the latter is arguably a more meaningful measure of domain generality.
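Conditional independence between continuous variables can be measured in several ways. One simple proxy, valid under roughly linear-Gaussian assumptions, is the partial correlation obtained by regressing the conditioning variable out of both quantities and correlating the residuals. The NumPy sketch below is illustrative only; it is not the specific measure used by TCRI, and the names are ours.

```python
import numpy as np

def partial_correlation(x, y, z):
    """Partial correlation of x and y given z (illustrative sketch).

    Regress z (with an intercept) out of x and y by least squares,
    then return the Pearson correlation of the residuals. A value
    near zero is consistent with x being independent of y given z
    under linear-Gaussian assumptions.
    """
    Z = np.column_stack([z, np.ones_like(z)])  # design matrix with intercept
    rx = x - Z @ np.linalg.lstsq(Z, x, rcond=None)[0]  # residual of x on z
    ry = y - Z @ np.linalg.lstsq(Z, y, rcond=None)[0]  # residual of y on z
    return np.corrcoef(rx, ry)[0, 1]
```

For model selection, such a statistic can be computed on held-out data to compare candidate representations: a representation whose residual dependence on the domain (given the target) is smaller is preferred.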

2. RELATED WORK

Domain adaptation and generalization have grown into large subfields in recent years. Thus, we do not attempt an exhaustive review and only highlight the papers most related to our work. To address covariate shift, Ben-David et al. (2009) give bounds on target error based on the H-divergence between the source and target covariate distributions, which motivates domain-alignment methods such as Domain Adversarial Neural Networks (Ganin et al., 2016). Follow-up work employs other notions of distance between covariate distributions for domain adaptation, such as maximum mean discrepancy (MMD) (Long et al., 2016) and the Wasserstein distance (Courty et al., 2017). However, Kpotufe and Martinet (2018) show that these divergence metrics fail to capture important properties of transferability, such as asymmetry and non-overlapping support. Zhao et al. (2019) show that even with distribution alignment of covariates, large distances between label distributions inhibit transfer; they propose a label-conditional importance-weighting method to address this limitation. Additionally, Schrouff et al. (2022) show that many real-world problems contain 'compound' shifts more complicated than covariate shift. Furthermore, domain-alignment methods are useful when one has unlabeled or partially labeled samples from the target domain during training; the domain generalization problem setting, however, may not include such information.

The notion of invariant representations begins to address domain generalization, the topic of this work, rather than domain adaptation. Arjovsky et al. (2019) propose an objective that learns a representation of the observed features which, when conditioned on, yields a distribution on targets that is domain-invariant, i.e., conditionally independent of the domain. They argue that satisfying this invariance gives a feature representation that uses only domain-general information. However, Rosenfeld et al. (2020) show that the IRM objective can fail to recover a predictor free of spurious correlations unless the number of observed domains exceeds the number of spurious features, which can inhibit generalization. Variants of this work (Krueger et al., 2021; Wang et al., 2022) address this problem by imposing higher-order moment constraints that reduce the necessary number of observed domains. However, as we will show, invariance on the observed domains is not sufficient for domain generalization. Additionally, one motivation for domain generalization is mitigating worst-domain performance. Gulrajani and Lopez-Paz (2020) observe empirically that ERM is competitive, and often best, in worst-domain accuracy across a range of real-world datasets. Rosenfeld et al. (2022) analyze domain generalization as extrapolation via bounded affine transformations and find that ERM remains minimax optimal in the linear regime. However, the extent of the worst-domain shift is often unknown in practice and may not be captured by bounded affine transformations (Shen et al., 2021). In contrast, our work allows for arbitrary distribution shifts, provided the invariant mechanisms remain unchanged. In addition, we show that our proposed method gives a predictor that recovers all domain-general mechanisms and is free of spurious correlations without necessitating examples (neither labeled nor unlabeled) from the target domain.
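For concreteness, the (biased) squared MMD between two samples under an RBF kernel, the quantity minimized by kernel-based alignment methods such as Long et al. (2016), can be estimated as in the NumPy sketch below. This is illustrative only; the function names are ours, and the bandwidth `gamma=1.0` is an arbitrary assumption (in practice it is tuned, e.g., by the median heuristic).

```python
import numpy as np

def rbf_kernel(a, b, gamma=1.0):
    """Pairwise RBF kernel matrix: k(a_i, b_j) = exp(-gamma * ||a_i - b_j||^2)."""
    d2 = (np.sum(a**2, axis=1)[:, None]
          + np.sum(b**2, axis=1)[None, :]
          - 2.0 * a @ b.T)  # squared Euclidean distances
    return np.exp(-gamma * d2)

def mmd2(x, y, gamma=1.0):
    """Biased estimate of the squared MMD between samples x and y."""
    return (rbf_kernel(x, x, gamma).mean()
            + rbf_kernel(y, y, gamma).mean()
            - 2.0 * rbf_kernel(x, y, gamma).mean())
```

Two samples from the same distribution yield an estimate near zero, while a covariate shift between them inflates it, which is what alignment methods penalize during representation learning.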

3. PROBLEM SETUP

We consider the data-generating process described by the causal graph in Figure 1 and the equivalent structural equation model (or structural causal model, SCM; Equation 1; Pearl (2010)). One setting where this graph applies is medicine, where we are often interested in predicting a condition from both its potential causes and its symptoms; these features may additionally be influenced by demographic factors that vary across hospitals (Schrouff et al., 2022). The graph also applies to physical processes with slow measurement, where one observes both upstream (causal) and downstream (anticausal) features of the events of interest. An example is the BOLD (Blood-Oxygen-Level-Dependent) signal in task-fMRI (functional Magnetic Resonance

