TEACHER GUIDED TRAINING: AN EFFICIENT FRAMEWORK FOR KNOWLEDGE TRANSFER

Abstract

The remarkable performance gains realized by large pretrained models, e.g., GPT-3, hinge on the massive amounts of data they are exposed to during training. Analogously, distilling such large models to compact models for efficient deployment also necessitates a large amount of (labeled or unlabeled) training data. In this paper, we propose the teacher-guided training (TGT) framework for training a high-quality compact model that leverages the knowledge acquired by pretrained generative models, while obviating the need to go through a large volume of data. TGT exploits the fact that the teacher has acquired a good representation of the underlying data domain, which typically corresponds to a much lower dimensional manifold than the input space. Furthermore, we can use the teacher to explore the input space more efficiently through sampling or gradient-based methods, which makes TGT especially attractive for limited-data or long-tail settings. We formally capture this benefit of the proposed data-domain exploration in our generalization bounds. We find that TGT can improve accuracy on several image classification benchmarks as well as a range of text classification and retrieval tasks.

1. INTRODUCTION

Recent general-purpose machine learning models (e.g., BERT (Devlin et al., 2019), DALL-E (Ramesh et al., 2021), SimCLR (Chen et al., 2020a), Perceiver (Jaegle et al., 2021), GPT-3 (Brown et al., 2020)), trained on broad data at scale, have demonstrated adaptability to a diverse range of downstream tasks. Although they are trained in an unsupervised (or so-called self-supervised) fashion, these models have been shown to capture highly specialized information in their internal representations, such as relations between entities (Heinzerling & Inui, 2021) or object hierarchies from images (Weng et al., 2021). Despite their impressive performance, the prohibitively high inference cost of such large models prevents their widespread deployment. A standard approach to reducing the inference cost while preserving performance is to train a compact (student) model via knowledge distillation (Bucilua et al., 2006; Hinton et al., 2015) from a large (teacher) model. However, existing distillation methods require a large amount of training data (labeled or unlabeled) for knowledge transfer. The teacher must be evaluated on every data point, which makes the process computationally expensive (Xie et al., 2020d; He et al., 2021; Sanh et al., 2019a). This cost is compounded by the need to repeat the distillation process separately for every downstream task, each with its own training set. Enabling efficient distillation is thus an important challenge. Additionally, minimizing the number of distillation samples would especially benefit low-data downstream tasks, e.g., those with long tails. Another inefficiency of standard distillation approaches is that each evaluation of the teacher utilizes only the final-layer output (aka logits). This ignores potentially useful internal representations that could also be leveraged for knowledge transfer.
Various extensions have been proposed in the literature along these lines (see, e.g., (Sun et al., 2020; Aguilar et al., 2020; Li et al., 2019; Sun et al., 2019) and references therein). However, despite their success, such extensions mostly use the teacher model in a black-box manner and do not fully utilize the domain understanding it contains (Cho & Hariharan, 2019; Stanton et al., 2021). In these approaches, the teacher is used passively: the input sample distribution remains fixed and does not adapt to the student model's performance. Consequently, these forms of distillation do not lead to faster training of a high-performance student model. In this work, we go beyond the passive application of large teacher models for training compact student models, and leverage the domain understanding captured by the teacher to generate new informative training instances that can help the compact model achieve higher accuracy with fewer samples, thus reducing training time. In particular, we propose the teacher guided training (TGT) framework for a more efficient transfer of knowledge from large models to a compact model. TGT relies on the fact that the teacher's internal representation of data often lies on a manifold of much smaller dimension than the input space. Furthermore, we can use the teacher to help guide training by identifying the directions where the student's current decision boundary starts to diverge from that of the teacher, e.g., by backpropagating through the teacher to identify regions of disagreement. Given a learning task, the framework leverages a large teacher with a pretrained generator and labeler that exhibits high performance on the task. In particular, we assume that the generator consists of an encoder and a decoder. TGT performs three key operations during student model training: (1) Given an original training instance, identify a novel task-relevant instance by using the teacher generator.
We search for informative instances in the lower dimensional latent space, through which the gradient can be propagated. (2) Obtain (soft) labels for the original and newly generated training instances from the teacher labeler; and (3) Minimize the student training objective, which depends on the original and newly generated instances along with their labels produced by the teacher labeler. We provide a theoretical justification for the TGT algorithm, showing that leveraging the data representation of large models ensures better generalization for the student. Given $n$ instances in a $D$-dimensional space, the generalization gap for learning a Lipschitz decision boundary of a classification task decays only as $O(n^{-1/D})$ (Györfi et al., 2002). In contrast, provided that the large model learns a good data representation in a $d$-dimensional latent space, the TGT framework realizes a generalization gap of $O(n^{-1/d} + \mathcal{W}(\mathcal{D}, \mathcal{D}_t))$, where $\mathcal{W}(\mathcal{D}, \mathcal{D}_t)$ denotes the Wasserstein distance between the data distribution $\mathcal{D}$ and the distribution $\mathcal{D}_t$ learned by the generative teacher model. Typically $d \ll D$, so TGT ensures much faster convergence whenever we use a high-quality generative teacher, making TGT especially attractive for low-data or long-tail regimes. To realize TGT, we take advantage of the fact that most unsupervised pretrained models, such as Transformers, VAEs, and GANs, have two components: (1) an encoder that maps data to a latent representation, and (2) a decoder that transforms the latent representation back to the original data space. We utilize the latent space of the data representations learned by the teacher model to efficiently search for regions of mismatch between the teacher's and student's decision boundaries. This search can take the form of either (i) a zero-order approach involving random perturbation or (ii) a first-order method exploring along the direction of the gradient of a suitably defined distance measure between the teacher and student models.
Many pretrained models, particularly in NLP such as T5 (Raffel et al., 2020), can also provide labels for a downstream task and act as the sole teacher. However, our approach is sufficiently general to utilize separate pretrained models for the generative and discriminative (labeler) functions (cf. Fig. 1); e.g., we employ a BiGAN as the generator and an EfficientNet as the labeler for an image classification task. Our main contributions are summarized as follows: 1. We introduce the TGT framework, a conceptually simple and scalable approach to distilling knowledge from a large teacher into a smaller student. TGT adaptively changes the distribution of distillation examples, yielding higher-performing student models with fewer training examples. 2. We provide theoretical justifications for utilizing the latent space of the teacher generator in the TGT framework, which yields tighter generalization bounds.

2. RELATED WORK

Our proposed TGT framework can be considered a form of data augmentation where data is dynamically added at points of current discrepancy between the teacher and student. Next, we provide a brief overview of how data augmentation has been used in the context of distillation and distinguish our work from these existing efforts. Using pseudo labels. The earliest line of work uses consistency regularization (Sajjadi et al., 2016; Tarvainen & Valpola, 2017) to obtain pseudo labels for unlabeled data, where a model is expected to make consistent predictions on an unlabeled instance and its augmented versions, cf. (Miyato et al., 2019; Xie et al., 2020a; Sohn et al., 2020; Zhu et al., 2021, inter alia). Another approach is self-training (Xie et al., 2020d; Du et al., 2021), where one learns a smaller teacher model on the labeled data, which then generates pseudo labels for a large relevant unlabeled set. A large student model is then trained on both the labeled and pseudo-labeled sets. Label propagation (Iscen et al., 2019) is another direction, where unlabeled instances receive pseudo labels based on neighboring labeled instances in a suitably constructed similarity graph. Furthermore, prior work on learning to teach (Fan et al., 2018; Raghu et al., 2021; Pham et al., 2021) dynamically updates the teacher so as to provide more valuable pseudo labels based on the student loss. Such an interactive approach presents a challenging optimization problem and potentially opens the door to borrowing techniques from reinforcement learning. In contrast, our work focuses on the setting where a high-quality pretrained teacher model is fixed throughout training, i.e., where updating the large teacher model is prohibitively costly or undesirable because such a model would potentially be used to distill many student models. Moreover, many large models like GPT-3 may only be available through API access, making it infeasible to update the teacher.
Using pretrained models. One can use large pretrained class-conditional generative models like BigGAN (Brock et al., 2019) or VQ-VAE2 (Razavi et al., 2019) to generate more data for augmentation. Despite evidence (Webster et al., 2019) that GANs do not memorize training data, using them to simply augment the training dataset has limited utility when training ResNets (Ravuri & Vinyals, 2019b; a). A potential reason is the lack of diversity (Arora et al., 2017) in GAN-generated data, especially among high-density regions (Arora et al., 2018). In contrast, we use generative models to adaptively explore the local region of disagreement between teacher and student, as opposed to blindly sampling from the generative model. This way, we avoid excessive reliance on samples from high-density regions, which often have low diversity. Another line of work, by Chen et al. (2020b), combines unsupervised/self-supervised pretraining on unlabeled data using a SimCLR-based approach (Chen et al., 2020a), task-specific finetuning on labeled data, and distillation (natural loss on labeled data and distillation loss on unlabeled data). Our work is closely related to this line of work, with two key differences: (1) we assume access to a very high-quality teacher, potentially trained on a much larger labeled set, to provide pseudo labels; (2) we go beyond utilizing a given relevant unlabeled dataset and explore the dynamic generation of domain-specific unlabeled data by leveraging the representations learned by pretrained models. Additionally, we develop a theoretical framework to establish the utility of unlabeled data instances for student training, specifically instances generated based on teacher-learned representations. Using both pseudo labels and pretrained models. The GAL framework (He et al., 2021) previously considered generating training instances by using pretrained generator models along with pseudo-labelers.
However, the GAL framework generates these new instances offline, at the beginning of student training. In contrast, our approach (cf. Fig. 1) generates new informative instances online to attain high performance and reduce training time for the student. Recently, MATE-KD (Rashid et al., 2021) also used a generator model to obtain new training instances based on the student's current performance (by looking at the divergence between the student and teacher predictions). However, there are two key differences between our TGT approach and the MATE-KD framework. First, their method updates the generator so as to find adversarial examples for the student, which can cause the generator to drift away from the true data distribution. Second, they introduce perturbations in the input space itself and do not leverage the latent space of the teacher, which is the crux of our method. See Appendix A for further details. Notably KDGAN (Wang et al., 2018) … Finally, data-free KD approaches (Nayak et al., 2019; Yoo et al., 2019; Chen et al., 2019) use only synthetically generated data for knowledge distillation. Unlike in TGT, the synthetic data distribution in such approaches is updated at each epoch, which causes the student model to lose information learned in earlier epochs and to suffer accuracy degradation (Binici et al., 2022). Within this framework, Micaelli & Storkey (2019) targeted generating samples that would yield maximum information gain for the student; however, their approach also suffers from drawbacks similar to those of MATE-KD noted above.

3. TEACHER GUIDED TRAINING

We begin by formally introducing our setup in Section 3.1. We then describe our proposed TGT framework in Section 3.2 and present its theoretical analysis in Section 3.3.

3.1. PROBLEM SETUP

In this paper, we focus on a multiclass classification task where, given an instance $x \in \mathcal{X}$, the objective is to predict its true label $y \in \mathcal{Y} := [K]$ out of $K$ potential classes. Let $\mathcal{D} := \mathcal{D}_{X,Y}$ denote the underlying (joint) data distribution over the instance and label spaces for the task. Moreover, we use $\mathcal{D}_X$ and $\mathcal{D}_{Y|X=x}$ to denote the marginal distribution over the instance space $\mathcal{X}$ and the conditional label distribution for a given instance $x$, respectively. A classification model $f : \mathcal{X} \to \mathbb{R}^K$, with $f(x) = (f(x)_1, \ldots, f(x)_K)$, takes in an input instance $x$ and yields scores for each of the $K$ classes. Finally, we are given a (tractable) loss function $\ell : \mathbb{R}^K \times [K] \to \mathbb{R}$ that closely approximates the model's misclassification error on an example $(x, y)$, e.g., the softmax-based cross-entropy loss. Given $n$ i.i.d. labeled samples $S^{\text{labeled}}_n := \{(x_i, y_i)\}_{i \in [n]}$ generated from $\mathcal{D}$ and a collection of allowable models $\mathcal{F}$, one typically learns a model via empirical risk minimization (ERM):

$$f_n = \arg\min_{f \in \mathcal{F}} \frac{1}{n} \sum_{i \in [n]} \ell(f(x_i), y_i). \tag{1}$$

In our TGT setup, we further assume access to a high-quality teacher model, which has:
• Teacher generator. A generative component that captures $\mathcal{D}_X$ well, e.g., a transformer, VAE, or ALI-GAN. This usually consists of an encoder $\mathrm{Enc} : \mathcal{X} \to \mathbb{R}^d$ and a decoder $\mathrm{Dec} : \mathbb{R}^d \to \mathcal{X}$.
• Teacher labeler. A classification network, denoted by $h : \mathcal{X} \to \mathbb{R}^K$, with good performance on the underlying classification task. In general, our framework allows $h$ to be either a head on top of the teacher generator or an independent large teacher classification model.

Given $S^{\text{labeled}}_n$ and such a teacher model, our objective is to learn a high-quality compact student (classification) model in $\mathcal{F}$, as assessed by its misclassification error on $\mathcal{D}$.
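As a minimal illustration of Eq. (1), the sketch below evaluates the empirical risk with the softmax cross-entropy loss mentioned above and returns the minimizer over a small, finite candidate set standing in for F. This is a simplification under stated assumptions: real ERM optimizes over a parameterized class (typically by gradient descent), labels are 0-indexed here rather than drawn from [K] = {1, ..., K}, and `f_good` and `f_bad` are hypothetical toy scorers, not models from the paper.

```python
import math

def softmax_cross_entropy(scores, label):
    """Loss l(f(x), y): softmax cross-entropy for a score vector and a 0-indexed label."""
    # log-sum-exp with the max subtracted for numerical stability
    m = max(scores)
    log_norm = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_norm - scores[label]

def empirical_risk(model, samples):
    """(1/n) * sum_i l(f(x_i), y_i) over labeled pairs (x_i, y_i)."""
    return sum(softmax_cross_entropy(model(x), y) for x, y in samples) / len(samples)

def erm(candidates, samples):
    """Eq. (1) restricted to a finite candidate set: return the empirical-risk minimizer."""
    return min(candidates, key=lambda f: empirical_risk(f, samples))

# Usage: 1-D inputs, K = 2 classes, two hypothetical linear scorers.
samples = [(-2.0, 0), (-1.0, 0), (1.0, 1), (2.0, 1)]

def f_good(x):
    return [-x, x]  # favours class 1 for positive x

def f_bad(x):
    return [x, -x]  # the opposite assignment

best = erm([f_good, f_bad], samples)
print(best.__name__)  # → f_good
```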

3.2. PROPOSED APPROACH

To train a student model $f \in \mathcal{F}$, we propose to minimize:

$$R^{\mathrm{TGT}}_f(S^{\text{labeled}}_n) := \frac{1}{n} \sum_{i \in [n]} \Big[\ell(f(x_i), y_i) + \ell_d(f(x_i), h(x_i))\Big] + \frac{1}{m} \sum_{j \in [m]} \ell_d(f(\tilde{x}_j), h(\tilde{x}_j)), \tag{2}$$

where $\ell_d : \mathbb{R}^K \times \mathbb{R}^K \to \mathbb{R}$ is a loss function that captures the mismatch between two models $f$ and $h$, and $\tilde{S}_m = \{\tilde{x}_j\}_{j \in [m]}$ is introduced below. The first term, $\ell(f(x_i), y_i)$, corresponds to the standard ERM problem (cf. Eq. (1)). The subsequent terms, $\ell_d(f(x_i), h(x_i))$ and $\ell_d(f(\tilde{x}_j), h(\tilde{x}_j))$, do not make use of labels. In particular, the second term, $\ell_d(f(x_i), h(x_i))$, corresponds to standard knowledge distillation, where the teacher $h$ provides supervision for the student $f$. We introduce a novel third term, $\ell_d(f(\tilde{x}_j), h(\tilde{x}_j))$, where $\tilde{S}_m = \{\tilde{x}_j\}$ is generated based on $S_n = \{x_i\}$. Here, we want to generate informative instances $\tilde{S}_m$ that will help the student learn faster, e.g., points on the data manifold where the student disagrees with the teacher. In other words, we want to find $\tilde{x}$ as follows:

$$\tilde{x} \in \arg\max_{x \in \mathcal{X}} \ell_d(f(x), h(x)) \quad \text{such that} \quad p_{\mathcal{D}_X}(x) \geq \lambda. \tag{3}$$

Note that the objective and the constraint in Eq. (3) ensure, respectively, that we select an instance where the student and teacher disagree, and that the instance belongs to a region where the true data distribution assigns non-trivial mass. Based on this, we propose two specific approaches to generate the novel samples $\tilde{S}_m$:
1. Isotropically perturb in latent space:
$$\tilde{x} = \mathrm{Dec}(\mathrm{Enc}(x) + \nu), \quad \text{where } \nu \sim \mathcal{N}(0, \sigma^2 I_d). \tag{4}$$
This can be regarded as a zero-order search in the latent space, which satisfies the constraint of remaining within the data manifold.
2. Gradient-based exploration: Run a few iterations of gradient ascent on Eq. (3) in order to find the example that diverges most from the teacher. To enforce the constraint, we run the gradient ascent in the latent space of the teacher generator, as opposed to performing gradient ascent in the instance space $\mathcal{X}$, which might move the perturbed point out of the data manifold.
For a high-quality teacher generator, the latent space should capture the data manifold well. To implement this, we need to backpropagate all the way through the student and teacher labeler to the teacher decoder, as shown in Fig. 1. Mathematically, it involves the following three operations:

$$z := \mathrm{Enc}(x); \quad z \leftarrow z + \eta \nabla_z \ell_d\big(f(\mathrm{Dec}(z)), h(\mathrm{Dec}(z))\big); \quad \tilde{x} := \mathrm{Dec}(z). \tag{5}$$

This is akin to a first-order search in the latent space.

Extension to discrete data. Note that perturbing an instance from a discrete domain, e.g., text data, is not as straightforward as in a continuous space. Typically, one has to resort to expensive combinatorial search or crude approximations to perform such perturbations (Tan et al., 2020; Zang et al., 2020; Ren et al., 2019). Interestingly, our approach in Eq. (4) provides a simple alternative, where one performs the perturbation in the latent space, which is continuous. On the other hand, in gradient-based exploration, we assume that $\mathcal{X}$ is a differentiable space in order to calculate necessary quantities such as $\partial f(x)/\partial x$ in Eq. (5). This assumption holds for various data types such as images and point clouds, but not for discrete data like text. We can, however, circumvent this limitation by sharing weights between the output softmax layer of the teacher's decoder Dec and the input embedding layer of the student $f$ (and also of the teacher labeler $h$ when an independent model is used). One can then bypass the discrete space during the backward pass, similar to the ideas behind VQ-VAE (Hafner et al., 2019). Note that during the forward pass, we still need the discrete representation for decoding, e.g., using beam search.

Finally, we address the superficial resemblance between our approach and adversarial training. For the latter, the goal is to learn a robust classifier, i.e., to increase the margin. Towards this, for any $x$, one encourages model agreement in its local neighborhood $B_r(x)$, i.e., $f(x') = f(x)$ for all $x' \in B_r(x)$.
One needs to carefully choose a small enough neighborhood by restricting $r$, so as not to cross the decision boundary. In contrast, we are not looking for such max-margin training, which has its own issues (Nowak-Vila et al., 2021). We simply desire global agreement between the teacher and student, i.e., $f(x') = h(x')$ for all $x'$. As a result, we can explore much bigger regions as long as we remain on the data manifold, i.e., where $p_{\mathcal{D}_X}(x)$ is non-trivially large.
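To make Eqs. (2), (4), and (5) concrete, here is a toy, stdlib-only sketch built on loudly stated assumptions. The "generator" is a hypothetical linear encoder/decoder whose data manifold is the line spanned by a unit vector `U` (so D = 3 and d = 1); the teacher labeler and student are hypothetical two-class linear models; and ℓ_d is taken to be KL(softmax(h) ‖ softmax(f)), one common distillation loss that the text above does not pin down. The latent gradient in Eq. (5) is approximated by a central finite difference rather than by backpropagating through f, h, and Dec. None of these names come from the paper.

```python
import math
import random

U = [0.6, 0.8, 0.0]  # hypothetical unit vector spanning the toy data manifold (D = 3, d = 1)

def enc(x):
    """Toy linear encoder Enc: R^3 -> R (projection onto U)."""
    return sum(u * xi for u, xi in zip(U, x))

def dec(z):
    """Toy linear decoder Dec: R -> R^3; Dec(Enc(x)) = x for x on the manifold."""
    return [z * u for u in U]

def log_softmax(scores):
    m = max(scores)
    lse = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - lse for s in scores]

def cross_entropy(scores, label):
    """Loss l: softmax cross-entropy with a 0-indexed label."""
    return -log_softmax(scores)[label]

def kl_distill(student_scores, teacher_scores):
    """Assumed l_d: KL(softmax(teacher) || softmax(student))."""
    log_p, log_q = log_softmax(teacher_scores), log_softmax(student_scores)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))

def teacher(x):
    """Hypothetical teacher labeler h: two-class logits."""
    return [0.0, 2.0 * enc(x)]

def make_student(v):
    """Hypothetical linear student f with weight vector v."""
    return lambda x: [0.0, sum(vi * xi for vi, xi in zip(v, x))]

def explore_random(x, sigma, rng):
    """Eq. (4): isotropic Gaussian perturbation in latent space, then decode."""
    return dec(enc(x) + rng.gauss(0.0, sigma))

def explore_gradient(x, student, eta=0.5, steps=5, fd_step=1e-4):
    """Eq. (5): gradient ascent on l_d(f(Dec(z)), h(Dec(z))) over the latent z.
    The gradient is a central finite difference standing in for backpropagation."""
    def disagreement(z):
        x_dec = dec(z)
        return kl_distill(student(x_dec), teacher(x_dec))
    z = enc(x)
    for _ in range(steps):
        grad = (disagreement(z + fd_step) - disagreement(z - fd_step)) / (2 * fd_step)
        z += eta * grad
    return dec(z)

def tgt_objective(student, labeled, generated):
    """Eq. (2): ERM + distillation on labeled S_n, plus distillation on generated S~_m."""
    labeled_term = sum(cross_entropy(student(x), y) + kl_distill(student(x), teacher(x))
                       for x, y in labeled) / len(labeled)
    generated_term = sum(kl_distill(student(x), teacher(x)) for x in generated) / len(generated)
    return labeled_term + generated_term

# Usage: a student whose weights are under-scaled relative to the teacher.
rng = random.Random(0)
labeled = [([0.6, 0.8, 0.0], 1), ([-0.6, -0.8, 0.0], 0)]
student = make_student([0.0, 0.5, 0.0])
gen_random = [explore_random(x, 0.5, rng) for x, _ in labeled]
gen_grad = [explore_gradient(x, student) for x, _ in labeled]
loss = tgt_objective(student, labeled, gen_random + gen_grad)
```

Because every generated point is produced by `dec`, it stays on the toy manifold whatever the perturbation, which is the property the latent-space search is meant to preserve. In this toy, the few gradient steps push the latent code to larger |z|, where the teacher is confident and the under-scaled student is not, so the disagreement at the generated points exceeds that at the starting points.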

3.3. VALUE OF GENERATING SAMPLES VIA THE LATENT SPACE

Now, we formally show how leveraging the latent space can help learning. For this exposition, we assume $\mathcal{X} = \mathbb{R}^D$. Furthermore, for directly learning in the input space, we assume that our function class $\mathcal{F}$ corresponds to all Lipschitz functions that map $\mathbb{R}^D$ to $\mathbb{R}^K$. For any such function $f \in \mathcal{F}$, existing generalization bounds take the form (Devroye et al., 2013; Mohri et al., 2018):

$$R_{\ell,f}(\mathcal{D}) \leq R_{\ell,f}(S_n) + \underbrace{\mathfrak{R}_n(\mathcal{G}_{\ell,\mathcal{F}})}_{\leq O(n^{-1/D})} + O\Big(\sqrt{\log(1/\delta)/n}\Big),$$

where $R_{\ell,f}(\mathcal{D})$ is the true population risk of the classifier, $R_{\ell,f}(S_n)$ is the empirical risk, and $\mathfrak{R}_n(\mathcal{G}_{\ell,\mathcal{F}})$ is the Rademacher complexity of the induced function class $\mathcal{G}_{\ell,\mathcal{F}}$, which is known in our case to be $O(n^{-1/D})$ (see Appendix B for more details). Note that any reduction in the Rademacher term would imply a smaller generalization gap, which is our goal. In our TGT framework, we assume the availability of a teacher that is able to learn a good representation of the underlying data distribution. In particular, we assume that, for $x \in \mathrm{supp}(\mathcal{D}_X)$, we have

$$\|\mathrm{Dec} \circ \mathrm{Enc}(x) - x\| \leq \epsilon, \tag{6}$$

i.e., for $x$, applying the decoder Dec to the latent representation of $x$ produced by the encoder Enc leads to a point $\mathrm{Dec} \circ \mathrm{Enc}(x) \in \mathcal{X}$ that approximates $x$ with a small error. This ability of the teacher generator to model the data distribution using a latent representation can be used to reduce the complexity of the required function class. Specifically, in the TGT framework, we leverage the teacher decoder to restrict the function class to compositions of the decoder function Dec with a learnable Lipschitz function operating on the latent space $\mathbb{R}^d$. Since $d \ll D$, this leads to a function class with much lower complexity. Next, we formally capture this idea for distillation with both the original samples $S_n$ sampled from $\mathcal{D}_X$ and the novel samples $\tilde{S}$ introduced by the teacher generator. In what follows, we only consider the distillation losses and ignore the first loss term (which depends on true labels).
Our analysis can be easily extended to take the latter term into account (e.g., by using tools from Foster et al. (2019)). We start with standard distillation in the following result. See Appendix C.1 for the details.

Theorem 3.1. Suppose a generative model with Enc and Dec satisfies the approximation guarantee in Eq. (6) for $\mathcal{D}_X$. Let Dec and the teacher labeler $h$ be Lipschitz functions, and let the distillation loss $\ell_d$ satisfy Assumption C.1. Then, with probability at least $1 - \delta$, the following holds for any $f \in \mathcal{F}$:

$$R_{\ell,f}(\mathcal{D}) \leq R^{h}_{\ell_d,f}(S_n) + \underbrace{\mathfrak{R}_n(\mathcal{G}^{h,\mathrm{Dec}}_{\ell_d,\mathcal{F}})}_{\leq O(n^{-1/d})} + O\Big(\sqrt{\log(1/\delta)/n}\Big) + L\epsilon + O\Big(\sqrt{K}\,\mathbb{E}_{\mathcal{D}_X}\big\|\mathcal{D}_{Y|X} - h(X)\big\|_2\Big),$$

where $L$ is the effective Lipschitz constant of $\mathcal{G}^{h,\mathrm{Dec}}_{\ell_d,\mathcal{F}} = \{z \mapsto \ell_d(f \circ \mathrm{Dec}(z), h \circ \mathrm{Dec}(z)) : f \in \mathcal{F}\}$, an induced function class that maps $\mathbb{R}^d$ (the latent space of the generator) to $\mathbb{R}$.

Thus, we can reduce the Rademacher term from $O(n^{-1/D})$ to $O(n^{-1/d})$, which yields a significant reduction in sample complexity. However, as the teacher model is not perfect, a penalty is incurred in terms of the reconstruction error $L\epsilon$ and the prediction error $O(\sqrt{K}\,\mathbb{E}_{\mathcal{D}_X}\|\mathcal{D}_{Y|X} - h(X)\|_2)$. Thus far, we have not leveraged the fact that we can also use the teacher to generate additional samples. Accounting for the samples $\tilde{S}_n$ (cf. Section 3.2), one can obtain a similar generalization gap for distillation based on the teacher-generated samples:

Theorem 3.2. Let $\tilde{S}_n = \{\tilde{x}_i\}_{i \in [n]}$ be $n$ i.i.d. samples generated by the TGT framework, whose distribution is denoted by $\tilde{\mathcal{D}}_X$. Further, let $\tilde{f}_n \in \mathcal{F}$ denote the student model learned via distillation on $\tilde{S}_n$, with $h$ as the teacher model and $\ell_d$ a distillation loss satisfying Assumption C.1. Then, with probability at least $1 - \delta$, we have

$$R_{\ell,\tilde{f}_n}(\mathcal{D}) \leq R^{h}_{\ell_d,\tilde{f}_n}(\tilde{S}_n) + \underbrace{\tilde{\mathfrak{R}}_n(\mathcal{G}^{h,\mathrm{Dec}}_{\ell_d,\mathcal{F}})}_{\leq O(n^{-1/d})} + O\Big(\sqrt{\log(1/\delta)/n}\Big) + \mathcal{W}(\mathcal{D}_X, \tilde{\mathcal{D}}_X) + O\Big(\sqrt{K}\,\mathbb{E}_{\mathcal{D}_X}\big\|\mathcal{D}_{Y|X} - h(X)\big\|_2\Big),$$

where $\mathcal{G}^{h,\mathrm{Dec}}_{\ell_d,\mathcal{F}}$ is … generalization bounds (Maurer & Pontil, 2009) in Appendix C.3. Such bounds suggest that, besides minimizing the discrepancy $\mathcal{W}(\mathcal{D}_X, \tilde{\mathcal{D}}_X)$, an ideal $\tilde{\mathcal{D}}_X$ should reduce the variance of $\ell_d(f(x), h(x))$ for newly generated instances. Incidentally, the sampling approach realized by the gradient-based exploration in Eq. (5) aims to achieve this: it controls $\mathcal{W}(\mathcal{D}_X, \tilde{\mathcal{D}}_X)$ by operating in the latent space of a good-quality teacher generative model, and it minimizes variance by finding instances with high loss values through gradient ascent, thereby striking a balance between the two objectives. See Appendix C.3 for a detailed discussion.

Table 1: … We report top-1 accuracy on balanced eval sets. We also state the number of model parameters and inference cost (in terms of FLOPs) for all the methods. Note that TGT leads to performance improvements over standard distillation on all three datasets, particularly for ImageNet-LT, where the teacher generator models the task distribution well. TGT also often outperforms the stated baselines, which rely on much larger and more expensive models.
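To get a rough feel for the rates in Theorems 3.1 and 3.2, the sketch below drops all constants, treats the Rademacher term as exactly n^(-1/dim), and asks how many samples push that term below a target. It also computes an empirical 1-D Wasserstein-1 distance between two equal-size samples, the kind of quantity W(D_X, D̃_X) measures. The dimensions (D = 20, d = 5) and the target 0.5 are arbitrary illustrative choices; real inputs are high-dimensional, so the 1-D W1 only illustrates the definition.

```python
import math

def rademacher_rate(n, dim):
    """Rademacher term of a Lipschitz class on R^dim, up to constants: n^(-1/dim)."""
    return n ** (-1.0 / dim)

def samples_needed(target, dim):
    """Smallest n with n^(-1/dim) <= target, i.e. n >= target^(-dim) (constants ignored)."""
    return math.ceil(target ** (-dim))

def wasserstein_1d(xs, ys):
    """Empirical W1 between two equal-size 1-D samples: mean gap between sorted values."""
    if len(xs) != len(ys):
        raise ValueError("this sketch assumes equal sample sizes")
    return sum(abs(a - b) for a, b in zip(sorted(xs), sorted(ys))) / len(xs)

# Illustrative dimensions only: input dimension D = 20 vs. latent dimension d = 5.
print(samples_needed(0.5, 20), samples_needed(0.5, 5))  # → 1048576 32
print(wasserstein_1d([0.0, 1.0, 2.0], [1.0, 2.0, 3.0]))  # → 1.0
```

Halving the target gap multiplies the required n by 2^dim, which is why moving the effective dimension from D to d matters so much; the W1 term, by contrast, does not shrink with n and is controlled only by the quality of the generator.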

4. EXPERIMENTS

We now conduct a comprehensive empirical study of our TGT framework in order to establish that TGT (i) leads to high accuracy in transferring knowledge in low data/long-tail regimes (Section 4.1); (ii) effectively increases sample size (Section 4.2); and (iii) has wide adaptability even to discrete data domains such as text classification (Section 4.3) and retrieval (Section 4.4).

4.1. LONG-TAIL IMAGE CLASSIFICATION

Setup. We evaluate TGT by training student models on three benchmark long-tail image classification datasets: ImageNet-LT (Liu et al., 2019c), SUN-LT (Patterson & Hays, 2012), and Places-LT (Liu et al., 2019c). We employ off-the-shelf teacher models, in particular BigBiGAN (ResNet-50) (Donahue & Simonyan, 2019) and EfficientNet-B3 (Xie et al., 2020c), as the teacher generator and teacher labeler models, respectively. We utilize MobileNetV3 (Howard et al., 2019) as the compact student model architecture. The teacher labeler model is self-trained on JFT-300M (Sun et al., 2017) and then finetuned on the task-specific long-tail dataset. The teacher generator is trained on the unlabeled full version of ImageNet (Russakovsky et al., 2015). Results. The results are reported in Table 1, compared with similarly sized baselines (we exclude very large transformer models). We see that TGT is able to effectively transfer the knowledge acquired by the teacher during its training on huge amounts of data into a significantly smaller student model, which also has lower inference cost. TGT considerably improves performance across the board over standard distillation, even on SUN-LT and Places-LT, whose data distributions do not exactly match the distribution the teacher's generator was trained on. That said, the gains from TGT are more pronounced when the mismatch between the task data distribution and the distribution modeled by the generator is small, as is the case for ImageNet-LT. The fact that TGT (random) (cf. Eq. (4)) provides large gains over standard distillation establishes the value of utilizing the latent space, as suggested by our analysis in Section 3.3. Note that TGT (gradient-based) brings further gains over TGT (random), particularly on SUN-LT and Places-LT, which are extremely long-tailed.
We believe that gradient-based first-order exploration is especially useful in settings where data is extremely sparse or where isotropic random perturbation in the latent space does not produce sufficiently diverse instances. A systematic study of this constitutes an interesting avenue for future research. Owing to its computational efficiency, we focus on TGT (random) for the rest of the paper. Note that some of the baselines in Table 1 rely on specialized loss functions and/or training methods designed for long-tail settings, whereas we do not leverage such techniques. Combining the TGT framework with a long-tail-specific loss function instead of the standard cross-entropy loss could potentially improve its performance. We leave this direction for future exploration.

To further showcase the effectiveness of knowledge transfer via TGT, we simulate a low-data regime by varying the amount of available training data for ImageNet (Russakovsky et al., 2015) and studying its impact on the student's performance. We use the same model architectures as in Section 4.1, but finetune the teacher labeler on the entire ImageNet. We then compare the performance of the student trained via TGT with students trained via normal training (with one-hot labels) and standard distillation.

4.3. TEXT CLASSIFICATION

Setup. We evaluate the proposed TGT framework on four benchmark text classification datasets: Amazon-5 (Zhang et al., 2015), IMDB (Maas et al., 2011), MNLI (Williams et al., 2018), and Yelp-5 (Zhang et al., 2015). Following Xie et al. (2020a), we also consider extremely sub-sampled versions of Amazon-5 and Yelp-5 consisting of only 2.5k labeled examples. Again, we utilize off-the-shelf teacher models, in particular BART-base (Lewis et al., 2020) and RoBERTa-large (Liu et al., 2019a) as the teacher generator and teacher labeler, respectively. Following Rashid et al. (2021), we employ a DistilBERT (Sanh et al., 2019b) model as the student architecture. Both teacher networks are pretrained on a very large generic text corpus of size 160GB. The teacher labeler is finetuned on each task-specific dataset, while the teacher generator is not specialized to any specific task. Results. We compare TGT with other data augmentation and distillation baselines in Table 2. Note that TGT considerably improves performance and beats the state-of-the-art methods MATE-KD (Rashid et al., 2021) and UDA (Xie et al., 2020a). Interestingly, by using TGT on a randomly initialized student, we can match the performance of finetuning a pretrained model (with one-hot labels) on Amazon-5 and Yelp-5. We highlight that baselines such as MATE-KD always work with a pretrained student model. Thus, the improvements realized by TGT with a randomly initialized student demonstrate substantial savings in overall data and training-time requirements, as it eliminates the need for pretraining on a large corpus. This further establishes that TGT can enable data-efficient knowledge transfer from the teacher to the student.

Table 3: Performance of TGT and various baselines on the NQ retrieval task (Kwiatkowski et al., 2019), measured by recall@20 and recall@100. The teacher labeler follows the setup of Oguz et al. (2021), which pretrains RoBERTa-base on a large corpus and on PAQ (Lewis et al., 2021) and then finetunes it on NQ (Kwiatkowski et al., 2019). BART-base (Lewis et al., 2020) serves as a task-agnostic generator. All student models follow the DistilBERT architecture (Sanh et al., 2019b). TGT significantly outperforms standard training (One-hot) and teacher-label-only distillation (Distillation), closing the teacher-student gap by 37% at recall@20 and 63% at recall@100 compared to standard distillation. See Appendix F.4 for more details on the experimental setup.

Setup. Finally, we evaluate TGT on Natural Questions (NQ) (Kwiatkowski et al., 2019), a text retrieval benchmark. The task is to find a matching passage for a question out of a large candidate passage corpus (21M passages). We use the RoBERTa-base dual-encoder model of Oguz et al. (2021) as the teacher labeler and BART-base (Lewis et al., 2020) as the teacher generator. We utilize a DistilBERT dual-encoder model as our student architecture. We follow the standard retrieval distillation setup, where the teacher labeler provides labels for all within-batch question-to-passage pairs for the student to match. Besides one-hot training and standard distillation, we consider another baseline, namely uniform negatives: for each question-to-passage pair in NQ, we uniformly sample 2 additional passages from the passage corpus during training. TGT instead dynamically generates 2 confusing passages for each question-passage pair with the BART generator, applying the isotropic latent perturbation of Eq. (4). Results. Table 3 shows that TGT significantly improves performance, closing the teacher-student gap by 37% at recall@20 and 63% at recall@100 compared to standard distillation.
Unlike TGT, uniform negatives only partially helped (a slight improvement on recall@20 but a degradation on recall@100 compared to standard distillation). A plausible explanation is that, due to the extremely large passage corpus (21M), uniformly sampled passages are rarely relevant to the matching question-to-passage pair in NQ. TGT instead generates informative passages that are close to the matching pair.
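A minimal sketch of the within-batch retrieval distillation described above, assuming dot-product dual encoders and a KL-divergence objective between the teacher's and the student's softmax distributions over each question's candidate passages. The exact loss is not specified in this section, so that choice is an assumption; the helper names are hypothetical and embeddings are plain Python lists. Extra passages, whether uniform negatives or TGT-generated ones, would simply be appended to the candidate list.

```python
import math
from typing import List

Vector = List[float]


def dot(u: Vector, v: Vector) -> float:
    return sum(a * b for a, b in zip(u, v))


def softmax(scores: List[float]) -> List[float]:
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def score_matrix(questions: List[Vector], passages: List[Vector]) -> List[List[float]]:
    # Row i holds the similarity of question i to every passage in the batch,
    # i.e., all within-batch question-to-passage pairs.
    return [[dot(q, p) for p in passages] for q in questions]


def retrieval_distillation_loss(teacher_scores: List[List[float]],
                                student_scores: List[List[float]]) -> float:
    """Mean over questions of KL(teacher || student) across candidate passages."""
    total = 0.0
    for t_row, s_row in zip(teacher_scores, student_scores):
        t_prob = softmax(t_row)
        s_prob = softmax(s_row)
        total += sum(t * (math.log(t) - math.log(s))
                     for t, s in zip(t_prob, s_prob) if t > 0)
    return total / len(teacher_scores)
```

Because KL(t‖s) differs from the soft-label cross-entropy only by the teacher's entropy, which does not depend on the student, minimizing either one yields the same student gradients.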

5. CONCLUSION AND FUTURE DIRECTIONS

We have introduced a simple and theoretically justified distillation scheme (TGT) that adaptively generates samples with the aim of closing the divergence between student and teacher predictions. Our results show it to outperform, in aggregate, existing distillation approaches. Unlike alternative methods, it is also applicable to both continuous and discrete domains, as the results on image and text data show. TGT is orthogonal to other approaches that enable efficient inference such as quantization and pruning, and combining them is an interesting avenue for future work. Another potential research direction is to employ TGT for multi-modal data which would require accommodating multiple generative models with their own latent spaces, raising both practical and theoretical challenges.

ETHICS STATEMENT

The TGT framework relies on the availability of a good-quality teacher for the underlying domain to provide efficient distillation. The impact of knowledge distillation on transferring the teacher model's biases to the resulting student model is far from well understood. Moreover, the teacher generators that TGT utilizes are often large pretrained models trained on large amounts of unfiltered data. As a result, these models can carry biases that the user is unaware of. We also do not address how the various biases present in the generator impact the student model's fairness. A deeper study of this issue is required for our proposed method, as well as for knowledge distillation as an ML technique in general.

A FURTHER COMPARISON WITH MATE-KD

MATE-KD (Rashid et al., 2021) alternately trains a generator model and the student model, with the aim of generating the most adversarial examples for the student during training. This can cause the generator to drift away from the true data distribution. In contrast, we keep the pretrained teacher generator fixed throughout student training. Our objective in employing the generator is to leverage the domain knowledge it has already acquired during pretraining. While we do want to generate 'hard instances' for the student, we also want those instances to be relevant for the underlying task. Thus, keeping the generator fixed acts as a form of regularization: the training instances the student encounters do not introduce a domain mismatch. Keeping in mind the objective of producing new informative training instances that are in-domain, we introduce perturbations in the latent space realized by the encoder of the teacher generator (see Figure 1). This differs from directly perturbing an original training instance in the input space itself, as done by MATE-KD.
As evident from our theoretical analysis and empirical evaluation, for a fixed teacher generator, employing perturbations in the latent space leads to more informative data augmentation and enables good performance in both the image and text domains.

B BACKGROUND AND NOTATION

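For $a, b \in \mathbb{R}$, we use $a = O(b)$ to denote that there exists a constant $\gamma > 0$ such that $a \le \gamma \cdot b$. Given a collection of $n$ i.i.d. random variables $U_n = \{u_1, \ldots, u_n\} \subset \mathcal{U}$ generated from a distribution $\mathcal{D}_U$ and a function $\tau: \mathcal{U} \to \mathbb{R}$, we define the empirical mean of $\{\tau(u_1), \ldots, \tau(u_n)\}$ as $\mathbb{E}_{U_n}[\tau(U)] := \frac{1}{n}\sum_{i \in [n]} \tau(u_i)$. For the underlying multiclass classification problem defined by the distribution $\mathcal{D} := \mathcal{D}_{X \times Y}$, we assume that the label set $\mathcal{Y}$ with $K$ classes takes the form $[K] := \{1, \ldots, K\}$. We use $\mathcal{F}$ to denote the collection of potential classification models that the learning method is allowed to select from, namely the function class or hypothesis set: $\mathcal{F} \subseteq \{\mathcal{X} \to \mathbb{R}^K\}$, a subset of all functions that map elements of the instance space $\mathcal{X}$ to elements of $\mathbb{R}^K$. Given a classification loss function $\ell: \mathbb{R}^K \times \mathcal{Y} \to \mathbb{R}$, a model $f: \mathcal{X} \to \mathbb{R}^K$, and a sample $S^{\mathrm{labeled}}_n = \{(x_i, y_i)\}_{i \in [n]}$ generated from $\mathcal{D}$, we define the empirical risk for $f \in \mathcal{F}$ as follows.

$$R_{\ell,f}(S^{\mathrm{labeled}}_n) := \mathbb{E}_{S^{\mathrm{labeled}}_n}[\ell(f(X), Y)] = \frac{1}{n}\sum_{i \in [n]} \ell(f(x_i), y_i).$$

Further, we define the population risk for $f \in \mathcal{F}$ associated with the data distribution $\mathcal{D}$ as follows.

$$R_{\ell,f}(\mathcal{D}) = \mathbb{E}_{X,Y \sim \mathcal{D}}[\ell(f(X), Y)].$$

When the loss function $\ell$ is clear from the context, we drop $\ell$ from the notation and simply use $R_f(S^{\mathrm{labeled}}_n)$ and $R_f(\mathcal{D})$ to denote the empirical and population risks for $f$, respectively. Given a function class $\mathcal{F}$, the loss function $\ell$ induces the following function class.

$$\mathcal{G}_{\ell,\mathcal{F}} = \{(x, y) \mapsto \ell(f(x), y) : f \in \mathcal{F}\}.$$

Definition B.1 (Rademacher complexity of $\mathcal{G}_{\ell,\mathcal{F}}$). Given a sample $S^{\mathrm{labeled}}_n = \{(x_i, y_i)\}_{i \in [n]} \sim \mathcal{D}^n$ and a vector $\sigma = (\sigma_1, \ldots, \sigma_n) \in \{+1, -1\}^n$ of $n$ i.i.d. Rademacher random variables (each uniform on $\{+1, -1\}$), the empirical Rademacher complexity $\mathfrak{R}_{S^{\mathrm{labeled}}_n}(\mathcal{G}_{\ell,\mathcal{F}})$ and the Rademacher complexity $\mathfrak{R}_n(\mathcal{G}_{\ell,\mathcal{F}})$ are defined as

$$\mathfrak{R}_{S^{\mathrm{labeled}}_n}(\mathcal{G}_{\ell,\mathcal{F}}) = \frac{1}{n}\mathbb{E}_{\sigma}\Big[\sup_{g \in \mathcal{G}_{\ell,\mathcal{F}}} \sum_{i=1}^{n} \sigma_i\, g(x_i, y_i)\Big] \quad\text{and}\quad \mathfrak{R}_n(\mathcal{G}_{\ell,\mathcal{F}}) = \mathbb{E}_{S \sim \mathcal{D}^n}\big[\mathfrak{R}_{S}(\mathcal{G}_{\ell,\mathcal{F}})\big].$$

Let $S_n = \{x_i\}_{i \in [n]}$ be a set of $n$ unlabeled samples generated from $\mathcal{D}_X$. Then, given a teacher model $h: \mathcal{X} \to \mathbb{R}^K$ and a distillation loss $\ell_d: \mathbb{R}^K \times \mathbb{R}^K \to \mathbb{R}$, we define the empirical (distillation) risk for $f \in \mathcal{F}$ to be

$$R^h_{\ell_d,f}(S_n) := \mathbb{E}_{S_n}[\ell_d(f(X), h(X))] = \frac{1}{n}\sum_{i \in [n]} \ell_d(f(x_i), h(x_i)).$$

Accordingly, the population (distillation) risk for $f \in \mathcal{F}$ is defined as

$$R^h_{\ell_d,f}(\mathcal{D}_X) := \mathbb{E}_{X \sim \mathcal{D}_X}[\ell_d(f(X), h(X))].$$

Again, when $\ell_d$ is clear from the context, we simply use $R^h_f(S_n)$ and $R^h_f(\mathcal{D}_X)$ to denote the empirical and population distillation risks for $f$, respectively. Note that, for a (student) function class $\mathcal{F}$ and a teacher model $h$, $\ell_d$ produces an induced function class $\mathcal{G}^h_{\ell_d,\mathcal{F}}$, defined as

$$\mathcal{G}^h_{\ell_d,\mathcal{F}} := \{x \mapsto \ell_d(f(x), h(x)) : f \in \mathcal{F}\}. \qquad (15)$$

Definition B.2 (Rademacher complexity of $\mathcal{G}^h_{\ell_d,\mathcal{F}}$). Given a sample $S_n = \{x_i\}_{i \in [n]} \sim \mathcal{D}^n_X$ and a vector $\sigma = (\sigma_1, \ldots, \sigma_n) \in \{+1, -1\}^n$ of $n$ i.i.d. Rademacher random variables, the empirical Rademacher complexity $\mathfrak{R}_{S_n}(\mathcal{G}^h_{\ell_d,\mathcal{F}})$ and the Rademacher complexity $\mathfrak{R}_n(\mathcal{G}^h_{\ell_d,\mathcal{F}})$ are defined as

$$\mathfrak{R}_{S_n}(\mathcal{G}^h_{\ell_d,\mathcal{F}}) = \frac{1}{n}\mathbb{E}_{\sigma}\Big[\sup_{g \in \mathcal{G}^h_{\ell_d,\mathcal{F}}} \sum_{i=1}^{n} \sigma_i\, g(x_i)\Big] \quad\text{and}\quad \mathfrak{R}_n(\mathcal{G}^h_{\ell_d,\mathcal{F}}) = \mathbb{E}_{S \sim \mathcal{D}^n_X}\big[\mathfrak{R}_{S}(\mathcal{G}^h_{\ell_d,\mathcal{F}})\big].$$

C DEFERRED PROOFS FROM SECTION 3

C.1 PROOF OF THEOREM 3.1

In this subsection, we present a general version of Theorem 3.1. Before that, we state the following relevant assumption on the distillation loss $\ell_d$.

Assumption C.1. Let $\ell: \mathbb{R}^K \times \mathcal{Y} \to \mathbb{R}$ be a bounded loss function. For a teacher function $h: \mathcal{X} \to \mathbb{R}^K$, the distillation loss $\ell_d$ takes the form $\ell_d(f(x), h(x)) = \sum_{y \in [K]} h(x)_y \cdot \ell(f(x), y)$.

Remark C.2. Note that the cross-entropy loss $\ell_d(f(x), h(x)) = -\sum_{y} h(x)_y \cdot \log f(x)_y$, one of the most common choices for the distillation loss, indeed satisfies Assumption C.1.

The following result is a general version of Theorem 3.1 in the main body.

Theorem C.3.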
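Let a generator with encoder Enc and decoder Dec ensure the approximation guarantee in Eq. (6) for $\mathcal{D}_X$. Let Dec and the teacher labeler be Lipschitz functions, $\mathcal{F}$ be a class of Lipschitz functions, and the distillation loss $\ell_d$ be Lipschitz. Then, with probability at least $1 - \delta$, the following holds for any $f \in \mathcal{F}$.

$$R^h_{\ell_d,f}(\mathcal{D}_X) \le R^h_{\ell_d,f}(S_n) + O\big(n^{-1/d}\big) + L\epsilon + O\Big(\sqrt{\log(1/\delta) \cdot \tfrac{1}{n}}\Big), \qquad (18)$$

where $L$ denotes the effective Lipschitz constant of the induced function class $\mathcal{G}^h_{\ell_d,\mathcal{F}}$ in Eq. (15). Additionally, if the distillation loss $\ell_d$ satisfies Assumption C.1 with a classification loss $\ell$, then Eq. (18) further implies the following.

$$R_{\ell,f}(\mathcal{D}) \le R^h_{\ell_d,f}(S_n) + O\big(n^{-1/d}\big) + L\epsilon + O\Big(\sqrt{\log(1/\delta) \cdot \tfrac{1}{n}}\Big) + O\Big(\sqrt{K} \cdot \mathbb{E}_{\mathcal{D}_X}\big[\|\mathcal{D}_{Y|X} - h(X)\|_2\big]\Big). \qquad (19)$$

Proof. Note that

$$\begin{aligned} R^h_{\ell_d,f}(\mathcal{D}_X) &= \mathbb{E}_{\mathcal{D}_X}[\ell_d(f(X), h(X))] \\ &\le \mathbb{E}_{S_n}[\ell_d(f(X), h(X))] + \sup_{f \in \mathcal{F}} \Big(\mathbb{E}_{\mathcal{D}_X}[\ell_d(f(X), h(X))] - \mathbb{E}_{S_n}[\ell_d(f(X), h(X))]\Big) \\ &\overset{(i)}{=} \mathbb{E}_{S_n}[\ell_d(f(X), h(X))] + \sup_{g \in \mathcal{G}^h_{\ell_d,\mathcal{F}}} \Big(\mathbb{E}_{\mathcal{D}_X}[g(X)] - \mathbb{E}_{S_n}[g(X)]\Big) \\ &\overset{(ii)}{\le} \mathbb{E}_{S_n}[\ell_d(f(X), h(X))] + \mathfrak{R}_{S_n}(\mathcal{G}^h_{\ell_d,\mathcal{F}}), \end{aligned} \qquad (20)$$

where (i) follows from the definition of $\mathcal{G}^h_{\ell_d,\mathcal{F}}$ in Eq. (15) and (ii) follows from the standard symmetrization argument (Devroye et al., 2013; Mohri et al., 2018). Next, we turn our focus to the empirical Rademacher complexity $\mathfrak{R}_{S_n}(\mathcal{G}^h_{\ell_d,\mathcal{F}})$. Recall that $S_n = \{x_1, x_2, \ldots, x_n\}$ contains $n$ i.i.d. samples generated from the distribution $\mathcal{D}_X$. We define another set of $n$ points $\bar{S}_n = \{\bar{x}_1 = \mathrm{Dec} \circ \mathrm{Enc}(x_1), \ldots, \bar{x}_n = \mathrm{Dec} \circ \mathrm{Enc}(x_n)\}$. It follows from our assumption on the quality of the generator (cf. Eq. (6)) that

$$\|\mathrm{Dec} \circ \mathrm{Enc}(x_i) - x_i\| \le \epsilon, \quad \forall i \in [n]. \qquad (21)$$

Note that $\mathfrak{R}_{S_n}(\mathcal{G}^h_{\ell_d,\mathcal{F}}) = \frac{1}{n}\mathbb{E}_{\sigma}\big[\sup_{g \in \mathcal{G}^h_{\ell_d,\mathcal{F}}} \sum_i \sigma_i\, g(x_i)\big]$, where $\sigma$ denotes a vector of $n$ i.i.d. Rademacher random variables. Then, with all suprema taken over $g \in \mathcal{G}^h_{\ell_d,\mathcal{F}}$,

$$\begin{aligned} \mathfrak{R}_{S_n}(\mathcal{G}^h_{\ell_d,\mathcal{F}}) &= \frac{1}{n}\mathbb{E}_{\sigma}\Big[\sup_{g} \sum_i \sigma_i \big(g(x_i) - g(\bar{x}_i) + g(\bar{x}_i)\big)\Big] \\ &\le \frac{1}{n}\mathbb{E}_{\sigma}\Big[\sup_{g} \sum_i \sigma_i\, g(\bar{x}_i)\Big] + \frac{1}{n}\mathbb{E}_{\sigma}\Big[\sup_{g} \sum_i \sigma_i \big(g(x_i) - g(\bar{x}_i)\big)\Big] \\ &\le \frac{1}{n}\mathbb{E}_{\sigma}\Big[\sup_{g} \sum_i \sigma_i\, g(\bar{x}_i)\Big] + \sup_{g} \frac{1}{n}\sum_i \big|g(x_i) - g(\bar{x}_i)\big| \\ &\le \frac{1}{n}\mathbb{E}_{\sigma}\Big[\sup_{g} \sum_i \sigma_i\, g(\bar{x}_i)\Big] + \frac{1}{n}\sum_i L \cdot \|x_i - \bar{x}_i\| \\ &\le \frac{1}{n}\mathbb{E}_{\sigma}\Big[\sup_{g} \sum_i \sigma_i\, g(\bar{x}_i)\Big] + L\epsilon \\ &= \frac{1}{n}\mathbb{E}_{\sigma}\Big[\sup_{g} \sum_i \sigma_i\, g(\mathrm{Dec}(z_i))\Big] + L\epsilon, \end{aligned} \qquad (22)$$

where $z_i = \mathrm{Enc}(x_i)$ for $i \in [n]$. By definition of $\mathcal{G}^h_{\ell_d,\mathcal{F}}$, $g(\mathrm{Dec}(z)) = \ell_d(f(\mathrm{Dec}(z)), h(\mathrm{Dec}(z)))$ for some $f \in \mathcal{F}$. Now, we can define a new function class from $\mathbb{R}^d$ to $\mathbb{R}$:

$$\mathcal{G}^{h,\mathrm{Dec}}_{\ell_d,\mathcal{F}} = \{z \mapsto \ell_d(f \circ \mathrm{Dec}(z), h \circ \mathrm{Dec}(z)) : f \in \mathcal{F}\}. \qquad (23)$$

Therefore, it follows from Eq. (22) and Eq. (23) that

$$\mathfrak{R}_{S_n}(\mathcal{G}^h_{\ell_d,\mathcal{F}}) \le \mathfrak{R}_{E_n}(\mathcal{G}^{h,\mathrm{Dec}}_{\ell_d,\mathcal{F}}) + L\epsilon, \qquad (24)$$

where $E_n = \{\mathrm{Enc}(x_1), \ldots, \mathrm{Enc}(x_n)\} \subset \mathbb{R}^d$. It follows from the standard concentration results for the empirical Rademacher complexity around the Rademacher complexity that, with probability at least $1 - \delta$,

$$\mathfrak{R}_{E_n}(\mathcal{G}^{h,\mathrm{Dec}}_{\ell_d,\mathcal{F}}) \le \mathfrak{R}_n(\mathcal{G}^{h,\mathrm{Dec}}_{\ell_d,\mathcal{F}}) + O\Big(\sqrt{\log(1/\delta) \cdot \tfrac{1}{n}}\Big). \qquad (25)$$

Since $f \in \mathcal{F}$, $h$, and Dec are Lipschitz functions, $\mathcal{G}^{h,\mathrm{Dec}}_{\ell_d,\mathcal{F}}$ is a collection of Lipschitz functions from $\mathbb{R}^d$ to $\mathbb{R}$. Thus, it follows from standard results (Gottlieb et al., 2016, Theorem 4.3) that

$$\mathfrak{R}_n(\mathcal{G}^{h,\mathrm{Dec}}_{\ell_d,\mathcal{F}}) \le O\big(n^{-1/d}\big). \qquad (26)$$

Now, Eq. (18) follows from Eq. (20), Eq. (24), Eq. (25), and Eq. (26). Finally, Eq. (19) follows by combining Lemma D.4 with Eq. (18).

C.2 PROOF OF THEOREM 3.2

Here, we present the following result, which along with Theorem C.5 implies Theorem 3.2 stated in the main body.

Theorem C.4. Let $\tilde{S}_n = \{\tilde{x}_i\}_{i \in [n]}$ be $n$ i.i.d. samples generated from a distribution $\tilde{\mathcal{D}}_X$. Further, let $\hat{f}_n \in \mathcal{F}$ denote the student model learned via distillation on $\tilde{S}_n$, with $h$ and $\ell_d$ as the teacher model and distillation loss, respectively.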
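Then, with probability at least $1 - \delta$, we have

$$R^h_{\ell_d,\hat{f}_n}(\mathcal{D}_X) \le R^h_{\ell_d,\hat{f}_n}(\tilde{S}_n) + L \cdot W(\mathcal{D}_X, \tilde{\mathcal{D}}_X) + \tilde{\mathfrak{R}}_n(\mathcal{G}^h_{\ell_d,\mathcal{F}}) + O\Big(\sqrt{\log(1/\delta) \cdot \tfrac{1}{n}}\Big), \qquad (27)$$

where $L$ denotes the Lipschitz constant of the induced function class $\mathcal{G}^h_{\ell_d,\mathcal{F}}$ (cf. Lemma D.3) and $\tilde{\mathfrak{R}}_n(\mathcal{G}^h_{\ell_d,\mathcal{F}}) = \mathbb{E}_{S \sim \tilde{\mathcal{D}}^n_X}\big[\mathfrak{R}_{S}(\mathcal{G}^h_{\ell_d,\mathcal{F}})\big]$ denotes the Rademacher complexity of the induced function class $\mathcal{G}^h_{\ell_d,\mathcal{F}}$, defined in Eq. (15), under $\tilde{\mathcal{D}}_X$. If $\tilde{S}_n$ is constructed with the TGT framework based on a generator with encoder Enc and decoder Dec, then Eq. (27) further specializes to

$$R^h_{\ell_d,\hat{f}_n}(\mathcal{D}_X) \le R^h_{\ell_d,\hat{f}_n}(\tilde{S}_n) + L \cdot W(\mathcal{D}_X, \tilde{\mathcal{D}}_X) + \tilde{\mathfrak{R}}_n(\mathcal{G}^{h,\mathrm{Dec}}_{\ell_d,\mathcal{F}}) + O\Big(\sqrt{\log(1/\delta) \cdot \tfrac{1}{n}}\Big), \qquad (28)$$

where $\mathcal{G}^{h,\mathrm{Dec}}_{\ell_d,\mathcal{F}}$ defines the following induced function class from $\mathbb{R}^d$ (i.e., the latent space of the generator) to $\mathbb{R}$:

$$\mathcal{G}^{h,\mathrm{Dec}}_{\ell_d,\mathcal{F}} = \{z \mapsto \ell_d(f \circ \mathrm{Dec}(z), h \circ \mathrm{Dec}(z)) : f \in \mathcal{F}\}. \qquad (29)$$

Proof. Note that

$$\begin{aligned} R^h_{\ell_d,\hat{f}_n}(\tilde{\mathcal{D}}_X) &= \mathbb{E}_{\tilde{\mathcal{D}}_X}[\ell_d(\hat{f}_n(X), h(X))] \\ &\le \mathbb{E}_{\tilde{S}_n}[\ell_d(\hat{f}_n(X), h(X))] + \sup_{f \in \mathcal{F}} \Big(\mathbb{E}_{\tilde{\mathcal{D}}_X}[\ell_d(f(X), h(X))] - \mathbb{E}_{\tilde{S}_n}[\ell_d(f(X), h(X))]\Big) \\ &= \mathbb{E}_{\tilde{S}_n}[\ell_d(\hat{f}_n(X), h(X))] + \sup_{g \in \mathcal{G}^h_{\ell_d,\mathcal{F}}} \Big(\mathbb{E}_{\tilde{\mathcal{D}}_X}[g(X)] - \mathbb{E}_{\tilde{S}_n}[g(X)]\Big) \\ &\le \mathbb{E}_{\tilde{S}_n}[\ell_d(\hat{f}_n(X), h(X))] + \mathfrak{R}_{\tilde{S}_n}(\mathcal{G}^h_{\ell_d,\mathcal{F}}), \end{aligned} \qquad (30)$$

where the last two steps follow from the definition of $\mathcal{G}^h_{\ell_d,\mathcal{F}}$ (cf. Eq. (15)) and the standard symmetrization argument (Devroye et al., 2013; Mohri et al., 2018), respectively. Now, the standard concentration results for the empirical Rademacher complexity imply that, with probability at least $1 - \delta$, we have the following.

$$\mathfrak{R}_{\tilde{S}_n}(\mathcal{G}^h_{\ell_d,\mathcal{F}}) \le \mathbb{E}_{S \sim \tilde{\mathcal{D}}^n_X}\big[\mathfrak{R}_{S}(\mathcal{G}^h_{\ell_d,\mathcal{F}})\big] + O\Big(\sqrt{\log(1/\delta) \cdot \tfrac{1}{n}}\Big) \qquad (31)$$

$$= \tilde{\mathfrak{R}}_n(\mathcal{G}^h_{\ell_d,\mathcal{F}}) + O\Big(\sqrt{\log(1/\delta) \cdot \tfrac{1}{n}}\Big). \qquad (32)$$

It follows from Lemma D.3 that

$$R^h_{\ell_d,\hat{f}_n}(\mathcal{D}_X) \le R^h_{\ell_d,\hat{f}_n}(\tilde{\mathcal{D}}_X) + L \cdot W(\mathcal{D}_X, \tilde{\mathcal{D}}_X). \qquad (33)$$

Now the first part of Theorem C.4, as stated in Eq. (27), follows by combining Eq. (30), Eq. (31), and Eq. (33). We now focus on establishing Eq. (28). Note that, for a sample $\tilde{S}_n = \{\tilde{x}_1, \ldots, \tilde{x}_n\}$ generated by the TGT framework, there exist $\{z_1, \ldots, z_n\} \subset \mathbb{R}^d$ such that

$$\tilde{x}_i = \mathrm{Dec}(z_i), \quad \forall i \in [n]. \qquad (34)$$

Thus, writing $Z_n = \{z_1, \ldots, z_n\}$,

$$\mathfrak{R}_{\tilde{S}_n}(\mathcal{G}^h_{\ell_d,\mathcal{F}}) = \frac{1}{n}\mathbb{E}_{\sigma}\Big[\sup_{g \in \mathcal{G}^h_{\ell_d,\mathcal{F}}} \sum_i \sigma_i\, g(\tilde{x}_i)\Big] \overset{(i)}{=} \frac{1}{n}\mathbb{E}_{\sigma}\Big[\sup_{g \in \mathcal{G}^h_{\ell_d,\mathcal{F}}} \sum_i \sigma_i\, g(\mathrm{Dec}(z_i))\Big] \le \frac{1}{n}\mathbb{E}_{\sigma}\Big[\sup_{g' \in \mathcal{G}^{h,\mathrm{Dec}}_{\ell_d,\mathcal{F}}} \sum_i \sigma_i\, g'(z_i)\Big] = \mathfrak{R}_{Z_n}(\mathcal{G}^{h,\mathrm{Dec}}_{\ell_d,\mathcal{F}}), \qquad (35)$$

where (i) employs Eq. (34). Thus, combining Eq. (30) and Eq. (35) gives us

$$R^h_{\ell_d,\hat{f}_n}(\tilde{\mathcal{D}}_X) \le \mathbb{E}_{\tilde{S}_n}[\ell_d(\hat{f}_n(X), h(X))] + \mathfrak{R}_{Z_n}(\mathcal{G}^{h,\mathrm{Dec}}_{\ell_d,\mathcal{F}}). \qquad (36)$$

Now, similar to the proof of Eq. (27), we can invoke Lemma D.3 and the concentration result for the empirical Rademacher complexity to obtain the desired result in Eq. (28) from Eq. (36).

Note that, if the distillation loss $\ell_d$ satisfies Assumption C.1 with a loss function $\ell$, then one can combine Theorem C.4 and Lemma D.4 to readily obtain bounds on $R_{\ell,\hat{f}_n}(\mathcal{D})$ with an additional term $O\big(\sqrt{K} \cdot \mathbb{E}_{\mathcal{D}_X}[\|\mathcal{D}_{Y|X} - h(X)\|_2]\big)$. This term captures the quality of the teacher labeler $h$.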
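To make Definitions B.1 and B.2 concrete, the sketch below estimates the empirical Rademacher complexity by Monte Carlo for a finite function class, represented as a matrix of function values on the sample. The finite-class restriction, the function name, and the number of draws are assumptions for illustration; the bounds above rely on analytical results (e.g., for Lipschitz classes) rather than on such an estimate.

```python
import random
from typing import Sequence


def empirical_rademacher(values: Sequence[Sequence[float]],
                         num_draws: int = 2000, seed: int = 0) -> float:
    """Monte Carlo estimate of (1/n) * E_sigma[max_g sum_i sigma_i * g(x_i)].

    values[k][i] is g_k(x_i): the k-th function of a finite class evaluated
    on the i-th sample point. Each sigma_i is uniform on {-1, +1}.
    """
    rng = random.Random(seed)
    n = len(values[0])
    total = 0.0
    for _ in range(num_draws):
        sigma = [rng.choice((-1.0, 1.0)) for _ in range(n)]
        # Supremum over the (finite) class for this draw of sigma.
        total += max(sum(s * v for s, v in zip(sigma, row)) for row in values)
    return total / (num_draws * n)
```

For instance, for the class {g, -g} with g ≡ 1 on a single point, every draw of σ attains the supremum value 1, so the estimate is exactly 1.0; richer classes can correlate better with random signs and therefore have larger complexity.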

C.3 WEIGHTED ERM: AN ALTERNATIVE TRAINING PROCEDURE FOR TGT

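Note that, given the samples $\tilde{S}_n = \{\tilde{x}_i\}_{i \in [n]}$ generated from $\tilde{\mathcal{D}}_X$ and a teacher labeler $h$, we minimize the following empirical risk for student training:

$$R^h_{\ell_d,f}(\tilde{S}_n) = \frac{1}{n}\sum_{i \in [n]} \ell_d(f(\tilde{x}_i), h(\tilde{x}_i)). \qquad (37)$$

However, as we notice in Theorem C.4, this leads to an additional $W(\mathcal{D}_X, \tilde{\mathcal{D}}_X)$ penalty term in the generalization bound. One standard approach to address this issue is to consider the following weighted empirical risk:

$$R^{h,\mathrm{IS}}_{\ell_d,f}(\tilde{S}_n) = \frac{1}{n}\sum_{i \in [n]} \ell_d(f(\tilde{x}_i), h(\tilde{x}_i)) \cdot \frac{p_{\mathcal{D}_X}(\tilde{x}_i)}{p_{\tilde{\mathcal{D}}_X}(\tilde{x}_i)}, \qquad (38)$$

Proof (of Lemma D.3). Note that

$$\begin{aligned} R^h_{\ell_d,f}(\mathcal{D}^1_X) - R^h_{\ell_d,f}(\mathcal{D}^2_X) &= \mathbb{E}_{X \sim \mathcal{D}^1_X}[\ell_d(f(X), h(X))] - \mathbb{E}_{X' \sim \mathcal{D}^2_X}[\ell_d(f(X'), h(X'))] \\ &\le \sup_{g \in \mathcal{G}^h_{\ell_d,\mathcal{F}}} \Big(\mathbb{E}_{X \sim \mathcal{D}^1_X}[g(X)] - \mathbb{E}_{X' \sim \mathcal{D}^2_X}[g(X')]\Big) \\ &\overset{(i)}{=} L \cdot \sup_{g \in \mathcal{G}^h_{\ell_d,\mathcal{F}}} \Big(\mathbb{E}_{X \sim \mathcal{D}^1_X}\Big[\frac{g(X)}{L}\Big] - \mathbb{E}_{X' \sim \mathcal{D}^2_X}\Big[\frac{g(X')}{L}\Big]\Big) \\ &\overset{(ii)}{\le} L \cdot \sup_{g \in \mathrm{Lip}_1(\rho)} \Big(\mathbb{E}_{X \sim \mathcal{D}^1_X}[g(X)] - \mathbb{E}_{X' \sim \mathcal{D}^2_X}[g(X')]\Big) \\ &\overset{(iii)}{=} L \cdot W(\mathcal{D}^1_X, \mathcal{D}^2_X), \end{aligned}$$

where (i) follows by dividing and multiplying by $L$; (ii) follows since, for any $g \in \mathcal{G}^h_{\ell_d,\mathcal{F}}$, $g/L$ is 1-Lipschitz; and (iii) follows from Lemma D.2.

Lemma D.4. Let the distillation loss $\ell_d$ satisfy Assumption C.1 with a bounded loss function $\ell: \mathbb{R}^K \times \mathcal{Y} \to \mathbb{R}$. Then, given a teacher $h: \mathcal{X} \to \mathbb{R}^K$ and a student model $f: \mathcal{X} \to \mathbb{R}^K$, we have

$$R^h_{\ell_d,f}(\mathcal{D}_X) - R_{\ell,f}(\mathcal{D}) \le O\Big(\sqrt{K} \cdot \mathbb{E}_{\mathcal{D}_X}\big[\|\mathcal{D}_{Y|X} - h(X)\|_2\big]\Big),$$

where $\mathcal{D}_{Y|X} = (\mathcal{D}_{Y|X}(1), \ldots, \mathcal{D}_{Y|X}(K))$ is treated as a vector in $\mathbb{R}^K$.

Proof. Note that

$$\begin{aligned} R^h_{\ell_d,f}(\mathcal{D}_X) - R_{\ell,f}(\mathcal{D}) &= \mathbb{E}_{\mathcal{D}_X}[\ell_d(f(X), h(X))] - \mathbb{E}_{\mathcal{D}}[\ell(f(X), Y)] \\ &= \mathbb{E}_{\mathcal{D}_X}\Big[\sum_{y \in [K]} h(X)_y \cdot \ell(f(X), y)\Big] - \mathbb{E}_{\mathcal{D}_X}\Big[\sum_{y \in [K]} \mathcal{D}_{Y|X}(y) \cdot \ell(f(X), y)\Big] \\ &= \mathbb{E}_{\mathcal{D}_X}\Big[\sum_{y \in [K]} \big(h(X)_y - \mathcal{D}_{Y|X}(y)\big) \cdot \ell(f(X), y)\Big] \\ &\overset{(i)}{\le} \mathbb{E}_{\mathcal{D}_X}\big[\|\mathcal{D}_{Y|X} - h(X)\|_2 \cdot \|\ell(f(X))\|_2\big], \end{aligned}$$

where $\ell(f(X)) := (\ell(f(X), 1), \ldots, \ell(f(X), K)) \in \mathbb{R}^K$ and (i) follows from the Cauchy-Schwarz inequality. Now the statement of Lemma D.4 follows from the assumption that the loss $\ell$ is bounded, which gives $\|\ell(f(X))\|_2 = O(\sqrt{K})$.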
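A minimal sketch of the importance-weighted empirical risk in Eq. (38), assuming both densities can be evaluated pointwise (which, as noted later in this appendix, is generally not the case during training). Here `loss(x)` is a hypothetical stand-in for $\ell_d(f(x), h(x))$, and the helper name is illustrative only.

```python
from typing import Callable, Sequence, TypeVar

X = TypeVar("X")


def weighted_distillation_risk(
    samples: Sequence[X],
    loss: Callable[[X], float],
    p_target: Callable[[X], float],
    p_proposal: Callable[[X], float],
) -> float:
    """(1/n) * sum_i loss(x_i) * p_target(x_i) / p_proposal(x_i), cf. Eq. (38).

    samples are assumed to be drawn from the proposal distribution (tilde D_X),
    and p_target is the density of D_X. The proposal density must be positive
    at every sample for the importance weight to be defined.
    """
    if not samples:
        raise ValueError("need at least one sample")
    total = 0.0
    for x in samples:
        q = p_proposal(x)
        if q <= 0.0:
            raise ValueError("proposal density must be positive at every sample")
        total += loss(x) * p_target(x) / q
    return total / len(samples)
```

When the two densities coincide, every weight equals 1 and the expression reduces to the unweighted risk in Eq. (37).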

E ADDITIONAL EXPERIMENTS

E.1 LONG-TAIL IMAGE CLASSIFICATION

Please see Table 4 for the Places365-LT results. The relevant discussion is provided in Section 4.1. We also provide an expanded version of Table 1 (from the main text) with additional baselines in Table 5.

F DETAILS TO REPRODUCE OUR EMPIRICAL RESULTS

Here we provide details to reproduce our experimental results.

F.1 LONG-TAIL IMAGE CLASSIFICATION (SEC. 4.1)

Dataset. The full balanced versions of the 3 datasets (ImageNet 4, Places365 5, SUN397 6) are available in tensorflow-datasets (https://www.tensorflow.org/datasets/). Next to obtain the the official repository 9 and used default parameters from the codebase. We fine-tuned on all 3 datasets (ImageNet-LT, SUN397-LT, Places365-LT) for 3 epochs. We directly used the BigBiGAN ResNet-50 checkpoint from the official repository https://github.com/deepmind/deepmind-research/tree/master/bigbigan as the teacher generator. (We did not fine-tune it.) Student training. We start from a randomly initialized MobileNetV3-0.75 model. We employed the SGD optimizer with a cosine schedule (peak learning rate of 0.4, decaying to 0). We also did a linear warm-up (from 0 to the peak learning rate of 0.4) for the first 5 epochs. The input image sizes differ across the EfficientNet-B3, BigBiGAN-ResNet50, and MobileNetV3-0.75 models; starting from the original images in the dataset, we use TensorFlow's bicubic resizing to obtain an appropriately sized image for each model. We did a grid search over the perturbation parameters σ and η (cf. Eq. (4) and Eq. (5)).

Dataset. We used the ImageNet 10 dataset from the tensorflow-datasets repository (https://www.tensorflow.org/datasets/). We used the built-in sub-sampling functionality available in tensorflow-datasets (https://www.tensorflow.org/datasets/splits) to simulate the low-data regime. Teacher model. For the teacher labeler, we directly used the trained EfficientNet-B3 model checkpoint available from the "Sharpness Aware Minimization" repository 11. For the teacher generator, we directly used the trained BigBiGAN checkpoint from the official repository https://github.com/deepmind/deepmind-research/tree/master/bigbigan. (We did not fine-tune either of the models.) Student training. We start from a randomly initialized MobileNetV3-0.75 model.
We employed the SGD optimizer with a cosine schedule (peak learning rate of 0.4, decaying to 0). We also did a linear warm-up (from 0 to the peak learning rate of 0.4) for the first 5 epochs.

Optimizer. For all text retrieval model training, we employed the Adam optimizer with a linear decay schedule (peak learning rate of 1e-5, decaying to 1e-7). We also did a linear warm-up (from 0 to the peak learning rate of 1e-5) for 1K steps. We used a batch size of 128. Teacher fine-tuning. For the teacher labeler dual encoder (a question encoder and a passage encoder), we utilized RoBERTa-Base (Liu et al., 2019a) This same teacher labeler and generator are used for all student training except for direct training (one-hot). Student training. We start from the DistilBERT pretrained checkpoint downloaded from the HuggingFace repository 19. All students are trained for 40K steps. The teacher labeler labels all question-passage pairs within the batch, as well as 2 additional passages per question-passage pair for the uniform-negatives baseline and for TGT. We employed an off-the-shelf BART-base model as our generator (Lewis et al., 2020), and the isotropic perturbation was added via random Gaussian noise of scale σ = 0.1 combined with masking a fraction p = 0.2 of the original passage.
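The two learning-rate schedules described above (linear warm-up followed by cosine decay to 0 for the image students, and linear warm-up followed by linear decay from 1e-5 to 1e-7 for the retrieval models) can be sketched as plain functions of the step index. Converting the 5 warm-up epochs into steps requires a steps-per-epoch value that is not given here, and we assume the decay ends exactly at the final training step; both points, as well as the function names, are assumptions for illustration.

```python
import math


def warmup_cosine(step: int, total_steps: int, warmup_steps: int, peak_lr: float) -> float:
    """Linear warm-up from 0 to peak_lr, then cosine decay to 0 at total_steps."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * progress))


def warmup_linear(step: int, total_steps: int, warmup_steps: int,
                  peak_lr: float, final_lr: float) -> float:
    """Linear warm-up from 0 to peak_lr, then linear decay to final_lr at total_steps."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return peak_lr + (final_lr - peak_lr) * progress
```

For the retrieval setup, `warmup_linear(step, total_steps=40_000, warmup_steps=1_000, peak_lr=1e-5, final_lr=1e-7)` would give the rate at each of the 40K student steps; clamping `progress` at 1.0 keeps the rate at its final value if training runs past `total_steps`.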

G QUALITATIVE EXAMPLES OF GENERATED EXAMPLES

G.1 IMAGE CLASSIFICATION

We show some representative examples of images generated using TGT-random as well as TGT-gradient from the ImageNet classification experiment in Table 8.

G.2 TEXT CLASSIFICATION

We show some representative examples of generated text using TGT from the experiment on MNLI classification in Table 9 . Then I got up as softly as I could, and felt in the dark along the left-hand wall. [SEP] The wall was wet. Then I got up as softly as I could, and walked the way I felt in the dark along the left [SEP] The wall was wet. Data label: Entails Teacher label: Entail But then this very particular island is hardly in danger of being invaded except, of course, by tourism. [SEP] This island is least likely to be invaded by tourism. But then this very particular island is not in danger of being invaded except, of course, by tourism. [SEP] The island is likely to be invaded by tourism.

Data label: Contradicts

Teacher label: Neutral All you need to do is just wander off the beaten path, beyond the bustling tourist zone. [SEP] There is no point going off the beaten path, there is nothing there. All you need to do is just wander off the beaten path, and you'll be in the bustling tourist zone of the city. [SEP] There is no point going off the beaten path, there is nothing there.

Data label: Entails

Teacher label: Neutral The silt of the River Maeander has also stranded the once-mighty city of Miletus. [SEP] The River Maeander has been depositing silt near Miletus for nearly two millennia. The silt of the River Mae has also stranded the once-mighty city of Miletus. [SEP] The River Maeander has been depositing silt near Miletus for more than two decades. Data label: Entails Teacher label: Entails It was hardly the most enlightened of times, not with the conflict in Indochina rapidly becoming America's costliest and most divisive war. [SEP] The war in Indochina has cost America 100 billion dollars so far. It was hardly the most enlightened of times, not with the war in Indochina becoming America's costliest and most divisive war. [SEP] The war in Indochina has cost America 100 billion dollars so far.



† Equal contribution.

We empirically show the superiority of TGT to existing state-of-the-art distillation methods on both vision and NLP tasks, unlike most prior work, which is specialized to one domain.

Results for Places-LT and additional baselines for ImageNet-LT and SUN-LT are in Appendix E.

For the sake of brevity, we simply include the softmax operation in the definition of $h$ and $f$, i.e., $h(x)$ and $f(x)$ are valid probability distributions over $\mathcal{Y} = [K]$.

Note that the formulation assumes that $\mathcal{D}_X \ll \tilde{\mathcal{D}}_X$, i.e., $\mathcal{D}_X$ is absolutely continuous w.r.t. $\tilde{\mathcal{D}}_X$. Also, one can replace the pdfs with probability mass functions if $\mathcal{D}_X$ and $\tilde{\mathcal{D}}_X$ are discrete distributions.

https://www.tensorflow.org/datasets/catalog/imagenet2012
https://www.tensorflow.org/datasets/catalog/places365_small
https://www.tensorflow.org/datasets/catalog/sun397
https://drive.google.com/drive/u/1/folders/1j7Nkfe6ZhzKFXePHdsseeeGI877Xu1yf
https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/noisystudent/efficientnet-b3/checkpoint.tar.gz
https://www.tensorflow.org/datasets/catalog/imagenet2012
https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/noisystudent/efficientnet-b3/checkpoint.tar.gz
https://dl.fbaipublicfiles.com/dpr_scale/paq/PAQ.dpr.train.neg1.jsonl.zip
https://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz
https://dl.fbaipublicfiles.com/fairseq/models/bart.base.tar.gz
https://huggingface.co/distilroberta-base/tree/main



Figure 1: An overview of the proposed teacher guided training (TGT) framework. Given a learning task, the framework leverages a large teacher with a pretrained generator and a labeler that exhibits high performance on the task. In particular, we assume that the generator consists of an encoder and a decoder. TGT performs three key operations during student model training: (1) Given an original training instance, use the teacher generator to identify a novel task-relevant instance; we search for informative instances in the lower-dimensional latent space, through which we can propagate gradients. (2) Obtain (soft) labels for the original and newly generated training instances from the teacher labeler. (3) Minimize the student training objective, which depends on the original and newly generated instances along with their labels produced by the teacher labeler.
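Step (3) above can be sketched as follows, assuming the distillation loss is the soft-label cross-entropy of Remark C.2 and that original and generated instances are weighted equally. The equal weighting and the function names are our assumptions for illustration, not a specification of the exact TGT objective. Teacher and student outputs are given as probability vectors, i.e., with the softmax already applied, following the convention used for $h$ and $f$ in this paper.

```python
import math
from typing import List, Sequence

Probs = Sequence[float]


def soft_cross_entropy(teacher: Probs, student: Probs) -> float:
    """l_d(f(x), h(x)) = -sum_y h(x)_y * log f(x)_y (cf. Remark C.2)."""
    return -sum(t * math.log(s) for t, s in zip(teacher, student) if t > 0)


def tgt_student_objective(teacher_orig: List[Probs], student_orig: List[Probs],
                          teacher_gen: List[Probs], student_gen: List[Probs]) -> float:
    """Average distillation loss over original and teacher-generated instances."""
    losses = [soft_cross_entropy(t, s) for t, s in zip(teacher_orig, student_orig)]
    losses += [soft_cross_entropy(t, s) for t, s in zip(teacher_gen, student_gen)]
    if not losses:
        raise ValueError("need at least one instance")
    return sum(losses) / len(losses)
```

Here the teacher labeler supplies `teacher_orig` and `teacher_gen` (step 2), while the generated instances themselves come from step (1).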

Figure 2: Comparison among normal training (one-hot), standard distillation (distillation), and TGT in simulated low-data regimes. We imitate a low-data regime by subsampling the ImageNet training set and evaluate the trained student models on the entire eval set. We employ 450k training steps for normal training and standard distillation, and 112k training steps for TGT. TGT outperforms the other methods with fewer training steps, thus effectively simulating an increased sample size.

Fig. 2 shows that both TGT and standard distillation utilize additional training data more effectively than normal training, with TGT being the more efficient of the two. Interestingly, employing TGT is equivalent to a 4x increase in sample size compared to normal training. This verifies that TGT generates informative training instances for the student.



student's performance, resulting in wasteful exploration of those regions of the input space where the student is already good. Further, unlike TGT, they search for examples in the input space, which is often inefficient due to its large ambient dimension.

defined in Thm. 3.1. Please see Appendix C.2 for a more precise statement and proof of Thm. 3.2. Compared with the generalization gap for standard distillation (cf. Thm. 3.1), the generalization gap for TGT in Thm. 3.2 does not have the reconstruction-error term $L\epsilon$. Thus, by working with samples that have an exact latent representation, TGT avoids this reconstruction error penalty. On the other hand, the generalization gap for TGT does have an additional term $W(\mathcal{D}_X, \tilde{\mathcal{D}}_X)$, capturing the mismatch between the original data distribution and the distribution of the samples used by TGT.

Performance of TGT and various baselines on long-tail image classification benchmarks (see Appendix E for results on Places-LT). Rows with * denote results taken from Menon et al. (2021b) and the rest were taken from Samuel et al. (
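For intuition about the $W(\mathcal{D}_X, \tilde{\mathcal{D}}_X)$ term, the sketch below computes the Wasserstein-1 distance between two empirical distributions on the real line with equally many samples and $\rho(x, x') = |x - x'|$; in this special case an optimal coupling simply matches the sorted samples. The 1-D setting and the function name are illustrative assumptions: for high-dimensional inputs such as images, $W$ has no such closed form.

```python
from typing import Sequence


def wasserstein1_1d(xs: Sequence[float], ys: Sequence[float]) -> float:
    """W1 between two equal-size empirical distributions on the real line.

    Pairing the i-th smallest x with the i-th smallest y is an optimal
    coupling in 1-D, so W1 is the mean absolute gap between sorted samples.
    """
    if len(xs) != len(ys) or not xs:
        raise ValueError("expects two non-empty samples of equal size")
    return sum(abs(a - b) for a, b in zip(sorted(xs), sorted(ys))) / len(xs)
```

For example, shifting every sample by a constant c moves the empirical distribution by exactly W1 = |c|, while reordering the samples leaves it at 0.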

Performance of TGT and various baselines from the literature on four text classification benchmarks. For student model training, we show results for task-specific finetuning on both randomly initialized and pretrained DistilBERT models. Note that TGT (Pretrained), i.e., TGT with a pretrained student model, outperforms all other methods across the board. Even more interestingly, on Amazon-5 and Yelp-5, TGT with a randomly initialized student, i.e., TGT (Random Init), outperforms the standard approach of finetuning a pretrained model with one-hot labels, i.e., One-hot (Pretrained).

The hyper-parameters and the search grid are listed in the table below:

Probably Asked Questions (PAQ) dataset (Lewis et al., 2021) from the official repository of "Domain-matched Pre-training Tasks for Dense Retrieval", available at https://github.com/facebookresearch/dpr-scale, which has been aligned to the same passage corpus 16

pretrained checkpoint 17 from the official FAIRSEQ repository https://github.com/facebookresearch/fairseq. We then conducted a first round of fine-tuning for 300k iterations on the passage-aligned PAQ dataset. We used the same configuration as Oguz et al. (2021), except that Oguz et al. trained on PAQ for longer. After this pretraining, the teacher is fine-tuned on NQ-open (Kwiatkowski et al., 2019) for 40K steps. Similar to Karpukhin et al. (2020) and Oguz et al. (2021), the teacher is trained with within-batch negatives and the softmax-based cross-entropy loss. For the teacher generator, we directly use a pretrained BART-Base (Lewis et al., 2020) checkpoint 18 from the official FAIRSEQ repository https://github.com/facebookresearch/fairseq. (We did not fine-tune it.)

Neutral The house was bought with the royalties she earned from her first book, The Tales of Peter Rabbit. [SEP] The house was bought with the money she inherited from her grandfather. The book was published in the United States in 1987 with the royalties she received from her first book, The Tales of Peter Rabbit. [SEP] The house was bought with the money she inherited from her grandfather. Data label: Entail Teacher label: Entail Leather goods are no longer a bargain in Spain, though very good quality products may still be priced lower than at home. [SEP] Leather goods are still very cheap in Spain. Leather and leather goods are no longer a bargain in Spain, though very good quality products may still be priced lower than at home and abroad. [SEP] Leather goods are still very cheap at Spain.


where $p_{\mathcal{D}_X}$ and $p_{\tilde{\mathcal{D}}_X}$ denote the probability density functions (pdfs) of $\mathcal{D}_X$ and $\tilde{\mathcal{D}}_X$, respectively. 3 Accordingly, we define a new induced function class related to the weighted empirical risk:

$$\mathcal{G}^{h,\mathrm{IS}}_{\ell_d,\mathcal{F}} := \Big\{x \mapsto \ell_d(f(x), h(x)) \cdot \frac{p_{\mathcal{D}_X}(x)}{p_{\tilde{\mathcal{D}}_X}(x)} : f \in \mathcal{F}\Big\}.$$

Importantly, we have

$$\mathbb{E}_{X \sim \tilde{\mathcal{D}}_X}\Big[\ell_d(f(X), h(X)) \cdot \frac{p_{\mathcal{D}_X}(X)}{p_{\tilde{\mathcal{D}}_X}(X)}\Big] = \mathbb{E}_{X \sim \mathcal{D}_X}[\ell_d(f(X), h(X))] = R^h_{\ell_d,f}(\mathcal{D}_X).$$

Thus, following the analysis utilized in Theorem C.4, one can obtain a high-probability generalization bound of the form

$$R^h_{\ell_d,f}(\mathcal{D}_X) \le R^{h,\mathrm{IS}}_{\ell_d,f}(\tilde{S}_n) + \tilde{\mathfrak{R}}_n(\mathcal{G}^{h,\mathrm{IS}}_{\ell_d,\mathcal{F}}) + O\Big(\sqrt{\log(1/\delta) \cdot \tfrac{1}{n}}\Big),$$

which avoids the $W(\mathcal{D}_X, \tilde{\mathcal{D}}_X)$ term.

In what follows, we explore an alternative approach to highlight the importance of the sampling approach adopted by (gradient-based) TGT. By leveraging the variance-based generalization bound (Maurer & Pontil, 2009) that was previously utilized by Menon et al. (2021a) in the context of distillation, we obtain the following result for the weighted empirical risk in Eq. (38).

Proposition C.6. Let $h$, $\ell_d$, $\mathcal{F}$, and $\tilde{S}_n$ be as defined in the statement of Theorem C.4. Further, assume that $\ell^{\mathrm{IS}}_d(x) := \ell_d(f(x), h(x)) \cdot p_{\mathcal{D}_X}(x) / p_{\tilde{\mathcal{D}}_X}(x)$ is bounded for all $x \in \mathrm{supp}(\tilde{\mathcal{D}}_X)$. Then, for any $f \in \mathcal{F}$, the following holds with probability at least $1 - \delta$. where (I) denotes Here, denoting the covering number (Devroye et al., 2013) of the set

Proof. By utilizing the uniform convergence version of Bennett's inequality and a uniform bound for $\mathrm{Var}_{\tilde{S}_n}(\ell^{\mathrm{IS}}_d(x))$, where $\mathrm{Var}_{\tilde{S}_n}(\ell^{\mathrm{IS}}_d(x))$ denotes the empirical variance of $\ell^{\mathrm{IS}}_d(x)$ based on $\tilde{S}_n$, the following holds with probability at least $1 - \delta$ (Maurer & Pontil, 2009). The statement of Proposition C.6 then follows from Eq. (43).

Note that combining Eq. (42) with Lemma D.4 translates the bound on $R^h_{\ell_d,f}(\mathcal{D}_X)$ into a bound on $R_{\ell,f}(\mathcal{D})$ with an additional penalty term that depends on the quality of the teacher labeler $h$.

Remark C.7. Eq. (42) suggests a general approach to selecting the distribution $\tilde{\mathcal{D}}_X$ that generates the training samples $\tilde{S}_n$. In order to ensure a small generalization gap, it is desirable that the variance term $\mathrm{Var}_{\tilde{\mathcal{D}}_X}(\ell^{\mathrm{IS}}_d(x))$ be as small as possible. Note that the distribution that minimizes this variance takes the form

$$p^{*}_{\tilde{\mathcal{D}}_X}(x) \propto p_{\mathcal{D}_X}(x) \cdot \ell_d(f(x), h(x)). \qquad (44)$$

This resembles the Lagrangian form of Eq. (3). Interestingly, the TGT framework with gradient-based sampling (cf. Eq. (5)) focuses on generating samples that maximize the right-hand side (RHS) of Eq. (44) by first taking a sample generated according to $\mathcal{D}_X$ and then perturbing it in the latent space to maximize the loss $\ell_d(f(x), h(x))$. Thus, the resulting distribution $\tilde{\mathcal{D}}_X$ has a pdf that aims to approximate the variance-minimizing pdf in Eq. (44). Here it is worth pointing out that, since the exact forms of $p_{\tilde{\mathcal{D}}_X}(\cdot)$ and $p_{\mathcal{D}_X}(\cdot)$ are generally not available during training, it is not straightforward to optimize the weighted risk introduced in Eq. (38).

As introduced in Section 3, the TGT framework optimizes the empirical risk in Eq. (37) as opposed to minimizing Eq. (38). In this case, one can obtain a variance-based bound analogous to Eq. (42) that takes the form: where (II) denotes and $\mathcal{M}(n)$ depends on the covering number for the induced function class $\mathcal{G}^h_{\ell_d,\mathcal{F}}$ (cf. Eq. (15)). Notably, this bound again incurs a penalty of $W(\mathcal{D}_X, \tilde{\mathcal{D}}_X)$.

Remark C.8. Note that Eq. (45) suggests a general approach to selecting the distribution $\tilde{\mathcal{D}}_X$ that generates the training samples $\tilde{S}_n$. In order to ensure a small generalization gap, we need to focus on two terms depending on $\tilde{\mathcal{D}}_X$: (1) the variance term $\mathrm{Var}_{\tilde{\mathcal{D}}_X}(\ell^h_{d,f}(x))$; and (2) the divergence term $W(\mathcal{D}_X, \tilde{\mathcal{D}}_X)$. Finding a distribution that jointly minimizes both terms is a non-trivial task. That said, in our sampling approach in Eq. (5), we control $W(\mathcal{D}_X, \tilde{\mathcal{D}}_X)$ by operating in the latent space of a good-quality teacher generative model and reduce the variance by finding points with high loss values through gradient ascent, thereby striking a balance between the two objectives.
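The gradient-based sampling discussed in Remark C.8 can be sketched as plain gradient ascent on a latent code, starting from $z = \mathrm{Enc}(x)$ and moving toward higher values of $\ell_d(f(\mathrm{Dec}(z)), h(\mathrm{Dec}(z)))$. This is a simplified stand-in for Eq. (5), whose exact form is not restated here: we assume a fixed step size η and a fixed number of steps, and we approximate gradients with central finite differences so that the sketch needs no autodiff library. The callable `loss_of_z` is a hypothetical placeholder for the composition of decoder, student, teacher, and distillation loss.

```python
from typing import Callable, List

Latent = List[float]


def latent_gradient_ascent(loss_of_z: Callable[[Latent], float], z0: Latent,
                           eta: float = 0.1, num_steps: int = 10,
                           eps: float = 1e-5) -> Latent:
    """Move a latent code uphill on loss_of_z, e.g. z -> l_d(f(Dec(z)), h(Dec(z))).

    Gradients are approximated by central finite differences; a practical
    implementation would instead backpropagate through the decoder.
    """
    z = list(z0)
    for _ in range(num_steps):
        grad = []
        for j in range(len(z)):
            z_plus, z_minus = list(z), list(z)
            z_plus[j] += eps
            z_minus[j] -= eps
            grad.append((loss_of_z(z_plus) - loss_of_z(z_minus)) / (2 * eps))
        # Ascent step: increase the loss to obtain a harder, still in-domain sample.
        z = [zj + eta * gj for zj, gj in zip(z, grad)]
    return z
```

On a toy concave "loss" with its maximum at (3, -1), repeated ascent steps drive the latent code to that maximizer; in TGT, the resulting code would then be decoded and labeled by the teacher.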

D TOOLBOX

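This section presents the necessary definitions and lemmas that we utilize to establish our theoretical results presented in Section 3 (and restated in Appendix C).

Definition D.1 (Wasserstein-1 metric). Let $(\mathcal{X}, \rho)$ be a metric space. Given two probability distributions $\mathcal{D}^1_X$ and $\mathcal{D}^2_X$ over $\mathcal{X}$, the Wasserstein-1 distance between $\mathcal{D}^1_X$ and $\mathcal{D}^2_X$ is defined as follows.

$$W(\mathcal{D}^1_X, \mathcal{D}^2_X) := \inf_{\pi \in \Pi(\mathcal{D}^1_X, \mathcal{D}^2_X)} \mathbb{E}_{(X, X') \sim \pi}\big[\rho(X, X')\big],$$

where $\Pi(\mathcal{D}^1_X, \mathcal{D}^2_X)$ denotes the set of all joint distributions over $\mathcal{X} \times \mathcal{X}$ that have $\mathcal{D}^1_X$ and $\mathcal{D}^2_X$ as their marginals.

Lemma D.2 (Kantorovich-Rubinstein duality (Villani, 2008)). Let $\mathrm{Lip}_1(\rho)$ denote the set of all 1-Lipschitz functions in the metric $\rho$, i.e., for any $g \in \mathrm{Lip}_1(\rho)$, $|g(x) - g(x')| \le \rho(x, x')$ for all $x, x' \in \mathcal{X}$. Then,

$$W(\mathcal{D}^1_X, \mathcal{D}^2_X) = \sup_{g \in \mathrm{Lip}_1(\rho)} \Big(\mathbb{E}_{X \sim \mathcal{D}^1_X}[g(X)] - \mathbb{E}_{X' \sim \mathcal{D}^2_X}[g(X')]\Big).$$

Lemma D.3. Let $\ell_d: \mathbb{R}^K \times \mathbb{R}^K \to \mathbb{R}$ be a loss function employed during distillation. For a given teacher $h: \mathcal{X} \to \mathbb{R}^K$ and a function class $\mathcal{F}$, we assume that the induced function class

$$\mathcal{G}^h_{\ell_d,\mathcal{F}} := \{x \mapsto \ell_d(f(x), h(x)) : f \in \mathcal{F}\} \qquad (49)$$

is contained in the class of $L$-Lipschitz functions with respect to a metric $\rho$. Then, for any two distributions $\mathcal{D}^1_X$ and $\mathcal{D}^2_X$ and any $f \in \mathcal{F}$, we have

$$R^h_{\ell_d,f}(\mathcal{D}^1_X) - R^h_{\ell_d,f}(\mathcal{D}^2_X) \le L \cdot W(\mathcal{D}^1_X, \mathcal{D}^2_X),$$

where $W(\mathcal{D}^1_X, \mathcal{D}^2_X)$ denotes the Wasserstein-1 metric between the two distributions $\mathcal{D}^1_X$ and $\mathcal{D}^2_X$ (cf. Definition D.1).

(Liu et al., 2019c). The table shows the top-1 accuracy on the corresponding balanced eval sets for TGT and different long-tail baselines from the literature (taken from Samuel et al. (2021)). We also state the number of model parameters and inference cost (in terms of FLOPs) for all the methods. Note that TGT leads to performance improvements over standard distillation. Note that, for Places-LT, TGT does not outperform the stated baselines from the literature that rely on specialized loss functions and/or training procedures designed for the long-tail setting. One reason for this could be that BigBiGAN does not generate very informative samples for Places-LT due to distribution mismatch.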
That said, as discussed in Section 4.1, one can combine the TGT framework with a long-tail-specific loss function, instead of the standard cross-entropy loss, to further improve its performance.

Method | Architecture | Top-1 acc. (%) | #Params | FLOPs
(Schönfeld et al., 2019) | ResNet-101 | 32.8 | 42 M | 7.6 B
LWS (Kang et al., 2020) | ResNeXt-50 | 33.9 | 25 M | 4.2 B
DRAGON + Bal'Loss (Samuel et al., 2021) | ResNet

We report top-1 accuracy on balanced eval sets. We also state the number of model parameters and inference cost (in terms of FLOPs) for all the methods. TGT leads to performance improvements over standard distillation on both datasets, particularly on ImageNet-LT, where the teacher generator models the task distribution well. TGT also often outperforms the stated baselines that rely on much larger and more expensive models.

long-tail version of the datasets, we downloaded the image IDs from the repository of "Large-Scale Long-Tailed Recognition in an Open World" (Liu et al., 2019b), according to which we subsampled the full balanced dataset.

Teacher fine-tuning. For the teacher labeler, we follow the Sharpness-Aware Minimization (SAM) codebase (Foret et al., 2020), available at https://github.com/google-research/sam, to fine-tune on the long-tail datasets. We start with the pretrained EfficientNet-B3 model checkpoint available from

F.3 TEXT CLASSIFICATION (SEC. 4.3)

Dataset. We conduct text classification experiments on the following datasets:
• Amazon-5, downloaded from http://goo.gl/JyCnZq
• IMDB, from tensorflow-datasets: https://www.tensorflow.org/datasets/catalog/imdb_reviews
• MNLI, from tensorflow-datasets: https://www.tensorflow.org/datasets/catalog/multi_nli
• Yelp-5, downloaded from http://goo.gl/JyCnZq

Optimizer. For all training, we employed the Adam optimizer with a linear decay schedule (peak learning rate of 3e-5, decaying to 0), preceded by a linear warm-up. We used a batch size of 128.

Teacher fine-tuning.
For the teacher labeler, we started from the pretrained RoBERTa-Base (Liu et al., 2019a) checkpoint from the official FAIRSEQ repository (https://github.com/facebookresearch/fairseq). We fine-tuned it using the default parameters, except for the number of steps, which are the same as those listed in Table 7. For the teacher generator, we directly use a pretrained BART-Base (Lewis et al., 2020) checkpoint from the official FAIRSEQ repository (https://github.com/facebookresearch/fairseq), without fine-tuning.

Student training. We start from the pretrained DistilBERT checkpoint downloaded from the HuggingFace repository. We perturb inputs by adding Gaussian noise with variance σ² between the encoder and the decoder, as well as by masking out a fraction p of the input. We then generate new examples by greedy decoding with the BART teacher generator, using a sequence length of 512. For dual-input classification tasks such as MNLI, we generate the two inputs independently. We performed a grid search over the perturbation parameters σ and p.
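The input perturbation described under Student training can be sketched as below. This is a minimal, hypothetical illustration rather than the actual pipeline: a random embedding table stands in for the BART encoder, `MASK_ID` and the grid values are made up, and greedy decoding is omitted. It only shows the two perturbations, masking a fraction p of the tokens and adding Gaussian noise with variance σ² to the encoder output, swept over a small (σ, p) grid.

```python
import itertools

import numpy as np

MASK_ID = 0  # hypothetical id of the mask token


def mask_tokens(token_ids, p, rng):
    """Replace a fraction p of the positions, chosen uniformly at random, with MASK_ID."""
    masked = np.array(token_ids, copy=True)
    num_to_mask = int(round(p * len(masked)))
    positions = rng.permutation(len(masked))[:num_to_mask]
    masked[positions] = MASK_ID
    return masked


def perturb(token_ids, embedding, sigma, p, rng):
    """Mask a fraction p of the input, encode it, and add N(0, sigma^2) noise to the encoding."""
    masked = mask_tokens(token_ids, p, rng)
    states = embedding[masked]     # hypothetical stand-in for the generator's encoder
    noisy_states = states + rng.normal(scale=sigma, size=states.shape)
    # Per the text, the noisy states would next be decoded greedily by the generator (not shown).
    return masked, noisy_states


rng = np.random.default_rng(0)
vocab_size, hidden_dim, seq_len = 50, 16, 20
embedding = rng.normal(size=(vocab_size, hidden_dim))
tokens = rng.integers(1, vocab_size, size=seq_len)  # ids >= 1, so no token equals MASK_ID

# Hypothetical grid over (sigma, p); the grid values used in the paper are not given here.
for sigma, p in itertools.product([0.0, 0.1, 0.5], [0.0, 0.15, 0.3]):
    masked, noisy_states = perturb(tokens, embedding, sigma, p, rng)
    print(sigma, p, int(np.sum(masked == MASK_ID)), noisy_states.shape)
```

In a real setup, each (σ, p) pair would be scored by the student's downstream validation accuracy after distilling on the generated examples, and the best pair kept.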

