CAUSAL BALANCING FOR DOMAIN GENERALIZATION

Abstract

While machine learning models rapidly advance the state-of-the-art on various real-world tasks, out-of-domain (OOD) generalization remains a challenging problem given the vulnerability of these models to spurious correlations. We propose a balanced mini-batch sampling strategy to transform a biased data distribution into a spurious-free balanced distribution, based on the invariance of the underlying causal mechanisms of the data generation process. We argue that the Bayes optimal classifiers trained on such a balanced distribution are minimax optimal across a diverse enough environment space. We also provide an identifiability guarantee for the latent variable model of the proposed data generation process, given sufficiently many training environments. Experiments on DomainBed demonstrate empirically that our method obtains the best performance across the 20 baselines reported on the benchmark.

1. INTRODUCTION

Machine learning is achieving tremendous success in many fields with useful real-world applications (Silver et al., 2016; Devlin et al., 2019; Jumper et al., 2021). While machine learning models can perform well on in-domain data sampled from seen environments, they often fail to generalize to out-of-domain (OOD) data sampled from unseen environments (Quiñonero-Candela et al., 2009; Szegedy et al., 2014). One explanation is that machine learning models are prone to learning spurious correlations that change between environments. For example, in image classification, instead of relying on the object of interest, models easily rely on surface-level textures (Jo & Bengio, 2017; Geirhos et al., 2019) or image backgrounds (Beery et al., 2018; Zhang et al., 2020). This vulnerability to changes in environments can cause serious problems for machine learning systems deployed in the real world, calling their reliability over time into question. Various methods have been proposed to improve OOD generalizability by exploiting the invariance of causal features or of the underlying causal mechanism (Pearl, 2009) through which data is generated. Such methods often aim to find invariant data representations using new loss function designs that incorporate invariance conditions across different domains into the training process (Arjovsky et al., 2020; Mahajan et al., 2021; Liu et al., 2021a; Lu et al., 2022; Wald et al., 2021). Unfortunately, these approaches face a trade-off between restrictive linear models and approaches without theoretical guarantees (Arjovsky et al., 2020; Wald et al., 2021), and empirical studies have shown their utility in the real world to be questionable (Gulrajani & Lopez-Paz, 2020). In this paper, we consider the setting in which multiple training domains/environments are available.
We theoretically show that the Bayes optimal classifier trained on a balanced (spurious-free) distribution is minimax optimal across all environments. We then propose a principled two-step method to sample balanced mini-batches from such a balanced distribution: (1) learn the observed data distribution using a variational autoencoder (VAE) and identify the latent covariate; (2) match training examples with the closest latent covariates to create balanced mini-batches. By only modifying the mini-batch sampling strategy, our method is lightweight and highly flexible, enabling seamless incorporation into complex classification models or improvement upon other domain generalization methods. Our contributions are as follows: (1) We propose a general non-linear causality-based framework for domain generalization in classification tasks; (2) We prove that a spurious-free balanced distribution produces minimax optimal classifiers for OOD generalization; (3) We rigorously demonstrate that the source of spurious correlation, as a latent variable, can be identified in a nonlinear setting given a large enough set of training environments; (4) We propose a novel and principled balanced mini-batch sampling algorithm that, in an ideal scenario, removes the spurious correlations in the observed data distribution; (5) Our empirical results show that our method obtains significant performance gains compared to 20 baselines on DomainBed (Gulrajani & Lopez-Paz, 2020).

2. PRELIMINARIES

Problem Setting. We consider a standard domain generalization setting with a potentially high-dimensional variable X (e.g., an image), a label variable Y, and a discrete environment (or domain) variable E, with sample spaces X, Y, E, respectively. Here we focus on classification problems with Y = {1, 2, ..., m} and X ⊆ R^d. We assume that the training data are collected from a finite subset of training environments E_train ⊂ E. The training data D^e = {(x^e_i, y^e_i)}_{i=1}^{N^e} are then sampled from the distribution p^e(X, Y) = p(X, Y | E = e) for each e ∈ E_train. Our goal is to learn a classifier C_ψ : X → Y that performs well in a new, unseen environment e_test ∉ E_train.

We assume that the observed data distribution p^e(X, Y) arises from a data generation process represented by an underlying structural causal model (SCM), shown in Figure 1a. More specifically, we assume that X is caused by the label Y, an unobserved latent variable Z (with sample space Z ⊆ R^n), and an independent noise variable ε, with the following formulation: X = f(Y, Z) + ε = f_Y(Z) + ε. We assume the causal mechanism is invariant across all environments e ∈ E, and we further characterize f with the following assumption:

Assumption 2.1. f : {1, 2, ..., m} × Z → X is injective, and f^{-1} : X → {1, 2, ..., m} × Z is its left inverse.

Note that this assumption forces the generation process of X to depend on both Z and Y instead of only one of them. Suppose ε has a known probability density function p_ε > 0. Then we have p_f(X|Z, Y) = p_ε(X − f_Y(Z)). While the causal mechanism is invariant across environments, we assume that the correlation between the label Y and the latent Z is environment-variant, and that Z excludes information about Y; i.e., Y cannot be recovered as a function of Z. If Y were a function of Z, the generation process of X could completely ignore Y, and f would not be injective.
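As a concrete illustration, the assumed SCM can be simulated with a toy, hypothetical choice of f and p^e(Z|Y); all names and parameter values below are ours for illustration, not the paper's:

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_env(n_samples, rho, m=2, noise_std=0.1):
    """Toy instance of the assumed SCM X = f(Y, Z) + eps.

    rho parameterizes the environment-specific spurious correlation
    p^e(Z | Y): with probability rho the latent Z agrees with the label Y.
    The mechanism f and the noise are invariant; only p^e(Z | Y) changes
    across environments.
    """
    y = rng.integers(0, m, size=n_samples)               # p^e(Y): uniform here
    agree = rng.random(n_samples) < rho                  # env-variant p^e(Z | Y)
    z = np.where(agree, y, (y + rng.integers(1, m, size=n_samples)) % m)
    # f(Y, Z): injective in (Y, Z) -- each (y, z) pair has a distinct mean
    x = np.stack([y.astype(float), z.astype(float)], axis=1)
    x += rng.normal(0.0, noise_std, size=x.shape)        # independent noise eps
    return x, y, z

x, y, z = sample_env(10_000, rho=0.9)
corr = np.mean(y == z)   # empirical spurious Y-Z agreement, approximately rho
```

Training a classifier on one such environment absorbs the Y↔Z→X correlation encoded by rho, which breaks in an environment with a different rho.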
We consider the following family of distributions:

F = { p^e(X, Y, Z) = p_f(X|Z, Y) p^e(Z|Y) p^e(Y) | p^e(Z|Y), p^e(Y) > 0 }.    (1)

The environment space we consider is then the index set of F: E = { e | p^e ∈ F }. Note that any mixture of distributions from F is also a member of F; i.e., any combination of environments from E is also an environment in E. To better understand our setting, consider the following example: an image X of an object in class Y has an appearance driven by the fundamental shared properties of Y, as well as other meaningful latent features Z that do not determine "Y-ness" but can be spuriously correlated with Y. In Figure 2, we plot causal diagrams for the joint distributions p(X, Y, E) of two example domain generalization datasets, ColoredMNIST (Arjovsky et al., 2020) and PACS (Li et al., 2017). In ColoredMNIST, Z indicates the assigned color, which is determined by the digit label Y and the environment e through p^e(Z|Y). In PACS, images of the same objects in different styles (e.g., sketches and photographs) occur in different environments, with Z containing this stylistic information. In this setting, we can see that the correlation between X and Y varies with e. We argue that the correlation path Y ↔ Z → X is not stable in an unseen environment e ∉ E_train, as it involves E, and we only want to learn the stable causal relation Y → X. However, a learned predictor may inevitably absorb the unstable relation between X and Y if we simply train it on the observed training distribution p^e(X, Y) with empirical risk minimization.

Balanced Distribution. To avoid learning the unstable relations, we propose to consider a balanced distribution:

Definition 2.2. A balanced distribution can be written as p_B(X, Y, Z) = p_f(X|Y, Z) p_B(Z) p_B(Y), where p_B(Y) = U{1, 2, ..., m} and Y ⊥⊥_B Z. Here we do not specify p_B(Z).
Note that p_B(X|Y, Z) = p_f(X|Y, Z) is a consequence of the unchanged causal mechanism Z → X ← Y, and that p_B(X, Y, Z) ∈ F can also be regarded as constructing a new environment B ∈ E. In this new distribution, X and Y are only correlated through the stable causal relation Y → X. We argue that the Bayes optimal classifier trained on such a balanced distribution has the lowest worst-case risk among the Bayes optimal classifiers trained on the environments in E as defined in Equation (1). To support this statement, we further assume some degree of disentanglement of the causal mechanism:

Assumption 2.3. There exist functions g_Y, g_Z and noise variables ε_Y, ε_Z such that (Y, Z) = f^{-1}(X − ε) = (g_Y(X − ε_Y), g_Z(X − ε_Z)), and ε_Y ⊥⊥_B ε_Z.

The above assumption implies that Y ⊥⊥_B Z | X. We then have the following theorem:

Theorem 2.4. Consider a classifier C_ψ(X) = argmax_Y p_ψ(Y|X) with parameter ψ. The risk of such a classifier on an environment e ∈ E is its cross entropy: L^e(p_ψ(Y|X)) = −E_{p^e(X,Y)} log p_ψ(Y|X). Assume that E satisfies: ∀e ∈ E, Y ̸⊥⊥_{p^e} Z ⟹ ∃e′ ∈ E s.t. L^{e′}(p^e(Y|X)) − L^{e′}(p_B(Y)) > 0. Then the Bayes optimal classifier trained on any balanced distribution p_B(X, Y) is minimax optimal across all environments in E: p_B(Y|X) = argmin_{p_ψ∈F} max_{e∈E} L^e(p_ψ(Y|X)).

The assumption states that the environment space E is large and diverse enough that a perfect classifier on one environment always performs worse than random guessing on some other environment. Under this assumption, no Bayes optimal classifier produced by any other environment in E has better worst-case OOD performance than the one produced by the balanced distribution.

3. METHOD

We propose a two-phase method that first uses a VAE to learn the underlying data distribution p^e(X, Y, Z) with latent covariate Z for each e ∈ E_train, and then uses the learned distribution to calculate a balancing score to create a balanced distribution based on the training data.

3.1. LATENT COVARIATE LEARNING

We argue that the underlying joint distribution p^e(X, Y, Z) can be learned and identified by a VAE, given a sufficiently large set of training environments E_train. To specify the correlation between Z and Y, we assume that the conditional distribution p^e(Z|Y) is conditionally factorial with an exponential family distribution:

Assumption 3.1. The correlation between Y and Z in environment e is characterized by

p^e_{T,λ}(Z|Y) = ∏_{i=1}^{n} [Q_i(Z_i) / W^e_i(Y)] exp( Σ_{j=1}^{k} T_{ij}(Z_i) λ^e_{ij}(Y) ),    (2)

where Z_i is the i-th element of Z, Q = [Q_i]_i : Z → R^n is the base measure, W^e = [W^e_i]_i : Y → R^n is the normalizing constant, T = [T_{ij}]_{ij} : Z → R^{nk} is the sufficient statistics, and λ^e = [λ^e_{ij}]_{ij} : Y → R^{nk} are the Y-dependent parameters. Here n is the dimension of the latent variable Z, and k is the dimension of each sufficient statistic.

Note that k, Q, and T are determined by the type of the chosen exponential family distribution and are thus independent of the environment. The conditionally factorial prior is a simplification in the spirit of the mean-field approximation, which can be expressed as a closed form of the true prior (Blei et al., 2017). The exponential family assumption is not very restrictive, as exponential families have universal approximation capabilities (Sriperumbudur et al., 2017). We then consider the following conditional generative model in each environment e ∈ E_train, with parameters θ = (f, T, λ): p^e_θ(X, Z|Y) = p_f(X|Z, Y) p^e_{T,λ}(Z|Y). We use a VAE to estimate this generative model with the following evidence lower bound (ELBO) in each environment e ∈ E_train:

E_{D^e}[log p^e_θ(X|Y)] ≥ L^e_{θ,φ} := E_{D^e}[ E_{q^e_φ(Z|X,Y)}[log p_f(X|Z, Y)] − D_KL(q^e_φ(Z|X, Y) || p^e_{T,λ}(Z|Y)) ].

The KL-divergence term can be calculated analytically. To sample from the variational distribution q^e_φ(Z|X, Y), we use the reparameterization trick (Kingma & Welling, 2013).
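When the conditional prior is chosen Gaussian (an exponential family with k = 2 and sufficient statistics T(z) = (z, z²)), the KL term of the ELBO has a well-known closed form. A minimal sketch, with illustrative parameter values of our own choosing:

```python
import numpy as np

def gaussian_kl(mu_q, logvar_q, mu_p, logvar_p):
    """Analytic KL( N(mu_q, var_q) || N(mu_p, var_p) ) for diagonal
    Gaussians, summed over latent dimensions -- the closed-form KL term
    of the ELBO when p^e_{T,lambda}(Z|Y) is a factorial Gaussian."""
    var_q, var_p = np.exp(logvar_q), np.exp(logvar_p)
    return 0.5 * np.sum(
        logvar_p - logvar_q + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0,
        axis=-1,
    )

# Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
# so gradients flow through mu and logvar.
rng = np.random.default_rng(0)
mu_q, logvar_q = np.array([0.5, -0.2]), np.array([-1.0, 0.0])
z = mu_q + np.exp(0.5 * logvar_q) * rng.standard_normal(2)

kl_same = gaussian_kl(mu_q, logvar_q, mu_q, logvar_q)     # zero when q == p
kl = gaussian_kl(mu_q, logvar_q, np.zeros(2), np.zeros(2))  # vs. N(0, I) prior
```

In the full model the prior's mean and variance would come from the learned λ^e(Y) rather than being fixed to a standard normal.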
We then maximize the above ELBO, averaged over all training environments, (1/|E_train|) Σ_{e∈E_train} L^e_{θ,φ}, to obtain model parameters (θ, φ). To show that we can uniquely recover the latent variable Z up to some simple transformations, we show that the model parameter θ is identifiable up to some simple transformations. That is, for any {θ = (f, T, λ), θ′ = (f′, T′, λ′)} ⊂ Θ,

p^e_θ(X|Y) = p^e_{θ′}(X|Y), ∀e ∈ E_train ⟹ θ ∼ θ′,

where Θ is the parameter space and ∼ represents an equivalence relation. Specifically, we consider the following equivalence relation from Khemakhem et al. (2020):

Definition 3.2. If (f, T, λ) ∼_A (f′, T′, λ′), then there exists an invertible matrix A ∈ R^{nk×nk} and a vector c ∈ R^{nk} such that T(f^{-1}(x)) = A T′(f′^{-1}(x)) + c, ∀x ∈ X.

When the underlying model parameter θ* can be recovered by perfectly fitting the data distribution p^e_{θ*}(X|Y) for all e ∈ E_train, the joint distribution p^e_{θ*}(X, Z|Y) is also recovered. This further implies the recovery of the prior p^e_{θ*}(Z|Y) and the true latent variable Z*. The identifiability of our proposed latent covariate learning model can then be summarized as follows:

Theorem 3.3. Suppose we observe data sampled from the generative model defined according to Equation (2), with parameters θ = (f, T, λ). In addition to Assumption 2.1 and Assumption 3.1, assume the following conditions hold: (1) The set {x ∈ X | φ_ε(x) = 0} has measure zero, where φ_ε is the characteristic function of the density p_ε. (2) The sufficient statistics T_{ij} are differentiable almost everywhere, and (T_{ij})_{1≤j≤k} are linearly independent on any subset of X of measure greater than zero. (3) There exist nk + 1 distinct pairs (y_0, e_0), ..., (y_{nk}, e_{nk}) such that the nk × nk matrix L = (λ^{e_1}(y_1) − λ^{e_0}(y_0), ..., λ^{e_{nk}}(y_{nk}) − λ^{e_0}(y_0)) is invertible. Then the parameters θ = (f, T, λ) are ∼_A-identifiable.
Note that the last assumption in Theorem 3.3 requires nk + 1 distinct points (y_i, e_i), so the product space Y × E_train has to be large enough; i.e., we need m|E_train| > nk. The invertibility of L requires the differences λ^{e_i}(y_i) − λ^{e_0}(y_0) to be linearly independent, which further implies the diversity of the environment space E.
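The counting condition m|E_train| > nk and the invertibility of L can be checked numerically. A small sketch with hypothetical, randomly drawn λ parameters (random vectors are linearly independent with probability one, so the check passes here by construction):

```python
import numpy as np

rng = np.random.default_rng(0)
n, k = 3, 2                      # latent dimension n, sufficient-statistic dimension k
m, n_env = 4, 2                  # need m * |E_train| > n * k, i.e. 8 > 6: satisfied

# Hypothetical lambda^e(y) parameters: one nk-vector per (y, e) pair.
lam = {(y, e): rng.normal(size=n * k) for y in range(m) for e in range(n_env)}

pairs = list(lam)[: n * k + 1]   # nk + 1 distinct (y, e) points
base = lam[pairs[0]]             # the reference point (y_0, e_0)
# Columns are the differences lambda^{e_l}(y_l) - lambda^{e_0}(y_0).
L = np.stack([lam[p] - base for p in pairs[1:]], axis=1)

full_rank = np.linalg.matrix_rank(L) == n * k   # invertible <=> rank nk
```

With real learned parameters this check can fail when environments are too similar, which is exactly the diversity requirement discussed above.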

3.2. BALANCED MINI-BATCH SAMPLING

We consider using a classic method that has been widely used in average treatment effect (ATE) estimation, balancing score matching (Rosenbaum & Rubin, 1983), to sample balanced mini-batches that mimic the balanced distribution shown in Figure 1b. A balancing score is used to balance the systematic differences between the treated units and the controlled units, revealing the true causal effect from the observed data. It is defined as follows:

Definition 3.4. A balancing score b(Z) is a function of the covariate Z s.t. Z ⊥⊥ Y | b(Z).

There is a wide range of functions of Z that can be used as a balancing score, where the propensity score p(Y = 1|Z) is the coarsest one and the covariate Z itself is the finest one (Rosenbaum & Rubin, 1983). To extend this statement to non-binary treatments, we first define the propensity score s(Z) for Y ∈ Y = {1, 2, ..., m} as a vector:

Definition 3.5. The propensity score for Y ∈ {1, 2, ..., m} is s(Z) = [p(Y = y|Z)]_{y=1}^{m}.

We then have the following theorem that applies to the vector version of the propensity score s(Z):

Theorem 3.6. Let b(Z) be a function of Z. Then b(Z) is a balancing score if and only if b(Z) is finer than s(Z); i.e., there exists a function g such that s(Z) = g(b(Z)).

We use b^e(Z) to denote the balancing score for a specific environment e. The propensity score is then s^e(Z) = [p^e(Y = y|Z)]_{y=1}^{m}, which can be derived from the VAE's conditional prior p^e_{T,λ}(Z|Y) as defined in Equation (2):

p^e(Y = y|Z) = p^e_{T,λ}(Z|Y = y) p^e(Y = y) / Σ_{i=1}^{m} p^e_{T,λ}(Z|Y = i) p^e(Y = i),    (3)

where p^e(Y = i) can be directly estimated from the training data D^e. In practice, we adopt the propensity score computed from Equation (3) as the balancing score. For each example with latent z and label y sampled from environment e, we match it with a examples whose labels differ from y (drawn uniformly over the other m − 1 labels) and whose balancing scores b^e(z_j) are the closest to b^e(z). The label distribution of the resulting mini-batch is then approximately

p̃_B(Y|Z, E) = (1/(a+1)) ( a/(m−1) + ((m−a−1)/(m−1)) p(Y|Z, E) ).

When a = m − 1, p̃_B(Y|Z, E) = 1/m = p_B(Y). With a perfect match at every step (i.e., b^e(z_j) = b^e(z)) and a = m − 1, we can obtain a completely balanced mini-batch sampled from the balanced distribution.
However, an exact match of balancing scores is unlikely in reality, so a larger a will introduce more noise. This can be mitigated by choosing a smaller a, which on the other hand will increase the dependency between Y and Z. In practice, the choice of a therefore reflects a trade-off between the balancing score matching quality and the degree of dependency between Y and Z.

4. EXPERIMENTS

Baselines: We apply our proposed balanced mini-batch sampling method along with four representative, widely used domain generalization algorithms: empirical risk minimization (ERM) (Vapnik, 1998), invariant risk minimization (IRM) (Arjovsky et al., 2020), GroupDRO (Sagawa et al., 2019) and deep CORAL (Sun & Saenko, 2016), and compare the performance of our balanced mini-batch sampling strategy with the usual random mini-batch sampling strategy. We compare our method with 20 baselines in total (Xu et al., 2020; Li et al., 2018a; Ganin et al., 2016; Li et al., 2018c;b; Krueger et al., 2021; Blanchard et al., 2021; Zhang et al., 2021; Nam et al., 2021; Huang et al., 2020; Shi et al., 2022; Parascandolo et al., 2021; Shahtalebi et al., 2021; Rame et al., 2022; Kim et al., 2021) reported on DomainBed, including two recent causality-based baselines, CausIRL CORAL and CausIRL MMD (Chevalley et al., 2022), that also utilize the invariance of causal mechanisms. We also compare with a group-based method, PI (Bao et al., 2021), which interpolates the distributions of the correct and the wrong predictions, on ColoredMNIST 10. To control for the effect of the base algorithms, we use the same set of hyperparameters for both the random sampling baselines and our methods. We primarily consider train domain validation for model selection, as it is the most practical validation method. A detailed description of datasets, baselines, hyperparameter tuning and model selection can be found in Appendix B.
ColoredMNIST: We use the ColoredMNIST dataset as a proof-of-concept scenario, as we already know color is a dominant latent covariate that exhibits spurious correlation with the digit label. For ColoredMNIST 10, we adopt the setting from Bao et al. (2021), a multiclass version of the original ColoredMNIST dataset (Arjovsky et al., 2020). The label y is assigned according to the numeric digit of the MNIST image with 25% label noise. We then assign one of a set of 10 colors (each indicated by a separate color channel) to the image according to the label y: with probability e we assign the corresponding color, and with probability 1 − e we randomly choose another color. Here e ∈ {0.1, 0.2} for the two train environments and e = 0.9 for the test environment. For ColoredMNIST, we adopt the original setting from Arjovsky et al. (2020), which has only two classes (digit smaller/larger than 5) and two colors, with three environments e ∈ {0.1, 0.2, 0.9}. Balanced mini-batch example. An example of a balanced mini-batch created by our method from digits 4, 5 and 7 in ColoredMNIST 10 is illustrated in Figure 3. In the random mini-batch, labels are spuriously correlated with color, e.g., most 6s are blue, most 1s are red, and most 2s are yellow. In the balanced mini-batch, we force each label to have a uniform color distribution by matching each example with an example with a different label but the same color. Here, the color information is implicitly learned by latent covariate learning. ColoredMNIST main results. Table 1 shows the out-of-domain accuracy of our method combined with various base algorithms on the ColoredMNIST 10 and ColoredMNIST datasets. Our balanced mini-batch sampling increases the accuracy of all base algorithms by a large margin, with CORAL improving the most (57% and 47.3%). Note that the highest possible accuracy without relying on the color feature is 75%.
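The coloring procedure described above can be sketched as follows; this is a simplified stand-in that assigns color labels only, ignoring the image tensors:

```python
import numpy as np

rng = np.random.default_rng(0)

def colour_labels(y, e, n_colors=10):
    """Assign colours ColoredMNIST-style: with probability e the colour
    matches the label; otherwise one of the other colours is drawn
    uniformly at random."""
    n = len(y)
    keep = rng.random(n) < e
    offset = rng.integers(1, n_colors, size=n)      # uniform over the other colours
    return np.where(keep, y, (y + offset) % n_colors)

y = rng.integers(0, 10, 50_000)
c_train = colour_labels(y, e=0.1)   # train env: colour matches the label 10% of the time
c_test = colour_labels(y, e=0.9)    # test env: colour matches the label 90% of the time
```

Because the colour-label agreement flips between train and test, any classifier that leans on colour degrades sharply at test time, which is what the balanced mini-batches are designed to prevent.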
In Figure 4, we study important factors in our proposed method by ablating on the ColoredMNIST 10 dataset with ERM. The effectiveness of balancing. We construct oracle balanced mini-batches with b(Z) = Color, and control the degree of balancing by varying the fraction of balanced examples in a mini-batch: for each randomly sampled example, with probability β we match it with 9 examples with the same color but different labels to balance the mini-batch; otherwise, we match it with 9 examples with the same color and label to maintain the original distribution. Figure 4a shows that increasing the balancing fraction increases the OOD performance. The effect of the number of matched examples a. Figure 4b shows that as a increases, the OOD performance first increases, then becomes stable with a slightly decreasing trend. This result is consistent with our analysis in Section 3.2: a large a increases balancing in theory, but due to the imperfect learning of the latent covariate Z, a large a eventually introduces more low-quality matches, which may hurt performance. We also observe that a very large a is not needed to reach the maximum performance. The effect of different test environments. In Figure 4c, we fix the train environments as [0.1, 0.2] and evaluate on different test environments. We report the results chosen by train domain validation, as the results with test domain validation are almost the same. The accuracy of the model trained with random mini-batches drops linearly as the test environment changes from 0.1 to 0.9, indicating that the model learns to use the color feature as its main predictive evidence. In contrast, the accuracy of the model trained with balanced mini-batches produced by our method stays almost the same across all test domains, indicating that the model learns to use domain-invariant features.

DomainBed:

We investigate the effectiveness of our method under different situations.

DomainBed main results.

In Table 2, we consider combining our method with four representative base algorithms: ERM, IRM, GroupDRO, and CORAL. IRM represents a wide range of invariant representation learning baselines. GroupDRO represents group-based methods that minimize the worst-group error. CORAL represents distribution matching algorithms that match the feature distribution across train domains. In general, our method improves the average performance of all the base algorithms by one to two points (1.6% for ERM, IRM and GroupDRO), with CORAL improving the most (2.2%). CORAL likely works best with our method, achieving state-of-the-art OOD accuracy not only on average but also on the ColoredMNIST, PACS, OfficeHome, and DomainNet datasets, because our method aims to balance the data distribution and close the distribution gap between domains, which is in line with the objective of distribution matching algorithms. Our proposed method improves the most on ColoredMNIST, OfficeHome, and DomainNet, while it is not very effective on RotatedMNIST and VLCS. Reason for significant improvements. The large improvement on ColoredMNIST (8.6% for ERM, 7.2% for IRM, 1.8% for GroupDRO and 15.1% for CORAL) is likely because the dominant latent covariate, color, is relatively easy to learn with a low-dimensional VAE. The good performance on OfficeHome and DomainNet (1.7% for ERM, 6.6% for IRM, 7.5% for GroupDRO and 2.4% for CORAL) is likely due to the large number of classes: OfficeHome has 65 classes and DomainNet has 345 classes, while all the other datasets have at most 10 classes. According to Theorem 3.3, a larger number of labels or environments enables the identification of a higher-dimensional latent covariate, which is more likely to capture the complex underlying data distribution. Reason for insignificant improvements.
The lower performance on RotatedMNIST is because the digits in each domain are all rotated by the same degree. Since classes are balanced, images in each domain are already balanced for rotation, the dominant latent covariate. As the performance with random mini-batches is already very high, the noise introduced by the matching procedure may hurt performance. VLCS, on the other hand, has a rather complex data distribution, as the images from each domain are very different realistic photos collected in different ways. However, VLCS has only 5 classes and 4 domains, which enables the identification of only a very low-dimensional latent covariate, insufficient to capture the complexity of each domain. In practice, we suggest using our method when there is a large number of classes or domains, preferably combined with distribution matching algorithms for domain generalization.

5. RELATED WORK

A growing body of work has investigated the out-of-domain (OOD) generalization problem with causal modeling. One prominent idea is to learn invariant features. When multiple training domains are available, this can be approximated by enforcing invariance conditions across training domains through a regularization term added to the usual empirical risk minimization (Arjovsky et al., 2020; Krueger et al., 2021; Bellot & van der Schaar, 2020; Wald et al., 2021; Chevalley et al., 2022). There are also group-based works (Sagawa et al., 2019; Bao et al., 2021; Liu et al., 2021b; Sanh et al., 2021; Piratla et al., 2021; Zhou et al., 2021) that improve worst-group performance and can be applied to the domain generalization problem. However, recent work claims that many of these approaches still fail to achieve the intended invariance property (Kamath et al., 2021; Rosenfeld et al., 2020; Guo et al., 2021), and a thorough empirical study questions the true effectiveness of these domain generalization methods (Gulrajani & Lopez-Paz, 2020). Instead of using datasets from multiple domains, Makar et al. (2022) remove spurious correlations by reweighting toward a balanced distribution with the help of auxiliary labels.

To sample from the balanced distribution, we use a classic method for average treatment effect (ATE) estimation (Holland, 1986), balancing score matching (Rosenbaum & Rubin, 1983). Causal effect estimation studies the effect a treatment would have had on a unit that in reality received another treatment. A causal graph (Pearl, 2009) similar to Figure 1a is usually considered in a causal effect estimation problem, where Z is called the covariate (e.g., a patient profile), which is observed before the treatment Y ∈ {0, 1} (e.g., taking a placebo or a drug) is applied. We denote the effect of receiving a specific treatment Y = y as X_y (e.g., blood pressure). Note that the causal graph implies the Strong Ignorability assumption (Rubin, 1978); i.e., Z includes all variables related to both X and Y. In the case of a binary treatment, the ATE is defined as E[X_1 − X_0].
For a randomized controlled trial, the ATE can be directly estimated by E[X|Y = 1] − E[X|Y = 0], since in this case Z ⊥⊥ Y and there are no systematic differences between units exposed to one treatment and units exposed to another. In most observed datasets, however, Z is correlated with Y, so E[X_1] and E[X_0] are not directly comparable. We can then use a balancing score b(Z) (Dawid, 1979) to de-correlate Z and Y, and the ATE can be estimated by matching units with the same balancing score but different treatments: E[X_1 − X_0] = E_{b(Z)}[ E[X|Y = 1, b(Z)] − E[X|Y = 0, b(Z)] ]. Recently, Schwab et al. (2018) extended this method to individual treatment effect (ITE) estimation (Holland, 1986) by constructing virtually randomized mini-batches with balancing scores.
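A small numerical sketch of this contrast, on a synthetic confounded dataset of our own construction: the naive difference of means is biased by the Z → Y path, while stratifying on the balancing score (here Z itself, the finest one) recovers the true ATE:

```python
import numpy as np

rng = np.random.default_rng(0)

# Confounded observational data: Z drives both treatment Y and outcome X.
n = 100_000
z = rng.integers(0, 2, n)                        # binary covariate
p_treat = np.where(z == 1, 0.8, 0.2)             # Z correlated with treatment Y
y = (rng.random(n) < p_treat).astype(int)
x = 1.0 * y + 2.0 * z + rng.normal(0, 0.1, n)    # true ATE of Y on X is 1.0

naive = x[y == 1].mean() - x[y == 0].mean()      # biased: absorbs the Z -> Y path

# Stratify on the balancing score, then average over its distribution:
strata = [
    x[(y == 1) & (z == v)].mean() - x[(y == 0) & (z == v)].mean()
    for v in (0, 1)
]
weights = [np.mean(z == v) for v in (0, 1)]
ate = float(np.dot(weights, strata))             # close to the true effect 1.0
```

The mini-batch matching of Section 3.2 plays the same role: within a matched batch, Y is (approximately) independent of Z given the balancing score, so the classifier sees only the stable Y → X relation.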

6. CONCLUSION

Our novel causality-based domain generalization method for classification tasks samples balanced mini-batches to reduce the presence of spurious correlations in the data. We propose a spurious-free balanced distribution and show that the Bayes optimal classifier trained on such a distribution is minimax optimal over all environments. We show that our assumed data generation model with an invariant causal mechanism can be identified up to simple transformations. We demonstrate theoretically that, under ideal scenarios, our balanced mini-batches are approximately sampled from a spurious-free balanced distribution with the same causal mechanism. Our experiments empirically show the effectiveness of our method in both semi-synthetic and real-world settings.

A PROOFS

In this section, we give full proofs of the main theorems in the paper.

A.1 BALANCED DISTRIBUTION

A.1.1 PROOF FOR THEOREM 2.4

Here we give a proof of the minimax optimality of the Bayes optimal classifier trained on a balanced distribution.

Proof. The Bayes optimal classifier trained on a balanced distribution p_B(X, Y) has p_ψ(Y|X) = p_B(Y|X). Consider the expected cross entropy loss of this classifier on an unseen test distribution p^e:

L^e(p_B(Y|X)) = −E_{p^e(X,Y)} log p_B(Y|X)    (4)
= −E_{p^e(X,Y)} log p_B(Y) + E_{p^e(X,Y)} log [p_B(Y) / p_B(Y|X)]
= L^e(p_B(Y)) + E_{p^e(X,Y,Z)} log [p_B(Y) / p_B(Y|X)]
= L^e(p_B(Y)) + E_{p^e(Y,Z)} E_{p_B(X|Y,Z)} log [p_B(Y) / p_B(Y|X)]
= L^e(p_B(Y)) + E_{p^e(Y,Z)} E_{p_B(X|Y,Z)} log [p_B(Y|Z) / p_B(Y|X, Z)]    (5)
= L^e(p_B(Y)) + E_{p^e(Y,Z)} E_{p_B(X|Y,Z)} log [p_B(X|Z) / p_B(X|Y, Z)]
= L^e(p_B(Y)) − E_{p^e(Y,Z)} KL[p_B(X|Y, Z) || p_B(X|Z)].

• Equation (4) is the definition of the cross entropy loss.
• Equation (5) is obtained from Y ⊥⊥_B Z (so p_B(Y) = p_B(Y|Z)) and Y ⊥⊥_B Z | X (so p_B(Y|X) = p_B(Y|X, Z)).

Thus the cross entropy loss of p_B(Y|X) in any environment e is no larger than that of the random guess p_B(Y) = 1/m:

L^e(p_B(Y|X)) − L^e(p_B(Y)) = −E_{p^e(Y,Z)} KL[p_B(X|Y, Z) || p_B(X|Z)] ≤ 0,

which means:

max_{e′∈E} [L^{e′}(p_B(Y|X)) − L^{e′}(p_B(Y))] ≤ 0.

That is, the performance of the classifier trained on p_B(X, Y) is at least as good as a random guess in every environment. By the environment diversity assumption, for any p^e with Y ̸⊥⊥_e Z there exists an environment e′ such that p^e(Y|X) performs worse than a random guess. So we have:

max_{e′∈E} [L^{e′}(p_B(Y|X)) − L^{e′}(p_B(Y))] ≤ 0 < max_{e′∈E} [L^{e′}(p^e(Y|X)) − L^{e′}(p_B(Y))].

Thus we have the following minimax optimality:

p_B(Y|X) = argmin_{p_ψ∈F} max_{e∈E} L^e(p_ψ(Y|X)).
A.2 LATENT COVARIATE LEARNING A.2.1 PROOF FOR THEOREM 3.3 We now prove Theorem 3.3 setting up the identifiability of the necessary parameters that capture the spuriously correlated covariate features in the VAE. The proof is based on the proof of Theorem 1 in (Motiian et al., 2017) , with the following modifications: 1. We use both E and Y as auxiliary variables. 2. We include Y in the causal mechanism of generating X by X = f (Y, Z) + ϵ = f Y (Z) + ϵ. Proof. Step I. In this step, we transform the equality of the marginal distributions over observed data into the equality of a noise-free distribution. Suppose we have two sets of parameters θ = (f , T, λ) and θ ′ = (f ′ , T ′ , λ ′ ) such that p θ (X|Y, E = e) = p θ ′ (X|Y, E = e), ∀e ∈ E train , then: Z p T,λ (Z|Y, E = e)p f (X|Z, Y )dZ = Z P T ′ ,λ ′ (Z|Y, E = e)p ′ f (X|Z, Y )dZ ⇒ Z p T,λ (Z|Y, E = e)pϵ(X -fY (Z))dZ = Z p T ′ ,λ ′ (Z|Y, E = e)pϵ(X -f ′ Y (Z))dZ ⇒ X p T,λ (f -1 ( X)|Y, E = e)volJ f -1 ( X)pϵ(X -X)d X = X p T ′ ,λ ′ (f ′-1 ( X)|Y, E = e)volJ f ′-1 ( X) pϵ(X -X)d X (6) ⇒ R d pT,λ,f,Y,e ( X)pϵ(X -X)d X = R d pT ′ ,λ ′ ,f ′ ,Y,e ( Xpϵ(X -X)d X) (7) ⇒ (p T,λ,f ,Y,e * pϵ)(X) = (p T ′ ,λ ′ ,f ′ ,Y,e * PE )(X) (8) ⇒ F [p T,λ,f ,Y,e ](ω)ϕϵ(ω) = F [p T ′ ,λ ′ ,f ′ ,Y,e ](ω)ϕϵ(ω) (9) ⇒ F [p T,λ,f ,Y,e ](ω) = F [p T ′ ,λ ′ ,f ′ ,Y,e ](ω) (10) ⇒ pT,λ,f,Y,e (X) = pT ′ ,λ ′ ,f ′ ,Y,e (X). • In Equation ( 6), we denote the volume of a matrix A as volA := √ det A T A. J denotes the Jacobian. We made the change of variable X = f Y (Z) on the left hand side and X = fY (Z) on the right hand side. Since f is injective, we have f -1 ( X) = (Y, Z). Here we abuse f -1 ( X) to specifically denote the recovery of Z, i.e. f -1 ( X) = Z. • In Equation ( 7), we introduce pT,λ,f,Y,e (X) = p T,λ (f -1 Y (X)|Y, E = e)volJ f -1 Y (X)1 X (X), on the left hand side, and similarly on the right hand side. • In Equation ( 8), we use * for the convolution operator. • In Equation ( 9), we use F [•] to designate the Fourier transform. 
The characteristic function of $\epsilon$ is then $\phi_\epsilon = \mathcal{F}[p_\epsilon]$.
- In Equation (10), we dropped $\phi_\epsilon(\omega)$ from both sides, as it is non-zero almost everywhere (by assumption (1) of the Theorem).

Step II. In this step, we remove all terms that are a function of $X$, $Y$, or $e$. Taking the logarithm of both sides of Equation (11) and replacing $p_{T,\lambda}$ by its expression from Equation (3), we get:
$$\log \mathrm{vol} J_{f^{-1}}(X) + \sum_{i=1}^n \Big( \log Q_i(f_i^{-1}(X)) - \log W_i^e(Y) + \sum_{j=1}^k T_{i,j}(f_i^{-1}(X))\, \lambda_{i,j}^e(Y) \Big) = \log \mathrm{vol} J_{f'^{-1}}(X) + \sum_{i=1}^n \Big( \log Q'_i(f'^{-1}_i(X)) - \log W'^e_i(Y) + \sum_{j=1}^k T'_{i,j}(f'^{-1}_i(X))\, \lambda'^e_{i,j}(Y) \Big).$$
Let $(e_0, y_0), (e_1, y_1), \ldots, (e_{nk}, y_{nk})$ be the points provided by assumption (3) of the Theorem. We evaluate the above equation at these points to obtain $nk+1$ equations, and subtract the first equation from the remaining $nk$ equations to obtain, for $l = 1, \ldots, nk$:
$$\langle T(f^{-1}(X)), \lambda^{e_l}(y_l) - \lambda^{e_0}(y_0) \rangle + \sum_{i=1}^n \log \frac{W_i^{e_0}(y_0)}{W_i^{e_l}(y_l)} = \langle T'(f'^{-1}(X)), \lambda'^{e_l}(y_l) - \lambda'^{e_0}(y_0) \rangle + \sum_{i=1}^n \log \frac{W'^{e_0}_i(y_0)}{W'^{e_l}_i(y_l)}. \quad (12)$$
Let $L$ be the matrix defined in assumption (3) and $L'$ be similarly defined for $\lambda'$ ($L'$ is not necessarily invertible). Then Equation (12) can be rewritten in matrix form:
$$L^T T(f^{-1}(X)) = L'^T T'(f'^{-1}(X)) + b, \quad (13)$$
where $b = [b_l]_{l=1}^{nk}$ collects the $\log W$ terms. We multiply both sides of Equation (13) by $L^{-T}$ to get:
$$T(f^{-1}(X)) = A\, T'(f'^{-1}(X)) + c, \quad (14)$$
where $A = L^{-T} L'^T$ and $c = L^{-T} b$.

Step III. To complete the proof, we need to show that $A$ is invertible. By the definition of $T$ and according to assumption (2), its Jacobian exists and is an $nk \times n$ matrix of rank $n$. This implies that the Jacobian of $T' \circ f'^{-1}$ exists and is of rank $n$, and so is $A$. We distinguish two cases:
1. If $k = 1$, then $A$ is invertible, as $A \in \mathbb{R}^{n \times n}$.
2. If $k > 1$, define $\bar x = f^{-1}(x)$ and $T_i(\bar x_i) = (T_{i,1}(\bar x_i), \ldots, T_{i,k}(\bar x_i))$. Suppose, for contradiction, that for any choice of $\bar x_i^1, \bar x_i^2, \ldots, \bar x_i^k$, the family $\big( \frac{dT_i(\bar x_i^1)}{d\bar x_i^1}, \ldots, \frac{dT_i(\bar x_i^k)}{d\bar x_i^k} \big)$ is never linearly independent.
This means that $T_i(\mathbb{R})$ is included in a subspace of $\mathbb{R}^k$ of dimension at most $k - 1$. Let $h$ be a non-zero vector orthogonal to $T_i(\mathbb{R})$. Then for all $\bar x \in \mathbb{R}$, we have $\langle \frac{dT_i(\bar x)}{d\bar x}, h \rangle = 0$. By integrating, we find that $\langle T_i(\bar x), h \rangle = \text{const}$. Since this holds for all $\bar x \in \mathbb{R}$ and some $h \ne 0$, we conclude that the distribution is not strongly exponential. So by contradiction, there exist $k$ points $\bar x_i^1, \bar x_i^2, \ldots, \bar x_i^k$ such that $\big( \frac{dT_i(\bar x_i^1)}{d\bar x_i^1}, \ldots, \frac{dT_i(\bar x_i^k)}{d\bar x_i^k} \big)$ are linearly independent. Collect these points into $k$ vectors $(\bar x^1, \ldots, \bar x^k)$, and concatenate the $k$ Jacobians $J_T(\bar x^l)$ evaluated at each of those vectors horizontally into the matrix $Q = (J_T(\bar x^1), \ldots, J_T(\bar x^k))$; similarly define $Q'$ as the concatenation of the Jacobians of $T'(f'^{-1} \circ f(\bar x))$ evaluated at those points. Then the matrix $Q$ is invertible. By differentiating Equation (14) with respect to each $\bar x^l$, we get:
$$Q = A Q'.$$
The invertibility of $Q$ implies the invertibility of $A$ and $Q'$. This completes the proof.

A.3 BALANCED MINI-BATCH SAMPLING

A.3.1 PROOF FOR THEOREM 3.6

Our characterization of all possible balancing scores extends the proof of Theorem 2 in Rosenbaum & Rubin (1983) from a binary treatment to multiple treatments.

Proof. First, suppose the balancing score $b(Z)$ is finer than the propensity score $s(Z)$. By the definition of a balancing score (Theorem 3.4), $b(Z)$ is a balancing score if and only if $Y \perp\!\!\!\perp Z \mid b(Z)$, i.e.
$$p(Y|Z, b(Z)) = p(Y|b(Z)). \quad (15)$$
On the other hand, since $b(Z)$ is a function of $Z$, we always have
$$p(Y|Z, b(Z)) = p(Y|Z). \quad (16)$$
Equations (15) and (16) together show that $b(Z)$ is a balancing score if and only if $p(Y|b(Z)) = p(Y|Z)$, so it suffices to show the latter. Let the $y$-th entry of $s(Z)$ be $s_y(Z) = p(Y=y|Z)$. Then:
$$\mathbb{E}[s_y(Z)|b(Z)] = \int_{\mathcal{Z}} p(Y=y|Z=z)\, p(Z=z|b(Z))\, dz = p(Y=y|b(Z)). \quad (17)$$
But since $b(Z)$ is finer than $s(Z)$, $b(Z)$ is also finer than $s_y(Z)$, so
$$\mathbb{E}[s_y(Z)|b(Z)] = s_y(Z). \quad (18)$$
Then by Equation (17) and Equation (18), we have $p(Y=y|Z) = p(Y=y|b(Z))$, as required. So $b(Z)$ is a balancing score.

For the converse, suppose $b(Z)$ is a balancing score but not finer than $s(Z)$. Then there exist $z_1$ and $z_2$ such that $s(z_1) \ne s(z_2)$ but $b(z_1) = b(z_2)$. By the definition of $s(\cdot)$, there exists $y$ such that $p(Y=y|z_1) \ne p(Y=y|z_2)$. This means $Y$ and $Z$ are not conditionally independent given $b(Z)$, so $b(Z)$ is not a balancing score. Therefore, to be a balancing score, $b(Z)$ must be finer than $s(Z)$. Note that $s(Z)$ is itself a balancing score, since $s(Z)$ is a function of itself.

A.3.2 PROOF FOR THEOREM 3.7

We provide a proof of Theorem 3.7, demonstrating the feasibility of balanced mini-batch sampling.

Proof. In Algorithm 1, by uniformly sampling $a$ different labels with $y \ne y^e$, we mean sampling $Y_{\text{alt}} = \{y^1, y^2, \ldots, y^a\}$ by the following procedure:
$$y^1 \sim U\{1, 2, \ldots, m\} \setminus \{y^e\}$$
$$y^2 \sim U\{1, 2, \ldots, m\} \setminus \{y^e, y^1\}$$
$$\vdots$$
$$y^a \sim U\{1, 2, \ldots, m\} \setminus \{y^e, y^1, y^2, \ldots, y^{a-1}\},$$
where $U$ denotes the uniform distribution.
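The characterization in Theorem 3.6 can be illustrated on a toy discrete example (our own construction): a score finer than the propensity score renders Y and Z conditionally independent, while a coarser score does not:

```python
import numpy as np

z_vals = np.arange(6)
p_z = np.full(6, 1 / 6)

# The propensity depends on z only through b(z) = z mod 3, so b(Z) is
# finer than the propensity score s(Z) = p(Y=1|Z).
b = z_vals % 3
s_by_b = np.array([0.2, 0.5, 0.8])
p_y1_z = s_by_b[b]                  # p(Y=1 | Z=z)

for v in np.unique(b):
    mask = b == v
    p_y1_given_b = (p_z[mask] * p_y1_z[mask]).sum() / p_z[mask].sum()
    # Y ⊥ Z | b(Z): conditioning further on z adds nothing beyond b(z).
    assert np.allclose(p_y1_z[mask], p_y1_given_b)

# A score that is NOT finer than s(Z) fails to balance.
b_bad = z_vals % 2
mask = b_bad == 0
p_y1_given_bad = (p_z[mask] * p_y1_z[mask]).sum() / p_z[mask].sum()
print(np.allclose(p_y1_z[mask], p_y1_given_bad))   # False
```

Conditioning on the coarser score `b_bad` mixes z-values with different propensities, so Y and Z remain dependent given it, exactly as the converse direction of the proof predicts.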
Suppose $D_{\text{balanced}} \sim \tilde p^B(X, Y)$ and $D^e \sim p(X, Y|E=e)$ for all $e \in \mathcal{E}_{\text{train}}$. Suppose we have an exact match every time we match a balancing score. Then for all $e \in \mathcal{E}_{\text{train}}$, we have
$$
\begin{aligned}
\tilde p^B(Y|b^e(Z), E=e) ={}& \frac{1}{a+1}\, p(Y|b^e(Z), E=e) + \frac{1}{a+1}\big(1 - p(Y|b^e(Z), E=e)\big)\frac{1}{m-1} \\
&+ \frac{1}{a+1}\big(1 - p(Y|b^e(Z), E=e)\big)\Big(1 - \frac{1}{m-1}\Big)\frac{1}{m-2} + \cdots \\
&+ \frac{1}{a+1}\big(1 - p(Y|b^e(Z), E=e)\big)\Big(1 - \frac{1}{m-1}\Big)\Big(1 - \frac{1}{m-2}\Big)\cdots\Big(1 - \frac{1}{m-a+1}\Big)\frac{1}{m-a} \\
={}& \frac{1}{a+1}\Big( \frac{a}{m-1} + \frac{m-a-1}{m-1}\, p(Y|b^e(Z), E=e) \Big).
\end{aligned}
$$
Since conditioning on $Z$ is equivalent to conditioning on $b^e(Z)$ (by the definition of a balancing score), it follows that
$$\tilde p^B(Y|Z, E) = \frac{1}{a+1}\Big( \frac{a}{m-1} + \frac{m-a-1}{m-1}\, p(Y|Z, E) \Big).$$
When $a = m - 1$, we have $\tilde p^B(Y|Z, E) = \frac{1}{m} = U\{1, 2, \ldots, m\}$, which means $\tilde p^B(X, Y, Z) = p^B(X, Y, Z)$, i.e. $D_{\text{balanced}}$ can be regarded as sampled from the balanced distribution $p^B$ defined in Definition 2.2.
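The closed form above can be checked by Monte Carlo simulation. The sketch below (the values of m, a, and q are illustrative choices of ours) uses the fact that sequentially sampling labels uniformly without replacement, as in the proof, is distributionally equivalent to drawing a uniform random a-subset of the labels other than y^e:

```python
import numpy as np

rng = np.random.default_rng(0)
m, a = 5, 3                                # classes; matched examples per anchor
q = np.array([0.5, 0.2, 0.1, 0.1, 0.1])    # p(Y | b^e(Z), E=e) at a fixed score value

counts = np.zeros(m)
n_trials = 100_000
for _ in range(n_trials):
    y_e = rng.choice(m, p=q)               # anchor label from the original data
    pool = [y for y in range(m) if y != y_e]
    y_alt = rng.choice(pool, size=a, replace=False)   # uniform w/o replacement
    counts[np.append(y_alt, y_e)] += 1     # the (a+1)-example group

empirical = counts / (n_trials * (a + 1))
closed_form = (a / (m - 1) + (m - a - 1) / (m - 1) * q) / (a + 1)
print(np.abs(empirical - closed_form).max())   # small Monte Carlo error
```

Setting a = m − 1 in `closed_form` gives the uniform vector 1/m, matching the a = m − 1 case discussed above.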

B EXPERIMENT DETAILS

In this section, we give more details of our experiments. We perform our experiments on the DomainBed codebase (Gulrajani & Lopez-Paz, 2020).

B.1 DATASETS

ColoredMNIST is a variant of the MNIST handwritten digit classification dataset. Each domain in {0.1, 0.2, 0.9} is constructed by spuriously correlating the digits with their color. This dataset contains 70,000 examples of dimensions (2, 28, 28) and 2 classes, where the class indicates whether the digit is less than 5, with 25% label noise. RotatedMNIST is another variant of MNIST where each domain contains digits rotated by α degrees, with α ∈ {0, 15, 30, 45, 60, 75}. This dataset contains 70,000 examples of dimensions (1, 28, 28) and 10 classes, where the class indicates the digit. PACS comprises four domains: art, cartoons, photos, and sketches. This dataset contains 9,991 examples of dimensions (3, 224, 224)

B.2 BASELINES

We choose ERM, IRM, GroupDRO, and CORAL as base algorithms to apply our method because they are representative methods for domain generalization, and they serve as strong baselines when compared to a wide range of domain generalization methods. Empirical risk minimization (ERM) is the default training scheme for most machine learning problems, merging all training data into one dataset and minimizing the training error across all training domains. Invariant risk minimization (IRM) represents a wide range of invariant representation learning baselines. IRM learns a data representation such that the optimal linear classifier on top of it is invariant across training domains. Group distributionally robust optimization (GroupDRO) represents group-based methods that minimize the worst group errors. GroupDRO performs ERM while increasing the weight of the environments with larger errors. Deep CORAL represents the distribution matching algorithms. CORAL matches the mean and covariance of feature distributions across training domains. According to Gulrajani & Lopez-Paz (2020), CORAL is the best-performing domain generalization algorithm averaged across 7 datasets, compared to 13 other baselines.
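As an illustration of the CORAL idea described above, here is a simplified numpy sketch of the mean/covariance-matching penalty (the actual Deep CORAL objective operates on differentiable feature tensors and includes scaling factors; this only shows the matching criterion):

```python
import numpy as np

def coral_penalty(f_s, f_t):
    """CORAL-style penalty: squared distance between the feature means
    and covariances of two domains (simplified numpy sketch)."""
    mean_diff = ((f_s.mean(0) - f_t.mean(0)) ** 2).sum()
    cov_s = np.cov(f_s, rowvar=False)
    cov_t = np.cov(f_t, rowvar=False)
    cov_diff = ((cov_s - cov_t) ** 2).sum()
    return mean_diff + cov_diff

rng = np.random.default_rng(0)
f_a = rng.normal(0.0, 1.0, size=(256, 8))   # features from domain A
f_b = rng.normal(0.5, 2.0, size=(256, 8))   # shifted and rescaled domain B
print(coral_penalty(f_a, f_a) == 0.0)       # True: identical feature statistics
print(coral_penalty(f_a, f_b) > 0.0)        # True: mismatched statistics are penalized
```

Minimizing this penalty during training pulls the first two moments of the per-domain feature distributions together.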

B.3 HYPERPARAMETER SELECTION

Base algorithms: For the architecture of the image classifiers, following the DomainBed setting, we train a convolutional neural network from scratch for the ColoredMNIST and RotatedMNIST datasets, and use a pre-trained ResNet50 (He et al., 2016) for all other datasets. Each experiment is repeated with 3 different random seeds. We choose the hyperparameters of the base algorithms based on the default hyperparameter search with random mini-batch sampling. More specifically, we extract the hyperparameters from the official experimental logs provided in the DomainBed GitHub repository. To retrieve the hyperparameters, we ran the script collect_results_detailed.py, modified from the provided collect_results.py script, to collect the hyperparameters used to produce the DomainBed results table with train-domain validation. Balanced mini-batch construction: We use a multi-layer perceptron (MLP) based VAE (Kingma & Welling, 2013) to learn the latent covariate Z. For ColoredMNIST, ColoredMNIST10, and RotatedMNIST, we use a 2-layer MLP with 512 neurons in each layer. For all other datasets, we use a 3-layer MLP with 1024 neurons in each layer. We choose the conditional prior p_T(Z|Y, E=e) to be a Gaussian distribution with a diagonal covariance matrix. We also choose the noise distribution p_ε to be a Gaussian distribution with zero mean and identity covariance matrix. We choose the largest possible latent dimension n according to Theorem 3.3, up to 64. We choose the KL divergence as our distance metric d on DomainBed. The hyperparameters we use are shown in Table 3. We control k by choosing different distributions to model the latent covariate: for k = 2, we choose a Normal distribution, and for k = 1, we choose a Normal distribution with a fixed variance equal to the identity matrix.
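Given a learned conditional Gaussian prior, the propensity score s^e(Z) = p(Y|Z, E=e) used as the balancing score can be computed by Bayes' rule. The following is a minimal numpy sketch with hypothetical parameters (the actual implementation operates on the VAE's learned prior):

```python
import numpy as np

def log_gauss_diag(z, mu, var):
    """Log-density of a diagonal Gaussian."""
    return -0.5 * (np.log(2 * np.pi * var) + (z - mu) ** 2 / var).sum(-1)

def propensity(z, mus, vars_, p_y):
    """s^e(z) = p(Y | Z=z, E=e) via Bayes' rule from the conditional
    prior p(Z | Y, E=e), one diagonal Gaussian per class."""
    logp = np.array([log_gauss_diag(z, mu, v) for mu, v in zip(mus, vars_)])
    logp += np.log(p_y)            # label prior in domain e
    logp -= logp.max()             # numerical stability before exponentiating
    w = np.exp(logp)
    return w / w.sum()

# Hypothetical 2-class, 4-dimensional latent example.
mus = [np.zeros(4), np.ones(4)]
vars_ = [np.ones(4), np.ones(4)]
s = propensity(np.zeros(4), mus, vars_, p_y=np.array([0.5, 0.5]))
print(s)   # strongly favors class 0; entries sum to 1
```

The resulting vector is exactly the quantity matched across examples when constructing balanced mini-batches.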
When choosing the latent dimension n, we follow the identifiability requirement m|E_train| > nk from Section 3.1, and choose the maximum allowed n up to λ = 64 for large images (224 × 224) and up to λ = 16 for small images (28 × 28), i.e. n = min{⌊m|E_train|/k⌋, λ}. For the distance metric d, we choose the KL divergence on all datasets except ColoredMNIST10, where we choose the L_∞ distance. The choice of distance metric usually does not affect the final results much, as shown in Table 4. We tune the number of matched examples a for each base algorithm with train-domain validation; the best a for each base algorithm is shown in the order ERM/IRM/GroupDRO/CORAL in the last column of Table 3. Typically, the best a for a dataset is similar across base algorithms.

Per-test-environment out-of-domain accuracy on RotatedMNIST (columns are rotation angles):

Algorithm       | 0          | 15         | 30         | 45         | 60         | 75         | Avg
ERM             | 95.9 ± 0.1 | 98.9 ± 0.0 | 98.8 ± 0.0 | 98.9 ± 0.0 | 98.9 ± 0.0 | 96.4 ± 0.0 | 98.0
IRM             | 95.5 ± 0.1 | 98.8 ± 0.2 | 98.7 ± 0.1 | 98.6 ± 0.1 | 98.7 ± 0.0 | 95.9 ± 0.2 | 97.7
GroupDRO        | 95.6 ± 0.1 | 98.9 ± 0.1 | 98.9 ± 0.1 | 99.0 ± 0.0 | 98.9 ± 0.0 | 96.5 ± 0.2 | 98.0
CORAL           | 95.8 ± 0.3 | 98.8 ± 0.0 | 98.9 ± 0.0 | 99.0 ± 0.0 | 98.9 ± 0.1 | 96.4 ± 0.2 | 98.0
Ours+ERM        | 94.8 ± 0.3 | 98.4 ± 0.1 | 98.7 ± 0.0 | 98.8 ± 0.0 | 98.8 ± 0.0 | 96.4 ± 0.1 | 97.7
Ours+IRM        | 93.0 ± 0.5 | 98.2 ± 0.1 | 98.6 ± 0.1 | 98.3 ± 0.2 | 98.6 ± 0.1 | 94.3 ± 0.2 | 96.8
Ours+GroupDRO   | 94.8 ± 0.2 | 98.5 ± 0.1 | 98.9 ± 0.0 | 98.8 ± 0.0 | 98.9 ± 0.1 | 95.9 ± 0.3 | 97.6
Ours+CORAL      | 94.5 ± 0.4 | 98.7 ± 0.0 | 98.8 ± 0.1 | 99.0 ± 0.0 | 98.9 ± 0.0 | 96.2 ± 0.2 | 97.7
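The latent-dimension rule and one possible distance metric d can be sketched as follows (the closed-form KL between diagonal Gaussians is one natural implementation of d; the example settings are ours):

```python
import numpy as np

def latent_dim(m, n_envs, k, cap):
    """n = min{ floor(m * |E_train| / k), λ }: the largest latent dimension
    compatible with the identifiability requirement, capped at λ."""
    return min(m * n_envs // k, cap)

def kl_diag_gauss(mu1, var1, mu2, var2):
    """Closed-form KL divergence between two diagonal Gaussians; usable as
    the distance d between Gaussian balancing scores (one possible choice)."""
    return 0.5 * (np.log(var2 / var1) + (var1 + (mu1 - mu2) ** 2) / var2 - 1).sum()

# E.g., an OfficeHome-like setting: m = 65 classes, 3 train environments, k = 2.
print(latent_dim(m=65, n_envs=3, k=2, cap=64))   # capped at λ = 64
d = kl_diag_gauss(np.zeros(3), np.ones(3), np.zeros(3), np.ones(3))
print(d)   # 0.0 for identical distributions
```

Because the KL is zero only for identical score distributions, matching by minimizing d directly targets the exact-match condition of Theorem 3.7.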

C DISCUSSIONS AND LIMITATIONS

The experiments show that our balanced mini-batch sampling method outperforms the random mini-batch sampling baseline when applied to multiple domain generalization methods, on both semi-synthetic and real-world datasets. While our method can be easily incorporated into other domain generalization methods with good performance, it has some potential drawbacks. First, its computational complexity grows quadratically with the dataset size: for each training example, our method searches across the dataset to find the closest match in balancing score, which could become a computational bottleneck on large datasets. However, this can be mitigated by matching examples offline before training, or with more efficient search methods. The second caveat is that we do not provide an optimized model selection method to complement our method. While it is possible to balance the held-out validation set with our method and choose the best model based on accuracy on the balanced validation set, the quality of such a balanced validation set is questionable given the small size of a typical validation set. For now, we recommend the train-domain validation scheme in practice.
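As a sketch of the more efficient search mentioned above: when the balancing score is a scalar (e.g., binary classification), pre-sorting scores per label once offline reduces each query to an O(log N) binary search instead of an O(N) scan (stdlib-only sketch; function names are ours):

```python
import bisect

def build_index(scores, labels):
    """Sort scores per label once, offline, before training."""
    per_label = {}
    for s, y in zip(scores, labels):
        per_label.setdefault(y, []).append(s)
    for y in per_label:
        per_label[y].sort()
    return per_label

def closest_match(index, query_score, alt_label):
    """O(log N) nearest balancing score among examples with the requested label."""
    cand = index[alt_label]
    i = bisect.bisect_left(cand, query_score)
    neighbors = cand[max(0, i - 1): i + 1]   # the two bracketing candidates
    return min(neighbors, key=lambda c: abs(c - query_score))

index = build_index([0.1, 0.4, 0.35, 0.8, 0.55], [0, 0, 1, 1, 1])
print(closest_match(index, query_score=0.5, alt_label=1))  # 0.55
```

For vector-valued scores, the same offline-precomputation idea applies with a spatial index in place of the sorted lists.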

D IN-DEPTH COMPARISON WITH RELATED WORK

D.1 COMPARISON OF ASSUMPTIONS

Certain assumptions are needed for our paper, as in other works on domain generalization. Our assumptions are not stronger than those of other domain generalization works that give similar generalization guarantees; arguably, ours are weaker than most of them. We provide the identifiability of the balanced distribution given a finite set of train environments, and prove that the Bayes optimal classifier trained on the balanced distribution is minimax optimal across all environments. Our main assumptions are the factorial exponential family distribution of the latent covariate given the label, the invertible causal function f, and the additive noise. Similar assumptions have been made in Sun et al. (2021). Works without constraints on environments usually can only provide a generalization guarantee when optimizing over all environments (Mahajan et al., 2021), or do not provide any such guarantees (Chen et al., 2021; Li et al., 2022). To provide a generalization guarantee with a single or a small number of train environments, works such as Yuan et al. (2021) must impose further restrictions, e.g., a more restrictive linear causal model or direct observation of the variable spuriously correlated with the label Y. In practice, the model built with our assumptions works well on real-world datasets that do not exactly fit them, which empirically demonstrates that our method is robust against violations of our assumptions.

D.2 COMPARISON OF CAUSAL MODEL

In general, the assumed underlying Structural Causal Model (SCM) is determined by the nature of the task. Sometimes such an SCM can be designed by a human expert who knows the data generation process of the task. In our paper, we adopt a coarse-grained SCM for general image classification tasks with only three variables: the image X, the label Y, and the latent variable Z. Our high-level philosophy is that the image itself is merely a record of what has been done, and the label can usually be regarded as a driving force of the recorded event. When one intervenes on the image X, the label Y of the image does not necessarily change. However, if the intervention is on the class label Y, the image X almost surely changes for a well-defined image classification task. For example, in the medical domain, a disease (Y) causes lesions, which in turn drive the different appearance of MRI images (X). Another example is when Y is the object class of the item appearing in the image X, which is usually the case for the most widely used image classification benchmarks like ImageNet. We have also discussed this in Section 2.1 of our paper. However, there can be exceptions. For example, if we are asked to classify whether we feel happy or sad after seeing a picture, the picture X becomes the cause of the sentiment label Y. Such scenarios are less common in real-world image classification tasks. To resolve the issue of different SCMs for different tasks, Christiansen et al. (2021) consider all SCMs that can be transformed into a specific linear form with plausible interventions. Wald et al. (2021) assume X can be disentangled into features causing Y and features caused by Y, and derive their theoretical results with a linear SCM. We assume a more general nonlinear SCM with Y → X, which is suitable for most of the image classification tasks we consider. On the other hand, Yuan et al. (2021) directly assume an SCM with X → Y.
Empirically, they obtained worse results on the PACS (84.4 vs. 86.7) and OfficeHome (64.2 vs. 69.6) datasets, which suggests that our SCM is more suitable. A principled way of identifying the causal relationship (if any) between X and Y is causal discovery. However, current causal discovery techniques cannot handle the complex, high-dimensional image data we consider in this paper (Vowels et al., 2021). A loosely related work is Hoover (1990), which proposes that decomposing a joint distribution following the causal graph is more stable under interventions than a random decomposition. Our paper uses the invariance of p(X|Y, Z), where Z represents domain-dependent features like camera positions and picture style; it is hard to find such invariance in other decompositions. On the other hand, quite a few works assume no direct causal relationship between X and Y (Chen et al., 2021; Mahajan et al., 2021; Liu et al., 2021a; Sun et al., 2021; Ahuja et al., 2021; Li et al., 2022). Instead, they assume a causal feature Z_causal directly causing X, together with another non-causal feature Z_non-causal; Y is caused by Z_causal, which implies that Z_causal may contain more information than Y. Such a causal model can be viewed as a noisy version of ours, as we identify Y with the causal feature Z_causal, and Z with the non-causal feature Z_non-causal. Different papers model the spurious correlation between Z_causal and Z_non-causal in the SCM in different ways, while we only require that Z_causal and Z_non-causal are correlated, without specifying how.



We publicly release our code at https://github.com/WANGXinyiLinda/causal-balancing-for-domain-generalization. See Appendix A for proofs of all theorems. The DomainBed codebase is available at https://github.com/facebookresearch/DomainBed. The official experimental logs referenced in Appendix B.3 are at https://drive.google.com/file/d/16VFQWTble6-nB5AdXBtQpQFwjEC7CChM/



Figure 1: The causal graphical model assumed for the data generation process in environment e ∈ E. Shaded nodes are observed and white nodes are unobserved. Black arrows denote causal relations invariant across environments. The red dashed line denotes a correlation that varies across environments.

Figure 2: Annotated example causal graphs of two realizations of the joint distribution p(X, Y, E).

We use the learned propensity score s^e(Z) as our balancing score (b(Z) = s^e(Z)) and propose to construct balanced mini-batches by matching 1 ≤ a ≤ m − 1 examples with different labels but the same/closest balancing score b^e(Z) ∈ B with each train example. The detailed sampling procedure is shown in Algorithm 1.

Algorithm 1: Balanced Mini-batch Sampling.
Input: |E_train| training datasets D^e = {(x_i^e, y_i^e)}_{i=1}^{N^e} for all e ∈ E_train; a balancing score b^e(z_i) inferred from each training data point (x_i^e, y_i^e); a distance metric d : B × B → R.
Output: A balanced batch of data D_balanced consisting of B × |E_train| × (a + 1) examples.
    D_balanced ← ∅
    for e ∈ E_train do
        Randomly sample B data points D^e_random from D^e
        Add D^e_random to D_balanced
        for (x^e, y^e) ∈ D^e_random do
            Y_alt = {y^i ∼ U{1, 2, ..., m} \ {y^e, y^1, ..., y^{i−1}} | i ∈ [1, a]}
            Compute the balancing score b^e(z^e) from (x^e, y^e)
            for y^i ∈ Y_alt do
                j = argmin_{j ∈ [1, N^e]} d(b^e(z_j), b^e(z^e)) such that y_j^e = y^i and (x_j^e, y_j^e) ∈ D^e
                Add (x_j^e, y_j^e) to D_balanced

We denote the data distribution obtained from Algorithm 1 by p̃^B(X, Y, Z, E). Then we have:

Theorem 3.7. If d(b^e(z_j), b^e(z^e)) = 0 in Algorithm 1, the balanced mini-batch can be regarded as sampling from a semi-balanced distribution with p̃^B(Y|Z, E) = (1/(a+1)) ( a/(m−1) + ((m−a−1)/(m−1)) p(Y|Z, E) ).
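A compact Python sketch of Algorithm 1 on a toy dataset (exhaustive nearest-score search with scalar balancing scores; all names are ours, and a practical implementation would batch and index this search):

```python
import numpy as np

def balanced_minibatch(datasets, scores, m, B, a, d, rng):
    """Sketch of Algorithm 1: per environment, sample B anchors at random,
    then match each anchor with `a` examples having different (uniformly
    drawn) labels and the closest balancing score under the distance `d`.
    `datasets[e]` is a list of (x, y) pairs; `scores[e][i]` is b^e(z_i)."""
    batch = []
    for e, data in enumerate(datasets):
        idx = rng.choice(len(data), size=B, replace=False)
        for i in idx:
            x_e, y_e = data[i]
            batch.append((x_e, y_e))
            # Draw `a` distinct alternative labels uniformly, excluding y_e.
            labels = [y for y in range(m) if y != y_e]
            y_alt = rng.choice(labels, size=a, replace=False)
            for y_i in y_alt:
                # Closest balancing score among examples with label y_i.
                cands = [j for j, (_, y) in enumerate(data) if y == y_i]
                j = min(cands, key=lambda j: d(scores[e][j], scores[e][i]))
                batch.append(data[j])
    return batch

# Toy usage: one environment, 3 classes, scalar scores, absolute distance.
rng = np.random.default_rng(0)
m = 3
envs = [[(np.zeros(2), y % m) for y in range(30)]]
scores = [list(rng.random(30))]
batch = balanced_minibatch(envs, scores, m=m, B=4, a=m - 1, d=lambda u, v: abs(u - v), rng=rng)
print(len(batch))   # B * |E_train| * (a + 1) = 4 * 1 * 3 = 12
```

With a = m − 1, every anchor is grouped with one match per remaining class, which is the setting under which Theorem 3.7 recovers the fully balanced distribution.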

(a) A random mini-batch. (b) A balanced mini-batch (obtained by our method).

Figure 3: A random mini-batch and a balanced mini-batch from the ColoredMNIST10 dataset. Note that there is 25% label noise, so mismatches between the label y and the image are expected.
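The ColoredMNIST-style construction behind this figure (25% label noise, with color agreeing with the label except with probability e) can be sketched as follows, following the usual IRM-style recipe; the exact DomainBed implementation may differ in details:

```python
import numpy as np

def color_dataset(digits, e, rng):
    """ColoredMNIST-style labeling sketch: binary label = (digit < 5) with
    25% label noise; the color channel agrees with the noisy label except
    with probability e (the domain parameter)."""
    y = (digits < 5).astype(int)
    y ^= rng.random(len(y)) < 0.25          # 25% label noise
    color = y ^ (rng.random(len(y)) < e)    # spurious channel, correlation 1 - e
    return y, color

rng = np.random.default_rng(0)
digits = rng.integers(0, 10, size=100_000)
y, color = color_dataset(digits, e=0.1, rng=rng)
print((y == color).mean())   # ≈ 0.9: color is highly predictive in-domain
```

Because the color-label agreement flips from 0.9 in a train domain to 0.1 in the e = 0.9 test domain, a classifier relying on color fails out-of-domain, which is exactly the spurious correlation the balanced mini-batches are designed to remove.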

Figure 4: The out-of-domain accuracy versus (a) the degree of balancing, (b) the number of matched examples a, and (c) the test environment, on the ColoredMNIST10 dataset with the ERM base algorithm.

Now we want to prove that for all $e \in \mathcal{E}$,
$$Y \perp\!\!\!\perp_e Z,\quad Y \perp\!\!\!\perp_e Z \mid X,\quad p^e(Y) = \tfrac{1}{m} \ \Longrightarrow\ p^e(Y|X) = p^B(Y|X).$$
For any $Z \in \mathcal{Z}$, we have:
$$p^e(Y|X) = p^e(Y|X,Z) = p^e(Y)\,\frac{p^e(X|Y,Z)}{\mathbb{E}_{p^e(Y|Z)}[p^e(X|Z,Y)]} = p^B(Y)\,\frac{p^B(X|Y,Z)}{\mathbb{E}_{p^B(Y)}[p^B(X|Z,Y)]} = p^B(Y|X,Z) = p^B(Y|X).$$



and 7 classes, where the class indicates the object in the image. VLCS comprises four photographic domains: Caltech101, LabelMe, SUN09, and VOC2007. This dataset contains 10,729 examples of dimensions (3, 224, 224) and 5 classes, where the class indicates the main object in the photo. OfficeHome includes four domains: art, clipart, product, and real. This dataset contains 15,588 examples of dimensions (3, 224, 224) and 65 classes, where the class indicates the object in the image. TerraIncognita contains photographs of wild animals taken by camera traps at four different locations: L100, L38, L43, and L46. This dataset contains 24,788 examples of dimensions (3, 224, 224) and 10 classes, where the class indicates the animal in the image. DomainNet has six domains: clipart, infographics, painting, quickdraw, real, and sketch. This dataset contains 586,575 examples of size (3, 224, 224) and 345 classes.

Figure 5 shows three sets of reconstructed images with the same latent covariate Z and different labels Y using our VAE model. We can see that Z keeps the color feature and some style features, while the digit shape is changed to the closest digit belonging to class Y.

Yuan et al. (2021), Wald et al. (2021), and Ahuja et al. (2021) use a more restrictive linear causal model; Arjovsky et al. (2020) only provide a full solution for linear classifiers; Christiansen et al. (2021) assume additive confounders; Yuan et al. (2021), Makar et al. (2022), and Puli et al. (2022) need to observe the variable spuriously correlated with the label Y.

Out-of-domain accuracy on ColoredMNIST10 and ColoredMNIST with two train environments [0.1, 0.2] and one test environment [0.9].

Out-of-domain accuracy on DomainBed benchmark. Numbers are averaged over all test environments with standard deviation over 3 runs. The training domain validation scheme is used. Full results on each test environment can be found in Appendix B.4.

Several works, such as Puli et al. (2022), propose to utilize an additional auxiliary variable different from the label to address the OOD problem using a single train domain. Their methods are two-phased: (1) reweight the train data with respect to the auxiliary variable; (2) add invariance regularizations to the training objective. The limitation of such methods is that they can only handle distribution shifts induced by the chosen auxiliary variable. Little & Badawy (2019) also propose a bootstrapping method that resamples train data by reweighting to mimic a randomized controlled trial. There are also single-phase methods like Wang et al. (2021), which proposes new training objectives to reduce spurious correlations. Some other OOD works aim to improve OOD accuracy without any additional information. Liu et al. (2021a) and Lu et al. (2022) propose to use VAEs to learn latent variables in an assumed causal graph, under appropriate assumptions on the train data distribution in a single train domain. The identifiability of such latent variables is usually based on Khemakhem et al. (2020), which assumes that the latent variable has a factorial exponential family distribution given an auxiliary variable.

Choice of hyperparameters for constructing balanced mini-batches, including training the VAE model for latent covariate learning (n, lr, batch size) and the balancing score matching (a, d).

Out-of-domain accuracy on ColoredMNIST10 when using different distance metrics.





ACKNOWLEDGMENTS

This work was supported by the National Science Foundation award #2048122. The views expressed are those of the author and do not reflect the official policy or position of the US government. We thank Google and the Robert N. Noyce Trust for their generous gift to the University of California. This work was also supported in part by the National Science Foundation Graduate Research Fellowship under Grant No. 1650114. This work was also partially supported by the National Institutes of Health (NIH) under Contract R01HL159805, by the NSF-Convergence Accelerator Track-D award #2134901, by a grant from Apple Inc., a grant from KDDI Research Inc, and generous gifts from Salesforce Inc., Microsoft Research, and Amazon Research. This work was also supported by NSERC Discovery Grant RGPIN-2022-03215, DGECR-2022-00357. 

