DOING FAST ADAPTATION FAST: CONDITIONALLY INDEPENDENT DEEP ENSEMBLES FOR DISTRIBUTION SHIFTS

Abstract

Classifiers in a diverse ensemble capture distinct predictive signals, which is valuable for datasets containing multiple strongly predictive signals. Performing fast adaptation at test time allows us to generalize to distributions where certain signals are no longer predictive, or to avoid relying on sensitive or protected attributes. However, ensemble learning is often expensive, even more so when we need to enforce diversity constraints between the high-dimensional representations of the classifiers. Instead, we propose an efficient and fast method for learning ensemble diversity. We minimize the conditional mutual information between the classifiers' output distributions, a quantity which can be cheaply and exactly computed from empirical data. The resulting ensemble contains individually strong predictors whose outputs are dependent only through the label. We demonstrate the efficacy of our method on shortcut learning tasks. Performing fast adaptation on our ensemble selects shortcut-invariant models that generalize well to test distributions where the shortcuts are uncorrelated with the label.

1. INTRODUCTION

Some of the strongest scientific theories are supported by multiple sources of evidence, a principle the 19th-century philosopher William Whewell called "consilience". Evolution is one such example, having been firmly corroborated by fields ranging from paleontology to genetics. In many real-world applications of machine learning, datasets can similarly contain multiple predictive signals that explain the label well. In these settings, a standard model typically learns from a combination of predictive features (Ross et al., 2018; Kirichenko et al., 2022). Such a model will fail to generalize under distribution shifts that break the correlation between certain signals and the label (Hovy & Søgaard, 2015; Hashimoto et al., 2018; Puli et al., 2022). This shortcoming can be addressed by learning a diverse set or ensemble of classifiers. Such methods typically exploit some notion of independence to learn multiple classifiers that rely on different predictive signals. We can then perform fast adaptation, using a small amount of out-of-distribution (OOD) validation data to select the model that generalizes best. Learning diversity is also beneficial in and of itself: diverse classifiers are empirically more human-interpretable than a single fitted model (Ross et al., 2018), possibly because they learn disentangled representations that correspond to natural factors of variation (Shu et al., 2019). The key challenge is quantifying the right notion of diversity. Existing work has exploited concepts like input-gradient or parameter orthogonality as a proxy for statistical independence (Teney et al., 2021; Xu et al., 2021). Because OOD generalization fundamentally requires additional assumptions or data beyond the observed training data (Bareinboim et al., 2022; Schölkopf et al., 2021), previous work has also assumed access to unlabelled test data and measured disagreement on those examples (Lee et al., 2022; Pagliardini et al., 2022).
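The fast-adaptation step described above is simple: evaluate each member of the learnt ensemble on a small labelled OOD validation set and keep the best one. A minimal sketch (the function name and toy data are illustrative, not from the paper):

```python
import numpy as np

def fast_adapt(ensemble_probs, y_val):
    """Select the ensemble member with the highest accuracy on a small
    OOD validation set.

    ensemble_probs: list of (N_val, K) arrays of predicted class
        probabilities, one array per classifier in the ensemble.
    y_val: (N_val,) array of integer labels from the test environment.
    Returns the index of the selected classifier.
    """
    accs = [np.mean(np.argmax(p, axis=1) == y_val) for p in ensemble_probs]
    return int(np.argmax(accs))

# Toy usage: two classifiers; only the second matches the validation labels.
y_val = np.array([0, 1, 1, 0])
p_bad = np.array([[0.9, 0.1]] * 4)        # always predicts class 0
p_good = np.eye(2)[y_val] * 0.8 + 0.1     # confident and correct
best = fast_adapt([p_bad, p_good], y_val)  # -> 1
```

Because only accuracies on a handful of validation examples are computed, this selection step costs far less than retraining or fine-tuning any model.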
However, these objectives or assumptions are often prohibitive or unrealistic in real-world settings. For example, group-balanced test data is not always obtainable, e.g. when deploying a pneumonia model to multiple new hospitals whose patient profiles may change over time. Another costly example is enforcing input-gradient orthogonality on high-dimensional covariates like images or text, where it is difficult to avoid learning from orthogonal covariates of the same underlying feature, such as neighboring pixels. To avoid the pitfalls of operating in high-dimensional input or parameter space, a promising line of work instead adopts an information-theoretic perspective and treats the problem as representation learning. These approaches apply the information bottleneck method and minimize mutual information between the representations learnt by each classifier, forcing the classifiers to rely on distinct meaningful features for prediction. Most notably, Pace et al. (2020) and Rame & Cord (2021) minimize mutual information between the classifier representations conditioned on the label. Since no pair of predictors can both be accurate while remaining unconditionally independent, the extra conditioning prevents learning weak classifiers. The resulting ensemble contains accurate classifiers that nevertheless rely on distinct predictive signals. The only core assumption is that the underlying predictive signals are themselves conditionally independent. These approaches are conceptually appealing but practically challenging. Mutual information between high-dimensional representations is intractable and must be approximated, either via variational (e.g. Fischer, 2020) or contrastive (e.g. Oord et al., 2018) bounds. Such approximations are also computationally expensive, a problem compounded in the ensemble setting, where we wish to train multiple classifiers quickly. We seek to learn ensemble diversity fast and effectively.
Our key insight is that it suffices to enforce conditional independence on the output distributions of the classifiers. Our first contribution is proposing conditional mutual information (CMI) between output distributions as the regularizing objective. Assuming conditionally independent predictive signals, minimizing CMI between output distributions guarantees that the ensemble in which separate classifiers learn separate predictive signals is a minimizing solution. Since the output distributions are categorical, CMI can be cheaply and exactly computed from empirical data. In addition, our method avoids additional sources of data that are unavailable in many real-world domains, such as unlabelled test data or "group" labels for each predictive signal in the dataset. We permit only a small amount of validation data from the test distribution for (1) hyperparameter tuning and (2) selecting the final predictor from our ensemble. We dub our approach Conditionally Independent Deep Ensembles (CoDE). Our second contribution is evaluating CoDE on benchmark datasets for shortcut learning (Geirhos et al., 2020). Shortcuts are signals that are (i) highly but spuriously correlated with the label in the training distribution, possibly due to biases in data collection or other systematic pre-processing errors (Torralba & Efros, 2011), and (ii) preferentially learnt by a neural network, possibly due to simplicity biases (Shah et al., 2020) or architectural biases (e.g. convolutional neural networks (CNNs) relying on texture over shape (Baker et al., 2018)). An empirical risk minimizing (ERM) model will rely on shortcuts and fail to generalize to test distributions where they are no longer correlated with the label.
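To make concrete why CMI over categorical outputs is cheap, consider the plug-in estimate of I(Ŷ1; Ŷ2 | Y) from hard predictions: for each label value, build the K×K joint table of the two classifiers' predictions, then average the per-label mutual information weighted by p(y). The sketch below illustrates this quantity; the paper's actual regularizer may differ in detail (e.g. it may operate on soft output distributions):

```python
import numpy as np

def empirical_cmi(pred1, pred2, y, K):
    """Plug-in estimate of I(pred1; pred2 | y) for K-class predictions.

    pred1, pred2: (N,) integer predicted classes from two classifiers.
    y: (N,) integer labels. Returns CMI in nats.
    """
    cmi = 0.0
    for c in range(K):
        mask = (y == c)
        p_y = mask.mean()
        if p_y == 0:
            continue
        # joint distribution of (pred1, pred2) given y = c
        joint = np.zeros((K, K))
        for a, b in zip(pred1[mask], pred2[mask]):
            joint[a, b] += 1
        joint /= joint.sum()
        m1 = joint.sum(axis=1)  # marginal of pred1 given y = c
        m2 = joint.sum(axis=0)  # marginal of pred2 given y = c
        nz = joint > 0
        cmi += p_y * np.sum(joint[nz] * np.log(joint[nz] / np.outer(m1, m2)[nz]))
    return cmi
```

For two classifiers that always agree (but vary within a label class), the estimate is log 2 per binary prediction; for conditionally independent predictions it is exactly zero. The cost is O(N + K^3) per pair, independent of representation dimensionality.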
This is a natural application for our method, as the core assumption of conditional independence holds in many such datasets: in natural images, for example, the foreground typically determines the label and is thus conditionally independent of the background (the shortcut). We show that CoDE effectively recovers an ensemble in which the shortcut features and the true signal are learnt by separate classifiers.

2. PRELIMINARIES: SETUP AND NOTATION

In Section 3, we will fully motivate the assumptions behind our model of the data-generating process (DGP). However, we describe it here first to establish key terminology and concepts.

Data-Generating Process. Let z denote the set of latent factors that generate the set of observed features x ∈ R^P, and let y ∈ {0, 1, ..., K−1} denote the label. The data p_e(x, y, z) is generated from a family of distributions indexed by e, the environment. We only consider: (i) a single training environment (e = tr), from which we have access to i.i.d. labelled training examples D_tr = {(x_i, y_i)}_{i=1}^{N}, and (ii) a test environment (e = te), from which we draw unlabelled test examples that our model should perform well on. We also allow access to a small set of labelled validation data D_val = {(x_i, y_i)}_{i=1}^{N_val} from the test environment, which is used only for hyperparameter tuning and ensembling (i.e. constructing the final model from the set of learnt classifiers). We make the following assumptions on the DGP:
(i) all label information is encoded by z, i.e. p_e(y|x, z) = p_e(y|z) for all e;
(ii) p_e(x|z) = p(x|z) is invariant across all e.
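A toy instance of this DGP helps fix intuition (illustrative only; the construction and parameter names are not from the paper). Here z holds two latent factors: a core signal that always determines y, and a shortcut whose correlation with y varies with the environment e. The observation model p_e(x|z) = p(x|z) is held fixed across environments, as assumption (ii) requires; only the latent distribution changes:

```python
import numpy as np

def sample_env(n, shortcut_corr, rng):
    """Sample (x, y) from a toy two-factor DGP.

    shortcut_corr: probability that the shortcut latent agrees with y;
    this is the only quantity that differs between environments.
    """
    y = rng.integers(0, 2, size=n)          # binary label (K = 2)
    z_core = y                               # core latent encodes y exactly
    flip = rng.random(n) > shortcut_corr
    z_shortcut = np.where(flip, 1 - y, y)    # environment-dependent shortcut
    # invariant observation model p(x|z): a noisy feature per latent factor
    x = np.stack([z_core + 0.1 * rng.standard_normal(n),
                  z_shortcut + 0.1 * rng.standard_normal(n)], axis=1)
    return x, y

rng = np.random.default_rng(0)
x_tr, y_tr = sample_env(10_000, shortcut_corr=0.95, rng=rng)  # e = tr
x_te, y_te = sample_env(10_000, shortcut_corr=0.5, rng=rng)   # e = te
```

In the training environment the shortcut feature x[:, 1] predicts y about 95% of the time, so an ERM model can lean on it; in the test environment it is uninformative, and only a classifier using x[:, 0] generalizes.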

