DIVERSIFY AND DISAMBIGUATE: OUT-OF-DISTRIBUTION ROBUSTNESS VIA DISAGREEMENT

Abstract

Real-world machine learning problems often exhibit shifts between the source and target distributions, in which source data does not fully convey the desired behavior on target inputs. Different functions that achieve near-perfect source accuracy can make differing predictions on test inputs, and such ambiguity makes robustness to distribution shifts challenging. We propose DivDis, a simple two-stage framework for identifying and resolving ambiguity in data. DivDis first learns a diverse set of hypotheses that achieve low source loss but make differing predictions on target inputs. We then disambiguate by selecting one of the discovered functions using additional information, for example, a small number of target labels. Our experimental evaluation shows improved performance in subpopulation shift and domain generalization settings, demonstrating that DivDis can scalably adapt to distribution shifts in image and text classification benchmarks.

1. INTRODUCTION

Datasets are often underspecified: multiple plausible hypotheses describe the data equally well (D'Amour et al., 2020), and the data offers no further evidence to prefer one over another. Despite such ambiguity, machine learning models typically commit to only one of the possible explanations of the given data. This choice can be suboptimal, causing models to fail when the data distribution shifts, as is common in real-world applications. For example, examination of a chest X-ray dataset (Oakden-Rayner et al., 2020) has shown that many images of patients with pneumothorax include a thin drain used for treating the disease. A standard classifier trained on this dataset can erroneously identify such drains as a predictive feature of the disease, exhibiting degraded accuracy on the intended distribution of patients not yet being treated. To avoid such failures, it is desirable to have a model that can discover a diverse collection of alternative plausible hypotheses.

The standard empirical risk minimization (ERM; Vapnik, 1992) paradigm performs poorly on underspecified data, because ERM tends to select a solution based on the most salient features without considering alternatives (Geirhos et al., 2020; Shah et al., 2020; Scimeca et al., 2021). This simplicity bias occurs even when training an ensemble (Hansen & Salamon, 1990; Lakshminarayanan et al., 2017), because each member is still biased towards simple functions. While many recent methods (Ganin et al., 2016; Sagawa et al., 2020; Liu et al., 2021) improve robustness in distribution shift settings, we find that they fail on data with more severe underspecification. This is because, like ERM, these methods consider only a single solution even in situations where multiple explanations exist. We propose Diversify and Disambiguate (DivDis), a two-stage framework for learning from underspecified data.
Our key idea is to learn a collection of diverse functions that are consistent with the training data but make differing predictions on unlabeled test datapoints. DivDis operates as follows. We train a neural network consisting of a shared backbone feature extractor with multiple heads, each representing a different function. As in regular training, each head is trained to predict labels for training data, but the heads are additionally encouraged to represent functions that differ from one another. More specifically, the heads are trained to make disagreeing predictions on a separate unlabeled dataset from the test distribution, a setting close to transductive learning. At test time, we select one member of the diversified set of functions by querying labels for the datapoints most informative for disambiguation. We summarize this framework visually in Fig. 1. DivDis is designed for scenarios with underspecified data and distribution shift; its heads will not yield a set of diverse functions in settings where only one function can achieve low training loss.

We evaluate DivDis in several settings in which underspecification limits the performance of prior methods, such as standard subpopulation shift benchmarks (Sagawa et al., 2020) and the large-scale CXR and Camelyon17 datasets (Wang et al., 2017; Sagawa et al., 2022). DivDis achieves an over-15% improvement in worst-group accuracy on the Waterbirds task when tuning hyperparameters without any spurious attribute annotations, and outperforms existing semi-supervised methods on the Camelyon17 task. We also consider challenging problem settings in which labels are completely correlated with spurious attributes, so that a classifier based on the spurious feature can achieve zero loss. In these completely correlated settings, our experiments find that DivDis is substantially more sample-efficient: DivDis with 4 target domain labels outperforms two fine-tuning methods that use 128 labels.
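The disagreement objective of the DIVERSIFY stage can be sketched as a mutual-information penalty between pairs of heads on unlabeled target data. The NumPy snippet below is an illustrative reconstruction under our own naming (`pairwise_mi` and its arguments are not from the paper): treating each head's output as a random variable, it estimates the joint distribution by averaging the outer product of the two heads' predicted distributions over a target batch, and minimizing the resulting MI pushes the heads toward statistically independent, i.e. disagreeing, predictions.

```python
import numpy as np

def pairwise_mi(p1, p2, eps=1e-12):
    """Mutual information between two heads' predictions on a target batch.

    p1, p2: (n, c) arrays of predicted class probabilities from two heads.
    The joint over the two heads' outputs is estimated by averaging the
    outer product of their per-example distributions over the batch.
    """
    joint = (p1[:, :, None] * p2[:, None, :]).mean(axis=0)   # (c, c) joint
    marginal = np.outer(p1.mean(axis=0), p2.mean(axis=0))    # product of marginals
    return float((joint * np.log((joint + eps) / (marginal + eps))).sum())

# Two heads that always agree are maximally dependent (MI = log 2 here),
# while a head with constant output is independent of any other head (MI = 0).
one_hot = np.eye(2)
agree = pairwise_mi(one_hot[[0, 1, 0, 1]], one_hot[[0, 1, 0, 1]])
indep = pairwise_mi(one_hot[[0, 1, 0, 1]], one_hot[[0, 0, 0, 0]])
```

In a full training loop, a penalty of this form would be summed over all head pairs on each unlabeled target batch and added, with a weighting coefficient, to the heads' shared cross-entropy loss on labeled source data.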

2. LEARNING FROM UNDERSPECIFIED DATA

We consider a supervised learning setting in which we train a model $f$ that takes an input $x \in \mathcal{X}$ and predicts its corresponding label $y \in \mathcal{Y}$. We train $f$ with a labeled dataset $D_S = \{(x_1, y_1), (x_2, y_2), \ldots\}$ drawn from a data distribution $p_S(x, y)$. The model $f$ is selected from a hypothesis class $\mathcal{F}$ by approximately minimizing the predictive risk $\mathbb{E}_{p_S(x,y)}[\ell(f(x), y)]$ on the data distribution, and is evaluated via its predictive risk on held-out samples from $p_S(x, y)$. Standard procedures such as regularization and cross-validation encourage such generalization. However, even if a function $f$ generalizes to unseen data sampled from the same distribution $p_S(x, y)$, performance often deteriorates under distribution shift, when we evaluate on target data sampled from a different distribution $p_T(x, y)$.

In many distribution shift scenarios (Koh et al., 2021), the overall data distribution can be modeled as a mixture of domains, where each domain $d \in \mathcal{D}$ corresponds to a fixed data distribution $p_d(x, y)$. In this paper, we specifically consider a subpopulation shift setting, where the source and target distributions are different mixtures of the same underlying domains: $p_S = \sum_{d \in \mathcal{D}} w^S_d \, p_d$ and $p_T = \sum_{d \in \mathcal{D}} w^T_d \, p_d$, where $\{w^S_d\}_{d \in \mathcal{D}} \neq \{w^T_d\}_{d \in \mathcal{D}}$. Conditions like subpopulation shift can be inherently underspecified because the generative process underlying the data distribution, i.e. the domains and mixture coefficients, admits many possibilities. We formalize this intuition through a notion of near-optimal sets of hypotheses, defining the $\varepsilon$-optimal set for a data distribution as follows:

Definition 1 ($\varepsilon$-optimal set). Let $p(x, y)$ be a distribution over $\mathcal{X} \times \mathcal{Y}$, and let $\mathcal{F}$ be a set of predictors $f : \mathcal{X} \to \mathcal{Y}$. Let $L_p : \mathcal{F} \to \mathbb{R}$ be the risk with respect to $p(x, y)$. The $\varepsilon$-optimal set with respect to $\mathcal{F}$ at level $\varepsilon \geq 0$ is defined as

$$\mathcal{F}_\varepsilon = \{ f \in \mathcal{F} \mid L_p(f) \leq \varepsilon \}. \tag{1}$$

Put differently, the $\varepsilon$-optimal set consists of all functions that generalize within the distribution $p(x, y)$. The constant $\varepsilon$ controls the degree of generalization, and we consider small $\varepsilon$ from here onward. Note that a model's predictions on samples from $p_S(x, y)$, whether $D_S$ or a held-out validation set, cannot be used to distinguish between different near-optimal functions with respect to $p_S(x, y)$: by definition, the predictions of any two models $f_1, f_2 \in \mathcal{F}_\varepsilon$ are nearly identical on $p_S(x, y)$ for small $\varepsilon$. Based on source data alone, we have insufficient reason to prefer any member of $\mathcal{F}_\varepsilon$ over another. Our state of belief should therefore cover $\mathcal{F}_\varepsilon$ as comprehensively as possible, putting nonzero weight on many functions that embody different inductive biases. This reasoning
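As a concrete toy illustration of Definition 1 (the hypotheses and data below are our own hypothetical constructions, not an experiment from the paper), consider a finite hypothesis class on source data where an intended feature and a spurious feature are perfectly correlated with the label. Both single-feature rules land in the $\varepsilon$-optimal set, so source data alone cannot distinguish them:

```python
import numpy as np

def eps_optimal_set(hypotheses, X, y, eps):
    """Return the names of hypotheses with risk <= eps, using
    empirical 0-1 loss on (X, y) as the risk L_p."""
    return [name for name, f in hypotheses.items()
            if np.mean(f(X) != y) <= eps]

# Source data where feature 0 (intended) and feature 1 (spurious)
# are perfectly correlated with the label.
X_source = np.array([[0, 0], [1, 1], [0, 0], [1, 1]])
y_source = np.array([0, 1, 0, 1])

hypotheses = {
    "use_feature_0": lambda X: X[:, 0],
    "use_feature_1": lambda X: X[:, 1],
    "constant_zero": lambda X: np.zeros(len(X), dtype=int),
}

# Both feature-based rules achieve zero source risk and are therefore
# indistinguishable on source data; the constant predictor is excluded.
F_eps = eps_optimal_set(hypotheses, X_source, y_source, eps=0.0)
```

Under subpopulation shift, the two rules in `F_eps` can behave very differently on target data, which is precisely the underspecification that motivates covering the set rather than picking one member.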



Figure 1: Our two-stage framework for learning from underspecified data. In the DIVERSIFY stage, we train each head in a multi-headed neural network to accurately predict the labels of source data while also outputting differing predictions for unlabeled target data. In the DISAMBIGUATE stage, we choose one of the heads by observing labels for an informative subset of the target data.
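The DISAMBIGUATE stage described in the caption above can be sketched as follows (an illustrative NumPy reconstruction under our own naming, not the paper's reference code): rank target points by how much the heads disagree, query labels for the top few, and keep the head that is most accurate on the queried set.

```python
import numpy as np

def disambiguate(head_probs, true_labels, budget):
    """Select one head using a small label budget.

    head_probs: (H, n, c) per-head class probabilities on target points.
    true_labels: (n,) oracle labels, read only at the queried points
    (standing in for a human annotator). Returns the chosen head index.
    """
    preds = head_probs.argmax(axis=2)                      # (H, n) hard predictions
    H = preds.shape[0]
    # Disagreement score: number of head pairs that differ on each point.
    score = sum((preds[i] != preds[j]).astype(int)
                for i in range(H) for j in range(i + 1, H))
    query = np.argsort(-score)[:budget]                    # most informative points
    accs = (preds[:, query] == true_labels[query]).mean(axis=1)
    return int(accs.argmax())

# Two heads agree on points 0-1 and disagree on points 2-3; querying the
# two disagreement points is enough to identify the accurate head.
one_hot = np.eye(2)
head_probs = np.stack([one_hot[[0, 1, 0, 1]], one_hot[[0, 1, 1, 0]]])
chosen = disambiguate(head_probs, np.array([0, 1, 0, 1]), budget=2)
```

Querying where heads disagree is what makes the labels informative: points where all heads agree reveal nothing about which head to keep.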

