KNOWLEDGE DISTILLATION AS SEMIPARAMETRIC INFERENCE

Abstract

A popular approach to model compression is to train an inexpensive student model to mimic the class probabilities of a highly accurate but cumbersome teacher model. Surprisingly, this two-step knowledge distillation process often leads to higher accuracy than training the student directly on labeled data. To explain and enhance this phenomenon, we cast knowledge distillation as a semiparametric inference problem with the optimal student model as the target, the unknown Bayes class probabilities as nuisance, and the teacher probabilities as a plug-in nuisance estimate. By adapting modern semiparametric tools, we derive new guarantees for the prediction error of standard distillation and develop two enhancements, cross-fitting and loss correction, to mitigate the impact of teacher overfitting and underfitting on student performance. We validate our findings empirically on both tabular and image data and observe consistent improvements from our knowledge distillation enhancements.

1. INTRODUCTION

Knowledge distillation (KD) (Craven & Shavlik, 1996; Breiman & Shang, 1996; Bucila et al., 2006; Li et al., 2014; Ba & Caruana, 2014; Hinton et al., 2015) is a widely used model compression technique that enables the deployment of highly accurate predictive models on devices such as phones, watches, and virtual assistants (Stock et al., 2020). KD operates by training a compressed student model to mimic the predicted class probabilities of an expensive, high-quality teacher model. Remarkably, and across a wide variety of domains (Hinton et al., 2015; Sanh et al., 2019; Jiao et al., 2019; Liu et al., 2018; Tan et al., 2018; Fakoor et al., 2020), this two-step process often leads to higher accuracy than training the student directly on the raw labeled dataset.

While the practice of KD is now well developed, a general theoretical understanding of its successes and failures is still lacking. As we detail below, a number of authors have argued that the success of KD lies in the more precise "soft labels" provided by the teacher's predicted class probabilities. Recently, Menon et al. (2020) observed that these teacher probabilities can serve as a proxy for the Bayes probabilities (i.e., the true class probabilities) and that the closer the teacher and Bayes probabilities, the better the student's performance should be. Building on this observation, we cast KD as a plug-in approach to semiparametric inference (Kosorok, 2007): that is, we view KD as fitting a student model $f$ in the presence of nuisance (the Bayes probabilities $p_0$) with the teacher's probabilities $\hat{p}$ as a plug-in estimate of $p_0$. This insight allows us to adapt modern tools from semiparametric inference to analyze the error of a distilled student in Sec. 3. Our analysis also reveals two distinct failure modes of KD: one due to teacher overfitting and data reuse and the other due to teacher underfitting from model misspecification or insufficient training.

In Sec. 4, we introduce and analyze two complementary KD enhancements that correct for these failures: cross-fitting, a popular technique from semiparametric inference (see, e.g., Chernozhukov et al., 2018), mitigates teacher overfitting through data partitioning, while loss correction mitigates teacher underfitting by reducing the bias of the plug-in estimate $\hat{p}$. The latter enhancement was inspired by the orthogonal machine learning (Chernozhukov et al., 2018; Foster & Syrgkanis, 2019) approach to semiparametric inference, which suggests a particular adjustment for the teacher's log probabilities. We argue in Sec. 4 that this orthogonal correction minimizes the teacher bias but often at the cost of unacceptably large variance. Our proposed correction avoids this variance explosion by balancing the bias and variance terms in our generalization bounds.

In Sec. 5, we complement our theoretical analysis with a pair of experiments demonstrating the value of our enhancements on six real classification problems. On five real tabular datasets, cross-fitting and loss correction improve student performance by up to 4% AUC over vanilla KD. Furthermore, on CIFAR-10 (Krizhevsky & Hinton, 2009), a benchmark image classification dataset, our enhancements improve vanilla KD accuracy by up to 1.5% when the teacher model overfits.

Related work. Since we cannot review the vast literature on KD in its entirety, we point the interested reader to Gou et al. (2020) for a recent overview of the field. We devote this section to reviewing theoretical advances in the understanding of KD and summarize complementary empirical studies and applications of KD in the extended literature review in App. A. A number of papers have argued that the availability of soft class probabilities from the teacher rather than hard labels enables us to improve training of the student model. This was hypothesized in Hinton et al. (2015) with empirical justification.
Phuong & Lampert (2019) consider the case in which the teacher is a fixed linear classifier and the student is either a linear model or a deep linear network. They show that the student can learn the teacher perfectly if the number of training examples exceeds the ambient dimension. Vapnik & Izmailov (2015) discuss the setting of learning with privileged information where one has additional information at training time which is not available at test time. Lopez-Paz et al. (2015) draw a connection between this and KD, arguing that KD is effective because the teacher learns a better representation allowing the student to learn at a faster rate. They hypothesize that a teacher's class probabilities enable student improvement by indicating how difficult each point is to classify. Tang et al. (2020) argue using empirical evidence that label smoothing and reweighting of training examples using the teacher's predictions are key to the success of KD. Mobahi et al. (2020) analyzed the case of self-distillation in which the student and teacher function classes are identical. Focusing on kernel ridge regression models, they proved that self-distillation can act as increased regularization strength. Bu et al. (2020) considers more generic model compression in a rate-distortion framework, where the rate is the size of the student model and distortion is the difference in excess risk between the teacher and the student. Menon et al. (2020) consider the case of losses such that the population risk is linear in the Bayes class probabilities. They consider distilled empirical risk and Bayes distilled empirical risk which are the risk computed using the teacher class probabilities and Bayes class probabilities respectively rather than the observed label. They show that the variance of the Bayes distilled empirical risk is lower than the empirical risk. 
Then, using analysis from Maurer & Pontil (2009) and Bennett (1962), they derive the excess risk of the distilled empirical risk as a function of the $\ell_2$ distance between the teacher's class probabilities and the Bayes class probabilities. We significantly depart from Menon et al. (2020) in multiple ways: i) our Thm. 1 allows for the common practice of data re-use, ii) our results cover the standard KD losses SEL and ACE, which are non-linear in $p_0$, iii) we use localized Rademacher analysis to achieve tight fast rates for standard KD losses, and iv) we use techniques from semiparametric inference to improve upon vanilla KD.

2. KNOWLEDGE DISTILLATION BACKGROUND

We consider a multiclass classification problem with $k$ classes and $n$ training datapoints $z_i = (x_i, y_i)$ sampled independently from some distribution $P$. Each feature vector $x$ belongs to a set $\mathcal{X}$, each label vector $y \in \{e_1,\dots,e_k\} \subset \{0,1\}^k$ is a one-hot encoding of the class label, and the conditional probability of observing each label is the Bayes class probability function $p_0(x) = E[Y \mid X = x]$. Our aim is to identify a scoring rule $f : \mathcal{X} \to \mathbb{R}^k$ that minimizes a prediction loss on average under the distribution $P$.

Knowledge distillation. Knowledge distillation (KD) is a two-step training process where one first uses a labeled dataset to train a teacher model and then trains a student model to predict the teacher's predicted class probabilities. Typically the teacher model is larger and more cumbersome, while the student is smaller and more efficient. Knowledge distillation was first motivated by model compression (Bucila et al., 2006), to find compact yet high-performing models to be deployed (such as on mobile devices). In training the student to match the teacher's prediction probability, there are several types of loss functions that are commonly used. Let $\hat{p}(x) \in \mathbb{R}^k$ be the teacher's vector of predicted class probabilities, $f(x) \in \mathbb{R}^k$ be the student model's output, and $[k] \triangleq \{1,2,\dots,k\}$. The most popular distillation loss functions $\ell(z; f(x), \hat{p}(x))$ include the squared error logit (SEL) loss (Ba & Caruana, 2014)

$\ell_{se}(z; f(x), \hat{p}(x)) \triangleq \sum_{j\in[k]} \frac{1}{2}\big(f_j(x) - \log(\hat{p}_j(x))\big)^2$ (SEL)

and the annealed cross-entropy (ACE) loss (Hinton et al., 2015)

$\ell_\beta(z; f(x), \hat{p}(x)) = -\sum_{j\in[k]} \frac{\hat{p}_j(x)^\beta}{\sum_{l\in[k]} \hat{p}_l(x)^\beta} \log\Big(\frac{\exp(\beta f_j(x))}{\sum_{l\in[k]} \exp(\beta f_l(x))}\Big)$ (ACE)

for an inverse temperature $\beta > 0$. These loss functions measure the divergence between the probabilities predicted by the teacher and the student. A student model trained with knowledge distillation often performs better than the same model trained from scratch (Bucila et al., 2006; Hinton et al., 2015). In Secs. 3 and 4, we will adapt modern tools from semiparametric inference to understand and enhance this phenomenon.
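The two distillation losses above can be evaluated directly for a single example. The following is a minimal sketch, assuming strictly positive teacher probabilities; the function names `sel_loss` and `ace_loss` are illustrative and not taken from the paper or its code release.

```python
import math


def sel_loss(f, p_teacher):
    """Squared error logit loss: sum_j 0.5 * (f_j - log p_j)^2."""
    return sum(0.5 * (fj - math.log(pj)) ** 2 for fj, pj in zip(f, p_teacher))


def ace_loss(f, p_teacher, beta=1.0):
    """Annealed cross-entropy between tempered teacher and student softmax."""
    # Tempered teacher probabilities p_j^beta / sum_l p_l^beta.
    powered = [pj ** beta for pj in p_teacher]
    total = sum(powered)
    soft_targets = [v / total for v in powered]
    # Student log-softmax of beta * f, computed stably via log-sum-exp.
    scaled = [beta * fj for fj in f]
    m = max(scaled)
    log_norm = m + math.log(sum(math.exp(s - m) for s in scaled))
    return -sum(t * (s - log_norm) for t, s in zip(soft_targets, scaled))


# Hypothetical teacher output for one example with k = 3 classes.
p_hat = [0.7, 0.2, 0.1]
# Student scores that reproduce the teacher's log-probabilities exactly.
f_match = [math.log(pj) for pj in p_hat]
sel_at_match = sel_loss(f_match, p_hat)
ace_at_match = ace_loss(f_match, p_hat, beta=1.0)
```

When the student scores equal the teacher's log-probabilities, the SEL loss vanishes, while the ACE loss with $\beta = 1$ reduces to the entropy of the teacher's probability vector, which is its minimum over student scores.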

3. DISTILLATION AS SEMIPARAMETRIC INFERENCE

In semiparametric inference (Kosorok, 2007), one aims to estimate a target parameter or function $f_0$, but that estimation depends on an auxiliary nuisance function $p_0$ that is unknown and not of primary interest. We cast the knowledge distillation process as a semiparametric inference problem by treating the unknown Bayes class probabilities $p_0$ as nuisance and the teacher's predicted probabilities as a plug-in estimate of that nuisance. This perspective allows us to bound the generalization error of the student in terms of the mean squared error (MSE) between the teacher and the Bayes probabilities. In the next section (Sec. 4) we use techniques from semiparametric inference to enhance the performance of the student. The interested reader can consult Tsiatis (2007) for more details on semiparametric inference.

Our analysis starts from the following perspective on distillation. For a given pointwise loss function $\ell(z; f(x), p_0(x))$, we view the goal of the student as minimizing an oracle population loss over a function class $\mathcal{F}$,

$L_D(f, p_0) = E[\ell(Z; f(X), p_0(X))]$ with $f_0 \triangleq \arg\min_{f\in\mathcal{F}} L_D(f, p_0)$.

The main hurdle is that this objective depends on the unknown Bayes probabilities $p_0$. We view the teacher's model $\hat{p}$ as an approximate version of $p_0$ and bound the distillation error of the student as a function of the teacher's estimation error. Typical semiparametric inference considers cases where $f_0$ is a finite dimensional parameter; however, recent work of Foster & Syrgkanis (2019) extends this framework to infinite dimensional models $f_0$ and develops a statistical learning theory framework with a nuisance component. The distillation problem fits exactly into this setup.
Bounds on vanilla KD. As a first step we derive a vanilla bound on the error of the distilled student model without any further modifications of the distillation process, i.e., we assume that the student is trained on the same data as the teacher and is trained by running empirical risk minimization (ERM) on the plug-in loss, plugging in the teacher's model instead of $p_0$, i.e.,

$\hat{f} = \arg\min_{f\in\mathcal{F}} L_n(f, \hat{p})$ for $L_n(f, \hat{p}) \triangleq E_n[\ell(Z; f(X), \hat{p}(X))]$ (Vanilla KD)

where $E_n[X] = \frac{1}{n}\sum_{i=1}^n X_i$ denotes the empirical expectation of a random variable.

Technical definitions. Before presenting our main theorem we introduce some technical notation. For a vector valued function $f$ that takes as input a random variable $X$, we use the shorthand notation $\|f\|_{p,q} \triangleq \big\| \|f(X)\|_p \big\|_{L^q} = E\big[\|f(X)\|_p^q\big]^{1/q}$. Let $\nabla_\phi$ and $\nabla_\pi$ denote the partial derivatives of $\ell(z;\phi,\pi)$ with respect to its second and third input, respectively, and $\nabla_{\phi\pi}$ the Jacobian of cross partial derivatives, i.e., $[\nabla_{\phi\pi}\ell(z;\phi,\pi)]_{i,j} = \frac{\partial^2}{\partial\phi_j \partial\pi_i}\ell(z;\phi,\pi)$. Finally, let $q_{f,p}(x) = E[\nabla_{\phi\pi}\ell(Z; f(X), p(X)) \mid X = x]$ and $\gamma_{f,p}(x) = E_{U\sim\mathrm{Unif}([0,1])}[q_{f,\,Up+(1-U)p_0}(x)]$.

Critical radius. Finally, we need to define the notion of the critical radius (see, e.g., Wainwright (2019, 14. for universal constants $c_0, c_1$ and $\delta_n$ an upper bound on the critical radius of the function class $\mathcal{G} \triangleq \{z \mapsto r(\ell(z; f(x), p(x)) - \ell(z; f_0(x), p(x))) : f \in \mathcal{F},\ p \in \mathcal{P},\ r \in [0,1]\}$. Let $\mu(z) = \sup_\phi \|\nabla_\phi \ell(z;\phi,\hat{p}(x))\|_2$, and assume that the loss $\ell(z;\phi,\pi)$ is $\sigma$-strongly convex in $\phi$ for each $z$ and that each $g \in \mathcal{G}$ is uniformly bounded in $[-H,H]$. Then the Vanilla KD $\hat{f}$ satisfies

$\|\hat{f} - f_0\|_{2,2}^2 = \frac{1}{\sigma^2}\, O\big(\delta_{n,\zeta}^2 C^2 H^2 \|\mu\|_4^2 + \|\gamma_{f_0,\hat{p}}^\top(\hat{p} - p_0)\|_{2,2}^2\big)$

with probability at least $1-\zeta$.

Thm. 1, proved in App. C, shows that vanilla distillation yields an accurate student whenever the teacher generalizes well (i.e., $\|\hat{p} - p_0\|_{2,2}$ is small) and the student and teacher model classes $\mathcal{F}$ and $\mathcal{P}$ are not too complex.
The $\ell_2/\ell_4$ ratio requirement can be removed at the expense of replacing $\|\mu\|_4$ by $\|\mu\|_\infty = \sup_z |\mu(z)|$ in the final bound. Moreover, we highlight that the strong convexity requirement for $\ell$ is satisfied by all standard distillation objectives including SEL and ACE, as it is strong convexity with respect to the output of $f$ and not the parameters of $f$. Even this requirement could be removed, but this would yield slow rate bounds of the form $\|\hat{f} - f_0\|_{2,2}^2 = O\big(\delta_{n,\zeta} + \|\gamma_{f_0,\hat{p}}^\top(\hat{p} - p_0)\|_{2,2}^2\big)$.

Failure modes of vanilla KD. Thm. 1 also hints at two distinct ways in which vanilla distillation could fail. First, since the student only learns from the teacher and does not have access to the original labels, we would expect the student to be erroneous when the teacher probabilities are inaccurate due to model misspecification, an overly restrictive teacher function class, or insufficient training. Prop. 2, proved in App. D, confirms that, in the worst case, student error suffers from inaccuracy due to this teacher underfitting even when both the student and teacher belong to low complexity model classes.

Proposition 2 (Impact of teacher underfitting on vanilla KD). There exists a classification problem in which the following properties all hold simultaneously with high probability for $f_0 = \log(p_0)$:
• The teacher learns $\hat{p}(x) = \frac{1}{n(1+\lambda)}\sum_{i=1}^n y_i$ for all $x \in \mathcal{X}$ via ridge regression with $\lambda = \Theta(1/n^{1/4})$.
• Vanilla KD with SEL loss and constant $\hat{f}$ satisfies $\|\hat{f} - f_0\|_{2,2}^2 \ge \|\gamma_{f_0,p_0}^\top(\hat{p} - p_0)\|_{2,2}^2 = \Omega(\frac{1}{\sqrt{n}})$, matching the dependence of the Thm. 1 upper bound up to a constant factor.
• Enhanced KD with SEL loss, $\hat{\gamma}^{(t)} = \mathrm{diag}(\frac{1}{\hat{p}^{(t)}})$, and constant $\hat{f}$ satisfies $\|\hat{f} - f_0\|_{2,2}^2 = O(\frac{1}{n})$.

Second, the critical radius in Thm. 1 depends on the complexity of the teacher model class $\mathcal{P}$. If $\mathcal{P}$ has a large critical radius, then the student error bound suffers due to potential teacher overfitting even if the teacher generalizes well. Prop. 3, proved in App. E, shows that, in the worst case, this teacher overfitting penalty is unavoidable and does in fact lead to increased student error. This occurs as the student only has access to the teacher's training set probabilities which, due to overfitting, need not reflect its test set probabilities.

Proposition 3 (Impact of teacher overfitting on vanilla KD). There exists a classification problem in which the following properties all hold simultaneously with high probability for $f_0 = E[\log(p_0(X))]$:
• The critical radius $\delta_n$ of the teacher-student function class $\mathcal{G}$ in Thm. 1 is a non-vanishing constant, due to the complexity of the teacher's function class.
• The Vanilla KD error $\|\hat{f} - f_0\|_{2,2}^2$ for constant $\hat{f}$ with SEL loss is lower bounded by a non-vanishing constant, matching the $\delta_n$ dependence of the Thm. 1 upper bound up to a constant factor.
• Enhanced KD with SEL loss, $\hat{\gamma}^{(t)} = 0$, and constant $\hat{f}$ satisfies $\|\hat{f} - f_0\|_{2,2}^2 = O(n^{-4/(4+d)})$.

These examples serve to lower bound student performance in the worst case by the teacher's critical radius and class probability MSE, matching the upper bounds given in Thm. 1. However, we note that in other, better-case scenarios vanilla distillation can perform better than the upper bound of Thm. 1 would imply. In the next section, we adapt and generalize techniques from semiparametric inference to mitigate the effects of teacher overfitting and underfitting in all cases.

4. ENHANCING KNOWLEDGE DISTILLATION

To address the two distinct inefficiencies of vanilla distillation revealed in Sec. 3, we will adapt and generalize two distinct techniques from semiparametric inference: orthogonal correction and cross-fitting.

4.1. COMBATING TEACHER UNDERFITTING WITH LOSS CORRECTION

We can view the plug-in distillation loss $\ell(z; f(x), \hat{p}(x))$ as a zeroth order Taylor approximation to the ideal loss $\ell(z; f(x), p_0(x))$ around $\hat{p}$. An ideal first-order approximation would take the form

$\ell(z; f(x), \hat{p}(x)) + \langle p_0(x) - \hat{p}(x), \nabla_\pi \ell(z; f(x), \hat{p}(x)) \rangle$.

However, its computation also requires knowledge of $p_0$. Nevertheless, since $p_0(x) = E[Y \mid X = x]$, we can always construct an unbiased estimate of the ideal first order term by replacing $p_0(x)$ with $y$:

$\ell_{ortho}(z; f(x), \hat{p}(x)) = \ell(z; f(x), \hat{p}(x)) + \langle y - \hat{p}(x), E[\nabla_\pi \ell(z; f(x), \hat{p}(x)) \mid x] \rangle$. (1)

For standard distillation base losses like SEL and ACE, the orthogonal loss (1) has an especially simple form, as $\nabla_\pi \ell(z; f(x), \hat{p}(x))$ is linear in $f$. Indeed, this is true more generally for the following class of Bregman divergence losses.

Definition 1 (Bregman divergence losses). Any Bregman divergence loss function of the form

$\ell(z; f(x), \hat{p}(x)) \triangleq \Psi(f(x)) - \Psi(g(\hat{p}(x))) - \langle \nabla_g \Psi(g(\hat{p}(x))), f(x) - g(\hat{p}(x)) \rangle$

has

$\ell_{ortho}(z; f(x), \hat{p}(x)) = \ell(z; f(x), \hat{p}(x)) + (y - \hat{p}(x))^\top \nabla_p g(\hat{p}(x))^\top \nabla^2_{gg}\Psi(g(\hat{p}(x)))\, f(x) + \mathrm{const}$ (2)

with the second term bilinear in $f(x)$ and $y - \hat{p}(x)$.

For the SEL loss, $\Psi(s) = \frac{1}{2}\|s\|_2^2$, $g(p) = \log(p)$, and the correction matrix $\nabla_p g(\hat{p}(x))^\top \nabla^2_{gg}\Psi(g(\hat{p}(x))) = \mathrm{diag}(\frac{1}{\hat{p}(x)})$. Similarly, the ACE loss falls into the class of Bregman divergence losses. We will show that orthogonal correction (1) can significantly improve student bias due to teacher underfitting; however, for our standard distillation losses (SEL and ACE), the same orthogonal correction term often introduces unreasonably large variance due to division by small probabilities appearing in the correction matrix (see Definition 1). To grant ourselves more flexibility in balancing bias and variance, we propose and analyze a family of $\gamma$-corrected losses, parameterized by a matrix valued function $\gamma : \mathcal{X} \to \mathbb{R}^{k\times k}$:

$\ell_\gamma(z; f(x), \hat{p}(x)) \triangleq \ell(z; f(x), \hat{p}(x)) + (y - \hat{p}(x))^\top \gamma(x) f(x)$

to mimic the bilinear structure of Bregman orthogonal losses (2).
Note that we can always recover the vanilla distillation loss by taking $\gamma \equiv 0$. We denote the associated population and empirical risks by $L_D(f,\hat{p},\gamma) \triangleq E[\ell_\gamma(Z; f(X), \hat{p}(X))]$ and $L_n(f,\hat{p},\gamma) \triangleq E_n[\ell_\gamma(Z; f(X), \hat{p}(X))]$. Observe that at $p_0$ the correction term is mean-zero and hence $L_D(f,p_0,\gamma)$ is independent of $\gamma$:

$L_D(f,p_0) \triangleq E[\ell(Z; f(X), p_0(X))] = L_D(f,p_0,\gamma)$ for all $\gamma$.

The $\gamma$-corrected loss has strong connections to the literature on Neyman orthogonality (Chernozhukov et al., 2018; Chernozhukov et al., 2016; Nekipelov et al., 2018; Chernozhukov et al., 2018; Foster & Syrgkanis, 2019). In particular, if the function $\gamma$ is set appropriately, then one can show that the $\gamma$-corrected loss function satisfies the condition of a Neyman orthogonal loss defined by Foster & Syrgkanis (2019). We begin our analysis by showing a general lemma for any estimator $\hat{f}$, which adapts the main theorem of Foster & Syrgkanis (2019) to account for approximate orthogonality; the proof can be found in App. F.

Lemma 4 (Algorithm-agnostic analysis). Consider any estimation algorithm that produces an estimate $\hat{f}$ with small plug-in excess risk, i.e., $L_D(\hat{f},\hat{p},\gamma) - L_D(f_0,\hat{p},\gamma) \le \epsilon(\hat{f},\hat{p},\gamma)$. If the loss $L_D$ is $\sigma$-strongly convex with respect to $f$ and $\mathcal{F}$ is a convex set, then

$\frac{\sigma}{4}\|\hat{f} - f_0\|_{2,2}^2 \le \epsilon(\hat{f},\hat{p},\gamma) + \frac{1}{\sigma}\|(\gamma_{f_0,\hat{p}} - \gamma)^\top(\hat{p} - p_0)\|_{2,2}^2$.

If, in addition, $\sup_{z,\phi,\pi,i\in[d]} \|\nabla_{\phi_i\pi\pi}\ell(z;\phi,\pi)\|_{op} \le M$, then

$\|(\gamma_{f_0,\hat{p}} - \gamma)^\top(\hat{p} - p_0)\|_{2,2}^2 \le 2\|(q_{f_0,\hat{p}} - \gamma)^\top(\hat{p} - p_0)\|_{2,2}^2 + M^2 k \|\hat{p} - p_0\|_{2,4}^4$.

Connection to Neyman orthogonality. Remarkably, if we set $\gamma = q_{f_0,\hat{p}}$, then the $\gamma$-corrected loss is Neyman orthogonal (Foster & Syrgkanis, 2019), and the student MSE bound depends only on the squared MSE of the teacher. Moreover, $q_{f_0,\hat{p}}$ is an observable quantity for any Bregman divergence loss (Definition 1) as $q_{f_0,\hat{p}}$ is independent of $f_0$. However, we note that this setting of $\gamma$ can lead to larger variance, i.e., the achievable excess risk can be much larger than the excess risk without the correction.
For instance, in the case of the SEL loss, $q_{f_0,\hat{p}}(x) = \frac{1}{\hat{p}(x)}$, which can be excessively large when $\hat{p}$ is close to 0, leading to a large increase in the variance of our loss. Thus, in a departure from the standard approach in semiparametric inference, we will be choosing $\gamma$ in practice to balance bias and variance.

Example instantiation of student's estimation algorithm. If we use plug-in empirical risk minimization, i.e., $\hat{f} = \arg\min_{f\in\mathcal{F}} L_n(f,\hat{p},\gamma)$, to estimate $f_0$ with $\hat{p}$ estimated on an independent sample, then the results of Maurer & Pontil (2009) directly imply that as long as the loss function $\ell(z;\phi,\pi)$ is uniformly bounded in $[-H,H]$, then, with probability at least $1-\delta$,

$\epsilon(\hat{f},\hat{p},\gamma) = O\Big(\sqrt{\frac{\sup_{f\in\mathcal{F}} \mathrm{Var}(\ell_\gamma(Z; f(X), \hat{p}(X)))\,\log(\tau(n)/\delta)}{n}} + \frac{H\log(\tau(n)/\delta)}{n}\Big)$

where $\tau(n) = N_\infty(1/n, \mathcal{F}, 2n)$ and $N_\infty(\epsilon, \mathcal{F}, m)$ is the $\ell_\infty$ empirical covering number of function class $\mathcal{F}$ in the worst case over all realizations of $m$ data points and at approximation level $\epsilon$. This result has two drawbacks: it is a slow rate result that scales as $1/\sqrt{n}$ for parametric or bounded Vapnik-Chervonenkis (VC)-dimension classes, and it requires the student to be fit on a completely separate dataset from the teacher's. In the next theorem, we address both of these drawbacks: i) we invoke localized Rademacher complexity analysis to provide a fast rate result which would be of the order of $1/n$ for VC or parametric function classes, and ii) we use a more sophisticated data-partitioning technique called cross-fitting, which allows the student to be trained using all of the available teacher data.

4.2. COMBATING TEACHER OVERFITTING WITH CROSS-FITTING

We now describe a more sophisticated version of data partitioning to make use of all data points in our student estimation, while at the same time not suffering from the sample complexity of the teacher's function space. This approach is referred to as cross-fitting (CF) in the semiparametric inference literature (see, e.g., Chernozhukov et al. (2018)):
1. Partition the dataset into $B$ equally sized folds $P_1,\dots,P_B$.
2. For each fold $t \in [B]$ estimate $\hat{p}^{(t)}$ and $\hat{\gamma}^{(t)}$ using all the out-of-fold data points.
3. Estimate $\hat{f}$ by minimizing the empirical loss:

$\hat{f} = \arg\min_{f\in\mathcal{F}} \frac{1}{n}\sum_{t=1}^B \sum_{i\in P_t} \ell_{\hat{\gamma}^{(t)}}(Z_i; f(X_i), \hat{p}^{(t)}(X_i))$. (Enhanced KD)

In other words, the nuisance estimates $(\hat{\gamma}^{(t)}, \hat{p}^{(t)})$ that are evaluated on the data points in fold $t$ when fitting the student in step 3 are estimated only using data points outside of $P_t$.

Theorem 5 (Enhanced KD analysis). Suppose $f_0$ belongs to a convex set $\mathcal{F}$. Let $\delta_{n/B,\zeta/B} = \delta_{n/B} + c_0\sqrt{\frac{B\log(c_1 B/\zeta)}{n}}$ for universal constants $c_0, c_1$ and $\delta_{n/B}$ an upper bound on the critical radius of the class

$\mathcal{G}(\hat{p}^{(t)},\hat{\gamma}^{(t)}) = \{z \mapsto r\big(\ell_{\hat{\gamma}^{(t)}}(z; f(x), \hat{p}^{(t)}(x)) - \ell_{\hat{\gamma}^{(t)}}(z; f_0(x), \hat{p}^{(t)}(x))\big) : f \in \mathcal{F},\ r \in [0,1]\}$

for each $t \in [B]$. Let $\mu(z) = \sup_{f\in\mathcal{F},\,t\in[B]} \|\nabla_\phi \ell_{\hat{\gamma}^{(t)}}(z; f(x), \hat{p}^{(t)}(x))\|_2$, and assume that, with probability 1 for each $t \in [B]$, the loss $\ell_{\hat{\gamma}^{(t)}}(z;\phi,\hat{p}^{(t)}(x))$ is $\sigma$-strongly convex in $\phi$ for each $z$ and each $g \in \mathcal{G}(\hat{p}^{(t)},\hat{\gamma}^{(t)})$ is uniformly bounded in $[-H,H]$. Moreover, suppose that the function class $\mathcal{F}$ satisfies the $\ell_2/\ell_4$ ratio condition: $\sup_{f\in\mathcal{F}} \frac{\|f - f_0\|_{2,4}}{\|f - f_0\|_{2,2}} \le C$. If $\hat{f}$ is the output of Enhanced KD, then, with probability at least $1-\zeta$,

$\frac{\sigma}{8}\|\hat{f} - f_0\|_{2,2}^2 = \frac{1}{\sigma}\, O\Big(\delta_{n/B,\zeta/B}^2 C^2 H^2 \|\mu\|_4^2 + \frac{1}{B}\sum_{t=1}^B E\big[\|(Y - \hat{p}^{(t)}(X))^\top \hat{\gamma}^{(t)}(X)\|_2^4\big]\Big) + \frac{1}{\sigma}\, O\Big(\frac{1}{B}\sum_{t=1}^B \|(\gamma_{f_0,\hat{p}^{(t)}} - \hat{\gamma}^{(t)})^\top(\hat{p}^{(t)} - p_0)\|_{2,2}^2\Big)$.

The proof is found in App. G. Observe that, unlike Thm. 1, the function classes $\mathcal{G}(\hat{p}^{(t)},\hat{\gamma}^{(t)})$ in Thm. 5 do not vary the teacher's model over $\mathcal{P}$ but rather evaluate $\hat{p}$ at the specific out-of-fold estimates $\hat{p}^{(t)}$ and only vary $f \in \mathcal{F}$.
Since in practice the teacher's model can be quite complex, removing this dependence on the sample complexity of the teacher's function space can bring immense improvement, with the critical radius of $\mathcal{G}(\hat{p}^{(t)},\hat{\gamma}^{(t)})$ significantly smaller than that of $\mathcal{G}$ from Thm. 1. For instance, suppose that the loss function $\ell_{\hat{\gamma}^{(t)}}(z; f, \hat{p}^{(t)})$ is $L$-Lipschitz with respect to $f$ and that $\mathcal{F}$ is a VC-subgraph class with VC dimension $d_{\mathcal{F}}$. Then the critical radius of the function class $\mathcal{G}(\hat{p}^{(t)},\hat{\gamma}^{(t)})$ is of order $\sqrt{d_{\mathcal{F}}\log(n)/n}$ for any choice of $(\hat{p}^{(t)},\hat{\gamma}^{(t)})$ (see, e.g., Foster & Syrgkanis, 2019, Sec. 4.2). However, under the same conditions, the critical radius of the teacher-student function class $\mathcal{G}$ in Thm. 1 will still depend on the teacher's function space. If $\mathcal{P}$ is also a VC-subgraph class with VC dimension $d_{\mathcal{P}} \gg d_{\mathcal{F}}$, then the critical radius of $\mathcal{G}$ will be of the much larger order $\sqrt{d_{\mathcal{P}}\log(n)/n}$.

We can also see in the bound of Thm. 5 the interplay between bias and variance introduced by $\gamma$. In particular, the part of the bound that depends on $\hat{\gamma}^{(t)}$ can be further simplified as

$E\big[\delta_{n,\zeta}^4 C^4 \|(Y - \hat{p}(X))^\top \hat{\gamma}^{(t)}(X)\|_2^4 + \|(\gamma_{\hat{p},0}(X) - \hat{\gamma}^{(t)}(X))^\top(\hat{p}(X) - p_0(X))\|_2^4\big]$, (3)

where the terms respectively encode the increase in variance and decrease in bias from employing loss correction. Notably, Thm. 5 implies that CF without $\gamma$-correction (i.e., $\hat{\gamma}^{(t)}(x) = 0$) is sufficient to reduce student error due to teacher overfitting but may still be susceptible to excessive student error due to teacher underfitting. These qualitative predictions accord with our experimental observations in Sec. 5 and Fig. 5.
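The three cross-fitting steps of this section can be sketched in a few lines. The code below is an illustration under simplifying assumptions, not the authors' implementation: it uses integer class labels, no loss correction (the $\gamma \equiv 0$ case), a toy teacher that ignores the features, and a constant student fit with the SEL loss. All function names, including the user-supplied `fit_teacher` callable, are hypothetical.

```python
import math
import random


def make_folds(n, n_folds, seed=0):
    """Step 1: randomly partition indices 0..n-1 into n_folds folds."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    return [idx[t::n_folds] for t in range(n_folds)]


def cross_fit_teacher(xs, ys, fit_teacher, n_folds=5, seed=0):
    """Step 2: out-of-fold teacher probabilities for every training point.

    fit_teacher(xs_train, ys_train) must return a function mapping a
    feature x to a list of class probabilities.
    """
    folds = make_folds(len(xs), n_folds, seed)
    p_oof = [None] * len(xs)
    for fold in folds:
        held_out = set(fold)
        train = [i for i in range(len(xs)) if i not in held_out]
        teacher = fit_teacher([xs[i] for i in train], [ys[i] for i in train])
        for i in fold:
            p_oof[i] = teacher(xs[i])  # this teacher never saw point i
    return p_oof, folds


def fit_frequency_teacher(xs_train, ys_train, n_classes=2, smoothing=1.0):
    """Toy teacher: smoothed class frequencies, ignoring the features."""
    counts = [smoothing] * n_classes
    for y in ys_train:
        counts[y] += 1.0
    total = sum(counts)
    probs = [c / total for c in counts]
    return lambda x: probs


def fit_constant_sel_student(p_oof):
    """Step 3 for a constant student: the SEL minimizer is the mean log-probability."""
    n, k = len(p_oof), len(p_oof[0])
    return [sum(math.log(p[j]) for p in p_oof) / n for j in range(k)]


xs = list(range(10))
ys = [0, 1, 0, 0, 1, 0, 1, 0, 0, 1]
p_oof, folds = cross_fit_teacher(xs, ys, fit_frequency_teacher, n_folds=5)
student_scores = fit_constant_sel_student(p_oof)
```

The key design point is that each training point is scored only by a teacher fit without it, so the student never sees probabilities that the teacher could have memorized.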

4.3. BIASED STOCHASTIC GRADIENT DESCENT ANALYSIS

When the set of candidate prediction rules $f_\theta$ is parameterized by a vector $\theta \in \mathbb{R}^d$, we may alternatively fit $\theta$ via stochastic gradient descent (SGD) (Robbins & Monro, 1951; Bottou & Bousquet, 2008) on the $\gamma$-corrected objective $L_D(f_\theta,\hat{p},\gamma)$. With a minibatch size of 1 and a starting point $\theta_0$, the parameter updates take the form

$\theta_{t+1} = \theta_t - \eta_t \nabla_\theta f_\theta(X_t)^\top \nabla_\phi \ell_\gamma(W_t; f_\theta(X_t), \hat{p}(X_t))$ for $t+1 \in [n]$. (4)

Ideally, these updates would converge to a minimizer of the ideal risk $L(\theta;p_0) = L_D(f_\theta,p_0)$. Our next result shows that, if the teacher $\hat{p}$ is independent of $(W_t)_{t\in[n]}$, then the SGD updates (4) have excess ideal risk governed by a bias term $\zeta(\gamma)$ and a variance term $\sigma(\gamma)^2/n$. Here, $\sigma_0^2(\theta)$ represents the baseline stochastic gradient variance that would be incurred if SGD were run directly on the ideal risk $L(\theta;p_0)$ rather than our surrogate risk. Our proof in App. H builds upon the biased SGD bounds of Ajalloeian & Stich (2020).

Theorem 6 (Biased SGD analysis). Suppose that the loss $L(\theta;p_0)$ is $\lambda$-strongly smooth in $\theta$. Define the bias and root-variance parameters

$\zeta(\gamma) \triangleq \sup_{\theta\in\mathbb{R}^d} \|\nabla_\theta f_\theta^\top(\gamma_{f_\theta,\hat{p}} - \gamma)^\top(\hat{p} - p_0)\|_{2,2}$,

$\sigma(\gamma) \triangleq \sup_{\theta\in\mathbb{R}^d} \sigma_0(\theta) + E\big[\|\nabla_\theta f_\theta(X)^\top \gamma(X)^\top (Y - p_0(X))\|_2^2\big] + \|\nabla_\theta f_\theta^\top(\gamma_{f_\theta,\hat{p}} - \gamma)^\top(\hat{p} - p_0)\|_{2,2}^2$

for $\sigma_0^2(\theta) \triangleq \sum_{i\in[d]} \mathrm{Var}[\nabla_{\theta_i}\ell(W; f_\theta(X), p_0(X))]$ the unbiased SGD variance. If $F_0 = L(\theta_0;p_0) - \min_{\theta\in\mathbb{R}^d} L(\theta;p_0)$, then the iterates $\{\theta_t\}_{t=1}^n$ of the loss corrected SGD algorithm satisfy

$\min_{t\in[n]} E\big[\|\nabla_\theta L(\theta_t;p_0)\|_2^2\big] = O\Big(\frac{\sigma(\gamma)\sqrt{\lambda F_0}}{\sqrt{n}} + \zeta^2(\gamma)\Big)$.

If, in addition, $L(\theta;p_0)$ is $\mu$-strongly convex in $\theta$, then the iterates satisfy

$E\big[L(\theta_n;p_0) - \min_{\theta\in\mathbb{R}^d} L(\theta;p_0)\big] = \frac{1}{\mu}\, O\Big(\frac{\lambda}{\mu}\frac{\sigma(\gamma)^2}{n} + \zeta^2(\gamma)\Big) + O\big(F_0 e^{-\frac{\mu}{2\lambda}n}\big)$.

Similar to Thm. 5, the bound in Thm. 6 portrays the interplay of bias and variance as $\gamma$ ranges from 0 to $q_{f_\theta,\hat{p}}$ (recall that $q_{f_\theta,\hat{p}}$ is independent of $f_\theta$ for any Bregman loss).
In particular, the part of the bound for strongly convex losses that depends on $\gamma$ can be further simplified to:

$E\Big[\Big(\frac{\lambda\|\gamma(X)\|_2^2\,\|Y - p_0(X)\|_2^2}{\mu n} + \|(\gamma_{f_\theta,\hat{p}}(X) - \gamma(X))^\top(\hat{p}(X) - p_0(X))\|_2^2\Big)\|\nabla_\theta f_\theta(X)\|_2^2\Big]$ (5)

This has a very intuitive form: the first term is the impact of $\gamma(X)$ on the variance, which is also related to the square of the noise of $y$, divided by the standard error scaling. The second controls how $\gamma$ improves the bias introduced by the error in the teacher's $\hat{p}$.
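To make the single-example update (4) concrete, here is a minimal sketch for a linear student $f_\theta(x) = \Theta x$ trained with the SEL loss. It assumes $\gamma \equiv 0$ (no loss correction) and a constant step size, so the gradient with respect to the student output is simply $f_\theta(x) - \log \hat{p}(x)$; the function name `sgd_sel_linear` and the toy data are illustrative only.

```python
import math


def sgd_sel_linear(xs, teacher_probs, n_classes, step_size=0.5, n_epochs=1):
    """Run update (4) for a linear student f_theta(x) = Theta x with SEL loss.

    Assumes gamma = 0, so the loss gradient with respect to the student
    output phi is phi - log p_hat(x).
    """
    dim = len(xs[0])
    theta = [[0.0] * dim for _ in range(n_classes)]  # starting point theta_0
    for _ in range(n_epochs):
        for x, p in zip(xs, teacher_probs):
            # Student output for this example.
            f = [sum(w * xi for w, xi in zip(row, x)) for row in theta]
            grad_phi = [fj - math.log(pj) for fj, pj in zip(f, p)]
            # Chain rule: d loss / d Theta[j][i] = grad_phi[j] * x[i].
            for j in range(n_classes):
                for i in range(dim):
                    theta[j][i] -= step_size * grad_phi[j] * x[i]
    return theta


# Intercept-only features make the student a constant predictor.
xs = [[1.0]] * 60
teacher_probs = [[0.8, 0.2]] * 60
theta = sgd_sel_linear(xs, teacher_probs, n_classes=2)
```

With a constant teacher and intercept-only features, each step halves the distance to the teacher's log-probabilities, so the fitted intercepts approach $\log 0.8$ and $\log 0.2$.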

5. EXPERIMENTS

We complement our theoretical analysis with a pair of experiments demonstrating the practical benefits of cross-fitting and loss correction on six real-world classification tasks. Throughout, we use the SEL loss and report mean performance ± 1 standard error across 5 independent runs. Code to replicate all experiments can be found at https://github.com/microsoft/semiparametric-distillation, and supplementary experimental details and results can be found in App. I.

Selecting the loss correction matrix $\gamma$. Motivated by the analyses in Sec. 4, for each training point $(x,y)$, we will select our correction matrix $\gamma(x)$ to balance bias and variance by minimizing a pointwise upper bound on the loss correction error (5) (ideally with a closed-form solution to avoid excessive computational overhead). To eliminate dependence on the unobserved $p_0$, we observe that the bias term $\|(\gamma_{f_\theta,\hat{p}}(x) - \gamma(x))^\top(\hat{p}(x) - p_0(x))\|_2^2 = O(\|q_{f_\theta,\hat{p}}(x) - \gamma(x)\|_{op}^2)$ up to additive terms independent of $\gamma$. We introduce a tunable hyperparameter $\alpha > 0$ to trade off between this bias bound and the variance term in (5) and select $\gamma(x) = \mathrm{diag}(v(x))$ to minimize:

$E\big[\|\gamma(x)(y - \hat{p}(x))\|_2^2 \mid x\big] + \alpha\|q_{f_\theta,\hat{p}}(x) - \gamma(x)\|_{op}^2 = E\big[\|v(x)(y - \hat{p}(x))\|_2^2 \mid x\big] + \alpha\big\|\tfrac{1}{\hat{p}(x)} - v(x)\big\|_2^2$.

Since the conditional expectation involves the unknown quantity $p_0$, we estimate $E[\|v(x)(y - \hat{p}(x))\|_2^2 \mid x]$ with its sample value $\|v(x)(y - \hat{p}(x))\|_2^2$. This objective is quadratic in $v(x)$ and thus has a closed-form solution. Given $\gamma(x)$, the student's loss-corrected objective is equivalent to a square loss with labels $\log(\hat{p}(x)) + \gamma(x)^\top(y - \hat{p}(x))$.

Tabular data. We first validate our KD enhancements on five real-world tabular datasets, namely FICO (FIC), StumbleUpon (Eve; Liu et al., 2017), and Adult, Higgs, and MAGIC from Dheeru & Karra Taniskidou (2017), with random forest (Breiman, 2001) students and teachers. In Fig. 1a, we examine the impact of varying student model capacity for a fixed high-capacity teacher with 500 trees on FICO.
This setting lends itself to teacher overfitting, and we find that cross-fitting consistently improves upon vanilla KD by up to 4 AUC percentage points. In Fig. 1b we explore the impact of teacher underfitting by limiting the teacher's maximum tree depth on Adult. Here we observe consistent gains from loss correction with student performance exceeding even that of the teacher for smaller maximum tree depths. Analogous results for the remaining datasets can be found in App. I.1. Image data. We next validate our KD enhancements on the image classification dataset CIFAR-10 ( Krizhevsky & Hinton, 2009) . We pair a residual network (ResNet-8) student with teacher networks of varying depths (ResNet-14/20/32/44/56) (He et al., 2016) . It has been observed that larger and deeper teachers need not yield better students, as the teacher might overfit to the training set (Cho & Hariharan, 2019; Müller et al., 2019) . To induce this overfitting, we turn off data augmentation (random horizontal flipping and cropping). We compare students trained with Vanilla KD and Enhanced KD with and without loss correction in Fig. 2 . We find that cross-fitting consistently reduces the effect of teacher overfitting with largest impact realized for the deepest models. This effect is most evident in the cross-entropy test loss, where the Vanilla KD student incurs significantly larger loss than the cross-fitted student. For both accuracy and test loss, employing loss correction on top of cross-fitting provides an additional small performance boost. Effect of the loss correction hyperparameter α. Our hyperparameter α controls the tradeoff between bias and variance in loss correction. When α is very small, the objective is close to the vanilla KD objective. When α is large, the objective is closer to the Neyman-orthogonal loss. In Figure 3 , we show the effect of varying α, with ResNet-8 as the student and ResNet-20 as the teacher, on the CIFAR-10 dataset. 
Large values of α lead to high variance and thus lower test accuracy. Intermediate values of α improve on both the Vanilla KD objective, which corresponds to α = 0, and the orthogonal objective (α = ∞). The test accuracy drops sharply beyond some threshold of α as the variance becomes too high (due to the terms $q_{\hat p}(x) = \mathrm{diag}\big(\tfrac{1}{\hat p_1(x)}, \dots, \tfrac{1}{\hat p_K(x)}\big)$), causing training to become unstable.
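To make the role of α concrete, the following is a minimal sketch of the per-example computation, assuming the single-sample objective $\sum_j v_j^2 (y_j - \hat p_j)^2 + \alpha (1/\hat p_j - v_j)^2$. Because this objective is separable and quadratic, setting its derivative to zero gives $v_j = \alpha / \big(\hat p_j ((y_j - \hat p_j)^2 + \alpha)\big)$. The function names and the example probabilities are hypothetical, and the released code may organize this computation differently.

```python
import math

def correction_weights(p_hat, y, alpha):
    """Per-class minimizer of v_j^2 (y_j - p_j)^2 + alpha * (1/p_j - v_j)^2.

    Assumes every p_j > 0 (e.g., clipped teacher probabilities).
    alpha = 0 is treated as "no correction" (v = 0), i.e., Vanilla KD.
    """
    if alpha <= 0:
        return [0.0 for _ in p_hat]
    return [(alpha / p) / ((yj - p) ** 2 + alpha) for p, yj in zip(p_hat, y)]

def corrected_labels(p_hat, y, alpha):
    """Square-loss targets log(p_j) + v_j * (y_j - p_j) for gamma = diag(v)."""
    v = correction_weights(p_hat, y, alpha)
    return [math.log(p) + vj * (yj - p) for p, yj, vj in zip(p_hat, y, v)]

# Hypothetical teacher probabilities and one-hot label for a single example.
p_hat = [0.7, 0.2, 0.1]
y = [1.0, 0.0, 0.0]
for alpha in (0.0, 0.05, 1e6):
    print(alpha, corrected_labels(p_hat, y, alpha))
```

With α = 0 the targets reduce to $\log \hat p$, and as α grows $v_j$ approaches $1/\hat p_j$, so the correction term $(y_j - \hat p_j)/\hat p_j$ can become very large when $\hat p_j$ is small, which is consistent with the instability described above.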

6. CONCLUSION

We developed a new analysis of knowledge distillation under the lens of semiparametric inference. By framing the KD process as learning with plug-in estimation in the presence of nuisance, we obtained new generalization bounds for distillation and new lower bounds highlighting the susceptibility of KD to teacher overfitting and underfitting. To address these failure modes, we introduced two complementary KD enhancements, cross-fitting and loss correction, which improve student performance both in theory and in practice.

Past work has shown that augmenting the student training set with synthetic data from a generative model (e.g., a generative adversarial network (Liu et al., 2018) or MUNGE (Bucila et al., 2006)) often leads to improved student performance. A natural next step is to prove an analogue of Thm. 5 for synthetic augmentation to understand when this strategy successfully mitigates the impact of teacher overfitting. In addition, two tantalizing open questions are, first, whether other techniques from semiparametric inference, such as targeted maximum likelihood (Van Der Laan & Rubin, 2006), can be used to improve KD performance and, second, whether a semiparametric perspective can explain the surprising success of self-distillation (Furlanello et al., 2018) and noisy student training (Xie et al., 2020) through which students routinely outperform their teachers.

A EXTENDED LITERATURE REVIEW

We point the interested reader to Gou et al. (2020) for a sweeping survey of the many developments in knowledge distillation over the past half decade. In addition to the references discussing theoretical aspects of knowledge distillation provided in Sec. 1, we highlight here a number of empirical investigations of why distillation works. Cho & Hariharan (2019) show that larger teacher models do not necessarily improve the performance of student models, as parsimonious student models are not able to mimic the teacher model. They suggest early stopping when training large teacher neural networks as a means of regularization. Cheng et al. (2020) demonstrate that, when applied to image data, distillation allows the student neural net to learn multiple visual concepts simultaneously, whereas neural networks trained on raw data learn concepts sequentially. Knowledge distillation has also been used for adversarial attacks (Papernot et al., 2016b; Ross & Doshi-Velez, 2017; Gil et al., 2019; Goldblum et al., 2020), data security (Papernot et al., 2016a; Lopes et al., 2017; Wang et al., 2019), image processing (Li & Hoiem, 2017; Wang et al., 2017; Chen et al., 2018; Li et al., 2017), natural language processing (Nakashole & Flauger, 2017; Mou et al., 2016; Hu et al., 2018; Freitag et al., 2017), and speech processing (Chebotar & Waters, 2016; Lu et al., 2017; Watanabe et al., 2017; Oord et al., 2018; Shen et al., 2018).
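B GLOSSARY

Population risk: $L_D(f,p) \triangleq \mathbb{E}[\ell(Z;f(X),p(X))]$
Empirical risk: $L_n(f,p) \triangleq \mathbb{E}_n[\ell(Z;f(X),p(X))]$
Population optimal student model: $f_0 \triangleq \operatorname{argmin}_{f\in\mathcal{F}} L_D(f,p_0)$
Empirical optimal student model: $\hat f \triangleq \operatorname{argmin}_{f\in\mathcal{F}} L_n(f,\hat p)$
$\|f\|_{p,q}$: $\big\| \|f(X)\|_p \big\|_{L^q} = \mathbb{E}[\|f(X)\|_p^q]^{1/q}$
$\nabla_\phi \ell$: partial derivative of $\ell(z;\phi,\pi)$ with respect to the second input
$\nabla_\pi \ell$: partial derivative of $\ell(z;\phi,\pi)$ with respect to the third input
$\nabla_{\phi\pi} \ell$: $[\nabla_{\phi\pi}\ell(z;\phi,\pi)]_{i,j} = \frac{\partial^2}{\partial\phi_j \partial\pi_i}\ell(z;\phi,\pi)$
$q_{f,p}(x)$: $\mathbb{E}[\nabla_{\phi\pi}\ell(Z;f(X),p(X)) \mid X = x]$
$\gamma_{f,p}(x)$: $\mathbb{E}_{U\sim\mathrm{Unif}([0,1])}[q_{f,\,Up+(1-U)p_0}(x)]$
$R(\delta;\mathcal{F})$: localized Rademacher complexity of function class $\mathcal{F}$
$\delta_n$: critical radius
$\gamma$-corrected loss: $\ell_\gamma(z;f(x),p(x)) \triangleq \ell(z;f(x),p(x)) + (y-p(x))^\top \gamma(x) f(x)$
Population $\gamma$-risk: $L_D(f,p,\gamma) \triangleq \mathbb{E}[\ell_\gamma(Z;f(X),p(X))]$
Empirical $\gamma$-risk: $L_n(f,p,\gamma) \triangleq \mathbb{E}_n[\ell_\gamma(Z;f(X),p(X))]$

$$L_n(\hat f,\hat p) - L_n(f_0,\hat p) - \big(L_D(\hat f,\hat p) - L_D(f_0,\hat p)\big) \le O\big(H\delta_{n,\zeta}\,\|\ell_{\hat f,\hat p} - \ell_{f_0,\hat p}\|_{2,2} + H\delta_{n,\zeta}^2\big)$$

with probability at least $1-\zeta$. Moreover, by Cauchy-Schwarz, $\|\ell_{\hat f,\hat p} - \ell_{f_0,\hat p}\|_{2,2} \le \|\mu\|_4 \|\hat f - f_0\|_{2,4}$. By the assumed $\ell_2/\ell_4$ ratio condition we therefore have

$$\mathcal{E}(\hat f,\hat p,\hat\gamma) \le O\big(\delta_{n,\zeta}\, C H \|\mu\|_4 \|\hat f - f_0\|_{2,2} + H\delta_{n,\zeta}^2\big).$$

Plugging this bound into Lemma 4 (which holds irrespective of whether data re-use, sample splitting, or cross-fitting is employed) and applying the arithmetic-geometric mean inequality yields

$$\frac{\sigma}{8}\|\hat f - f_0\|_{2,2}^2 \le \frac{1}{\sigma}\, O\big(\delta_{n,\zeta}^2 C^2 H^2 \|\mu\|_4^2 + \|\gamma_{f_0,\hat p}^\top(\hat p - p_0)\|_{2,2}^2\big).$$

[...] by Taylor's theorem with Lagrange remainder. This non-vanishing student error reflects the non-vanishing critical radius $\delta_n$ of the composite student-teacher function class $\mathcal{G}$ defined in Thm. 1; since the student function class $\mathcal{F}$ has low complexity, the complexity of $\mathcal{G}$ is driven by the highly flexible interpolating teacher.

Next, instantiate the notation of Thm. 5, and consider a student prediction rule $\hat f$ trained via Enhanced KD with SEL loss, $\hat\gamma^{(t)} = 0$, $B = O(1)$, and $\mathcal{F}$ as in (6). The critical radius of $\mathcal{G}(\hat p^{(t)},\hat\gamma^{(t)})$ satisfies $\delta_{n/B} = O(\sqrt{Bk/n})$ by Wainwright (2019, Ex. 13.8). Moreover, each cross-fitted teacher satisfies $\mathbb{E}[\|p_0 - \hat p^{(t)}\|_{2,2}^2] = O(n^{-4/(4+d)})$ by Belkin et al. (2019, Thm. 1), so, by Chebyshev's and Jensen's inequalities, with probability at least $1-\zeta/2$,

$$\|p_0 - \hat p^{(t)}\|_{2,2} \le \mathbb{E}[\|p_0 - \hat p^{(t)}\|_{2,2}] + \sqrt{2B\,\mathrm{Var}(\|p_0 - \hat p^{(t)}\|_{2,2})/\zeta} \le \big(1+\sqrt{2B/\zeta}\big)\sqrt{\mathbb{E}[\|p_0 - \hat p^{(t)}\|_{2,2}^2]} = O(n^{-2/(4+d)})$$

for all $t$. Therefore, Thm. 5 implies that

$$\|\hat f - f_0\|_{2,2}^2 = O\Big(\frac{1}{n} + \frac{1}{B}\sum_{t=1}^B \|\gamma_{f_0,\hat p^{(t)}}^\top(\hat p^{(t)} - p_0)\|_{2,2}^2\Big) = O\Big(\frac{1}{n} + \frac{1}{B}\sum_{t=1}^B \big\|\mathrm{diag}\big(\tfrac{1}{\hat p^{(t)}}\big)(\hat p^{(t)} - p_0)\big\|_{2,2}^2\Big) = O\Big(\frac{1}{n} + \frac{1}{B\epsilon^2}\sum_{t=1}^B \|\hat p^{(t)} - p_0\|_{2,2}^2\Big) = O(n^{-4/(4+d)})$$

with probability at least $1-\zeta$.

F PROOF OF LEMMA 4: ALGORITHM-AGNOSTIC ANALYSIS

First we define for any functional $L(f)$ the Fréchet derivative as

$$D_f L(f)[\nu] = \frac{\partial}{\partial t} L(f + t\nu)\Big|_{t=0}.$$

When $L$ is an operator of the form $\mathbb{E}[g(f(X))]$, then $D_f L(f)[\nu] = \mathbb{E}[\nabla g(f(X))^\top \nu(X)]$. By the $\sigma$-strong convexity of $L_D$ (footnote 6), we have that

$$L_D(\hat f,\hat p,\hat\gamma) \ge L_D(f_0,\hat p,\hat\gamma) + D_f L_D(f_0,\hat p,\hat\gamma)[\hat f - f_0] + \frac{\sigma}{2}\|\hat f - f_0\|_{2,2}^2.$$

Furthermore, our excess risk assumption and the optimality of $f_0$ give us

$$\frac{\sigma}{2}\|\hat f - f_0\|_{2,2}^2 \le \underbrace{L_D(\hat f,\hat p,\hat\gamma) - L_D(f_0,\hat p,\hat\gamma)}_{\text{excess risk of } \hat f} - D_f L_D(f_0,\hat p,\hat\gamma)[\hat f - f_0]$$
$$\overset{(a)}{\le} \mathcal{E}(\hat f,\hat p,\hat\gamma) - \underbrace{D_f L_D(f_0,p_0,\hat\gamma)[\hat f - f_0]}_{\ge 0 \text{ by optimality of } f_0} + D_f\big(L_D(f_0,p_0,\hat\gamma) - L_D(f_0,\hat p,\hat\gamma)\big)[\hat f - f_0].$$

By Taylor's theorem with integral remainder,

$$\mathbb{E}[\langle \nabla_\phi\ell(W;f_0(x),p_0(x)) - \nabla_\phi\ell(W;f_0(x),\hat p(x)),\ \hat f(x) - f_0(x)\rangle \mid X = x] = (p_0(x) - \hat p(x))^\top \gamma_{f_0,\hat p}(x)\,(\hat f(x) - f_0(x)) \qquad (7)$$

whenever $\nabla_{\phi\pi}\ell$ is well-defined. We can now invoke the expansion (7) and Cauchy-Schwarz to obtain the bound

$$D_f\big(L_D(f_0,p_0,\hat\gamma) - L_D(f_0,\hat p,\hat\gamma)\big)[\hat f - f_0]$$
$$= \mathbb{E}[\langle \nabla_\phi\ell(W;f_0(X),p_0(X)) - \nabla_\phi\ell(W;f_0(X),\hat p(X)),\ \hat f(X) - f_0(X)\rangle] - \mathbb{E}[(p_0(X) - \hat p(X))^\top \hat\gamma(X)(\hat f(X) - f_0(X))]$$
$$= \mathbb{E}[(p_0(X) - \hat p(X))^\top (\gamma_{f_0,\hat p}(X) - \hat\gamma(X))(\hat f(X) - f_0(X))]$$
$$\le \mathbb{E}[\|(p_0(X) - \hat p(X))^\top (\gamma_{f_0,\hat p}(X) - \hat\gamma(X))\|_2\, \|\hat f(X) - f_0(X)\|_2]$$
$$\le \|(p_0 - \hat p)^\top(\gamma_{f_0,\hat p} - \hat\gamma)\|_{2,2}\, \|\hat f - f_0\|_{2,2}.$$

Thus, combining all the above inequalities:

$$\frac{\sigma}{2}\|\hat f - f_0\|_{2,2}^2 \le \mathcal{E}(\hat f,\hat p,\hat\gamma) + \|(\hat p - p_0)^\top(\gamma_{f_0,\hat p} - \hat\gamma)\|_{2,2}\, \|\hat f - f_0\|_{2,2}.$$

By an AM-GM inequality, for all $a,b \ge 0$: $a\cdot b \le \frac{1}{2}\big(\frac{2}{\sigma}a^2 + \frac{\sigma}{2}b^2\big)$.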
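Applying this to the product of norms on the RHS and re-arranging yields

$$\frac{\sigma}{4}\|\hat f - f_0\|_{2,2}^2 \le \mathcal{E}(\hat f,\hat p,\hat\gamma) + \frac{1}{\sigma}\|(\hat p - p_0)^\top(\gamma_{f_0,\hat p} - \hat\gamma)\|_{2,2}^2.$$

To get the final inequality, observe that

$$\|(\hat p - p_0)^\top(\gamma_{f_0,\hat p} - \hat\gamma)\|_{2,2}^2 \le 2\|(\hat p - p_0)^\top(q_{f_0,\hat p} - \hat\gamma)\|_{2,2}^2 + 2\|(\hat p - p_0)^\top(\gamma_{f_0,\hat p} - q_{f_0,\hat p})\|_{2,2}^2.$$

Moreover, by the boundedness of the third derivative, we have

$$\|(\hat p - p_0)^\top(\gamma_{f_0,\hat p} - q_{f_0,\hat p})\|_{2,2}^2 \le \mathbb{E}[\|\hat p(X) - p_0(X)\|_2^2\, \|\gamma_{f_0,\hat p}(X) - q_{f_0,\hat p}(X)\|_2^2] \le \mathbb{E}[\|\hat p(X) - p_0(X)\|_2^2\, M^2 k\, \|\hat p(X) - p_0(X)\|_2^2] \le M^2 k\, \|\hat p - p_0\|_{2,4}^4.$$

Combining all the above yields the final bound.

G PROOF OF THM. 5: CROSS-FITTED ERM ANALYSIS

Let $L_{n,t}$ denote the empirical loss over the samples in the $t$-th fold and $\hat p^{(t)},\hat\gamma^{(t)}$ the nuisance functions used on the samples in the $t$-th fold. For any $t \in [B]$ and conditional on $\hat p^{(t)},\hat\gamma^{(t)}$, suppose that $\delta_n$ upper bounds the critical radius of the function class $\mathcal{G}(\hat p^{(t)},\hat\gamma^{(t)})$. Then, by Lemma 11 of Foster & Syrgkanis (2019) (footnote 7), if we write $\ell_{t,f}(z) = \ell_{\hat\gamma^{(t)}}(z;f(x),\hat p^{(t)}(x))$, w.p. $1-\zeta$:

$$L_{n,t}(\hat f,\hat p^{(t)},\hat\gamma^{(t)}) - L_{n,t}(f_0,\hat p^{(t)},\hat\gamma^{(t)}) - \big(L_D(\hat f,\hat p^{(t)},\hat\gamma^{(t)}) - L_D(f_0,\hat p^{(t)},\hat\gamma^{(t)})\big) \le O\big(H\delta_{n/B,\zeta}\,\|\ell_{t,\hat f} - \ell_{t,f_0}\|_{2,2} + H\delta_{n/B,\zeta}^2\big).$$

Moreover, by the definition of cross-fitted ERM,

$$\frac{1}{B}\sum_{t=1}^B L_{n,t}(\hat f,\hat p^{(t)},\hat\gamma^{(t)}) - L_{n,t}(f_0,\hat p^{(t)},\hat\gamma^{(t)}) \le 0.$$

Thus we have that w.p. $1-\zeta B$:

$$\frac{1}{B}\sum_{t=1}^B L_D(\hat f,\hat p^{(t)},\hat\gamma^{(t)}) - L_D(f_0,\hat p^{(t)},\hat\gamma^{(t)}) \le O\Big(H\delta_{n/B,\zeta}\,\frac{1}{B}\sum_{t=1}^B \|\ell_{t,\hat f} - \ell_{t,f_0}\|_{2,2} + H\delta_{n/B,\zeta}^2\Big).$$

Moreover, if we let $\mu(z) = \sup_{\phi,t}\|\nabla_\phi\ell(z;\phi,\hat p^{(t)}(x))\|_2$, then by the Cauchy-Schwarz inequality:

$$\|\ell_{t,f} - \ell_{t,f_0}\|_{2,2} \le \|\mu\|_4\|f - f_0\|_{2,4} + \sqrt{\mathbb{E}\big[\big((Y - \hat p^{(t)}(X))^\top\hat\gamma^{(t)}(X)(f(X) - f_0(X))\big)^2\big]}$$
$$\le \|\mu\|_4\|f - f_0\|_{2,4} + \sqrt{\mathbb{E}\big[\|(Y - \hat p^{(t)}(X))^\top\hat\gamma^{(t)}(X)\|_2^2\, \|f(X) - f_0(X)\|_2^2\big]}$$
$$\le \Big(\|\mu\|_4 + \mathbb{E}\big[\|(Y - \hat p^{(t)}(X))^\top\hat\gamma^{(t)}(X)\|_2^4\big]^{1/4}\Big)\|f - f_0\|_{2,4}.$$

If we further assume that the function class $\mathcal{F}$ satisfies the $\ell_2/\ell_4$ condition

$$\sup_{f\in\mathcal{F}} \frac{\|f - f_0\|_{2,4}}{\|f - f_0\|_{2,2}} \le C,$$

then w.p. $1-\zeta$:

$$\frac{1}{B}\sum_{t=1}^B \mathcal{E}(\hat f,\hat p^{(t)},\hat\gamma^{(t)}) \le O\Big(H\delta_{n/B,\zeta/B}\,\frac{1}{B}\sum_{t=1}^B C\Big(\|\mu\|_4 + \mathbb{E}\big[\|(Y - \hat p^{(t)}(X))^\top\hat\gamma^{(t)}(X)\|_2^4\big]^{1/4}\Big)\|\hat f - f_0\|_{2,2} + H\delta_{n/B,\zeta/B}^2\Big).$$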
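Applying Lemma 4 for any $\hat p^{(t)},\hat\gamma^{(t)}$ and averaging the final inequality we get:

$$\frac{\sigma}{4}\|\hat f - f_0\|_{2,2}^2 \le \frac{1}{B}\sum_{t=1}^B \Big(\mathcal{E}(\hat f,\hat p^{(t)},\hat\gamma^{(t)}) + \frac{1}{\sigma}\|(\gamma_{f_0,\hat p^{(t)}} - \hat\gamma^{(t)})^\top(\hat p^{(t)} - p_0)\|_{2,2}^2\Big).$$

Plugging the bound above into Lemma 4 and applying the AM-GM inequality and Jensen's inequality yields:

$$\frac{\sigma}{8}\|\hat f - f_0\|_{2,2}^2 \le \frac{1}{\sigma}\,O\Big(\delta_{n/B,\zeta/B}^2 C^2 H^2\Big(\|\mu\|_4^2 + \frac{1}{B}\sum_{t=1}^B \sqrt{\mathbb{E}\big[\|(Y - \hat p^{(t)}(X))^\top\hat\gamma^{(t)}(X)\|_2^4\big]}\Big)\Big) + \frac{1}{\sigma}\,O\Big(\frac{1}{B}\sum_{t=1}^B \|(\gamma_{f_0,\hat p^{(t)}} - \hat\gamma^{(t)})^\top(\hat p^{(t)} - p_0)\|_{2,2}^2\Big).$$

H PROOF OF THM. 6: BIASED SGD ANALYSIS

Below, for any integer $s$, we define the operator norm of any vector $v \in \mathbb{R}^s$ and any tensor $T$ operating on $\mathbb{R}^s$ as $\|v\|_{op} \triangleq \|v\|_2$ and $\|T\|_{op} \triangleq \sup_{v:\|v\|_2=1}\|T[v]\|_{op}$. Recall the definition

$$\nabla(W;\theta,p,\gamma) = \nabla_\theta f_\theta(X)^\top \nabla_\phi\ell_\gamma(W;f_\theta(X),p(X)) = \nabla_\theta f_\theta(X)^\top\big(\nabla_\phi\ell(W;f_\theta(X),p(X)) + \gamma(X)^\top(Y - p(X))\big).$$

Observe that since $\mathbb{E}[Y \mid X = x] = p_0(x)$, we can write for any $\gamma$:

$$L(\theta;p_0) = \mathbb{E}[\ell(W;f_\theta(X),p_0(X)) + (Y - p_0(X))^\top\gamma(X) f_\theta(X)] = \mathbb{E}[\ell_\gamma(W;f_\theta(X),p_0(X))].$$

Thus we also have that, for all $\theta,\gamma$: $\nabla_\theta L(\theta;p_0) = \mathbb{E}[\nabla(W;\theta,p_0,\gamma)]$. Given this observation, we can decompose the gradient that is used in our SGD algorithm into a bias and a variance component, when viewed from the perspective of a biased SGD algorithm for the population oracle loss:

$$\nabla(W;\theta,p,\gamma) = \nabla_\theta L(\theta;p_0) + \underbrace{\mathbb{E}[\nabla(W;\theta,p,\gamma)] - \mathbb{E}[\nabla(W;\theta,p_0,\gamma)]}_{b(\theta,p,\gamma)} + \underbrace{\nabla(W;\theta,p,\gamma) - \mathbb{E}[\nabla(W;\theta,p,\gamma)]}_{n(W;\theta,p,\gamma)}.$$

The following two lemmas bound the gradient bias and noise terms.

Lemma 7 (Gradient bias). If $\sup_{x,\phi,\pi}\|\mathbb{E}[\nabla_{\pi\pi\phi}\ell(W;\phi,\pi) \mid X = x]\|_{op} \le M$, then for any parameter vector $\theta$ and functions $p$ and $\gamma$, we have

$$b(\theta,p,\gamma) = \mathbb{E}[\nabla_\theta f_\theta(X)^\top(\gamma_{f_\theta,p}(X) - \gamma(X))^\top(p(X) - p_0(X))],$$
$$\|b(\theta,p,\gamma)\|_2 \le \|\nabla_\theta f_\theta^\top(\gamma_{f_\theta,p} - \gamma)^\top(p - p_0)\|_{2,2}, \quad\text{and}$$
$$\|b(\theta,p,\gamma)\|_2 \le \|\nabla_\theta f_\theta^\top(q_{f_\theta,p} - \gamma)^\top(p - p_0)\|_{2,2} + \frac{M}{2}\|\nabla_\theta f_\theta\|_{F,2}\,\|p - p_0\|_{2,4}^2.$$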
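Proof. By Taylor's theorem with integral remainder and Lagrange remainder respectively, the SGD bias for each parameter $i$ takes the form

$$b_i(\theta,p,\gamma) = \mathbb{E}[\nabla_i(W;\theta,p,\gamma)] - \mathbb{E}[\nabla_i(W;\theta,p_0,\gamma)]$$
$$= \mathbb{E}[\langle\nabla_\phi\ell_\gamma(W;f_\theta(X),p(X)) - \nabla_\phi\ell_\gamma(W;f_\theta(X),p_0(X)),\ \nabla_{\theta_i} f_\theta(X)\rangle]$$
$$= \mathbb{E}[(p(X) - p_0(X))^\top(\gamma_{f_\theta,p}(X) - \gamma(X))\nabla_{\theta_i} f_\theta(X)]$$
$$= \mathbb{E}[(p(X) - p_0(X))^\top(q_{f_\theta,p}(X) - \gamma(X))\nabla_{\theta_i} f_\theta(X)] + \frac{1}{2}\mathbb{E}\big[\nabla_{\pi\pi\phi}\ell(W;f_\theta(X),\bar p(X))[\nabla_{\theta_i} f_\theta(X),\ p(X) - p_0(X),\ p(X) - p_0(X)]\big]$$

for some convex combination $\bar p(X)$ of $p(X)$ and $p_0(X)$. Furthermore, our operator norm assumption and Cauchy-Schwarz imply

$$|b_i(\theta,p,\gamma)| \le |\mathbb{E}[(p(X) - p_0(X))^\top(q_{f_\theta,p}(X) - \gamma(X))\nabla_{\theta_i} f_\theta(X)]| + \frac{M}{2}\mathbb{E}[\|\nabla_{\theta_i} f_\theta(X)\|_2\, \|p(X) - p_0(X)\|_2^2]$$
$$\le |\mathbb{E}[(p(X) - p_0(X))^\top(q_{f_\theta,p}(X) - \gamma(X))\nabla_{\theta_i} f_\theta(X)]| + \frac{M}{2}\|\nabla_{\theta_i} f_\theta\|_{2,2}\,\|p - p_0\|_{2,4}^2.$$

Thus, by the triangle inequality and Jensen's inequality we find that

$$\|b(\theta,p,\gamma)\|_2 \le \|\nabla_\theta f_\theta^\top(\gamma_{f_\theta,p} - \gamma)^\top(p - p_0)\|_{2,2}$$

and

$$\|b(\theta,p,\gamma)\|_2 \le \|\nabla_\theta f_\theta^\top(q_{f_\theta,p} - \gamma)^\top(p - p_0)\|_{2,2} + \frac{M}{2}\|\nabla_\theta f_\theta\|_{F,2}\,\|p - p_0\|_{2,4}^2.$$

Lemma 8 (Gradient Variance). Define [...]. For any parameter $\theta$ and functions $p$ and $\gamma$,

$$\mathbb{E}[\|n(W;\theta,p,\gamma)\|_2^2] \le \sigma_0(\theta) + \mathbb{E}[\|\nabla_\theta f_\theta(X)^\top\gamma(X)^\top(Y - p_0(X))\|_2^2] + \|\nabla_\theta f_\theta^\top(\gamma_{f_\theta,p} - \gamma)^\top(p - p_0)\|_{2,2}^2.$$

We show the full results for all 5 of the datasets in Figs. 4 and 5.

We use SGD with initial learning rate 0.1, momentum 0.9, and batch size 128 to train for 200 epochs. We use the standard learning rate decay schedule, where the learning rate is divided by 5 at epochs 60, 120, and 160. For loss correction, we select the value of the hyperparameter α that yields the highest accuracy on a held-out validation set. For cross-fitting, we use 10 folds.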
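The step decay described above can be sketched as a small function of the epoch index. This is only an illustration under stated assumptions: epochs are taken to be 0-indexed, the reduced rate is assumed to apply from each milestone epoch onward, and the function name `learning_rate` is hypothetical rather than taken from the released code.

```python
def learning_rate(epoch, base_lr=0.1, milestones=(60, 120, 160), factor=5.0):
    """Step schedule: divide the rate by `factor` once each milestone is reached.

    Assumes 0-indexed epochs; the drop takes effect at the milestone epoch itself.
    """
    drops = sum(1 for m in milestones if epoch >= m)
    return base_lr / (factor ** drops)

# Learning rate used in each of the 200 training epochs.
schedule = [learning_rate(e) for e in range(200)]
print(schedule[0], schedule[60], schedule[120], schedule[160])
```

In a PyTorch training loop, the same behavior would typically be obtained with `torch.optim.lr_scheduler.MultiStepLR` using `milestones=[60, 120, 160]` and `gamma=0.2`, though whether the authors used that scheduler is not stated here.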



1. These loss functions do not depend on the ground-truth label $y$, but we use the augmented notation $\ell(z;f(x),p(x))$ to accommodate the enhanced distillation losses presented in Sec. 4.
2. In fact, under the Lipschitz condition alone and using contraction lemma arguments as in Foster & Syrgkanis (2019, Lem. 11), one can derive a version of Thm. 5 in which the upper bound depends only on the critical radius of the function class $\{r(f - f_0) : f \in \mathcal{F}, r \in [0,1]\}$, which solely depends on the function space of the student.
3. Balancing the bias and variance terms (3) of Thm. 5 yields a similar objective.
4. An alternative estimate that performs slightly worse is $\|v(x) \odot \sqrt{\hat p(x)(1-\hat p(x))}\|_2^2$.
5. We apply Foster & Syrgkanis (2019, Lem. 11) with $L_g = g$ for $g \in \mathcal{G}$ with $g^* = 0$. Then we instantiate the concentration inequality for the choice $g = \ell_{\hat f,\hat p} - \ell_{f_0,\hat p} \in \mathcal{G}$.
6. Notably this strong convexity assumption can be relaxed to $\mathbb{E}[\nabla_\phi\ell(W;f_0(X),p_0(X))^\top(\hat f(X) - f_0(X))] \ge 0$.
7. We apply the lemma with $L_g = g$ and $g \in \mathcal{G}(\hat p^{(t)},\hat\gamma^{(t)})$ and $g^* = 0$. Then we instantiate the concentration inequality with $g = \ell_{t,\hat f} - \ell_{t,f_0} \in \mathcal{G}(\hat p^{(t)},\hat\gamma^{(t)})$.

& Rubin, 2006), can be used to improve KD performance and, second, whether a semiparametric perspective can explain the surprising success of self-distillation (Furlanello et al., 2018) and noisy student training (Xie et al., 2020) through which students routinely outperform their teachers.



Figure 1: For random forest students and teachers, cross-fitting improves student performance when the teacher overfits, while loss correction improves student performance when the teacher underfits.

Figure 2: On CIFAR-10 with ResNet students and teachers, cross-fitting reduces the effect of teacher overfitting, and loss correction yields an additional small performance boost. Here, the test loss is cross-entropy.

Figure 3: On CIFAR-10 with ResNet students and teachers, large values of the loss correction hyperparameter α (corresponding to the orthogonal loss correction) lead to large variance and training instability, while intermediate values improve upon cross-fit KD without loss correction (α = 0). Here, the test loss is cross-entropy.

C PROOF OF THM. 1: VANILLA DISTILLATION ANALYSIS

Introduce the shorthand $\ell_{f,p}(z) = \ell(z;f(x),p(x))$. Since $\delta_n$ upper bounds the critical radius of the function class $\mathcal{G}$, the localized Rademacher analysis of Foster & Syrgkanis (2019, Lem. 11) implies (footnote 5)

Figure 4: Tabular random forest distillation with varying student complexity.

The critical radius of a class $\mathcal{F}$, taking values in $[-H,H]$, is the smallest positive solution $\delta_n$ to the inequality $R(\delta;\mathcal{F}) \le \delta^2/H$.

Theorem 1 (Vanilla KD analysis). Suppose $f_0$ belongs to a convex set $\mathcal{F}$ satisfying the $\ell_2/\ell_4$ ratio condition $\sup_{f\in\mathcal{F}} \|f - f_0\|_{2,4}/\|f - f_0\|_{2,2} \le C$ and that the teacher estimates $\hat p \in \mathcal{P}$ from the same dataset used to train the student. Let $\delta_{n,\zeta} = \delta_n + c_0$

Glossary of notation


D PROOF OF PROP. 2: IMPACT OF TEACHER UNDERFITTING ON VANILLA DISTILLATION

Suppose that $p_0$ does not vary with $x$ and, for known $\epsilon > 0$, belongs to the set $\mathcal{P} = \{p : p_j(x) \in [\epsilon,1],\ \forall x \in \mathcal{X},\ j \in [k]\}$. As all quantities in this proof are independent of $x$, we will omit the dependence on $x$ whenever convenient.

Consider the constant teacher estimate $\hat p = \frac{\bar y}{1+\lambda} \vee \epsilon$ obtained via ridge regression with regularization strength $\lambda \le 1$ and $\bar y \triangleq \frac{1}{n}\sum_{i=1}^n y_i$. A constant student prediction rule trained via Vanilla KD with SEL loss yields $\hat f(x) = \log(\hat p)$.

Suppose that, unbeknownst to the teacher and student, the true $p_0$ satisfies the more stringent condition $p_{0,j} \ge 2\epsilon$ for all $j \in [k]$. Then the student satisfies [...] by the concavity of the logarithm and the choice $\lambda \le 1$. Since [...] (Bernstein, 1946), we have [...] $= 1$ with probability 1 by the law of the iterated logarithm, [...] now yields the first two advertised claims.

The final claim follows directly from Thm. 5 with $B = O(1)$ as $\hat\gamma^{(t)} = \gamma_{f_0,\hat p}$ and the critical radius of $\mathcal{G}(\hat p^{(t)},\hat\gamma^{(t)})$ satisfies $\delta_{n/B} = O(\sqrt{Bk/n})$ by Wainwright (2019, Ex. 13.8).

E PROOF OF PROP. 3: IMPACT OF TEACHER OVERFITTING ON VANILLA DISTILLATION

Suppose that $p_0$ has Lipschitz gradient and, for known $\epsilon > 0$, belongs to the set [...]. Suppose moreover that $X \in \mathbb{R}^d$ has Lebesgue density bounded away from 0 and ∞ and that [...] $\mathbb{E}[(1 - p_{0,j}(X))/p_{0,j}(X)]$ for each $j$. Consider the teacher estimates $\hat p_j(x) = \max(\epsilon, \tilde p_j(x))$ for $\tilde p$ the Nadaraya-Watson kernel smoothing estimator (Nadaraya, 1964; Watson, 1964) [...]. By Belkin et al. (2019, Thm. 1), the teacher satisfies $\mathbb{E}[\|p_0 - \hat p\|_{2,2}^2] = O(n^{-4/(4+d)})$.

Now instantiate the notation of Thm. 1, and consider a student trained to learn a constant prediction rule via Vanilla KD with the SEL loss and [...] (6). Since $\hat p$ exactly interpolates the observed labels (i.e., $\hat p(x_i) = y_i$), the critical radius of the teacher-student function class $\mathcal{G}$ satisfies $\delta_n = \Omega(1)$.
Moreover, since the student only has access to the teacher's training set probabilities, its estimate [...]

[...] $p(X) - p_0(X)]]$ for some convex combination $\bar p(X)$ of $p(X)$ and $p_0(X)$.

We begin by bounding the target expectation using Cauchy-Schwarz [...] We next employ the law of total variance to rewrite the variance terms: [...] Finally, we control $\mathrm{Var}[Z_i]$ using Cauchy-Schwarz [...]

The two claims of Thm. 6 now follow from Theorems 2 and 3 of Ajalloeian & Stich (2020) respectively, with the parameters $\sigma^2$ and $\zeta$ instantiated with quantities $\sigma^2(\gamma)$ and $\zeta(\gamma)$ of Lemmas 7 and 8.

I.1 TABULAR DATA

We use cross-fitting with 10 folds. The student is trained using the SEL loss with clipped teacher class probabilities $\max(\hat p(x), \epsilon)$ for $\epsilon = 10^{-3}$. The α hyperparameter of the loss correction was chosen by cross-validation with 5 folds. We repeat the experiments 5 times to measure the mean and standard deviation.

For the overfitting experiment, we use a random forest with 500 trees as the teacher and a random forest with 1-40 trees as the student.

We also evaluate the impact of teacher underfitting by limiting the teacher's maximum tree depth (from 1 to 20). Lower depth corresponds to greater underfitting. The teacher has 100 trees, and the student has 10 trees. For all of the datasets, loss correction successfully mitigates the teacher's underfitting and thus improves the student's performance. The effect is most pronounced when the teacher underfits more heavily (has lower tree depth).
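The cross-fitting and clipping steps described above can be illustrated with a small, self-contained sketch. It does not reproduce the random forest setup: the teacher here is a hypothetical lookup-table model that memorizes its training labels, chosen only because it overfits in an obvious way. The names `fit_memorizing_teacher`, `cross_fit_probabilities`, and `sel_targets` and the toy data are assumptions made for illustration, and the fold assignment by index modulo the number of folds is one simple choice among many.

```python
import math

def fit_memorizing_teacher(xs, ys):
    """Hypothetical interpolating teacher for binary labels: returns the stored
    label for a seen input and the training base rate for an unseen one."""
    table = dict(zip(xs, ys))
    base_rate = sum(ys) / len(ys)
    return lambda x: table.get(x, base_rate)

def cross_fit_probabilities(xs, ys, fit, n_folds):
    """Predict each point with a teacher fit on the other folds only."""
    n = len(xs)
    preds = [None] * n
    for t in range(n_folds):
        held_out = [i for i in range(n) if i % n_folds == t]
        train = [i for i in range(n) if i % n_folds != t]
        teacher = fit([xs[i] for i in train], [ys[i] for i in train])
        for i in held_out:
            preds[i] = teacher(xs[i])
    return preds

def sel_targets(probs, eps=1e-3):
    """Log of clipped teacher probabilities, the regression targets of the SEL loss."""
    return [math.log(max(p, eps)) for p in probs]

xs = list(range(20))                      # distinct toy inputs
ys = [1, 0, 1, 1, 0, 1, 0, 1, 1, 1] * 2   # binary labels
in_sample = [fit_memorizing_teacher(xs, ys)(x) for x in xs]
cross_fit = cross_fit_probabilities(xs, ys, fit_memorizing_teacher, n_folds=10)
print(in_sample[:4], [round(p, 3) for p in cross_fit[:4]])
```

In this toy example the in-sample teacher probabilities simply reproduce the labels, so a student distilled from them gains nothing over training on the labels directly, whereas the cross-fitted probabilities come from teachers that never saw the point being predicted.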

