PROGRESSIVE KNOWLEDGE DISTILLATION: BUILDING ENSEMBLES FOR EFFICIENT INFERENCE

Anonymous

Abstract

We study the problem of progressive distillation: given a large, pre-trained teacher model g, we seek to decompose the model into an ensemble of smaller, low-inference-cost student models f_i. The resulting ensemble allows for flexibly tuning accuracy vs. inference cost, which is useful for a multitude of applications in efficient inference. Our method, B-DISTIL, is an algorithmic procedure that uses function composition over intermediate activations to construct expressive ensembles with performance similar to that of g, but with much smaller student models. We demonstrate the effectiveness of B-DISTIL by decomposing pretrained models across a variety of image, speech, and sensor datasets. Our method comes with strong theoretical guarantees in terms of both convergence and generalization.

1. INTRODUCTION

Knowledge distillation aims to transfer the knowledge of a large model into a smaller one (Buciluǎ et al., 2006; Hinton et al., 2015). While this technique is commonly used for model compression, one downside is that the procedure is fairly rigid, resulting in a single compressed model of a fixed size. In this work, we instead consider the problem of progressive distillation: approximating a large model via an ensemble of smaller, low-latency models. The resulting decomposition is useful for a number of applications in efficient inference. For example, components of the ensemble can be selectively combined to flexibly meet accuracy/latency constraints (Li et al., 2019; Yang & Fan, 2021), can enable efficient parallel inference execution schemes, and can facilitate early-exit (Bolukbasi et al., 2017; Dennis et al., 2018) or anytime inference (Ruiz & Verbeek, 2021; Huang et al., 2017a) applications, where inference may be interrupted due to variable resource availability. More specifically, we seek to distill a large pre-trained model, g, onto an ensemble of low-parameter-count, low-latency models, f_i. We additionally aim for the resulting ensemble to form a decomposition, such that evaluating the first model produces a coarse estimate of the prediction (e.g., covering common cases), and evaluating additional models improves on this estimate (see Figure 1). Such an ensemble has major advantages for on-device inference: 1) inference cost vs. accuracy trade-offs can be controlled on demand at execution time, 2) the ensemble can be executed in parallel, in sequence, or a mix of both, and 3) coarse initial predictions can be improved without re-evaluation in response to resource availability.
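The progressive refinement behavior described above can be pictured with a minimal NumPy sketch. This is an illustration of the anytime-inference interface only, not the B-DISTIL training procedure; the toy `students` and the `budget` stopping rule are hypothetical.

```python
import numpy as np

def progressive_predict(models, x, budget=None):
    """Evaluate ensemble members in sequence, summing their logits.

    Each additional model refines the running prediction; stopping
    early (e.g., when a compute budget is exhausted) still yields a
    coarse but usable estimate.
    """
    logits = None
    for i, f in enumerate(models):
        out = f(x)
        logits = out if logits is None else logits + out
        if budget is not None and i + 1 >= budget:
            break  # interrupted: return the current coarse estimate
    return logits

# Toy students (inputs ignored): each contributes a partial vote.
students = [lambda x: np.array([0.0, 0.4]),
            lambda x: np.array([0.1, 0.3]),
            lambda x: np.array([0.3, 0.0])]

coarse = progressive_predict(students, x=None, budget=1)  # one model only
refined = progressive_predict(students, x=None)           # full ensemble
```

The key property is that `coarse` is produced at a fraction of the cost of `refined`, while `refined` reuses rather than recomputes the work already done.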
While traditional distillation methods are effective when transferring information to a single model of similar capacity, performance can degrade significantly as the capacity of the student model is reduced (Mirzadeh et al., 2020; Gao et al., 2021). Moreover, distilling a deep network onto a weighted sum of shallow networks rarely performs better than distilling onto a single model (Cho & Hariharan, 2019; Allen-Zhu & Li, 2020). Our method exploits connections to minimax optimization and online learning (Schapire & Freund, 2013) to allow models in our ensemble to compose and reuse the intermediate activations of other models during inference. As long as these composition functions are resource efficient, we can increase the capacity of the base model class at roughly the inference cost of a single model. Moreover, we show that our procedure retains the theoretical guarantees of these methods (Schapire & Freund, 2013). Concretely:
• We formulate progressive distillation as a two-player zero-sum game, derive a weak learning condition for distillation, and present our algorithm, B-DISTIL, to approximately solve this game. To make the search for weak learners among low-parameter-count models feasible, we explicitly solve a log-barrier based relaxation of our weak learning condition. Moreover, by allowing models to reuse computation from select intermediate layers of previously evaluated models of the ensemble, we are able to increase model capacity without a significant increase in inference cost.
• We empirically evaluate our algorithm on synthetic as well as real-world classification tasks from vision, speech, and sensor data, with models suitable for the respective domains.
We show that our ensemble behaves like a decomposition, allowing a run-time trade-off between accuracy and computation while remaining competitive with the teacher model.
• We provide theoretical guarantees for our algorithm in terms of in-sample convergence and generalization performance. Our framework is not architecture- or task-specific and can recover existing ensemble models used in efficient inference; we believe it puts forth a general lens through which to view previous work and to develop new, principled approaches for efficient inference.
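To give intuition for the log-barrier relaxation mentioned above: where a hard weak-learning condition would require every per-example margin (the candidate student's agreement with the target) to be strictly positive, a log barrier turns that constraint into a smooth penalty that can be minimized. The sketch below is illustrative only; the exact condition and objective used by B-DISTIL differ, and `log_barrier_penalty` is a hypothetical name.

```python
import numpy as np

def log_barrier_penalty(margins, t=1.0):
    """Soft relaxation of the hard constraints margin_i > 0.

    The barrier term -log(margin_i)/t grows without bound as a margin
    approaches zero and is undefined for violated constraints, giving a
    differentiable surrogate that steers the search toward candidates
    satisfying the weak learning condition.
    """
    margins = np.asarray(margins, dtype=float)
    if np.any(margins <= 0):
        return np.inf  # outside the barrier's domain: condition violated
    return float(np.sum(-np.log(margins)) / t)
```

Minimizing such a surrogate over candidate students makes the search feasible with gradient-based training, since small-but-positive margins are penalized smoothly rather than rejected outright.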

2. BACKGROUND AND RELATED WORK

Knowledge distillation. Machine learning inference is often resource-constrained in practice by requirements on memory, energy, cost, or latency. This has spurred the development of numerous techniques for model compression. A particularly popular approach is knowledge distillation, which transfers the knowledge of a larger model (or model ensemble) to a smaller one (Buciluǎ et al., 2006; Hinton et al., 2015). Despite its popularity, compression via distillation has several known pitfalls. Most notably, many have documented that distillation may not perform well when there is a capacity gap between the teacher and student, i.e., when the teacher is significantly larger than the student (Mirzadeh et al., 2020; Gao et al., 2021; Cho & Hariharan, 2019; Allen-Zhu & Li, 2020). When distilling onto a weighted combination of models, it has been observed that adding additional models to the ensemble does not dramatically improve performance over that of a single distilled model (Allen-Zhu & Li, 2020). There is also a lack of theoretical work characterizing when and why distillation is effective for compression (Gou et al., 2021). Our work aims to address many of these pitfalls by developing a principled approach for progressively distilling a large model onto an ensemble of smaller, low-capacity ones.

Early exits and anytime inference. Numerous applications stand to benefit from the output of progressive distillation, which allows for flexibly tuning accuracy vs. inference cost and for executing inference in parallel. Trading off accuracy against inference cost is particularly useful for applications that use early-exit or anytime inference schemes. In on-device continuous (online) inference settings, early-exit models aim to evaluate common cases quickly in order to improve energy efficiency and prolong battery life (Dennis et al., 2018; Bolukbasi et al., 2017).
For instance, a battery-powered device continuously listening for voice commands can use early-exit methods to improve battery efficiency by quickly classifying non-command speech. Many methods that produce early-exit models are also applicable to anytime inference (Huang et al., 2017a; Ruiz & Verbeek, 2021). In anytime inference, the aim is to produce a prediction even when inference is interrupted, say due to resource contention or a scheduler decision. Unlike early-exit methods, where the classifier chooses when to exit, anytime inference methods have no control over when they are interrupted. We explore the effectiveness of our method, B-DISTIL, for such applications in Section 5.
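The contrast between the two settings can be sketched as follows. A confidence-threshold exit rule is one common choice for early exit; it is illustrative here, not the specific rule studied in this paper, and the toy `models` and `threshold` are assumptions.

```python
import numpy as np

def early_exit_predict(models, x, threshold=0.9):
    """Sequential early-exit inference: the classifier itself decides
    to stop as soon as the running ensemble is confident enough, so
    common (easy) cases exit after evaluating only a few models."""
    logits = np.zeros_like(models[0](x))
    for n_used, f in enumerate(models, start=1):
        logits = logits + f(x)
        probs = np.exp(logits - logits.max())  # stable softmax
        probs = probs / probs.sum()
        if probs.max() >= threshold:
            break  # confident enough: exit early
    return int(np.argmax(probs)), n_used
```

Anytime inference is the same loop with the exit decision taken away from the classifier: an external interrupt determines how many iterations run, and the model must return its best current prediction regardless.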



Figure 1: In progressive distillation, a large teacher model is distilled onto low-inference-cost models. The more student models we evaluate, the closer the ensemble's decision boundary is to that of the teacher model. Models in the ensemble are allowed to depend on previously computed features.

