IMPROVED GROUP ROBUSTNESS VIA CLASSIFIER RE-TRAINING ON INDEPENDENT SPLITS

Abstract

Deep neural networks trained by minimizing the average risk can achieve strong average performance, but their performance on a subgroup may degrade if that subgroup is underrepresented in the overall data population. Group distributionally robust optimization (GDRO; Sagawa et al., 2020a) is a standard baseline for learning models with strong worst-group performance. However, GDRO requires group labels for every training example and is prone to overfitting, often requiring careful control of model capacity via regularization or early stopping. When only a limited number of group labels is available, Just Train Twice (JTT; Liu et al., 2021) is a popular approach that infers a pseudo group label for every unlabeled example; however, the pseudo-labeling process can be highly sensitive to model selection. To alleviate overfitting in GDRO and the fragility of pseudo-labeling in JTT, we propose a new method based on classifier retraining on independent splits of the training data. We find that our sample-splitting procedure achieves robust worst-group performance in the fine-tuning step. On benchmark image and text classification tasks, our approach consistently reduces the number of group labels required and the amount of hyperparameter search needed during training. Experimental results confirm that our approach performs favorably compared with existing methods (including GDRO and JTT), whether group labels are available during training or only during validation.

1. INTRODUCTION

For many tasks, deep neural networks (DNNs) are developed under the assumption that test data are drawn independently from the same distribution as the training data, a setting referred to as IID generalization. However, the performance of a DNN is known to degrade when the test distribution differs from the training distribution, a problem referred to as out-of-distribution (OOD) generalization. OOD generalization is crucial in safety-critical applications such as self-driving cars (Filos et al., 2020) and medical imaging (Oakden-Rayner et al., 2020); addressing it is therefore foundational for real-world deployment of deep learning models.

A notable instance of the OOD generalization problem is the group-shift setting, in which predefined attributes divide the input space into groups whose distributions may differ. The goal is to find a model that performs well across all predefined groups (Sagawa et al., 2020a). As in other OOD generalization problems, DNNs trained by empirical risk minimization (ERM) are observed to suffer from poor worst-group performance despite good average-group performance. This difficulty can be attributed to shortcut learning (Geirhos et al., 2020) or spurious correlations (Sagawa et al., 2020a; Arjovsky et al., 2019): ERM favors models that discriminate based on simpler and/or spurious features of the data. Instead, one wishes for the learning algorithm to produce a model whose features (i.e., correlations) perform well not only on the training distribution but also on all distributions the task may generate, such as the worst-group distribution. In recent years, the group-shift setting has received considerable attention, where Sagawa et al.
(2020a) first investigate distributionally robust optimization (DRO) (Duchi et al., 2021; Ben-Tal et al., 2013) in this setting and introduce Group DRO (GDRO), which directly optimizes the worst-group error. Since then, GDRO has been the standard method for producing group-robust models. As group annotations can be expensive to obtain, many works consider ways to reduce the number of group labels needed (Liu et al., 2021; Creager et al., 2021; Nam et al., 2022). These methods usually follow the framework of first inferring pseudo group labels using a reference model (pseudo-labeling) and then applying a group-robust algorithm such as GDRO to the pseudo-labeled data. While results for these methods have been promising, they usually introduce several additional sensitive hyperparameters that can be expensive to tune. Our work aims to obtain group-robust models using as few group labels as possible while alleviating the need to carefully control model capacity.

Our contribution. In this work, we propose a simple approach called CROIS (Classifier Retraining On Independent Splits). By forgoing the error-prone and costly pseudo-labeling phase and instead concentrating on efficiently utilizing group labels by applying them only to the final classifier layer, CROIS achieves good robust performance without relying on a multitude of hyperparameters and large-scale tuning, which has been a growing concern in the community (Gulrajani and Lopez-Paz, 2021). In short, CROIS takes advantage of the features learned by ERM (Kang et al., 2019; Menon et al., 2021) while overcoming the deficiency of its memorization behavior (Sagawa et al., 2020b) by treating the training data as two independent splits: one group-unlabeled split to train the feature extractor and one group-labeled split to retrain only the classifier with a robust algorithm such as GDRO. We demonstrate through ablation studies that the use of independent splits is crucial for robust classifier retraining.
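At each step, GDRO maintains one weight per group and exponentially upweights groups that currently incur high loss; the model is then updated on the reweighted loss. The following is a minimal sketch of that weight update, assuming the per-group losses on the current batch are already computed (the function name and step size `eta` are our own illustrative choices, not an API from the cited works):

```python
import numpy as np

def gdro_step(group_losses, q, eta=0.01):
    """One GDRO-style update: exponentiated-gradient ascent on the
    group weights q, then the reweighted loss for the model update.

    group_losses: average loss of each group on the current batch.
    q: current group weights (non-negative, summing to one).
    """
    q = q * np.exp(eta * group_losses)  # upweight high-loss groups
    q = q / q.sum()                     # renormalize onto the simplex
    robust_loss = float(np.dot(q, group_losses))
    return robust_loss, q

# Starting from uniform weights, the harder group gains weight:
q0 = np.array([0.5, 0.5])
loss, q1 = gdro_step(np.array([1.0, 3.0]), q0, eta=1.0)
```

In a full training loop, `robust_loss` would be backpropagated through the model, and `q` would be carried across steps so that persistently hard groups come to dominate the objective.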
Furthermore, CROIS's restriction of GDRO to a low-capacity linear layer reduces GDRO's sensitivity to model capacity, as well as the amount of data GDRO needs to generalize well (e.g., Figure 1 and Figure 3). In various settings where group labels are only partially available during training (Section 4.1), our experimental results on standard datasets, including Waterbird, CelebA, MultiNLI, and CivilComments, show improved performance over existing methods, including JTT (Liu et al., 2021) and SSA (Nam et al., 2022), despite minimal parameter tuning and no reliance on pseudo-labeling. In another setting where more group labels are available (Section 4.2), our robust performance remains competitive with GDRO even when using only a fraction of the available group labels, demonstrating our method's label efficiency. Finally, our results provide further evidence that ERM-trained DNNs contain good features for both image classification (Menon et al., 2021) and natural language classification tasks.
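The two-stage recipe can be sketched end to end on toy data. This is a minimal illustration under our own naming and with stand-ins (a random-projection "feature extractor", a logistic head, a GDRO-style reweighting); the paper's pipeline instead trains a DNN backbone by ERM and retrains its final layer with GDRO:

```python
import numpy as np

def fit_erm_features(X, dim=2, seed=0):
    # Stand-in for stage 1 (ERM on the group-unlabeled split): a real
    # run would train a DNN backbone; here, a fixed random projection.
    rng = np.random.default_rng(seed)
    W = rng.standard_normal((X.shape[1], dim))
    return lambda X_: np.tanh(X_ @ W)

def retrain_head_gdro(feats, y, groups, steps=200, lr=0.5, eta=0.1):
    # Stage 2: retrain only a linear head on the group-labeled split,
    # with GDRO-style exponential upweighting of high-loss groups.
    n_groups = int(groups.max()) + 1
    q = np.full(n_groups, 1.0 / n_groups)    # group weights
    w, b = np.zeros(feats.shape[1]), 0.0
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(feats @ w + b)))
        losses = np.array([  # per-group logistic loss
            -np.mean(y[groups == g] * np.log(p[groups == g] + 1e-9)
                     + (1 - y[groups == g]) * np.log(1 - p[groups == g] + 1e-9))
            for g in range(n_groups)])
        q *= np.exp(eta * losses); q /= q.sum()
        sample_w = q[groups] / np.bincount(groups)[groups]
        grad = (p - y) * sample_w            # d(weighted loss)/d logits
        w -= lr * (feats.T @ grad); b -= lr * grad.sum()
    return w, b

# Independent split: ERM features from one part, robust head on the other.
rng = np.random.default_rng(1)
X = rng.standard_normal((400, 5)); y = (X[:, 0] > 0).astype(float)
groups = rng.integers(0, 2, size=400)
idx = rng.permutation(400)
erm_idx, lab_idx = idx[:280], idx[280:]      # 70% / 30% split
phi = fit_erm_features(X[erm_idx])           # stage 1: features only
w, b = retrain_head_gdro(phi(X[lab_idx]), y[lab_idx], groups[lab_idx])
```

Note that only the small, group-labeled split ever sees group labels, and only the linear head is retrained on it; the feature extractor is fit on the disjoint, unlabeled split and then frozen.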

1.1. RELATED WORKS

There are three main settings for the group-shift problem: (1) full availability of group labels, (2) limited availability of group labels, and (3) no availability of group labels. Other related areas include domain generalization and long-tailed classification.



Figure 1: Worst-group learning curves on the Waterbird dataset for GDRO (left) and CROIS (right), from the setting in Section 4.2. In the left panel, the validation accuracy of GDRO becomes unstable as the number of epochs increases beyond 100, while the training accuracy remains close to 100%. CROIS instead uses 30% of the training data with group labels to fine-tune the classification layer of a DNN obtained via ERM on the rest of the training data without group labels (see also Algorithm 1). This allows CROIS to reuse ERM features while improving generalization compared with GDRO.
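The worst-group metric tracked in such learning curves is simply the minimum per-group accuracy. A small sketch (the function name is ours):

```python
import numpy as np

def worst_group_accuracy(preds, labels, groups):
    """Minimum accuracy over the groups present in `groups`."""
    return min((preds[groups == g] == labels[groups == g]).mean()
               for g in np.unique(groups))

# High average accuracy can hide total failure on a rare group:
preds  = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 0])
labels = np.ones(10, dtype=int)
groups = np.array([0] * 9 + [1])   # group 1 has a single example
```

Here the average accuracy is 0.9 while the worst-group accuracy is 0.0, which is exactly the gap between ERM's average-risk objective and the group-robust objective.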

