MAST: MASKED AUGMENTATION SUBSPACE TRAINING FOR GENERALIZABLE SELF-SUPERVISED PRIORS

Abstract

Recent Self-Supervised Learning (SSL) methods are able to learn feature representations that are invariant to different data augmentations, which can then be transferred to downstream tasks of interest. However, different downstream tasks require different invariances for their best performance, so the optimal choice of augmentations for SSL depends on the target task. In this paper, we aim to learn self-supervised features that generalize well across a variety of downstream tasks (e.g., object classification, detection and instance segmentation) without knowing any task information beforehand. We do so by Masked Augmentation Subspace Training (or MAST) to encode in the single feature space the priors from different data augmentations in a factorized way. Specifically, we disentangle the feature space into separate subspaces, each induced by a learnable mask that selects relevant feature dimensions to model invariance to a specific augmentation. We show the success of MAST in jointly capturing generalizable priors from different augmentations, using both unique and shared features across the subspaces. We further show that MAST benefits from uncertainty modeling to reweight ambiguous samples from strong augmentations that may cause similarity mismatch in each subspace. Experiments demonstrate that MAST consistently improves generalization on various downstream tasks, while being task-agnostic and efficient during SSL. We also provide interesting insights about how different augmentations are related and how uncertainty reflects learning difficulty.

1. INTRODUCTION

Self-Supervised Learning (SSL) for image representation has made significant progress over the past few years. The feature representations are typically learned to be invariant to different data augmentations (e.g., Random Flip and Color Jitter). For example, the popular contrastive SSL methods (Chen et al., 2020a; He et al., 2020) learn invariances by discriminating augmented views of the same image (positive pair) from those of different images (negative pair), while recent non-contrastive SSL methods (Chen & He, 2021; Grill et al., 2020; Bardes et al., 2022) simply maximize the similarity between positive pairs. Such learned features have been shown to generalize across many downstream tasks, including classification, object detection, and instance segmentation. Despite this strong transfer performance, we lack a good theoretical understanding of why both contrastive and non-contrastive SSL methods generalize so well. Balestriero & LeCun (2022) recently proposed a unified framework demonstrating that the key to generalization lies in the alignment between the pairwise relation in SSL (characterized by augmented inputs) and the downstream task. This is in line with other theories (Arora et al., 2019; HaoChen et al., 2021) that quantify how data augmentations implicitly encode the class distributions of downstream tasks.

Motivated by these theoretical analyses, we aim to develop a practical SSL method that directly captures meaningful priors from data augmentations so as to encourage generalization across a range of tasks. Invariance, as achieved through augmentation, is a useful mechanism for facilitating generalization: one can imagine that invariance to Random Flip will boost generalization on many vision tasks. However, as shown in Fig. 1, the optimal set of augmentations (and thus invariances) highly depends on the downstream task, which has also been observed by Tian et al. (2020).
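To make the contrastive mechanism described above concrete, the following is a minimal sketch of an InfoNCE-style loss: embeddings of two augmented views of the same image form positive pairs (the diagonal of the similarity matrix), while all other rows serve as negatives. This is a generic illustration of contrastive SSL (as in Chen et al., 2020a), not the method proposed in this paper; all names and the temperature value are illustrative.

```python
import torch
import torch.nn.functional as F

def info_nce_loss(z1, z2, temperature=0.1):
    """Contrastive loss over two augmented views.

    z1, z2: (N, D) embeddings of two views of the same N images.
    Row i of z1 and row i of z2 form a positive pair; every other
    row acts as a negative.
    """
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    logits = z1 @ z2.t() / temperature    # (N, N) cosine similarities
    labels = torch.arange(z1.size(0))     # positives lie on the diagonal
    return F.cross_entropy(logits, labels)

# Toy usage: embeddings of 4 images under two different augmentations.
z1, z2 = torch.randn(4, 16), torch.randn(4, 16)
loss = info_nce_loss(z1, z2)
```

Minimizing this loss pulls the two views of each image together in embedding space, which is precisely how invariance to the applied augmentations is induced.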
Sometimes, invariances (e.g., to Color Jitter) that prove helpful for one downstream task (e.g., object detection) may even hurt generalization on another (e.g., bird classification, which requires accurate color information to distinguish similarly shaped bird species). Hence it is impossible to maximize generalization by finding good augmentations for SSL in a task-dependent way: not only may the invariances learned across augmentations contradict each other, but we often do not know the target task a priori during SSL. Furthermore, manually finding suitable augmentations for generalizing to a new task quickly becomes cumbersome.

In this paper, we propose a new SSL method that learns generalizable features without presuming any downstream task information. Our method, called Masked Augmentation Subspace Training (MAST), achieves this goal by learning a single but disentangled feature space that encodes a set of potentially contradicting augmentation invariances. For each augmentation, we learn invariance to it in a specialized feature subspace, induced from the full space by a learned masking operation. This allows us to learn features unique to each augmentation subspace as well as features shared between subspaces.
There are two main benefits of such subspace training: 1) we explicitly avoid feature suppression (Li et al., 2020) by jointly training separate feature subspaces, such that the learning of one subspace does not compromise that of another; 2) we obtain a disentangled, full feature space that is pre-trained in a task-agnostic way, yet does not discard the diverse information (i.e., invariances) required by all possible downstream tasks. We further model uncertainties in the feature representations to reduce the harm caused by ambiguous samples from strong augmentations.

To examine how representation effectiveness scales with augmentation diversity, we run experiments with different numbers of augmentations, starting from the 5 standard augmentations typically used in SSL (Chen et al., 2020a) and extending to an additional 10 augmentations (15 in total). Note that Tian et al. (2020) and Wang & Qi (2021) also learn from stronger augmentations, but not in our factorized and uncertainty-aware fashion. For transfer learning on downstream tasks, we simply drop our subspace masks and finetune the full feature representations, which is highly efficient. This is in contrast to LooC (Xiao et al., 2021), which needs to combine multiple feature "heads", with one head invariant to all augmentations and each remaining head sensitive to one particular augmentation but invariant to the others. Both the "leave-one-out" training and the feature ensembling strategies in LooC lead to parameter redundancy and high cost (thus allowing only a few augmentation-specific heads). Experiments show that MAST, while being efficient and task-agnostic, achieves state-of-the-art transfer performance on diverse downstream vision tasks. Investigations of the subspace masks and uncertainties also provide interesting insights into how different augmentations are related, and how uncertainty reflects learning difficulty to avoid similarity mismatch during invariance learning.
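The subspace idea described above can be sketched in code: one learnable mask per augmentation softly selects the feature dimensions that should be invariant to that augmentation, and a per-subspace alignment loss pulls positive pairs together within each masked subspace. This is our own hypothetical reconstruction of the general mechanism, not the paper's exact architecture or loss; all class and variable names are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedSubspaces(nn.Module):
    """Sketch of MAST-style subspace masking (illustrative, not the
    paper's exact formulation): one learnable soft mask per augmentation
    selects the feature dimensions modeling invariance to it."""

    def __init__(self, feat_dim=128, n_augs=5):
        super().__init__()
        # Real-valued logits; sigmoid yields soft masks in (0, 1),
        # so dimensions can be unique to one subspace or shared.
        self.mask_logits = nn.Parameter(torch.zeros(n_augs, feat_dim))

    def forward(self, z1, z2):
        """z1, z2: (N, D) features of two augmented views."""
        masks = torch.sigmoid(self.mask_logits)   # (n_augs, D)
        loss = 0.0
        for m in masks:
            # Project both views into this augmentation's subspace ...
            s1 = F.normalize(z1 * m, dim=1)
            s2 = F.normalize(z2 * m, dim=1)
            # ... and pull positive pairs together there
            # (non-contrastive, cosine-alignment style).
            loss = loss + (1 - (s1 * s2).sum(dim=1)).mean()
        return loss / masks.size(0)

model = MaskedSubspaces(feat_dim=32, n_augs=3)
z1, z2 = torch.randn(8, 32), torch.randn(8, 32)
loss = model(z1, z2)
```

Because the masks are trained jointly rather than competing for the same dimensions, one subspace's invariance need not suppress another's; at transfer time the masks can simply be dropped and the full feature space finetuned.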
To summarize, our main contributions are:

• We introduce MAST to make SSL representations disentangled and uncertainty-aware, effectively encoding different augmentation invariances for good generalization.

• We show that MAST is efficient, resists feature suppression, and achieves state-of-the-art downstream performance on diverse vision tasks without presuming any task information during pre-training.
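The uncertainty modeling mentioned above, reweighting ambiguous samples from strong augmentations, could take a form like the following heteroscedastic scheme: a predicted per-sample log-variance down-weights pairs whose similarity is unreliable, with a log-variance penalty preventing the model from inflating uncertainty everywhere. This is a generic sketch of such reweighting, not necessarily MAST's exact formulation; the function and variable names are assumptions.

```python
import torch
import torch.nn.functional as F

def uncertainty_weighted_alignment(s1, s2, log_var):
    """Illustrative uncertainty-aware alignment loss.

    s1, s2: (N, D) normalized subspace features of positive pairs.
    log_var: (N,) predicted log-variance per sample; ambiguous samples
    (e.g., under strong augmentations) get high variance, hence low weight.
    """
    mismatch = 1 - (s1 * s2).sum(dim=1)          # per-pair alignment error
    weight = torch.exp(-log_var)                 # high uncertainty -> low weight
    return (weight * mismatch + log_var).mean()  # penalty term deters inflated variance

# Toy usage with a hypothetical uncertainty predictor output of zeros.
s1 = F.normalize(torch.randn(8, 32), dim=1)
s2 = F.normalize(torch.randn(8, 32), dim=1)
log_var = torch.zeros(8)
loss = uncertainty_weighted_alignment(s1, s2, log_var)
```

Under this kind of scheme, a sample whose strong augmentation destroys the relevant content (causing similarity mismatch) can raise its predicted variance and contribute less to the invariance loss.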



Figure 1: The change in downstream performance per task due to the selective removal of each augmentation during SSL, relative to a baseline SSL method, MoCo (He et al., 2020), trained with 5 standard data augmentations: Color Jitter, Gaussian Blur, Random Flip, Random Grayscale, and RandomResizedCrop. A performance drop after removing an augmentation indicates that it helps the corresponding task, while a performance gain indicates harm. We observe that RandomResizedCrop strongly benefits all tasks. Other augmentations encode their specific invariances that can be more helpful for some tasks than others. Some augmentations (e.g., Color Jitter) even hurt one task (CUB-200 bird classification) despite being quite helpful for other tasks.

