A THEORETICAL STUDY OF INDUCTIVE BIASES IN CONTRASTIVE LEARNING

Abstract

Understanding self-supervised learning is important but challenging. Previous theoretical works study the role of pretraining losses and view neural networks as general black boxes. However, the recent work of Saunshi et al. (2022) argues that the model architecture, a component largely ignored by previous works, also has significant influence on the downstream performance of self-supervised learning. In this work, we provide the first theoretical analysis of self-supervised learning that incorporates the effect of inductive biases originating from the model class. In particular, we focus on contrastive learning, a popular self-supervised learning method that is widely used in the vision domain. We show that when the model has limited capacity, contrastive representations recover certain special clustering structures that are compatible with the model architecture, but ignore many other clustering structures in the data distribution. As a result, our theory can capture the more realistic setting where contrastive representations have much lower dimensionality than the number of clusters in the data distribution. We instantiate our theory on several synthetic data distributions and provide empirical evidence to support the theory.

1. INTRODUCTION

Recent years have witnessed the effectiveness of pre-trained representations, which are learned on unlabeled data with self-supervised losses and then adapted to a wide range of downstream tasks.



Figure 1: A simplified version of the synthetic example proposed in Saunshi et al. (2022). The orange points are the original data and the blue points are augmented data (obtained by adding noise in the spurious dimension); the dimension invariant to augmentation is desired. Edges represent positive pairs that are constructed from augmentation. We say a real-valued function implements a cluster if it outputs 1 on the cluster and 0 on all other data; here implementing means matching the exact value, rather than merely matching the label after applying some linear threshold. The figure shows two possible ways to partition the data into two clusters, but only the one on the left-hand side (which captures the invariant dimension) is implementable by a linear function. Black numbers indicate the target output on the data, and green numbers indicate the output of the implementing function, which extrapolates outside of the data support. Note that the linear model is not composed with a threshold function. The partition on the right-hand side is not implementable because any linear model that outputs constant 1 on the upper-left small cluster would also output 1 on the bottom-left small cluster due to linear extrapolation; red numbers indicate the outputs of the linear function that contradict the target.

Our analysis builds on the framework of HaoChen et al. (2021) on contrastive learning and can be seen as a refinement of their results by further characterizing the model architecture's impact on the learned representations. We recall that HaoChen et al. (2021) show that contrastive learning, with sufficient data and a parameterized model class of finite complexity, is equivalent to spectral clustering on a so-called population positive-pair graph, where nodes are augmented images and an edge between the nodes x and x′ is weighted according to the probability of encountering (x, x′) as a positive pair.
They essentially assume that the positive-pair graph contains several major semantically meaningful clusters, and prove that contrastive representations exhibit a corresponding clustering structure in Euclidean space, that is, images with relatively small graph distance have nearby representations. Their results rely heavily on the clustering property of the graph: the representation dimensionality and pre-training sample complexity both scale with the number of clusters. The important recent work of Saunshi et al. (2022), however, demonstrates with a synthetic setting that contrastive learning can provably work with linear model architectures even if the number of clusters is huge (e.g., exponential in the dimensionality). Beyond the simple synthetic example discussed in their paper, no previous work formally characterizes this effect in a general setting.

In this work, we develop a general theory that leverages the inductive bias to avoid the dependency on the potentially huge number of clusters: although there may exist a large number of clusters in the positive-pair graph, the number of clusters implementable by the model (which we call minimal implementable clusters) could be much smaller, even exponentially so. Figure 1 shows an example where a linear function can implement one clustering structure but not the other, despite both being valid clusterings of the positive-pair graph. It is possible that a minimal implementable cluster consists of multiple well-separated sub-clusters but none of these sub-clusters can be implemented by the model class. We show that contrastive representations only recover the clustering structures that are compatible with the model class, hence low-dimensional contrastive representations work well on downstream tasks. Concretely, suppose the number of minimal implementable clusters is m, which can be much smaller than the number m̄ of natural clusters in the graph. HaoChen et al. (2021) prove the efficacy of contrastive learning assuming the representation dimensionality (hence also the sample complexity) is larger than m̄. We develop a new theory (Theorem 1) that makes the representation dimensionality depend only on m instead of m̄. We also extend this result to a more complex setting with even more structured clusters, e.g., when there are 2^s clusters with certain geometric structures, the representation dimensionality can scale with only s instead of 2^s. See Theorem 2 and its instantiation on Example 1 for this result.

We instantiate our theory on several synthetic data distributions and show that contrastive learning with appropriate model architectures can reduce the representation dimensionality, allowing better sample complexity. We consider a data distribution on a hypercube, first proposed by Saunshi et al. (2022), which contains a small subspace of features that are invariant to data augmentation and a large subspace of spurious features. When the function class is linear, we show that the contrastive representations can solve downstream binary classification tasks if the downstream label depends on only one dimension of the invariant features (Theorem 3). When the function class is ReLU networks (hence more expressive), we show that the contrastive representations can solve more diverse downstream classification problems where the label can depend on all invariant features (Theorem 4). We also provide examples for Lipschitz-continuous function classes (Theorem 5) and convolutional neural networks (Theorem 6). Finally, we provide experimental results to support our theory: we propose a method to test the number of implementable clusters of ResNet-18 on the CIFAR-10 dataset, and show that there are indeed only a small number of implementable clusters under the model architecture constraint (Section B).

2. RELATED WORKS

The empirical success of contrastive learning has attracted a series of theoretical works that study the contrastive loss (Arora et al., 2019; HaoChen et al., 2021; 2022; Tosh et al., 2020; 2021; Lee et al., 2020; Wang et al., 2021; Nozawa & Sato, 2021; Ash et al., 2022; Tian, 2022) , most of which treat the model class as a black box except for Lee et al. (2020) which studies the learned representation with linear models, and Tian (2022) and Wen & Li (2021) which study the training dynamics of contrastive learning for linear and 2-layer ReLU networks. Several theoretical works also study non-contrastive methods for self-supervised representation learning (Wen & Li, 2022; Tian et al., 2021; Garrido et al., 2022; Balestriero & LeCun, 2022) . There are also works theoretically studying self-supervised learning in other domains such as language modeling (Wei et al., 2021; Xie et al., 2021; Saunshi et al., 2020) .

3. FROM CLUSTERS TO MINIMAL IMPLEMENTABLE CLUSTERS

In this section, we introduce our main theoretical results regarding the role of inductive biases of architectures in contrastive learning. Recall that contrastive learning encourages two different views of the same input (also called a positive pair) to have similar representations, while two random views of two different inputs (also called a negative pair) have representations that are far from each other. Formally, we use p_data to denote the distribution of a random view of a random input, p_pos to denote the distribution of a random positive pair, and X to denote the support of p_data. For instance, X is the set of all augmentations of all images for visual representation learning. Following the setup of HaoChen et al. (2022), for a representation map f : X → R^k, where k is the representation dimensionality, we learn the contrastive representation by minimizing the following generalized spectral contrastive loss:

L_λ(f) := E_{(x,x⁺)∼p_pos}[‖f(x) − f(x⁺)‖₂²] + λ · R(f),   where   R(f) := ‖E_{x∼p_data}[f(x)f(x)⊤] − I‖_F²,

and λ > 0 is a hyperparameter indicating the regularization strength; the regularizer R(f) normalizes the representation covariance towards the identity matrix. This loss is very similar to the popular Barlow Twins loss (Zbontar et al., 2021) and has been shown to work well empirically (HaoChen et al., 2021). Theoretically, the prior work proposes the notion of the positive-pair graph, with X being the vertex set and an edge between the nodes x and x′ weighted according to the probability of encountering (x, x′) as a positive pair (i.e., p_pos(x, x′)). This graph is defined on the population data and intuitively captures the semantic relationship between different data: when the positive pairs are formed by applying data augmentation to the same natural data, it is expected that datapoints in the same cluster of the positive-pair graph have similar semantic meanings.

[Figure 3: illustration of a positive pair and the positive-pair graph.]
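As a concrete illustration, the generalized spectral contrastive loss above can be estimated from a batch of positive pairs. The sketch below is our own NumPy illustration (function and variable names are ours, not the authors' code):

```python
import numpy as np

def spectral_contrastive_loss(z, z_pos, lam):
    """z, z_pos: (n, k) arrays holding f(x) and f(x+) for n positive pairs;
    lam: the regularization strength lambda."""
    n, k = z.shape
    # Invariance term: E ||f(x) - f(x+)||_2^2 over positive pairs.
    invariance = np.mean(np.sum((z - z_pos) ** 2, axis=1))
    # Regularizer R(f): || E[f(x) f(x)^T] - I ||_F^2, pooling both views.
    pooled = np.concatenate([z, z_pos], axis=0)
    cov = pooled.T @ pooled / (2 * n)
    reg = np.sum((cov - np.eye(k)) ** 2)
    return invariance + lam * reg

# A representation that is exactly invariant on positive pairs and whitened
# (E[f f^T] = I) attains the minimum value 0 (up to floating-point error).
z = np.sqrt(2) * np.eye(2)[np.repeat([0, 1], 4)]  # cluster-indicator features
print(spectral_contrastive_loss(z, z.copy(), lam=1.0))  # ~0
```

The toy representation assigns each point a scaled one-hot cluster indicator, which is both invariant across positive pairs and whitened, matching the intuition that cluster indicators minimize this loss.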
Figure 3 gives a demonstration of the positive-pair graph. Their analysis shows that learning contrastive representations with the above loss is equivalent to spectral clustering (Ng et al., 2001) on this positive-pair graph, hence can learn meaningful representations when the graph has clustering structure. Different from the prior work, we study the representation map that minimizes the contrastive loss within a certain function class F. Here we assume functions in F map data in X to representations in R^k for some dimensionality k. The main contribution of our result is the improvement of k due to the consideration of this specific function class: by studying the representation learned within a constrained model class F, we show that the necessary representation dimensionality k is much smaller than that required in the prior work. As a result, the sample complexity for the downstream labeled task is improved compared to the prior work.

Let {S₁, S₂, …, S_m} be an m-way partition of X, i.e., disjoint non-empty subsets of X such that X = ∪_{i∈[m]} S_i. For any x ∈ X, let id_x be the index such that x ∈ S_{id_x}. We consider a partition of the graph such that there is not much connection between any two clusters, which is formalized by the following assumption.

Assumption 1 (α-separability). The probability of a positive pair belonging to two different sets is less than α: Pr_{(x,x⁺)∼p_pos}(id_x ≠ id_{x⁺}) ≤ α.

We consider downstream tasks that are r-way classification problems with label function y(·) : X → [r]. We assume that the downstream task aligns with the clusters:

Assumption 2. The downstream label y(x) is constant on each S_i.

Our key assumptions about the function class are that it can implement desirable clustering structures (Assumption 4) but cannot break the positive-pair graph into too many clusters (Assumption 3).
Let S ⊂ X be a subset of X, p_data^S be the distribution p_data restricted to the set S, and p_pos^S be the positive-pair distribution p_pos conditioned on both datapoints in the pair belonging to S. For any function g : S → R, we define the following expansion quantity:

Q_S(g) := E_{(x,x⁺)∼p_pos^S}[(g(x) − g(x⁺))²] / E_{x∼p_data^S, x′∼p_data^S}[(g(x) − g(x′))²].   (4)

We let Q_S(g) = ∞ if the denominator is 0. The numerator represents the discrepancy within a random positive pair, and the denominator represents the global variance of g. Intuitively, a smaller value of Q_S(g) means that g does a better job of separating the set S into disjoint sub-clusters, and hence implements an inner-cluster connection structure that is sparse. For instance, if S contains two disjoint sub-clusters and g takes a different constant value on each of them, then Q_S(g) = 0. On the other hand, if S is densely connected, then Q_S(g) > 0 regardless of the choice of g. The first assumption about the function class F states that no function in the class can break one cluster into two well-separated sub-clusters:

Assumption 3 (F-implementable inner-cluster connection larger than β). For any function f ∈ F and any linear head w ∈ R^k, let g(x) = w⊤f(x). For any i ∈ [m], we have Q_{S_i}(g) ≥ β.

We note that when the function class F contains all functions from X to R^k, Assumption 3 essentially says that each of {S₁, S₂, …, S_m} has large internal expansion, hence recovers Assumption 3.5 in HaoChen et al. (2021). However, when F has limited capacity, each cluster S_i can still contain well-separated sub-clusters, as long as those sub-clusters cannot be implemented by functions in F. Assumption 3 implies that the function class cannot be too expressive. However, in order for the learned representation map to be useful for downstream tasks, it needs to be expressive enough to represent the useful information.
Thus, we introduce the following assumption on the function class.

Assumption 4 (Implementability). Recall that id_x is the index such that x ∈ S_{id_x}. There exists a function f ∈ F such that f(x) = e_{id_x} for all x ∈ X, where e_i ∈ R^m is the vector whose i-th dimension is 1 and whose other dimensions are 0.

When both Assumption 3 and Assumption 4 hold, we say {S₁, S₂, …, S_m} are minimal implementable clusters with respect to F. We also introduce the following Assumption 5, which holds for any function class implemented by a neural network whose last layer is linear. We note that this assumption is needed only for the technical rigour of the proof, and is not essential to the conceptual message of our theory.

Assumption 5 (Closure under scaling). For any function f ∈ F and vector u ∈ R^k, define the function f′(x) = u ⊙ f(x), where ⊙ denotes element-wise product. Then f′ ∈ F.

Let P_min := min_{i∈[m]} Pr_{x∼p_data}(x ∈ S_i) and P_max := max_{i∈[m]} Pr_{x∼p_data}(x ∈ S_i) be the probabilities of the smallest and largest sets, respectively. Under the above assumptions, we have the following theorem, which shows that learning a representation map within F with representation dimensionality k = m can solve the downstream task:

Theorem 1. Suppose {S₁, S₂, …, S_m} are minimal implementable clusters with respect to F (i.e., Assumptions 1 and 3 hold), and the function class F satisfies Assumptions 4 and 5. For λ > α/P_min, consider a learned representation map f = arg min_{f∈F} L_λ(f) that minimizes the contrastive loss. Then, when k = m, for any downstream task that satisfies Assumption 2, there exists a linear head W ∈ R^{r×k} which achieves downstream error

E_{x∼p_data}[‖W f(x) − e_{y(x)}‖₂²] ≤ (α/β) · P_max/(P_min − α).

We note that P_max ≈ P_min when the partition is balanced.
Thus, as long as α ≪ P_min (i.e., the probability of a positive pair crossing different clusters is smaller than the probability of it containing data from the smallest cluster), the right-hand side is roughly α/β. Hence, when the inter-cluster connection α is smaller than the implementable inner-cluster connection β, the downstream accuracy is high.

Comparison with HaoChen et al. (2021). We note that our result requires k = m, whereas HaoChen et al. (2021) provide an analysis in a more general setting for any k that is large enough. Thus, when the function class F is the set of all functions, our theorem recovers a special case of HaoChen et al. (2021). Our result requires a stricter choice of k mainly because when F has limited capacity, a higher-dimensional feature may contain many "wrong features" while omitting the "right features", a phenomenon that does not occur when F contains all functions.
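For intuition, the expansion quantity Q_S(g) of equation (4) can be estimated by Monte Carlo from positive pairs and i.i.d. samples inside S. The sketch below is our own illustration; it uses the identity E_{x,x′}[(g(x) − g(x′))²] = 2·Var(g) for i.i.d. x, x′ in the denominator:

```python
import numpy as np

def expansion_quantity(g_pos_left, g_pos_right, g_iid):
    """Estimate Q_S(g). g_pos_left/g_pos_right: g evaluated on the two sides
    of positive pairs sampled inside S; g_iid: g on i.i.d. draws from
    p_data restricted to S."""
    num = np.mean((g_pos_left - g_pos_right) ** 2)
    den = 2 * np.var(g_iid)  # E[(g(x) - g(x'))^2] = 2 Var(g) for i.i.d. x, x'
    return np.inf if den == 0 else num / den

# If S splits into two disconnected parts and g is constant on each part,
# every positive pair stays within a part, so Q_S(g) = 0.
g_iid = np.array([0.0] * 5 + [1.0] * 5)
print(expansion_quantity(g_iid, g_iid.copy(), g_iid))  # 0.0
```

A small estimated Q_S(g) over many heads g = w⊤f, f ∈ F, is evidence that the model class can break S into sub-clusters, i.e., that Assumption 3 would fail for that S.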

4. AN EIGENFUNCTION VIEWPOINT

In this section, we introduce an eigenfunction perspective that generalizes the theory in the previous section to more general settings. We first introduce the background on eigenfunctions and discuss their relation with contrastive learning. Then we develop a theory that incorporates the model architecture with assumptions stated in the language of eigenfunctions. The advantage over the previous section is that we can further reduce the required representation dimensionality when the minimal implementable clusters exhibit certain internal structures. Our theory relies on the Laplacian operator L, which maps a function g : X → R to another function L(g) : X → R defined as follows:

L(g)(x) := g(x) − ∫ (p_pos(x, x′) / p_data(x)) · g(x′) dx′.

We say a function g is an eigenfunction of L with eigenvalue ψ ∈ R if

E_{x∼p_data}[(ψ · g(x) − L(g)(x))²] = 0.

This essentially means that L(g) = ψ · g on the support of p_data. Intuitively, small eigenfunctions (i.e., eigenfunctions with small eigenvalues) correspond to clusters in the positive-pair graph. To see this, let g implement the indicator function of a cluster S, i.e., g(x) = 1 if x ∈ S, and g(x) = 0 if x ∉ S. One can verify that L(g)(x) = 0 for all x, thus g is an eigenfunction with eigenvalue 0. In this section, we provide a generalized theory based on characterizing the implementability of eigenfunctions. Intuitively, we will assume that there exist k (and only k) orthogonal eigenfunctions in the function class with very small eigenvalue, and that the downstream task can be solved by these eigenfunctions. More formally, let ϕ ≥ 0 be a very small real number (which can be thought of as 0), and let f_eig : X → R^k be a k-dimensional representation map in the function class F such that

E_{(x,x⁺)∼p_pos}[‖f_eig(x) − f_eig(x⁺)‖₂²] ≤ ϕ   (9)

and

E_{x∼p_data}[f_eig(x) f_eig(x)⊤] = I.   (10)
Intuitively, when ϕ is small, each dimension of f_eig corresponds to an eigenfunction of the graph Laplacian with small eigenvalue, as formalized by the following Proposition 1 in the case ϕ = 0.

Proposition 1. For any i ∈ [k] and any function f_eig satisfying equation 9 with ϕ = 0 and equation 10, the function g(x) = f_eig(x)_i is an eigenfunction of L with eigenvalue 0.

Recall that our assumptions in the previous section intuitively say that even though a large number of clusters exist in the positive-pair graph, many of them are not implementable by the function class. From the eigenfunction viewpoint, this means that only a small number of eigenfunctions with small eigenvalue are in the function class. Thus, we make the following corresponding assumption, which says that the vector-valued function f_eig contains all the implementable eigenfunctions with small eigenvalue.

Assumption 6. Suppose g is a function implementable by F (in the sense that g(x) = f(x)_i for some f ∈ F and i ∈ [k]) and

E_{(x,x⁺)∼p_pos}[(g(x) − g(x⁺))²] ≤ φ̄ · E_{x∼p_data}[g(x)²];

then there exists w ∈ R^k such that

E_{x∼p_data}[(w⊤f_eig(x) − g(x))²] ≤ ϵ.

Here both φ̄ and ϵ are very small and can be thought of as 0. We consider downstream tasks that can be solved by f_eig. Let ⃗y(x) ∈ R^r be a vector that represents the downstream label of data x (e.g., the one-hot embedding of the label when the downstream task is classification). We make the following assumption on the downstream task:

Assumption 7. There exists a linear head W* ∈ R^{r×k} with norm ∥W*∥_F ≤ B such that

E_{x∼p_data}[‖W* f_eig(x) − ⃗y(x)‖₂²] ≤ ζ.

Here ζ is very small and can be thought of as 0. We have the following theorem using the above two assumptions:

Theorem 2. Suppose the function f_eig ∈ F satisfies Assumption 6 with (φ̄, ϵ) and Assumption 7 with (B, ζ). Suppose φ̄ > ϕ, or φ̄ = ϕ = 0.
Then, for any λ > 0 such that ϕ ≤ φ̄(1 − ϕ/λ), and learned representation map f = arg min_{f∈F} L_λ(f), there exists a linear head W ∈ R^{r×k} such that

E_{x∼p_data}[‖W f(x) − ⃗y(x)‖₂²] ≲ ζ + B² k (ϵ + ϕ/λ).   (14)

Since ζ, ϵ and ϕ are all very small, the right-hand side of equation 14 is very small, hence the learned representation achieves small downstream error. As we will see in the first example of the next section, Theorem 2 indeed allows the representation dimensionality to be smaller than the number of minimal implementable clusters in the graph, hence generalizing the result of the previous section.

Relationship between Theorem 2 and Theorem 1. We note that in the setting of Theorem 1, the indicator function of each minimal implementable cluster would be an achievable eigenfunction. Theorem 2 considers a more general situation than Theorem 1, where the minimal implementable clusters may not be well-defined, yet we can still show good results when the dimensionality equals the number of achievable eigenfunctions. We will mainly use Theorem 2 for the examples because it is more general and easier to use, whereas we present Theorem 1 because it is more intuitive. For instance, in our Example 1, Theorem 1 only applies when s = 1, whereas Theorem 2 applies for arbitrary s.
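On a finite domain, the Laplacian operator L and the cluster-indicator intuition behind Proposition 1 can be checked directly: the integral becomes a matrix-vector product. The toy positive-pair matrix below is our own assumed example with two disconnected clusters, not data from the paper:

```python
import numpy as np

# Toy p_pos on four points forming two disconnected clusters {0,1} and {2,3}.
# The numbers are illustrative assumptions; p_pos is symmetric and sums to 1.
P = np.array([[0.10, 0.15, 0.00, 0.00],
              [0.15, 0.10, 0.00, 0.00],
              [0.00, 0.00, 0.10, 0.15],
              [0.00, 0.00, 0.15, 0.10]])
p_data = P.sum(axis=1)  # marginal of one side of a positive pair

def laplacian(g):
    """L(g)(x) = g(x) - sum_x' [p_pos(x, x') / p_data(x)] g(x')."""
    return g - (P @ g) / p_data

# The indicator of a disconnected cluster is an eigenfunction with eigenvalue 0.
g = np.array([1.0, 1.0, 0.0, 0.0])
print(np.allclose(laplacian(g), 0.0))  # True
```

A function that contrasts the two points inside a cluster, such as g = (1, −1, 0, 0), is also an eigenfunction here, but with a large eigenvalue, matching the intuition that small eigenvalues correspond to cluster structure.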

5. INSTANTIATIONS ON SEVERAL SYNTHETIC DATA DISTRIBUTIONS

In this section, we instantiate our previous theory on several examples of data distributions and show that when the model class has limited capacity, one can learn low-dimensional representations using contrastive learning and solve the downstream task with simple linear probing. In all of these examples, if we use a much more expressive model class, the representation dimensionality needs to be much higher, and hence more downstream samples are needed. These results demonstrate the benefit of leveraging inductive biases of the model architecture in contrastive learning.

5.1. LINEAR FUNCTIONS

Our first example is the hypercube example proposed in Saunshi et al. (2022).

Example 1. The natural data x is drawn uniformly from the d-dimensional hypercube {−1, 1}^d. Given natural data x, an augmented datapoint x̄ ∼ A(x) is sampled as follows: first uniformly sample a scalar τ ∼ [1/2, 1], then scale the (s+1)-th to d-th dimensions of x by τ while keeping the first s dimensions the same. Intuitively, the last d − s dimensions correspond to spurious features that can be changed by data augmentation, and the first s dimensions are invariant features that contain information about the downstream task. The downstream task is a binary classification problem, where the label y(x) = sgn(x_i) is the sign of one of the first s dimensions, i ∈ [s].

We consider contrastive learning with the linear function class defined below:

Definition 1 (Linear function class). Let U ∈ R^{k×d} be a matrix, and let f_U(x) = Ux denote the linear function with weight matrix U. We define the k-dimensional linear function class as F_linear = {f_U : U ∈ R^{k×d}}.

Saunshi et al. (2022) directly compute the learned representations from contrastive learning. Here we show that the example can be viewed as an instantiation of our more general Theorem 2. In particular, we have the following result:

Theorem 3. In Example 1, suppose we set the output dimensionality to k = s and learn a linear representation map that minimizes the contrastive loss, f = arg min_{f∈F_linear} L_λ(f), for any λ > 0. Then, there exists a linear head w ∈ R^k such that

E_{x∼p_data}[(w⊤f(x) − y(x))²] = 0.

In contrast, suppose the function class is the set of universal function approximators F_uni. So long as the output dimensionality is no more than 2^{d−1}, there exists a solution f′ ∈ arg min_{f∈F_uni} L_λ(f) such that for any linear head w ∈ R^k, we have

E_{x∼p_data}[(w⊤f′(x) − y(x))²] ≥ 1.

We note that, as an implication of the lower bound, previous works that analyze universal function approximators (Arora et al., 2019; Tosh et al., 2021; HaoChen et al., 2021) would not be able to show good downstream accuracy unless the representation dimensionality is larger than 2^{d−1}. In contrast, our theory, which incorporates the inductive biases of the function class, shows that a much lower representation dimensionality k = s suffices. We also note that this example shows a situation where Theorem 2 works but Theorem 1 does not, demonstrating how our theory derived from the eigenfunction viewpoint allows for lower representation dimensionality. There are 2^s model-restricted minimal clusters in the graph, each encoded by one configuration of the s invariant feature dimensions. However, all the functions in F_linear that implement a cluster lie in an s-dimensional subspace, so we can find an s-dimensional f_eig that satisfies Assumption 6. As a result, learning s-dimensional representations already suffices for solving the downstream task.
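The data distribution and augmentation of Example 1 are easy to simulate; the sketch below (our own code, with hypothetical names) makes the invariant/spurious split concrete:

```python
import numpy as np

def sample_example1(n, d, s, rng):
    """Draw n natural datapoints from the uniform hypercube {-1, 1}^d and one
    augmentation of each: scale dimensions s+1..d by tau ~ Uni[1/2, 1]."""
    x = rng.choice([-1.0, 1.0], size=(n, d))
    tau = rng.uniform(0.5, 1.0, size=(n, 1))
    x_aug = x.copy()
    x_aug[:, s:] *= tau  # spurious dimensions change under augmentation
    return x, x_aug

rng = np.random.default_rng(0)
x, x_aug = sample_example1(1000, d=10, s=3, rng=rng)
print(np.array_equal(x[:, :3], x_aug[:, :3]))  # True: invariant features
```

Note that the first s coordinates are exactly preserved, so a linear map that projects onto them is invariant on positive pairs, which is the mechanism behind Theorem 3.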

5.2. RELU NETWORKS

In the previous example, the downstream task is only a binary classification problem whose label is defined by one invariant feature dimension. Here we show that when we use a ReLU network as the model architecture, linear probing can solve more diverse downstream tasks, where the label can depend on the invariant feature dimensions arbitrarily.

Example 2. The natural data distribution and the data augmentation are defined in the same way as in Example 1. The downstream task is an r-way classification problem such that the label function y(·) : X → [r] satisfies y(x) = y(x′) if x_{1:s} = x′_{1:s}. In other words, the label only depends on the first s dimensions of the data.

Definition 2 (ReLU networks). Let U ∈ R^{k×d} and b ∈ R^k. We use f_{U,b}(x) = σ(Ux + b) to denote the ReLU network with weight U and bias b, where σ is the element-wise ReLU activation. We define the k-dimensional ReLU network function class as F_ReLU = {f_{U,b} : U ∈ R^{k×d}, b ∈ R^k}.

We have the following theorem, which shows the effectiveness of the ReLU network architecture.

Theorem 4. In Example 2, suppose we set the output dimensionality to k = 2^s and learn a ReLU network representation map f = arg min_{f∈F_ReLU} L_λ(f) for some λ > 0. Then, we can find a linear head W ∈ R^{r×k} such that

E_{x∼p_data}[‖W f(x) − e_{y(x)}‖₂²] = 0.

In contrast, suppose the function class is the set of universal function approximators F_uni. So long as the output dimensionality is no more than 2^{d−s}, there exists a solution f′ ∈ arg min_{f∈F_uni} L_λ(f) such that for any linear head W ∈ R^{r×k}, we have

E_{x∼p_data}[‖W f′(x) − e_{y(x)}‖₂²] ≥ 1/2.
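To see why k = 2^s can suffice, note that a one-layer ReLU network can map each configuration of the s invariant dimensions to a one-hot vector: the unit with weight pattern c ∈ {−1, 1}^s on the first s coordinates (zero on spurious dimensions) and bias −(s − 1) outputs 1 exactly when x_{1:s} = c, since c⊤x_{1:s} = s iff c matches and is at most s − 2 otherwise. The following is our own sketch of this intuition, not the proof of Theorem 4:

```python
import itertools
import numpy as np

def one_hot_relu(s, d):
    """Weights and bias of a 2^s-unit ReLU layer whose output is the one-hot
    encoding of the sign pattern of the first s dimensions."""
    patterns = np.array(list(itertools.product([-1.0, 1.0], repeat=s)))
    U = np.zeros((2 ** s, d))
    U[:, :s] = patterns            # zero weight on the spurious dimensions
    b = -(s - 1.0) * np.ones(2 ** s)
    return U, b

def relu_net(U, b, x):
    return np.maximum(U @ x + b, 0.0)

s, d = 3, 8
U, b = one_hot_relu(s, d)
x = np.array([1.0, -1.0, 1.0, 0.7, -0.6, 0.9, -0.8, 0.5])  # augmented data
out = relu_net(U, b, x)
# Exactly one unit fires, with value s - (s - 1) = 1, regardless of how the
# spurious dimensions were scaled by augmentation.
print(out.tolist())
```

Since the output ignores the spurious dimensions entirely, it is exactly invariant on positive pairs, i.e., it implements the 2^s cluster indicators with a single ReLU layer.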

5.3. LIPSCHITZ CONTINUOUS FUNCTIONS

In many real-world settings where a neural network is trained with weight decay, the resulting model usually has a limited weight norm, which encourages the network to have a smaller Lipschitz constant. The implicit bias of the optimizer can further encourage the smoothness of the learned function. Here we provide an example showing that restricting the model class to Lipschitz-continuous functions allows us to use lower-dimensional representations. In particular, we consider the following example, where a large number of clusters are located close to each other despite being disconnected in the positive-pair graph. Our result shows that contrastive learning with Lipschitz-continuous functions groups those clusters together, allowing for lower representation dimensionality.

Example 3. Let S₁, S₂, …, S_m ⊂ R^d be m manifolds in R^d, each of which may contain many disconnected subsets. Suppose the radius of every manifold is no larger than ρ, that is, for any i ∈ [m] and two datapoints x, x′ ∈ S_i, we have ∥x − x′∥₂ ≤ ρ. We also assume that different manifolds are separated by γ, that is, for any i, j ∈ [m] such that i ≠ j, and x ∈ S_i, x′ ∈ S_j, we have ∥x − x′∥₂ ≥ γ. The data distribution p_data is supported on S₁ ∪ S₂ ∪ … ∪ S_m and satisfies Pr_{x∼p_data}(x ∈ S_i) = 1/m for every i ∈ [m]. A positive pair only contains data in the same S_i. The downstream task is an r-way classification problem such that the label function y(·) : X → [r] satisfies y(x) = y(x′) if x and x′ belong to the same set S_i.

We introduce the following family of Lipschitz-continuous functions with parameter κ:

Definition 3 (κ-Lipschitz continuous functions). A function f : R^d → R^k is κ-Lipschitz if ∥f(x) − f(x′)∥₂ ≤ κ∥x − x′∥₂ for all x, x′ ∈ R^d. We define the κ-Lipschitz function class F_Lip,κ as the set of all κ-Lipschitz continuous functions from R^d to R^k.

We have the following theorem:

Theorem 5. In Example 3, suppose κ ≥ √(2m)/γ.
Let the output dimensionality be k = m and learn a κ-Lipschitz continuous function f ∈ arg min_{f∈F_Lip,κ} L_λ(f) for some λ > 0. Then, we can find a linear head W ∈ R^{r×k} such that

E_{x∼p_data}[‖W f(x) − e_{y(x)}‖₂²] ≤ 2 r m κ² ρ².   (18)

On the other hand, suppose the positive-pair graph contains m̄ disconnected clusters, and the function class is the set of universal function approximators F_uni. So long as the output dimensionality satisfies k < m̄, there exists a solution f′ ∈ arg min_{f∈F_uni} L_λ(f) such that for any linear head W ∈ R^{r×k}, we have

E_{x∼p_data}[‖W f′(x) − e_{y(x)}‖₂²] ≥ 1/m̄.

We note that a smaller κ (hence a smoother function class) decreases the right-hand side of equation 18 and leads to better downstream performance.

6. CONCLUSION

In this paper, we provide a theoretical analysis of contrastive learning that incorporates the inductive biases of the model class. We prove that contrastive learning with appropriate model architectures allows for lower representation dimensionality (hence better sample complexity), and instantiate this theory on several interesting examples. One open question is to allow k > m in our theory, which we believe requires additional assumptions on the structure of the model family. We note that our work only concerns the inductive biases originating from the model architecture, whereas in practice the learned representations also depend on the optimization method. Hence, another interesting future direction is to study how the implicit bias introduced by the optimizer influences self-supervised learning.

A ADDITIONAL EXAMPLES

A.1 CONVOLUTIONAL NEURAL NETWORKS

Our last example shows that convolutional neural networks can learn contrastive representations more efficiently than fully-connected networks when the downstream task has a certain rotational invariance structure. We consider the following data generative model, where the data contains a feature patch that determines the downstream label.

Example 4. The natural data x ∈ R^d is defined as follows: for some consecutive s dimensions x_{t:t+s−1} (the informative patch), we have x_{t:t+s−1} ∈ {−γ, γ}^s, where γ > 1. The other d − s dimensions of x (the spurious dimensions) are all in {−1, 1}. Given natural data x, its augmentations are generated by first sampling τ ∼ Uni[0, 1], then multiplying the spurious dimensions of x by τ while keeping the informative patch the same. The downstream task is an r-way classification problem such that the label function y(·) : X → [r] satisfies y(x) = y(x′) if the informative patches of x and x′ are the same.

We consider the following convolutional neural network model with k channels.

Definition 4 (Convolutional neural networks). Let U = [u₁, u₂, …, u_k]⊤ ∈ R^{k×s} and b ∈ R^k. We use f^conv_{U,b} : X → R^k to represent the following convolutional neural network:

f^conv_{U,b}(x)_i = Σ_{t=1}^{d} σ(u_i⊤ x_{t:t+s−1} + b_i),

where σ is the ReLU activation function. We define the convolutional neural network class as F_conv = {f^conv_{U,b} : U ∈ R^{k×s}, b ∈ R^k}.

We have the following theorem, which shows that contrastive learning with convolutional neural networks requires lower representation dimensionality than using fully-connected ReLU networks.

Theorem 6. In Example 4, let the output dimensionality be k = 2^s and learn a convolutional neural network f ∈ arg min_{f∈F_conv} L_λ(f) for some λ > 0. Then, we can find a linear head W ∈ R^{r×k} such that

E_{x∼p_data}[‖W f(x) − e_{y(x)}‖₂²] = 0.
On the other hand, suppose the function class is the set of ReLU networks F_ReLU. Then, as long as the output dimensionality is less than d · 2^s, there exists a function f′ ∈ arg min_{f ∈ F_ReLU} L_λ(f) such that for any linear head W ∈ R^{r×k}, we have E_{x∼p_data} ∥W f′(x) − e_{y(x)}∥²₂ ≥ 1/(d · 2^s).
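To make Definition 4 concrete, here is a minimal NumPy sketch of the convolutional feature map f^conv_{U,b}, using the circular indexing x_{d+j} = x_j from the paper's footnote. The function name and test values are illustrative, not taken from the paper's experiments.

```python
import numpy as np

def f_conv(x, U, b):
    """Sketch of Definition 4: f(x)_i = sum_t ReLU(u_i^T x_{t:t+s-1} + b_i),
    summing over all d window positions with circular indexing x_{d+j} = x_j."""
    d = x.shape[0]
    k, s = U.shape
    x_pad = np.concatenate([x, x[: s - 1]])      # circular padding
    out = np.zeros(k)
    for t in range(d):                           # one term per window position
        out += np.maximum(U @ x_pad[t : t + s] + b, 0.0)
    return out
```

Because the sum ranges over every window position, the output is invariant to cyclic shifts of x, which is exactly the structure that lets the convolutional class ignore the location of the informative patch.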

B NUMERICAL SIMULATIONS

Recall that our assumptions intuitively state that the model architecture cannot break the data into too many well-separated clusters. In this section, we propose a method to empirically test how many clusters a model architecture can partition the positive-pair graph of the data distribution into. Given a deep neural network and a target number of clusters r, ideally we aim to find a function f from the model class that maps each data point to a one-hot vector in dimension r encoding the cluster identity, that is, f(x) ∈ {e_1, ..., e_r} where e_i is the i-th natural basis vector in R^r. Under this constraint, we minimize the disagreement between the outputs of a positive pair, E_{(x,x+)∼p_pos}[∥f(x) − f(x+)∥²₂], which measures the total weight of inter-cluster edges. However, the one-hot constraint makes optimization challenging. Note that when the r clusters have the same probability mass 1/r, we have E[f(x)f(x)^⊤] = I/r. We use this equation as the constraint on f and arrive at the following relaxation of the original optimization problem:

b_r = min_f E_{(x,x+)∼p_pos}[∥f(x) − f(x+)∥²₂]  s.t.  E[f(x)f(x)^⊤] = I/r.  (19)

We use b_r as a surrogate for how well the architecture can partition the graph into r clusters; a smaller b_r means that partitioning is easier. We implement equation 19 empirically by first minimizing the contrastive loss L_λ(f_θ) with representation dimension k = r and a heavily tuned regularization strength λ. We then whiten the obtained model f_θ so that its output has covariance exactly I/r; that is, f(x) = E_{x∼p_data}[f_θ(x)f_θ(x)^⊤]^{−1/2} f_θ(x)/√r is a feasible solution of the program in equation 19, and we compute b_r = E_{(x,x+)∼p_pos}[∥f(x) − f(x+)∥²₂]. We try various choices of λ and report the smallest result as the final estimate of b_r. We run this test with a ResNet-18 model on CIFAR-10, compute b_r for r ∈ {10, 100, 500}, and list the results in the table below.
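The whitening step above can be sketched as follows. This is a minimal sketch assuming the trained representations f_θ(x) and f_θ(x+) for n positive pairs have already been collected into arrays; the helper names are ours, not from the paper's code.

```python
import numpy as np

def sym_inv_sqrt(M):
    """Symmetric inverse square root M^{-1/2} of a positive-definite matrix."""
    w, V = np.linalg.eigh(M)
    return V @ np.diag(1.0 / np.sqrt(w)) @ V.T

def estimate_b_r(feats, feats_pos):
    """Surrogate b_r from equation 19: whiten f_theta so that
    E[f(x) f(x)^T] = I / r, then average ||f(x) - f(x+)||^2 over pairs."""
    n, r = feats.shape
    W = sym_inv_sqrt(feats.T @ feats / n)   # E[f_theta f_theta^T]^{-1/2}
    f = feats @ W / np.sqrt(r)              # feasible point of equation 19
    f_pos = feats_pos @ W / np.sqrt(r)
    return np.mean(np.sum((f - f_pos) ** 2, axis=1))
```

As described in the text, this estimate would be repeated over several regularization strengths λ and the smallest value reported.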
Here we note that b_r increases from 0.127 to 0.315 as r increases from 10 to 500, suggesting that although the network can partition the data relatively well into 10 clusters, it cannot partition the data into 500 well-separated clusters, which supports our theoretical assumptions. More details can be found in Appendix B.1. We train a ResNet-18 model on the CIFAR-10 training set and estimate b_r on the test set. We train with SGD using an initial learning rate of 0.01 decayed with a cosine schedule. All experiments are run for 200 epochs. We test with r ∈ {10, 100, 500} and grid search over λ ∈ {0.1, 0.3, 1, 3, 10, 30, 100, 300, 1000}; the result for each configuration is listed in the table below.

We first show that f* achieves a small contrastive loss. For the regularizer term, we have E_{x∼p_data}[f*(x)f*(x)^⊤] = Σ_{i∈[m]} P_i · (1/P_i) · e_i e_i^⊤ = I, and thus R(f*) = 0. (20) For the discrepancy term, letting P_min := min_{i∈[m]} P_i be the probability mass of the smallest set, we have E_{x,x+}∥f*(x) − f*(x+)∥²₂ ≤ (1/P_min) · Pr_{(x,x+)∼p_pos}(id_x ≠ id_{x+}) ≤ α/P_min. (21) Combining equation 20 and equation 21, we have L_λ(f*) = E_{x,x+}∥f*(x) − f*(x+)∥²₂ + λ · R(f*) ≤ α/P_min. Since f̂ = arg min_{f∈F} L_λ(f) is the minimizer of the contrastive loss within the function class, we have L_λ(f̂) ≤ L_λ(f*) ≤ α/P_min. Define the matrix M := E_{x∼p_data}[f̂(x)f̂(x)^⊤]. We have ∥M − I∥²_F ≤ L_λ(f̂)/λ ≤ α/(λP_min). Since λ > α/P_min, M is full rank, so we can define f̃(x) := M^{−1/2} f̂(x). Let Q := E_{x∼p_data}[f̃(x)f*(x)^⊤] and π(x) := f̃(x) − Q f*(x). We know that E_{x∼p_data}[π(x)f*(x)^⊤] = E_{x∼p_data}[f̃(x)f*(x)^⊤] − Q E_{x∼p_data}[f*(x)f*(x)^⊤] = 0.
Using Assumption 3, we have

E_{(x,x+)∼p_pos}∥π(x) − π(x+)∥²₂ ≥ Σ_{i∈[m]} (P_i − α) · E_{(x,x+)∼p_pos,i}∥π(x) − π(x+)∥²₂
≥ β · Σ_{i∈[m]} (P_i − α) · E_{x∼p_data,i, x′∼p_data,i}∥π(x) − π(x′)∥²₂
= 2β · Σ_{i∈[m]} (P_i − α) · E_{x∼p_data,i}∥π(x)∥²₂
≥ 2β · (1 − α/P_min) · E_{x∼p_data}∥π(x)∥²₂, (30)

where the equality uses that π has mean zero on each S_i (a consequence of E_{x∼p_data}[π(x)f*(x)^⊤] = 0). On the other hand, we have

E_{(x,x+)∼p_pos}∥π(x) − π(x+)∥²₂ ≤ E_{(x,x+)∼p_pos}∥f̃(x) − f̃(x+)∥²₂ ≤ ∥M^{−1}∥_spec · E_{(x,x+)∼p_pos}∥f̂(x) − f̂(x+)∥²₂ ≤ (1 + α/(λP_min)) · α/P_min. (31)

Combining equation 30 and equation 31, we have E_{x∼p_data}∥π(x)∥²₂ ≤ (1 + α/(λP_min)) · α/(2β(P_min − α)). By Lemma 1, there exists a matrix U ∈ R^{m×m} such that E_{x∼p_data}∥f*(x) − U M^{−1/2} f̂(x)∥²₂ ≤ (1 + α/(λP_min)) · α/(2β(P_min − α)) (33) ≤ α/(β(P_min − α)). Thus, if we define W = diag(√P_1, √P_2, ..., √P_m) · U M^{−1/2}, then we have E_{x∼p_data}∥e_{id_x} − W f̂(x)∥²₂ ≤ P_max · E_{x∼p_data}∥f*(x) − U M^{−1/2} f̂(x)∥²₂ (35) ≤ P_max · α/(β(P_min − α)), which finishes the proof.

Lemma 1. Suppose f: X → R^m and g: X → R^m are two functions defined on X such that E_{x∼p_data}[f(x)f(x)^⊤] = E_{x∼p_data}[g(x)g(x)^⊤] = I. Define the projection of f onto the orthogonal complement of g as π_f(x) = f(x) − E_{x′∼p_data}[f(x′)g(x′)^⊤] g(x). Then, there exists a matrix U ∈ R^{m×m} such that E_{x∼p_data}∥g(x) − U f(x)∥²₂ = E_{x∼p_data}∥π_f(x)∥²₂.

Proof of Lemma 1. Let U = E_{x∼p_data}[g(x)f(x)^⊤], so that E_{x∼p_data}[f(x)g(x)^⊤] = U^⊤ and π_f(x) = f(x) − U^⊤ g(x). We have

E_{x∼p_data}∥g(x) − U f(x)∥²₂ = E∥g(x)∥²₂ − 2E[g(x)^⊤ U f(x)] + E[f(x)^⊤ U^⊤ U f(x)] = m − 2∥U∥²_F + ∥U∥²_F = m − ∥U∥²_F.

On the other hand, we have

E_{x∼p_data}∥π_f(x)∥²₂ = E∥f(x) − U^⊤ g(x)∥²₂ = E∥f(x)∥²₂ − 2E[f(x)^⊤ U^⊤ g(x)] + E[g(x)^⊤ U U^⊤ g(x)] = m − 2∥U∥²_F + ∥U∥²_F = m − ∥U∥²_F.

Thus, E_{x∼p_data}∥g(x) − U f(x)∥²₂ = E_{x∼p_data}∥π_f(x)∥²₂, which finishes the proof.
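Lemma 1 is easy to sanity-check numerically: after whitening two random feature matrices so that their empirical second moments are the identity, both sides of the lemma evaluate to m − ∥U∥²_F. This is a verification sketch on arbitrary synthetic data, not an experiment from the paper.

```python
import numpy as np

def whiten(A):
    """Rescale A so that the empirical second moment A^T A / n is the identity."""
    w, V = np.linalg.eigh(A.T @ A / len(A))
    return A @ V @ np.diag(1.0 / np.sqrt(w)) @ V.T

rng = np.random.default_rng(0)
n, m = 2000, 5
F = whiten(rng.standard_normal((n, m)))    # rows play the role of f(x)
G = whiten(rng.standard_normal((n, m)))    # rows play the role of g(x)

U = G.T @ F / n                            # U = E[g(x) f(x)^T]
lhs = np.mean(np.sum((G - F @ U.T) ** 2, axis=1))  # E||g(x) - U f(x)||^2
Pi = F - G @ U                             # pi_f(x) = f(x) - U^T g(x), in row form
rhs = np.mean(np.sum(Pi ** 2, axis=1))     # E||pi_f(x)||^2
assert np.isclose(lhs, rhs)                # both equal m - ||U||_F^2
```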

D PROOFS FOR SECTION 4

Proof of Proposition 1. Define g̃(x) = √(p_data(x)) · g(x) and the symmetric Laplacian operator L(g̃)(x) = g̃(x) − ∫ (p_pos(x, x′) / (√(p_data(x)) √(p_data(x′)))) g̃(x′) dx′. It can be verified that ∫ g̃(x) L(g̃)(x) dx = 0. Since the operator L is PSD, this implies ∫ (L(g̃)(x))² dx = 0, which is equivalent to E_{x∼p_data}[(L(g)(x))²] = 0 and finishes the proof.

Proof of Theorem 2. Since L_λ(f_eig) ≤ ϕ, we know that L_λ(f̂) ≤ ϕ, so ∥E_{x∼p_data}[f̂(x)f̂(x)^⊤] − I∥²_F ≤ ϕ/λ. In Assumption 6, set f = f̂ and sum over i = 1, 2, ..., k; then for some matrix W, E_{x∼p_data}∥W f_eig(x) − f̂(x)∥²₂ ≤ kϵ. Let Q := E_{x∼p_data}[f̂(x)f̂(x)^⊤]. We have E_{x∼p_data}∥Q^{−1/2}f̂(x) − f̂(x)∥²₂ ≤ (2ϕ/λ) · E_{x∼p_data}∥f̂(x)∥²₂ ≤ (2ϕ/λ) · k(1 + ϕ/λ). Thus, E_{x∼p_data}∥W f_eig(x) − Q^{−1/2}f̂(x)∥²₂ ≤ 2kϵ + (4ϕ/λ) · k(1 + ϕ/λ). (58) Define the matrix M := E_{x∼p_data}[f_eig(x)(Q^{−1/2}f̂(x))^⊤]. Using Lemma 1 and equation 58, we have E_{x∼p_data}∥f_eig(x) − M Q^{−1/2}f̂(x)∥²₂ ≤ 2kϵ + (4ϕ/λ) · k(1 + ϕ/λ) ≤ 2kϵ + (8ϕ/λ) · k. Thus, using Assumption 7, we have E_{x∼p_data}∥W* M Q^{−1/2}f̂(x) − e_{y(x)}∥²₂ ≤ 2E_{x∼p_data}∥W* f_eig(x) − e_{y(x)}∥²₂ + 2E_{x∼p_data}∥W* M Q^{−1/2}f̂(x) − W* f_eig(x)∥²₂ ≤ 2ζ + 4B²kϵ + (16ϕ/λ) · B²k. (63)

E PROOFS FOR SECTION 5

E.1 PROOF FOR EXAMPLE 1

Proof of Theorem 3. Define Û = [e_1, e_2, ..., e_s]^⊤ ∈ R^{s×d}. We can verify that E_{(x,x+)∼p_pos}∥f_Û(x) − f_Û(x+)∥²₂ = 0 and E_{x∼p_data}[f_Û(x)f_Û(x)^⊤] = I. Thus, we can view f_Û as the f_eig in Section 4. Let U ∈ R^{k×d} and i ∈ [k] be such that E_{(x,x+)∼p_pos}[(f_U(x)_i − f_U(x+)_i)²] = 0. Since x and x+ differ only on the (s+1)-th to d-th dimensions, U_i must be 0 on the (s+1)-th to d-th dimensions. Thus, U_i is in the span of e_1, e_2, ..., e_s, and as a result Assumption 6 holds with ϵ = 0. Since the downstream task's label equals x_i for some i ∈ [s], we can set W* = e_i^⊤ and obtain E_{x∼p_data}[(W* f_Û(x) − y(x))²] = 0.
Hence Assumption 7 holds with α = 0 and B = 1. Applying Theorem 2 finishes the proof for the linear function class case. For the case of universal function approximators, without loss of generality we assume the downstream task's label only depends on the first dimension of x, i.e., y(x) = sgn(x_1). When k ≤ 2^{d−1}, we can construct a function f: X → R^k such that for every dimension j ∈ [k], we have f(x)_j = √k when x_{2:d}, viewed as a binary number, equals j, and f(x)_j = 0 otherwise. It can be verified that L_λ(f) = 0, hence f is a minimizer of the contrastive loss. However, f(x) is agnostic to the first dimension of x, hence the downstream error is at least 1.

E.2 PROOF FOR EXAMPLE 2

Proof of Theorem 4. For any vector h ∈ {−1, 1}^s, we define bin(h) ∈ {0, 1, ..., 2^s − 1} to be the number obtained by viewing (h + 1)/2 as a binary string. Since bin(·) is a one-to-one mapping, we can define U ∈ R^{k×d} such that the i-th row of U satisfies: the first s dimensions equal √k · bin^{−1}(i − 1), and the rest d − s dimensions are 0. For Assumption 6, consider a function f_{U′,b′} ∈ F_ReLU and an index i ∈ [k] such that E_{(x,x+)∼p_pos}[(f_{U′,b′}(x)_i − f_{U′,b′}(x+)_i)²] = 0. Suppose there exist natural data x̄ ≠ x̄′ with augmentations x, x′ such that f_{U′,b′}(x)_i > f_{U′,b′}(x′)_i. Then there must be (U′_i)_{s+1:d} ≠ 0 and σ(U′_i^⊤ x) > 0. This implies that there exists another augmentation x̃ of x̄ with σ(U′_i^⊤ x̃) ≠ σ(U′_i^⊤ x). Hence, we would have E_{(x,x+)∼p_pos}[(f_{U′,b′}(x)_i − f_{U′,b′}(x+)_i)²] > 0, a contradiction. Therefore f_{U′,b′}(x)_i = f_{U′,b′}(x′)_i, so (f_{U′,b′})_i can only be a function of x_{1:s}, and there exists a vector w ∈ R^k such that f_{U′,b′}(x)_i = w^⊤ f_{U,b}(x), which means Assumption 6 holds with ϵ = 0. Applying Theorem 2 finishes the proof of equation 17. The result about universal function approximators follows the same proof as Theorem 3, except that the function is constructed using the last (d − s) dimensions rather than the last (d − 1) dimensions.

E.3 PROOF FOR EXAMPLE 3

Proof of Theorem 5. Let id_x be the index such that x ∈ S_{id_x}, and define f_eig(x) = √m · e_{id_x}. It can be verified that f_eig satisfies equation 9 and equation 10. For f ∈ F_Lip,κ and i ∈ [m], define g(x) = f(x)_i. Suppose E_{x∼p_data}[g(x)²] = 1; then we can choose m data points x_1, x_2, ..., x_m such that x_i ∈ S_i and (1/m) Σ_{i∈[m]} g(x_i)² ≤ 1. Define the vector w ∈ R^m such that w_i = (1/√m) · g(x_i). We have E_{x∼p_data}[(w^⊤ f_eig(x) − g(x))²] = (1/m) Σ_{i∈[m]} E_{x∼S_i}[(g(x_i) − g(x))²] (69) ≤ (1/m) Σ_{i∈[m]} κ²ρ² = κ²ρ². Thus, f_eig satisfies Assumption 6 with ϵ = κ²ρ². Since the data in the same S_i have the same downstream label, we know that Assumption 7 holds with B = √r and α = 0. Applying Theorem 2 finishes the proof of the upper bound. For the lower bound, let S be the largest of the m clusters. When k < m, we can construct a minimizer of the contrastive loss that maps all data in S to 0, hence the final error is at least 1/m.
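The construction in the proof above can be checked numerically: with m equal-mass clusters, f_eig(x) = √m · e_{id_x} has identity second moment, and a simple linear head recovers the one-hot label e_{y(x)} exactly. The cluster labels below are synthetic stand-ins, not data from the paper.

```python
import numpy as np

m, r, n_per = 6, 3, 200
rng = np.random.default_rng(0)
labels = rng.integers(0, r, size=m)        # downstream label y_i of cluster S_i
ids = np.repeat(np.arange(m), n_per)       # cluster id of each data point (equal mass)

F = np.sqrt(m) * np.eye(m)[ids]            # f_eig(x) = sqrt(m) * e_{id_x}
assert np.allclose(F.T @ F / len(F), np.eye(m))   # E[f_eig f_eig^T] = I

# linear head W = (1/sqrt(m)) * sum_i e_{y_i} e_i^T maps f_eig(x) to e_{y(x)} exactly
W = np.eye(r)[labels].T / np.sqrt(m)
assert np.allclose(F @ W.T, np.eye(r)[labels[ids]])
```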



L_λ(f) = E_{(x,x+)∼p_pos}[∥f(x) − f(x+)∥²₂] + λ · R(f). (1)

¹Here we denote x_{d+i} = x_i.



Figure 1: A simple example where the linear function class learns the correct feature and ignores the spurious feature. A simplified version of the synthetic example proposed in Saunshi et al. (2022). The orange points are the original data and the blue points are augmented data (obtained by adding noise in the spurious dimension). The dimension invariant to augmentation is the desired one. Edges represent positive pairs that are constructed from augmentation. We say a real-valued function implements a cluster if it outputs 1 on the cluster and 0 on all other data. Note that here implementing means matching the exact value, rather than merely matching the label after applying some linear threshold. The figure shows two possible ways to partition the data into two clusters, but only the one on the left-hand side (which captures the invariant dimension) is implementable by a linear function. Black numbers indicate the target output on the data, and green numbers indicate the output of the implementing function, which extrapolates outside of the data support. Note that the linear model is not composed with a threshold function. The partition on the right-hand side is not implementable because any linear model that outputs a constant 1 on the upper-left small cluster would also output 1 on the bottom-left small cluster due to linear extrapolation. Red numbers indicate outputs of the linear function that contradict the target.

Figure 2: A demonstration of the positive-pair graph. When the positive pairs are formed by applying data augmentation (such as rotation) to the same natural image, data with the same semantic meaning (e.g., the two butterfly images) tend to belong to the same cluster in the positive-pair graph. Data points with different semantic meanings (e.g., a butterfly image and a clock image) would not be connected in the positive-pair graph, hence belong to different clusters.

λ        0.1    0.3    1      3      10     30     100    300    1000
r = 10   —      0.127  0.134  0.144  0.151  0.155  0.215  0.343  0.901
r = 100  0.887  0.660  0.408  0.245  0.204  0.220  0.254  0.424  1.579
r = 500  1.031  0.981  0.710  0.554  0.427  0.372  0.315  0.481  1.231

C PROOFS FOR SECTION 3

Proof of Theorem 1. Let P_i := Pr_{x∼p_data}(x ∈ S_i) be the probability of S_i. Let f* be the function f*(x) = (1/√P_{id_x}) · e_{id_x}. From Assumption 5 and Assumption 4, we have f* ∈ F.

√k · bin^{−1}(i − 1), and the rest d − s dimensions are 0. Let the bias vector b ∈ R^k be such that every dimension equals −√k · (r − 1). We have f_{U,b}(x) = √k · e_{bin(x_{1:s})+1} ∈ R^k. Since E_{x∼p_data}[f_{U,b}(x)f_{U,b}(x)^⊤] = I and E_{(x,x+)∼p_pos}[∥f_{U,b}(x) − f_{U,b}(x+)∥²₂] = 0, we can view f_{U,b} as the f_eig in Section 4. Assumption 7 naturally holds with B = 1.

E.4 PROOF FOR EXAMPLE 4

Proof of Theorem 6. For any vector h ∈ {−1, 1}^s, we define bin(h) ∈ {0, 1, ..., 2^s − 1} to be the number obtained by viewing (h + 1)/2 as a binary string. Since bin(·) is a one-to-one mapping, we can define U ∈ R^{k×s} whose i-th row is constructed from bin^{−1}(i − 1), analogously to the proof of Theorem 4 (here t denotes the starting position of the informative patch in x). It can be verified that E_{(x,x+)∼p_pos}[∥f^conv_{U,b}(x) − f^conv_{U,b}(x+)∥²₂] = 0. Also, Assumption 7 holds with B = 1. Then, for any x in the support of p_data, define x̄ as the vector obtained by replacing the spurious dimensions of x with 0. Since x̄ is in the support of x's augmentations and the model is continuous, we know that f^conv_{U,b}(x) = f^conv_{U,b}(x̄). Further notice that for any two data points x, x′ with the same informative patch (whose locations might differ) and the corresponding x̄, x̄′, there must be f^conv_{U,b}(x̄) = f^conv_{U,b}(x̄′), which finishes the proof of the upper bound. For the lower bound, we note that due to the lack of invariance to the informative patch location, we can construct a network with a d · 2^s-dimensional output that satisfies equation 9 and equation 10. When the output dimension is less than d · 2^s, there exists a minimizer of the contrastive loss that merges two of these d · 2^s clusters. If these two clusters have different downstream labels, a loss of at least 1/(d · 2^s) is incurred because the data are mapped to the same feature, which finishes the proof of the lower bound.

