REWEIGHTING AUGMENTED SAMPLES BY MINIMIZING THE MAXIMAL EXPECTED LOSS

Abstract

Data augmentation is an effective technique for improving the generalization of deep neural networks. However, previous data augmentation methods usually treat the augmented samples equally, without considering their individual impacts on the model. To address this, for the augmented samples generated from the same training example, we propose to assign them different weights. We construct the maximal expected loss, which is the supremum over any reweighted loss on the augmented samples. Inspired by adversarial training, we minimize this maximal expected loss (MMEL) and obtain a simple and interpretable closed-form solution: more attention should be paid to augmented samples with large loss values (i.e., harder examples). Minimizing this maximal expected loss enables the model to perform well under any reweighting strategy. The proposed method can generally be applied on top of any data augmentation method. Experiments are conducted on both natural language understanding tasks with token-level data augmentation, and image classification tasks with commonly-used image augmentation techniques such as random crop and horizontal flip. Empirical results show that the proposed method improves the generalization performance of the model.

1. INTRODUCTION

Deep neural networks have achieved state-of-the-art results in various natural language processing (NLP) tasks (Sutskever et al., 2014; Vaswani et al., 2017; Devlin et al., 2019) and computer vision (CV) tasks (He et al., 2016; Goodfellow et al., 2016). One approach to improving the generalization performance of deep neural networks is data augmentation (Xie et al., 2019; Jiao et al., 2019; Cheng et al., 2019; 2020). However, directly incorporating these augmented samples into the training set raises a problem: minimizing the average loss on all these samples means treating them equally, without considering their different implicit impacts on the loss. To address this, we propose to minimize a reweighted loss on these augmented samples so that the model utilizes them more effectively. Example reweighting has previously been explored extensively in curriculum learning (Bengio et al., 2009; Jiang et al., 2014), boosting algorithms (Freund & Schapire, 1999), focal loss (Lin et al., 2017) and importance sampling (Csiba & Richtárik, 2018). However, none of these focus on reweighting augmented samples rather than the original training samples. A recent work (Jiang et al., 2020a) also assigns different weights to augmented samples, but their weights are predicted by a mentor network, whereas we obtain the weights from the closed-form solution of minimizing the maximal expected loss (MMEL). In addition, they focus on image samples with noisy labels, while our method applies to both textual and image data. Tran et al. (2017) propose to minimize the loss on augmented samples under the framework of the Expectation-Maximization algorithm, but they mainly focus on the generation of augmented samples. Unfortunately, in practice there is no way to directly access the optimal reweighting strategy.
Thus, inspired by adversarial training (Madry et al., 2018), we propose to minimize the maximal expected loss (MMEL) on augmented samples generated from the same training example. Since the maximal expected loss is the supremum over any possible reweighting strategy on the augmented samples' losses, minimizing this supremum makes the model perform well under any reweighting strategy. More importantly, we derive a closed-form solution of the weights, in which augmented samples with larger training losses have larger weights. Intuitively, MMEL allows the model to keep focusing on augmented samples that are harder to train. The procedure of our method is summarized as follows. We first generate the augmented samples with commonly-used data augmentation techniques, e.g., lexical substitution for textual input (Jiao et al., 2019), or random crop and horizontal flip for image data (Krizhevsky et al., 2012). Then we explicitly derive the closed-form solution of the weights on each of the augmented samples. After that, we update the model parameters with respect to the reweighted loss. The proposed method can generally be applied on top of any data augmentation method in domains like natural language processing and computer vision. Empirical results on both natural language understanding tasks and image classification tasks show that the proposed reweighting strategy consistently outperforms the counterpart without it, as well as other reweighting strategies like uniform reweighting.

2. RELATED WORK

Data augmentation. Data augmentation has proven to be an effective technique for improving generalization on various tasks, e.g., natural language processing (Xie et al., 2019; Zhu et al., 2020; Jiao et al., 2019), computer vision (Krizhevsky et al., 2014), and speech recognition (Park et al., 2019). For image data, baseline augmentation methods like random crop, flip, scaling, and color augmentation (Krizhevsky et al., 2012) have been widely used. Other heuristic data augmentation techniques, like Cutout (DeVries & Taylor, 2017), which masks image patches, and Mixup (Zhang et al., 2018), which combines pairs of examples and their labels, were proposed later. Automatically searching for augmentation policies (Cubuk et al., 2018; Lim et al., 2019) has recently been proposed to improve performance further. For textual data, Zhang et al. (2015); Wei & Zou (2019) and Wang (2015) use lexical substitution based on the embedding space. Jiao et al. (2019); Cheng et al. (2019); Kumar et al. (2020) generate augmented samples with a pre-trained language model. Some other techniques, like back translation (Xie et al., 2019), random noise injection (Xie et al., 2017) and data mixup (Guo et al., 2019; Cheng et al., 2020), have also proven useful.

Adversarial training. Adversarial learning is used to enhance the robustness of a model (Madry et al., 2018) by dynamically constructing augmented adversarial samples with projected gradient descent during training. Although adversarial training hurts generalization on the task of image classification (Raghunathan et al., 2019), it has been shown that adversarial training can be used as data augmentation to help generalization in neural machine translation (Cheng et al., 2019; 2020) and natural language understanding (Zhu et al., 2020; Jiang et al., 2020b).
Our proposed method differs from adversarial training in that we adversarially decide the weight on each augmented sample, while traditional adversarial training adversarially generates the augmented input samples themselves. In (Behpour et al., 2019), adversarial learning is used as data augmentation in object detection. The adversarial samples (i.e., bounding boxes that are maximally different from the ground truth) are reweighted to form the underlying annotation distribution. However, besides the difference in model and task, their training objective and the resulting solution are also different from ours.

Sample reweighting. Minimizing a reweighted loss on training samples has been widely explored in the literature. Curriculum learning (Bengio et al., 2009; Jiang et al., 2014) first feeds easier and then harder data into the model to accelerate training. Zhao & Zhang (2014); Needell et al. (2014); Csiba & Richtárik (2018); Katharopoulos & Fleuret (2018) use importance sampling to reduce the variance of stochastic gradients and achieve a faster convergence rate. Boosting algorithms (Freund & Schapire, 1999) choose harder examples to train subsequent classifiers. Similarly, hard example mining (Malisiewicz et al., 2011) downsamples the majority class and exploits the most difficult examples. Focal loss (Lin et al., 2017; Goyal & He, 2018) focuses on harder examples by reshaping the standard cross-entropy loss in object detection. Ren et al. (2018); Jiang et al. (2018); Shu et al. (2019) use meta-learning methods to reweight examples to handle the noisy-label problem. Unlike all these existing methods, in this work we reweight the losses of augmented samples instead of those of the original training samples.

3. MINIMIZE THE MAXIMAL EXPECTED LOSS

In this section, we derive our reweighting strategy on augmented samples from the perspective of maximal expected loss. We first give a derivation of the closed-form solution of the weights on augmented samples. Then we describe two kinds of loss under this formulation. Finally, we give the implementation details using the natural language understanding task as an example.

3.1. WHY MAXIMAL EXPECTED LOSS

Consider a classification task with $N$ training samples. For the $i$-th training sample $x_i$, its label is denoted $y_{x_i}$. Let $f_\theta(\cdot)$ be the model with parameters $\theta$, which outputs classification probabilities, and let $\ell(\cdot, \cdot)$ denote the loss function, e.g., the cross-entropy loss between the output $f_\theta(x_i)$ and the ground-truth label $y_{x_i}$. Given an original training sample $x_i$, the set of augmented samples generated from it by some method is $B(x_i)$. Without loss of generality, we assume $x_i \in B(x_i)$. The conventional training objective minimizes the loss on every augmented sample $z \in B(x_i)$:
$$\min_\theta \frac{1}{N}\sum_{i=1}^{N}\left[\frac{1}{|B(x_i)|}\sum_{(z, y_z)\in B(x_i)} \ell(f_\theta(z), y_z)\right], \qquad (1)$$
where $y_z$ is the label of $z \in B(x_i)$ and can be different from $y_{x_i}$, and $|B(x_i)|$ is the number of augmented samples in $B(x_i)$, which is assumed to be finite. In equation (1), for each given $x_i$, the weights on its augmented samples are identical (i.e., $1/|B(x_i)|$). However, different samples have different implicit impacts on the loss, and we can assign them different weights to facilitate training. Note that computing the weighted sum of the losses of the augmented samples in $B(x_i)$ can be viewed as taking the expectation of the loss on $z \in B(x_i)$ under a certain distribution. When the augmented samples generated from the same training sample are drawn from the uniform distribution $P_U(\cdot \mid x_i)$ on $B(x_i)$, the loss in equation (1) can be rewritten as
$$\min_\theta R_\theta(P_U) = \min_\theta \frac{1}{N}\sum_{i=1}^{N}\Big\{\mathbb{E}_{z\sim P_U(\cdot\mid x_i)}\big[\ell(f_\theta(z), y_z)\big] - \lambda_P\,\mathrm{KL}\big(P_U(\cdot\mid x_i)\,\|\,P_U(\cdot\mid x_i)\big)\Big\}, \qquad (2)$$
since the Kullback-Leibler (KL) divergence $\mathrm{KL}(P_U(\cdot\mid x_i)\,\|\,P_U(\cdot\mid x_i))$ equals zero. When the augmented samples are drawn from a more general conditional distribution $P_B(\cdot\mid\cdot)$ (simplified as $P_B$ when there is no ambiguity) instead of the uniform distribution, we can generalize $P_U$ here to $P_B$, which yields the objective in equation (3). The KL divergence term in equation (3) serves as a regularizer that encourages $P_B$ to stay close to $P_U$ (see Remark 2). From equation (3), the conditional distribution $P_B$ determines the weight of each augmented sample in $B(x_i)$. There may exist an optimal formulation of $P_B$ in some regime, e.g., one corresponding to the optimal generalization ability of the model. Unfortunately, we cannot explicitly characterize such an unknown optimal $P_B$.
To address this, we borrow the idea from adversarial training (Madry et al., 2018) and minimize the maximal reweighted loss on the augmented samples. The model is then guaranteed to perform well under any reweighting strategy, including the underlying optimal one. Specifically, with a general conditional distribution $P_B$, the reweighted objective is
$$\min_\theta R_\theta(P_B) = \min_\theta \frac{1}{N}\sum_{i=1}^{N}\Big\{\mathbb{E}_{z\sim P_B(\cdot\mid x_i)}\big[\ell(f_\theta(z), y_z)\big] - \lambda_P\,\mathrm{KL}\big(P_B(\cdot\mid x_i)\,\|\,P_U(\cdot\mid x_i)\big)\Big\}. \qquad (3)$$
Let $P_\theta^* = \arg\sup_{P_B} R_\theta(P_B)$. Our objective is to minimize the maximal expected loss
$$\min_\theta R_\theta(P_\theta^*) = \min_\theta \sup_{P_B} R_\theta(P_B). \qquad (4)$$
Remark 1. When $P_B(\cdot\mid x_i)$ reduces to the uniform distribution $P_U(\cdot\mid x_i)$ for any $x_i$, since $\mathrm{KL}(P_U(\cdot\mid x_i)\,\|\,P_U(\cdot\mid x_i)) = 0$, the objective in equation (3) reduces to the one in equation (1). The following Remark 2 discusses the KL divergence term in equation (3).
Remark 2. Since we take a supremum over $P_B$ in equation (4), the regularizer $\mathrm{KL}(P_B\,\|\,P_U)$ encourages $P_B$ to be close to $P_U$, because it attains its minimal value zero when $P_B = P_U$. Thus the regularizer controls the diversity among the augmented samples by constraining the discrepancy between $P_B$ and the uniform distribution $P_U$; e.g., a larger $\lambda_P$ promotes a larger diversity among the augmented samples.
The following Theorem 1 gives the explicit formulation of $R_\theta(P_\theta^*)$.
Theorem 1. Let $R_\theta(P_B)$ and $R_\theta(P_\theta^*)$ be defined as in equations (3) and (4). Then
$$R_\theta(P_\theta^*) = \frac{1}{N}\sum_{i=1}^{N}\Big[\sum_{z\in B(x_i)} P_\theta^*(z\mid x_i)\,\ell(f_\theta(z), y_z) - \lambda_P\,P_\theta^*(z\mid x_i)\log\big(|B(x_i)|\,P_\theta^*(z\mid x_i)\big)\Big], \qquad (5)$$
where
$$P_\theta^*(z\mid x_i) = \frac{\exp\big(\frac{1}{\lambda_P}\ell(f_\theta(z), y_z)\big)}{\sum_{z'\in B(x_i)}\exp\big(\frac{1}{\lambda_P}\ell(f_\theta(z'), y_{z'})\big)} = \mathrm{Softmax}_z\Big(\frac{1}{\lambda_P}\ell\big(f_\theta(B(x_i)), y_{B(x_i)}\big)\Big), \qquad (6)$$
and $\mathrm{Softmax}_z(\cdot)$ denotes the probability assigned to $z$ by the softmax over the vector $\big(\frac{1}{\lambda_P}\ell(f_\theta(z_1), y_{z_1}), \cdots, \frac{1}{\lambda_P}\ell(f_\theta(z_{|B(x_i)|}), y_{z_{|B(x_i)|}})\big)$.
Remark 3.
If we ignore the KL divergence term in equation (3), due to the equivalence between minimizing the cross-entropy loss and maximum likelihood estimation (Martens, 2019), the proposed MMEL also falls into the generalized Expectation-Maximization (GEM) framework (Dempster et al., 1977). Specifically, given a training example, its augmented samples can be viewed as latent variables, and any reweighting of these augmented samples corresponds to a specific conditional distribution of the augmented samples given the training sample. In the expectation step (E-step), we explicitly derive the closed-form solution of the weights on each of these augmented samples according to equation (6). In the maximization step (M-step), since there is no analytical solution for deep neural networks, following (Tran et al., 2017), we update the model parameters with respect to the reweighted loss by one step of gradient descent.

The proof of Theorem 1 can be found in Appendix A. From Theorem 1, the loss value of each augmented sample $z \in B(x_i)$ decides its weight, and the weight is normalized by a softmax over all augmented samples in $B(x_i)$. This reweighting strategy pays more attention to augmented samples with higher loss values. The strategy is similar to those in (Lin et al., 2017; Zhao & Zhang, 2014), but they apply it to the original training samples.
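The closed-form weights in equation (6) are simply a softmax over the per-sample losses scaled by $1/\lambda_P$. A minimal numpy sketch (the function name is ours, not the authors' code):

```python
import numpy as np

def mmel_weights(losses, lam_p=1.0):
    """Closed-form solution of Theorem 1 (equation (6)): the worst-case
    reweighting is a softmax over per-sample losses, so harder augmented
    samples (larger loss) receive larger weights."""
    losses = np.asarray(losses, dtype=float)
    s = losses / lam_p
    s -= s.max()            # subtract the max for numerical stability
    w = np.exp(s)
    return w / w.sum()

# Harder augmented samples get larger weights; equal losses give
# uniform weights, and a large lambda_P flattens the weights toward uniform.
w = mmel_weights([0.1, 0.5, 2.0], lam_p=1.0)
```

Consistent with Remark 2, increasing `lam_p` pushes the weights toward the uniform distribution, i.e., a stronger pull of the KL regularizer toward $P_U$.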

3.2. TWO TYPES OF LOSS

For an augmented sample $z \in B(x_i)$, instead of computing the discrepancy between the output probability $f_\theta(z)$ and the hard label $y_z$ as in equation (5), one can also compute the discrepancy between $f_\theta(z)$ and the "soft" probability $f_\theta(x_i)$, in the absence of ground-truth labels on the augmented samples, as in (Xie et al., 2019). In the following, we use the superscript "hard" for the loss in equation (5):
$$R_\theta^{\text{hard}}(P_\theta^*, x_i) = \sum_{z\in B(x_i)} P_\theta^*(z\mid x_i)\,\ell(f_\theta(z), y_z) - \lambda_P\,P_\theta^*(z\mid x_i)\log\big(|B(x_i)|\,P_\theta^*(z\mid x_i)\big), \qquad (7)$$
to distinguish it from the following objective, which uses the soft probability:
$$R_\theta^{\text{soft}}(P_\theta^*, x_i) = \ell(f_\theta(x_i), y_{x_i}) + \lambda_T \sum_{z\in B(x_i);\, z\neq x_i} P_\theta^*(z\mid x_i)\,\ell(f_\theta(z), f_\theta(x_i)) - \lambda_P\,P_\theta^*(z\mid x_i)\log\big((|B(x_i)|-1)\,P_\theta^*(z\mid x_i)\big). \qquad (8)$$
The two terms in $R_\theta^{\text{soft}}(P_\theta^*, x_i)$ respectively correspond to the loss on the original training sample $x_i$ and the reweighted loss on the augmented samples. The reweighted loss promotes a small discrepancy between the augmented samples and the original training sample. $\lambda_T > 0$ is a coefficient balancing the two loss terms, and $P_\theta^*(z \mid x_i)$ is defined similarly to equation (6) as
$$P_\theta^*(z\mid x_i) = \frac{\exp\big(\frac{1}{\lambda_P}\ell(f_\theta(z), f_\theta(x_i))\big)}{\sum_{z'\in B(x_i);\, z'\neq x_i}\exp\big(\frac{1}{\lambda_P}\ell(f_\theta(z'), f_\theta(x_i))\big)}. \qquad (9)$$

Figure 1: MMEL with two types of losses. Figure (1a), MMEL-H, shows the hard loss (7) with probabilities computed using (6), while Figure (1b), MMEL-S, shows the soft loss (8) with probabilities computed using (9).

Algorithm 1 Minimize the Maximal Expected Loss (MMEL)
Input: Training set $\{(x_1, y_{x_1}), \cdots, (x_N, y_{x_N})\}$, batch size $S$, learning rate $\eta$, number of training iterations $T$, $R_\theta$ equals $R_\theta^{\text{hard}}$ or $R_\theta^{\text{soft}}$.
1: for $i$ in $\{1, 2, \cdots, N\}$ do  // generate augmented samples
2:   Generate $B(x_i)$ using some data augmentation method.
3: end for
4: for $t = 1, \cdots, T$ do  // minimize the maximal expected loss
5:   Randomly sample a mini-batch $\mathcal{S} = \{(x_{i_1}, y_{x_{i_1}}), \cdots, (x_{i_S}, y_{x_{i_S}})\}$ from the training set.
6:   Fetch the augmented samples $B(x_{i_1}), B(x_{i_2}), \cdots, B(x_{i_S})$.
7:   Compute $P_\theta^*$ according to equation (6) or (9).
8:   Update the model parameters $\theta_{t+1} = \theta_t - \frac{\eta}{S}\sum_{x\in\mathcal{S}} \nabla_\theta R_\theta(P_\theta^*, x)$.
9: end for

The two losses are shown in Figure 1. Summing over all training samples, we get the two kinds of reweighted training objectives.

Remark 4. The proposed MMEL-S tries to reduce the discrepancy between $f_\theta(z)$ and $f_\theta(x_i)$ for $z \in B(x_i)$. However, if the prediction $f_\theta(x_i)$ is inaccurate, such misleading supervision for $z$ may lead to degraded performance of MMEL-S. More details are in Appendix B.
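Given the per-sample losses on $B(x_i)$, the per-example MMEL-H objective in equation (7) can be evaluated directly from the closed-form weights. A minimal numpy sketch (names are ours, not the authors' code):

```python
import numpy as np

def mmel_hard(losses, lam_p=1.0):
    """R_hard for one training example (equation (7)): the softmax-reweighted
    loss minus lambda_P * KL(P* || uniform over B(x_i))."""
    losses = np.asarray(losses, dtype=float)
    s = losses / lam_p
    s -= s.max()                      # numerical stability
    p = np.exp(s)
    p /= p.sum()                      # closed-form weights from equation (6)
    b = len(losses)
    kl = np.sum(p * np.log(b * p))    # KL(P* || uniform)
    return np.sum(p * losses) - lam_p * kl
```

One can check by substituting the closed-form weights that this value equals $\lambda_P \log\big(\frac{1}{|B(x_i)|}\sum_z \exp(\ell_z/\lambda_P)\big)$, a log-sum-exp smoothing of the average loss in equation (1), and hence is never smaller than that average.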

3.3. EXAMPLE: MMEL IMPLEMENTATION ON NATURAL LANGUAGE UNDERSTANDING TASKS

In this section, we elaborate on the implementation of the proposed method, using textual data in natural language understanding tasks as an example. Our method has two phases. In the first phase, we generate the augmented samples. In the second phase, with these augmented samples, we update the model parameters with respect to the hard reweighted loss (7) or its soft counterpart (8). The generation and training procedures can be decoupled: the augmented samples are generated offline, only once, in the first phase. In the second phase, since we have the explicit solution of the weights on the augmented samples, and the multiple forward and backward passes over these augmented samples can be computed in parallel, the whole training time is similar to that of regular training for an appropriate number of augmented samples. The whole training process is shown in Algorithm 1.

Generation of Textual Augmented Data. Various methods have been proposed to generate augmented samples for textual data. Recently, large-scale pre-trained language models like BERT (Devlin et al., 2019) and GPT-2 (Radford et al., 2019), which learn contextualized representations, have been widely used to generate high-quality augmented sentences (Jiao et al., 2019; Kumar et al., 2020). In this paper, we use a BERT pre-trained with masked language modeling to generate augmented samples. For each original input sentence, we randomly mask $k$ tokens. We then do a forward pass of the BERT to predict the tokens at the masked positions by greedy search. Details can be found in Algorithm 2 in Appendix C.

Mismatching Label. For $R_\theta^{\text{hard}}$ in equation (7), the loss term $\ell(f_\theta(z), y_z)$ on an augmented sample $z \in B(x_i)$ relies on its label $y_z$.
Unlike image data, where conventional augmentation methods like random crop and horizontal flip do not change an image's label, substituting even one word in a sentence can drastically change its meaning. For instance, suppose the original sentence is "She is my daughter", and the word "She" is masked. The top 5 words predicted by the pre-trained BERT are "This, She, That, It, He". Apparently, for the linguistic acceptability task, replacing "She" with "He" can change the label from linguistically "acceptable" to "non-acceptable". Thus, for textual input, for the term $\ell(f_\theta(z), y_z)$ in the hard loss (7), instead of directly setting $y_z$ to $y_{x_i}$ (Zhu et al., 2020), we replace $y_z$ with the output probability of a trained teacher model. On the other hand, for the soft loss in equation (8), if an augmented sample $z \in B(x_i)$ is predicted to be in a different class than $x_i$ by the teacher model, it is unreasonable to still minimize the discrepancy between $f_\theta(z)$ and $f_\theta(x_i)$. In this case, we replace $f_\theta(x_i)$ in the loss term $\lambda_T \sum_{z\in B(x_i);\, z\neq x_i} P_\theta^*(z\mid x_i)\,\ell(f_\theta(z), f_\theta(x_i))$ with the output probability from the teacher model.
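The generation step above can be sketched as follows. Here `predict_ranked` is a hypothetical stand-in for a masked language model such as BERT that returns candidate tokens for a position ranked by likelihood; following Algorithm 2, the $i$-th augmented sample takes the $i$-th most likely word at the first masked position and the top prediction elsewhere:

```python
import random

def augment_greedy(tokens, predict_ranked, n_aug, k, seed=0):
    """Sketch of greedy masked-LM augmentation. `predict_ranked(tokens, pos)`
    returns candidate words for position `pos`, best first; it stands in for
    a pre-trained BERT and is an assumption of this sketch."""
    rng = random.Random(seed)
    positions = rng.sample(range(len(tokens)), k)  # positions to mask
    augmented = []
    for i in range(n_aug):
        z = list(tokens)
        for j, p in enumerate(positions):
            ranked = predict_ranked(z, p)
            # diversify samples via the i-th candidate at the first position
            idx = min(i, len(ranked) - 1) if j == 0 else 0
            z[p] = ranked[idx]
        augmented.append(z)
    return augmented
```

In the paper's setting, the resulting samples would then be relabeled by the fine-tuned teacher model to handle the mismatching-label problem.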

4. EXPERIMENTS

In this section, we evaluate the efficacy of the proposed MMEL algorithm with both hard loss (MMEL-H) and soft loss (MMEL-S). Experiments are conducted on both the image classification tasks CIFAR-10 and CIFAR-100 (Krizhevsky et al., 2014) with the ResNet Model (He et al., 2016) , and the General Language Understanding Evaluation (GLUE) tasks (Wang et al., 2019) with the BERT model (Devlin et al., 2019) .

4.1. EXPERIMENTS ON IMAGE CLASSIFICATION TASKS.

Data. CIFAR (Krizhevsky et al., 2014) is a benchmark dataset for image classification. We use both CIFAR-10 and CIFAR-100 in our experiments, which consist of color images with 50000 training samples and 10000 validation samples, from 10 and 100 object classes, respectively.

Setup. The model we use is ResNet (He et al., 2016) with different depths. We use random crop and horizontal flip (Krizhevsky et al., 2012) to augment the original training images. Since these operations do not change the label, we directly adopt the original training sample's label for all its augmented samples. Following (He et al., 2016), we use the SGD optimizer with momentum to train each model for 200 epochs. The learning rate starts at 0.1 and decays by a factor of 0.2 at epochs 60, 120 and 160. The batch size is 128, and the weight decay is 5e-4. For each $x_i$, $|B(x_i)| = 10$. The KL regularization coefficient $\lambda_P$ is 1.0 for both MMEL-H and MMEL-S. The $\lambda_T$ in equation (8) for MMEL-S is selected from {0.5, 1.0, 2.0}. We compare the proposed MMEL with conventional training with data augmentation (abbreviated "Baseline(DA)") under the same number of epochs. Though MMEL can be computed efficiently in parallel, it processes $|B(x_i)| = 10$ times more training data. For a fair comparison, we also compare with two other baselines that use 10 times more data: (i) naive training with data augmentation but with 10 times more training epochs than MMEL (abbreviated "Baseline(DA+Long)"); in this case, the learning rate accordingly decays at epochs 600, 1200 and 1600; and (ii) training with data augmentation under the MMEL framework but with uniform weights on the augmented samples (abbreviated "Baseline(DA+UNI)").

Main Results. The results are shown in Table 1. As can be seen, for both CIFAR-10 and CIFAR-100, MMEL-H and MMEL-S significantly outperform Baseline(DA), with over 0.5 points higher accuracy on all four architectures.
Compared to Baseline(DA+Long), the proposed MMEL-H and MMEL-S also have comparable or better performance while being much more efficient in training. This is because our backward pass only computes the gradient of the weighted loss instead of the separate loss of each example. Compared to Baseline(DA+UNI), which has the same computational cost as MMEL-H and MMEL-S, the proposed methods also perform better. This indicates the efficacy of the proposed maximal-expected-loss-based reweighting strategy. We further evaluate the proposed method on the large-scale dataset ImageNet (Deng et al., 2009). The detailed results are in Appendix B.
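The step decay used in the CIFAR setup above (learning rate 0.1, multiplied by 0.2 at epochs 60, 120 and 160) can be written as a small helper; this is a sketch of the stated schedule, not the authors' code:

```python
def step_lr(epoch, base_lr=0.1, milestones=(60, 120, 160), gamma=0.2):
    """Learning rate after `epoch` epochs: start at `base_lr` and multiply
    by `gamma` at each milestone epoch that has been reached."""
    lr = base_lr
    for m in milestones:
        if epoch >= m:
            lr *= gamma
    return lr
```

For the Baseline(DA+Long) runs, the same helper would be used with milestones (600, 1200, 1600) over 2000 epochs.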

4.2. RESULTS ON NATURAL LANGUAGE UNDERSTANDING TASKS

Data. GLUE is a benchmark containing various natural language understanding tasks, including textual entailment (RTE and MNLI), question answering (QNLI), similarity and paraphrase (MRPC, QQP, STS-B), sentiment analysis (SST-2) and linguistic acceptability (CoLA). Among them, STS-B is a regression task, CoLA and SST-2 are single-sentence classification tasks, while the rest are sentence-pair classification tasks. Following (Devlin et al., 2019), on the development set we report Spearman correlation for STS-B, Matthews correlation for CoLA and accuracy for the other tasks. On the test set, we report F1 for QQP and MRPC.

Setup. The backbone model is BERT_BASE (Devlin et al., 2019). We use the method in Section 3.3 to generate augmented samples. For the mismatching-label problem described in Section 3.3, we use a BERT_BASE model fine-tuned on the downstream task as the teacher model to predict the label of each generated sample $z \in B(x_i)$. For each $x_i$, $|B(x_i)| = 5$. The fraction of masked tokens in each sentence is 0.4. The KL regularization coefficient $\lambda_P$ is 1.0 for both MMEL-H and MMEL-S. The $\lambda_T$ in equation (8) for MMEL-S is 1.0. The other training hyperparameters can be found in Appendix D. The derivation of MMEL in Section 3 is based on classification, while STS-B is a regression task. Hence, we generalize our loss function for regression tasks as follows. For the hard loss in equation (7), we directly replace $y_z \in \mathbb{R}$ with the teacher model's prediction on $z$. For the soft loss (8), for each entry of $f_\theta(x_i)$ in the loss term $\lambda_T \sum_{z\in B(x_i);\, z\neq x_i} P_\theta^*(z\mid x_i)\,\mathrm{MSE}(f_\theta(z), f_\theta(x_i))$, we replace it with the teacher model's prediction if the difference between them is larger than 0.5. Similar to Section 4.1, we compare with three baselines. However, we change the first baseline to naive training without data augmentation (abbreviated "Baseline"), since data augmentation is not used by default in NLP tasks.
The other two baselines are similar to those in Section 4.1: (i) "Baseline(DA+Long)", which fine-tunes BERT with data augmentation with the same batch size; and (ii) "Baseline(DA+UNI)", which fine-tunes BERT on the augmented samples using the average loss. We also compare with another recent data augmentation technique, SMART (Jiang et al., 2020b).

Main Results. The development and test set results on the GLUE benchmark are shown in Table 3. The development set results for the BERT baseline are from our re-implementation, which is comparable to or better than the results reported in the original paper (Devlin et al., 2019). The results for SMART are taken from (Jiang et al., 2020b), and there are no test set results in (Jiang et al., 2020b). As can be seen, data augmentation significantly improves generalization on the GLUE tasks. Compared to the baseline without data augmentation (Baseline), MMEL-H or MMEL-S consistently achieves better performance, especially on small datasets like CoLA and RTE. Similar to the observations on the image classification tasks in Section 4.1, the proposed MMEL-H and MMEL-S are more efficient and perform better than Baseline(DA+Long). MMEL-H and MMEL-S also outperform Baseline(DA+UNI), indicating the superiority of the proposed reweighting strategy. In addition, our proposed method also beats SMART in both accuracy and efficiency, because SMART uses PGD-k (Madry et al., 2018) to construct adversarial augmented samples, which requires nearly k times more training cost. Figure 2 shows the development set accuracy over the course of training. As can be seen, training with MMEL-H or MMEL-S converges faster and reaches better accuracy, except on SST-2 and RTE, where the performance is similar.

Effect of Predicted Labels. For the augmented samples from the same original sample, we use a fine-tuned task-specific BERT_BASE teacher model to predict their labels, as mentioned in Section 3.3, to handle the mismatching-label problem.
In Table 4, we show the comparison between using the label of the original sample and using predicted labels. As can be seen, using predicted labels significantly improves performance. Comparing with the results in Table 3, using the label of the original sample can even hurt performance.

5. CONCLUSION

In this work, we propose to minimize a reweighted loss over the augmented samples, which directly considers their implicit impacts on the loss. Since we cannot access the optimal reweighting strategy, we propose to minimize the supremum of the loss under all reweighting strategies, and we give a closed-form solution of the optimal weights. Our method can be applied on top of any data augmentation method. Experiments on both image classification tasks and natural language understanding tasks show that the proposed method improves the generalization performance of the model while being efficient in training.

C GENERATING AUGMENTED SAMPLES FOR TEXTUAL SAMPLES

In this section, we elaborate on the procedure of generating augmented sentences for a sequence using greedy-based and beam-based methods. For each original input sentence, we randomly mask k tokens (obtained by rounding the product of the masking ratio and the sequence length to the nearest integer), and then do a forward pass of the BERT to predict the tokens at the masked positions using greedy search. The detailed procedure is shown in Algorithm 2. We also use beam search (Yang et al., 2018) to generate augmented data; its details can be found in (Yang et al., 2018). For sentence-pair tasks, we treat the two sentences separately and generate augmented samples for each of them.


In the following, we vary the factors that may affect the quality of the generated augmented samples:
1. The number of masked tokens, which equals the replacement proportion multiplied by the sentence length. This affects the diversity of the augmented samples, i.e., replacing a large proportion of tokens makes the augmented sample less similar to the original one.
2. Treating the two sentences separately in sentence-pair tasks when generating augmented samples, versus concatenating them into a single sentence.
3. The generation method: greedy search (Algorithm 2) or beam search.
The results are shown in Table 6. As can be seen, compared with the Baseline without data augmentation, MMEL-H and MMEL-S under all hyperparameter configurations achieve higher accuracy, showing the efficacy of data augmentation and the proposed reweighting strategy. There is no significant difference between greedy search and beam search for generating the augmented samples. In this natural language understanding task, training with augmented samples generated with a suitably larger replacement proportion (i.e., larger diversity) performs slightly better. For sentence-pair tasks, treating the two sentences separately and generating augmented samples for each of them performs slightly better. In the experiments in Section 4.2, we use greedy search with masking proportion 0.4, and generate augmented samples for each sentence separately in sentence-pair tasks.

D HYPERPARAMETERS FOR THE EXPERIMENTS ON THE GLUE BENCHMARK

The optimizer we use is AdamW (Loshchilov & Hutter, 2018). The hyperparameters of the BERT_BASE model are listed in Table 7.







Figure 2: Development set results on the BERT_BASE model with different loss functions.

Algorithm 2 Augmented Sample Generation by Greedy Search
Input: Pre-trained language model BertModel, original sentence $x$, number of augmented samples $|B(x)| - 1$, number of masked tokens $k$.
Output: Augmented samples $B(x) = \{z_1, z_2, \cdots, z_{|B(x)|-1}\}$.
1: Randomly sample $k$ positions $\{p_1, \cdots, p_k\}$ and get $x_{\text{mask}}$.
2: for $i = 1, 2, \cdots, |B(x)| - 1$ do  // generate the i-th augmented sample
3:   $z_i \leftarrow x_{\text{mask}}$.
4:   $z_i[p_1] \leftarrow$ the $i$-th most likely word predicted by BertModel$(z_i[p_1] \mid z_i)$.
5:   for $j = 2, \cdots, k$ do
6:     $z_i[p_j] \leftarrow$ the most likely word predicted by BertModel$(z_i[p_j] \mid z_i)$.
7:   end for
8: end for


Table 1: Performance of ResNet on CIFAR-10 and CIFAR-100. The time is the training time measured on a single NVIDIA V100 GPU. The results of five independent runs with "mean (±std)" are reported, except for "Baseline(DA+Long)", which is slow to train.

Varying the Number of Augmented Samples. One hyperparameter of the proposed method is the number of augmented samples $|B(x_i)|$. In Table 2, we evaluate the effect of $|B(x_i)|$ on the CIFAR dataset. We vary $|B(x_i)|$ in {2, 5, 10, 20} for both MMEL-H and MMEL-S with the other settings unchanged. As can be seen, the performance of MMEL improves with more augmented samples when $|B(x_i)|$ is small. However, the performance gain begins to saturate when $|B(x_i)|$ reaches 5 or 10 in some cases. Since a larger $|B(x_i)|$ also brings more training cost, we should choose a proper number of augmented samples rather than continually increasing it.

Table 2: Performance of MMEL on CIFAR-10 and CIFAR-100 with ResNet with varying $|B(x_i)|$. Here "MMEL-*-k" means training with the MMEL-* loss with $|B(x_i)| = k$. The results are averaged over five independent runs with "mean (±std)" reported.

Table 3: Development and test set results on the BERT_BASE model. The training time is measured on a single NVIDIA V100 GPU. The results of Baseline, Baseline(DA+UNI), MMEL-H and MMEL-S are obtained from five independent runs with "mean (±std)" reported.

Table 4: Effect of using the predicted label. Development set results are reported.

Table 5: Performance of ResNet on ImageNet.


* This work was done while Mingyang Yi was an intern at Huawei Noah's Ark Lab.

A PROOF OF THEOREM 1

Proof. For any given $x_i$ and $B(x_i)$, we aim to find the distribution $P_\theta(\cdot \mid x_i)$ on $B(x_i)$ that maximizes
$$\sum_{z\in B(x_i)} P_\theta(z\mid x_i)\,\ell(f_\theta(z), y_z) - \lambda_P \sum_{z\in B(x_i)} P_\theta(z\mid x_i)\log\big(|B(x_i)|\,P_\theta(z\mid x_i)\big)$$
subject to $\sum_{z\in B(x_i)} P_\theta(z\mid x_i) = 1$. Since the objective is concave in $P_\theta$, by the Lagrange multiplier method, let
$$L(P_\theta, \mu) = \sum_{z} P_\theta(z\mid x_i)\,\ell(f_\theta(z), y_z) - \lambda_P \sum_{z} P_\theta(z\mid x_i)\log\big(|B(x_i)|\,P_\theta(z\mid x_i)\big) + \mu\Big(\sum_{z} P_\theta(z\mid x_i) - 1\Big).$$
From $\nabla_{P_\theta} L(P_\theta, \mu) = \nabla_\mu L(P_\theta, \mu) = 0$, for any pair $z_u, z_v \in B(x_i)$, we have
$$\ell(f_\theta(z_u), y_{z_u}) - \lambda_P\big(\log(|B(x_i)|\,P_\theta(z_u\mid x_i)) + 1\big) = \ell(f_\theta(z_v), y_{z_v}) - \lambda_P\big(\log(|B(x_i)|\,P_\theta(z_v\mid x_i)) + 1\big).$$
Hence
$$\frac{P_\theta(z_u\mid x_i)}{P_\theta(z_v\mid x_i)} = \exp\Big(\frac{1}{\lambda_P}\big(\ell(f_\theta(z_u), y_{z_u}) - \ell(f_\theta(z_v), y_{z_v})\big)\Big).$$
Summing over $z_v \in B(x_i)$ and using $\sum_{z_v} P_\theta(z_v\mid x_i) = 1$, we obtain
$$P_\theta^*(z_u\mid x_i) = \frac{\exp\big(\frac{1}{\lambda_P}\ell(f_\theta(z_u), y_{z_u})\big)}{\sum_{z\in B(x_i)}\exp\big(\frac{1}{\lambda_P}\ell(f_\theta(z), y_z)\big)}.$$
Substituting $P_\theta^*$ into $R_\theta$ completes the proof.
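The maximizer derived above can also be checked numerically: for random losses, no distribution on $B(x_i)$ attains a larger objective value than the softmax solution of Theorem 1. A small verification sketch (ours, not part of the paper):

```python
import numpy as np

rng = np.random.default_rng(0)
losses = rng.uniform(0.0, 3.0, size=6)   # per-sample losses on B(x_i)
lam_p = 0.7
b = len(losses)

def objective(p):
    # reweighted loss minus lambda_P * KL(p || uniform on B(x_i))
    return np.sum(p * losses) - lam_p * np.sum(p * np.log(b * p))

# closed-form maximizer from Theorem 1
s = losses / lam_p
p_star = np.exp(s - s.max())
p_star /= p_star.sum()

# random candidate distributions never beat the closed-form solution
for _ in range(1000):
    q = rng.dirichlet(np.ones(b))
    q = np.clip(q, 1e-12, None)
    q /= q.sum()
    assert objective(q) <= objective(p_star) + 1e-9
```

In particular, the uniform distribution (whose KL term vanishes) gives the plain average loss, which the maximal expected loss upper-bounds.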

B MMEL ON LARGE-SCALE DATASET

In this section, we evaluate the proposed MMEL method on the large-scale image classification task ImageNet (Deng et al., 2009).

Data. ImageNet is a benchmark dataset containing color images, with over 1 million training samples and 50000 validation samples from 1000 categories.

Setup. The model we use is ResNet for ImageNet with three different depths (He et al., 2016). All these experiments are conducted for 100 epochs, and the learning rate decays at epochs 30, 60, and 90. We set the batch size to 256 and $|B(x_i)| = 10$ for each $x_i$. The other experimental settings follow Section 4.1. We compare the proposed method with "Baseline(DA)".

Main Results. The results are shown in Table 5. The proposed MMEL-H improves the performance of the model at all three depths. However, the proposed MMEL-S is beaten by the baseline method. We speculate that this is due to the relatively larger proportion of inaccurate predictions on the original training samples on this large-scale dataset. More specifically, as in equation (8), for each augmented sample $z \in B(x_i)$, the proposed MMEL-S encourages the model to fit the output $f_\theta(x_i)$ of the original training sample. However, the accuracy on the original training samples in the ImageNet dataset cannot reach 100% (e.g., it is about 80% for ResNet50). The inaccurate prediction $f_\theta(x_i)$ can be misleading supervision for an augmented sample $z \in B(x_i)$, leading to degraded performance of MMEL-S. Thus, we suggest using MMEL-H when the accuracy on the original training samples is relatively low.

Published as a conference paper at ICLR 2021

