MASK-TUNING: TOWARDS IMPROVING PRE-TRAINED LANGUAGE MODELS' GENERALIZATION

Abstract

Pre-trained language models suffer from a well-known generalization problem. The issue arises because their learning process relies heavily on spurious correlations, which hold for the majority of training examples but do not hold in general. As a consequence, model performance drops substantially on out-of-distribution datasets. Previous studies have proposed various solutions, including data augmentation and improvements to the learning process. In this paper, we present Mask-tuning, an approach that alleviates the impact of spurious correlations on the fine-tuning learning process. To achieve this goal, Mask-tuning integrates masked language model training into fine-tuning. Mask-tuning perturbs the linguistic relations within a downstream task's training examples and computes the masked language model training loss. The perturbed examples are then fed into the fine-tuning process, where they are classified against their ground-truth labels to compute the fine-tuning loss. Finally, the Mask-tuning loss, a weighted aggregation of the masked language model training loss and the fine-tuning loss, updates both the masked language model and the fine-tuning model through training iterations. Extensive experiments show that Mask-tuning consistently improves pre-trained language models' generalization on out-of-distribution datasets and enhances their performance on in-distribution datasets. The source code and pre-trained models will be available on the authors' GitHub page.

1. INTRODUCTION

One of the challenges in building a pre-trained language model that generalizes robustly is that training sets do not represent the linguistic diversity of real-world language; as a result, performance drops dramatically on out-of-distribution datasets. This divergence between in-distribution and out-of-distribution performance is known as the generalization gap. Previous studies (Zhang et al., 2019; McCoy et al., 2019; Tu et al., 2020) have shown that pre-trained language models trained on a specific dataset are likely to learn spurious correlations: prediction rules that work for the majority of examples but do not hold in general. In other words, the fine-tuning loss function cannot incentivize the language model to learn linguistic patterns from the minority examples and generalize them to more challenging examples (e.g., an out-of-distribution dataset). Prior work has attributed the generalization gap to annotation artifacts (Gururangan et al., 2018), dataset bias (He et al., 2019; Clark et al., 2019; Mahabadi et al., 2020), spurious correlations (Tu et al., 2020; Kaushik et al., 2020), and group shift (Oren et al., 2019). However, most of these methods rely on the strong assumption that the dataset's spurious correlations or biased keywords are known in advance, and they also tend to degrade model performance on the in-distribution dataset. One of the primary solutions for mitigating the generalization gap is to directly increase the number of minority examples in the training set, creating a more balanced training dataset; this has been pursued under the names data augmentation (Ng et al., 2020; Garg & Ramakrishnan, 2020), domain adaptation (Gururangan et al., 2020; Lee et al., 2020), multi-task learning (Pruksachatkun et al., 2020; Phang et al., 2020), and adversarial data generation (Garg & Ramakrishnan, 2020; Li et al., 2021).
Although these solutions have achieved varying levels of success, they face the following challenges: a) they require large amounts of related and useful training examples; b) they assume that an appropriate intermediate task and dataset are readily available; c) they need to re-run the pre-training phase from scratch; and d) some of them add a complex procedure to the learning process. All of these requirements demand significant computational resources or human annotation (Zhou & Bansal, 2020; Chen et al., 2021). In this paper, we propose Mask-tuning, inspired by recent studies in computer vision (Hendrycks et al., 2020b; Mohseni et al., 2020) showing that joint training improves the robustness of deep learning models. However, those studies used a mix of in-distribution and out-of-distribution training examples (labeled and unlabeled) and changed the model's architecture by adding multiple auxiliary heads. In contrast, our approach (Mask-tuning) employs the pre-trained language model's original masked language modeling and fine-tuning objectives and uses only the downstream task's training dataset as both the labeled and unlabeled data. As a result, Mask-tuning is applicable to any pre-trained language model that works with the original fine-tuning. Mask-tuning reinforces the fine-tuning learning process by integrating masked language model training into it. The proposed approach hinders the learning of spurious correlations by using the masked language model loss to adjust the fine-tuning training loss when the predicted output would otherwise follow the learned spurious correlations. We use masked language modeling to perturb the linguistic relations within input training examples by masking a certain percentage of input tokens, predicting those masked tokens, and computing the masked language model loss.
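As an illustration of the perturbation step, the following minimal sketch masks a random fraction of input tokens and records the original tokens as masked language model prediction targets. The function name, the `[MASK]` symbol, and the 15% default rate are illustrative assumptions; in practice the masking follows the pre-trained model's tokenizer and masking scheme.

```python
import random

MASK_TOKEN = "[MASK]"  # placeholder symbol; the real token depends on the tokenizer

def perturb_tokens(tokens, mask_prob=0.15, rng=None):
    """Randomly replace a fraction of tokens with a mask symbol.

    Returns the perturbed sequence and a dict mapping each masked
    position to its original token, which serves as the MLM target.
    """
    rng = rng or random.Random()
    perturbed, targets = [], {}
    for i, tok in enumerate(tokens):
        if rng.random() < mask_prob:
            perturbed.append(MASK_TOKEN)
            targets[i] = tok  # original token is the label for position i
        else:
            perturbed.append(tok)
    return perturbed, targets

tokens = "the movie was surprisingly good".split()
perturbed, targets = perturb_tokens(tokens, mask_prob=0.3, rng=random.Random(0))
```

The perturbed sequence is what Mask-tuning passes on to the classification head, while `targets` drives the masked language model loss.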
The perturbed example is then fed to fine-tuning for classification according to its ground-truth label. Afterward, the Mask-tuning loss, a weighted aggregation of the masked language modeling loss and the fine-tuning loss, updates the masked language model and fine-tuning through training iterations. Mask-tuning, and especially the fine-tuning step after masked language modeling, learns that training examples with different linguistic relations can have the same semantics and belong to the same label. Our main contributions are as follows:
• We study the effect of integrating masked language modeling into the fine-tuning training process on pre-trained language models' generalization and propose Mask-tuning. Our method is a plug-and-play tool applicable to any pre-trained language model without any changes to the model's architecture, and it uses only the downstream task's training dataset. To our knowledge, we are the first to integrate the masked language model into the fine-tuning learning process with training-loss aggregation.
• We conduct comprehensive experiments under a consistent evaluation process to verify the effectiveness of Mask-tuning. Using BERT, RoBERTa, and the autoregressive BART language models, we show that Mask-tuning outperforms eight state-of-the-art approaches from the literature on three downstream tasks, across three in-distribution and five out-of-distribution datasets.
• Our ablation studies show the necessity of running masked language model training simultaneously with fine-tuning. Extensive experiments show that when Mask-tuning is trained only with the fine-tuning loss, or when the two training steps are applied separately (only fine-tuning on the perturbed training examples), it cannot mitigate the generalization gap as well as Mask-tuning with the joint loss.

2. APPROACH

In this section, we formally introduce the proposed approach, Mask-tuning. We first define the details of the Mask-tuning setup, which is also illustrated in Fig. 1. Then we explain the perturbation strategy and the insights behind Mask-tuning.

