KNOWLEDGE DISTILLATION AS SEMIPARAMETRIC INFERENCE

Abstract

A popular approach to model compression is to train an inexpensive student model to mimic the class probabilities of a highly accurate but cumbersome teacher model. Surprisingly, this two-step knowledge distillation process often leads to higher accuracy than training the student directly on labeled data. To explain and enhance this phenomenon, we cast knowledge distillation as a semiparametric inference problem with the optimal student model as the target, the unknown Bayes class probabilities as nuisance, and the teacher probabilities as a plug-in nuisance estimate. By adapting modern semiparametric tools, we derive new guarantees for the prediction error of standard distillation and develop two enhancements, cross-fitting and loss correction, to mitigate the impact of teacher overfitting and underfitting on student performance. We validate our findings empirically on both tabular and image data and observe consistent improvements from our knowledge distillation enhancements.

1. INTRODUCTION

Knowledge distillation (KD) (Craven & Shavlik, 1996; Breiman & Shang, 1996; Bucila et al., 2006; Li et al., 2014; Ba & Caruana, 2014; Hinton et al., 2015) is a widely used model compression technique that enables the deployment of highly accurate predictive models on devices such as phones, watches, and virtual assistants (Stock et al., 2020). KD operates by training a compressed student model to mimic the predicted class probabilities of an expensive, high-quality teacher model. Remarkably, and across a wide variety of domains (Hinton et al., 2015; Sanh et al., 2019; Jiao et al., 2019; Liu et al., 2018; Tan et al., 2018; Fakoor et al., 2020), this two-step process often leads to higher accuracy than training the student directly on the raw labeled dataset.

While the practice of KD is now well developed, a general theoretical understanding of its successes and failures is still lacking. As we detail below, a number of authors have argued that the success of KD lies in the more precise "soft labels" provided by the teacher's predicted class probabilities. Recently, Menon et al. (2020) observed that these teacher probabilities can serve as a proxy for the Bayes probabilities (i.e., the true class probabilities) and that the closer the teacher and Bayes probabilities, the better the student's performance should be. Building on this observation, we cast KD as a plug-in approach to semiparametric inference (Kosorok, 2007): that is, we view KD as fitting a student model f in the presence of nuisance (the Bayes probabilities p_0) with the teacher's probabilities p as a plug-in estimate of p_0. This insight allows us to adapt modern tools from semiparametric inference to analyze the error of a distilled student in Sec. 3. Our analysis also reveals two distinct failure modes of KD: one due to teacher overfitting and data reuse and the other due to teacher underfitting from model misspecification or insufficient training. In Sec.
4, we introduce and analyze two complementary KD enhancements that correct for these failures: cross-fitting, a popular technique from semiparametric inference (see, e.g., Chernozhukov et al., 2018), mitigates teacher overfitting through data partitioning, while loss correction mitigates teacher underfitting by reducing the bias of the plug-in estimate p. The latter enhancement was inspired by the orthogonal machine learning approach to semiparametric inference (Chernozhukov et al., 2018; Foster & Syrgkanis, 2019), which suggests a particular adjustment to the teacher's log probabilities. We argue in Sec. 4 that this orthogonal correction minimizes the teacher bias but often at the cost of unacceptably large variance. Our proposed correction avoids this variance explosion by balancing the bias and variance terms in our generalization bounds. In Sec. 5, we complement our theoretical analysis with a pair of experiments demonstrating the value of our enhancements on six real classification problems. On five real tabular datasets, cross-fitting and loss correction improve student performance by up to 4% AUC over vanilla KD. Furthermore, on CIFAR-10 (Krizhevsky & Hinton, 2009), a benchmark image classification dataset, our enhancements improve vanilla KD accuracy by up to 1.5% when the teacher model overfits.

Related work. Since we cannot review the vast literature on KD in its entirety, we point the interested reader to Gou et al. (2020) for a recent overview of the field. We devote this section to reviewing theoretical advances in the understanding of KD and summarize complementary empirical studies and applications of KD in the extended literature review in App. A. Bu et al. (2020) consider more generic model compression in a rate-distortion framework, where the rate is the size of the student model and the distortion is the difference in excess risk between the teacher and the student. Menon et al. (2020) consider the case of losses for which the population risk is linear in the Bayes class probabilities. They study the distilled empirical risk and the Bayes distilled empirical risk, i.e., the risks computed using the teacher class probabilities and the Bayes class probabilities, respectively, rather than the observed labels. They show that the variance of the Bayes distilled empirical risk is lower than that of the empirical risk. Then, using analysis from Maurer & Pontil (2009) and Bennett (1962), they derive the excess risk of the distilled empirical risk as a function of the ℓ_2 distance between the teacher's class probabilities and the Bayes class probabilities. We depart significantly from Menon et al. (2020) in multiple ways: i) our Thm. 1 allows for the common practice of data re-use, ii) our results cover the standard KD losses SEL and ACE, which are non-linear in p_0, iii) we use localized Rademacher analysis to achieve tight fast rates for standard KD losses, and iv) we use techniques from semiparametric inference to improve upon vanilla KD.
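The cross-fitting enhancement sketched above can be illustrated in a few lines of NumPy. This is a minimal sketch under our own naming, not the paper's implementation: `train_teacher` is a hypothetical stand-in for any routine that fits a teacher and returns a function mapping features to class probabilities.

```python
import numpy as np

def crossfit_soft_labels(X, y, train_teacher, n_folds=2, seed=0):
    """Cross-fitting for KD: each point receives soft labels from a
    teacher trained only on the OTHER folds, so no point is labeled by
    a teacher that saw it during training (mitigating overfitting from
    data reuse).

    train_teacher(X_train, y_train) -> callable mapping features to an
    (m, k) array of predicted class probabilities (hypothetical API).
    """
    n = len(X)
    rng = np.random.default_rng(seed)
    # Random but balanced fold assignment.
    fold = rng.permutation(n) % n_folds
    soft = np.empty_like(y, dtype=float)
    for j in range(n_folds):
        held_out = fold == j
        teacher = train_teacher(X[~held_out], y[~held_out])
        soft[held_out] = teacher(X[held_out])
    return soft  # used in place of p(x) when training the student
```

The student is then trained against `soft` exactly as in vanilla KD; only the provenance of the soft labels changes.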

2. KNOWLEDGE DISTILLATION BACKGROUND

We consider a multiclass classification problem with k classes and n training datapoints z_i = (x_i, y_i) sampled independently from some distribution P. Each feature vector x belongs to a set X, each label vector y ∈ {e_1, ..., e_k} ⊂ {0,1}^k is a one-hot encoding of the class label, and the conditional probability of observing each label is given by the Bayes class probability function p_0(x) = E[Y | X = x]. Our aim is to identify a scoring rule f : X → R^k that minimizes a prediction loss on average under the distribution P.

Knowledge distillation. Knowledge distillation (KD) is a two-step training process in which one first uses a labeled dataset to train a teacher model and then trains a student model to predict the teacher's predicted class probabilities. Typically the teacher model is larger and more cumbersome, while the student is smaller and more efficient. Knowledge distillation was first motivated by model compression (Bucila et al., 2006), i.e., finding compact yet high-performing models for deployment (such as on mobile devices). Several types of loss functions are commonly used to train the student to match the teacher's predicted probabilities. Let p(x) ∈ R^k be the teacher's vector of predicted class probabilities, f(x) ∈ R^k be the student model's output, and [k] := {1, 2, ..., k}. The most popular distillation loss is the cross-entropy between the teacher's probabilities and the student's softmax outputs,

ℓ(f(x), p(x)) = − Σ_{j ∈ [k]} p_j(x) log( exp(f_j(x)) / Σ_{l ∈ [k]} exp(f_l(x)) ),

often applied to temperature-scaled probabilities (Hinton et al., 2015).
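As a concrete illustration, the teacher cross-entropy distillation loss can be computed as follows; this is a minimal NumPy sketch with our own function names, not code from the paper.

```python
import numpy as np

def softmax(f):
    """Row-wise softmax of student scores f with shape (n, k),
    shifted by the row max for numerical stability."""
    z = f - f.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def distillation_loss(f, p, eps=1e-12):
    """Cross-entropy between teacher probabilities p (n, k) and the
    student's softmax probabilities, averaged over the n examples.
    eps guards against log(0)."""
    q = softmax(f)
    return -np.mean(np.sum(p * np.log(q + eps), axis=1))
```

With a uniform teacher (p_j = 1/k) and an all-zero student score vector, the loss equals log k, its minimum for that teacher, since the student's softmax is also uniform.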
A number of papers have argued that the availability of soft class probabilities from the teacher, rather than hard labels, enables improved training of the student model. This was hypothesized, with empirical support, by Hinton et al. (2015). Phuong & Lampert (2019) consider the case in which the teacher is a fixed linear classifier and the student is either a linear model or a deep linear network. They show that the student can learn the teacher perfectly if the number of training examples exceeds the ambient dimension. Vapnik & Izmailov (2015) discuss the setting of learning with privileged information, in which one has additional information at training time that is not available at test time. Lopez-Paz et al. (2015) draw a connection between this setting and KD, arguing that KD is effective because the teacher learns a better representation, allowing the student to learn at a faster rate. They hypothesize that a teacher's class probabilities enable student improvement by indicating how difficult each point is to classify. Tang et al. (2020) argue, using empirical evidence, that label smoothing and reweighting of training examples using the teacher's predictions are key to the success of KD. Mobahi et al. (2020) analyze the case of self-distillation, in which the student and teacher function classes are identical. Focusing on kernel ridge regression models, they prove that self-distillation can act as a form of increased regularization.

