NEURAL NETWORKS EFFICIENTLY LEARN LOW-DIMENSIONAL REPRESENTATIONS WITH SGD

Abstract

We study the problem of training a two-layer neural network (NN) of arbitrary width using stochastic gradient descent (SGD), where the input x ∈ ℝ^d is Gaussian and the target y ∈ ℝ follows a multiple-index model, i.e., y = g(⟨u_1, x⟩, …, ⟨u_k, x⟩) with a noisy link function g. We prove that the first-layer weights of the NN converge to the k-dimensional principal subspace spanned by the vectors u_1, …, u_k of the true model, when online SGD with weight decay is used for training. This phenomenon has several important consequences when k ≪ d. First, by employing uniform convergence on this smaller subspace, we establish a generalization error bound of O(√(kd/T)) after T iterations of SGD, which is independent of the width of the NN. We further demonstrate that SGD-trained ReLU NNs can learn a single-index target of the form y = f(⟨u, x⟩) + ϵ by recovering the principal direction, with a sample complexity linear in d (up to log factors), where f is a monotonic function with at most polynomial growth and ϵ is the noise. This is in contrast to the known d^Ω(p) sample requirement to learn any degree-p polynomial in the kernel regime, and it shows that NNs trained with SGD can outperform the neural tangent kernel at initialization. Finally, we also provide compressibility guarantees for NNs using the approximate low-rank structure produced by SGD.

1. INTRODUCTION

The task of learning an unknown statistical (teacher) model using data is fundamental in many areas of learning theory. There has been a considerable amount of research dedicated to this task, especially when the trained (student) model is a neural network (NN), providing precise and non-asymptotic guarantees in various settings (Zhong et al., 2017; Goldt et al., 2019; Ba et al., 2019; Sarao Mannelli et al., 2020; Zhou et al., 2021; Akiyama & Suzuki, 2021; Abbe et al., 2022; Ba et al., 2022; Damian et al., 2022; Veiga et al., 2022). As evident from these works, explaining the remarkable learning capabilities of NNs requires arguments beyond classical learning theory (Zhang et al., 2021). The connection between NNs and kernel methods has been particularly useful in this pursuit (Jacot et al., 2018; Chizat et al., 2019). In particular, a two-layer NN with randomly initialized and untrained weights is an example of a random features model (Rahimi & Recht, 2007), and regression on the second layer captures several interesting phenomena that NNs exhibit in practice (Louart et al., 2018; Mei & Montanari, 2022), e.g., the cusp in the learning curve. However, NNs also inherit favorable characteristics from the optimization procedure (Ghorbani et al., 2019; Allen-Zhu & Li, 2019; Yehudai & Shamir, 2019; Li et al., 2020; Refinetti et al., 2021), which cannot be captured by associating NNs with regression on random features. Indeed, recent works have established a separation between NNs and kernel methods, relying on the emergence of representation learning as a consequence of gradient-based training (Abbe et al., 2022; Ba et al., 2022; Barak et al., 2022; Damian et al., 2022), which often exhibits a natural bias towards low-complexity models.

A theme that has emerged repeatedly in modern learning theory is the implicit regularization effect provided by the training dynamics (Neyshabur et al., 2014). The work by Soudry et al.
(2018) has inspired an abundance of recent works focusing on the implicit bias of gradient descent favoring, in some sense, low-complexity models, e.g., by achieving min-norm and/or max-margin solutions despite the lack of any explicit regularization (Gunasekar et al., 2018; Li et al., 2018; Ji & Telgarsky, 2019; Gidel et al., 2019; Chizat & Bach, 2020; Pesme et al., 2021). However, these works mainly consider linear models or unrealistically wide NNs, and the notion of reduced complexity, as well as its implications for generalization, varies. A concrete example in this domain is compressibility and its connection to generalization (Arora et al., 2018; Suzuki et al., 2020). Indeed, when a trained NN can be compressed into a smaller NN with similar prediction behavior, the two models exhibit similar generalization performance. Thus, the model complexity of the original NN can be explained by the smaller complexity of the compressed one, which is classically linked to better generalization.

In this paper, we demonstrate the emergence of low-complexity structures during the training procedure. More specifically, we consider training a two-layer student NN of arbitrary width m, where the input x ∈ ℝ^d is Gaussian and the target y ∈ ℝ follows a multiple-index teacher model, i.e., y = g(⟨u_1, x⟩, …, ⟨u_k, x⟩; ϵ) with a link function g and noise ϵ independent of the input. In this setting, we prove that the first-layer weights trained by online stochastic gradient descent (SGD) with weight decay converge to the k-dimensional subspace spanned by the weights of the teacher model, span(u_1, …, u_k), which we refer to as the principal subspace. Our primary focus is the case where the target depends only on a few important directions of the input, i.e., k ≪ d, which induces a low-dimensional structure on the SGD-trained first-layer weights whose impact on generalization is profound.
First, convergence to the principal subspace leads to an improved bound on the generalization gap for SGD, independent of the width of the NN. In the specific case of learning a single-index target with a ReLU student network, we show that this convergence leads to useful features that improve upon the initial random features. Hence, we prove that NNs can learn certain degree-p polynomials with a number of samples (almost) linear in d using online SGD, whereas learning a degree-p polynomial with any rotationally invariant kernel, including the neural tangent kernel (NTK) at initialization, requires d^Ω(p) samples (Donhauser et al., 2021). We summarize our contributions as follows.

• We show in Theorem 3 that NNs learn low-dimensional representations by proving that the iterates of online SGD on the first layer of a two-layer NN of width m converge to a √m·ε-neighborhood of the principal subspace after O(d/ε²) iterations, with high probability. The error tolerance of √m·ε is sufficient to guarantee that the risk of the SGD iterates and that of their orthogonal projection onto the principal subspace are within O(ε) of each other.

• We demonstrate the impact of learning low-dimensional representations with three applications.

- For a single-index target y = f(⟨u, x⟩) + ϵ with a monotonic link function f whose second derivative f′′ has at most polynomial growth, we prove in Theorem 4 that ReLU networks of width m can learn this target after T iterations of SGD with an excess risk estimate of Õ(√(d/T) + 1/m), with high probability (see the illustration in Figure 1). In particular, the number of iterations is linear in the input dimension d, even when f is a polynomial of any (fixed) degree p.

- Based on a uniform convergence argument on the principal subspace, we prove in Theorem 5 that T iterations of SGD will produce a model with a generalization gap of O(√(kd/T)), with high probability. Remarkably, this rate is independent of the width m of the NN, even in the case k ≍ d, where the target is an arbitrary function of the input and not necessarily low-dimensional.

- Finally, we provide a compressibility result that follows directly from the low dimensionality of the principal subspace. We prove that T iterations of SGD produce first-layer weights that are compressible to rank k with a risk deviation of O(√(d/T)), with high probability.

The rest of the paper is organized as follows. We discuss the notation and the related work in the remainder of this section. We describe the problem formulation and preliminaries in Section 2, and provide an analysis for the warm-up case of population gradient descent in Section 2.1. Our main result on SGD is presented in Section 3. We discuss three implications of our main theorem in Section 4, where we provide results on learnability, generalization gap, and compressibility in Sections 4.1, 4.2, and 4.3, respectively. We finally conclude with a brief discussion in Section 5.
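To make the training setting concrete, the following NumPy sketch (not the paper's experiment) trains only the first layer of a two-layer ReLU network with online SGD and weight decay on a Gaussian multiple-index target, and tracks the fraction of the first-layer weight mass lying in the principal subspace span(u_1, …, u_k). The tanh-sum link function, the fixed ±1/m second layer, and all hyperparameters are illustrative assumptions rather than the values prescribed by the theorems.

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, m = 64, 2, 200                # input dimension, index dimension, width

# Teacher: y = g(<u_1, x>, ..., <u_k, x>) with orthonormal directions u_j.
U, _ = np.linalg.qr(rng.standard_normal((d, k)))   # columns span the principal subspace

def g(z):
    return np.tanh(z).sum()         # illustrative link function

# Student: two-layer ReLU network; only the first layer W is trained here.
W = rng.standard_normal((m, d)) / np.sqrt(d)
a = rng.choice([-1.0, 1.0], size=m) / m            # fixed second layer

def subspace_fraction(W, U):
    """Fraction of the Frobenius mass of W inside span(u_1, ..., u_k)."""
    return np.linalg.norm(W @ U) ** 2 / np.linalg.norm(W) ** 2

eta, lam, T = 0.05, 0.05, 10_000    # step size, weight decay, iterations
for _ in range(T):
    x = rng.standard_normal(d)      # fresh sample each step: online SGD
    y = g(U.T @ x)
    pre = W @ x
    err = a @ np.maximum(pre, 0.0) - y             # squared-loss residual
    grad = np.outer(err * a * (pre > 0), x)        # dL/dW for the ReLU layer
    W -= eta * (grad + lam * W)                    # SGD step with weight decay

frac = subspace_fraction(W, U)      # at initialization this is about k/d
print(f"weight mass in the principal subspace: {frac:.2f}")
```

The component of each neuron orthogonal to the principal subspace receives no mean drift from the gradient and is shrunk by the weight decay at every step, so the printed fraction rises well above its value of roughly k/d ≈ 0.03 at initialization.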



Figure 1: A two-layer ReLU network with m = 1000 and d = 2 is trained via SGD with weight decay to recover a tanh single-index model. The initial neurons (red) converge to the principal subspace. 10% of the student neurons are visualized.
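The low-dimensional structure visible in Figure 1 is also what underlies the compressibility result, and it can be sanity-checked numerically. The sketch below builds a first-layer matrix whose mass is concentrated near a k-dimensional subspace (a stand-in for SGD-trained weights, since no training is run here, and the residual noise level 0.05 is an illustrative assumption), truncates it to its best rank-k approximation via the SVD, and compares the predictions of the original and compressed networks on fresh Gaussian inputs.

```python
import numpy as np

rng = np.random.default_rng(1)
d, k, m = 64, 2, 200

# Stand-in for SGD-trained first-layer weights: mass concentrated in a
# k-dimensional subspace plus a small orthogonal residual (assumed here,
# not produced by actual training).
U, _ = np.linalg.qr(rng.standard_normal((d, k)))
W = rng.standard_normal((m, k)) @ U.T + 0.05 * rng.standard_normal((m, d))
a = rng.choice([-1.0, 1.0], size=m) / m            # fixed second layer

def forward(W, X):
    """Two-layer ReLU network evaluated on a batch of inputs X."""
    return np.maximum(X @ W.T, 0.0) @ a

# Compress: keep only the top-k singular directions of W.
Uw, s, Vt = np.linalg.svd(W, full_matrices=False)
W_k = (Uw[:, :k] * s[:k]) @ Vt[:k]                 # best rank-k approximation

X = rng.standard_normal((5000, d))                 # fresh Gaussian inputs
gap = np.mean((forward(W, X) - forward(W_k, X)) ** 2)
print(f"spectral ratio s_{k + 1}/s_1 = {s[k] / s[0]:.3f}")
print(f"squared prediction gap after rank-{k} compression: {gap:.2e}")
```

In the paper's setting the approximate low-rank structure is produced by SGD itself; here it is injected by construction, purely to illustrate how the rank-k truncation and the resulting risk deviation would be computed.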

