OUTLIER-ROBUST GROUP INFERENCE VIA GRADIENT SPACE CLUSTERING

Abstract

Traditional machine learning models focus on achieving good performance on the overall training distribution, but they often underperform on minority groups. Existing methods can improve worst-group performance, but they suffer from several limitations: (i) they require group annotations, which are often expensive and sometimes infeasible to obtain, and/or (ii) they are sensitive to outliers. Most related works fail to solve these two issues simultaneously, as they take conflicting perspectives on minority groups and outliers. We address the problem of learning group annotations in the presence of outliers by clustering the data in the space of gradients of the model parameters. We show that data in the gradient space has a simpler structure while preserving information about minority groups and outliers, making it suitable for standard clustering methods like DBSCAN. Extensive experiments demonstrate that our method significantly outperforms state-of-the-art methods in terms of both group identification and downstream worst-group performance.

1. INTRODUCTION

Empirical Risk Minimization (ERM), i.e., the minimization of the average training loss over the set of model parameters, is the standard training procedure in machine learning. It yields models with strong in-distribution performance (i.e., low loss on test data drawn from the same distribution as the training data), but it does not guarantee satisfactory performance on minority groups that contribute relatively few data points to the training loss function (Sagawa et al., 2019; Koh et al., 2021). This effect is particularly problematic when the minority groups correspond to socially protected groups. For example, in the toxic text classification task, certain identities are overwhelmingly abused in the online conversations that form the training data for models detecting toxicity (Dixon et al., 2018). Such data lacks sufficient non-toxic examples mentioning these identities, yielding problematic and unfair spurious correlations; as a result, ERM learns to associate these identities with toxicity (Dixon et al., 2018; Garg et al., 2019; Yurochkin & Sun, 2020). A related phenomenon is subpopulation shift (Koh et al., 2021), i.e., when the test distribution differs from the train distribution in terms of group proportions. Under subpopulation shift, poor performance on the minority groups in the train data translates into poor overall test performance, since these groups are more prevalent or more heavily weighted at test time. Subpopulation shift occurs in many application domains (Tatman, 2017; Beery et al., 2018; Oakden-Rayner et al., 2020; Santurkar et al., 2020; Koh et al., 2021).

Prior work offers a variety of methods for training models robust to subpopulation shift and spurious correlations, including group distributionally robust optimization (gDRO) (Hu et al., 2018; Sagawa et al., 2019), importance weighting (Shimodaira, 2000; Byrd & Lipton, 2019), subsampling (Sagawa et al., 2020; Idrissi et al., 2022; Maity et al., 2022), and variations of tilted ERM (Li et al., 2020; 2021). These methods succeed in achieving comparable performance across groups in the data, but they require group annotations. Such annotations can be expensive to obtain, e.g., labeling spurious backgrounds in image recognition (Beery et al., 2018) or labeling identity mentions in the toxicity example. It can also be challenging to anticipate all potential spurious correlations in advance: the culprit could be the background, the time of day, the camera angle, or unanticipated identities subject to harassment.

Recently, methods have emerged for learning group annotations (Sohoni et al., 2020; Liu et al., 2021; Creager et al., 2021), along with variations of DRO that do not require groups (Hashimoto et al., 2018; Zhai et al., 2021). One common theme is to treat data where an ERM model makes mistakes (i.e., high-loss points) as a minority group (Hashimoto et al., 2018; Liu et al., 2021) and to upweight these points. Unfortunately, such methods are at risk of overfitting to outliers (e.g., mislabeled data or corrupted images), which are also high-loss points. Indeed, existing methods for outlier-robust training propose to ignore the high-loss points (Shen & Sanghavi, 2019), the direct opposite of the approach of Hashimoto et al. (2018) and Liu et al. (2021).

In this paper, our goal is to learn group annotations in the presence of outliers. Rather than relying on loss values, which, as seen above, create opposing tradeoffs, we propose to first represent each data point by the gradient of its loss w.r.t. the model parameters.
Such gradients tell us how a specific data point wants the parameters of the model to change in order to fit it better. In this gradient space, we anticipate that groups (conditioned on the label) correspond to clusters of gradients. Outliers, on the other hand, largely correspond to isolated gradients: they are likely to want the model parameters to change differently from any of the groups and from other outliers. See Figure 1 for an illustration. The gradient space structure allows us to separate out the outliers and learn the group annotations via traditional clustering techniques such as DBSCAN (Ester et al., 1996); a minimal sketch of this pipeline appears below. We use the learned group annotations to train models with improved worst-group performance (measured w.r.t. the true group annotations). We summarize our contributions below:

• We show that the gradient space simplifies the data structure and makes it easier to learn group annotations via clustering.
• We propose Gradient Space Partitioning (GRASP), a method for learning group annotations in the presence of outliers for training models robust to subpopulation shift.
• We conduct extensive experiments on one synthetic dataset and three datasets from different modalities, and demonstrate that our method achieves state-of-the-art results in terms of both group identification quality and downstream worst-group performance.
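To make the pipeline concrete, the following is a minimal sketch of gradient-space group inference, not the paper's exact implementation. The choice of a binary logistic-regression model, the decision to take gradients only w.r.t. the linear layer's weights and bias, and the DBSCAN hyperparameters (`eps`, `min_samples`) are all illustrative assumptions.

```python
# Minimal sketch of gradient-space group inference (illustrative assumptions:
# binary logistic regression, gradients w.r.t. weights and bias only,
# hand-picked DBSCAN hyperparameters).
import numpy as np
from sklearn.cluster import DBSCAN
from sklearn.linear_model import LogisticRegression


def per_sample_gradients(model, X, y):
    """Per-sample gradient of the logistic loss w.r.t. (weights, bias).

    For a linear model with sigmoid output p = sigma(w.x + b), the gradient
    of the cross-entropy loss at sample (x, y) is (p - y) * [x, 1].
    """
    p = model.predict_proba(X)[:, 1]           # sigma(w.x + b), y in {0, 1}
    residual = (p - y)[:, None]                # shape (n, 1)
    return residual * np.hstack([X, np.ones((len(X), 1))])


def infer_groups(X, y, eps=0.5, min_samples=10):
    """Cluster per-sample gradients within each class using DBSCAN.

    Returns inferred group labels; -1 marks points DBSCAN leaves
    unclustered, which we treat as outliers.
    """
    model = LogisticRegression().fit(X, y)     # ERM reference model
    grads = per_sample_gradients(model, X, y)
    groups = np.full(len(X), -1)
    next_id = 0
    for c in np.unique(y):                     # cluster conditioned on label
        idx = np.where(y == c)[0]
        labels = DBSCAN(eps=eps, min_samples=min_samples).fit_predict(grads[idx])
        clustered = labels >= 0
        groups[idx[clustered]] = labels[clustered] + next_id
        if clustered.any():
            next_id += labels.max() + 1
    return groups
```

The inferred labels can then feed any group-aware training method (e.g., gDRO or subsampling), while points labeled -1 can be down-weighted or dropped as outliers.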

2. PRELIMINARIES AND RELATED WORK

Figure 1: An illustration of learning group annotations in the presence of outliers. (a) A toy dataset in two dimensions. There are four groups g = 1, 2, 3, 4 and an outlier. Groups g = 1 and g = 3 are majority groups, each distributed as a mixture of three components; g = 2 and g = 4 are unimodal minority groups. The y-axis is the decision boundary of a logistic regression classifier. Panels (b, c, d) compare different data views for learning group annotations and detecting outliers via clustering of samples with y = 0: (b) loss values can conflate outliers and minority samples, both of which can have high loss; (c) in the original feature space it is difficult to distinguish one of the majority group modes from the minority group; (d) the gradient space (bias gradient omitted for visualization) simplifies the data structure, making it easier to identify the minority group and to detect outliers.

In this section, we review the problem of training models in the presence of minority groups. Denote $[N] = \{1, \dots, N\}$. Consider a dataset $\mathcal{D} = \{z_i\}_{i=1}^n \subset \mathcal{Z}$ consisting of $n$ samples $z = (x, y)$, where $x \in \mathcal{X} = \mathbb{R}^d$ is the input feature vector and $y \in \mathcal{Y} = \{1, \dots, C\}$ is the class label. The samples from each class $y \in \mathcal{Y}$ are partitioned into $K_y$ groups. Denote the full set of $K$ groups by $\{G_1, \dots, G_K\} \triangleq \mathcal{P} \subset \mathcal{Z}$, where $K = \sum_{y \in \mathcal{Y}} K_y$ and $G_k \cap G_{k'} = \emptyset$ for all pairs $k \neq k' \in [K]$. Denote the group membership of each point in the dataset as $\{g_i\}_{i=1}^n$, where $g_i \in [K]$ for all $i \in [n]$. For example, in toxicity classification, a group could correspond to a toxic comment mentioning

