DISTRIBUTIONALLY ROBUST POST-HOC CLASSIFIERS UNDER PRIOR SHIFTS

Abstract

The generalization ability of machine learning models degrades significantly when the test distribution shifts away from the training distribution. We investigate the problem of training models that are robust to shifts caused by changes in the distribution of class-priors or group-priors. The presence of skewed training priors can often lead to models overfitting to spurious features. Unlike existing methods, which optimize for either the worst or the average performance over classes or groups, our work is motivated by the need for finer control over the robustness properties of the model. We present an extremely lightweight post-hoc approach that performs scaling adjustments to predictions from a pre-trained model, with the goal of minimizing a distributionally robust loss around a chosen target distribution. These adjustments are computed by solving a constrained optimization problem on a validation set and applied to the model at test time. Our constrained optimization objective is inspired by a natural notion of robustness to controlled distribution shifts. Our method comes with provable guarantees and empirically makes a strong case for distributionally robust post-hoc classifiers. An empirical implementation is available at https://github.com/weijiaheng/Drops.

1. INTRODUCTION

Distribution shift, where the test distribution moves away from the training distribution, deteriorates the generalization of machine learning models and is a major challenge for successfully deploying these models in the wild. We are specifically interested in distribution shifts resulting from changes in the marginal class priors or group priors from training to test. Such shifts often arise from a skewed distribution of classes or groups in the training data, where vanilla empirical risk minimization (ERM) can lead to models overfitting to spurious features: features that appear predictive on the training data but do not generalize to the test set. For example, the background can act as a spurious feature for predicting the object of interest in images, e.g., camels against a desert background or waterbirds against a water background (Sagawa et al., 2020). Distributionally robust optimization (DRO) (Ben-Tal et al., 2013; Duchi et al., 2016; Duchi & Namkoong, 2018; Sagawa et al., 2020) is a popular framework for addressing this problem, formulating a robust optimization problem over class- or group-specific losses. The common metrics of interest in DRO methods are either the average accuracy or the worst-case accuracy over classes/groups (Menon et al., 2021; Jitkrittum et al., 2022; Rosenfeld et al., 2022; Piratla et al., 2022; Sagawa et al., 2020; Zhai et al., 2021; Xu et al., 2020; Kirichenko et al., 2022). However, these metrics cover only the two ends of the full spectrum of prior shifts. We are instead motivated by the need to measure the robustness of a model at various points along this spectrum. To this end, we consider applications where we are given a target prior distribution (which could come from a practitioner or default to the uniform distribution), and would like to train a model that is robust to varying distribution shifts around this prior.
Instead of taking the conventional approach of optimizing for either the average accuracy or the worst-case accuracy, we seek to maximize the minimum accuracy within a δ-radius ball around the specified target distribution. This strategy encourages generalization across a spectrum of controlled distribution shifts governed by the parameter δ. When δ = 0, our objective is simply the average accuracy under the specified target priors, and as δ → ∞, it reduces to the model's worst-case accuracy, thus providing a natural way to interpolate between the two extremes of average-case and worst-case optimization. To train a classifier that performs well on this distributionally robust objective, we propose a fast and extremely lightweight post-hoc method that learns scaling adjustments to the predictions of a pre-trained model. These adjustments are computed by solving a constrained optimization problem on a validation set and are then applied to the model at evaluation time. A key advantage of our method is that it can reuse the same pre-trained model for different robustness requirements by simply rescaling the model's predictions. This is in contrast to several existing DRO methods that train all model parameters with a robust optimization loss (Sagawa et al., 2020; Piratla et al., 2022), which requires group annotations for the training data and careful regularization to work with overparameterized models (Sagawa et al., 2020). Our approach, on the other hand, only needs group annotations for a smaller held-out set and works by scaling the predictions of a pre-trained model at test time. Our method also comes with provable convergence guarantees. We apply our method to standard benchmarks for class imbalance and group DRO, and show that it compares favorably to existing methods when evaluated on a range of distribution shifts away from the target prior distribution.
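The interpolation between average-case and worst-case objectives can be made concrete. Below is a minimal sketch (not the paper's exact formulation) that assumes the δ-ball around the target prior is measured by KL divergence: the inner maximization over priors then admits an exponential-tilting solution, and the tilt temperature can be found by bisection. The function name and the choice of a KL ball are illustrative assumptions.

```python
import numpy as np

def worst_case_loss(losses, target, delta):
    """Maximize sum_i p_i * losses[i] over priors p with KL(p || target) <= delta.

    Illustrative sketch: the KL-constrained maximizer has the exponential
    form p(tau) proportional to target * exp(tau * losses), and the KL
    divergence KL(p(tau) || target) is nondecreasing in tau, so we can
    bisect on the temperature tau to hit the constraint boundary.
    """
    losses = np.asarray(losses, dtype=float)
    target = np.asarray(target, dtype=float)

    def tilt(tau):
        w = target * np.exp(tau * (losses - losses.max()))  # numerically stable tilt
        p = w / w.sum()
        # KL(p || target), treating 0 * log 0 as 0
        kl = np.sum(p * np.log(np.where(p > 0, p / target, 1.0)))
        return p, kl

    if delta == 0:  # delta = 0: plain average loss under the target prior
        return float(np.dot(target, losses))
    lo, hi = 0.0, 1.0
    while tilt(hi)[1] < delta and hi < 1e6:  # grow until KL exceeds delta (or caps out)
        hi *= 2.0
    for _ in range(100):  # bisection on the tilt temperature
        mid = 0.5 * (lo + hi)
        if tilt(mid)[1] < delta:
            lo = mid
        else:
            hi = mid
    p, _ = tilt(0.5 * (lo + hi))
    return float(np.dot(p, losses))
```

For example, with per-class losses [0.1, 0.5, 0.9] and a uniform target prior, δ = 0 recovers the average 0.5, while a very large δ recovers the worst per-class loss 0.9, matching the interpolation described above.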

2. BACKGROUND

We are primarily interested in two specific types of prior shift for distributional robustness of classifiers. In this section, we briefly introduce the problem settings for the two prior shifts and set up notation.

Class-Level Prior Shifts. We consider a multi-class classification problem with instance space $\mathcal{X}$ and output space $[m] = \{1, \ldots, m\}$. Let $D$ denote the underlying data distribution over $\mathcal{X} \times [m]$, so that the instance $X$ and label $Y$ are random variables with $(X, Y) \sim D$. We define the conditional-class probability $\eta_y(x) = P(Y = y \mid X = x)$ and the class priors $\pi_y = P(Y = y)$; note that $\pi_y = \mathbb{E}[\eta_y(X)]$. We use $u = [\frac{1}{m}, \ldots, \frac{1}{m}]$ to denote the uniform prior over $m$ classes. Our goal is to learn a multi-class classifier $h : \mathcal{X} \to [m]$ that maps an instance $x \in \mathcal{X}$ to one of $m$ classes. We do so by first learning a scoring function $f : \mathcal{X} \to \Delta_m$ that estimates the conditional-class probabilities for a given instance, and constructing the classifier by predicting the class with the highest score: $h(x) = \arg\max_{j \in [m]} f_j(x)$. We measure the performance of a scoring function using a loss function $\ell : [m] \times \Delta_m \to \mathbb{R}_+$, and define the per-class loss as $\ell_i(f) := \mathbb{E}[\ell(y, f(x)) \mid y = i]$. Let $\{(x_i, y_i)\}_{i=1}^n$ be a set of training samples. The empirical estimate of the training-set prior is $\hat{\pi}_i := \frac{1}{n} \sum_{j \in [n]} \mathbb{1}(y_j = i)$, where $\mathbb{1}(\cdot)$ is the indicator function. Under class prior shift, the class prior probabilities at test time shift away from $\pi$. A special case of such class-level prior shifts is class-imbalanced learning (Lin et al., 2017; Cui et al., 2019; Cao et al., 2019; Ren et al., 2020; Menon et al., 2021), where $\pi$ is a long-tailed distribution while the class prior at test time is usually taken to be uniform. Regular ERM tends to focus on the majority classes at the expense of the minority classes.
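The empirical quantities above are straightforward to compute. Below is a minimal sketch, assuming the scoring function outputs a probability vector per sample and taking cross-entropy as the loss $\ell$; the function names are illustrative, not from the paper.

```python
import numpy as np

def empirical_priors(y, m):
    """Empirical class priors: pi_hat_i = (1/n) * sum_j 1(y_j = i)."""
    return np.bincount(y, minlength=m) / len(y)

def per_class_losses(y, probs):
    """Per-class loss l_i(f): mean cross-entropy over samples with label i."""
    ce = -np.log(probs[np.arange(len(y)), y])  # per-sample cross-entropy
    return np.array([ce[y == i].mean() for i in range(probs.shape[1])])
```

For instance, with labels [0, 0, 1] and predicted probability rows [[0.9, 0.1], [0.8, 0.2], [0.5, 0.5]], the empirical priors are [2/3, 1/3], and each per-class loss averages the negative log true-class probability within that class.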
Recent work (Menon et al., 2021) uses temperature-scaled logit adjustment with the training class priors to adapt the model for average class accuracy. Our method also applies post-hoc adjustments to model probabilities, but our goal differs from that of Menon et al. (2021): we target varying distribution shifts around the uniform prior, and the scaling adjustments are learned on a held-out set by optimizing a constrained robust loss.

Group-Level Prior Shifts. The notion of groups arises when each data point $(x, y)$ is associated with an attribute $a \in \mathcal{A}$ that is spuriously correlated with the label. This forms $m = |\mathcal{A}| \times |\mathcal{Y}|$ groups as the cross-product of the attributes and classes. The data distribution $D$ is taken to be a mixture of the $m$ groups, with mixture prior probabilities $\pi$ and group-conditional distributions $D_j$, $j \in [m]$. In this setting, we have $n$ training samples $\{(x_i, y_i)\}_{i=1}^n$ drawn i.i.d. from $D$, with empirical group-prior probabilities $\hat{\pi}$. For skewed group priors $\pi$, regular ERM is vulnerable to spurious correlations between the attributes and labels, and the accuracy degrades


