CAUSAL BALANCING FOR DOMAIN GENERALIZATION

Abstract

While machine learning models rapidly advance the state-of-the-art on various real-world tasks, out-of-domain (OOD) generalization remains a challenging problem given the vulnerability of these models to spurious correlations. We propose a balanced mini-batch sampling strategy that transforms a biased data distribution into a spurious-free balanced distribution, based on the invariance of the underlying causal mechanisms of the data generation process. We argue that the Bayes optimal classifiers trained on such a balanced distribution are minimax optimal across a diverse enough environment space. We also provide an identifiability guarantee for the latent variable model of the proposed data generation process when enough training environments are available. Experiments conducted on DomainBed demonstrate empirically that our method obtains the best performance compared to 20 baselines reported on the benchmark.¹

1. INTRODUCTION

Machine learning is achieving tremendous success in many fields with useful real-world applications (Silver et al., 2016; Devlin et al., 2019; Jumper et al., 2021). While machine learning models can perform well on in-domain data sampled from seen environments, they often fail to generalize to out-of-domain (OOD) data sampled from unseen environments (Quiñonero-Candela et al., 2009; Szegedy et al., 2014). One explanation is that machine learning models are prone to learning spurious correlations that change between environments. For example, in image classification, instead of relying on the object of interest, machine learning models easily rely on surface-level textures (Jo & Bengio, 2017; Geirhos et al., 2019) or background environments (Beery et al., 2018; Zhang et al., 2020). This vulnerability to changes in environments can cause serious problems for machine learning systems deployed in the real world, calling into question their reliability over time.

Various methods have been proposed to improve OOD generalizability by exploiting the invariance of causal features or of the underlying causal mechanism (Pearl, 2009) through which data is generated. Such methods often aim to find invariant data representations using new loss function designs that incorporate invariance conditions across different domains into the training process (Arjovsky et al., 2020; Mahajan et al., 2021; Liu et al., 2021a; Lu et al., 2022; Wald et al., 2021). Unfortunately, these approaches face a trade-off between restrictive linear models and a lack of theoretical guarantees (Arjovsky et al., 2020; Wald et al., 2021), and empirical studies have shown their practical utility to be questionable (Gulrajani & Lopez-Paz, 2020). In this paper, we consider the setting in which multiple training domains/environments are available.
We theoretically show that the Bayes optimal classifier trained on a balanced (spurious-free) distribution is minimax optimal across all environments. We then propose a principled two-step method to sample balanced mini-batches from such a balanced distribution: (1) learn the observed data distribution with a variational autoencoder (VAE) and identify the latent covariate; (2) match training examples with the closest latent covariates to create balanced mini-batches. Because it only modifies the mini-batch sampling strategy, our method is lightweight and highly flexible, enabling seamless incorporation into complex classification models and improvement upon other domain generalization methods. Our contributions are as follows: (1) We propose a general non-linear causality-based framework for the domain generalization problem in classification tasks; (2) We prove that a spurious-free balanced distribution produces minimax optimal classifiers for OOD generalization; (3) We rigorously demonstrate that the source of spurious correlation, as a latent variable, can be identified in a nonlinear setting given a large enough set of training environments; (4) We propose a novel and principled balanced mini-batch sampling algorithm that, in an ideal scenario, removes the spurious correlations in the observed data distribution; (5) Our empirical results show that our method obtains a significant performance gain compared to 20 baselines on DomainBed (Gulrajani & Lopez-Paz, 2020).
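To make step (2) concrete, the following is a minimal sketch of the matching step, assuming latent codes have already been inferred (random arrays stand in for VAE-encoded covariates). The exact matching criterion of our algorithm is specified later in the paper; here, as an illustrative assumption, each sampled anchor is paired with the nearest-in-z example carrying a different label, so that within the mini-batch the latent covariate z carries no information about the label y.

```python
import numpy as np

def balanced_minibatch(latents, labels, batch_size, rng):
    """Pair each sampled anchor with the different-label example whose
    latent covariate z is closest (an illustrative balancing rule that
    decorrelates z from y within the mini-batch)."""
    anchors = rng.choice(len(latents), size=batch_size // 2, replace=False)
    batch = []
    for i in anchors:
        cand = np.flatnonzero(labels != labels[i])  # different-label candidates
        d = np.linalg.norm(latents[cand] - latents[i], axis=1)
        j = cand[np.argmin(d)]                      # closest latent covariate
        batch.extend([i, j])
    return np.array(batch)

rng = np.random.default_rng(0)
latents = rng.normal(size=(200, 4))    # stand-in for VAE-inferred latent codes z
labels = rng.integers(0, 2, size=200)  # binary labels for illustration
batch = balanced_minibatch(latents, labels, batch_size=16, rng=rng)
```

Because only index selection changes, such a sampler drops into any existing training loop (e.g. as a custom batch sampler) without touching the model or the loss.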

2. PRELIMINARIES

Problem Setting. We consider a standard domain generalization setting with a potentially high-dimensional variable $X$ (e.g., an image), a label variable $Y$, and a discrete environment (or domain) variable $E$ in the sample spaces $\mathcal{X}, \mathcal{Y}, \mathcal{E}$, respectively. Here we focus on classification problems with $\mathcal{Y} = \{1, 2, \dots, m\}$ and $\mathcal{X} \subseteq \mathbb{R}^d$. We assume that the training data are collected from a finite subset of training environments $\mathcal{E}_{train} \subset \mathcal{E}$. The training data $D_e = \{(x_i^e, y_i^e)\}_{i=1}^{N_e}$ are then sampled from the distribution $p^e(X, Y) = p(X, Y \mid E = e)$ for all $e \in \mathcal{E}_{train}$. Our goal is to learn a classifier $C_\psi: \mathcal{X} \to \mathcal{Y}$ that performs well in a new, unseen environment $e_{test} \notin \mathcal{E}_{train}$. We assume that the data generation process of the observed data distribution $p^e(X, Y)$ is represented by an underlying structural causal model (SCM), shown in Figure 1a. More specifically, we assume that $X$ is caused by the label $Y$, an unobserved latent variable $Z$ (with sample space $\mathcal{Z} \subseteq \mathbb{R}^n$), and an independent noise variable $\epsilon$, with the following formulation: $X = f(Y, Z) + \epsilon = f_Y(Z) + \epsilon$. We assume the causal mechanism is invariant across all environments $e \in \mathcal{E}$, and we further characterize $f$ with the following assumption: Assumption 2.1. $f: \{1, 2, \dots, m\} \times \mathcal{Z} \to \mathcal{X}$ is injective. $f^{-1}: \mathcal{X} \to \{1, 2, \dots, m\} \times \mathcal{Z}$ is the left inverse of $f$. Note that this assumption forces the generation process of $X$ to depend on both $Z$ and $Y$ instead of only one of them. Suppose $\epsilon$ has a known probability density function $p_\epsilon > 0$. Then we have $p_f(X \mid Z, Y) = p_\epsilon(X - f_Y(Z))$. While the causal mechanism is invariant across environments, we assume that the correlation between the label $Y$ and the latent $Z$ is environment-variant, and that $Z$ excludes $Y$ information, i.e., $Y$ cannot be recovered as a function of $Z$. If $Y$ were a function of $Z$, the generation process of $X$ could completely ignore $Y$, and $f$ would not be injective.
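The assumed SCM can be instantiated as a small simulation. The mechanism below is a hypothetical toy choice of an injective $f_Y(Z)$ (the label and the latent occupy separate coordinates of $X$); the parameter `rho` controlling the environment-variant correlation $p^e(Z \mid Y)$ is likewise illustrative, not part of the paper's formulation.

```python
import numpy as np

def sample_environment(n, rho, rng):
    """Sample (x, y, z) from the toy SCM X = f_Y(Z) + eps.
    The mechanism f is shared across environments, while the
    strength `rho` of the Y-Z correlation varies per environment."""
    y = rng.integers(0, 2, size=n)                         # label Y
    z = rho * (2 * y - 1) + rng.normal(scale=0.5, size=n)  # p^e(Z|Y) depends on rho
    eps = rng.normal(scale=0.1, size=(n, 2))               # independent noise
    # toy injective mechanism f_Y(Z): Y and Z on separate coordinates
    x = np.stack([3.0 * y, z], axis=1) + eps
    return x, y, z

rng = np.random.default_rng(0)
# train environment: strong spurious correlation between Y and Z
x_tr, y_tr, z_tr = sample_environment(1000, rho=0.9, rng=rng)
# test environment: the same mechanism f, but the correlation is flipped
x_te, y_te, z_te = sample_environment(1000, rho=-0.9, rng=rng)
```

A classifier that shortcuts through the second coordinate (the latent $Z$) fits the train environment but fails under the flipped correlation, whereas the first coordinate stays predictive in both, which is exactly the invariance the SCM encodes.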
We consider the following family of distributions: $\mathcal{F} = \{\, p^e(X, Y, Z) = p_f(X \mid Z, Y)\, p^e(Z \mid Y)\, p^e(Y) \;\big|\; p^e(Z \mid Y), p^e(Y) > 0 \,\}_e$. The environment space we consider is then the set of all indices of $\mathcal{F}$: $\mathcal{E} = \{ e \mid p^e \in \mathcal{F} \}$. Note that any mixture of distributions from $\mathcal{F}$ is also a member of $\mathcal{F}$, i.e., any combination of environments from $\mathcal{E}$ is also an environment in $\mathcal{E}$. To better understand our setting, consider the following example: an image $X$ of an object in class $Y$ has an appearance driven by the fundamental shared properties of $Y$ as well as other meaningful latent features $Z$ that do not determine "$Y$-ness" but can be spuriously correlated with $Y$. In



¹We publicly release our code at https://github.com/WANGXinyiLinda/causal-balancing-for-domain-generalization.



Figure 1: The causal graphical model assumed for the data generation process in environment $e \in \mathcal{E}$. Shaded nodes are observed and white nodes are unobserved. Black arrows denote causal relations that are invariant across environments. The red dashed line denotes a correlation that varies across environments.

