PROGRESSIVE KNOWLEDGE DISTILLATION: BUILDING ENSEMBLES FOR EFFICIENT INFERENCE

Anonymous

Abstract

We study the problem of progressive distillation: given a large, pre-trained teacher model g, we seek to decompose the model into an ensemble of smaller, low-inference-cost student models f_i. The resulting ensemble allows for flexibly tuning accuracy versus inference cost, which is useful for a multitude of applications in efficient inference. Our method, B-DISTIL, relies on an algorithmic procedure that uses function composition over intermediate activations to construct expressive ensembles with performance comparable to g, but with much smaller student models. We demonstrate the effectiveness of B-DISTIL by decomposing pre-trained models across a variety of image, speech, and sensor datasets. Our method comes with strong theoretical guarantees in terms of both convergence and generalization.

1. INTRODUCTION

Knowledge distillation aims to transfer the knowledge of a large model into a smaller one (Buciluǎ et al., 2006; Hinton et al., 2015). While this technique is commonly used for model compression, one downside is that the procedure is fairly rigid, resulting in a single compressed model of a fixed size. In this work, we instead consider the problem of progressive distillation: approximating a large model via an ensemble of smaller, low-latency models. The resulting decomposition is useful for a number of applications in efficient inference. For example, components of the ensemble can be selectively combined to flexibly meet accuracy/latency constraints (Li et al., 2019; Yang & Fan, 2021), can enable efficient parallel inference execution schemes, and can facilitate early-exit (Bolukbasi et al., 2017; Dennis et al., 2018) or anytime inference (Ruiz & Verbeek, 2021; Huang et al., 2017a) applications, scenarios where inference may be interrupted due to variable resource availability.

More specifically, we seek to distill a large pre-trained model, g, onto an ensemble of low-parameter-count, low-latency models, f_i. We additionally aim for the resulting ensemble to form a decomposition, such that evaluating the first model produces a coarse estimate of the prediction (e.g., covering common cases), and evaluating additional models improves on this estimate (see Figure 1). Such an ensemble has major advantages for on-device inference: 1) inference cost vs. accuracy trade-offs can be controlled on demand at execution time; 2) the ensemble can be executed in parallel, in sequence, or a mix of both; and 3) coarse initial predictions can be improved without re-evaluation in response to resource availability (see the sketch at the end of this section).

While traditional distillation methods are effective when transferring information to a single model of similar capacity, it has been shown that performance can degrade significantly when the capacity of the student model is reduced (Mirzadeh et al., 2020; Gao et al., 2021). Moreover, distillation of a deep network onto a weighted sum of shallow networks rarely performs better than distillation onto a single model (Cho & Hariharan, 2019; Allen-Zhu & Li, 2020). Our method exploits connections to minimax optimization and online learning (Schapire & Freund, 2013) to allow models in our ensemble to compose and reuse intermediate activation outputs of other models during inference. As long as these composition functions are resource efficient, we can increase the capacity of the base class at roughly the same inference cost as a single model. Moreover, we show that our procedure retains the theoretical guarantees of these methods (Schapire & Freund, 2013). Concretely:

• We formulate progressive distillation as a two-player zero-sum game, derive a weak learning condition for distillation, and present our algorithm, B-DISTIL, to approximately solve this game. To make the search for weak learners among low-parameter-count models feasible, we explicitly solve
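To make the anytime/early-exit use case referenced above concrete, the following is a minimal PyTorch sketch of evaluating an additive ensemble of small students, where each student may reuse the previous student's intermediate activations. The names (StudentBlock, anytime_predict) and the specific composition function (a linear map added to the current hidden state) are illustrative assumptions, not B-DISTIL's actual implementation.

```python
# Minimal sketch (PyTorch) of anytime inference with an additive ensemble of
# small student models. Names and the composition scheme are illustrative
# assumptions for exposition, not the paper's actual implementation.
import torch
import torch.nn as nn


class StudentBlock(nn.Module):
    """A small student that can reuse the previous student's hidden features."""

    def __init__(self, in_dim, hidden_dim, num_classes):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.ReLU())
        # Cheap composition function over the previous student's activations.
        self.compose = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.head = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, prev_hidden=None):
        h = self.encoder(x)
        if prev_hidden is not None:
            # Reuse intermediate activations computed by the previous student.
            h = h + self.compose(prev_hidden)
        return self.head(h), h


def anytime_predict(students, x, budget):
    """Evaluate students in sequence, refining a running logit estimate.

    `budget` is the number of students we can afford to run; the running sum
    of logits is a valid (coarse) prediction after every step, so inference
    can be interrupted at any point.
    """
    logits, hidden = 0.0, None
    for f in students[:budget]:
        out, hidden = f(x, hidden)
        logits = logits + out  # each additional student refines the estimate
    return logits


# Example: four students for a 10-class task; a larger budget buys accuracy.
students = nn.ModuleList([StudentBlock(784, 64, 10) for _ in range(4)])
coarse = anytime_predict(students, torch.randn(8, 784), budget=1)
refined = anytime_predict(students, torch.randn(8, 784), budget=4)
```

In this sketch, running more students only adds their (small) forward passes, and the composition map lets a later student build on features already computed, mirroring how the paper's composition over intermediate activations is intended to raise the effective capacity of the base class at roughly the inference cost of a single model.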

