LIMITS OF ALGORITHMIC STABILITY FOR DISTRIBUTIONAL GENERALIZATION

Paper under double-blind review

Abstract

As machine learning models are increasingly considered for safety-critical settings, it is important to understand when models may fail after deployment. One cause of model failure is distribution shift, where the training and test data distributions differ. In this paper we investigate whether training models with algorithmically stable methods improves model robustness, motivated by recent theoretical developments that show a connection between the two. We use techniques from differentially private stochastic gradient descent (DP-SGD) to control the level of algorithmic stability during training. We compare the performance of algorithmically stable training procedures to stochastic gradient descent (SGD) across a variety of possible distribution shifts, specifically covariate, label, and subpopulation shifts. We find that algorithmically stable training procedures yield models with a consistently lower generalization gap across various types and severities of shift, as well as higher absolute test performance under label shift. Finally, we demonstrate that there is a tradeoff between distributional robustness, stability, and performance.

1. INTRODUCTION

As machine learning (ML) is applied in high-stakes decision-making settings such as healthcare (Ghassemi et al., 2017; Rajkomar et al., 2018; Zhang et al., 2021a) and lending (Liu et al., 2018; Weber et al., 2020), it is important to consider scenarios in which models fail. Typically, models are trained with empirical risk minimization (ERM), which assumes that the training and test data are sampled i.i.d. from the same underlying distribution (Vapnik, 1999). Unfortunately, this assumption means that ERM is susceptible to performance degradation under distribution shift (Nagarajan et al., 2021). Distribution shift occurs when the data distribution encountered during deployment differs from the training distribution, or changes over time while the model is in use. In practice, even subtle shifts can significantly affect model performance (Rabanser et al., 2019). Given that distribution shift is a significant source of model failure, much work has been directed toward improving model robustness to distribution shifts (Taori et al., 2020; Cohen et al., 2019; Engstrom et al., 2019; Geirhos et al., 2018; Zhang et al., 2019; Zhang, 2019). One concept recently introduced to improve model robustness is distributional generalization (Kulynych et al., 2022; Nakkiran & Bansal, 2020; Kulynych et al., 2020). Distributional generalization (DG) extends classical generalization to encompass any evaluation function (instead of just the loss objective) and allows the train and test distributions to differ. Kulynych et al. (2022) prove that algorithms satisfying total variation stability (TV stability) bound the gap between train and test metrics when distribution shift is present; i.e., algorithms that satisfy TV stability also satisfy DG. This motivates the use of techniques from differentially private (DP) learning to satisfy DG, since DP implies TV stability (Kulynych et al., 2022).
We know from other work that DP learning often comes at a cost to accuracy (Bagdasaryan et al., 2019; Suriyakumar et al., 2021; Jayaraman & Evans, 2019). Unfortunately, these works do not thoroughly explore the empirical implications of their theorems across a wide variety of settings, beyond a positive result in Suriyakumar et al. (2021). Because robustness to new settings matters for deployed models, it is important to understand how the theory of distributional generalization works in practice when facing different types and severities of shift. Furthermore, it is hard to determine from the current theory how practitioners should tune the level of stability to obtain high-performing models. In this paper we conduct an extensive empirical study of the impact of using algorithmically stable learning strategies for robustness under distribution shift. Stable learning (SL) refers to approaches that constrain the model optimization objective or learning algorithm to improve model stability. We focus on two questions regarding the use of SL for DG in practice: (i) Under what types of shift is SL more robust and accurate than ERM? (ii) Are SL-trained models consistently robust across all hyperparameters, model architectures, and shift severities? We target four common examples of shift: covariate (Shimodaira, 2000), label (Lipton et al., 2018; Storkey, 2009), subpopulation (Duchi & Namkoong, 2021; Koh et al., 2021), and natural shifts (Taori et al., 2020). We use state-of-the-art models and large benchmark datasets focusing on realistic prediction tasks in object recognition, satellite imaging, biomedical imaging, and clinical notes (see Table 1, with details in Section 4.2). The primary comparison we make is through the generalization gap, defined as the difference in model performance between training and testing (Zhang et al., 2021b).
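To make the comparison concrete, here is a minimal sketch (ours, not code from the paper) of the generalization gap for an accuracy metric; the function names are illustrative:

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def generalization_gap(train_true, train_pred, test_true, test_pred):
    """Generalization gap: train performance minus test performance
    (in the sense of Zhang et al., 2021b). A large positive gap
    indicates degradation on the (possibly shifted) test data."""
    return accuracy(train_true, train_pred) - accuracy(test_true, test_pred)
```

Under distribution shift, the same quantity is computed with a test set drawn from the shifted distribution D′ rather than the training distribution D.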
Under extensive experimentation, incorporating 32 distinct types of distribution shift and 5 severity levels, we find: 1. SL improves both accuracy and robustness for label and natural shifts. 2. SL has a robustness-accuracy tradeoff for covariate and subpopulation shift. 3. The tradeoffs of SL are consistent across different shift severities, model architectures, and hyperparameter settings.

2. RELATED WORK

Many approaches have been developed in pursuit of robustness to distribution shift, including domain adaptation (Wang & Deng, 2018), out-of-distribution detection (Yang et al., 2021), adversarial training (Madry et al., 2018; Ilyas et al., 2019), and algorithmic improvements (Sagawa et al., 2019). To address distribution shift, many recent techniques for distributionally robust optimization (DRO), such as risk-averse learning (Curi et al., 2020), have been developed. However, many of these methods do not perform better than ERM (Pfohl et al., 2022) and involve complex implementations, making them difficult to use. Algorithmic stability has also been explored as a route to distributional robustness. It is often easier to implement, with simpler methods such as ℓ2 regularization (Wibisono et al., 2009), early stopping (Hardt et al., 2016), and differentially private stochastic gradient descent (DP-SGD) (Abadi et al., 2016). Early stopping and ℓ2 regularization have already been studied for their potential to improve distributional robustness (Sagawa et al., 2019). However, it is difficult to conduct fine-grained analyses of improved robustness with these methods because their stability is not directly controllable. This motivates our use of DP-SGD to investigate the limits of stability for DG, since we can control the level of stability by adjusting the noise multiplier σ in DP-SGD. While algorithmic stability has been explored theoretically in previous works (see Section 3), in this paper we explore it empirically across various synthetic and natural distribution shifts.

3. BACKGROUND AND NOTATION

We provide an overview of the connections between algorithmic stability, DP, and different forms of generalization. It is well established that algorithmic stability implies generalization in the traditional ERM setup (Bousquet & Elisseeff, 2002). Further work has proven that DP implies stability and thus generalization (Bassily et al., 2016; Dwork et al., 2015). Below we define these concepts and draw connections between them, clarifying the theoretical implication that DP leads to improved distributional robustness.
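For intuition about the TV stability notion referenced in Sections 1 and 3, the following sketch (ours, not code from the paper) computes the total variation distance between two discrete distributions; a TV-stable learner keeps this distance small between its output distributions on neighboring datasets:

```python
def tv_distance(p, q):
    """Total variation distance between two discrete distributions:
    TV(p, q) = 0.5 * sum_i |p_i - q_i|.
    Takes values in [0, 1]; 0 means identical distributions."""
    assert len(p) == len(q), "distributions must share a support"
    return 0.5 * sum(abs(pi - qi) for pi, qi in zip(p, q))
```

In practice one would estimate this distance empirically, e.g., from histograms of model outputs across repeated training runs on neighboring datasets.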

We assume there is a training dataset D_train = {(x_i, y_i)}_{i=1}^n of labeled examples such that D_train ∼ D, and a test dataset D_test = {(x_i, y_i)}_{i=1}^m of labeled examples such that D_test ∼ D′. Given D_train, we use a randomized learning algorithm M(D_train) to learn the parameters θ ∈ Θ of a model relating the datapoints {x_i} to their corresponding labels {y_i}. We now describe differential privacy and its links to stability.

Definition 1 (Differential Privacy (Dwork et al., 2006)). Suppose we have two datasets D, D′ at Hamming distance 1, i.e., they differ in exactly one example. An algorithm M is (ϵ, δ)-differentially private if, for every set of outcomes S ⊆ Θ,

Pr[M(D) ∈ S] ≤ e^ϵ · Pr[M(D′) ∈ S] + δ.
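The DP-SGD mechanism used to control stability can be sketched as follows. This is a minimal, illustrative implementation in the style of Abadi et al. (2016); the function name and plain-Python gradient representation are our own, not the paper's code:

```python
import math
import random

def dp_sgd_step(theta, per_example_grads, lr=0.1, clip_norm=1.0, sigma=1.0, seed=0):
    """One DP-SGD update: clip each per-example gradient to L2 norm
    <= clip_norm, average the clipped gradients, then add Gaussian
    noise whose scale is set by the noise multiplier sigma (the
    stability knob). Gradients are plain lists of floats here."""
    rng = random.Random(seed)
    n, d = len(per_example_grads), len(theta)
    avg = [0.0] * d
    for g in per_example_grads:
        norm = math.sqrt(sum(v * v for v in g))
        scale = min(1.0, clip_norm / max(norm, 1e-12))  # clip to clip_norm
        for j in range(d):
            avg[j] += scale * g[j] / n
    # Noise std sigma * clip_norm / n on the averaged gradient, as in
    # adding N(0, (sigma * clip_norm)^2) to the sum and dividing by n.
    noisy = [avg[j] + rng.gauss(0.0, sigma * clip_norm / n) for j in range(d)]
    return [theta[j] - lr * noisy[j] for j in range(d)]
```

Setting sigma = 0 recovers clipped SGD; increasing sigma increases stability (and privacy) at a cost to accuracy, which is precisely the tradeoff studied in this paper.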

