TEACHER GUIDED TRAINING: AN EFFICIENT FRAMEWORK FOR KNOWLEDGE TRANSFER

Abstract

The remarkable performance gains realized by large pretrained models, e.g., GPT-3, hinge on the massive amounts of data they are exposed to during training. Analogously, distilling such large models into compact models for efficient deployment also necessitates a large amount of (labeled or unlabeled) training data. In this paper, we propose the teacher-guided training (TGT) framework for training a high-quality compact model that leverages the knowledge acquired by pretrained generative models, while obviating the need to go through a large volume of data. TGT exploits the fact that the teacher has acquired a good representation of the underlying data domain, which typically corresponds to a much lower dimensional manifold than the input space. Furthermore, we can use the teacher to explore the input space more efficiently through sampling or gradient-based methods, making TGT especially attractive for limited-data or long-tail settings. We formally capture this benefit of the proposed data-domain exploration in our generalization bounds. We find that TGT improves accuracy on several image classification benchmarks as well as a range of text classification and retrieval tasks.

1. INTRODUCTION

Recent general purpose machine learning models (e.g., BERT (Devlin et al., 2019), DALL-E (Ramesh et al., 2021), SimCLR (Chen et al., 2020a), Perceiver (Jaegle et al., 2021), GPT-3 (Brown et al., 2020)), trained on broad data at scale, have demonstrated adaptability to a diverse range of downstream tasks. Despite being trained in an unsupervised (or so-called self-supervised) fashion, these models have been shown to capture highly specialized information in their internal representations, such as relations between entities (Heinzerling & Inui, 2021) or object hierarchies from images (Weng et al., 2021). However, the prohibitively high inference cost of such large models prevents their widespread deployment. A standard approach to reducing the inference cost while preserving performance is to train a compact (student) model via knowledge distillation (Bucilua et al., 2006; Hinton et al., 2015) from a large (teacher) model. However, existing distillation methods require a large amount of training data (labeled or unlabeled) for knowledge transfer. For each data point, the teacher must be evaluated, making the process computationally expensive (Xie et al., 2020d; He et al., 2021; Sanh et al., 2019a). This is compounded by the need to repeat the distillation process separately for every downstream task, each with its own training set. Enabling efficient distillation is thus an important challenge. Additionally, minimizing the number of distillation samples would especially benefit low-data downstream tasks, e.g., those with long-tailed distributions. Another inefficiency of standard distillation approaches is that within each evaluation of the teacher, only the final layer output (aka logits) is utilized. This ignores potentially useful internal representations, which can also be leveraged for knowledge transfer.
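As background, the standard distillation objective referenced here (Hinton et al., 2015) trains the student to match the teacher's temperature-softened logits. Below is a minimal NumPy sketch of that baseline, not the paper's own implementation; the function names and the conventional T² scaling are our choices:

```python
import numpy as np

def softmax(z, T=1.0):
    """Temperature-softened softmax, computed stably."""
    z = z / T
    z = z - z.max(axis=-1, keepdims=True)  # guard against overflow
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def distillation_loss(student_logits, teacher_logits, T=2.0):
    """KL divergence between softened teacher and student distributions.
    The T**2 factor keeps gradient magnitudes comparable across
    temperatures, following the usual convention."""
    p = softmax(teacher_logits, T)  # soft targets from the teacher
    q = softmax(student_logits, T)  # student predictions
    return T**2 * np.sum(p * (np.log(p) - np.log(q)), axis=-1).mean()
```

Note that only the teacher's final-layer logits enter this loss, which is precisely the limitation discussed above.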
Various extensions have been proposed in the literature along these lines (see, e.g., (Sun et al., 2020; Aguilar et al., 2020; Li et al., 2019; Sun et al., 2019) and references therein). However, despite their success, such extensions mostly use the teacher model in a black-box manner and do not fully utilize the domain understanding it contains (Cho & Hariharan, 2019; Stanton et al., 2021). In these approaches, the teacher is used passively: the input sample distribution remains fixed and does not adapt to the student model's performance. Consequently, these forms of distillation do not lead to faster training of a high-performance student model. In this work, we go beyond the passive application of large teacher models for training compact student models, and leverage the domain understanding captured by the teacher to generate new informative training instances that help the compact model achieve higher accuracy with fewer samples, thus reducing training time. In particular, we propose the teacher guided training (TGT) framework for a more efficient transfer of knowledge from large models to a compact model. TGT relies on the fact that the teacher's internal representation of the data often lies on a manifold of much lower dimension than the input space. Furthermore, we can use the teacher to guide training by identifying the directions in which the student's current decision boundary starts to diverge from that of the teacher, e.g., via backpropagating through the teacher to identify regions of disagreement. To realize TGT, we take advantage of the fact that most unsupervised pretrained models, such as Transformers, VAEs, and GANs, have two components: (1) an encoder that maps data to a latent representation, and (2) a decoder that transforms the latent representation back to the original data space.
We utilize the latent space of data representations learned by the teacher model to efficiently search for regions of mismatch between the teacher's and student's decision boundaries. This search can take the form of either (i) a zero-order approach involving random perturbation, or (ii) a first-order method exploring along the direction of the gradient of a suitably defined distance measure between the teacher and student models. Many pretrained models, particularly in NLP such as T5 (Raffel et al., 2020), can also provide labels for a downstream task and act as the sole teacher. However, our approach is sufficiently general to utilize separate pretrained models for the generative and discriminative (labeler) functions (cf. Fig. 1), e.g., we employ a BiGAN as generator and an EfficientNet as labeler for an image classification task. Our main contributions are summarized as follows:

1. We introduce the TGT framework, a conceptually simple and scalable approach to distilling knowledge from a large teacher into a smaller student. TGT adaptively changes the distribution of distillation examples, yielding higher performing student models with fewer training examples.

2. We provide theoretical justification for utilizing the latent space of the teacher generator in the TGT framework, which yields tighter generalization bounds.

3. We empirically show the superiority of TGT over existing state-of-the-art distillation methods on both vision and NLP tasks, unlike most prior work, which is specialized to one domain.
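The two search strategies can be sketched as follows. This is an illustrative simplification, not the paper's implementation: `decoder`, `teacher`, and `student` are placeholder callables, the disagreement measure is a plain squared distance, and the first-order variant uses finite differences where a real implementation would backpropagate through the teacher's decoder:

```python
import numpy as np

def disagreement(z, decoder, teacher, student):
    """Distance between teacher and student predictions at the decoded
    point x = decoder(z); a simple squared distance for illustration."""
    x = decoder(z)
    return np.sum((teacher(x) - student(x)) ** 2)

def zero_order_search(z0, decoder, teacher, student, sigma=0.1, k=8, seed=None):
    """(i) Zero-order: randomly perturb the latent code of a real example
    and keep the candidate where teacher and student disagree most."""
    rng = np.random.default_rng(seed)
    candidates = z0 + sigma * rng.standard_normal((k,) + z0.shape)
    scores = [disagreement(z, decoder, teacher, student) for z in candidates]
    return candidates[int(np.argmax(scores))]

def first_order_step(z0, decoder, teacher, student, lr=0.1, eps=1e-4):
    """(ii) First-order: ascend the gradient of the disagreement w.r.t.
    the latent code (approximated here by central finite differences)."""
    g = np.zeros_like(z0)
    for i in range(z0.size):
        e = np.zeros_like(z0)
        e.flat[i] = eps
        g.flat[i] = (disagreement(z0 + e, decoder, teacher, student)
                     - disagreement(z0 - e, decoder, teacher, student)) / (2 * eps)
    return z0 + lr * g
```

Both variants move in the low-dimensional latent space rather than the input space, which is the source of the efficiency gain argued for above.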



Figure 1: An overview of the proposed teacher guided training (TGT) framework. Given a learning task, the framework leverages a large teacher, comprising a pretrained generator and labeler, that exhibits high performance on the task. In particular, we assume that the generator consists of an encoder and a decoder. TGT performs three key operations during student model training: (1) Given an original training instance, identify a novel task-relevant instance by using the teacher generator. We search for informative instances in the lower dimensional latent space, into which gradients can be propagated. (2) Obtain (soft) labels for the original and newly generated training instances from the teacher labeler. (3) Minimize the student training objective, which depends on the original and newly generated instances along with their labels produced by the teacher labeler.
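The three operations in the caption can be condensed into a single training-step skeleton. All function arguments here are hypothetical placeholders standing in for the teacher's components and the student's optimizer, not the paper's actual interfaces:

```python
def tgt_step(x, encoder, decoder, labeler, student, explore, update):
    """One TGT training step, following the three operations of Fig. 1:
    1. map x into the latent space and search (via `explore`) for an
       informative neighbour, decoded back to the input space;
    2. query the teacher labeler for soft labels on both instances;
    3. update the student on both (instance, soft label) pairs."""
    z = encoder(x)
    x_new = decoder(explore(z, student))       # (1) teacher-guided instance
    y, y_new = labeler(x), labeler(x_new)      # (2) soft labels from teacher
    update(student, [(x, y), (x_new, y_new)])  # (3) student training objective
    return x_new
```

Because `explore` sees the current student, the distribution of generated instances adapts as the student improves, which is the active use of the teacher that distinguishes TGT from passive distillation.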

