NEURAL NETWORKS EFFICIENTLY LEARN LOW-DIMENSIONAL REPRESENTATIONS WITH SGD

Abstract

We study the problem of training a two-layer neural network (NN) of arbitrary width using stochastic gradient descent (SGD), where the input x ∈ R^d is Gaussian and the target y ∈ R follows a multiple-index model, i.e., y = g(⟨u_1, x⟩, ..., ⟨u_k, x⟩) with a noisy link function g. We prove that the first-layer weights of the NN converge to the k-dimensional principal subspace spanned by the vectors u_1, ..., u_k of the true model, when online SGD with weight decay is used for training. This phenomenon has several important consequences when k ≪ d. First, by employing uniform convergence on this smaller subspace, we establish a generalization error bound of O(√(kd/T)) after T iterations of SGD, which is independent of the width of the NN. We further demonstrate that SGD-trained ReLU NNs can learn a single-index target of the form y = f(⟨u, x⟩) + ϵ by recovering the principal direction, with a sample complexity linear in d (up to log factors), where f is a monotonic function with at most polynomial growth and ϵ is the noise. This is in contrast to the known d^{Ω(p)} sample requirement to learn any degree-p polynomial in the kernel regime, and it shows that NNs trained with SGD can outperform the neural tangent kernel at initialization. Finally, we also provide compressibility guarantees for NNs using the approximate low-rank structure produced by SGD.

1. INTRODUCTION

The task of learning an unknown statistical (teacher) model using data is fundamental in many areas of learning theory. There has been a considerable amount of research dedicated to this task, especially when the trained (student) model is a neural network (NN), providing precise and non-asymptotic guarantees in various settings (Zhong et al., 2017; Goldt et al., 2019; Ba et al., 2019; Sarao Mannelli et al., 2020; Zhou et al., 2021; Akiyama & Suzuki, 2021; Abbe et al., 2022; Ba et al., 2022; Damian et al., 2022; Veiga et al., 2022). As evident from these works, explaining the remarkable learning capabilities of NNs requires arguments beyond classical learning theory (Zhang et al., 2021). The connection between NNs and kernel methods has been particularly useful in this pursuit (Jacot et al., 2018; Chizat et al., 2019). In particular, a two-layer NN with randomly initialized and untrained weights is an example of a random features model (Rahimi & Recht, 2007), and regression on the second layer captures several interesting phenomena that NNs exhibit in practice (Louart et al., 2018; Mei & Montanari, 2022), e.g. the cusp in the learning curve. However, NNs also inherit favorable characteristics from the optimization procedure (Ghorbani et al., 2019; Allen-Zhu & Li, 2019; Yehudai & Shamir, 2019; Li et al., 2020; Refinetti et al., 2021), which cannot be captured by associating NNs with regression on random features. Indeed, recent works have established a separation between NNs and kernel methods, relying on the emergence of representation learning as a consequence of gradient-based training (Abbe et al., 2022; Ba et al., 2022; Barak et al., 2022; Damian et al., 2022), which often exhibits a natural bias towards low-complexity models.

A theme that has emerged repeatedly in modern learning theory is the implicit regularization effect provided by the training dynamics (Neyshabur et al., 2014). The work by Soudry et al.
(2018) has inspired an abundance of recent works focusing on the implicit bias of gradient descent favoring, in some sense, low-complexity models, e.g. by achieving min-norm and/or max-margin solutions despite the lack of any explicit regularization (Gunasekar et al., 2018; Li et al., 2018; Ji & Telgarsky, 2019; Gidel et al., 2019; Chizat & Bach, 2020; Pesme et al., 2021). However, these works mainly consider linear models or unrealistically wide NNs, and the notion of reduced complexity, as well as its implications for generalization, varies. A concrete example in this domain is compressibility and its connection to generalization (Arora et al., 2018; Suzuki et al., 2020). Indeed, when a trained NN can be compressed into a smaller NN with similar prediction behavior, the resulting models exhibit similar generalization performance. Thus, the model complexity of the original NN can be explained by the smaller complexity of the compressed one, which is classically linked to better generalization.

In this paper, we demonstrate the emergence of low-complexity structures during the training procedure. More specifically, we consider training a two-layer student NN of arbitrary width m, where the input x ∈ R^d is Gaussian and the target y ∈ R follows a multiple-index teacher model, i.e. y = g(⟨u_1, x⟩, ..., ⟨u_k, x⟩; ϵ) with a link function g and a noise ϵ independent of the input. In this setting, we prove that the first-layer weights trained by online stochastic gradient descent (SGD) with weight decay converge to the k-dimensional subspace spanned by the weights of the teacher model, span(u_1, ..., u_k), which we refer to as the principal subspace. Our primary focus is the case where the target values depend only on a few important directions of the input, i.e. k ≪ d, which induces a low-dimensional structure on the SGD-trained first-layer weights, whose impact on generalization is profound.
First, convergence to the principal subspace leads to an improved bound on the generalization gap for SGD, independent of the width of the NN. In the specific case of learning a single-index target with a ReLU student network, we show that this convergence leads to useful features that improve upon the initial random features. Hence, we prove that NNs can learn certain degree-p polynomials with a number of samples (almost) linear in d using online SGD, while learning a degree-p polynomial with any rotationally invariant kernel, including the neural tangent kernel (NTK) at initialization, requires d^{Ω(p)} samples (Donhauser et al., 2021). We summarize our contributions as follows.

• We show in Theorem 3 that NNs learn low-dimensional representations, by proving that the iterates of online SGD on the first layer of a two-layer NN with width m converge to a √m·ε-neighborhood of the principal subspace after O(d/ε²) iterations, with high probability. The error tolerance of √m·ε is sufficient to guarantee that the risk of the SGD iterates and that of their orthogonal projection onto the principal subspace are within O(ε) of each other.
• We demonstrate the impact of learning low-dimensional representations with three applications.
  – For a single-index target y = f(⟨u, x⟩) + ϵ with a monotonic link function f where f′′ has at most polynomial growth, we prove in Theorem 4 that ReLU networks of width m can learn this target after T iterations of SGD with an excess risk estimate of Õ(√(d/T) + 1/m), with high probability (see the illustration in Figure 1). In particular, the number of iterations is linear in the input dimension d, even when f is a polynomial of any (fixed) degree p.
  – Based on a uniform convergence argument on the principal subspace, we prove in Theorem 5 that T iterations of SGD will produce a model with generalization error of O(√(kd/T)), with high probability. Remarkably, this rate is independent of the width m of the NN, even in the case k ≍ d where the target is an arbitrary function of the input, not necessarily low-dimensional.
  – Finally, we provide a compressibility result directly following from the low-dimensionality of the principal subspace. We prove that T iterations of SGD produce first-layer weights that are compressible to rank k with a risk deviation of O(√(d/T)), with high probability.

The rest of the paper is organized as follows. We discuss the notation and the related work in the remainder of this section. We describe the problem formulation and preliminaries in Section 2, and provide an analysis for the warm-up case of population gradient descent in Section 2.1. Our main result on SGD is presented in Section 3. We discuss three implications of our main theorem in Section 4, where we provide results on learnability, generalization gap, and compressibility in Sections 4.1, 4.2, and 4.3, respectively. We finally conclude with a brief discussion in Section 5.

Notation. For a loss function ℓ : R² → R, let ∂_iℓ and ∂²_{ij}ℓ denote its partial derivatives with respect to the ith and jth inputs for i, j ∈ {1, 2}. For quantities a and b, a ≲ b means a ≤ Cb for some absolute constant C > 0, and a ≍ b means both a ≳ b and a ≲ b. Finally, Unif(A) denotes the uniform distribution over a set A, and N(0, I_d) denotes the d-dimensional isotropic Gaussian distribution.

1.1. RELATED WORK

Training dynamics of NNs. Several works have demonstrated learnability in a special case of the teacher-student setting where the teacher model is similar to the student NN being trained (Zhong et al., 2017; Brutzkus & Globerson, 2017; Li & Yuan, 2017; Zhang et al., 2019; Zhou et al., 2021). This setting has also been studied through the lens of the loss landscape (Safran et al., 2021) and optimization over measures (Akiyama & Suzuki, 2021). We stress that our results work under misspecification and hold for generic teacher models that are not necessarily NNs with an architecture similar to the student's.

Two scaling regimes of analysis have seen a surge of recent interest. In the regime of lazy training (Chizat et al., 2019), the parameters hardly move from initialization and the NN does not learn useful features, behaving like a kernel method (Jacot et al., 2018; Du et al., 2019; Allen-Zhu et al., 2019; Arora et al., 2019; Oymak & Soltanolkotabi, 2020). However, many works have shown that deep learning is more powerful than kernel models (Yehudai & Shamir, 2019; Ghorbani et al., 2020; Geiger et al., 2020), establishing a clear separation between them; thus, several important characteristics of NNs cannot be captured with lazy training (Ghorbani et al., 2019), even though it might still perform better than feature learning in certain low-dimensional settings (Petrini et al., 2022). In the other scaling regime, gradient descent on infinitely wide NNs reduces to a Wasserstein gradient flow, known as the mean-field regime, where feature learning is possible (Chizat & Bach, 2018; Rotskoff & Vanden-Eijnden, 2018; Mei et al., 2019; Nitanda et al., 2022; Chizat, 2022). Closer to our results, the concurrent work of Hajjar & Chizat (2022) shows that low-dimensional targets induce low-dimensional dynamics on mean-field NNs. However, these results mostly hold for infinite or very wide NNs, and quantitative guarantees are difficult to obtain.
Our setting is different from both of these regimes, as we allow for NNs of arbitrary width without excessive overparameterization.

Feature learning with multiple-index teacher models. The task of learning a target that is an unknown low-dimensional function of the input is fundamental in statistics (Li & Duan, 1989). Several recent works in the learning theory literature have also focused on this problem, with the aim of demonstrating that NNs can learn useful feature representations, outperforming kernel methods (Bauer & Kohler, 2019; Ghorbani et al., 2020). In particular, Abbe et al. (2022) study the necessary and sufficient conditions for learning with sample complexity linear in d with inputs on the hypercube, in the mean-field limit. Closer to our setting are the recent works Ba et al. (2022); Damian et al. (2022); Barak et al. (2022), which demonstrate a clear separation between NNs and kernel methods by leveraging the effect of representation learning. However, their analysis considers a single (full) gradient step on the first-layer weights followed by training the second-layer parameters. In contrast, in our learnability result we train both layers with SGD, which induces essentially different learning dynamics.

Generalization bounds for SGD. A popular algorithm-dependent approach for studying generalization is through algorithmic stability (Bousquet & Elisseeff, 2002; Feldman & Vondrak, 2018; Bousquet et al., 2020), which has been used to study the generalization behavior of gradient-based methods in various settings (Hardt et al., 2016; Bassily et al., 2020; Farghly & Rebeschini, 2021; Kozachkov et al., 2022). Other approaches include studying the low-dimensional structure of the trajectory (Simsekli et al., 2020; Park et al., 2022) or the invariant measure of continuous-time approximations of SGD (Camuto et al., 2021), and employing information-theoretic tools (Neu et al., 2021). Among these works, Barsbey et al. (2021) show that SGD is able to learn compressible networks. However, they require large width for the mean-field approximation and assume that the SGD iterates converge to a heavy-tailed distribution, whereas we make neither assumption.

2. PRELIMINARIES: NEURAL NETWORKS AND THE PRINCIPAL SUBSPACE

For an input x ∈ R^d, we consider training a two-layer neural network (NN) with m neurons

ŷ(x; W, a, b) = Σ_{i=1}^{m} a_i σ(⟨w_i, x⟩ + b_i), (2.1)

where σ is the activation function, {w_i}_{1≤i≤m} are the first-layer weights collected in the rows of the matrix W ∈ R^{m×d}, b ∈ R^m is the bias, and a ∈ R^m is the second-layer weights. We assume x ∼ N(0, I_d) and the target is generated from a multiple-index (teacher) model given by

y = g(⟨u_1, x⟩, ..., ⟨u_k, x⟩; ϵ), (2.2)

for a weakly differentiable link function g : R^{k+1} → R and a noise ϵ. Throughout the paper, the noise ϵ is assumed to be independent of the input x, and our framework covers the special noiseless case where ϵ = 0. While our results remain valid regardless of how k and d compare, they are most insightful when k ≪ d; thus, we specifically consider this regime when interpreting the results. We also collect the teacher weights {u_i}_{1≤i≤k} in the rows of the matrix U ∈ R^{k×d} and write y = g(Ux; ϵ) for simplicity.

For a given loss function ℓ(ŷ, y), we consider the population and the empirical risks

R(W, a, b) := E[ℓ(ŷ(x; W, a, b), y)] and R̂(W, a, b) := (1/T) Σ_{t=0}^{T−1} ℓ(ŷ(x^{(t)}; W, a, b), y^{(t)}),

where the expectation is over the data distribution. Similarly, for some τ ≥ 1, the truncated loss is defined as ℓ_τ(ŷ, y) := ℓ(ŷ, y) ∧ τ with the corresponding risks R_τ and R̂_τ, both of which are used in Section 4 to obtain sharp high-probability statements. In the warm-up case, we consider the L₂-regularized population risk with a penalty parameter λ ≥ 0, defined as

R_λ(W, a, b) := R(W, a, b) + (λ/2)∥W∥_F². (2.3)

To minimize (2.3), we use stochastic gradient descent (SGD) over the first-layer weights, where we are interested in the convergence of iterates to the principal subspace defined by the teacher weights

S(U) := span(u_1, ..., u_k)^m = {CU : C ∈ R^{m×k}}.
Notice that the principal subspace satisfies S(U) ⊆ R^{m×d}, and its dimension is mk as opposed to the ambient dimension md, with any matrix in this subspace having rank at most k. For any vector v ∈ R^d, we let v_∥ denote the orthogonal projection of v onto span(u_1, ..., u_k) and v_⊥ := v − v_∥. Similarly, for a matrix W ∈ R^{m×d}, we define W_∥ and W_⊥ by applying the projection to each row. We make the following assumption on the data generating process.

Assumption 1 (Student-teacher setup). The student model is a two-layer NN (2.1) trained over the data set {(x^{(i)}, y^{(i)})}_{i≥1}, where the target values y^{(i)} are generated according to the teacher model (2.2) and the inputs satisfy x^{(i)} ~iid N(0, I_d). The link function g(·, ..., ·; ϵ) is weakly differentiable (see e.g. Evans (2010, Sec. 5.2.1) for the definition) for any fixed ϵ.

The Gaussian input is a rather standard assumption in the literature, especially in recent works that consider the student-teacher setup; see e.g. Safran et al. (2021); Zhou et al. (2021); Damian et al. (2022). The multiple-index teacher model (2.2) can encode a broad class of input-output relations through the non-linear link function, including a multi-layer fully-connected NN with arbitrary depth and width and weakly differentiable activations.

The smoothness properties of the activation σ play an important role in our analysis. As such, we consider two scenarios, with different requirements on the loss function. Commonly used activations such as sigmoid and tanh satisfy Assumption 2.A. For the ReLU activation in Assumption 2.B, we choose σ′(z) = 1(z ≥ 0) as its weak derivative. We highlight that Assumption 2.B is satisfied by common Lipschitz and convex loss functions such as the Huber loss

ℓ_H(ŷ − y) := ½(ŷ − y)² if |ŷ − y| ≤ 1, and |ŷ − y| − ½ if |ŷ − y| > 1, (2.4)

as well as the logistic loss ℓ_L(ŷ, y) := log(1 + e^{−ŷy}), up to appropriate scaling constants.
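To make the objects above concrete, here is a minimal numpy sketch of the setup; the tanh link g, the tanh student activation, and all constants below are illustrative choices, not the ones fixed by our assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, m = 50, 3, 32

# Teacher directions u_1, ..., u_k: rows of U are orthonormal.
U = np.linalg.qr(rng.standard_normal((d, k)))[0].T  # shape (k, d)

def teacher(x, eps=0.0):
    """Multiple-index target (2.2) with an illustrative link g."""
    return np.tanh(U @ x).sum() + eps

def student(x, W, a, b):
    """Two-layer network (2.1), here with sigma = tanh."""
    return a @ np.tanh(W @ x + b)

def split_rows(W):
    """Row-wise orthogonal projection onto the principal subspace span(u_1, ..., u_k)."""
    W_par = W @ U.T @ U        # W_parallel: each row projected onto the span of U's rows
    return W_par, W - W_par    # (W_parallel, W_perp)

# Initialization at the Assumption 3 scaling.
W = rng.standard_normal((m, d)) / np.sqrt(d)
a = rng.uniform(-1, 1, m) / m
b = rng.choice([-1.0, 1.0], m)
W_par, W_perp = split_rows(W)
```

Any matrix in S(U) has rank at most k, which is reflected in `W_par` having rank at most k regardless of the width m.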

2.1. WARM-UP: POPULATION GRADIENT DESCENT

In this section, we study the dynamics of population gradient descent (PGD) to motivate our investigation of the more practically relevant case of SGD. When initialized from W⁰, PGD with a fixed step size η updates the current iterate W^t according to the rule

W^{t+1} = W^t − η∇_W R_λ(W^t). (2.5)

We use the following initialization throughout the paper.

Assumption 3 (Initialization). For all 1 ≤ i ≤ m, 1 ≤ j ≤ d, we initialize the NN weights and biases with √d·W⁰_{ij} ~iid N(0, 1), m·a⁰_i ~iid Unif([−1, 1]), and b⁰_i ~iid Unif({−1, 1}).

While this initialization is standard in the mean-field regime, we only use it to simplify the exposition. Indeed, we can initialize W and a with any scheme that guarantees ∥W∥_F ≲ √m and ∥a∥_∞ ≲ m^{−1} with high probability. Further, the initialization of b mostly matters in the analysis of the ReLU activation.

Next, we show that the population gradient admits a certain decomposition which plays a central role in our analysis. For smooth activations, the result below is a remarkable consequence of Stein's lemma: it provides a certain alignment between the true statistical model (teacher) and the model being trained (student), which has a profound impact on the learning dynamics. We generalize this result to ReLU through a sequence of smooth approximations (see Appendix A.1 for details).

Lemma 1. Under Assumptions 1&2.A or 1&2.B, the gradient of the population risk can be written as

∇_W R_λ(W) = (H(W) + λI_m)W + D(W)U, (2.6)

for some H(W) ∈ R^{m×m} and D(W) ∈ R^{m×k} (with explicit forms provided in Appendix A.1).

Notice that the subset of critical points {W* : ∇R_λ(W*) = 0} for which H(W*) + λI_m is invertible belongs to the principal subspace, i.e. W* ∈ S(U). Further, if we initialize PGD (2.5) within the principal subspace, i.e. W⁰ ∈ S(U), the subsequent iterates for t > 0 remain in this subspace. In the statistics literature, the setting we consider is often termed model misspecification, i.e.
the teacher model generating the data and the student model being trained are different. Proposition 2 states a general result: as long as the target depends on certain directions, PGD will produce weights in their span, despite the possible mismatch between the two models. We highlight that classical results on dimension reduction, e.g. Li & Duan (1989); Li (1991), rely on a similar principle, which has also been used for designing optimization algorithms; see e.g. Erdogdu et al. (2019). Here, however, we are interested in the implications of this phenomenon in the setting of NNs trained with SGD. The following result, proved in Appendix A.2, demonstrates the algorithmic implications of Lemma 1 in the simple case of PGD, and shows that the iterates converge to the principal subspace.

Proposition 2. Consider running T PGD iterations (2.5) with an initialization satisfying Assumption 3 and a step size η > 0. For any γ > 0, define the rescaled parameters λ̄ := mλ and η̄ := mη, chosen as follows.
1. Smooth activation. Under Assumptions 1&2.A, let λ̄ ≥ 1 + γ + √(1 + 2γ + 2R(W⁰)) and η̄ ≲ (λ̄ + λ̄ς/γ + ς)^{−1}, where ς := E[∥U^⊤∇g(Ux; ϵ)∥₂].
2. ReLU activation. Under Assumptions 1&2.B, let λ̄ ≥ γ + 2√(2/(eπ)) and η̄ ≲ λ̄^{−1}.
Then, with probability at least 1 − e^{−Cmd} over the initialization, where C is an absolute constant, the iterates of PGD satisfy

∥W^T_⊥∥_F ≤ (1 − η̄γ)^T ∥W⁰_⊥∥_F. (2.7)

A few remarks are in order. First and most importantly, PGD iterates converge to the principal subspace as T → ∞, and this phenomenon is mainly due to the alignment provided by Lemma 1. Indeed, in the limit, (2.7) provides sparsity in the basis of the principal subspace, and it is widely known that L₂-regularization does not have a similar sparsity effect (unless λ → ∞), in contrast to its L₁ counterpart. Thus, the choice of λ in Proposition 2 will in general lead to non-trivial projections W^T_∥, as we will demonstrate in Section 4.1 with a learnability result. However, without L₂-regularization, i.e.
λ = 0, it is possible to converge to a critical point W* for which H(W*) + λI_m is not invertible; hence, the weights are likely to lie outside the principal subspace in this case (cf. Figures 1&2). It is also worth emphasizing that the penalty level used in the above proposition still allows for non-convexity, as we demonstrate with an example in Appendix D. We finally remark that, as evident from the proof, Proposition 2 remains valid even for unbounded smooth activations.
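To see where the decomposition (2.6) comes from in the smooth case, the following proof sketch applies Stein's lemma row by row; the exact forms of H(W) and D(W) are deferred to Appendix A.1, and we abbreviate σ′_i := σ′(⟨w_i, x⟩ + b_i):

```latex
% Stein's lemma: for x ~ N(0, I_d) and weakly differentiable F, E[F(x) x] = E[\nabla F(x)].
% Apply it to F_i(x) := \partial_1 \ell(\hat y, y)\, a_i \sigma'_i:
\begin{align*}
\nabla_{w_i} R(W)
  &= \mathbb{E}\left[ \partial_1 \ell(\hat y, y)\, a_i \sigma'_i \, x \right]
   = \mathbb{E}\left[ \nabla_x F_i(x) \right] \\
  &= \underbrace{\sum_{j=1}^{m} \mathbb{E}\left[ a_i a_j\, \partial^2_{11}\ell\, \sigma'_i \sigma'_j \right] w_j
     + \mathbb{E}\left[ a_i\, \partial_1 \ell\, \sigma''_i \right] w_i}_{\text{row } i \text{ of } H(W)W}
   \;+\; \underbrace{\mathbb{E}\left[ a_i\, \partial^2_{12}\ell\, \sigma'_i\, \nabla g(Ux;\epsilon) \right]^{\top} U}_{\text{row } i \text{ of } D(W)U}.
\end{align*}
```

Differentiating through ŷ produces the terms along the student rows w_j (collected in H(W)W), while differentiating through y = g(Ux; ϵ) produces the term along the rows of U (collected in D(W)U); adding the penalty gradient λW yields (2.6).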

3. CONVERGENCE OF STOCHASTIC GRADIENT DESCENT

We now consider stochastic gradient descent (SGD) in the online setting, where at each iteration t we have access to a new data point (x^{(t)}, y^{(t)}) drawn from the same distribution, independently of the previous samples. We update the first-layer weights W^t with a (possibly) time-varying step size η_t and weight decay, according to the update rule

W^{t+1} = (1 − η_t λ)W^t − η_t ∇_W ℓ(ŷ(x^{(t)}; W^t, a, b), y^{(t)}). (3.1)

The above algorithm can be used to minimize the population risk (2.3) in practice (Polyak & Juditsky, 1992), even in certain non-convex landscapes (Yu et al., 2021). As we demonstrate next, SGD preserves several important characteristics of its population counterpart, PGD.

Theorem 3. Consider running T SGD iterations (3.1) over samples satisfying Assumption 1, with an initialization satisfying Assumption 3, and using the following step size schedules, where λ̄ := mλ.
1. Constant step size. Under Assumption 2.A, choose the constant step size η_t = 2m log(T)/(γT). For any γ > 0, let λ̄ ≥ 1 + γ + √(1 + 2γ + 2R(W⁰) + C log(T)²d/(γ²T)), where C is an absolute constant, and suppose T ≳ (1/δ)^{C/d} ∨ (λ̄/γ) log(λ̄/γ). Then, for a penalty λ = λ̄/m, with probability at least 1 − δ,

∥W^T_⊥∥_F/√m ≲ √(log(T)(d + log(1/δ))/(γ²T)).

2. Decreasing step size. Under Assumption 2.B, with a suitably decreasing step size schedule η_t ≍ m/(γt) (cf. the schedule in Theorem 4), with probability at least 1 − δ,

∥W^T_⊥∥_F/√m ≲ √((d + log(1/δ))/(γ²T)),

whenever m ≳ log(1/δ) and T ≳ λ̄²(d + log(1/δ)).

Remark. Our results are most insightful when γ ≍ 1, with respective rates of Õ(√(d/T)) and O(√(d/T)) in the constant and decreasing step size settings; this scaling allows efficient learning of certain targets (see Section 4.1). Indeed, choosing a large γ may significantly restrict the learnability properties and will result in underfitting. However, if we ignore the underfitting issue, one can get the fastest convergence rate by choosing γ ≍ √(T(d + log(1/δ))), which yields ∥W^T_⊥∥_F/√m ≲ 1/T. We also note that the convergence rate is stated in the normalized distance to the principal subspace, i.e.
∥W^T_⊥∥_F/√m ≤ ε, as this is sufficient to guarantee that the risk of W^T and that of its orthogonal projection W^T_∥ are within O(ε) of each other.

The above result states that, with a number of samples linear in d, SGD produces iterates that are in close proximity to the principal subspace; thus, it efficiently learns (approximately) low-dimensional weights, exhibiting an implicit bias towards low-complexity models. While prior works have established that NNs adapt to low-dimensional manifold structures in the data in some contexts (Chen et al., 2019; Buchanan et al., 2021; Wang et al., 2021), our result has a different nature. More specifically, the interplay between two forces is in effect here. The most important one is the linear relationship between the first-layer weights and the input in both the student and the teacher models, together with the input distribution: the alignment described in Lemma 1 yields sparsified weights in a basis defined by the teacher network, effectively reducing the dimension from d to k. The second force is the explicit L₂-regularization. We emphasize that L₂-regularization does not play the main role in this sparsification; even though it may provide shrinkage towards zero, an L₂ penalty will in general produce non-sparse solutions. However, it is still required to ensure that SGD avoids critical points outside of the principal subspace. Although Theorem 3 has no implications for the convergence behavior of W_∥, in the next section we show that the implied low-dimensional structure is sufficient to provide guarantees on the generalization error, learnability, and compressibility of SGD. The proof of this theorem is provided in Appendix B, and is based on a recursion on the moment generating function of ∥W^t_⊥∥_F. We note that, as in Proposition 2, the regularization in Theorem 3 does not imply (strong) convexity in general, which we demonstrate with a non-convex example in Appendix D.
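The update (3.1) and the quantity tracked by Theorem 3 can be simulated in a few lines. The numpy sketch below uses the squared loss, a tanh student, and untuned illustrative hyper-parameters; note that this crude constant weight decay also shrinks W_∥, whereas the theorem's carefully chosen λ is meant to preserve a useful parallel component:

```python
import numpy as np

rng = np.random.default_rng(1)
d, k, m, T = 30, 2, 16, 4000
U = np.linalg.qr(rng.standard_normal((d, k)))[0].T  # teacher directions (orthonormal rows)
a = rng.uniform(-1, 1, m) / m                       # second layer fixed during this phase
b = rng.choice([-1.0, 1.0], m)
W = rng.standard_normal((m, d)) / np.sqrt(d)

lam, eta = 0.5, 0.1                                 # illustrative weight decay / step size

def perp_norm(W):
    # Frobenius distance to the principal subspace, i.e. ||W_perp||_F.
    return np.linalg.norm(W - W @ U.T @ U)

start = perp_norm(W)
for t in range(T):
    x = rng.standard_normal(d)                      # fresh sample each step (online SGD)
    y = np.tanh(U @ x).sum()                        # noiseless multi-index target
    s = np.tanh(W @ x + b)
    # Gradient of 0.5 * (yhat - y)^2 with respect to W: rows (yhat - y) * a_i * sigma'(pre_i) * x.
    grad = np.outer((a @ s - y) * a * (1.0 - s**2), x)
    W = (1.0 - eta * lam) * W - eta * grad          # update (3.1)
end = perp_norm(W)
```

On this toy instance, ∥W^T_⊥∥_F ends well below its initial value; the dependence on d and T matches the theorem only qualitatively, since none of the constants here are the theorem's.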

4.1. LEARNING SINGLE-INDEX TARGETS

An essential characteristic of NNs is their ability to learn useful representations, which allows them to adapt to the underlying misspecified statistical model. Although this fundamental property has been the guiding principle in empirical studies, it was mathematically proven only recently for gradient-based training (Abbe et al., 2022; Ba et al., 2022; Barak et al., 2022; Damian et al., 2022; Frei et al., 2022); see also the survey of prior works in Malach et al. (2021). Our results in the previous section are in the same spirit, establishing the convergence of SGD to the principal subspace, which is a span of useful directions associated with the target function being learned. As such, we leverage the learned low-dimensional representations to demonstrate that SGD is capable of learning a target function of the form y = f(⟨u, x⟩) + ϵ with a number of samples linear in d (up to logarithmic factors). For simplicity, we work with the Huber loss below; however, we can accommodate any Lipschitz and convex loss at the expense of a more detailed analysis.

Algorithm 1 Training a two-layer ReLU network with SGD.
Input: a⁰, b⁰ ∈ R^m, W⁰ ∈ R^{m×d}, {(x^{(t)}, y^{(t)})}_{0≤t≤T−1}, (η_t)_{t≥0}, (η′_t)_{t≥0}, λ, λ′, ∆.
1: for t = 0, ..., T − 1 do
2:   W^{t+1} = (1 − η_t λ)W^t − η_t ∇_W ℓ(ŷ(x^{(t)}; W^t, a⁰, b⁰), y^{(t)}).
3: end for
4: Let b_j ~iid Unif(−∆, ∆) for 1 ≤ j ≤ m.
5: for t = 0, ..., T′ − 1 do
6:   Sample i_t ∼ Unif{0, ..., T − 1}.
7:   a^{t+1} = (1 − η′_t λ′)a^t − η′_t ∇_a ℓ(ŷ(x^{(i_t)}; W^T, a^t, b), y^{(i_t)}).
8: end for
9: return (W^T, a^{T′}, b).

In the sequel, we use Algorithm 1 and train the first layer of the NN with online SGD using T data samples. Then, we randomly choose the biases and run T′ SGD iterations on the second layer, reusing the same data samples used to train the first layer. Thus, the overall sample complexity is T, whereas the total number of SGD iterations performed is T + T′. We highlight that the recent works Ba et al.
(2022); Barak et al. (2022); Damian et al. (2022) perform only one gradient step on the first-layer weights, whereas in Algorithm 1 we train the entire NN with SGD.

Theorem 4. Suppose that the data come from a single-index model y = f(⟨u, x⟩) + ϵ with a monotone differentiable f and ν-sub-Gaussian noise ϵ, and that Assumptions 1&2.B hold. Further, let ∥u∥₂ = 1, |f(0)| < 1, and consider the Huber loss (2.4) for simplicity. Consider running Algorithm 1 with the initialization 0 < a⁰_j = a ≲ 1/m, 0 < b⁰_j = b ≲ 1, and w⁰_j = w⁰ ∼ N(0, (1/d)I_d) for all j, with the hyper-parameters λ = λ̄/m = γ/m + 2ab√(2/(eπ)) for any γ ≍ 1, η_t = m(2(t* + t) + 1)/(γ(t* + t + 1)²) with t* ≍ γ^{−1}, η′_t = (2t + 1)/(λ′(t + 1)²), and ∆ ≍ log(T/δ). Then, for T ≳ (d + log(1/δ)) ∨ ((λ̄d/γ) log(m/δ)), some λ′ > 0 (see (C.8)), and sufficiently large T′ (see (C.9)), with probability at least 1 − δ,

R_τ(W^T, a^{T′}, b) − E[ℓ_H(ϵ)] ≲ ∆*² log(T/δ)/m + √((d + log(1/δ))/T) + ν√(log(1/δ)/T), (4.1)

where ∆*, defined in (C.7), is poly(log(T/δ)) when f′′ has at most polynomial growth.

This result implies that a ReLU NN trained with SGD can learn any monotone polynomial with a sample complexity linear in the input dimension d, up to logarithmic factors. Indeed, this is consistent with the work of Ben Arous et al. (2021); they establish a sharp sample complexity of Õ(d^{1∨(I−2)}) for learning a target with online SGD using the same activation f in the student network, where I is the information exponent (I = 1 in the above case due to the monotonicity of f). While their setting assumes the link function f is known, we highlight that it covers I ≥ 1, whereas Theorem 4 is a proof of concept demonstrating the learnability implications of convergence to the principal subspace, even when f is unknown. Building on their work, the concurrent work of Bietti et al. (2022) also proves learnability for unknown single-index targets with I ≥ 1, albeit with a sample complexity of d² for I = 1 when training ReLU students.
Nonparametric regression with NNs has also been considered within the NTK framework (Hu et al., 2021; Kuzborskij & Szepesvári, 2021), but our result holds beyond this regime, as a width m that is polylogarithmic in T/δ suffices, in contrast to the poly(T) requirement of the NTK regime. Additionally, learning any degree-p polynomial using rotationally invariant kernels requires d^{Ω(p)} samples for a variety of input distributions, including the isotropic Gaussian (Donhauser et al., 2021); thus, our result shows that SGD is able to efficiently learn target functions that kernel methods cannot. For polynomial targets, Damian et al. (2022) consider training the first-layer weights with one gradient descent step with a carefully chosen weight decay, and obtain a sample complexity of d² to learn any degree-p polynomial depending on a few directions. Finally, Chen & Meka (2020) propose a method that can train NNs to learn such polynomials with sample complexity linear in d; yet, their algorithm is not a simple variant of SGD and requires a non-trivial warm-start initialization.

The proof of Theorem 4, detailed in Appendix C.2, relies on the fact that after training the first layer, the weights align with the true direction u. Then, similarly to Damian et al. (2022), we construct an optimal a* with |a*_i| ≍ m^{−1} for every i achieving a small empirical risk, and employ the generalization bound of Theorem 5 to obtain a rate estimate on the population risk.
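A runnable sketch of Algorithm 1 in the setting of Theorem 4 follows. The cubic link f(z) = z + 0.1z³, the simplified step size schedules, and the 1/√m feature rescaling in the second phase (absorbed into a for numerical convenience) are all hypothetical choices, not the tuned hyper-parameters of the theorem:

```python
import numpy as np

rng = np.random.default_rng(2)
d, m, T, Tp = 20, 32, 3000, 20000
u = rng.standard_normal(d)
u /= np.linalg.norm(u)

def f(z):                                   # monotone link with polynomial growth (illustrative)
    return z + 0.1 * z**3

X = rng.standard_normal((T, d))
y = f(X @ u) + 0.01 * rng.standard_normal(T)

def huber(r):
    r = np.abs(r)
    return np.where(r <= 1.0, 0.5 * r**2, r - 0.5)

# Phase 1 (lines 1-3): online SGD on W; a, b frozen at identical positive values,
# and all rows of W start from the same w0, as in Theorem 4's initialization.
a0, b0, lam = 1.0 / m, 0.5, 1.0 / m
W = np.tile(rng.standard_normal(d) / np.sqrt(d), (m, 1))
for t in range(T):
    pre = W @ X[t] + b0
    # Huber loss derivative in yhat: clip(yhat - y, -1, 1).
    dl = np.clip(a0 * np.maximum(pre, 0.0).sum() - y[t], -1.0, 1.0)
    grad = np.outer(dl * a0 * (pre > 0.0), X[t])
    eta = m / (t + 100.0)                   # decreasing schedule (illustrative)
    W = (1.0 - eta * lam) * W - eta * grad

# Phase 2 (lines 4-8): resample biases, then SGD on a over reused samples.
Delta, lam2 = 2.0, 0.05
b = rng.uniform(-Delta, Delta, m)
feat = np.maximum(X @ W.T + b, 0.0) / np.sqrt(m)   # rescaled ReLU features
a_vec = np.full(m, a0)
risk0 = huber(feat @ a_vec - y).mean()
for t in range(Tp):
    i = rng.integers(T)                     # i_t ~ Unif{0, ..., T-1}
    dl = np.clip(feat[i] @ a_vec - y[i], -1.0, 1.0)
    eta2 = 1.0 / (lam2 * (t + 20.0))
    a_vec = (1.0 - eta2 * lam2) * a_vec - eta2 * dl * feat[i]
riskT = huber(feat @ a_vec - y).mean()

cos = abs(W[0] @ u) / np.linalg.norm(W[0])  # alignment of the (shared) first-layer row with u
```

Because all rows of W start identical and receive identical updates, the first layer stays rank one and serves purely to recover the principal direction u; the randomly resampled biases then give the second layer a diverse set of one-dimensional ReLU features to fit f with.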

4.2. GENERALIZATION GAP

For a given learning algorithm, the gap between its empirical and population risks is termed the generalization gap (not to be confused with the excess risk), and establishing convergence estimates for this quantity is a fundamental problem in learning theory. Classical results rely on uniform convergence over the feasible domain containing the weights; thus, they apply to any learning algorithm, including SGD (Neyshabur et al., 2019). However, these bounds often diverge with the width of the NN, yielding vacuous estimates in the overparameterized regime (Zhang et al., 2021). To alleviate this, recent works considered establishing estimates for a specific learning algorithm; see the related work in Section 1.1.

Here, we are interested in deriving an estimate for the generalization gap over the SGD-trained first-layer weights, which holds uniformly over the second-layer weights and biases. More specifically, after T iterations of SGD (3.1) initialized with (W⁰, a⁰, b⁰), we study the quantity

E(W^T) := sup_{(a,b)∈S} R_τ(W^T, a, b) − R̂_τ(W^T, a, b) with S := {(a, b) ∈ R^m × R^m : ∥a∥₂ ≤ r_a/√m, ∥b∥_∞ ≤ r_b},

where the scaling ensures ŷ = O(1) when ∥w_j∥₂ ≍ 1, which is the setting considered in Theorem 4. We state the following bound on E(W^T); the proof, provided in Appendix C.1, is based on a covering argument over the smaller-dimensional principal subspace implied by Theorem 3.

Theorem 5. Consider the setting of Theorem 3 with either decreasing or constant step size. For any δ > 0, if T ≳ (d + log(1/δ)) ∨ ((κλ̄d/γ) log(m/δ)), then with probability at least 1 − δ,

E(W^T) ≲ τ r_a (√(κ(d + log(1/δ))/(γ²T)) + (r_b + λ̄^{−1})√(dk/T)), (4.2)

where we let κ = 1 for decreasing step size and κ = log(T) for constant step size.

The above bound is independent of the width m of the NN, and only grows with the dimension d of the input space and the dimension k of the principal subspace; thus, it produces non-vacuous estimates in the overparametrized regime where m is large.
Further, the bound is stable in the number of SGD iterations T , that is, it converges to zero as T → ∞. We remark that generalization bounds for SGD that rely on algorithmic stability are optimal for strongly convex objectives (Hardt et al., 2016 ); yet, they lead to unstable diverging bounds in non-convex settings as T → ∞. As such, these techniques often require early stopping, which is clearly not needed in our result.

4.3. COMPRESSIBILITY

NNs exhibit compressibility in empirical studies, which is known to be associated with better generalization. Under the assumption that the trained network is compressible, several works established generalization bounds; see e.g. Arora et al. (2018); Suzuki et al. (2020). However, a theoretical justification of this assumption, specifically for a NN trained with SGD, was missing. Theorem 3 provides a concrete answer to this question: since the SGD iterate W_T converges to a low-rank matrix, the resulting weights are compressible. More precisely, let π_k : R^{m×d} → R^{m×d} be the low-rank approximation operator defined by π_k(W) := arg min_{W′ : rank(W′)≤k} ∥W − W′∥_F. As W_T^∥ lies in the principal subspace, it has rank at most k. Thus, we can write ∥W_T − π_k(W_T)∥_F ≤ ∥W_T − W_T^∥∥_F = ∥W_T^⊥∥_F, and the following is an immediate consequence of the bound of Theorem 3, namely ∥W_T^⊥∥_F/√m ≲ √(d/T), combined with the Lipschitzness of R_τ(W), which we prove in Lemma 17.

Proposition 6. Consider the setting of Theorem 3 with either decreasing or constant step size. Then, with probability at least 1 − δ,

R_τ(π_k(W_T), a, b) − R_τ(W_T, a, b) ≲ (τκ/γ)√((d + log(1/δ))/T),   (4.3)

where we let κ = 1 for decreasing step size and κ = log(T) for constant step size.

This result demonstrates that the low-dimensionality of the trained NN yields a rate of O(√(d/T)) for the gap between its population risk and that of its compressed version. The bound is independent of both the width m and the dimension k of the principal subspace. Finally, we highlight that Suzuki et al. (2020) provide generalization bounds by assuming a near low-rank structure for the weight matrix, namely that its j-th singular value decays proportionally to j^{−α} for some α > 1/2; this condition imposes a structure quite different from the one we proved in Theorem 3.
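The operator π_k is computed by truncating the singular value decomposition, which is optimal in Frobenius norm by the Eckart–Young theorem. A minimal NumPy sketch, with arbitrary shapes and a synthetic near-rank-k matrix standing in for W_T:

```python
import numpy as np

def pi_k(W, k):
    """Best rank-k approximation of W in Frobenius norm (Eckart-Young),
    computed by truncating the SVD."""
    U, s, Vt = np.linalg.svd(W, full_matrices=False)
    return (U[:, :k] * s[:k]) @ Vt[:k, :]

rng = np.random.default_rng(1)
m, d, k = 64, 32, 3
# A near-rank-k weight matrix: a dominant rank-k part plus a small residual,
# mimicking W_T = W_T_par + W_T_perp with small ||W_T_perp||_F.
W_par = rng.standard_normal((m, k)) @ rng.standard_normal((k, d))
W = W_par + 1e-2 * rng.standard_normal((m, d))

Wk = pi_k(W, k)
err = np.linalg.norm(W - Wk)                    # ||W - pi_k(W)||_F
assert np.linalg.matrix_rank(Wk) <= k
# Optimality: no rank-k candidate (here, W_par itself) does better than the truncation.
assert err <= np.linalg.norm(W - W_par) + 1e-9
```

Since Lipschitzness of the truncated risk transfers the small weight-space error to a small risk gap, this is exactly the mechanism behind Proposition 6.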
5. CONCLUSION

We studied the dynamics of SGD with weight decay on two-layer NNs, and proved that under a multiple-index teacher model, the first-layer weights converge to the principal subspace, i.e., the span of the weights of the teacher. This phenomenon is of particular interest when the target depends on the input only along a few important directions. In this setting, we proved novel generalization bounds for SGD via uniform convergence on the low-dimensional principal subspace. Further, we proved that two-layer ReLU networks can learn a single-index target with a monotone link of at most polynomial growth using online SGD, with a number of samples almost linear in d. Thus, as an implication of low-dimensionality, we established a separation between kernel methods and trained NNs, where the former suffer from the curse of dimensionality.

Two principal forces are responsible for the emergence of the low-dimensional structure. The main one is the linear interaction between the Gaussian input and the first-layer weights in both the student and teacher models. The secondary one is the weight decay, which allows SGD to avoid critical points outside of the principal subspace; Figure 2 shows the convergence behavior in the absence of weight decay. Understanding more precisely the range of λ that implies convergence to the principal subspace, as well as investigating the possibility of learning multiple-index models using this convergence, are left as important directions for future work.

A PROOFS OF SECTION 2.1

Additional Notation. We employ the following notation throughout the appendix. For vectors v and u, we use ⟨v, u⟩ and v • u to denote their Euclidean inner product and element-wise product, and ∥v∥_p and diag(v) to denote the L_p-norm and the diagonal matrix whose diagonal entries are given by v. For matrices V and W, we use ⟨V, W⟩_F, ∥V∥_F, and ∥V∥₂ to denote the Frobenius inner product, the Frobenius norm, and the operator norm, respectively. For an activation function σ : R → R, σ′ and σ″ denote its first and second (weak) derivatives, applied element-wise to vector inputs. We frequently write ∇ℓ for ∇_W ℓ when clear from the context, and use the shorthand σ_{a,b}(Wx) for a • σ(Wx + b), and similarly for σ′_{a,b}(Wx) and σ″_{a,b}(Wx). We use vec(A) ∈ R^{mn} to denote the vectorized representation of a matrix A ∈ R^{m×n}, and A ⊗ B for the Kronecker product of two matrices A ∈ R^{m×n} and B ∈ R^{p×q}; recall that the Kronecker product is an mp × nq block matrix comprised of m × n blocks of shape p × q, where block (i, j) is given by A_{ij}B. In the appendix, we prove the statements of the main text in a more general formulation.
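The block structure of the Kronecker product defined above, together with the quadratic-form identity it induces (of the kind used later when bounding Hessians), can be checked numerically; NumPy's row-major ravel plays the role of vec here, and all shapes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(2)
m, n, p, q = 3, 4, 2, 5
A = rng.standard_normal((m, n))
B = rng.standard_normal((p, q))

# The Kronecker product is an (m*p) x (n*q) block matrix whose (i, j) block is A[i, j] * B.
K = np.kron(A, B)
for i in range(m):
    for j in range(n):
        assert np.allclose(K[i * p:(i + 1) * p, j * q:(j + 1) * q], A[i, j] * B)

# Quadratic form of the kind appearing in the Hessian bounds: with row-major
# flattening as vec, vec(V)^T (s s^T (x) x x^T) vec(V) = <s, V x>^2.
mm, dd = 3, 4
V = rng.standard_normal((mm, dd))
s = rng.standard_normal(mm)
x = rng.standard_normal(dd)
quad = V.ravel() @ np.kron(np.outer(s, s), np.outer(x, x)) @ V.ravel()
assert np.isclose(quad, (s @ V @ x) ** 2)
```

This identity is what allows the md × md Hessian quadratic forms to be controlled by scalar second moments of ⟨s, Vx⟩.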
In particular, for smooth activations we assume sup|σ′| ≤ β₁ and sup|σ″| ≤ β₂ for some β₁, β₂ ∈ R₊, and we write sup|σ| ≤ β₀ with β₀ ∈ (0, ∞]. For the bias vector b ∈ R^m, we consider the following general setting: b_j iid∼ D_b with |b_j| ≥ b_* for some b_* > 0. This setting clearly covers the case b_j = ±1 from the initialization of Assumption 3. Throughout the appendix, C denotes a generic positive absolute constant (e.g. 10), whose value may change from line to line.

A.1 PROOF OF LEMMA 1

In what follows, ∇^⊤ denotes the Jacobian for vector-valued functions and ∇ its transpose, which coincides with the gradient for real-valued functions. When σ is twice differentiable (Assumption 2.A), standard matrix calculus yields

∇_W R(W) (a)= E[∇_W ℓ(ŷ(x; W), y)]
= E[∂₁ℓ(ŷ(x; W), y) ∇_W ŷ(x; W)]
= E[∂₁ℓ(ŷ(x; W), y) σ′_{a,b}(Wx) x^⊤]
= E[E[∂₁ℓ(ŷ(x; W), y) σ′_{a,b}(Wx) x^⊤ | ϵ]]
(b)= E[E[∇_x^⊤{∂₁ℓ(ŷ(x; W), g_ϵ(Ux)) σ′_{a,b}(Wx)} | ϵ]]
= E[∂₁²ℓ σ′_{a,b}(Wx) ∇_x^⊤ŷ(x; W) + ∂₁ℓ ∇_x^⊤σ′_{a,b}(Wx) + ∂²₁₂ℓ σ′_{a,b}(Wx) ∇_x^⊤g_ϵ(Ux)]
= E[∂₁²ℓ σ′_{a,b}(Wx)σ′_{a,b}(Wx)^⊤ + ∂₁ℓ diag(σ″_{a,b}(Wx))] W + E[∂²₁₂ℓ σ′_{a,b}(Wx) ∇g_ϵ(Ux)^⊤] U
= H(W)W + D(W)U,   (A.1)

where (a) follows from the dominated convergence theorem, (b) follows from Stein's lemma, and ∇g_ϵ is the weak derivative of g_ϵ w.r.t. its inputs. Combining the above calculations with the gradient of the regularization term, where

D(W) = E[∂²₁₂ℓ(ŷ, y)(a • σ′(Wx + b)) ∇g_ϵ^⊤],   (A.2)
H(W) = E[(a • σ′(Wx + b))(a • σ′(Wx + b))^⊤] + E[(ŷ − y) diag(a • σ″(Wx + b))],   (A.3)

the proof is complete for smooth activations.

For the ReLU activation and ℓ satisfying Assumption 2.B, we introduce the smooth approximation σ_ι(z) = (1/ι) log(1 + e^{ιz}), ι > 0. Then,

H_ι(W) = E[∂₁²ℓ (a • σ′_ι(Wx + b))(a • σ′_ι(Wx + b))^⊤] + E[∂₁ℓ diag(a • σ″_ι(Wx + b))] ⪰ −∥a∥_∞ max_{1≤j≤m} E[|σ″_ι(⟨w_j, x⟩ + b_j)|] I_m.

As σ″_ι ≥ 0, the critical step is to show that lim_{ι→∞} E[σ″_ι(⟨w, x⟩ + b)] < ∞, uniformly over w. Let z = ⟨w, x⟩ + b, so that z ∼ N(b, ∥w∥₂²). Using σ″_ι(z) ≤ ι e^{−ιz} for z ≥ 0 and completing the square,

∫₀^∞ σ″_ι(z) e^{−(z−b)²/(2∥w∥₂²)}/(√(2π)∥w∥₂) dz
≤ ι ∫₀^∞ e^{−ιz − (z−b)²/(2∥w∥₂²)}/(√(2π)∥w∥₂) dz
= ι e^{−b²/(2∥w∥₂²) + (ι∥w∥₂ − b/∥w∥₂)²/2} ∫₀^∞ e^{−(z/∥w∥₂ + ι∥w∥₂ − b/∥w∥₂)²/2}/(√(2π)∥w∥₂) dz
= ι e^{−b²/(2∥w∥₂²) + (ι∥w∥₂ − b/∥w∥₂)²/2} (1 − Φ(ι∥w∥₂ − b/∥w∥₂))
(a)≤ ι e^{−b²/(2∥w∥₂²)}/(√(2π)∥w∥₂(ι − b/∥w∥₂²))
(b)≤ √(2/π) e^{−b²/(2∥w∥₂²)}/∥w∥₂
(c)≤ (1/|b|)√(2/(eπ)),

where (a) follows from the Gaussian tail bound 1 − Φ(x) ≤ e^{−x²/2}/(√(2π)x), with Φ the standard Gaussian CDF; (b) holds for large enough ι; and (c) follows by taking the supremum over ∥w∥₂. An analogous computation on (−∞, 0] gives the same bound, so E[σ″_ι(⟨w_j, x⟩ + b_j)] ≤ (2/|b_j|)√(2/(eπ)), and consequently

−(2∥a∥_∞/b_*)√(2/(eπ)) I_m ⪯ H_ι(W) ⪯ (∥a∥₂² + (2∥a∥_∞/b_*)√(2/(eπ))) I_m,

where b_* = min_{1≤j≤m} |b_j|. Moreover, since σ′_ι(Wx + b) converges almost surely (i.e., except on the null event that ⟨w_j, x⟩ + b_j = 0 for some j) to σ′(Wx + b), the dominated convergence theorem gives

∇R(W) = lim_{ι→∞} H_ι(W)W + lim_{ι→∞} D_ι(W)U.

The dominated convergence theorem also yields D_ι(W) → D(W) as ι → ∞, with D(W) given in (A.2). Letting H(W) = lim_{ι→∞} H_ι(W), we observe that

−(2∥a∥_∞/b_*)√(2/(eπ)) I_m ⪯ H(W) ⪯ (∥a∥₂² + (2∥a∥_∞/b_*)√(2/(eπ))) I_m.   (A.4)

This finishes the proof of Lemma 1. In the case of smooth activations (Assumption 2.A), we have the following bounds.

Lemma 7. Let R(W) := E[ℓ(ŷ(x; W, a, b), y)] be the unregularized population risk. Under Assumptions 1 & 2.A, we have

−β₂∥a∥_∞√(2R(W)) I_m ⪯ H(W) ⪯ (β₁²∥a∥₂² + β₂∥a∥_∞√(2R(W))) I_m.   (A.5)

Proof. Assumption 2.A requires ℓ(ŷ, y) = ½(ŷ − y)². Hence, by definition of H,

H(W) = E[σ′_{a,b}(Wx)σ′_{a,b}(Wx)^⊤] + E[(ŷ(x; W) − y) diag(σ″_{a,b}(Wx))].

The first term is positive semi-definite and easily bounded: for an arbitrary vector v ∈ R^m,

0 ≤ v^⊤ E[σ′_{a,b}(Wx)σ′_{a,b}(Wx)^⊤] v ≤ E[∥σ′_{a,b}(Wx)∥₂²] ∥v∥₂² ≤ β₁²∥a∥₂² ∥v∥₂².

For the second term, we have

−β₂∥a∥_∞ E[|ŷ − y|] I_m ⪯ E[(ŷ(x; W) − y) diag(σ″_{a,b}(Wx))] ⪯ β₂∥a∥_∞ E[|ŷ − y|] I_m,

and E[|ŷ(x; W) − y|] ≤ √(2R(W)) by Jensen's inequality.

A.2 PROOF OF PROPOSITION 2

In order to present the proof of Proposition 2, we need uniform control over the eigenspectrum of H(W). In the case of ReLU (Assumption 2.B), this follows from (A.4).
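The smoothing construction can be sanity-checked numerically. The following sketch (with arbitrary choices of ι, b, and ∥w∥₂) verifies that σ_ι(z) = log(1 + e^{ιz})/ι approaches ReLU, and that the Gaussian expectation of σ″_ι stays below the (2/|b|)√(2/(eπ)) bound derived above, via a simple quadrature.

```python
import numpy as np

iota = 50.0
b, w_norm = 1.0, 0.7     # illustrative bias and ||w||_2

# Softplus smoothing: sigma_iota(3) should already be very close to relu(3) = 3.
sp = np.logaddexp(0.0, iota * 3.0) / iota

# sigma_iota''(z) = iota * s(iota z)(1 - s(iota z)) with s = sigmoid,
# computed in log-space for numerical stability.
z = np.linspace(-20.0, 20.0, 400001)
log_sig2 = np.log(iota) + iota * z - 2.0 * np.logaddexp(0.0, iota * z)
sig2 = np.exp(log_sig2)

# E[sigma_iota''(<w, x> + b)] for x ~ N(0, I): then z ~ N(b, ||w||^2).
dens = np.exp(-(z - b) ** 2 / (2 * w_norm ** 2)) / (np.sqrt(2 * np.pi) * w_norm)
dz = z[1] - z[0]
expected = np.sum(sig2 * dens) * dz   # trapezoid-style quadrature on a fine grid

bound = (2.0 / abs(b)) * np.sqrt(2.0 / (np.e * np.pi))
assert abs(sp - 3.0) < 1e-6
assert expected <= bound
```

As ι grows, σ″_ι concentrates near z = 0, so the expectation approaches the Gaussian density of ⟨w, x⟩ + b at zero, which the bound dominates uniformly over ∥w∥₂.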
For smooth activations (Assumption 2.A), we need to establish the boundedness of R(W_t) along the trajectory. The first step towards this goal is an estimate of R_λ(W_{t+1}) − R_λ(W_t), which depends on the local smoothness of R_λ(W). We denote by ∇²R_λ(W) the full Hessian of the risk, an md × md matrix comprised of d × d blocks (∇²_{w_i,w_j}R_λ(W))_{1≤i,j≤m}, where (∇²_{w_i,w_j}R_λ(W))_{pq} = ∂²R_λ(W)/(∂(w_i)_p ∂(w_j)_q).

Lemma 8. Let R(W) := E[ℓ(ŷ(x; W, a, b), y)] be the unregularized population risk. Under Assumptions 1 & 2.A, the eigenspectrum of the Hessian satisfies

(λ − β₂∥a∥_∞√(6R(W))) I_{md} ⪯ ∇²R_λ(W) ⪯ (λ + β₁²∥a∥₂² + β₂∥a∥_∞√(6R(W))) I_{md}.   (A.6)

Proof. By the chain rule,

∇²_{w_i,w_j}R(W) = E[{a_i a_j σ′(⟨w_i, x⟩ + b_i)σ′(⟨w_j, x⟩ + b_j) + (ŷ(x; W) − y) δ_{ij} a_i σ″(⟨w_i, x⟩ + b_i)} xx^⊤],

where δ_{ij} is the Kronecker delta. As a result, in matrix form, the Hessian reads

∇²R(W) = E[σ′_{a,b}(Wx)σ′_{a,b}(Wx)^⊤ ⊗ xx^⊤] + E[(ŷ(x; W) − y) diag(σ″_{a,b}(Wx)) ⊗ xx^⊤].

The first term is a positive semi-definite matrix with bounded spectral norm; indeed, for any V ∈ R^{m×d},

0 ≤ vec(V)^⊤ E[σ′_{a,b}(Wx)σ′_{a,b}(Wx)^⊤ ⊗ xx^⊤] vec(V) = E[⟨σ′_{a,b}(Wx)x^⊤, V⟩²_F] = E[⟨σ′_{a,b}(Wx), Vx⟩²] ≤ E[∥σ′_{a,b}(Wx)∥₂² ∥Vx∥₂²] ≤ β₁²∥a∥₂² ∥V∥²_F.

The second term is bounded as follows: for all 1 ≤ j ≤ m and any v ∈ R^d,

|v^⊤ E[(ŷ(x; W) − y) a_j σ″(⟨w_j, x⟩ + b_j) xx^⊤] v| ≤ β₂∥a∥_∞ E[(ŷ(x; W) − y)²]^{1/2} E[(x^⊤v)⁴]^{1/2} = β₂∥a∥_∞ √(2R(W)) √(3∥v∥₂⁴),

which completes the proof.

Lemma 9. In the same setting as the previous lemma, for any W, W′ ∈ R^{m×d} we have

R_λ(W′) ≤ R_λ(W) + ⟨∇R_λ(W), W′ − W⟩_F + ((λ + β₁²∥a∥₂² + 2β₂∥a∥_∞√(3R(W)))/2) ∥W′ − W∥²_F + ((√6/2) β₁β₂∥a∥_∞∥a∥₂) ∥W′ − W∥³_F.

Proof. By Taylor's theorem,

R_λ(W′) = R_λ(W) + ⟨∇R_λ(W), W′ − W⟩_F + (1/2)⟨vec(W′ − W), ∇²R_λ(W_α) vec(W′ − W)⟩   (A.7)

for some W_α = W + α(W′ − W), α ∈ [0, 1]. The last term can be estimated using Lemma 8:

⟨vec(W′ − W), ∇²R_λ(W_α) vec(W′ − W)⟩ ≤ ∥∇²R_λ(W_α)∥₂ ∥W′ − W∥²_F ≤ (λ + β₁²∥a∥₂² + β₂∥a∥_∞√(6R(W_α))) ∥W′ − W∥²_F.   (A.8)

Next, we provide an upper bound for R(W_α):

2(R(W_α) − R(W)) = E[(ŷ(x; W_α) − y)² − (ŷ(x; W) − y)²]
= E[(ŷ(x; W_α) − ŷ(x; W))²] + 2E[(ŷ(x; W_α) − ŷ(x; W))(ŷ(x; W) − y)]
≤ E[(ŷ(x; W_α) − ŷ(x; W))²] + 2E[(ŷ(x; W_α) − ŷ(x; W))²]^{1/2} √(2R(W))
(∗)≤ β₁²∥a∥₂² ∥W_α − W∥²_F + 2β₁∥a∥₂ ∥W_α − W∥_F √(2R(W))
≤ β₁²∥a∥₂² ∥W′ − W∥²_F + 2β₁∥a∥₂ ∥W′ − W∥_F √(2R(W))
≤ 2β₁²∥a∥₂² ∥W′ − W∥²_F + 2R(W),   (A.9)

where the last inequality follows from Young's inequality, and (∗) is due to the estimate

E[(ŷ(x; W_α) − ŷ(x; W))²] = E[(Σ_{j=1}^m a_j{σ(⟨(w_α)_j, x⟩ + b_j) − σ(⟨w_j, x⟩ + b_j)})²] ≤ ∥a∥₂² Σ_{j=1}^m E[(σ(⟨(w_α)_j, x⟩ + b_j) − σ(⟨w_j, x⟩ + b_j))²] ≤ β₁²∥a∥₂² Σ_{j=1}^m E[⟨(w_α)_j − w_j, x⟩²] = β₁²∥a∥₂² ∥W_α − W∥²_F.   (A.10)

By (A.9), √(6R(W_α)) ≤ 2√(3R(W)) + √6 β₁∥a∥₂∥W′ − W∥_F; plugging this into (A.8) completes the proof.

In order to prove Proposition 2, we additionally need a bound on the norm of the iterates {W_t}_{t≥0} of the trajectory.

Lemma 10. Let {W_t}_{t≥0} be the sequence of PGD iterates (2.5). Suppose that there exists T ≥ 1 such that R_λ(W_t) is non-increasing in t = 0, 1, …, T. Under Assumptions 1 & 2.A, for λ > β₂∥a∥_∞√(2R_λ(W_0)) and

η < (λ + β₁²∥a∥₂² + β₂∥a∥_∞√(2R_λ(W_0)))^{-1},   (A.11)

we have

∥W_t∥_F ≤ (1 − ηγ/m)^t ∥W_0∥_F + mβ₁∥a∥₂ E[∥∇g_ϵ^⊤U∥₂]/γ   ∀ t ≤ T,   (A.12)

where γ/m := λ − β₂∥a∥_∞√(2R_λ(W_0)).

Proof. The update rule of PGD reads

W_{t+1} = (I_m − η(H(W_t) + λI_m)) W_t − ηD(W_t)U.   (A.13)

Since R(W_t) ≤ R_λ(W_t) ≤ R_λ(W_0) for all t ≤ T, we obtain from Lemma 7

(λ − β₂∥a∥_∞√(2R_λ(W_0))) I_m ⪯ H(W_t) + λI_m ⪯ (λ + β₁²∥a∥₂² + β₂∥a∥_∞√(2R_λ(W_0))) I_m,   (A.14)

and for η as in (A.11) we have 0 ⪯ I_m − η(H(W_t) + λI_m) ⪯ (1 − ηγ/m)I_m. Therefore,

∥W_{t+1}∥_F ≤ (1 − ηγ/m)∥W_t∥_F + η∥D(W_t)U∥_F,   (A.15)

and the last term is easily bounded:

∥D(W_t)U∥_F = ∥E[σ′_{a,b}(W_tx)∇g_ϵ^⊤U]∥_F ≤ E[∥σ′_{a,b}(W_tx)∇g_ϵ^⊤U∥_F] ≤ β₁∥a∥₂ E[∥∇g_ϵ^⊤U∥₂].   (A.16)

The statement of the lemma then follows by plugging the above bound back into (A.15) and expanding the recursion.

We are now ready to prove Proposition 2.

Proof. [Proposition 2] Throughout, we write ς := E[∥∇g_ϵ^⊤U∥₂] and set ϱ := λ + β₁²∥a∥₂² + β₂∥a∥_∞√(2R_λ(W_0)); notice that in the setting of Proposition 2 we have ϱ ≍ λ. We consider the event where ∥W_0∥_F ≤ √(2m), which happens with probability at least 1 − exp(−Cmd).

Smooth activations. We begin by considering the condition

λ ≥ γ/m + β₂∥a∥_∞√(2R_λ(W_0)).   (A.17)

Solving the corresponding quadratic inequality, (A.17) is satisfied whenever

λ ≥ γ/m + β₂²∥a∥_∞²∥W_0∥²_F/2 + √(β₂⁴∥a∥_∞⁴∥W_0∥⁴_F/4 + (γ/m)β₂²∥a∥_∞²∥W_0∥²_F + 2β₂²∥a∥_∞²R(W_0)).

In the setting of Proposition 2 with ∥a∥_∞ ≲ m⁻¹, the above simplifies to λ ≥ (1 + γ + √(1 + 2γ + 2R(W_0)))/m, which is satisfied in Proposition 2. Thus, we assume (A.17) holds in the rest of the proof for smooth activations.

We use induction on t to show that R_λ(W_t) is non-increasing. The base case is trivial; assuming the claim holds up to time t, Lemma 9 implies

R_λ(W_{t+1}) ≤ R_λ(W_t) − η∥∇R_λ(W_t)∥²_F + Cη²ϱ∥∇R_λ(W_t)∥²_F + Cβ₁β₂∥a∥_∞∥a∥₂ η³∥∇R_λ(W_t)∥³_F.   (A.18)

Moreover, we have the following upper bound on the gradient norm:

∥∇R_λ(W_t)∥_F (a)≤ ∥H(W_t) + λI_m∥₂∥W_t∥_F + ∥D(W_t)U∥_F (b)≲ ϱ∥W_0∥_F + β₁ς∥a∥₂(mϱ/γ + 1),   (A.19)

where (a) follows from the closed form of the gradient (2.6), and (b) follows from (A.12), (A.14) and (A.16). Thus, choosing η ≲ (ϱ∥W_0∥_F∥a∥₂ + β₁ς∥a∥₂²(mϱ/γ + 1))^{-1} ensures η∥a∥₂∥∇R_λ(W_t)∥_F ≤ 1, and consequently

R_λ(W_{t+1}) ≤ R_λ(W_t) − η∥∇R_λ(W_t)∥²_F + Cη²(ϱ + β₁β₂∥a∥_∞)∥∇R_λ(W_t)∥²_F.

Therefore, with a choice of η ≲ (ϱ + β₁β₂∥a∥_∞)^{-1}, we will have R_λ(W_{t+1}) ≤ R_λ(W_t) − Cη∥∇R_λ(W_t)∥²_F. As ∥a∥₂ ≲ m^{-1/2}, ∥W_0∥_F ≲ m^{1/2}, and ϱ ≍ λ, the two conditions on η simplify to η ≲ (λ + λς/γ + ς)^{-1} and η ≲ λ^{-1}, respectively, which completes the induction.

Finally, recall the update rule of PGD (A.13). Projecting each row of this recursion onto the orthogonal complement of the principal subspace (which annihilates the term ηD(W_t)U, whose rows lie in the principal subspace) yields

W_{t+1}^⊥ = (I_m − η(H(W_t) + λI_m)) W_t^⊥.

Again from (A.14), we have that for η < ϱ^{-1}, 0 ⪯ I_m − η(H(W_t) + λI_m) ⪯ (1 − ηγ/m)I_m, and therefore ∥W_{t+1}^⊥∥_F ≤ (1 − ηγ/m)∥W_t^⊥∥_F. We have shown (2.7), and the proof for smooth activations is complete.
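The projected recursion at the end of the proof is a plain linear contraction, which can be illustrated with a fixed PSD surrogate for H(W_t) (in the paper H varies along the trajectory; the constants below are arbitrary illustrative choices).

```python
import numpy as np

rng = np.random.default_rng(3)
m, d, k = 12, 8, 2
eta, lam = 0.1, 0.5

# Orthonormal basis U of a hypothetical k-dimensional principal subspace.
U = np.linalg.qr(rng.standard_normal((d, k)))[0].T   # (k, d), orthonormal rows
P_perp = np.eye(d) - U.T @ U                          # projector onto its complement

# A fixed PSD surrogate standing in for H(W_t).
A = rng.standard_normal((m, m))
H = A @ A.T / m

W_perp = rng.standard_normal((m, d)) @ P_perp         # orthogonal component of W
n0 = np.linalg.norm(W_perp)
M = np.eye(m) - eta * (H + lam * np.eye(m))           # one projected PGD step
prev = n0
for _ in range(100):
    W_perp = M @ W_perp                               # W_perp^{t+1} = M W_perp^t
    cur = np.linalg.norm(W_perp)
    assert cur <= prev + 1e-12                        # monotone contraction
    prev = cur

# Geometric decay at rate at least eta*lam, since H + lam*I >= lam*I here.
assert prev <= (1 - eta * lam) ** 100 * n0 + 1e-9
```

With weight decay, H + λI is uniformly positive definite, so the orthogonal component decays geometrically; this is exactly the mechanism behind (2.7).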

ReLU activation. By (A.4), for λ ≥ γ/m + (2∥a∥_∞/b_*)√(2/(eπ)) and η < (λ + ∥a∥₂² + (2∥a∥_∞/b_*)√(2/(eπ)))^{-1}, we have 0 ⪯ I_m − η(H(W) + λI_m) ⪯ (1 − ηγ/m)I_m. Notice that in the setting of Proposition 2, ∥a∥₂² + ∥a∥_∞ ≲ λ, thus η ≲ λ^{-1} suffices for the above inequality to hold. The rest of the proof follows similarly to the smooth case.

B PROOFS OF SECTION 3

We begin by characterizing the tail behavior of the stochastic gradient noise in the SGD updates (3.1) through the following lemma.

Lemma 11. For any fixed W ∈ R^{m×d}, let Γ := ∇ℓ(ŷ(x; W), y) − E[∇ℓ(ŷ(x; W), y)] denote the zero-mean stochastic noise in the gradient of the loss ℓ when (x, y) are generated according to Assumption 1, and recall that ∇ℓ(ŷ(x; W), y) = ∂₁ℓ(ŷ(x; W), y) σ′_{a,b}(Wx) x^⊤. Suppose sup_{ŷ,y} |∂₁ℓ(ŷ, y)| ≤ κ. Then, for any V ∈ R^{m×d}, the zero-mean random variable ⟨V, Γ⟩_F is Cβ₁κ∥a∥₂∥V∥_F-sub-Gaussian.

Proof. We use the shorthand notation ∇ℓ := ∇_W ℓ(ŷ(x; W), y) and ∇R := ∇_W R(W). We compute

E[|⟨V, ∇ℓ − ∇R⟩_F|^p]^{1/p} (a)≤ E[|⟨V, ∇ℓ⟩_F|^p]^{1/p} + E[|⟨V, ∇R⟩_F|^p]^{1/p} (b)≤ 2E[|⟨V, ∇ℓ⟩_F|^p]^{1/p} ≤ 2κ E[|⟨V, σ′_{a,b}(Wx)x^⊤⟩_F|^{2p}]^{1/(2p)},

where (a) and (b) follow from the Minkowski and Jensen inequalities, respectively. Furthermore,

E[|⟨V, σ′_{a,b}(Wx)x^⊤⟩_F|^{2p}]^{1/(2p)} = E[|⟨Vx, σ′_{a,b}(Wx)⟩|^{2p}]^{1/(2p)} ≤ β₁∥a∥₂ E[∥Vx∥₂^{2p}]^{1/(2p)} ≤ β₁∥a∥₂(∥V∥_F + C∥V∥₂√p),

where the last inequality follows from the Gaussianity of Vx and Lemma 32. Hence, E[|⟨V, ∇ℓ − ∇R⟩_F|^p]^{1/p} ≤ Cβ₁κ∥a∥₂∥V∥_F √p, and invoking Lemma 28 implies the sub-Gaussianity of ⟨V, ∇ℓ − ∇R⟩_F, which completes the proof.

We proceed with a lemma that constitutes the main part of the proof of Theorem 3, establishing a recursive bound on the moment generating function (MGF) of ∥W_t^⊥∥²_F, which will in turn be used to prove high-probability bounds on ∥W_t^⊥∥_F.

Lemma 12. Let {A_t}_{t≥0} be a sequence of events such that A_t is measurable with respect to the sigma algebra generated by {W_j}_{j=0}^t, A_{t+1} ⊆ A_t, and H(W_t) + λI_m ⪰ (γ/m)I_m on A_t. Suppose κ = sup_{ŷ,y}|∂₁ℓ(ŷ, y)| < ∞, and consider either a constant step size η_t = η or the decreasing schedule η_t = (m/γ)(1 − (t + t_*)²/(t + t_* + 1)²), with η_tρ ≲ 1. Then, with probability at least 1 − δ − P(A_t^c),

∥W_t^⊥∥²_F ≲ Π_{j=0}^{t−1}(1 − η_jγ/m) ∥W_0^⊥∥²_F + mη_tκ²(d + log(1/δ))/γ.   (B.1)

Proof. Let F_t denote the sigma algebra generated by {W_j}_{j=0}^t. Recall from Lemma 11 that we define Γ_t = ∇ℓ(ŷ(x^(t); W_t), y^(t)) − E[∇ℓ(ŷ(x^(t); W_t), y^(t))] with ∇ℓ(ŷ(x^(t); W_t), y^(t)) = ∂₁ℓ(ŷ(x^(t); W_t), y^(t)) σ′_{a,b}(W_tx^(t))(x^(t))^⊤. Then, for the SGD updates we have W_{t+1} = W_t − η_t∇R_λ(W_t) − η_tΓ_t.
By projecting the iterates onto the orthogonal complement of the principal subspace,

W_{t+1}^⊥ = (I_m − η_t(H(W_t) + λI_m)) W_t^⊥ − η_tΓ_t^⊥.

Let M_t := I_m − η_t(H(W_t) + λI_m). Then, observing that 1_{A_{t+1}} ≤ 1_{A_t}, for any 0 ≤ s ≲ γ/(mη_tκ²) we have

E[1_{A_{t+1}} e^{s∥W_{t+1}^⊥∥²_F} | F_0] ≤ E[1_{A_t} e^{s∥M_tW_t^⊥∥²_F + sη_t²∥Γ_t^⊥∥²_F + ⟨−2sη_tM_tW_t^⊥, Γ_t^⊥⟩_F} | F_0]
= E[1_{A_t} e^{s∥M_tW_t^⊥∥²_F} E[e^{sη_t²∥Γ_t^⊥∥²_F} e^{⟨−2sη_tM_tW_t^⊥, Γ_t^⊥⟩_F} | F_t] | F_0]
≤ E[1_{A_t} e^{s∥M_tW_t^⊥∥²_F} E[e^{2sη_t²∥Γ_t^⊥∥²_F} | F_t]^{1/2} E[e^{⟨−4sη_tM_tW_t^⊥, Γ_t^⊥⟩_F} | F_t]^{1/2} | F_0],   (B.2)

where the last inequality follows from the Cauchy–Schwarz inequality for conditional expectations. Moreover, it is straightforward to observe that ∥∇ℓ(ŷ(x^(t); W_t), y^(t))∥²_F ≤ κ²∥x^(t)∥₂², hence E[∥∇ℓ(ŷ(x^(t); W_t), y^(t))∥²_F] ≤ κ²d. Note that, by Jensen's inequality, ∥Γ_t^⊥∥²_F ≤ 2∥∇ℓ(ŷ(x^(t); W_t), y^(t))∥²_F + 2E[∥∇ℓ(ŷ(x^(t); W_t), y^(t))∥²_F]. Consequently,

E[exp(2sη_t²∥Γ_t^⊥∥²_F) | F_t] ≤ exp(4sη_t²κ²d) E[exp(4sη_t²κ²∥x∥₂²) | F_t] ≤ exp(4sη_t²κ²d) exp(8sη_t²κ²d),

where the second inequality follows from Lemma 33 provided 4sη_t²κ² ≤ 1/4. Since s ≲ γ/(mη_tκ²), satisfying the condition of Lemma 33 requires η_tγ/m ≲ 1, which is guaranteed by our assumption η_tρ ≲ 1 with a suitably small absolute constant, as γ/m ≤ λ ≤ ρ. Next, we bound the last term in (B.2). Let V := −4sη_tM_tW_t^⊥. Then, by Lemma 11,

E[exp(⟨V, Γ_t^⊥⟩_F) | F_t] ≤ exp(Cs²η_t²κ²∥M_tW_t^⊥∥²_F).

Putting things back together in (B.2) and using the tower property of expectation, we have

E[1_{A_{t+1}} e^{s∥W_{t+1}^⊥∥²_F} | F_0] ≤ E[1_{A_t} e^{s(1 + Csη_t²κ²)∥M_tW_t^⊥∥²_F + Csη_t²κ²d} | F_0].   (B.3)

Next, we bound ∥M_t∥₂. By the definition of A_t, we can already ensure H(W_t) + λI_m ⪰ (γ/m)I_m in (B.3). Recall the definition of H(W_t):

H(W_t) = E[∂₁²ℓ(ŷ(x; W_t), y) σ′_{a,b}(W_tx)σ′_{a,b}(W_tx)^⊤ + ∂₁ℓ(ŷ(x; W_t), y) diag(σ″_{a,b}(W_tx))].
Notice that 0 ≤ ∂₁²ℓ(ŷ(x; W), y) ≤ 1 under either Assumption 2.A or Assumption 2.B, and moreover |∂₁ℓ(ŷ, y)| ≤ κ. Thus,

H(W_t) + λI_m ⪯ (λ + β₁²∥a∥₂² + β₂κ∥a∥_∞) I_m = ρI_m.

Therefore, 0 ⪯ I_m − η_t(H(W_t) + λI_m) ⪯ (1 − η_tγ/m)I_m, and as a result ∥M_t∥₂ ≤ 1 − η_tγ/m. Combined with (B.3), we have

E[1_{A_{t+1}} exp(s∥W_{t+1}^⊥∥²_F) | F_0] ≤ E[1_{A_t} exp(s(1 + Csη_t²κ²)(1 − η_tγ/m)²∥W_t^⊥∥²_F + Csη_t²dκ²) | F_0] ≤ exp(Csη_t²κ²d) E[1_{A_t} exp(s(1 − η_tγ/m)∥W_t^⊥∥²_F) | F_0],   (B.4)

where the second inequality holds by the fact that Csη_t²κ² ≤ η_tγ/m, which in turn holds when a small enough absolute constant is chosen in 0 ≤ s ≲ γ/(mη_tκ²). Also notice that for the decreasing step size,

1 − γη_t/m = (t + t_*)²/(t + t_* + 1)² ≤ [1 − (t + t_*)²/(t + t_* + 1)²]/[1 − (t + t_* − 1)²/(t + t_*)²] = η_t/η_{t−1},   (B.5)

(and the above holds trivially for constant step size); thus, when s ≤ C₁γ/(mη_tκ²) for some absolute constant C₁, we have s(1 − η_tγ/m) ≤ C₁γ/(mη_{t−1}κ²) with the same absolute constant. Hence we are allowed to expand the recursion (B.4), which implies, for all 0 ≤ s ≲ γ/(mη_{t−1}κ²),

E[1_{A_t} exp(s∥W_t^⊥∥²_F) | F_0] ≤ exp(s Π_{j=0}^{t−1}(1 − η_jγ/m)∥W_0^⊥∥²_F + Csκ²d Σ_{i=0}^{t−1} η_i² Π_{j=i+1}^{t−1}(1 − η_jγ/m)).

Moreover, a direct calculation implies that with both the constant and decreasing step sizes of Lemma 12, we have

Σ_{i=0}^{t−1} η_i² Π_{j=i+1}^{t−1}(1 − η_jγ/m) ≤ Cmη_t/γ

(with C = 1 for constant step size). Thus, for all 0 ≤ s ≲ γ/(mη_{t−1}κ²),

E[1_{A_t} exp(s∥W_t^⊥∥²_F) | F_0] ≤ exp(s Π_{j=0}^{t−1}(1 − η_jγ/m)∥W_0^⊥∥²_F + Csmη_tκ²d/γ).

Finally, we can apply a Chernoff bound to obtain

P(A_t ∩ {∥W_t^⊥∥²_F ≥ ε} | F_0) ≤ exp(s(Π_{j=0}^{t−1}(1 − η_jγ/m)∥W_0^⊥∥²_F + Cmη_tκ²d/γ − ε)).

By choosing ε = Π_{j=0}^{t−1}(1 − η_jγ/m)∥W_0^⊥∥²_F + Cmη_tκ²(d + log(1/δ))/γ and the largest possible s ≲ γ/(mη_tκ²), we obtain

P(∥W_t^⊥∥²_F ≥ ε | F_0) ≤ P(A_t ∩ {∥W_t^⊥∥²_F ≥ ε} | F_0) + P(A_t^c | F_0) ≤ δ + P(A_t^c | F_0).
Taking a further expectation to remove the conditioning on the initialization completes the proof. The proof of Theorem 3 for the decreasing step size follows by a direct computation of the quantities in Lemma 12 and is presented below. On the other hand, in order to obtain a better dependence on λ, the choice of the events A_t for the constant step size is more subtle, and is presented in Section B.2.
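The qualitative content of Lemma 12, namely that online SGD with weight decay drives the component of W orthogonal to the principal subspace down to a small noise floor, can be illustrated with a toy single-index teacher. The dimensions, step size, and regularization below are arbitrary choices made so the effect is visible quickly; this is a sketch, not the exact setting of Theorem 3.

```python
import numpy as np

rng = np.random.default_rng(0)
d, m = 20, 16
eta, lam, T = 0.02, 0.5, 2000
u = np.zeros(d); u[0] = 1.0                          # teacher direction

a = np.where(np.arange(m) % 2 == 0, 1.0, -1.0) / m   # frozen second layer, |a_j| = 1/m
b = np.ones(m)                                        # frozen biases
W = rng.standard_normal((m, d)) / np.sqrt(d)          # rows of roughly unit norm

def perp_norm(W):
    # ||W_perp||_F for the 1-dimensional principal subspace span(u).
    return np.linalg.norm(W - np.outer(W @ u, u))

n0 = perp_norm(W)
for _ in range(T):                                    # online SGD: one fresh sample per step
    x = rng.standard_normal(d)
    y = np.maximum(x @ u, 0.0)                        # single-index ReLU teacher
    pre = W @ x + b
    yhat = a @ np.maximum(pre, 0.0)
    grad = np.outer((yhat - y) * a * (pre > 0), x)    # dLoss/dW for squared loss
    W -= eta * (grad + lam * W)                       # SGD step with weight decay

assert perp_norm(W) < 0.5 * n0
```

The orthogonal component contracts geometrically through the weight decay and H(W) terms, while the gradient noise keeps it at a floor of order √(mηκ²d/γ), matching the two terms of (B.1).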

B.1 PROOF OF THEOREM 3 FOR DECREASING STEPSIZE

This part is directly implied by Lemma 12. The following argument holds on an event where ∥W_0∥_F ≲ √m, which happens with probability at least 1 − O(δ). To see the connection, we first present an improvement over Lemma 7 for the case of smooth activations. Recall the definition of H(W) for the squared-error loss ℓ(ŷ, y) = (ŷ − y)²/2:

H(W) = E[σ′_{a,b}(Wx)σ′_{a,b}(Wx)^⊤] + E[(ŷ(x; W, a, b) − y) diag(σ″_{a,b}(Wx))].

Notice that under Assumption 2.A we have |ŷ| ≤ β₀∥a∥₁. Then, basic matrix algebra similar to that of Lemma 7, along with the triangle inequality, shows

−β₂∥a∥_∞(β₀∥a∥₁ + E[|y|]) I_m ⪯ H(W) ⪯ (β₁²∥a∥₂² + β₂∥a∥_∞(β₀∥a∥₁ + E[|y|])) I_m.

Therefore, with λ ≥ γ/m + β₂∥a∥_∞(β₀∥a∥₁ + E[|y|]), we have H(W) + λI_m ⪰ (γ/m)I_m for all W. In addition, |∂₁ℓ(ŷ, y)| ≤ β₀∥a∥₁ + K by the triangle inequality. Thus, we can invoke Lemma 12 with η_t = (m/γ)(1 − (t_* + t)²/(t_* + t + 1)²), κ = β₀∥a∥₁ + K, and 1_{A_t} = 1. For the ReLU activation (Assumption 2.B), with β₀ = β₁ = β₂ = 1, K ≲ 1, ∥a∥_∞ ≤ 1/m, ∥a∥₂ ≤ 1/√m, and b_* = 1 in the statement of the theorem, we again have H(W) + λI_m ⪰ (γ/m)I_m; hence, this time we can invoke Lemma 12 with the same decreasing η_t, 1_{A_t} = 1, and κ = 1.

B.2 PROOF OF THEOREM 3 FOR CONSTANT STEPSIZE

In order to improve the condition on λ, we will specifically consider events A_t on which max_{0≤j≤t} R_λ(W_j) is bounded. The following lemma shows that these events occur with high probability.

Lemma 13. Under Assumptions 1 & 2.A, consider the setting of Lemma 12 with constant step size η ≲ ρ⁻¹. Then, with probability at least 1 − T exp(−CTηρd),

max_{0≤t≤T} R_λ(W_t) ≤ R_λ(W_0) + CTη²κ²ρd.   (B.6)

Proof. First, recall from Lemma 8, together with √(R_λ(W)) ≲ β₀∥a∥₁ + K = κ, that ∥∇²R_λ(W)∥₂ ≲ λ + β₁²∥a∥₂² + β₂κ∥a∥_∞ = ρ. We will first prove that for any t ≤ T and any s ≲ (ηκ²)⁻¹,

E[e^{sR_λ(W_t)} | W_0] ≤ e^{sR_λ(W_0) + Cstη²κ²ρd}.

Recall that F_t is the sigma algebra generated by {W_j}_{j=0}^t, and Γ_t := ∇ℓ(W_t) − ∇R(W_t). By Taylor's theorem and Young's inequality,

R_λ(W_{t+1}) ≤ R_λ(W_t) − η⟨∇R_λ(W_t), ∇R_λ(W_t) + Γ_t⟩_F + (ρη²/2)∥∇R_λ(W_t) + Γ_t∥²_F
≤ R_λ(W_t) − η(1 − ηρ)∥∇R_λ(W_t)∥²_F − η⟨∇R_λ(W_t), Γ_t⟩_F + ρη²∥Γ_t∥²_F,

and hence

E[e^{sR_λ(W_{t+1})} | F_0] ≤ E[e^{sR_λ(W_t) − sη(1−ηρ)∥∇R_λ(W_t)∥²_F} E[e^{−sη⟨∇R_λ(W_t), Γ_t⟩_F + sρη²∥Γ_t∥²_F} | F_t] | F_0]
(a)≤ E[e^{sR_λ(W_t) − sη(1−ηρ)∥∇R_λ(W_t)∥²_F} E[e^{−2sη⟨∇R_λ(W_t), Γ_t⟩_F} | F_t]^{1/2} E[e^{2sρη²∥Γ_t∥²_F} | F_t]^{1/2} | F_0],

where (a) follows from the Cauchy–Schwarz inequality for conditional expectations. Moreover, in this setting |∂₁ℓ(ŷ, y)| = |ŷ − y| ≤ β₀∥a∥₁ + K; thus, letting κ = β₀∥a∥₁ + K in Lemma 11, the sub-Gaussianity of Γ_t gives

E[e^{⟨−2sη∇R_λ(W_t), Γ_t⟩_F} | F_t] ≤ e^{Cs²η²κ²∥∇R_λ(W_t)∥²_F}.

Furthermore, we have the upper bound

E[e^{Csρη²∥∇ℓ(W_t)∥²_F} | F_t] ≤ E[e^{Csρη²κ²∥x∥₂²} | F_t] ≤ e^{Csρη²κ²d},

where the second inequality holds for s ≲ 1/(ρη²κ²) with a sufficiently small absolute constant, by Lemma 33. Similarly to the argument in Lemma 12, since we choose s ≲ (ηκ²)⁻¹, satisfying the condition of Lemma 33 only requires ηρ ≲ 1 with a sufficiently small absolute constant. Putting the above bounds together, we have

E[e^{sR_λ(W_{t+1})} | F_0] ≤ E[e^{sR_λ(W_t) − sη(1 − ηρ − Csηκ²)∥∇R_λ(W_t)∥²_F + Csη²κ²ρd} | F_0].

Expanding the recursion yields, for any s ≲ (ηκ²)⁻¹ (with a sufficiently small absolute constant),

E[e^{sR_λ(W_t)} | F_0] ≤ e^{sR_λ(W_0) + Cstη²κ²ρd}.

As a result, by applying Markov's inequality at time t ≤ T, we have

P(R_λ(W_t) ≥ R_λ(W_0) + CTη²κ²ρd) ≤ e^{−CsTη²κ²ρd} ≤ e^{−CTηρd}.

Consequently, with a union bound,

P(max_{0≤t≤T} R_λ(W_t) ≥ R_λ(W_0) + CTη²κ²ρd) ≤ T e^{−CTηρd},

which completes the proof. As shown by the following proposition, the rest of the proof is analogous to the decreasing step size case.

Proposition 14. Consider the setting of Lemma 13 with constant step size η ≲ ρ⁻¹ and λ sufficiently large that λ ≥ γ/m + β₂∥a∥_∞√(2(R_λ(W_0) + CTη²κ²ρd)). Then, with probability at least 1 − Te^{−CTηρd} − δ, we have

∥W_T^⊥∥²_F ≤ (1 − ηγ/m)^T ∥W_0∥²_F + Cmηκ²(d + log(1/δ))/γ.   (B.7)

Proof. Let A_t = {max_{0≤i≤t} R_λ(W_i) ≤ R_λ(W_0) + CTη²κ²ρd}. Notice that A_t is measurable with respect to {W_j}_{j=0}^t and A_{t+1} ⊆ A_t. By the bound on H(W) established in Lemma 7, on A_t we have H(W_i) + λI_m ⪰ (γ/m)I_m for all 0 ≤ i ≤ t. Moreover, from Lemma 13, we have P(A_T^c) ≤ Te^{−CTηρd}. Invoking Lemma 12 finishes the proof.

The above proposition immediately implies the statement of Theorem 3 for the constant step size, which we record here as a corollary.

Corollary 15 (Theorem 3 for constant step size). Consider the setting of Lemma 13 with λ as in Proposition 14, for constant step size η = 2m log(T)/(γT) with T ≥ (1/δ)^{C/d}. Then, with probability at least 1 − δ, we have

∥W_T^⊥∥_F ≲ ∥W_0^⊥∥_F/T + (mκ/γ)√(log(T)(d + log(1/δ))/T).
(B.8)

C PROOFS OF SECTION 4

C.1 PROOF OF THEOREM 5

As our arguments are based on the Rademacher complexity of a two-layer neural network, we require control of the norm of W_t; we prove a high-probability bound for this norm in the following lemma.

Lemma 16. Under Assumptions 1 & 2.A or 1 & 2.B, with either decreasing or constant step size as in Theorem 3, let κ = sup_{ŷ,y}|∂₁ℓ(ŷ, y)| < ∞ and κ_∞ := β₁∥a∥_∞κ. Then, for any t ≥ 1, with probability at least 1 − m exp(−γtd/(mφλ)), we have, for all 1 ≤ j ≤ m,

∥w_j^t∥₂ ≤ Π_{i=0}^{t−1}(1 − η_iλ)∥w_j^0∥₂ + 3κ_∞√d/λ,   (C.1)

where φ = 1 for decreasing step size and φ = log(T) for constant step size.

Proof. First, we prove that for any t > 0 and 0 ≤ s ≤ 2√d/(κ_∞η_{t−1}), we have

E[exp(s∥w_j^t∥₂) | W_0] ≤ exp(s Π_{i=0}^{t−1}(1 − η_iλ)∥w_j^0∥₂ + 2sκ_∞√d/λ).   (C.2)

The base case t = 0 is trivial. For the induction step, for any 0 ≤ s ≤ 2√d/(κ_∞η_t) we have

E[exp(s∥w_j^{t+1}∥₂) | W_0] = E[exp(s∥(1 − η_tλ)w_j^t − η_t∇_{w_j}ℓ(ŷ(x; W_t), y)∥₂) | W_0]
≤ E[exp(s(1 − η_tλ)∥w_j^t∥₂ + sη_t∥∇_{w_j}ℓ(ŷ(x; W_t), y)∥₂) | W_0]
≤ E[exp(s(1 − η_tλ)∥w_j^t∥₂ + sη_tκ_∞∥x∥₂) | W_0]
= E[exp(s(1 − η_tλ)∥w_j^t∥₂) E[exp(sη_tκ_∞∥x∥₂) | W_t, W_0] | W_0]
(a)≤ E[exp(s(1 − η_tλ)∥w_j^t∥₂) exp(sη_tκ_∞√d + s²κ_∞²η_t²/2) | W_0]
(b)≤ exp(s Π_{i=0}^t(1 − η_iλ)∥w_j^0∥₂ + 2sκ_∞√d/λ),

where (a) holds since ∥x∥₂ is a 1-Lipschitz function of a standard Gaussian random vector, hence sub-Gaussian with parameter 1 (Lemma 29), with E[∥x∥₂] ≤ √d; and (b) holds by the induction hypothesis (notice that for the decreasing step size, s(1 − η_tλ) ≤ 2√d/(κ_∞η_{t−1}) by (B.5)). Next, we apply the following Chernoff bound:

P(∥w_j^t∥₂ > Π_{i=0}^{t−1}(1 − η_iλ)∥w_j^0∥₂ + 3κ_∞√d/λ | W_0) ≤ exp(−sκ_∞√d/λ),

which holds for any 0 ≤ s ≤ 2√d/(κ_∞η_{t−1}).
Choosing the largest possible s, and noting that η_{t−1} ≤ 2mφ/(γt), yields an upper bound of exp(−γtd/(mφλ)) on the conditional probability; taking an expectation then removes the randomness of conditioning on w_j^0, and a union bound over j gives the desired result.

In addition, we would like to approximate R_τ(W_T) and R̂_τ(W_T) by R_τ(W_T^∥) and R̂_τ(W_T^∥), respectively. As a result, we investigate the Lipschitzness of the population and empirical risks in the next lemma.

Lemma 17. Under either Assumptions 1 & 2.A or 1 & 2.B, the truncated risk W ↦ R_τ(W) is √(2τ)β₁∥a∥₂-Lipschitz. Moreover, for T ≥ d + log(1/δ), with probability at least 1 − δ over the stochasticity of {x^(t)}_{0≤t≤T−1}, the truncated empirical risk W ↦ R̂_τ(W) is Cτβ₁∥a∥₂-Lipschitz for some absolute constant C.

Proof. We begin with the simple observation that ŷ ↦ ℓ(ŷ, y) ∧ τ is √(2τ)-Lipschitz when ℓ(ŷ, y) = ½(ŷ − y)², and 1-Lipschitz when |∂₁ℓ(ŷ, y)| ≤ 1; since τ ≥ 1, we may treat both cases as √(2τ)-Lipschitz. Thus, by Jensen's inequality,

|R_τ(W) − R_τ(W′)| ≤ √(2τ) E[|ŷ(x; W) − ŷ(x; W′)|]
≤ √(2τ) E[(Σ_{j=1}^m a_jσ(⟨w_j, x⟩ + b_j) − Σ_{j=1}^m a_jσ(⟨w′_j, x⟩ + b_j))²]^{1/2}
(a)≤ √(2τ)∥a∥₂ (Σ_{j=1}^m E[(σ(⟨w_j, x⟩ + b_j) − σ(⟨w′_j, x⟩ + b_j))²])^{1/2}
≤ √(2τ)β₁∥a∥₂ (Σ_{j=1}^m E[⟨w_j − w′_j, x⟩²])^{1/2}   (C.3)
= √(2τ)β₁∥a∥₂ ∥W − W′∥_F,

where (a) follows from the Cauchy–Schwarz inequality. Note that (C.3) also holds for |R̂_τ(W) − R̂_τ(W′)| when the expectation is over the empirical distribution given by the training samples, meaning

|R̂_τ(W) − R̂_τ(W′)| ≤ √(2τ)β₁∥a∥₂ (Σ_{j=1}^m (w_j − w′_j)^⊤ ((1/T)Σ_{t=0}^{T−1} x^(t)(x^(t))^⊤) (w_j − w′_j))^{1/2}.   (C.4)

By Lemma 30, with probability at least 1 − δ, we have ∥(1/T)Σ_{t=0}^{T−1} x^(t)(x^(t))^⊤ − I_d∥₂ ≲ 1, which completes the proof.

Lemma 18. Suppose either Assumptions 1 & 2.A or 1 & 2.B hold.
Denote the loss with ℓ(ŷ, y) = ℓ(ŷ − y), and let

S̃ = {W̃ ∈ R^{m×k}, a, b ∈ R^m : ∥a∥₂ ≤ r_a/√m, ∥b∥_∞ ≤ r_b, ∥w̃_j∥₂ ≤ r_w̃ ∀ 1 ≤ j ≤ m},
G̃ = {(x̃, y) ↦ ℓ(ŷ(x̃; W̃, a, b), y) ∧ τ : (W̃, a, b) ∈ S̃},

for x̃ ∈ R^k and y ∈ R. Let R(G̃) denote the Rademacher complexity of the function class G̃ (see Lemma 31 for the definition). Then, with x̃ ∼ N(0, UU^⊤) for some U ∈ R^{k×d}, we have

R(G̃) ≤ 2√(2τ)β₁(r_w̃∥U∥_F + r_b)r_a/√T,

where T is the number of samples.

Proof. Let F̃ = {(x̃, y) ↦ f_{a,W̃}(x̃, y) : (W̃, a, b) ∈ S̃} for f_{a,W̃}(x̃, y) = ŷ(x̃; W̃, a, b) − y. Define g(z) := ℓ(z) ∧ τ, and notice that G̃ = {(x̃, y) ↦ g(f_{a,W̃}(x̃, y)) : f_{a,W̃} ∈ F̃}; since g is √(2τ)-Lipschitz, Talagrand's contraction principle yields R(G̃) ≤ √(2τ)R(F̃). Moreover, with i.i.d. Rademacher variables {ξ_t},

R(F̃) = E[sup_{(W̃,a,b)∈S̃} (1/T)Σ_{t=0}^{T−1} ξ_t(a^⊤σ(W̃x̃^(t) + b) − y^(t))]
= E[sup_{(W̃,a,b)∈S̃} (1/T)Σ_{t=0}^{T−1} ξ_t a^⊤σ(W̃x̃^(t) + b)]
(a)≤ (r_a/T) E[sup_{(W̃,b)} ∥Σ_{t=0}^{T−1} ξ_tσ(W̃x̃^(t) + b)∥_∞]
≤ (r_a/T) E[sup_{∥w̃∥₂≤r_w̃, |b̃|≤r_b} |Σ_{t=0}^{T−1} ξ_tσ(⟨w̃, x̃^(t)⟩ + b̃)|]
(b)≤ (2β₁r_a/T) E[sup_{∥w̃∥₂≤r_w̃, |b̃|≤r_b} |Σ_{t=0}^{T−1} ξ_t(⟨w̃, x̃^(t)⟩ + b̃)|]
≤ (2β₁r_a/T) (E[sup_{∥w̃∥₂≤r_w̃} |Σ_{t=0}^{T−1} ξ_t⟨w̃, x̃^(t)⟩|] + E[sup_{|b̃|≤r_b} |Σ_{t=0}^{T−1} ξ_tb̃|])
≤ (2β₁r_a/T) (r_w̃ E[∥Σ_{t=0}^{T−1} ξ_tx̃^(t)∥₂] + r_b√T)
≤ 2β₁(r_w̃∥U∥_F + r_b)r_a/√T,

where the second equality holds since the term Σ_t ξ_ty^(t) does not depend on the parameters and has zero mean, (a) holds by Hölder's inequality together with ∥a∥₁ ≤ √m∥a∥₂ ≤ r_a, and (b) follows from σ being β₁-Lipschitz, via another application of Talagrand's contraction principle.

Proof. [Proof of Theorem 5] Let E₁ denote the event of Lemma 17. On E₁, we have the following decomposition of the generalization error:

R_τ(W_T) − R̂_τ(W_T) = R_τ(W_T) − R_τ(W_T^∥) + R_τ(W_T^∥) − R̂_τ(W_T^∥) + R̂_τ(W_T^∥) − R̂_τ(W_T) ≤ Cτβ₁∥a∥₂∥W_T^⊥∥_F + R_τ(W_T^∥) − R̂_τ(W_T^∥),

where the upper bound follows from Lemma 17. Consequently,

sup_{a,b} R_τ(W_T, a, b) − R̂_τ(W_T, a, b) ≤ (Cτβ₁r_a/√m)∥W_T^⊥∥_F + sup_{a,b} R_τ(W_T^∥, a, b) − R̂_τ(W_T^∥, a, b).   (C.5)

We begin by upper bounding the first term. From Theorem 3, on an event E₂ of probability at least 1 − O(δ), we have

∥W_T^⊥∥_F/√m ≲ κ√((d + log(1/δ))/(γ²T)).
Next, we bound the second term in (C.5). For each W , define W̃ := W ∥ U † , where U † is the Moore-Penrose pseudo-inverse of U . Then, since we have the representation W ∥ = M U for some M ∈ R m×k , W̃ U = W ∥ U † U = M U U † U = M U = W ∥ . Thus, W̃ x̃ = W x and ℓ(ŷ(x̃; W̃ , a, b), y) = ℓ(ŷ(x; W , a, b), y) for x̃ = U x, whenever W lies in the principal subspace, i.e. W = W ∥ . Let E 3 denote the event of Lemma 16, on which ∥w T j ∥ 2 ≤ Π T -1 i=0 (1 - η i γ/m)∥w 0 j ∥ 2 + 3κ ∞ √d/λ, and consequently ∥ w̃T j ∥ 2 ≤ ∥U † ∥ 2 (Π T -1 i=0 (1 - η i γ/m)∥w 0 j ∥ 2 + 3κ ∞ √d/λ) for any 1 ≤ j ≤ m. Define r w̃T as the RHS bound above. Then, on E 3 , sup a,b R τ (W T ∥ ) - Rτ (W T ∥ ) ≤ sup ( W̃ ,a,b)∈ S̃ R τ ( W̃ , a, b) - Rτ ( W̃ , a, b), where we recall S̃ := { W̃ ∈ R m×k , a, b ∈ R m : ∥a∥ 2 ≤ r a /√m , ∥b∥ ∞ ≤ r b , ∥ w̃j ∥ 2 ≤ r w̃T ∀ 1 ≤ j ≤ m}. Additionally, define G̃ = {(x̃, y) → ℓ(ŷ(x̃; W̃ , a, b), y) ∧ τ : ( W̃ , a, b) ∈ S̃}. Then Lemma 31 and Lemma 18 yield E sup ( W̃ ,a,b)∈ S̃ R τ ( W̃ ) - Rτ ( W̃ ) ≤ 2R( G̃) ≲ τ β 1 (r w̃T + r b )r a ∥U ∥ F √(1/T). Besides, as the loss is bounded by τ , by McDiarmid's inequality, on an event E 4 which holds with probability at least 1 - O(δ), we have sup ( W̃ ,a,b)∈ S̃ R τ ( W̃ ) - Rτ ( W̃ ) ≤ E sup ( W̃ ,a,b)∈ S̃ [R τ ( W̃ ) - Rτ ( W̃ )] + Cτ √(log(1/δ)/T), and consequently, on ∩ 4 i=1 E i , sup a,b R τ (W T ∥ , a, b) - Rτ (W T ∥ , a, b) ≲ τ β 1 (r w̃T + r b )r a ∥U ∥ F √(1/T) + τ √(log(1/δ)/T). Finally, observe that ∥a∥ 1 ≤ √m∥a∥ 2 ≤ r a , and without loss of generality assume U is orthonormal, hence ∥U † ∥ 2 = 1 and ∥U ∥ F = √k. Thus, with probability at least 1 - O(δ), sup a,b R τ (W T , a, b) - Rτ (W T , a, b) ≲ τ β 1 r a κ √((d + log(1/δ))/(γ 2 T)) + τ β 1 r a ((t * /(t * + T)) 2 r w + κ ∞ /λ + r b ) √(dk/T) + τ √(log(1/δ)/T). (C.6) We remark that in the setting of Theorem 3, which is adopted in Theorem 5, ∥a∥ ∞ ≲ m -1 , thus κ ∞ ≲ m -1 . Finally, we observe that r w ≤ √2m with probability at least 1 - O(δ) over the initialization, which completes the proof.
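The √(k/T) rate entering the bound above comes from the linear term E ∥Σ t ξ t x̃(t) ∥ 2 in Lemma 18. A small Monte-Carlo sketch (with illustrative sizes of our own choosing, not taken from the paper) shows the empirical Rademacher average of the unit-norm linear class tracking √(k/T):

```python
import numpy as np

def rademacher_linear_term(k, T, trials=200, seed=0):
    """Monte-Carlo estimate of (1/T) E || sum_t xi_t x_t ||_2 with x_t ~ N(0, I_k).

    This equals E sup_{||w||_2 <= 1} (1/T) |sum_t xi_t <w, x_t>|, the quantity
    bounded by sqrt(k/T) in the Lemma 18 computation (with U orthonormal)."""
    rng = np.random.default_rng(seed)
    vals = []
    for _ in range(trials):
        x = rng.standard_normal((T, k))          # inputs projected to R^k
        xi = rng.choice([-1.0, 1.0], size=T)     # Rademacher signs
        vals.append(np.linalg.norm(xi @ x) / T)  # sup attained at w = (sum xi x)/||.||
    return float(np.mean(vals))

for T in (250, 1000, 4000):
    est = rademacher_linear_term(k=4, T=T)
    print(T, est, np.sqrt(4 / T))  # the estimate stays within a constant of sqrt(k/T)
```

Doubling T four-fold halves the estimate, matching the 1/√T dependence in (C.6).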

C.2 PROOF OF THEOREM 4

Note that due to the special symmetry in the initialization of Algorithm 1, while training the first layer, all neurons take an identical value, i.e. w t j = w t for all j, and the stochastic gradient with respect to any neuron can be denoted by ∇ℓ = a∂ 1 ℓ(ŷ, y)σ ′ (⟨w, x⟩ + b)x. Furthermore, ∇ wj R λ (W ) is also identical for all j; by the population gradient formula (2.6), we denote it by ∇R λ (w) = (h(w) + λ)w + d(w)u, where h(w) = Σ m j=1 H ij (W ) and d(w) = a E[∂ 2 12 ℓ(ŷ, y)σ ′ (⟨w, x⟩ + b)f ′ (⟨u, x⟩)]. Additionally, via the arguments in the proof of Lemma 1, it is not difficult to observe that γ/m ≤ h(w) + λ ≲ m -1 . Furthermore, similar to the arguments of Lemma 11, ⟨∇ℓ, v⟩ is Ca∥v∥ 2 -sub-Gaussian for any v ∈ R d . Next, we derive a lower bound on ⟨w t , u⟩ to argue that useful features are learned, which first requires a sharper upper bound on ∥w t ∥ 2 than that of Lemma 16. This improvement is possible because we consider the special case w t j = w t here. Lemma 19. Suppose t ≥ d. Then, ∥w t ∥ 2 ≤ (t * /(t * + t))∥w 0 ∥ 2 + Cma/γ with probability at least 1 - exp(-C(t * + t)). In particular, using a union bound, we have sup t≥t 0 ∥w t ∥ 2 ≤ ∥w 0 ∥ 2 + Cma/γ ≲ 1 with probability at least 1 - exp(-C(t * + t 0 )) - exp(-Cd). Proof. Let e t := ∇ w ℓ - ∇ w R. Then we have w t+1 = w t - η t ∇ w R λ - η t e t . Recall that ⟨e t , v⟩ is Ca∥v∥ 2 -sub-Gaussian, and F t is the sigma algebra generated by {w j } 0≤j≤t . Let ω t := w t - η t ∇ w R λ . Then, for any 0 ≤ s ≲ γ/(mη t a 2 ), E[exp(s∥w t+1 ∥ 2 2 ) | F 0 ] = E[exp(s∥ω t ∥ 2 2 - 2sη t ⟨ω t , e t ⟩ + sη 2 t ∥e t ∥ 2 2 ) | F 0 ] ≤ E[exp(s∥ω t ∥ 2 2 ) E[exp(-4sη t ⟨ω t , e t ⟩) | F t ] 1/2 E[exp(2sη 2 t ∥e t ∥ 2 2 ) | F t ] 1/2 | F 0 ]. By sub-Gaussianity of ⟨ω t , e t ⟩ we have E[exp(-4sη t ⟨ω t , e t ⟩) | F t ] ≤ exp(Cs 2 η 2 t a 2 ∥ω t ∥ 2 2 ). Moreover, as ∥∇ℓ∥ 2 ≤ |a|∥x∥ 2 , by Jensen's inequality, ∥e t ∥ 2 2 ≤ 2∥∇ℓ∥ 2 2 + 2 E ∥∇ℓ∥ 2 2 ≤ 2a 2 (∥x∥ 2 2 + d).
Thus E exp(2sη 2 t ∥e t ∥ 2 2 ) | F t ≤ exp(Csη 2 t a 2 d) for s ≲ 1 η 2 t a 2 (which holds by s ≲ γ mηta 2 , see the proof of Lemma 12 for more details), i.e. we have E exp s∥w t+1 ∥ 2 2 | F 0 ≤ E exp s(1 + Csη 2 t a 2 )∥ω t ∥ 2 2 + Csη 2 t a 2 d | F 0 . Recall that by our choice of η t , 0 ≤ (1 -η t (h(w t ) + λ)) ≤ 1 -ηtγ m (cf. proof of Lemma 12), and ω t = (1 -η t (h(w t ) + λ))w t -η t d(w t )u. As ∥u∥ 2 = 1 and |d(w t )| ≲ |a|, we have ∥ω t ∥ 2 2 ≤ (1 -ηtγ m ) 2 ∥w t ∥ 2 2 + Ca 2 η 2 t + 2η t Ca(1 -ηtγ m )∥w t ∥ 2 (a) ≤ (1 -ηtγ m ) 2 ∥w t ∥ 2 2 + η t 4Cma 2 γ + γ 4m (1 -ηtγ m ) 2 ∥w t ∥ 2 2 + Ca 2 η 2 t (b) ≤ (1 -3ηtγ 2m )∥w t ∥ 2 2 + Cmη t a 2 γ + Ca 2 η 2 t . where (a) holds by Young's inequality and (b) holds for η t γ/m ≲ 1 with a sufficiently small absolute constant. Therefore, for s ≲ γ mηta 2 , E exp s∥w t+1 ∥ 2 2 | F 0 ≤ E exp s(1 -ηtγ m )∥w t ∥ 2 2 + Csmη t a 2 γ + Csη 2 t a 2 d | F 0 . Notice that we can expand the recursion since s(1 -ηtγ m ) ≲ γ mηt-1a 2 (cf. proof of Lemma 12, Eq. (B.5)). Expanding the recursion yields, E exp s∥w t ∥ 2 2 | F 0 ≤ exp s t * t * + t 2 ∥w 0 ∥ 2 F + Csm 2 a 2 (t + d) γ 2 (t * + t) . Finally, we apply a Chernoff bound with the maximum choice of s ≲ γ mηta 2 , and combine it with the fact that ∥w 0 ∥ 2 ≲ 1 with probability at least 1 -exp(-Cd). Lemma 20. Suppose mab < 1 -|f (0)|. Then, we have |⟨w t , u⟩| ≳ 1 with probability at least 1 -2 exp(-Ct) -exp(-Cd). Proof. We will only prove for the case where f is increasing as the case for decreasing f is similar. We begin by proving an upper bound for d(w) when ∥w∥ ≲ 1. By the triangle inequality,  |ŷ -y| ≤ |ŷ| + |f (0)| + |f (⟨u, x⟩) -f (0)| + |ϵ|. d(w) = a E ∂ 2 12 ℓ(ŷ, y)σ ′ (⟨w, x⟩ + b)f ′ (⟨u, x⟩) ≲ -a E[1(|ϵ| ≲ 1)1(|⟨w, x⟩| ≲ 1)1(|⟨u, x⟩| ≲ 1)f ′ (⟨u, x⟩)] = -a E[1(|ϵ| ≲ 1)] E[1(|⟨w, x⟩| ≲ 1)1(|⟨u, x⟩| ≲ 1)f ′ (⟨u, x⟩)] ≲ -a. where the last line is obtained by considering supremum over ∥w∥ 2 ≲ 1. Let A t = {sup t0≤t ′ ≤t ∥w t ′ ∥ 2 ≲ 1}. 
Then, E[exp(-s⟨w t+1 , u⟩)1 At+1 ] ≤ E[exp(-s⟨w t+1 , u⟩)1 At ] = E[exp(-s⟨w t , u⟩ + sη t ⟨∇ℓ + λw t , u⟩)1 At ] (a) ≤ E[exp(-s⟨w t , u⟩ + sη t ⟨∇R λ , u⟩ + Cs 2 η 2 t a 2 )1 At ] (b) = E[exp(-s(1 - η t (h(w t ) + λ))⟨w t , u⟩ + sη t (d(w t ) + Csη t a 2 ))1 At ] (c) ≤ exp(-Csη t a) E[exp(-s(1 - η t (h(w t ) + λ))⟨w t , u⟩)1 At ], where (a) follows from the sub-Gaussianity of the stochastic noise in the gradient, (b) follows from the definition of ∇R λ , and (c) holds for s ≲ (η t a) -1 with a sufficiently small absolute constant. Notice that by the condition on t * inherited from Theorem 3, 1 - η t (h(w t ) + λ) > 0, and since s(1 - η t (h(w t ) + λ)) ≤ s(1 - η t γ/m), we can expand the recursion: E[exp(-s⟨w t , u⟩)1 At ] ≤ E[exp(-Csa Σ t-1 i=t 0 η i Π t-1 j=i+1 (1 - η j γ/m) + s Π t-1 i=t 0 (1 - η i γ/m)|⟨w t 0 , u⟩|)1 At 0 ] ≤ E exp(-Cs(1 - ((t * + t 0 )/(t * + t)) 2 ) + Cs((t * + t 0 )/(t * + t)) 2 ), where in the second inequality we used a ≍ m -1 and γ ≍ 1. Applying the Chernoff bound implies that ⟨w t , u⟩ ≳ 1 with probability at least 1 - P(A C t ) - exp(-Ct) ≥ 1 - exp(-C(t * + t 0 )) - exp(-Cd) - exp(-Ct). Finally, the result follows by letting t 0 = Ct for a sufficiently small absolute constant C. We have proven that |⟨w t , u⟩| ≳ 1 while ∥w t ⊥ ∥ 2 → 0. This fact shows that the features learned in the first layer are useful. What remains to be shown is an approximation result: for a carefully constructed second layer, the network can approximate polynomials of the desired type. This type of approximation using random biases has been adopted from Damian et al. (2022). We first present an approximation result using infinitely many neurons. Lemma 21. Let 0 < |α| ≤ r and b ∼ Unif(-2r∆, 2r∆). For any smooth f : R → R, let f̃α : R → R be a smooth function such that f̃α (z) = f (z) for |z| ≤ r∆/|α| and f̃α (-2r∆/α) = f̃ ′ α (-2r∆/α) = 0. Then, for |z| ≤ ∆ we have E b [(4r∆/α 2 ) f̃ ′′ α (-b/α) σ(αz + b)] = f (z).
Published as a conference paper at ICLR 2023
Proof.
Using integration by parts, we have E b [(4r∆/α 2 ) f̃ ′′ α (-b/α) σ(αz + b)] = ∫ 2r∆ -αz f̃ ′′ α (-b/α)(z + b/α) db/α = -f̃ ′ α (-2r∆/α)(z + 2r∆/α) + ∫ z -2r∆/α f̃ ′ α (b) db = f̃α (z) = f (z). Now, by a concentration argument, we state an approximation result with finitely many neurons. Lemma 22. Let r * ≤ |α j | ≤ r and b j ∼ Unif(-2r∆, 2r∆). Let ∆ * := ∆ sup j,|z|≤2r∆/r * | f̃ ′′ αj (z)|, (C.7) where f̃αj is the extension of f introduced in Lemma 21. Then there exists a(α j , b j ) such that for any fixed z ∈ [-∆, ∆], with probability at least 1 - δ over the choice of (b j ), we have |Σ m j=1 a(α j , b j )σ(α j z + b j ) - f (z)| ≲ (r 2 ∆∆ * /r * 2 ) √(log(1/δ)/m). Moreover, ∥a∥ 2 ≲ r∆ * /(r * 2 √m). Proof. Let f̃α (z) be a candidate in Lemma 21, which can be obtained e.g. by extending f with suitable polynomials (notice that f̃α only needs to be twice differentiable on its domain). Now choose a j = (4r∆/(α 2 j m)) f̃ ′′ αj (-b j /α j ). In the following lemma, we briefly record useful properties of W T which will be of help for invoking the above approximation results and providing guarantees when the second layer is optimized by SGD. Through the rest of the proof, we add the mild assumption that d ≳ log(1/δ); otherwise, we need to add e -Cd to the probability of failure in Theorem 4. Lemma 23. Suppose T ≳ d + log(1/δ). Then with probability at least 1 - δ over the choice of (b j ) 1≤j≤m and {(x (t) , y (t) )} T -1 t=0 , the following statements hold: 1. ∥w T j ∥ 2 ≍ |⟨w T j , u⟩| ≍ 1 for all 1 ≤ j ≤ m. 2. ∥ 1/T Σ T -1 t=0 x (t) (x (t) ) ⊤ ∥ 2 ≲ 1. 3. ∥W T ⊥ ∥ F ≲ √(m(d + log(1/δ))/T). 4. |⟨u, x (t) ⟩| ≲ ∆ for all 0 ≤ t ≤ T - 1. 5. ∥W T x (t) ∥ 2 ≲ √m(√d + ∆) for all 0 ≤ t ≤ T - 1. Proof. We will show that each of the events holds with probability (w.p.) at least 1 - O(δ). Recall from Lemma 19 that ∥w T j ∥ 2 ≲ 1 for all j w.p. ≥ 1 - O(δ), which implies the same for ⟨w T j , u⟩. On the other hand, from Lemma 20, |⟨w T j , u⟩| ≳ 1 for all j w.p. ≥ 1 - O(δ).
Combining these events implies that |⟨w T j , u⟩| ≍ 1. The fact that ∥ 1/T Σ T -1 t=0 x (t) (x (t) ) ⊤ ∥ 2 ≲ 1 w.p. ≥ 1 - O(δ) for T ≳ d + log(1/δ) follows from the statement of Lemma 30. Furthermore, ∥W T ⊥ ∥ F ≲ √(m(d + log(1/δ))/T) w.p. ≥ 1 - O(δ) follows from Theorem 3. Note that as ⟨u, x (t) ⟩ ∼ N (0, 1), by the choice of ∆, we have |⟨u, x (t) ⟩| ≳ ∆ w.p. at most O(δ/T), thus |⟨u, x (t) ⟩| ≲ ∆ for all 0 ≤ t ≤ T - 1 w.p. ≥ 1 - O(δ) by a union bound. Finally, we have ∥W T x (t) ∥ 2 ≤ ∥W T ∥ x (t) ∥ 2 + ∥W T ⊥ x (t) ∥ 2 ≲ √m|⟨u, x (t) ⟩| + √(m(d + log(1/δ))/T) ∥x (t) ∥ 2 ≲ √m|⟨u, x (t) ⟩| + √m∥x (t) ∥ 2 . The first term is already bounded by √m∆ with probability at least 1 - O(δ). Moreover, recall that ∥x (t) ∥ 2 - E ∥x (t) ∥ 2 is 1-sub-Gaussian, thus by a union bound, ∥x (t) ∥ 2 - √d ≲ √(log(T /δ)) ≲ ∆ for all 0 ≤ t ≤ T - 1. Thus, w.p. ≥ 1 - O(δ) we have ∥W T x (t) ∥ 2 ≲ √m(√d + ∆), which completes the proof. From this point onwards, we denote the Huber loss by ℓ H (ŷ, y) = ℓ(ŷ - y). Notice that ℓ H is 1-Lipschitz. Lemma 24. Recall R(W T , a, b) = 1/T Σ T -1 t=0 ℓ H (Σ m j=1 a j σ(⟨w T j , x (t) ⟩ + b j ) - f (⟨u, x (t) ⟩) - ϵ (t) ), the empirical risk of W T given by Algorithm 1. Let ∆ ≍ √(log(T /δ)), ∆ * as defined in (C.7), and b j i.i.d. ∼ Unif(-∆, ∆). Then, with probability at least 1 - δ (over the randomness of (b j ) 1≤j≤m and {x (t) , y (t) } T -1 t=0 , hence W T ), for T ≳ d + log(1/δ), there exists a * with ∥a * ∥ 2 ≲ ∆ * /√m such that R(W T , a * , b) - E[ℓ H (ϵ)] ≲ ∆ * ∆ √(log(T /δ)/m) + ∆ * √((d + log(1/δ))/T) + ν √(log(1/δ)/T). Proof. We condition the following discussion on the event of Lemma 23. Let α j = ⟨w T j , u⟩, and let a * be constructed according to Lemma 22. By the Lipschitzness of the Huber loss, for an individual sample (x, y) we have ℓ H (ŷ(x; W T , a * , b) - f (⟨u, x⟩) - ϵ) ≤ ℓ H (ϵ) + |ŷ(x; W T , a * , b) - f (⟨u, x⟩)| ≤ ℓ H (ϵ) + |ŷ(x; W T , a * , b) - ŷ(x; W T ∥ , a * , b)| + |ŷ(x; W T ∥ , a * , b) - f (⟨u, x⟩)|.
Moreover, by the Cauchy-Schwarz inequality, |ŷ(x; W T , a * , b) - ŷ(x; W T ∥ , a * , b)| ≤ ∥a * ∥ 2 (Σ m j=1 (σ(⟨w T j , x⟩ + b j ) - σ(⟨(w T j ) ∥ , x⟩ + b j )) 2 ) 1/2 ≤ ∥a * ∥ 2 (Σ m j=1 ⟨(w T j ) ⊥ , x⟩ 2 ) 1/2 . Additionally, since ∥ 1/T Σ T -1 t=0 x (t) (x (t) ) ⊤ ∥ 2 ≲ 1, by Jensen's inequality, 1/T Σ T -1 t=0 |ŷ(x (t) ; W T , a * , b) - ŷ(x (t) ; W T ∥ , a * , b)| ≤ ∥a * ∥ 2 (1/T Σ T -1 t=0 ∥W T ⊥ x (t) ∥ 2 2 ) 1/2 ≲ ∥a * ∥ 2 ∥W T ⊥ ∥ F ≲ ∆ * √((d + log(1/δ))/T). On the other hand, let z (t) := ⟨u, x (t) ⟩ ≲ ∆. Then, we can apply Lemma 22 along with a union bound, which states that with probability 1 - O(δ) over the choice of (b j ) 1≤j≤m , 1/T Σ T -1 t=0 |ŷ(x (t) ; W T ∥ , a * , b) - f (⟨u, x (t) ⟩)| ≤ 1/T Σ T -1 t=0 |Σ m j=1 a * j σ(α j z (t) + b j ) - f (z (t) )| ≲ ∆∆ * √(log(T /δ)/m). Combining the events above, we have with probability at least 1 - δ, R(W T , a * , b) - 1/T Σ T -1 t=0 ℓ H (ϵ (t) ) ≲ ∆ * ∆ √(log(T /δ)/m) + √((d + log(1/δ))/T). The final step is to apply a concentration bound for 1/T Σ T -1 t=0 ℓ H (ϵ (t) ). Note that as ℓ H (ϵ) ≤ |ϵ|, if |ϵ| is ν-sub-Gaussian, then ℓ H (ϵ) - E[ℓ H (ϵ)] is also Cν-sub-Gaussian (as can be verified e.g. by Lemma 28). Now we can analyze the SGD run on the second layer a to give a high-probability statement for Rλ ′ (a T ). As Rλ ′ (a) is a smooth and strongly convex function of a, we state the following well-known elementary convergence result of SGD for smooth and strongly convex functions with bounded noise, which we present in a high-probability framework suitable for our analysis. Lemma 26. Let R : R m → R be a µ-strongly convex function satisfying µI m ⪯ ∇ 2 a R(a) ⪯ LI m . Suppose we run the SGD iterates a t+1 = a t - η t g t with E[g t | a t ] = ∇ a R(a t ) and ∥g t ∥ 2 ≤ G. Choose η t = (2t + 1)/(µ(t + 1) 2 ). Then, with probability at least 1 - δ, R(a T ) - R * ≤ R 0 /T 2 + CLG 2 /(µ 2 T) + CG 2 log(1/δ)/(µT), where R * = min a R(a). Proof. Let e t = g t - ∇ a R(a t ) denote the stochastic noise.
By the smoothness of R, we have R(a t+1 ) - R * ≤ R(a t ) - R * - η t ⟨∇ a R(a t ), ∇ a R(a t ) + e t ⟩ + (Lη 2 t /2)∥∇ a R(a t ) + e t ∥ 2 2 . Finally, applying a Chernoff bound using s = (4η t-1 G 2 ) -1 concludes the proof. We are finally in a position to complete the proof of Theorem 4. Proof. [Proof of Theorem 4] We consider the event of Lemma 23, on which, from Lemma 24, we know that with probability at least 1 - δ over the dataset and (b j ) 1≤j≤m the stated bound holds with a = 1 m /m, where 1 m is the vector of all ones. It is easy to observe that the results hold with high probability when a follows the initialization of Assumption 3 as well. Furthermore, we work on the event where ∥W 0 ∥ F ≤ 2√m, which holds with probability at least 1 - exp(-md/2). We begin by constructing a non-convex example for Proposition 2. For this example, we choose σ such that β 1 ≤ 1, σ(1) = σ(-1) = 0, σ ′ (1) = σ ′ (-1) = 0, and σ ′′ (-1) = σ ′′ (1) = β 2 = 1. An example of such a function is σ(z) = (cos(πz) + 1)/π 2 . Then, using the computations of Lemma 8, we have the expression below for ∇ 2 W R λ (W ), where the resulting inequality holds for sufficiently large K; hence R λ (W ) is non-convex at least on a neighborhood around zero. ∇ 2 W R λ (W ) = E σ ′ a,b (W x)σ Next, we construct a non-convex example for the smooth and decaying step size case of Theorem 3. This time, we require σ(±1) = -β 0 = -1 (which automatically implies σ ′ (±1) = 0, as σ attains its minimum there) and σ ′′ (±1) = β 2 = 1. For instance, we can choose σ(z) = (cos(πz) + 1)/π 2 - 1. Then, simplifying ∇ 2 R λ (0) yields ∇ 2 W R λ (0) = I m ⊗ (-I d - E[yxx ⊤ ])/m + (λ/m) I md = I m ⊗ (((λ - 1)/m) I d - E[yxx ⊤ ]/m).
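The two trigonometric activations used in the non-convex examples can be checked numerically. The following finite-difference sketch (our own illustration, not part of the paper's argument) verifies σ(±1), σ′(±1), σ″(±1) and β 1 = sup|σ′| ≤ 1 for both choices:

```python
import math

def sigma1(z):  # first example: sigma(z) = (cos(pi z) + 1) / pi^2
    return (math.cos(math.pi * z) + 1) / math.pi**2

def sigma2(z):  # second example: sigma(z) = (cos(pi z) + 1) / pi^2 - 1
    return sigma1(z) - 1

def d1(f, z, h=1e-5):   # central difference, first derivative
    return (f(z + h) - f(z - h)) / (2 * h)

def d2(f, z, h=1e-4):   # central difference, second derivative
    return (f(z + h) - 2 * f(z) + f(z - h)) / h**2

for z in (1.0, -1.0):
    # sigma1 vanishes with zero slope and unit curvature at z = +/-1
    print(sigma1(z), d1(sigma1, z), d2(sigma1, z))
    # sigma2 attains its minimum value -1 there, with the same curvature
    print(sigma2(z), d1(sigma2, z), d2(sigma2, z))

# beta_1 = sup |sigma'| = 1/pi <= 1 (shared by both, as they differ by a constant)
print(max(abs(d1(sigma1, 0.001 * i)) for i in range(-1000, 1001)))
```

Since σ″(z) = -cos(πz) for both, β 2 = sup|σ″| = 1, matching the requirement in the text.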



Figure 1: A two-layer ReLU network with m = 1000, d = 2 is trained to recover a tanh single-index model via SGD with weight decay. The initial neurons (red) converge to the principal subspace. 10% of the student neurons are visualized.
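A minimal simulation in the spirit of Figure 1 can be sketched as follows. The sizes and hyperparameters here (d = 5, m = 20, the step size, and the weight decay) are our own illustrative choices, not the figure's m = 1000, d = 2 setup; the sketch only tracks the fraction of first-layer weight norm orthogonal to span{u}, which shrinks under online SGD with weight decay:

```python
import numpy as np

rng = np.random.default_rng(0)
d, m, T = 5, 20, 5000          # hypothetical sizes for a quick run
eta, lam = 0.05, 0.1           # illustrative step size and weight decay
u = np.zeros(d); u[0] = 1.0    # teacher direction, y = tanh(<u, x>)

a = np.full(m, 1.0 / m)        # second layer held fixed while training W
b = np.full(m, 0.1)
W = rng.standard_normal((m, d)) / np.sqrt(d)

def orth_frac(W):
    """Fraction of the Frobenius norm of W lying outside span{u}."""
    W_par = np.outer(W @ u, u)          # row-wise projection onto span{u}
    return np.linalg.norm(W - W_par) / np.linalg.norm(W)

init_frac = orth_frac(W)
for _ in range(T):                      # online SGD with weight decay on W
    x = rng.standard_normal(d)
    y = np.tanh(x @ u)
    pre = W @ x + b
    yhat = a @ np.maximum(pre, 0.0)     # two-layer ReLU network output
    grad = np.outer(a * (yhat - y) * (pre > 0), x)   # squared-loss gradient in W
    W = (1 - eta * lam) * W - eta * grad

print(init_frac, orth_frac(W))  # the orthogonal fraction drops during training
```

Rerunning with lam = 0 illustrates the Figure 2 phenomenon: without weight decay the orthogonal component is not driven to zero.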

Assumption 2.A (Smooth activation). The activation function σ satisfies |σ(z)|, |σ ′ (z)|, |σ ′′ (z)| ≤ 1 for all z ∈ R, the loss is ℓ(ŷ, y) = 1 2 (ŷ -y) 2 for simplicity, and y satisfies |y| ≤ K almost surely. Assumption 2.B (ReLU activation). The activation function σ is σ(z) = max(z, 0) for z ∈ R. The loss satisfies 0 ≤ ∂ 2 1 ℓ(ŷ, y) ≤ 1, |∂ 1 ℓ(ŷ, y)| ≤ 1, and |∂ 2 12 ℓ(ŷ, y)| ≤ 1.
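The Huber loss used in the proof of Theorem 4 is one concrete loss compatible with Assumption 2.B; the following finite-difference sketch (our own check, with an arbitrary target y = 0.7) verifies the derivative bounds |∂ 1 ℓ| ≤ 1 and |∂ 2 12 ℓ| ≤ 1 for ℓ(ŷ, y) = ℓ H (ŷ - y):

```python
import numpy as np

def huber(z):
    """Huber function: quadratic for |z| <= 1, linear beyond."""
    z = np.asarray(z, dtype=float)
    return np.where(np.abs(z) <= 1, 0.5 * z**2, np.abs(z) - 0.5)

def loss(yhat, y):  # ell(yhat, y) = ell_H(yhat - y)
    return huber(yhat - y)

h = 1e-5
yhat = np.linspace(-3, 3, 601)
y = 0.7

# first partial in yhat: equals clip(yhat - y, -1, 1), so bounded by 1
d1 = (loss(yhat + h, y) - loss(yhat - h, y)) / (2 * h)

# mixed partial in (yhat, y): equals -1 on the quadratic region, 0 beyond
d12 = (loss(yhat + h, y + h) - loss(yhat + h, y - h)
       - loss(yhat - h, y + h) + loss(yhat - h, y - h)) / (4 * h**2)

print(np.max(np.abs(d1)))    # ~1, consistent with |d1 ell| <= 1
print(np.max(np.abs(d12)))   # ~1, consistent with |d12 ell| <= 1
```

The second partial ∂ 2 1 ℓ is the indicator of the quadratic region, so 0 ≤ ∂ 2 1 ℓ ≤ 1 holds as well.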

e.g. Hardt et al. (2016); Soudry et al. (2018); Yun et al. (2021); Park et al. (2022).

Figure 2: Neurons fail to converge to the principal subspace without weight decay, in the same experimental setup as Figure 1.

, |ŷ| ≤ ma(|⟨w, x⟩| + b). Thus, for |⟨w, x⟩| ≤ ((1 - |f (0)| - mab)/(2ma)) ∧ b, |⟨u, x⟩| ≲ 1, and |ϵ| ≲ 1 with sufficiently small absolute constants, we have |ŷ - y| ≤ 1, hence ∂ 2 12 ℓ(ŷ, y) = -1. Then we have,

αj ). Then Lemma 21 ensures that E bj [a(α j , b j )σ(α j z + b j )] = f (z). It immediately follows that ∥a∥ 2 ≤ Cr∆ * r * 2 √ m and |a j σ(αz+b j )| ≤ Cr 2 ∆∆ * r * 2 m . Applying the Hoeffding's inequality finishes the proof.
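The random-bias identity behind this construction (Lemma 21) can be sanity-checked numerically. Below is a sketch with hypothetical values r = ∆ = α = 1 and f (z) = (z + 2)², chosen so that f and f ′ already vanish at -2r∆/α = -2 and no extension f̃ is needed; a uniform average over a fine grid of b stands in for E b :

```python
import numpy as np

r, Delta, alpha = 1.0, 1.0, 1.0       # hypothetical parameters
a0 = -2 * r * Delta / alpha           # point where f and f' must vanish
f  = lambda z: (z - a0) ** 2          # f(a0) = f'(a0) = 0, so f itself is admissible
f2 = lambda z: 2.0 + 0.0 * z          # f'' (constant, broadcast over the grid)

relu = lambda t: np.maximum(t, 0.0)
b = np.linspace(-2 * r * Delta, 2 * r * Delta, 200001)  # grid for b ~ Unif(-2,2)

for z in (-1.0, -0.3, 0.5, 1.0):      # test points with |z| <= Delta
    integrand = (4 * r * Delta / alpha**2) * f2(-b / alpha) * relu(alpha * z + b)
    expectation = integrand.mean()    # grid average approximates E_b
    print(z, expectation, f(z))       # the two values agree
```

The agreement reflects the integration-by-parts computation in the proof: the relu kernel with uniformly random bias reproduces f exactly on [-∆, ∆].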

Then a sub-Gaussian concentration bound implies thatE[ℓ H (ϵ)] -1 T T -1 t=0 ℓ H (ϵ (t) ) ≲ ν log(1/δ)T , which finishes the proof.Let E S [•] denote expectation w.r.t. the random sampling of SGD used to train a, hence conditioned on {x(t) , y (t) } T -1 t=0 . Also, define the stochastic noise in the gradient w.r.t. a ase t a = ∇ a ℓ(ŷ(x (it) ; W T , a t , b) -y (it) ) -∇ a R(W T , a t , b). Notice that E S [e ta ] = 0. Lemma 25. On the event of Lemma 23 and with (b j )i.i.d.∼ Unif(-∆, ∆), consider the mapping a → Rλ ′ (a). Then, ∇ 2 a Rλ ′ (a) ≾ m∆ 2 + λ ′ , and ∥e t a For ∇ 2 a R(a), and any v ∈ R m with ∥v∥ 2 = 1, we have the following computation:ŷ, y)v ⊤ σ(W T x (t) + b)σ(W T x (t) + b) ⊤ v + λ ′ T x (t) + b)∥ 2 + λ ′ (a) ≲ ∥W T ∥ F + ∥b∥ 2 2 + λ ′ ≲ m∆ 2 + λ ′ where (a) holds since ∥ 1 T T -1 t=0 x (t) x (t) ⊤ ∥ 2 ≲ 1. Thus ∇ 2 a Rλ ′ (a) ≾ m∆ 2 + λ ′ . On the other hand, as ∥W T x (t) ∥ 2 ≲ √ m( √ d + ∆) for all 0 ≤ t ≤ T -1, we have ∥e t a ∥ 2 ≤ 2∥∇ a ℓ∥ 2 ≤ 2∥W T x (t) + b∥ 2 ≲

T , a, b) -E[ℓ H (ϵ)] ≲ ∆ 2 *Notice that a → R(W , a, b) is a convex function. Thus by strong duality, there exists λ ′ > 0 such that the value of the above constrained minimization problem is equal to the value of the following regularized minimization problem,min a Rλ ′ (W T , a, b) -E[ℓ H (ϵ)] ≲ ∆ 2 * λ ′ can be chosen such that the unique solution to ∇ a R(W T , a * , b) + λ ′ a * = 0 (C.8) has ∥a * ∥ 2 ≲ ∆ * √ m .Notice that this a * is the unique solution to arg min a Rλ ′ (W T , a, b). Moreover, from Lemma 26 we haveRλ ′ (W T , a T ′ , b) -Rλ ′ (W T , a * , b) ≲ R(W T , a 0 , b) T ′ 2 + (d + ∆ 2 )(∆ 2 + λ ′ /m + log(1/δ)) (λ ′ /m) 2 T ′ ,and by strong convexity∥a T ′ -a * ∥ 2 2 ≤ 2 m Rλ ′ (W T , a T ′ , b) -Rλ ′ (W T , a * , b) .Thus, with sufficiently large T ′ such that R(W T , a 0 , b)T ′ 2 + (d + ∆ 2 )(∆ 2 + λ ′ /m + log(1/δ)) (λ ′ /m) 2 T ′ ≲ ∆ 2 * d + log(1/δ) T ∧ λ ′ ∆ * √ m , (C.9) we have ∥a T ′ ∥ 2 ≲ ∆ * √ m and Rλ ′ (a T ′ ) -E[ℓ H (ϵ)] ≲ ∆ 2 * invokeTheorem 5, to close the generalization gap and getR τ (W T , a ′ , b) -E[ℓ H (ϵ)] ≲ ∆ 2 * NON-CONVEX R λ (W )Here, we outline examples for which R λ (W ) is non-convex on a neighborhood around W = 0 while λ = λ m satisfies the condition in Proposition 2 or Theorem 3. For simplicity of exposition, in both examples we fix a = 1m

′ a,b (W x) ⊤ + (ŷ(x; W , a, b) -y) diag(σ ′′ a,b (W x)) ⊗ xx ⊤ ⊤ + λI md = I m ⊗ -E yxx ⊤ m + λ m I d . Therefore, ∇ 2 R λ (0) is not positive semi-definite (PSD) if and only if -1 m E yxx ⊤ + λ m I d is not PSD. Moreover, by Jensen's inequality ŷ2 (x; W 0 , a, b) ≤ 1 m m i=1 (σ( w 0 j , x + b j ) -σ(b j )) 0 ) ≤ 2 E ŷ2 + 2 E y 2 ≤ 8 + 2 E y 2 .Now, let y = K⟨w, x⟩ 2 for some w with ∥w∥ 2 = 1. Then E y 2 = 3K 2 , and choosing λ = 1 + √ 9 + 6K 2 + ϑ for arbitrarily small ϑ suffices to satisfy the condition of Proposition 2. Then we havew ⊤ ∇ 2 W R λ (0)w = w ⊤ -E yxx ⊤

Lemma 12. Consider running the iterates of SGD (3.1), under either Assumptions 1&2.A or 1&2.B, with a stepsize sequence {η t } t≥0 that is either constant, η t = η, or decreasing, η t = m(2(t * + t) + 1)/(γ(t * + t + 1) 2 ) (cf. (Gower et al., 2019, Theorem 3.2)). Let κ := sup |∂ 1 ℓ(ŷ, y)|, κ̄ := β 1 ∥a∥ 2 κ, and ρ := λ + β 2 1 ∥a∥ 2 2 + β 2 κ∥a∥ ∞ . Suppose η 0 ≲ ρ -1 . Let F t denote the sigma algebra generated by {W j } t j=0 , and let {A t } t≥0 be a sequence of decreasing events (i.e. A t+1 ⊆ A t ), such that A t ∈ F t and on A t we have H(W t ) + λI m ⪰ (γ/m) I m . Then, for every t ≥ 0, with probability at least P

and ∥a∥ 1 ≤ 1, hence ρ ≍ λ, and with t * ≍ λ/γ we can guarantee η t λ ≲ 1. As the step size condition of Lemma 12 is satisfied, the desired result follows. Similarly, for ReLU we have |∂ 1 ℓ(ŷ, y)| ≤ 1 by Assumption 2.B, and for λ ≥ γ/m + 2∥a∥ ∞

Rademacher random variables. Then, similar to the Rademacher bound of Damian et al. (

t ) + e t + -η t ∥∇ a R(a t )∥ 2 2 -η t ∇ a R(a t ), e t + Notice that by Jensen's inequality, ∥∇ a R(a t )∥ 2 ≤ G, thus ∥e t ∥ 2 ≤ 2G and the zero-mean random variable ⟨∇ a R(a t ), e t ⟩ is 2G∥∇ a R(a t )∥ 2 -sub-Gaussian conditioned on a t . Now, we can establish the following recursive bound on the MGF of R t := R(a t ) -R * . For 0 ≤ s ≤ 1 4ηtG 2 we haveE e sR t+1 ≤ E exp sR t -sη t ∥∇ a R(a t )∥ 2 2 -sη t ∇ a R(a t ), e t + η 2 t LG 2 2 ≤ E exp sR t -sη t (1 -2sη t G 2 )∥∇ a R(a t )∥ 2 2 + ≤ E exp s(1 -η t µ)R t + LG 2 η 2 t 2where (a) follows since R(a) is strongly convex thus satisfies the Polyak-Łojasiewicz inequality 2µ(R(a) -R * ) ≤ ∥∇ a R(a)∥ 2 2 . As s(1 -η t µ) ≤

ACKNOWLEDGMENTS

The authors would like to thank Denny Wu for generating the figures, and both DW and Matthew S. Zhang for valuable feedback on the manuscript. This project was mainly funded by the CIFAR AI Catalyst grant. The authors also acknowledge the following funding sources: SP was supported by Institute of Information & communications Technology Planning & Evaluation (IITP) grant funded by the Korea government (MSIT) (No. 2019-0-00079, Artificial Intelligence Graduate School Program, Korea University). MG was supported by NSERC Grant [2022-04106], IM was supported by a Samsung grant and CIFAR AI Chairs program. Finally, MAE was supported by NSERC Grant [2019-06167] and CIFAR AI Chairs program.

E AUXILIARY LEMMAS

In order to be explicit, we state the following definitions and lemmas that will be used in the proof of Theorem 5. We only state them here and refer the reader to Wainwright (2019) and Vershynin (2018) for proofs and more details. Definition 27. (Wainwright, 2019, Definitions 2.2 and 2 2) A zero-mean random variable z is said to be ν-sub-Gaussian if E exp(sz) ≤ exp(s 2 ν 2 /2) for all s ∈ R, and is said to be ν-sub-exponential if the same bound holds for all |s| ≤ 1/ν. Lemma 28. (Vershynin, 2018, Propositions 2.5.2 and 2.7.1) Suppose z is a zero-mean L-sub-Gaussian random variable. Then z is cL-sub-exponential for an absolute constant c > 0. Lemma 29. (Wainwright, 2019, Theorem 2.26) Let f : R d → R be an L-Lipschitz function and x ∼ N (0, I d ). Then f (x) - E f (x) is CL-sub-Gaussian for an absolute constant C. Lemma 30. (Wainwright, 2019, Example 6.3) Let {x (i) } 1≤i≤n be a sequence of i.i.d. standard Gaussian random vectors x (i) ∼ N (0, I d ). It holds with probability at least 1 - δ that ∥ 1/n Σ n i=1 x (i) x (i)⊤ - I d ∥ 2 ≤ C(√((d + log(1/δ))/n) + (d + log(1/δ))/n), where C is an absolute constant. The next lemma is the well-known symmetrization argument that upper bounds the expected value of an empirical process by the Rademacher complexity. Lemma 31. (Mohri et al., 2018, Theorem 3.3) Let F be a class of functions f : R p → R for some p > 0. For a number of samples T and a probability distribution P on R p , define the Rademacher complexity of F as R(F) := E sup f ∈F 1/T |Σ T -1 t=0 ξ t f (x (t) )|, where {x (t) } T -1 t=0 i.i.d. ∼ P and {ξ t } T -1 t=0 are independent Rademacher random variables (i.e. ±1 equiprobably). Then the following holds: E sup f ∈F (E[f (x)] - 1/T Σ T -1 t=0 f (x (t) )) ≤ 2R(F). Furthermore, we have the following fact for standard normal random vectors.

Lemma 32. Let x ∼ N (0, I d ). There exists a constant C > 0 such that for any V ∈ R m×k and p ≥ 1 we have E[∥V x∥ p 2 ] 1/p ≤ C(E ∥V x∥ 2 + √p ∥V ∥ 2 ). Proof. First of all, ∥V x∥ 2 is a ∥V ∥ 2 -Lipschitz function of x, thus Lemma 29 applies and ∥V x∥ 2 is sub-Gaussian. Furthermore, by applying Lemma 28 to ∥V x∥ 2 - E[∥V x∥ 2 ] and Minkowski's inequality, we have E[∥V x∥ p 2 ] 1/p ≤ E ∥V x∥ 2 + C√p ∥V ∥ 2 , where the last inequality follows from Jensen's inequality. Lemma 33. Let x ∼ N (0, I d ). Then E exp(c∥x∥ 2 2 ) ≤ exp(2cd) for c ≤ 1/4. Proof. Gaussian integration yields E exp(cx 2 i ) = 1/√(1 - 2c). Furthermore, for c ≤ 1/4 we have 1/√(1 - 2c) ≤ exp(2c), and the result follows by independence of the coordinates.

