BIAS AMPLIFICATION IMPROVES WORST-GROUP ACCURACY WITHOUT GROUP INFORMATION

Abstract

Neural networks produced by standard training are known to suffer from poor accuracy on rare subgroups despite achieving high accuracy on average, due to correlations between certain spurious features and labels. Previous approaches based on worst-group loss minimization (e.g., Group-DRO) are effective in improving worst-group accuracy but require expensive group annotations for all training samples. In this paper, we focus on the more challenging and realistic setting where group annotations are available only on a small validation set or not at all. We propose BAM, a novel two-stage training algorithm: in the first stage, the model is trained with a bias-amplification scheme that introduces a learnable auxiliary variable for each training sample and adopts the squared loss; in the second stage, we upweight the samples that the bias-amplified model misclassifies, and then continue training the same model on the reweighted dataset. Empirically, BAM consistently improves worst-group accuracy over its counterparts, achieving state-of-the-art performance on spurious-correlation benchmarks in computer vision and natural language processing. Moreover, we find a simple stopping criterion that completely removes the need for group annotations, with little or no loss in worst-group accuracy.
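To make the Stage-1 objective concrete, the per-sample auxiliary variables and the squared loss might be combined as in the following minimal NumPy sketch. This is an illustration under stated assumptions, not the paper's exact formulation: the names `aux_strength` and `b`, and the specific additive form of the shift, are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sketch of the Stage-1 bias-amplification objective:
# each training sample i carries a learnable auxiliary vector b[i]
# that is added to the model's logits, and the squared loss against
# the one-hot label is minimized. Easy (bias-aligned) samples can
# absorb their loss through b[i], amplifying the model's bias.
n_samples, n_classes = 4, 3
aux_strength = 0.5                      # assumed scaling hyperparameter
b = np.zeros((n_samples, n_classes))    # per-sample auxiliary variables

def stage1_loss(logits, idx, labels):
    onehot = np.eye(n_classes)[labels]
    shifted = logits + aux_strength * b[idx]  # bias-amplified output
    return np.mean((shifted - onehot) ** 2)   # squared loss, not CE

logits = rng.normal(size=(n_samples, n_classes))
labels = np.array([0, 1, 2, 0])
loss = stage1_loss(logits, np.arange(n_samples), labels)
```

In practice both the network weights and the auxiliary variables `b` would be trained jointly by gradient descent; the sketch shows only the loss computation.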

1. INTRODUCTION

The presence of spurious correlations in the data distribution, also referred to as "shortcuts" (Geirhos et al., 2020), is known to cause machine learning models to learn unintended decision rules that rely on spurious features. For example, image classifiers can rely largely on the background instead of the intended object features to make predictions (Beery et al., 2018). A similar phenomenon is also prevalent in natural language processing (Gururangan et al., 2018) and reinforcement learning (Lehman et al., 2020). In this paper, we focus on the group robustness formulation of such problems (Sagawa et al., 2019), where we assume the existence of spurious attributes in the training data and define groups to be the combinations of class labels and spurious attributes. The objective is to achieve high worst-group accuracy on test data, which would indicate that the model is not exploiting the spurious attributes. Under this setup, one type of method uses a distributionally robust optimization framework to directly minimize the worst-group training loss (Sagawa et al., 2019). While these methods are effective in improving worst-group accuracy, they require group annotations for all training examples, which is expensive and oftentimes unrealistic. To resolve this issue, a line of recent work has focused on designing methods that do not require group annotations for the training data but need them for a small set of validation data (Liu et al., 2021; Nam et al., 2020; 2022; Zhang et al., 2022). A common feature shared by these methods is that they all consist of training two models: the first model is trained using plain empirical risk minimization (ERM) and is intended to be "biased" toward certain groups; then, certain results from the first model are utilized to train a debiased second model to achieve better worst-group performance.
For instance, a representative method is JTT (Liu et al., 2021), which, after training the first model using ERM for a few epochs, trains the second model while upweighting the training examples incorrectly classified by the first model. The core question that motivates this paper is: since the first model is intended to be biased, can we amplify its bias in order to improve the final group robustness? Intuitively, a bias-amplified first model can provide better information to guide the second model toward being debiased.

Evaluated on various benchmark datasets for spurious correlations, including Waterbirds (Wah et al., 2011; Sagawa et al., 2019), CelebA (Liu et al., 2015; Sagawa et al., 2019), MultiNLI (Williams et al., 2018; Sagawa et al., 2019), and CivilComments-WILDS (Borkan et al., 2019; Koh et al., 2021), we find that BAM achieves state-of-the-art worst-group accuracy compared to existing methods that use group annotations only on a validation set for hyperparameter tuning. We also conduct a detailed ablation study and observe that every component of BAM (auxiliary variables, squared loss, continued training) is crucial to its improved performance. Furthermore, we explore the possibility of completely removing the need for group annotations. We find that a low class-accuracy difference (which requires no group annotations to evaluate) is strongly correlated with high worst-group accuracy. Using the minimum class-accuracy difference as the stopping criterion, BAM outperforms the previous state-of-the-art annotation-free method, GEORGE (Sohoni et al., 2020), by a considerable margin and closes the performance gap between GEORGE and fully supervised Group-DRO by an average of 88% on the image classification datasets.
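The two ingredients described above, upweighting the Stage-1 model's mistakes and stopping at the minimum class-accuracy difference, can be sketched as follows. This is a minimal illustration: the `upweight_factor` value and the use of plain per-class accuracies are assumptions, and the paper may define these quantities differently.

```python
import numpy as np

def reweight(preds, labels, upweight_factor=20.0):
    """Stage-2 sketch: upweight samples the bias-amplified Stage-1
    model misclassifies; all other samples keep weight 1."""
    weights = np.ones(len(labels))
    weights[preds != labels] = upweight_factor
    return weights

def class_accuracy_difference(preds, labels):
    """Annotation-free stopping signal: the gap between the best and
    worst per-class accuracy. It needs class labels but no group
    annotations; stopping at the epoch minimizing this gap is the
    criterion the paper correlates with high worst-group accuracy."""
    classes = np.unique(labels)
    accs = [np.mean(preds[labels == c] == c) for c in classes]
    return max(accs) - min(accs)
```

For example, with `preds = [0, 1, 1, 0, 1]` and `labels = [0, 1, 0, 0, 1]`, only the third sample is upweighted, and the class-accuracy difference is 1 - 2/3 = 1/3.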

2. RELATED WORK

A variety of recent works discuss different realms of robustness, for instance, class imbalance (He & Garcia, 2009; Huang et al., 2016; Khan et al., 2017; Johnson & Khoshgoftaar, 2019; Thabtah et al., 2020) and robustness under distribution shift, where the target data distribution differs from the source data distribution (Clark et al., 2019; Zhang et al., 2020; Marklund et al., 2020; Lee et al., 2022; Yao et al., 2022). In this paper, we mainly focus on improving group robustness. Categorized by the amount of information available for training and validation, we discuss three directions below.

Improving Group Robustness with Training Group Annotations. Multiple works have used training group annotations to improve worst-group accuracy (Byrd & Lipton, 2019; Khani et al., 2019; Goel et al., 2020; Cao et al., 2020; Sagawa et al., 2020). Other works include minimizing the worst-group training loss using distributionally robust optimization (Group-DRO) (Sagawa et al., 2019), simple training data balancing (SUBG) (Idrissi et al., 2022), and retraining the last layer of the



In Figure 1, we use Grad-CAM visualization to illustrate that our bias-amplified model from Stage 1 focuses more on the image background, while the final model after Stage 2 focuses on the target object.



Figure 1: Using Grad-CAM (Selvaraju et al., 2017) to visualize the effect of the bias amplification and rebalanced training stages: the classifier relies heavily on background information to make predictions after bias amplification, but focuses on the useful feature (the bird itself) after the rebalanced training stage.

