UNDERSTANDING THE FAILURE MODES OF OUT-OF-DISTRIBUTION GENERALIZATION

Abstract

Empirical studies suggest that machine learning models often rely on features, such as the background, that may be spuriously correlated with the label only during training time, resulting in poor accuracy at test time. In this work, we identify the fundamental factors that give rise to this behavior, by explaining why models fail this way even in easy-to-learn tasks where one would expect them to succeed. In particular, through a theoretical study of gradient-descent-trained linear classifiers on some easy-to-learn tasks, we uncover two complementary failure modes. These modes arise from how spurious correlations induce two kinds of skews in the data: one geometric in nature, and the other statistical. Finally, we construct natural modifications of image classification datasets to understand when these failure modes can arise in practice. We also design experiments to isolate the two failure modes when training modern neural networks on these datasets.

1. INTRODUCTION

A machine learning model in the wild (e.g., a self-driving car) must be prepared to make sense of its surroundings in rare conditions that may not have been well-represented in its training set, ranging from mild glitches in the camera to strange weather conditions. This out-of-distribution (OoD) generalization problem has been extensively studied within the framework of domain generalization (Blanchard et al., 2011; Muandet et al., 2013). Here, the classifier has access to training data sourced from multiple "domains" or distributions, but no data from the test domains; by observing the various kinds of shifts exhibited by the training domains, the classifier can hope to learn to be robust to such shifts.

The simplest approach to domain generalization is based on the Empirical Risk Minimization (ERM) principle (Vapnik, 1998): pool the data from all the training domains (ignoring the "domain label" on each point) and train a classifier by gradient descent to minimize the average loss on this pooled dataset. Alternatively, many recent studies (Ganin et al., 2016; Arjovsky et al., 2019; Sagawa et al., 2020a) have focused on designing more sophisticated algorithms that do utilize the domain label on the datapoints, e.g., by enforcing certain representational invariances across domains. A basic premise behind pursuing such sophisticated techniques, as emphasized by Arjovsky et al. (2019), is the empirical observation that ERM-based gradient-descent training (or for convenience, just ERM) fails in a characteristic way. As a standard illustration, consider a cow-camel classification task (Beery et al., 2018) where the background happens to be spuriously correlated with the label in a particular manner only during training: say, most cows are found against a grassy background and most camels against a sandy one.
Then, during test time, if the correlation is completely flipped (i.e., all cows in deserts, and all camels in meadows), one observes that the accuracy of ERM drops drastically. Evidently, ERM, in its unrestrained attempt at fitting the data, indiscriminately relies on all kinds of informative features, including unreliable spurious features like the background. An algorithm that carefully uses domain label information, on the other hand, can hope to identify and rely purely on invariant features (or "core" features (Sagawa et al., 2020b)).

While the above narrative is an oft-stated motivation behind developing sophisticated OoD generalization algorithms, there is little formal explanation as to why ERM fails in this characteristic way. Existing works (Sagawa et al., 2020b; Tsipras et al., 2019; Arjovsky et al., 2019; Shah et al., 2020) provide valuable answers to this question through concrete theoretical examples; however, their examples critically rely on certain factors to make the task difficult enough for ERM to rely on the spurious features. For instance, many of these examples have invariant features that are only partially predictive of the label (see Fig 1a). Surprisingly though, ERM relies on spurious features even in much easier-to-learn tasks where these complicating factors are absent, such as in tasks with fully predictive invariant features: e.g., Fig 1c, the Waterbirds/CelebA examples in Sagawa et al. (2020a), or, for that matter, any real-world situation where the object shape perfectly determines the label. This failure in easy-to-learn tasks, as we argue later, is not straightforward to explain (see Fig 1b for a brief idea). Evidently, there must exist factors more general and fundamental than those known so far that cause ERM to fail. Our goal in this work is to uncover these fundamental factors behind the failure of ERM, in the hope of providing a vital foundation for future work to reason about OoD generalization.
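To make this characteristic failure concrete, the following is a minimal toy sketch of our own (not an experiment from this paper): the invariant feature fully predicts the label but with a small margin, while a spurious feature agrees with the label 90% of the time during training. A linear classifier trained by plain gradient descent on the pooled (ERM) logistic loss leans heavily on the spurious coordinate, so its accuracy collapses when the correlation is flipped at test time. All constants and names here are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

def make_data(n, p_agree):
    """Invariant feature fully predicts y (but with a small margin);
    spurious feature agrees with y only with probability p_agree."""
    y = rng.choice([-1.0, 1.0], size=n)
    x_inv = 0.1 * y                                   # fully predictive
    x_sp = np.where(rng.random(n) < p_agree, y, -y)   # spuriously correlated
    return np.stack([x_inv, x_sp], axis=1), y

X_tr, y_tr = make_data(10_000, 0.9)  # training: background mostly matches label
X_te, y_te = make_data(10_000, 0.0)  # test: correlation completely flipped

# ERM: minimize the average logistic loss on the pooled data by gradient descent.
w = np.zeros(2)
lr, steps = 1.0, 300
for _ in range(steps):
    margins = y_tr * (X_tr @ w)
    # gradient of mean log(1 + exp(-margin)) w.r.t. w
    grad = -((y_tr / (1.0 + np.exp(margins))) @ X_tr) / len(y_tr)
    w -= lr * grad

acc = lambda X, y: np.mean(np.sign(X @ w) == y)
print("weights [w_inv, w_sp]:", w)        # spurious weight dominates
print("train accuracy:", acc(X_tr, y_tr))
print("flipped-test accuracy:", acc(X_te, y_te))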
Indeed, recent empirical work (Gulrajani & Lopez-Paz, 2020) has questioned whether existing alternatives necessarily outperform ERM on OoD tasks; however, due to a lack of theory, it is not clear how to hypothesize about when/why one algorithm would outperform another here. Through our theoretical study, future work can hope to be better positioned to precisely identify the key missing components in these algorithms, and bridge these gaps to better solve the OoD generalization problem. Our contributions. To identify the most fundamental factors causing OoD failure, our strategy is to (a) study tasks that are "easy" to succeed at, and (b) to demonstrate that ERM relies on spurious features despite how easy the tasks are. More concretely: 1. We formulate a set of constraints on how our tasks must be designed so that they are easy to succeed at (e.g., the invariant feature must be fully predictive of the label). Notably, this class of easy-to-learn tasks provides both a theoretical test-bed for reasoning about OoD generalization and also a simplified empirical test-bed. In particular, this class encompasses simplified MNIST and CIFAR10-based classification tasks where we establish empirical failure of ERM. 2. We identify two complementary mechanisms of failure of ERM that arise from how spurious correlations induce two kinds of skews in the data: one that is geometric and the other statistical. In particular, we theoretically isolate these failure modes by studying linear classifiers trained by



Figure 1: Unexplained OoD failure: Existing theory can explain why classifiers rely on the spurious feature when the invariant feature is in itself not informative enough (Fig 1a). But when invariant features are fully predictive of the label, these explanations fall apart. E.g., in the four-point-dataset of Fig 1b, one would expect the max-margin classifier to easily ignore spurious correlations (also see Sec 3). Yet, why do classifiers (including the max-margin) rely on the spurious feature, in so many real-world settings where the shapes are perfectly informative of the object label (e.g., Fig 1c)? We identify two fundamental factors behind this behavior. In doing so, we also identify and explain other kinds of vulnerabilities such as the one in Fig 1d (see Sec 4).

or the Waterbirds/CelebA examples in Sagawa et al.

funding

performed in part while Vaishnavh Nagarajan was interning at Blueshift, Alphabet.

