OUT-OF-DISTRIBUTION GENERALIZATION VIA RISK EXTRAPOLATION

Abstract

Distributional shift is one of the major obstacles when transferring machine learning prediction systems from the lab to the real world. To tackle this problem, we assume that variation across training domains is representative of the variation we might encounter at test time, but also that shifts at test time may be more extreme in magnitude. In particular, we show that reducing differences in risk across training domains can reduce a model's sensitivity to a wide range of extreme distributional shifts, including the challenging setting where the input contains both causal and anti-causal elements. We motivate this approach, Risk Extrapolation (REx), as a form of robust optimization over a perturbation set of extrapolated domains (MM-REx), and propose a penalty on the variance of training risks (V-REx) as a simpler variant. We prove that variants of REx can recover the causal mechanisms of the targets, while also providing some robustness to changes in the input distribution ("covariate shift"). By appropriately trading off robustness to causally induced distributional shifts and covariate shift, REx is able to outperform alternative methods such as Invariant Risk Minimization in situations where these types of shift co-occur.

1. INTRODUCTION

While neural networks often exhibit super-human generalization on the training distribution, they can be extremely sensitive to distributional shift, presenting a major roadblock for their practical application (Su et al., 2019; Engstrom et al., 2017; Recht et al., 2019; Hendrycks & Dietterich, 2019). This sensitivity is often caused by relying on "spurious" features unrelated to the core concept we are trying to learn (Geirhos et al., 2018). For instance, Beery et al. (2018) give the example of an image recognition model failing to correctly classify cows on the beach, since it has learned to make predictions based on the features of the background (e.g. a grassy field) instead of just the animal. In this work, we consider out-of-distribution (OOD) generalization, also known as domain generalization, where a model must generalize appropriately to a new test domain for which it has neither labeled nor unlabeled training data. Following common practice (Ben-Tal et al., 2009), we formulate this as optimizing the worst-case performance over a perturbation set of possible test domains, F:

    R^{OOD}_{F}(θ) = max_{e ∈ F} R^e(θ).

Since generalizing to arbitrary test domains is impossible, the choice of perturbation set encodes our assumptions about which test domains might be encountered. Instead of making such assumptions a priori, we assume access to data from multiple training domains, which can inform our choice of perturbation set. A classic approach for this setting is group distributionally robust optimization (DRO) (Sagawa et al., 2019), where F contains all mixtures of the training distributions. This is mathematically equivalent to considering convex combinations of the training risks. However, we aim for a more ambitious form of OOD generalization, over a larger perturbation set. Our method, minimax Risk Extrapolation (MM-REx), is an extension of DRO where F instead contains affine combinations of training risks; see Figure 1.
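The contrast between DRO and MM-REx can be made concrete. Maximizing a weighted sum of training risks over weights that sum to 1 and are bounded below by a constant λ_min ≤ 0 has a simple closed form: all weights sit at the lower bound except the one on the largest risk. The sketch below is illustrative only (the function name and the choice of plain Python floats are ours); setting λ_min = 0 recovers group DRO over the convex hull of the training risks, while λ_min < 0 extrapolates beyond it.

```python
def mm_rex_objective(risks, lambda_min=-1.0):
    """Worst-case affine combination of training risks (minimax REx).

    Maximizes sum_e c_e * R_e subject to sum_e c_e = 1 and c_e >= lambda_min.
    The maximum is attained by putting weight lambda_min on every risk except
    the largest one, which receives the remaining mass 1 - (m-1)*lambda_min.
    Expanding that vertex gives the closed form returned below.
    """
    m = len(risks)
    return (1 - m * lambda_min) * max(risks) + lambda_min * sum(risks)


# lambda_min = 0 reduces to group DRO: the worst single training risk.
print(mm_rex_objective([1.0, 2.0], lambda_min=0.0))   # 2.0
# lambda_min < 0 penalizes spread between risks more aggressively.
print(mm_rex_objective([1.0, 2.0], lambda_min=-1.0))  # (1+2)*2 - 1*3 = 3.0
```

Note that as λ_min becomes more negative, the objective increasingly punishes any gap between the best and worst training risks, which is what drives risks toward equality across domains.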
Under specific circumstances, MM-REx can be thought of as DRO over a set of extrapolated domains[1], allowing us to carry over machinery from the DRO setting. In particular, focusing on supervised learning, we show that Risk Extrapolation can uncover invariant relationships between inputs X and targets Y. Intuitively, an invariant relationship is a statistical relationship which is maintained across all domains in F. Returning to the cow-on-the-beach example, the relationship between the animal and the label is expected to be invariant, while the relationship between the background and the label is not. A model which bases its predictions on such an invariant relationship is said to perform invariant prediction.[2]

Many domain generalization methods assume P(Y|X) is an invariant relationship, limiting distributional shift to changes in P(X), which are known as covariate shift (David et al., 2010). This assumption can easily be violated, however. For instance, when Y causes X, a more sensible assumption is that P(X|Y) is fixed, with P(Y) varying across domains (Schölkopf et al., 2012; Lipton et al., 2018). In general, invariant prediction may involve an aspect of causal discovery. Depending on the perturbation set, however, other, more predictive, invariant relationships may also exist (Koyama & Yamaguchi, 2020).

The first method for invariant prediction to be compatible with modern deep learning problems and techniques is Invariant Risk Minimization (IRM) (Arjovsky et al., 2019), making it a natural point of comparison. Our work focuses on explaining how REx addresses OOD generalization, and on highlighting differences (especially advantages) compared with IRM and other domain generalization methods; see Table 1.
Broadly speaking, REx optimizes for robustness to the forms of distributional shift that have been observed to have the largest impact on performance in training domains. This can prove a significant advantage over the more focused (but also limited) robustness that IRM targets. For instance, unlike IRM, REx can also encourage robustness to covariate shift (see Section 3). And indeed, our experiments show that REx significantly outperforms IRM in settings that involve covariate shift and require invariant prediction, including modified versions of CMNIST and simulated robotics tasks from the DeepMind Control Suite. On the other hand, because REx does not distinguish between underfitting and inherent noise, IRM has an advantage in settings where some domains are intrinsically harder than others. We perform several other sets of experiments in order to better understand and compare REx and IRM. Our contributions include:
1) MM-REx, a novel domain generalization problem formulation suitable for invariant prediction.
2) Demonstrating that REx solves invariant prediction tasks where IRM fails due to covariate shift.
3) Proving that equality of risks can be a sufficient criterion for discovering causal structure.
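The simpler V-REx variant named in the abstract replaces the minimax problem with a penalty on the variance of the training risks, so that the total loss is the mean risk plus β times that variance. A minimal sketch, assuming scalar per-domain risks and a hyperparameter β that we choose for illustration:

```python
def v_rex_objective(risks, beta=10.0):
    """V-REx: mean training risk plus a variance penalty across domains.

    When all domains achieve equal risk the penalty vanishes; large beta
    pushes the learner toward equalizing risks across training domains.
    """
    m = len(risks)
    mean_risk = sum(risks) / m
    var_risk = sum((r - mean_risk) ** 2 for r in risks) / m
    return mean_risk + beta * var_risk


# Equal risks across domains: penalty is zero, objective is the mean risk.
print(v_rex_objective([1.0, 1.0], beta=10.0))  # 1.0
# Unequal risks: the variance term adds to the mean.
print(v_rex_objective([0.0, 2.0], beta=1.0))   # mean 1.0 + var 1.0 = 2.0
```

In practice the per-domain risks would be minibatch losses from a shared model, and the whole expression would be differentiated end-to-end; the population-variance form above (dividing by m) is one natural choice.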



[1] We define "extrapolation" to mean "outside the convex hull"; see Appendix B for more.
[2] Note this is different from learning an invariant representation (Ganin et al., 2016); see Section 2.2.



Figure 1: Left: Robust optimization optimizes worst-case performance over the convex hull of the training distributions. Right: By extrapolating risks, REx encourages robustness to larger shifts. Here e_1, e_2, and e_3 represent training distributions, and P_1(X, Y), P_2(X, Y) represent particular directions of variation in the affine space of quasiprobability distributions over (X, Y).

