AUGMENTATION WITH PROJECTION: TOWARDS AN EFFECTIVE AND EFFICIENT DATA AUGMENTATION PARADIGM FOR DISTILLATION

Abstract

Knowledge distillation is one of the primary methods of transferring knowledge from large to small models. However, it requires massive task-specific data, which may not be feasible in many real-world applications. Data augmentation methods such as representation interpolation, token replacement, or augmentation with models are applied to tackle this problem. However, these data augmentation methods either potentially cause shifts in decision boundaries (representation interpolation), are not expressive enough (token replacement), or introduce too much computational overhead (augmentation with models). To this end, we propose AugPro (Augmentation with Projection), an effective and efficient data augmentation method for distillation. Our method builds on top of representation interpolation augmentation methods to maintain the diversity of expressions and converts the augmented data to tokens to avoid shifting decision boundaries. It uses simple operations that come with little computational overhead. The results on multiple GLUE tasks show that our methods can improve distillation performance by a large margin at a low time cost. Code is available at https://github.com

1. INTRODUCTION

Large-scale language models (Devlin et al., 2018; Raffel et al., 2020; Brown et al., 2020; Zhang et al., 2022c) have achieved great success on various natural language processing (NLP) tasks, such as information extraction (Lu et al., 2021) and question answering (Kassner & Schütze, 2020). However, large-scale models have high computational overhead, which limits their deployment in edge devices and fast response scenarios (Sun et al., 2020b). One widely used solution is to perform knowledge distillation (Hinton et al., 2015) from large-scale models to small-scale models. This method, however, usually requires a large amount of data to guarantee the transfer quality, which may not be easily obtained in real-world applications. To this end, data augmentation methods are applied (Liang et al., 2020; Wang & Yang, 2020; Zhang et al., 2022b) to improve the distillation performance.

There are three major types of data augmentation methods: (1) Representation interpolation. For example, Liang et al. (2020), Chen et al. (2020a) and Sun et al. (2020a) apply linear interpolation (Zhang et al., 2017) to word embeddings, hidden states between transformer layers, and encoder outputs, respectively, to augment the original dataset with virtual data points. The data points are virtual because they are not real language inputs. Instead, they are representations (e.g., embeddings). (2) Token replacement. Kobayashi (2018) replaces tokens with their synonyms. Easy data augmentation (Wei & Zou, 2019) combines synonym replacement, random insertion, random swap, and random deletion. (3) Augmentation with models. Yoo et al. (2021) and Zhou et al. (2021) use GPT-3 (Brown et al., 2020) and T5 (Raffel et al., 2020), respectively, as the language model to generate new text data of similar types.

Representation interpolation (type (1)) supports many operations such as linear interpolation (Zhang et al., 2017) and small perturbation (Madry et al., 2017). This makes the methods very expressive in generating a diverse range of data.
However, the newly generated representations (e.g., embeddings) may lie outside the real data distribution. For instance, word embeddings are converted from a vocabulary in the text domain. Performing augmentation at this level may result in representations that have no counterparts in the vocabulary. As a result, the augmented data may mislead the model into learning a shifted decision boundary, which can largely affect the quality (Section 3). Token replacement (type (2)) can generate in-domain data easily. By using synonym replacement (Wang & Yang, 2015), new data can be obtained at a low cost. Despite this good property, this stream of methods lacks the ability to generate diversified data. Consequently, they contribute little to sampling low-resource data areas, which limits the performance gains in practice. Augmentation with models (type (3)) generates both diversified and in-domain data using large language models such as GPT-3 (Brown et al., 2020) and T5 (Raffel et al., 2020). Due to their large computational overhead, however, the final distillation quality is highly limited by the amount of generated data: generating even tens of thousands of sentences is usually not affordable in practice. Figure 1 summarizes the advantages of each augmentation method.

Considering all the approaches above, we propose AugPro, an effective and efficient data augmentation method for the distillation scenario, which absorbs the advantages above without being limited by their drawbacks. Specifically, AugPro: (1) (effectiveness) is as expressive as representation interpolation; (2) (effectiveness) does not mislead decision boundaries; (3) (efficiency) has low computational overhead. In the knowledge distillation scenario, we can always use the teacher to label the hallucinated data. This suggests that we can encourage AugPro to produce data that are as diverse as possible and that are not limited to instances with only the same or flipped labels.
Given two images $\mathbf{x}_1, \mathbf{x}_2$ and their labels $y_1, y_2$, MixUp uses linear interpolation to generate a new data point $\mathbf{x}'$ and its label $y'$:

$$\mathbf{x}' = \mathrm{MixUp}(\mathbf{x}_1, \mathbf{x}_2) = \lambda \mathbf{x}_1 + (1-\lambda)\mathbf{x}_2, \qquad y' = \mathrm{MixUp}(y_1, y_2) = \lambda y_1 + (1-\lambda) y_2 \tag{1}$$

FGSM (Goodfellow et al., 2014) and PGA (Madry et al., 2017) use gradients to generate adversarial examples. Given an image $\mathbf{x}$, FGSM generates new data $\mathbf{x}'$:

$$\mathbf{x}' = \mathbf{x} + \epsilon\,\mathrm{Sign}(\nabla_{\mathbf{x}} \mathcal{L}) \tag{2}$$

where $\mathcal{L}$ is the loss of a specific task and $\epsilon$ is a small value. $\mathbf{x}'$ and $\mathbf{x}$ have the same label. CutMix (Yun et al., 2019) cuts images and then concatenates them together to get a new one. Though these methods were originally designed for images, they can be adapted to NLP tasks. Liang et al. (2020) use MixUp on word embeddings to augment data for knowledge distillation. Chen et al. (2020b) use MixUp on hidden states between transformer layers. Jindal et al. (2020) also use MixUp on hidden states but consider the effect of mean and variance. Zhang et al. (2022b) apply PGA to student models' embeddings and leave teacher models' embeddings unchanged, finding that PGA can benefit knowledge distillation.

Token replacement methods mainly focus on language inputs. Synonym replacement methods (Kobayashi, 2018) replace tokens with their synonyms. Easy data augmentation (EDA) (Wei & Zou, 2019) incorporates synonym replacement, random insertion, random swap, and random deletion. TreeMix (Zhang et al., 2022a) uses a constituency parser to decide which token should be replaced.

Augmentation with models is another approach to generating new data. FlipDA (Zhou et al., 2021) uses T5 to generate data that has flipped labels. GPT3Mix (Yoo et al., 2021) designs prompts and uses GPT-3 to generate new data. Back translation (Yu et al., 2018) uses neural networks to translate inputs to another language and then translate them back.
Our method (AugPro) uses representation interpolation as the backbone and utilizes projection to convert representations to tokens with low-cost operations.
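To make Equations (1) and (2) concrete, the following is a minimal sketch of both operations on plain Python lists. The function names (`mixup`, `fgsm`), the toy vectors, and the hand-written gradient are illustrative assumptions rather than the authors' implementation; in practice the gradient would come from automatic differentiation of the task loss.

```python
from typing import List, Sequence

Vector = List[float]


def mixup(a: Sequence[float], b: Sequence[float], lam: float) -> Vector:
    """Equation (1): linear interpolation of two vectors (inputs or one-hot labels)."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    return [lam * u + (1.0 - lam) * v for u, v in zip(a, b)]


def sign(v: float) -> float:
    """Return -1.0, 0.0, or 1.0 depending on the sign of v."""
    return float((v > 0) - (v < 0))


def fgsm(x: Sequence[float], grad: Sequence[float], eps: float) -> Vector:
    """Equation (2): shift x by eps along the sign of the loss gradient."""
    return [xi + eps * sign(gi) for xi, gi in zip(x, grad)]


# Toy values, made up purely for illustration.
e1, e2 = [1.0, 0.0, 2.0], [3.0, 4.0, 0.0]   # two representation vectors
y1, y2 = [1.0, 0.0], [0.0, 1.0]             # one-hot labels
mixed_x = mixup(e1, e2, lam=0.5)
mixed_y = mixup(y1, y2, lam=0.5)
adv_x = fgsm(e1, grad=[0.3, -2.0, 0.0], eps=0.1)  # hypothetical gradient
```

Note that the mixed label is a soft label, and that a zero gradient coordinate leaves the corresponding input coordinate unchanged.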

3. MOTIVATING EXAMPLES

Though representation interpolation methods have the good property of generating diverse data, we find that these vision techniques cannot be directly applied to NLP tasks. This is because representation interpolation augments data in a continuous manner, in which case the new data may never exist in the discrete input space, causing decision boundary shifts.

Take as an example a simple two-dimensional problem with linear separability (Figure 2). Let $\mathcal{X} = \{x_1, x_2, x_3, x_4\}$ be the universe of all data to learn, and $\mathcal{Y} = \{y_1, y_2, y_3, y_4\}$ be the corresponding labels. Suppose we know all of $\mathcal{X}, \mathcal{Y}$ and run a linear support vector machine (SVM) with a hard margin, i.e., $\min_{\beta, b} \|\beta\|^2$ such that $y_i(\beta^\top x_i + b) \geq 1$ for all $i$; we get the solution $\beta^*, b^*$. Since it is hard to get all data in the real-world setting, we suppose that we only have {x

To this end, augmented data should be real data in the input space, i.e., in the format of tokens, to alleviate this problem. This observation leads to our method AugPro, which uses projection to convert augmented representations to symbolic tokens. Compared with the virtual data points generated by representation interpolation, projection can explore more real data and leads to a lower error (Section 4.2 and Appendix H).

4. METHODOLOGY

In this section, we first formulate the definition of knowledge distillation in NLP (Hinton et al., 2015) and then introduce our method AugPro.

[Figure 2 panels: (a) Ground-Truth; (b) MixUp; (c) MixUp with Projection]

Figure 2: Decision boundary shifts in 2D space for discrete datasets. (a) is the ground truth, where $x_1, x_3$ are the observable training data while $x_2, x_4$, drawn in transparent colors, are the unseen data. In (b), one gets the augmented data point $x_{\text{MixUp}}$ with label $y_{\text{MixUp}} = -1$ and runs SVM on $\{x_1, x_3, x_{\text{MixUp}}\}$ with their labels. In (c), one projects $x_{\text{MixUp}}$ to its nearest neighbor $x_2$ and runs SVM on $\{x_1, x_2, x_3\}$. The correction by projection in (c) brings a smaller decision boundary shift than (b).

4.1. KNOWLEDGE DISTILLATION

Knowledge distillation is a method to distill knowledge from large-scale models to small-scale models. Formally, considering an NLP classification task, we have a corpus $D = \{(x_i, y_i)\}_{i=1}^{N}$ that contains $N$ input-output pairs, where $x_i$ is an input sentence with tokens $x_i = [w_{i1}, \cdots, w_{in_i}]$, $w_{ik} \in V$, $V$ is the vocabulary, and $n_i$ is the number of tokens in $x_i$. $y_i$ is the output label for $x_i$. We use plain rather than bold symbols for $x$ because language inputs are sequences of tokens, which is different from images. Then we distill knowledge from a large-scale model $f(\cdot, \theta_T)$ with parameters $\theta_T$ (i.e., a teacher model) to a small-scale model $g(\cdot, \theta_S)$ with parameters $\theta_S$ (i.e., a student model). In practice, $\theta_T$ has many more parameters than $\theta_S$. The distillation process can be divided into two stages:

• Teacher training. Optimize $\theta_T$ on the dataset $D$. In classification problems, we use the cross-entropy loss to do empirical risk minimization on $\theta_T$:

$$\theta'_T = \arg\min_{\theta_T} \frac{1}{N}\sum_{i=1}^{N} \mathrm{CrossEntropy}\big(f(x_i, \theta_T), y_i\big) \tag{3}$$

• Student training. Optimize $\theta_S$ on the dataset $D$ with both ground-truth labels and outputs from the teacher. In classification problems,

$$\theta'_S = \arg\min_{\theta_S} \mathcal{L}_{\mathrm{KD}} = \arg\min_{\theta_S} \frac{1}{N}\sum_{i=1}^{N} \Big[\mathrm{CrossEntropy}\big(g(x_i, \theta_S), y_i\big) + d\big(g(x_i, \theta_S), f(x_i, \theta'_T)\big)\Big] \tag{4}$$

where $d(\cdot, \cdot)$ is a distance function. In practice, $d(\cdot, \cdot)$ could be cross-entropy or mean square error.

Empirical results from former studies (Hinton et al., 2015; Sun et al., 2020b; Sanh et al., 2019; Sun et al., 2019) show that knowledge distillation trains a better $\theta'_S$ because the student model not only learns from ground-truth labels but also learns the generality from the teacher model. Note that for student training, we can combine knowledge distillation and data augmentation:

$$\theta'_S = \arg\min_{\theta_S} \mathcal{L}_{\mathrm{KD}} + \mathcal{L}_{\mathrm{Aug}}$$

where $\mathcal{L}_{\mathrm{Aug}}$ denotes the knowledge distillation loss on augmented data, which leads to different variants of methods.
Data augmentation is one important way to help the student learn more effectively, so how to generate the new data used in the augmentation loss is the main topic of the remaining sections.
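The student objective above can be sketched in a few lines. This is a minimal illustration under stated assumptions: models are represented only by their output logits, the distance $d$ is taken to be the mean squared error between student and teacher probabilities (the text allows cross-entropy or mean square error), and all function names are hypothetical.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(probs: Sequence[float], label: int) -> float:
    """Cross-entropy between predicted probabilities and an integer class label."""
    return -math.log(probs[label])


def mse(p: Sequence[float], q: Sequence[float]) -> float:
    """Assumed choice for the distance d(., .): mean squared error."""
    return sum((a - b) ** 2 for a, b in zip(p, q)) / len(p)


def kd_loss(student_logits: Sequence[Sequence[float]],
            teacher_logits: Sequence[Sequence[float]],
            labels: Sequence[int]) -> float:
    """Average over the batch of CrossEntropy(student, y) + d(student, teacher)."""
    total = 0.0
    for s, t, y in zip(student_logits, teacher_logits, labels):
        ps, pt = softmax(s), softmax(t)
        total += cross_entropy(ps, y) + mse(ps, pt)
    return total / len(labels)


# A student that matches the teacher pays only the cross-entropy term.
agree = kd_loss([[2.0, 0.0]], [[2.0, 0.0]], [0])
disagree = kd_loss([[2.0, 0.0]], [[0.0, 2.0]], [0])
```

For the augmentation loss, the same distance term would be evaluated on augmented inputs, with the teacher's output serving as the only supervision.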

4.2. AUGPRO: AUGMENTATION WITH PROJECTION

In this section, we introduce four variants of $\mathcal{L}_{\mathrm{Aug}}$: two backbones (MixUp and FGSM) and two AugPro variants building on top of them. Figure 4 shows the concept of our proposed method.

Algorithm 1: AugPro Algorithm
Input: Dataset $D = \{(X, Y)\}$, representation interpolation function $h(\cdot)$, projection function $p(\cdot)$, the teacher model $f$ with fine-tuned parameters $\theta'_T$ and the student model $g$ with parameters $\theta_S$ that need to be fine-tuned, vocabulary $V$, learning rate $\eta$, training steps $K$, batch size $B$, sentence length $L$, embedding dimension $H$.
Output: Fine-tuned $\theta'_S$
• $k = 0$
• while $k < K$
  - Sample a data batch $\mathcal{B} = \{(x, y)\} \in V^{B \times L}$ from $D = \{(X, Y)\}$.
  - Get augmented representations $\mathcal{B}_{\mathrm{rep}} = h(\mathcal{B}) \in \mathbb{R}^{B \times L \times H}$ (e.g., Equation (1, 2)).
  - Project representations to tokens $\mathcal{B}' = p(\mathcal{B}_{\mathrm{rep}}) \in V^{B \times L}$ (e.g., Equation 5).
  - Use Equation (4) to compute $\mathcal{L}_{\mathrm{KD}}$.
  - Compute loss $\mathcal{L}_{\mathrm{Aug}}$ (e.g., Equation (6, 7)) based on $\mathcal{B}'$, $f$ and $g$.
  - $\theta_S = \theta_S - \eta \nabla(\mathcal{L}_{\mathrm{KD}} + \mathcal{L}_{\mathrm{Aug}})$, $k = k + 1$
• return $\theta_S$

AugPro builds on top of representation interpolation augmentation methods. The pipeline (Algorithm 1) can be divided into three steps: (1) We first get augmented representations (i.e., $h(\cdot)$). (2) Then we use projection to convert representations to tokens (i.e., $p(\cdot)$). (3) Finally, we compute $\mathcal{L}_{\mathrm{Aug}}$ to update student models.

The key to step (2) is projection. Concretely, language models map tokens to representations, and projection aims to find an inverse mapping that maps representations back to tokens. This way, AugPro avoids shifting decision boundaries (Section 3). However, the inverse mapping is hard to find in practice. First, popular language model architectures such as transformers (Vaswani et al., 2017) are usually complex, so their inverse mapping is hard to obtain. Second, the input space of language is discrete, making the mapping irreversible. Thus, we can only use an approximation technique to get an approximate inverse mapping. To this end, we focus on the inverse mapping at the embedding level.
First, the structure of the embedding mapping is much simpler than that of the transformer layers, so its inverse mapping is straightforward to find. Second, we can use nearest neighbors as the approximation method, which is cheap. Based on the analysis above, we use nearest neighbors as our projection, i.e., the function $p(\cdot)$ in Algorithm 1. AugPro does not rely on specific representation interpolation methods. In this paper, we apply AugPro to MixUp and FGSM, i.e., the $h(\cdot)$ in Algorithm 1 is MixUp or FGSM as in Equation (1, 2). We illustrate two variants that perform projection (step (2)) and compute the $\mathcal{L}_{\mathrm{Aug}}$ loss (step (3)) in the following text. We slightly abuse the notation of $f$ and $g$ to illustrate AugPro better. We divide $f$ and $g$ into two parts: the first part is an embedding function that maps tokens to embedding vectors ($f_e$ and $g_e$, where $e$ denotes embeddings), and the second part is the rest of the model ($f_l$ and $g_l$, where $l$ denotes layers). Under this definition, $f = f_l \circ f_e$ and $g = g_l \circ g_e$.
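The nearest-neighbor projection $p(\cdot)$ described above can be sketched as follows, assuming cosine similarity as the similarity function. The three-token vocabulary, its embedding table, and the function names are hypothetical and chosen only so the example is easy to follow; a real implementation would score all positions against the full embedding matrix with one matrix multiplication.

```python
import math
from typing import Dict, List, Sequence


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector is all zeros."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0
    return dot / (norm_u * norm_v)


def project(representations: Sequence[Sequence[float]],
            table: Dict[str, Sequence[float]]) -> List[str]:
    """Map each augmented vector to the vocabulary token with the most similar embedding."""
    return [max(table, key=lambda w: cosine(rep, table[w])) for rep in representations]


# Hypothetical embedding table with three tokens.
toy_table = {
    "watch": [1.0, 0.0, 0.0],
    "good": [0.0, 1.0, 0.0],
    "video": [0.0, 0.0, 1.0],
}
augmented = [[0.9, 0.2, 0.0], [0.1, 0.7, 0.3]]  # vectors that match no token exactly
tokens = project(augmented, toy_table)
```

Because the output is a list of vocabulary tokens, it can be fed to both the teacher and the student like any ordinary input sentence.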

AugPro-Mix

We get AugPro-Mix by applying the AugPro paradigm to MixUp. First, we apply MixUp to word embeddings and labels. This gives us embeddings from the teacher $e^f_{\text{MixUp}}$, embeddings from the student $e^g_{\text{MixUp}}$, and the label $y_{\text{MixUp}}$. When we use MixUp data to construct losses, $\mathcal{L}_{\mathrm{Aug}}$ becomes

$$\mathcal{L}_{\text{MixUp}} = \frac{1}{M}\sum_{j=1}^{M}\Big[\mathrm{CrossEntropy}\big(g_l(e^g_{\text{MixUp},j}, \theta_S), y_{\text{MixUp},j}\big) + d\big(g_l(e^g_{\text{MixUp},j}, \theta_S), f_l(e^f_{\text{MixUp},j}, \theta'_T)\big)\Big]$$

where $M$ denotes the number of augmented data. For AugPro, we use nearest neighbors to get the AugPro-Mix tokens $x_{\text{AugPro-Mix}}$:

$$x_{\text{AugPro-Mix}} = [w_{\text{AugPro-Mix},1}, \cdots, w_{\text{AugPro-Mix},n}], \quad \text{where } w_{\text{AugPro-Mix},i} = \arg\max_{w \in V} \mathrm{Sim}\big(e^f_{\text{MixUp}}(i), f_e(w)\big) \tag{5}$$

Sim is the similarity function, which can be the cosine similarity, and $e(i)$ denotes the $i$-th embedding vector in $e$. Concrete examples can be found in Appendix D. Then the loss function $\mathcal{L}_{\mathrm{Aug}}$ becomes:

$$\mathcal{L}_{\text{AugPro-Mix}} = \frac{1}{M}\sum_{j=1}^{M} d\big(g(x_{\text{AugPro-Mix},j}, \theta_S), f(x_{\text{AugPro-Mix},j}, \theta'_T)\big) \tag{6}$$

We do not use $y_{\text{MixUp}}$ because the projection operation (nearest neighbors in AugPro-Mix) does not necessarily preserve the label.

AugPro-FGSM

Though adversarial examples (AE) were originally intended to improve the robustness of models and may harm the performance on clean inputs (Raghunathan et al., 2019), Zhang et al. (2022b) show that AE can benefit knowledge distillation. We get AugPro-FGSM by applying AugPro to FGSM. We first apply FGSM to the student model and get augmented data $e^g_{\text{FGSM}}$. The augmented data can be used to construct $\mathcal{L}_{\mathrm{Aug}}$ directly:

$$\mathcal{L}_{\text{FGSM}} = \frac{1}{M}\sum_{j=1}^{M} d\big(g_l(e^g_{\text{FGSM},j}, \theta_S), f_l(e^f_j, \theta'_T)\big)$$

Following Equation 5, we can get $x_{\text{AugPro-FGSM}}$ by changing the subscripts accordingly. We usually set $\epsilon$ in Equation (2) to a large value in AugPro-FGSM since we do not want $x_{\text{AugPro-FGSM}}$ to be the same as the original input, whereas $\epsilon$ in Equation (2) is a small value in FGSM. We use the cosine similarity to implement the Sim function.
The loss function $\mathcal{L}_{\mathrm{Aug}}$ becomes:

$$\mathcal{L}_{\text{AugPro-FGSM}} = \frac{1}{M}\sum_{j=1}^{M} d\big(g(x_{\text{AugPro-FGSM},j}, \theta_S), f(x_{\text{AugPro-FGSM},j}, \theta'_T)\big) \tag{7}$$

Label Diversity. We take two sentences from the SST-2 dataset (Socher et al., 2013) as an example to further explain that projection does not necessarily preserve labels but generates diverse labels. The first sentence is "watch on video at home" with the sentiment Neutral. The second sentence is "as good" with the sentiment Positive. Then we can get the AugPro-Mix sentence "watch good video at home". Obviously, the label of the AugPro-Mix sentence should be Positive rather than the linear interpolation of Positive and Neutral. This is the desired property in distillation, as we use the teacher to label these newly generated data points.

Computational Overhead. If we assume the complexity of computing the cosine similarity between two vectors is $O(d)$, where $d$ is the dimension of the vectors, then the complexity of the projection (nearest neighbors in our implementation) is $O(NVd)$, where $N$ is the sentence length and $V$ is the vocabulary size. $N$ is usually within hundreds. $V$ is usually around 30,000 in popular pre-trained language models that use sub-word tokens, such as BERT (Devlin et al., 2018) and T5 (Raffel et al., 2020). As a result, $O(NVd)$ incurs little cost. In addition, the projection operation can be parallelized since the $NV$ similarity calculations do not affect each other. On modern parallel computing architectures, such as GPUs and TPUs, the projection can therefore be computed much faster. Compared with the complexity of major large-scale language models, this computation takes a small portion of resources. The detailed running time comparison can be found in Section 5.2.

Three Properties of AugPro. Since AugPro supports the operations used in representation interpolation methods, AugPro is expressive (property (1)). AugPro also converts representations to tokens to avoid shifting decision boundaries, leading to a smaller error rate (property (2)). It can be shown that AugPro-Mix has a $\frac{1}{4N}$ lower error rate than MixUp, and AugPro-FGSM has a $\frac{1}{2N}$ lower error rate than FGSM under certain assumptions (Appendix H). Moreover, AugPro has a low computational overhead to guarantee efficiency (property (3)), as described in the previous paragraph.
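As a rough sense of scale for the $O(NVd)$ projection cost, the short sketch below counts the multiply-adds needed to score every position against every vocabulary embedding. The sentence length and embedding dimension are hypothetical values picked for illustration; only the vocabulary size of roughly 30,000 comes from the text, and the function name is an assumption.

```python
def projection_multiply_adds(seq_len: int, vocab_size: int, dim: int) -> int:
    """Multiply-adds for one (seq_len x dim) by (dim x vocab_size) similarity matrix."""
    return seq_len * vocab_size * dim


# Hypothetical sizes: 128 tokens, 30,000 vocabulary entries, 512-dimensional embeddings.
cost = projection_multiply_adds(seq_len=128, vocab_size=30_000, dim=512)
print(f"{cost:,} multiply-adds per sentence")
```

Under these assumed sizes the count is about two billion multiply-adds per sentence. All of them are independent, so the whole projection reduces to a single matrix multiplication followed by an arg max over the vocabulary axis, which is the kind of workload that parallel accelerators handle well.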

5. EXPERIMENTS

Our experiments aim to answer two questions: (1) How effective is AugPro when applied to the knowledge distillation scenario? (2) Is AugPro efficient?

Datasets and Settings. Following previous knowledge distillation works (Liang et al., 2020; Zhang et al., 2022b), we use the GLUE (Wang et al., 2018) datasets as the benchmark. We use EncT5 (Liu et al., 2021) as our teacher and student models for the following reasons: (1) T5 performs much better than BERT and is close to SOTA in many tasks. EncT5 is a simplified version of T5 that uses the whole encoder of T5 but only one decoder layer. EncT5 performs similarly to T5 on classification tasks such as GLUE tasks with fewer parameters. For example, EncT5 (small) only contains 37M parameters but can perform similarly to T5 (small), which contains 77M parameters. Using EncT5 makes the results more convincing and shows that our method is still useful even with powerful models. (2) Previous methods (Liang et al., 2020; Zhang et al., 2022b) distill knowledge from a 12-layer BERT to a 6-layer BERT or a 3-layer BERT. However, the gap between the teacher and student models is marginal. Therefore, the room for improvement is limited, and variance weakens the credibility of the results. To this end, we distill knowledge from EncT5 (Large, 24-layer, 354M, teacher) to EncT5 (small, 8-layer, 37M, student), as the two models have a significant performance gap.

et al., 2018) first translates the input to another language and then translates it back. We choose back translation as a representative method for the data augmentation type "augmentation with models". (4) Knowledge Distillation + K-Nearest-Neighbors (KD+KNN): KNN (Wang & Yang, 2015) first selects tokens from inputs, then replaces them with the K nearest neighbors in the embedding space. KNN can be regarded as a token replacement method. (5) KD+MixUp. (6) KD+FGSM. (7) KD+TMix (Chen et al., 2020b): MixUp on the hidden states between transformer layers.
The last three methods are of the "representation interpolation" type. We train student models for 0.6M steps with a batch size of 512. Due to the high computation cost, we only augment data to twice the original dataset size for back translation. For all other methods, we augment data to twice the original batch size for each batch, i.e., we augment 0.6M steps × 512 batch size = 307.2M data in total. More training details are in Appendix C.

5.1. EFFECTIVENESS OF AUGPRO

Table 1 shows the results of knowledge distillation. Due to the high cost, we only report back translation results on the RTE dataset. We first use the training data to train a teacher model and then distill knowledge from the teacher model to the student model on the training data. We can conclude that: (1) All data augmentation methods benefit distillation. (2) AugPro significantly improves the distillation performance compared with the corresponding baselines. Specifically, AugPro is extremely useful for low-resource datasets such as CoLA and RTE. AugPro-Mix achieves scores 5.97% and 9.02% higher than MixUp on CoLA and RTE, respectively. AugPro-FGSM achieves scores 10.52% and 8.31% higher than FGSM on CoLA and RTE, respectively. For large datasets such as MNLI, AugPro-Mix and AugPro-FGSM also improve the performance. (3) Moreover, combining AugPro-FGSM and AugPro-Mix achieves the best performance among all listed methods. Compared with vanilla knowledge distillation, combining AugPro-Mix and AugPro-FGSM improves the performance by 2% to 14%.

1 . We only keep 10% of the training data labeled and assume the rest are unlabeled. Then we use the labeled training data to train a teacher model and the unlabeled training data to do knowledge distillation; this is a more realistic setting since it is often easier to get unlabeled data than labeled data. The conclusions above still hold. Specifically, AugPro improves the accuracy by 1% to 2% on average on three datasets. Compared with vanilla distillation, AugPro improves the accuracy by around 2% at most on three datasets.

5.2. EFFICIENCY OF AUGPRO

The efficiency of AugPro lies in two aspects. First, its complexity is low. Second, it can be computed in parallel. To fully demonstrate these two advantages, we report the actual running time of AugPro and the baselines in Table 3. KD with data augmentation takes roughly twice the time of vanilla KD since these methods use twice as much data as vanilla KD. We can also observe that augmentation with models (KD+BT) takes much more time than other kinds of baselines, which shows that this method is not efficient enough. Finally, AugPro brings little computational overhead, as its time cost is the same as that of the baselines. The results also show that KNN is much slower than other methods, which is explained in Appendix G.

5.3. ABLATION STUDY

In ablation studies, we follow the settings used in Table 1 unless otherwise stated.

Perturbation scale for ϵ in AugPro-FGSM. The key hyperparameter in AugPro-FGSM is ϵ in Equation (2). A small ϵ makes $x_{\text{AugPro-FGSM}}$ the same as the original input. A large ϵ tends to make $x_{\text{AugPro-FGSM}}$ hard to understand, meaningless, and out of domain. Therefore, a proper ϵ is essential. Our experiments find that ϵ = 35 is the best fit for T5 embeddings. Table 4 shows the KD+AugPro-FGSM performance with different ϵ.

Signs of gradients in AugPro-FGSM are not important. The effectiveness of AugPro-FGSM could come from the signs of the gradients or from the projection in AugPro. To show that AugPro-FGSM mainly benefits from AugPro, we implement two AugPro-FGSM variants: AugPro-FGSMD (Descent Projection), which uses the opposite signs to AugPro-FGSM, and AugPro-FGSMR (Random Projection), which uses random signs. Table 6 shows the results of AugPro-FGSM and its two variants. We can observe that AugPro-FGSM has a similar score to its variants in all settings. Thus, AugPro-FGSM mainly benefits from AugPro. We also conduct experiments that follow the setting of Table 2, and the results can be found in Appendix E.

AugPro generates diverse labels. We show that AugPro generates diverse labels at the end of Section 4. Here we empirically show that assuming AugPro preserves labels may harm performance. If AugPro preserved labels, AugPro-Mix and AugPro-FGSM data would have the same labels as MixUp data and the original data, respectively. We use these augmented data together with such labels to fine-tune student models directly. The results in Table 5 suggest that such augmented data and labels may harm performance. Therefore, AugPro generates diverse labels and does not necessarily preserve labels.

AugPro consistently benefits KD with different data sizes. Figure 3 shows the performance of AugPro with different data sizes. It can be observed that AugPro is better than all baselines at all data sizes. Moreover, AugPro is extremely useful when the data size is small. For example, AugPro improves the accuracy by 4% (SST-2) and 6% (MNLI-M) when the data size is 10%. We also report results on the MNLI-MM dataset in Appendix F.
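The three perturbation directions compared in the sign ablation can be sketched as below. This is an illustration under assumptions: the function name, the mode strings, the toy embedding, and the hand-written gradient are all hypothetical, and the subsequent projection back to tokens is omitted.

```python
import random
from typing import List, Optional, Sequence


def sign(v: float) -> int:
    """Return -1, 0, or 1 depending on the sign of v."""
    return (v > 0) - (v < 0)


def perturb(embedding: Sequence[float], grad: Sequence[float], eps: float,
            mode: str, rng: Optional[random.Random] = None) -> List[float]:
    """Shift each coordinate by eps in a direction chosen by `mode`.

    "ascent"  : sign of the gradient (AugPro-FGSM)
    "descent" : opposite sign (AugPro-FGSMD)
    "random"  : a random sign per coordinate (AugPro-FGSMR)
    """
    if mode == "ascent":
        directions = [sign(g) for g in grad]
    elif mode == "descent":
        directions = [-sign(g) for g in grad]
    elif mode == "random":
        rng = rng or random.Random()
        directions = [rng.choice((-1, 1)) for _ in grad]
    else:
        raise ValueError(f"unknown mode: {mode}")
    return [e + eps * d for e, d in zip(embedding, directions)]


emb = [0.5, -1.0, 2.0]       # toy embedding vector
grad = [0.2, -0.7, 1.3]      # hypothetical loss gradient
up = perturb(emb, grad, eps=35.0, mode="ascent")
down = perturb(emb, grad, eps=35.0, mode="descent")
rand = perturb(emb, grad, eps=35.0, mode="random", rng=random.Random(0))
```

In all three modes every coordinate moves by the same magnitude ϵ, which is consistent with the observation that the variants score similarly once the perturbed vectors are projected back to tokens.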

6. CONCLUSIONS AND FUTURE WORK

We propose AugPro, an effective and efficient data augmentation paradigm for knowledge distillation. We use projection to tackle the problem of shifting decision boundaries caused by traditional representation interpolation methods in knowledge distillation. Moreover, AugPro has a low computation cost and is fast on modern computing architectures. Results on GLUE tasks demonstrate the effectiveness and efficiency of AugPro. In the future, we will further explore the impact of AugPro on labels to make it helpful in other scenarios.

A THE CONCEPT FIGURE OF AUGPRO

Here we show a concept figure (Figure 4 ) to let readers better understand the difference between AugPro (e.g., AugPro-Mix) and previous works (e.g., MixUp (Zhang et al., 2017) ).

[Figure 4: concept figure of AugPro compared with previous works. Recoverable labels: Input1, Input2, Teacher Embedding]

To fine-tune the teacher model, we use a dropout rate of 0.1 and a learning rate of 1e-3 for all GLUE tasks.

B CONCRETE NUMBERS OF THE EXAMPLE IN SECTION 3

$\mathcal{X} = \{x_1 = (2.5, 2),\ x_2 = (2, -2),\ x_3 = (-2.5, -2),\ x_4 = (-2, 2)\}$
$\mathcal{Y} = \{y_1 = +1,\ y_2 = +1,\ y_3 = -1,\ y_4 = -1\}$
• β * = (

For knowledge distillation, we set the dropout rate to 0.1 for both the teacher and student models. We find that adding dropout to teachers makes the distillation better. We run all experiments with learning rates of 1e-3 and 1e-5 and report the best results. As a result, the learning rate is set to 1e-5 for all experiments on the STSB dataset and for the EncT5$_8$-FT and EncT5$_8$-KD experiments on the CoLA, MRPC and RTE datasets. All other experiments use a learning rate of 1e-3. The λ (Equation (1)) for AugPro-Mix is 0.5. Previous works (Zhang et al., 2017; Liang et al., 2020) use a beta distribution to sample λ for MixUp. We try λ ∼ Beta(0.4, 0.4) and λ = 0.5 for MixUp and find that they have similar performance in most tasks. In some tasks such as CoLA, λ = 0.5 is better. Therefore, we use λ = 0.5 for MixUp. Following previous works (Zhou et al., 2021), we set k = 15 for the KNN baseline and randomly select 10% of the tokens to replace. We use the outputs of the 4th layer of the student model and the 12th layer of the teacher model, i.e., the middle layer of both models, to conduct the TMix experiments.

Published as a conference paper at ICLR 2023

C.2 IMPLEMENTATION DETAIL OF EQUATION 5

Algorithm 1 shows the AugPro pipeline. Here, we show a detailed implementation of Equation 5 in Algorithm 2. We believe these two algorithms can help readers reproduce our methods and results.

Following are three MNLI inputs:
(i) hypothesis: The judgments need to consider the broader public interest. premise: These judgments need to be made in a consistent manner with consideration of the broader public interest in the program or activity under review.
(ii) hypothesis: I agree. premise: yeah i i agree with that
(iii) hypothesis: She's never been to a hospital. premise: and uh as if that wasn't bad enough the the ones that were half alive that they rushed to the hospital and she got to work on and she got to see them die and uh and they just get all the world's worst situations very few rewarding situations uh and

If we use AugPro-Mix on (i) and (ii), and on (ii) and (iii), we get:
(i + ii) hypothesis: I judgment. premise: yeah judgmenti needi agree made in
(ii + iii) hypothesis: She agree. never been to a hospital. premise: yeah uh as agreeif that wasn't bad enough the the ones that were half alive that they rushed to the hospital and she got to work on and she got to see them die and uh and they just get all the world's worst situations very few rewarding situations uh an

If we use AugPro-FGSM on (i), (ii) and (iii), we get:
(i) hypothesis: River judgment Mit handy to consider the broader public interest. premise: These judgment Mit handy to be made in a consistent manner with consideration-the broader public interest in the federal clȃdire activity under reviewed.
(ii) hypothesis: Kyle agree. premise: yeah chambres chambres agree with that
(iii) hypothesis: She' Mit never been to a hospital. premise: and uh as if that wasn' financing bad enough the the ones that are half alive that they rushed to the hospital and she got to work on and she got to see them die and uh and they just get all the world' Mit worst situations very few rewarding situations uh and

The augmented sentences above are originally token lists, not real sentences. Fortunately, T5 uses SentencePiece to construct its vocabulary, which supports precise de-tokenization for any token list. To convert token lists to sentences, we use the simple command "".join(tokens).replace("▁", " ").

Linguistic Analysis. Our motivation focuses on the perspective of machine learning, i.e., avoiding shifting decision boundaries by converting representations to tokens. Here, we add a brief analysis from a linguistic perspective. We can observe that the augmented sentences above may have grammatical errors and "meaningless" tokens, and may be less meaningful than the original sentences. However, "meaningless" to humans does not imply meaningless to models, as AugPro indeed boosts the distillation performance. Besides, the augmented sentences are not totally semantically meaningless to humans. It is hard to see why the augmented data are so helpful from the linguistic perspective, suggesting that we should focus on analyzing these data from the machine learning perspective, which is exactly our motivation. To further support our motivation, we evaluate a simple baseline, "random generation", which randomly chooses tokens from the vocabulary and concatenates them to form an augmented sentence. Random generation can generate meaningless sentences easily. Results are shown in Table 7. We can conclude that random generation is a poor augmentation method, suggesting that AugPro is helpful not because of meaningful or meaningless semantics, but because it avoids shifting decision boundaries.
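The de-tokenization step can be illustrated with a short sketch. SentencePiece marks word boundaries with the character "▁" (U+2581), so joining the pieces and replacing that marker with a space recovers the sentence. The pieces in the example are hypothetical and not actual T5 tokenizer output, the function name is an assumption, and the final `strip()` is an addition to the command quoted in the text that removes the leading space.

```python
from typing import Sequence

WORD_BOUNDARY = "\u2581"  # the SentencePiece whitespace marker


def detokenize(pieces: Sequence[str]) -> str:
    """Join sub-word pieces and turn word-boundary markers into spaces."""
    return "".join(pieces).replace(WORD_BOUNDARY, " ").strip()


# Hypothetical pieces for illustration only.
pieces = ["\u2581I", "\u2581agree", "."]
sentence = detokenize(pieces)
print(sentence)
```

Because the mapping only concatenates pieces and substitutes one character, it works for any token list, including the ungrammatical ones produced by projection.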

E SIGNS OF GRADIENTS IN AUGPRO-FGSM IS NOT IMPORTANT (MORE RESULTS)

Table 8 shows the effect of the signs of gradients in AugPro-FGSM with the setting of Table 2. The conclusion is the same as the one drawn from Table 6.

F AUGPRO CONSISTENTLY BENEFITS KD WITH DIFFERENT DATA SIZES (MORE RESULTS)

The main text shows that AugPro consistently benefits KD with different data sizes on the SST-2 and MNLI-M datasets. Figure 5 shows that this conclusion still holds on the MNLI-MM dataset. The overall trend is similar to that on the MNLI-M dataset.

G KNN IS SLOWER THAN OTHER METHODS

There are two reasons why KNN is slower than other methods. First, KNN has a higher computational complexity, $O(NVd\log k)$, than AugPro ($O(NVd)$). Second, KNN is hard to implement with XLA. A popular and fast implementation of KNN uses np.argpartition. However, XLA does not support partition operations (see https://github.com/google/jax/issues/10541), making KNN hard to implement with JAX on TPUs. To this end, Chern et al. (2022) propose another method to run KNN on TPUs at peak FLOP/s. We use their method to implement KNN, but this implementation may not be the optimal solution for KNN.

H AUGPRO CAN ACHIEVE LESS ERROR RATES

Suppose $\mathcal{X} = \{-1, 1\}^{2\log n}$ is the universe of all data, with labels $\mathcal{Y}$ to learn. Define a distribution $P$ such that $(x, y) \sim P$ if $x$ is drawn uniformly and independently from $\mathcal{X}$ and $y$ is drawn uniformly and independently from $\{-1, +1\}$. $X_D$ is the training data, consisting of $n$ samples drawn uniformly and independently from $\mathcal{X}$. Suppose $n$ is even. Let $x(j)$ be the $j$-th coordinate of vector $x$.

We construct $X_{\text{AugPro-Mix}}$ as follows: index the elements in $X_D$ arbitrarily, and let $x_i$ denote the $i$-th element. For each positive integer $i \in [n/2]$, we construct an augmented data point $z_i$ such that, for each coordinate $j \in [2\log n]$, $z_i(j) = x_{2i-1}(j)$ if $x_{2i-1}(j) = x_{2i}(j)$; otherwise $z_i(j)$ is chosen uniformly from $\{1, -1\}$.

As for constructing $X_{\text{AugPro-FGSM}}$, for each $x_i \in X_D$, we add Gaussian noise $N(0, 4)$ to each coordinate of $x_i$, and then project back to $\mathcal{X}$ to get $w_i$.

Similarly, assume $A_{\text{AugPro-FGSM}}$ can observe the dataset $X_{\text{AugPro-FGSM}}$ and its labels $Y_{\text{AugPro-FGSM}}$, and define the generalization error to be

$$\mathrm{error}(A_{\text{AugPro-FGSM}}) = \Pr_{(x,y)\sim P,\,(X_{\text{AugPro-FGSM}},\,Y_{\text{AugPro-FGSM}})}\big[A(x, X_{\text{AugPro-FGSM}}, Y_{\text{AugPro-FGSM}}) \ne y\big].$$

Claim 2. One has

$$\lim_{n\to\infty} \frac{\mathrm{error}(A) - \mathrm{error}(A_{\text{AugPro-FGSM}})}{1/(2n)} = 1.$$
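The toy construction above can be simulated in a few lines of Python to see how many distinct data points each augmentation contributes. This is only a sketch under stated assumptions: $\log$ is read as base 2 (so that the universe has $n^2$ elements), $N(0, 4)$ is read as variance 4 (standard deviation 2), and "project back" is taken to be the coordinate-wise sign. All function names are hypothetical.

```python
import math
import random

def sample_point(dim, rng):
    """Draw a point uniformly from {-1, 1}^dim."""
    return tuple(rng.choice((-1, 1)) for _ in range(dim))

def mix_point(a, b, rng):
    """AugPro-Mix toy rule: keep agreeing coordinates, randomize the rest."""
    return tuple(x if x == y else rng.choice((-1, 1)) for x, y in zip(a, b))

def fgsm_point(a, rng, std=2.0):
    """AugPro-FGSM toy rule: add N(0, 4) noise, then project back by sign."""
    return tuple(1 if x + rng.gauss(0.0, std) >= 0 else -1 for x in a)

def distinct_counts(n, seed=0):
    """Return (|X_D|, |X_AugPro-Mix|, |X_AugPro-FGSM|) after de-duplication.

    n is assumed to be an even power of two so that 2*log2(n) is an integer.
    """
    rng = random.Random(seed)
    dim = 2 * int(math.log2(n))  # universe size is 2^dim = n^2
    data = [sample_point(dim, rng) for _ in range(n)]
    mixed = [mix_point(data[2 * i], data[2 * i + 1], rng) for i in range(n // 2)]
    noisy = [fgsm_point(x, rng) for x in data]
    return len(set(data)), len(set(data + mixed)), len(set(data + noisy))

size_d, size_mix, size_fgsm = distinct_counts(n=1024)
```

In such a run, `size_mix - size_d` should be close to $n/2$ and `size_fgsm - size_d` close to $n$, which is the gap in covered data points that the two claims turn into a gap in error.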

H.1 PROOF OF CLAIM 1

We provide the following preliminary background on martingales to prove Claim 1, for completeness.

Definition 1 (Martingale). A sequence of random variables $Y_1, Y_2, \cdots$ is a martingale with respect to another sequence $X_1, X_2, \cdots$ if for all $n$, $\mathbb{E}[|Y_n|] < \infty$ and $\mathbb{E}[Y_{n+1} \mid X_1, \cdots, X_n] = Y_n$.

Definition 2 (Martingale Difference). $\{D_k\}_{k=1}^{\infty}$ is a martingale difference sequence w.r.t. $\{X_k\}_{k=1}^{\infty}$ if for all $n$:
• $D_n$ is a measurable function of $X_1, \cdots, X_n$
• $\mathbb{E}[|D_n|] < \infty$
• $\mathbb{E}[D_{n+1} \mid X_1, \cdots, X_n] = 0$.

If $Y_k$ is a martingale, then $D_k = Y_k - Y_{k-1}$ is a martingale difference sequence.

In Freedman's Inequality (Proposition 1), the predictable quadratic variation process is $W_t := \sum_{j=1}^{t} \mathbb{E}[X_j^2 \mid \mathcal{F}_{j-1}]$ for all $t$, where $\{\mathcal{F}_t\}$ is the filtration. Then for all $\ell \ge 0$ and $\sigma^2 > 0$,

$$\Pr\Big[\exists k \ge 0 : \Big|\sum_{t=1}^{k} X_t\Big| \ge \ell \ \wedge\ W_k \le \sigma^2\Big] \le 2\exp\Big(-\frac{\ell^2/2}{\sigma^2 + M\ell/3}\Big).$$

With these tools, we are ready to prove the claim.

Proof of Claim 1. In our setting, it is evident that error

Choosing $c_1$ large enough proves Equation (8).

Similarly, we show that for some constant $c_2 > 0$,

$$\Pr\big[|X_{\text{AugPro-Mix}}| \ge 3n/2 - 3c_2\sqrt{n\log n}\big] \ge 1 - 1/\mathrm{poly}(n). \quad (9)$$

It is equivalent to constructing $X_{\text{AugPro-Mix}}$ by iterations. In the $i$-th iteration, we draw $x_{2i-1}$ and $x_{2i}$ i.i.d. uniformly and construct $z_i$ as described before. Let $X^{i}_{\text{AugPro-Mix}}$ denote the data we get after constructing $x_{2i-1}$, $x_{2i}$ and $z_i$. Let $Y'_i = |X^{i}_{\text{AugPro-Mix}}|$ and let $D'_i$ be the indicator that $Y'_i - Y'_{i-1} = 3$ (i.e., all three data points are distinct and constructed for the first time). We know $Y'_{n/2} \ge 3\sum_{i=1}^{n/2} D'_i$. For any fixed vector $x \in \{-1, 1\}^{2\log n}$, we know $\Pr[x_{2i-1} = x] = \Pr[x_{2i} = x] = \Pr[z_i = x] = 1/n^2$. Hence, for any $X^{i-1}_{\text{AugPro-Mix}}$ and by the union bound, one has

$$\mathbb{E}\big[D'_i = 1 \mid X^{i-1}_{\text{AugPro-Mix}}\big] \ge 1 - \frac{3\,(|X^{i-1}_{\text{AugPro-Mix}}| + 3)}{|\mathcal{X}|} \ge 1 - \frac{9}{2n}.$$

Let $\mathcal{F}'_i$ be the filtration, and hence $\mathbb{E}[D'_i \mid \mathcal{F}'_{i-1}] \ge 1 - \frac{9}{2n}$. Let $\tilde{D}'_i := D'_i - \mathbb{E}[D'_i \mid \mathcal{F}'_{i-1}]$.

It is equivalent to constructing $X_{\text{AugPro-FGSM}}$ iteration by iteration, where in the $i$-th iteration we draw $x_i$ and construct $w_i$ as described before. Let $X^{i}_{\text{AugPro-FGSM}}$ be the data we get after constructing $x_i$ and $w_i$, let $Y_i = |X^{i}_{\text{AugPro-FGSM}}|$, and let $D_i$ be the indicator that $Y_i - Y_{i-1} = 2$. For any fixed vector $x \in \{-1, 1\}^{2\log n}$, we know $\Pr[x_i = x] = 1/n^2 = \Pr[w_i = x]$ and $\Pr[w_i = x_i] < 0.7^{2\log n} < 1/n^{1.02}$. Let $\mathcal{F}_i$ be the filtration. For any $X^{i-1}_{\text{AugPro-FGSM}}$, by the union bound, one has

$$\mathbb{E}\big[D_i = 1 \mid X^{i-1}_{\text{AugPro-FGSM}}\big] \ge 1 - \frac{2\,|X^{i-1}_{\text{AugPro-FGSM}}|}{|\mathcal{X}|} - \frac{1}{n^{1.02}} \ge 1 - \frac{3}{n}.$$



Converted time. We run BT on TPU v2 and compute the equivalent time cost.

https://github.com/google/jax/issues/10541



Figure 1: An illustration of each augmentation method's advantages.

1 , x 3 } as the training dataset. If we simply use MixUp, we get $x_{\text{MixUp}}$ as the augmented data, whose label is $y_{\text{MixUp}} = \mathrm{sign}((\beta^*)^\top x_{\text{MixUp}} + b^*) = y_3 = y_4$. Now running the linear SVM with $\{x_2, x_4, x_{\text{MixUp}}\}$ and labels $\{y_2, y_4, y_{\text{MixUp}}\}$, we get $\beta_{\text{MixUp}}$ and $b_{\text{MixUp}}$. Nevertheless, $\mathrm{sign}(\beta_{\text{MixUp}}^\top x_3 + b_{\text{MixUp}}) \ne y_3$. As a comparison, if we project $x_{\text{MixUp}}$ to its nearest neighbor and get $x_{\text{MixUp-P}} = x_2$, whose label is $y_{\text{MixUp-P}} = y_2$, running SVM with $\{x_1, x_2, x_{\text{MixUp-P}}\}$ and $\{y_1, y_2, y_{\text{MixUp-P}}\}$ yields $\beta_{\text{MixUp-P}}$ and $b_{\text{MixUp-P}}$, which classify all data correctly. Appendix B shows the concrete value of each parameter.

We train several baselines for comparison: (1) Fine-Tuning (FT): We directly fine-tune EncT5 on the dataset. (2) Knowledge Distillation (KD): We first fine-tune a teacher model (EncT5 Large), then distill knowledge from the teacher model to the student model (EncT5 Small). (3) Knowledge Distillation + Back Translation (KD+BT): Back translation (Yu

Figure 3: AugPro performance with different data sizes. Figure (a) and Figure (b) are for SST-2 dataset and MNLI-M dataset. Blue lines (or triangle markers) are AugPro methods. Yellow lines (or diamond markers) are baseline methods. The green line (or X marker) is KD. AugPro has the same line type as the corresponding baseline. For example, AugPro-FGSM and FGSM are all dashed lines.

Figure 4: Left: MixUp with knowledge distillation. Right: AugPro-Mix with knowledge distillation.

4/9, -1/18), $b^* = 0$.
• $x_{\text{MixUp}} = \frac{13}{25} x_3 + \frac{12}{25} x_1$, $\beta_{\text{MixUp}} = (250/533, 200/533)$, $b_{\text{MixUp}} = -12/13$
• $\beta_{\text{MixUp-P}} = (4/9, 0)$, $b_{\text{MixUp-P}} = 1/9$

C IMPLEMENTATION DETAILS

C.1 HYPERPARAMETERS

We use JAX and T5X to implement EncT5 and AugPro, and use T5 1.1 checkpoints to initialize models. The batch size is 512, and the maximum sentence length is 128. All experiments are trained for 0.6M steps and run on 8 TPU v3 slices.
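The projection step that AugPro applies to augmented representations (a nearest-neighbor lookup against the token embedding table) can be sketched as follows. This is a minimal NumPy illustration rather than the JAX/T5X implementation used in the experiments. It assumes that "nearest" means the highest dot-product similarity, and the function name `project_to_tokens` and the toy embedding table are hypothetical.

```python
import numpy as np

def project_to_tokens(b_rep, embeddings):
    """Map augmented representations to nearest-token ids.

    b_rep:      (B, L, H) augmented representations
    embeddings: (Z, H) token embedding table
    returns:    (B, L) integer token ids
    """
    # Dot-product similarity between every position and every vocabulary token.
    sims = np.einsum("blh,zh->blz", b_rep, embeddings)
    # Assumption: "nearest" means highest similarity, hence argmax over the vocabulary axis.
    return np.argmax(sims, axis=2)

# Toy example with a 4-token vocabulary and one-hot embeddings (illustrative only).
emb = np.eye(4)
mixed = 0.7 * emb[[0, 1]] + 0.3 * emb[[2, 3]]  # MixUp-style interpolation, shape (2, 4)
token_ids = project_to_tokens(mixed[None, :, :], emb)
```

Because the lookup is a single matrix product followed by an argmax, it avoids the partition or top-k operations that a general KNN search would need.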

Projection Algorithm
Input: Augmented representations $B_{rep} \in \mathbb{R}^{B \times L \times H}$, where $B$ is the batch size, $L$ is the sentence length, and $H$ is the embedding dimension. Vocabulary $V$ contains $Z$ tokens. Token embeddings $E \in \mathbb{R}^{Z \times H}$.
Output: Projected augmented data $B' \in V^{B \times L}$
• Compute the similarity between each token pair: $\mathrm{Sims} = B_{rep} \cdot E^{\top} \in \mathbb{R}^{B \times L \times Z}$
• Find the nearest neighbor, i.e., the most similar token: $B'_{index} = \arg\max_{\text{axis}=2} \mathrm{Sims} \in \mathbb{N}^{B \times L}$
• Find tokens according to the indices to get $B'$
• return $B'$

D EXAMPLES OF AUGPRO-MIX AND AUGPRO-FGSM

We use three MNLI inputs to show the data augmented by AugPro-Mix and AugPro-FGSM.

Figure 5: AugPro performance with different data sizes on MNLI-MM dataset. Blue lines (or triangle markers) are AugPro methods. Yellow lines (or diamond markers) are baseline methods. The green line (or X marker) is KD. AugPro has the same line type as the corresponding baseline. For example, AugPro-FGSM and FGSM are all dashed lines.

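Let us consider optimal algorithms $A$ and $A_{\text{AugPro-Mix}}$. $A$ can only observe $X_D$ (or $X_{\text{MixUp}}$, $X_{\text{FGSM}}$) and their labels $Y_D$ (or $Y_{\text{MixUp}}$, $Y_{\text{FGSM}}$). For the proof, we assume $A$ can only observe $X_D$ and their labels $Y_D$; the other two cases can be proved similarly. $A_{\text{AugPro-Mix}}$ can observe $X_{\text{AugPro-Mix}}$ and their corresponding labels $Y_{\text{AugPro-Mix}}$. Define the generalization error

$$\mathrm{error}(A) := \Pr_{(x,y)\sim P,\,(X_D, Y_D)}\big[A(x, X_D, Y_D) \ne y\big],$$

where $A(x, X_D, Y_D) \in \{+1, -1\}$ is the prediction for $x$ output by $A$, and

$$\mathrm{error}(A_{\text{AugPro-Mix}}) = \Pr_{(x,y)\sim P,\,(X_{\text{AugPro-Mix}},\,Y_{\text{AugPro-Mix}})}\big[A(x, X_{\text{AugPro-Mix}}, Y_{\text{AugPro-Mix}}) \ne y\big].$$

We have the following claim:

Claim 1. One has

$$\lim_{n\to\infty} \frac{\mathrm{error}(A) - \mathrm{error}(A_{\text{AugPro-Mix}})}{1/(4n)} = 1.$$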

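$\mathrm{error}(A) = \Pr_{x\sim P}[x \notin X_D]/2 = 1/2 - \mathbb{E}[|X_D|]/(2n^2)$, where the expectation is taken over the randomness of $X_D$, and $|\cdot|$ denotes the cardinality after removing duplicate elements. Similarly, we have $\mathrm{error}(A_{\text{AugPro-Mix}}) = 1/2 - \mathbb{E}[|X_{\text{AugPro-Mix}}|]/(2n^2)$. It suffices to prove

$$\lim_{n\to\infty} \frac{\mathbb{E}\big[|X_{\text{AugPro-Mix}}| - |X_D|\big]}{n/2} = 1.$$

First, we show that for some constant $c_1 > 0$,

$$\Pr\big[|X_D| \ge n - c_1\sqrt{n\log n}\big] \ge 1 - 1/\mathrm{poly}(n). \quad (8)$$

It is equivalent to drawing the elements in $X_D$ one by one. Let $X^{i}_D$ denote the elements after drawing $x_i$. Let $Y_i = |X^{i}_D|$, and let $D_i$ be the indicator that $x_i \ne x_j$ for all $0 < j < i$, where $D_1 = 1$. Then we know $|X_D| = Y_n$, $D_i = Y_i - Y_{i-1}$, and for any $X^{i-1}_D$, $\mathbb{E}[D_i = 1 \mid X^{i-1}_D] \ge 1 - |X^{i-1}_D|/|\mathcal{X}| \ge 1 - \frac{1}{n}$. Let $\mathcal{F}_i$ be the filtration, and hence $\mathbb{E}[D_i \mid \mathcal{F}_{i-1}] \ge 1 - \frac{1}{n}$. Let $\tilde{D}_i := D_i - \mathbb{E}[D_i \mid \mathcal{F}_{i-1}]$ be a martingale difference sequence. Hence $|\tilde{D}_i| \le 1$ almost surely, and $\mathbb{E}[\tilde{D}_i^2 \mid \mathcal{F}_{i-1}] \le 1$. Let $W_i = \sum_{j=1}^{i} \mathbb{E}[\tilde{D}_j^2 \mid \mathcal{F}_{j-1}]$. By Freedman's Inequality, one has

$$\Pr\big[Y_n < n - c_1\sqrt{n\log n}\big] \le \Pr\Big[\Big|\sum_{i=1}^{n} \tilde{D}_i\Big| \ge c_1\sqrt{n\log n} - O(1) \ \wedge\ W_n \le n\Big] \le \frac{1}{n^{c'_1}},$$

where $c'_1 = c_1 - O(1)$.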

Then $\tilde{D}'_i$ is a martingale difference sequence. Similarly, $|\tilde{D}'_i| \le 1$ almost surely and $\mathbb{E}[\tilde{D}'^{2}_i \mid \mathcal{F}'_{i-1}] \le 1$. Let $W'_i = \sum_{j=1}^{i} \mathbb{E}[\tilde{D}'^{2}_j \mid \mathcal{F}'_{j-1}]$. By Freedman's Inequality, one has

$$\Pr\big[Y'_{n/2} < 3n/2 - 3c_2\sqrt{n\log n}\big] \le \Pr\Big[\Big|\sum_{i=1}^{n/2} \tilde{D}'_i\Big| \ge c_2\sqrt{n\log n} - O(1) \ \wedge\ W'_{n/2} \le n/2\Big] \le \frac{1}{n^{c'_2}},$$

where $c'_2 = c_2 - O(1)$. Choosing $c_2$ large enough completes the proof of Equation (9). Combining Equations (8) and (9) proves the statement.

H.2 PROOF OF CLAIM 2

Proof. Similarly, we have $\mathrm{error}(A_{\text{AugPro-FGSM}}) = 1/2 - \mathbb{E}[|X_{\text{AugPro-FGSM}}|]/(2n^2)$, and it suffices to prove

$$\lim_{n\to\infty} \frac{\mathbb{E}\big[|X_{\text{AugPro-FGSM}}| - \mathbb{E}[|X_D|]\big]}{n} = 1.$$

We already have Equation (8). It suffices to prove that for some constant $c_3 > 0$,

$$\Pr\big[|X_{\text{AugPro-FGSM}}| \ge 2n - c_3\sqrt{n\log n}\big] \ge 1 - 1/\mathrm{poly}(n). \quad (10)$$

For each $x_i \in X_{\text{AugPro-FGSM}}$ and each coordinate $j \in [2\log n]$, we know $\Pr[x_i(j) = w_i(j)] = \Phi(1/2) < 0.7$, where $\Phi$ is the cumulative distribution function (CDF) of the one-dimensional standard Gaussian distribution, i.e., $\Phi(t) = \Pr_{x\sim N(0,1)}[x \le t]$.
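The two numerical constants used in this step can be checked directly. The short Python sketch below evaluates $\Phi(1/2)$ with the error function and the exponent behind the bound $0.7^{2\log n} < 1/n^{1.02}$ used in the proof. It assumes $\log$ is base 2, which is consistent with the universe having $n^2$ elements. The helper name `standard_normal_cdf` is hypothetical.

```python
import math

def standard_normal_cdf(t):
    """Phi(t) = Pr[x <= t] for x ~ N(0, 1), computed via the error function."""
    return 0.5 * (1.0 + math.erf(t / math.sqrt(2.0)))

# Per-coordinate probability that the sign survives N(0, 4) noise (std 2):
# for x = 1, the sign is kept when 2 * g > -1, i.e. g > -1/2 for g ~ N(0, 1).
p_keep = standard_normal_cdf(0.5)

# 0.7 ** (2 * log2(n)) equals n ** (2 * log2(0.7)), so this is the exponent of n.
exponent = 2.0 * math.log2(0.7)
```

Here `p_keep` is roughly 0.69, just below 0.7, and `exponent` is slightly below -1.02, which matches the two inequalities in the proof.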

uses a different setting from Table

Knowledge distillation on the GLUE dataset. We first use the training data to train a teacher model and then distill knowledge from the teacher model to the student model on the training data. EncT5 L denotes EncT5 with L transformer layers. L = 24 and L = 8 denote the teacher model with 354M parameters and the student model with 37M parameters, respectively.

Knowledge distillation on the GLUE dataset with a different setting from Table 1. We regard 10% of the data as labeled and the rest as unlabeled. The teacher model is first trained on labeled training data and then used for knowledge distillation on unlabeled training data.

Time (minutes) costs every 1000 steps on average of various methods on 8 TPU v3 slices. Costs contain data augmentation, the forward pass, and the backpropagation. The table is divided into four parts. Each part contains a specific data augmentation type (KD, augmentation with models, token replacement, representation interpolation and AugPro).

KD+AugPro-FGSM performance with different ϵ.

AugPro-Mix and AugPro-FGSM are used to fine-tune student models. MixUp labels are used for AugPro-Mix data. AugPro-FGSM uses the original label.

Results of the random generation baseline.

Performance of KD+AugPro-FGSM and its variants with different signs. AugPro-FGSMD denotes FGSM with Descent Projection and uses the opposite sign to AugPro-FGSM. AugPro-FGSMR denotes FGSM with Random Projection and uses a random sign. This table follows the setting in Table 2.

Proposition 1 (Freedman's Inequality, Freedman (1975)). Consider a real-valued martingale difference sequence $\{X_t\}$ which is uniformly bounded, i.e., $|X_t| \le M$ almost surely for all $t$. Define the predictable quadratic variation process of the martingale W

availability

/ google-research/google-research/tree/master


where $c'_3 = c_3 - O(1)$. Choosing $c_3$ large enough completes the proof of Equation (10). Combining Equations (8) and (10) completes the proof.
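To see why choosing the constants $c_1$, $c_2$, $c_3$ large enough suffices, the right-hand side of Freedman's Inequality can be evaluated numerically. The Python sketch below is illustrative only: it plugs in $\ell = c\sqrt{n\log n}$, $\sigma^2 = n$ and $M = 1$, ignores the $O(1)$ shift in $\ell$, and reads $\log$ as base 2. The function names are hypothetical.

```python
import math

def freedman_bound(ell, sigma2, m):
    """Right-hand side of Freedman's inequality: 2 exp(-(ell^2 / 2) / (sigma2 + m * ell / 3))."""
    return 2.0 * math.exp(-(ell ** 2 / 2.0) / (sigma2 + m * ell / 3.0))

def failure_probability(n, c):
    """Bound with ell = c * sqrt(n log n), sigma^2 = n, and increments bounded by M = 1."""
    ell = c * math.sqrt(n * math.log2(n))
    return freedman_bound(ell, sigma2=n, m=1.0)

# Failure-probability bounds for n = 1024 and a few choices of the constant c.
bounds = {c: failure_probability(1024, c) for c in (1, 2, 3)}
```

The bound shrinks rapidly as `c` grows, so a moderately large constant already pushes the failure probability below any fixed inverse polynomial in $n$.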

