AVOIDING SPURIOUS CORRELATIONS VIA LOGIT CORRECTION

Abstract

Empirical studies suggest that machine learning models trained with empirical risk minimization (ERM) often rely on attributes that may be spuriously correlated with the class labels. Such models typically perform poorly at inference time on data lacking such correlations. In this work, we explicitly consider a situation where potential spurious correlations are present in the majority of training data. In contrast with existing approaches, which use the ERM model outputs to detect the samples without spurious correlations and either heuristically upweight or upsample those samples, we propose the logit correction (LC) loss, a simple yet effective improvement on the softmax cross-entropy loss, to correct each sample's logit. We demonstrate that minimizing the LC loss is equivalent to maximizing the group-balanced accuracy, so the proposed LC can mitigate the negative impacts of spurious correlations. Our extensive experimental results further reveal that the proposed LC loss outperforms state-of-the-art solutions on multiple popular benchmarks by a large margin (an average 5.5% absolute improvement), without access to spurious attribute labels. LC is also competitive with oracle methods that make use of the attribute labels.

1. INTRODUCTION

In practical applications such as self-driving cars, a robust machine learning model must be designed to comprehend its surroundings even in rare conditions that may not have been well-represented in its training set. However, deep neural networks can be negatively affected by spurious correlations between observed features and class labels that hold for well-represented groups but not for rare groups. For example, when classifying stop signs versus other traffic signs in autonomous driving, 99% of the stop signs in the United States are red. A model trained with standard empirical risk minimization (ERM) may achieve low average training error by relying on the spurious color attribute instead of the desired "STOP" text on the sign, resulting in high average accuracy but low worst-group accuracy (e.g., making errors on yellow or faded stop signs). This illustrates a fundamental issue: models trained on such datasets can be systematically biased by spurious correlations present in their training data (Ben-Tal et al., 2013; Rosenfeld et al., 2018; Beery et al., 2018; Zhang et al., 2019). Such biases must be mitigated in many fields, including algorithmic fairness (Du et al., 2021), machine learning in healthcare (Oakden-Rayner et al., 2020; Liu et al., 2020b; 2022a), and public policy (Rodolfa et al., 2021).

Formally, spurious correlations occur when the target label is mistakenly associated with one or more confounding factors present in the training data. The group of samples in which the spurious correlations occur is often called the majority group, since spurious correlations are expected to occur in most samples, while the minority groups contain samples whose features are not spuriously correlated. The performance degradation of ERM on a dataset with spurious correlations (Nagarajan et al., 2021; Nguyen et al., 2021) has two main causes: 1) the geometric skew and 2) the statistical skew.
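The majority/minority group structure described above can be made concrete with a toy dataset in the spirit of Colored MNIST: a core feature determines the label, while a low-noise "shortcut" feature agrees with the label on 99% of samples. All names, feature dimensions, and noise scales below are illustrative assumptions, not the datasets used in the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

def make_spurious_dataset(n=10000, corr=0.99):
    """Toy binary dataset with a spurious correlation.

    The 'core' feature determines the label y; the 'spurious' feature
    matches y for a fraction `corr` of samples (the majority group)
    and disagrees with it otherwise (the minority group).
    """
    y = rng.integers(0, 2, size=n)                       # class labels
    core = y + 0.5 * rng.normal(size=n)                  # informative but noisy
    majority = rng.random(n) < corr                      # spurious attr aligned?
    spurious_attr = np.where(majority, y, 1 - y)         # e.g. red vs. yellow sign
    spurious = spurious_attr + 0.1 * rng.normal(size=n)  # low-noise shortcut
    X = np.stack([core, spurious], axis=1)
    return X, y, majority

X, y, majority = make_spurious_dataset()
print(majority.mean())  # close to 0.99: the minority group is rare
```

Because the spurious feature is less noisy than the core feature, a classifier minimizing average training error is drawn to the shortcut, which fails exactly on the rare minority group.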
For a robust classifier, the classification margin on the minority group should be much larger than that on the majority group (Nagarajan et al., 2021). However, a classifier trained with ERM maximizes margins uniformly and therefore yields equal training margins for the majority and minority groups; this results in the geometric skew. The statistical skew is caused by the slow convergence of gradient descent, which may cause the network to first learn the "easy-to-learn" spurious attributes instead of the true label information and rely on them unless trained for long enough (Nagarajan et al., 2021; Liu et al., 2020a; 2022b). To determine whether samples are from the majority or minority groups, we would need the group information during training, which is often impractical. Therefore, many existing approaches assume the absence of group information: they first detect the minority group (Nguyen et al., 2021; Liu et al., 2021b; Nam et al., 2020) and then upweight/upsample the samples in the minority group during training (Li & Vasconcelos, 2019; Nam et al., 2020; Lee et al., 2021; Liu et al., 2021a). While intuitive, upweighting only addresses the statistical skew (Nguyen et al., 2021), and it is often hard to choose an optimal upweighting scale for the weighted loss in practice. Following Menon et al. (2013); Collell et al. (2016) on learning from imbalanced data, we argue that the goal of training a debiased model is to achieve a high average accuracy over all groups (Group-Balanced Accuracy, GBA, defined in Sec. 3), implying that the training loss should be Fisher consistent with GBA (Menon et al., 2013; Collell et al., 2016): the minimizer of the loss function should be a maximizer of GBA. In this paper, we revisit the logit adjustment method (Menon et al., 2021) for long-tailed datasets and propose a new loss, called logit correction (LC), to reduce the impact of spurious correlations.
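The logit adjustment idea being revisited can be sketched numerically: cross-entropy is applied to logits shifted by log-priors, so an over-represented class (or group) must be preferred by more than its prior gap before the loss rewards it. This is only the class-prior form from Menon et al. (2021); the per-sample group correction of the LC loss is derived in the paper. Function names and values here are illustrative assumptions.

```python
import numpy as np

def logit_adjusted_ce(logits, labels, log_priors, tau=1.0):
    """Cross-entropy on prior-adjusted logits (class-prior sketch only).

    Adding tau * log-priors to the logits before the softmax makes the
    loss harder to minimize for under-represented labels, pushing the
    model to enlarge their raw logits instead of following the skewed
    training distribution.
    """
    z = logits + tau * log_priors                       # shift by log prior
    z = z - z.max(axis=1, keepdims=True)                # numerical stability
    log_softmax = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -log_softmax[np.arange(len(labels)), labels].mean()

logits = np.array([[2.0, 1.0]])
log_priors = np.log(np.array([0.99, 0.01]))  # class 1 is rare

# For the rare label, the adjusted loss exceeds the plain cross-entropy,
# so gradient descent works harder on exactly those samples.
rare_adjusted = logit_adjusted_ce(logits, np.array([1]), log_priors)
rare_plain = logit_adjusted_ce(logits, np.array([1]), np.zeros(2))
print(rare_adjusted > rare_plain)
```

At inference time the adjustment is dropped, which is equivalent to dividing the predicted probabilities by the training priors.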
We show that the proposed LC loss is able to mitigate both the statistical and the geometric skews that cause performance degradation. More importantly, under mild conditions, its solution is Fisher consistent for maximizing GBA. To calculate the corrected logit, we study the spurious correlation and propose to use the outputs of the ERM model to estimate the group priors. To further reduce the geometric skew, we build on MixUp (Zhang et al., 2018) and propose a simple yet effective method called Group MixUp that synthesizes samples from existing ones and thus increases the number of samples in the minority groups. The main contributions of our work include:

• We propose the logit correction loss to mitigate spurious correlations during training. The loss ensures Fisher consistency with GBA and alleviates the statistical and geometric skews.

• We propose the Group MixUp method to increase the diversity of the minority groups and further reduce the geometric skew.

• The proposed method significantly improves GBA and the worst-group accuracy when the group information is unknown. With only 0.5% of the samples from the minority group, the proposed method improves the accuracy by 6.03% and 4.61% over the state-of-the-art on the Colored MNIST and Corrupted CIFAR-10 datasets, respectively.
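The MixUp-style interpolation underlying Group MixUp can be sketched as follows: a convex combination of a majority-group sample and a same-class minority-group sample yields a new sample that enlarges the effective minority set. The pairing strategy and mixing schedule here are simplified assumptions; the exact Group MixUp procedure is specified later in the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

def group_mixup(x_majority, x_minority, alpha=1.0):
    """MixUp-style interpolation (Zhang et al., 2018) between a
    majority-group and a minority-group sample sharing a class label.

    lam ~ Beta(alpha, alpha); the mixed sample lies on the segment
    between the two inputs, so it keeps the class-relevant content
    while diluting the spurious attribute of the majority sample.
    """
    lam = rng.beta(alpha, alpha)
    return lam * x_majority + (1 - lam) * x_minority

x_maj = np.ones(4)    # stand-in features for a majority-group sample
x_min = np.zeros(4)   # same class, minority group
x_mix = group_mixup(x_maj, x_min)
print(x_mix)          # every entry lies between 0 and 1
```

Since both inputs carry the same class label, the mixed sample can be trained with that label directly, with no label interpolation needed.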

2. RELATED WORK

Spurious correlations are ubiquitous in real-world datasets. A typical mitigation strategy first detects the minority groups and then designs a learning algorithm to improve the group-balanced accuracy and/or the worst-group accuracy. We review existing approaches along these two steps.

Detecting Spurious Correlations. Early work often relies on predefined spurious correlations (Kim et al., 2019; Sagawa et al., 2019; Li & Vasconcelos, 2019). While effective, annotating the spurious attribute for each training sample is very expensive and sometimes impractical. Solutions that do not require spurious attribute annotations have recently attracted a lot of attention. Many existing works (Sohoni et al., 2020; Nam et al., 2020; Liu et al., 2021a; Zhang et al., 2022) assume that the ERM model tends to focus on the spurious attribute (though it may still learn the core features; Kirichenko et al., 2022; Wei et al., 2023), so "hard" examples (whose predicted labels conflict with the ground-truth label) are likely to be in the minority group. Sohoni et al. (2020); Seo et al. (2022), on the other hand, propose to estimate the unknown group information by clustering. Our work follows the path of using the ERM model.

Mitigating Spurious Correlations. Nagarajan et al. (2021); Nguyen et al. (2021) show that the geometric skew and the statistical skew are the two main reasons hurting the performance of the conventional ERM model. Reweighting (resampling), which assigns higher weights (sampling rates) to minority samples, is commonly used to remove the statistical skew (Li & Vasconcelos, 2019; …).
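The ERM-based detection heuristic discussed in this section can be sketched as follows: samples for which the ERM model assigns low probability to the true class ("hard" examples) are flagged as likely minority-group samples. The threshold and probability values are illustrative assumptions, not the criterion of any specific method.

```python
import numpy as np

def flag_minority(erm_probs, labels, threshold=0.5):
    """Flag likely minority-group samples from ERM outputs.

    erm_probs: (n, num_classes) predicted probabilities of an ERM model.
    labels:    (n,) ground-truth class indices.
    A sample is flagged when the ERM probability of its true class is
    below `threshold`, i.e. the shortcut feature misleads the model.
    """
    p_true = erm_probs[np.arange(len(labels)), labels]
    return p_true < threshold

probs = np.array([[0.9, 0.1],   # confident and correct  -> majority
                  [0.2, 0.8]])  # wrong on true class 0  -> minority
labels = np.array([0, 0])
print(flag_minority(probs, labels))  # [False  True]
```

Downstream methods differ mainly in what they do with this flag: upweighting, upsampling, or, as in this paper, correcting the logits directly.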

