KNOWLEDGE DISTILLATION AS SEMIPARAMETRIC INFERENCE

Abstract

A popular approach to model compression is to train an inexpensive student model to mimic the class probabilities of a highly accurate but cumbersome teacher model. Surprisingly, this two-step knowledge distillation process often leads to higher accuracy than training the student directly on labeled data. To explain and enhance this phenomenon, we cast knowledge distillation as a semiparametric inference problem with the optimal student model as the target, the unknown Bayes class probabilities as nuisance, and the teacher probabilities as a plug-in nuisance estimate. By adapting modern semiparametric tools, we derive new guarantees for the prediction error of standard distillation and develop two enhancements, cross-fitting and loss correction, to mitigate the impact of teacher overfitting and underfitting on student performance. We validate our findings empirically on both tabular and image data and observe consistent improvements from our knowledge distillation enhancements.

1. INTRODUCTION

Knowledge distillation (KD) (Craven & Shavlik, 1996; Breiman & Shang, 1996; Bucila et al., 2006; Li et al., 2014; Ba & Caruana, 2014; Hinton et al., 2015) is a widely used model compression technique that enables the deployment of highly accurate predictive models on devices such as phones, watches, and virtual assistants (Stock et al., 2020). KD operates by training a compressed student model to mimic the predicted class probabilities of an expensive, high-quality teacher model. Remarkably, and across a wide variety of domains (Hinton et al., 2015; Sanh et al., 2019; Jiao et al., 2019; Liu et al., 2018; Tan et al., 2018; Fakoor et al., 2020), this two-step process often leads to higher accuracy than training the student directly on the raw labeled dataset.

While the practice of KD is now well developed, a general theoretical understanding of its successes and failures is still lacking. As we detail below, a number of authors have argued that the success of KD lies in the more precise "soft labels" provided by the teacher's predicted class probabilities. Recently, Menon et al. (2020) observed that these teacher probabilities can serve as a proxy for the Bayes probabilities (i.e., the true class probabilities) and that the closer the teacher and Bayes probabilities, the better the student's performance should be. Building on this observation, we cast KD as a plug-in approach to semiparametric inference (Kosorok, 2007): that is, we view KD as fitting a student model f in the presence of nuisance (the Bayes probabilities p_0) with the teacher's probabilities p as a plug-in estimate of p_0. This insight allows us to adapt modern tools from semiparametric inference to analyze the error of a distilled student in Sec. 3. Our analysis also reveals two distinct failure modes of KD: one due to teacher overfitting and data reuse, and the other due to teacher underfitting from model misspecification or insufficient training.

In Sec. 4, we introduce and analyze two complementary KD enhancements that correct for these failures: cross-fitting (a popular technique from semiparametric inference; see, e.g., Chernozhukov et al., 2018) mitigates teacher overfitting through data partitioning, while loss correction mitigates teacher underfitting by reducing the bias of the plug-in estimate p. The latter enhancement was inspired by the orthogonal machine learning (Chernozhukov et al., 2018; Foster & Syrgkanis, 2019) approach to semiparametric inference, which suggests a particular adjustment for the teacher's log probabilities. We argue in Sec. 4 that this orthogonal correction minimizes the teacher bias but often at the cost of
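Concretely, the vanilla distillation step fits the student by cross-entropy against the teacher's predicted class probabilities, treating them as a plug-in for the unknown Bayes probabilities p_0. The following is a minimal NumPy sketch of that objective; the function and variable names are ours and not taken from any released implementation:

```python
import numpy as np

def distillation_loss(student_probs, teacher_probs):
    """Plug-in distillation loss: cross-entropy of the student's predicted
    probabilities against the teacher's probabilities, which stand in for
    the unknown Bayes class probabilities p_0.

    Both arrays have shape (n_examples, n_classes) and rows sum to 1.
    """
    eps = 1e-12  # numerical floor to avoid log(0)
    return -np.mean(np.sum(teacher_probs * np.log(student_probs + eps), axis=1))
```

A student that matches the teacher exactly attains the minimum of this loss (the teacher's average entropy), so minimizing it over a student class drives the student toward the teacher's soft labels rather than the hard 0/1 labels.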
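The cross-fitting enhancement can likewise be sketched: partition the data into folds, fit a teacher on each fold's complement, and give every example soft labels from a teacher that never trained on it, which blocks the overfitting-by-data-reuse failure mode. In the sketch below, the helper `train_teacher` and its `predict_proba` interface are illustrative assumptions, not part of the paper:

```python
import numpy as np

def cross_fit_teacher_probs(X, y, train_teacher, n_folds=2, seed=0):
    """Cross-fitting sketch: each example's teacher probabilities come from
    a teacher fit only on the other folds, so no example is soft-labeled by
    a teacher that saw it during training.

    `train_teacher(X, y)` is assumed to return a fitted model exposing
    `predict_proba(X)` (an illustrative interface).
    """
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(len(X)), n_folds)
    n_classes = len(np.unique(y))
    probs = np.zeros((len(X), n_classes))
    for k, fold in enumerate(folds):
        # Fit the teacher on all folds except fold k ...
        train_idx = np.concatenate([f for j, f in enumerate(folds) if j != k])
        teacher = train_teacher(X[train_idx], y[train_idx])
        # ... and soft-label only the held-out fold k.
        probs[fold] = teacher.predict_proba(X[fold])
    return probs
```

The student is then distilled against these held-out probabilities instead of the in-sample teacher outputs; the extra cost is training `n_folds` teachers rather than one.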

