DISTILLING MODEL FAILURES AS DIRECTIONS IN LATENT SPACE

Abstract

Existing methods for isolating hard subpopulations and spurious correlations in datasets often require human intervention. This can make these methods labor-intensive and dataset-specific. To address these shortcomings, we present a scalable method for automatically distilling a model's failure modes. Specifically, we harness linear classifiers to identify consistent error patterns, and, in turn, induce a natural representation of these failure modes as directions within the feature space. We demonstrate that this framework allows us to discover and automatically caption challenging subpopulations within the training dataset. Moreover, by combining our framework with off-the-shelf diffusion models, we can generate images that are especially challenging for the analyzed model, and thus can be used to perform synthetic data augmentation that helps remedy the model's failure modes.

1. INTRODUCTION

The composition of the training dataset has key implications for machine learning models' behavior (Feldman, 2019; Carlini et al., 2019; Koh & Liang, 2017; Ghorbani & Zou, 2019; Ilyas et al., 2022), especially as training environments often deviate from deployment conditions (Rabanser et al., 2019; Koh et al., 2020; Hendrycks et al., 2020). For example, a model might struggle on a specific subpopulation in the data if that subpopulation was mislabeled (Northcutt et al., 2021; Stock & Cisse, 2018; Beyer et al., 2020; Vasudevan et al., 2022), underrepresented (Sagawa et al., 2020; Santurkar et al., 2021), or corrupted (Hendrycks & Dietterich, 2019; Hendrycks et al., 2020). More broadly, the training dataset might contain spurious correlations, encouraging the model to rely on prediction rules that do not generalize to deployment (Xiao et al., 2020; Geirhos et al., 2020; DeGrave et al., 2021). Moreover, identifying meaningful subpopulations within data enables dataset refinement (such as filtering or relabeling) (Yang et al., 2019; Stock & Cisse, 2018), as well as training fairer (Kim et al., 2019; Du et al., 2021) or more accurate (Jabbour et al., 2020; Srivastava et al., 2020) models.

However, dominant approaches to identifying biases and difficult subpopulations within datasets often require human intervention, which is typically labor-intensive and thus not conducive to routine usage. For example, recent works (Tsipras et al., 2020; Vasudevan et al., 2022) resort to manual data exploration to identify label idiosyncrasies and failure modes in widely used datasets such as ImageNet. A different line of work (Sohoni et al., 2020; Nam et al., 2020; Kim et al., 2019; Liu et al., 2021; Hashimoto et al., 2018) does offer automatic methods for identifying and intervening on hard examples, but these methods are not designed to capture simple, human-understandable patterns. For instance, Liu et al. (2021) directly upweight inputs that were misclassified early in training, but such examples do not necessarily represent a consistent failure mode.

This motivates the question: how can we extract meaningful patterns of model errors on large datasets?

One way to approach this question is through model interpretability methods. These include saliency maps (Adebayo et al., 2018; Simonyan et al., 2013), integrated gradients (Sundararajan et al., 2017), and LIME (Ribeiro et al., 2016b), which perform feature attribution for particular inputs. Specifically, they aim to highlight which parts of an input were most important for a model's prediction, and can thus hint at the brittleness of that prediction. However, while feature attribution can indeed help debug individual test examples, it does not provide a global understanding of the underlying biases of the dataset, at least not without manually examining many such individual attributions.
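To make the idea from the abstract concrete, the core mechanism can be sketched in a few lines: fit a linear classifier in feature space to separate a model's correct and incorrect examples, and read off its weight vector as a candidate failure direction. The sketch below is illustrative only, using synthetic Gaussian features, a planted "hidden" failure direction, and plain logistic regression as the linear classifier; all names and dimensions are assumptions, not the paper's actual pipeline.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for per-example feature embeddings of one class
# (a hypothetical substitute for embeddings from a shared feature space).
n, d = 500, 32
features = rng.normal(size=(n, d))

# Planted failure mode: the model errs on examples with a large
# component along a hidden direction (e.g., an unusual background).
hidden = np.zeros(d)
hidden[0] = 1.0
is_error = (features @ hidden > 0.8).astype(float)  # 1 = misclassified

# Fit a linear classifier (logistic regression via gradient descent on
# the regularized log-loss) to separate correct from incorrect examples.
w = np.zeros(d)
for _ in range(2000):
    p = 1.0 / (1.0 + np.exp(-(features @ w)))   # predicted error probability
    w -= 0.1 * (features.T @ (p - is_error) / n + 1e-3 * w)

# The unit weight vector is a candidate "failure direction" in feature space.
direction = w / np.linalg.norm(w)

# Scoring examples by their projection onto this direction surfaces the
# hardest subpopulation, which can then be inspected or captioned.
scores = features @ direction
```

In this toy setup, `direction` recovers the planted `hidden` direction, and examples with high `scores` coincide with the misclassified subpopulation; the real method would additionally caption this direction and feed it to a generative model to synthesize hard examples.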

