DOING FAST ADAPTATION FAST: CONDITIONALLY INDEPENDENT DEEP ENSEMBLES FOR DISTRIBUTION SHIFTS

Abstract

Classifiers in a diverse ensemble capture distinct predictive signals, which is valuable when a dataset contains multiple strongly predictive signals. Performing fast adaptation at test time then allows us to generalize to distributions where certain signals are no longer predictive, or to avoid relying on sensitive or protected attributes. However, ensemble learning is often expensive, even more so when diversity constraints must be enforced between the classifiers' high-dimensional representations. Instead, we propose an efficient and fast method for learning ensemble diversity: we minimize the conditional mutual information between the classifiers' output distributions given the label, a quantity that can be computed cheaply and exactly from empirical data. The resulting ensemble contains individually strong predictors whose outputs are dependent only through the label. We demonstrate the efficacy of our method on shortcut learning tasks: performing fast adaptation on our ensemble selects shortcut-invariant models that generalize well to test distributions where the shortcuts are uncorrelated with the label.
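The quantity minimized above, the conditional mutual information between two classifiers' discrete outputs given the label, can indeed be computed exactly from empirical counts. The sketch below is an illustration of that computation for hard (argmax) predictions, not the paper's implementation; the function name is hypothetical:

```python
import numpy as np

def conditional_mutual_information(a, b, y):
    """Empirical I(A; B | Y) in nats, for discrete predictions a, b and labels y."""
    a, b, y = (np.asarray(v) for v in (a, b, y))
    cmi = 0.0
    for yv in np.unique(y):
        mask = y == yv
        p_y = mask.mean()  # empirical P(Y = yv)
        # Re-index the predictions observed within this label slice.
        _, ia = np.unique(a[mask], return_inverse=True)
        _, ib = np.unique(b[mask], return_inverse=True)
        # Empirical joint distribution P(A, B | Y = yv).
        joint = np.zeros((ia.max() + 1, ib.max() + 1))
        np.add.at(joint, (ia, ib), 1.0 / mask.sum())
        pa = joint.sum(axis=1, keepdims=True)  # P(A | Y = yv)
        pb = joint.sum(axis=0, keepdims=True)  # P(B | Y = yv)
        nz = joint > 0
        cmi += p_y * np.sum(joint[nz] * np.log(joint[nz] / (pa @ pb)[nz]))
    return cmi
```

The value is zero exactly when the two classifiers are conditionally independent given the label, which is the diversity condition the abstract describes; no gradients through high-dimensional representations are needed.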

1. INTRODUCTION

Some of the strongest scientific theories are supported by multiple sources of evidence, a principle the 19th-century philosopher William Whewell called "consilience". Evolution is one such example, having been firmly corroborated by fields ranging from paleontology to genetics. In many real-world applications of machine learning, datasets can similarly contain multiple predictive signals that explain the label well. In these settings, a standard model typically learns from a combination of predictive features (Ross et al., 2018; Kirichenko et al., 2022). Such a model will fail to generalize to distribution shifts that break the correlation between certain signals and the label (Hovy & Søgaard, 2015; Hashimoto et al., 2018; Puli et al., 2022). This shortcoming can be addressed by learning a diverse set or ensemble of classifiers. Such methods typically exploit some notion of independence to learn multiple classifiers that rely on different predictive signals. We can then perform fast adaptation, using a small amount of out-of-distribution (OOD) validation data to select the model that generalizes best. Learning diversity is also beneficial in and of itself: diverse classifiers have been shown empirically to be more human-interpretable than a single fitted model (Ross et al., 2018), possibly because they learn disentangled representations that correspond to natural factors of variation (Shu et al., 2019). The key challenge is quantifying the right notion of diversity. Existing work has exploited concepts like input gradient or parameter orthogonality as a proxy for statistical independence (Teney et al., 2021; Xu et al., 2021). To tackle OOD generalization, which fundamentally requires additional assumptions or data beyond the observed training data (Bareinboim et al., 2022; Schölkopf et al., 2021), previous work has also assumed access to unlabelled test data and measured disagreement on those examples (Lee et al., 2022; Pagliardini et al., 2022).
However, these objectives or assumptions are often prohibitive or unrealistic in real-world settings. For example, group-balanced test data is not always obtainable, e.g. when deploying a pneumonia model to multiple new hospitals whose patient profiles may change over time. Another costly example is enforcing input gradient orthogonality on high-dimensional covariates like images or text, where it is difficult to avoid learning orthogonal covariates of the same underlying feature, such as neighboring pixels.
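The fast-adaptation step introduced above, selecting the ensemble member that generalizes best using a small OOD validation set, can be sketched as follows. This is a minimal illustration with hypothetical function names, assuming each member exposes a function that returns hard label predictions:

```python
import numpy as np

def select_member(predict_fns, x_val, y_val):
    """Return the index of the ensemble member with the highest accuracy
    on a small out-of-distribution validation set."""
    accuracies = [np.mean(f(x_val) == y_val) for f in predict_fns]
    return int(np.argmax(accuracies))
```

In the shortcut-learning setting, the member selected this way is the one whose predictive signal remains correlated with the label after the shift, i.e. the shortcut-invariant model.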

