INVARIANT CAUSAL REPRESENTATION LEARNING

Abstract

Due to spurious correlations, machine learning systems often fail to generalize to environments whose distributions differ from those seen at training time. Prior work addressing this problem, either explicitly or implicitly, attempted to find a data representation that has an invariant causal relationship with the outcome. This is done by leveraging a diverse set of training environments to reduce the effect of spurious features, on top of which an invariant classifier is then built. However, these methods have generalization guarantees only when both the data representation and the classifier come from a linear model class. As an alternative, we propose Invariant Causal Representation Learning (ICRL), a learning paradigm that enables out-of-distribution generalization in the nonlinear setting (i.e., nonlinear representations and nonlinear classifiers). It builds upon a practical and general assumption: data representations factorize when conditioning on the outcome and the environment. Based on this, we show identifiability up to a permutation and pointwise transformation. We also prove that all direct causes of the outcome can be fully discovered, which further enables us to obtain generalization guarantees in the nonlinear setting. Extensive experiments on both synthetic and real-world datasets show that our approach significantly outperforms a variety of baseline methods.

1. INTRODUCTION

In recent years, despite various impressive success stories, there is still a significant lack of robustness in machine learning algorithms. Specifically, machine learning systems often fail to generalize outside of a specific training distribution, because they usually learn easier-to-fit spurious correlations which are prone to change from training to testing environments. We illustrate this point with the widely used example of classifying images of camels and cows (Beery et al., 2018). The training dataset has a selection bias: most pictures of cows are taken in green pastures, while most pictures of camels happen to be in deserts. After training a convolutional neural network on this dataset, it was found that the model picked up the spurious correlation, i.e., it associated green pastures with cows and deserts with camels, and therefore classified images of green pastures as cows and images of deserts as camels. As a result, the model failed to classify images of cows taken on sandy beaches.

To address the aforementioned problem, a natural idea is to identify which features of the training data present domain-varying spurious correlations with labels and which features describe true correlations of interest that are stable across domains. In the example above, the former are the features describing the context (e.g., pastures and deserts), whilst the latter are the features describing animals (e.g., animal shape). Arjovsky et al. (2019) suggest that one can identify the stable features and build invariant predictors on them by exploiting the varying degrees of spurious correlation naturally present in training data collected from multiple environments. The authors proposed the invariant risk minimization (IRM) approach to find data representations for which the optimal classifier is invariant across all environments.
Since this formulation is a challenging bi-level optimization problem, the authors proved the generalization of IRM across all environments by constraining both data representations and classifiers to be linear (Theorem 9 in Arjovsky et al. (2019)). Ahuja et al. (2020) studied the problem from the perspective of game theory, with an approach that we call IRMG for short. They showed that the set of Nash equilibria of a proposed game is equivalent to the set of invariant predictors for any finite number of environments, even with nonlinear data representations and nonlinear classifiers. However, these theoretical results in the nonlinear setting only guarantee that one can learn invariant predictors from training environments; they do not guarantee that the learned invariant predictors generalize well across all environments, including unseen ones.
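To make the IRM idea above concrete, the practical relaxation proposed by Arjovsky et al. (2019), often called IRMv1, penalizes the squared gradient of each environment's risk with respect to a fixed scalar "dummy" classifier evaluated at w = 1. The following NumPy sketch (not the authors' implementation; the synthetic environments and squared loss are our illustrative assumptions) shows why an invariant representation incurs a small penalty while a representation whose correlation with the label flips across environments incurs a large one:

```python
import numpy as np

def irm_penalty(phi, y):
    """Squared gradient of the per-environment risk
    R_e(w) = mean((w * phi - y)^2) with respect to the scalar dummy
    classifier w, evaluated analytically at w = 1."""
    grad = np.mean(2.0 * phi * (phi - y))
    return grad ** 2

# Two synthetic environments: the invariant feature relates to y the same
# way in both, while the spurious feature's correlation with y flips sign
# between environments (as in the cows-vs-camels example).
rng = np.random.default_rng(0)
y1 = rng.normal(size=1000)
y2 = rng.normal(size=1000)

invariant_1, invariant_2 = y1.copy(), y2.copy()    # phi = y in both envs
spurious_1, spurious_2 = 2.0 * y1, -2.0 * y2       # correlation flips

pen_inv = irm_penalty(invariant_1, y1) + irm_penalty(invariant_2, y2)
pen_spur = irm_penalty(spurious_1, y1) + irm_penalty(spurious_2, y2)
assert pen_inv < pen_spur  # invariant representation is preferred
```

In the full IRMv1 objective this penalty is summed over environments and added, with a large weight, to the pooled empirical risk; the bi-level structure mentioned above is what this single-level relaxation approximates.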

