PREDICTING WHAT YOU ALREADY KNOW HELPS: PROVABLE SELF-SUPERVISED LEARNING

Abstract

Self-supervised representation learning solves auxiliary prediction tasks (known as pretext tasks) that do not require labeled data, in order to learn semantic representations. These pretext tasks are created solely from the input features, such as predicting a missing image patch, recovering the color channels of an image from context, or predicting missing words in text; yet predicting this known information helps in learning representations effective for downstream prediction tasks. This paper posits a mechanism based on approximate conditional independence to formalize how solving certain pretext tasks can learn representations that provably decrease the sample complexity of downstream supervised tasks. Formally, we quantify how the approximate independence between the components of the pretext task (conditional on the label and latent variables) allows us to learn representations that can solve the downstream task with drastically reduced sample complexity by training just a linear layer on top of the learned representation.

1. INTRODUCTION

Self-supervised learning has revitalized machine learning models in computer vision, language modeling, and control problems (see Jing & Tian, 2020; Kolesnikov et al., 2019; Devlin et al., 2018; Wang & Gupta, 2015; Jang et al., 2018, and references therein). Training a model with auxiliary tasks based only on input features reduces the extensive costs of data collection and semantic annotation for downstream tasks. It is also known to improve the adversarial robustness of models (Hendrycks et al., 2019; Carmon et al., 2019; Chen et al., 2020a). Self-supervised learning creates pseudo-labels based solely on input features, and solves auxiliary prediction tasks (pretext tasks) in a supervised manner. However, the underlying principles of self-supervised learning remain mysterious, since it is a priori unclear why predicting what we already know should help. We thus raise the following question:

What conceptual connection between pretext and downstream tasks ensures good representations?

What is a good way to quantify this? As a thought experiment, consider a simple downstream task of classifying desert, forest, and sea images. A meaningful pretext task is to predict the background color of images (known as image colorization (Zhang et al., 2016)). Denote X1, X2, Y to be the input image, color channel, and the downstream label respectively. Given knowledge of the label Y, one can plausibly predict the background X2 without knowing much about X1. In other words, X2 is approximately independent of X1 conditioned on the label Y. Consider another task of inpainting (Pathak et al., 2016) the front of a building (X2) from the rest (X1). While knowing the label "building" (Y) is not sufficient for successful inpainting, adding latent variables Z such as architectural style, location, and window positions ensures that the variation in X2 given Y, Z is small. We can interpret this mathematically as X1 being approximately conditionally independent of X2 given Y, Z. In the above settings with conditional independence, the only way to solve the pretext task from X1 is to first implicitly predict Y and then predict X2 from Y. Thus, even without labeled data, the information of Y is hidden in the prediction for X2.

Contributions. We propose a mechanism based on approximate conditional independence (ACI) to explain why solving pretext tasks created from known information can learn representations that provably reduce downstream sample complexity. For instance, under conditional independence (CI), the learned representation requires only Õ(k) labeled samples to solve a k-way supervised task. Under ACI (quantified by the norm of a certain partial covariance matrix), we show similar sample complexity improvements. We verify our main theorem (Theorem 4.2) using simulations, check that the pretext task helps when CI is approximately satisfied in the text domain, and demonstrate on a real-world image dataset that a pretext task-based linear model outperforms or is comparable to many baselines.

1.1. RELATED WORK

Self-supervised learning (SSL) methods in practice: There has been a flurry of self-supervised methods lately. One class of methods reconstructs images from corrupted or incomplete versions of them, such as denoising auto-encoders (Vincent et al., 2008), image inpainting (Pathak et al., 2016), and the split-brain autoencoder (Zhang et al., 2017). Pretext tasks are also created using visual common sense, including predicting rotation angle (Gidaris et al., 2018), relative patch position (Doersch et al., 2015), recovering color channels (Zhang et al., 2016), solving jigsaw puzzles (Noroozi & Favaro, 2016), and discriminating images created from distortion (Dosovitskiy et al., 2015). We refer to the above procedures as reconstruction-based SSL. Another popular paradigm is contrastive learning (Chen et al., 2020b;c). The idea is to learn representations that bring similar data points closer while pushing randomly selected points further away (Wang & Gupta, 2015; Logeswaran & Lee, 2018; Arora et al., 2019), or to maximize a contrastive-based mutual information lower bound between different views (Hjelm et al., 2018; Oord et al., 2018; Tian et al., 2019). A popular approach in the text domain is based on language modeling, where models like BERT and GPT create auxiliary tasks for next-word prediction (Devlin et al., 2018; Radford et al., 2018). The natural ordering or topology of data is also exploited in video-based (Wei et al., 2018; Misra et al., 2016; Fernando et al., 2017), graph-based (Yang et al., 2020; Hu et al., 2019), and map-based (Zhang et al., 2019) self-supervised learning. For instance, one pretext task is to determine the correct temporal order of video frames, as in Misra et al. (2016).

Theory for self-supervised learning: Our work initiates some theoretical understanding of reconstruction-based SSL. Related to our work is the recent theoretical analysis of contrastive learning. Arora et al. (2019) show guarantees for representations from contrastive learning on linear classification tasks using a class conditional independence assumption, but do not handle approximate conditional independence. Recently, Tosh et al. (2020a) show that contrastive learning representations can linearly recover any continuous function of the underlying topic posterior under a topic modeling assumption for text. While their assumption bears some similarity to ours, the assumption of independent sampling of words that they exploit is strong and does not generalize to other domains like images. More recently, concurrent work by Tosh et al. (2020b) shows guarantees for contrastive learning, but not reconstruction-based SSL, under a multi-view redundancy assumption that is very similar to our CI assumption. Wang & Isola (2020) theoretically study contrastive learning on the hypersphere through intuitive properties like alignment and uniformity of representations; however, no theoretical connection is made to downstream tasks. There is a mutual information maximization view of contrastive learning, but Tschannen et al. (2019) point out issues with it. Previous attempts to explain negative-sampling-based methods (Mikolov et al., 2013) use the theory of noise contrastive estimation (Gutmann & Hyvärinen, 2010; Ma & Collins, 2018); however, the guarantees are only asymptotic and not for downstream tasks. CI is also used in sufficient dimension reduction (Fukumizu et al., 2004; 2009). CI and redundancy assumptions on multiple views (Kakade & Foster, 2007; Ando & Zhang, 2007) are used to analyze a canonical-correlation-based dimension reduction algorithm. Finally, Alain & Bengio (2014) and Vincent (2011) provide theoretical analyses of denoising auto-encoders.

1.2. OVERVIEW OF RESULTS

Section 2 introduces notation, the setup, and the self-supervised learning procedure considered in this work. In Section 3, we analyze downstream sample complexity under CI. Section 4 presents our main result under relaxed conditions: ACI with latent variables, finite samples in both pretext and downstream tasks, various function classes, and both regression and classification tasks. Experiments verifying our theoretical findings are in Section 5.


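The conditional-independence mechanism described in the introduction can be illustrated with a minimal simulation, in the spirit of the synthetic experiments of Section 5. This is a sketch, not the paper's experimental code: the dimensions, class means, and sample sizes below are hypothetical choices. We generate data where the pretext target X2 is independent of the input X1 given the label Y, learn a representation by linearly regressing X2 on X1 (no labels used), and then fit a linear classifier on that representation with only a handful of labeled examples.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data satisfying conditional independence: given the label Y,
# the pretext target X2 (a 2-d "color") is independent of the input X1.
# All dimensions and means are illustrative, not the paper's settings.
k, d, n_pretext, n_down, n_test = 3, 20, 10_000, 50, 2_000
means = 3.0 * rng.normal(size=(k, d))                 # class-conditional means of X1
colors = 3.0 * np.array([[1, 0], [0, 1], [-1, -1.]])  # class-conditional means of X2

def sample(n):
    y = rng.integers(0, k, size=n)
    x1 = means[y] + rng.normal(size=(n, d))
    x2 = colors[y] + rng.normal(size=(n, 2))          # X2 independent of X1 given Y
    return x1, x2, y

# Pretext task: linear regression predicting X2 from X1 (unlabeled data).
X1, X2, _ = sample(n_pretext)
W, *_ = np.linalg.lstsq(X1, X2, rcond=None)
psi = lambda x: x @ W                                 # learned 2-d representation

# Downstream task: fit a linear classifier on psi(X1) with few labels.
X1d, _, yd = sample(n_down)
V, *_ = np.linalg.lstsq(psi(X1d), np.eye(k)[yd], rcond=None)

# Evaluate on fresh data.
X1t, _, yt = sample(n_test)
acc = np.mean(np.argmax(psi(X1t) @ V, axis=1) == yt)
print(f"downstream accuracy with {n_down} labels: {acc:.2f}")
```

Because solving the pretext task requires implicitly predicting Y, the two-dimensional representation psi(X1) clusters by class, so a linear classifier trained on just 50 labeled points classifies fresh data accurately, far fewer labels than would be needed to learn a classifier directly on the 20-dimensional input.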