SELF-SUPERVISED REPRESENTATION LEARNING WITH RELATIVE PREDICTIVE CODING

Abstract

This paper introduces Relative Predictive Coding (RPC), a new contrastive representation learning objective that maintains a good balance among training stability, minibatch size sensitivity, and downstream task performance. The key to the success of RPC is two-fold. First, RPC introduces the relative parameters to regularize the objective for boundedness and low variance. Second, RPC contains no logarithm or exponential score functions, which are the main cause of training instability in prior contrastive objectives. We empirically verify the effectiveness of RPC on benchmark vision and speech self-supervised learning tasks. Lastly, we relate RPC with mutual information (MI) estimation, showing that RPC can be used to estimate MI with low variance.

1. INTRODUCTION

Unsupervised learning has drawn tremendous attention recently because it can extract rich representations without label supervision. Self-supervised learning, a subset of unsupervised learning, learns representations by allowing the data to provide supervision (Devlin et al., 2018). Among its mainstream strategies, self-supervised contrastive learning has been successful in visual object recognition (He et al., 2020; Tian et al., 2019; Chen et al., 2020c), speech recognition (Oord et al., 2018; Rivière et al., 2020), language modeling (Kong et al., 2019), graph representation learning (Velickovic et al., 2019) and reinforcement learning (Kipf et al., 2019). The idea of self-supervised contrastive learning is to learn latent representations such that related instances (e.g., patches from the same image; defined as positive pairs) will have representations within close distance, while unrelated instances (e.g., patches from two different images; defined as negative pairs) will have distant representations (Arora et al., 2019). Prior work has formulated contrastive learning objectives as maximizing the divergence between the distributions of related and unrelated instances. In this regard, different divergence measurements often lead to different loss function designs. For example, variational mutual information (MI) estimation (Poole et al., 2019) inspires Contrastive Predictive Coding (CPC) (Oord et al., 2018). Note that MI is also the KL-divergence between the distributions of related and unrelated instances (Cover & Thomas, 2012). While the choices of contrastive learning objectives are abundant (Hjelm et al., 2018; Poole et al., 2019; Ozair et al., 2019), we point out that there are three challenges faced by existing methods. The first challenge is training stability, where an unstable training process with high variance may be problematic. For example, Hjelm et al. (2018); Tschannen et al. (2019); Tsai et al.
(2020b) show that contrastive objectives with large variance cause numerical issues and yield poor downstream performance with their learned representations. The second challenge is sensitivity to minibatch size, where objectives requiring a huge minibatch size may restrict their practical usage. For instance, SimCLRv2 (Chen et al., 2020c) utilizes CPC as its contrastive objective and reaches state-of-the-art performance on multiple self-supervised and semi-supervised benchmarks. Nonetheless, the objective is trained with a minibatch size of 8,192, and this scale of training requires enormous computational power. The third challenge is downstream task performance, which is the one that we would like to emphasize the most. For this reason, in most cases, CPC is the objective that we would adopt for contrastive representation learning, due to its favorable performance in downstream tasks (Tschannen et al., 2019; Baevski et al., 2020).

Table 1: Overview of different contrastive learning objectives. P_XY represents the distribution of related samples (positively-paired), and P_X P_Y represents the distribution of unrelated samples (negatively-paired). f(x, y) ∈ F for F being any class of functions f : X × Y → R. †: Compared to J_CPC and J_RPC, we empirically find J_WPC performs worse on complex real-world image datasets spanning CIFAR-10/-100 (Krizhevsky et al., 2009) and ImageNet (Russakovsky et al., 2015).

Relating to KL-divergence between P_XY and P_X P_Y: J_DV (Donsker & Varadhan, 1975), J_NWJ (Nguyen et al., 2010), and J_CPC (Oord et al., 2018):
  J_DV(X, Y) := sup_{f∈F} E_{P_XY}[f(x, y)] − log(E_{P_X P_Y}[e^{f(x, y)}])
  J_NWJ(X, Y) := sup_{f∈F} E_{P_XY}[f(x, y)] − E_{P_X P_Y}[e^{f(x, y) − 1}]
  J_CPC(X, Y) := sup_{f∈F} E_{(x, y_1)∼P_XY, {y_j}_{j=2}^N ∼ P_Y}[log(e^{f(x, y_1)} / (1/N) Σ_{j=1}^N e^{f(x, y_j)})]

Relating to JS-divergence between P_XY and P_X P_Y: J_JS (Nowozin et al., 2016):
  J_JS(X, Y) := sup_{f∈F} E_{P_XY}[−log(1 + e^{−f(x, y)})] − E_{P_X P_Y}[log(1 + e^{f(x, y)})]

Relating to Wasserstein-divergence between P_XY and P_X P_Y: J_WPC (Ozair et al., 2019), with F_L denoting the space of 1-Lipschitz functions:
  J_WPC(X, Y) := sup_{f∈F_L} E_{(x, y_1)∼P_XY, {y_j}_{j=2}^N ∼ P_Y}[log(e^{f(x, y_1)} / (1/N) Σ_{j=1}^N e^{f(x, y_j)})] †

Relating to χ²-divergence between P_XY and P_X P_Y: J_RPC (ours):
  J_RPC(X, Y) := sup_{f∈F} E_{P_XY}[f(x, y)] − α E_{P_X P_Y}[f(x, y)] − (β/2) E_{P_XY}[f²(x, y)] − (γ/2) E_{P_X P_Y}[f²(x, y)]

This paper presents a new contrastive representation learning objective: the Relative Predictive Coding (RPC), which attempts to achieve a good balance among these three challenges: training stability, sensitivity to minibatch size, and downstream task performance. At the core of RPC are the relative parameters, which are used to regularize RPC for its boundedness and low variance. From a modeling perspective, the relative parameters act as an ℓ2 regularization for RPC. From a statistical perspective, the relative parameters prevent RPC from growing to extreme values, as well as upper bound its variance. In addition to the relative parameters, RPC contains no logarithm and exponential, which are the main cause of training instability in prior contrastive learning objectives (Song & Ermon, 2019). To empirically verify the effectiveness of RPC, we consider benchmark self-supervised representation learning tasks, including visual object classification on CIFAR-10/-100 (Krizhevsky et al., 2009), STL-10 (Coates et al., 2011), and ImageNet (Russakovsky et al., 2015), and speech recognition on LibriSpeech (Panayotov et al., 2015). Comparing RPC to prior contrastive learning objectives, we observe lower variance during training, lower minibatch size sensitivity, and consistent performance improvements. Lastly, we also relate RPC with MI estimation, empirically showing that RPC can estimate MI with low variance.
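As a concrete illustration, the J_RPC objective in Table 1 can be estimated from a minibatch of scores. The sketch below is a minimal numpy version, assuming a score matrix with scores[i, j] = f(x_i, y_j), whose diagonal entries act as positive pairs (samples from P_XY) and off-diagonal entries act as negative pairs (samples from P_X P_Y); the values of the relative parameters alpha, beta, gamma here are illustrative placeholders, not tuned settings from the paper.

```python
import numpy as np

def rpc_objective(scores, alpha=1.0, beta=0.5, gamma=0.5):
    """Empirical estimate of J_RPC from a minibatch score matrix.

    scores[i, j] = f(x_i, y_j); diagonal entries are positively-paired,
    off-diagonal entries are negatively-paired. alpha, beta, gamma are
    the relative parameters (values here are illustrative only).
    """
    n = scores.shape[0]
    pos = np.diag(scores)                 # scores of samples from P_XY
    neg = scores[~np.eye(n, dtype=bool)]  # scores of samples from P_X P_Y
    return (pos.mean()
            - alpha * neg.mean()
            - 0.5 * beta * (pos ** 2).mean()
            - 0.5 * gamma * (neg ** 2).mean())
```

Training maximizes this quantity over the score function f; note how the squared terms penalize extreme score values, which is what keeps the estimate bounded and its variance controlled.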

2. PROPOSED METHOD

This paper presents a new contrastive representation learning objective -- the Relative Predictive Coding (RPC). At a high level, RPC 1) introduces the relative parameters to regularize the objective for boundedness and low variance; and 2) achieves a good balance among the three challenges in contrastive representation learning objectives: training stability, sensitivity to minibatch size, and downstream task performance. We begin by describing prior contrastive objectives along with their limitations on the three challenges in Section 2.1. Then, we detail our presented objective and its modeling benefits in Section 2.2. An overview of different contrastive learning objectives is provided in Table 1. We defer all proofs to the Appendix.

Notation

We use an uppercase letter to denote a random variable (e.g., X), a lowercase letter to denote an outcome of this random variable (e.g., x), and a calligraphic letter to denote the sample space of this random variable (e.g., X). Next, if the samples (x, y) are related (or positively-paired), we write (x, y) ∼ P_XY with P_XY being the joint distribution of X × Y. If the samples (x, y) are unrelated (negatively-paired), we write (x, y) ∼ P_X P_Y with P_X P_Y being the product of marginal distributions over X × Y. Last, we define f ∈ F for F being any class of functions f : X × Y → R.
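In a minibatch implementation, samples from P_XY and P_X P_Y are typically obtained by pairing aligned entries of two batches versus mismatched entries. The short sketch below illustrates this notation; the arrays and the way y is generated from x are assumptions for exposition only, not a construction from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
batch = 4
x = rng.normal(size=(batch, 3))             # representations of X
y = x + 0.1 * rng.normal(size=(batch, 3))   # related representations of Y

# Aligned pairs (x_i, y_i) approximate draws from the joint P_XY.
positive_pairs = [(x[i], y[i]) for i in range(batch)]

# Mismatched pairs (x_i, y_j), j != i, approximate draws from P_X P_Y.
negative_pairs = [(x[i], y[j])
                  for i in range(batch)
                  for j in range(batch) if i != j]
```

For a batch of size N this yields N positive pairs and N(N − 1) negative pairs, which is the sampling scheme the contrastive objectives in Table 1 rely on.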

2.1. PRELIMINARY

Contrastive representation learning encourages contrast between the positive and the negative pairs of representations from the related data X and Y. Specifically, when sampling a pair



Project page: https://github.com/martinmamql/relative_predictive_coding




