OUT-OF-DISTRIBUTION GENERALIZATION VIA RISK EXTRAPOLATION

Abstract

Distributional shift is one of the major obstacles when transferring machine learning prediction systems from the lab to the real world. To tackle this problem, we assume that variation across training domains is representative of the variation we might encounter at test time, but also that shifts at test time may be more extreme in magnitude. In particular, we show that reducing differences in risk across training domains can reduce a model's sensitivity to a wide range of extreme distributional shifts, including the challenging setting where the input contains both causal and anti-causal elements. We motivate this approach, Risk Extrapolation (REx), as a form of robust optimization over a perturbation set of extrapolated domains (MM-REx), and propose a penalty on the variance of training risks (V-REx) as a simpler variant. We prove that variants of REx can recover the causal mechanisms of the targets, while also providing some robustness to changes in the input distribution ("covariate shift"). By appropriately trading off robustness to causally induced distributional shifts and covariate shift, REx is able to outperform alternative methods such as Invariant Risk Minimization in situations where these types of shift co-occur.
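As a concrete illustration of the simpler variant mentioned above, the V-REx objective can be sketched in a few lines: an average-risk (ERM) term plus a penalty on the variance of per-domain risks. The penalty weight `beta` and the scalar-risk interface are illustrative assumptions, not the paper's tuned settings:

```python
from statistics import mean, pvariance

def vrex_objective(risks, beta=10.0):
    """V-REx sketch: average training risk plus a penalty on the
    variance of per-domain risks. `risks` holds one scalar risk
    R_e(theta) per training domain; `beta` (illustrative value)
    trades off ERM against equalizing risk across domains."""
    return mean(risks) + beta * pvariance(risks)

# Equal risks incur no penalty; unequal risks are penalized.
print(vrex_objective([1.0, 1.0]))  # -> 1.0
print(vrex_objective([0.0, 2.0]))  # -> 11.0
```

Driving the variance term to zero enforces equal risk across training domains, which is the sense in which reducing risk differences confers robustness.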

1. INTRODUCTION

While neural networks often exhibit super-human generalization on the training distribution, they can be extremely sensitive to distributional shift, presenting a major roadblock for their practical application (Su et al., 2019; Engstrom et al., 2017; Recht et al., 2019; Hendrycks & Dietterich, 2019). This sensitivity is often caused by relying on "spurious" features unrelated to the core concept we are trying to learn (Geirhos et al., 2018). For instance, Beery et al. (2018) give the example of an image recognition model failing to correctly classify cows on the beach, since it has learned to make predictions based on the features of the background (e.g. a grassy field) instead of just the animal.

In this work, we consider out-of-distribution (OOD) generalization, also known as domain generalization, where a model must generalize appropriately to a new test domain for which it has neither labeled nor unlabeled training data. Following common practice (Ben-Tal et al., 2009), we formulate this as optimizing the worst-case performance over a perturbation set of possible test domains, $\mathcal{F}$:

$$R^{\mathrm{OOD}}_{\mathcal{F}}(\theta) = \max_{e \in \mathcal{F}} R_e(\theta)$$

Since generalizing to arbitrary test domains is impossible, the choice of perturbation set encodes our assumptions about which test domains might be encountered. Instead of making such assumptions a priori, we assume access to data from multiple training domains, which can inform our choice of perturbation set. A classic approach for this setting is group distributionally robust optimization (DRO) (Sagawa et al., 2019), where $\mathcal{F}$ contains all mixtures of the training distributions; this is mathematically equivalent to considering convex combinations of the training risks. However, we aim for a more ambitious form of OOD generalization, over a larger perturbation set. Our method, minimax Risk Extrapolation (MM-REx), is an extension of DRO where $\mathcal{F}$ instead contains affine combinations of training risks; see Figure 1.
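To make the contrast with DRO concrete, the worst case over affine combinations of training risks admits a simple closed form once the coefficients are bounded below: the maximizer puts all remaining weight on the largest risk. The sketch below assumes a lower bound `lam_min` on the coefficients (an illustrative hyperparameter); `lam_min = 0` recovers group DRO (the worst convex mixture), while `lam_min < 0` extrapolates beyond the convex hull of the training risks:

```python
def mmrex_objective(risks, lam_min=-0.5):
    """Sketch of a minimax objective over affine combinations
    sum_e(lam_e * R_e) with sum_e(lam_e) = 1 and lam_e >= lam_min.
    The maximum concentrates all weight above the lower bound on
    the largest risk, yielding this closed form."""
    m = len(risks)
    return (1 - m * lam_min) * max(risks) + lam_min * sum(risks)

print(mmrex_objective([1.0, 3.0], lam_min=0.0))   # -> 3.0 (group DRO)
print(mmrex_objective([1.0, 3.0], lam_min=-0.5))  # -> 4.0
```

Note how a negative `lam_min` amplifies the gap between the worst and remaining risks, pushing the model toward equal risks across domains more strongly than DRO does.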
Under specific circumstances, MM-REx can be thought of as DRO over a set of extrapolated domains¹, allowing us to carry over machinery



¹ We define "extrapolation" to mean "outside the convex hull"; see Appendix B for more.
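The footnote's distinction can be checked mechanically: an affine combination only requires weights summing to one, while a convex combination additionally requires non-negative weights. A minimal sketch (helper names are ours, for illustration only):

```python
def is_affine(weights, tol=1e-9):
    """Weights sum to 1; individual weights may be negative."""
    return abs(sum(weights) - 1.0) <= tol

def is_convex(weights, tol=1e-9):
    """Weights sum to 1 and are all non-negative."""
    return is_affine(weights, tol) and all(w >= -tol for w in weights)

# [1.5, -0.5] sums to 1 but has a negative weight: an affine
# combination that extrapolates outside the convex hull.
print(is_convex([0.5, 0.5]))   # -> True
print(is_affine([1.5, -0.5]))  # -> True
print(is_convex([1.5, -0.5]))  # -> False
```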

