BEHIND THE SCENES OF GRADIENT DESCENT: A TRAJECTORY ANALYSIS VIA BASIS FUNCTION DECOMPOSITION

Abstract

This work analyzes the solution trajectory of gradient-based algorithms via a novel basis function decomposition. We show that, although solution trajectories of gradient-based algorithms may vary depending on the learning task, they behave almost monotonically when projected onto an appropriate orthonormal function basis. Such a projection gives rise to a basis function decomposition of the solution trajectory. Theoretically, we use our proposed basis function decomposition to establish the convergence of gradient descent (GD) on several representative learning tasks. In particular, we improve the convergence of GD on symmetric matrix factorization and provide a completely new convergence result for orthogonal symmetric tensor decomposition. Empirically, we illustrate the promise of our proposed framework on realistic deep neural networks (DNNs) across different architectures, gradient-based solvers, and datasets. Our key finding is that gradient-based algorithms monotonically learn the coefficients of a particular orthonormal function basis of DNNs, defined as the eigenvectors of the conjugate kernel after training. Our code is available at github.com/jianhaoma/function-basis-decomposition.

[1] Given a probability distribution D, we define the L2(D)-norm as ∥f∥²_{L2(D)} = E_{x∼D}[f²(x)].
[2] Both Legendre and Hermite polynomials can be derived sequentially via the Gram-Schmidt procedure. For instance, the first three Legendre polynomials are P1(x) = 1/√2, P2(x) = √(3/2) x, and P3(x) = √(5/8)(3x² − 1).

1. INTRODUCTION

Learning highly nonlinear models amounts to solving a nonconvex optimization problem, which is typically done via variants of gradient descent (GD). But how does GD learn nonlinear models? Classical optimization theory asserts that, in the face of nonconvexity, GD and its variants may lack any meaningful optimality guarantee; they produce solutions that, while being first- or second-order optimal (Nesterov, 1998; Jin et al., 2017), may not be globally optimal. In the rare event that GD does recover a globally optimal solution, the recovered solution may correspond to an overfitted model rather than one with desirable generalization. Inspired by the remarkable empirical success of gradient-based algorithms in learning complex models, recent work has postulated that typical training losses have benign landscapes: they are devoid of spurious local minima, and their global solutions coincide with true solutions, i.e., solutions corresponding to the true model. For instance, different variants of low-rank matrix factorization (Ge et al., 2016; 2017) and deep linear NNs (Kawaguchi, 2016) have benign landscapes. However, when spurious solutions do exist (Safran & Shamir, 2018) or global and true solutions do not coincide (Ma & Fattahi, 2022b), such a holistic view of the optimization landscape cannot explain the success of gradient-based algorithms.

To address this issue, another line of research has focused on analyzing the solution trajectories of different algorithms. Trajectory analysis has proven extremely powerful in sparse recovery (Vaskevicius et al., 2019), low-rank matrix factorization (Li et al., 2018; Stöger & Soltanolkotabi, 2021), and linear DNNs (Arora et al., 2018; Ma & Fattahi, 2022a). However, these analyses are tailored to specific models and thereby cannot be directly generalized. In this work, we propose a unifying framework for analyzing the optimization trajectory of GD based on a novel basis function decomposition.
We show that, although the dynamics of GD may vary drastically across models, they behave almost monotonically when projected onto an appropriate choice of orthonormal function basis.

Motivating example: Our first example illustrates this phenomenon on DNNs. We study the optimization trajectories of two adaptive gradient-based algorithms, AdamW and LARS, on three DNN architectures, AlexNet (Krizhevsky et al., 2017), ResNet-18 (He et al., 2016), and Vision Transformer (ViT) (Dosovitskiy et al., 2020), trained on the CIFAR-10 dataset. The first row of Figure 1 shows the top-5 coefficients of the solution trajectory when projected onto a randomly generated orthonormal basis. The trajectories of the coefficients are highly non-monotonic and almost indistinguishable (they range between −0.04 and 0.06), implying that the energy of the obtained model is spread across different orthogonal components. The second row of Figure 1 shows the same trajectory after projection onto an orthogonal basis defined as the eigenvectors of the conjugate kernel after training (Long, 2021) (see Section 3.4 and Appendix B for more details). Unlike the previous case, the top-5 coefficients carry more energy and behave monotonically (modulo small fluctuations induced by the stochasticity of the algorithm) in all three architectures, until they plateau around their steady state. In other words, the algorithm behaves more monotonically after projection onto a correct choice of orthonormal basis.

1.1. MAIN CONTRIBUTIONS

The monotonicity of the projected solution trajectory motivates the use of an appropriate basis function decomposition to analyze the behavior of gradient-based algorithms. In this paper, we show how an appropriate basis function decomposition can be used to provide a much simpler convergence analysis for gradient-based algorithms on several representative learning problems, from simple kernel regression to complex DNNs. Our main contributions are summarized below:

- Global convergence of GD via basis function decomposition: We prove that GD learns the coefficients of an appropriate function basis that forms the true model. In particular, we show that GD learns the true model when applied to the expected ℓ2-loss under certain gradient independence and gradient dominance conditions. Moreover, we characterize the convergence rate of GD, identifying conditions under which it enjoys linear or sublinear convergence rates. Our result does not require a benign landscape for the loss function and applies to both convex and nonconvex settings.

- Application to learning problems: We show that our general framework is well suited for analyzing the solution trajectory of GD on several representative learning problems. Unlike existing results, our proposed method leads to a much simpler trajectory analysis of GD for much broader classes of models. Using our technique, we improve the convergence of GD on symmetric matrix factorization and provide an entirely new convergence result for GD on orthogonal symmetric tensor decomposition. We also prove that GD exhibits an incremental learning phenomenon in both problems.

- Empirical validation on DNNs: We empirically show that our proposed framework applies to DNNs beyond GD. More specifically, we show that different gradient-based algorithms monotonically learn the coefficients of a particular function basis defined as the eigenvectors of the conjugate kernel after training (also known as the "after kernel" regime).
We show that this phenomenon happens across different architectures, datasets, solvers, and loss functions, strongly motivating the use of function basis decomposition to study deep learning.

2. GENERAL FRAMEWORK: FUNCTION BASIS DECOMPOSITION

We study the optimization trajectory of GD on the expected (population) ℓ2-loss

min_{θ∈Θ} L(θ) := (1/2) E_{x,y}[(f_θ(x) − y)²]. (expected ℓ2-loss)

Here the input x ∈ R^d is drawn from an unknown distribution D, and the output label y is generated as y = f⋆(x) + ε, where ε is additive noise, independent of x, with mean E[ε] = 0 and variance E[ε²] = σ²_ε < ∞. The model f_θ(x) is characterized by a parameter vector θ ∈ R^m, which naturally induces a set of admissible models (model space for short) F_Θ := {f_θ : θ ∈ R^m}. We do not require the true model f⋆ to lie within the model space; instead, we seek a model f_{θ⋆} ∈ F_Θ that is closest to f⋆ in L2(D)-distance.[1] In other words, we consider f⋆(x) = f_{θ⋆}(x) + f⊥⋆(x), where θ⋆ = argmin_θ ∥f_θ − f⋆∥_{L2(D)}. To minimize the expected ℓ2-loss, we use vanilla GD with constant step-size η > 0:

θ_{t+1} = θ_t − η ∇L(θ_t). (GD)

Definition 1 (Orthonormal function basis). A set of functions {ϕ_i(x)}_{i∈I} forms an orthonormal function basis for the model space F_Θ with respect to the L2(D)-metric if
• for any i ∈ I, we have E_{x∼D}[ϕ_i(x)²] = 1;
• for any i, j ∈ I with i ≠ j, we have E_{x∼D}[ϕ_i(x)ϕ_j(x)] = 0;
• for any f_θ ∈ F_Θ, there exists a unique sequence of basis coefficients {β_i(θ)}_{i∈I} such that f_θ(x) = Σ_{i∈I} β_i(θ) ϕ_i(x).

Example 1 (Orthonormal basis for polynomials). Suppose that F_Θ is the class of all univariate real polynomials of degree at most n, that is, F_Θ = {Σ_{i=1}^{n+1} θ_i x^{i−1} : θ ∈ R^{n+1}}. If D is the uniform distribution on [−1, 1], then the so-called Legendre polynomials form an orthonormal basis for F_Θ with respect to the L2(D)-metric (Olver et al., 2010, Chapter 14). Moreover, if D is a normal distribution, then the Hermite polynomials define an orthonormal basis for F_Θ with respect to the L2(D)-metric (Olver et al., 2010, Chapter 18).[2]

Example 2 (Orthonormal basis for symmetric matrix factorization).
Suppose that the true model is defined as f_{U⋆}(X) = ⟨U⋆U⋆ᵀ, X⟩ for some rank-r matrix U⋆ ∈ R^{d×r}, and consider an "overparameterized" function class F_Θ = {f_U(X) : U ∈ R^{d×r′}}, where r′ ≥ r is an overestimate of the rank. Moreover, suppose that the entries of X ∼ D are iid with zero mean and unit variance. Denote the eigenvalues of U⋆U⋆ᵀ by σ_1 ≥ ··· ≥ σ_d with σ_{r+1} = ··· = σ_d = 0, and their corresponding eigenvectors by z_1, ..., z_d. It is easy to verify that the functions ϕ_{ij}(X) = ⟨z_i z_jᵀ, X⟩ for 1 ≤ i, j ≤ d define a valid orthonormal basis for F_Θ with respect to the L2(D)-metric. Moreover, for any f_U(X), the basis coefficients can be obtained as β_{ij}(U) = E[⟨UUᵀ, X⟩⟨z_i z_jᵀ, X⟩] = ⟨z_i z_jᵀ, UUᵀ⟩. As will be shown in Section 3.2, this choice of orthonormal basis significantly simplifies the dynamics of GD for symmetric matrix factorization.

Given the input distribution D, we write f_{θ⋆}(x) = Σ_{i∈I} β_i(θ⋆) ϕ_i(x), where {ϕ_i(x)}_{i∈I} is an orthonormal basis for F_Θ with respect to the L2(D)-metric and {β_i(θ⋆)}_{i∈I} are the true basis coefficients; for short, we write β⋆_i = β_i(θ⋆). In light of this, the expected loss can be written as

L(θ) = (1/2) Σ_{i∈I} (β_i(θ) − β⋆_i)²  [optimization error]  +  (1/2) ∥f⊥⋆∥²_{L2(D)}  [approximation error]  +  σ²_ε/2  [noise].

Accordingly, GD takes the form

θ_{t+1} = θ_t − η Σ_{i∈I} (β_i(θ_t) − β⋆_i) ∇β_i(θ_t). (GD dynamic)

Two important observations are in order based on the GD dynamic. First, due to the decomposed nature of the expected loss, the solution trajectory is independent of the approximation error and the noise. Second, to prove the global convergence of GD, it suffices to show the convergence of β_i(θ_t) to β⋆_i. In fact, we will show that the coefficients β_i(θ_t) enjoy simpler dynamics for particular choices of orthonormal basis that satisfy appropriate conditions.

Assumption 1 (Boundedness and smoothness).
There exist constants L_f, L_g, L_H > 0 such that ∥f_θ∥_{L2(D)} ≤ L_f, ∥∇f_θ∥_{L2(D)} ≤ L_g, and ∥∇²f_θ∥_{L2(D)} ≤ L_H. These assumptions are common in the optimization literature (Bubeck et al., 2015) and sometimes necessary to ensure the convergence of local-search algorithms (Patel & Berahas, 2022). Moreover, although they may not hold globally, all of our subsequent results hold whenever Assumption 1 is satisfied within any bounded region for θ that contains the solution trajectory. We will also relax these assumptions for several learning problems.

Proposition 1 (Dynamic of β_i(θ_t)). Under Assumption 1 and based on the GD dynamic, we have

β_i(θ_{t+1}) = β_i(θ_t) − η Σ_{j∈I} (β_j(θ_t) − β⋆_j) ⟨∇β_i(θ_t), ∇β_j(θ_t)⟩ ± O(η² L_H L_f² L_g²).

The above proposition holds for any valid choice of orthonormal basis {ϕ_i(x)}_{i∈I}. Indeed, there may exist multiple choices for the orthonormal basis, and not all of them lead to equally simple dynamics for the coefficients; examples of "good" and "bad" choices of orthonormal basis were presented for DNNs in our earlier motivating example. An ideal choice of orthonormal basis should satisfy ⟨∇β_i(θ_t), ∇β_j(θ_t)⟩ ≈ 0 for i ≠ j, i.e., the gradients of the coefficients remain orthogonal along the solution trajectory. Under this assumption, the dynamics of β_i(θ_t) almost decompose over different indices:

β_i(θ_{t+1}) ≈ β_i(θ_t) − η (β_i(θ_t) − β⋆_i) ∥∇β_i(θ_t)∥² ± O(η²),

where the last term accounts for the second-order interactions among the basis coefficients. If such an ideal orthonormal basis exists, our next theorem shows that GD efficiently learns the true basis coefficients. To streamline the presentation, we assume that β⋆_1 ≥ ··· ≥ β⋆_k > 0 and β⋆_i = 0 for i > k, for some k < ∞. We refer to the index set S = {1, ..., k} as the signal and to the index set E = I \ S as the residual. When there is no ambiguity, we also refer to β_i(θ_t) as a signal if i ∈ S.
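Under gradient independence and gradient dominance, each signal coefficient approximately follows the decoupled scalar recursion above, with ∥∇β_i(θ_t)∥² replaced by its lower bound C²β_i(θ_t)^{2γ}. The following sketch iterates this idealized scalar model (not the actual GD iteration; C = 1 and all constants are arbitrary) and illustrates the rate separation between γ = 1/2 and γ = 1 that Theorem 1 below quantifies:

```python
def iterations_to_converge(gamma, beta_star=1.0, alpha=1e-3, eta=0.05, tol=1e-2):
    """Run the idealized per-coefficient dynamic
        beta_{t+1} = beta_t + eta * (beta_star - beta_t) * beta_t ** (2 * gamma)
    (gradient dominance with C = 1), starting from a small initialization alpha,
    until beta is within tol of beta_star; return the iteration count."""
    beta, t = alpha, 0
    while beta_star - beta > tol:
        beta += eta * (beta_star - beta) * beta ** (2 * gamma)
        t += 1
    return t

t_linear = iterations_to_converge(gamma=0.5)     # gamma = 1/2: linear-rate regime
t_sublinear = iterations_to_converge(gamma=1.0)  # gamma = 1: sublinear regime
print(t_linear, t_sublinear)
```

With these (arbitrary) constants, the γ = 1 run needs far more iterations than the γ = 1/2 run, mirroring the O((1/ϵ) log(1/ϵ)) versus O(1/ϵ^{2γ}) separation discussed below.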
Theorem 1 (Convergence of GD with a finite ideal basis). Suppose that the initial point θ_0 satisfies

β_i(θ_0) ≥ C_1 α for all i ∈ S, (lower bound on signals at θ_0)
∥f_{θ_0}∥_{L2(D)} = (Σ_{i∈I} β_i(θ_0)²)^{1/2} ≤ C_2 α, (upper bound on energy at θ_0)

for some C_1, C_2 > 0 and α ≲ β⋆_k. Moreover, suppose that the orthonormal function basis is finite, i.e., |I| = d for some finite d, and the gradients of the coefficients satisfy the following conditions for every 0 ≤ t ≤ T:

⟨∇β_i(θ_t), ∇β_j(θ_t)⟩ = 0 for all i ≠ j, (gradient independence)
∥∇β_i(θ_t)∥ ≥ C |β_i(θ_t)|^γ for all i ∈ S, (gradient dominance)

for some C > 0 and 1/2 ≤ γ ≤ 1. Then, GD with step-size η ≲ (α^{2γ} / (√d C² L_H L_g² L_f² β⋆_k^{2γ})) log^{−1}(dβ⋆_k/(C_1 α)) satisfies:
• If γ = 1/2, then within T = O((1/(C² η β⋆_k)) log(β⋆_k/(C_1 α))) iterations, we have ∥f_{θ_T} − f_{θ⋆}∥_{L2(D)} ≲ α.
• If 1/2 < γ ≤ 1, then within T = O(1/(C² η β⋆_k α^{2γ−1})) iterations, we have ∥f_{θ_T} − f_{θ⋆}∥_{L2(D)} ≲ α.

Theorem 1 shows that, under certain conditions on the basis coefficients and their gradients, GD with a constant step-size converges to a model that is at most α-away from the true model. In particular, to achieve an ϵ-accurate solution for any ϵ > 0, GD requires O((1/ϵ) log(1/ϵ)) iterations for γ = 1/2 and O(1/ϵ^{2γ}) iterations for 1/2 < γ ≤ 1 (ignoring the dependency on other problem-specific parameters). Due to its generality, our theorem inevitably relies on a small step-size and leads to a conservative convergence rate for GD. Later, we will show how our proposed approach can be tailored to specific learning problems to achieve better convergence rates in each setting.

How realistic are the assumptions of Theorem 1? A natural question is whether the conditions of Theorem 1 are realistic. We start with the conditions on the initial point. Intuitively, these assumptions entail that a non-negligible fraction of the energy is carried by the signal at the initial point. We note that these assumptions are mild and expected to hold in practice.
For instance, we will show in Section 3 that, depending on the learning task, they are guaranteed to hold with fixed, random, or spectral initialization. We have also empirically verified that these conditions are satisfied for DNNs with random or default initialization. For instance, Figure 2a illustrates the top-20 basis coefficients at θ_0 for LARS with random initialization on a realistic CNN; a non-negligible fraction of the energy at the initial point is carried by the first few coefficients. The conditions on the coefficient gradients are indeed harder to satisfy: as will be shown later, the existence of an ideal orthonormal basis may not be guaranteed even for linear NNs. Nonetheless, we have empirically verified that, with an appropriate choice of the orthonormal basis, the gradients of the coefficients remain approximately independent throughout the solution trajectory. Figure 2b shows that, when the orthonormal basis is chosen as the eigenvectors of the conjugate kernel after training, the maximum value of |cos(∇β_i(θ_t), ∇β_j(θ_t))| remains small throughout the solution trajectory. Finally, we turn to the gradient dominance condition. Intuitively, this condition entails that the gradient of each signal scales with its norm. We prove that it is guaranteed to hold for kernel regression, symmetric matrix factorization, and symmetric tensor decomposition. Moreover, we have empirically verified that gradient dominance holds across different DNN architectures: Figure 2c shows that the condition is satisfied for the top-4 basis coefficients of the solution trajectory (other signal coefficients behave similarly). We also note that our theoretical result on GD may not naturally extend to LARS.
Nonetheless, our extensive simulations suggest that our proposed analysis can be extended to other stochastic and adaptive variants of GD (see Section 3.4 and Appendix B); a rigorous verification of this conjecture is left as future work.

3. APPLICATIONS

In this section, we show how our proposed basis function decomposition can be used to study the performance of GD on different learning tasks, from simple kernel regression to complex DNNs. We start with classical kernel regression, for which GD is known to converge linearly (Karimi et al., 2016). Our purpose is to revisit GD through the lens of basis function decomposition, where there is a natural and simple choice for the basis functions. Next, we apply our approach to two important learning problems, namely symmetric matrix factorization and orthogonal symmetric tensor decomposition. In particular, we show how our proposed approach improves the convergence of GD for symmetric matrix factorization and leads to a completely new convergence result for orthogonal symmetric tensor decomposition. Finally, through extensive experiments, we showcase the promise of our proposed basis function decomposition on realistic DNNs.

3.1. KERNEL REGRESSION

In kernel regression (KR), the goal is to fit a regression model f_θ(x) = Σ_{i=1}^d θ_i ϕ_i(x) from the function class F_Θ = {f_θ(x) : θ ∈ R^d} to the observation y, where {ϕ_i(x)}_{i=1}^d are known kernel functions. Examples of KR include linear regression, polynomial regression (including the settings described in Example 1), and the neural tangent kernel (NTK) (Jacot et al., 2018). Without loss of generality, we may assume that the kernel functions {ϕ_i(x)}_{i=1}^d are orthonormal. Under this assumption, the basis coefficients can be defined as β_i(θ) = θ_i, and the expected loss can be written as

L(θ_t) = (1/2) E[(f_{θ_t}(x) − f_{θ⋆}(x))²] = (1/2) ∥θ_t − θ⋆∥² = (1/2) Σ_{i=1}^d (β_i(θ_t) − θ⋆_i)². (5)

Moreover, the coefficients satisfy the gradient independence condition. Therefore, an adaptation of Proposition 1 reveals that the dynamics of the basis coefficients are independent of each other.

Proposition 2 (Dynamics of β_i(θ_t)). Consider GD with a step-size that satisfies 0 < η < 1. Then,
• for i ∈ S, we have β_i(θ_t) = β⋆_i − (1 − η)^t (β⋆_i − β_i(θ_0));
• for i ∉ S, we have β_i(θ_t) = (1 − η)^t β_i(θ_0).

Without loss of generality, we assume that 0 < θ⋆_k ≤ ··· ≤ θ⋆_1 ≤ 1 and ∥θ_0∥_∞ ≤ α. Then, given Proposition 2, we have |β_i(θ_t)| ≤ 2 + α for every 1 ≤ i ≤ k. Therefore, gradient dominance is satisfied with parameters (C, γ) = (1/√(2 + α), 1/2). Since both gradient independence and gradient dominance are satisfied, the convergence of GD can be established with an appropriate initial point.

Theorem 2. Suppose that θ_0 = α·1, where α ≲ k|θ⋆_k|/d. Then, within T ≲ (1/η) log(k|θ⋆_1|/α) iterations, GD with step-size 0 < η < 1 satisfies ∥θ_T − θ⋆∥ ≲ α.

Theorem 2 reveals that GD with a large step-size and a small initial point converges linearly to an ϵ-accurate solution, provided that the initialization scale is chosen as α = ϵ. This is indeed better than our result on the convergence of GD for general models in Theorem 1.
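Proposition 2's closed form can be checked directly against the GD recursion on the quadratic loss (5). A minimal sketch, using the signal values and hyperparameters of the kernel-regression experiment in Appendix B.1 (coefficients 10, 5, 3, 1, θ_0 = 5×10⁻⁷·1, η = 0.4):

```python
import numpy as np

d, k = 20, 4
theta_star = np.zeros(d)
theta_star[:k] = [10.0, 5.0, 3.0, 1.0]      # signal coefficients
eta, alpha, T = 0.4, 5e-7, 50
theta0 = np.full(d, alpha)                  # theta_0 = alpha * 1

theta = theta0.copy()
for _ in range(T):
    theta -= eta * (theta - theta_star)     # GD on L = 0.5 * ||theta - theta*||^2

# Proposition 2: beta_i(theta_t) = beta*_i - (1 - eta)^t (beta*_i - beta_i(theta_0))
closed_form = theta_star - (1 - eta) ** T * (theta_star - theta0)
assert np.allclose(theta, closed_form)
print(np.linalg.norm(theta - theta_star))   # decays geometrically at rate (1 - eta)
```

Note that every coefficient, large or small, contracts at the same rate (1 − η), consistent with the linear convergence in Theorem 2.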

3.2. SYMMETRIC MATRIX FACTORIZATION

In symmetric matrix factorization (SMF), the goal is to learn a model f_{U⋆}(X) = ⟨U⋆U⋆ᵀ, X⟩ with a low-rank matrix U⋆ ∈ R^{d×r}, where each entry of X ∼ D is iid with E[X_{ij}] = 0 and E[X_{ij}²] = 1. Examples of SMF include matrix sensing (Li et al., 2018) and matrix completion (Ge et al., 2016). Given the eigenvectors {z_1, ..., z_d} of U⋆U⋆ᵀ and a function class F_Θ = {f_U(X) : U ∈ R^{d×r′}} with r′ ≥ r, it was shown in Example 2 that the functions ϕ_{ij}(X) = ⟨z_i z_jᵀ, X⟩ define a valid orthonormal basis for F_Θ with coefficients β_{ij}(U) = ⟨z_i z_jᵀ, UUᵀ⟩. Therefore, we have

L(U_t) = (1/4) E[(f_{U_t}(X) − f_{U⋆}(X))²] = (1/4) E[⟨U_tU_tᵀ − U⋆U⋆ᵀ, X⟩²] = (1/4) Σ_{i,j} (β_{ij}(U_t) − β⋆_{ij})².

Here, the true basis coefficients are β⋆_{ii} = σ_i for i ≤ r, and β⋆_{ij} = 0 otherwise. Moreover, one can write ∥∇β_{ii}(U)∥_F = 2∥z_i z_iᵀ U∥_F = 2√(β_{ii}(U)), so gradient dominance holds with parameters (C, γ) = (2, 1/2). However, gradient independence does not hold for this choice of function basis: given any pair (i, j) and (i, k) with j ≠ k, we have ⟨∇β_{ij}(U), ∇β_{ik}(U)⟩ = ⟨z_j z_kᵀ, UUᵀ⟩, which may not be zero. Despite the absence of gradient independence, our next proposition characterizes the dynamic of β_{ij}(U_t) via a finer control over the coefficient gradients.

Proposition 3. Suppose that γ := min_{1≤i≤r}{σ_i − σ_{i+1}} > 0. Let U_0 = αB, where

α ≲ min{(ησ_r²)^{Ω(σ_1/γ)}, (σ_r/d)^{Ω(σ_1/γ)}, (κ log²(d))^{−Ω(1/(ησ_r))}}

and the entries of B are drawn independently from a standard normal distribution. Suppose that the step-size of GD satisfies η ≲ 1/σ_1. Then, with probability at least 1 − exp(−Ω(r′)):
• For 1 ≤ i ≤ r, we have 0.99σ_i ≤ β_{ii}(U_t) ≤ σ_i within O((1/(ησ_i)) log(σ_i/α)) iterations.
• For t ≥ 0 and i ≠ j or i, j > r, we have |β_{ij}(U_t)| ≲ poly(α).
Proposition 3 shows that GD with small random initialization learns larger eigenvalues before smaller ones, a behavior commonly referred to as incremental learning. Incremental learning for SMF has recently been studied for gradient flow (Arora et al., 2019a; Li et al., 2020), as well as for GD with identical initialization in the special case r′ = d (Chou et al., 2020). To the best of our knowledge, Proposition 3 is the first result that provides a full characterization of the incremental learning phenomenon for GD with random initialization on SMF.

Theorem 3. Suppose that the conditions of Proposition 3 are satisfied. Then, with probability at least 1 − exp(−Ω(r′)) and within T ≲ (1/(ησ_r)) log(σ_r/α) iterations, GD satisfies ∥U_T U_Tᵀ − M⋆∥_F ≲ r′ log(d) d α².

It has been shown in (Stöger & Soltanolkotabi, 2021, Theorem 3.3) that GD with small random initialization satisfies ∥U_T U_Tᵀ − M⋆∥_F ≲ (d²/r′)^{15/16} α^{21/16} within the same number of iterations. Theorem 3 improves the dependency of the final error on the initialization scale α.
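The orthonormal basis of Example 2, which underlies all results in this subsection, is easy to verify numerically. A minimal sketch (dimensions and random seed are arbitrary): it checks Parseval's identity for the coefficient matrix B = Zᵀ(UUᵀ)Z and compares one Monte Carlo expectation against the closed form β_ij(U) = ⟨z_i z_jᵀ, UUᵀ⟩.

```python
import numpy as np

rng = np.random.default_rng(1)
d, r, r_prime = 6, 2, 4
U_star = rng.standard_normal((d, r))
# Eigenvectors z_1, ..., z_d of U* U*^T define the basis phi_ij(X) = <z_i z_j^T, X>
_, Z = np.linalg.eigh(U_star @ U_star.T)
U = rng.standard_normal((d, r_prime))       # an arbitrary overparameterized iterate

B = Z.T @ (U @ U.T) @ Z                     # B[i, j] = beta_ij(U) = <z_i z_j^T, U U^T>
# Parseval: the coefficients carry exactly the energy of f_U, as expected
# for an orthonormal basis
assert np.isclose((B ** 2).sum(), np.linalg.norm(U @ U.T, "fro") ** 2)

# Monte Carlo check of beta_ij(U) = E[<U U^T, X> <z_i z_j^T, X>] for one pair (i, j)
X = rng.standard_normal((200_000, d, d))    # iid entries, zero mean, unit variance
i, j = 0, 1
inner_M = np.einsum("nab,ab->n", X, U @ U.T)
inner_Z = np.einsum("nab,ab->n", X, np.outer(Z[:, i], Z[:, j]))
print(B[i, j], np.mean(inner_M * inner_Z))  # should agree up to Monte Carlo error
```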

3.3. ORTHOGONAL SYMMETRIC TENSOR DECOMPOSITION

We use our approach to provide a new convergence guarantee for GD on orthogonal symmetric tensor decomposition (OSTD). In OSTD, the goal is to learn f_{U⋆}(X) = ⟨T_{U⋆}, X⟩, where U⋆ = [u⋆_1, ..., u⋆_r] ∈ R^{d×r} and T_{U⋆} = Σ_{i=1}^r (u⋆_i)^{⊗l} = Σ_{i=1}^d σ_i z_i^{⊗l} is a symmetric tensor of order l and rank r. Here, σ_1 ≥ ··· ≥ σ_d are the tensor eigenvalues with σ_{r+1} = ··· = σ_d = 0, and z_1, ..., z_d are the corresponding tensor eigenvectors; the notation u^{⊗l} refers to the l-fold outer product of u. We assume that X ∼ D is an order-l tensor whose entries are iid with zero mean and unit variance. Examples of OSTD include tensor regression (Tong et al., 2022) and tensor completion (Liu et al., 2012). When the rank of T_{U⋆} is unknown, it must be overestimated; even when the rank is known, overestimating it can improve the convergence of gradient-based algorithms (Wang et al., 2020). This leads to an overparameterized model f_U(X) = ⟨T_U, X⟩ with T_U = Σ_{i=1}^{r′} u_i^{⊗l} for an overestimated rank r′ ≥ r. Accordingly, the function class is defined as F_Θ = {f_U(X) : U = [u_1, ..., u_{r′}] ∈ R^{d×r′}}. Upon defining a multi-index Λ = (j_1, ..., j_l), the functions ϕ_Λ(X) = ⟨⊗_{k=1}^l z_{j_k}, X⟩ for 1 ≤ j_1, ..., j_l ≤ d form an orthonormal basis for F_Θ with basis coefficients

β_Λ(U) = E[⟨T_U, X⟩⟨⊗_{k=1}^l z_{j_k}, X⟩] = ⟨Σ_{i=1}^{r′} u_i^{⊗l}, ⊗_{k=1}^l z_{j_k}⟩ = Σ_{i=1}^{r′} Π_{k=1}^l ⟨u_i, z_{j_k}⟩,

and the expected loss can be written as

L(U) = (1/2) E[(f_U(X) − f_{U⋆}(X))²] = (1/2) ∥Σ_{i=1}^{r′} u_i^{⊗l} − Σ_{i=1}^r σ_i z_i^{⊗l}∥²_F = (1/2) Σ_Λ (β_Λ(U) − β⋆_Λ)²,

where the true basis coefficients are β⋆_{Λ_i} = σ_i for Λ_i = (i, ..., i), 1 ≤ i ≤ r, and β⋆_Λ = 0 otherwise. Unlike KR and SMF, neither gradient independence nor gradient dominance is satisfied for OSTD with random or identical initialization.
However, we show that these conditions are approximately satisfied throughout the solution trajectory, provided that the initial point is nearly aligned with the eigenvectors z_1, ..., z_{r′}; in other words, cos(u_i(0), z_i) ≈ 1 for every 1 ≤ i ≤ r′. Assuming that the initial point satisfies this alignment condition, we show that the entire solution trajectory remains aligned with these eigenvectors, i.e., cos(u_i(t), z_i) ≈ 1 for every 1 ≤ i ≤ r′ and 1 ≤ t ≤ T. Using this key result, we show that both gradient independence and gradient dominance are approximately satisfied throughout the solution trajectory. We briefly explain the intuition behind our approach for gradient dominance and defer the rigorous analysis of gradient independence to the appendix. Note that if cos(u_i(t), z_i) ≈ 1, then β_{Λ_i}(U_t) ≈ ⟨u_i(t), z_i⟩^l and ∥∇β_{Λ_i}(U_t)∥_F ≈ ∥∇_{u_i} β_{Λ_i}(U_t)∥ ≈ l⟨u_i(t), z_i⟩^{l−1}. Therefore, gradient dominance holds with parameters (C, γ) = (l, (l−1)/l). We make this intuition rigorous in Appendix G.

Proposition 4. Suppose that the initial point U_0 is chosen such that ∥u_i(0)∥ = α^{1/l} and cos(u_i(0), z_i) ≥ √(1 − γ) for all 1 ≤ i ≤ r′, where α ≲ d^{−l³} and γ ≲ (lκ)^{−l/(l−2)}. Then, GD with step-size η ≲ 1/(lσ_1) satisfies:
• For 1 ≤ i ≤ r, we have 0.99σ_i ≤ β_{Λ_i}(U_t) ≤ 1.01σ_i within O((1/(ηlσ_r)) α^{−(l−2)/l}) iterations.
• For t ≥ 0 and Λ ≠ Λ_i, we have |β_Λ(U_t)| = poly(α).

Proposition 4 shows that, similar to SMF, GD learns the tensor eigenvalues incrementally. However, unlike SMF, we require a specific alignment for the initial point. We note that such an initial point can be obtained in a pre-processing step via the tensor power method, within a number of iterations that is almost independent of d (Anandkumar et al., 2017, Theorem 1). We believe that Proposition 4 can be extended to random initialization; we leave the rigorous verification of this conjecture to future work.
Equipped with this proposition, we next establish the convergence of GD on OSTD.

Theorem 4. Suppose that the conditions of Proposition 4 are satisfied. Then, within T ≲ (1/(ηlσ_r)) α^{−(l−2)/l} iterations, GD satisfies ∥T_{U_T} − T_{U⋆}∥²_F ≲ r d^l γ σ_1^{(l−1)/l} α^{1/l}.

Theorem 4 shows that, with appropriate choices of η and α, GD converges to a solution satisfying ∥T_{U_T} − T_{U⋆}∥²_F ≤ ϵ within O(d^{l(l−2)}/ϵ^{l−2}) iterations. To the best of our knowledge, this is the first result establishing the convergence of GD with a large step-size on OSTD.
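The multi-index coefficient formula β_Λ(U) = Σ_{i=1}^{r′} Π_{k=1}^l ⟨u_i, z_{j_k}⟩ used throughout this subsection can be checked numerically. A minimal sketch for order l = 3 (dimensions, seed, and the chosen multi-index are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(4)
d, r_prime = 5, 3
U = rng.standard_normal((d, r_prime))             # columns u_1, ..., u_{r'}
Z = np.linalg.qr(rng.standard_normal((d, d)))[0]  # orthonormal z_1, ..., z_d

# T_U = sum_i u_i^{(x)3} as an order-3 tensor
T_U = np.einsum("ai,bi,ci->abc", U, U, U)

Lam = [0, 2, 4]                                   # an arbitrary multi-index (j_1, j_2, j_3)
phi = np.einsum("a,b,c->abc", Z[:, Lam[0]], Z[:, Lam[1]], Z[:, Lam[2]])
beta_direct = np.sum(T_U * phi)                   # <T_U, z_{j1} (x) z_{j2} (x) z_{j3}>
beta_formula = np.sum(np.prod(Z[:, Lam].T @ U, axis=0))  # sum_i prod_k <u_i, z_{j_k}>
assert np.isclose(beta_direct, beta_formula)
```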

3.4. EMPIRICAL VERIFICATION ON NEURAL NETWORKS

In this section, we numerically show that the conjugate kernel after training (A-CK) can be used as a valid orthonormal basis for DNNs to capture the monotonicity of the solution trajectory of different optimizers on image classification tasks. To ensure consistency with our general framework, we use the ℓ2-loss, which has been shown to perform comparably to the commonly used cross-entropy loss (Hui & Belkin, 2020); in Appendix B, we extend our simulations to the cross-entropy loss. The conjugate kernel (CK) is a tool for analyzing the generalization performance of DNNs that uses the second-to-last layer (the layer before the last linear layer) at the initial point as the feature map (Daniely et al., 2016; Fan & Wang, 2020; Hu & Huang, 2021). Recently, Long (2021) showed that A-CK, a variant of CK evaluated at the last epoch, better explains the generalization properties of realistic DNNs. Surprisingly, we find that A-CK can be used not only to characterize the generalization performance but also to capture the underlying solution trajectory of different gradient-based algorithms.

To formalize the idea, note that any neural network whose last layer is linear can be written as f_θ(x) = Wψ(x), where x ∈ R^d is the input drawn from the distribution D, ψ(x) ∈ R^m is the feature map with m features, and W ∈ R^{k×m} is the last linear layer, with k the number of classes. We denote the trained model, i.e., the model at the last epoch, by f_{θ∞}(x) = W_∞ψ_∞(x). To form an orthonormal basis, we use the SVD to obtain a series of basis functions ϕ_i(x) = W_{∞,i}ψ_∞(x) that satisfy E_{x∼D}[∥ϕ_i(x)∥²] = 1 and E_{x∼D}[⟨ϕ_i(x), ϕ_j(x)⟩] = δ_{ij}, where δ_{ij} is the Kronecker delta. Hence, the coefficient β_i(θ_t) at each epoch t can be obtained as β_i(θ_t) = E_{x∼D}[⟨f_{θ_t}(x), ϕ_i(x)⟩], where the expectation is estimated by its sample mean on the test set. More details on our implementation can be found in Appendix B.
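The construction above can be sketched in a few lines. This is one plausible instantiation consistent with the description (whiten the trained features so their second moment is the identity, then take an SVD of the last layer in whitened coordinates, which makes E[⟨ϕ_i, ϕ_j⟩] = δ_ij); random Gaussian matrices stand in for ψ_∞ and W_∞, and the exact procedure used in our experiments is the one described in Appendix B.

```python
import numpy as np

rng = np.random.default_rng(5)
n, m, k = 5000, 32, 10                      # test samples, feature width, classes
Psi = rng.standard_normal((n, m)) @ rng.standard_normal((m, m))  # stand-in for psi_inf(x)
W_inf = rng.standard_normal((k, m))         # stand-in for the trained last layer

# Whiten the features so their empirical second moment is the identity
Sigma = Psi.T @ Psi / n
evals, Q = np.linalg.eigh(Sigma)
Phi = (Psi @ Q) / np.sqrt(evals)            # whitened features

# SVD of the last layer expressed in whitened coordinates
A, s, Bt = np.linalg.svd(W_inf @ Q * np.sqrt(evals), full_matrices=False)
basis = Phi @ Bt.T                          # scalar parts; phi_i(x) = A[:, i] * basis[:, i]

G = basis.T @ basis / n                     # Gram matrix: should be the k x k identity
assert np.allclose(G, np.eye(k), atol=1e-6)
# The trained model decomposes as f(x) = sum_i s_i * phi_i(x)
assert np.allclose(Psi @ W_inf.T, (basis * s) @ A.T)
print(s[:5])                                # coefficients of the trained model in this basis
```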
Performance on convolutional neural networks: We use LARS to train CNNs of varying depth on the MNIST dataset. These networks are trained so that their test accuracies exceed 96%. Figures 3a-3c illustrate the evolution of the top-5 basis coefficients after projecting the LARS trajectory onto the orthonormal basis obtained from A-CK. The basis coefficients are consistently monotonic across different depths, elucidating the generality of our proposed basis function decomposition. In the appendix, we discuss the connection between the convergence of the basis coefficients and the test accuracy for different architectures and loss functions.

Performance with different optimizers: The monotonic behavior of the projected solution trajectory is also observed across optimizers. Figures 3d-3f show the solution trajectories of three optimizers, namely LARS, SGD, and AdamW, on AlexNet with the CIFAR-10 dataset. All three optimizers exhibit a monotonic trend after projection onto the orthonormal basis obtained from A-CK. Although our theoretical results hold only for GD, our simulations highlight the strength of the proposed basis function decomposition in capturing the behavior of other gradient-based algorithms on DNNs.


A DISCUSSION AND FUTURE DIRECTIONS

Extension to empirical loss and implications for generalization. The most significant future direction is to extend our analysis from the expected loss to the finite-sample regime. It has been shown that incremental learning can drive generalization (Gissin et al., 2019). Hence, our proposed framework is likely to help explain the puzzling generalization ability of overparameterized machine learning models. For such an extension, the main technical difficulty is the relaxation of gradient independence: in the finite-sample regime, gradient independence does not hold exactly due to the randomness of the samples. An alternative approach would therefore be to establish gradient independence approximately and with high probability. We have successfully applied our framework to matrix factorization and tensor decomposition, where gradient independence holds only approximately; we therefore believe that more general guarantees in the finite-sample regime are not out of reach.

Extension to other optimization algorithms. In this paper, we mainly focused on GD as a representative of various optimization algorithms. Nonetheless, we believe that our analysis can be adapted to investigate other local-search algorithms, such as GD with momentum (Nesterov, 1983), SGD (Robbins & Monro, 1951), and Adam (Kingma & Ba, 2014). Overall, our approach can offer a unified framework for examining the implicit bias and incremental learning phenomena of different local-search algorithms.

B ADDITIONAL EXPERIMENTS

In this section, we provide more details on our simulation and further explore the empirical strength of the proposed basis function decomposition on different datasets, optimizers, loss functions, and batch sizes; see Table 1 for a summary of our simulations in this section. 

B.1 NUMERICAL VERIFICATION OF OUR THEORETICAL RESULTS

In this section, we provide experimental evidence to support our theoretical results on kernel regression (KR), symmetric matrix factorization (SMF), and orthogonal symmetric tensor decomposition (OSTD). The results are presented in Figure 4.

Kernel regression. We randomly generate 20 orthonormal kernel functions. The true model is comprised of 4 signal terms with basis coefficients 10, 5, 3, and 1. Figure 4a shows the trajectories of the top-4 basis coefficients of GD with initial point $\theta_0 = 5\times 10^{-7}\cdot\mathbf{1}$ and step-size $\eta = 0.4$. It can be seen that GD learns the different coefficients at the same rate, which is in line with Proposition 2.

Symmetric matrix factorization. In this simulation, we aim to recover a rank-4 matrix $M^\star = V\Sigma V^\top \in \mathbb{R}^{20\times 20}$. In particular, we assume that $V\in\mathbb{R}^{20\times 4}$ is a randomly generated orthonormal matrix and $\Sigma = \mathrm{Diag}\{10, 5, 3, 1\}$. We consider a fully over-parameterized model where $U\in\mathbb{R}^{20\times 20}$ (i.e., $r' = 20$). Figure 4b illustrates the incremental learning phenomenon that was proved in Proposition 3 for GD with small Gaussian initialization $U_{ij} \overset{\text{i.i.d.}}{\sim} N(0,\alpha^2)$, $\alpha = 5\times 10^{-7}$, and step-size $\eta = 0.04$.

Orthogonal symmetric tensor decomposition. Finally, we present our simulations for OSTD. We aim to recover a rank-4 symmetric tensor of the form $T^\star = \sum_{i=1}^4 \sigma_i z_i^{\otimes 4}$, where the $\sigma_i$ are the nonzero eigenvalues with values $\{10, 5, 3, 1\}$ and the $z_i\in\mathbb{R}^{10}$ are the corresponding eigenvectors. We again consider a fully over-parameterized model with $r' = 10$. Figure 4c shows the incremental learning phenomenon for GD with an aligned initial point satisfying $\cos(u_i(0), z_i) \ge 0.9983$ for $1\le i\le r'$ and step-size $\eta = 0.001$.
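The SMF experiment above is easy to reproduce; the following is a minimal numpy sketch (the dimensions, spectrum, initialization scale, and step size follow the text, while the iteration count and random seed are our choices):

```python
import numpy as np

rng = np.random.default_rng(0)
d, r = 20, 4
sigma = np.array([10.0, 5.0, 3.0, 1.0])

# Random orthonormal columns Z and the rank-4 ground truth M* = Z diag(sigma) Z^T.
Q, _ = np.linalg.qr(rng.standard_normal((d, d)))
Z = Q[:, :r]
M_star = Z @ np.diag(sigma) @ Z.T

# Fully over-parameterized GD (r' = 20) with small Gaussian initialization.
alpha, eta, T = 5e-7, 0.04, 4000
U = alpha * rng.standard_normal((d, d))
traj = []
for _ in range(T):
    U = U - eta * (U @ U.T - M_star) @ U   # gradient of (1/4)||UU^T - M*||_F^2
    traj.append([Z[:, i] @ U @ U.T @ Z[:, i] for i in range(r)])
traj = np.array(traj)                      # traj[t, i] = beta_ii(U_t)
```

Plotting the columns of `traj` reproduces the staircase pattern of Figure 4b: $\beta_{11}$ saturates at 10 first, then $\beta_{22}$ at 5, and so on.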

B.2 DERIVATION OF BASIS FUNCTIONS FOR DNNS

In this section, we provide more details on how we evaluate our proposed orthogonal basis induced by the A-CK and calculate the corresponding coefficients $\beta_i(\theta_t)$ for an arbitrary neural network. First, recall that any neural network whose last layer is linear can be characterized as $f_\theta(x) = W\psi(x)$, where $x\in\mathbb{R}^d$ is the input drawn from the distribution $D$, $\psi(x)\in\mathbb{R}^m$ is the feature map with $m$ features, and $W\in\mathbb{R}^{k\times m}$ is the last linear layer, with $k$ referring to the number of classes. We denote the trained model, i.e., the model at the last epoch, by $f_{\theta_\infty}(x) = W_\infty\psi_\infty(x)$. To form an orthogonal basis, we use the SVD to obtain a series of basis functions of the form $\phi_i(x) = \widetilde{W}_{\infty,i}\widetilde{\psi}_\infty(x)$ that satisfy $\mathbb{E}_{x\sim D}[\|\phi_i(x)\|^2] = 1$ and $\mathbb{E}_{x\sim D}[\langle\phi_i(x),\phi_j(x)\rangle] = \delta_{ij}$, where $\delta_{ij}$ is the Kronecker delta. The coefficient $\beta_i(\theta_t)$ at each epoch $t$ can then be derived as $\beta_i(\theta_t) = \mathbb{E}_{x\sim D}[\langle f_{\theta_t}(x), \phi_i(x)\rangle]$. In all of our implementations, we use the test dataset to approximate the population distribution.

Step 1: Obtaining the orthogonal basis $\phi_i(x)$. We denote by $\Psi = [\psi(x_1),\cdots,\psi(x_N)]\in\mathbb{R}^{m\times N}$ the feature matrix, where $N$ is the number of test data points. We write the SVD of $\Psi$ as $\Psi = U\Sigma V^\top$. The right singular vectors collected in $V$ can be used to define the desired orthogonal basis for $\{\psi_i(x)\}_{i=1}^m$. To this goal, we write the prediction matrix as $F = W\Psi = \widetilde{W}\widetilde{\Psi}$, where $\widetilde{W} = WU\Sigma$ and $\widetilde{\Psi} = V^\top$. Our goal is to define a set of matrices $A_i$ such that the functions $\phi_i(x) = A_i\widetilde{\psi}(x)$ form a valid orthonormal basis for $F = \widetilde{W}\widetilde{\Psi}$. Before designing such $A_i$, first note that, due to the orthogonality of $\{\widetilde{\psi}_i(x)\}$, we have
$$\mathbb{E}\big[\|\phi_i(x)\|^2\big] = \|A_i\|_F^2, \qquad \mathbb{E}\big[\langle\phi_i(x),\phi_j(x)\rangle\big] = \langle A_i, A_j\rangle.$$
Therefore, it suffices to ensure that the $\{A_i\}$ are orthonormal. Consider the SVD of $\widetilde{W}$ as $\widetilde{W} = \sum_i \sigma_i u_iv_i^\top$, and define $A_i = u_iv_i^\top$. Clearly, the $\{A_i\}$ defined this way are orthonormal. Moreover, it is easy to see that the basis coefficients of the trained model (treated as the true basis coefficients) are exactly the singular values of $\widetilde{W}$.

Step 2: Obtaining the basis coefficients $\beta_i(\theta_t)$. After obtaining the desired orthonormal basis $\phi_i(x)$, we can calculate the coefficient $\beta_i(\theta_t)$ at each epoch. Given the linear layer $W_t$ and the feature matrix $\Psi_t$ at epoch $t$, we obtain the coefficients of the signal terms by projecting the prediction matrix $F_t = W_t\Psi_t$ onto $\widetilde{\Psi} = V^\top$. In particular, we write the projected prediction matrix as
$$P_{\widetilde{\psi}}F_t = P_{\widetilde{\psi}}W_t\Psi_t = W_t\Psi_t VV^\top = \widetilde{W}_t\widetilde{\Psi}, \qquad \text{where } \widetilde{W}_t = W_t\Psi_t V.$$
Hence, the basis coefficients can be easily calculated as $\beta_i(\theta_t) = \big\langle \widetilde{W}_t,\, u_iv_i^\top\big\rangle$.
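The two steps above can be condensed into a short numpy sketch; random matrices stand in for the trained network's last layer and features, and all sizes are arbitrary toy choices:

```python
import numpy as np

rng = np.random.default_rng(1)
k, m, N = 3, 8, 200          # classes, features, test points (toy sizes)

# Stand-ins for the trained model's last layer and feature matrix.
Psi_inf = rng.standard_normal((m, N))
W_inf = rng.standard_normal((k, m))

# Step 1: orthonormalize the features, then SVD of the transformed last layer.
U, S, Vt = np.linalg.svd(Psi_inf, full_matrices=False)
W_tilde = W_inf @ U @ np.diag(S)          # so that F = W_tilde @ Vt
Uw, Sw, Vwt = np.linalg.svd(W_tilde, full_matrices=False)
A = [np.outer(Uw[:, i], Vwt[i]) for i in range(len(Sw))]  # orthonormal {A_i}

# Step 2: coefficients of an arbitrary epoch-t model (here: the trained one,
# whose coefficients should come out as the singular values of W_tilde).
W_t, Psi_t = W_inf, Psi_inf
W_t_tilde = W_t @ Psi_t @ Vt.T            # project F_t onto the basis Vt
beta = np.array([np.sum(W_t_tilde * A_i) for A_i in A])
```

For the trained model itself, `beta` recovers `Sw`, i.e., the "true" basis coefficients; at intermediate epochs one would plug in that epoch's `W_t` and `Psi_t` instead.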

B.3 FURTHER DETAILS ON THE EXPERIMENTS

In this section, we provide more details on the experiments presented in the main body of the paper and compare them with other DNN architectures. All of our experiments are implemented in a Python 3.9, PyTorch 1.12.1 environment and run through a local SLURM server on NVIDIA Tesla V100-PCIE-16GB GPUs. We use the additional NNGeometry package for calculating batch gradients, and our implementation of ViT is adapted from https://juliusruseckas.github.io/ml/cifar10-vit.html. To ensure consistency with our theoretical results, we drop the last softmax operator and use the $\ell_2$-loss throughout this section. All of our training data are augmented by RandomCrop and RandomHorizontalFlip, and normalized by mean and standard deviation.

Experimental details for Figure 1. Here, we describe our implementation details for Figure 1 and present additional experiments on VGG-11, ResNet-34, and ResNet-50 with the CIFAR-10 dataset. The results can be seen in Figure 5. We use standard data augmentation for all architectures except for ViT; for ViT, we only use data normalization. To obtain a stable A-CK, we train the above models for 300 epochs. For ResNet-18, we use LARS with a learning rate of 0.5 and apply small initialization with $\alpha = 0.3$, i.e., we scale the default initial point by $\alpha = 0.3$. For ResNet-34 and ResNet-50, we choose the default learning rate and apply small initialization with $\alpha = 0.3$. For ViT, we use AdamW with a learning rate of 0.01. The remaining parameters are set to their default values.

Experimental details for the first row of Figure 3. Here we conduct experiments on the MNIST dataset with different CNN architectures. The CNNs are composed of $l$ blocks of layers, followed by a single fully connected layer. A block consists of a convolutional layer, an activation layer, and a pooling layer. In our experiments, we use ReLU activation and vary the depth of the network. For the first block, we use identity pooling; for the remaining blocks, we use max-pooling. For the 2-block CNN, we set the convolutional layer widths to 256 and 64, respectively. For the 3-block CNN, we set the convolutional layer widths to 256, 128, and 64, respectively. For the 4-block CNN, we set the convolutional layer widths to 256, 128, 128, and 64, respectively. To train these networks, we use LARS with a learning rate of 0.05. The remaining parameters are set to their default values. We run 20 epochs to calculate the A-CK.

Experimental details for the second row of Figure 3. We conduct experiments to compare the performance of different optimizers on the CIFAR-10 dataset. In particular, we use AlexNet to compare three optimizers, namely SGD, AdamW, and LARS. For SGD, we set the base learning rate to 2 with Nesterov momentum of 0.9 and weight decay of 0.0001, together with the "linear warm-up" technique. For AdamW, we set the learning rate to 0.01 and keep the remaining parameters unchanged. For LARS, we set the learning rate to 2 with Nesterov momentum of 0.9 and weight decay of 0.0001. The remaining parameters are set to their default values.

B.4 EXPERIMENTS FOR CIFAR-100

In this section, we conduct experiments using the CIFAR-100 dataset, which is larger than both CIFAR-10 and MNIST. Our simulations are run on AlexNet, VGG-11, ViT, ResNet-18, ResNet-34, and ResNet-50. In particular, we use the "loss scaling trick" (Hui & Belkin, 2020), defined as follows: consider a datapoint $(x,y)$, where $x\in\mathbb{R}^d$ is the input and $y\in\mathbb{R}^k$ is a one-hot vector with a 1 at position $i$. Then, the scaled $\ell_2$-loss is defined as
$$\ell_{2,\text{scaling}}(x) = k\cdot\big(f_\theta(x)[i] - M\big)^2 + \sum_{i'\ne i}\big(f_\theta(x)[i']\big)^2,$$
for some constants $k, M > 0$. We set these parameters to $k = 1$ and $M = 4$. For AlexNet and VGG-11, we use LARS with a base learning rate of $\eta = 1$. For ViT, we use AdamW with a base learning rate of 0.01 and a batch size of 256. For the ResNet architectures, we use SGD with a base learning rate of 0.3 and a batch size of 64. We also add 5 warm-up epochs for the ResNets. All the remaining parameters for the above architectures are set to their default values. The results can be seen in Figure 6. Our experiments highlight a trade-off between the monotonicity of the projected solution trajectories and the test accuracy: to obtain a higher test accuracy, one typically needs to pick a larger learning rate, which in turn results in more sporadic behavior of the solution trajectories. Nonetheless, even with a large learning rate, the basis coefficients remain relatively monotonic after the first few epochs and converge to meaningful values.
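The scaled loss is a one-line modification of the per-sample $\ell_2$-loss; a minimal sketch (the function name and per-sample signature are our own — in practice the loss is applied batch-wise inside the training loop):

```python
import numpy as np

def scaled_l2_loss(f_x, y_index, k_scale=1.0, M=4.0):
    """Scaled l2-loss of Hui & Belkin (2020): upweight the true-class logit
    by k_scale and pull it toward M instead of 1; all off-class logits are
    pulled toward 0. Defaults match the paper's choice k = 1, M = 4."""
    f_x = np.asarray(f_x, dtype=float)
    off = np.delete(f_x, y_index)                 # logits of the wrong classes
    return k_scale * (f_x[y_index] - M) ** 2 + np.sum(off ** 2)
```

For example, a perfectly fitted sample with logits `[4, 0, 0]` and label 0 incurs zero loss, while an all-zero prediction incurs $k\cdot M^2 = 16$.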

B.5 EXPERIMENTS FOR DIFFERENT LOSSES

Next, we compare the projected solution trajectories under two loss functions, namely the $\ell_2$-loss and the cross-entropy (CE) loss. We use LARS to train AlexNet on the CIFAR-10 dataset with both losses. In particular, we add the softmax operator before training with the CE loss. For the CE loss, we set the base learning rate of LARS to 1; for the $\ell_2$-loss, we use a base learning rate of 2. The remaining parameters are set to their default values. The results can be seen in Figure 7. We observe that, similar to the $\ell_2$-loss, the solution trajectory of the CE loss behaves monotonically after projection onto the orthogonal basis induced by the A-CK. Inspired by these observations, another avenue for future research would be to extend our framework to general loss functions. Interestingly, the basis coefficients under the CE loss converge much more slowly than those under the $\ell_2$-loss. This is despite the fact that the CE loss can learn slightly faster than the $\ell_2$-loss in terms of test accuracy, as shown by Hui & Belkin (2020).

B.6 EXPERIMENTS FOR DIFFERENT BATCH SIZE

Next, we study the effect of different batch sizes on the solution trajectory. We train AlexNet on the CIFAR-10 dataset. When testing different batch sizes, we follow the "linear scaling" rule, i.e., the learning rate scales linearly with the batch size. For a batch size of 32, we use SGD with a base learning rate of $\eta = 0.1$ and 5 warm-up epochs. For a batch size of 64, we use SGD with a base learning rate of $\eta = 0.2$ and 5 warm-up epochs. For a batch size of 256, we use LARS with a base learning rate of 1. The remaining hyperparameters are set to their default values. The results are reported in Figure 8. We see that the projected solution trajectories share a similar monotonic behavior across different batch sizes.
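The linear scaling rule itself is a one-liner; a sketch (the function name is ours, and the base pair of batch size 32 with learning rate 0.1 matches the SGD setting above):

```python
def scaled_lr(base_lr, base_batch, batch_size):
    """Linear scaling rule: scale the learning rate proportionally
    to the ratio between the new batch size and the base batch size."""
    return base_lr * batch_size / base_batch
```

With base pair `(32, 0.1)`, a batch size of 64 yields a learning rate of 0.2, consistent with the settings listed above.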

B.7 EXPERIMENTS FOR RESNET-18 WITH SGD ON CIFAR-10

Lastly, we train ResNet-18 with SGD on CIFAR-10. The results can be seen in Figure 9. To achieve good generalization, we use a large learning rate $\eta = 0.3$ (with 5 warm-up epochs starting at $1\times 10^{-5}$) and decrease the learning rate by a factor of 0.33 every 50 epochs. Moreover, we use a large batch size of 512 to better imitate the trajectory of GD. The simulation shows an approximately monotonic behavior (modulo the fluctuations caused by the randomness of the gradients) in the dynamics of the top-5 components, which again validates our theoretical results.
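The schedule described above can be sketched as a single function; we assume the warm-up is linear (as in the "linear warm-up" technique used elsewhere in our experiments), and the exact per-epoch interpolation is our choice:

```python
def lr_at_epoch(epoch, base_lr=0.3, warmup_epochs=5, warmup_start=1e-5,
                decay_factor=0.33, decay_every=50):
    """Learning-rate schedule for the ResNet-18/SGD run: linear warm-up
    from warmup_start to base_lr over the first epochs, then a step decay
    by decay_factor every decay_every epochs."""
    if epoch < warmup_epochs:
        frac = epoch / warmup_epochs
        return warmup_start + frac * (base_lr - warmup_start)
    return base_lr * decay_factor ** ((epoch - warmup_epochs) // decay_every)
```

In a PyTorch training loop, the same schedule would typically be set per epoch via `optimizer.param_groups` or a `LambdaLR` scheduler.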

C RELATED WORK

GD for general nonconvex optimization. Gradient descent and its stochastic or adaptive variants are considered the "go-to" algorithms in large-scale (unconstrained) nonconvex optimization. Because of their first-order nature, they are known to converge to first-order stationary points (Nesterov, 1998). Only recently has it been shown that GD (Lee et al., 2019; Panageas et al., 2019) and its variants, such as perturbed GD (Jin et al., 2017) and SGD (Fang et al., 2019; Daneshmand et al., 2018), can avoid saddle points and converge to second-order stationary points. However, these guarantees do not quantify the distance between the obtained solution and the globally optimal and/or true solutions. To the best of our knowledge, the largest subclass of nonconvex optimization problems for which GD or its variants converge to meaningful solutions are those with benign landscapes. These problems include different variants of low-rank matrix optimization with exactly parameterized rank, namely matrix completion (Ge et al., 2016), matrix sensing (Ge et al., 2017; Zhang et al., 2021), dictionary learning (Sun et al., 2016), and robust PCA (Fattahi & Sojoudi, 2020), as well as deep linear neural networks (Kawaguchi, 2016). However, a benign landscape is too restrictive a property to hold in practice; for instance, Zhang (2021) shows that spurious local minima are ubiquitous in low-rank matrix optimization, even under fairly mild conditions. Therefore, the notion of a benign landscape cannot be used to explain the success of local-search algorithms in more complex learning tasks.

GD for specific learning problems. Although there does not exist a unifying framework for studying the global convergence of GD on general learning tasks, its convergence has been established for specific learning problems, such as kernel regression (which includes the neural tangent kernel (Jacot et al., 2018)), sparse recovery (Vaskevicius et al., 2019), matrix factorization (Li et al., 2018), tensor decomposition (Wang et al., 2020; Ge et al., 2021), and linear neural networks (Arora et al., 2018). In what follows, we review the specific learning tasks that are most related to our work. The convergence of GD on kernel regression was studied well before the emergence of deep learning. Later work (Wang et al., 2020; Ge et al., 2021) studied the dynamics of a modified GD for overcomplete nonconvex tensor decomposition. Moreover, Razin et al. (2021; 2022) analyzed the implicit regularization and incremental learning of gradient flow in hierarchical tensor decomposition and showed its connection to neural networks.

Conjugate kernel. The conjugate kernel (CK) at the initial point has been considered one of the promising tools for studying the generalization properties of DNNs (Daniely et al., 2016; Hu & Huang, 2021; Fan & Wang, 2020). However, similar to the NTK, a major shortcoming of the CK is that it cannot fully characterize the behavior of practical neural networks (Vyas et al., 2022). Recent results have suggested that the conjugate kernel evaluated after training (for both NTK and CK) can better describe the generalization properties of DNNs (Fort et al., 2020; Long, 2021). In our work, we show that this "after-kernel regime" can also be adapted to study the optimization trajectory of practical DNNs.

D PROOFS FOR GENERAL FRAMEWORK

D.1 PROOF OF PROPOSITION 1

To prove this proposition, we first combine (1) and (GD):
$$\theta_{t+1} = \theta_t - \eta\nabla L(\theta_t) = \theta_t - \eta\sum_{i\in I}\big(\beta_i(\theta_t)-\beta_i^\star\big)\nabla\beta_i(\theta_t).$$
For notational simplicity, we denote $E(\theta_t) = \frac{1}{2}\sum_{i\in E}\beta_i^2(\theta_t)$. Then, one can write
$$\theta_{t+1} = \theta_t - \eta\nabla L(\theta_t) = \theta_t - \eta\sum_{i\in S}\big(\beta_i(\theta_t)-\beta_i^\star\big)\nabla\beta_i(\theta_t) - \eta\nabla E(\theta_t).$$
Due to the mean-value theorem, there exists a $\xi\in\mathbb{R}^m$ such that
$$\beta_i(\theta_{t+1}) = \beta_i\Big(\theta_t - \eta\sum_{j\in I}\big(\beta_j(\theta_t)-\beta_j^\star\big)\nabla\beta_j(\theta_t)\Big) = \beta_i(\theta_t) - \eta\sum_{j\in I}\big(\beta_j(\theta_t)-\beta_j^\star\big)\langle\nabla\beta_i(\theta_t),\nabla\beta_j(\theta_t)\rangle + \frac{\eta^2}{2}\big\langle\nabla L(\theta_t),\nabla^2\beta_i(\xi)\nabla L(\theta_t)\big\rangle. \quad (12)$$
On the other hand, one can write
$$\big\langle\nabla L(\theta_t),\nabla^2\beta_i(\xi)\nabla L(\theta_t)\big\rangle \le \sup_\theta\big\|\nabla^2\beta_i(\theta)\big\|\,\|\nabla L(\theta_t)\|^2.$$
For $\sup_\theta\|\nabla^2\beta_i(\theta)\|$, we further have
$$\sup_\theta\big\|\nabla^2\beta_i(\theta)\big\| = \sup_\theta\big\|\nabla^2\mathbb{E}[f_\theta(x)\phi_i(x)]\big\| = \sup_\theta\big\|\mathbb{E}[\nabla^2 f_\theta(x)\phi_i(x)]\big\| \le \sup_\theta\mathbb{E}\big[\|\nabla^2 f_\theta(x)\|\,|\phi_i(x)|\big] \overset{(a)}{\le} \sup_\theta\mathbb{E}\big[\|\nabla^2 f_\theta(x)\|^2\big]^{1/2}\mathbb{E}\big[\phi_i^2(x)\big]^{1/2} \overset{(b)}{\le} L_H.$$
Here, we used the Cauchy–Schwarz inequality for (a). Moreover, for (b), we used Assumption 1 and the definition of the orthonormal basis. Hence, we have
$$\beta_i(\theta_{t+1}) = \beta_i(\theta_t) - \eta\sum_{j\in I}\big(\beta_j(\theta_t)-\beta_j^\star\big)\langle\nabla\beta_i(\theta_t),\nabla\beta_j(\theta_t)\rangle \pm \tfrac{1}{2}\eta^2 L_H\|\nabla L(\theta_t)\|^2. \quad (15)$$
Now, it suffices to bound $\|\nabla L(\theta_t)\|^2$. Using the Cauchy–Schwarz inequality, we have
$$\|\nabla L(\theta_t)\|^2 = \Big\|\nabla\mathbb{E}\Big[\tfrac{1}{2}\big(f_{\theta_t}-f_{\theta^\star}\big)^2\Big]\Big\|^2 = \big\|\mathbb{E}\big[(f_{\theta_t}-f_{\theta^\star})\nabla f_{\theta_t}\big]\big\|^2 \le \|f_{\theta_t}-f_{\theta^\star}\|_{L^2(D)}^2\,\|\nabla f_{\theta_t}\|_{L^2(D)}^2 \le 4L_g^2L_f^2.$$
Therefore, we conclude that
$$\beta_i(\theta_{t+1}) = \beta_i(\theta_t) - \eta\sum_{j\in I}\big(\beta_j(\theta_t)-\beta_j^\star\big)\langle\nabla\beta_i(\theta_t),\nabla\beta_j(\theta_t)\rangle \pm 2\eta^2 L_H L_g^2 L_f^2,$$
which completes the proof. □

D.2 PROOF OF THEOREM 1

Proof. Invoking the gradient independence condition, Proposition 1 can be simplified as
$$\beta_i(\theta_{t+1}) = \beta_i(\theta_t) - \eta\sum_{j\in I}\big(\beta_j(\theta_t)-\beta_j^\star\big)\langle\nabla\beta_i(\theta_t),\nabla\beta_j(\theta_t)\rangle \pm 2\eta^2 L_H L_g^2 L_f^2 = \beta_i(\theta_t) - \eta\big(\beta_i(\theta_t)-\beta_i^\star\big)\|\nabla\beta_i(\theta_t)\|^2 \pm 2\eta^2 L_H L_g^2 L_f^2.$$
We next provide upper and lower bounds for the residual and signal terms. Recall that $S = \{i\in I : \beta_i^\star \ne 0\}$ and $E = I\setminus S$. We first consider the dynamic of the signal terms $\beta_i(\theta_t)$, $i\in S$. Without loss of generality, we assume $\beta_i^\star > 0$. Then, due to the gradient dominance condition, we have the following lower bound:
$$\beta_i(\theta_{t+1}) \ge \beta_i(\theta_t) - \eta\big(\beta_i(\theta_t)-\beta_i^\star\big)\|\nabla\beta_i(\theta_t)\|^2 - 2\eta^2 L_H L_g^2 L_f^2 \ge \Big(1+C_2\eta\big(\beta_i^\star-\beta_i(\theta_t)\big)\beta_i^{2\gamma-1}(\theta_t)\Big)\beta_i(\theta_t) - 2\eta^2 L_H L_g^2 L_f^2, \quad i\in S. \quad (19)$$
Next, for the dynamic of the residual terms $\beta_i(\theta_t)$, $i\in E$, we have
$$\beta_i(\theta_{t+1}) = \big(1-\eta\|\nabla\beta_i(\theta_t)\|^2\big)\beta_i(\theta_t) \pm 2\eta^2 L_H L_g^2 L_f^2. \quad (20)$$
Next, we show that $\|\nabla\beta_i(\theta_t)\| \le L_g$. One can write
$$\|\nabla\beta_i(\theta_t)\| = \big\|\nabla\mathbb{E}[f_{\theta_t}(x)\phi_i(x)]\big\| \le \mathbb{E}\big[\|\nabla f_{\theta_t}(x)\|\,|\phi_i(x)|\big] \le \mathbb{E}\big[\|\nabla f_{\theta_t}(x)\|^2\big]^{1/2}\mathbb{E}\big[\phi_i^2(x)\big]^{1/2} \le L_g. \quad (21)$$
Due to our choice of the step-size, we have $\eta \lesssim \frac{1}{L_g^2}$, which in turn implies $1-\eta\|\nabla\beta_i(\theta_t)\|^2 \le 1$. Therefore, we have
$$|\beta_i(\theta_{t+1})| \le |\beta_i(\theta_t)| + 2\eta^2 L_H L_g^2 L_f^2. \quad (22)$$
Now, we are ready to prove the theorem. We divide the proof into two cases.

Case 1: $\gamma = \frac{1}{2}$. In this case, since we set the step-size $\eta \lesssim \frac{\alpha}{\sqrt{d}\,C_2 L_H L_g^2 L_f^2 \beta_k^\star}\log^{-1}\big(\frac{d\beta_k^\star}{C_1\alpha}\big)$, we can simplify the dynamics of both the signal and residual terms in Equation 19 and Equation 20 as follows:
$$\beta_i(\theta_{t+1}) \ge \big(1+0.5C_2\eta(\beta_i^\star-\beta_i(\theta_t))\big)\beta_i(\theta_t) \quad \forall i\in S, \qquad |\beta_i(\theta_{t+1})| \le |\beta_i(\theta_t)| + 2\eta^2 L_H L_g^2 L_f^2 \quad \forall i\in E. \quad (23)$$
We first analyze the dynamic of the signal $\beta_i(\theta_t)$ for $i\in S$. To this goal, we further divide this case into two phases. In the first phase, we assume $C_1\alpha \le \beta_i(\theta_t) \le \frac{1}{2}\beta_i^\star$. Under this assumption, we can simplify the dynamic of $\beta_i(\theta_t)$ as
$$\beta_i(\theta_{t+1}) \ge \big(1+0.25C_2\eta\beta_i^\star\big)\beta_i(\theta_t).$$
Therefore, within $T_1 = O\big(\frac{1}{C_2\eta\beta_i^\star}\log\frac{\beta_i^\star}{C_1\alpha}\big)$ iterations, $\beta_i(\theta_t)$ becomes larger than $\frac{1}{2}\beta_i^\star$. In the second phase, we assume that $\beta_i(\theta_t) \ge \beta_i^\star/2$ and define $y_t = \beta_i^\star - \beta_i(\theta_t)$. One can write
$$y_{t+1} \le \big(1-0.5C_2\eta\beta_i(\theta_t)\big)y_t \le \big(1-0.25C_2\eta\beta_i^\star\big)y_t.$$
Hence, after an additional $T_2 = O\big(\frac{1}{C_2\eta\beta_i^\star}\log\frac{d\beta_i^\star}{\alpha}\big)$ iterations, we have $y_t \le \frac{\alpha}{\sqrt{d}}$, which implies $\beta_i(\theta_t) \ge \beta_i^\star - \frac{\alpha}{\sqrt{d}}$. Next, we show that there exists a time $t^\star$ such that $\beta_i^\star - \frac{\alpha}{\sqrt{d}} \le \beta_i(\theta_{t^\star}) \le \beta_i^\star + \frac{\alpha}{\sqrt{d}}$. Without loss of generality, we assume that $t^\star$ is the first time that $\beta_i(\theta_t) \ge \beta_i^\star - \frac{\alpha}{\sqrt{d}}$. Due to the dynamic of $\beta_i(\theta_t)$, the distance between two adjacent iterations can be upper bounded as
$$|\beta_i(\theta_{t+1})-\beta_i(\theta_t)| \le \eta\,|\beta_i^\star-\beta_i(\theta_t)|\,\|\nabla\beta_i(\theta_t)\|^2 + 2\eta^2 L_H L_g^2 L_f^2 \le \eta L_g^2\,|\beta_i^\star-\beta_i(\theta_t)| + 2\eta^2 L_H L_g^2 L_f^2.$$
In particular, for $t = t^\star-1$, we have $\beta_i(\theta_{t^\star-1}) \le \beta_i^\star - \frac{\alpha}{\sqrt{d}}$, which in turn implies
$$\beta_i(\theta_{t^\star}) \le \beta_i(\theta_{t^\star-1}) + \eta L_g^2\,|\beta_i^\star-\beta_i(\theta_{t^\star-1})| + 2\eta^2 L_H L_g^2 L_f^2 \le \beta_i^\star + 2\eta^2 L_H L_g^2 L_f^2 \le \beta_i^\star + \frac{\alpha}{\sqrt{d}}. \quad (26)$$

Published as a conference paper at ICLR 2023

Therefore, for each $i\in S$, we have $|\beta_i(\theta_t)-\beta_i^\star| \le \frac{\alpha}{\sqrt{d}}$ within $T_i = O\big(\frac{1}{C_2\eta\beta_i^\star}\log\frac{d\beta_i^\star}{C_1\alpha}\big)$ iterations. Meanwhile, we can show that the residual terms $|\beta_i(\theta_t)|$, $i\in E$, remain small for $\max_{i\in S} T_i = O\big(\frac{1}{C_2\eta\beta_k^\star}\log\frac{d\beta_k^\star}{C_1\alpha}\big)$ iterations:
$$|\beta_i(\theta_t)| \le |\beta_i(\theta_0)| + \max_{i\in S} T_i\cdot 2\eta^2 L_H L_g^2 L_f^2 = |\beta_i(\theta_0)| + O\Big(\frac{\eta}{\beta_k^\star}\log\Big(\frac{d\beta_k^\star}{C_1\alpha}\Big)L_H L_g^2 L_f^2\Big) = |\beta_i(\theta_0)| + O\Big(\frac{\alpha}{\sqrt{d}}\Big). \quad (28)$$
Therefore, within $T = O\big(\frac{1}{\eta\beta_k^\star}\log\frac{d\beta_k^\star}{C_1\alpha}\big)$ iterations:
$$\|f_{\theta_T}-f_{\theta^\star}\|_{L^2(D)}^2 = \sum_{i\in S}\big(\beta_i^\star-\beta_i(\theta_T)\big)^2 + \sum_{i\in E}\beta_i^2(\theta_T) \lesssim \alpha^2. \quad (29)$$

Case 2: $\frac{1}{2} < \gamma \le 1$. In this case, we have the following bounds for the signal and residual terms:
$$\beta_i(\theta_{t+1}) \ge \Big(1+C_2\eta\big(\beta_i^\star-\beta_i(\theta_t)\big)\beta_i^{2\gamma-1}(\theta_t)\Big)\beta_i(\theta_t) - 2\eta^2 L_H L_g^2 L_f^2 \quad \forall i\in S, \qquad |\beta_i(\theta_{t+1})| \le |\beta_i(\theta_t)| + 2\eta^2 L_H L_g^2 L_f^2 \quad \forall i\in E. \quad (30)$$
We first analyze the dynamic of the signal term $\beta_i(\theta_t)$ for $i\in S$. We will show that $|\beta_i(\theta_t)-\beta_i^\star| \le \frac{\alpha}{\sqrt{k}}$ within $T = O\big(\frac{1}{C_2\eta\beta_i^\star\alpha^{2\gamma-1}}\big)$ iterations. Due to $\eta \lesssim \frac{\alpha^{2\gamma}}{\sqrt{d}\,C_2 L_H L_g^2 L_f^2\,\beta_k^{\star 2\gamma}}$, we can further simplify the dynamic of $\beta_i(\theta_t)$ as
$$\beta_i(\theta_{t+1}) \ge \Big(1+0.5C_2\eta\big(\beta_i^\star-\beta_i(\theta_t)\big)\beta_i^{2\gamma-1}(\theta_t)\Big)\beta_i(\theta_t).$$
Next, we divide our analysis into two phases. In the first phase, we have $\beta_i(\theta_t) \le \frac{1}{2}\beta_i^\star$. We denote the number of iterations in this phase by $T_{i,1}$, and further divide this period into $\lceil\log(\beta_i^\star/2\alpha)\rceil$ substages. In Substage $k$, we have $C_1 2^{k-1}\alpha \le \beta_i(\theta_t) \le C_1 2^k\alpha$. Let $t_k$ be the number of iterations in Substage $k$. We first provide an upper bound for $t_k$. To this goal, note that at this substage
$$\beta_i(\theta_{t+1}) \ge \Big(1+0.5C_2\eta\big(\beta_i^\star-\beta_i(\theta_t)\big)\beta_i^{2\gamma-1}(\theta_t)\Big)\beta_i(\theta_t) \ge \Big(1+0.25C_2\eta\beta_i^\star\beta_i^{2\gamma-1}(\theta_t)\Big)\beta_i(\theta_t). \quad (32)$$
Hence, we have
$$t_k \le \frac{\log 2}{\log\big(1+0.25C_2\eta\beta_i^\star(C_1 2^{k-1}\alpha)^{2\gamma-1}\big)}. \quad (33)$$
Summing over $t_k$, we obtain an upper bound for $T_{i,1}$:
$$T_{i,1} = \sum_{k=1}^{\lceil\log(\beta_i^\star/2\alpha)\rceil} t_k \lesssim \sum_{k=1}^{\infty}\frac{1}{C_2\eta\beta_i^\star(C_1 2^{k-1}\alpha)^{2\gamma-1}} \lesssim \frac{1}{C_2\eta\beta_i^\star(C_1\alpha)^{2\gamma-1}}. \quad (34)$$
Via a similar argument, we can show that in the second phase, we have $|\beta_i(\theta_t)-\beta_i^\star| \lesssim \frac{\alpha}{\sqrt{d}}$ within an additional $T_{i,2} = O\big(\frac{1}{C_2\eta\beta_i^\star(C_1\alpha)^{2\gamma-1}}\big)$ iterations. Therefore, for each $i\in S$, we conclude that $|\beta_i(\theta_t)-\beta_i^\star| \le \frac{\alpha}{\sqrt{d}}$ within $T_i = T_{i,1}+T_{i,2} = O\big(\frac{1}{C_2\eta\beta_i^\star(C_1\alpha)^{2\gamma-1}}\big)$ iterations. Meanwhile, for the residual terms $\beta_i(\theta_t)$, $i\in E$, we have
$$|\beta_i(\theta_t)| \le |\beta_i(\theta_0)| + \max_{i\in S} T_i\cdot 2\eta^2 L_H L_g^2 L_f^2 = |\beta_i(\theta_0)| + O\Big(\frac{\alpha}{\sqrt{d}}\Big).$$
Therefore,
$$\|f_{\theta_T}-f_{\theta^\star}\|_{L^2(D)}^2 = \sum_{i\in S}\big(\beta_i^\star-\beta_i(\theta_T)\big)^2 + \sum_{i\in E}\beta_i^2(\theta_T) \lesssim \alpha^2$$
within $T = O\big(\frac{1}{C_2\eta\beta_k^\star(C_1\alpha)^{2\gamma-1}}\big)$ iterations. This completes the proof.
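The signal dynamics analyzed in the two cases above can be simulated directly; a small sketch of the error-free recursion $\beta_{t+1} = \beta_t + C_2\eta(\beta^\star-\beta_t)\beta_t^{2\gamma-1}$ (dropping the $\eta^2$ error term), contrasting $\gamma = 1/2$ (all signals learned at the same rate) with $\gamma = 1$ (larger signals learned first). Step size, initialization, and the 0.99 saturation threshold are our choices:

```python
def time_to_saturate(beta_star, gamma, eta=0.01, alpha=1e-4, C2=1.0, T=200000):
    """Iterate beta_{t+1} = beta_t + C2*eta*(beta* - beta_t)*beta_t^(2*gamma-1)
    and return the first t at which beta_t >= 0.99 * beta*."""
    b = alpha
    for t in range(T):
        if b >= 0.99 * beta_star:
            return t
        b = b + C2 * eta * (beta_star - b) * b ** (2 * gamma - 1)
    return T

# gamma = 1/2: the update is linear in (beta* - beta), so signals of very
# different magnitudes saturate after essentially the same number of steps.
t_big, t_small = time_to_saturate(10.0, 0.5), time_to_saturate(1.0, 0.5)

# gamma = 1: the growth rate scales with beta_t itself, so the larger
# signal saturates first -- the incremental learning phenomenon.
s_big, s_small = time_to_saturate(10.0, 1.0), time_to_saturate(1.0, 1.0)
```

The contrast between the two regimes mirrors Proposition 2 (uniform rates for kernel regression, $\gamma = 1/2$) and Proposition 3 (incremental learning for matrix factorization, $\gamma = 1$).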

E PROOFS FOR KERNEL REGRESSION

E.1 PROOF OF PROPOSITION 2

Note that $\theta_{t+1} = \theta_t - \eta(\theta_t-\theta^\star)$ and $\beta_i(\theta) = \theta_i$. Hence, for every $1\le i\le k$, we have
$$\beta_i(\theta_{t+1}) = \beta_i(\theta_t) + \eta\big(\theta_i^\star-\beta_i(\theta_t)\big).$$
This in turn implies
$$\beta_i^\star-\beta_i(\theta_{t+1}) = (1-\eta)\big(\beta_i^\star-\beta_i(\theta_t)\big) \;\Longrightarrow\; \beta_i^\star-\beta_i(\theta_t) = (1-\eta)^t\big(\beta_i^\star-\beta_i(\theta_0)\big).$$
For $i > k$, we have
$$\beta_i(\theta_{t+1}) = (1-\eta)\beta_i(\theta_t) \;\Longrightarrow\; \beta_i(\theta_t) = (1-\eta)^t\beta_i(\theta_0),$$
which completes the proof. □
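The closed form above is easy to verify numerically; a minimal sketch mirroring the kernel-regression experiment of Appendix B.1 (dimensions, initialization, and step size follow that setup; the number of iterations is our choice):

```python
import numpy as np

d, k, eta, T = 20, 4, 0.4, 50
theta_star = np.zeros(d)
theta_star[:k] = [10.0, 5.0, 3.0, 1.0]
theta = np.full(d, 5e-7)            # theta_0 = 5e-7 * ones

traj = [theta.copy()]
for _ in range(T):
    theta = theta - eta * (theta - theta_star)   # GD on 0.5*||theta - theta*||^2
    traj.append(theta.copy())

# Closed form from Proposition 2: theta* - theta_t = (1-eta)^t (theta* - theta_0).
t = 30
pred = theta_star - (1 - eta) ** t * (theta_star - traj[0])
```

Every coordinate contracts by the same factor $(1-\eta)$ per step, which is exactly the "same rate for all coefficients" behavior shown in Figure 4a.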

E.2 PROOF OF THEOREM 2

Due to our choice of initial point $\theta_0 = \alpha\mathbf{1}$ with $\alpha \lesssim |\beta_k^\star|$, we have $|\beta_i^\star-\beta_i(\theta_0)| \le 2|\beta_i^\star|$. Hence, by Proposition 2, we have
$$|\beta_i^\star-\beta_i(\theta_t)| = (1-\eta)^t|\beta_i^\star-\beta_i(\theta_0)| \le 2(1-\eta)^t|\beta_i^\star|, \quad i\in S, \qquad |\beta_i(\theta_t)| \le (1-\eta)^t\alpha, \quad i\in E.$$
Therefore, to prove $|\beta_i^\star-\beta_i(\theta_t)| \le \frac{\alpha}{\sqrt{k}}$ for $i\in S$, it suffices to have
$$2(1-\eta)^t|\beta_i^\star| \le \frac{\alpha}{\sqrt{k}} \;\Longrightarrow\; t \gtrsim \frac{1}{\eta}\log\Big(\frac{k|\theta_i^\star|}{\alpha}\Big).$$
On the other hand, to ensure $|\beta_i(\theta_t)| \le \frac{\alpha}{\sqrt{d}}$ for $i\in E$, it suffices to have
$$(1-\eta)^t\alpha \le \frac{\alpha}{\sqrt{d}} \;\Longrightarrow\; t \gtrsim \frac{1}{\eta}\log(d).$$
Recall that $\alpha \lesssim \frac{k|\theta_k^\star|}{d}$ and $|\theta_1^\star| \ge |\theta_2^\star| \ge \cdots \ge |\theta_k^\star| > 0$. Therefore, within $T = O\big(\frac{1}{\eta}\log\frac{k|\theta_1^\star|}{\alpha}\big)$ iterations, we have
$$\|\theta_T-\theta^\star\| \le \sqrt{\sum_{i=1}^{k}\frac{\alpha^2}{k} + \sum_{i=k+1}^{d}\frac{\alpha^2}{d}} \le \sqrt{2}\,\alpha,$$
which completes the proof. □

F PROOFS FOR SYMMETRIC MATRIX FACTORIZATION

F.1 INITIALIZATION

We start by proving that both the lower bound on the signals at $\theta_0$ and the upper bound on the energy at $\theta_0$ are satisfied with high probability. Recall that each element of $U_0$ is drawn from $N(0,\alpha^2)$. The following proposition characterizes the upper and lower bounds for the different coefficients $\beta_{ij}(U_0)$.

Proposition 5 (Initialization). With probability at least $1-e^{-\Omega(r')}$, we have
$$\beta_{ii}(U_0) = \big\langle z_iz_i^\top, U_0U_0^\top\big\rangle \ge \tfrac{1}{4}r'\alpha^2 \quad \text{for all } 1\le i\le r, \qquad |\beta_{ij}(U_0)| = \big|\big\langle z_iz_j^\top, U_0U_0^\top\big\rangle\big| \le 4\log(d)\,r'\alpha^2 \quad \text{for all } i\ne j \text{ or } i,j > r.$$

In light of the above equality and the orthogonality of $\{z_i\}_{i\in[d]}$, the RHS of Equation 55 can be written in terms of $\beta_{kl}(U_t)$, $1\le k,l\le d$. In particular,
$$\big\langle (U_tU_t^\top-M^\star)U_tU_t^\top,\, z_iz_j^\top\big\rangle = \Big\langle \Big(\sum_{k,l}\big(\beta_{kl}(U_t)-\beta_{kl}^\star\big)z_kz_l^\top\Big)\Big(\sum_{k,l}\beta_{kl}(U_t)z_kz_l^\top\Big),\, z_iz_j^\top\Big\rangle = \sum_k\big(\beta_{ik}(U_t)-\beta_{ik}^\star\big)\beta_{kj}(U_t) = \big(\beta_{ii}(U_t)-\sigma_i\big)\beta_{ij}(U_t) + \sum_{k\ne i}\big(\beta_{ik}(U_t)-\beta_{ik}^\star\big)\beta_{kj}(U_t). \quad (56)$$
The other terms in Equation 55 can be written in terms of $\beta_{ij}(U_t)$ in an identical fashion. Substituting these derivations back into Equation 55, we obtain
$$\beta_{ij}(U_{t+1}) = \big(1+\eta(\sigma_i+\sigma_j)\big)\beta_{ij}(U_t) - 2\eta\sum_k\beta_{ik}(U_t)\beta_{kj}(U_t) + \eta^2\Big(\sum_{k,l}\beta_{ik}(U_t)\beta_{kl}(U_t)\beta_{lj}(U_t) - (\sigma_i+\sigma_j)\sum_k\beta_{ik}(U_t)\beta_{kj}(U_t) + \sigma_i\sigma_j\beta_{ij}(U_t)\Big). \quad (57)$$
Note that the above equality holds for any $1\le i,j\le r'$. In particular, for $i=j$, we further have
$$\beta_{ii}(U_{t+1}) = \big(1+2\eta(\sigma_i-\beta_{ii}(U_t))\big)\beta_{ii}(U_t) - 2\eta\sum_{j\ne i}\beta_{ij}^2(U_t) + \eta^2\Big(\sum_{j,k}\beta_{ij}(U_t)\beta_{ik}(U_t)\beta_{jk}(U_t) - 2\sigma_i\sum_j\beta_{ij}^2(U_t) + \sigma_i^2\beta_{ii}(U_t)\Big),$$
which completes the proof.
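The recursion in Equation 57 is an exact algebraic identity for one GD step, so it can be checked numerically against a direct matrix update; a small numpy sketch (toy dimensions and a random, not small, initialization are our choices — the identity holds for any $U$):

```python
import numpy as np

rng = np.random.default_rng(3)
d, r, eta = 6, 2, 0.05
Q, _ = np.linalg.qr(rng.standard_normal((d, d)))   # orthonormal basis {z_i}
sig = np.zeros(d)
sig[:r] = [3.0, 1.0]
M_star = Q @ np.diag(sig) @ Q.T

U = 0.1 * rng.standard_normal((d, d))

def betas(U):
    """beta_ij(U) = <z_i z_j^T, U U^T> for all i, j, as a d x d matrix."""
    return Q.T @ (U @ U.T) @ Q

# One GD step on (1/4)*||UU^T - M*||_F^2, then the coefficients directly...
B = betas(U)
U_next = U - eta * (U @ U.T - M_star) @ U
lhs = betas(U_next)

# ...versus the coefficient recursion of Equation 57.
S = sig[:, None] + sig[None, :]                    # S_ij = sigma_i + sigma_j
rhs = ((1 + eta * S) * B - 2 * eta * B @ B
       + eta**2 * (B @ B @ B - S * (B @ B) + np.outer(sig, sig) * B))
```

The two sides agree to machine precision, confirming that the GD dynamics are fully captured by the basis coefficients in the eigenbasis of $M^\star$.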

F.3 PROOFS OF PROPOSITION 3 AND THEOREM 3

To streamline the presentation, we prove Proposition 3 and Theorem 3 simultaneously. The main idea behind our proof technique is to divide the solution trajectory into $r$ substages: in Substage $i$, the basis coefficient $\beta_{ii}(U_t)$ converges linearly to $\sigma_i$ while all the remaining coefficients remain almost unchanged. More precisely, suppose that Substage $i$ lasts from iteration $t_{i,s}$ to $t_{i,e}$. We will show that $\beta_{ii}(U_{t_{i,e}}) \approx \sigma_i$ and $\beta_{ij}(U_{t_{i,e}}) \approx \beta_{ij}(U_{t_{i,s}})$. Recall that $\gamma = \min_{1\le i\le r}\{\sigma_i-\sigma_{i+1}\}$ is the eigengap of the true model, which we assume is strictly positive.

Substage 1. In the first substage, we show that $\beta_{11}(U_t)$ approaches $\sigma_1$ while $|\beta_{ij}(U_t)|$, $i,j\ge 2$, remains in the order of $\mathrm{poly}(\alpha)$ within $T_1 = O\big(\frac{1}{\eta\sigma_1}\log\frac{\sigma_1}{\alpha}\big)$ iterations. To formalize this idea, we further divide this substage into two phases. In the first phase (which we refer to as the warm-up phase), we show that $\beta_{11}(U_t)$ quickly dominates the remaining terms $\beta_{ij}(U_t)$, $(i,j)\ne(1,1)$, within $O\big(\frac{1}{\eta\gamma}\log\log(d)\big)$ iterations. This is shown in the following lemma.

Lemma 1 (Warm-up phase). Suppose that the initial point satisfies Equation 45 and Equation 46. Then, within $O\big(\frac{1}{\eta\gamma}\log\log(d)\big)$ iterations, we have $|\beta_{ij}(U_t)| \le \beta_{11}(U_t) \lesssim r'\alpha^2(\log d)^{1+\sigma_1/\gamma}$.

Proof. We use an inductive argument. Due to our choice of the initial point, we have $|\beta_{ij}(U_0)| \lesssim r'\alpha^2(\log d)^{1+\sigma_1/\gamma}$ for $1\le i,j\le d$. Now, suppose that at time $t$, we have $|\beta_{ij}(U_t)| \lesssim r'\alpha^2(\log d)^{1+\sigma_1/\gamma}$ for $1\le i,j\le d$. Then, by Proposition 6, we have
$$\beta_{11}(U_{t+1}) \ge \Big(1+2\eta\sigma_1 - O\big(\eta\,dr'\alpha^2(\log d)^{1+\sigma_1/\gamma}\big)\Big)\beta_{11}(U_t).$$
Similarly, for the remaining coefficients $\beta_{ij}(U_t)$, $(i,j)\ne(1,1)$, we have
$$|\beta_{ij}(U_{t+1})| \le \Big(1+\eta\big(\sigma_1+\sigma_2+O\big(dr'\alpha^2(\log d)^{1+\sigma_1/\gamma}\big)\big)\Big)|\beta_{ij}(U_t)|.$$
Note that the step-size satisfies $\eta \lesssim \frac{1}{\sigma_1}$, and $\sigma_1-\sigma_2 \ge \gamma \gtrsim dr'\alpha^2(\log d)^{1+\sigma_1/\gamma}$. Hence, we have
$$\frac{\beta_{11}(U_{t+1})}{|\beta_{ij}(U_{t+1})|} \ge \frac{1+2\eta\sigma_1 - O\big(\eta dr'\alpha^2(\log d)^{1+\sigma_1/\gamma}\big)}{1+\eta\big(\sigma_1+\sigma_2+O\big(dr'\alpha^2(\log d)^{1+\sigma_1/\gamma}\big)\big)}\cdot\frac{\beta_{11}(U_t)}{|\beta_{ij}(U_t)|} = \Bigg(1+\frac{\eta\big(\sigma_1-\sigma_2-O\big(dr'\alpha^2(\log d)^{1+\sigma_1/\gamma}\big)\big)}{1+\eta\big(\sigma_1+\sigma_2+O\big(dr'\alpha^2(\log d)^{1+\sigma_1/\gamma}\big)\big)}\Bigg)\frac{\beta_{11}(U_t)}{|\beta_{ij}(U_t)|} \ge \Bigg(1+\frac{\eta\gamma}{1+0.5\eta(\sigma_1+\sigma_2)}\Bigg)\frac{\beta_{11}(U_t)}{|\beta_{ij}(U_t)|} \ge \big(1+0.5\eta\gamma\big)\frac{\beta_{11}(U_t)}{|\beta_{ij}(U_t)|}.$$
This further implies
$$\frac{\beta_{11}(U_t)}{|\beta_{ij}(U_t)|} \ge (1+0.5\eta\gamma)^t\,\frac{\beta_{11}(U_0)}{|\beta_{ij}(U_0)|}.$$
On the other hand, Equation 49 and Equation 50 imply that $\frac{\beta_{11}(U_0)}{|\beta_{ij}(U_0)|} \ge \frac{1}{16\log(d)}$. Hence, within $O\big(\frac{1}{\eta\gamma}\log\log(d)\big)$ iterations, we have $\beta_{11}(U_t) \ge |\beta_{ij}(U_t)|$ for all $(i,j)\ne(1,1)$. Moreover, during this phase,
$$|\beta_{ij}(U_t)| \le \beta_{11}(U_t) \le r'\alpha^2\log(d)\,(1+2\eta\sigma_1)^{O\left(\frac{1}{\eta\gamma}\log\log(d)\right)} \lesssim r'\alpha^2(\log d)^{1+\sigma_1/\gamma},$$
which completes the proof. □

After the warm-up phase, we show that $\beta_{11}(U_t)$ quickly approaches $\sigma_1$ while the remaining coefficients remain small.

Lemma 2 (Fast growth). After the warm-up phase followed by $O\big(\frac{1}{\eta\sigma_1}\log\frac{\sigma_1}{\alpha}\big)$ iterations, we have $0.99\sigma_1 \le \beta_{11}(U_t) \le \sigma_1$. Moreover, for $|\beta_{ij}(U_t)|$, $(i,j)\ne(1,1)$, we have $|\beta_{ij}(U_t)| \lesssim \sigma_1\big(r'\log(d)\alpha^2\big)^{\frac{2\sigma_1-\sigma_i-\sigma_j}{2\sigma_1}}$.

Before providing the proof of Lemma 2, we analyze an intermediate logistic map which, as will be shown later, closely resembles the dynamic of $\beta_{ij}(U_t)$:
$$x_{t+1} = (1+\eta\sigma-\eta x_t)\,x_t, \qquad x_0 = \alpha. \tag{logistic map}$$
The following two lemmas characterize the dynamic of a single logistic map, as well as the dynamic of the ratio between two different logistic maps.

Lemma 3 (Iteration complexity of the logistic map). Suppose that $\alpha \le \varepsilon \le 0.1\sigma$. Then, for the logistic map, we have $x_T \ge \sigma-\varepsilon$ within $T = \frac{1}{\log(1+\eta\sigma)}\big(\log(4\sigma/\alpha)+\log(4\sigma/\varepsilon)\big)$ iterations.

Lemma 4 (Separation between two logistic maps).
Let $\sigma_1, \sigma_2$ be such that $\sigma_1 > \sigma_2 > 0$, and
$$x_{t+1} = (1+\eta\sigma_1-\eta x_t)\,x_t, \quad x_0 = \alpha; \qquad y_{t+1} = (1+\eta\sigma_2-\eta y_t)\,y_t, \quad y_0 = \alpha.$$
Then, within $T = \frac{1}{\log(1+\eta\sigma_1)}\log\big(\frac{16\sigma_1^2}{\varepsilon\alpha}\big)$ iterations, we have
$$\sigma_1-\varepsilon \le x_T \le \sigma_1, \qquad y_T \le \frac{16\sigma_1^2}{\varepsilon}\,\alpha^{\frac{\sigma_1-\sigma_2}{\sigma_1+\sigma_2}}.$$
The proofs of Lemmas 3 and 4 are deferred to Appendix F.4. We are now ready to provide the proof of Lemma 2.

Proof of Lemma 2. Similar to the proof of Lemma 1, we use an inductive argument. Suppose that $t_0$ is the time at which the second phase starts. According to Lemma 1, we have
$$|\beta_{ij}(U_{t_0})| \le \beta_{11}(U_{t_0}) \lesssim r'\alpha^2(\log d)^{1+\sigma_1/\gamma} \lesssim \sigma_1\big(r'\log(d)\alpha^2\big)^{\frac{2\sigma_1-\sigma_i-\sigma_j}{2\sigma_1}}.$$
Therefore, the base case of our induction holds. Next, suppose that at some time $t$ within the second phase, we have $|\beta_{ij}(U_t)| \lesssim \sigma_1\alpha^{\frac{2\sigma_1-\sigma_i-\sigma_j}{2\sigma_1}}$ for all $(i,j)\ne(1,1)$. Our goal is to show that $|\beta_{ij}(U_{t+1})| \lesssim \sigma_1\alpha^{\frac{2\sigma_1-\sigma_i-\sigma_j}{2\sigma_1}}$. To this goal, we consider two cases.

Case I: $i\le r$ or $j\le r$, and $(i,j)\ne(1,1)$. We have
$$|\beta_{ij}(U_{t+1})| \le \big(1+\eta(\sigma_i+\sigma_j-2\beta_{ii}(U_t)-2\beta_{jj}(U_t))\big)|\beta_{ij}(U_t)| + 2\eta\sum_{k\ne i,j}|\beta_{ik}(U_t)\beta_{jk}(U_t)| + \eta^2\Big|\sum_{k,l}\beta_{ik}(U_t)\beta_{kl}(U_t)\beta_{lj}(U_t) - (\sigma_i+\sigma_j)\sum_k\beta_{ik}(U_t)\beta_{kj}(U_t) + \sigma_i\sigma_j\beta_{ij}(U_t)\Big| \overset{(a)}{\le} \Big(1+\eta\big(\sigma_i+\sigma_j+\eta\sigma_i\sigma_j+O\big(d\alpha^{\frac{\gamma}{\sigma_1}}\big)\big)\Big)|\beta_{ij}(U_t)|.$$
Here, in (a) we used the assumption $|\beta_{ij}(U_t)| \lesssim \sigma_1\alpha^{\frac{2\sigma_1-\sigma_i-\sigma_j}{2\sigma_1}} \lesssim \sigma_1\alpha^{\frac{\gamma}{2\sigma_1}}$ for all $(i,j)\ne(1,1)$. Hence, by Lemma 4, we have $|\beta_{ij}(U_{t+1})| \lesssim \sigma_1\alpha^{\frac{2\sigma_1-\sigma_i-\sigma_j}{2\sigma_1}}$, where $1\le i\le r$ or $1\le j\le r$.

Case II: $i,j\ge r+1$. For $\beta_{ij}(U_t)$ with $i,j\ge r+1$, the dynamic is characterized by
$$|\beta_{ij}(U_{t+1})| \le \big(1-\eta(2\beta_{ii}(U_t)+2\beta_{jj}(U_t))\big)|\beta_{ij}(U_t)| + 2\eta\sum_{k\ne i,j}|\beta_{ik}(U_t)\beta_{jk}(U_t)| + \eta^2\sum_{k,l}|\beta_{ik}(U_t)\beta_{kl}(U_t)\beta_{lj}(U_t)| \le \Big(1+\eta\,O\big(d\alpha^{\frac{\gamma}{\sigma_1}}\big)\Big)|\beta_{ij}(U_t)|.$$
Hence, for $t \lesssim \frac{1}{\eta\sigma_1}\log\frac{\sigma_1}{\alpha}$, we have
$$|\beta_{ij}(U_t)| \le \Big(1+\eta\,O\big(d\alpha^{\frac{\gamma}{\sigma_1}}\big)\Big)^{O\left(\frac{1}{\eta\sigma_1}\log\frac{\sigma_1}{\alpha}\right)}|\beta_{ij}(U_0)| \lesssim |\beta_{ij}(U_0)|,$$
since we assume $\alpha \lesssim \big(\frac{\sigma_1}{d}\big)^{\sigma_1/\gamma}$.
This completes our inductive proof of $|\beta_{ij}(U_t)| \lesssim \sigma_1\alpha^{\frac{2\sigma_1-\sigma_i-\sigma_j}{2\sigma_1}}$ for all $(i,j)\ne(1,1)$ in the second phase. Finally, we turn to $\beta_{11}(U_t)$. One can write
$$\beta_{11}(U_{t+1}) = \big(1+2\eta(\sigma_1-\beta_{11}(U_t))\big)\beta_{11}(U_t) - 2\eta\sum_{j\ne 1}\beta_{1j}^2(U_t) + \eta^2\Big(\sum_{j,k}\beta_{1j}(U_t)\beta_{1k}(U_t)\beta_{jk}(U_t) - 2\sigma_1\sum_j\beta_{1j}^2(U_t) + \sigma_1^2\beta_{11}(U_t)\Big) \overset{(a)}{\ge} \Big(1+2\eta\big(\sigma_1+0.5\eta\sigma_1^2-\beta_{11}(U_t)-O\big(d\alpha^{\frac{\gamma}{2\sigma_1}}\big)\big)\Big)\beta_{11}(U_t) \overset{(b)}{\ge} \Big(1+2\eta\big(0.9995\sigma_1+0.5\eta\sigma_1^2-\beta_{11}(U_t)\big)\Big)\beta_{11}(U_t).$$
Here, in (a) we used the fact that $|\beta_{ij}(U_t)| \lesssim \sigma_1\big(r'\log(d)\alpha^2\big)^{\frac{2\sigma_1-\sigma_i-\sigma_j}{2\sigma_1}} \lesssim \sigma_1\alpha^{\frac{\gamma}{2\sigma_1}}$ for all $(i,j)\ne(1,1)$. In (b), we used the assumption that $\alpha \lesssim \big(\frac{\sigma_r}{d}\big)^{2\sigma_1/\gamma}$. The above inequality, together with Lemma 3, entails that within $\frac{1}{\log(1+\eta\sigma_1)}\log\frac{1600\sigma_1}{\alpha} = O\big(\frac{1}{\eta\sigma_1}\log\frac{\sigma_1}{\alpha}\big)$ iterations, we have $\beta_{11}(U_t) \ge 0.99\sigma_1$. This completes the proof of Lemma 2 and marks the end of Substage 1. □

Next, we move on to Substage 2.

Substage 2. In Substage 2, we show that the second component $\beta_{22}(U_t)$ converges to $\sigma_2$ within $O\big(\frac{1}{\eta\sigma_2}\log\frac{\sigma_2}{\alpha}\big)$ iterations while the other coefficients remain small. To this goal, we first study the one-step dynamic of $\beta_{22}(U_t)$:
$$\beta_{22}(U_{t+1}) = \big(1+2\eta(\sigma_2-\beta_{22}(U_t))\big)\beta_{22}(U_t) - 2\eta\sum_{j\ne 2}\beta_{2j}^2(U_t) + \eta^2\Big(\sum_{j,k}\beta_{2j}(U_t)\beta_{2k}(U_t)\beta_{jk}(U_t) - 2\sigma_2\sum_j\beta_{2j}^2(U_t) + \sigma_2^2\beta_{22}(U_t)\Big). \quad (76)$$
In contrast with the dynamic of $\beta_{11}(U_t)$, not all the coefficients $\beta_{ij}(U_t)$ with $i=2$ or $j=2$ are smaller than $\beta_{22}(U_t)$ at the beginning of Substage 2. In particular, the basis coefficient $|\beta_{12}(U_t)|$ may be much larger than $\beta_{22}(U_t)$ at the beginning of Substage 2. To see this, note that, according to Equation 68, we have $\beta_{12}(U_t) \lesssim \alpha^{\frac{\sigma_1-\sigma_2}{2\sigma_1}}$ and $\beta_{22}(U_t) \lesssim \alpha^{\frac{2(\sigma_1-\sigma_2)}{2\sigma_1}}$. Hence, it may be possible to have $|\beta_{12}(U_t)| \asymp \sqrt{\beta_{22}(U_t)} \gg \beta_{22}(U_t)$. Therefore, the term $2\eta\beta_{12}^2(U_t)$ in Equation 76 must be handled with extra care.
Note that if we can show $\sigma_2\beta_{22}(U_t)\gg\beta_{12}^2(U_t)$, then $2\eta\beta_{12}^2(U_t)$ can be combined with the first term on the RHS of Equation 76, and the argument made in Substage 1 can be repeated to complete the proof of Substage 2. However, our provided bound in Equation 68 only implies $\beta_{12}^2(U_t)\asymp\beta_{22}(U_t)$. Therefore, we need to provide a tighter analysis to show that $\beta_{12}^2(U_t)\ll\sigma_2\beta_{22}(U_t)$ along the trajectory. Upon controlling $\beta_{12}^2(U_t)$, we can then show the convergence of $\beta_{22}(U_t)$ similar to our analysis for $\beta_{11}(U_t)$ in Substage 1.

To control the behavior of $\beta_{12}^2(U_t)$, we study the ratio $\omega(t):=\frac{\beta_{12}^2(U_t)}{\beta_{22}(U_t)}$. We will show that $\omega(t)\ll\sigma_2$ along the trajectory. To this goal, we will show that $\omega(t)$ can only increase for $O\!\left(\frac{1}{\eta\sigma_1}\log\frac{\sigma_1}{\alpha}\right)$ iterations. Therefore, its maximum along the solution trajectory happens at $T=O\!\left(\frac{1}{\eta\sigma_1}\log\frac{\sigma_1}{\alpha}\right)$, and by bounding this maximum, we can show that $\omega(t)$ remains small throughout the solution trajectory. First, at the initial point, we have $\omega(0)\le 64\log^2(d)r'\alpha^2$, which satisfies our claim. We next provide an upper bound for $\omega(t+1)$ based on $\omega(t)$. Note that
\[
\omega(t+1)\le\frac{\big(1+\eta(\sigma_1+\sigma_2-2\beta_{11}(U_t)+\mathrm{poly}(\alpha))\big)^2}{\big(1+\eta(\sigma_2-\mathrm{poly}(\alpha))\big)^2}\cdot\frac{\beta_{12}^2(U_t)}{\beta_{22}(U_t)}
=\left(1+\frac{\eta(\sigma_1-2\beta_{11}(U_t)+\mathrm{poly}(\alpha))}{1+\eta(\sigma_2-\mathrm{poly}(\alpha))}\right)^2\omega(t)
\]
\[
\le\big(1+\eta(\sigma_1-2\beta_{11}(U_t))-0.5\eta^2\sigma_1\sigma_2\big)^2\,\omega(t)
\le\Big((1+\eta\sigma_1)^2-2\eta\beta_{11}(U_t)-0.4\eta^2\sigma_1\sigma_2\Big)\omega(t).
\]
Due to the first inequality, $\omega(t)$ can be increasing only until $\beta_{11}(U_t)\ge\frac{\sigma_1}{2}\pm\mathrm{poly}(\alpha)$. On the other hand, due to the dynamic of $\beta_{11}$ in Substage 1, we can show that $\beta_{11}(U_t)\ge\frac{\sigma_1}{2}\pm\mathrm{poly}(\alpha)$ in at most $T=O\!\left(\frac{1}{\eta\sigma_1}\log\frac{\sigma_1}{\alpha}\right)$ iterations. Therefore, $\omega(t)$ takes its maximum at $T=O\!\left(\frac{1}{\eta\sigma_1}\log\frac{\sigma_1}{\alpha}\right)$. On the other hand, we know that $\beta_{11}(U_t)$ satisfies
\[
\beta_{11}(U_{t+1})=\Big((1+\eta\sigma_1)^2-2\eta\beta_{11}(U_t)\pm\eta\,\mathrm{poly}(\alpha)\Big)\beta_{11}(U_t).
\]
Hence, we can bound $\omega(t)$ as
\[
\omega(T)\le\prod_{t=0}^{T-1}\Big((1+\eta\sigma_1)^2-2\eta\beta_{11}(U_t)-0.4\eta^2\sigma_1\sigma_2\Big)\,\omega(0)
\le\prod_{t=0}^{T-1}\frac{(1+\eta\sigma_1)^2-2\eta\beta_{11}(U_t)-0.4\eta^2\sigma_1\sigma_2}{(1+\eta\sigma_1)^2-2\eta\beta_{11}(U_t)\pm\eta\,\mathrm{poly}(\alpha)}\cdot\frac{\beta_{11}(U_{t+1})}{\beta_{11}(U_t)}\,\omega(0)
\]
\[
\overset{(a)}{=}\prod_{t=0}^{T-1}\big(1-\Omega(\eta^2\sigma_1\sigma_2)\big)\frac{\beta_{11}(U_{t+1})}{\beta_{11}(U_t)}\,\omega(0)
\le 256\big(1-\Omega(\eta^2\sigma_1\sigma_2)\big)^T\beta_{11}(U_T)\log^2(d).
\]
Here in (a) we used the fact that $\alpha\lesssim\left(\frac{\eta\sigma_2}{r}\right)^{\sigma_1/\gamma}$. Hence, we have
\[
\omega(T)\lesssim\big(1-\Omega(\eta^2\sigma_1\sigma_2)\big)^T\sigma_1\log^2(d)\lesssim\left(\frac{\alpha}{\sigma_1}\right)^{\Omega(\eta\sigma_2)}\sigma_1\log^2(d).
\]
Due to our assumption $\alpha\lesssim\big(\kappa\log^2(d)\big)^{-\Omega\left(\frac{1}{\eta\sigma_r}\right)}$, we conclude that $\omega(t)\le 0.01\sigma_2$. Therefore, Equation 76 can be lower bounded as
\[
\beta_{22}(U_{t+1})\ge\big(1+1.99\eta(\sigma_2-\beta_{22}(U_t))\big)\beta_{22}(U_t)-2\eta\sum_{j>2}\beta_{2j}^2(U_t)+\eta^2\Big(\sum_{j,k}\beta_{2j}(U_t)\beta_{2k}(U_t)\beta_{jk}(U_t)-2\sigma_2\sum_j\beta_{2j}^2(U_t)+\sigma_2^2\beta_{22}(U_t)\Big).
\]
The rest of the proof is a line-by-line reconstruction of Substage 1 and hence omitted for brevity.

Substage $k$, $3\le k\le r$. Via an identical argument to Substage 2, we can show that for each Substage $k$ with $3\le k\le r$, we have
\[
0.99\sigma_k\le\beta_{kk}(U_t)\le\sigma_k
\]
within $T_k=O\!\left(\frac{1}{\eta\sigma_k}\log\frac{\sigma_1}{\alpha}\right)$ iterations. This completes the proof of the first statement of Proposition 3.

To prove the second statement of Proposition 3 as well as Theorem 3, we next control the residual terms. First, we consider the residual term $\beta_{ij}(U_t)$ where either $i\le r$ or $j\le r$. Note that $\beta_{ij}(U_t)=\beta_{ji}(U_t)$, and hence we can assume $i\le r$ without loss of generality. We will show that $\beta_{ij}(U_t)$ decreases linearly once the corresponding signal $\beta_{ii}(U_t)$ converges to the vicinity of $\sigma_i$. To this goal, it suffices to control the largest component $\beta_{\max}(U_t):=\max_{i\neq j,\,i\le r}|\beta_{ij}(U_t)|$. Without loss of generality, we assume that the index $(i,j)$ attains the maximum at time $t$, i.e., $\beta_{\max}(U_t)=|\beta_{ij}(U_t)|$.
One can write
\[
|\beta_{ij}(U_{t+1})|\le\big(1+\eta(\sigma_i+\sigma_j-2\beta_{ii}(U_t)-2\beta_{jj}(U_t))\big)|\beta_{ij}(U_t)|+2\eta\sum_{k\neq i,j}|\beta_{ik}(U_t)\beta_{jk}(U_t)|
\]
\[
\qquad+\eta^2\Big|\sum_{k,l}\beta_{ik}(U_t)\beta_{kl}(U_t)\beta_{lj}(U_t)-(\sigma_i+\sigma_j)\sum_k\beta_{ik}(U_t)\beta_{kj}(U_t)+\sigma_i\sigma_j\beta_{ij}(U_t)\Big|
\overset{(a)}{\le}\big(1-0.9\eta(\sigma_i+\sigma_j)\big)|\beta_{ij}(U_t)|\le(1-0.9\eta\sigma_r)|\beta_{ij}(U_t)|. \tag{78}
\]
Here in (a) we used the fact that $0.99\sigma_i\le\beta_{ii}(U_t)\le\sigma_i$ for $1\le i\le r$ and the fact that $\beta_{\max}(U_t)=|\beta_{ij}(U_t)|$. Hence, we conclude that $\beta_{\max}(U_{t+1})\le(1-0.9\eta\sigma_r)\beta_{\max}(U_t)$. Therefore, within an additional $O\!\left(\frac{1}{\eta\sigma_r}\log\frac{1}{\alpha}\right)$ iterations, we have $|\beta_{ij}(U_t)|\le r'\log(d)\alpha^2$ for all $(i,j)$ such that $i\le r$, $i\neq j$. The remaining residual terms, i.e., those coefficients $\beta_{ij}(U_t)$ for which $i,j>r$, can be bounded via the same approach as in Case II of Substage 1. In particular, we can show that $|\beta_{ij}(U_t)|\lesssim|\beta_{ij}(U_0)|\lesssim r'\log(d)\alpha^2$ for all $i,j>r$. For brevity, we omit this step. This completes the proof of the second statement of Proposition 3.

Finally, to prove Theorem 3, we show that once $|\beta_{ij}(U_t)|\le r'\log(d)\alpha^2$ for all $(i,j)\neq(k,k)$, $1\le k\le r$, the signals $\beta_{kk}(U_t)$ further converge to $\sigma_k\pm O(\alpha^2)$ within $O\!\left(\frac{1}{\eta\sigma_k}\log\frac{\sigma_k}{\alpha}\right)$ iterations. To see this, we simplify the dynamic of $\beta_{kk}(U_t)$ as
\[
\beta_{kk}(U_{t+1})=\big(1+2\eta(\sigma_k-\beta_{kk}(U_t))\big)\beta_{kk}(U_t)-2\eta\sum_{j\neq k}\beta_{kj}^2(U_t)+\eta^2\Big(\sum_{j,l}\beta_{kj}(U_t)\beta_{kl}(U_t)\beta_{jl}(U_t)-2\sigma_k\sum_j\beta_{kj}^2(U_t)+\sigma_k^2\beta_{kk}(U_t)\Big)
\]
\[
=\big(1+\eta(\sigma_k-\beta_{kk}(U_t))\big)^2\beta_{kk}(U_t)-O\big(\eta d r'^2\log^2(d)\alpha^4\big),
\]
which leads to
\[
\sigma_k-\beta_{kk}(U_{t+1})=\Big(1-\eta\beta_{kk}(U_t)\big(2+\eta(\sigma_k-\beta_{kk}(U_t))\big)\Big)(\sigma_k-\beta_{kk}(U_t))+O\big(\eta d r'^2\log^2(d)\alpha^4\big)\le(1-1.98\eta\sigma_k)(\sigma_k-\beta_{kk}(U_t))+O\big(\eta d r'^2\log^2(d)\alpha^4\big). \tag{80}
\]
Hence, within an additional $O\!\left(\frac{1}{\eta\sigma_k}\log\frac{\sigma_k}{\alpha}\right)$ iterations, we have $|\sigma_k-\beta_{kk}(U_t)|=O\!\left(\frac{1}{\sigma_k}d r'^2\log^2(d)\alpha^4\right)$ for every $1\le k\le r$. In conclusion, we have
\[
\big\|U_T U_T^\top-M^\star\big\|_F^2=\sum_{i,j=1}^d\big(\beta_{ij}(U_T)-\beta^\star_{ij}\big)^2\le r'^2\log^2(d)d^2\alpha^4
\]
within $O\!\left(\frac{1}{\eta\sigma_r}\log\frac{\sigma_r}{\alpha}\right)$ iterations. This completes the proof of Theorem 3. □
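The incremental learning behavior established above is easy to reproduce numerically. The following is a minimal sketch (all dimensions, spectra, and constants are illustrative choices, not taken from the paper): it runs GD on the symmetric matrix factorization loss $\frac{1}{4}\|UU^\top - M^\star\|_F^2$ from a small random initialization and checks that the diagonal basis coefficients $\beta_{kk}(U)$ reach $\sigma_k$ while the final error remains at the initialization scale.

```python
import numpy as np

# GD on symmetric matrix factorization: U_{t+1} = U_t + eta*(M* - U U^T) U,
# starting from a small random initialization of scale alpha.  We track the
# basis coefficients beta_ij(U) = z_i^T U U^T z_j; the diagonal entries
# beta_kk should converge to sigma_k while the residual error stays small.
rng = np.random.default_rng(0)
d, r, rp = 6, 2, 6                     # ambient dim, true rank, search rank r'
sigmas = np.array([1.0, 0.5])          # sigma_1 > sigma_2
alpha, eta = 1e-5, 0.05

Z = np.linalg.qr(rng.standard_normal((d, d)))[0]   # orthonormal z_1, ..., z_d
M_star = (Z[:, :r] * sigmas) @ Z[:, :r].T          # sum_k sigma_k z_k z_k^T

U = alpha * rng.standard_normal((d, rp))
for _ in range(6000):
    U = U + eta * (M_star - U @ U.T) @ U

B = Z.T @ (U @ U.T) @ Z                            # B[i, j] = beta_ij(U)
err = np.linalg.norm(U @ U.T - M_star, 'fro')
print(np.diag(B)[:r], err)
```

Tracking `np.diag(B)` across iterations (rather than only at the end) reproduces the sequential, almost-monotone rise of $\beta_{11}$ followed by $\beta_{22}$ described in the substage analysis.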

F.4 ANALYSIS OF THE LOGISTIC MAP

In this section, we provide the proofs of Lemmas 3 and 4.

UPPER BOUND OF ITERATION COMPLEXITY

Recall the logistic map
\[
x_{t+1}=(1+\eta\sigma-\eta x_t)x_t,\qquad x_0=\alpha. \tag{82}
\]
Here the initial value satisfies $0<\alpha\ll\sigma$. Vaskevicius et al. (2019) provide both upper and lower bounds for $x_t$ that follows the above logistic map. However, their bounds are not directly applicable to our setting. Hence, we need to develop a new proof for Lemma 3.

Proof of Lemma 3. We divide the dynamic into two stages: (a) $x_t\le\frac{1}{2}\sigma$, and (b) $x_t\ge\frac{1}{2}\sigma$.

Stage 1: $x_t\le\frac{1}{2}\sigma$. We consider $K=\big\lceil\log\big(\tfrac{1}{2}\sigma/\alpha\big)\big\rceil$ substages, where in each Substage $k$, we have $\alpha e^k\le x_t\le\alpha e^{k+1}$. Suppose that $t_k$ is the number of iterations in Substage $k$. One can write
\[
x_{t+1}=(1+\eta\sigma-\eta x_t)x_t\ge(1+\eta\sigma-\eta\alpha e^{k+1})x_t.
\]
Hence, it suffices to find the smallest $t=t_{\min}$ such that
\[
\alpha e^k(1+\eta\sigma-\eta\alpha e^{k+1})^t\ge\alpha e^{k+1}.
\]
Solving this inequality leads to
\[
t_{\min}=\frac{1}{\log(1+\eta\sigma-\eta\alpha e^{k+1})}.
\]
Based on the above equality, we provide an upper bound for $t_{\min}$:
\[
t_{\min}=\frac{1}{\log(1+\eta\sigma)+\log\big(1-\eta\alpha e^{k+1}/(1+\eta\sigma)\big)}
\overset{(a)}{\le}\frac{1}{\log(1+\eta\sigma)-\frac{\eta\alpha e^{k+1}}{1+\eta\sigma-\eta\alpha e^{k+1}}}
\le\frac{1+\frac{\eta\alpha e^{k+1}}{1+\eta\sigma-\eta\alpha e^{k+1}}}{\log(1+\eta\sigma)}
\overset{(b)}{\le}\frac{1+\eta\alpha e^{k+1}}{\log(1+\eta\sigma)}.
\]
Here in (a) we used the fact that $\log(1+x)\ge\frac{x}{1+x}$ for all $x>-1$, and in (b) we used the fact that $x_t\le\frac{1}{2}\sigma$. Hence, we have $t_k\le\frac{1+\eta\alpha e^{k+1}}{\log(1+\eta\sigma)}$. Therefore, the total iteration complexity of Stage 1 is upper bounded by
\[
T_1=\sum_{k=0}^{K-1}t_k\le\sum_{k=0}^{K-1}\frac{1+\eta\alpha e^{k+1}}{\log(1+\eta\sigma)}\le\frac{\big\lceil\log\big(\tfrac{1}{2}\sigma/\alpha\big)\big\rceil+\eta\sigma}{\log(1+\eta\sigma)}\le\frac{\log(4\sigma/\alpha)}{\log(1+\eta\sigma)}.
\]

Stage 2: $x_t\ge\frac{1}{2}\sigma$. In this stage, we rewrite Equation 82 as
\[
\sigma-x_{t+1}=(1-\eta x_t)(\sigma-x_t).
\]
Via a similar argument, we can show that within an additional $T_2=\frac{\log(4\sigma/\varepsilon)}{\log(1+\eta\sigma)}$ iterations, we have $x_t\ge\sigma-\varepsilon$. Hence, the total number of iterations is $T=T_1+T_2=\frac{1}{\log(1+\eta\sigma)}\big(\log(4\sigma/\alpha)+\log(4\sigma/\varepsilon)\big)$. This completes the proof of Lemma 3. □
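The iteration bound of Lemma 3 is straightforward to check numerically. The sketch below simulates the recursion and records when it first reaches $\sigma-\varepsilon$ (the parameter values $\sigma=1$, $\alpha=10^{-6}$, $\eta=0.1$, $\varepsilon=10^{-3}$ are illustrative, not from the paper):

```python
import numpy as np

# Logistic-map recursion from Lemma 3: x_{t+1} = (1 + eta*sigma - eta*x_t)*x_t,
# x_0 = alpha.  We check that x_t climbs from alpha to sigma - eps within
# T = (log(4*sigma/alpha) + log(4*sigma/eps)) / log(1 + eta*sigma) iterations.
sigma, alpha, eta, eps = 1.0, 1e-6, 0.1, 1e-3

T_bound = int(np.ceil((np.log(4 * sigma / alpha) + np.log(4 * sigma / eps))
                      / np.log(1 + eta * sigma)))

x, hit = alpha, None
for t in range(1, 2 * T_bound + 1):
    x = (1 + eta * sigma - eta * x) * x
    if hit is None and x >= sigma - eps:
        hit = t            # first iteration at which x_t >= sigma - eps

print(hit, T_bound)
```

Since the map is monotone and bounded by $\sigma$ for $x_0<\sigma$ and $\eta\sigma<1$, the trajectory stays in $[\alpha,\sigma]$ throughout.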

SEPARATION BETWEEN TWO INDEPENDENT SIGNALS

In this section, we show that there is a sharp separation between two logistic maps with signals $\sigma_1,\sigma_2$, provided that $\sigma_1\neq\sigma_2$. In particular, suppose that $\sigma_1-\sigma_2\ge\gamma>0$ and
\[
x_{t+1}=(1+\eta\sigma_1-\eta x_t)x_t,\quad x_0=\alpha,\qquad y_{t+1}=(1+\eta\sigma_2-\eta y_t)y_t,\quad y_0=\alpha.
\]

Proof of Lemma 4. By Lemma 3, we have $x_T\ge\sigma_1-\varepsilon$ within $T=\frac{1}{\log(1+\eta\sigma_1)}\log\frac{16\sigma_1^2}{\varepsilon\alpha}$ iterations. Therefore, it suffices to show that $y_t$ remains small for $t\le T$. To this goal, note that
\[
y_{t+1}=(1+\eta\sigma_2-\eta y_t)y_t\le(1+\eta\sigma_2)y_t\le(1+\eta\sigma_2)^{t+1}\alpha.
\]
Hence, we need to bound $\Gamma=(1+\eta\sigma_2)^T$. Taking the logarithm of both sides, we have
\[
\log(\Gamma)=T\log(1+\eta\sigma_2)=\log\left(\frac{16\sigma_1^2}{\varepsilon\alpha}\right)\frac{\log(1+\eta\sigma_2)}{\log(1+\eta\sigma_1)}.
\]
Now, we provide a lower bound for the ratio $\log(1+\eta\sigma_1)/\log(1+\eta\sigma_2)$:
\[
\frac{\log(1+\eta\sigma_1)}{\log(1+\eta\sigma_2)}=1+\frac{\log\Big(1+\frac{\eta(\sigma_1-\sigma_2)}{1+\eta\sigma_2}\Big)}{\log(1+\eta\sigma_2)}\ge 1+\frac{\eta(\sigma_1-\sigma_2)}{1+\eta\sigma_2}\Big/\left(\Big(1+\frac{\eta(\sigma_1-\sigma_2)}{1+\eta\sigma_2}\Big)\eta\sigma_2\right)\overset{(a)}{\ge}1+\frac{\sigma_1-\sigma_2}{2\sigma_2}=\frac{\sigma_1+\sigma_2}{2\sigma_2},
\]
where (a) follows from the assumption $\eta\le\frac{1}{4\sigma_1}$. Therefore, we have
\[
\Gamma\le\exp\left(\log\left(\frac{16\sigma_1^2}{\varepsilon\alpha}\right)\frac{2\sigma_2}{\sigma_1+\sigma_2}\right)=\left(\frac{16\sigma_1^2}{\varepsilon\alpha}\right)^{2\sigma_2/(\sigma_1+\sigma_2)},
\]
which implies that
\[
y_T\le\alpha\left(\frac{16\sigma_1^2}{\varepsilon\alpha}\right)^{2\sigma_2/(\sigma_1+\sigma_2)}\le\frac{16\sigma_1^2}{\varepsilon}\,\alpha^{\frac{\sigma_1-\sigma_2}{\sigma_1+\sigma_2}}.
\]
This completes the proof of Lemma 4. □
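The separation of Lemma 4 can likewise be observed numerically. The sketch below (with illustrative values $\sigma_1=1$, $\sigma_2=0.5$, $\alpha=10^{-8}$) runs both recursions for the $T$ prescribed by the lemma and compares $y_T$ with the stated bound:

```python
import numpy as np

# Two logistic maps with signals sigma1 > sigma2 started from the same small
# alpha (Lemma 4): by the time x_t has essentially reached sigma1, y_t is
# still close to its initialization scale.  Parameter values are illustrative.
sigma1, sigma2, alpha, eta, eps = 1.0, 0.5, 1e-8, 0.1, 1e-3

T = int(np.ceil(np.log(16 * sigma1**2 / (eps * alpha))
                / np.log(1 + eta * sigma1)))

x = y = alpha
for _ in range(T):
    x = (1 + eta * sigma1 - eta * x) * x
    y = (1 + eta * sigma2 - eta * y) * y

# Lemma 4's bound on the lagging signal:
y_bound = (16 * sigma1**2 / eps) * alpha ** ((sigma1 - sigma2) / (sigma1 + sigma2))
print(x, y, y_bound)
```

This gap is exactly what drives the incremental, one-component-at-a-time learning behavior in the substage analysis above.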

G PROOF FOR TENSOR DECOMPOSITION

In this section, we prove our results for the orthogonal symmetric tensor decomposition (OSTD). Different from matrix factorization, we use a special initialization that aligns with the ground truth. In particular, for all $1\le i\le r'$, we assume that $\sin\angle(u_i(0),z_i)\le\gamma$ for some small $\gamma$. We will show that $u_i(t)$ stays aligned with $z_i$ along the whole optimization trajectory. To this goal, we define $v_{ij}(t)=\langle u_i(t),z_j\rangle$ for every $1\le i\le r'$ and $1\le j\le d$. Recall that $\Lambda$ is a multi-index of length $l$. We define $|\Lambda|_k$ as the number of times the index $k$ appears among the elements of $\Lambda$. Evidently, we have $0\le|\Lambda|_k\le l$. Based on these definitions, one can write
\[
\beta_\Lambda(U)=\sum_{i=1}^{r'}\prod_{k=1}^{d}\langle u_i,z_k\rangle^{|\Lambda|_k}=\sum_{i=1}^{r'}\prod_{k=1}^{d}v_{ik}^{|\Lambda|_k}.
\]
Now, it suffices to study the dynamic of $v_{ij}(t)$. In particular, we will show that $v_{ij}(t)$ remains small except for the top-$r$ diagonal elements $v_{jj}(t)$, $1\le j\le r$, which approach $\sigma_j^{1/l}$. To make this intuition more concrete, we divide the terms $\{v_{ij}(t)\}$ into three parts:
• signal terms, defined as $v_{jj}(t)$, $1\le j\le r$;
• diagonal residual terms, defined as $v_{jj}(t)$, $r+1\le j\le d$; and
• off-diagonal residual terms, defined as $v_{ij}(t)$ for all $i\neq j$.
Moreover, we define $V(t)=\max_{i\neq j}|v_{ij}(t)|$ as the maximum off-diagonal residual term at every iteration $t$. When there is no ambiguity, we will omit the dependence on the iteration $t$; for example, we write $v_{ij}=v_{ij}(t)$ and $V=V(t)$. Similarly, when there is no ambiguity, we write $\beta_\Lambda(t)$ or $\beta_\Lambda$ in lieu of $\beta_\Lambda(U(t))$. Our next lemma characterizes the relationship between $\beta_\Lambda$ and $v_{ij}$.

Lemma 5. Suppose that $\max_{j\ge r+1}v_{jj}^l\lesssim\sigma_1^{\frac{l-1}{l}}V$ and $V\le\sigma_1^{1/l}d^{-\frac{1}{l-1}}$. Then,
• for $\beta_{\Lambda_j}$ with $\Lambda_j=(j,\ldots,j)$, we have $\big|\beta_{\Lambda_j}-v_{jj}^l\big|\le r'V^l$;
• for $\beta_\Lambda$ with at least two different indices in $\Lambda$, we have $|\beta_\Lambda|\le 2r\sigma_1^{\frac{l-1}{l}}V$.
The proof of this lemma is deferred to Appendix G.1.
Lemma 5 reveals that the magnitude of $\beta_\Lambda$ can be upper bounded via $\max_{i\neq j}|v_{ij}|$. Next, we control $v_{ij}$ by providing both lower and upper bounds on its dynamics.

Proposition 7 (One-step dynamics of $v_{ij}(t)$). Suppose that $V(t)\lesssim\sigma_1^{1/l}d^{-\frac{1}{l-1}}$ and $v_{ii}(t)\le\sigma_1^{1/l}$ for all $1\le i\le d$. Moreover, suppose that the step size satisfies $\eta\lesssim\frac{1}{l\sigma_1}$. Then,
• for the signal terms $v_{ii}(t)$, $1\le i\le r$, we have
\[
v_{ii}(t+1)\ge v_{ii}(t)+\eta l\big(\sigma_i-v_{ii}^l(t)-2d^{l-1}v_{ii}^{l-2}(t)V^2(t)\big)v_{ii}^{l-1}(t)-ld^l\eta\sigma_1^{\frac{l-1}{l}}V^l(t);
\]
• for the diagonal residual terms $v_{ii}(t)$, $r+1\le i\le d$, we have
\[
v_{ii}(t+1)\le v_{ii}(t)-\eta l\,v_{ii}^{2l-1}(t)+2\eta l d^l\sigma_1^{\frac{l-1}{l}}V^l(t);
\]
• for the off-diagonal residual term $V(t)$, we have
\[
V(t+1)\le V(t)+3\eta l\sigma_1 V^{l-1}(t).
\]
The proof of this proposition is deferred to Appendix G.2. Equipped with the above one-step dynamics, we next provide a bound on the growth rate of $v_{ij}$.

Proposition 8. Suppose that the initial point satisfies $\sin\angle(u_i(0),z_i)\le\gamma$ and $\|u_i(0)\|=\alpha^{1/l}$, with $\alpha\lesssim\frac{1}{d^{l^3}}$ and $\gamma\lesssim\big(\frac{1}{l\kappa}\big)^{\frac{l}{l-2}}$. Moreover, suppose that the step size satisfies $\eta\lesssim\frac{1}{l\sigma_1}$. Then, within $t^\star=\frac{8}{\eta l\sigma_r}\alpha^{-\frac{l-2}{l}}$ iterations,
• for the signal terms $v_{ii}$, $1\le i\le r$, we have $\big|v_{ii}^l(t^\star)-\sigma_i\big|\le 8d^{l-1}\sigma_i^{\frac{l-2}{l}}\alpha^{2/l}\gamma^2$;
• for the diagonal residual terms $v_{ii}$, $r+1\le i\le d$, we have $|v_{ii}(t^\star)|\le 2\alpha^{1/l}$;
• for the off-diagonal residual term, we have $V(t^\star)\le 2^{1/l}V(0)\le(2\alpha)^{1/l}\gamma$.
The proof of this proposition is deferred to Appendix G.3. With the above proposition, we are ready to prove Theorem 4.

Proof of Theorem 4. We have the following decomposition:
\[
\|T-T^\star\|_F^2=\sum_\Lambda\big(\beta_\Lambda-\beta^\star_\Lambda\big)^2.
\]
Hence, it suffices to bound each $|\beta_\Lambda-\beta^\star_\Lambda|$. Combining Lemma 5 and Proposition 8, we have for every $\beta_{\Lambda_j}$ with $1\le j\le r$:
\[
\big|\beta_{\Lambda_j}(t^\star)-\beta^\star_{\Lambda_j}\big|\overset{\text{Lemma 5}}{\le}\big|v_{jj}^l(t^\star)-\sigma_j\big|+r'V^l(t^\star)\overset{\text{Prop. 8}}{\le}8d^{l-1}\sigma_j^{\frac{l-2}{l}}\alpha^{2/l}\gamma^2+2d\alpha\gamma^l\le 16d^{l-1}\sigma_j^{\frac{l-2}{l}}\alpha^{2/l}\gamma^2,
\]
where in the last inequality, we used $\sigma_j\ge 1$, $\alpha\le 1$, and $\gamma\le 1$.
For the remaining diagonal elements $\beta_{\Lambda_j}$, $r+1\le j\le d$, we have
\[
\big|\beta_{\Lambda_j}(t^\star)\big|\overset{\text{Lemma 5}}{\le}\big|v_{jj}^l(t^\star)\big|+r'V^l(t^\star)\overset{\text{Prop. 8}}{\le}2^l\alpha+2d\alpha\gamma^l.
\]
For a general $\beta_\Lambda$ with at least two different indices in the multi-index $\Lambda$, we have
\[
|\beta_\Lambda(t^\star)|\overset{\text{Lemma 5}}{\le}2r\sigma_1^{\frac{l-1}{l}}V(t^\star)\overset{\text{Prop. 8}}{\le}4r\sigma_1^{\frac{l-1}{l}}\alpha^{1/l}\gamma.
\]
Hence, we conclude
\[
\|T(t^\star)-T^\star\|_F^2\le\sum_{i=1}^r 16d^{l-1}\sigma_i^{\frac{l-2}{l}}\alpha^{2/l}\gamma^2+\sum_{i=r+1}^d\big(2^l\alpha+2d\alpha\gamma^l\big)+d^l\cdot 4r\sigma_1^{\frac{l-1}{l}}\alpha^{1/l}\gamma\le 8rd^l\gamma\sigma_1^{\frac{l-1}{l}}\alpha^{1/l},
\]
which completes the proof of the theorem.

G.1 PROOF OF LEMMA 5

Proof. We first analyze $\beta_{\Lambda_j}$. Note that
\[
\beta_{\Lambda_j}=\sum_{i=1}^{r'}\langle u_i,z_j\rangle^l=\sum_{i=1}^{r'}v_{ij}^l=v_{jj}^l+\sum_{i\neq j}v_{ij}^l.
\]
Hence,
\[
\big|\beta_{\Lambda_j}-v_{jj}^l\big|=\Big|\sum_{i\neq j}v_{ij}^l\Big|\le r'V^l,
\]
where we used the definition of $V$. For a general $\beta_\Lambda$ with at least two different elements in the multi-index $\Lambda$, we have
\[
|\beta_\Lambda|=\Big|\sum_{j=1}^{r'}\prod_{k\in\Lambda}v_{jk}^{|\Lambda|_k}\Big|\le\sum_{j=1}^{r'}|v_{jj}|^{|\Lambda|_j}V^{l-|\Lambda|_j}=\sum_{j=1}^{r}|v_{jj}|^{|\Lambda|_j}V^{l-|\Lambda|_j}+\sum_{j=r+1}^{r'}|v_{jj}|^{|\Lambda|_j}V^{l-|\Lambda|_j}\le r\sigma_1^{\frac{l-1}{l}}V+(r'-r)\Big(\max_{j\ge r+1}v_{jj}^l+V^l\Big)\le 2r\sigma_1^{\frac{l-1}{l}}V,
\]
where in the last inequality, we used the assumptions $V\lesssim\sigma_1^{1/l}d^{-\frac{1}{l-1}}$ and $\max_{j\ge r+1}v_{jj}^l\lesssim\sigma_1^{\frac{l-1}{l}}V$. This completes the proof.
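The product formula for $\beta_\Lambda$ used throughout Lemma 5 can be verified directly against a tensor computed in coordinates. A minimal NumPy sketch for $l=3$ (the sizes and the random seed are illustrative):

```python
import numpy as np

# In the orthonormal basis {z_k}, the coefficient of T(U) = sum_i u_i^{(x)3}
# at a multi-index Lambda is  beta_Lambda(U) = sum_i prod_{k in Lambda} v_ik,
# with v_ik = <u_i, z_k>.  We also check |beta_{Lambda_j} - v_jj^l| <= r' V^l.
rng = np.random.default_rng(1)
d, rp, l = 4, 3, 3                               # ambient dim, r', order l
Z = np.linalg.qr(rng.standard_normal((d, d)))[0]  # orthonormal z_1, ..., z_d
U = rng.standard_normal((d, rp))                  # columns u_1, ..., u_{r'}

V_mat = U.T @ Z                                   # V_mat[i, k] = <u_i, z_k>
T = np.einsum('ai,bi,ci->abc', U, U, U)           # sum_i u_i (x) u_i (x) u_i

def beta(Lam):
    """Coefficient of T at multi-index Lam, computed two ways."""
    direct = np.einsum('abc,a,b,c->', T, Z[:, Lam[0]], Z[:, Lam[1]], Z[:, Lam[2]])
    formula = sum(np.prod([V_mat[i, k] for k in Lam]) for i in range(rp))
    return direct, formula

V_off = max(abs(V_mat[i, k]) for i in range(rp) for k in range(d) if i != k)
d_diag, f_diag = beta((0, 0, 0))                  # Lambda_j = (j, j, j)
d_mix, f_mix = beta((0, 1, 2))                    # at least two distinct indices
print(f_diag, f_mix)
```

The identity holds exactly (up to floating point) because the $z_k$ are orthonormal, so each inner product $\langle T, z_{k_1}\otimes z_{k_2}\otimes z_{k_3}\rangle$ factors through the $v_{ik}$.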

G.2 PROOF OF PROPOSITION 7

In this section, we provide the proof of Proposition 7. For simplicity, and whenever there is no ambiguity, we omit the iteration index $t$ and denote iteration $t+1$ with the superscript '+'. For instance, we write $v_{ij}=v_{ij}(t)$ and $v_{ij}^+=v_{ij}(t+1)$. Recall that $v_{ij}=\langle u_i,z_j\rangle$. For simplicity, we denote $\mu=\sigma_1^{1/l}$. Hence, by our assumption, we have $v_{ii}\le\mu$. We first provide the exact dynamic of $v_{ij}$ in the following lemma.

Lemma 6. The one-step dynamic of $v_{ij}$ takes the following form:
\[
v_{ij}^+=v_{ij}+\eta l\big(\sigma_j-v_{jj}^l\big)v_{ij}^{l-1}-\eta l\sum_{k\in[r'],\,k\neq j}v_{kj}^l\,v_{ij}^{l-1}-\eta\sum_{s\in[l-1]}s\sum_{\Lambda:\,|\Lambda|_j=s}\beta_\Lambda\Big(\prod_{k\in\Lambda,\,k\neq j}v_{ik}^{|\Lambda|_k}\Big)v_{ij}^{s-1}. \tag{112}
\]

Proof. Recall that $L=\frac{1}{2}\sum_\Lambda(\beta_\Lambda-\beta^\star_\Lambda)^2$ and $\beta_\Lambda=\sum_{i=1}^{r'}\prod_{k\in\Lambda}\langle u_i,z_k\rangle^{|\Lambda|_k}$. Moreover, we have $\beta^\star_{\Lambda_j}=\sigma_j$ for $1\le j\le r$, and $\beta^\star_\Lambda=0$ otherwise. We first calculate
\[
\nabla_{u_i}\beta_\Lambda=\sum_{s\in\Lambda}|\Lambda|_s\Big(\prod_{k\in\Lambda,\,k\neq s}\langle u_i,z_k\rangle^{|\Lambda|_k}\Big)\langle u_i,z_s\rangle^{|\Lambda|_s-1}z_s=\sum_{s\in\Lambda}|\Lambda|_s\Big(\prod_{k\in\Lambda,\,k\neq s}v_{ik}^{|\Lambda|_k}\Big)v_{is}^{|\Lambda|_s-1}z_s.
\]
Hence, the partial derivative of $L$ with respect to $u_i$ is
\[
\nabla_{u_i}L=\sum_\Lambda(\beta_\Lambda-\beta^\star_\Lambda)\nabla_{u_i}\beta_\Lambda=\sum_\Lambda(\beta_\Lambda-\beta^\star_\Lambda)\sum_{s\in\Lambda}|\Lambda|_s\Big(\prod_{k\in\Lambda,\,k\neq s}v_{ik}^{|\Lambda|_k}\Big)v_{is}^{|\Lambda|_s-1}z_s.
\]
Note that $\{z_j\}_{j\in[d]}$ are orthonormal vectors. Hence, we have
\[
\langle\nabla_{u_i}\beta_\Lambda,z_j\rangle=\begin{cases}|\Lambda|_j\Big(\prod_{k\in\Lambda,\,k\neq j}v_{ik}^{|\Lambda|_k}\Big)v_{ij}^{|\Lambda|_j-1}&\text{if }j\in\Lambda,\\[2pt]0&\text{if }j\notin\Lambda.\end{cases} \tag{114}
\]
By the definition of $v_{ij}$, its update rule can be written in the following way:
\[
v_{ij}^+=\langle u_i^+,z_j\rangle\overset{(a)}{=}v_{ij}-\eta\langle\nabla_{u_i}L,z_j\rangle=v_{ij}+\eta\sum_\Lambda(\beta^\star_\Lambda-\beta_\Lambda)\langle\nabla_{u_i}\beta_\Lambda,z_j\rangle
\overset{(b)}{=}v_{ij}+\eta\sum_{\Lambda:\,|\Lambda|_j\ge 1}(\beta^\star_\Lambda-\beta_\Lambda)|\Lambda|_j\Big(\prod_{k\in\Lambda,\,k\neq j}v_{ik}^{|\Lambda|_k}\Big)v_{ij}^{|\Lambda|_j-1}
\overset{(c)}{=}v_{ij}+\eta\sum_{s\in[l]}\sum_{\Lambda:\,|\Lambda|_j=s}s(\beta^\star_\Lambda-\beta_\Lambda)\Big(\prod_{k\in\Lambda,\,k\neq j}v_{ik}^{|\Lambda|_k}\Big)v_{ij}^{s-1}.
\]
Here in (a), we used the update rule for $u_i$. In (b), we applied Equation 114 to exclude those $\Lambda$ without $j$. In (c), we simply rearranged the above equation according to the cardinality $|\Lambda|_j$.
We further isolate the term that only involves $v_{ij}$:
\[
v_{ij}^+\overset{(a)}{=}v_{ij}+\eta l(\sigma_j-\beta_{\Lambda_j})v_{ij}^{l-1}-\eta\sum_{s\in[l-1]}s\sum_{\Lambda:\,|\Lambda|_j=s}\beta_\Lambda\Big(\prod_{k\in\Lambda,\,k\neq j}v_{ik}^{|\Lambda|_k}\Big)v_{ij}^{s-1}
\overset{(b)}{=}v_{ij}+\eta l(\sigma_j-v_{jj}^l)v_{ij}^{l-1}-\eta l\sum_{k\in[r'],\,k\neq j}v_{kj}^l\,v_{ij}^{l-1}-\eta\sum_{s\in[l-1]}s\sum_{\Lambda:\,|\Lambda|_j=s}\beta_\Lambda\Big(\prod_{k\in\Lambda,\,k\neq j}v_{ik}^{|\Lambda|_k}\Big)v_{ij}^{s-1}.
\]
Here in (a), we rearranged terms and isolated the term with $|\Lambda|_j=l$. Note that the remaining terms must satisfy $1\le|\Lambda|_j\le l-1$, which indicates that there must be at least two different indices in $\Lambda$, which in turn implies $\beta^\star_\Lambda=0$. In (b), we used the definition of $\beta_{\Lambda_j}$. This completes the proof of Lemma 6.

Equipped with Lemma 6, we are ready to prove Proposition 7.

Proof of Proposition 7. The proof is divided into three parts.

Signal term $v_{ii}(t)$, $1\le i\le r$. We first consider the signal terms $v_{ii}(t)$, $1\le i\le r$. Upon setting $i=j$ in Lemma 6, we have
\[
v_{ii}^+=v_{ii}+\eta l\Big(\sigma_i-v_{ii}^l-\sum_{k\neq i}v_{ki}^l\Big)v_{ii}^{l-1}-\eta\sum_{s\in[l-1]}s\sum_{\Lambda:\,|\Lambda|_i=s}(A), \tag{117}
\]
where $(A)=\beta_\Lambda\big(\prod_{k\in\Lambda,\,k\neq i}v_{ik}^{|\Lambda|_k}\big)v_{ii}^{s-1}$ for $|\Lambda|_i=s\in\{1,\ldots,l-1\}$. We have
\[
(A)=\sum_{j=1}^{r'}\Big(\prod_{h\in\Lambda}v_{jh}^{|\Lambda|_h}\Big)\Big(\prod_{k\in\Lambda,\,k\neq i}v_{ik}^{|\Lambda|_k}\Big)v_{ii}^{s-1}
\overset{(a)}{\le}\sum_{j=1}^{r'}|v_{jj}|^{|\Lambda|_j}V^{l-|\Lambda|_j}\,V^{l-s}v_{ii}^{s-1}
\overset{(b)}{\le}v_{ii}^{2s-1}V^{2l-2s}+r'\max_{j\neq i}|v_{jj}|^{|\Lambda|_j}V^{2l-s-|\Lambda|_j}v_{ii}^{s-1}
\overset{(c)}{\le}v_{ii}^{2l-3}V^2+r'\mu^{l-1}V^l. \tag{118}
\]
In (a), we used the fact that $|v_{ij}|\le V$ for all $i\neq j$. In (b), we isolated the term with $j=i$ and bounded the remaining terms by their maximum value. In (c), we used the fact that $V\le v_{ii}\le\mu$. After substituting Equation 118 into Equation 117, we have
\[
v_{ii}^+\ge v_{ii}+\eta l\big(\sigma_i-v_{ii}^l-r'V^l\big)v_{ii}^{l-1}-\eta\Big(v_{ii}^{2l-3}V^2+r'\mu^{l-1}V^l\Big)\sum_{s=1}^{l-1}s\,C_l^s(d-1)^{l-s},
\]
where $C_l^s=\binom{l}{s}$. Note that $\sum_{s=1}^{l-1}s\,C_l^s(d-1)^{l-s}=ld^{l-1}-l\le ld^{l-1}$.
Therefore, we have
\[
v_{ii}(t+1)\ge v_{ii}+\eta l\big(\sigma_i-v_{ii}^l-r'V^l\big)v_{ii}^{l-1}-ld^{l-1}\eta\Big(v_{ii}^{2l-3}V^2+r'\mu^{l-1}V^l\Big)
\ge v_{ii}+\eta l\big(\sigma_i-v_{ii}^l-r'V^l-d^{l-1}v_{ii}^{l-2}V^2\big)v_{ii}^{l-1}-ld^{l-1}\eta r'\mu^{l-1}V^l
\]
\[
\ge v_{ii}+\eta l\big(\sigma_i-v_{ii}^l-2d^{l-1}v_{ii}^{l-2}V^2\big)v_{ii}^{l-1}-ld^l\eta\sigma_1^{\frac{l-1}{l}}V^l,
\]
where in the last inequality, we used the fact that $r'\le d$.

Diagonal residual term $v_{ii}$, $r+1\le i\le d$. In this case we consider the terms $v_{ii}$ with $r+1\le i\le d$, which is similar to the case $1\le i\le r$. Without loss of generality, we assume that $v_{ii}\ge 0$; the case $v_{ii}\le 0$ can be argued in an identical fashion. By Equation 117 (with $\sigma_i=0$), we have
\[
v_{ii}^+=v_{ii}-\eta l\Big(v_{ii}^l+\sum_{k\neq i}v_{ki}^l\Big)v_{ii}^{l-1}-\eta\sum_{s=1}^{l-1}s\sum_{\Lambda:\,|\Lambda|_i=s}(A)
\le v_{ii}-\eta l\big(v_{ii}^l-r'V^l\big)v_{ii}^{l-1}-\eta\sum_{s=1}^{l-1}s\sum_{\Lambda:\,|\Lambda|_i=s}(A). \tag{121}
\]
For $(A)=\beta_\Lambda\big(\prod_{k\in\Lambda,\,k\neq i}v_{ik}^{|\Lambda|_k}\big)v_{ii}^{s-1}$, we further have
\[
(A)=\sum_{j=1}^{r'}\Big(\prod_{h\in\Lambda}v_{jh}^{|\Lambda|_h}\Big)\Big(\prod_{k\in\Lambda,\,k\neq i}v_{ik}^{|\Lambda|_k}\Big)v_{ii}^{s-1}
\ge\Big(\prod_{h\in\Lambda}v_{ih}^{|\Lambda|_h}\Big)\Big(\prod_{k\in\Lambda,\,k\neq i}v_{ik}^{|\Lambda|_k}\Big)v_{ii}^{s-1}-\sum_{j\neq i}\Big|\prod_{h\in\Lambda}v_{jh}^{|\Lambda|_h}\Big|\,\Big|\prod_{k\in\Lambda,\,k\neq i}v_{ik}^{|\Lambda|_k}\Big|\,v_{ii}^{s-1}
\]
\[
\ge\underbrace{\Big(\prod_{k\in\Lambda,\,k\neq i}v_{ik}^{2|\Lambda|_k}\Big)v_{ii}^{2s-1}}_{\ge 0}-r'\max_{j\neq i}|v_{jj}|^{|\Lambda|_j}V^{2l-s-|\Lambda|_j}v_{ii}^{s-1}
\ge-r'\mu^{l-1}V^l.
\]
Therefore, we obtain
\[
v_{ii}^+\le v_{ii}-\eta l\big(v_{ii}^l-r'V^l\big)v_{ii}^{l-1}+\eta ld^{l-1}r'\mu^{l-1}V^l\le v_{ii}-\eta l\,v_{ii}^{2l-1}+2\eta ld^{l-1}r'\mu^{l-1}V^l. \tag{123}
\]

Off-diagonal residual term $V(t)$. Finally, we characterize the dynamic of $V(t)=\max_{i\neq j}|v_{ij}(t)|$. To this goal, we first consider the dynamic of each $v_{ij}$ such that $i\neq j$. Without loss of generality, we assume that $v_{ij}\ge 0$.
One can write
\[
v_{ij}^+=v_{ij}+\eta l\Big(\sigma_j-\sum_{k\in[r']}v_{kj}^l\Big)v_{ij}^{l-1}-\eta\sum_{s\in[l-1]}s\sum_{\Lambda:\,|\Lambda|_j=s}\beta_\Lambda\Big(\prod_{k\in\Lambda,\,k\neq j}v_{ik}^{|\Lambda|_k}\Big)v_{ij}^{s-1}
\]
\[
\le V+\eta l\sigma_1 V^{l-1}+\eta lr'V^{2l-1}-\eta\sum_{s\in[l-1]}s\sum_{\Lambda:\,|\Lambda|_j=s}\beta_\Lambda\Big(\prod_{k\in\Lambda,\,k\neq j}v_{ik}^{|\Lambda|_k}\Big)v_{ij}^{s-1}
\le V+2\eta l\sigma_1 V^{l-1}-\eta\sum_{s\in[l-1]}s\sum_{\Lambda:\,|\Lambda|_j=s}\beta_\Lambda\Big(\prod_{k\in\Lambda,\,k\neq j}v_{ik}^{|\Lambda|_k}\Big)v_{ij}^{s-1}. \tag{124}
\]

Iteration complexity of the diagonal residual term $v_{ii}(t)$, $r+1\le i\le d$. Next, we show that $v_{ii}(t)$, $r+1\le i\le d$, remains small during $0\le t\le t^\star=\frac{8}{\eta l\sigma_r}\alpha^{-\frac{l-2}{l}}$. First, by Equation 123, we have for every $0\le t\le t^\star$
\[
v_{ii}(t+1)\le v_{ii}(t)+2\eta ld^l\sigma_1^{\frac{l-1}{l}}V^l(t). \tag{132}
\]
Note that $V(t)\le 2^{1/l}V(0)$ for all $t\le t^\star$, which leads to
\[
v_{ii}(t)\le v_{ii}(0)+4\eta ld^l\sigma_1^{\frac{l-1}{l}}\gamma^l\alpha\,t\le v_{ii}(0)+4\eta ld^l\sigma_1^{\frac{l-1}{l}}\gamma^l\alpha\,t^\star\le\alpha^{1/l}+O\big(d^l\kappa\alpha^{2/l}\gamma^l\big)\le 2\alpha^{1/l} \tag{133}
\]
for $0\le t\le t^\star$. Here we used the assumptions $\alpha\lesssim\frac{1}{d^{l^3}}$ and $\gamma\lesssim\big(\frac{1}{\kappa l}\big)^{\frac{1}{l-2}}$.

Iteration complexity of the signal term $v_{ii}(t)$, $1\le i\le r$. As the last piece of the proof, we show that by iteration $t^\star=\frac{8}{\eta l\sigma_r}\alpha^{-\frac{l-2}{l}}$, the signal $v_{ii}$, $1\le i\le r$, converges to $\sigma_i^{1/l}$. First, recall that
\[
v_{ii}(t+1)\ge v_{ii}(t)+\eta l\big(\sigma_i-v_{ii}^l(t)-2d^{l-1}v_{ii}^{l-2}(t)V^2(t)\big)v_{ii}^{l-1}(t)-2ld^l\eta\sigma_1^{\frac{l-1}{l}}V^l(t)
\ge v_{ii}(t)+\eta l\big(\sigma_i-v_{ii}^l(t)-4d^{l-1}v_{ii}^{l-2}(t)\alpha^{2/l}\gamma^2\big)v_{ii}^{l-1}(t)-4ld^l\eta\sigma_1^{\frac{l-1}{l}}\gamma^l\alpha, \tag{134}
\]
where in the last inequality we used the fact that $V(t)\le 2^{1/l}V(0)$ for $0\le t\le t^\star$. In light of the above inequality, we characterize the convergence of $v_{ii}$ using a similar method as in Ma & Fattahi (2022a). In particular, we divide our analysis into two phases.

Phase 1. In the first phase, we have $v_{ii}\le(0.5\sigma_i)^{1/l}$. First, since $v_{ii}(0)\ge\alpha^{1/l}\sqrt{1-\gamma^2}$, we can easily conclude that $v_{ii}(t+1)\ge v_{ii}(t)$ by induction. Hence, we can simplify the dynamic as
\[
v_{ii}(t+1)\ge\big(1+\eta l(0.99\sigma_i-v_{ii}^l)v_{ii}^{l-2}(t)\big)v_{ii}(t)\ge\big(1+0.49\eta l\sigma_i v_{ii}^{l-2}(t)\big)v_{ii}(t). \tag{135}
\]
Next, we further split the interval $I=\big[0,(0.5\sigma_i)^{1/l}\big]$ into $N=O\!\left(\log\frac{(0.5\sigma_i)^{1/l}}{\alpha^{1/l}}\right)$ sub-intervals $\{I_0,\ldots,I_{N-1}\}$, where $I_k=[2^k v_{ii}(0),\,2^{k+1}v_{ii}(0))$. Let $T_k$ collect the iterations that $v_{ii}$ spends in $I_k$. Accordingly, let $|T_k|=t_k$ be the number of iterations that $v_{ii}$ spends within $I_k$. First note that $v_{ii}(t)\ge 2^k v_{ii}(0)$ for every $t\in T_k$. Hence, we have
\[
\big(1+0.49\eta l\sigma_i 2^{(l-2)k}v_{ii}^{l-2}(0)\big)^{t_k}\ge 2,
\]
which implies
\[
t_k\le\frac{\log 2}{0.49\eta l\sigma_i v_{ii}^{l-2}(0)}\,2^{-(l-2)k}.
\]
By summing over $k=0,\ldots,N-1$, we can upper bound the required number of iterations $T_3$:
\[
T_3\le\sum_{k=0}^{\infty}t_k\le\sum_{k=0}^{\infty}\frac{\log 2}{0.49\eta l\sigma_i v_{ii}^{l-2}(0)}\,2^{-(l-2)k}\le\frac{4}{\eta l\sigma_i}\alpha^{-\frac{l-2}{l}}\le\frac{4}{\eta l\sigma_r}\alpha^{-\frac{l-2}{l}}\ll T_1,
\]
where the last inequality is due to our assumption $\gamma\lesssim\big(\frac{1}{\kappa l}\big)^{\frac{1}{l-2}}$.

Phase 2. In the second phase, we have $v_{ii}\ge(0.5\sigma_i)^{1/l}$. We further simplify Equation 134 as
\[
v_{ii}(t+1)\ge v_{ii}(t)+\eta l\Big(\sigma_i-8d^{l-1}\sigma_i^{\frac{l-2}{l}}\alpha^{2/l}\gamma^2-v_{ii}^l(t)\Big)v_{ii}^{l-1}(t)\ge v_{ii}(t)+\eta l\big(\tilde\sigma_i-v_{ii}^l(t)\big)v_{ii}^{l-1}(t),
\]
where we denote $\tilde\sigma_i=\sigma_i-8d^{l-1}\sigma_i^{\frac{l-2}{l}}\alpha^{2/l}\gamma^2$. Then, via a similar argument, within an additional $T_4\le\frac{4}{\eta l\sigma_r}\alpha^{-\frac{l-2}{l}}$ iterations, we have
\[
v_{ii}^l(t)\ge\sigma_i-8d^{l-1}\sigma_i^{\frac{l-2}{l}}\alpha^{2/l}\gamma^2.
\]
A similar argument on the upper bound shows $v_{ii}^l(t)\le\sigma_i+8d^{l-1}\sigma_i^{\frac{l-2}{l}}\alpha^{2/l}\gamma^2$, which completes the proof. □
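The two-phase analysis above predicts that, unlike the logistic map of the matrix case, the signal dynamic for $l\ge 3$ escapes a small initialization polynomially rather than logarithmically: the iteration count scales as $\alpha^{-(l-2)/l}$. A scalar sketch of the residual-free signal recursion from Proposition 7 illustrates this scaling (parameter values are illustrative):

```python
# Residual-free signal recursion from Proposition 7:
#   v_{t+1} = v_t + eta*l*(sigma - v_t^l) * v_t^(l-1),  v_0 = alpha^(1/l).
# For l = 3 the time to escape scales like 1/v_0, i.e. alpha^(-(l-2)/l),
# so shrinking alpha by 1000x should increase the iteration count ~10x.
def iters_to_converge(sigma, alpha, eta, l, tol=1e-3, max_iter=10**7):
    v = alpha ** (1.0 / l)
    for t in range(1, max_iter + 1):
        v = v + eta * l * (sigma - v**l) * v ** (l - 1)
        if sigma - v**l <= tol:
            return t
    return max_iter

eta, sigma, l = 0.01, 1.0, 3
t1 = iters_to_converge(sigma, 1e-6, eta, l)
t2 = iters_to_converge(sigma, 1e-9, eta, l)   # alpha smaller by 1e3
print(t1, t2)                                 # expect t2/t1 near (1e3)**(1/3) = 10
```

This is the scalar analogue of the interval-doubling argument in Phase 1: almost all iterations are spent near the initialization, where the growth rate $v^{l-2}$ is smallest.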



• If $\theta_0$ is selected from an isotropic Gaussian distribution, then $C_1$ and $C_2$ may scale with $k$ and $d$. However, to streamline the presentation, we keep this dependency implicit.
• Suppose that $\{\phi_i(x)\}_{i=1}^d$ are not orthonormal. Let $\{\tilde\phi_i(x)\}_{i\in I}$ be any orthonormal basis for $\mathcal{F}_\Theta$. Then, there exists a matrix $A$ such that $\phi_i(x)=\sum_j A_{ij}\tilde\phi_j(x)$ for every $1\le i\le d$. Therefore, upon defining $\tilde\theta=A^\top\theta$, one can write $f_{\theta^\star}(x)=\sum_{i\in I}\tilde\theta^\star_i\tilde\phi_i(x)$, which has the same form as the regression model.
• We use the notations $u_t$ or $u(t)$ interchangeably to denote the solution at iteration $t$.
• In "linear warm-up", we linearly increase the learning rate in the first 5 epochs. More precisely, we set the initial learning rate to $1\times 10^{-5}$ and linearly increase it to the selected learning rate over 5 epochs. After the first 5 epochs, the learning rate follows a regular decay scheme.
• For instance, assume that $\Lambda=(1,1,2)$. Then, $|\Lambda|_1=2$, $|\Lambda|_2=1$, and $|\Lambda|_3=0$.



Figure 1: The solution trajectories of LARS on AlexNet and ResNet-18 and AdamW on ViT with ℓ2-loss after projecting onto two different orthonormal bases. The first row shows the trajectories of the top-5 coefficients after projecting onto a randomly generated orthonormal basis. The second row shows the trajectories of the top-5 coefficients after projecting onto the eigenvectors of the conjugate kernel evaluated at the last epoch. More detail on our implementation can be found in Appendix B.

Figure 2: The conditions of Theorem 1 are approximately satisfied for LARS on a 2-layer CNN with MNIST dataset. Here, the coefficients are obtained by projecting the solution trajectory onto the eigenvectors of the conjugate kernel after training. (a) The top-20 basis coefficients at a small random initial point along with residual energy on the remaining coefficients. (b) The maximum value of | cos(∇βi(θ), ∇βj(θ))| for 1 ≤ i < j ≤ 10 along the solution trajectory. (c) The scaling of ∥∇βi(θt)∥ with respect to |βi(θt)| for the top-4 coefficients.

Figure 3: (First row) The projected trajectories of LARS on CNNs with the MNIST dataset. The test accuracies for the 2-Block, 3-Block, and 4-Block CNNs are 96.10%, 96.48%, and 96.58%, respectively. (Second row) The projected trajectories of different optimizers on AlexNet with the CIFAR-10 dataset. The test accuracies for LARS, SGD, and AdamW are 90.54%, 91.03%, and 90.26%, respectively. We use the following settings for each optimizer: (d) LARS: learning rate of 2, Nesterov momentum of 0.9, and weight decay of $1\times 10^{-4}$. (e) SGD: learning rate of 2 with "linear warm-up", Nesterov momentum of 0.9, and weight decay of $1\times 10^{-4}$. (f) AdamW: learning rate of 0.01.

Figure 4: Experimental verification to support Theorems 2, 3, and 4 in Section 3. The first row shows the projected trajectories of GD onto the specific basis functions we defined for each problem. The second row shows the estimation error.

Figure 5: Solution trajectories for different architectures trained on CIFAR-10. The test accuracies for Alexnet,

Figure 6: Solution trajectories for different architectures trained on CIFAR-100. The test accuracies for

Figure 7: Solution trajectories of LARS on the CIFAR-10 dataset with ℓ2-loss and CE loss. The test accuracies are 90.54% for ℓ2-loss and 91.09% for CE loss.

Figure 8: Solution trajectories for AlexNet on the CIFAR-10 dataset with different batch sizes. The test accuracies for batch size 32, 64 and 256 are 91.50%, 91.79%, and 90.12%, respectively.


Jialun Zhang, Salar Fattahi, and Richard Y. Zhang. Preconditioned gradient descent for overparameterized nonconvex matrix factorization. Advances in Neural Information Processing Systems, 34:5985-5996, 2021.

Richard Y. Zhang. Sharp global guarantees for nonconvex low-rank matrix recovery in the overparameterized regime. arXiv preprint arXiv:2104.10790, 2021.

Jiacheng Zhuo, Jeongyeol Kwon, Nhat Ho, and Constantine Caramanis. On the computational and statistical complexity of over-parameterized matrix sensing. arXiv preprint arXiv:2102.02756, 2021.

Numerical Verification of our Theoretical Results
B.2 Derivation of Basis Functions for DNNs
B.3 Further Details on the Experiments
B.4 Experiments for CIFAR-100
B.5 Experiments for Different Losses
B.6 Experiments for Different Batch Sizes
B.7 Experiments for ResNet-18 with SGD on CIFAR-10

The summary of our experiments.


ACKNOWLEDGEMENTS

We thank Richard Y. Zhang and Tiffany Wu for helpful feedback. We would also like to thank Ruiqi Gao and Chenwei Wu for their insightful discussions. This research is supported, in part, by NSF Award DMS-2152776, ONR Award N00014-22-1-2127, MICDE Catalyst Grant, MIDAS PODS grant and Startup Funding from the University of Michigan.


Proof. First note that $U_0^\top z_i\sim\mathcal{N}(0,\alpha^2 I_{r'\times r'})$. Hence, a standard concentration bound on Gaussian random vectors implies the stated bounds. Via a union bound, we have that with probability of at least $1-e^{-Cr'}$, these bounds hold simultaneously. Given these bounds, one can derive the two claimed estimates, which completes the proof.

F.2 ONE-STEP DYNAMICS

In this section, we characterize the one-step dynamics of the basis coefficients. To this goal, we first provide a more precise statement of Proposition 3 along with its proof.

Proposition 6. For the diagonal elements $\beta_{ii}(U_t)$, we have the one-step dynamic stated below, where $\sigma_i=0$ for $r<i\le d$. Moreover, an analogous one-step dynamic holds for every $\beta_{ij}(U_t)$ with $i\neq j$.

Proof. The iterations of GD on SMF take the form $U_{t+1}=U_t+\eta(M^\star-U_tU_t^\top)U_t$. This leads to the stated dynamics of $\beta_{ij}(U_t)=z_i^\top U_tU_t^\top z_j$.

Based on these definitions, one can write the corresponding bound, where in the last inequality we use the assumption $V\lesssim\sigma_1^{1/l}$. We further note that $|\Lambda|_h+|\Lambda|_i\le l-s$. Hence, $(A)$ can be lower bounded accordingly. Plugging our estimate for $(A)$ into Equation 124, we finally obtain the claimed bound on $V(t+1)$, where the last inequality comes from the assumption that $V\le\frac{1}{d^l}$. This completes the proof of Proposition 7.

G.3 PROOF OF PROPOSITION 8

In this section, we provide the proof of Proposition 8. We will show that the signal terms $v_{ii}(t)$, $1\le i\le r$, quickly converge to $\sigma_i^{1/l}$, while the residual terms remain small. To this goal, we first study the dynamic of $V(t)$.

Iteration complexity of the off-diagonal residual term $V(t)$. To start, we first study the time required for the off-diagonal term $V(t)$ to grow from $V(0)$ to $2^{1/l}V(0)$, i.e., $T_1=\min\{t\ge 0:V(t)\ge 2^{1/l}V(0)\}$. By Proposition 7, we know that
\[
V(t+1)\le V(t)+3\eta l\sigma_1 V^{l-1}(t).
\]
Hence, for $0\le t\le T_1$, we have
\[
V(t+1)\le\Big(1+3\eta l\sigma_1\big(2^{1/l}V(0)\big)^{l-2}\Big)V(t)\le\big(1+6\eta l\sigma_1 V^{l-2}(0)\big)V(t).
\]
Solving the above inequality for $T_1$, we obtain
\[
T_1\ge\frac{\log\big(2^{1/l}\big)}{\log\big(1+6\eta l\sigma_1 V^{l-2}(0)\big)}\ge\frac{\log 2}{6\eta l^2\sigma_1 V^{l-2}(0)}.
\]
On the other hand, our initial point satisfies $\sin\angle(u_i(0),z_i)\le\gamma$ and $\|u_i(0)\|=\alpha^{1/l}$. Hence, we have $V(0)\le\alpha^{1/l}\gamma$. Substituting this into the above bound, we conclude that
\[
T_1\ge\frac{\log 2}{6\eta l^2\sigma_1\gamma^{l-2}}\,\alpha^{-\frac{l-2}{l}}\gg t^\star.
\]
Hence, we have $V(t^\star)\le 2^{1/l}V(0)\le(2\alpha)^{1/l}\gamma$.

H AUXILIARY LEMMAS

Lemma 7 (Bernoulli inequality). For $0\le x<\frac{1}{r-1}$ and $r>1$, we have
\[
(1+x)^r\le 1+\frac{rx}{1-(r-1)x}.
\]

