DIVERSIFY AND DISAMBIGUATE: OUT-OF-DISTRIBUTION ROBUSTNESS VIA DISAGREEMENT

Abstract

Real-world machine learning problems often exhibit shifts between the source and target distributions, in which source data does not fully convey the desired behavior on target inputs. Different functions that achieve near-perfect source accuracy can make differing predictions on test inputs, and such ambiguity makes robustness to distribution shifts challenging. We propose DivDis, a simple two-stage framework for identifying and resolving ambiguity in data. DivDis first learns a diverse set of hypotheses that achieve low source loss but make differing predictions on target inputs. We then disambiguate by selecting one of the discovered functions using additional information, for example, a small number of target labels. Our experimental evaluation shows improved performance in subpopulation shift and domain generalization settings, demonstrating that DivDis can scalably adapt to distribution shifts in image and text classification benchmarks.

1. INTRODUCTION

Datasets are often underspecified: multiple plausible hypotheses describe the data equally well (D'Amour et al., 2020), and the data offers no further evidence to prefer one over another. Despite such ambiguity, machine learning models typically commit to only one of the possible explanations of the given data. Such choices can be suboptimal, causing these models to fail when the data distribution shifts, as is common in real-world applications. For example, examination of a chest X-ray dataset (Oakden-Rayner et al., 2020) has shown that many images of patients with pneumothorax include a thin chest drain used to treat the disease. A standard classifier trained on this dataset can erroneously identify such drains as a predictive feature of the disease, exhibiting degraded accuracy on the intended distribution of patients not yet being treated. To avoid such failures, it is desirable to have a model that can discover a diverse collection of plausible alternative hypotheses.

The standard empirical risk minimization (ERM; Vapnik, 1992) paradigm performs poorly on underspecified data because ERM tends to select a solution based on the most salient features without considering alternatives (Geirhos et al., 2020; Shah et al., 2020; Scimeca et al., 2021). This simplicity bias occurs even when training an ensemble (Hansen & Salamon, 1990; Lakshminarayanan et al., 2017), because each model in the ensemble is still biased towards simple functions. While many recent methods (Ganin et al., 2016; Sagawa et al., 2020; Liu et al., 2021) improve robustness in distribution shift settings, we find that they fail on data with more severe underspecification. This is because, like ERM, these methods consider only a single solution even in situations where multiple explanations exist.

We propose Diversify and Disambiguate (DivDis), a two-stage framework for learning from underspecified data.
Our key idea is to learn a collection of diverse functions that are consistent with the training data but make differing predictions on unlabeled test datapoints. DivDis operates as follows. We train a neural network consisting of a shared backbone feature extractor with multiple heads, each representing a different function. As in regular training, each head is trained to predict labels for the training data, but the heads are additionally encouraged to represent functions that differ from one another. More specifically, the heads are trained to make disagreeing predictions on a separate unlabeled dataset drawn from the test distribution, a setting close to transductive learning. At test time, we select one member of the diversified collection by querying labels for the datapoints that are most informative for disambiguation. We visually summarize this framework in Fig. 1. DivDis is designed for
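The two stages above can be sketched in a few lines of NumPy. This is a simplified illustration under stated assumptions, not the paper's implementation: the function names `pairwise_mi` and `select_head` are hypothetical, and penalizing the mutual information between pairs of heads' predictions on unlabeled target data is one common way to instantiate the "disagreeing predictions" objective (it pushes heads toward statistically independent predictions). Disambiguation is sketched as choosing the head with the highest accuracy on a handful of queried target labels.

```python
import numpy as np

def pairwise_mi(p1, p2, eps=1e-12):
    """Estimate mutual information between two heads' binary predictions.

    p1, p2: arrays of shape (N,), each head's predicted probability of
    class 1 on a batch of unlabeled target inputs. Adding this term to the
    source cross-entropy loss (and minimizing it) encourages the two heads
    to make statistically independent predictions on the target batch.
    """
    q1 = np.stack([1 - p1, p1], axis=1)                # (N, 2) per-example marginals
    q2 = np.stack([1 - p2, p2], axis=1)
    joint = (q1[:, :, None] * q2[:, None, :]).mean(0)  # (2, 2) empirical joint
    marg = joint.sum(1, keepdims=True) * joint.sum(0, keepdims=True)
    # KL divergence between the joint and the product of marginals
    return float((joint * np.log((joint + eps) / (marg + eps))).sum())

def select_head(head_probs, labels):
    """Disambiguate: pick the head most accurate on a few labeled target points.

    head_probs: (H, M) predicted probabilities from H heads on M queried points.
    labels: (M,) the queried target labels.
    """
    preds = (np.asarray(head_probs) > 0.5).astype(int)
    return int(np.argmax((preds == np.asarray(labels)).mean(axis=1)))
```

In a full training loop, a term like `pairwise_mi` summed over all head pairs would be added to the usual cross-entropy loss on labeled source data; `select_head` then mimics the disambiguation stage once a small number of target labels has been queried.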

