ADAPTIVE RISK MINIMIZATION: A META-LEARNING APPROACH FOR TACKLING GROUP SHIFT

Anonymous

Abstract

A fundamental assumption of most machine learning algorithms is that the training and test data are drawn from the same underlying distribution. However, this assumption is violated in almost all practical applications: machine learning systems are regularly tested under distribution shift, due to temporal correlations, particular end users, or other factors. In this work, we consider the setting where the training data are structured into groups and test time shifts correspond to changes in the group distribution. Prior work has approached this problem by attempting to be robust to all possible test time distributions, which may degrade average performance. In contrast, we propose to use ideas from meta-learning to learn models that are adaptable, such that they can adapt to shift at test time using a batch of unlabeled test points. We acquire such models by learning to adapt to training batches sampled according to different distributions, which simulate structural shifts that may occur at test time. Our primary contribution is to introduce the framework of adaptive risk minimization (ARM), a formalization of this setting that lends itself to meta-learning. We develop meta-learning methods for solving the ARM problem and show that, compared to a variety of prior approaches, they provide substantial gains on image classification problems in the presence of shift.

1. INTRODUCTION

The standard assumption in empirical risk minimization (ERM) is that the data distribution at test time will match the distribution at training time. When this assumption does not hold, the performance of standard ERM methods typically deteriorates rapidly, and this setting is commonly referred to as distribution or dataset shift (Quiñonero Candela et al., 2009; Lazer et al., 2014). For instance, we can imagine a handwriting classification system that, after training on a large database of past images, is deployed to specific end users. Some new users have peculiarities in their handwriting style, leading to shift in the input distribution. This test scenario must be carefully considered when building machine learning systems for real world applications.

Algorithms for handling distribution shift have been studied under a number of frameworks (Quiñonero Candela et al., 2009). Many of these frameworks aim for zero shot generalization to shift, which requires assumptions that are more restrictive but still realistic. For example, one popular assumption is that the training data are provided in groups and that distributions at test time will represent either new group distributions or new groups altogether. This assumption is used by, e.g., group distributionally robust optimization (DRO) (Hu et al., 2018; Sagawa et al., 2020), robust federated learning (Mohri et al., 2019; Li et al., 2020), and domain generalization (Blanchard et al., 2011; Gulrajani & Lopez-Paz, 2020). In practice, training groups or tasks are generally constructed using meta-data, which exists for most commonly used datasets. This assumption allows for more tractable optimization and still permits a wide range of realistic distribution shifts. However, achieving strong zero shot generalization in this setting remains a hard problem.
For example, DRO methods, which focus on achieving maximal worst case performance, can often be overly pessimistic and learn models that do not perform well on the actual test distributions (Hu et al., 2018).

In this work, we take a different approach to combating group distribution shift: we learn models that deal with shift by adapting to the test time distribution. To do so, we assume that we can access a batch of unlabeled data points at test time, as opposed to individual isolated inputs, which can be used to implicitly infer the test distribution. This assumption is reasonable in many standard supervised learning setups. For example, we do not access single handwritten characters from an end user, but rather collections of characters such as sentences or paragraphs. When combined with the group assumption above, we arrive at a problem setting that is similar to the standard meta-learning setting (Vinyals et al., 2016). This allows us to extend well established tools and techniques from meta-learning to address distribution shift problems. Meta-learning typically assumes that training data are grouped into tasks and new tasks are encountered at meta-test time; however, these new tasks still include labeled examples for adaptation. As illustrated in Figure 1, we instead aim to train a model that uses unlabeled data to adapt to the test distribution, thereby not requiring the model to generalize zero shot to all test distributions as in prior approaches.

The main contribution of this paper is to introduce the framework of adaptive risk minimization (ARM), in which models have the opportunity to adapt to the data distribution at test time based on unlabeled data points. This contribution provides a principled approach for designing meta-learning methods to tackle distribution shift.
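As a minimal illustration of how training distributions exhibiting group shift can be simulated, the following sketch samples a training batch whose examples all come from a single group, mimicking a test distribution that concentrates on one user or domain. The function and variable names here are ours, purely for illustration, and are not part of the ARM formalism.

```python
import random
from collections import defaultdict

def sample_group_batch(xs, ys, groups, batch_size, rng=random):
    """Sample a training batch from a single group, simulating a
    test-time distribution that places all mass on that group."""
    # Index the training set by group annotation z.
    by_group = defaultdict(list)
    for x, y, g in zip(xs, ys, groups):
        by_group[g].append((x, y))
    # Pick one group, then draw the whole batch from it.
    g = rng.choice(list(by_group))
    batch = [rng.choice(by_group[g]) for _ in range(batch_size)]
    return g, batch
```

During meta-training, one such batch per step would be adapted to (using only the inputs) and then evaluated (using the labels), simulating the shift the model will face at test time.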
We introduce an algorithm and instantiate a set of methods for solving ARM that, given a set of candidate distribution shifts, meta-learn models that are adaptable to these shifts. One such method is based on meta-training a model such that simply updating batch normalization statistics (Ioffe & Szegedy, 2015) provides effective adaptation at test time, and we demonstrate that this simple approach can produce surprisingly strong results. Our experiments demonstrate that the proposed methods, by leveraging the meta-training phase, are able to outperform prior methods for handling distribution shift in image classification settings exhibiting group shift, including benchmarks for federated learning (Caldas et al., 2019) and testing image classifier robustness (Hendrycks & Dietterich, 2019).
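The batch normalization based adaptation mentioned above can be sketched as follows: at test time, normalization statistics are recomputed from the unlabeled test batch rather than taken from the running statistics stored during training. This is a simplified NumPy sketch of a single normalization layer under that scheme; the actual method additionally meta-trains the model so that this adaptation is effective.

```python
import numpy as np

def batchnorm_adapt(x, gamma, beta, eps=1e-5):
    """Normalize a test batch of features using the batch's own
    statistics, instead of running statistics saved at training time.

    x: array of shape (batch_size, num_features)
    gamma, beta: learned scale and shift parameters
    """
    # Statistics come from the unlabeled test batch itself, which is
    # what lets the layer adapt to the shifted input distribution.
    mu = x.mean(axis=0)
    var = x.var(axis=0)
    x_hat = (x - mu) / np.sqrt(var + eps)
    return gamma * x_hat + beta
```

In a deep network, this amounts to keeping batch normalization layers in "training mode" statistics-wise at test time, so each unlabeled batch supplies its own normalization.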

2. RELATED WORK

A number of prior works have studied distribution shift in various forms (Quiñonero Candela et al., 2009). In this section, we review prior work in robust optimization, meta-learning, and adaptation.

Robust optimization. DRO methods optimize machine learning systems to be robust to adversarial data distributions, thus optimizing for worst case performance against distribution shift (Globerson & Roweis, 2006; Ben-Tal et al., 2013; Liu & Ziebart, 2014; Esfahani & Kuhn, 2015; Miyato et al., 2015; Duchi et al., 2016; Blanchet et al., 2016). Recent work has shown that these algorithms can be utilized with deep neural networks, with additional care taken for regularization and model capacity (Sagawa et al., 2020). Unlike DRO methods, ARM methods do not require the model to generalize zero shot to all test time distribution shifts, but instead train it to adapt to these shifts. Also of particular interest are methods for robustness or adaptation to different users (Horiguchi et al., 2018; Chen et al., 2018; Jiang et al., 2019; Fallah et al., 2020; Lin et al., 2020), a setting commonly referred to as robust or fair federated learning (McMahan et al., 2017; Mohri et al., 2019; Li et al., 2020). Unlike these works, we consider the federated learning problem setting in which we do not assume access to any labels from any test users, as we partition users into disjoint train and test sets. We argue that this is a realistic setting for many practical machine learning systems: oftentimes, the only available information from the end user is an unlabeled batch of data.
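For concreteness, the worst-case objective optimized by group DRO can be written as follows, where the notation is ours: $\theta$ denotes model parameters, $\mathcal{G}$ the set of training groups, $p_g$ the data distribution of group $g$, and $\ell$ the loss function.

```latex
\min_{\theta} \; \max_{g \in \mathcal{G}} \; \mathbb{E}_{(x, y) \sim p_g} \left[ \ell(\theta; x, y) \right]
```

Because the inner maximum is taken over all groups, the solution can be dominated by the hardest group, which is the source of the pessimism discussed above.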



Figure 1: A schematic of the ARM problem setting and approach, described in detail in Section 3. Left: During training, we assume access to labeled data along with group information z, which allows us to construct training distributions that exhibit group distribution shift. For example, a training distribution may place uniform mass on only a single user's examples. We use these training distributions to learn a model that is adaptable to distribution shift via a form of meta-learning. We detail the specific adaptation procedures (orange box) that we consider in Section 3 and Figure 2. Right: We perform unsupervised adaptation to different test distributions, without requiring zero shot generalization to shift as in prior methods. If the test shifts we observe are similar to those simulated by the training distributions, e.g., we deploy the model to new end users at test time, then we expect that we can effectively adapt to these test distributions for better performance.

