RECYCLING SCRAPS: IMPROVING PRIVATE LEARNING BY LEVERAGING INTERMEDIATE CHECKPOINTS

Abstract

All state-of-the-art (SOTA) differentially private machine learning (DP ML) methods are iterative in nature, and their privacy analyses allow publicly releasing the intermediate training checkpoints. However, DP ML benchmarks, and even practical deployments, typically use only the final training checkpoint to make predictions. In this work, for the first time, we comprehensively explore various methods that aggregate intermediate checkpoints to improve the utility of DP training. Empirically, we demonstrate that checkpoint aggregations provide significant gains in prediction accuracy over the existing SOTA for the CIFAR10 and StackOverflow datasets, and that these gains are magnified in settings with periodically varying training data distributions. For instance, we improve SOTA StackOverflow accuracies to 22.7% (+0.43% absolute) for ε = 8.2, and 23.84% (+0.43%) for ε = 18.9. Theoretically, we show that uniform tail averaging of checkpoints improves the empirical risk minimization bound compared to the last checkpoint of DP-SGD. Lastly, we initiate an exploration into estimating the uncertainty that DP noise adds to the predictions of DP ML models. We prove that, under standard assumptions on the loss function, the sample variance of the last few checkpoints provides a good approximation of the variance of the final model of a DP run. Empirically, we show that the last few checkpoints can provide a reasonable lower bound on the variance of a converged DP model.

1. INTRODUCTION

Machine learning models can unintentionally memorize sensitive information about the data they were trained on, which has led to numerous attacks that extract private information about the training data (Ateniese et al., 2013; Fredrikson et al., 2014; 2015; Carlini et al., 2019; Shejwalkar et al., 2021; Carlini et al., 2021; 2022). For instance, membership inference attacks (Shokri et al., 2017) can infer whether a target sample was used to train a given ML model, while property inference attacks (Melis et al., 2019; Mahloujifar et al., 2022) can infer certain sensitive properties of the training data. To address such privacy risks, the literature has introduced various approaches to privacy-preserving ML (Nasr et al., 2018; Shejwalkar & Houmansadr, 2021; Tang et al., 2022). In particular, iterative techniques like differentially private stochastic gradient descent (DP-SGD) (Song et al., 2013; Bassily et al., 2014; Abadi et al., 2016b; McMahan et al., 2017) and DP Follow The Regularized Leader (DP-FTRL) (Kairouz et al., 2021) have become the state-of-the-art for training DP neural networks. For establishing benchmarks, prior works in DP ML (Abadi et al., 2016b; McMahan et al., 2017; 2018; Thakkar et al., 2019; Erlingsson et al., 2019; Wang et al., 2019b; Zhu & Wang, 2019; Balle et al., 2020; Erlingsson et al., 2020; Papernot et al., 2020; Tramer & Boneh, 2020; Andrew et al., 2021; Kairouz et al., 2021; Amid et al., 2022; De et al., 2022; Feldman et al., 2022) use only the final model output by the DP algorithm. This is also how DP models are deployed in practice (Ramaswamy et al., 2020; McMahan et al., 2022). However, the privacy analyses for these techniques allow releasing/using all of the intermediate training checkpoints. In this work, we comprehensively study various methods that leverage intermediate checkpoints to 1) improve the utility of DP training, and 2) quantify the uncertainty in DP ML models that is due to the DP noise.

Accuracy improvement using checkpoints:

We propose two classes of aggregation methods based on aggregating either the parameters of checkpoints or their outputs. We provide both theoretical and empirical analyses for our aggregation methods. Theoretically, we show that the excess empirical risk of the final checkpoint of DP-SGD is a factor of log(n) larger than that of the weighted average of the past k checkpoints, where n is the size of the dataset. Empirically, we demonstrate significant top-1 accuracy gains due to our aggregations for an image classification (CIFAR10) task and a next-word prediction (StackOverflow) task. Specifically, we show that our checkpoint aggregations achieve absolute (relative) prediction accuracy improvements of 3.79% (7.2%) at ε = 1 for the CIFAR10 (DP-SGD) and 0.43% (1.9%) at ε = 8.2 for the StackOverflow (DP-FTRLM) SOTA baselines, respectively. We also show that our aggregations significantly reduce the variance in the performance of DP models over training. Finally, we show that these benefits are further magnified in more practical settings with periodically varying training data distributions. For instance, we note absolute (relative) accuracy gains of 17.4% (28.6%) at ε = 8 for CIFAR10 over the DP-SGD baseline in such a setting.
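The parameter-aggregation idea above can be sketched as a uniform tail average over the last k checkpoints of a DP run. The snippet below is a minimal illustration, not the paper's exact configuration: checkpoints are represented as flattened parameter vectors, and the choice of k is an assumption.

```python
import numpy as np

def uniform_tail_average(checkpoints, k):
    """Uniformly average the parameters of the last k checkpoints.

    checkpoints: list of 1-D numpy arrays (flattened model parameters),
    ordered from earliest to latest training step.
    Returns a single averaged parameter vector.
    """
    if not 1 <= k <= len(checkpoints):
        raise ValueError("k must be between 1 and the number of checkpoints")
    tail = checkpoints[-k:]          # keep only the last k checkpoints
    return np.mean(tail, axis=0)     # element-wise uniform average

# Toy usage: three "checkpoints" of a 2-parameter model.
ckpts = [np.array([1.0, 2.0]), np.array([3.0, 4.0]), np.array([5.0, 6.0])]
avg = uniform_tail_average(ckpts, k=2)  # averages only the last two
```

Since the checkpoints are already released under the run's privacy analysis, this post-processing incurs no additional privacy cost.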

Uncertainty quantification using checkpoints:

There are various sources of randomness in an ML training pipeline (Abdar et al., 2021), e.g., the choice of initial parameters, dataset, batching, etc. This randomness induces uncertainty in the predictions made using such ML models. In critical domains, e.g., medical diagnosis, self-driving cars, and financial market analysis, failing to capture the uncertainty in such predictions can have undesirable repercussions. DP learning adds an additional source of randomness by injecting noise at every training round. Hence, it is paramount to quantify the reliability of DP models, e.g., by quantifying the uncertainty in their predictions. To this end, we take the first steps towards quantifying the uncertainty that DP noise adds to DP ML training. In prior work, Karwa & Vadhan (2017) develop finite-sample confidence intervals, but for the simpler Gaussian mean estimation problem. Various methods exist for uncertainty quantification in ML-based systems (Mitchell, 1980; Roy et al., 2018; Begoli et al., 2019; Hubschneider et al., 2019; McDermott & Wikle, 2019; Tagasovska & Lopez-Paz, 2019; Wang et al., 2019a; Nair et al., 2020; Ferrando et al., 2022). However, these methods either use specialized (or simpler) model architectures to facilitate uncertainty quantification, or are not directly applicable to quantifying the uncertainty in DP ML due to DP noise. For example, a common approach to uncertainty quantification (Barrientos et al., 2019; Nissim et al., 2007; Brawner & Honaker, 2018; Evans et al., 2020), which we call the independent runs method, needs k independent (bootstrap) runs of the ML algorithm. However, repeating a DP ML algorithm multiple times can incur significant privacy and computation costs. To address this issue, we propose to use the last k checkpoints of a single run of a DP ML algorithm as a proxy for the k final checkpoints from independent runs. This does not incur any additional privacy cost to the DP ML algorithm.
Furthermore, it is useful in practice as it does not incur additional training compute, and it can work with any algorithm having intermediate checkpoints. Theoretically, we consider using the sample variance of a statistic f(θ) at checkpoints θ_{t_1}, ..., θ_{t_k} as an estimator of the variance of f(θ_{t_k}), i.e., the statistic at the final checkpoint, and give a bound on the bias of this estimator. As expected, our bound on the bias decreases as the "burn-in" time t_1 and the time between checkpoints both increase. Intuitively, our proof shows that (i) as the burn-in time increases, the marginal distribution of each θ_{t_i} approaches the distribution of θ_{t_k}, and (ii) as the time between checkpoints increases, any pair θ_{t_i}, θ_{t_j} approaches pairwise independence. Both (i) and (ii) are proven via a mixing-time bound, which shows that starting from any point distribution θ_0, the Markov chain given by DP-SGD approaches its stationary distribution at a certain rate. Empirically, we show our method provides reasonable lower bounds on the uncertainty quantified using the more accurate (but privacy- and computation-intensive) independent runs method.
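The estimator described above can be sketched as follows. This is a minimal illustration under assumed inputs: `statistic` is any scalar-valued function of the model parameters (e.g., a prediction score), and the checkpoint representation is a flattened parameter vector.

```python
import numpy as np

def checkpoint_variance(statistic, checkpoints, k):
    """Sample variance of a scalar statistic over the last k checkpoints
    of a single DP training run.

    This serves as a proxy (empirically, a lower-bound estimate) for the
    variance the independent runs method would measure, at no extra
    privacy or training cost.
    """
    vals = np.array([statistic(theta) for theta in checkpoints[-k:]])
    return vals.var(ddof=1)  # unbiased sample variance

# Toy usage: the statistic is the sum of the parameters.
stat = lambda theta: float(theta.sum())
ckpts = [np.array([1.0, 1.0]), np.array([1.5, 0.5]), np.array([2.0, 1.0])]
var_est = checkpoint_variance(stat, ckpts, k=3)
```

In line with the bias bound sketched above, the checkpoints fed to this estimator should be taken after a sufficient burn-in time and spaced apart in training steps.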

2. IMPROVING ACCURACY BY AGGREGATING DP TRAINING CHECKPOINTS

In this section, we describe our checkpoint aggregation methods, followed by the experimental setup we use for evaluation. We then detail experimental results demonstrating the significant accuracy gains that checkpoint aggregations provide for DP ML models.



Checkpoint aggregations: Chen et al. (2017) and Izmailov et al. (2018) explore checkpoint aggregation methods to improve performance in (non-DP) ML settings, but observe negligible performance gains. To our knowledge, De et al. (2022) is the only work in the DP ML literature that uses intermediate checkpoints after training. They apply an exponential moving average (EMA) over the checkpoints of DP-SGD and note non-trivial gains in performance. However, we propose various aggregation methods that outperform EMA on standard benchmarks.
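For reference, the EMA baseline of De et al. (2022) can be sketched as below. This is an illustrative implementation: the decay value and the checkpoint representation (flattened parameter vectors) are assumptions, not the exact setup of that work.

```python
import numpy as np

def ema_checkpoints(checkpoints, decay=0.99):
    """Exponential moving average over checkpoint parameters.

    checkpoints: list of 1-D numpy arrays, ordered from earliest to
    latest. Later checkpoints receive geometrically larger weight.
    The decay value here is illustrative.
    """
    ema = checkpoints[0].copy()
    for theta in checkpoints[1:]:
        ema = decay * ema + (1.0 - decay) * theta
    return ema

# Toy usage with a 1-parameter model and a small decay for clarity.
ckpts = [np.array([0.0]), np.array([2.0]), np.array([4.0])]
ema = ema_checkpoints(ckpts, decay=0.5)
```

Unlike the uniform tail average, EMA weights recent checkpoints more heavily, which is the baseline our proposed aggregations are compared against.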

