LAYER GRAFTED PRE-TRAINING: BRIDGING CONTRASTIVE LEARNING AND MASKED IMAGE MODELING FOR LABEL-EFFICIENT REPRESENTATIONS

Abstract

Recently, both Contrastive Learning (CL) and Masked Image Modeling (MIM) have demonstrated that self-supervision is a powerful approach to learning good representations. However, naively combining them is far from successful. In this paper, we start by making the empirical observation that jointly optimizing the CL and MIM losses leads to conflicting gradient directions that grow more severe as the layers go deeper. This motivates us to shift the paradigm from combining the losses at the end to choosing the proper learning method per network layer. Based on experimental observations, we find that MIM and CL are suited to lower and higher layers, respectively. We hence propose to combine them in a surprisingly simple, "sequential cascade" fashion: early layers are first trained under one MIM loss, on top of which later layers continue to be trained under another CL loss. The proposed Layer Grafted Pre-training learns good visual representations that demonstrate superior label efficiency in downstream applications, in particular yielding strong few-shot performance besides linear evaluation. For instance, on ImageNet-1k, Layer Grafted Pre-training yields 65.5% Top-1 accuracy under 1% few-shot learning with ViT-B/16, which improves over the MIM and CL baselines by 14.4% and 2.1%, respectively, with no bells and whistles. The code is available at https://github.com/VITA-Group/layerGraftedPretraining_ICLR23.git.

1. INTRODUCTION

Self-supervision has demonstrated undoubted power in learning strong visual representations, with two mainstream representative methods: Contrastive Learning (CL) (Chen et al., 2020b; He et al., 2020; Chen et al., 2020d; 2021; Grill et al., 2020; Caron et al., 2021) and Masked Image Modeling (MIM) (Bao et al., 2021; He et al., 2021; Xie et al., 2022; Dong et al., 2021; 2022). The two methods follow different mechanisms and often manifest different strengths. Generally, CL performs an instance-level task that pulls augmented views of the same image together while pushing different images apart, making it adept at learning semantically-aware clustering structures across images. In contrast, MIM draws inspiration from BERT (Devlin et al., 2018) and performs masked token or pixel reconstruction, which facilitates the learning of rich local structures within the same image. In particular, although MIM has recently surpassed CL in fine-tuning performance on many datasets, CL often remains a top competitor in data-scarce, few-shot downstream applications (Chen et al., 2020c; d; Tian et al., 2020). A natural question then follows: are CL and MIM indeed complementary to each other, and is there a way to best combine their strengths?

One immediate, conceptually simple idea is to borrow from multi-task learning (MTL) and jointly optimize the two losses on top of the same backbone. Unfortunately, our preliminary experiment (see Section 2.2) shows that such a vanilla combination fails to improve over either baseline, and in fact often compromises the single-loss performance. A deeper dive reveals that the two losses, when optimized together, incur increasingly severe conflicts in their gradient directions as the layers go deeper (see Figure 1). This causes considerable hurdles for the (pre-)training to proceed effectively.
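The per-layer conflict can be probed directly by comparing gradient directions. The sketch below is a toy PyTorch illustration, not the paper's actual setup: a small shared backbone receives a reconstruction-style loss (a stand-in for MIM) and an alignment-style loss (a stand-in for CL), and we measure the cosine similarity between the two losses' gradients at each layer. Values near -1 indicate conflicting update directions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy 3-layer network standing in for the shared backbone.
torch.manual_seed(0)
layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(3)])

def forward(x):
    for layer in layers:
        x = F.relu(layer(x))
    return x

x = torch.randn(8, 16)
target = torch.randn(8, 16)

# Two stand-in objectives sharing the backbone: a reconstruction-style
# loss (MIM proxy) and an alignment-style loss (CL proxy).
feat = forward(x)
loss_mim = F.mse_loss(feat, target)
loss_cl = -F.cosine_similarity(feat, feat.roll(1, dims=0)).mean()

grads = {}
for name, loss in [("mim", loss_mim), ("cl", loss_cl)]:
    g = torch.autograd.grad(loss, [l.weight for l in layers], retain_graph=True)
    grads[name] = [gi.flatten() for gi in g]

# Per-layer cosine similarity between the two gradient directions.
cosines = []
for i, (g_mim, g_cl) in enumerate(zip(grads["mim"], grads["cl"])):
    cos = F.cosine_similarity(g_mim, g_cl, dim=0).item()
    cosines.append(cos)
    print(f"layer {i}: grad cosine = {cos:.3f}")
```

With real CL and MIM losses on a deep backbone, the same measurement underlies the trend reported in Figure 1: the gradient cosine drops further below zero in deeper layers.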
We are then inspired to ask: if the two losses conflict when both are placed at the end, how about placing them differently, such as appending them to different layers? Based on experimental observations, lower layers tend to learn better from the MIM loss, which captures local spatial details, while higher layers benefit more from the CL loss, which learns semantically-aware grouping and invariance. Motivated by this, we propose a simple MIM→CL Grafting idea to combine the best of both worlds: (step i) first train the lower layers with the MIM loss and fix their weights, on top of which (step ii) the higher-layer weights continue to be trained under the CL loss. This simple cascaded training idea neatly separates the MIM and CL losses to avoid the conflicts that arise when they are placed together; each loss is also strategically placed to pre-train its most suitable portion of the network. In practice, we "smooth out" the grafting by allowing the lower layers to be slowly tuned in step ii. Our ablation experiments also find that the order of grafting matters, i.e., reversing the MIM/CL loss locations and performing CL→MIM considerably damages performance.

The contributions of this paper are summarized as follows:

• We propose Layer Grafted Pre-training, a principled framework to merge MIM and CL, improving representation learning beyond both, with no bells and whistles.

• We investigate the different preferences of lower and higher layers towards the CL and MIM losses, and show that the order of grafting matters.

• Despite its embarrassing simplicity, the proposed Layer Grafted Pre-training demonstrates more desirable representation quality, and consequently superior label efficiency in downstream applications, yielding strong few-shot performance besides linear evaluation.
For example, we achieve [65.5%, 77.8%, 77.7%] in terms of [1% few-shot, 10% few-shot, linear evaluation] performance, improving over MIM and CL baselines by [14.4%, 4.5%, 9.7%] and [2.1%, 2.4%, 1.0%], respectively.
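The two-step grafting recipe can be sketched in a few lines of PyTorch. The block count, learning rates, and the lower/upper split below are illustrative assumptions, not the paper's ViT-B/16 configuration; the key point is that stage i optimizes only the lower layers under the MIM objective, and stage ii continues with the CL objective while giving the grafted lower layers a much smaller learning rate rather than freezing them outright.

```python
import torch
import torch.nn as nn

# Hypothetical 4-block encoder; the 2/2 lower/upper split is illustrative.
torch.manual_seed(0)
blocks = nn.ModuleList(
    [nn.Sequential(nn.Linear(32, 32), nn.GELU()) for _ in range(4)]
)
lower, upper = blocks[:2], blocks[2:]
lower_params = [p for b in lower for p in b.parameters()]
upper_params = [p for b in upper for p in b.parameters()]

# Step i: pre-train the lower blocks under the MIM loss; the optimizer
# only sees lower-layer parameters, so upper layers are untouched.
opt_mim = torch.optim.AdamW(lower_params, lr=1.5e-4)

h = torch.randn(8, 32)
for b in lower:
    h = b(h)
loss_stage1 = h.pow(2).mean()  # stand-in for the MIM reconstruction loss
loss_stage1.backward()
opt_mim.step()
opt_mim.zero_grad()

# Step ii: continue training under the CL loss on top. The grafted lower
# layers are "smoothed out" with a much smaller learning rate instead of
# being fully frozen.
opt_cl = torch.optim.AdamW([
    {"params": lower_params, "lr": 1.5e-5},
    {"params": upper_params, "lr": 1.5e-4},
])
```

Reversing the order (CL first, then MIM on top) fits the same template but, as the ablations show, considerably damages performance.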

2.1. PRELIMINARY AND OVERVIEW

In Contrastive Learning (CL), the learning target is to pull positive pairs together in the feature space while pushing negative pairs apart. Formally, the loss can be defined as:

$$\mathcal{L}_{\mathrm{CL}}(v_i, v_i^{+}, V^{-}, \tau) = \frac{1}{N} \sum_{i=1}^{N} -\log \frac{\exp\left(v_i \cdot v_i^{+} / \tau\right)}{\exp\left(v_i \cdot v_i^{+} / \tau\right) + \sum_{v_i^{-} \in V^{-}} \exp\left(v_i \cdot v_i^{-} / \tau\right)}$$

where $(v_i, v_i^{+})$ represents the features of a positive pair while $(v_i, v_i^{-})$ represents the features of a negative pair. $V^{-}$ is the pool of negative features, $\tau$ denotes the temperature, and $N$ is the number of samples. In practice, the positive pairs are often different augmented views of the same image, while the negative pool is composed of views from different images (Chen et al., 2021).

On the other hand, Masked Image Modeling (MIM) learns to reconstruct a corrupted image where some parts of the image or feature map are masked out. The learning target can be formulated as:

$$\mathcal{L}_{\mathrm{MIM}}(x_i, M) = \frac{1}{N} \sum_{i=1}^{N} D\left(d\left(f\left(M \odot x_i\right)\right), x_i\right)$$

where $x_i$ and $M$ are the input images and randomly generated masks, respectively. $f$ and $d$ represent the encoding and decoding functions, respectively; $d(f(M \odot x_i))$ is the image generated conditioned on the masked image $M \odot x_i$, and $D$ measures the difference between $d(f(M \odot x_i))$ and the original image $x_i$.

Overview. In the remaining parts of this section, we first introduce our preliminary exploration of the MTL of the MIM and CL tasks in Section 2.2, which reveals the existence of conflicting gradient directions. Afterward, in Section 2.3, we provide a simple separation idea for mitigating the conflicts, which further leads to the proposed Layer Grafted Pre-training in Section 2.4.
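To make the two objectives concrete, the sketch below implements both losses in PyTorch. It is a minimal illustration rather than the paper's training code: the feature normalization, the choice of mean-squared error for $D$, and elementwise masking for $M \odot x$ are conventional assumptions, and the negative pool is shared across samples for simplicity.

```python
import torch
import torch.nn.functional as F

def info_nce(v, v_pos, v_neg, tau=0.2):
    """Contrastive loss following the CL formula above.
    v, v_pos: (N, D) features of positive pairs; v_neg: (K, D) negative pool.
    L2-normalizing features is a common convention, assumed here."""
    v = F.normalize(v, dim=-1)
    v_pos = F.normalize(v_pos, dim=-1)
    v_neg = F.normalize(v_neg, dim=-1)
    pos = torch.exp((v * v_pos).sum(-1) / tau)   # (N,) positive similarities
    neg = torch.exp(v @ v_neg.T / tau).sum(-1)   # (N,) summed negative terms
    return (-torch.log(pos / (pos + neg))).mean()

def mim_loss(x, mask, f, d):
    """Masked-reconstruction loss following the MIM formula above.
    f: encoder, d: decoder; D is taken to be mean-squared error,
    and the mask is applied elementwise (M ⊙ x)."""
    recon = d(f(mask * x))
    return F.mse_loss(recon, x)

# Usage with random features and identity encoder/decoder:
v, v_pos = torch.randn(4, 8), torch.randn(4, 8)
v_neg = torch.randn(16, 8)
cl = info_nce(v, v_pos, v_neg)

x = torch.randn(4, 8)
mask = (torch.rand(4, 8) > 0.75).float()
mim = mim_loss(x, mask, lambda t: t, lambda t: t)
```

In a real pipeline, `f` and `d` would be the ViT encoder and a lightweight decoder, and `v` would come from a projection head on augmented views.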

