MAST: MASKED AUGMENTATION SUBSPACE TRAINING FOR GENERALIZABLE SELF-SUPERVISED PRIORS

Abstract

Recent Self-Supervised Learning (SSL) methods are able to learn feature representations that are invariant to different data augmentations, which can then be transferred to downstream tasks of interest. However, different downstream tasks require different invariances for their best performance, so the optimal choice of augmentations for SSL depends on the target task. In this paper, we aim to learn self-supervised features that generalize well across a variety of downstream tasks (e.g., object classification, detection and instance segmentation) without knowing any task information beforehand. We do so via Masked Augmentation Subspace Training (MAST), which encodes in a single feature space the priors from different data augmentations in a factorized way. Specifically, we disentangle the feature space into separate subspaces, each induced by a learnable mask that selects relevant feature dimensions to model invariance to a specific augmentation. We show the success of MAST in jointly capturing generalizable priors from different augmentations, using both unique and shared features across the subspaces. We further show that MAST benefits from uncertainty modeling to reweight ambiguous samples from strong augmentations that may cause similarity mismatch in each subspace. Experiments demonstrate that MAST consistently improves generalization on various downstream tasks, while being task-agnostic and efficient during SSL. We also provide interesting insights about how different augmentations are related and how uncertainty reflects learning difficulty.

1. INTRODUCTION

Self-Supervised Learning (SSL) for image representation has made significant progress over the past few years. The feature representations are typically learned to be invariant to different data augmentations (e.g., Random Flip and Color Jitter). For example, the popular contrastive SSL methods (Chen et al., 2020a; He et al., 2020) learn invariances by discriminating augmented views of the same image (positive pair) from those of different images (negative pair), while recent non-contrastive SSL methods (Chen & He, 2021; Grill et al., 2020; Bardes et al., 2022) simply maximize the similarity between positive pairs. Such learned features are shown to generalize across many downstream tasks, including classification, object detection, instance segmentation, etc.

Despite achieving strong transfer performance, we lack a good theoretical understanding of why both contrastive and non-contrastive SSL methods generalize so well. Balestriero & LeCun (2022) recently proposed a unified framework demonstrating that the key to generalization lies in the alignment between the pairwise relation in SSL (characterized by augmented inputs) and the downstream task. This is also in line with other theories (Arora et al., 2019; HaoChen et al., 2021) that quantify how data augmentations implicitly encode the class distributions in downstream tasks. Motivated by these theoretical analyses, we aim at a working SSL method that can directly capture meaningful priors from data augmentations in order to encourage generalization for a range of tasks.

Invariance, as achieved through augmentation, is a useful mechanism to facilitate generalization. For example, one can imagine that invariance to Random Flip will boost generalization on many vision tasks. However, as shown in Fig. 1, the optimal set of augmentations (and thus invariances) highly depends on the downstream task, which has been similarly observed in (Tian et al., 2020).
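To make the masked-subspace idea concrete, the following is a minimal, illustrative sketch (not the paper's actual implementation): each augmentation k gets a mask over feature dimensions, and a non-contrastive similarity loss is computed between the masked projections of two augmented views. The function name and the use of a simple cosine-similarity objective are our own assumptions for illustration.

```python
import numpy as np

def masked_subspace_loss(z1, z2, masks):
    """Illustrative sketch: for each augmentation-specific mask,
    maximize cosine similarity between the masked projections of the
    two augmented views z1, z2 of shape [batch, dim].
    masks: list of [dim] arrays with values in [0, 1] (assumed learnable
    in the real method; fixed here for illustration)."""
    losses = []
    for m in masks:
        a, b = z1 * m, z2 * m  # project both views into subspace k
        a = a / (np.linalg.norm(a, axis=1, keepdims=True) + 1e-8)
        b = b / (np.linalg.norm(b, axis=1, keepdims=True) + 1e-8)
        # 1 - cosine similarity, averaged over the batch
        losses.append(1.0 - np.sum(a * b, axis=1).mean())
    return float(np.mean(losses))
```

Under this sketch, identical views incur zero loss in every subspace, while views that differ only under one augmentation are penalized mainly through that augmentation's subspace, leaving the other subspaces free to encode different invariances.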
Sometimes, invariances (e.g., to Color Jitter) that prove helpful for one downstream task (e.g., object detection)

