MIXKD: TOWARDS EFFICIENT DISTILLATION OF LARGE-SCALE LANGUAGE MODELS

Abstract

Large-scale language models have recently demonstrated impressive empirical performance. Nevertheless, the improved results are attained at the price of bigger models, more power consumption, and slower inference, which hinder their applicability to low-resource (both memory and computation) platforms. Knowledge distillation (KD) has been demonstrated as an effective framework for compressing such big models. However, large-scale neural network systems are prone to memorize training instances, and thus tend to make inconsistent predictions when the data distribution is altered slightly. Moreover, the student model has few opportunities to request useful information from the teacher model when there is limited task-specific data available. To address these issues, we propose MixKD, a data-agnostic distillation framework that leverages mixup, a simple yet efficient data augmentation approach, to endow the resulting model with stronger generalization ability. Concretely, in addition to the original training examples, the student model is encouraged to mimic the teacher's behavior on the linear interpolation of example pairs as well. We prove from a theoretical perspective that under reasonable conditions MixKD gives rise to a smaller gap between the generalization error and the empirical error. To verify its effectiveness, we conduct experiments on the GLUE benchmark, where MixKD consistently leads to significant gains over the standard KD training, and outperforms several competitive baselines. Experiments under a limited-data setting and ablation studies further demonstrate the advantages of the proposed approach.

1. INTRODUCTION

Recent language models (LM) pre-trained on large-scale unlabeled text corpora in a self-supervised manner have significantly advanced the state of the art across a wide variety of natural language processing (NLP) tasks (Devlin et al., 2018; Liu et al., 2019c; Yang et al., 2019; Joshi et al., 2020; Sun et al., 2019b; Clark et al., 2020; Lewis et al., 2019; Bao et al., 2020) . After the LM pretraining stage, the resulting parameters can be fine-tuned to different downstream tasks. While these models have yielded impressive results, they typically have millions, if not billions, of parameters, and thus can be very expensive from storage and computational standpoints. Additionally, during deployment, such large models can require a lot of time to process even a single sample. In settings where computation may be limited (e.g. mobile, edge devices), such characteristics may preclude such powerful models from deployment entirely. One promising strategy to compress and accelerate large-scale language models is knowledge distillation (Zhao et al., 2019; Tang et al., 2019; Sun et al., 2020) . The key idea is to train a smaller model (a "student") to mimic the behavior of the larger, stronger-performing, but perhaps less practical model (the "teacher"), thus achieving similar performance with a faster, lighter-weight model. A simple but powerful method of achieving this is to use the output probability logits produced by the teacher model as soft labels for training the student (Hinton et al., 2015) . With higher entropy than one-hot labels, these soft labels contain more information for the student model to learn from. 
Previous efforts on distilling large-scale LMs mainly focus on designing better training objectives, such as matching intermediate representations (Sun et al., 2019a; Mukherjee & Awadallah, 2019) , learning multiple tasks together (Liu et al., 2019a) , or leveraging the distillation objective during the pre-training stage (Jiao et al., 2019; Sanh et al., 2019) . However, much less effort has been made to enrich task-specific data, a potentially vital component of the knowledge distillation procedure. In particular, tasks with fewer data samples provide less opportunity for the student model to learn from the teacher. Even with a well-designed training objective, the student model is still prone to overfitting, despite effectively mimicking the teacher network on the available data. In response to these limitations, we propose improving the value of knowledge distillation by using data augmentation to generate additional samples from the available task-specific data. These augmented samples are further processed by the teacher network to produce additional soft labels, providing the student model more data to learn from a large-scale LM. Intuitively, this is akin to a student learning more from a teacher by asking more questions to further probe the teacher's answers and thoughts. In particular, we demonstrate that mixup (Zhang et al., 2018) can significantly improve knowledge distillation's effectiveness, and we show with a theoretical framework why this is the case. We call our framework MixKD. We conduct experiments on 6 GLUE datasets (Wang et al., 2019) across a variety of task types, demonstrating that MixKD significantly outperforms knowledge distillation (Hinton et al., 2015) and other previous methods that compress large-scale language models. In particular, we show that our method is especially effective when the number of available task data samples is small, substantially improving the potency of knowledge distillation. 
We also visualize representations learned with and without MixKD to show the value of interpolated distillation samples, perform a series of ablation and hyperparameter sensitivity studies, and demonstrate the superiority of MixKD over other BERT data augmentation strategies.

2. RELATED WORK

2.1. MODEL COMPRESSION

Compressing large-scale language models, such as BERT, has attracted significant attention recently. Knowledge distillation has been demonstrated as an effective approach, which can be leveraged during both the pre-training and task-specific fine-tuning stages. Prior research efforts mainly focus on improving the training objectives to benefit the distillation process. Specifically, Turc et al. (2019) advocate that task-specific knowledge distillation can be improved by first pre-training the student model. Clark et al. (2019) show that a multi-task BERT model can be learned by distilling from multiple single-task teachers. Liu et al. (2019b) propose learning a stronger student model by distilling knowledge from an ensemble of BERT models. Patient knowledge distillation (PKD), introduced by Sun et al. (2019a), encourages the student model to mimic the teacher's intermediate layers in addition to output logits. DistilBERT (Sanh et al., 2019) reduces the depth of the BERT model by a factor of 2 via knowledge distillation during the pre-training stage. In this work, we evaluate MixKD on the case of task-specific knowledge distillation. Notably, it can be extended to the pre-training stage as well, which we leave for future work. Moreover, our method can be flexibly integrated with different KD training objectives (described above) to obtain even better results. In this paper, we use the BERT-base model as the testbed, without loss of generality.

2.2. DATA AUGMENTATION IN NLP

Data augmentation (DA) has been studied extensively in computer vision as a powerful technique to incorporate prior knowledge of invariances and improve the robustness of learned models (Simard et al., 1998; 2003; Krizhevsky et al., 2012). Recently, it has also been applied and shown effective on natural language data. Many approaches can be categorized as label-preserving transformations, which essentially produce neighbors around a training example that maintain its original label. For example, EDA (Wei & Zou, 2019) uses various rule-based operations such as synonym replacement, word insertion, swap or deletion to obtain augmented samples. Back-translation (Yu et al., 2018; Xie et al., 2019) is another popular approach belonging to this type, which relies on pre-trained translation models. Additionally, methods based on paraphrase generation have also been leveraged from the data augmentation perspective (Kumar et al., 2019). On the other hand, label-altering techniques like mixup (Zhang et al., 2018) have also been proposed for language (Guo et al., 2019; Chen et al., 2020), producing interpolated inputs and labels for the model to predict. The proposed MixKD framework leverages the ability of mixup to facilitate the student learning more information from the teacher. It is worth noting that MixKD can be combined with arbitrary label-preserving DA modules. Back-translation is employed as a special case here, and we believe other advanced label-preserving transformations developed in the future can benefit the MixKD approach as well.

2.3. MIXUP

Mixup (Zhang et al., 2018) is a popular data augmentation strategy to increase model generalizability and robustness by training on convex combinations of pairs of inputs and labels $(x_i, y_i)$ and $(x_j, y_j)$:

$$x' = \lambda x_i + (1 - \lambda) x_j \tag{1}$$

$$y' = \lambda y_i + (1 - \lambda) y_j \tag{2}$$

with $\lambda \in [0, 1]$ and $(x', y')$ being the resulting virtual training example.
This concept of interpolating samples was later generalized with Manifold mixup (Verma et al., 2019a) and also found to be effective in semi-supervised learning settings (Verma et al., 2019b; c; Berthelot et al., 2019b; a) . Other strategies include mixing together samples resulting from chaining together other augmentation techniques (Hendrycks et al., 2020) , or replacing linear interpolation with the cutting and pasting of patches (Yun et al., 2019) .
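The interpolation in equations 1 and 2 can be sketched in a few lines of plain Python. This is only an illustration: the function name `mixup` and the toy two-dimensional inputs are assumptions made for the example, not part of the original method description.

```python
def mixup(x_i, y_i, x_j, y_j, lam):
    """Return the convex combination of two (input, label) pairs (equations 1 and 2)."""
    if not 0.0 <= lam <= 1.0:
        raise ValueError("lam must lie in [0, 1]")
    x_mix = [lam * a + (1.0 - lam) * b for a, b in zip(x_i, x_j)]
    y_mix = [lam * a + (1.0 - lam) * b for a, b in zip(y_i, y_j)]
    return x_mix, y_mix


# Toy example (hypothetical values): two inputs with one-hot labels over 2 classes.
x_mix, y_mix = mixup([1.0, 0.0], [1.0, 0.0], [0.0, 2.0], [0.0, 1.0], lam=0.25)
print(x_mix, y_mix)
```

Because both labels are probability vectors and the weights sum to one, the mixed label remains a valid probability vector.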

3. METHODOLOGY

3.1. PRELIMINARIES

In NLP, an input sample $i$ is often represented as a vector of tokens $w_i = \{w_{i,1}, w_{i,2}, \ldots, w_{i,T}\}$, with each token $w_{i,t} \in \mathbb{R}^V$ a one-hot vector often representing words (but also possibly subwords, punctuation, or special tokens) and $V$ being the vocabulary size. These discrete tokens are then mapped to word embeddings $x_i = \{x_{i,1}, x_{i,2}, \ldots, x_{i,T}\}$, which serve as input to the machine learning model $f$. For supervised classification problems, a one-hot label $y_i \in \mathbb{R}^C$ indicates the ground-truth class of $x_i$ out of $C$ possible classes. The parameters $\theta$ of $f$ are optimized with some form of stochastic gradient descent so that the output of the model $f(x_i) \in \mathbb{R}^C$ is as close to $y_i$ as possible, with cross-entropy as the most common loss function:

$$L_{MLE} = -\frac{1}{n} \sum_{i=1}^{n} y_i \cdot \log(f(x_i)) \tag{3}$$

where $n$ is the number of samples, and $\cdot$ is the dot product.
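As a small worked illustration of the cross-entropy objective above, the following sketch computes the loss for two samples. The helper name `cross_entropy`, the clipping constant `eps`, and the toy probabilities are assumptions introduced for the example.

```python
import math


def cross_entropy(labels, probs, eps=1e-12):
    """Average of -y_i . log(f(x_i)) over n samples; probs are clipped at eps for safety."""
    n = len(labels)
    total = 0.0
    for y, p in zip(labels, probs):
        total -= sum(y_c * math.log(max(p_c, eps)) for y_c, p_c in zip(y, p))
    return total / n


# Hypothetical model outputs for two samples over C = 2 classes.
labels = [[1.0, 0.0], [0.0, 1.0]]
probs = [[0.9, 0.1], [0.2, 0.8]]
loss = cross_entropy(labels, probs)
print(round(loss, 4))
```

With one-hot labels only the probability assigned to the true class contributes, so the loss here is the mean of $-\log 0.9$ and $-\log 0.8$.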

3.2. KNOWLEDGE DISTILLATION FOR BERT

Consider two models $f$ and $g$ parameterized by $\theta_T$ and $\theta_S$, respectively, with $|\theta_T| \gg |\theta_S|$. Given enough training data and sufficient optimization, $f$ is likely to yield better accuracy than $g$, due to higher modeling capacity, but may be too bulky or slow for certain applications. Being smaller in size, $g$ is more likely to satisfy operational constraints, but its weaker performance can be seen as a disadvantage. To improve $g$, we can use the output prediction $f(x_i)$ on input $x_i$ as extra supervision for $g$ to learn from, seeking to match $g(x_i)$ with $f(x_i)$. Given these roles, we refer to $g$ as the student model and $f$ as the teacher model. While there are a number of recent large-scale language models driving the state of the art, we focus here on BERT (Devlin et al., 2018) models. Following Sun et al. (2019a), we use the notation BERT$_k$ to indicate a BERT model with $k$ Transformer (Vaswani et al., 2017) layers. While powerful, BERT models also tend to be quite large; for example, the default bert-base-uncased (BERT$_{12}$) has ∼110M parameters. Reducing the number of layers (e.g. using BERT$_3$) makes such models significantly more portable and efficient, but at the expense of accuracy. With a knowledge distillation set-up, however, we aim to reduce this loss in performance.

3.3. MIXUP DATA AUGMENTATION FOR KNOWLEDGE DISTILLATION

While knowledge distillation can be a powerful technique, if the size of the available data is small, then the student has only limited opportunities to learn from the teacher. This may make it much harder for knowledge distillation to close the gap between student and teacher model performance. To correct this, we propose using data augmentation for knowledge distillation. While data augmentation (Yu et al., 2018; Xie et al., 2019; Yun et al., 2019; Kumar et al., 2019; Hendrycks et al., 2020; Shen et al., 2020; Qu et al., 2020 ) is a commonly used technique across machine learning for increasing training samples, robustness, and overall performance, a limited modeling capacity constrains the representations the student is capable of learning on its own. Instead, we propose using the augmented samples to further query the teacher model, whose large size often allows it to learn more powerful features. While many different data augmentation strategies have been proposed for NLP, we focus on mixup (Zhang et al., 2018) for generating additional samples to learn from the teacher. Mixup's vicinal risk minimization tends to result in smoother decision boundaries and better generalization, while also being cheaper to compute than methods such as backtranslation (Yu et al., 2018; Xie et al., 2019) . Mixup was initially proposed for continuous data, where interpolations between data points remain in-domain; its efficacy was demonstrated primarily on image data, but examples in speech recognition and tabular data were also shown to demonstrate generality. Directly applying mixup to NLP is not quite as straightforward as it is for images, as language commonly consists of sentences of variable length, each comprised of discrete word tokens. Since performing mixup directly on the word tokens doesn't result in valid language inputs, we instead perform mixup on the word embeddings at each time step x i,t (Guo et al., 2019) . 
This can be interpreted as a special case of Manifold mixup (Verma et al., 2019a), where the mixing layer is set to the embedding layer. In other words, mixup samples are generated as:

$$x'_{i,t} = \lambda x_{i,t} + (1 - \lambda) x_{j,t} \quad \forall t \tag{4}$$

$$y'_i = \lambda y_i + (1 - \lambda) y_j \tag{5}$$

with $\lambda \in [0, 1]$; sampling $\lambda$ from a Uniform or Beta distribution is a common choice. Note that we index the augmented sample with $i$ regardless of the value of $\lambda$. Sentence length variability can be mitigated by grouping mixup pairs by length. Alternatively, padding is a common technique for setting a consistent input length across samples; thus, if $x_i$ contains more word tokens than $x_j$, then the extra word embeddings are mixed up with zero paddings. We find this approach to be effective, while also being much simpler to implement.

We query the teacher model with the generated mixup sample $x'_i$, producing output prediction $f(x'_i)$. The student is encouraged to imitate this prediction on the same input, by minimizing the objective:

$$L_{TMKD} = d(f(x'_i), g(x'_i)) \tag{6}$$

where $d(\cdot, \cdot)$ is a distance metric for distillation, with temperature-adjusted cross-entropy and mean square error (MSE) being common choices. Since we have the mixup samples already generated (with an easy-to-generate interpolated pseudolabel $y'_i$), we can also train the student model on these augmented data samples in the usual way, with a cross-entropy objective:

$$L_{SM} = -\frac{1}{n} \sum_{i=1}^{n} y'_i \cdot \log(g(x'_i)) \tag{7}$$

Our final objective for MixKD is a sum of the original data cross-entropy loss, student cross-entropy loss on the mixup samples, and knowledge distillation from the teacher on the mixup samples:

$$L = L_{MLE} + \alpha_{SM} L_{SM} + \alpha_{TMKD} L_{TMKD} \tag{8}$$

where $\alpha_{SM}$ and $\alpha_{TMKD}$ are hyperparameters weighting the loss terms.
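The embedding-level mixing with zero padding and the weighted sum of the three loss terms can be sketched as follows. This is a minimal illustration under stated assumptions rather than the authors' implementation: the helper names (`pad`, `mix_embeddings`, `mse`, `mixkd_loss`) and all numeric values are hypothetical, and real embeddings would be tensors produced by the model's embedding layer.

```python
def pad(seq, length, dim):
    """Right-pad a list of embedding vectors with zero vectors up to `length`."""
    return seq + [[0.0] * dim for _ in range(length - len(seq))]


def mix_embeddings(x_i, x_j, lam):
    """Mix two token-embedding sequences position by position, padding the shorter one."""
    dim = len(x_i[0])
    T = max(len(x_i), len(x_j))
    a, b = pad(x_i, T, dim), pad(x_j, T, dim)
    return [[lam * u + (1.0 - lam) * v for u, v in zip(ta, tb)] for ta, tb in zip(a, b)]


def mse(p, q):
    """Mean squared error, one possible choice for the distillation distance d."""
    return sum((u - v) ** 2 for u, v in zip(p, q)) / len(p)


def mixkd_loss(l_mle, l_sm, l_tmkd, alpha_sm=1.0, alpha_tmkd=1.0):
    """Weighted sum of the three loss terms in the final objective."""
    return l_mle + alpha_sm * l_sm + alpha_tmkd * l_tmkd


# Hypothetical 2-dimensional embeddings: x_i has two tokens, x_j has one.
mixed = mix_embeddings([[1.0, 1.0], [2.0, 2.0]], [[0.0, 4.0]], lam=0.5)
# Hypothetical teacher and student predictions on the mixed input.
l_tmkd = mse([0.7, 0.3], [0.6, 0.4])
total = mixkd_loss(l_mle=0.5, l_sm=0.25, l_tmkd=0.125)
print(mixed, l_tmkd, total)
```

The second token of the longer sequence is mixed with a zero vector, which is what padding the shorter sequence amounts to at the embedding level.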

3.4. THEORETICAL ANALYSIS

We develop a theoretical foundation for the proposed framework. We wish to prove that by adopting data augmentation for knowledge distillation, one can achieve i) a smaller gap between generalization error and empirical error, and ii) better generalization. To this end, assume the original training data $\{x_i\}_{i=1}^{n}$ are sampled i.i.d. from the true data distribution $p(x)$, and the augmented data distribution by mixup is denoted as $q(x)$ (apparently $p$ and $q$ are dependent). Let $f$ be the teacher function, and $g \in G$ be the learnable student function. Denote the loss function to learn $g$ as $l(\cdot, \cdot)$¹. The population risk w.r.t. $p(x)$ is defined as $R(f, g, p) = \mathbb{E}_{x \sim p(x)}[l(f(x), g(x))]$, and the empirical risk as $R_{emp}(f, g, \{x_i\}_{i=1}^{n}) = \frac{1}{n} \sum_{i=1}^{n} l(f(x_i), g(x_i))$. A classic statement for generalization is the following: with at least $1 - \delta$ probability, we have

$$R(f, g_p, p) - R_{emp}(f, g_p, \{x_i\}_{i=1}^{n}) \le \epsilon, \tag{9}$$

where $\epsilon > 0$, and we have used $g_p$ to indicate that the function is learned based on $p(x)$. Note different training data would correspond to a different error $\epsilon$ in equation 9. We use $\epsilon_p$ to denote the minimum value over all $\epsilon$'s satisfying equation 9. Similarly, we can replace $p$ with $q$, and $\{x_i\}_{i=1}^{n}$ with $\{x_i\}_{i=1}^{a} \cup \{x'_i\}_{i=1}^{b}$ in equation 9 in the data-augmentation case. In this case, the student function is learned based on both the training data and augmented data, which we denote as $g^*$. Similarly, we also have a corresponding minimum error, which we denote as $\epsilon^*$. Consequently, our goal of better generalization corresponds to proving $R(f, g^*, p) \le R(f, g_p, p)$, and the goal of a smaller gap corresponds to proving $\epsilon^* \le \epsilon_p$. In our theoretical results, we will give conditions when these goals are achievable. First, we consider the following three cases about the joint data $X \triangleq \{x_i\}_{i=1}^{a} \cup \{x'_i\}_{i=1}^{b}$ and the function class $G$:

• Case 1: There exists a distribution $\tilde{p}$ such that $X$ are i.i.d. samples from it²; $G$ is a finite set.
• Case 2: There exists $\tilde{p}$ such that $X$ are i.i.d. samples from it; $G$ is an infinite set.
• Case 3: There does not exist a distribution $\tilde{p}$ such that $X$ are i.i.d. samples from it.

Our theoretical results are summarized in Theorems 1-3, which state that with enough augmented data, our method can achieve smaller generalization errors. Proofs are given in the Appendix.

Theorem 1 Assume the loss function $l(\cdot, \cdot)$ is upper bounded by $M > 0$. Under Case 1, there exists a constant $c > 0$ such that if

$$b \ge \frac{M^2 \log(|G|/\delta)}{c} - a$$

then $\epsilon^* \le \epsilon_p$, where $\epsilon^*$ and $\epsilon_p$ denote the minimal generalization gaps one can achieve with or without augmented data, with at least $1 - \delta$ probability. If further assuming a better empirical risk with data augmentation (which is usually the case in practice), i.e., $R_{emp}(f, g^*, \{x_i\}_{i=1}^{a} \cup \{x'_i\}_{i=1}^{b}) \le R_{emp}(f, g_p, \{x_i\}_{i=1}^{n})$, we have $R(f, g^*, p) \le R(f, g_p, p)$.

Theorem 2 Assume the loss function $l(\cdot, \cdot)$ is upper bounded by $M > 0$ and Lipschitz continuous. Fix the probability parameter $\delta$. Under Case 2, there exists a constant $c > 0$ such that if

$$b \ge \frac{M^2 \log(1/\delta)}{c} - a$$

then $\epsilon^* \le \epsilon_p$, where $\epsilon^*$ and $\epsilon_p$ denote the minimal generalization gaps one can achieve with or without augmented data, with at least $1 - \delta$ probability. If further assuming a better empirical risk with data augmentation, i.e., $R_{emp}(f, g^*, \{x_i\}_{i=1}^{a} \cup \{x'_i\}_{i=1}^{b}) \le R_{emp}(f, g_p, \{x_i\}_{i=1}^{n})$, we have $R(f, g^*, p) \le R(f, g_p, p)$.

A more interesting setting is Case 3. Our result is based on Baxter (2000), which studies learning from different and possibly correlated distributions.

Theorem 3 Assume the loss function $l(\cdot, \cdot)$ is upper bounded. Under Case 3, there exist constants $c_1, c_2, c_3 > 0$ such that if b ≥ a log(4/δ) / (c_1 a − c […] $R_{emp}(f, g^*, \{x_i\}_{i=1}^{a} \cup \{x'_i\}_{i=1}^{b}) \le R_{emp}(f, g_p, \{x_i\}_{i=1}^{n})$, we have $R(f, g^*, p) \le R(f, g_p, p)$.

Remark 4 For Theorem 3 to hold, based on Baxter (2000), it is enough to ensure $\{x_i, x'_i\}$ and $\{x_j, x'_j\}$ to be independent for $i \ne j$. We achieve this by constructing $x'_i$ with $x_i$ and an extra random sample from the training data. Since all $(x_i, x_j)$ and the extra random samples are independent, the resulting concatenation will also be independent.
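To get a feel for the sample-size condition in Theorem 1, the sketch below evaluates the bound $b \ge M^2 \log(|G|/\delta)/c - a$ numerically. The theorem only asserts that some constant $c > 0$ exists, so the value of `c` used here, like every other number in the example, is a purely hypothetical choice; the function name `min_augmented_samples` is likewise an assumption.

```python
import math


def min_augmented_samples(M, G_size, delta, c, a):
    """Smallest non-negative integer b with b >= M^2 * log(|G| / delta) / c - a."""
    bound = M ** 2 * math.log(G_size / delta) / c - a
    return max(0, math.ceil(bound))


# Hypothetical setting: loss bounded by 1, |G| = 1000, delta = 0.05, c = 0.01, a = 500.
b_min = min_augmented_samples(M=1.0, G_size=1000, delta=0.05, c=0.01, a=500)
print(b_min)
```

The bound decreases linearly in the number of original samples $a$, so once $a$ is large enough the condition is met without any augmented data.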

4. EXPERIMENTS

We demonstrate the effectiveness of MixKD on a number of GLUE (Wang et al., 2019) dataset tasks: Stanford Sentiment Treebank (SST-2) (Socher et al., 2013), Microsoft Research Paraphrase Corpus (MRPC) (Dolan & Brockett, 2005), Quora Question Pairs (QQP)³, Multi-Genre Natural Language Inference (MNLI) (Williams et al., 2018), Question Natural Language Inference (QNLI) (Rajpurkar et al., 2016), and Recognizing Textual Entailment (RTE) (Dagan et al., 2005; Haim et al., 2006; Giampiccolo et al., 2007; Bentivogli et al., 2009). Note that MNLI contains both an in-domain (MNLI-m) and cross-domain (MNLI-mm) evaluation set. These datasets span sentiment analysis, paraphrase similarity matching, and natural language inference types of tasks. We use the Hugging Face Transformers⁴ implementation of BERT for our experiments.

4.1. GLUE DATASET EVALUATION

We first analyze the contributions of each component of our method, evaluating on the dev set of the GLUE datasets. For the teacher model, we fine-tune a separate 12 Transformer-layer bert-base-uncased (BERT 12 ) for each task. We use the smaller BERT 3 and BERT 6 as the student model. We find that initializing the embeddings and Transformer layers of the student model from the first k layers of the teacher model provides a significant boost to final performance. We use MSE as the knowledge distillation distance metric d(•, •). We generate one mixup sample for each original sample in each minibatch (mixup ratio of 1), with λ ∼ Beta(0.4, 0.4). We set hyperparameters weighting the components in the loss term in equation 8 as α SM = α TMKD = 1. As a baseline, we fine-tune the student model on the task dataset without any distillation or augmentation, which we denote as BERT k -FT. We compare this against MixKD, with both knowledge distillation on the teacher's predictions (L TMKD ) and mixup for the student (L SM ), which we call BERT k -SM+TMKD. We also evaluate an ablated version without the student mixup loss (BERT k -TMKD) to highlight the knowledge distillation component specifically. We note that our method can also easily be combined with other forms of data augmentation. For example, backtranslation (translating an input sequence to the data space of another language and then translating back to the original language) tends to generate varied but semantically similar sequences; these sentences also tend to be of higher quality than masking or word-dropping approaches. We show that our method has an additive effect with other techniques by also testing our method with the dataset augmented with German backtranslation, using the fairseq (Ott et al., 2019) neural machine translation codebase to generate these additional samples. 
We also compare all of the aforementioned variants with backtranslation samples augmenting the data; we denote these variants with an additional +BT. We report the model accuracy (and F 1 score, for MRPC and QQP) in Table 1 . We also show the performance of the full-scale teacher model (BERT 12 ) and DistilBERT (Sanh et al., 2019) , which performs basic knowledge distillation during BERT pre-training to a 6-layer model. For our method, we observe that a combination of data augmentation and knowledge distillation leads to significant gains in performance, with the best variant often being the combination of teacher mixup knowledge distillation, student mixup, and backtranslation. In the case of SST-2, for example, BERT 6 -SM+TMKD+BT is able to capture 99.88% of the performance of the teacher model, closing 91.27% of the gap between the fine-tuned student model and the teacher, despite using far fewer parameters and having a much faster inference speed (Table 2 ).
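The two percentages quoted for SST-2 can be read as simple ratios. The sketch below shows one natural way to compute them; the exact formulas are an assumption, since the text does not spell them out, and the accuracies used are made-up placeholders rather than values from Table 1.

```python
def fraction_of_teacher(student, teacher):
    """Student score expressed as a percentage of the teacher score."""
    return 100.0 * student / teacher


def gap_closed(student_ft, student_kd, teacher):
    """Percentage of the fine-tuned-student-to-teacher gap recovered by distillation."""
    return 100.0 * (student_kd - student_ft) / (teacher - student_ft)


# Placeholder accuracies (hypothetical, not taken from the paper's tables).
teacher_acc, ft_acc, kd_acc = 92.0, 88.0, 91.0
print(fraction_of_teacher(kd_acc, teacher_acc), gap_closed(ft_acc, kd_acc, teacher_acc))
```

The second quantity is the more informative of the two when the fine-tuned student is already close to the teacher, because the first ratio is then near 100% regardless of the method.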


After analyzing the contributions of the components of our model on the dev set, we find the SM+TMKD+BT variant to have the best performance overall and thus focus on this variant. We submit this version of MixKD to the GLUE test server, reporting its results in comparison with fine-tuning (FT), vanilla knowledge distillation (KD) (Hinton et al., 2015) , and patient knowledge distillation (PKD) (Sun et al., 2019a) in Table 3 . Once again, we observe that our model outperforms the baseline methods on most tasks.

4.2. LIMITED-DATA SETTINGS

One of the primary motivations for using data augmentation for knowledge distillation is to give the student more opportunities to query the teacher model. For datasets with a large enough number of samples relative to the task's complexity, the original dataset may provide enough chances to learn from the teacher, reducing the relative value of data augmentation. As such, we also evaluate MixKD with a BERT 3 student on downsampled versions of QQP, MNLI (matched and mismatched), and QNLI in Figure 1 . We randomly select 10% and 1% of the data from 

4.4. HYPERPARAMETER SENSITIVITY & FURTHER ANALYSIS

Loss Hyperparameters Our final objective in equation 8 has hyperparameters α SM and α TMKD , which control the weight of the student model's cross-entropy loss for the mixup samples and the knowledge distillation loss with the teacher's predictions on the mixup samples, respectively. We demonstrate that the model is fairly stable over a wide range by sweeping both α SM and α TMKD over the range {0.1, 0.5, 1.0, 2.0, 10.0}. We do this for a BERT 3 student and BERT 12 teacher, with SST-2 as the task; we show the results of this sensitivity study, both with and without German backtranslation, in Figure 3 . Given the overall consistency, we observe that our method is stable over a wide range of settings.
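The sweep described above is a full grid over the two loss weights, giving 25 configurations. A minimal sketch of such a grid is shown below; the `evaluate` callable is a hypothetical stand-in for training a student with the given weights and returning its dev-set accuracy.

```python
from itertools import product

# Values swept for both alpha_SM and alpha_TMKD, as listed in the text.
ALPHAS = [0.1, 0.5, 1.0, 2.0, 10.0]


def sweep(evaluate):
    """Call `evaluate(alpha_sm, alpha_tmkd)` for every pair in the grid."""
    return {
        (a_sm, a_tmkd): evaluate(a_sm, a_tmkd)
        for a_sm, a_tmkd in product(ALPHAS, repeat=2)
    }


# Placeholder scorer; a real run would train and evaluate a student model here.
results = sweep(lambda a_sm, a_tmkd: 0.0)
print(len(results))
```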

Mixup Ratio

We also investigate the effect of the mixup ratio: the number of mixup samples generated for each sample in a minibatch. We run a smaller sweep of α SM and α TMKD over the range {0.5, 1.0, 2.0} for mixup ratios of 2 and 3 for a BERT$_3$ student on SST-2, with and without German backtranslation, in Figure 3. We conclude that the mixup ratio does not have a strong effect on overall performance. Given that a higher mixup ratio requires more computation (due to more samples over which to compute the forward and backward pass), we find a mixup ratio of 1 to be enough.
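To make the notion of a mixup ratio concrete, the sketch below draws, for each sample in a minibatch, `ratio` mixing partners and a coefficient λ ∼ Beta(0.4, 0.4) as in Section 4.1. Pairing samples through a random permutation of the minibatch is an assumption made for this illustration (the text does not specify how partners are chosen), and the function name `mixup_pairs` is hypothetical.

```python
import random


def mixup_pairs(batch_size, ratio=1, alpha=0.4, rng=None):
    """Return (i, j, lam) triples: `ratio` mixup partners j for every sample index i."""
    rng = rng or random.Random()
    triples = []
    for _ in range(ratio):
        partners = list(range(batch_size))
        rng.shuffle(partners)  # assumed pairing scheme: a random permutation per round
        for i, j in enumerate(partners):
            lam = rng.betavariate(alpha, alpha)
            triples.append((i, j, lam))
    return triples


# A minibatch of 4 samples with a mixup ratio of 2 yields 8 mixup samples.
triples = mixup_pairs(batch_size=4, ratio=2, rng=random.Random(0))
print(len(triples))
```

Under this scheme a sample can occasionally be paired with itself, in which case the mixup sample equals the original. The number of extra teacher and student forward passes grows linearly with the ratio.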

A PROOFS

Proof of Theorem 1. First of all, $\{x_i\}_{i=1}^{a} \cup \{x'_i\}_{i=1}^{b}$ can be regarded as drawn from the distribution $r(x) = \frac{a\,p(x) + b\,q(x)}{a + b}$. Given that $G$ is finite, we have the following theorem.

Theorem 5 (Mohri et al., 2018) Let $l$ be a bounded loss function and let the hypothesis set $G$ be finite. Then for any $\delta > 0$, with probability at least $1 - \delta$, the following inequality holds for all $g \in G$:

$$R(f, g, p) - R_{emp}(f, g, \{x_i\}_{i=1}^{n}) \le M \sqrt{\frac{\log(|G|/\delta)}{2n}}$$

Thus we have in our case:

$$R(f, g_p, p) - R_{emp}(f, g_p, \{x_i\}_{i=1}^{n}) \le \epsilon_p \le M \sqrt{\frac{\log(|G|/\delta)}{2n}}$$

and

$$\begin{aligned}
& R(f, g^*, p) - R_{emp}(f, g^*, \{x_i\}_{i=1}^{a} \cup \{x'_i\}_{i=1}^{b}) \\
&= R(f, g^*, r) - R_{emp}(f, g^*, \{x_i\}_{i=1}^{a} \cup \{x'_i\}_{i=1}^{b}) + \int l(f(x), g^*(x))\,(p(x) - r(x))\,dx \\
&= R(f, g^*, r) - R_{emp}(f, g^*, \{x_i\}_{i=1}^{a} \cup \{x'_i\}_{i=1}^{b}) + \frac{b}{a + b} \int l(f(x), g^*(x))\,(p(x) - q(x))\,dx \\
&\le R(f, g^*, r) - R_{emp}(f, g^*, \{x_i\}_{i=1}^{a} \cup \{x'_i\}_{i=1}^{b}) + \int l(f(x), g^*(x))\,(p(x) - q(x))\,dx \\
&\le M \sqrt{\frac{\log(|G|/\delta)}{2(a + b)}} + \Delta
\end{aligned} \tag{10}$$

where $\Delta = \int l(f(x), g^*(x))\,(p(x) - q(x))\,dx$. If

$$b \ge \frac{M^2 \log(|G|/\delta)}{2(\epsilon_p - \Delta)^2} - a$$

then

$$2(a + b) \ge \frac{M^2 \log(|G|/\delta)}{(\epsilon_p - \Delta)^2}, \qquad (\epsilon_p - \Delta)^2 \ge \frac{M^2 \log(|G|/\delta)}{2(a + b)}, \qquad \epsilon_p \ge M \sqrt{\frac{\log(|G|/\delta)}{2(a + b)}} + \Delta.$$

Substituting into equation 10, we have

$$R(f, g^*, p) - R_{emp}(f, g^*, \{x_i\}_{i=1}^{a} \cup \{x'_i\}_{i=1}^{b}) \le \epsilon_p.$$

Recall the definition of $\epsilon^*$, which is the minimum value of all possible $\epsilon$ satisfying $R(f, g^*, p) - R_{emp}(f, g^*, \{x_i\}_{i=1}^{a} \cup \{x'_i\}_{i=1}^{b}) \le \epsilon$; we know that $\epsilon^* \le \epsilon_p$. Letting $c = 2(\epsilon_p - \Delta)^2$, we can conclude the theorem.

Proof of Theorem 2. First of all, $\{x_i\}_{i=1}^{a} \cup \{x'_i\}_{i=1}^{b}$ can be regarded as drawn from the distribution $r(x) = \frac{a\,p(x) + b\,q(x)}{a + b}$.

Theorem 6 (Mohri et al., 2018) Let $l$ be a non-negative loss function upper bounded by $M > 0$, and for any fixed $y$, let $l(y, y')$ be $L$-Lipschitz for some $L > 0$. Then with probability at least $1 - \delta$,

$$R(f, g, p) - R_{emp}(f, g, \{x_i\}_{i=1}^{n}) \le 2L\,\mathfrak{R}_p(G) + M \sqrt{\frac{\log(1/\delta)}{2n}}$$

Thus we have

$$R(f, g_p, p) - R_{emp}(f, g_p, \{x_i\}_{i=1}^{n}) \le \epsilon_p \le 2L\,\mathfrak{R}_p(G) + M \sqrt{\frac{\log(1/\delta)}{2n}}$$

where $\mathfrak{R}_p(G)$ is the Rademacher complexity over all samples of size $n$ drawn from $p(x)$.

B VARIANCE ANALYSIS

For the purpose of getting a sense of variance, we run experiments with additional random seeds on MRPC and RTE, which are relatively smaller datasets, and MNLI and QNLI, which are relatively larger datasets. Mean and standard deviation on the dev set of these GLUE datasets are reported in Table 5 . We observe the variance of the same model's performance to be small, especially on the relatively larger datasets. 
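Entries of the form mean±deviation can be produced from per-seed scores as sketched below. The scores are made-up placeholders, and the use of the sample standard deviation (`statistics.stdev`) is an assumption, since the text does not say which estimator was used.

```python
from statistics import mean, stdev


def summarize(scores):
    """Format per-seed scores as 'mean±std' with two decimals (sample standard deviation)."""
    return f"{mean(scores):.2f}±{stdev(scores):.2f}"


# Hypothetical dev-set accuracies from three random seeds.
summary = summarize([82.0, 82.1, 81.9])
print(summary)
```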



¹ This is essentially the same as $L$ in equation 8. We use a different notation $l(f(x), g(x))$ to explicitly spell out the two data-wise arguments $f(x)$ and $g(x)$.

² We make such an assumption because $x_i$ and $x'_i$ are dependent, thus the existence of such a distribution is unknown.

³ data.quora.com/First-Quora-Dataset-Release-Question-Pairs

⁴ https://huggingface.co/transformers/

CONCLUSIONS

We introduce MixKD, a method that uses data augmentation to significantly increase the value of knowledge distillation for compressing large-scale language models. Intuitively, MixKD allows the student model additional queries to the teacher model, granting it more opportunities to absorb the latter's richer representations. We analyze MixKD from a theoretical standpoint, proving that our approach results in a smaller gap between generalization error and empirical error, as well as better generalization, under appropriate conditions. Our approach's success on a variety of GLUE tasks demonstrates its broad applicability, with a thorough set of experiments for validation. We also believe that the MixKD framework can further reduce the gap between student and teacher models with the incorporation of more recent mixup and knowledge distillation techniques (Lee et al., 2020; Wang et al., 2020; Mirzadeh et al., 2019), and we leave this to future work.

⁵ https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/TinyBERT



Figure 1: Results of limited data case, where both the teacher and student models are learned with only 10% (left) or 1% of the training data (right).

Figure 3: Hyperparameter sensitivity analysis regarding the MixKD approach, with different choices of α TMKD , α SM and the ratio of mixup samples (w.r.t. the original training data).

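We also have

$$\begin{aligned}
& R(f, g^*, p) - R_{emp}(f, g^*, \{x_i\}_{i=1}^{a} \cup \{x'_i\}_{i=1}^{b}) \\
&= R(f, g^*, r) - R_{emp}(f, g^*, \{x_i\}_{i=1}^{a} \cup \{x'_i\}_{i=1}^{b}) + \int l(f(x), g^*(x))\,(p(x) - r(x))\,dx \\
&= R(f, g^*, r) - R_{emp}(f, g^*, \{x_i\}_{i=1}^{a} \cup \{x'_i\}_{i=1}^{b}) + \frac{b}{a + b} \int l(f(x), g^*(x))\,(p(x) - q(x))\,dx \\
&\le R(f, g^*, r) - R_{emp}(f, g^*, \{x_i\}_{i=1}^{a} \cup \{x'_i\}_{i=1}^{b}) + \int l(f(x), g^*(x))\,(p(x) - q(x))\,dx \\
&\le 2L\,\mathfrak{R}_r(G) + M \sqrt{\frac{\log(1/\delta)}{2(a + b)}} + \Delta
\end{aligned} \tag{11}$$

where $\Delta = \int l(f(x), g^*(x))\,(p(x) - q(x))\,dx$, and $\mathfrak{R}_r(G)$ is the Rademacher complexity over all samples of size $(a + b)$ drawn from $r(x) = \frac{a\,p(x) + b\,q(x)}{a + b}$. If

$$b \ge \frac{M^2 \log(1/\delta)}{2(\epsilon_p - \Delta - 2L\,\mathfrak{R}_r(G))^2} - a$$

then:

$$2(a + b) \ge \frac{M^2 \log(1/\delta)}{(\epsilon_p - \Delta - 2L\,\mathfrak{R}_r(G))^2}, \qquad \epsilon_p - \Delta - 2L\,\mathfrak{R}_r(G) \ge M \sqrt{\frac{\log(1/\delta)}{2(a + b)}}, \qquad \epsilon_p \ge M \sqrt{\frac{\log(1/\delta)}{2(a + b)}} + \Delta + 2L\,\mathfrak{R}_r(G).$$

Substituting into equation 11, we have:

$$R(f, g^*, p) - R_{emp}(f, g^*, \{x_i\}_{i=1}^{a} \cup \{x'_i\}_{i=1}^{b}) \le \epsilon_p.$$

Recall the definition of $\epsilon^*$, which is the minimum value of all possible $\epsilon$ satisfying $R(f, g^*, p) - R_{emp}(f, g^*, \{x_i\}_{i=1}^{a} \cup \{x'_i\}_{i=1}^{b}) \le \epsilon$; we know that $\epsilon^* \le \epsilon_p$. Letting $c = 2(\epsilon_p - \Delta - 2L\,\mathfrak{R}_r(G))^2$, we can conclude the theorem.

Proof of Theorem 3. Similar to previous theorems, we write

$$\begin{aligned}
& R(f, g^*, p) - R_{emp}(f, g^*, \{x_i\}_{i=1}^{a} \cup \{x'_i\}_{i=1}^{b}) \\
&= R\!\left(f, g^*, \tfrac{ap + bq}{a + b}\right) - R_{emp}(f, g^*, \{x_i\}_{i=1}^{a} \cup \{x'_i\}_{i=1}^{b}) + \int l(f(x), g^*(x)) \left(p(x) - \frac{a\,p(x) + b\,q(x)}{a + b}\right) dx \\
&= R\!\left(f, g^*, \tfrac{ap + bq}{a + b}\right) - R_{emp}(f, g^*, \{x_i\}_{i=1}^{a} \cup \{x'_i\}_{i=1}^{b}) + \frac{b}{a + b} \int l(f(x), g^*(x))\,(p(x) - q(x))\,dx \\
&\le R\!\left(f, g^*, \tfrac{ap + bq}{a + b}\right) - R_{emp}(f, g^*, \{x_i\}_{i=1}^{a} \cup \{x'_i\}_{i=1}^{b}) + \Delta
\end{aligned} \tag{12}$$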

Table 1: GLUE dev set results. We report the results of our BERT$_{12}$ teacher model, the 6-layer DistilBERT, and 3- and 6-layer MixKD student models with various ablations. DistilBERT results taken from Sanh et al. (2019). For MRPC and QQP, we report F1/Accuracy.

Table 3: GLUE test server results. We show results for the full variants of the 3- and 6-layer MixKD student models (SM+TMKD+BT). Knowledge distillation (KD) and Patient Knowledge Distillation (PKD) results are from Sun et al. (2019a).

Table 2: Computation cost comparison of teacher and student models on SST-2 with batch size of 16 on a Nvidia TITAN X GPU.

We compare our approach with the data augmentation module proposed by TinyBERT (Jiao et al., 2019). As shown in Table 4, our approach exhibits much stronger results for distilling a 6-layer BERT model (on both MNLI and SST-2 datasets). Notably, TinyBERT's data augmentation module is much less efficient than mixup's simple operation, generating 20 times the original data as augmented samples, thus leading to massive computation overhead.

Table 5: Mean and variance reported for BERT$_6$-TMKD+BT, BERT$_6$-SM+TMKD+BT, BERT$_3$-TMKD+BT and BERT$_3$-SM+TMKD+BT.

| Model | MRPC (F1/Acc) | MNLI | QNLI | RTE |
|---|---|---|---|---|
| BERT$_6$-TMKD+BT | 89.79±0.27/85.04±0.48 | 82.05±0.11 | 88.42±0.06 | 69.37±0.50 |
| BERT$_6$-SM+TMKD+BT | 89.64±0.38/84.43±0.36 | 82.41±0.12 | 88.76±0.15 | 68.02±0.11 |
| BERT$_3$-TMKD+BT | 84.79±0.33/75.82±0.48 | 77.16±0.03 | 84.60±0.07 | 62.47±0.36 |
| BERT$_3$-SM+TMKD+BT | 84.53±0.39/75.85±0.60 | 77.42±0.11 | 84.88±0.06 | 60.83±0.18 |

ACKNOWLEDGMENTS

CC is partly supported by the Verizon Media FREP program.

annex

where $\Delta = \int l(f(x), g^*(x))\,(p(x) - q(x))\,dx$. For notation consistency, we write $R\!\left(f, g^*, \tfrac{ap + bq}{a + b}\right) = \int l(f(x), g^*(x))\,\frac{a\,p(x) + b\,q(x)}{a + b}\,dx$. However, $\{x_i\}_{i=1}^{a} \cup \{x'_i\}_{i=1}^{b}$ are not drawn from the same distribution (which is $r(x) = \frac{a\,p(x) + b\,q(x)}{a + b}$ in previous cases).

Let $\gamma = \frac{a + b}{a}$; we split $\{x_i\}_{i=1}^{a} \cup \{x'_i\}_{i=1}^{b}$ into $\gamma$ parts that don't overlap with each other. The first part is $\{x_i\}_{i=1}^{a}$; all the other parts have at least $a$ elements from […] where $C(G)$ is the space capacity defined in Definition 4 in Baxter (2000), which depends on $\epsilon^*$ and $G$.

By Theorem 4 in Baxter (2000), […] By Theorem 5 in Baxter (2000), […] The last inequality comes from $b \le \gamma a$, which is because of $\gamma = \frac{a + b}{a}$. Then we have […] Combining equation 13 and equation 14, we have […] Substituting into equation 12, we have: […] $\le \epsilon_p$.

Recall the definition of $\epsilon^*$, which is the minimum value of all possible $\epsilon$ satisfying $R(f, g^*, p) - R_{emp}(f, g^*, \{x_i\}_{i=1}^{a} \cup \{x'_i\}_{i=1}^{b}) \le \epsilon$; we know that $\epsilon^* \le \epsilon_p$.

