INVARIANT CAUSAL REPRESENTATION LEARNING

Abstract

Due to spurious correlations, machine learning systems often fail to generalize to environments whose distributions differ from the ones used at training time. Prior work addressing this problem, either explicitly or implicitly, attempted to find a data representation that has an invariant causal relationship with the outcome. This is done by leveraging a diverse set of training environments to reduce the effect of spurious features, on top of which an invariant classifier is then built. However, these methods have generalization guarantees only when both the data representation and the classifier come from a linear model class. As an alternative, we propose Invariant Causal Representation Learning (ICRL), a learning paradigm that enables out-of-distribution generalization in the nonlinear setting (i.e., nonlinear representations and nonlinear classifiers). It builds upon a practical and general assumption: the data representation factorizes when conditioning on the outcome and the environment. Based on this, we show identifiability up to a permutation and pointwise transformation. We also prove that all direct causes of the outcome can be fully discovered, which further enables us to obtain generalization guarantees in the nonlinear setting. Extensive experiments on both synthetic and real-world datasets show that our approach significantly outperforms a variety of baseline methods.

1. INTRODUCTION

In recent years, despite various impressive success stories, machine learning algorithms still lack robustness. Specifically, machine learning systems often fail to generalize outside of a specific training distribution, because they tend to learn easier-to-fit spurious correlations, which are prone to change between training and testing environments. We illustrate this point with the widely used example of classifying images of camels and cows (Beery et al., 2018). The training dataset has a selection bias: most pictures of cows are taken in green pastures, while most pictures of camels happen to be in deserts. After training a convolutional network on this dataset, the model was found to exploit this spurious correlation: it associated green pastures with cows and deserts with camels, and therefore classified any image with green pastures as a cow and any image with a desert as a camel. As a result, the model failed to classify images of cows taken on sandy beaches.

To address this problem, a natural idea is to identify which features of the training data present domain-varying spurious correlations with labels and which features describe true correlations of interest that are stable across domains. In the example above, the former are the features describing the context (e.g., pastures and deserts), whilst the latter are the features describing the animals (e.g., animal shape). Arjovsky et al. (2019) suggest that one can identify the stable features and build invariant predictors on them by exploiting the varying degrees of spurious correlation naturally present in training data collected from multiple environments. The authors proposed the invariant risk minimization (IRM) approach to find data representations for which the optimal classifier is invariant across all environments.
Since this formulation is a challenging bi-leveled optimization problem, the authors proved the generalization of IRM across all environments by constraining both data representations and classifiers to be linear (Theorem 9 in Arjovsky et al. (2019)). Ahuja et al. (2020) studied the problem from the perspective of game theory, with an approach that we call IRMG for short. They showed that the set of Nash equilibria of a proposed game is equivalent to the set of invariant predictors for any finite number of environments, even with nonlinear data representations and nonlinear classifiers. However, these theoretical results in the nonlinear setting only guarantee that one can learn invariant predictors from training environments; they do not guarantee that the learned invariant predictors generalize well across all environments, including unseen testing environments. In fact, the authors directly borrowed the linear generalization result from Arjovsky et al. (2019) and presented it as Theorem 2 in Ahuja et al. (2020).

In this work, we propose an alternative learning paradigm, called Invariant Causal Representation Learning (ICRL), which enables out-of-distribution (OOD) generalization in the nonlinear setting (i.e., nonlinear representations and nonlinear classifiers). We first introduce a practical and general assumption: the data representation factorizes (i.e., its components are independent of each other) when conditioning on the outcome (e.g., labels) and the environment (represented as an index). This assumption builds a bridge between supervised and unsupervised learning, leading to a guarantee that the data representation can be identified up to a permutation and pointwise transformation. We then theoretically show that all the direct causes of the outcome can be fully discovered. Based on this, the challenging bi-leveled optimization problem in IRM and IRMG can be reduced to two simpler independent optimization problems; that is, learning the data representation and learning the optimal classifier can be performed separately. This further enables us to attain generalization guarantees in the nonlinear setting.

Contributions. We propose Invariant Causal Representation Learning (ICRL), a novel learning paradigm that enables OOD generalization in the nonlinear setting. (i) We introduce a conditional factorization assumption on the data representation for OOD generalization (Assumption 1). (ii) Based on this assumption, we show that each component of the representation can be identified up to a permutation and pointwise transformation (Theorems 1, 2 & 3). (iii) We further prove that all the direct causes of the outcome can be fully discovered (Proposition 1). (iv) We show that our approach has generalization guarantees in the nonlinear setting (Proposition 2). (v) Empirical results demonstrate that our approach significantly outperforms IRM and IRMG in nonlinear scenarios.
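The decoupling of representation learning from classifier learning can be illustrated with a toy two-stage pipeline. This is a minimal sketch: the whitening step and the logistic classifier below are illustrative stand-ins chosen for simplicity, not the actual ICRL components.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy two-stage pipeline: the representation (Stage 1) and the classifier
# (Stage 2) are fit separately, as two independent optimization problems.
# Both stages are simple stand-ins, NOT the paper's actual models.

# Synthetic binary data: the label shifts the mean of the first feature;
# the second feature is pure noise.
n = 400
y = rng.integers(0, 2, size=n)
O = np.column_stack([y * 2.0 - 1.0 + 0.5 * rng.standard_normal(n),
                     rng.standard_normal(n)])

# Stage 1: "representation learning" stand-in -- whiten the inputs.
def fit_representation(O):
    mu, sd = O.mean(axis=0), O.std(axis=0)
    return lambda o: (o - mu) / sd

phi = fit_representation(O)
Z = phi(O)  # learned representation, now held fixed

# Stage 2: logistic-regression classifier trained on the fixed representation.
def fit_classifier(Z, y, lr=0.5, steps=300):
    w, b = np.zeros(Z.shape[1]), 0.0
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(Z @ w + b)))  # predicted probabilities
        w -= lr * Z.T @ (p - y) / len(y)        # gradient step on weights
        b -= lr * np.mean(p - y)                # gradient step on bias
    return w, b

w, b = fit_classifier(Z, y)
acc = np.mean(((Z @ w + b) > 0).astype(int) == y)
```

Because the representation is frozen before the classifier is trained, there is no bi-leveled optimization: each stage can use whatever (possibly nonlinear) model class is appropriate.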

2.1. IDENTIFIABLE VARIATIONAL AUTOENCODERS

A general issue with variational autoencoders (VAEs) (Kingma & Welling, 2013; Rezende et al., 2014) is the lack of identifiability guarantees for the deep latent variable model. In other words, it is generally impossible to approximate the true joint distribution over observed and latent variables, including the true prior and posterior distributions over the latent variables. Consider a simple latent variable model where O ∈ R^d stands for an observed variable (random vector) and X ∈ R^n for a latent variable. Khemakhem et al. (2020) showed that any model with an unconditional latent distribution p_θ(X) is unidentifiable: we can always find transformations of X which change its value but do not change its distribution. Hence, the primary assumption that they make to obtain an identifiability result is to include a conditionally factorized prior distribution over the latent variables, p_θ(X|U), where U ∈ R^m is an additionally observed variable (Hyvarinen et al., 2019). More specifically, let θ = (f, T, λ) ∈ Θ be the parameters of the conditional generative model p_θ(O, X|U) = p_f(O|X) p_{T,λ}(X|U), where p_f(O|X) = p_ε(O − f(X)), in which ε is an independent noise variable with probability density function p_ε(ε), and the prior probability density function is specifically given by p_{T,λ}(X|U) = ∏_i [Q_i(X_i)/Z_i(U)] exp(∑_{j=1}^{k} T_{i,j}(X_i) λ_{i,j}(U)), where Q_i is the base measure, Z_i(U) the normalizing constant, T_i = (T_{i,1}, . . . , T_{i,k}) the sufficient statistics, λ_i(U) = (λ_{i,1}(U), . . . , λ_{i,k}(U)) the corresponding parameters depending on U, and k the dimension of each sufficient statistic, which is fixed in advance. It is worth noting that this assumption is not very restrictive, as exponential families have universal approximation capabilities (Sriperumbudur et al., 2017).
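As a concrete instance of this conditionally factorized prior, the sketch below evaluates the Gaussian case, an exponential family with k = 2 and sufficient statistics T_i(x) = (x, x²). The environment-dependent mean and variance used here are toy placeholders of our own, not parameters from the paper.

```python
import numpy as np

# Sketch of the conditionally factorized exponential-family prior
# p(X|U) = prod_i [Q_i(X_i)/Z_i(U)] * exp(sum_j T_ij(X_i) * lambda_ij(U)),
# instantiated for a factorized Gaussian (k = 2, T_i(x) = (x, x^2)).

def gaussian_nat_params(mu, var):
    """Natural parameters of N(mu, var): (mu/var, -1/(2*var))."""
    return mu / var, -1.0 / (2.0 * var)

def log_prior(x, mu_u, var_u):
    """log p(X = x | U), where mu_u and var_u stand in for the
    environment-dependent parameters lambda_i(U)."""
    lam1, lam2 = gaussian_nat_params(mu_u, var_u)
    log_q = -0.5 * np.log(2.0 * np.pi)                      # log Q_i(x_i)
    log_z = mu_u ** 2 / (2.0 * var_u) + 0.5 * np.log(var_u) # log Z_i(U)
    return np.sum(log_q + lam1 * x + lam2 * x ** 2 - log_z)

x = np.array([0.3, -1.2])
mu_u = np.array([0.0, 1.0])   # toy environment-dependent mean
var_u = np.array([1.0, 0.5])  # toy environment-dependent variance

# The joint log density is the sum of per-component log densities:
# conditioned on U, the prior factorizes across the components of X.
per_comp = [log_prior(x[i:i + 1], mu_u[i:i + 1], var_u[i:i + 1])
            for i in range(2)]
assert np.isclose(log_prior(x, mu_u, var_u), sum(per_comp))
```

The final assertion checks the defining property of the assumption: given U, the components of X are independent, so the log density decomposes as a sum over components.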
As in VAEs, we maximize the corresponding evidence lower bound: L_iVAE(θ, φ) := E_{p_D} E_{q_φ(X|O,U)} [log p_θ(O, X|U) − log q_φ(X|O, U)], where p_D denotes the empirical data distribution given by the dataset D = {(O^(i), U^(i))}_{i=1}^{N}. This approach is called identifiable VAE (iVAE). Most importantly, it can be proved that iVAE identifies the latent variables X up to a permutation and pointwise transformation under the conditions stated in Theorem 2 of Khemakhem et al. (2020).
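The inner expectation of this bound can be estimated by Monte Carlo sampling from the approximate posterior. The sketch below does this for a single datapoint in a linear-Gaussian toy model; all names, dimensions, and the fixed observation noise of 0.1 are our own illustrative choices, not the paper's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Monte Carlo estimate of the per-datapoint ELBO
# E_{q_phi(X|O,U)}[ log p_theta(O, X|U) - log q_phi(X|O,U) ]
# with diagonal-Gaussian densities and a toy linear decoder f(X) = W X.

def log_gauss(x, mu, var):
    """Log density of a factorized Gaussian evaluated at x."""
    return np.sum(-0.5 * np.log(2 * np.pi * var) - (x - mu) ** 2 / (2 * var))

def elbo_estimate(o, u, W, prior_mu, prior_var, enc_mu, enc_var,
                  n_samples=500):
    """Single-datapoint Monte Carlo ELBO with a reparameterized posterior."""
    est = 0.0
    for _ in range(n_samples):
        x = enc_mu + np.sqrt(enc_var) * rng.standard_normal(enc_mu.shape)
        log_px_u = log_gauss(x, prior_mu, prior_var)  # p_{T,lambda}(X|U)
        log_po_x = log_gauss(o, W @ x, 0.1)           # p_f(O|X), noise var 0.1
        log_q = log_gauss(x, enc_mu, enc_var)         # q_phi(X|O,U)
        est += log_px_u + log_po_x - log_q
    return est / n_samples

o = np.array([0.5, -0.3])
u = 0  # environment index (unused by this toy prior)
W = np.eye(2)
prior_mu, prior_var = np.zeros(2), np.ones(2)
# Exact posterior of this linear-Gaussian model, so the bound is tight:
enc_var = np.full(2, 1.0 / 11.0)
enc_mu = o * (10.0 / 11.0)
elbo = elbo_estimate(o, u, W, prior_mu, prior_var, enc_mu, enc_var)
```

Because the toy model is linear-Gaussian and the encoder above is its exact posterior, the estimate equals the log marginal likelihood log p(O|U); with an approximate posterior, the estimate would be a strict lower bound.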



A brief description of variational autoencoders is given in Appendix A.

