OUT-OF-DISTRIBUTION PREDICTION WITH INVARIANT RISK MINIMIZATION: THE LIMITATION AND AN EFFECTIVE FIX

Abstract

This work considers the out-of-distribution (OOD) prediction problem where (1) the training data come from multiple domains and (2) the test domain is unseen during training. DNNs fail in OOD prediction because they are prone to picking up spurious correlations. Recently, Invariant Risk Minimization (IRM) was proposed to address this issue, and its effectiveness has been demonstrated in the colored MNIST experiment. Nevertheless, we find that the performance of IRM can be dramatically degraded under strong Λ spuriousness: when the spurious correlation between the spurious features and the class label is strong due to the strong causal influence of their common cause, the domain label, on both of them (see Fig. 1). In this work, we try to answer the questions: why does IRM fail in the aforementioned setting? Why does IRM work for the original colored MNIST dataset? We then propose a simple and effective approach to fix the problem of IRM: we combine IRM with conditional distribution matching to avoid a specific type of spurious correlation under strong Λ spuriousness. Empirically, we design a series of semi-synthetic datasets, the colored MNIST plus, which expose the problems of IRM and demonstrate the efficacy of the proposed method.



1 INTRODUCTION

Strong empirical results have demonstrated the efficacy of deep neural networks (DNNs) in a variety of areas, including computer vision, natural language processing, and speech recognition. However, such positive results overwhelmingly rely on the assumption that the training, validation, and test data consist of independent and identically distributed samples from the same underlying distribution. In contrast, in the setting of out-of-distribution (OOD) prediction, where (1) the training data come from multiple domains and (2) the test set has a different distribution from the training data, the performance of DNNs can be dramatically degraded. This is because DNNs are prone to picking up spurious correlations that do not hold beyond the training data distribution (Beery et al., 2018; Arjovsky et al., 2019). For example, when most camel pictures in a training set have a desert in the background, DNNs will pick up the spurious correlation between the desert background and the class label, leading to failures when camel pictures come with different backgrounds in a test set. Therefore, OOD prediction remains an extremely challenging problem for DNNs.

Invariant causal relationships across different domains turn out to be the key to addressing the challenge of OOD prediction. The causal graph in Fig. 1 describes the relationships among the variables in the OOD prediction problem. Although spurious correlations learned in one domain are unreliable in another, invariant causal relationships enable DNNs that capture them to generalize to unseen domains. In practice, it is extremely difficult to know whether an input feature is causal or spurious. Thus, a recipe for training DNNs that capture causal relationships is to learn causal feature representations that hold an invariant causal relationship with the class label.
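As a concrete reference for the objective under study, the IRMv1 formulation of Arjovsky et al. (2019) adds to the sum of per-environment risks a penalty on the gradient of each environment's risk with respect to a fixed dummy classifier scale w evaluated at w = 1. The sketch below is a minimal NumPy illustration for a scalar predictor with squared loss, where that gradient has a closed form; the original method applies autograd to arbitrary losses and models, so this is a simplification for exposition, not the paper's implementation.

```python
import numpy as np

def irmv1_penalty(feats, y):
    """IRMv1-style gradient penalty for a scalar predictor with squared loss.

    With loss(w) = mean((w * feats - y)^2) and a dummy scale w fixed at 1,
    d loss / d w |_{w=1} = 2 * mean((feats - y) * feats),
    and the penalty is the squared norm of that gradient.
    """
    grad = 2.0 * np.mean((feats - y) * feats)
    return grad ** 2

def irm_objective(envs, lam):
    """Sum of per-environment risks plus lam times the gradient penalties.

    envs: list of (feats, y) pairs, one per training environment.
    """
    risk = sum(np.mean((f - yy) ** 2) for f, yy in envs)
    penalty = sum(irmv1_penalty(f, yy) for f, yy in envs)
    return risk + lam * penalty
```

When the predictor is simultaneously optimal in every environment, each per-environment gradient vanishes and the penalty term is zero; a large penalty signals that the representation is exploiting environment-specific (spurious) structure.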



Figure 1: The causal graph in OOD prediction: P(Y|X_c) is invariant across domains, while the spurious correlation P(Y|X_s) may vary. A directed (bidirected) edge denotes a causal relationship (correlation).
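The proposed fix combines IRM with conditional distribution matching across domains. The paper's exact matching criterion is not specified in this excerpt; as one hedged illustration of what such a penalty can look like, the sketch below aligns the class-conditional feature means across environments (first-moment matching only, a deliberate simplification).

```python
import numpy as np

def cond_match_penalty(feats_by_env, labels_by_env, num_classes):
    """Illustrative class-conditional distribution matching penalty.

    For each class c, compute the mean feature vector within each
    environment and penalize the squared deviation of each environment's
    mean from the cross-environment center. Higher moments or kernel
    criteria (e.g., MMD) could be substituted for a stricter match.
    """
    penalty = 0.0
    for c in range(num_classes):
        means = []
        for f, y in zip(feats_by_env, labels_by_env):
            mask = (y == c)
            if mask.any():
                means.append(f[mask].mean(axis=0))
        if len(means) > 1:
            stacked = np.stack(means)
            center = stacked.mean(axis=0)
            penalty += np.sum((stacked - center) ** 2)
    return penalty
```

The penalty is zero exactly when, for every class, all environments share the same class-conditional feature mean, which rules out features whose distribution given the label shifts with the domain.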

