AGRO: ADVERSARIAL DISCOVERY OF ERROR-PRONE GROUPS FOR ROBUST OPTIMIZATION

Abstract

Models trained via empirical risk minimization (ERM) are known to rely on spurious correlations between labels and task-independent input features, resulting in poor generalization under distributional shift. Group distributionally robust optimization (G-DRO) can alleviate this problem by minimizing the worst-case loss over a set of pre-defined groups of training data. G-DRO successfully improves performance on the worst group, where the correlation does not hold. However, G-DRO assumes that the spurious correlations and the associated worst groups are known in advance, making it challenging to apply to new tasks with potentially multiple unknown spurious correlations. We propose AGRO (Adversarial Group discovery for Distributionally Robust Optimization), an end-to-end approach that jointly identifies error-prone groups and improves accuracy on them. AGRO equips G-DRO with an adversarial slicing model that finds a group assignment for training examples which maximizes the worst-case loss over the discovered groups. On the WILDS benchmark, AGRO yields 8% higher model performance on average on known worst groups than prior group discovery approaches used with G-DRO. AGRO also improves out-of-distribution performance on SST2, QQP, and MS-COCO, datasets whose potential spurious correlations are as yet uncharacterized. Human evaluation of AGRO groups shows that they contain well-defined, yet previously unstudied, spurious correlations that lead to model errors.

1. INTRODUCTION

Neural models trained using the empirical risk minimization (ERM) principle are highly accurate on average, yet they consistently fail on rare or atypical examples that are unlike the training data. Such models may end up relying on spurious correlations (between labels and task-independent features) that reduce empirical loss on the training data but do not hold outside the training distribution (Koh et al., 2021; Hashimoto et al., 2018). Figure 1 shows examples of such correlations in the MultiNLI and CelebA datasets. Building models that degrade gracefully under distributional shifts is important for robust optimization, domain generalization, and fairness (Lahoti et al., 2020; Madry et al., 2017). When the correlations are known and the training data can be partitioned into dominant and rare groups, group distributionally robust optimization (G-DRO, Sagawa et al., 2019) can efficiently minimize the worst (highest) expected loss over groups and improve performance on the rare group. A key limitation of G-DRO is the need for a pre-defined partitioning of the training data based on a known spurious correlation; such correlations may be unknown, protected, or expensive to obtain. In this paper, we present AGRO (Adversarial Group discovery for Distributionally Robust Optimization), an end-to-end unsupervised optimization technique that jointly learns to find error-prone training groups and minimize the expected loss on them. Prior work on group discovery limits the space of discoverable groups for tractability. For example, Wu et al. (2022) use prior knowledge about the task to find simple correlations, e.g., the presence of negation in the text is correlated with the contradiction label (Figure 1). However, such task-specific approaches do not generalize to tasks with different and/or unknown (types of) spurious correlations.
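The G-DRO objective described above minimizes the largest per-group average loss. A minimal sketch of that worst-group computation (in NumPy, with hypothetical variable and function names, not the paper's implementation):

```python
import numpy as np

def worst_group_loss(losses, group_ids, num_groups):
    """G-DRO objective: the largest mean loss over any pre-defined group.

    losses    -- per-example training losses, shape (n,)
    group_ids -- known group index for each example, shape (n,)
    """
    group_means = np.array([
        losses[group_ids == g].mean() if np.any(group_ids == g) else 0.0
        for g in range(num_groups)
    ])
    # The robust model would minimize this maximum over groups.
    return group_means.max(), int(group_means.argmax())

# Toy example: group 1 (where the spurious correlation fails) has higher loss.
losses = np.array([0.1, 0.2, 0.9, 1.1])
groups = np.array([0, 0, 1, 1])
loss, g = worst_group_loss(losses, groups, num_groups=2)
# loss == 1.0 (mean of group 1), g == 1
```

Minimizing this quantity, rather than the overall mean loss, is what lets G-DRO protect the rare group at some cost to average accuracy.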
Approaches using generalizable features are semi-supervised (Sohoni et al., 2020; Liu et al., 2021) in that they assume access to group information on a held-out dataset. However, obtaining supervision for group assignments is costly and can lead to cascading pipeline errors. In contrast, AGRO is completely unsupervised and end-to-end, making no assumptions about the nature of the task or the availability of additional supervision. To address these challenges, AGRO constructs a new parameterized grouper model that produces a soft distribution over groups for every example in the training data and is jointly trained with the task model. We introduce two key contributions to train this model. First, the grouper model does not make task-specific assumptions about its inputs. Instead, it relies on computationally extracted features from the ERM model, including: (a) predictions and mistakes of the ERM model on training and validation instances, and (b) pretrained dataset-agnostic representations and representations fine-tuned on the task data. Second, AGRO jointly optimizes the task model and the grouper model. We formulate a zero-sum game between the grouper model, which assigns instances to groups, and the robust model, which seeks to minimize the worst expected loss over the set of inferred groups. Specifically, while G-DRO optimizes the robust model to minimize the worst group loss [1], the grouper model adversarially seeks a probabilistic group assignment such that the worst group loss is maximized. On four datasets in the WILDS benchmark (Koh et al., 2021) (MultiNLI, CivilComments, CelebA, and Waterbirds), AGRO simultaneously improves performance on multiple worst groups [2] corresponding to previously characterized spurious correlations, compared to ERM and G-DRO with known group assignments.
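The zero-sum game can be illustrated numerically: the grouper's soft assignments q(g|x) induce per-group expected losses, and a grouper update shifts assignment mass so that the worst of these losses grows. The sketch below uses made-up function names and a simplified logit-space ascent step; it illustrates the adversarial direction of the game, not the paper's actual training procedure:

```python
import numpy as np

def soft_group_losses(losses, group_probs):
    """Per-group expected loss under a soft group assignment.

    losses      -- per-example task losses, shape (n,)
    group_probs -- grouper output q(g|x_i), shape (n, k), rows sum to 1
    """
    mass = group_probs.sum(axis=0)  # soft group sizes
    return (group_probs * losses[:, None]).sum(axis=0) / np.maximum(mass, 1e-12)

def grouper_ascent_step(losses, group_probs, lr=0.5):
    """One illustrative adversarial step: move high-loss examples toward
    the currently worst group, increasing the worst-group loss that the
    robust model must then minimize."""
    g_losses = soft_group_losses(losses, group_probs)
    worst = int(g_losses.argmax())
    logits = np.log(np.maximum(group_probs, 1e-12))
    logits[:, worst] += lr * losses  # high-loss examples gain mass here
    probs = np.exp(logits - logits.max(axis=1, keepdims=True))
    return probs / probs.sum(axis=1, keepdims=True)
```

Starting from a uniform assignment, one ascent step concentrates high-loss examples into a single group, so the maximum of `soft_group_losses` does not decrease; alternating this with a robust-model minimization step gives the min-max structure of the game.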
AGRO also improves worst-group performance over prior approaches that find spurious correlations and groups by 8% on average, establishing a new SOTA for such methods on two of the WILDS datasets. On natural language inference, sentiment analysis, paraphrase detection, and common-object classification (COCO), AGRO improves robustness to uncharacterized distributional shifts compared to prior approaches, as demonstrated by gains on out-of-distribution datasets for these tasks. Ablations on different parts of the framework underscore the need for a generalizable feature space and end-to-end optimization. We develop a novel annotation task in which humans analyze the discovered AGRO groups, distinguishing group members from random examples and perturbing them to potentially change model predictions. We find that humans can identify existing and previously unknown features in AGRO groups that lead to systematic model errors and are potentially spurious, such as the correlation between antonyms and contradiction in MultiNLI, or the correlation between hats, sunglasses, and short hair with non-blondes in CelebA. Our code and models are public [3].

2. RELATED WORK

Group distributionally robust optimization G-DRO is a popular variant of distributionally robust optimization (DRO) (Ben-Tal et al., 2013; Duchi et al., 2016) that optimizes for robustness over various types of sub-population (group) shifts: label shifts (Hu et al., 2018), shifts in data sources or domains (Oren et al., 2019), or test-time distributional shifts due to spurious correlations in training

[1] Largest group-wise loss in a minibatch.
[2] A group of examples where the spurious correlation between feature and label does not hold.
[3] https://github.com/bhargaviparanjape/robust-transformers



Figure 1: Groups discovered by AGRO on CelebA image classification dataset and MultiNLI sentence pair classification dataset.

