BREAKING CORRELATION SHIFT VIA CONDITIONAL INVARIANT REGULARIZER

Abstract

Recently, generalization on out-of-distribution (OOD) data with correlation shift has attracted great attention. Correlation shift is caused by spurious attributes that correlate with the class label, as the correlation between them may vary between training and test data. For such a problem, we show that models that are conditionally independent of the spurious attributes given the class label are OOD generalizable. Based on this, a metric, Conditional Spurious Variation (CSV), which controls the OOD generalization error, is proposed to measure such conditional independence. To improve OOD generalization, we regularize the training process with the proposed CSV. Under mild assumptions, our training objective can be formulated as a nonconvex-concave minimax problem. An algorithm with a provable convergence rate is proposed to solve the problem. Extensive empirical results verify our algorithm's efficacy in improving OOD generalization.

1. INTRODUCTION

The success of standard learning algorithms relies heavily on the assumption that training and test data are identically distributed. In the real world, however, this assumption is often violated due to varying circumstances, selection bias, and other reasons (Meinshausen & Bühlmann, 2015). Thus, learning a model that generalizes on out-of-distribution (OOD) data has attracted great attention. OOD data (Ye et al., 2022) can be categorized into data with diversity shift or correlation shift; roughly speaking, the two shifts correspond, respectively, to a mismatch of the spectrum and a spurious correlation between the training and test distributions. Compared with diversity shift, correlation shift is less explored (Ye et al., 2022), yet a misleading spurious correlation that holds on training data may deteriorate the model's performance on test data (Beery et al., 2018). Correlation shift means that, for the spurious attributes in the data, the (spurious) correlation between the class label and these attributes varies from training to test data (Figure 1).

Based on a theoretical characterization of correlation shift, we show that a model which is conditionally independent of the spurious attributes given the class label has stable performance across training and OOD test data. We then propose a metric, Conditional Spurious Variation (CSV, Definition 2), to measure such conditional independence. Notably, in contrast to existing metrics related to OOD generalization (Hu et al., 2020; Mahajan et al., 2021), CSV controls the OOD generalization error. To improve OOD generalization, we regularize the training process with an estimated CSV. When the spurious attributes are observable, we propose an estimator of CSV. When this observability condition is violated, we propose another estimator that approximates a sharp upper bound of CSV. We regularize training with one of the two estimators, depending on whether the spurious attributes are observable.
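To make the idea of measuring conditional independence concrete, the following is a minimal sketch of a CSV-style statistic: for each class label, it measures how much the average loss varies across groups defined by a spurious attribute, then averages over classes. The function name and the max-minus-min aggregation are illustrative assumptions for this sketch; the paper's Definition 2 may use a different form.

```python
import numpy as np

def conditional_spurious_variation(losses, labels, spurious):
    """Sketch of a CSV-style statistic (hypothetical form).

    For each class y, compute the mean loss within each spurious-attribute
    group, take the spread (max - min) of those group means, and average
    the per-class spreads.  A value of 0 indicates that, conditionally on
    the label, the loss does not depend on the spurious attribute.
    """
    per_class_spread = []
    for y in np.unique(labels):
        mask = labels == y
        # mean loss within each spurious group of class y
        group_means = [losses[mask & (spurious == a)].mean()
                       for a in np.unique(spurious[mask])]
        per_class_spread.append(max(group_means) - min(group_means))
    return float(np.mean(per_class_spread))
```

A model whose errors are identically distributed across spurious groups within each class yields a statistic of zero, which is the conditional-independence property the paper argues enables OOD generalization under correlation shift.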
Our method relaxes the observability condition of (Sagawa et al., 2019). Under mild smoothness assumptions, the regularized training objective can be formulated as a specific nonconvex-concave minimax problem. A stochastic-gradient-descent-based algorithm with a provable convergence rate of order O(T^{-2/5}) is proposed to solve it, where T is the number of iterations.
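The generic template behind such minimax solvers can be sketched as plain gradient descent-ascent: the minimizing player takes a descent step while the maximizing player takes an ascent step. This is only an illustration on a simple convex-concave toy objective, not the paper's algorithm, whose step sizes, updates, and analysis for the nonconvex-concave case may differ.

```python
def gradient_descent_ascent(grad_x, grad_y, x0, y0, eta_x, eta_y, T):
    """Gradient descent-ascent for min_x max_y f(x, y) (generic sketch).

    grad_x, grad_y: callables returning the partial gradients of f.
    eta_x, eta_y:   step sizes for the min and max players.
    T:              number of iterations.
    """
    x, y = x0, y0
    for _ in range(T):
        x = x - eta_x * grad_x(x, y)  # descent step for the min player
        y = y + eta_y * grad_y(x, y)  # ascent step for the max player
    return x, y

# Toy example: f(x, y) = x^2 + x*y - y^2, saddle point at (0, 0).
x_star, y_star = gradient_descent_ascent(
    lambda x, y: 2 * x + y,   # df/dx
    lambda x, y: x - 2 * y,   # df/dy
    x0=1.0, y0=1.0, eta_x=0.1, eta_y=0.1, T=2000)
```

With small step sizes, the iterates of this toy problem spiral into the saddle point (0, 0); the paper's contribution is establishing an O(T^{-2/5}) rate for the harder nonconvex-concave setting.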

