OUT-OF-DISTRIBUTION PREDICTION WITH INVARIANT RISK MINIMIZATION: THE LIMITATION AND AN EFFECTIVE FIX

Abstract

This work considers the out-of-distribution (OOD) prediction problem where (1) the training data come from multiple domains and (2) the test domain is unseen during training. DNNs fail in OOD prediction because they are prone to picking up spurious correlations. Recently, Invariant Risk Minimization (IRM) was proposed to address this issue, and its effectiveness has been demonstrated in the colored MNIST experiment. Nevertheless, we find that the performance of IRM can be dramatically degraded under strong Λ spuriousness: when the spurious correlation between the spurious features and the class label is strong due to the strong causal influence of their common cause, the domain label, on both of them (see Fig. 1). In this work, we try to answer the following questions: why does IRM fail in the aforementioned setting, and why does IRM work for the original colored MNIST dataset? We then propose a simple and effective approach to fix this problem of IRM: we combine IRM with conditional distribution matching to avoid a specific type of spurious correlation under strong Λ spuriousness. Empirically, we design a series of semi-synthetic datasets, the colored MNIST plus, which exposes the problems of IRM and demonstrates the efficacy of the proposed method.



1 INTRODUCTION

Figure 1: The causal graph in OOD prediction over E, X_c, Y and X_s: P(Y | X_c) is invariant across domains. The spurious correlation P(Y | X_s) may vary. A directed (bidirected) edge is a causal relationship (correlation).

Strong empirical results have demonstrated the efficacy of deep neural networks (DNNs) in a variety of areas, including computer vision, natural language processing and speech recognition. However, such positive results overwhelmingly rely on the assumption that the training, validation and test data consist of independent and identically distributed samples from the same underlying distribution. In contrast, in the setting of out-of-distribution (OOD) prediction, where (1) the training data are from multiple domains and (2) the test set has a different distribution from the training set, the performance of DNNs can be dramatically degraded. This is because DNNs are prone to picking up spurious correlations which do not hold beyond the training data distribution (Beery et al., 2018; Arjovsky et al., 2019). For example, when most camel pictures in a training set have a desert in the background, DNNs will pick up the spurious correlation between a desert and the class label, leading to failures when camel pictures come with different backgrounds in a test set. Therefore, OOD prediction remains an extremely challenging problem for DNNs.

Invariant causal relationships across different domains turn out to be the key to addressing the challenge of OOD prediction. The causal graph in Fig. 1 describes the relationships of the variables in the OOD prediction problem. Although spurious correlations learned in one domain are unreliable in another, invariant causal relationships enable DNNs that capture them to generalize to unseen domains. In practice, it is extremely difficult to know whether an input feature is causal or spurious.
Thus, a recipe for training DNNs that capture causal relationships is to learn causal feature representations that hold invariant causal relationships with the class label. The main challenge of OOD prediction then becomes learning causal feature representations given training data from multiple domains. Invariant Risk Minimization (IRM) (Arjovsky et al., 2019) was proposed to learn causal feature representations for OOD prediction. IRM formulates the invariance of causal relationships as a constraint: causal feature representations must result in the same optimal classifier across all domains. IRM is written as the conditional independence Y ⊥⊥ E | F(X) in (Chang et al., 2020; Zeng et al., 2019), which can be derived from the causal graph in Fig. 1 when F(X) is a mapping of the causal features (X_c in Fig. 1) with no information loss. A detailed discussion of IRM and Y ⊥⊥ E | F(X) is in Appendix A. Despite IRM's success on the colored MNIST dataset (Arjovsky et al., 2019), in this work we find an important issue with IRM that has not been discussed. Specifically, we consider the situation where strong spurious correlations among the spurious features, the class label and the domain label hold only for the training data. We name this setting strong Λ spuriousness, after the shape of the structure among X_s, E and Y in Fig. 1. Under strong Λ spuriousness, the IRM-regularized empirical risk can attain low values with spurious feature representations that accurately predict the domain label. This is because, in this setting, picking up spurious features can achieve high accuracy in predicting both the domain and the class in the training set, but not in the test set. However, the colored MNIST dataset cannot expose this issue because the strong similarity between its two training domains makes it difficult to pick up the weak spurious correlation between the domain label and the spurious features. To illustrate this problem, we design a new dataset, the colored MNIST plus.
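To make the Λ structure concrete, a toy generator following the causal graph in Fig. 1 can be sketched: the domain label E causally influences both the spurious feature X_s and the class label Y, while the causal feature X_c is drawn from Y, so P(Y | X_c) stays invariant across domains. This is only a hypothetical stand-in, not the actual colored MNIST plus construction; the function and parameter names are illustrative.

```python
import numpy as np

def sample_domain(e, n, p_label_given_domain, rng):
    """Toy data generator following the Fig. 1 causal graph.

    The domain label e is a common cause of the class label y and of the
    spurious feature x_s; the causal feature x_c is drawn from y, so
    P(y | x_c) is the same in every domain. Illustrative only -- the
    colored MNIST plus construction is not reproduced here.
    """
    # E -> Y: the class balance depends on the domain (the Λ structure).
    y = (rng.random(n) < p_label_given_domain[e]).astype(int)
    # Y -> X_c: the causal feature carries label information in every domain.
    x_c = y + 0.1 * rng.standard_normal(n)
    # E -> X_s: the spurious feature identifies the domain, not the label.
    x_s = e + 0.1 * rng.standard_normal(n)
    return np.stack([x_c, x_s], axis=1), y
```

Under strong Λ spuriousness, `p_label_given_domain` differs sharply across training domains, so x_s predicts both the domain and, spuriously, the class label on training data, while the relation between x_c and y is unchanged everywhere.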
As shown in Fig. 2, on this dataset the performance of IRM models is significantly degraded under strong Λ spuriousness. Moreover, to resolve this issue of IRM, we propose an effective solution that combines IRM with conditional distribution matching (CDM). The CDM constraint requires that the distribution of representations of instances from the same class be invariant across domains. Theoretically, we show that (1) causal feature representations can satisfy CDM and IRM at the same time and (2) CDM can prevent DNNs from learning spurious feature representations that accurately predict the domain label. Empirically, on our newly introduced dataset, the proposed method achieves significant performance improvements over IRM under strong Λ spuriousness.
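The CDM constraint can be enforced with various distribution-matching penalties. Below is a minimal numpy sketch that matches only the class-conditional feature means across domains, a first-moment proxy for full conditional distribution matching (which could instead use, e.g., MMD or an adversarial critic); `cdm_penalty` is an illustrative name, not the paper's actual implementation.

```python
import numpy as np

def cdm_penalty(features_per_domain, labels_per_domain):
    """First-moment sketch of the CDM constraint.

    For each class, computes the per-domain mean of the feature
    representations and penalizes squared deviations from the
    cross-domain average. A zero penalty means the class-conditional
    feature means already agree across all domains.
    """
    classes = np.unique(np.concatenate(labels_per_domain))
    penalty = 0.0
    for c in classes:
        # Mean representation of class c in each domain that contains it.
        means = [f[y == c].mean(axis=0)
                 for f, y in zip(features_per_domain, labels_per_domain)
                 if np.any(y == c)]
        center = np.mean(means, axis=0)
        penalty += sum(np.sum((m - center) ** 2) for m in means)
    return penalty
```

A representation that encodes the domain (e.g., color) shifts the per-class means between domains and is penalized, while a causal representation with domain-invariant class-conditional distributions incurs no penalty.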

2 PRELIMINARIES AND IRM

Notations. We use lowercase (e.g., x), uppercase (e.g., X) and calligraphic uppercase (e.g., X) letters for values, random variables and spaces, respectively. We let X ∈ X, Y ∈ Y and E ∈ E denote the raw input, the class label and the domain label, where X, Y and E are the spaces of inputs, class labels and domains. A DNN model consists of a feature learning function F and a classifier G. The feature learning function F : X → R^d maps the raw input X to its d-dimensional representation F(X). The classifier G : R^d → Y maps a feature representation to a class label. We denote their parameters by θ_F and θ_G, respectively, and let θ = Concat(θ_F, θ_G) denote their concatenation. A domain e of n_e instances is denoted by D_e = {x_i^e, y_i^e}_{i=1}^{n_e}. Let E_tr and E_ts denote the sets of training and test domains; in OOD prediction, we have (1) |E_tr| > 1 and (2) E_tr ∩ E_ts = ∅.

Problem Statement. Given data from multiple training domains {D_e}_{e ∈ E_tr}, we aim to predict the label y_i^e of each instance with features x_i^e from a test domain in E_ts.

IRM. Invariant Risk Minimization (Arjovsky et al., 2019) is a recently proposed method to impose the causal inductive bias that the causal relationships between causal features and the label should be invariant across different domains. It not only aims to address the challenging OOD prediction problem, but is also pioneering work that guides causal machine learning research towards the development of inductive biases imposing causal constraints. The effectiveness of IRM and its variants has been demonstrated across various areas, including computer vision (Ahuja et al., 2020), natural language processing (Chang et al., 2020), CTR prediction (Zeng et al., 2019), reinforcement learning (Zhang et al., 2020a) and financial forecasting (Krueger et al., 2020). Arjovsky et al. (2019) propose the original formulation of IRM as a two-stage optimization problem:

argmin_{θ_F, θ_G} Σ_{e ∈ E_tr} E_{(x,y) ∼ D_e} [R_e(G(F(x)), y)]
s.t. θ_G ∈ argmin_{θ'_G} E_{(x,y) ∼ D_e} [R_e(G(F(x); θ'_G), y)], ∀ e ∈ E_tr,    (1)

where R_e denotes the risk in domain e.
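The bi-level problem in Eq. (1) is hard to optimize directly; Arjovsky et al. (2019) relax it to a practical objective (IRMv1) that adds, per domain, the squared gradient of the risk with respect to a fixed scalar "dummy" classifier w = 1 placed on top of the logits. Below is a minimal numpy sketch for binary logistic loss with labels in {-1, +1}; the function names are illustrative, and because w is a scalar the gradient is computed in closed form rather than by autograd.

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def irm_penalty(logits, labels):
    """IRMv1-style penalty for binary labels in {-1, +1}.

    Scales the logits by a dummy scalar classifier w = 1 and returns the
    squared gradient of the logistic risk with respect to w.
    """
    w = 1.0
    # d/dw mean log(1 + exp(-y * w * z)) = mean(-y * z * sigmoid(-y * w * z))
    grad_w = np.mean(-labels * logits * sigmoid(-labels * w * logits))
    return grad_w ** 2

def irm_objective(logits_per_domain, labels_per_domain, lam):
    """Empirical risk summed over domains plus lam times the IRM penalty."""
    risk, penalty = 0.0, 0.0
    for z, y in zip(logits_per_domain, labels_per_domain):
        risk += np.mean(np.log1p(np.exp(-y * z)))
        penalty += irm_penalty(z, y)
    return risk + lam * penalty
```

A zero penalty in every domain means the shared classifier w = 1 is already simultaneously optimal across domains, which is exactly the constraint in Eq. (1); large `lam` trades empirical risk for this invariance.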

