TARGET CONDITIONED REPRESENTATION INDEPENDENCE (TCRI): FROM DOMAIN-INVARIANT TO DOMAIN-GENERAL REPRESENTATIONS

Abstract

We propose a Target Conditioned Representation Independence (TCRI) objective for domain generalization. TCRI addresses a limitation of existing domain generalization methods: their constraints are incomplete. Specifically, TCRI implements regularizers motivated by conditional independence constraints that are sufficient to strictly identify the complete set of invariant mechanisms, which we show is necessary and sufficient for domain generalization. Empirically, we show that TCRI is effective on both synthetic and real-world data. TCRI is competitive with baselines in average accuracy while outperforming them in worst-domain accuracy, indicating the desired cross-domain stability.

1. INTRODUCTION

Machine learning algorithms are evaluated by their ability to generalize, i.e., to make reasonable predictions on unseen examples. Learning frameworks are often designed to exploit structure shared between the training data and the data expected at deployment. A common assumption is that training and testing examples are drawn independently from the same distribution (iid). Under the iid assumption, Empirical Risk Minimization (ERM; Vapnik, 1991) and its variants give strong generalization guarantees and are effective in practice. Nevertheless, many practical problems exhibit distribution shift between train and test domains, and ERM can fail in this setting (Arjovsky et al., 2019). This failure mode has impactful real-world implications, for example in safety-critical settings such as autonomous driving (Amodei et al., 2016; Filos et al., 2020), where a lack of robustness to distribution shift can lead to human casualties, and in ethical settings such as healthcare, where distribution shifts can lead to biases that adversely affect subgroups of the population (Singh et al., 2021).

To address this limitation, many works have developed approaches for learning under distribution shift. Among the various strategies to achieve domain generalization, Invariant Causal Prediction (ICP; Peters et al., 2016) has emerged as a popular one. ICP assumes that while some aspects of the data distributions may vary across domains, the causal structure (i.e., the data-generating mechanisms) remains the same, and it seeks to learn those domain-general causal predictors. Following ICP, Arjovsky et al. (2019) propose Invariant Risk Minimization (IRM) to identify invariant mechanisms by learning a representation of the observed features that yields a shared optimal linear predictor across domains.
However, recent work (Rosenfeld et al., 2020) has shown that the IRM objective does not necessarily strictly identify the causal predictors, i.e., the learned representation may include non-causal features. We therefore investigate the conditions necessary to learn the desired domain-general predictor and diagnose that the common domain-invariance Directed Acyclic Graph (DAG) constraint is insufficient to (i) strictly and (ii) wholly identify the set of causal mechanisms from observed domains. This insight motivates us to specify appropriate conditions for learning domain-general models, which we propose to implement with regularizers.

Contributions. We show that neither a strict subset nor a strict superset of the invariant causal mechanisms is sufficient to learn domain-general predictors. Unlike previous work, we outline constraints that identify the strict and complete set of causal mechanisms needed to achieve domain generality. We then propose regularizers implementing these constraints and empirically show the efficacy of our proposed algorithm compared to the state-of-the-art on synthetic and real-world data. In doing so, we observe that conditional independence measures are effective for model selection, outperforming standard validation approaches. While our contributions focus on methodology, our results also highlight existing gaps in standard evaluation with domain-average metrics, which we show can hide worst-case performance, arguably a more meaningful measure of domain generality.

2. RELATED WORK

Domain adaptation and generalization have grown into large subfields in recent years, so we do not attempt an exhaustive review and only highlight the papers most related to our work. To address covariate shift, Ben-David et al. (2009) give bounds on target error based on the H-divergence between the source and target covariate distributions, which motivates domain alignment methods such as Domain Adversarial Neural Networks (Ganin et al., 2016). Others have followed up with different notions of distance between covariate distributions for domain adaptation, such as maximum mean discrepancy (MMD; Long et al., 2016) and the Wasserstein distance (Courty et al., 2017). However, Kpotufe and Martinet (2018) show that these divergence metrics fail to capture important properties of transferability, such as asymmetry and non-overlapping support. Zhao et al. (2019) show that even with distribution alignment of covariates, large distances between label distributions inhibit transfer; they propose a label-conditional importance weighting method to address this limitation. Additionally, Schrouff et al. (2022) show that many real-world problems contain more complicated 'compound' shifts than covariate shifts. Furthermore, domain alignment methods are useful when one has unlabeled or partially labeled samples from the target domain during training, but the domain generalization problem setting may not include such information.

The notion of invariant representations begins to address domain generalization, the topic of this work, rather than domain adaptation. Arjovsky et al. (2019) propose an objective to learn a representation of the observed features which, when conditioned on, yields a distribution on targets that is domain-invariant, that is, conditionally independent of the domain. They argue that satisfying this invariance gives a feature representation that only uses domain-general information. However, Rosenfeld et al. (2020) show that the IRM objective can fail to recover a predictor free of spurious correlations unless the number of observed domains exceeds the number of spurious features, which can inhibit generalization. Variants of this work (Krueger et al., 2021; Wang et al., 2022) address this problem by imposing higher-order moment constraints to reduce the necessary number of observed domains. However, as we will show, invariance on the observed domains is not sufficient for domain generalization.

Additionally, one motivation for domain generalization is mitigating worst-domain performance. Gulrajani and Lopez-Paz (2020) observe empirically that ERM is competitive and often best in worst-domain accuracy across a range of real-world datasets. Rosenfeld et al. (2022) analyze domain generalization as extrapolation via bounded affine transformations and find that ERM remains minimax optimal in the linear regime. However, the extent of the worst-domain shift is often unknown in practice and may not be captured by bounded affine transformations (Shen et al., 2021). In contrast, our work allows for arbitrary distribution shifts, provided that the invariant mechanisms remain unchanged. In addition, we show that our proposed method gives a predictor that recovers all domain-general mechanisms and is free of spurious correlations without requiring examples (labeled or unlabeled) from the target domain.

3. PROBLEM SETUP

We consider the data-generating process described by the causal graph in Figure 1 and the equivalent structural causal model (SCM; Equation 1; Pearl, 2010). One setting where this graph applies is medicine, where we are often interested in predicting conditions from both potential causes and symptoms of the condition; these features may in turn be influenced by demographic factors that vary across hospitals (Schrouff et al., 2022). Similarly, in physical processes where measurement is slow, one observes both upstream (causal) and downstream (anticausal) features of the events of interest. An example is task-fMRI (functional Magnetic Resonance Imaging), where the BOLD (Blood-Oxygen-Level-Dependent) signal is much slower than the neural processes in the brain associated with the task (Glover, 2011). Many other real-world problems fall under this causal-and-anticausal setting; we note that this graph is also assumed by previous work (Arjovsky et al., 2019).

We also assume that the observed data are drawn from a set of training domains E_tr = {e_i : i = 1, 2, ..., |E_tr|}, all generated according to Equation 1, thereby fixing the mechanisms by which each observed distribution is generated:

$$\mathrm{SCM}(e_i) := \begin{cases} z_c^{(e_i)} \sim P_{Z_c}^{(e_i)}, \\ y^{(e_i)} = f_y\big(z_c^{(e_i)}, \nu\big), & \nu \perp z_c^{(e_i)}, \\ z_e^{(e_i)} = f_{z_e}\big(y^{(e_i)}, \eta^{(e_i)}\big), & \eta^{(e_i)} \perp y^{(e_i)}, \end{cases} \tag{1}$$

where P_{Z_c} is the causal covariate distribution, f_y and f_{z_e} are the generative mechanisms of y and z_e, respectively, and ν, η are exogenous variables. These mechanisms are assumed to hold for any domain generated by this generative process, i.e., µ_{e_i}(y | z_c) = µ(y | z_c) and ν_{e_i}(z_e | y) = ν(z_e | y) for all e_i ∈ E, for some distributions µ and ν, where E is the set of all possible domains.
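As a concrete illustration of this setup, the following sketch draws two domains from the SCM above. Linear mechanisms f_y, f_ze and Gaussian noise are simplifying assumptions made here for illustration only (the paper leaves the mechanisms general). The fitted slope of y on z_c should be stable across domains, while the slope of y on z_e should not be:

```python
import numpy as np

def sample_domain(n, sigma_c, sigma_eta, rng):
    """One domain of the SCM in Equation 1, with linear mechanisms and
    Gaussian noise assumed purely for illustration."""
    z_c = rng.normal(0.0, sigma_c, size=n)    # causal feature; distribution varies by domain
    nu = rng.normal(0.0, 0.25, size=n)        # exogenous noise, independent of z_c
    y = z_c + nu                              # y = f_y(z_c, nu): invariant mechanism
    eta = rng.normal(0.0, sigma_eta, size=n)  # exogenous noise, independent of y
    z_e = y + eta                             # z_e = f_ze(y, eta): anticausal feature
    return z_c, y, z_e

rng = np.random.default_rng(0)
# Two domains that share mechanisms but differ in noise scales.
zc0, y0, ze0 = sample_domain(20000, sigma_c=0.1, sigma_eta=0.1, rng=rng)
zc1, y1, ze1 = sample_domain(20000, sigma_c=1.0, sigma_eta=1.0, rng=rng)

# Regression of y on z_c is stable across domains (slope ~ 1 in both);
# regression of y on z_e drifts with the domain's noise scale.
slope_c0 = np.polyfit(zc0, y0, 1)[0]
slope_c1 = np.polyfit(zc1, y1, 1)[0]
slope_e0 = np.polyfit(ze0, y0, 1)[0]
slope_e1 = np.polyfit(ze1, y1, 1)[0]
```

This is exactly the asymmetry the text describes: the mechanism y → z_e is preserved, but the best predictor of y from z_e is not.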
Under the Markov assumption, we can immediately read off some properties of any distribution induced by the data-generating process in Figure 1: (i) e ⊥ Y | Z_c, (ii) Z_c ⊥ Z_e | Y, e, and (iii) Y ⊥̸ e | X. However, as shown in Figure 1, we only observe an unknown function of the latent variables z_c and z_e, namely x = h(z_c, z_e), and we would like predictions that do not depend on the domain e. Ideally, we want to map x → z_c so that the mechanism f_y can be learned, as this would suffice for domain generalization. In contrast, although the mechanism y → z_e is preserved, we have no such guarantee for the inverse mapping from z_e to y, which may not exist or be unique, and therefore does not satisfy domain generalization. This becomes a problem for mappings that include z_e: µ_{e_i}(y | z_e) ≠ µ_{e_j}(y | z_e) and µ_{e_i}(y | z_c, z_e) ≠ µ_{e_j}(y | z_c, z_e) for i ≠ j. The latter implies that µ_{e_i}(y | x) ≠ µ_{e_j}(y | x) for i ≠ j, and therefore the original observed features are not domain-general. The generative process and its implications are vital to our approach, as we assume it is preserved across domains.

Assumption 3.1. All distributions, train and test (observed and unobserved at train time), are generated by the generative process described in Equation 1.

In the following sections, we introduce our proposed algorithm for learning a feature extractor that maps to Z_c and show the generative properties that are necessary and sufficient to do so.

4. TARGET-CONDITIONED REPRESENTATION INDEPENDENCE OBJECTIVE

We first define two distinct types of representations of the observed features that we will need henceforth: domain-invariant and domain-general.

Definition 4.1. A domain-invariant representation Φ(X), with respect to a set of observed domains E_tr, is one that satisfies µ_{e_i}(y | Φ(X) = x) = µ_{e_j}(y | Φ(X) = x) for all e_i, e_j ∈ E_tr and any fixed x, where e_i, e_j are domain identifiers and µ is a probability distribution.

In other words, under a domain-invariant representation, the output conditional distribution for a given input is the same across the reference (typically observed) domains. This is consistent with the existence of a µ such that µ(y | Φ(X) = x) = µ(y | do(Φ(X) = x)), where do(·) is the "do-operator" (Pearl, 2010), denoting an intervention (arbitrary assignment of value). Conversely, for a domain-specific representation, the output conditional distribution for a given input need not be the same across domains. However, by definition, a domain-invariant representation is invariant only up to a specific set of domains, often the set of domains it is learned on. It may therefore not be domain-invariant with respect to test domains without additional assumptions on the set of training domains, e.g., on their convex hull (Rosenfeld et al., 2020). We thus need to refine this property to better specify the representations we would like to learn. This motivates the following definition, which ties domain-invariance to a specific generative process rather than to a set of observed distributions and, together with Assumption 3.1, connects causality to domain-generality.

Definition 4.2. A representation is domain-general with respect to a DAG G if it is domain-invariant with respect to E, where E is the set of all possible domains generated by G.

By Assumption 3.1, the causal mechanism from features Z_c to target Y is domain-general, so a natural strategy is to extract Z_c from the observed features.
To understand when we can recover Z_c from the observed X, we consider two conditional independence properties implied by the assumed causal graph: Y ⊥ e | Z_c, which we call the domain-invariance property, and Z_c ⊥ Z_e | Y, e, which we capture in the following target conditioned representation independence property (TCRI; Definition 4.3).

Definition 4.3 (Target Conditioned Representation Independence). Two functions Φ, Ψ are said to satisfy TCRI with respect to random variables X, Y, e_i if I(Φ(X), Ψ(X); Y) = I(Z_c, Z_e; Y) (the total-chain-information criterion) and Φ(X) ⊥ Ψ(X) | Y for all e_i.

We show in Section 5 that these properties together identify Z_c from X, Y, giving a domain-general representation. Based on these results, we design an algorithm to learn a feature mapping that recovers Z_c, i.e., the domain-general representation (mechanisms); Figure 2 illustrates the learning framework. In practice, we propose a TCRI objective containing four terms, each related to a property desired of the learned representations:

$$\mathcal{L}_{TCRI} = \frac{1}{|E_{tr}|} \sum_{e_i \in E_{tr}} \big[ \alpha \mathcal{L}_{\Phi} + (1 - \alpha) \mathcal{L}_{\Phi \oplus \Psi} + \lambda \mathcal{L}_{IRMv1'} + \beta \mathcal{L}_{CI} \big], \tag{2}$$

where α ∈ [0, 1], β > 0, and λ > 0 are hyperparameters. In detail, we let L_Φ = R^{e_i}(θ_c ∘ Φ) represent the domain-general predictor's risk, where Φ : X → H_Φ and θ_c : H_Φ → Y, and we let L_{IRMv1'} be a penalty on Φ and the linear predictor θ_c that enforces that Φ has the same optimal predictor across training domains, capturing the domain-invariance property (Arjovsky et al., 2019):

$$\mathcal{L}_{IRMv1'} = \big\| \nabla_{\theta_c} R^{e_i}(\theta_c \circ \Phi) \big\|^2,$$

where R^{e_i} denotes the empirical risk achieved on domain e_i. Together, L_Φ and L_{IRMv1'} implement the domain-invariance property; however, we know that this is not sufficient for domain generalization (Rosenfeld et al., 2020). We will show that adding the TCRI property suffices for domain generalization (Theorem 5.4).
To implement the TCRI property (Definition 4.3), we also learn a domain-specific representation Ψ, constrained to (i) be conditionally independent of the domain-general representation given the target and domain, and (ii) yield a predictor as good as one learned from X when combined with Φ. We first address (ii): given a domain-specific representation Ψ : X → H_Ψ, we define a set of domain-specific predictors {θ_{e_i} : H_Φ × H_Ψ → Y, i = 1, ..., |E_tr|} and add a term to the objective that minimizes their loss:

$$\mathcal{L}_{\Phi \oplus \Psi} = R^{e_i}\big(\theta_{e_i} \circ (\Phi \oplus \Psi)\big).$$

This term aims to enforce the total-chain-information criterion of TCRI by allowing Φ and Ψ to minimize a domain-specific loss together, so that non-domain-general information about the target can be used to improve performance. Since the domain-general and domain-specific representations each carry unique information about the target, the optimal model uses both. We allow θ_{e_i} to differ across domains since, by definition of the problem, we expect these mechanisms to vary across domains.

We define L_CI to be the conditional independence part of the TCRI property and use the V-statistic-based Hilbert-Schmidt Independence Criterion (HSIC) estimate (see Gretton et al. (2007) for details on HSIC). For the two representations Φ(X), Ψ(X) whose independence we want to assess (Φ(X) ⊥ Ψ(X)), define

$$\mathcal{L}_{CI} = \frac{1}{C} \sum_{k=1}^{C} \mathrm{HSIC}\big(\Phi(X)_k, \Psi(X)_k\big) = \frac{1}{C} \sum_{k=1}^{C} \frac{1}{n_k^2} \mathrm{trace}\big(K_\Phi H_{n_k} K_\Psi H_{n_k}\big),$$

where k indicates the class of the examples entering the estimate, C is the number of classes, K_Φ, K_Ψ ∈ R^{n_k × n_k} are Gram matrices with K_Φ^{i,j} = κ(Φ(X)_i, Φ(X)_j) and K_Ψ^{i,j} = ω(Ψ(X)_i, Ψ(X)_j) for radial basis function kernels κ, ω, H_{n_k} = I_{n_k} − (1/n_k) 1 1^⊤ is the centering matrix, I_{n_k} is the n_k × n_k identity matrix, 1 is the n_k-dimensional all-ones vector, and ⊤ denotes the transpose.
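The HSIC penalty can be sketched in a few lines of numpy. This is a minimal, self-contained version of the biased V-statistic estimate with RBF kernels; the bandwidth choice (gamma) is an illustrative assumption, and the per-class grouping used in L_CI is omitted here for brevity:

```python
import numpy as np

def rbf_gram(X, gamma=1.0):
    """Gram matrix K[i, j] = exp(-gamma * ||x_i - x_j||^2) for an RBF kernel."""
    sq = np.sum(X * X, axis=1)
    d2 = np.maximum(sq[:, None] + sq[None, :] - 2.0 * X @ X.T, 0.0)
    return np.exp(-gamma * d2)

def hsic(A, B, gamma=1.0):
    """Biased V-statistic HSIC estimate: (1/n^2) trace(K_A H K_B H)."""
    n = A.shape[0]
    H = np.eye(n) - np.ones((n, n)) / n  # centering matrix H_n = I - (1/n) 11^T
    return float(np.trace(rbf_gram(A, gamma) @ H @ rbf_gram(B, gamma) @ H) / n**2)

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
Z = rng.normal(size=(200, 2))  # drawn independently of X
dependent = hsic(X, X)         # strong dependence: large value
independent = hsic(X, Z)       # independence: value near zero
```

The biased estimate is always nonnegative, so it can be minimized directly as a penalty.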
We condition on a label and domain by taking only the examples of each label-domain pair, computing the empirical HSIC, and averaging across labels. We note that any criterion for conditional independence can be used as L_CI, e.g., partial covariance. Altogether, we have the following task:

$$\min_{\Phi, \Psi, \theta_c, \theta_1, \ldots, \theta_{|E_{tr}|}} \; \frac{1}{|E_{tr}|} \sum_{e_i \in E_{tr}} \big[ \alpha \mathcal{L}_{\Phi} + (1 - \alpha) \mathcal{L}_{\Phi \oplus \Psi} + \lambda \mathcal{L}_{IRMv1'} + \beta \mathcal{L}_{CI} \big]. \tag{3}$$

We compute the complete objective for each domain separately in order to condition on the domain e, and after minimizing this objective, only the invariant representation and its predictor, θ_c ∘ Φ, are used.
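The class-conditional averaging that turns a marginal dependence score into L_CI can be sketched as follows. For self-containedness, the sketch uses a simple squared cross-covariance score (a partial-covariance-style criterion, which the text notes is also admissible) rather than HSIC:

```python
import numpy as np

def cross_cov_sq(A, B):
    """Dependence score: mean squared entry of the cross-covariance matrix.
    A partial-covariance-style stand-in for HSIC."""
    A = A - A.mean(axis=0)
    B = B - B.mean(axis=0)
    C = A.T @ B / len(A)
    return float(np.mean(C**2))

def l_ci(phi_x, psi_x, y, measure=cross_cov_sq):
    """L_CI within one domain: average the dependence score over examples
    grouped by class label (conditioning on Y; conditioning on e is implicit,
    since the score is computed per domain)."""
    return float(np.mean([measure(phi_x[y == k], psi_x[y == k]) for k in np.unique(y)]))

rng = np.random.default_rng(0)
y = rng.integers(0, 2, size=4000)
# Case 1: Phi and Psi are related only through y -> small penalty after conditioning.
phi = y[:, None] + rng.normal(size=(4000, 1))
psi = y[:, None] + rng.normal(size=(4000, 1))
# Case 2: Psi copies Phi directly -> large penalty even after conditioning on y.
psi_dep = phi + 0.1 * rng.normal(size=(4000, 1))
```

The penalty is near zero in case 1 and large in case 2, which is exactly the distinction TCRI asks L_CI to enforce.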

5. CONDITIONS FOR A DOMAIN-GENERAL REPRESENTATION

We now provide analysis to justify our method. Consider a feature extractor Φ : X → Z, where X, Z are the input and latent feature spaces, respectively. We first show that a Φ that captures a strict subset of the causal features satisfies the domain-invariance property Y ⊥ e | Φ(X) while not necessarily being domain-general.

Lemma 5.1 (Insufficiency of causal subsets for domain generalization). Conditioning on a subset of the causal variables (invariant mechanisms) does not imply domain generalization (Definition 4.2):

$$Z \subset Z_c \;\not\Longrightarrow\; \mu_{e_i}(y \mid Z = z) = \mu_{e_j}(y \mid Z = z) \quad \forall\, e_i \neq e_j,\; z \in Z,$$

where Z_c is the causal feature space.

Proof. We provide a simple counterexample. Suppose we have the generative process d_{e_i,1} → z_1 → y ← z_2 ← d_{e_i,2} with z_1, z_2 ∈ R, d_{e_i,1}, d_{e_i,2} ∈ R_{≥0}, r ∈ R_{>0}, and

$$y = \begin{cases} 0 & \text{if } z_1^2 + z_2^2 \le r, \\ 1 & \text{if } z_1^2 + z_2^2 > r. \end{cases}$$

Suppose we observe domains {e_i : i = 1, ..., |E_tr|}, where z_1 ∼ Uniform(−r + d_{e_i,1}, r + d_{e_i,1}) and z_2 ∼ Uniform(−r + d_{e_i,2}, r + d_{e_i,2}) for domain e_i, and d_{e_i} is a domain-specific quantity. Now suppose we condition on z_1 alone. The optimal predictor µ_{e_i}(y = 1 | z_1) is the probability that z_2^2 > r − z_1^2 under z_2 ∼ Uniform(−r + d_{e_i,2}, r + d_{e_i,2}), which depends on d_{e_i,2}. Thus z_1, a causal subset, does not yield a domain-general representation, since its optimal predictor depends on d_{e_i}, which changes across domains. We provide a visualization of Lemma 5.1 in Figure 3.

Additionally, we show that subsets of the causal features can satisfy the domain-invariance property.

Lemma 5.2. A representation Φ that maps to a strictly causal subset can be domain-invariant (Definition 4.1).

Proof. This proof follows a similar argument to Lemma 5.1. We replace the mechanism for z_1 with z_1 ∼ Uniform(−d_{e_i,1} · r, d_{e_i,1} · r), i.e., the domain scales z_1. Suppose d_{e_i,2} = 0 for all training domains {e_i : i = 1, ..., |E_tr|} and Φ(z_1, z_2) = z_1. There is still a covariate distribution shift due to the varying scales d_{e_i,1}. Then Φ is domain-invariant on the training domains but is no longer domain-general with respect to any domain e_j with d_{e_j,2} > 0.

In other words, one failure case illustrating the limitation of the domain-invariance property (Y ⊥ e | Φ(X)) alone is when the shifts observed in the training set differ from those at test time. Specifically, an incomplete causal representation may not be domain-general if the test-time shifts act on causal features that the incomplete representation does not capture. We now show that the domain-invariance property and the Target Conditioned Representation Independence (TCRI) property together are sufficient for recovering the complete set of invariant mechanisms.
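The counterexample of Lemma 5.1 is easy to check numerically. The following Monte Carlo sketch (with r = 1 and the illustrative shift values d_{e_i,2} ∈ {0, 0.8}, chosen here for clarity) estimates µ_e(y = 1 | z_1 ≈ 0) in two domains and shows that a predictor based on the causal subset z_1 alone is not stable:

```python
import numpy as np

def sample(n, d1, d2, r, rng):
    """The Lemma 5.1 counterexample: z1, z2 uniform with domain shifts d1, d2,
    and y = 1 iff z1^2 + z2^2 > r."""
    z1 = rng.uniform(-r + d1, r + d1, size=n)
    z2 = rng.uniform(-r + d2, r + d2, size=n)
    y = (z1**2 + z2**2 > r).astype(float)
    return z1, z2, y

rng = np.random.default_rng(0)
r, n = 1.0, 400000
z1_a, _, y_a = sample(n, d1=0.0, d2=0.0, r=r, rng=rng)  # domain A: no shift
z1_b, _, y_b = sample(n, d1=0.0, d2=0.8, r=r, rng=rng)  # domain B: shifted along z2

# Estimate mu_e(y = 1 | z1 near 0) in each domain.
p_a = y_a[np.abs(z1_a) < 0.1].mean()  # near 0 in domain A
p_b = y_b[np.abs(z1_b) < 0.1].mean()  # substantially larger in domain B
```

Even though z_1 is causal, its optimal predictor shifts with d_{e_i,2}, exactly as the lemma states.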
Lemma 5.3 shows that the TCRI property directly solves the problem identified in Lemma 5.1: the TCRI property implies that all causal information about the target is aggregated into one representation, Φ. To identify Z_c, it then only remains to show that Φ is strictly causal.

Theorem 5.4 (Sufficiency of the TCRI and domain-invariance properties for identifying Z_c). Recall that Z_c, Z_e are the true latent features in the assumed generative model and X are the observed features (Figure 1). If Φ, Ψ satisfy TCRI and the domain-invariance property, then Φ recovers Z_c and is therefore domain-general.

Proof. By Lemma 5.3, when Φ, Ψ satisfy TCRI, Φ(X) contains all of the causal (necessarily domain-general) information in X. However, Φ(X) may also contain non-causal (domain-specific) information, spoiling the domain-generality of Φ. It remains to show that Φ(X) is strictly causal once we add the domain-invariance property. If Φ satisfies the domain-invariance property, then Y ⊥ e | Φ(X). This cannot hold if Φ(X) contains features of Z_e: since Z_e is a collider of e and Y, conditioning on Z_e opens a path between e and Y, making them dependent. Thus Φ(X) can only contain causal features. Therefore, a representation satisfying TCRI and the domain-invariance property is wholly and strictly causal and thus domain-general, the latter following from Z_c having invariant mechanisms.

Theorem 5.4 suggests that by learning two representations that together capture the mutual information between the observed X and Y, where one satisfies the domain-invariance property and both satisfy TCRI, one can recover the strict and complete causal feature extractor and a domain-general predictor.

Remark 5.5 (Revisiting Lemma 5.1's counterexample). Given two representations Φ, Ψ that satisfy TCRI, Φ necessarily captures both z_1 and z_2.
By definition, Φ and Ψ must together capture all of the information in z_1, z_2 about y, and we know from the graph that z_1 and z_2 are conditionally dependent given y: they are both causes of the collider y, so conditioning on y renders the marginally independent variables dependent. Hence z_1 and z_2 must be captured by the same feature extractor.

Remark 5.6. One limitation of TCRI is a failure mode in which the strictly anticausal representation gives a domain-invariant predictor. In this case, either representation may play the role of Φ. However, one benefit of having a domain-specific predictor for each observed domain is that we can check whether those classifiers are interchangeable. Specifically, if the causal features are mapped to the domain-specific feature extractor, the domain-specific classifiers will give similar results when applied to domains they were not trained on, since they are based on invariant causal mechanisms. This, however, gives a test, not a fix, for this setting; we leave a fix for future work.

6. EXPERIMENTS

To evaluate our method in a setting that exactly matches our assumptions and where we know the ground-truth mechanisms, we generate synthetic data from

$$\mathrm{SCM}(e_i) := \begin{cases} z_c^{(e_i)} \sim \mathrm{Exp}(\sigma_{e_i}), \\ y^{(e_i)} = z_c^{(e_i)} + \mathrm{Exp}(0.25), \\ z_e^{(e_i)} = y^{(e_i)} + \mathrm{Exp}\big(\sigma^{\eta}_{e_i}\big), \end{cases} \tag{4}$$

with domain parameters σ_{e_i}, σ^η_{e_i}; code is provided in the supplemental materials. We observe two domains with parameters σ_{e_i=0} = σ^η_{e_i=0} = 0.1 and σ_{e_i=1} = σ^η_{e_i=1} = 1, each with 1000 samples, and use linear feature extractors and predictors. We fix θ_c = 1, i.e., ŷ = x × Φ × w, where x ∈ R^{N×2}, Φ ∈ R^{2×1}, w ∈ R. Minimizing the TCRI objective (Equation 3) recovers a linear feature representation that maps back to z_c (Table 1, which reports the learned feature-extractor weights; Oracle indicates the coefficients achieved by regressing y on z_c directly). Note that ERM corresponds to λ = β = 0, α = 1, and IRM to λ = 0.1, β = 0, α = 1; additional discussion of the algorithms can be found in the appendix.

We also evaluate our proposed method on real-world datasets. Given observed domains E_tr = {e_i : i = 1, 2, ..., |E_tr|}, we train on E_tr \ {e_i} and evaluate the model on the unseen domain e_i, for each e_i.

Model selection: Typically, ML practitioners use a within-domain hold-out validation set for model selection. However, this strategy is biased towards the empirical risk minimizer, i.e., the model with the lowest error on validation sets from the training domains, and we know that the model achieving the highest validation accuracy may not be domain-general. The same is true for an out-of-domain validation set that is not from the target domain. Alternatively, we propose to leverage the generative assumptions for model selection; specifically, we use another property of our desired model, a low L_CI, as the selection criterion.
To do this, we follow the practice of a hold-out within-domain validation set, but we compute L_CI on the validation set and choose the model with the lowest CI score instead of the one with the highest validation accuracy. We compare this strategy with validation accuracy in our results; additional details can be found in Appendix C.

ColoredMNIST: We evaluate our method on the ColoredMNIST dataset (Arjovsky et al., 2019), which is composed of 7000 (2 × 28 × 28) image and binary-label pairs. There are three domains with different correlations between image color and label: the image color is spuriously related to the label by assigning a color to each of the two classes (0: digits 0-4, 1: digits 5-9), and the color is then flipped with probability {0.1, 0.2, 0.9} to create three domains, making the color-label relationship domain-specific. There is also label-flip noise of 0.25, so the best achievable accuracy is 75%. As in Figure 1, Z_c corresponds to the original image, Z_e to the color, e to the label-color correlation, Y to the image label, and X to the observed colored image. Code (a variant of https://github.com/facebookresearch/DomainBed) can be found at https://anonymous.4open.science/r/DomainBed-8D3F. We use the MNIST-ConvNet backbone (Gulrajani and Lopez-Paz, 2020) for the MNIST datasets and parameterize our experiments with the DomainBed hyperparameters, with three trials to select the best model (Gulrajani and Lopez-Paz, 2020). The MNIST-ConvNet backbone corresponds to the generic featurizer F in Figure 2, and Φ and Ψ are linear layers of size 128 × 128 appended to the backbone. The predictors θ_c, θ_1, ..., θ_{|E_tr|} are also linear, appended to the Φ and Ψ layers, respectively.
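The ColoredMNIST labeling recipe described above can be sketched without the images themselves. The function below reproduces the binarize, label-flip, and color-flip steps, and the per-domain color-label agreement it yields shows why color is spurious (the function name and sample count are illustrative):

```python
import numpy as np

def colored_mnist_labels(digits, color_flip_p, rng, label_flip_p=0.25):
    """ColoredMNIST recipe (Arjovsky et al., 2019): binarize the digit, add 25%
    label noise, then color by the noisy label and flip the color with a
    domain-specific probability."""
    y_clean = (digits >= 5).astype(int)                       # 0: digits 0-4, 1: digits 5-9
    y = y_clean ^ (rng.random(digits.size) < label_flip_p)    # label-flip noise of 0.25
    color = y ^ (rng.random(digits.size) < color_flip_p)      # spurious color channel
    return y.astype(int), color.astype(int)

rng = np.random.default_rng(0)
digits = rng.integers(0, 10, size=200000)
y_p90, c_p90 = colored_mnist_labels(digits, 0.1, rng)  # +90% domain
y_m90, c_m90 = colored_mnist_labels(digits, 0.9, rng)  # -90% domain
agree_p90 = float(np.mean(y_p90 == c_p90))  # ~0.9: color is highly predictive
agree_m90 = float(np.mean(y_m90 == c_m90))  # ~0.1: the relationship is reversed
```

A model that leans on color in the +90% and +80% domains is therefore worse than chance in the -90% domain.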

6.2. RESULTS AND DISCUSSION

We observe that TCRI is competitive in average accuracy with the baseline methods. However, it is significantly more stable when using conditional independence for model selection: the worst-domain accuracy is highest and the variance across domains is lowest for TCRI, both by a large margin. Note that domain -90%, which has a color-label flip probability of 0.9, has a majority color-label pairing opposite to that of domains +90% and +80%, with flip probabilities of 0.1 and 0.2, respectively. Accordingly, the baseline algorithms generalize poorly to domain -90% relative to the other two domains, a clear indication that, unlike TCRI, the baselines use spurious information (color) for prediction. While TCRI does not reach the best achievable accuracy of 75%, the low variance in cross-domain accuracy indicates that the information it uses for prediction is general across the three domains.

Worst-domain accuracy: An important implication of a domain-general predictor is stability: robustness in worst-domain performance, up to domain difficulty. While average accuracy across domains provides some insight into an algorithm's ability to generalize to new domains, it can be dominated by the performance on a subset of the observed domains. For example, ARM outperforms the baselines in average accuracy, but this improvement is driven primarily by the first domain (+90%), while the worst-domain accuracy stays the same. In the context of real-world challenges such as algorithmic fairness, comparable worst-domain accuracy is necessary (Hardt et al., 2016). TCRI achieves 5x (53.0 vs. 10.2) the worst-domain accuracy of the best baseline while maintaining competitive average accuracy, outperforming most of the baselines on average.
Additionally, the variance of the domain accuracies is over 10x lower than that of the baselines, further evidence of the cross-domain stability one would expect from a domain-general algorithm. Notably, the baselines include V-REx, which explicitly regularizes risk variance across the observed domains.
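The distinction between average and worst-domain evaluation can be made concrete with a small helper. The per-domain accuracies below are illustrative numbers, not the paper's results:

```python
import numpy as np

def domain_metrics(accs):
    """Summarize per-domain accuracies: the average can mask a catastrophic
    worst domain, which the worst-domain and variance views expose."""
    a = np.asarray(accs, dtype=float)
    return {"average": a.mean(), "worst": a.min(), "variance": a.var()}

# Hypothetical per-domain accuracies for two methods:
baseline = domain_metrics([0.90, 0.85, 0.10])  # strong average, collapses on one domain
stable = domain_metrics([0.60, 0.58, 0.55])    # lower average, far better worst case
```

Ranking methods by "average" alone would prefer the baseline here, even though its worst-domain behavior is catastrophic.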

Limitations:

The strength of TCRI is also its limitation: TCRI is conservative, so as to be robust to worst-domain shifts. While many important real-world problems require robustness to worst-domain shifts, this is not always the case, and in such settings TCRI sacrifices performance gains from non-domain-general information that may nonetheless be domain-invariant with respect to the expected domains. The practitioner should apply this method when it is appropriate, i.e., when domain generality is critical and the target domains may be sufficiently different from the source domains. It is, however, important to note that in many settings where one is satisfied with domain-invariance rather than domain-generality, ERM may be sufficient (Gulrajani and Lopez-Paz, 2020; Rosenfeld et al., 2022).

7. CONCLUSION AND FUTURE WORK

We address the inability of state-of-the-art algorithms to learn domain-general predictors by developing an objective that enforces DAG properties sufficient to disentangle causal (domain-general) and anticausal mechanisms. We compare domain generality to domain-invariance, showing that our method is competitive with other state-of-the-art domain generalization algorithms on real-world datasets in terms of average accuracy across domains. Moreover, TCRI outperforms all baseline algorithms in worst-domain accuracy, indicating the desired stability across domains. We also find that using conditional independence metrics for model selection outperforms the typical validation-accuracy strategy. Future work includes investigating other model selection strategies that preserve the desired domain-generality properties and curating more real-world benchmark datasets that exhibit worst-case behavior.

A SIMULATED DATA

We observe two domains with parameters σ_{e_i=0} = σ^η_{e_i=0} = 0.1 and σ_{e_i=1} = σ^η_{e_i=1} = 1, each with 1000 samples. We let θ_c = 1 and use linear feature extractors and predictors. Minimizing the TCRI objective (Equation 3) recovers a linear feature representation that maps back to z_c (Table 1). Note that ERM corresponds to λ = β = 0, α = 1; IRM to λ = 0.1, β = 0, α = 1; and TCRI to λ = 0.1, β = 10, α = 0.75. In addition to letting θ_c = 1 be a dummy variable, we also solve the OLS (Ordinary Least Squares, (X^⊤X)^{-1} X^⊤ y) problem to compute the L_{Φ⊕Ψ} term in the loss. Each backward pass takes all examples from a domain as a batch.
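The closed-form OLS step mentioned above is, in a minimal numpy sketch:

```python
import numpy as np

def ols(X, y):
    """Closed-form ordinary least squares, w = (X^T X)^{-1} X^T y, computed via
    a linear solve rather than an explicit inverse for numerical stability."""
    return np.linalg.solve(X.T @ X, X.T @ y)

# A small worked example with an exact fit: w = [1, 2].
X = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = np.array([1.0, 2.0, 3.0])
w = ols(X, y)
```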

B HYPERPARAMETER SELECTION

We perform a random search over hyperparameters for our method, with six randomly selected hyperparameter sets in total. Additionally, for each set, we run three trials to generate standard errors. Additional sampling details can be found in https://anonymous.4open.science/r/DomainBed-8D3F/domainbed/hparams_registry.py. We use the default values for the baseline algorithms, since their results closely match those reported by Gulrajani and Lopez-Paz (2020). The hyperparameters used for each trial are also provided in the supplemental material.

C MODEL SELECTION

Across the hyperparameter sweep, we select the model with the lowest average conditional independence score between the two TCRI representations (Definition 4.3) to evaluate on our test set, in lieu of selecting the model with the highest validation accuracy on the training domains. Additionally, we show results based on oracle selection, that is, selection based on held-out target-domain data. We observe that TCRI still outperforms the baseline methods and achieves source-domain accuracies closer to the baselines'. This suggests that the regularizers are not so harsh as to prevent TCRI models from learning good predictors in practice. The results do suggest, however, that there is room for improvement in model selection; we leave this for future work.
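A minimal sketch of this selection rule, using a within-class cross-covariance as a stand-in for the conditional independence score of Definition 4.3 (the paper's exact score may differ; `select_model` and its candidate format are hypothetical):

```python
import numpy as np

def conditional_cross_covariance(phi, psi, y):
    # Average absolute within-class (i.e., conditioned-on-Y) cross-covariance
    # between the two representations: lower means closer to conditional
    # independence. phi: (n, d1), psi: (n, d2), y: (n,) integer labels.
    scores = []
    for label in np.unique(y):
        p = phi[y == label] - phi[y == label].mean(axis=0)
        q = psi[y == label] - psi[y == label].mean(axis=0)
        cov = p.T @ q / max(len(p) - 1, 1)
        scores.append(np.abs(cov).mean())
    return float(np.mean(scores))

def select_model(candidates):
    # candidates: list of (model_id, phi, psi, y) tuples from the sweep;
    # return the id with the lowest conditional independence score.
    return min(candidates, key=lambda c: conditional_cross_covariance(*c[1:]))[0]
```

Note that this criterion needs no target-domain data, unlike the oracle selection also reported.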

D ON BENCHMARK DATASETS FOR EVALUATING DOMAIN GENERALIZATION: WORST-CASE

We show results below that illustrate the challenge of accurately evaluating the efficacy of a domain generalization algorithm. We first note that we expect ERM (naive) to perform poorly on domain generalization tasks, certainly when we observe worst-case shifts at test time. However, like other works (Gulrajani and Lopez-Paz, 2020), we observe that ERM performs as well as other baselines during transfer on various benchmark datasets. Previous theoretical results (Rosenfeld et al., 2022) suggest that this observation is indicative of properties of the benchmark domains that may be sufficient for domain generalization with ERM: specifically, that the distribution (and equivalently the loss) of the target domain can be written as a convex combination of those of the source domains. To investigate this further, we develop additional experiments motivated by ColoredMNIST (Arjovsky et al., 2019), which does not appear to fall into the scenario of Rosenfeld et al. (2022). Among the +90%, +80%, and -90% domains of ColoredMNIST, the -90% domain has the opposite relationship between the spurious correlation and the label, so relying on the spurious correlation generalizes catastrophically in the -90% domain. In this setting, the baseline algorithms we present achieve poor accuracy in the -90% domain while maintaining high accuracy in the +90% and +80% domains. Consequently, we investigate two settings, setting a: {+90%, +80%, +70%, -90%} domains, and setting b: {+90%, +80%, -80%, -90%} domains. In setting a, we add another domain with the majority direction of the relationship between the spurious correlation and the labels; in setting b, we add another domain with the minority direction. We use oracle model selection (held-out target data) for all methods to remove the effect of model selection on the results.
We find that in setting a, where we add a domain (+70%) whose spurious correlations do not generalize to the -90% domain, the worst-case accuracy across baselines remains very different from the median case (Table 4).
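The domain construction above can be sketched as follows. The 25% label noise and the flip-probability mechanism follow the standard ColoredMNIST recipe of Arjovsky et al. (2019); `colorize` is a hypothetical helper, and the flip probabilities map "+90%" to 0.1, "+70%" to 0.3, "-80%" to 0.8, and "-90%" to 0.9.

```python
import numpy as np

def colorize(digits, flip_prob, rng):
    # ColoredMNIST-style domain: binary label from the digit, 25% label
    # noise, then a color equal to the (noisy) label flipped with
    # probability flip_prob. flip_prob < 0.5 gives a "+" domain (color
    # mostly agrees with the label), flip_prob > 0.5 a "-" domain.
    labels = (digits >= 5).astype(int)
    labels ^= rng.random(len(labels)) < 0.25          # label noise
    colors = labels ^ (rng.random(len(labels)) < flip_prob)
    return labels, colors

rng = np.random.default_rng(0)
digits = rng.integers(0, 10, size=20000)
# Setting a adds a majority-direction domain (+70%, flip 0.3);
# setting b adds a minority-direction domain (-80%, flip 0.8).
setting_a = [colorize(digits, p, rng) for p in (0.1, 0.2, 0.3, 0.9)]
setting_b = [colorize(digits, p, rng) for p in (0.1, 0.2, 0.8, 0.9)]
```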



Figure 1: Graphical model depicting the structure of our data-generating process; shaded nodes indicate observed variables. X represents the observed features, Y represents the observed targets, and e represents domain influences. There is an explicit separation of domain-general (causal) Z_c and domain-specific (anticausal) Z_e features, which are combined to generate the observed X.

Figure 2: We first generate a feature representation via the featurizer F. During training, both representations, Φ and Ψ, generate predictions: domain-general and domain-specific predictions, respectively. However, only the domain-invariant representations/predictions are used at test time, indicated by the solid red arrows. ⊕ indicates concatenation.

Figure 3: Visualization of Lemma 5.1, where x_1 = z_1 and x_2 = z_2. The large circle is the true decision boundary, where × : y = 0 and • : y = 1. The dotted square indicates the covariate distribution (uniform), which differs across domains. The length of the solid line indicates µ(y = 0 | z_1 = 0). Clearly, the length of each line changes across domains; hence the causal subset z_1 does not yield a domain-general predictor.

Lemma 5.3 (Sufficiency of TCRI for Causal Aggregation). Recall X, Z_c, Z_e, Y from Figure 1. Let Z_c, Z_e be direct causes and direct effects of Y, respectively, and recall that X is a function of Z_c and Z_e. If the two representations induced by the feature extractors Φ, Ψ satisfy TCRI, then, wlog, I(Φ(X); Y) ≥ I(Z_c; Y).

6.1 DATASETS

Algorithms: We compare our method to the following baselines: Empirical Risk Minimization (ERM; Vapnik (1991)), Invariant Risk Minimization (IRM; Arjovsky et al. (2019)), Variance Risk Extrapolation (V-REx; Krueger et al. (2021)), Meta-Learning for Domain Generalization (MLDG; Li et al. (2018)), Group Distributionally Robust Optimization (GroupDRO; Sagawa et al. (2019)), and Adaptive Risk Minimization (ARM; Zhang et al. (2021)).

Colored MNIST. 'acc' indicates model selection via validation accuracy, and 'cov' indicates model selection via validation conditional independence. 'ci' indicates a conditional cross-covariance penalty, and 'HSIC' a Hilbert-Schmidt Independence Criterion penalty. Columns {+90%, +80%, -90%} indicate domains with {0.1, 0.2, 0.9} digit-label and color correlation, respectively. μ and σ indicate the mean and standard deviation of the average domain accuracies, over 3 trials each.
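For reference, the 'HSIC' penalty named in the caption can be sketched as a biased Hilbert-Schmidt Independence Criterion estimate with RBF kernels; this is a generic textbook estimator, not necessarily the exact penalty implemented in the paper, and the bandwidth `sigma` is an assumed default.

```python
import numpy as np

def hsic(x, y, sigma=1.0):
    # Biased HSIC estimate: trace(K H L H) / (n - 1)^2, where K, L are RBF
    # kernel matrices over x and y, and H centers them. Larger values
    # indicate stronger statistical dependence between x and y.
    def rbf(a):
        sq = ((a[:, None, :] - a[None, :, :]) ** 2).sum(-1)
        return np.exp(-sq / (2 * sigma ** 2))
    n = len(x)
    h = np.eye(n) - np.ones((n, n)) / n
    k, l = rbf(x), rbf(y)
    return float(np.trace(k @ h @ l @ h) / (n - 1) ** 2)
```

Unlike the linear cross-covariance ('ci') penalty, HSIC is sensitive to nonlinear dependence between the two representations.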

Colored MNIST. Columns {+90%, +80%, -90%} indicate domains with {0.1, 0.2, 0.9} digit-label and color correlation, respectively. μ and σ indicate the mean and standard deviation of the average domain accuracies, over 3 trials each. Using the oracle selection method (held-out target data).

Colored MNIST setting a. Columns {+90%, +80%, +70%, -90%} indicate domains with {0.1, 0.2, 0.3, 0.9} digit-label and color correlation, respectively. μ and σ indicate the mean and standard deviation of the average domain accuracies, over 3 trials each. Using the oracle selection method (held-out target data).

E THEORETICAL RESULTS

Lemma E.1 (Sufficiency of TCRI for Causal Aggregation). Recall X, Z_c, Z_e, Y from Figure 1. Let Z_c, Z_e be direct causes and direct effects of Y, respectively, and recall that X is a function of Z_c and Z_e. If the two representations induced by the feature extractors Φ, Ψ satisfy TCRI, then, wlog, I(Φ(X); Y) ≥ I(Z_c; Y).

Proof.
(i) First, we define the K^i_{Z_c}'s to be random variables with non-zero mutual information with Z_c, both marginally and conditioned on Y.
(ii) From (i), we have that for any pair …
(iii) Given the total-chain-information criterion, there exists a set of …
(iv) Combining (ii) and (iii), we have that all Z_c^i's are aggregated in one of the two representations, say Φ, since for any Z_c^k that satisfies (i), (ii) implies Z_c^k ∈ Φ(X), and therefore (iii) implies I(Φ(X); Y) ≥ I(Z_c; Y).

