BREAKING CORRELATION SHIFT VIA CONDITIONAL INVARIANT REGULARIZER

Abstract

Recently, generalization on out-of-distribution (OOD) data with correlation shift has attracted great attention. Correlation shift is caused by spurious attributes that correlate with the class label, as the correlation between them may vary between training and test data. For this problem, we show that, given the class label, models that are conditionally independent of the spurious attributes are OOD generalizable. Based on this, we propose a metric, Conditional Spurious Variation (CSV), which measures such conditional independence and controls the OOD generalization error. To improve OOD generalization, we regularize the training process with the proposed CSV. Under mild assumptions, our training objective can be formulated as a nonconvex-concave minimax problem. An algorithm with a provable convergence rate is proposed to solve this problem. Extensive empirical results verify our algorithm's efficacy in improving OOD generalization.

1. INTRODUCTION

The success of standard learning algorithms relies heavily on the assumption that training and test data are identically distributed. However, in the real world, this assumption is often violated due to varying circumstances, selection bias, and other reasons (Meinshausen & Bühlmann, 2015). Thus, learning a model that generalizes on out-of-distribution (OOD) data has attracted great attention. OOD data (Ye et al., 2022) can be categorized into data with diversity shift or correlation shift. Roughly speaking, the two shifts correspond respectively to a mismatch of the spectrum and a spurious correlation between training and test distributions. Compared with diversity shift, correlation shift is less explored (Ye et al., 2022), although a misleading spurious correlation that holds on training data may deteriorate the model's performance on test data (Beery et al., 2018). Under correlation shift, for the spurious attributes in the data, the (spurious) correlation between the class label and these attributes varies from training to test data (Figure 1). Based on a theoretical characterization of this shift, we show that a model which, given the class label, is conditionally independent of the spurious attributes has stable performance across training and OOD test data. Then, a metric, Conditional Spurious Variation (CSV, Definition 2), is proposed to measure such conditional independence. Notably, in contrast to existing metrics related to OOD generalization (Hu et al., 2020; Mahajan et al., 2021), our CSV controls the OOD generalization error. To improve OOD generalization, we regularize the training process with an estimated CSV. When the spurious attributes are observable, we propose an estimator of CSV. However, this observability condition may be violated. In that case, we propose another estimator, which approximates a sharp upper bound of CSV. We regularize the training process with one of the two, depending on whether the spurious attributes are observable.
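The formal definition of CSV (Definition 2) is not reproduced in this excerpt; the following is only an illustrative sketch of the underlying idea. Given per-sample losses, class labels, and spurious-attribute annotations, it measures, for each class, the largest gap between the mean losses of groups sharing a spurious attribute, averaged over classes. The function name and the max-minus-min form are our own simplification, not the authors' exact estimator.

```python
import numpy as np

def conditional_spurious_variation(losses, labels, attrs):
    """Illustrative conditional variation of the loss across spurious groups.

    For each class y, compute the mean loss within each group sharing a
    spurious attribute a, and take the largest gap between group means;
    the result is averaged over classes. A model whose loss is
    conditionally independent of the attributes given the label scores 0.
    """
    losses, labels, attrs = map(np.asarray, (losses, labels, attrs))
    variations = []
    for y in np.unique(labels):
        mask = labels == y
        group_means = [losses[mask & (attrs == a)].mean()
                       for a in np.unique(attrs[mask])]
        variations.append(max(group_means) - min(group_means))
    return float(np.mean(variations))
```

For instance, if within one class the loss is 1.0 on one attribute group and 0.0 on the other, the variation is 1.0; if the loss does not depend on the attribute given the class, it is 0.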
Our method relaxes the observability condition in (Sagawa et al., 2019). Under mild smoothness assumptions, the regularized training objective can be formulated as a specific nonconvex-concave minimax problem. A stochastic-gradient-descent-based algorithm with a provable convergence rate of order O(T^{-2/5}) is proposed to solve it, where T is the number of iterations.
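The paper's own algorithm is not given in this excerpt; as a generic illustration of solving a nonconvex-concave minimax problem with stochastic-gradient-style updates, the sketch below runs two-timescale gradient descent-ascent on the toy objective f(x, y) = y·sin(x) - y²/2, which is concave in y and nonconvex in x. The objective, step sizes, and update order are our own choices for illustration, not the authors' method.

```python
import math

def gradient_descent_ascent(x0=2.0, y0=0.0, eta_x=0.05, eta_y=0.5, steps=2000):
    """Two-timescale GDA on min_x max_y f(x, y) = y*sin(x) - y**2 / 2.

    The inner problem is concave in y, so with the faster step size
    eta_y the ascent variable tracks its maximizer y* = sin(x), while
    x descends the nonconvex envelope max_y f = sin(x)**2 / 2.
    """
    x, y = x0, y0
    for _ in range(steps):
        grad_x = y * math.cos(x)   # partial derivative of f in x
        grad_y = math.sin(x) - y   # partial derivative of f in y
        x -= eta_x * grad_x        # descent step on the outer variable
        y += eta_y * grad_y        # ascent step on the inner variable
    return x, y
```

Starting from x = 2, the iterates converge to a stationary point of the envelope at x = π (where sin(x) = 0 and hence y → 0); the faster inner step size is what keeps y close to the inner maximizer throughout.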

2. RELATED WORKS AND PRELIMINARIES

2.1 RELATED WORKS

OOD Generalization. The appearance of OOD data (Hendrycks & Dietterich, 2018) has been widely observed in the machine learning community (Recht et al., 2019; Schneider et al., 2020; Salman et al., 2020; Tu et al., 2020; Lohn, 2020). To tackle this, researchers have proposed various algorithms from different perspectives, e.g., distributionally robust optimization (Sinha et al., 2018; Volpi et al., 2018; Sagawa et al., 2019; Yi et al., 2021b; Levy et al., 2020) or causal inference (Arjovsky et al., 2019; He et al., 2021; Liu et al., 2021b; Mahajan et al., 2021; Wang et al., 2022; Ye et al., 2021). Ye et al. (2022) point out that OOD data can be categorized into data with diversity shift (e.g., PACS (Li et al., 2018)) and correlation shift (e.g., Waterbirds (Sagawa et al., 2019)); we focus on the latter in this paper, as it deteriorates the performance of the model on OOD test data (Geirhos et al., 2018; Beery et al., 2018; Xie et al., 2020; Wald et al., 2021).

Domain Generalization. The goal of domain generalization is to extrapolate the model to test data from unseen domains, i.e., to capture OOD generalization. The problem we explore can be treated by domain generalization methods, since data with different spurious attributes can be regarded as coming from different domains. The core idea in domain generalization is to learn a domain-invariant model. To this end, Arjovsky et al. (2019); Hu et al. (2020); Li et al. (2018); Mahajan et al. (2021); Heinze-Deml & Meinshausen (2021); Krueger et al. (2021); Wald et al. (2021); Seo et al. (2022) propose plenty of invariant metrics as training regularizers. However, unlike our CSV, none of these metrics controls the OOD generalization error. Moreover, none of these methods captures the invariance corresponding to the correlation shift we discuss (see Section 4.1). This motivates us to reconsider the effectiveness of these methods. Finally, in contrast to ours, these methods require observable domain labels, which is often impractical.
The techniques in (Liu et al., 2021b; Devansh Arpit, 2019; Sohoni et al., 2020; Creager et al., 2021) are also applicable without domain information, but they are built on strong assumptions (Gaussian mixture data (Liu et al., 2021b) or a linear model (Devansh Arpit, 2019)) or require a high-quality spurious-attribute classifier (Sohoni et al., 2020; Creager et al., 2021).

Distributional Robustness. Distributional robustness (Ben-Tal et al., 2013) based methods minimize the worst-case loss over different groups of data (Sagawa et al., 2019; Liu et al., 2021a; Zhou et al., 2022). The groups are decided via certain rules, e.g., data with the same spurious attributes (Sagawa et al., 2019), or annotated via validation sets with observable spurious attributes (Liu et al., 2021a; Zhou et al., 2022). However, Sagawa et al. (2019) find that directly minimizing the worst-group loss results in an unstable training process. In contrast, our method has a stable training process, as it balances the objectives of accuracy and robustness to spurious attributes (see Section 5).
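The worst-group objective of Group DRO (Sagawa et al., 2019) is typically optimized online: group weights are updated by exponentiated gradient ascent so that high-loss groups receive more weight, and the model is trained on the weight-averaged loss. The sketch below shows one such weight update in isolation; real implementations interleave it with model-parameter updates on minibatches, and the step size name `eta` is our own.

```python
import numpy as np

def group_dro_step(group_losses, q, eta=0.1):
    """One exponentiated-gradient update of the group weights in Group DRO.

    Weights of high-loss groups grow multiplicatively and are then
    renormalized to the simplex; the returned weighted loss upper-bounds
    the average loss and approaches the worst-group loss as the weights
    concentrate.
    """
    group_losses = np.asarray(group_losses, dtype=float)
    q = q * np.exp(eta * group_losses)   # up-weight high-loss groups
    q = q / q.sum()                      # project back to the simplex
    return q, float(q @ group_losses)
```

Repeated steps with fixed group losses concentrate the weight on the worst group, so the weighted loss approaches the worst-group loss; the instability noted by Sagawa et al. (2019) arises when this concentration interacts with noisy minibatch estimates of the group losses.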

2.2. PROBLEM SETUP

We collect the notations used in this paper. ∥·∥ is the ℓ2-norm of vectors. O(·) is the order of a number. The sample (X, Y) ∈ X × Y, where X and Y are respectively the input data and its label. The integer set



Figure 1: Examples of CelebA (Liu et al., 2015), Waterbirds (Sagawa et al., 2019), MultiNLI (Williams et al., 2018), and CivilComments (Borkan et al., 2019) involved in this paper. The class labels and spurious attributes are colored red and blue, respectively. Their correlation may vary from the training set to the test set. More details are given in Section 6.

Finally, extensive experiments are conducted to empirically verify the effectiveness of our methods on OOD data with spurious correlations. Concretely, we conduct experiments on the benchmark classification datasets CelebA (Liu et al., 2015), Waterbirds (Sagawa et al., 2019), MultiNLI (Williams et al., 2018), and CivilComments (Borkan et al., 2019). Empirical results show that our algorithm consistently improves the model's generalization on OOD data with correlation shift.


