MASK-TUNING: TOWARDS IMPROVING PRE-TRAINED LANGUAGE MODELS' GENERALIZATION

Abstract

Pre-trained language models suffer from a well-known generalization problem. The issue stems from a learning process that relies heavily on spurious correlations, which hold for the majority of training examples but do not hold in general. As a consequence, the models' performance drops substantially on out-of-distribution datasets. Previous studies have proposed various solutions, including data augmentation and improvements to the learning process. In this paper, we present Mask-tuning, an approach that alleviates the impact of spurious correlations on the fine-tuning learning process. To achieve this goal, Mask-tuning integrates masked language training into fine-tuning. Specifically, Mask-tuning perturbs the linguistic relations within the downstream task's training examples and computes a masked language training loss. The perturbed examples are then fed into the fine-tuning process to be classified against their ground-truth labels, yielding the fine-tuning training loss. The Mask-tuning loss, a weighted aggregation of the masked language model training loss and the fine-tuning loss, then updates both the masked language model and the fine-tuning model across training iterations. Extensive experiments show that Mask-tuning consistently improves pre-trained language models' generalization on out-of-distribution datasets and enhances their performance on in-distribution datasets. The source code and pre-trained models will be available on the authors' GitHub page.
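A minimal sketch of the training-step structure described above, assuming a token-level masking rate and a scalar mixing weight `alpha` (both names and values are illustrative, not taken from the paper's released code): an example is perturbed by random masking, a masked-language-model loss is computed on the masked positions, a fine-tuning (classification) loss is computed on the same perturbed example, and the two are combined by a weighted aggregation.

```python
# Illustrative sketch of one Mask-tuning step; function names and the
# masking scheme are assumptions, not the authors' implementation.
import random

MASK = "[MASK]"

def mask_tokens(tokens, mask_prob=0.15, seed=None):
    """Randomly replace tokens with [MASK].

    Returns the perturbed token sequence and, per position, the original
    token the MLM objective must recover (None where nothing was masked).
    """
    rng = random.Random(seed)
    perturbed, mlm_targets = [], []
    for tok in tokens:
        if rng.random() < mask_prob:
            perturbed.append(MASK)
            mlm_targets.append(tok)   # masked position: MLM target is the original token
        else:
            perturbed.append(tok)
            mlm_targets.append(None)  # unmasked position: no MLM target
    return perturbed, mlm_targets

def mask_tuning_loss(mlm_loss, finetune_loss, alpha=0.5):
    """Weighted aggregation of the two training losses.

    alpha is a tunable weight; both the MLM head and the fine-tuning
    classifier would be updated by backpropagating this combined loss.
    """
    return alpha * mlm_loss + (1.0 - alpha) * finetune_loss
```

In a full implementation, `mlm_loss` and `finetune_loss` would be the cross-entropy losses of the MLM head (on masked positions) and of the task classifier (on the ground-truth label), both computed from the same perturbed example before a single backward pass on the combined loss.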

1. INTRODUCTION

One of the challenges in building a pre-trained language model with robust generalization is that training sets do not represent the linguistic diversity of real-world language; as a result, performance drops dramatically on out-of-distribution datasets. This divergence between in-distribution and out-of-distribution performance is known as the generalization gap. Previous studies (Zhang et al., 2019; McCoy et al., 2019; Tu et al., 2020) have shown that pre-trained language models trained on a specific dataset are likely to learn spurious correlations: prediction rules that work for the majority of examples but do not hold in general. In other words, the fine-tuning loss function does not incentivize the language model to learn linguistic patterns from minority examples and generalize them to more challenging examples (e.g., out-of-distribution datasets). Prior work has framed the generalization gap in several ways: annotation artifacts (Gururangan et al., 2018), dataset bias (He et al., 2019; Clark et al., 2019; Mahabadi et al., 2020), spurious correlation (Tu et al., 2020; Kaushik et al., 2020), and group shift (Oren et al., 2019). However, most of these methods rely on the strong assumption that the datasets' spurious correlations or biased keywords are known in advance, and they also tend to degrade the model's performance on the in-distribution dataset. One primary approach to mitigating the generalization gap is to directly increase the number of minority examples in the training set, creating a more balanced training dataset, which has been pursued in several lines of work: data augmentation (Ng et al., 2020; Garg & Ramakrishnan, 2020), domain adaptation

