UNDERSTANDING THE ROLE OF NONLINEARITY IN TRAINING DYNAMICS OF CONTRASTIVE LEARNING

Abstract

While the empirical success of self-supervised learning (SSL) relies heavily on deep nonlinear models, existing theoretical work on SSL still focuses on linear ones. In this paper, we study the role of nonlinearity in the training dynamics of contrastive learning (CL) on one- and two-layer nonlinear networks with homogeneous activation h(x) = h′(x)x. We have two major theoretical discoveries. First, the presence of nonlinearity can lead to many local optima even in the 1-layer setting, each corresponding to certain patterns in the data distribution, while with linear activation only one major pattern can be learned. This suggests that models with many parameters can be regarded as a brute-force way to find these local optima induced by nonlinearity. Second, in the 2-layer case, linear activation is proven incapable of learning weights specialized to diverse patterns, demonstrating the importance of nonlinearity. In addition, for the 2-layer setting, we also discover global modulation: local patterns that are discriminative from the perspective of global-level patterns are prioritized for learning, further characterizing the learning process. Simulation verifies our theoretical findings.

Published as a conference paper at ICLR 2023

[...] distributions, where f is a nonlinear network with parameters θ. Note that in Eqn. 2, the stop-gradient operator sg(α) means that while the value of α may depend on θ, when studying the local properties of θ, α makes no contribution to the gradient and should be treated as an independent variable. Since C_α is an abstract mathematical object with a complicated definition, as the first contribution, we give its connection to the regular variance V[·] when the pairwise importance α has certain kernel structures (Ghojogh et al., 2021; Paulsen & Raghupathi, 2016):

Definition 1 (Kernel structure of pairwise importance α). There exists a (kernel) function K(x, y) = Σ_{l=0}^{+∞} ϕ_l(x)ϕ_l(y) with nonnegative components ϕ_l(·) ≥ 0, such that α_ij = K(x_0[i], x_0[j]).

Definition 2 (Adjusted PDF p̃_l(x)).
For the l-th component ϕ_l of the mapping ϕ, we define the adjusted density p̃_l(x; α) := (1/z_l(α)) ϕ_l(x; α) p_D(x), where z_l(α) := ∫ ϕ_l(x) p_D(x) dx ≥ 0 is the normalizer. Obviously α_ij ≡ 1 (uniform α, corresponding to the quadratic loss) satisfies Def. 1 with the 1D mapping ϕ ≡ 1. Here we show a non-trivial case, Gaussian α, whose normalized version leads to InfoNCE:

Lemma 1 (Gaussian α). For any function g(·) that is bounded below, if we use α_ij := exp(−∥g(x_0[i]) − g(x_0[j])∥²_2 / 2τ) as the pairwise importance, then it has kernel structure (Def. 1).

Note that Gaussian α computes N² pairwise distances using un-augmented samples x_0, while InfoNCE (and most CL losses) uses augmented views x and x′ and normalizes along one dimension to yield an asymmetric α_ij. Here Gaussian α is a convenient tool for analysis. We now show that C_α is a summation of regular variances, but with the probability of the data adjusted by the pairwise importance α that has kernel structure. Please check Appendix A.1 for detailed proofs.

Lemma 2 (Relationship between contrastive covariance and variance at large batch size). If α satisfies Def. 1, then for any function g(·), C_α[g(x)] is asymptotically PSD when N → +∞:

    C_α[g(x)] → Σ_l z²_l V_{x_0∼p̃_l(·;α)}[ E_{x∼p_aug(·|x_0)}[g(x)|x_0] ]    (4)

Corollary 1 (No augmentation and large batch size). With the condition of Lemma 2, if we further assume there is no augmentation (i.e., p_aug(x|x_0) = δ(x − x_0)), then C_α[g(x)] → Σ_l z²_l V_{x_0∼p̃_l(·;α)}[g(x_0)].

Figure 12: Learned filters (of the 4 disjoint receptive fields) in the MNIST dataset without using augmentation. The 4 receptive fields correspond to the upper-left (0), upper-right (1), bottom-left (2) and bottom-right (3) parts of the input image.
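Lemma 1's kernel structure can be checked numerically. Below is a minimal sketch (our own illustration, not the paper's code) for scalar outputs g(x) ∈ R, using the Taylor-expansion feature map from the proof in Appendix A.1 (for d = 1, the enumeration of monomials reduces to y^l): a modest truncation order already reproduces the Gaussian α to high accuracy.

```python
import numpy as np
from math import factorial

tau, K = 1.0, 30                      # temperature; truncation order of the feature map
y = np.linspace(0.0, 2.0, 9)          # nonnegative scalar values y = g(x_0) - v

# Truncated feature map: phi_l(y) = exp(-y^2 / 2 tau) * y^l / sqrt(tau^l l!)
Phi = np.stack([np.exp(-y**2 / (2 * tau)) * y**l / np.sqrt(tau**l * factorial(l))
                for l in range(K)], axis=1)          # shape (9, K)

alpha_feat = Phi @ Phi.T                              # sum_l phi_l(y_i) phi_l(y_j)
alpha_true = np.exp(-(y[:, None] - y[None, :])**2 / (2 * tau))
assert np.allclose(alpha_feat, alpha_true, atol=1e-10)
```

The assertion succeeds because Σ_l φ_l(y_i)φ_l(y_j) is exactly the Taylor series of exp(y_i y_j / τ) times the Gaussian prefactors, whose tail beyond order 30 is negligible for the values used here.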

1. INTRODUCTION

Over the last few years, deep models have demonstrated impressive empirical performance in many disciplines, not only in the supervised setting but also in the recent self-supervised learning (SSL) setting, in which models are trained with a surrogate loss (e.g., predictive (Devlin et al., 2018; He et al., 2021), contrastive (Chen et al., 2020; Caron et al., 2020; He et al., 2020) or non-contrastive losses (Grill et al., 2020; Chen & He, 2020)) and the learned representation is then used for downstream tasks. From the theoretical perspective, understanding the role of nonlinearity in deep neural networks is one critical part of understanding how modern deep models work. Currently, most works focus on linear variants of deep models (Jacot et al., 2018; Arora et al., 2019a; Kawaguchi, 2016; Jing et al., 2022; Tian et al., 2021; Wang et al., 2021). When nonlinearity is involved, deep models are often treated as richer families of black-box functions than linear ones (Arora et al., 2019b; HaoChen et al., 2021). The role played by nonlinearity has also been studied, mostly in terms of model expressibility (Gühring et al., 2020; Raghu et al., 2017; Lu et al., 2017), in which specific weights are found to fit the complicated structure of the data well, regardless of the training algorithm. However, many questions remain open: if model capacity is the key, why do traditional models like k-NN (Fix & Hodges, 1951) or kernel SVM (Cortes & Vapnik, 1995) not achieve comparable empirical performance, even though theoretically they can also fit any function (Hammer & Gersmann, 2003; Devroye et al., 1994)? Moreover, while traditional ML theory suggests carefully controlling model capacity to avoid overfitting, large neural models often generalize well in practice (Brown et al., 2020; Chowdhery et al., 2022). In this paper, we study the critical role of nonlinearity in the training dynamics of contrastive learning (CL).
Specifically, by extending the recent α-CL framework (Tian, 2022) and linking it to kernels (Paulsen & Raghupathi, 2016), we show that even in 1-layer nonlinear networks, nonlinearity plays a critical role by creating many local optima. As a result, the more nonlinear nodes a 1-layer network has (with different initializations), the more local optima are likely to be collected as learned patterns in the trained weights, and the richer the resulting representation becomes. Moreover, popular loss functions like InfoNCE tend to have more local optima than quadratic ones. In contrast, in the linear setting, contrastive learning becomes PCA under certain conditions (Tian, 2022), and only the most salient pattern (i.e., the maximal eigenvector of the data covariance matrix) is learned while less salient ones are lost, regardless of the number of hidden nodes. Based on this finding, we extend our analysis to a 2-layer ReLU setting with non-overlapping receptive fields. In this setting, we prove a fundamental limitation of linear networks: the gradients of multiple weights at the same receptive field are always co-linear, preventing diverse pattern learning. Finally, we also characterize the interaction between layers in the 2-layer network: while many patterns exist in each receptive field, those contributing to global patterns are prioritized by the training dynamics. This global modulation changes the eigenstructure of the low-level covariance matrix so that relevant patterns are learned with higher probability. In summary, through the lens of training dynamics, we discover unique roles played by nonlinearity that linear activation cannot fulfill: (1) nonlinearity creates many local optima for different patterns of the data, and (2) nonlinearity enables weight specialization into diverse patterns. In addition, we also discover a mechanism by which global patterns prioritize which local patterns to learn, shedding light on the role played by network depth.
Preliminary experiments on simulated data verify our findings.

Related works. Many works analyze networks at initialization (Hayou et al., 2019; Roberts et al., 2021) and avoid the complicated training dynamics. Previous works that analyze training dynamics (Wilson et al., 1997; Li & Yuan, 2017; Tian et al., 2019; Tian, 2017; Allen-Zhu & Li, 2020) mostly focus on supervised learning. Different from Saunshi et al. (2022) and Ji et al. (2021), which analyze the feature learning process in linear models of CL, we focus on the critical role played by nonlinearity. Our analysis is also more general than Li & Yuan (2017), which focuses on a 1-layer ReLU network with symmetric weight structure trained on sparse linear models. Along the line of studying the dynamics of contrastive learning, Jing et al. (2022) analyzes dimensional collapse in 1- and 2-layer linear networks. Tian (2022) proves that such collapse happens in linear networks of any depth and further analyzes ReLU scenarios, but under strong assumptions (e.g., one-hot positive input). Our work uses much more relaxed assumptions and performs an in-depth analysis for homogeneous activations.

2. PROBLEM SETUP

Notation. In this section, we introduce our problem setup of contrastive learning. Let x_0 ∼ p_D(·) be a sample drawn from the dataset, and x ∼ p_aug(·|x_0) be an augmented view of the sample x_0. Here both x_0 and x are random variables. Let f = f(x; θ) be the output of a deep neural network that maps the input x into some representation space, with parameters θ to be optimized. Given a batch of size N, x_0[i] represents the i-th sample (i.e., instantiation) of the corresponding random variable, and x[i] and x[i′] are two of its augmented views. Here x[·] has 2N samples, 1 ≤ i ≤ N and N + 1 ≤ i′ ≤ 2N. Contrastive learning (CL) aims to learn the parameters θ so that the representations f are distinct from each other: we want to maximize the squared distance d²_ij := ∥f[i] − f[j]∥²_2 / 2 between samples i ≠ j and minimize d²_i := ∥f[i] − f[i′]∥²_2 / 2 between two views x[i] and x[i′] of the same sample x_0[i]. Many objectives in contrastive learning have been proposed to combine these two goals into one. For example, InfoNCE (Oord et al., 2018) minimizes the following (here τ is the temperature):

    L_nce := −τ Σ_{i=1}^N log [ exp(−d²_i/τ) / (ϵ exp(−d²_i/τ) + Σ_{j≠i} exp(−d²_ij/τ)) ]

In this paper, we follow α-CL (Tian, 2022), a general CL framework that covers a broad family of existing CL losses. α-CL maximizes an energy function E_α(θ) using gradient ascent: θ_{t+1} = θ_t + η∇_θ E_{sg(α(θ_t))}(θ), where η is the learning rate, sg(·) is the stop-gradient operator, the energy function is E_α(θ) := (1/2) tr C_α[f, f], and C_α[·, ·] is the contrastive covariance (Tian, 2022; Jing et al., 2022):

    C_α[a, b] := (1/2N²) Σ_{i,j=1}^N α_ij (a[i] − a[j])(b[i] − b[j])^⊤ − (1/2N) Σ_{i=1}^N ((1/N) Σ_{j≠i} α_ij) (a[i] − a[i′])(b[i] − b[i′])^⊤    (3)

One important quantity is the pairwise importance α(θ) = [α_ij(θ)]_{i,j=1}^N, which gives N² weights on the pairs of N samples in a batch.
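To make Eqn. 3 concrete, here is a minimal numpy sketch of the contrastive covariance (a hypothetical helper, not the authors' code). With uniform α and identical views (no augmentation), the intra-sample term vanishes and C_α reduces exactly to the biased empirical covariance:

```python
import numpy as np

def contrastive_covariance(a, a_prime, alpha):
    """Eqn. 3: C_alpha[a, a] for a batch of N representations.

    a, a_prime : (N, d) arrays, the two augmented views' outputs.
    alpha      : (N, N) pairwise importance (diagonal terms vanish in Eqn. 3).
    """
    N, d = a.shape
    diff = a[:, None, :] - a[None, :, :]              # (N, N, d): a[i] - a[j]
    inter = np.einsum('ij,ijp,ijq->pq', alpha, diff, diff) / (2 * N**2)
    w = (alpha.sum(axis=1) - np.diag(alpha)) / N      # (1/N) sum_{j != i} alpha_ij
    dv = a - a_prime                                  # a[i] - a[i']
    intra = np.einsum('i,ip,iq->pq', w, dv, dv) / (2 * N)
    return inter - intra

rng = np.random.default_rng(0)
a = rng.standard_normal((64, 3))
C = contrastive_covariance(a, a.copy(), np.ones((64, 64)))
# Uniform alpha + identical views: C_alpha equals the biased empirical covariance.
emp = (a - a.mean(0)).T @ (a - a.mean(0)) / 64
assert np.allclose(C, emp)
```

This is the finite-N analogue of Corollary 1 with the trivial kernel ϕ ≡ 1: the inter-sample term is the variance of the data, and the intra-sample term (zero here) subtracts the augmentation-induced variance.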
Intuitively, these weights make the training focus more on hard negative pairs, i.e., distinct sample pairs that are similar in the representation space but are supposed to be separated. Many existing CL losses (InfoNCE, triplet loss, etc.) are special cases of α-CL (Tian, 2022) obtained by choosing different α(θ); e.g., the quadratic loss corresponds to α_ij := const and InfoNCE (with ϵ = 0) corresponds to α_ij := exp(−d²_ij/τ) / Σ_{j≠i} exp(−d²_ij/τ). For brevity, C_α[x] := C_α[x, x]. For the energy function E_α(θ) := tr C_α[f(x; θ)], in this work we mainly study its landscape, i.e., the existence of local optima, their local properties, and the overall structure of the landscape.

3. ONE-LAYER CASE

Now let us first consider a 1-layer network with K hidden nodes: f(x; θ) = h(Wx), where W = [w_1, ..., w_K]^⊤ ∈ R^{K×d}, θ = {W} and h(x) is the activation. The k-th row of W is a weight w_k and its output is f_k := h(w_k^⊤ x). In this case, tr C_α[f] = Σ_{k=1}^K C_α[f_k]. We consider per-filter normalization ∥w_k∥_2 = 1, which can be achieved by imposing BatchNorm (Ioffe & Szegedy, 2015) at each node k (Tian, 2022). In this case, the optimization decouples over the filters w_k:

    max_θ E_α(θ) = (1/2) max_{∥w_k∥_2=1, 1≤k≤K} tr C_α[f] = (1/2) Σ_{k=1}^K max_{∥w_k∥_2=1} C_α[h(w_k^⊤ x)]    (5)

Now which parameters w_k maximize the summation? In the linear case, since C_α[h(w^⊤x)] = C_α[w^⊤x] = w^⊤ C_α[x] w, all w_k converge to the maximal eigenvector of C_α[x] (a constant matrix), regardless of how they are initialized and what the distribution of x is. Therefore, the linear case only learns the single most salient pattern, due to the (overly smooth) landscape of the objective function: a winner-take-all effect that neglects many patterns in the data. In contrast, nonlinearity can change the landscape and create more local optima in C_α[h(w^⊤x)], each capturing one pattern.
In this paper, we consider a general category of nonlinear activations:

Assumption 1 (Homogeneity (Du et al., 2018) / Reversibility (Tian et al., 2020)). The activation satisfies h(x) = h′(x)x.

Many activations satisfy this assumption, including linear, ReLU, LeakyReLU and monomial activations like h(x) = x^p (up to an additional global constant). In this case we have:

    C_α[h(w^⊤x)] = w^⊤ C_α[x̃_w] w    (6)

where x̃_w := x ∘ h′(w^⊤x) is the input data after nonlinear gating; one useful property is h(w^⊤x) = w^⊤ h′(w^⊤x) x = w^⊤ x̃_w. When there is no ambiguity, we simply write x̃_w as x̃ and omit the weight superscript. Now let A(w) := C_α[x̃_w]. With the constraint ∥w∥_2 = 1, the learning dynamics is:

Lemma 3 (Training dynamics of a 1-layer network with homogeneous activation in contrastive learning). The gradient dynamics of Eqn. 5 is (note that α is treated as an independent variable):

    ẇ_k = P^⊥_{w_k} A(w_k) w_k    (7)

Here P^⊥_{w_k} := I − w_k w_k^⊤ projects a vector onto the orthogonal complement of w_k. See Appendix B.2 for derivations. Now the questions are: what are the critical points of the dynamics, and are they attractive (i.e., local optima)? In the linear case, the maximal eigenvector is the fixed point; in the nonlinear case, we look for locally maximal eigenvectors, called LMEs.

Definition 3 (Locally maximal eigenvector (LME)). w* is a locally maximal eigenvector of A(w) if A(w*)w* = λ*w*, where λ* = λ_max(A(w*)) is the distinct maximal eigenvalue of A(w*).

It is easy to see that each LME is a critical point of the dynamics, since P^⊥_{w*} A(w*)w* = λ* P^⊥_{w*} w* = 0.
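The key identity behind Eqn. 6, h(w^⊤x) = w^⊤x̃_w, follows purely from Assumption 1 and can be sanity-checked numerically (our own sketch, not from the paper) for ReLU and LeakyReLU:

```python
import numpy as np

rng = np.random.default_rng(1)
d = 8
w = rng.standard_normal(d); w /= np.linalg.norm(w)
X = rng.standard_normal((1000, d))          # a batch of inputs x

for h, dh in [
    (lambda z: np.maximum(z, 0.0), lambda z: (z > 0).astype(float)),                  # ReLU
    (lambda z: np.where(z > 0, z, 0.05 * z), lambda z: np.where(z > 0, 1.0, 0.05)),   # LeakyReLU
]:
    pre = X @ w                             # w^T x
    x_gated = X * dh(pre)[:, None]          # gated input x~_w = x * h'(w^T x)
    # homogeneity h(z) = h'(z) z  implies  h(w^T x) = w^T x~_w
    assert np.allclose(h(pre), x_gated @ w)
```

Since h(w^⊤x) is linear in w once the gating h′(w^⊤x) is frozen, its contrastive covariance is the quadratic form w^⊤C_α[x̃_w]w, which is exactly how the nonlinearity enters A(w).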

3.1. EXAMPLES WITH MULTIPLE LMES IN RELU SETTING

To see why nonlinear activation leads to many LMEs in Eqn. 7, we first give two exemplar generative models of the input x that show Eqn. 7 has multiple critical points, then introduce more general cases. To keep the examples simple and clear, we assume the condition of Corollary 1 (no augmentation and large batch size) and let α_ij ≡ 1. Notice that x̃_w is then a deterministic function of x, therefore A(w) := C_α[x̃_w] = V[x̃_w]. We also use the ReLU activation h(x) = max(x, 0). Let U = [u_1, ..., u_M] be orthonormal bases (u_m^⊤ u_{m′} = I(m = m′)). Here are the two examples:

Latent categorical model. Suppose y is a categorical random variable taking M possible values, with P[x|y = m] = δ(x − u_m). Then we have (see Appendix B.1 for detailed steps):

    A(w)|_{w=u_m} := C_α[x̃_w] = V[x̃_w] = P[y = m](1 − P[y = m]) u_m u_m^⊤    (8)

Now it is clear that w = u_m is an LME for any m.

Latent summation model. Suppose there is a latent variable y so that x = Uy, where y := [y_1, y_2, ..., y_M]. Each y_m is a standardized Bernoulli random variable with E[y_m] = 0 and E[y²_m] = 1: that is, y_m = y_m^+ := √((1 − q_m)/q_m) with probability q_m and y_m = y_m^− := −√(q_m/(1 − q_m)) with probability 1 − q_m. For m_1 ≠ m_2, y_{m_1} and y_{m_2} are independent. Then we have:

    A(w)|_{w=u_m} := C_α[x̃_w] = V[x̃_w] = (1 − q_m)² u_m u_m^⊤ + q_m (I − u_m u_m^⊤)

which has u_m as its distinct maximal eigenvector with eigenvalue (1 − q_m)² when q_m < (1/2)(3 − √5) ≈ 0.382. Therefore, different w lead to different LMEs.

In both cases, the presence of ReLU removes the "redundant energy" so that A(w) can focus on specific directions, creating multiple LMEs that correspond to multiple learnable patterns. The two examples can be computed analytically due to our specific choices of the nonlinearity h and α.
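The latent categorical model is easy to verify by simulation (a sketch of ours, taking u_m to be standard basis vectors): sampling x = u_y with categorical y and computing the empirical variance of the ReLU-gated input at w = u_m recovers Eqn. 8.

```python
import numpy as np

rng = np.random.default_rng(0)
M, N = 4, 200_000
p = np.array([0.3, 0.3, 0.2, 0.2])          # P[y = m]; u_m = e_m (standard basis)
y = rng.choice(M, size=N, p=p)
X = np.eye(M)[y]                             # x = u_y

m = 0
w = np.eye(M)[m]                             # evaluate A at w = u_m
gated = X * (X @ w > 0)[:, None]             # x~_w = x * 1[w^T x > 0] (ReLU gating)
A_emp = np.cov(gated.T, bias=True)           # empirical V[x~_w]
A_theory = p[m] * (1 - p[m]) * np.outer(w, w)    # Eqn. 8
assert np.allclose(A_emp, A_theory, atol=0.01)
```

The gating zeroes out every sample except those with y = m, so x̃_w is a Bernoulli-scaled copy of u_m, whose variance is exactly P[y = m](1 − P[y = m]) along u_m and zero elsewhere.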

3.2. RELATE LMES TO LOCAL OPTIMA

Once LMEs are identified, the next step is to check whether they are attractive, i.e., stable critical points or local optima: do the weights converge to them and stay there during training? For this, some notation is introduced below.

Notations. Let λ_i(w) be the i-th largest eigenvalue of A(w), ϕ_i(w) the corresponding unit eigenvector, and λ_gap(w) := λ_1(w) − λ_2(w) the eigenvalue gap. Let ρ(w) be the local roughness measure: ρ(w) is the smallest scalar satisfying ∥(A(v) − A(w))w∥_2 ≤ ρ(w)∥v − w∥_2 + O(∥v − w∥²_2) in a local neighborhood of w. The following theorem gives a sufficient condition for the stability of w*:

Theorem 1 (Stability of w*). If w* is an LME of A(w) and λ_gap(w*) > ρ(w*), then w* is stable.

This shows that lowering the roughness measure ρ(w*) at critical points w* could lead to more local optima and more patterns being learned. To characterize this behavior, we bound ρ(w*):

Theorem 2 (Bound of the local roughness ρ(w) in the ReLU setting). If the input is bounded (∥x∥_2 ≤ C_0), α has kernel structure (Def. 1) and the batch size N → +∞, then ρ(w*) ≤ (C_0³ vol(C_0)/π) r(w*, α), where r(w, α) := Σ_{l=0}^{+∞} z²_l(α) max_{w^⊤x=0} p̃_l(x; α).

From Thm. 2, the bound critically depends on r(w*, α), which contains the adjusted density p̃_l(x; α) (Def. 2) on the plane w*^⊤x = 0. This is because a local perturbation of w* leads to the inclusion/exclusion of data close to the plane, and thus determines ρ(w*). Different α lead to different p̃_l(x; α), and thus different upper bounds on ρ(w*), creating fewer or more local optima (i.e., patterns) to learn. The following example shows that Gaussian α (see Lemma 1), whose normalized version is used in InfoNCE, can lead to more local optima than uniform α, by lowering the roughness bound characterized by r(w*, α):

Corollary 2 (Effect of different α).
For uniform α_u (α_ij := 1) and 1-D Gaussian α_g (α_ij := exp(−∥h(w^⊤x_0[i]) − h(w^⊤x_0[j])∥²_2 / 2τ)), we have r(w*, α_g) = z_0(α_g) r(w*, α_u), where z_0(α_g) := ∫ exp(−h²(w*^⊤x)/2τ) p_D(x) dx ≤ 1. As a result, z_0(α_g) ≪ 1 leads to r(w*, α_g) ≪ r(w*, α_u).

In practice, z_0(α_g) can be exponentially small (e.g., when most data appear on the positive side of the weight w*), and the roughness with Gaussian α can be much smaller than that with uniform α, which is presumably the reason why InfoNCE outperforms the quadratic CL loss (Tian, 2022).
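The quantity z_0(α_g) in Corollary 2 is just an expectation over p_D and can be estimated by Monte Carlo. The sketch below uses synthetic Gaussian data (the shift of 4 along w* is our illustrative assumption): when most of the mass lies on the positive side of w*, z_0 is tiny, so the Gaussian-α roughness bound sits far below the uniform-α one.

```python
import numpy as np

rng = np.random.default_rng(0)
d, N, tau = 5, 100_000, 0.1
w_star = np.zeros(d); w_star[0] = 1.0
relu = lambda z: np.maximum(z, 0.0)

def z0(mean_shift):
    # x ~ N(mean_shift * w_star, I);  z0 = E[exp(-h^2(w*^T x) / 2 tau)]
    X = rng.standard_normal((N, d)) + mean_shift * w_star
    return np.mean(np.exp(-relu(X @ w_star) ** 2 / (2 * tau)))

z0_centered = z0(0.0)   # about half the mass has h = 0, so z0 stays large
z0_shifted  = z0(4.0)   # almost all mass on the positive side: z0 is tiny
assert z0_shifted < 0.02 and z0_centered > 0.4
```

With the data centered, every sample on the negative side of w* contributes exp(0) = 1; shifting the mean along w* suppresses these contributions exponentially, matching the "exponentially small" regime described above.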

3.3. FINDING CRITICAL POINTS WITH INITIAL GUESS

In the following, we focus on how to find an LME when A(w) does not have an analytic form. We show that if there is an "approximate eigenvector" of A(w) := C_α[x̃_w], then a real one is nearby. Let L be the Lipschitz constant of A(w), i.e., ∥A(w) − A(w′)∥_2 ≤ L∥w − w′∥_2 for any w, w′ on the unit sphere ∥w∥_2 = 1, and let the correlation function c(w) := w^⊤ϕ_1(w) be the inner product between w and the maximal eigenvector of A(w). We can construct a fixed point using Power Iteration (PI) (Golub & Van Loan, 2013), starting from an initial value w = w(0):

    w(t + 1) ← A(w(t))w(t) / ∥A(w(t))w(t)∥_2    (PI)

We show that even though A(w) varies over the sphere ∥w∥_2 = 1, the iteration can still converge to a fixed point w* if the following quantity ω(w), called the irregularity, is small enough.

Definition 4 (Irregularity ω(w) in the neighborhood of fixed points). Let μ(w) := 0.5(1 + c(w)) c^{−2}(w) [1 − λ_gap(w)/λ_1(w)]² and define ω(w) := ω(c(w), λ_gap(w), λ_1(w), L, κ) ≥ 0 as

    ω(w) := μ(w) + 2κL²(1 + μ(w)c(w)) + 2Lλ^{−1}_gap(w) μ(w)(1 + μ(w)c(w)),

where κ is the high-order eigenvector bound defined in the Appendix (Lemma 9).

Intuitively, when w(0) is sufficiently close to some LME w*, i.e., w(0) is an "approximate" LME, we have ω(w(0)) ≪ 1. In such a case, w(0) can be used to find w* using Power Iteration (Eqn. PI).

Theorem 3 (Existence of critical points). Let c_0 := c(w(0)) ≠ 0. Suppose there exists γ < 1 so that sup_{w∈B_γ} ω(w) ≤ γ, where B_γ := {w : w^⊤w(0) ≥ (c_0 − c_γ)/(1 − c_γ)} with c_γ := 2√γ/(1 + γ) is a neighborhood of the initial value w(0). Then Power Iteration (Eqn. PI) converges to a critical point w* ∈ B_γ of Eqn. 7.

We leave further relaxation of this condition for future work.

Figure 2: (a) The two-layer network: inputs x_1, x_2, x_3 at disjoint receptive fields R_1, R_2, R_3, bottom-layer weights W = {w_km} producing outputs {f_1km}, and top-layer weights V. (b) Conditional independence: the inputs at the receptive fields are independent given a latent variable z (here taking values z = 0, 1, 2).

Possible relation to empirical observations. Since there exist many local optima in the dynamics (Eqn. 7), even though the objectives involving the w_k are identical (Eqn. 5), each w_k may still converge to a different local optimum due to initialization. We suspect that this is a tentative explanation of why larger models perform better: more local optima are collected, and some of them can be useful. Other empirical observations like the lottery ticket hypothesis (LTH) (Frankle & Carbin, 2019; Morcos et al., 2019; Tian et al., 2019; Yu et al., 2020), recently also verified in CL (Chen et al., 2021), may be explained similarly. In LTH, first a large network is trained and pruned to a small subnetwork S; then retraining S from its original initialization yields comparable or even better performance, while retraining S from a different initialization performs much worse. For LTH, our explanation is that S contains weights that are initialized luckily, i.e., close to useful local optima, and converge to them during training. We leave a thorough empirical study of this line of thought for future work. Given this intuition, it is tempting to study the distribution of local optima of Eqn. 7, the attractive basin Basin(w*) := {w : w(0) = w, lim_{t→+∞} w(t) = w*} of each local optimum w*, and the probability that randomly initialized weights fall into them. Interestingly, data augmentation may play an important role here, by removing unnecessary local optima induced by symmetry (see Appendix B.5) and focusing the learning on important patterns. Theorem 3 also gives hints. A formal study is left for future work.
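Under the latent categorical model of Sec. 3.1, A(w) = V[x̃_w] has an exact closed form (with u_m taken as standard basis vectors, the active set is {m : w^⊤u_m > 0}), so Power Iteration (Eqn. PI) can be run without sampling. In this sketch (our own illustration, not the paper's code), an initialization whose largest positive overlap is on u_1 converges to the LME w* = u_1 (coordinate 0 in the code):

```python
import numpy as np

p = np.array([0.3, 0.3, 0.2, 0.2])           # P[y = m];  u_m = e_m

def A(w):
    # x~_w = u_m w.p. p_m if w^T u_m > 0, else 0  =>  exact covariance V[x~_w]
    act = (w > 0).astype(float)
    q = p * act                               # probabilities of the active patterns
    return np.diag(q) - np.outer(q, q)

w = np.array([0.9, 0.4, -0.2, 0.1])
w /= np.linalg.norm(w)
for _ in range(100):                          # Power Iteration (Eqn. PI)
    w = A(w) @ w
    w /= np.linalg.norm(w)

# The iteration lands on the LME nearest the initialization: w* = u_0 here.
assert np.allclose(np.abs(w), np.eye(4)[0], atol=1e-6)
```

A different initialization (e.g., largest overlap on u_2) lands on a different LME, which is the "many local optima collected by many filters" picture described above.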

4. TWO-LAYER SETTING

Now we understand how a 1-layer nonlinear network learns in the contrastive learning setting. In practice, many patterns exist and most of them may not be relevant for the downstream tasks. A natural question arises: how does the network prioritize which patterns to learn? To answer this question, we analyze the behavior of 2-layer nonlinear networks with non-overlapping receptive fields (Fig. 2(a)).

Setting and Notations. In the lower layer, there are K disjoint receptive fields (abbreviated as RF) {R_k}, each with input x_k and M weights w_km ∈ R^d, where m = 1...M. The output of the bottom layer is denoted f_1, with f_1km as its km-th component and f_1[i] for the i-th sample; W := diag(w_11, ..., w_KM) is a block-diagonal matrix putting all projections together. The top layer has weight V ∈ R^{d_out×KM}. Define S := V^⊤V, whose (km, k′m′) entry is s_km,k′m′ := [S]_km,k′m′ = v_km^⊤ v_k′m′. At each RF R_k, define x̃_km as a brief notation for the gated input x̃^{w_km}_k := x_k ∘ h′(w_km^⊤ x_k).

Lemma 4 (Dynamics of a 2-layer nonlinear network with contrastive loss).

    V̇ = V C_α[f_1],    ẇ = P^⊥_w [(S ⊗ 1_d 1_d^⊤) ∘ C_α[x̃]] w    (12)

where 1_d is the d-dimensional all-one vector, ⊗ is the Kronecker product and ∘ is the Hadamard product. See Appendix C.1 for the proof.

Now we analyze the stationary points of these equations. If C_α[f_1] has a unique maximal eigenvector s, then following a similar analysis as in Tian (2022), a necessary condition for (W, V) to be a stationary point is V = vs^⊤, where v is an arbitrary unit vector. Therefore, S = V^⊤V = ss^⊤ is a rank-1 matrix and s_km,k′m′ = s_km s_k′m′. Note that s, as the unique maximal eigenvector of C_α[f_1], is a function of the low-level features computed by W. On the other hand, the stationary points of W can be much more complicated, since W receives the feedback term S from the top level. A more detailed analysis requires further assumptions, as we list below:

Assumption 2.
For the analysis of two-layer networks, we assume:

• Uniform α, large batch size and no augmentation. Then C_α[g(x)] = V[g(x)] for any function g(·), following Corollary 1.

• Fast top-level training. V undergoes fast training and has always converged to its stationary point given C_α[f_1]. That is, S = ss^⊤ is a rank-1 matrix.

• Conditional independence. The inputs at each R_k are conditionally independent given a latent global random variable z taking C different values: P[x|z] = Π_{k=1}^K P[x_k|z].

Explanation of the assumptions. The uniform α condition is mainly for notational simplicity; for kernel-like α, the analysis is similar, combining multiple variance terms via Lemma 1. The no-augmentation condition is mainly technical: the conclusion still holds if E_paug[g(x)|x_0] ≈ g(E_paug[x|x_0]) for g(x) := x̃_w. With all the assumptions, we can compute the term A_k(w_k) := C_α[x̃_k] = V[x̃_k]. Our Assumption 2 is weaker than the orthogonal mixture condition used to analyze CL in Tian (2022), which requires each instance of the input x_k[i] to have only one positive component.

4.1. WHY NONLINEARITY IS CRITICAL: LINEAR ACTIVATION FAILS

Since each RF R_k has M filters {w_km}, it would be ideal for each filter to capture one distinct pattern in the covariance matrix A_k. However, with linear activation, x̃_k = x_k and, as a result, learning of diverse features never happens, no matter how large M is (proof in Appendix C.4):

Theorem 4 (Gradient colinearity in linear networks). With linear activation, W follows the dynamics:

    ẇ_km = s_km b_k(W, V)    (14)

where b_k(W, V) := C_α[x_k, Σ_{k′,m′} s_k′m′ w_k′m′^⊤ x_k′] is a linear function w.r.t. W. As a result, (1) the ẇ_km are co-linear over m, and (2) if s_km ≠ 0, from any critical point with distinct {w_km} there exists a path of critical points leading to identical weights (w_km = w_k).

This exposes the weakness of linear activation. First, during CL training the gradients of all w_km within one RF R_k point in the same direction b_k; second, even if the critical points w_km have some diversity within RF R_k, there exists a path along which they converge to identical weights. Therefore, diverse features, even when they reside in the data, cannot be learned by linear models.
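The colinearity in Theorem 4 can be observed directly. In this sketch (our own construction; Σ, s, v are arbitrary), the energy of a linear 2-layer network under Assumption 2 is E = ½ tr(VWΣW^⊤V^⊤) with Σ = Cov(x) and a rank-1 top layer V = vs^⊤; its gradient w.r.t. W is V^⊤VWΣ, and the gradient rows within one receptive field are co-linear:

```python
import numpy as np

rng = np.random.default_rng(0)
K, M, d = 3, 4, 5                       # K RFs, M filters per RF, d dims per RF
D = K * d
Z = rng.standard_normal((2 * D, D))
Sigma = Z.T @ Z / (2 * D)               # input covariance Cov(x), PSD

# Block-diagonal W: filter w_km has support only on receptive field k.
W = np.zeros((K * M, D))
for k in range(K):
    W[k * M:(k + 1) * M, k * d:(k + 1) * d] = rng.standard_normal((M, d))

# Fast top-level training (Assumption 2): V = v s^T, so S = V^T V is rank-1.
s = rng.standard_normal(K * M)
v = rng.standard_normal(7)
V = np.outer(v, s)

# Gradient of E = 1/2 tr(V W Sigma W^T V^T) w.r.t. W (linear activation).
G = V.T @ V @ W @ Sigma

for k in range(K):
    rows = G[k * M:(k + 1) * M, k * d:(k + 1) * d]   # restrict to RF k's support
    rows = rows / np.linalg.norm(rows, axis=1, keepdims=True)
    cos = rows @ rows[0]
    assert np.allclose(np.abs(cos), 1.0)             # co-linear over m (Thm. 4)
```

Algebraically, G = ∥v∥² s s^⊤ W Σ, so row km equals ∥v∥² s_km (Σ_j s_j w_j^⊤)Σ: only the scalar s_km depends on m, exactly the ẇ_km = s_km b_k structure of Eqn. 14.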

4.2. THE EFFECT OF GLOBAL MODULATION IN THE SPECIAL CASE OF C = 2 AND M = 1

When z is binary (C = 2) with a single weight per RF (M = 1), the dynamics of w_k has a closed form. Let w_k represent w_k1, the only weight at each R_k, and ∆_k := E[x̃_k|z = 1] − E[x̃_k|z = 0]. We have:

Theorem 5 (Dynamics of w_k under conditional independence). When C = 2 and M = 1, the dynamics of w_k is given by (s²_k and δ_k ≥ 0 are scalars defined in the proof):

    ẇ_k = P^⊥_{w_k} [s²_k A_k(w_k) + δ_k ∆_k ∆_k^⊤] w_k    (15)

Figure 3: Data generation for the experiments. A generator such as **C*B*A*D* produces input sequences (e.g., ACCEBDAFDG) by filling each wildcard * with a random token; each token a is then mapped to its embedding u_a and the embeddings are concatenated to form the input x. Positive pairs are sequences from the same generator; negative pairs come from different generators.

Figure 4: When P = 1, the linear model works well regardless of the degree of over-parameterization β, while ReLU requires large over-parameterization to perform well. When each R_k has multiple patterns (P > 1) related to the generators, ReLU models capture diverse patterns better than linear ones in the over-parameterization region β > 1. We find a similar trend for other homogeneous activations such as LeakyReLU (with negative slope 0.05) and quadratic, while linear models are much less affected by over-parameterization. Although the trends are similar, the quadratic loss is not as effective as InfoNCE in feature learning. Each setting is repeated 3 times and mean/std are reported. See Appendix (Fig. 9 and Fig. 10) for χ⁻.

See the proof in Appendix C.3. There are several interesting observations. First, the dynamics are decoupled, and other w_k′ with k′ ≠ k affect the dynamics of w_k only through the matrix A_k(W). Second, while A_k(w_k) contains multiple patterns (i.e., local optima) in R_k, the additional term ∆_k∆_k^⊤, the global modulation from the top level, encourages the model to learn the pattern ∆_k, a discriminative feature that separates the events z = 0 and z = 1. Quantitatively:

Theorem 6 (Global modulation of attractive basin).
If the structural assumption A_k(w_k) = Σ_l g(u_l^⊤w_k) u_l u_l^⊤ holds, with g(·) > 0 a linearly increasing function and {u_l} orthonormal bases, then for A_k + cu_lu_l^⊤ with c > 0, the attractive basin of w_k = u_l is larger than that of A_k.

Therefore, if ∆_k is an LME of A_k and w_k is randomly initialized, Thm. 5 tells us that P[w_k → ∆_k] is higher than the probability that w_k converges to other patterns of A_k, i.e., the global variable z modulates the training of the lower layer. This is similar to "backward feature correction" (Allen-Zhu & Li, 2020) and "top-down modulation" (Tian et al., 2019) in supervised learning; here we show it in CL. We also analyze how BatchNorm helps alleviate diverse variances among RFs (see Appendix D).
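A simplified illustration of the modulation in Thm. 5 (a sketch of ours with a constant A_k, an assumption made for tractability; the numbers are arbitrary): adding the term δ∆∆^⊤ to the projected dynamics ẇ = P^⊥_w[A + δ∆∆^⊤]w can change which pattern the weight converges to.

```python
import numpy as np

def flow(Amod, w, steps=5000, eta=0.05):
    # projected gradient ascent on the sphere: w <- normalize(w + eta * P_w^perp Amod w)
    for _ in range(steps):
        g = Amod @ w
        w = w + eta * (g - (w @ g) * w)
        w /= np.linalg.norm(w)
    return w

A = np.diag([1.0, 0.8])                 # two patterns u_1, u_2 (eigenvalues 1.0 > 0.8)
Delta = np.array([0.0, 1.0])            # discriminative direction = u_2
w0 = np.array([0.8, 0.6])               # fixed initialization, closer to u_1

w_plain = flow(A, w0.copy())
w_mod   = flow(A + 0.5 * np.outer(Delta, Delta), w0.copy())

assert abs(w_plain[0]) > 0.99           # without modulation: converges to u_1
assert abs(w_mod[1])  > 0.99            # with modulation:    converges to u_2
```

Without the modulation, the locally dominant but task-irrelevant pattern u_1 wins; with it, the eigenstructure tilts toward the discriminative direction ∆ = u_2, which is the "global modulation of the attractive basin" in miniature.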

5. EXPERIMENTS

Setup. To verify our findings, we perform contrastive learning with a 2-layer network on a synthetic dataset containing token sequences, generated as follows. From a pool of G = 40 generators, we pick a generator of length K of the form **C*B*A*D* (here K = 10) and generate a sequence such as EFCDBAACDB by sampling from d = 20 tokens for each wildcard *. The final input x is then constructed by replacing each token a with a pre-defined embedding u_a ∈ R^d; {u_a} forms an orthonormal basis (see Fig. 3). Data augmentation is achieved by generating another sequence from the same generator. While there exist d = 20 tokens, in each RF R_k we pick a subset R^g_k of P < d tokens as the candidates used in the generators, to demonstrate the effect of global modulation. Before training, each generator is created by first randomly picking 5 receptive fields, then picking one of the P tokens from R^g_k at each chosen RF R_k, and filling the remaining RFs with the wildcard *. Therefore, if a token a appears at R_k but a ∉ R^g_k, then a must be instantiated from the wildcard. Any a ∉ R^g_k is noise and should not be learned in the weights of R_k, since it is not part of any global pattern from the generators. We train a 2-layer network on this dataset. The network has K = 10 disjoint RFs; within each RF there are M = βP filters, where β ≥ 1 is a hyper-parameter that controls the degree of over-parameterization. The network is trained with the InfoNCE loss and SGD with learning rate 2 × 10⁻³, momentum 0.9, and weight decay 5 × 10⁻³ for 5000 minibatches with batch size 128. The code is in PyTorch and runnable on a single modern GPU.

Evaluation metric. We check whether the weight corresponding to each token is learned in the lower layer. At each RF R_k, we know R^g_k, the subset of tokens it contains, as well as their embeddings {u_a}_{a∈R^g_k}, due to the generation process, and we verify whether these embeddings are learned after the model is trained.
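The generation process above can be sketched as follows (a hypothetical reimplementation; the function names are ours, and the embedding matrix is taken to be the identity, which is one valid orthonormal choice for {u_a}):

```python
import numpy as np

rng = np.random.default_rng(0)
K, d, P, G = 10, 20, 3, 40           # RFs, tokens, candidate patterns per RF, generators
U = np.eye(d)                         # orthonormal token embeddings u_a

# Per-RF candidate token subsets R^g_k (P tokens each).
Rg = [rng.choice(d, size=P, replace=False) for _ in range(K)]

def make_generator():
    # pick 5 RFs; fix one candidate token there, wildcard (*) elsewhere
    gen = np.full(K, -1)              # -1 encodes the wildcard *
    for k in rng.choice(K, size=5, replace=False):
        gen[k] = rng.choice(Rg[k])
    return gen

def sample_sequence(gen):
    # wildcards are filled with uniformly random tokens (possibly noise tokens)
    toks = np.where(gen >= 0, gen, rng.integers(0, d, size=K))
    return np.concatenate([U[a] for a in toks])       # input x in R^{K*d}

gens = [make_generator() for _ in range(G)]
g = gens[0]
x, x_aug = sample_sequence(g), sample_sequence(g)     # positive pair: same generator
assert x.shape == (K * d,) and x_aug.shape == (K * d,)
```

Negative pairs would simply be sequences drawn from two different entries of `gens`, matching the construction in Fig. 3.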
Specifically, for each token a ∈ R^g_k, we look for its best match among the learned filters {w_km}, as formulated by the following per-RF score χ⁺(R_k) and the overall matching score χ̄⁺ ∈ [−1, 1], the average over all RFs (similarly we can define χ⁻ for a ∉ R^g_k):

    χ⁺(R_k) = (1/P) Σ_{a∈R^g_k} max_m [w_km^⊤ u_a / (∥w_km∥_2 ∥u_a∥_2)],    χ̄⁺ = (1/K) Σ_k χ⁺(R_k)

5.1. RESULTS

Linear vs. ReLU activation and the effect of over-parameterization (Sec. 4.1). From Fig. 4, we can clearly see that ReLU (and other homogeneous) activations achieve better reconstruction of the input patterns when each RF contains many patterns (P > 1) and specialization of the filters within each RF is needed. On the other hand, when P = 1, linear activation works better. ReLU activation clearly benefits from over-parameterization (β > 1): the larger β is, the better χ̄⁺ becomes. In contrast, for linear activation, over-parameterization barely affects the performance, which is consistent with our theoretical analysis.

Quadratic versus InfoNCE. Fig. 4 shows that the quadratic CL loss underperforms InfoNCE, while the trend for linear/ReLU and over-parameterization remains similar. According to Corollary 2, non-uniform α (e.g., Gaussian α, Lemma 1) creates more and deeper local optima that better accommodate local patterns, yielding better performance. This provides a novel landscape point of view on why non-uniform α is better, expanding the intuition that it focuses more on important sample pairs.

Global modulation (Sec. 4.2). As shown in Fig. 5, the learned weights indeed focus on the token subset R^g_k that receives top-down support from the generators, and no noise token is learned. We also verify this quantitatively by computing χ⁻ over multiple runs, provided in Appendix (Fig. 9-10).
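The matching score χ̄⁺ defined above can be computed directly from the learned filters. Below is a minimal sketch (our own helper, not the released code); as a sanity check, filters that exactly equal the candidate embeddings give χ̄⁺ = 1.

```python
import numpy as np

def chi_plus(W_list, U, Rg):
    """W_list[k]: (M, d) learned filters of RF k;  U: (d, d) token embeddings (rows u_a);
    Rg[k]: candidate token indices of RF k.  Returns the overall score chi_bar_plus."""
    scores = []
    for Wk, idx in zip(W_list, Rg):
        Wn = Wk / np.linalg.norm(Wk, axis=1, keepdims=True)
        # best cosine match over filters m for each candidate token a
        per_token = [np.max(Wn @ (U[a] / np.linalg.norm(U[a]))) for a in idx]
        scores.append(np.mean(per_token))             # chi^+(R_k)
    return float(np.mean(scores))                     # average over RFs

d = 20
U = np.eye(d)
Rg = [np.array([0, 5, 7]), np.array([1, 2, 3])]
W_list = [U[idx] for idx in Rg]      # filters that perfectly recover each embedding
assert np.isclose(chi_plus(W_list, U, Rg), 1.0)
```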

A PROOFS

A.1 PROBLEM SETUP (SEC. 2)

Lemma 1 (Gaussian α). For any function g(·) that is bounded below, if we use α_ij := exp(-∥g(x_0[i]) - g(x_0[j])∥₂² / 2τ) as the pairwise importance, then it has kernel structure (Def. 1).

Proof. Since g(·) is bounded below, there exists a vector v so that each component of g(x) - v is nonnegative for any x. Let y[i] := g(x_0[i]) - v ∈ R^d; then y[i] ≥ 0 and we have:

    α_ij = exp(-∥y[i] - y[j]∥₂² / 2τ)   (17)
         = exp(-∥y[i]∥₂² / 2τ) exp(-∥y[j]∥₂² / 2τ) exp(y⊤[i]y[j] / τ)   (18)

Using the Taylor expansion, we have

    exp(y⊤[i]y[j]/τ) = 1 + y⊤[i]y[j]/τ + (1/2)(y⊤[i]y[j]/τ)² + ... + (1/k!)(y⊤[i]y[j]/τ)^k + ...   (19)

Let

    φ(y) := [1; τ^{-1/2} y; (1/√2!) AllChoose(τ^{-1/2} y, 2); ...; (1/√k!) AllChoose(τ^{-1/2} y, k); ...] ≥ 0

be an infinite-dimensional vector, where AllChoose(y, k) is a d^k-dimensional column vector that enumerates all possible d^k products y_{i1} y_{i2} ... y_{ik}, where 1 ≤ i_k ≤ d and y_i is the i-th component of y. Then it is clear that exp(y⊤[i]y[j]/τ) = φ⊤(y[i]) φ(y[j]) and thus

    α_ij = ϕ⊤(x_0[i]) ϕ(x_0[j]) = Σ_{l=0}^{+∞} ϕ_l(x_0[i]) ϕ_l(x_0[j])

which satisfies Def. 1. Here

    ϕ(x) := exp(-∥y∥₂²/2τ) φ(y) = exp(-∥g(x) - v∥₂²/2τ) φ(g(x) - v)

is the infinite-dimensional feature mapping for input x, and ϕ_l(x) is its l-th component.

Lemma 2 (Relationship between contrastive covariance and variance in the large-batch limit). If α satisfies Def. 1, then for any function g(·), C_α[g(x)] is asymptotically PSD when N → +∞:

    C_α[g(x)] → Σ_l z_l² V_{x_0∼p̃_l(·;α)} [ E_{x∼p_aug(·|x_0)}[g(x)|x_0] ]   (4)

Proof. First let

    C^inter_α[a, b] := (1/2N²) Σ_{i=1}^N Σ_{j≠i} α_ij (a[i] - a[j])(b[i] - b[j])⊤   (23)
    C^intra_α[a, b] := (1/2N) Σ_{i=1}^N [ (1/N) Σ_{j≠i} α_ij ] (a[i] - a[i′])(b[i] - b[i′])⊤   (24)

and C^inter_α[a] := C^inter_α[a, a], C^intra_α[a] := C^intra_α[a, a]. Then we have C_α[g] = C^inter_α[g] - C^intra_α[g].
With the condition, for the first term C^inter_α[g], we have

    C^inter_α[g] = (1/2N²) Σ_ij K(x_0[i], x_0[j]) (g(x[i]) - g(x[j]))(g(x[i]) - g(x[j]))⊤   (26)

When N → +∞, we have:

    C^inter_α[g] → (1/2) ∫ K(x_0, y_0)(g(x) - g(y))(g(x) - g(y))⊤ P(x, x_0) P(y, y_0) dx dy dx_0 dy_0

We integrate over x and y first:

    ∫ (g(x) - g(y))(g(x) - g(y))⊤ P(x|x_0) P(y|y_0) dx dy   (27)
    = E_{·|x_0}[gg⊤] + E_{·|y_0}[gg⊤] - E_{·|x_0}[g] E_{·|y_0}[g⊤] - E_{·|y_0}[g] E_{·|x_0}[g⊤]   (28)

We now compute the four terms separately. With the condition that K(x_0, y_0) = Σ_l ϕ_l(x_0) ϕ_l(y_0), and the definition of the adjusted probability p̃_l(x) := (1/z_l) ϕ_l(x) P(x) where z_l := ∫ ϕ_l(x) P(x) dx, for the first term we have:

    ∫ ϕ_l(x_0) ϕ_l(y_0) E_{·|x_0}[gg⊤] P(x_0) P(y_0) dx_0 dy_0 = z_l² ∫ E_{·|x_0}[gg⊤] p̃_l(x_0) dx_0   (29)
    = z_l² E_{x_0∼p̃_l} E_{·|x_0}[gg⊤]   (30)

So we have:

    C^inter_α[g] → Σ_l z_l² ( E_{x_0∼p̃_l} E_{·|x_0}[gg⊤] - E_{x_0∼p̃_l} E_{·|x_0}[g] · E_{x_0∼p̃_l} E_{·|x_0}[g⊤] )   (31)
    = Σ_l z_l² V_{x_0∼p̃_l, x∼p_aug(·|x_0)}[g]   (32)

On the other hand, for C^intra_α[g], when N → +∞ we have:

    (1/N) Σ_{j≠i} α_ij = (1/N) Σ_{j≠i} K(x_0[i], x_0[j]) → ∫ K(x_0, y_0) P(y_0) dy_0   (33)
    = Σ_l ϕ_l(x_0) ∫ ϕ_l(y_0) P(y_0) dy_0 = Σ_l z_l ϕ_l(x_0)   (34)

Therefore, we have:

    C^intra_α[g] → (1/2) Σ_l z_l ∫ ϕ_l(x_0) (g(x) - g(x′))(g(x) - g(x′))⊤ P(x, x′|x_0) P(x_0) dx dx′ dx_0   (35)

Similarly,

    ∫ (g(x) - g(x′))(g(x) - g(x′))⊤ P(x, x′|x_0) dx dx′   (36)
    = 2 ∫ g(x) g⊤(x) P(x|x_0) dx - 2 ∫ g(x) P(x|x_0) dx ∫ g⊤(x′) P(x′|x_0) dx′   (37)
    = 2 E_{x∼p_aug(·|x_0)}[gg⊤] - 2 E_{x∼p_aug(·|x_0)}[g] E_{x∼p_aug(·|x_0)}[g⊤]   (38)
    = 2 V_{x∼p_aug(·|x_0)}[g]   (39)

So we have:

    C^intra_α[g] → (1/2) Σ_l z_l ∫ ϕ_l(x_0) · 2 V_{x∼p_aug(·|x_0)}[g] P(x_0) dx_0   (40)
    = Σ_l z_l² E_{x_0∼p̃_l} V_{x∼p_aug(·|x_0)}[g]   (41)

Using the law of total variance, finally we have:

    C_α[g] → Σ_l z_l² V_{x_0∼p̃_l} E_{x∼p_aug(·|x_0)}[g]

B ONE-LAYER MODEL (SEC. 3)

B.1 COMPUTATION OF THE TWO EXAMPLE MODELS

Here we assume the ReLU activation h(x) := max(x, 0), which is a homogeneous activation h(x) = h′(x)x. Note that we take h′(0) = 0; therefore, for any sample x with w⊤x = 0, we do not consider it to be in the active region of ReLU, i.e., x̃_w = x · h′(w⊤x) = 0. Let z be a hidden binary variable; we can compute A(w) (here p_0 := P[z = 0] and p_1 := P[z = 1]):

    V[x̃_w] = V_z[E[x̃_w|z]] + E_z[V[x̃_w|z]] = p_0 p_1 ∆(w)∆⊤(w) + p_0 Σ_0(w) + p_1 Σ_1(w)   (43)

where ∆(w) := E[x̃|z = 1] - E[x̃|z = 0] and Σ_z(w) := V[x̃|z].

Latent categorical model. If w = u_m, let z := I(y = m). This leads to Σ_1(u_m) = Σ_0(u_m) = 0 and ∆(u_m) = u_m. Therefore, we have:

    A(w)|_{w=u_m} := C_α[x̃_w] = V[x̃_w] = P[y = m](1 - P[y = m]) u_m u_m⊤   (44)

Latent summation model. If w = u_m, first notice that due to the orthogonality constraints we have w⊤x = Σ_{m′} y_{m′} u_{m′}⊤ w = y_m. Let z := I(y_m > 0); then we can compute ∆(u_m) = y⁺_m u_m, Σ_1(u_m) = I - u_m u_m⊤, and Σ_0(u_m) = 0. Therefore, we have:

    A(w)|_{w=u_m} := C_α[x̃_w] = V[x̃_w] = (1 - q_m)² u_m u_m⊤ + q_m (I - u_m u_m⊤)
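The decomposition in Eqn. 43 is the law of total variance; it can be checked numerically on a toy two-component mixture (all distribution parameters below are arbitrary choices for illustration, with Σ_0 = Σ_1 = I):

```python
import numpy as np

rng = np.random.default_rng(1)
n, p1 = 200_000, 0.3
mu0, mu1 = np.array([0.0, 0.0]), np.array([2.0, -1.0])

z = rng.random(n) < p1                     # latent binary variable, P[z=1] = p1
x = rng.standard_normal((n, 2)) + np.where(z[:, None], mu1, mu0)

total = np.cov(x.T, bias=True)             # empirical V[x]

# Eqn. 43 with Sigma_0 = Sigma_1 = I:  V[x] = p0 p1 Delta Delta^T + I
delta = mu1 - mu0                          # Delta = E[x|z=1] - E[x|z=0]
pred = (1 - p1) * p1 * np.outer(delta, delta) + np.eye(2)
```

The empirical covariance matches the predicted decomposition up to sampling noise.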

B.2 DERIVATION OF TRAINING DYNAMICS

Lemma 3 (Training dynamics of 1-layer network with homogeneous activation in contrastive learning). The gradient dynamics of Eqn. 5 is (note that α is treated as an independent variable):

    ẇ_k = P⊥_{w_k} A(w_k) w_k

Here P⊥_{w_k} := I - w_k w_k⊤ projects a vector onto the subspace orthogonal to w_k.

Proof. First of all, it is clear from Eqn. 5 that each w_k evolves independently. Therefore, we omit the subscript k and derive the dynamics of one node w. To compute the training dynamics, we only need to compute the differential of C_α[h(w⊤x)]. We use the matrix differential form (Giles, 2008) to make the derivation easier to understand. For a one-layer network with K = 1 nodes, let

    E(w) := (1/2) C_α[h(w⊤x)] = (1/2) C_α[h(w⊤x), h(w⊤x)]

be the objective function to be maximized. Using the facts that
• C_α[x, y] is a bilinear form (linear w.r.t. x and y) given fixed α,
• for any vectors a and b, a⊤ C_α[x, y] b = C_α[a⊤x, b⊤y],
• for scalars x and y, C_α[x, y] = C_α[y, x],
and the product rule d(x · y) = dx · y + x · dy, we have:

    dE = (1/2) C_α[h(w⊤x), h′(w⊤x) dw⊤x] + (1/2) C_α[h′(w⊤x) dw⊤x, h(w⊤x)] = C_α[h(w⊤x), h′(w⊤x)x] dw   (46)

Now use the homogeneity condition (Assumption 1) for the activation h: h(x) = h′(x)x, which gives h(w⊤x) = h′(w⊤x) w⊤x. Therefore, we have:

    dE = w⊤ C_α[h′(w⊤x)x, h′(w⊤x)x] dw = w⊤ A(w) dw   (47)

where A(w) := C_α[h′(w⊤x)x, h′(w⊤x)x] = C_α[x̃_w, x̃_w]. Therefore, by checking the coefficient associated with the differential form dw, we know ∂E/∂w = A(w)w. By gradient ascent, we have ẇ = A(w)w. Since w has the additional constraint ∥w∥₂ = 1, the final dynamics is ẇ = P⊥_w A(w)w, where P⊥_w := I - ww⊤ is a projection matrix onto the orthogonal complement of the subspace spanned by w.

Proof. Suppose w* and a local perturbation w of it are both on the unit sphere, ∥w∥₂ = ∥w*∥₂ = 1. Since w is a local perturbation, we have w⊤w* ≥ 1 - ϵ for some ϵ ≪ 1.
In the following we bound ∥(A(w) - A(w*)) w*∥₂ in terms of ∥w - w*∥₂, which yields an upper bound on the local roughness metric ρ(w*). Let g(x) := x̃_w; applying Corollary 1 with no augmentation in the large-batch limit, we have

    A(w) := C_α[x̃_w] = Σ_l z_l² V_{p̃_l}[x̃_w]

where p̃_l(x) = (1/z_l) P(x) ϕ_l(x) is the probability distribution of the input x, adjusted by the feature mapping of the kernel function determined by the pairwise importance α_ij (Def. 1), and z_l is its normalization constant. To study (A(w) - A(w*)) w*, we study each component (V_{p̃_l}[x̃_w] - V_{p̃_l}[x̃_{w*}]) w*. Since x̃_w := x I(w⊤x ≥ 0), we have V_{p̃_l}[x̃_w] = E_{p̃_l}[xx⊤ I(w⊤x ≥ 0)] - E_{p̃_l}[x I(w⊤x ≥ 0)] E_{p̃_l}[x⊤ I(w⊤x ≥ 0)]. Let

    e := ∫_{w⊤x≥0} x p̃_l(x) dx,    e* := ∫_{w*⊤x≥0} x p̃_l(x) dx
    E := ∫_{w⊤x≥0} xx⊤ p̃_l(x) dx,  E* := ∫_{w*⊤x≥0} xx⊤ p̃_l(x) dx

so we can write V_{p̃_l}[x̃_w] = E - ee⊤ and V_{p̃_l}[x̃_{w*}] = E* - e*e*⊤. Define

    Ω⁺ := {x : w*⊤x ≥ 0, w⊤x ≤ 0},    Ω⁻ := {x : w*⊤x ≤ 0, w⊤x ≥ 0}   (71)

and Ω := Ω⁺ ∪ Ω⁻. Now let us bound (E - E*) w* and (e*e*⊤ - ee⊤) w*.

Bound on (E - E*) w*. We have:

    E - E* = ∫_{Ω⁻} xx⊤ p̃_l(x) dx - ∫_{Ω⁺} xx⊤ p̃_l(x) dx

and thus

    (E - E*) w* = ∫_{Ω⁻} xx⊤ w* p̃_l(x) dx - ∫_{Ω⁺} xx⊤ w* p̃_l(x) dx

For any x ∈ Ω⁺, we have:

    0 ≤ w*⊤x = w⊤x + (w* - w)⊤x ≤ (w* - w)⊤x ≤ C_0 ∥w* - w∥₂   (75)

Therefore |w*⊤x| ≤ C_0 ∥w* - w∥₂ and we have

    ∥∫_{Ω⁺} xx⊤ w* p̃_l(x) dx∥₂ ≤ ∫_{Ω⁺} |w*⊤x| ∥x∥₂ p̃_l(x) dx   (76)
    ≤ C_0² ∥w* - w∥₂ max_{x∈Ω⁺} p̃_l(x) ∫_{Ω⁺, ∥x∥₂≤C_0} dx   (77)
    = C_0³ ∥w* - w∥₂ max_{x∈Ω⁺} p̃_l(x) vol(C_0) arccos(w⊤w*) / 2π

where vol(C_0) is the volume of the d-dimensional ball of radius C_0.
Similarly, for x ∈ Ω⁻ we have

    0 ≥ w*⊤x = w⊤x + (w* - w)⊤x ≥ (w* - w)⊤x ≥ -C_0 ∥w* - w∥₂   (79)

hence |w*⊤x| ≤ C_0 ∥w* - w∥₂, and overall we have:

    ∥(E - E*) w*∥₂ ≤ (C_0³ vol(C_0)/π) ∥w* - w∥₂ max_{x∈Ω} p̃_l(x) arccos(w⊤w*)   (80)

Since arcsin √(1 - x²) ≤ √(1 - x²)/x for x ∈ (0, 1], we have:

    arccos(w⊤w*) = arcsin √(1 - (w⊤w*)²) ≤ √(1 - (w⊤w*)²) / w⊤w*   (81)
    = √((1 + w⊤w*)(1 - w⊤w*)) / w⊤w* ≤ √(2(1 - w⊤w*)) / w⊤w*   (82)
    ≤ (1/(1 - ϵ)) ∥w - w*∥₂   (83)

so we have:

    ∥(E - E*) w*∥₂ ≤ (C_0³ vol(C_0)/π) (1/(1 - ϵ)) ∥w* - w∥₂² max_{x∈Ω} p̃_l(x)

Therefore, ∥(E - E*) w*∥₂ is a second-order term w.r.t. ∥w - w*∥₂. Similarly,

    e - e* = ∫_{Ω⁻} x p̃_l(x) dx - ∫_{Ω⁺} x p̃_l(x) dx

Using a similar derivation, we conclude that ∥e(e - e*)⊤ w*∥₂ is also a second-order term. The only first-order term is ∥(e - e*) e*⊤ w*∥₂:

    ∥(e - e*) e*⊤ w*∥₂ ≤ E_{p̃_l}[h(w*⊤x)] ∫_Ω ∥x∥₂ p̃_l(x) dx   (87)
    ≤ C_0² ∫_Ω p̃_l(x) dx ≤ C_0² max_{x∈Ω} p̃_l(x) ∫_{Ω, ∥x∥₂≤C_0} dx   (88)
    ≤ (C_0³ vol(C_0)/π) arccos(w⊤w*) max_{x∈Ω} p̃_l(x)   (89)
    ≤ (C_0³ vol(C_0)/π) (1/(1 - ϵ)) ∥w - w*∥₂ max_{x∈Ω} p̃_l(x)   (90)

Overall we have:

    ∥(A(w) - A(w*)) w*∥₂ ≤ Σ_l z_l² ∥(V_{p̃_l}[x̃_w] - V_{p̃_l}[x̃_{w*}]) w*∥₂   (91)
    ≤ (C_0³ vol(C_0)/π) (1/(1 - ϵ)) [Σ_l z_l² max_{x∈Ω} p̃_l(x)] ∥w - w*∥₂ + O(∥w - w*∥₂²)   (92)

Since ρ(w*) is the smallest scalar that makes the local roughness bound hold, and ϵ is arbitrarily small, we have:

    ρ(w*) ≤ (C_0³ vol(C_0)/π) r(w*, α)

where r(w, α) := Σ_l z_l² max_{w⊤x=0} p̃_l(x; α).

Corollary 2 (Effect of different α). For uniform α_u (α_ij := 1) and 1-D Gaussian α_g (α_ij := exp(-∥h(w⊤x_0[i]) - h(w⊤x_0[j])∥₂²/2τ)), we have r(w*, α_g) = z_0(α_g) r(w*, α_u), with z_0(α_g) := ∫ exp(-h²(w*⊤x)/2τ) p_D(x) dx ≤ 1. As a result, z_0(α_g) ≪ 1 leads to r(w*, α_g) ≪ r(w*, α_u).

Proof. For uniform α_u, it is clear that the mapping ϕ_u(x) ≡ 1 is 1-dimensional. Therefore, p̃_0(x; α_u) := (1/z_0(α_u)) ϕ_{u0}(x) p_D(x) = p_D(x) with z_0(α_u) = ∫ ϕ_{u0}(x) p_D(x) dx = 1.
This means that

    r(w*, α_u) := Σ_{l=0}^{+∞} z_l²(α_u) max_{w*⊤x=0} p̃_l(x; α_u)   (94)
    = z_0²(α_u) max_{w*⊤x=0} p̃_0(x; α_u) = max_{w*⊤x=0} p_D(x)   (95)

For Gaussian α_g, from Lemma 1 we know that its infinite-dimensional mapping ϕ_g(x) has the following form for w = w*:

    ϕ_g(x) = e^{-h²(w*⊤x)/2τ} [1; τ^{-1/2} h(w*⊤x); (1/√2!) τ^{-1} h²(w*⊤x); ...; (1/√k!) τ^{-k/2} h^k(w*⊤x); ...]   (96)

When l ≥ 1, z_l² p̃_l(x; α_g) = z_l ϕ_{gl}(x) p_D(x) = 0 for any x on the plane w*⊤x = 0, since ϕ_{gl}(x) = 0 on the plane. On the other hand, ϕ_{g0}(x) = e^{-h²(w*⊤x)/2τ}, which equals the constant 1 on the plane. Therefore, we have:

    r(w*, α_g) := Σ_{l=0}^{+∞} z_l² max_{w*⊤x=0} p̃_l(x; α_g) = z_0²(α_g) max_{w*⊤x=0} p̃_0(x; α_g)   (97)
    = z_0(α_g) max_{w*⊤x=0} ϕ_{g0}(x) p_D(x) = z_0(α_g) max_{w*⊤x=0} p_D(x) = z_0(α_g) r(w*, α_u)   (98)

Notation. Let λ_i(w) and ϕ_i(w) be the i-th eigenvalue and unit eigenvector of A(w), where ϕ_1(w) corresponds to the largest eigenvalue. We first assume A(w) is positive definite (PD) and remove this assumption later. In this case, λ_1(w) ≥ λ_2(w) ≥ ... ≥ λ_d(w) > 0. Let c(w) := w⊤ϕ_1(w) be the inner product between w and the maximal eigenvector of A(w). Consider the following Power Iteration (PI) update:

    w̃(t+1) ← A(w(t)) w(t),    w(t+1) ← w̃(t+1)/∥w̃(t+1)∥₂   (101)

Along the trajectory, let ϕ_i(t) := ϕ_i(A(w(t))) be the i-th unit eigenvector of A(w(t)) and λ_i(t) the i-th eigenvalue. Define δw(t) := w(t+1) - w(t), δA(t) := A(w(t+1)) - A(w(t)), and c_t := c(w(t)) = ϕ_1⊤(t) w(t), d_t := ϕ_1⊤(t) w(t+1). Then -1 ≤ c_t, d_t ≤ 1, since they are inner products of unit vectors.

Theorem 3 (Existence of critical points). Let c_0 := c(w(0)) ≠ 0. Suppose there exists γ < 1 so that sup_{w∈B_γ} ω(w) ≤ γ, where B_γ := {w : w⊤w(0) ≥ (c_0 - c_γ)/(1 - c_γ)}, with c_γ := 2√γ/(1 + γ), is the neighborhood of the initial value w(0). Then Power Iteration (Eqn. 101) converges to a critical point w* ∈ B_γ of Eqn. 7.
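Critical points of the dynamics in Lemma 3 are exactly the fixed points of the iteration above. A small-step Euler sketch of the flow settles onto one in practice (uniform α and i.i.d. Gaussian stand-in data of our own choosing; convergence here is an empirical observation, not an instance of Theorem 3's conditions):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((500, 5))           # stand-in data x[i], no augmentation

def A_of_w(w):
    """Empirical A(w) = Var[x 1(w^T x >= 0)] (uniform alpha, ReLU gating)."""
    gated = X * (X @ w >= 0)[:, None]
    return np.cov(gated.T, bias=True)

def energy(w):
    return 0.5 * w @ A_of_w(w) @ w          # E(w) = 1/2 Var[h(w^T x)]

w = rng.standard_normal(5)
w /= np.linalg.norm(w)
E0 = energy(w)

eta = 0.05
for _ in range(5000):                       # Euler steps of  wdot = P_w^perp A(w) w
    g = A_of_w(w) @ w
    w += eta * (g - (w @ g) * w)
    w /= np.linalg.norm(w)

g = A_of_w(w) @ w
residual = np.linalg.norm(g - (w @ g) * w)  # ~0 at a critical point
```

Different initializations land on different critical points, matching the many-local-optima picture of Corollary 2.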
Due to the stop gradient, α and thus ϕ_l(·; α) are treated as constants when checking the local properties of the current parameters w. This means that in the local neighborhood of w, E_α(w) = E_α(R(t′)w).

Now notice an important observation: if w′ := R(t′)w ≠ w, then E_α(w) = E_α(R(t′)w) = E_α(w′), and therefore w cannot be a strict local optimum. Intuitively, this means that data augmentation can remove certain local optima of w if they are not locally invariant (i.e., R(t′)w ≠ w) to the transformation used in the augmentation. Therefore, augmentation removes certain patterns in the input data and their local optima in training, keeping only the patterns (local optima) that are most relevant to the task. Here we only use the 1-dimensional rotation group as a simple example. In practice, the augmentations may not globally form a Lie group, and there can be multiple types of augmentations, yielding a high-dimensional transformation space. Therefore, we may use a Lie algebra instead to capture the local transformation structure, without making assumptions about the global structure. We leave a formal study to future work.
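This effect is easy to see numerically in the 2D rotation example (a toy sketch with uniform α and data of our own choosing): under full rotation augmentation, the conditional mean E_{x∼aug(x_0)}[h(w⊤x)] becomes independent of the direction of x_0, so its variance over x_0, and hence every local optimum in θ, vanishes.

```python
import numpy as np

rng = np.random.default_rng(0)
relu = lambda t: np.maximum(t, 0)

a = rng.uniform(-np.pi, np.pi, 50)                    # data x0 on the unit circle
T = np.linspace(0, 2 * np.pi, 1024, endpoint=False)   # rotation group R(t')

def energy(theta, augment):
    # w = [cos(theta), sin(theta)], so w^T x = cos(angle(x) - theta)
    if augment:
        # replace h(w^T x0) by its average over all rotations R(t') x0
        resp = relu(np.cos(a[:, None] + T[None, :] - theta)).mean(axis=1)
    else:
        resp = relu(np.cos(a - theta))
    return 0.5 * resp.var()

thetas = np.linspace(0, 2 * np.pi, 100)
E_plain = np.array([energy(t, False) for t in thetas])
E_aug = np.array([energy(t, True) for t in thetas])
```

Without augmentation the landscape has clear structure; with full rotation augmentation it flattens to (numerically) zero, i.e., all rotation-sensitive local optima are removed.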

C TWO LAYER CASE (SEC. 4)

C.1 LEARNING DYNAMICS

Lemma 4 (Dynamics of 2-layer nonlinear network with contrastive loss).

    V̇ = V C_α[f_1],    ẇ = P⊥_w [(S ⊗ 1_d 1_d⊤) ∘ C_α[x̃]] w

where 1_d is the d-dimensional all-one vector, ⊗ is the Kronecker product, and ∘ is the Hadamard product.

Proof. The output of the 2-layer network can be written as follows:

    f_2l = Σ_k v_lk h(w_k⊤ x_k)

For convenience, we use f_1 := [h(w_k⊤ x_k)] to represent the column vector that collects the outputs of all intermediate nodes, and v_l⊤ is the l-th row vector of V. According to Theorem 1 in Tian (2022), the gradient descent direction of the contrastive loss corresponds to the gradient ascent direction of the energy function E_α(θ). From Eqn. 25 of that theorem, we have:

    ∂E/∂θ = Σ_l C_α[∂f_2l/∂θ, f_2l]

Therefore, for V = [v_lk] we have:

    v̇_i = ∂E/∂v_i = Σ_l C_α[∂f_2l/∂v_i, f_2l]   (111)
        = C_α[f_1, v_i⊤ f_1]   (112)
        = C_α[f_1, f_1] v_i   (113)

So we have v̇_i = C_α[f_1] v_i, i.e., V̇ = V C_α[f_1]. Now we compute ∂E/∂w_k:

    ẇ_k = ∂E/∂w_k = Σ_l C_α[∂f_2l/∂w_k, f_2l]   (114)
        = Σ_l C_α[v_lk h′(w_k⊤x_k) x_k, v_l⊤ f_1]   (115)
        = Σ_l v_lk C_α[x̃_k, v_l⊤ f_1]   (116)
        = Σ_l v_lk C_α[x̃_k, Σ_{k′} v_{lk′} h(w_{k′}⊤ x_{k′})]   (117)
        = Σ_{k′} Σ_l v_lk v_{lk′} C_α[x̃_k, x̃_{k′}] w_{k′}   (118)
        = Σ_{k′} s_{kk′} C_α[x̃_k, x̃_{k′}] w_{k′}   (119)

where S = [s_{kk′}] = V⊤V = Σ_l v_l v_l⊤. Let w := [w_1; ...; w_K]; this leads to the conclusion. When M > 1, the proof is similar.
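The compact form (S ⊗ 1_d 1_d⊤) ∘ C_α[x̃] is just a block-wise rescaling of C_α[x̃] by the entries of S; a quick numerical check with a random PSD matrix standing in for C_α[x̃]:

```python
import numpy as np

rng = np.random.default_rng(0)
K, d, L = 3, 4, 5
V = rng.standard_normal((L, K))
S = V.T @ V                                # S = [s_kk'] = sum_l v_l v_l^T

C = rng.standard_normal((K * d, K * d))
C = C @ C.T                                # PSD stand-in for C_alpha[x~]
w = rng.standard_normal(K * d)             # stacked [w_1; ...; w_K]

# (S ⊗ 1_d 1_d^T) ∘ C scales the (k, k') block of C by s_kk'
lhs = (np.kron(S, np.ones((d, d))) * C) @ w

# block form of Eqn. 119: wdot_k = sum_k' s_kk' C[k-block, k'-block] w_k'
rhs = np.concatenate([
    sum(S[k, kp] * C[k*d:(k+1)*d, kp*d:(kp+1)*d] @ w[kp*d:(kp+1)*d]
        for kp in range(K))
    for k in range(K)])
```

The two expressions agree entry-wise, confirming the notation.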

C.2 VARIANCE DECOMPOSITION

Let p_c := P[z = c] be the probability that the latent variable z takes categorical value c.

Lemma 5 (Closed form of variance under Assumption 2). With Assumption 2, we have

    V[x̃] = diag_k[L_k] + Σ_{c=0}^{C-1} p_c (1 - p_c)² ∆(c) ∆⊤(c)

where L_k := E_z V[x̃_k|z] ∈ R^{Md×Md} and ∆(c) := E[x̃|z = c] - E[x̃|z ≠ c] ∈ R^{MKd}. In particular, when C = 2 the second term becomes p_0 p_1 ∆∆⊤, a rank-1 matrix; here ∆ := ∆(0) for brevity.

Proof. Using variance decomposition, we have:

    V[x̃] = E_z V[x̃|z] + V_z E[x̃|z]   (121)

Recall that x̃_km is an abbreviation for the gated input:

    x̃_km := x̃_{w_km, k} := x_k · h′(w_km⊤ x_k)

By conditional independence, we have

    Cov[x̃_km, x̃_{k′m′} | z] = 0,  ∀ k ≠ k′   (123)

This is because x̃_km and x̃_{k′m′} are deterministic functions of x_k and x_{k′}, which are conditionally independent of each other. Let

    x̃_k := [x̃_k1; x̃_k2; ...; x̃_kM] ∈ R^{Md}   (124)

and L_k := E_z V[x̃_k|z]. Then we know that E_z V[x̃|z] = diag_k[L_k] is a block-diagonal matrix (see Fig. 6). On the other hand, V_z E[x̃|z] is a low-rank matrix:

    V_z E[x̃|z] = E_z [(E[x̃|z] - E[x̃])(E[x̃|z] - E[x̃])⊤]   (125)

Let q_c := E[x̃|z = c] and q_{-c} := E[x̃|z ≠ c]; then we have:

    E[x̃|z = c] - E[x̃] = q_c - Σ_{c′} p_{c′} q_{c′} = (1 - p_c) [q_c - Σ_{c′≠c} (p_{c′}/(1 - p_c)) q_{c′}]   (126)
    = (1 - p_c) [q_c - Σ_{c′≠c} P[z = c′|z ≠ c] q_{c′}]   (127)
    = (1 - p_c)(q_c - q_{-c})   (128)

Therefore, we have:

    V_z E[x̃|z] = E_z [(E[x̃|z] - E[x̃])(E[x̃|z] - E[x̃])⊤]   (129)
    = Σ_c p_c (1 - p_c)² (q_c - q_{-c})(q_c - q_{-c})⊤   (130)
    = Σ_c p_c (1 - p_c)² ∆(c) ∆⊤(c)   (131)

where

    ∆(c) := ∆(c; W) := q_c - q_{-c} = [∆_11(c); ...; ∆_KM(c)] ∈ R^{KMd}   (132)

and ∆_km(c) := ∆_km(c; w_km) := E[x̃_km|z = c] - E[x̃_km|z ≠ c]. We can see that V_z E[x̃|z] is at most rank C, since it is a summation of C rank-1 matrices. In particular, when C = 2, it is clear that ∆(0) = -∆(1), hence ∆(0)∆⊤(0) = ∆(1)∆⊤(1) and Σ_c p_c (1 - p_c)² = p_0 p_1² + p_1 p_0² = p_0 p_1. Hence the conclusion.
C.3 GLOBAL MODULATION WHEN C = 2 AND M = 1

Theorem 5 (Dynamics of w_k under conditional independence). When C = 2 and M = 1, the dynamics of w_k is given by (s_k² and δ_k ≥ 0 are scalars defined in the proof):

    ẇ_k = P⊥_{w_k} [s_k² A_k(w_k) + δ_k ∆_k ∆_k⊤] w_k   (15)

Proof. Since M = 1, each receptive field (RF) R_k only outputs a single node with output f_k. Let:

    L_k := E_z V[x̃_k|z]   (134)
    d_k := w_k⊤ L_k w_k = E_z V[f_k|z] ≥ 0   (135)
    D := diag_k[d_k]   (136)
    b := [b_k] := [w_k⊤ ∆_k] ∈ R^K   (137)

and let λ be the maximal eigenvalue of V[f_1]. Here L_k is a PSD matrix and D is a diagonal matrix. Then

    V[f_1] = D + p_0 p_1 bb⊤   (138)

is a diagonal matrix plus a rank-1 matrix. Since p_0 p_1 bb⊤ is always PSD, λ = λ_max(V[f_1]) ≥ λ_max(D) = max_k d_k. Then, using the Bunch-Nielsen-Sorensen formula (Bunch et al., 1978), the largest eigenvector s has components

    s_k = (1/Z) b_k/(d_k - λ)   (139)

where λ is the corresponding largest eigenvalue, satisfying 1 + p_0 p_1 Σ_k b_k²/(d_k - λ) = 0, and Z := √(Σ_k (b_k/(d_k - λ))²) is the normalization constant. Note that the above is well-defined: if k* = arg max_k d_k and b_{k*} ≠ 0, then λ > max_k d_k = d_{k*}, so b_k/(d_k - λ) is never infinite. So we have:

    ẇ_k = Σ_{k′} s_k s_{k′} C_α[x̃_k, x̃_{k′}] w_{k′}   (140)
        = Σ_{k′} s_k s_{k′} (L_k I(k = k′) + p_0 p_1 ∆_k ∆_{k′}⊤) w_{k′}
        = s_k² V[x̃_k] w_k + p_0 p_1 s_k ∆_k Σ_{k′≠k} s_{k′} ∆_{k′}⊤ w_{k′}   (141)
        = s_k² V[x̃_k] w_k + (p_0 p_1 b_k / (Z²(d_k - λ))) ∆_k Σ_{k′≠k} b_{k′}²/(d_{k′} - λ)   (142)
        = s_k² V[x̃_k] w_k + δ_k ∆_k ∆_k⊤ w_k = [s_k² V[x̃_k] + δ_k ∆_k ∆_k⊤] w_k

where

    δ_k := (p_0 p_1 / (Z²(λ - d_k))) Σ_{k′≠k} b_{k′}²/(λ - d_{k′})   (143)

Since λ ≥ max_k d_k, we have δ_k ≥ 0, and thus the modulation term is non-negative. Note that since p_0 p_1 Σ_k b_k²/(λ - d_k) = 1, we can also write δ_k = (1/(Z²(λ - d_k))) (1 - p_0 p_1 b_k²/(λ - d_k)).

Theorem 6 (Global modulation of attractive basin).
If the structural assumption A_k(w_k) = Σ_l g(u_l⊤ w_k) u_l u_l⊤ holds, with g(·) > 0 a linear increasing function and {u_l} an orthonormal basis, then for A_k + c u_l u_l⊤ with c > 0, the attractive basin of w_k = u_l is larger than that of A_k.

Proof. Since A_k(w) = Σ_l g(u_l⊤ w) u_l u_l⊤, we can write down its dynamics (omitting the projection P⊥_w for now):

    ẇ = A_k(w) w = Σ_l g(u_l⊤ w) u_l u_l⊤ w   (144)

Let y_l(t) := u_l⊤ w(t), i.e., y_l(t) is the component of the weight w(t) along the l-th direction (a change to the orthonormal basis {u_l}); then the dynamics above can be written as

    ẏ_l = g(y_l) y_l   (145)

which has the same form for all l, so we just need to study ẏ = g(y)y. Since g(y) > 0 is a linear increasing function, we can write g(y) = ay + b with a > 0, and without loss of generality set a = 1. Then we analyze the dynamics:

    ẏ_l = (y_l + b_l) y_l,  b_l > 0   (146)

which also covers the case of A_k + c u_l u_l⊤, which simply sets b_l = b + c. Solving the dynamics leads to the following closed-form solution:

    y_l(t) / (y_l(t) + b_l) = [y_l(0) / (y_l(0) + b_l)] e^{b_l t}   (147)

The 1-d dynamics has an unstable fixed point y_l = 0 and a stable one y_l = -b_l < 0. Therefore, when the initial condition satisfies y_l(0) < 0, the dynamics converges to y_l(+∞) = -b_l, a finite number. On the other hand, when y_l(0) > 0, the dynamics has a finite-time blow-up (Thompson et al., 1990; Goriely & Hyde, 1998): there exists a critical time t*_l = (1/b_l) ln(1 + b_l/y_l(0)) < +∞ so that y_l(t*_l) = +∞ (see Fig. 7). Note that this finite-time blow-up is not physical, since we have not taken into account the normalization Z(t), which depends on all y_l(t). The real quantity to be considered is ŷ_l(t) = y_l(t)/Z(t).
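The closed-form solution (Eqn. 147) and the stable point at -b_l can be cross-checked against a direct Euler integration (the constants below are arbitrary toy choices):

```python
import numpy as np

b, y0 = 1.5, -0.2           # y(0) < 0: should converge to the stable point -b
dt, T = 1e-4, 10.0

y = y0
for _ in range(int(T / dt)):
    y += dt * (y + b) * y   # Euler step of  y' = (y + b) y   (Eqn. 146)

# Eqn. 147: y/(y+b) = y0/(y0+b) * exp(b t)  =>  y(t) = b r / (1 - r)
r = y0 / (y0 + b) * np.exp(b * T)
y_closed = b * r / (1 - r)
```

Both the integrator and the closed form approach -b = -1.5, matching the stable fixed point of the flow.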
Fortunately, we do not need to estimate Z(t), since we are only interested in the ratio:

    r_{l/l′}(t) := ŷ_l(t)/ŷ_{l′}(t) = y_l(t)/y_{l′}(t)   (148)

If for some l and every l′ ≠ l we have r_{l/l′}(t) → ∞, then y_l(t) dominates and ŷ_l(t) → 1, i.e., the dynamics converges to u_l. Our task is now to determine which initial conditions y_l(0) and coefficients b_l make r_{l/l′}(t) → +∞. By comparing critical times, the component that blows up earliest, l* = arg min_l t*_l, is the winner, and this can be decided without computing the normalization constant Z(t). The critical time satisfies

    [y_l(0)/(y_l(0) + b_l)] e^{b_l t*_l} = 1   (149)

so t*_l = (1/b_l) ln(1 + b_l/y_l(0)). It is clear that when y_l(0) is larger, the critical time t*_l becomes smaller and the l-th component gains an advantage over the others. For b_l > 0, we have:

    ∂t*_l/∂b_l = (1/b_l²) [ (b_l/y_l(0)) / (1 + b_l/y_l(0)) - ln(1 + b_l/y_l(0)) ] < 0   (151)

where the last inequality is due to the fact that x/(1+x) < ln(1+x) for x > 0. Therefore, a larger b_l leads to a smaller t*_l. Since adding c u_l u_l⊤ with c > 0 to A_k increases b_l, it decreases t*_l and thus increases the advantage of the l-th component. In summary, larger b_l and larger y_l(0) both lead to smaller t*_l; for the same t*_l, a larger b_l can trade for a smaller y_l(0), i.e., a larger attractive basin.

Remark. Special case. We start by assuming that only one ϵ_l ≠ 0 while ϵ_{l′} = 0 for all l′ ≠ l, and then generalize to the case where all {ϵ_l} are real numbers. To quantify the probability that a random weight initialization leads to convergence to u_l, we set up some notation. Let the event E_l be "a random weight initialization of y leads to y → e_l", or equivalently w → u_l. Let Y_l be the random variable that instantiates the initial value y_l(0) under random weight initialization.
Then the convergence event E_l is equivalent to: (1) Y_l > 0 (so that the l-th component has the opportunity to grow), and (2) Y_l + ϵ_l is the maximum over all Y_{l′} for l′ ≠ l, where ϵ_l is an advantage (> 0) or disadvantage (< 0) obtained from a larger/smaller b_l due to global modulation (e.g., c). We therefore also call ϵ_l the modulation factor. Here we discuss the simple case where Y_l ∼ U[-1, 1] and the {Y_l} are independent. In this case, for a given l, max_{l′≠l} Y_{l′} is a random variable independent of Y_l, with cumulative distribution function (CDF) F_max(x) := P[max_{l′≠l} Y_{l′} ≤ x] = F^{d-1}(x), where F(x) is the CDF of Y_l. Then we have:

    P[E_l] = P[Y_l > 0, Y_l + ϵ_l ≥ max_{l′≠l} Y_{l′}]   (152)
    = ∫_0^{+∞} P[max_{l′≠l} Y_{l′} ≤ Y_l + ϵ_l | Y_l = y_l] P[Y_l = y_l] dy_l   (153)
    = ∫_0^{+∞} F^{d-1}(y_l + ϵ_l) dF(y_l)   (154)

When Y_l ∼ U[-1, 1], F(x) = min((x + 1)/2, 1) has a closed form and we can compute the integral:

    P[E_l] = P[Y_l > 0, Y_l + ϵ_l ≥ max_{l′≠l} Y_{l′}]
           = 1/2                                           if ϵ_l > 1
           = (1/d)(1 - ((1 + ϵ_l)/2)^d) + ϵ_l/2            if 0 ≤ ϵ_l ≤ 1
           = (1/d)[(1 + ϵ_l/2)^d - ((1 + ϵ_l)/2)^d]        if -1 < ϵ_l < 0

We can see that the modulation factor ϵ_l plays an important role in deciding the probability that w → u_l:

• No modulation. If ϵ_l = 0, then P[E_l] ∼ 1/d. This means that each dimension of y has equal probability to become the dominant component after training;

• Positive modulation. If ϵ_l > 0, then P[E_l] ≥ ϵ_l/2, and that particular l-th component has a much higher probability of becoming the dominant component, independent of the dimensionality d. Furthermore, the stronger the modulation, the higher the probability;

• Negative modulation. Finally, if ϵ_l < 0, then since 1 + ϵ_l/2 < 1, P[E_l] ≤ (1/d)(1 + ϵ_l/2)^d decays exponentially w.r.t. the dimensionality d.

General case. We now analyze the case where all ϵ_l are real numbers. Let l* = arg max_l ϵ_l and let c(k) be the index of the k-th largest ϵ_l (in descending order), so that c(1) = l*.
• For l = c(1) = l*, ϵ_l is the largest among {ϵ_l}. This case is similar to positive modulation, and thus

    P[E_l] = P[Y_l ≥ 0, Y_l + ϵ_l ≥ max_{l′≠l} (Y_{l′} + ϵ_{l′})] ≥ P[Y_l ≥ 0, Y_l + ϵ_{c(1)} - ϵ_{c(2)} ≥ max_{l′≠l} Y_{l′}]

where ϵ_{c(1)} - ϵ_{c(2)} ≥ 0 is the gap between the largest and second-largest ϵ_l. Therefore,

    P[E_{c(1)}] ≥ (1/2)(ϵ_{c(1)} - ϵ_{c(2)})

• For l of rank r (i.e., c(r) = l) and any r′ < r, we have:

    P[E_l] = P[Y_l ≥ 0, Y_l + ϵ_l ≥ max_{l′≠l} (Y_{l′} + ϵ_{l′})]
           ≤ P[Y_l ≥ 0, Y_l + ϵ_l ≥ max_{l′: c^{-1}(l′)≤r′} (Y_{l′} + ϵ_{l′})]
           = P[Y_l ≥ 0, Y_l + ϵ_l - ϵ_{c(r′)} ≥ max_{l′: c^{-1}(l′)≤r′} (Y_{l′} + ϵ_{l′} - ϵ_{c(r′)})]
           ≤ P[Y_l ≥ 0, Y_l + ϵ_l - ϵ_{c(r′)} ≥ max_{l′: c^{-1}(l′)≤r′} Y_{l′}]

This reduces to the case of negative modulation. Therefore, we have:

    P[E_{c(r)}] ≤ min_{r′<r} (1/(r′+1)) (1 - (ϵ_{c(r′)} - ϵ_{c(r)})/2)^{r′+1}

and the probability is exponentially small if r is large, i.e., if ϵ_l ranks low.
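The closed-form probability for 0 ≤ ϵ_l ≤ 1 can be verified by Monte Carlo (d, ϵ, and the sample count below are arbitrary choices):

```python
import numpy as np

rng = np.random.default_rng(0)
d, eps, n = 5, 0.4, 400_000

Y = rng.uniform(-1, 1, size=(n, d))        # i.i.d. initial components Y_l
# event E_1: Y_1 > 0 and Y_1 + eps beats every other component
event = (Y[:, 0] > 0) & (Y[:, 0] + eps >= Y[:, 1:].max(axis=1))
p_mc = event.mean()

# closed form for 0 <= eps <= 1 (Sec. C.3):
p_cf = (1 - ((1 + eps) / 2) ** d) / d + eps / 2
```

With d = 5 and ϵ = 0.4, both the simulation and the formula give about 0.37, well above the unmodulated baseline 1/d = 0.2.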

C.4 FUNDAMENTAL LIMITATION OF LINEAR MODELS

Theorem 4 (Gradient colinearity in linear networks). With linear activation, W follows the dynamics:

    ẇ_km = s_km b_k(W, V)

where b_k(W, V) := C_α[x_k, Σ_{k′,m′} s_{k′m′} w_{k′m′}⊤ x_{k′}] is a linear function of W. As a result, (1) the ẇ_km are co-linear over m, and (2) if s_km ≠ 0, from any critical point with distinct {w_km} there exists a path of critical points leading to identical weights (w_km = w_k).

Proof. In the linear case we have x̃_km = x_k, since there is no gating and all M filters share the same input x_k. Therefore, we can write down the dynamics of w_km as follows:

    ẇ_km = Σ_{k′,m′} s_{km,k′m′} C_α[x̃_km, x̃_{k′m′}] w_{k′m′}   (158)
         = Σ_{k′,m′} s_{km,k′m′} C_α[x_k, x_{k′}] w_{k′m′}   (159)

Now we use the fact that the top level learns fast, so that s_{km,k′m′} = s_km s_{k′m′}, which gives:

    ẇ_km = s_km Σ_{k′,m′} s_{k′m′} C_α[x_k, x_{k′}] w_{k′m′}   (160)
         = s_km C_α[x_k, Σ_{k′,m′} s_{k′m′} w_{k′m′}⊤ x_{k′}]   (161)

Let b_k(W, V) := C_α[x_k, Σ_{k′,m′} s_{k′m′} w_{k′m′}⊤ x_{k′}], a linear function of W; then we have:

    ẇ_km = s_km b_k(W, V)   (162)

Since b_k is independent of m, all ẇ_km are co-linear. For the second part, first of all, if W* is a critical point, we have the following two facts:

• Since there exists m with s_km ≠ 0, we know that b_k(W*) = 0;
• If W* contains two distinct filters w_k1 = µ_1 ≠ w_k2 = µ_2 covering the same receptive field R_k, then by the symmetry of the weights, W′*, in which w_k1 = µ_2 and w_k2 = µ_1, is also a critical point.

Then for any c ∈ [0, 1], since b_k(W) is linear in W, for the linear combination W_c := cW* + (1 - c)W′* we have:

    b_k(W_c) = b_k(cW* + (1 - c)W′*) = c b_k(W*) + (1 - c) b_k(W′*) = 0   (163)

Therefore W_c is also a critical point, in which w_k1 = cµ_1 + (1 - c)µ_2 and w_k2 = (1 - c)µ_1 + cµ_2. In particular, when c = 1/2, w_k1 = w_k2. Repeating this process for different m, we finally reach a critical point in which all w_km = w_k.
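Point (1) of Theorem 4 can be checked empirically for a single receptive field (random stand-in data; s_km and the filters W below are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
N, d, M = 1000, 4, 3
Xk = rng.standard_normal((N, d))           # inputs x_k to one receptive field
s = rng.standard_normal(M)                 # s_km for this RF
W = rng.standard_normal((M, d))            # filters w_km

# linear activation: x~_km = x_k for all m, so (single-RF version of Eqn. 162)
#   wdot_km = s_km * C_alpha[x_k, sum_m' s_km' w_km'^T x_k] = s_km * b_k
C = np.cov(Xk.T, bias=True)                # uniform alpha: C_alpha behaves like Var
b = C @ (W.T @ s)                          # b_k(W, V), independent of m
wdot = np.outer(s, b)                      # row m is s_km * b_k

rank = np.linalg.matrix_rank(wdot)         # all rows co-linear => rank 1
```

Every filter in the RF receives the same update direction b_k, only scaled by s_km, so no specialization across the M filters can emerge.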

D ANALYSIS OF BATCH NORMALIZATION

From the previous analysis of global modulation, it is clear that weight updating can be much slower for an RF with small d_k, due to the factor 1/(λ - d_k) in both s_k² (Eqn. 139) and δ_k (Eqn. 143) and the fact that λ ≥ max_k d_k. This happens when the variance of each receptive field varies a lot (i.e., some d_k are large while others are small). In this case, adding BatchNorm at each node alleviates this issue, as shown below. We consider BatchNorm right after f: f^bn_k[i] = (f_k[i] - µ_k)/σ_k, where µ_k and σ_k are the batch statistics computed by BatchNorm over all 2N samples in a batch:

    µ_k := (1/2N) Σ_i (f_k[i] + f_k[i′]),    σ_k² := (1/2N) Σ_i [(f_k[i] - µ_k)² + (f_k[i′] - µ_k)²]   (165)

When N → +∞, we have µ_k → E[f_k] and σ_k² → V[f_k] = w_k⊤ V[x̃_k] w_k. Let x̃^bn_k := σ_k^{-1} x̃_k and x̃^bn := [x̃^bn_1; x̃^bn_2; ...; x̃^bn_K]. When computing the gradient through the BatchNorm layer, we consider the following variant:

Definition 5 (mean-backprop BatchNorm). When computing the backpropagated gradient through BatchNorm, we only backpropagate through µ_k.

This leads to a model dynamics of a very similar form to Lemma 4:

Lemma 6 (Dynamics with mean-backprop BatchNorm). With mean-backprop BatchNorm (Def. 5), the dynamics is:

    V̇ = V C_α[f^bn_1],    ẇ = [(S ⊗ 1_d 1_d⊤) ∘ C_α[x̃^bn]] w

Proof. The proof is similar to Lemma 4. For V̇, it is the same after replacing f_1 with f^bn_1, which is the input to the top layer. For ẇ, similarly we have:

    ẇ_k = ∂E/∂w_k = Σ_l C_α[∂f_2l/∂w_k, f_2l]   (167)
        = Σ_l C_α[v_lk ∂f^bn_1k/∂w_k, Σ_{k′} v_{lk′} f^bn_{1k′}]   (168)
        = Σ_l C_α[v_lk ∂f^bn_1k/∂w_k, Σ_{k′} v_{lk′} (f_{1k′} - µ_{k′}) σ_{k′}^{-1}]   (169)

Note that C_α[·, µ_{k′} σ_{k′}^{-1}] = 0, since µ_{k′} and σ_{k′} are batch statistics and thus constant over the batch. On the other hand, for ∂f^bn_1k/∂w_k, we have:

    ∂f^bn_1k/∂w_k = (1/σ_k)(∂f_1k/∂w_k - ∂µ_k/∂w_k) - (f^bn_1k/σ_k) ∂σ_k/∂w_k   (170)

Note that ∂µ_k/∂w_k = E_sample[x̃_k], where E_sample[·] is the sample mean, a constant over the batch.
Therefore C_α[·, ∂µ_k/∂w_k] = 0. For mean-backprop BatchNorm, the gradient does not backpropagate through the variance, so the second term in Eqn. 170 is simply zero. Therefore, we have:

    ẇ_k = Σ_l C_α[v_lk σ_k^{-1} ∂f_1k/∂w_k, Σ_{k′} v_{lk′} f_{1k′} σ_{k′}^{-1}]   (172)
        = Σ_l C_α[v_lk σ_k^{-1} x̃_k, Σ_{k′} v_{lk′} w_{k′}⊤ x̃_{k′} σ_{k′}^{-1}]   (173)
        = Σ_{k′} s_{kk′} C_α[σ_k^{-1} x̃_k, σ_{k′}^{-1} x̃_{k′}] w_{k′}   (174)

Let x̃^bn_k := σ_k^{-1} x̃_k and x̃^bn := [x̃^bn_1; ...; x̃^bn_K] ∈ R^{Kd}. The conclusion follows.

Corollary 3 (Dynamics of w_k under conditional independence and BatchNorm). Let

    A^bn_k := V[x̃^bn_k] = σ_k^{-2} A_k   (175)
    d^bn_k := σ_k^{-2} d_k   (176)
    ∆^bn_k := σ_k^{-1} ∆_k = E[x̃^bn_k|z = 1] - E[x̃^bn_k|z = 0]   (177)

and let λ^bn be the maximal eigenvalue of V[f^bn_1]. Then we have:

• (1) λ^bn ≥ max_k d^bn_k;
• (2) For λ^bn, the associated unit eigenvector is s^bn := (1/Z^bn) [w_k⊤ ∆^bn_k / (λ^bn - d^bn_k)] ∈ R^K, where Z^bn is the normalization constant;
• (3) the dynamics of w_k is given by:

    ẇ_k = [(s^bn_k)² A^bn_k + δ^bn_k ∆^bn_k (∆^bn_k)⊤] w_k   (178)

where

    δ^bn_k := (p_0 p_1 / ((Z^bn)² (λ^bn - d^bn_k))) Σ_{k′≠k} (w_{k′}⊤ ∆^bn_{k′})² / (λ^bn - d^bn_{k′}) ≥ 0   (179)

Proof. Similar to Theorem 5.
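Corollary 3 rescales each receptive field by σ_k^{-2}. A tiny numeric illustration (all numbers hypothetical) of how this evens out the per-RF scales d_k, which would otherwise differ by orders of magnitude:

```python
import numpy as np

# hypothetical per-RF quantities: d_k = E_z V[f_k|z] and gaps (w_k^T Delta_k)^2
d = np.array([0.01, 0.1, 1.0, 10.0])       # spans three orders of magnitude
gap2 = 2.0 * d                             # suppose the gap scales with d_k
p0p1 = 0.25

# sigma_k^2 -> d_k + p0 p1 (w_k^T Delta_k)^2 in the large-batch limit, so
d_bn = d / (d + p0p1 * gap2)               # d_k^bn = sigma_k^{-2} d_k

spread_before = d.max() / d.min()          # 1000x spread across RFs
spread_after = d_bn.max() / d_bn.min()     # all d_k^bn identical here
```

Since the slow-down factor 1/(λ - d_k) is driven by the spread of the d_k, equalizing them speeds up the slowest RFs.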

Remarks

In the presence of BatchNorm, Lemma 5 still holds, since it only depends on the generative structure of the data. Therefore, we have

    σ_k² → V[f_k] = w_k⊤ V[x̃_k] w_k = d_k + p_0 p_1 (w_k⊤ ∆_k)²

and thus

    d^bn_k = σ_k^{-2} d_k → d_k / (d_k + p_0 p_1 (w_k⊤ ∆_k)²) = 1 / (1 + p_0 p_1 (w_k⊤ ∆_k / √d_k)²)

becomes more uniform. This is because d_k := w_k⊤ L_k w_k = E_z V[f_k|z] ≥ 0 (Eqn. 135).

E.1 ONE-LAYER SETTING

We added a simple experiment to visualize the local maxima of E_α(w) := (1/2) C_α[h(w⊤x)], where w = [cos θ, sin θ]⊤ is a 2D unit vector parameterized by θ and h is the ReLU activation. For simplicity, here we use a uniform α := 1. We put a few data points {x_i} on the unit circle, also parameterized by θ; the data points are located at {-4π/5, -π/2, -2π/5, -π/6, 0, π/5, 2π/5, 3π/5, 4π/5} and no data augmentation is used. The objective function E_α(θ) is plotted in Fig. 8. From the figure, we can see many local maxima (≥ 8) caused by nonlinearity (solid line), many more than 2 × 2 = 4, the maximal possible number of local maxima counting all PCA components in the 2D case (i.e., ±ϕ_1 and ±ϕ_2, where ϕ_1 and ϕ_2 are the orthogonal PCA directions in this 2D example). Moreover, unlike PCA directions, these local optima are not orthogonal to each other. In the linear case (dotted line), on the other hand, the curve is much smoother, with only two local maxima corresponding to ±ϕ_1, where ϕ_1 is the largest PCA eigenvector.
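This visualization is easy to reproduce (a compact script of our own; local maxima are counted on a dense periodic θ grid):

```python
import numpy as np

# data points on the unit circle, parameterized by angle
angles = np.array([-4*np.pi/5, -np.pi/2, -2*np.pi/5, -np.pi/6, 0.0,
                   np.pi/5, 2*np.pi/5, 3*np.pi/5, 4*np.pi/5])
X = np.stack([np.cos(angles), np.sin(angles)], axis=1)

thetas = np.linspace(-np.pi, np.pi, 2000, endpoint=False)
W = np.stack([np.cos(thetas), np.sin(thetas)], axis=1)   # w = [cos t, sin t]

proj = W @ X.T                                     # w^T x_i for each theta
E_relu = 0.5 * np.maximum(proj, 0).var(axis=1)     # E_alpha(theta), uniform alpha
E_lin = 0.5 * proj.var(axis=1)                     # linear-activation baseline

def n_local_maxima(E):
    """Count strict local maxima on the periodic theta grid."""
    return int(((E > np.roll(E, 1)) & (E > np.roll(E, -1))).sum())
```

The linear landscape is a shifted sinusoid in 2θ with exactly two maxima (±ϕ_1), while the ReLU landscape exhibits many more.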

E.2 2-LAYER SETTING

We also run more experiments in the 2-layer setting to further verify our theoretical findings.

Overall matching score χ̄+ and overall irrelevant-matching score χ̄-. As defined in the main text (Eqn. 16), the matching score χ+(R_k) is the degree of matching between the learned weights and the embeddings of the subset R^g_k of tokens that are allowed in the global patterns at receptive field R_k, and the overall matching score χ̄+ is χ+(R_k) averaged over all receptive fields:

    χ+(R_k) = (1/P) Σ_{a∈R^g_k} max_m (w_km⊤ u_a)/(∥w_km∥₂ ∥u_a∥₂),    χ̄+ = (1/K) Σ_k χ+(R_k)

Similarly, we define the irrelevant-matching score χ-(R_k) as the degree of matching between the learned weights and the embeddings of the tokens that are NOT in the subset R^g_k at receptive field R_k, with the overall irrelevant-matching score χ̄- defined analogously:

    χ-(R_k) = (1/P) Σ_{a∉R^g_k} max_m (w_km⊤ u_a)/(∥w_km∥₂ ∥u_a∥₂),    χ̄- = (1/K) Σ_k χ-(R_k)

Ideally, we want a high overall matching score χ̄+ and a low overall irrelevant-matching score χ̄-, which means that the important patterns in R^g_k (i.e., the patterns allowed by the global generators) are learned, while noisy patterns that are not part of the global patterns (i.e., the generators) are not. Fig. 9 shows that this is indeed the case.

Non-uniformity ζ and how BatchNorm interacts with it. When the scale of the input data varies a lot, BatchNorm starts to matter for discovering features with low magnitude (Sec. D). To model scale non-uniformity, we set ∥u_a∥₂ = ζ for ⌊d/2⌋ tokens and ∥u_a∥₂ = 1/ζ for the remaining tokens; larger ζ corresponds to higher non-uniformity across inputs. Fig. 11 shows that BN with ReLU activations handles large non-uniformity (large ζ) very well, compared to the case without BN.
Specifically, BN yields a higher $\bar\chi_+$ in the presence of high non-uniformity (e.g., $\zeta = 10$) when the network is over-parameterized ($\beta > 1$) and there are multiple candidates per $R_k$ ($P > 1$), a setting that is likely to hold in real-world scenarios. Indeed, in practice, features from different channels/modalities can have very different scales, and some local features that turn out to be very important to global features can have very small scale. In such cases, normalization techniques (e.g., BatchNorm) can be very useful, and our formulation justifies them in a mathematically consistent way.

Selectively backpropagating $\mu_k$ and $\sigma_k^2$ in BatchNorm. In our analysis of BatchNorm, we assume that the gradient backpropagates through the mean statistics $\mu_k$ but not the variance $\sigma_k^2$ (see Def. 5). Note that this is different from regular BatchNorm, in which both $\mu_k$ and $\sigma_k^2$ receive backpropagated gradients. Therefore, we test how this modified BN affects the matching score $\bar\chi_+$: we change whether $\mu_k$ and $\sigma_k^2$ receive backpropagated gradients, while the forward pass remains the same, yielding the following four variants:
$$f_k^{\mathrm{bn}}[i] := \frac{f_k[i] - \mu_k}{\sigma_k} \quad \text{(vanilla BatchNorm)}$$
$$f_k^{\mathrm{bn}}[i] := \frac{f_k[i] - \text{stop-gradient}(\mu_k)}{\sigma_k} \quad \text{(BatchNorm with backpropagated } \sigma_k\text{)}$$
$$f_k^{\mathrm{bn}}[i] := \frac{f_k[i] - \mu_k}{\text{stop-gradient}(\sigma_k)} \quad \text{(BatchNorm with backpropagated } \mu_k\text{)}$$
$$f_k^{\mathrm{bn}}[i] := \frac{f_k[i] - \text{stop-gradient}(\mu_k)}{\text{stop-gradient}(\sigma_k)} \quad \text{(BatchNorm without backpropagating statistics)}$$
As shown in Tbl. 1, it is interesting to see that if $\sigma_k^2$ is not backpropagated, the matching score $\bar\chi_+$ is actually better. This justifies our BN variant.

Quadratic versus InfoNCE loss. Fig. 10 shows that the quadratic loss (constant pairwise importance $\alpha$) yields a worse matching score than InfoNCE.
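The gradient difference between these variants can be made concrete with a small numpy sketch (our own illustration, not the paper's code): for a 1D batch, vanilla BN's backward pass carries an extra variance term that the mean-only-backprop variant of Def. 5 drops.

```python
import numpy as np

def bn_forward(x):
    """Forward pass shared by all four variants: y = (x - mu) / sigma."""
    mu, sigma = x.mean(), x.std()
    return (x - mu) / sigma, mu, sigma

def grad_vanilla(x, g):
    """dL/dx when both mu_k and sigma_k are backpropagated (vanilla BN).

    For L = sum_i g_i y_i:  dL/dx_j = (g_j - mean(g))/sigma - y_j (g.y)/(N sigma)."""
    N = len(x)
    y, _, sigma = bn_forward(x)
    return (g - g.mean()) / sigma - y * (g @ y) / (N * sigma)

def grad_mean_only(x, g):
    """dL/dx when mu_k is backpropagated but sigma_k is stop-gradient'ed
    (the Def. 5 variant): the variance term disappears."""
    _, _, sigma = bn_forward(x)
    return (g - g.mean()) / sigma
```

Both variants share the same forward pass and differ only in the backward pass, which is why Tbl. 1 can compare them under identical forward behavior.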
A high-level intuition is that InfoNCE dynamically adjusts the pairwise importance $\alpha$ (i.e., the focus on different sample pairs) during training to concentrate on the most important sample pairs, which makes learning patterns more efficient. We leave a comprehensive study to future work.

Visualization of learned weights. We can see that the filters in each receptive field indeed capture a diverse set of patterns from the input data points (e.g., parts of the digits). Furthermore, with additional data augmentation (i.e., random cropping and resizing with transforms.RandomResizedCrop((28, 28), scale=(0.5, 1.0), ratio=(0.5, 1.5)) in PyTorch), the resulting learned patterns become smoother, with much weaker high-frequency components. This is because the augmentation implicitly removes some of the local optima (Sec. B.5), leaving the more informative ones.
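The "dynamic pairwise importance" intuition can be made explicit: the softmax inside InfoNCE assigns each candidate pair a weight that plays the role of $\alpha$. The sketch below is our own simplification (full-batch, all-pairs candidates, hypothetical names):

```python
import numpy as np

def infonce_pair_importance(z, pos, tau=0.5):
    """InfoNCE for anchor i with positive pos[i]:
    loss_i = -log softmax_j(sim(i, j) / tau)[pos[i]], j != i.
    The softmax weights act as a dynamic pairwise importance alpha_ij:
    hard negatives (high similarity) receive larger alpha.
    z: (N, d) L2-normalized embeddings; pos: index of each anchor's positive."""
    sim = z @ z.T / tau
    np.fill_diagonal(sim, -np.inf)                 # exclude self-similarity
    alpha = np.exp(sim - sim.max(axis=1, keepdims=True))
    alpha /= alpha.sum(axis=1, keepdims=True)      # rows sum to 1
    loss = -np.log(alpha[np.arange(len(z)), pos])
    return loss.mean(), alpha

# By contrast, the quadratic loss corresponds to alpha held constant over pairs.
```

Since $\alpha$ here follows the current similarities, it is re-allocated at every step toward the hardest pairs, whereas the quadratic loss keeps it fixed.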

F CONTEXT AND MOTIVATIONS OF THEOREMS

Here are the motivations and intuitions behind each major theoretical result:

• First of all, Lemma 2 and Corollary 1 make a connection between a relatively new (and somewhat abstract) concept, the contrastive covariance $C_\alpha[\cdot]$, and the well-known regular covariance $\mathbb{V}[\cdot]$. This also lets us leverage existing properties of covariance to deepen our understanding of a quantity that plays an important role in contrastive learning (CL).

• After that, we mainly focus on studying the CL energy function $E_\alpha$, which can be represented in terms of the contrastive covariance $C_\alpha[\cdot]$. One important property of the energy function is whether there exist any local optima and what their properties are, since these local optima are the final destinations into which the network weights converge. Previous works on landscape analysis often treat local optima of neural networks as abstract objects in high-dimensional space; here, we make them as concrete as possible.

• For such analysis, we always start from the simplest case (e.g., one layer). Naturally, we have Lemma 3, which characterizes critical points (a superset of local optima), a few examples in Sec. 3.2, and properties of these critical points and when they become local optima in Sec. 3. Finally, Appendix B.5 gives a preliminary study of how data augmentation affects the distribution of the local optima.

• Then we extend our analysis to the 2-layer setting. The key question is the additional benefit of a 2-layer network compared to $K$ independent 1-layer ones. The assumption of disjoint receptive fields ensures an apples-to-apples comparison; otherwise additional complexity (e.g., overlapping receptive fields) would be involved.
As demonstrated in Theorem 5 and Theorem 6, we find the effect of global modulation in the 2-layer case: interactions across different receptive fields introduce additional terms in the dynamics that favor patterns related to the latent variable $z$, which induces conditional independence across the disjoint receptive fields.

• As a side note, in the 2-layer case we also have Theorem 4, which shows that linear activation does not learn distinct features, consistent with the 1-layer case, where linear activation $h(x) = x$ only yields the maximal PCA directions (Sec. 6).

G OTHER LEMMAS

$$\ge 1 - \frac{1}{2}\left(\frac{\lambda_2(t)}{\lambda_1(t)}\right)^2\left(\frac{1}{c_t^2} - 1\right) =: 1 - \mu_t(1 - c_t)$$
The first inequality is due to the fact that $\sum_{i>1} (\phi_i^\top(t) w(t))^2 = 1 - c_t^2$ (Parseval's identity). The last inequality is due to the fact that for $x > -1$, $(1+x)^\alpha \ge 1 + \alpha x$ when $\alpha \ge 1$ or $\alpha < 0$ (Bernoulli's inequality). Therefore the conclusion holds.

Lemma 8 (Bound of weight difference). If $c_t > 0$ and $\lambda_i(t) > 0$ for all $i$, then $\|\delta w(t)\|_2 \le \sqrt{2(1 + \mu_t c_t)(1 - c_t)}$.

Proof. First, for $w^\top(t+1) w(t)$, we have (notice that $\lambda_i(t) \ge 0$):

Lemma 9. Let $\delta A = A' - A$. Then the maximal eigenvectors $\phi_1 := \phi_1(A)$ and $\phi_1' := \phi_1(A')$ have the following Taylor expansion:
$$\phi_1' = \phi_1 + \Delta\phi_1 + O(\|\delta A\|_2^2)$$
where $\lambda_i$ is the $i$-th eigenvalue of $A$ and $\Delta\phi_1 := \sum_{j>1} \frac{\phi_j^\top \delta A\, \phi_1}{\lambda_1 - \lambda_j}\,\phi_j$ is the first-order term of the eigenvector perturbation. In terms of an inequality, there exists $\kappa > 0$ so that:
$$\|\phi_1' - (\phi_1 + \Delta\phi_1)\|_2 \le \kappa \|\delta A\|_2^2 \tag{191}$$

Proof. See time-independent perturbation theory in quantum mechanics (Fernández, 2000).

Lemma 10. Let $L$ be the minimal Lipschitz constant of $A$, so that $\|A(w') - A(w)\|_2 \le L\|w - w'\|_2$ holds. If $c_t > 0$ and $\lambda_i(t) > 0$ for all $i$, then:
$$|d_t - c_{t+1}| = |(\phi_1(t) - \phi_1(t+1))^\top w(t+1)| \le \nu_t (1 - c_t)$$

Proof. Using Lemma 9 and the fact that $\|w(t+1)\|_2 = 1$, we have:
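Lemma 9 is easy to sanity-check numerically. The sketch below (our own, with an arbitrary well-separated spectrum) verifies that the residual after the first-order correction shrinks roughly quadratically in $\|\delta A\|_2$:

```python
import numpy as np

rng = np.random.default_rng(0)

def first_order_correction(A, dA):
    """phi_1 of A and Delta phi_1 = sum_{j>1} (phi_j^T dA phi_1)/(lam_1-lam_j) phi_j."""
    vals, vecs = np.linalg.eigh(A)          # eigenvalues in ascending order
    lam1, phi1 = vals[-1], vecs[:, -1]
    corr = np.zeros_like(phi1)
    for j in range(len(vals) - 1):          # every eigenvector except the top one
        corr += (vecs[:, j] @ dA @ phi1) / (lam1 - vals[j]) * vecs[:, j]
    return phi1, corr

A = np.diag([3.0, 2.0, 1.0])                # well-separated spectrum
dA0 = rng.normal(size=(3, 3))
dA0 = 1e-2 * (dA0 + dA0.T) / 2              # small symmetric perturbation

def residual(scale):
    """|| phi_1(A + s*dA) - (phi_1 + Delta phi_1) ||_2 for s = scale."""
    dA = scale * dA0
    phi1, corr = first_order_correction(A, dA)
    _, vecs = np.linalg.eigh(A + dA)
    phi1p = vecs[:, -1]
    if phi1p @ phi1 < 0:                    # fix the eigenvector sign ambiguity
        phi1p = -phi1p
    return np.linalg.norm(phi1p - (phi1 + corr))

# Halving ||dA|| should shrink the residual by roughly a factor of 4 (Lemma 9).
e1, e2 = residual(1.0), residual(0.5)
```

This is exactly the $\kappa\|\delta A\|_2^2$ behavior in Eqn. (191).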



Compared to Tian (2022), our $C_\alpha$ definition has an additional constant $1/2N$ to simplify the notation.



Figure 1: Left: Summary of Sec. 3. (a) We analyze the dynamics of a one-layer network $h(w^\top x)$ under the CL loss (Eqn. 2). (b) With linear activation $h(x) = x$, there is only one fixed point (the PCA direction). (c) Nonlinear activation $h(x)$ creates many critical points, and a proper choice of pairwise importance $\alpha$ can make them local optima, enabling learning of diverse features. Right: Convergence patterns (iteration $t$ versus iteration discrepancy $\|w(t+1) - w(t)\|_2$) of Power Iteration (Eqn. PI) in latent summation models, when $\|U^\top U - I\|_2$ is small but non-zero. In this case, Theorem 3 shows that there still exist local optima close to each $u_m$.

Figure 2: Our setting for 2-layer network. (a) We use W for low-layer weights and V for top-layer weights.

Define $\tilde x_k := [\tilde x_{k1}; \tilde x_{k2}; \dots; \tilde x_{kM}] \in \mathbb{R}^{Md}$ as the concatenation of the $\tilde x_{km}$, and finally $\tilde x := [\tilde x_1; \dots; \tilde x_K] \in \mathbb{R}^{KMd}$ as the concatenation of all $\tilde x_k$. Similarly, let $w_k := [w_{km}]_{m=1}^M \in \mathbb{R}^{Md}$ be a concatenation of all $w_{km}$ in the same RF $R_k$, and $w := [w_k]_{k=1}^K \in \mathbb{R}^{KMd}$ be a column concatenation of all $w_k$. Finally, $P^\perp_w := \mathrm{diag}_{km}[P^\perp_{w_{km}}]$.

augmentation swaps with nonlinear gating. For conditional independence, intuitively $z$ can be regarded as different types of global patterns that determine what input $x$ can be perceived (Fig. 2(b)). Once $z$ is given, the remaining variation resides within each RF $R_k$ and is independent across different $R_k$. Note that there exist many patterns in each RF $R_k$: some are parts of the global pattern $z$, while others may come from noise. We study how each weight $w_{km}$ captures distinct and useful patterns after training.

Figure 3: Experimental setting (Sec. 5). When generating an input, we first randomly pick one generator (e.g., **C*B*A*D*) from a pool of $G$ generators, generate the sequence by instantiating each wildcard * with an arbitrary token, and then replace each token $a$ of the sequence with an embedding vector $u_a$ to form the input $x$. Inputs from the same generator are treated as positive pairs, and otherwise as negative pairs, for the contrastive loss.
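The generation pipeline in the caption can be sketched as follows (the vocabulary, embedding dimension, and generator pool here are hypothetical placeholders, not the paper's exact configuration):

```python
import numpy as np

rng = np.random.default_rng(0)
TOKENS = list("ABCDEFGH")                  # hypothetical token vocabulary
d = 8
U = rng.normal(size=(len(TOKENS), d))      # token embeddings u_a

def sample_sequence(generator):
    """Instantiate every wildcard '*' of a generator with a random token."""
    return [t if t != "*" else rng.choice(TOKENS) for t in generator]

def embed(seq):
    """Replace each token with its embedding vector to form the input x."""
    return np.concatenate([U[TOKENS.index(t)] for t in seq])

generators = [list("**C*B*A*D*"), list("*D*A*C**B*")]  # pool of G generators
g = rng.integers(len(generators))
x1 = embed(sample_sequence(generators[g]))  # positive pair: same generator,
x2 = embed(sample_sequence(generators[g]))  # different wildcard instantiations
```

Two draws from the same generator agree on the fixed (non-wildcard) positions and differ only in the wildcards, which is exactly what makes them a positive pair.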

Figure 4: Overall matching score χ+ (Eqn. 16) with InfoNCE (top row) and quadratic loss (bottom row).

Figure 5: Visualization of learned weights with $P = 3$ (3 local patterns related to generators at each RF) and $\beta = 5$ (5x over-parameterization). Each of the $K = 10$ subfigures corresponds to an RF (R0–R9). In each subfigure, the left panel shows the weights learned by ReLU, while the right panel is from linear activation. The 15 rows correspond to $M = \beta P = 15$ weights, and each weight is $d = 8$ dimensional. With ReLU activation, the learned weights clearly capture the 3 candidate tokens within $R_k^g$ at each RF $R_k$, while linear activation cannot.

Lemma (Bound of local roughness $\rho(w)$ in the ReLU setting). If the input is bounded, $\|x\|_2 \le C_0$, $\alpha$ has a kernel structure (Def. 1), and the batch size $N \to +\infty$, then $\rho(w_*) \le \frac{C_0^3\,\mathrm{vol}(C_0)}{\pi}\, r(w_*, \alpha)$, where $r(w, \alpha) := \sum_{l=0}^{+\infty} z_l^2(\alpha) \max_{w^\top x = 0} \tilde p_l(x; \alpha)$.

(69) and $\mathbb{V}_{\tilde p_l}[\tilde x_w] - \mathbb{V}_{\tilde p_l}[\tilde x_{w_*}] = (E - E_*) + (e_* e_*^\top - e e^\top)$. Define the following regions

$= z_0(\alpha_g)\, r(w_*, \alpha_u)$ (99). Here $z_0(\alpha_g) := \int \phi_{g0}(x)\, p_\mathcal{D}(x)\,\mathrm{d}x = \int e^{-\frac{h^2(w_*^\top x)}{2\tau}}\, p_\mathcal{D}(x)\,\mathrm{d}x \le 1$ (100).

B.4 FINDING CRITICAL POINTS WITH INITIAL GUESS (SEC. 3.3)

Figure 6: Decomposition of the variance term V[ x].

Figure 7: The one-dimensional dynamics (Eqn. 146) with $b_l > 0$. There exists a stable critical point $y_l = -b_l$ and an unstable critical point $y_l = 0$. When $y_l(0) > 0$, the dynamics blows up in finite time $t_l^* = \frac{1}{b_l}\ln\left(1 + \frac{b_l}{y_l(0)}\right)$.

Figure 8: Local optima of the objective $E_\alpha(w)$ under uniform $\alpha := 1$. $w = [\cos\theta, \sin\theta]^\top$ is parameterized by $\theta \in [-\pi, \pi]$, and each vertical dotted line is a data point. The dotted curve is $E_\alpha$ with linear activation $h(x) = x$, and the solid curve uses ReLU activation $h(x) = \max(x, 0)$. Both curves are scaled so that their maximal value is 1.

Figure 9: Overall matching score $\bar\chi_+$ (Eqn. 180, top row) and irrelevant-matching score $\bar\chi_-$ (Eqn. 181, bottom row). This is an extended version of Fig. 4. (a) When $P = 1$, the linear model works well regardless of the degree of over-parameterization $\beta$, while the ReLU model requires large over-parameterization to perform well. (b) When each $R_k$ has multiple local patterns related to the global patterns ($P > 1$), ReLU models capture diverse patterns better than linear ones in the over-parameterization regime $\beta > 1$ and stay focused on relevant local patterns related to the global patterns (i.e., low $\bar\chi_-$). Among all activations (homogeneous or non-homogeneous), ReLU shows its strength by achieving the lowest irrelevant-matching score $\bar\chi_-$. In contrast, linear models are much less affected by over-parameterization. Each setting is repeated 3 times and means/standard deviations are reported.

Figure 10: Overall matching score $\bar\chi_+$ (top row) and overall irrelevant-matching score $\bar\chi_-$ (Eqn. 16, bottom row) using the quadratic loss function rather than InfoNCE. The result using InfoNCE is shown in Fig. 9, with all experiment settings being the same except for the loss function. While we see similar trends as in Sec. 5.1, the quadratic loss is not as effective as InfoNCE in feature learning.

Figure 11: The effect of BatchNorm (BN) with ReLU activation in the presence of non-uniformity ζ of the input data. The non-uniformity is set to be ζ = 2 (top), ζ = 5 (middle) and ζ = 10 (bottom). For small non-uniformity, BN doesn't help much. For larger nonuniformity, BN yields better matching score χ+ in the over-parameterization region (large β) and multiple tokens per RF (large P ).

Figure 13: Same as Fig. 12 but with data augmentation during contrastive learning. The learned filters are smoother. According to the tentative theory in Sec. B.5, data augmentation removes some of the local optima.

$$|d_t - c_{t+1}| = |(\phi_1(t) - \phi_1(t+1))^\top w(t+1)| \le \nu_t (1 - c_t) \tag{192}$$
where
$$\nu(w) := 2\kappa L^2 (1 + \mu(w) c(w)) + 2L\, \lambda_{\mathrm{gap}}^{-1}(A(w))\, \mu(w)(1 + \mu(w) c(w)) \ge 0 \tag{193}$$
and $\nu_t := \nu(w(t))$.

There are $K$ disjoint receptive fields (abbreviated as RF) $R_k$, each with $M$ weight vectors in the low layer, denoted as $w_{km}$. The activation function of the hidden-layer nodes is $h(x)$ and can be linear or nonlinear. (b) Conditional independence in Assumption 2: there exists a global categorical variable $z$. Given $z$, variations in different RFs are assumed to be independent.

See proof in Appendix B.4. Intuitively, with $L$ and $\kappa$ small, $c_0$ close to 1, and $\lambda_{\mathrm{gap}}$ large, Eqn. 11 can always hold with $\gamma < 1$ and the fixed point exists. For example, for the two cases in Sec. 3.1, if $U = [u_1, \dots, u_M]$ is only approximately orthogonal (i.e., $\|U^\top U - I\|$ is non-zero but small), and/or the conditions of Corollary 1 hold approximately, then Theorem 3 shows that multiple local optima close to $u_m$ still exist for each $m$ (Fig. 1, right).
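The Power Iteration dynamics (Eqn. PI) referenced here can be simulated directly. The sketch below is our own simplification: it freezes $A$ instead of using the weight-dependent $A(w)$ from the paper, builds $A$ from a nearly orthogonal latent basis $U$ (so $\|U^\top U - I\|_2$ is small but non-zero), and tracks the iteration discrepancy $\|w(t+1) - w(t)\|_2$:

```python
import numpy as np

rng = np.random.default_rng(0)

def power_iteration(A, w0, steps=200):
    """w(t+1) = A w(t) / ||A w(t)||_2; returns the iterate and discrepancies."""
    w = w0 / np.linalg.norm(w0)
    discrepancies = []
    for _ in range(steps):
        w_next = A @ w
        w_next /= np.linalg.norm(w_next)
        discrepancies.append(np.linalg.norm(w_next - w))
        w = w_next
    return w, discrepancies

# Nearly orthogonal latent basis U and a fixed PSD surrogate for A(w).
U = np.eye(3) + 0.05 * rng.normal(size=(3, 3))
A = U @ np.diag([3.0, 2.0, 1.0]) @ U.T
w_star, disc = power_iteration(A, rng.normal(size=3))
```

The iterate converges (discrepancy drops to numerical noise) to the maximal eigenvector of $A$, which for small $\|U^\top U - I\|_2$ stays close to the dominant latent direction $u_1$, in the spirit of Theorem 3.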

Table 1: The effect of backpropagating different BN statistics under non-uniformity $\zeta = 10$. Backpropagating the gradient through the sample mean $\mu_k$ but not the sample variance $\sigma_k^2$ gives an overall good matching score $\bar\chi_+$, justifying our mean-backprop BatchNorm (Def. 5).

Lemma 7 (Bound of $1 - d_t$). Define $\mu_t := \frac{1}{2}\left(\frac{\lambda_2(t)}{\lambda_1(t)}\right)^2 \frac{1 + c_t}{c_t^2}$. If $c_t > 0$ and $\lambda_1(t) > 0$, then $1 - d_t \le \mu_t(1 - c_t)$.

For brevity, we omit all temporal notation if the quantity is evaluated at iteration t. E.g., δw means δw(t) and ϕ 1 means ϕ 1 (t).


Remarks. Note that an alternative route is to use the homogeneity condition first, $C_\alpha[h(w^\top x)] = w^\top A(x) w$, and then take the differential. This involves an additional term $\frac{1}{2} w^\top (\mathrm{d}A) w$; in the following we show that it is zero. For this we first compute $\mathrm{d}A$, in which we encounter the term $h''(w^\top x)\, w^\top x$. For ReLU activation, its second derivative is $h''(x) = \delta(x)$, where $\delta(x)$ is the Dirac delta function (Boas & Peters, 1984). From the property of the delta function, we have $x h''(x) = x\delta(x) = 0$, even when evaluated at $x = 0$. Therefore $h''(w^\top x)\, w^\top x = 0$ and $w^\top(\mathrm{d}A)w = 0$. The argument is similar for LeakyReLU.

The derivation uses the facts that $P^\perp_{w_*} A(w_*) w_* = 0$, $u^\top A(w_*) w_* = 0$ and $P^\perp_{w_*} u = u$. Therefore, we can decompose $P^\perp_v A(v) v$ into two parts. Since $u^\top w_* = 0$, $\|u\|_2 = \|v\|_2 = 1$ and $\|P^\perp_v\|_2 = 1$, by the definition of the local roughness measure $\rho(w_*)$ we obtain: when $\lambda_{\mathrm{gap}}(w_*) > \rho(w_*)$, we have $u^\top P^\perp_v A(v) v < 0$ for any $u \perp w_*$ and sufficiently small $\epsilon$. Therefore, the critical point $w_*$ is stable.

Proof. Note that if $c_0 < 0$, we can always use $-\phi_1(w)$ as the maximal eigenvector. First, assume $A(w)$ is positive definite (PD) over the entire unit sphere $\|w\|_2 = 1$; then follow Lemma 11 and notice that $\|w - w(0)\|_2$ remains small. When $A(w)$ is not PD, Theorem 3 still applies to the PD matrix $\hat A(w) := A(w) - \lambda_{\min}(w) I + \epsilon I$, with $L$ and $\kappa$ specified by $\hat A(w)$, where $\epsilon > 0$ is a small constant. This transformation keeps $c_0$, since the eigenvectors of $\hat A(w)$ are the same as those of $A(w)$. The resulting fixed point $\hat w_*$ is also a fixed point of the original problem with $A(w)$.

Remarks. Note that Lemma 11 assumes that $\mu_t + \nu_t \le \gamma$ holds along the trajectory $\{w(t)\}$. In Theorem 3, this cannot be assumed true until we prove that the entire trajectory stays within $B_\gamma$.

B.5 THE EFFECT OF DATA AUGMENTATION ON LOCAL OPTIMA

While the majority of the analysis focuses on the case without data augmentation (i.e., using Corollary 1), the original formulation in Lemma 2 can still handle contrastive learning in the presence of data augmentation. In fact, data augmentation plays an important role by removing unnecessary local optima.

First, Lemma 2 tells us that when $K = 1$, the objective (Eqn. 5) takes a form involving $b(w|x_0)$. Now let us consider the following simple data augmentation of $x_0$, where $R(t) \in \mathbb{R}^{d\times d}$ is a rotation parameterized by $t$, drawn uniformly from a parameter family $T$. We assume $\{R(t)\}_{t\in T}$ forms a 1-dimensional Lie group parameterized by $T$. This means that:

• Closure. For any $t, t' \in T$, there exists $t'' \in T$ so that $R(t'') = R(t)R(t')$.
• Existence of inverse elements. For each $R(t)$, there exists an inverse element $t' \in T$; the last equality is due to the fact that $R(t)$ is a rotation.
• Existence of an identity map. $R(0) = I$.

Then for any small transformation $R(t')$ applied to the weights $w$ (here "small" means $\|R(t') - I\|_2$ is small), we can write down $b(R(t')w|x_0)$ using the reparameterization trick. The last equality is due to the fact that $\{R(t)\}_{t\in T}$ is a Lie group, so that $R^{-1}(t)R(t')$ always maps to another group element $R(t'')$, and $t''$, as the resulting parameterization, is still uniform.

Since $\{\phi_j\}$ is a set of orthonormal bases, Parseval's identity tells us that for any vector $v$, its energy under any orthonormal basis is preserved: $\sum_j (\phi_j^\top v)^2 = \|v\|_2^2$. Using $-1 \le d_t \le 1$ and Lemma 7, and finally the bound on the weight difference (Lemma 8), we obtain the desired estimate.

Define $\omega(w) := \mu(w) + \nu(w)$ to be the irregularity (also defined in Def. 4). If there exists $\gamma < 1$ so that $\sup_w \omega(w) \le \gamma$ over the relevant region, then:
• the sequence $\{c_t\}$ increases monotonically and converges to 1;
• there exists $w_*$ so that $\lim_{t\to+\infty} w(t) = w_*$;
• $w_*$ is the maximal eigenvector of $A(w_*)$ and thus a fixed point of the gradient update (Eqn. 7).
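Continuing the 2D example of Appendix E, a quick (and entirely illustrative) way to see the smoothing effect is to augment each data point with small rotations and recount the local maxima of the same variance surrogate used earlier; the landscape averaged over the augmented cloud tends to lose spurious maxima. All choices below (augmentation width, grid size) are ours.

```python
import numpy as np

angles = np.array([-4*np.pi/5, -np.pi/2, -2*np.pi/5, -np.pi/6, 0,
                   np.pi/5, 2*np.pi/5, 3*np.pi/5, 4*np.pi/5])

def energy(theta, aug_width=0.0, n_aug=11):
    """Var of ReLU responses; augmentation rotates each data point by small
    angles t on a uniform grid in [-aug_width, aug_width] (a discrete stand-in
    for drawing R(t) uniformly from the parameter family T)."""
    shifts = (np.linspace(-aug_width, aug_width, n_aug)
              if aug_width > 0 else np.array([0.0]))
    aug_angles = (angles[:, None] + shifts[None, :]).ravel()
    # For unit vectors, w^T x_i = cos(theta - a_i).
    responses = np.maximum(np.cos(theta - aug_angles), 0.0)
    return np.var(responses)

def count_local_maxima(aug_width):
    grid = np.linspace(-np.pi, np.pi, 720, endpoint=False)
    vals = np.array([energy(t, aug_width) for t in grid])
    return int(np.sum((vals > np.roll(vals, 1)) & (vals > np.roll(vals, -1))))

# Compare the number of local maxima with and without rotation augmentation.
n_plain, n_aug = count_local_maxima(0.0), count_local_maxima(0.6)
```

This mirrors Fig. 13: averaging over the rotation group flattens the fine structure of the objective, typically leaving fewer, more informative optima.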
That is, $w_*$ is in the vicinity of the initial weight $w(0)$.

Proof. We first prove by induction that the following argument holds for any $t$: $w(t)$ is not far away from its initial value $w(0)$, which implies that $w(t) \in B_\gamma$.

Base case ($t = 1$). Since $1 \ge c_0 > 0$, $\mu(w) \ge 0$, and $A(w)$ is PD, applying Lemma 8 to $\|w(1) - w(0)\|_2$ gives the required bound; the last inequality is due to $\mu_0 \le \gamma$. Finally, we have $c_1 \ge 1 - \gamma(1 - c_0) \ge c_0 > 0$, so the base case is satisfied.

Inductive step. Assume the induction argument is true at step $t$, and thus $w(t) \in B_\gamma$; therefore, by the condition, we know $\mu_t + \nu_t \le \gamma$. By Lemma 8, we know that $w(t+1)$ also satisfies Eqn. 206, and similarly $c_{t+1} \ge 1 - \gamma(1 - c_t)$.

Therefore, $A(w_*) w_* = \lambda_* w_*$ and thus $P^\perp_{w_*} A(w_*) w_* = 0$, i.e., $w_*$ is a fixed point of the gradient update (Eqn. 7). Finally, since $\|\cdot\|_2$ is continuous, the conclusion follows.

