TOWARDS THE GENERALIZATION OF CONTRASTIVE SELF-SUPERVISED LEARNING

Abstract

Recently, self-supervised learning has attracted great attention, since it only requires unlabeled data for model training. Contrastive learning is one popular method for self-supervised learning and has achieved promising empirical performance. However, the theoretical understanding of its generalization ability is still limited. To this end, we define a kind of (σ, δ)-measure to mathematically quantify the data augmentation, and then provide an upper bound of the downstream classification error rate based on the measure. It reveals that the generalization ability of contrastive self-supervised learning is related to three key factors: alignment of positive samples, divergence of class centers, and concentration of augmented data. The first two factors are properties of learned representations, while the third one is determined by pre-defined data augmentation. We further investigate two canonical contrastive losses, InfoNCE and cross-correlation, to show how they provably achieve the first two factors. Moreover, we conduct experiments to study the third factor, and observe a strong correlation between downstream performance and the concentration of augmented data.

1. INTRODUCTION

Contrastive Self-Supervised Learning (SSL) has attracted great attention for its remarkable data efficiency and generalization ability in computer vision (He et al., 2020; Chen et al., 2020a;b; Grill et al., 2020; Chen & He, 2021; Zbontar et al., 2021) and natural language processing (Fang et al., 2020; Wu et al., 2020; Giorgi et al., 2020; Gao et al., 2021; Yan et al., 2021). It learns representations from large amounts of unlabeled data together with manually designed supervision signals, i.e., regarding the augmented views of a data sample as positive samples. The model is updated by encouraging the features of positive samples to be close to each other. To overcome the feature collapse issue, various losses (e.g., InfoNCE (Chen et al., 2020a; He et al., 2020) and cross-correlation (Zbontar et al., 2021)) and training strategies (e.g., stop gradient (Grill et al., 2020; Chen & He, 2021)) have been proposed. Despite the empirical success of contrastive SSL in terms of its generalization ability on downstream tasks, the theoretical understanding is still limited. Arora et al. (2019) propose a theoretical framework to show the provable downstream performance of contrastive SSL based on the InfoNCE loss. However, their results rely on the assumption that positive samples are drawn from the same latent class, instead of being augmented views of a data point as in practice. Wang & Isola (2020) propose alignment and uniformity to explain the downstream performance, but these are empirical indicators and lack theoretical generalization guarantees. Both of the above works avoid characterizing the important role of data augmentation, which is the key to the success of contrastive SSL, since the only human knowledge is injected via data augmentation. Recently, HaoChen et al. (2021) propose to model the augmented data as a graph and study contrastive SSL from a matrix decomposition perspective, but their analysis applies only to their own spectral contrastive loss.
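As a concrete reference for the losses discussed above, the following is a minimal NumPy sketch of the InfoNCE loss on two batches of embeddings, where `z1[i]` and `z2[i]` are assumed to be the two augmented views (a positive pair) of sample i. The function name, the temperature value, and the in-batch-negatives setup follow common SimCLR-style practice and are illustrative assumptions, not a specific codebase's implementation.

```python
import numpy as np

def info_nce(z1, z2, tau=0.5):
    """Minimal InfoNCE sketch: z1[i] and z2[i] form a positive pair;
    for each anchor z1[i], all z2[j] with j != i serve as negatives."""
    # L2-normalize so that inner products are cosine similarities.
    z1 = z1 / np.linalg.norm(z1, axis=1, keepdims=True)
    z2 = z2 / np.linalg.norm(z2, axis=1, keepdims=True)
    sim = z1 @ z2.T / tau  # temperature-scaled similarity matrix
    # Log-softmax over each row; the positive logit sits on the diagonal.
    log_prob = sim - np.log(np.exp(sim).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_prob))
```

Minimizing this loss pulls each positive pair together (alignment) while pushing the anchor away from the other samples in the batch, which is the mechanism the theories above try to explain.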
Besides the limitations of existing contrastive SSL theories, there are also interesting empirical observations that have not yet been explained theoretically. For example, why does richer data augmentation lead to better downstream performance (Figure 1)?

In this paper, we focus on provably exploring the generalization ability of contrastive SSL, which can explain the above observations. We start with understanding the role of data augmentation in contrastive SSL. Intuitively, samples from the same latent class are likely to have similar augmented views, which are mapped to close locations in the embedding space. Since the augmented views of each sample are encouraged to cluster in the embedding space by contrastive learning, different samples from the same latent class tend to be pulled closer. As an example, consider two images of dogs with different backgrounds (Figure 2). If we augment them with the transformation "crop", we may obtain two similar views (dog heads), whose representations (gray points in the embedding space) are close. As the augmented views of each dog image are enforced to be close in the embedding space by the contrastive objective, the representations of the two dog images (green and blue points) will be pulled closer to their augmented views (gray points). In this way, aligning positive samples gathers samples from the same class, and thus results in a clustered embedding space. Following this intuition, we define the augmented distance between two samples as the minimum distance between their augmented views, and further introduce the (σ, δ)-augmentation to measure the concentration of augmented data, i.e., for each latent class, the proportion of samples located in a ball with diameter δ (w.r.t. the augmented distance) is larger than σ. With this mathematical description of data augmentation settled, we then prove an upper bound on the downstream classification error rate in Section 3.
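The augmented distance and the (σ, δ)-concentration described above can be sketched numerically as follows. This is an illustrative approximation, not the paper's formal definition: the `augment` function, the number of sampled views, and the use of Euclidean distance are assumptions; the minimum over all admissible augmentations is estimated by Monte-Carlo sampling; and the diameter-δ ball is approximated by a δ-neighborhood around the best "center" sample.

```python
import numpy as np

def augmented_distance(x, y, augment, n_views=10, rng=None):
    """Estimate d_A(x, y): the minimum Euclidean distance between any
    pair of sampled augmented views of x and y."""
    if rng is None:
        rng = np.random.default_rng()
    vx = np.stack([augment(x, rng) for _ in range(n_views)])
    vy = np.stack([augment(y, rng) for _ in range(n_views)])
    diffs = vx[:, None, :] - vy[None, :, :]          # all view pairs
    return np.sqrt((diffs ** 2).sum(-1)).min()

def concentration_sigma(samples, delta, augment, n_views=10, seed=0):
    """For one latent class, estimate sigma: the largest fraction of
    samples lying within augmented distance delta of some sample
    (a neighborhood-based proxy for the diameter-delta ball)."""
    rng = np.random.default_rng(seed)
    n = len(samples)
    d = np.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            d[i, j] = d[j, i] = augmented_distance(
                samples[i], samples[j], augment, n_views, rng)
    return max((d[i] <= delta).mean() for i in range(n))
```

Under this proxy, a richer augmentation (one that maps intra-class samples to overlapping views) shrinks the augmented distances and therefore raises σ for a fixed δ, matching the intuition that sharper concentration yields a better bound.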
It reveals that the generalization of contrastive SSL is related to three key factors. The first one is alignment of positive samples, which is a common objective that contrastive learning algorithms aim to optimize. The second one is divergence of class centers, which prevents the collapse of representations. The third factor is concentration of augmented data, i.e., a sharper concentration of augmented data indicates a better generalization error bound. We remark that the first two factors are properties of representations that can be optimized during the learning process, whereas the third factor is determined by pre-defined data augmentation and is independent of the learning process. Thus, data augmentation plays a crucial role in contrastive SSL. We then study the above three factors in more depth. In Section 4, we rigorously prove that not only the InfoNCE loss but also the cross-correlation loss (which does not directly optimize the geometry of the embedding space) can satisfy the first two factors. For the third factor, we conduct various experiments on real-world datasets in Section 5 and observe that the downstream performance of contrastive SSL is highly correlated with the concentration of augmented data. In summary, our contributions include: 1) proposing a novel (σ, δ)-measure to quantify data augmentation; 2) presenting a theoretical framework for contrastive SSL that relates generalization to the alignment of positive samples, the divergence of class centers, and the concentration of augmented data; 3) proving that both the InfoNCE and cross-correlation losses achieve the first two factors; and 4) empirically demonstrating a strong correlation between the concentration of augmented data and downstream performance.
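To make the second loss family concrete, the following is a minimal NumPy sketch of a cross-correlation loss in the style of Barlow Twins (Zbontar et al., 2021). The batch-normalization step and the identity-matrix target follow the published formulation, but the code and the `lam` value are an illustrative version, not the authors' implementation.

```python
import numpy as np

def cross_correlation_loss(z1, z2, lam=0.005):
    """Cross-correlation loss sketch: push the cross-correlation matrix
    of the two views' embeddings toward the identity. The diagonal term
    aligns positive pairs; the off-diagonal term decorrelates feature
    dimensions, which prevents collapsed representations."""
    n = z1.shape[0]
    # Batch-normalize each embedding dimension.
    z1 = (z1 - z1.mean(0)) / z1.std(0)
    z2 = (z2 - z2.mean(0)) / z2.std(0)
    c = z1.T @ z2 / n                        # cross-correlation matrix
    on_diag = ((np.diag(c) - 1) ** 2).sum()  # alignment term
    off_diag = (c ** 2).sum() - (np.diag(c) ** 2).sum()  # decorrelation
    return on_diag + lam * off_diag
```

Note that, unlike InfoNCE, this objective never compares different samples' embeddings directly, which is why it is not obvious a priori that it controls the alignment and divergence factors; Section 4 makes that connection rigorous.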



Figure 1: SimCLR's embedding space with different richnesses of data augmentations on CIFAR-10.

