EXPLORING THE ROLE OF MEAN TEACHERS IN SELF-SUPERVISED MASKED AUTO-ENCODERS

Abstract

Masked image modeling (MIM) has become a popular strategy for self-supervised learning (SSL) of visual representations with Vision Transformers. A representative MIM model, the masked auto-encoder (MAE), randomly masks a subset of image patches and reconstructs the masked patches given the unmasked patches. Concurrently, many recent works in self-supervised learning utilize the student/teacher paradigm, which provides the student with an additional target based on the output of a teacher composed of an exponential moving average (EMA) of previous students. Although common, relatively little is known about the dynamics of the interaction between the student and teacher. Through analysis on a simple linear model, we find that the teacher conditionally removes previous gradient directions based on feature similarities, effectively acting as a conditional momentum regularizer. From this analysis, we present a simple SSL method, the Reconstruction-Consistent Masked Auto-Encoder (RC-MAE), by adding an EMA teacher to MAE. We find that RC-MAE converges faster and requires less memory than state-of-the-art self-distillation methods during pre-training, which may provide a way to enhance the practicality of prohibitively expensive self-supervised learning of Vision Transformer models. Additionally, we show that RC-MAE achieves greater robustness and better performance than MAE on downstream tasks such as ImageNet-1K classification, object detection, and instance segmentation.

1. INTRODUCTION

The Transformer (Vaswani et al., 2017) is the de facto standard architecture in natural language processing (NLP) and, through models such as the Vision Transformer (ViT) (Dosovitskiy et al., 2021), has also surpassed state-of-the-art convolutional neural network (CNN) feature extractors (He et al., 2016; Tan & Le, 2019) in vision tasks. Prior to the advent of ViTs, self-supervised learning (SSL) algorithms in the vision community (He et al., 2020; Chen et al., 2020c; Grill et al., 2020; Chen et al., 2021) utilized CNNs (e.g., ResNet (He et al., 2016)) as backbones, performing instance discrimination pretext tasks through contrastive learning (He et al., 2020; Chen et al., 2020c). Interestingly, self-distillation schemes (Grill et al., 2020; Caron et al., 2021) using a teacher consisting of an exponential moving average (EMA) of the previous students (i.e., a "mean" teacher (Tarvainen & Valpola, 2017)) have been shown to exhibit strong performance. Inspired by the success of masked language modeling (MLM) pre-training in NLP, recent SSL approaches in the vision community (Bao et al., 2022; Zhou et al., 2022; Xie et al., 2022; He et al., 2022; Assran et al., 2022) have proposed forms of masked image modeling (MIM) pretext tasks using ViT-based backbones. MIM is a simple pretext task which first randomly masks patches of an image and then predicts the contents of the masked patches (i.e., tokens) using various reconstruction targets, e.g., visual tokens (Bao et al., 2022; Dong et al., 2021), semantic features (Zhou et al., 2022; Assran et al., 2022), and raw pixels (He et al., 2022; Xie et al., 2022). In particular, iBOT (Zhou et al., 2022) and MSN (Assran et al., 2022) reconstruct the original image at a semantic feature level (i.e., global abstraction). Methods using semantic-level target representations exhibit strong performance on image-level classification tasks.
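The "mean" teacher referenced above is conceptually just a slow moving average of the student's weights. A minimal numpy sketch of one such EMA update follows; the function name and the toy one-layer parameterization are ours, not from any particular paper's code:

```python
import numpy as np

def ema_update(teacher_params, student_params, momentum=0.999):
    """One EMA ("mean teacher") update: the teacher slowly tracks
    a moving average of the student's parameters."""
    return [momentum * t + (1.0 - momentum) * s
            for t, s in zip(teacher_params, student_params)]

# Toy example with a single parameter vector per model.
teacher = [np.zeros(3)]
student = [np.ones(3)]
for _ in range(10):
    teacher = ema_update(teacher, student, momentum=0.9)
# With a fixed student, the teacher follows 1 - momentum**steps,
# so after 10 steps it sits at 1 - 0.9**10 of the way to the student.
```

In practice the momentum is set close to 1 (e.g., 0.999), so the teacher changes much more slowly than the student and provides a temporally smoothed target.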
In contrast, SimMIM (Xie et al., 2022) and MAE (He et al., 2022) provide pixel-level reconstructions of masked patches, and lead to superior performance on dense prediction tasks such as object detection and segmentation. However, self-distillation for pixel-level MIM has so far been under-explored.

A recent SSL approach, BYOL (Grill et al., 2020), has shown that a slight architectural asymmetry between a student and an EMA teacher can create a stable model which outperforms previous contrastive learning methods. The success of BYOL inspired empirical (Chen & He, 2021) and theoretical (Tian et al., 2021) analyses into what enables BYOL to effectively learn and avoid collapse with the EMA teacher during pre-training. Still, despite the popularity of the EMA teacher in SSL, relatively little is known about how the teacher interacts with the student throughout training.

In this work, we explore the dynamics of self-distillation in pixel-level MIM, e.g., MAE. Through analyzing a simple linear model, we investigate the dynamics between the gradients of an image reconstruction loss and a teacher consistency loss, finding that the gradients provided by the teacher's consistency loss conditionally adjust the current reconstruction gradient by a weighted mixture of previous gradients. The weights of the mixture are derived from similarities between current and previous features. Thus, the teacher acts like a conditional momentum regularizer. For example, Fig. 1(a) shows the case where the inputs which created the previous gradient momentum are similar to the ones which created the current gradients. In this case, the teacher makes a conditional correction to remove the previous direction from the momentum, allowing the student to learn from the newer knowledge in the current batch. If, however, the inputs which created both gradients are nearly orthogonal, the teacher instead responds with minimal to no correction.
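The conditional-correction intuition above can be illustrated numerically: scale the removal of the previous gradient direction by the similarity of the inputs that produced the previous and current gradients. The sketch below is an illustration of that intuition only, not the paper's exact derivation; the function name and the cosine-similarity weighting are our assumptions:

```python
import numpy as np

def conditional_correction(prev_grad, prev_input, cur_input):
    """Weight the removal of the previous gradient direction by the
    cosine similarity between the inputs that produced the previous
    and current gradients (illustrative form of the conditional
    momentum regularizer intuition)."""
    sim = np.dot(prev_input, cur_input) / (
        np.linalg.norm(prev_input) * np.linalg.norm(cur_input))
    return sim * prev_grad  # correction to subtract from the momentum

prev_grad = np.array([1.0, 0.0])

# Similar inputs -> large correction: the previous direction is removed.
corr_similar = conditional_correction(
    prev_grad, np.array([1.0, 0.1]), np.array([1.0, 0.0]))

# Nearly orthogonal inputs -> (near) zero correction.
corr_orthogonal = conditional_correction(
    prev_grad, np.array([0.0, 1.0]), np.array([1.0, 0.0]))
```

When the current batch resembles the one that produced the stale momentum, the correction is large and the old direction is removed; when the batches are unrelated, the correction vanishes and the momentum is left intact.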
We derive this conditional gradient effect in Proposition 4.1, and show evidence both in a simple linear model and in a deep ViT-based (Dosovitskiy et al., 2021) MAE model (Fig. 2). To empirically validate our analysis of the contributions of EMA teachers, we present a simple yet effective SSL approach, the Reconstruction-Consistent Masked Auto-Encoder (RC-MAE), by equipping MAE with an EMA teacher and providing a consistency target. Additionally, we study the effect of using different image masking strategies between the student and teacher models on the consistency objective, finding that using the same mask generally leads to better performance in both pre-training and downstream tasks. The same mask tends to form an objective orthogonal to the reconstruction loss (Fig. 3(b)), which has been shown (Suteu & Guo, 2019; Ajemian et al., 2013) to be beneficial for multi-task models, as there is limited interference between tasks. This observation may be of interest to any future SSL works which leverage multiple pre-training objectives.
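The RC-MAE objective described above pairs the usual pixel reconstruction loss with a consistency loss against the EMA teacher's prediction, computed on the same masked patches. A minimal numpy sketch follows; the function name, the per-patch mean-squared-error form, and the unit weighting of the two losses are our assumptions for illustration, not the paper's exact implementation:

```python
import numpy as np

rng = np.random.default_rng(0)

def rc_mae_losses(student_pred, teacher_pred, target, mask):
    """Mean-squared errors over masked patches only:
    - reconstruction loss: student prediction vs. original pixels
    - consistency loss:    student prediction vs. EMA teacher prediction
    Shapes: (num_patches, patch_dim); mask is boolean over patches."""
    rec = np.mean((student_pred[mask] - target[mask]) ** 2)
    con = np.mean((student_pred[mask] - teacher_pred[mask]) ** 2)
    return rec, con

num_patches, patch_dim = 16, 8
target = rng.normal(size=(num_patches, patch_dim))
# Stand-ins for decoder outputs of the student and EMA teacher.
student_pred = target + 0.1 * rng.normal(size=target.shape)
teacher_pred = target + 0.1 * rng.normal(size=target.shape)

# Both models see the same mask, the setting found to work best above.
mask = rng.random(num_patches) < 0.75   # roughly 75% of patches masked
rec_loss, con_loss = rc_mae_losses(student_pred, teacher_pred, target, mask)
total_loss = rec_loss + con_loss  # the student's combined objective
```

Only the student receives gradients from this combined objective; the teacher is updated purely by the EMA of student weights.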



Figure 1: Overview. (a): When the inputs which led to the previous gradients and current gradients are similar, the consistency gradient provides a conditional correction, allowing the student to learn from newer knowledge. (b): In RC-MAE, the reconstructed patches from the student are compared with the original input (reconstruction loss L_r), and with the predicted patches from the teacher (consistency loss L_c). (c): ImageNet-1K fine-tuning top-1 accuracy curve: RC-MAE achieves comparable accuracy (83.4%) at 800 epochs compared to MAE trained for 1600 epochs.

