CAUSAL BALANCING FOR DOMAIN GENERALIZATION

Abstract

While machine learning models rapidly advance the state of the art on various real-world tasks, out-of-domain (OOD) generalization remains a challenging problem given the vulnerability of these models to spurious correlations. We propose a balanced mini-batch sampling strategy that transforms a biased data distribution into a spurious-free, balanced one, exploiting the invariance of the underlying causal mechanisms of the data-generating process. We argue that Bayes optimal classifiers trained on such a balanced distribution are minimax optimal across a sufficiently diverse environment space. We also provide an identifiability guarantee for the latent-variable model of the proposed data-generating process, given sufficiently many training environments. Experiments on DomainBed demonstrate empirically that our method obtains the best performance across the 20 baselines reported on the benchmark.¹

1. INTRODUCTION

Machine learning is achieving tremendous success in many fields with useful real-world applications (Silver et al., 2016; Devlin et al., 2019; Jumper et al., 2021). While machine learning models can perform well on in-domain data sampled from seen environments, they often fail to generalize to out-of-domain (OOD) data sampled from unseen environments (Quiñonero-Candela et al., 2009; Szegedy et al., 2014). One explanation is that machine learning models are prone to learning spurious correlations that change between environments. For example, in image classification, instead of relying on the object of interest, models easily rely on surface-level textures (Jo & Bengio, 2017; Geirhos et al., 2019) or background scenery (Beery et al., 2018; Zhang et al., 2020). This vulnerability to changes in environments can cause serious problems for machine learning systems deployed in the real world, undermining their reliability over time.

Various methods have been proposed to improve OOD generalization by exploiting the invariance of causal features or of the underlying causal mechanism (Pearl, 2009) through which data is generated. Such methods often aim to learn invariant data representations via loss functions that incorporate invariance conditions across domains into training (Arjovsky et al., 2020; Mahajan et al., 2021; Liu et al., 2021a; Lu et al., 2022; Wald et al., 2021). Unfortunately, these approaches either rely on restrictive linear models or lack theoretical guarantees (Arjovsky et al., 2020; Wald et al., 2021), and empirical studies have called their real-world utility into question (Gulrajani & Lopez-Paz, 2020).

In this paper, we consider the setting in which multiple training domains/environments are available.
We theoretically show that the Bayes optimal classifier trained on a balanced (spurious-free) distribution is minimax optimal across all environments. We then propose a principled two-step method for sampling balanced mini-batches from such a balanced distribution: (1) learn the observed data distribution using a variational autoencoder (VAE) and identify the latent covariate; (2) match train examples
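To make the sampling step concrete, here is a minimal sketch of balanced mini-batch construction. It assumes latent codes (e.g., VAE posterior means) are already available for each training environment, and it matches each sampled example with its nearest same-label latent neighbor from another environment; the exact matching criterion used in the paper may differ, so treat this as an illustrative assumption rather than the method's definitive implementation. The function name `balanced_minibatch` is hypothetical.

```python
import numpy as np

def balanced_minibatch(z_a, y_a, z_b, y_b, batch_size, seed=None):
    """Sketch of balanced mini-batch sampling across two environments.

    z_a, z_b: (n, d) arrays of latent codes (assumed to come from a
              trained VAE encoder, step (1) of the method).
    y_a, y_b: (n,) integer label arrays.
    Returns index arrays (idx_a, idx_b): each sampled example from
    environment A is paired with its nearest same-label latent
    neighbor in environment B (the assumed matching rule).
    """
    rng = np.random.default_rng(seed)
    idx_a = rng.choice(len(z_a), size=batch_size, replace=True)
    idx_b = np.empty(batch_size, dtype=int)
    for i, a in enumerate(idx_a):
        # Candidates in B sharing the label of example a.
        same = np.flatnonzero(y_b == y_a[a])
        # Euclidean distance in latent space; pick the closest match.
        dists = np.linalg.norm(z_b[same] - z_a[a], axis=1)
        idx_b[i] = same[np.argmin(dists)]
    return idx_a, idx_b
```

Pairing examples that share a label but come from different environments is one way to break the environment-specific (spurious) correlation within each mini-batch, while the latent-space nearest-neighbor criterion keeps the matched pair close in the learned representation.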



¹ We publicly release our code at https://github.com/WANGXinyiLinda/causal-balancing-for-domain-generalization.

