DISTILLING MODEL FAILURES AS DIRECTIONS IN LATENT SPACE

Abstract

Existing methods for isolating hard subpopulations and spurious correlations in datasets often require human intervention. This can make these methods labor-intensive and dataset-specific. To address these shortcomings, we present a scalable method for automatically distilling a model's failure modes. Specifically, we harness linear classifiers to identify consistent error patterns, and, in turn, induce a natural representation of these failure modes as directions within the feature space. We demonstrate that this framework allows us to discover and automatically caption challenging subpopulations within the training dataset. Moreover, by combining our framework with off-the-shelf diffusion models, we can generate images that are especially challenging for the analyzed model, and thus can be used to perform synthetic data augmentation that helps remedy the model's failure modes.

1. INTRODUCTION

The composition of the training dataset has key implications for machine learning models' behavior (Feldman, 2019; Carlini et al., 2019; Koh & Liang, 2017; Ghorbani & Zou, 2019; Ilyas et al., 2022), especially as training environments often deviate from deployment conditions (Rabanser et al., 2019; Koh et al., 2020; Hendrycks et al., 2020). For example, a model might struggle on a specific subpopulation in the data if that subpopulation was mislabeled (Northcutt et al., 2021; Stock & Cisse, 2018; Beyer et al., 2020; Vasudevan et al., 2022), underrepresented (Sagawa et al., 2020; Santurkar et al., 2021), or corrupted (Hendrycks & Dietterich, 2019; Hendrycks et al., 2020). More broadly, the training dataset might contain spurious correlations, encouraging the model to depend on prediction rules that do not generalize to deployment (Xiao et al., 2020; Geirhos et al., 2020; DeGrave et al., 2021). Moreover, identifying meaningful subpopulations within data enables dataset refinement (such as filtering or relabeling) (Yang et al., 2019; Stock & Cisse, 2018), as well as training fairer (Kim et al., 2019; Du et al., 2021) or more accurate (Jabbour et al., 2020; Srivastava et al., 2020) models.

However, dominant approaches to identifying biases and difficult subpopulations within datasets often require human intervention, which is typically labor-intensive and thus not conducive to routine usage. For example, recent works (Tsipras et al., 2020; Vasudevan et al., 2022) resort to manual data exploration to identify label idiosyncrasies and failure modes in widely used datasets such as ImageNet. On the other hand, a different line of work (Sohoni et al., 2020; Nam et al., 2020; Kim et al., 2019; Liu et al., 2021; Hashimoto et al., 2018) does present automatic methods for identifying and intervening on hard examples, but these methods are not designed to capture simple, human-understandable patterns. For instance, Liu et al. (2021) directly upweight inputs that were misclassified early in training, but these examples do not necessarily represent a consistent failure mode. This motivates the question:

How can we extract meaningful patterns of model errors on large datasets?

One way to approach this question is through model interpretability methods. These include saliency maps (Adebayo et al., 2018; Simonyan et al., 2013), integrated gradients (Sundararajan et al., 2017), and LIME (Ribeiro et al., 2016b), all of which perform feature attribution for particular inputs. Specifically, they aim to highlight which parts of an input were most important for a model's prediction, and can thus hint at the brittleness of that prediction. However, while feature attribution can indeed help debug individual test examples, it does not provide a global understanding of the underlying biases of the dataset, at least not without manually examining many such individual attributions.

OUR CONTRIBUTIONS

In this work, we build a scalable mechanism for globally understanding large datasets from the perspective of the model's prediction rules. Specifically, our goal is not only to identify interpretable failure modes within the data, but also to inform actionable interventions that remedy these problems. Our approach distills a given model's failure modes as directions in a certain feature space. In particular, we train linear classifiers on normalized feature embeddings within that space to identify consistent mistakes in the original model's predictions. The decision boundary of each such classifier then defines a "direction" of hard examples. By measuring each data point's alignment with this identified direction, we can understand how relevant that example is for the failure mode we intend to capture. We leverage this framework to:

• Detection: Automatically detect and quantify reliability failures, such as brittleness to distribution shifts or performance degradation on hard subpopulations (Section 2.1).

• Interpretation: Understand and automatically assign meaningful captions to the error patterns identified by our method (Section 2.2).

• Intervention: Intervene during training in order to improve model reliability along the identified axes of failure (Section 2.3). In particular, by leveraging our framework in conjunction with off-the-shelf diffusion models, we can perform synthetic data augmentation tailored to remedy the analyzed model's mistakes.

Using our framework, we can automatically identify and intervene on hard subpopulations in image datasets such as CIFAR-10, ImageNet, and ChestX-ray14. Importantly, we do not require direct human intervention or pre-annotated subgroups. The resulting framework is thus a scalable approach to identifying important subpopulations in large datasets with respect to their downstream tasks.
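To give a flavor of the interpretation step, the toy sketch below ranks candidate captions by how well their text embeddings align with a failure direction in a shared embedding space. Everything here is fabricated for illustration: `embed_text` is a hypothetical stand-in that returns fixed random unit vectors (not a real text encoder), and the failure direction is planted by construction. An actual pipeline would use a real vision/language text encoder and a direction recovered from the model's errors.

```python
import numpy as np

rng = np.random.default_rng(1)
d = 512  # embedding width, chosen arbitrarily for the toy example

# Hypothetical stand-in text encoder: maps each caption to a fixed
# random unit vector. A real pipeline would call an actual text encoder.
_cache = {}
def embed_text(caption):
    if caption not in _cache:
        v = rng.normal(size=d)
        _cache[caption] = v / np.linalg.norm(v)
    return _cache[caption]

# Planted failure direction: close to one caption's embedding, plus noise.
noise = rng.normal(size=d)
noise /= np.linalg.norm(noise)
failure_direction = embed_text("a photo of an old woman") + 0.3 * noise
failure_direction /= np.linalg.norm(failure_direction)

candidates = [
    "a photo of an old man",
    "a photo of an old woman",
    "a photo of a young woman",
]

# Rank captions by alignment (inner product) with the failure direction;
# the top caption serves as a natural-language description of the failure mode.
ranked = sorted(candidates,
                key=lambda c: float(embed_text(c) @ failure_direction),
                reverse=True)
```

Because captions and images live in the same space, the same inner-product score that flags hard images can also rank descriptive text, which is what makes the captioning step automatic.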

2. CAPTURING FAILURE MODES AS DIRECTIONS WITHIN A LATENT SPACE

The presence of certain undesirable patterns in a training dataset, such as spurious correlations or underrepresented subpopulations, can prevent a learned model from properly generalizing during deployment. As a running example, consider the task of distinguishing "old" versus "young" faces, where, in the training dataset, age is spuriously correlated with gender (such that the faces of older men and younger women are overrepresented). Such correlations occur in the CelebA dataset (Liu et al., 2015) (though here we construct a dataset that strengthens them).* Thus, a model trained on such a dataset might rely too heavily on gender, and will struggle to predict the age of younger men or older women. How can we detect model failures on these subpopulations?

The guiding principle of our framework is to model such failure modes as directions within a certain latent space (Figure 1). In the above example, we would like to identify an axis such that the (easier) examples of "old men" and the (harder) examples of "old women" lie in opposite directions. We can then capture the role of an individual data point in the dataset by evaluating how closely its normalized embedding aligns with that extracted direction (axis). But how can we learn these directions?

Our method. The key idea of our approach is to find a hyperplane that best separates the correct examples from the incorrect ones. In the presence of global failure modes such as spurious correlations, the original model will likely make consistent mistakes, and these mistakes will share features. Using a held-out validation set, we can therefore train a linear support vector machine (SVM) for each class to predict the original model's mistakes based on these shared features. The SVM then establishes a decision boundary between the correct and incorrect examples, and the direction of the failure mode is orthogonal to this decision boundary (i.e., the normal vector to the hyperplane).
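The SVM step above can be sketched as follows. This is a minimal toy version on synthetic stand-in embeddings: the data, dimensions, and planted "hidden" failure direction are all fabricated for illustration, whereas a real pipeline would use the original model's actual mistakes on a held-out validation set, embedded in a shared vision/language space.

```python
import numpy as np
from sklearn.svm import LinearSVC

rng = np.random.default_rng(0)

# Toy validation set for one class: 200 feature embeddings in 32 dimensions.
# We plant a hidden direction along which mistakes concentrate.
n, d = 200, 32
embeddings = rng.normal(size=(n, d))
true_direction = rng.normal(size=d)
true_direction /= np.linalg.norm(true_direction)

# Label an example a "mistake" (1) when it leans along the hidden direction.
mistake = (embeddings @ true_direction + 0.3 * rng.normal(size=n)) > 0

# Normalize embeddings to unit length, as in the text.
embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)

# Fit a linear SVM separating correct (0) from incorrect (1) examples.
svm = LinearSVC(C=1.0).fit(embeddings, mistake.astype(int))

# The failure-mode direction is the normal vector of the decision boundary.
failure_direction = svm.coef_[0] / np.linalg.norm(svm.coef_[0])

# Score each example by alignment with the failure direction:
# larger score = harder example for the original model.
scores = embeddings @ failure_direction
hardest = np.argsort(scores)[::-1][:10]  # indices of the 10 "hardest" examples
```

In this synthetic setup the recovered `failure_direction` aligns closely with the planted one, and sorting by the alignment score surfaces the examples the (simulated) model got wrong.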
Intuitively, the more aligned an example is with the identified failure direction, the harder the example was for the original neural network. Details of our method can be found in Appendix A.

The choice of latent space: leveraging shared vision/language embeddings. Naturally, the choice of embedding space for the SVM greatly impacts the types of failure modes it picks up. Which embedding space should we choose? One option is to use the latent space of the original neural network. However, especially if the model fits the training data perfectly, it has likely learned latent



* We can also detect this failure mode in the original CelebA dataset (see Appendix B.1).

