LEARNING A NON-REDUNDANT COLLECTION OF CLASSIFIERS

Abstract

Supervised learning models constructed under the i.i.d. assumption have often been shown to exploit spurious or brittle predictive signals instead of more robust ones present in the training data. Inspired by Quality-Diversity algorithms, in this work we train a collection of classifiers to learn distinct solutions to a classification problem, with the goal of learning to exploit a variety of predictive signals present in the training data. We propose an information-theoretic measure of model diversity based on minimizing an estimate of conditional total correlation of final layer representations across models given the label. We consider datasets with synthetically injected spurious correlations and evaluate our framework's ability to rapidly adapt to a change in distribution that destroys the spurious correlation. We compare our method to a variety of baselines under this evaluation protocol, showing that it is competitive with other approaches while being more successful at isolating distinct signals. We also show that our model is competitive with Invariant Risk Minimization (IRM) under this evaluation protocol without requiring access to the environment information required by IRM to discriminate between spurious and robust signals.

1. INTRODUCTION

The Empirical Risk Minimization (ERM) principle (Vapnik, 2013), which underpins many machine learning models, is built on the assumption that training and testing samples are drawn i.i.d. from some hypothetical distribution. It has been demonstrated that certain violations of this assumption (in conjunction with potential misalignments between the formulation of the learning objectives and the underlying task of interest) lead to models that exploit spurious or brittle correlations in the training data. Examples include learning to exploit image backgrounds instead of objects in the foreground due to data biases (such as using grassy backgrounds to predict the presence of cows (Beery et al., 2018)), using textural as opposed to shape information to classify objects (Geirhos et al., 2018), and using signals not robust to small adversarial perturbations (Ilyas et al., 2019). Implicit in work that attempts to address these phenomena is the assumption that more robust predictive signals are indeed present in the training data, even if for various reasons current models tend not to leverage them.

In this work, drawing inspiration from Quality-Diversity algorithms (Pugh et al., 2016), which seek to construct a collection of high-performing, diverse solutions to a task, we aim to learn a collection of models, each incentivized to find a distinct, high-performing solution to a given supervised learning problem from a fixed training set. Informally, our motivation is that a sufficiently large collection of such distinct models would exploit robust signals present in the training data in addition to the brittle signals that current models tend to exploit. Thus, given the representations computed by such a collection, it may be possible to rapidly adapt to test-time shifts in distribution that destroy the predictive power of brittle features. Addressing this problem hinges on defining and enforcing an appropriate measure of model diversity.
To this end, we make the following contributions:

• We propose and motivate a novel measure of model diversity based on conditional total correlation (across models) of final layer representations given the label. Informally, this incentivizes models to each learn a non-redundant (in an information-theoretic sense) way to solve the given task.

• We estimate this measure using a proxy variational estimator computed from samples from the conditional joint and marginal distributions of final layer representations across models. We train a collection of models to be accordingly diverse, alternating between training a variational critic to maximize the variational estimator and training the models to minimize a weighted sum of the classification losses across models and the variational estimator.

• We empirically validate this framework by training on datasets with synthetically injected spurious correlations. We demonstrate that our framework learns a collection of representations which, under a linear fine-tuning protocol, are competitive with baselines in rapidly adapting to a shift in distribution that destroys the spurious correlations. We show that our framework performs favourably relative to baselines at isolating distinct predictive signals in different models in the collection, and as a result can make accurate test-time predictions without fine-tuning. We also compare our approach to the Invariant Risk Minimization (IRM) framework (Arjovsky et al., 2019), which leverages information from multiple training environments to identify signals robust to variations across environments. We show that our approach exploits the robust signals without requiring the metadata needed by IRM to discriminate between spurious and robust signals.

Figure 1: Graphical representation of our framework. We train multiple models to individually minimize supervised loss while simultaneously minimizing the conditional total correlation of their final layer representations given the label.

2. TOTAL CORRELATION

Total correlation (Watanabe, 1960) is a multivariate extension of mutual information. The total correlation of n random variables X_1, ..., X_n is given by (where KL denotes the Kullback-Leibler divergence):

TC(X_1, ..., X_n) := KL[ p(X_1, ..., X_n) || \prod_{i=1}^n p(X_i) ]

For n = 2 this is equivalent to mutual information. Total correlation can be interpreted as a measure of the amount of redundancy across the X_i's: a total correlation of zero corresponds to the X_i's being mutually independent. Given a classification problem of predicting a label Y from X, we consider the conditional total correlation of a collection of vector-valued representations h_1(X), ..., h_n(X) given Y, for differentiable functions h_i, defined as:

TC(h_1(X), ..., h_n(X) | Y) := E_{y ~ p(Y)} KL[ p(h_1(X), ..., h_n(X) | Y = y) || \prod_{i=1}^n p(h_i(X) | Y = y) ]

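As a concrete check of the definition, the total correlation of a small discrete joint distribution can be computed directly from probability tables. The following is an illustrative sketch (the function and example distributions are our own, not part of the proposed framework):

```python
import numpy as np

def total_correlation(joint):
    """TC(X_1, ..., X_n) = KL(p(x_1, ..., x_n) || prod_i p(x_i)) for a
    discrete joint distribution given as an n-dimensional probability array."""
    joint = np.asarray(joint, dtype=float)
    n = joint.ndim
    # Product of the n marginals, broadcast back to the shape of the joint.
    prod = np.ones_like(joint)
    for i in range(n):
        marginal = joint.sum(axis=tuple(j for j in range(n) if j != i))
        shape = [1] * n
        shape[i] = -1
        prod = prod * marginal.reshape(shape)
    # Sum only over cells with positive joint mass (0 log 0 = 0 convention).
    mask = joint > 0
    return float(np.sum(joint[mask] * np.log(joint[mask] / prod[mask])))

# Two perfectly correlated fair bits: TC equals their mutual information, log 2.
xy_copy = np.array([[0.5, 0.0], [0.0, 0.5]])
# Two independent fair bits: TC is zero.
xy_indep = np.full((2, 2), 0.25)
```

The first example recovers the n = 2 equivalence with mutual information noted in the text: two copies of a fair bit share exactly log 2 nats.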

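The role of the conditioning on Y can also be illustrated numerically: representations that each encode only the label are highly correlated unconditionally, but independent given the label, so a label-conditional measure does not penalize models for sharing the information they are all supposed to predict. A small sketch under the same discrete setup (helper and distributions are our own illustration, not part of the proposed framework):

```python
import numpy as np

def tc(joint):
    """Total correlation of a discrete joint distribution (n-dim array)."""
    joint = np.asarray(joint, dtype=float)
    n = joint.ndim
    prod = np.ones_like(joint)
    for i in range(n):
        marginal = joint.sum(axis=tuple(j for j in range(n) if j != i))
        shape = [1] * n
        shape[i] = -1
        prod = prod * marginal.reshape(shape)
    mask = joint > 0
    return float(np.sum(joint[mask] * np.log(joint[mask] / prod[mask])))

# Two binary representations h1, h2 that both copy a fair binary label Y:
# p(y, h1, h2) puts mass 0.5 on (0, 0, 0) and 0.5 on (1, 1, 1).
p_yhh = np.zeros((2, 2, 2))
p_yhh[0, 0, 0] = p_yhh[1, 1, 1] = 0.5

# Unconditionally the two representations look fully redundant: TC = log 2.
tc_uncond = tc(p_yhh.sum(axis=0))

# Given Y they are independent: E_{p(y)} TC(p(h1, h2 | y)) = 0.
p_y = p_yhh.sum(axis=(1, 2))
tc_cond = sum(p_y[y] * tc(p_yhh[y] / p_y[y]) for y in range(2))
```

Only redundancy beyond what the label explains contributes to the conditional measure, which is the property the diversity objective relies on.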