SELF-DISTILLATION FOR FURTHER PRE-TRAINING OF TRANSFORMERS

Abstract

Pre-training large transformer models on massive amounts of unlabeled data and fine-tuning them on labeled datasets for diverse downstream tasks has demonstrated remarkable success in various vision and natural language processing tasks. However, direct fine-tuning may result in suboptimal performance if there is a significant discrepancy between the pre-training and fine-tuning domains. To address this issue, some previous studies have proposed further pre-training strategies that continue to pre-train the model on the target unlabeled dataset before fine-tuning. However, these strategies are limited to language models and may result in overfitting when applied to Vision Transformers. To overcome this limitation, we propose self-distillation as a regularization method for the further pre-training stage. Our method first further pre-trains the initial pre-trained model on the target unlabeled data and uses it as a teacher for self-distillation. It then takes the same initial pre-trained model as a student and enforces its hidden representations to be close to those of the teacher while optimizing the student with a masked auto-encoding objective. Our experiments demonstrate the superiority of self-distillation over relevant baselines on various benchmark datasets for image and text classification tasks. Furthermore, we provide a theoretical analysis of our proposed method using a simplified model to shed light on how self-distillation for further pre-training can potentially enhance the performance of downstream tasks.

1. INTRODUCTION

Pre-trained transformer models (Devlin et al., 2019; Brown et al., 2020; Liu et al., 2019; He et al., 2022) have been effective on various vision and natural language processing tasks. The pre-trained models learn general representations from a large volume of unlabeled data, so they generalize well to various downstream tasks when fine-tuned on each task with a labeled dataset. However, in many real-world applications, adapting the pre-trained model to a specific downstream domain requires considerable effort because there is a significant distributional discrepancy between the data for the pre-training and fine-tuning stages. Moreover, it is difficult to collect a large amount of labeled data for such specific domains, which makes adapting the pre-trained model to downstream tasks even more challenging. Several works have proposed to tackle the problem of adapting pre-trained models to a specific domain. A prevalent approach is further pre-training, where we continue to update the parameters of the pre-trained model with self-supervision on additionally curated, domain-specific unlabeled data (Beltagy et al., 2019; Lee et al., 2020) before fine-tuning it on the target labeled data, as depicted in Figure 2b. Gururangan et al. (2020) also show that further pre-training only on the target unlabeled data is still effective without any extra data. However, most of the existing further pre-training approaches have focused on language models, and we find that the further pre-training strategy is not effective for Vision Transformer (ViT) (Dosovitskiy et al., 2021). As shown in Figure 1, ViT is vulnerable to overfitting and does not generalize well to downstream tasks when we continue to pre-train it on the target unlabeled data. Several regularization methods (Chen et al., 2020a; Gouk et al., 2021; Aghajanyan et al., 2021) have been proposed to tackle the overfitting issue of large pre-trained models; however, they do not consider an adaptation process such as further pre-training. Instead, they keep the distance between the final fine-tuned weights and the pre-trained weights small to promote the transfer of knowledge acquired during pre-training to downstream tasks for better generalization.
However, these regularizations hinder the adaptation of pre-trained models to downstream tasks, especially when there is a significant distributional shift between the pre-training data and the target data. This eventually results in worse generalization than the simple fine-tuning strategy. To tackle these limitations, we propose self-distillation as a regularization for further pre-training on a target unlabeled dataset, so that we can effectively adapt pre-trained models to downstream tasks in various domains with a limited amount of labeled data. For self-supervision, we focus on masked auto-encoding, since it does not depend on any data augmentations, unlike other self-supervised learning methods (Chen et al., 2020b; He et al., 2020; Grill et al., 2020; Zbontar et al., 2021; Chen & He, 2021; Caron et al., 2021), which require data augmentations to construct positive pairs for objectives such as contrastive learning. This is especially useful when it is hard to define meaningful data augmentations for a target domain. Specifically, we take the pre-trained model with an encoder f_{θ_init} and a decoder g_{ϕ_init}, which are pre-trained on a massive amount of unlabeled data from a general domain, and continue to pre-train it with the masked auto-encoding (MAE) objective (Devlin et al., 2019; He et al., 2022) on the target unlabeled data to obtain f_{θ_0} and g_{ϕ_0}. After that, we set the encoder f_{θ_0} as the teacher for self-distillation. Then we take a copy of the initial pre-trained model (f_{θ_init}, g_{ϕ_init}) as the student, and match the representations of the student encoder to those of the teacher encoder while optimizing the student with the MAE objective on the target unlabeled data. Finally, we fine-tune the self-distilled student f_{θ_1} on the target labeled data for the downstream task. We illustrate the overview of our method in Figure 2c.
To verify the efficacy of our method, we empirically show that it significantly improves the generalization performance of a pre-trained ViT and the language model RoBERTa (Liu et al., 2019), and outperforms the relevant baselines on various image and text classification datasets. Moreover, we theoretically analyze the proposed method with a simplified model to understand how self-distillation for further pre-training can potentially help improve the generalization performance on the target tasks after fine-tuning. Our contribution is threefold:
• We propose self-distillation for further pre-training on the target unlabeled dataset, where we enforce the representations of the student to be close to those of the further pre-trained teacher while training the student with the masked auto-encoding objective.
• We theoretically analyze the proposed method with a simplified model to understand how self-distillation for further pre-training can potentially lead to better generalization performance on downstream tasks.
• We empirically show that our method outperforms the relevant baselines with both ViT and RoBERTa on various image and text classification datasets.

Figure 1: Accuracy with a varying number of further pre-training steps.

Figure 2: Concepts. Comparison between methods for adapting pre-trained transformers to the target domain. (a) Fine-tuning without any further pre-training. (b) Further pre-training and fine-tuning. (c) Self-distillation in further pre-training and fine-tuning.
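The student objective in the method overview (optimizing the student with MAE while matching its encoder representations to those of the further pre-trained teacher f_{θ_0}) can be written loosely as L(θ, ϕ) = L_MAE(θ, ϕ) + λ · ‖f_θ(x) − f_{θ_0}(x)‖², with the student initialized at (θ_init, ϕ_init) and the teacher held fixed. The following is a minimal pure-Python sketch that only evaluates this combined loss once. It rests on several assumptions that this section does not specify: a squared Euclidean distance between representations, a hypothetical weighting coefficient `lam`, representations computed on the uncorrupted input, masked tokens replaced by a constant placeholder vector, and a mean-squared reconstruction error on masked positions. All helper names (`corrupt`, `mae_loss`, `distill_loss`, `self_distillation_objective`, `scale`) are hypothetical, the toy "networks" stand in for transformers, and no gradients are computed.

```python
from typing import Callable, List, Sequence

Matrix = List[List[float]]  # one row (feature vector) per token or patch
Encoder = Callable[[Matrix], Matrix]
Decoder = Callable[[Matrix], Matrix]


def corrupt(tokens: Matrix, mask: Sequence[bool], mask_value: float = 0.0) -> Matrix:
    """Replace each masked token by a constant vector (assumed stand-in for a [MASK] embedding)."""
    return [[mask_value] * len(t) if m else list(t) for t, m in zip(tokens, mask)]


def mae_loss(tokens: Matrix, mask: Sequence[bool], encode: Encoder, decode: Decoder) -> float:
    """Mean squared reconstruction error, averaged over masked positions only."""
    recon = decode(encode(corrupt(tokens, mask)))
    errors = [
        sum((r - t) ** 2 for r, t in zip(rec, tok)) / len(tok)
        for rec, tok, m in zip(recon, tokens, mask)
        if m
    ]
    return sum(errors) / len(errors) if errors else 0.0


def distill_loss(tokens: Matrix, student: Encoder, teacher: Encoder) -> float:
    """Mean over tokens of the squared L2 distance between student and teacher representations.

    The teacher is treated as frozen: in a real training loop its output would be
    wrapped in a stop-gradient so that only the student is updated.
    """
    h_student, h_teacher = student(tokens), teacher(tokens)
    per_token = [sum((a - b) ** 2 for a, b in zip(s, t)) for s, t in zip(h_student, h_teacher)]
    return sum(per_token) / len(per_token)


def self_distillation_objective(
    tokens: Matrix,
    mask: Sequence[bool],
    student: Encoder,
    decode: Decoder,
    teacher: Encoder,
    lam: float = 1.0,
) -> float:
    """Student loss: MAE reconstruction plus lam times the representation-matching term."""
    return mae_loss(tokens, mask, student, decode) + lam * distill_loss(tokens, student, teacher)


def scale(c: float) -> Encoder:
    """Toy 'network' that multiplies every feature by c (placeholder for a transformer)."""
    return lambda xs: [[c * v for v in x] for x in xs]


tokens = [[1.0, 2.0], [3.0, 4.0]]
mask = [True, False]   # only the first token is masked
student = scale(1.0)   # plays the role of the student encoder f_θ, initialized from θ_init
teacher = scale(2.0)   # plays the role of the further pre-trained teacher f_{θ_0}
decode = scale(1.0)    # plays the role of the student decoder g_ϕ
loss = self_distillation_objective(tokens, mask, student, decode, teacher, lam=0.5)
print(loss)  # 2.5 (MAE) + 0.5 * 15.0 (distillation) = 10.0
```

In an actual implementation, only the student encoder and decoder would be updated by minimizing this quantity on the target unlabeled data, after which the student encoder f_{θ_1} is fine-tuned on the labeled data.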

