WHEN MAJORITIES PREVENT LEARNING: ELIMINATING BIAS TO IMPROVE WORST-GROUP AND OUT-OF-DISTRIBUTION GENERALIZATION

Abstract

Modern neural networks trained on large datasets achieve state-of-the-art (in-distribution) generalization performance on various tasks. However, their strong generalization performance has been shown to be largely attributable to overfitting spurious biases in large datasets. This is evident from the poor generalization performance of such models on minorities and out-of-distribution data. To alleviate this issue, subsampling the majority groups has been shown to be very effective. However, it is not clear how to find the subgroups (e.g., within a class) in large real-world datasets. Moreover, naively subsampling the majority groups can entirely deplete some of their smaller subpopulations and drastically harm the in-distribution performance. Here, we show that tracking gradient trajectories of examples during the initial epochs allows for finding large subpopulations of data points. We leverage this observation and propose an importance sampling method that is biased towards selecting smaller subpopulations and eliminates bias in the large subpopulations. Our experiments confirm the effectiveness of our approach in eliminating spurious biases and learning higher-quality models with superior in- and out-of-distribution performance on various datasets.

1. INTRODUCTION

Large datasets have enabled modern neural networks to achieve unprecedented success on various tasks. Large datasets are, however, often heavily biased towards the data-rich head of the distribution (Le Bras et al., 2020; Sagawa et al., 2020; 2019). That is, there are large groups of potentially redundant data points belonging to majority subpopulations, and smaller groups of examples representing minorities. Larger groups often contain spurious biases, i.e., unintended but strong correlations between examples (e.g., image background) and their label. In such settings, overparameterized models learn to memorize the spurious features instead of the core features for the majority, and overfit the minorities (Sagawa et al., 2020). As a result, despite their superior performance on in-distribution data, overparameterized models trained on biased datasets often have poor worst-group and out-of-distribution generalization performance. To improve worst-group error and out-of-distribution generalization, techniques such as distributionally robust optimization (DRO) or up-weighting the minority groups are commonly used (Sagawa et al., 2019; 2020). However, such methods have been shown to be highly ineffective for overparameterized models in the presence of spurious features (Sagawa et al., 2020). When majority groups are sufficiently large and the spurious features are strong, overparameterized models choose to exploit the spurious features for the majorities and memorize the minorities, as this entails less memorization over the entire dataset. In this setting, up-weighting minorities only exacerbates spurious correlations, and subsampling the majorities has been advocated instead (Sagawa et al., 2020). But this requires the groups to be specified beforehand, which is not available for real-world datasets.
Besides, random subsampling of the majority groups can entirely deplete some of their subpopulations and drastically harm the in-distribution performance (Toneva et al., 2018; Paul et al., 2021). In this work, we propose an effective way to find large subpopulations of examples (see Fig. 1), and subsample them to ensure inclusion of representative examples from all the subpopulations. We rely on the following recent observations. In the initial training epochs, the network learns important features and the NTK undergoes rapid changes, which determine its final basin of convergence (Fort et al., 2020). This results in learning a linear function during the initial epochs, followed by learning functions of increasing complexity (Nakkiran et al., 2019). We show that large subpopulations are responsible for forming the initial linear model, by exerting large gradient forces in the first few epochs. The minorities, on the other hand, dictate the higher-complexity functions later in training. To find the large subpopulations, we track the gradient trajectories, i.e., the way the gradients change, during the initial training epochs. Then, we cluster similar gradient trajectories together and employ an importance sampling method that samples data points from every cluster with probability inversely proportional to the size of the cluster they belong to. This allows selecting a balanced subset from different clusters. By studying the effect of our method on the evolution of the model early in training, we show that our method allows the model to better learn from all the subpopulations by balancing the gradient forces between different groups. This enables learning higher-quality features. Our empirical studies confirm the effectiveness of our method in improving worst-group and out-of-distribution generalization, while enjoying superior in-distribution performance even when the size of the selected sample is small.
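The cluster-then-sample step described above can be sketched as follows. This is a minimal illustration, not the authors' implementation: the function name `inverse_cluster_size_sample` is hypothetical, the inputs are assumed to be per-example gradient-trajectory vectors already collected over the initial epochs, and a plain Lloyd k-means (with farthest-point initialization) stands in for whatever clustering procedure the paper uses. The key property it demonstrates is that weighting each example by the inverse of its cluster size gives every cluster equal expected probability mass in the selected subset.

```python
import numpy as np

def inverse_cluster_size_sample(trajectories, n_clusters, sample_size, seed=None):
    """Cluster gradient-trajectory vectors, then draw a subset where each
    example's sampling probability is inversely proportional to the size
    of the cluster it belongs to (illustrative sketch, not the paper's code)."""
    rng = np.random.default_rng(seed)
    n = trajectories.shape[0]
    # Farthest-point initialization keeps the initial centers well separated.
    centers = trajectories[[rng.integers(n)]].astype(float)
    for _ in range(n_clusters - 1):
        d = np.linalg.norm(trajectories[:, None] - centers[None], axis=2).min(axis=1)
        centers = np.vstack([centers, trajectories[d.argmax()]])
    # Plain Lloyd iterations: assign to the nearest center, recompute means.
    for _ in range(20):
        dists = np.linalg.norm(trajectories[:, None] - centers[None], axis=2)
        labels = dists.argmin(axis=1)
        for k in range(n_clusters):
            members = trajectories[labels == k]
            if len(members):
                centers[k] = members.mean(axis=0)
    # Importance sampling: each example is weighted by 1 / |its cluster|,
    # so every cluster contributes equal probability mass to the sample.
    counts = np.bincount(labels, minlength=n_clusters)
    weights = 1.0 / counts[labels]
    probs = weights / weights.sum()
    idx = rng.choice(n, size=sample_size, replace=False, p=probs)
    return idx, labels

# Toy demo: one large subpopulation (900 points) and one small one (100 points).
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0.0, 0.5, (900, 5)),
               rng.normal(10.0, 0.5, (100, 5))])
idx, labels = inverse_cluster_size_sample(X, n_clusters=2, sample_size=100, seed=0)
# Fraction of the sample coming from the small subpopulation: far above
# its 0.1 share of the data, close to the balanced 0.5.
print((idx >= 900).mean())
```

Under uniform subsampling the small subpopulation would receive about 10% of the selected subset; with inverse-cluster-size weighting both subpopulations receive roughly equal shares, which is the balancing effect the method relies on.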
Notably, on the CMNIST (Alain et al., 2015) and Waterbirds (Sagawa et al., 2019) datasets, which contain strong spurious biases, our method achieves comparable or even better performance than state-of-the-art methods that rely on the underlying group information to uniformly subsample the majority group. In addition, on CIFAR10, CIFAR100 (Krizhevsky et al., 2009), and Caltech256 (Griffin et al., 2007), our method provides superior in-distribution performance to state-of-the-art data pruning methods based on forgettability (Toneva et al., 2018) and EL2N (Paul et al., 2021) scores, especially for small subsets. At the same time, it outperforms such methods on out-of-distribution data, CIFAR10-C (Hendrycks & Dietterich, 2019).

2. RELATED WORK

Data pruning for worst-group generalization. To improve the generalization performance on minorities, preventing the model from learning spurious features is very helpful (Sagawa et al., 2019; 2020). For overparameterized models, randomly subsampling the majorities has been shown to be more effective (Sagawa et al., 2020) than distributionally robust optimization (DRO) (Sagawa et al., 2019) and up-weighting the minority groups (Sagawa et al., 2020). However, this requires the group labels to be specified beforehand, which is not available for large real-world datasets. Besides, if the majority contains imbalanced subpopulations, random subsampling inherits similar biases. Finally, random subsampling of the majority groups can entirely deplete some of their smaller subpopulations and drastically harm the in-distribution performance, as we empirically show. A different line of work (Sohoni et al., 2020; Nam et al., 2020; Ahmed et al., 2020; Liu et al., 2021; Creager et al., 2021; Taghanaki et al., 2021; Zhang et al., 2022; Nam et al., 2021) studies how to improve worst-group generalization without access to group labels. These methods require first training a model to minimize the average empirical risk before training the robust model, which doubles the training time and is thus also not practical for large real-world datasets.



Figure 1: An illustration of large/small subpopulations within the majority/minority groups in the same class of the Waterbirds dataset (Sagawa et al., 2019).

