FLATTER, FASTER: SCALING MOMENTUM FOR OPTIMAL SPEEDUP OF SGD

Abstract

Commonly used optimization algorithms often show a trade-off between good generalization and fast training times. For instance, stochastic gradient descent (SGD) tends to have good generalization; however, adaptive gradient methods have superior training times. Momentum can help accelerate training with SGD, but so far there has been no principled way to select the momentum hyperparameter. Here we study training dynamics arising from the interplay between SGD with label noise and momentum in the training of overparametrized neural networks. We find that scaling the momentum hyperparameter 1 − β with the learning rate to the power 2/3 maximally accelerates training, without sacrificing generalization. To analytically derive this result we develop an architecture-independent framework, where the main assumption is the existence of a degenerate manifold of global minimizers, as is natural in overparametrized models. Training dynamics display the emergence of two characteristic timescales that are well-separated for generic values of the hyperparameters. The maximum acceleration of training is reached when these two timescales meet, which in turn determines the scaling limit we propose. We confirm our scaling rule for synthetic regression problems (matrix sensing and the teacher-student paradigm) and classification on realistic datasets (ResNet-18 on CIFAR10, 6-layer MLP on FashionMNIST), suggesting the robustness of our scaling rule to variations in architectures and datasets.

1. INTRODUCTION

The modern paradigm for optimization of deep neural networks has engineers working with vastly overparametrized models and training to near-perfect accuracy (Zhang et al., 2017). In this setting, a model will typically have not just isolated minima in parameter space, but a continuous set of minimizers, not all of which generalize well. Liu et al. (2020) demonstrate that, depending on parameter initialization and hyperparameters, stochastic gradient descent (SGD) is capable of finding minima with wildly different test accuracies. Thus, the power of a particular optimization method lies in its ability to select a minimum that generalizes amongst this vast set. In other words, good generalization relies on the implicit bias or regularization of an optimization algorithm. There is a significant body of evidence that training deep nets with SGD leads to good generalization. Intuitively, SGD appears to prefer flatter minima (Keskar et al., 2017; Wu et al., 2018; Xie et al., 2020), and flatter minima generalize better (Hochreiter & Schmidhuber, 1997). More recently, a variant of SGD which introduces "algorithmic" label noise has proved especially amenable to rigorous treatment. In the overparametrized setting, Blanc et al. (2020) rigorously showed that SGD with label noise converges not just to any minimum, but to those minima with the smallest trace norm of the Hessian. However, Li et al. (2022) show that the dynamics of this regularization happen on a timescale proportional to the inverse square of the learning rate η, much slower than the time to first converge to an interpolating solution. We therefore consider the setting where the weights remain near the manifold of minima, which is responsible for significant regularization after the initial convergence of the train loss (Blanc et al., 2020). With the recent explosion in the size of both models and datasets, training time has become an important consideration in addition to asymptotic generalization error.
In this context, adaptive gradient methods such as Adam (Kingma & Ba, 2015) are widely preferred over variants of SGD, even though they often yield worse generalization errors in practical settings (Keskar & Socher, 2017; Wilson et al., 2017), though extensive hyperparameter tuning (Choi et al., 2019) or scheduling (Xie et al., 2022) can potentially obviate this problem. These two constraints motivate a careful analysis of how momentum accelerates SGD. Classic work on acceleration methods, which we refer to generally as momentum, has found a provable benefit in the deterministic setting, where gradient updates have no error. However, rigorous guarantees have been harder to find in the stochastic setting, and remain limited by strict conditions on the noise (Polyak, 1987; Kidambi et al., 2018) or on the model class and dataset structure (Lee et al., 2022). In this work, we show that there exists a scaling limit for SGD with momentum (SGDM) which provably increases the rate of convergence.

Notation.

In what follows, we denote by C^n, for n = 0, 1, ..., the set of functions with continuous n-th derivatives. For any function f, ∂f[u] and ∂²f[u, v] will denote directional first and second derivatives along directions defined by vectors u, v ∈ R^D. We may occasionally also write ∂²f[Σ] = Σ_{i,j=1}^D ∂²f[e_i, e_j] Σ_ij. Given a submanifold Γ ⊂ R^D and w ∈ Γ, we denote by T_w Γ the tangent space to Γ at w, and by P_L(w) the projector onto T_w Γ (we will often omit the dependence on w and simply write P_L). Given a matrix H ∈ R^{D×D}, we will denote by H^⊤ its transpose and by H^† its pseudoinverse.

1.1. HEURISTIC EXPLANATION FOR OPTIMAL MOMENTUM-BASED SPEEDUP

Deep neural networks typically possess a manifold of parametrizations with zero training error. Because the gradients of the loss function vanish along this manifold, the dynamics of the weights are completely frozen under gradient descent. However, as appreciated by Wei & Schwab (2019) and Blanc et al. (2020), noise can generate an average drift of the weights along the manifold. In particular, SGD noise can drive the weights to a lower-curvature region which, heuristically, explains the good generalization properties of SGD. Separately, it is well known that adding momentum typically leads to acceleration in training (Sutskever et al., 2013). Below, we will see that there is a nontrivial interplay between the drift induced by noise and momentum, and find that acceleration along the valley is maximized by a particular hyperparameter choice in the limit of small learning rate. In this section we will illustrate the main intuition leading to this prediction using heuristic arguments, and defer a more complete discussion to Sec. 3. We model momentum SGD with label noise using the following formulation:

π_{k+1} = βπ_k − ∇L(w_k) + εσ(w_k)ξ_k,    w_{k+1} = w_k + ηπ_{k+1},    (1)

where η is the learning rate, β is the (heavy-ball) momentum hyperparameter (as introduced by Polyak (1964)), w ∈ R^D denotes the weights and π ∈ R^D denotes the momentum (also called the auxiliary variable). L : R^D → R is the training loss, and σ : R^D → R^{D×r} is the noise function, whose dependence on the weights allows us to model the gradient noise due to SGD and label noise. Specifically, this admits modeling phenomena such as automatic variance reduction (Liu & Belkin, 2020), and expected smoothness, satisfying only very general assumptions such as those developed by Khaled & Richtárik (2020). Finally, ξ_k ∈ R^r is sampled i.i.d. at every timestep k from a distribution with zero mean and unit variance, and ε > 0 sets the noise strength.
We will now present a heuristic description of the drift dynamics induced by the noise along a manifold of minimizers Γ = {w : L(w) = 0} ⊆ R^D, in the limit ε → 0. In practice, this limit corresponds to choosing a small strength of label noise and a large minibatch size. Let us assume that the weights at initial time w_0 are already close to Γ, and that π_0 = 0. Because of this, the gradients of L at w_0 are very small, and so only fluctuations transverse to the manifold will generate systematic drifts. Denoting by δw_k = w_k − w_0 the displacement of the weights after k timesteps, let us Taylor expand the first equation in (1) to get

π_{k+1} = βπ_k − ∇²L(w_0)[δw_k] + εσ(w_0)ξ_k.

By construction, the Hessian ∇²L(w_0) vanishes along the directions tangent to Γ, while in the transverse directions we have an Ornstein-Uhlenbeck (OU) process. The number of time steps it takes this process to relax to the stationary state, or mixing time, is τ_1 = Θ(1/(1 − β)) as β → 1, which can be anticipated from the first equation in (1) since π_{k+1} − π_k ∼ −(1 − β)π_k (see Sec. 3.3 for a more detailed derivation). After this time, the variance of this linearized OU process becomes independent of the time step, and can be estimated to be (see Appendix F)

E[δw_k^T (δw_k^T)^⊤] = Θ(ε²η/(1 − β)),

where δw^T denotes the transverse part of the displacement. The longitudinal component of the momentum then reads

P_L π_{k+1} = −P_L Σ_{j=0}^k β^{k−j} ∇L(w_j) ≈ −(1/2) Σ_{j=0}^k β^{k−j} ∂²(P_L ∇L)[δw_j^T (δw_j^T)^⊤] ≈ −(1/(2(1 − β))) ∂²(P_L ∇L)[E δw^T (δw^T)^⊤].

The variance of the transverse displacements therefore generates longitudinal motion, and the latter will then scale as P_L π_{k+1} = Θ(ε²η/(1 − β)²). Using the second equation in (1), we see that P_L(δw_{k+1} − δw_k) = η P_L π_{k+1}. Define τ_2 to be the number of time steps it takes so that the displacement of w_k along Γ becomes a finite number as we take ε → 0 first, and η → 0 afterward. From the above considerations, we then find that τ_2 = Θ((1 − β)²/(ε²η²)).
The key observation now is that, when the weights are initialized near the valley, the convergence time is controlled by the largest of the two timescales τ_1 and τ_2. While τ_2 decreases as β → 1, τ_1 increases. Therefore, the optimal speedup of training is achieved when these two timescales intersect, τ_1 = τ_2, which happens for 1 − β = Cη^{2/3}. More generally, we consider a double scaling where β → 1 as η → 0 according to

β = 1 − Cη^γ.    (2)

We consider this scaling limit in this paper, and we thus find that the scaling power γ = 2/3 achieves the optimal speed-up along the zero-loss manifold.
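The update rule (1) and the scaling rule (2) are easy to state in code. The following is a minimal sketch; the default C = 0.2 and the generic `grad_fn`/`noise_fn` callables are illustrative assumptions of ours, not values prescribed by the theory.

```python
import numpy as np

def sgdm_label_noise_step(w, pi, grad_fn, noise_fn, eta, beta, eps, rng):
    """One step of heavy-ball SGD with label noise, Eq. (1):
    pi_{k+1} = beta*pi_k - grad L(w_k) + eps*sigma(w_k) xi_k
    w_{k+1}  = w_k + eta*pi_{k+1}
    """
    xi = rng.standard_normal(noise_fn(w).shape[1])   # zero mean, unit variance
    pi_new = beta * pi - grad_fn(w) + eps * noise_fn(w) @ xi
    w_new = w + eta * pi_new
    return w_new, pi_new

def momentum_from_lr(eta, C=0.2, gamma=2/3):
    """Proposed scaling rule: 1 - beta = C * eta**gamma, optimal at gamma = 2/3."""
    return 1.0 - C * eta**gamma
```

Note that the rule reduces to the familiar β = 0.9 when η = 1 and C = 0.1, as discussed in Sec. 4.2.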

1.2. LIMIT DRIFT-DIFFUSION

We now describe the rationale for obtaining the limiting drift-diffusion on the zero-loss manifold for a process of the form (1); this foreshadows the rigorous results presented in Sec. 3. As discussed above, the motion along the manifold is slow, as it takes Θ(ε^{−2}) time steps to accumulate a finite amount of longitudinal drift. We want to extract this slow longitudinal sector of the dynamics by projecting out the fast-moving components of the weights. For stable values of the optimization hyperparameters, the noiseless (ε = 0) dynamics (1) will map a generic pair (π, w), as k → ∞, to (0, w_∞), where w_∞ ∈ Γ. Define Φ : R^D × R^D → R^D to be this mapping, i.e. Φ(π, w) = w_∞. As we now show, when ε > 0, Φ can be used precisely to project onto the slow, noise-induced longitudinal dynamics. Let us collectively denote x_k = (π_k, w_k) and write Eq. (1) as x_{k+1} = x_k + F(x_k) + εσ(x_k)ξ_k. We can perform a Taylor expansion in ε to obtain

Φ(x_{k+1}) − Φ(x_k) = ε ∂Φ(x_k)[σ(x_k)ξ_k] + (ε²/2) ∂²Φ(x_k)[σ(x_k)ξ_k, σ(x_k)ξ_k] + ⋯.

Therefore, denoting Y(t = ε²k) = Φ(x_k), the limit dynamics as ε → 0 can be well-approximated by the continuous-time equation

dY = ∂Φ[σξ_t] √dt + (1/2) ∂²Φ[σξ_t, σξ_t] dt,

where we interpret the time increment as dt = ε², and E[ξ_t²] = 1. Note that until here we have not taken a small learning rate limit. The learning rate can be finite, as long as the map Φ(π, w) exists. The small-noise limit is sufficient to allow a continuous-time description of the limit dynamics because the noise-induced drift-diffusion along the valley requires Θ(ε^{−2}) timesteps to lead to appreciable longitudinal displacements. A similar approach to the one just described was used in Li et al. (2022), although in our case the limit drift-diffusion is obtained in the small-noise limit, rather than the small learning rate limit. The reason for our choice is that, since we scale β according to (2), the deterministic part of Eq. (1) becomes degenerate as we take η → 0, in which case it would not be possible to apply the mathematical framework of Katzenberger (1991) on which our results below rely. To further simplify our analysis, particularly the statement of Theorem B.4, we will further take η → 0 after taking ε → 0, and retain only the leading-order contributions in η. The main contributions of this paper are:

1. We develop a general formalism to study SGD with (heavy-ball) momentum in Sec. 3, extending the framework of Li et al. (2022) to study convergence rates and generalization with momentum.
2. We find a novel scaling regime of the momentum hyperparameter, 1 − β ∼ η^γ, and demonstrate a qualitative change in the noise-induced training dynamics as γ is varied.
3. We identify a special scaling limit, 1 − β ∼ η^{2/3}, where training achieves a maximal speedup at fixed learning rate η.
4. In Sec. 4, we demonstrate the relevance of our theory with experiments on toy models (2-layer neural networks, and matrix sensing) as well as realistic models and datasets (ResNet-18 on CIFAR10).
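To make the limit map Φ concrete, the sketch below computes Φ(π, w) by running the noiseless heavy-ball dynamics on a toy quadratic loss L(w) = w₂²/2 on R², whose zero-loss manifold is the w₁-axis. The loss, hyperparameters and step count are illustrative choices of ours.

```python
import numpy as np

def Phi(pi, w, eta=0.1, beta=0.9, steps=5000):
    """Limit map of the noiseless heavy-ball dynamics, Phi(pi, w) = w_infty,
    for the toy loss L(w) = 0.5 * w[1]**2, whose zero-loss manifold Gamma
    is the w[0]-axis (an illustrative stand-in for a generic valley)."""
    pi, w = np.asarray(pi, float).copy(), np.asarray(w, float).copy()
    for _ in range(steps):
        grad = np.array([0.0, w[1]])   # gradient of L vanishes along Gamma
        pi = beta * pi - grad          # pi_{k+1} = beta*pi_k - grad L(w_k)
        w = w + eta * pi               # w_{k+1}  = w_k + eta*pi_{k+1}
    return w
```

With zero initial momentum, Φ simply drops the transverse coordinate; an initial momentum pointing along the valley shifts the limit point by η·β/(1 − β) times its magnitude, illustrating that Φ depends on both arguments.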

2. RELATED WORKS

Loss Landscape in Overparametrized Networks. The geometry of the loss landscape is very hard to understand for real-world models. Choromanska et al. (2015) conjectured, based on empirical observations and on an idealized model, that most local minima have similar loss function values. Subsequent literature has shown in wider generality the existence of a manifold connecting degenerate minima of the loss function, particularly in overparametrized models. This was supported by work on mode connectivity (Freeman & Bruna, 2017; Garipov et al., 2018; Draxler et al., 2018; Kuditipudi et al., 2019), as well as by empirical observations that the loss Hessian possesses a large set of (nearly) vanishing eigenvalues (Sagun et al., 2016; 2017). In particular, Nguyen (2019) showed that for overparametrized networks with piecewise linear activations, all global minima are connected within a unique valley.

The Implicit Regularization of SGD. Wei & Schwab (2019), assuming the existence of a zero-loss valley, observed that SGD noise leads to a decrease in the trace of the Hessian. Blanc et al. (2020) demonstrated that SGD with label noise in the overparametrized regime induces a regularized loss that accounts for the decrease in the trace of the Hessian. Damian et al. (2021) extend this analysis to finite learning rate. HaoChen et al. (2021) study the effect of non-isotropic label noise in SGD and find a theoretical advantage in a quadratic overparametrized model. Wu et al. (2022) show that only minima with small enough Hessians (in Frobenius norm) are stable under SGD. The specific regularization induced by SGD was found in quadratic models (Pillaud-Vivien et al., 2022), 2-layer ReLU networks (Blanc et al., 2020), linear models (Li et al., 2022), and diagonal networks (Pesme et al., 2021).

Momentum. Two widely used forms of acceleration are Nesterov momentum (Nesterov, 1983) and Heavy Ball (HB) or Polyak momentum (Polyak, 1964). We focus on the latter in this paper, which we refer to simply as momentum.
Momentum provably improves convergence time in the deterministic setting. Intuitively, introducing β in Eq. (1) gives the motion in parameter space an effective "inertia" or memory, which promotes motion that does not strictly follow the local gradient, but rather moves along the directions which persistently decrease the loss function across iterations (Polyak, 1964; Sutskever et al., 2013). Less is known rigorously when stochastic gradient updates are used. Indeed, Polyak (1987) suggests that the benefits of acceleration with momentum disappear with stochastic optimization unless certain conditions are placed on the properties of the noise. See also (Jain et al., 2018; Kidambi et al., 2018) for more discussion and background on this issue. Nevertheless, in practice it is widely appreciated that momentum is important for convergence and generalization (Sutskever et al., 2013), and it is widely used in modern adaptive gradient algorithms (Kingma & Ba, 2015). Some limited results have been obtained showing speedup in the mean-field approximation (Mannelli & Urbani, 2021) and in linear regression (Jain et al., 2018). Modifications to Nesterov momentum to make it more amenable to stochasticity (Liu & Belkin, 2020; Allen-Zhu, 2017), and near saddle points (Xie et al., 2022), have also been considered.

3.1. GENERAL SETUP

Following the line of Sec. 1.2, in this and the following section we will rigorously derive the limiting drift-diffusion equation for the weights on the zero-loss manifold Γ, and extract the timescale τ_2 associated with this noise-induced motion. In Sec. 3.3 we will then compare τ_2 to the timescale τ_1 associated with the noiseless dynamics, and evaluate the optimal value of γ discussed around Eq. (2). We will use Eq. (1) to model momentum SGD. As illustrated in Sec. 1.1, the drift is controlled by the second moment of the fluctuations, and we thus expect the drift timescale to be Θ(ε^{−2}). We will then rescale time as k = t/ε², so that the motion in units of t is O(1) as ε → 0. More explicitly, take ε_n to be a positive sequence such that ε_n → 0 as n → ∞. For each n we consider the stochastic process that solves Eq. (1):

X_n(t) = X_n(0) + ∫_0^t σ(X_n) dZ_n + ∫_0^t F(X_n) dA_n,    with A_n(t) = ⌊t/ε_n²⌋,  Z_n(t) = ε_n Σ_{k=1}^{A_n(t)} ξ_k,    (4, 5)

where X(t = ε²k) = (π_k, w_k), σ(X) = (σ, ησ), F(X) = ((β − 1)π − ∇L(w), η(βπ − ∇L(w))), and ⌊x⌋ denotes the integer part of a real number x. See Appendix B.2 for a proof of the equivalence between (1) and (4).

Assumption 3.1. The loss function L : R^D → R is a C³ function whose first 3 derivatives are locally Lipschitz, and σ is continuous. Γ = {w ∈ R^D : L(w) = 0} is a C²-submanifold of R^D of dimension M, with 0 ≤ M ≤ D. Additionally, for w ∈ Γ, rank(∇²L(w)) = D − M.

Assumption 3.2. There exists an open neighborhood U of {0} × Γ ⊆ R^D × R^D such that gradient descent starting in U converges to a point x = (π, w) ∈ {0} × Γ. More explicitly, for x ∈ U, let ψ(x, 0) = x and ψ(x, k + 1) = ψ(x, k) + F(ψ(x, k)), i.e. ψ(x, k) is the k-th iterate of x ↦ x + F(x). Then Φ(x) ≡ lim_{k→∞} ψ(x, k) exists and lies in Γ. As a consequence, Φ ∈ C² on U (Falconer, 1983).

3.2. LIMITING DRIFT-DIFFUSION IN MOMENTUM SGD

In this section we shall obtain the explicit expression for the limiting drift-diffusion. The general framework is based on Katzenberger (1991) (reviewed in Appendix B). Before stating the result, we will need to introduce a few objects. Consider the process in Eq. (4). Note that, while at initialization we can have X_n(0) ∉ Γ, the solution X_n(t) → Γ as n → ∞, i.e. it becomes discontinuous. This is an effect of the speedup of time introduced around Eqs. (4)-(5). To overcome this issue, it is convenient to introduce Y_n(t) ≡ X_n(t) − ψ(X_n(0), A_n(t)) + Φ(X_n(0)), so that Y_n(0) ∈ Γ is initialized on the manifold.

Theorem 3.4 (Informal). Suppose the loss function L, the noise function σ, the manifold of minimizers Γ and the neighborhood U satisfy Assumptions 3.1 and 3.2, and that X_n(0) = X(0) ∈ U. Then, as ε_n → 0, and subsequently taking η → 0, Y_n(t) converges to Y(t), where the latter satisfies the limiting drift-diffusion equation

dY = (η^{1−γ}/C + η) P_L σ dW
  − (η^{2−2γ}/(2C²)) (∇²L)^† ∂²(∇L)[Σ_LL] dt
  − (η^{2−2γ}/C²) P_L ∂²(∇L)[(∇²L)^† Σ_TL] dt
  − (η^{2−2γ}/(2C²)) P_L ∂²(∇L)[L̃^{−1}_{∇²L} Σ_TT] dt,    (6)

where W(t) is a Wiener process. A rigorous version of this theorem is given in Appendix B. The first term in Eq. (6) induces diffusion in the longitudinal direction. The second term is of geometrical nature, and is necessary to guarantee that Y(t) remains on Γ. The last two terms describe the drift induced by the transverse fluctuations. Eq. (6) resembles in form that found in Li et al. (2022), although there are two crucial differences. First, time has been rescaled using the strength of the noise ε, rather than the learning rate. The different rescaling was necessary because the forcing term F in Eq. (4) depends non-homogeneously on η, and thus the theory of Katzenberger (1991) could not be directly applied had we taken the small learning rate limit. Second, and more crucially, the drift terms in Eq. (6) are proportional to η^{2−2γ}, which is a key ingredient leading to the change in hierarchy of the timescales discussed in Sec. 1.1. One final difference is that the last term involves the operator L̃_{∇²L} instead of the Lyapunov operator. For γ < 1/2, L̃_H reduces to the Lyapunov operator L_H at leading order in η, with L_H S ≡ {H, S}. For γ > 1/2, however, we cannot neglect the η-dependent term in L̃_H (see discussion at the end of Appendix C).

Corollary 3.5. In the case of label noise, i.e. when, for w ∈ Γ, Σ = c∇²L for some constant c > 0, Eq. (6) reduces to

dY = −(ε²η^{2−2γ}/(4C²)) P_L ∇ Tr(c∇²L) dt,

where we have rescaled time back to t = k, i.e. we performed t → ε²t.

3.3. SEPARATION OF TIMESCALES AND OPTIMAL MOMENTUM SCALING

The above results provide the estimate for the timescale τ_2 of the drift along the zero-loss valley. As discussed in Sec. 1.1, training along the zero-loss manifold Γ is maximally accelerated if this timescale is equal to the timescale τ_1 for relaxation of off-valley perturbations. As we take ε → 0, this relaxation is governed by the nonzero eigenvalues of the Hessian, as well as by the learning rate η and the momentum β. We therefore expect τ_1 = Θ(ε⁰), and this will be confirmed by the analysis below. It will therefore be sufficient to obtain the leading-order expression of τ_1 by focusing on the noiseless ε = 0 dynamics. Additionally, since we are interested in local relaxation, it will suffice to look at the linearized dynamics around Γ. Working in the extended phase space x_k = (π_k, w_k), and linearizing Eq. (1) around a fixed point x* = (0, w*), with w* ∈ Γ, the linearized update rule is δx_{k+1} = J(x*)δx_k, where δx_k = x_k − x* and J(x*) is the Jacobian evaluated at the fixed point (with the explicit form given in Eq. (133)). Denote by q_i the eigenvectors and λ_i the corresponding eigenvalues of the Hessian. We show in Appendix E that the Jacobian is diagonalized by eigenvectors of the form k^i_± = (µ^i_± q_i, q_i), with eigenvalues

κ^i_± = (1/2) [1 + β − ηλ_i ± √((1 + β − ηλ_i)² − 4β)],

and coefficients µ^i_± whose explicit form is given in Appendix E. We proceed to study the decay rate of the different modes of the Jacobian to draw conclusions about the characteristic timescales of fluctuations around the valley. Longitudinal motion: On the valley, the Hessian will have a number of "zero modes" with λ_i = 0. These lead to two distinct modes in the present setting with momentum, which we distinguish as pure and mixed. The first, pure longitudinal mode is an exact zero mode with κ^i_+ = 1 and k^i_+ = (0, q_i), corresponding to translations of the parameters along the manifold, keeping π = 0 at its fixed-point value.
The second mode is a mixed longitudinal mode with κ^i_− = β and k^i_− = (−(1 − β)/(2βη) q_i, q_i). This mode has a component of π along the valley, which must subsequently decay because the equilibrium is the single point π = 0. Therefore, this mode decays for π at the characteristic rate β, gleaned directly from Eq. (1). Transverse motion: When w and π are perturbed along the transverse directions q_i with positive λ_i, the relaxation behavior exhibits a qualitative change depending on β. Using the scaling β = 1 − Cη^γ, for small learning rate the spectrum is purely real for γ < 1/2, and comes in complex-conjugate pairs for γ > 1/2. This leads to two distinct scaling behaviors for the set of timescales. Defining a positive c_1 ≤ min{λ_i | λ_i > 0}, we find: 1) For γ < 1/2, the transverse modes are purely decaying, with

(1 − Cη^γ)^k ≤ |δx_k^T| ≤ (1 − (c_1/C)η^{1−γ})^k,

where the lower bound is set by the mixed longitudinal mode. 2) For γ > 1/2, the transverse modes are oscillatory, but with an envelope that decays like |δx_k^{T,env}| ≈ (1 − Cη^γ)^{k/2}. We leave the derivation of these results to Appendix E. Collecting these results, we can describe the hierarchy of inverse timescales τ_1^{−1} in the deterministic regime as a function of γ (excluding the pure longitudinal zero mode):

τ_1^{−1}     γ < 1/2           γ > 1/2
Long.        η^γ               η^γ
Transv.      η^{1−γ}, η^γ      η^γ, η^γ

These are illustrated schematically in Fig. 1(a), where the finite timescales are shown as a function of γ. We compare these "equilibration" timescales τ_1, i.e. characteristic timescales associated with relaxation back to the zero-loss manifold, with the timescale τ_2 ∼ η^{2(γ−1)} associated with the drift-diffusion of the noise-driven motion along the zero-loss manifold, Eq. (6). For small γ, the timescale associated with the drift-diffusion along the valley is much faster than that associated with the relaxation of the dynamics toward steady state.
Transverse and mixed longitudinal fluctuations relax much faster than the motion along the valley, and produce an effective drift toward the minimizer of the implicit regularizer. However, the timescales collide at γ = 2/3, suggesting a transition to a qualitatively different transport above this value, where the transverse and the mixed longitudinal dynamics, having a long timescale, will disrupt the longitudinal drift of Eq. (6). This leads us to propose γ = 2/3 as the optimal choice for SGD training with momentum. We consistently find evidence for such a qualitative transition in our experiments below. In addition, we see that the speedup of SGD with label noise is in fact maximal at this value, where the timescales meet.
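The spectral transition at γ = 1/2 is easy to check numerically. The sketch below evaluates the eigenvalues κ_± of the linearized update under the scaling β = 1 − Cη^γ; the constant C = 1 and the test values of λ_i and η are illustrative choices.

```python
import numpy as np

def kappa(lam, eta, gamma, C=1.0):
    """Eigenvalues of the heavy-ball update linearized around the valley,
    kappa_± = [1 + beta - eta*lam ± sqrt((1 + beta - eta*lam)^2 - 4*beta)]/2,
    with the scaling beta = 1 - C * eta**gamma."""
    beta = 1.0 - C * eta**gamma
    a = 1.0 + beta - eta * lam
    root = np.sqrt(complex(a * a - 4.0 * beta))  # complex sqrt: handles both regimes
    return (a + root) / 2.0, (a - root) / 2.0
```

For a zero mode (λ = 0) this returns κ = 1 and κ = β, the pure and mixed longitudinal modes; for a transverse mode (λ > 0) at small η the pair is real below γ = 1/2 and complex-conjugate above it.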

3.4. A SOLVABLE EXAMPLE

In this section we analyse a model that will allow us to determine, on top of the optimal exponent γ = 2/3, also the prefactor C. We will specialize to a 2-layer linear MLP, which is sufficient to describe the transition in the hierarchy of timescales described above, and is simple enough to exactly compute C. We will show in Sec. 4.1 that C depends only mildly on the activation function. We apply a simple matching principle to determine C, by asking that the deterministic timescale τ_1 be equal to the drift-diffusion timescale τ_2. In the previous section, we found the critical γ = 2/3 by requiring that these timescales have the same scaling in η. In order to determine C, we need more details of the model architecture.

Definition 3.6 (UV model). We define the UV model as a 2-layer linear network parametrized by f(x) = (1/√n) UVx ∈ R^m, where x ∈ R^d, V ∈ R^{n×d} and U ∈ R^{m×n}.

For a training dataset D = {(x^a, y^a), a = 1, ..., P}, the dataset covariance matrix is Σ_ij = (1/P) Σ_{a=1}^P x_i^a x_j^a, and the dataset variance is µ² = tr Σ. For mean-squared-error loss, it is possible to explicitly determine the trace of the Hessian (see Appendix D). SGD with label noise introduces y^a → y^a + εξ_t, where E[ξ_t²] = 1, from which we identify σ_{µ,ja} = P^{−1}∇_µ f_j(x^a), where µ runs over all parameter indices, j ∈ [m], and a ∈ [P]. With this choice, the SGD noise covariance satisfies σσ^⊤ = P^{−1}∇²L. Equipped with this, we may use Corollary 3.5 with c = P^{−1} to determine the effective drift (presented in Appendix D). For the vector UV model, the expression simplifies to

dY = −τ_2^{−1} Y dt,    with τ_2^{−1} = η^{2−2γ} ε²µ²/(2nPC²).

The timescale of the fast initial phase, τ_1^{−1} = (C/2)η^γ, follows from the previous section. Then requiring τ_1 = τ_2 implies not only γ = 2/3, but also

C = (ε²µ²/(Pn))^{1/3}.

One particular feature to note here is that C will be small for overparametrized models and/or training with large datasets.
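Matching τ_1 = τ_2 at γ = 2/3 gives the prefactor in closed form, which the sketch below implements. The example values (ε = 1/2, µ² ≈ 1, hidden width n = 10, P = 5) are assumptions of ours chosen to mimic the experiment of Sec. 4.1, where the width is not stated explicitly; with these assumed values the formula gives ≈ 0.17, in line with the theoretical estimate quoted there.

```python
def uv_prefactor(eps, mu2, n, P):
    """Prefactor from matching tau_1 = tau_2 in the UV model at gamma = 2/3:
    C = (eps**2 * mu2 / (P * n))**(1/3), where eps is the label-noise strength,
    mu2 = tr(Sigma) is the dataset variance, n the hidden width, P the dataset size."""
    return (eps ** 2 * mu2 / (P * n)) ** (1.0 / 3.0)
```

Note how C shrinks with both n and P, reflecting the remark above about overparametrized models and large datasets.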

4.1. 2-LAYER MLP WITH LINEAR AND NON-LINEAR ACTIVATIONS

The first experiment we consider is the vector UV model analyzed in Sec. 3.4. Our goal with this experiment is to analyze a simple model, and show quantitative agreement with theoretical expectations. Though simple, this model shows almost all of the important features of our analysis: the 2(1 − γ) exponent below γ = 2/3, the γ exponent above 2/3, and the constant C theoretically evaluated in Sec. 3.4. To this end, we extract the timescales at different values of γ and show them in Fig. 1. We train on an artificially generated dataset D = {(x^a, y^a)}_{a=1}^5 with x ∼ N(0, 1) and y = 0. We use the full dataset to train. From Eq. (124) we know that the norm of the weights follows an approximately exponential trajectory as it approaches the widest minimum (U = V = 0). We therefore measure the convergence timescale T_c by fitting an exponential a e^{−t/T_c} to the squared distance from the origin, |U|² + |V|². To extract the scaling of T_c with γ we perform SGD with label noise, with learning rates η ∈ [10^{−3}, 10^{−1}] and corresponding momentum parameters β = 1 − Cη^γ. We fit the timescale to a power law in the learning rate, T_c(η, γ) = T_0 η^{−α(γ)} (see Fig. 1(b)). Imposing that T_0 be independent of γ, as predicted by theory, we found the numerical value C ≈ 0.2, which is consistent with the theoretical estimate of C = 0.17 from Sec. 3.4. We find consistency with the prediction across all the values of γ we simulated. Note that, for γ > 2/3, the timescale estimate fluctuates more, which is a consequence of the slower timescale of the transverse modes. As discussed at the end of Sec. 3.3, this slowness disrupts the drift motion along the manifold. γ = 2/3 is clearly the optimal scaling. We repeated the same experiments using nonlinear activations; specifically, we considered tanh and ReLU acting on the first layer. The timescales for tanh are shown in Fig. 1(c), and we refer the reader to Appendix A for the ReLU case.
The optimal scaling value is still γ = 2/3 for both tanh and ReLU, and the optimal C remains close to 0.2.
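The timescale and exponent extraction described above can be sketched as follows: an exponential fit via linear regression in log space, followed by a power-law fit of T_c against η. The function names and the use of a plain least-squares fit are our illustrative choices.

```python
import numpy as np

def fit_timescale(traj, dt=1.0):
    """Convergence timescale T_c from a trajectory of |U|^2 + |V|^2, obtained
    by fitting a*exp(-t/T_c), i.e. a linear fit to the log of the trajectory."""
    t = np.arange(len(traj)) * dt
    slope, _ = np.polyfit(t, np.log(traj), 1)
    return -1.0 / slope

def fit_powerlaw(etas, timescales):
    """Fit T_c(eta) = T_0 * eta**(-alpha); returns (alpha, T_0)."""
    a, b = np.polyfit(np.log(etas), np.log(timescales), 1)
    return -a, np.exp(b)
```

In the experiment, imposing a γ-independent T_0 across the sweep is what pins down the numerical value of C.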

4.2. RESNET18 ON CIFAR10

We now verify our predictions on a realistic problem, which will demonstrate the robustness of our analysis. We focus on a ResNet18 (He et al., 2016) classifier, specifically the implementation of Liu (2021), trained on CIFAR10 (Krizhevsky et al., 2009). We aim to extrapolate the theory by showing optimal acceleration with our hyperparameter choice once training reaches an interpolating solution. To this purpose, we initialize the network on the valley, obtained by starting from random weight values and training the network using full-batch gradient descent, without label noise and with a fixed value β = 0.9, until it reaches perfect training accuracy. With this initialization, we then train with SGD and label noise for a fixed number of epochs, multiple times, for various values of the momentum hyperparameter β. Finally, we project the weights back onto the valley before recording the final test accuracy. This last step can be viewed as noise annealing, and allows us to compare the performance of training in the drift phase for the different values of β. From this procedure we extract the optimal momentum parameter β*(η) that maximizes the best test accuracy during training as a function of the learning rate, which we can then compare with the theoretical prediction. As shown in Fig. 2(b), 1 − β* follows the power law we predicted almost exactly. The optimal choice for speedup does not have to coincide with the optimal choice for generalization. Strikingly, this optimal choice of scaling also leads to the best generalization in a realistic setting! This can be easily interpreted if we assume that the more we decrease the Hessian, the better our model generalizes, and by applying the fact that our scaling leads to the fastest transport along the manifold. The second important point is the value of the constant C ≈ 0.1 found as the coefficient of the power-law fit. If we set η = 1, this corresponds to setting β* = 0.9, which is the traditionally recommended value.
The result here can therefore be viewed as a generalization of this common wisdom. For more experiments, we refer the reader to Appendices A.2 and A.4.

5. CONCLUSION

We studied the implicit regularization of SGD with label noise and momentum in the limit of small noise and learning rate. We found that there is an interplay between the speedup of momentum and the limiting diffusion generated by SGD noise. This gives rise to two characteristic timescales associated with the training dynamics, and the longest timescale essentially governs the training time. Maximum acceleration is thus reached when these two timescales coincide, which led us to identify an optimal scaling of the hyperparameters. This optimal scaling corresponds not only to faster training but also to superior generalization. More generally, we have shown how momentum can significantly enrich the dynamics of learning with SGD, modulating between qualitatively different phases of learning characterized by different timescales and dynamical behavior. It would be interesting to explore our scaling limit in statistical mechanical theories of learning to uncover further nontrivial effects on feature extraction and generalization in the phases of learning (Jacot et al., 2018; Roberts et al., 2021; Yang et al., 2022). For future work, it will be very interesting to generalize this result to adaptive optimization algorithms such as Adam and its variants, to use this principle to design new adaptive algorithms, and to study the interplay between the scaling we found and the hyperparameter schedule.

A EXPERIMENTAL DESIGN

A.1 UV MODEL

In our experiments with the vector UV model we aim to extract how the timescale of motion scales with $\eta$ for each $\gamma$; we therefore train the model over a sweep of both of these parameters. As a reminder, the loss of the UV model is
$$L = \frac{1}{2P}\sum_{i=1}^{P}\left(y_i - \frac{1}{\sqrt{n}}\,u\cdot v\,x_i\right)^2$$
where $P$ is the dataset size and $u, v$ are $n$-dimensional vectors. We set $y_i = 0$ for simplicity; with label noise this becomes $y_i(t) = \epsilon\,\xi_i(t)$ for i.i.d. standard Gaussian $\xi_i$. We initialize the $x_i$ once (not online) as i.i.d. standard Gaussian variables. For all experiments we choose $\epsilon = 1/2$. We initialize $u_i, v_i$ as i.i.d. standard Gaussians and keep this initialization fixed across all our experiments, to reduce the noise associated with the specific initialization. For each value of $(\eta, \beta)$ we train with label-noise SGD and momentum until $|u|^2 + |v|^2 < \varepsilon n$ with $\varepsilon = 0.1$, thereby obtaining a time series for each of $u(t)$ and $v(t)$. We extract the timescale by fitting $\log(|u|^2 + |v|^2)$ to a linear function and taking the slope.
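The training loop above can be sketched as follows. This is a minimal illustration, not the paper's actual run: the sizes and hyperparameters are assumptions chosen for speed, and we train for a fixed number of steps rather than to the stopping threshold.

```python
import numpy as np

# Vector UV model with label-noise momentum SGD.
# Loss: L = (1/2P) sum_i (y_i(t) - (u.v) x_i / sqrt(n))^2, y_i(t) = eps*xi_i(t).
rng = np.random.default_rng(0)
n, P, eps = 20, 50, 0.5          # illustrative sizes (assumptions)
eta, beta = 0.05, 0.9
x = rng.standard_normal(P)       # inputs drawn once, not online
u = rng.standard_normal(n)
v = rng.standard_normal(n)
pu = np.zeros(n); pv = np.zeros(n)   # momentum buffers

norms = []
for t in range(2000):
    y = eps * rng.standard_normal(P)           # fresh label noise each step
    resid = (u @ v / np.sqrt(n)) * x - y       # model output minus noisy label
    # gradients of L with respect to u and v
    gu = (resid @ x) / P * v / np.sqrt(n)
    gv = (resid @ x) / P * u / np.sqrt(n)
    pu = beta * pu - gu; pv = beta * pv - gv
    u = u + eta * pu;    v = v + eta * pv
    norms.append(u @ u + v @ v)

# Timescale extraction: slope of log(|u|^2 + |v|^2) against the step number.
slope = np.polyfit(np.arange(len(norms)), np.log(norms), 1)[0]
print("decay rate per step:", -slope)
```

In the actual experiment this rate is measured across the full (η, β) sweep to obtain the timescale scaling.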

A.2 MATRIX SENSING

We also explore the speedup for a well-understood problem: matrix sensing. The goal is to find a low-rank matrix $X^*$ given measurements along random matrices $A_i$: $y_i = \mathrm{Tr}(A_i X^*)$. Here $X^* \in \mathbb{R}^{d\times d}$ is a matrix of rank $r$ (Soudry et al., 2017; Li et al., 2018). Blanc et al. (2020) analyze matrix sensing using SGD with label noise and show that if $X^*$ is symmetric, with the hypothesis $X = UU^\top$ for some matrix $U$, then gradient descent with label noise not only satisfies the constraints $y_i = \mathrm{Tr}(A_i X)$, but also minimizes an implicit regularizer, the Frobenius norm of $U$, which eventually leads to the ground truth. In the analogous UV matrix model (where $X^*$ is an asymmetric matrix of low rank $r$), we demonstrate a considerable learning speedup by adding momentum, and show that this speedup is not monotonic in $\beta$: there is a value $\beta^*$ at which the acceleration appears optimal. This non-monotonicity with an optimal $\beta^*$ is observed for both the Hessian trace and the expected test error. Assuming that in this setting we also have $\gamma = 2/3$, we can extract $C^* = (1 - \beta^*)/\eta^{2/3} \approx 0.24\,P^{-1/3}$, which compares favorably with the upper bound of $\approx 0.12\,P^{-1/3}$ we may extract from Appendix D. In these experiments we aim to demonstrate the benefit of momentum in a popular setting. Matrix sensing corresponds to the following problem: given a target matrix $X^* \in \mathbb{R}^{d\times d}$ of low rank $r \ll d$ and measurements $\{y_i = \mathrm{Tr}(A_i X^*)\}_{i=1}^P$, how can we reconstruct $X^*$? One way to solve this problem is to write our guess $X = UV^\top$ as the product of two other matrices and perform stochastic gradient descent on them, hoping that the implicit regularization induced by this parametrization and the learning algorithm converges to a good low-rank $X$.
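The measurement setup described above can be sketched in a few lines. The sizes here are reduced assumptions for illustration; the paper's experiment uses d = 100, r = 5, P = 2500.

```python
import numpy as np

# Matrix sensing setup: a rank-r target X* in R^{d x d} and P random linear
# measurements y_i = Tr(A_i X*). Sizes reduced for speed (assumptions).
rng = np.random.default_rng(0)
d, r, P = 20, 2, 200

# Low-rank target: project a Gaussian matrix onto its top-r singular values.
X0 = rng.standard_normal((d, d))
U0, s, Vt = np.linalg.svd(X0)
s[r:] = 0.0                                   # zero out the smaller singular values
X_star = U0 @ np.diag(s) @ Vt
assert np.linalg.matrix_rank(X_star) == r

A = rng.standard_normal((P, d, d))            # measurement matrices
y = np.einsum('pij,ji->p', A, X_star)         # y_i = Tr(A_i X*)
print(y.shape)  # (200,)
```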

A.2.1 EXPERIMENTAL DETAILS

In our experiments we study the case $d = 100$, $r = 5$, $P = 5rd = 2500$. We draw the entries $(A_i)_{jk} \sim \mathcal{N}(0,1)$ as standard Gaussians, and choose $X^*$ by first drawing $(X_0)_{jk} \sim \mathcal{N}(0,1)$, performing an SVD, and projecting onto the top $r$ singular values by zeroing out the smaller singular values in the diagonal matrix. We initialize $U = V = I_d$. We perform SGD with momentum on the time-dependent loss (with label noise depending on time)
$$L(t) = \frac{1}{dP}\sum_{i=1}^{P}\left(\epsilon\,\xi_i(t) + y_i - \mathrm{Tr}(A_i U V^\top)\right)^2$$
where $\epsilon^2 = 0.1$ and $\xi_i(t) \sim \mathcal{N}(0,1)$. We choose $\eta = 0.1$ for all of our experiments. The inset shows that the orange curve crosses below the blue curve before convergence of the Hessian; therefore, the same value of $\beta$ is optimal for both the Hessian and the expected test error, and increasing or decreasing $\beta$ from this value slows down generalization. The Hessian of the loss is defined as the Hessian averaged over the noise; equivalently, we may simply set $\xi_i(t) = 0$ when calculating the Hessian, because averaging decouples the noise. Similarly, we define the expected test loss as an average over all $A_i$ with $\xi_i(t) = 0$, again to decouple the noise. Averaging over both $\xi_i(t)$ and $A_i$ would only add a constant term proportional to $\langle\xi_i(t)^2\rangle$, which we remove for clarity. As a result, the expected test error that we plot is proportional to the squared Frobenius norm of the difference between the model $UV^\top$ and the target $X^*$, $L = \frac{1}{d}\|UV^\top - X^*\|_F^2$. It is also interesting to note that we observe epoch-wise double descent (Nakkiran et al., 2020) in this problem. In particular, the peak in the test error can be controlled by the momentum hyperparameter, and becomes especially pronounced as $\beta \to 1$.
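The label-noise momentum-SGD update and the expected test error above can be sketched as follows. Sizes and hyperparameters are reduced assumptions for a fast illustration (the paper's run uses d = 100, r = 5, P = 2500, η = 0.1, ε² = 0.1), and the rank-1 target is arbitrary.

```python
import numpy as np

# Momentum SGD on L(t) = (1/(dP)) sum_i (eps*xi_i(t) + y_i - Tr(A_i U V^T))^2,
# asymmetric parametrization X = U V^T, initialization U = V = I_d.
rng = np.random.default_rng(1)
d, P = 8, 120                                  # reduced sizes (assumptions)
eps2, eta, beta = 0.1, 0.05, 0.9

X_star = np.outer(rng.standard_normal(d), rng.standard_normal(d))  # rank-1 target
A = rng.standard_normal((P, d, d))
y = np.einsum('pij,ji->p', A, X_star)          # clean measurements

U, V = np.eye(d), np.eye(d)
PU, PV = np.zeros((d, d)), np.zeros((d, d))    # momentum buffers

def grads(U, V, xi):
    """Gradients of the label-noise loss; Tr(A_i U V^T) = sum A_ij U_jm V_im."""
    resid = np.sqrt(eps2) * xi + y - np.einsum('pij,jm,im->p', A, U, V)
    gU = -(2.0 / (d * P)) * np.einsum('p,pij,im->jm', resid, A, V)
    gV = -(2.0 / (d * P)) * np.einsum('p,pij,jm->im', resid, A, U)
    return gU, gV

for t in range(50):
    xi = rng.standard_normal(P)                # fresh label noise each step
    gU, gV = grads(U, V, xi)
    PU = beta * PU - gU; PV = beta * PV - gV
    U = U + eta * PU;    V = V + eta * PV

# Expected test error: noise decoupled, (1/d) ||U V^T - X*||_F^2.
test_err = np.linalg.norm(U @ V.T - X_star) ** 2 / d
print("expected test error:", test_err)
```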

A.3 RESNET18 ON CIFAR10

We train our model in three steps: full-batch gradient descent without label noise until 100% training accuracy, SGD with label noise and momentum, and a final projection onto the interpolating manifold. The model is a ResNet18 trained on the CIFAR10 training set. The first step is full-batch gradient descent on the full CIFAR10 training set of 50,000 samples. We train with learning rate η = 0.1 and momentum β = 0.9, with a schedule that increases the learning rate linearly from 0 to η over 600 epochs and then follows a cosine schedule for 1400 more epochs, stopping the first time the network reaches 100% training accuracy, which happened at epoch 1119 in our run. This model is saved, and the same one is used for all subsequent runs. The loss function we use is cross entropy. Because we choose a label noise level of p = 0.2, corresponding to a uniformly wrong label 20% of the time, during this phase of training we train with the loss averaged over this randomness; notice that the loss is linear in the labels, so taking the expectation is easy. The second step starts from the initialization of step 1 and trains with a different learning rate and momentum parameter. In this step we choose the same level of label noise, p = 0.2, but take it to be stochastic. Additionally, we use SGD with a batch size of 512 instead of full-batch gradient descent; this necessitates decreasing the learning rate, because the noise is greatly increased, as demonstrated in the main text. In this step we train for a fixed 200 epochs for each learning rate and momentum combination, and we only compare runs with the same learning rate. We show an example of the test accuracy during phase 2 for η = 0.001 in Fig. 4. Notice that for both too-small and too-large momentum values, the convergence to a good test accuracy is slower.
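The expected loss over the label-noise randomness, used in the first training phase, follows from linearity of cross entropy in the (one-hot) label: with probability p the label is replaced by a uniformly random wrong class, so the expectation is a fixed reweighting of the per-class log-probabilities. A minimal NumPy sketch (the function name and shapes are our own, for illustration):

```python
import numpy as np

def expected_label_noise_ce(logits, labels, p=0.2):
    """Cross entropy averaged over label noise: true label kept with prob 1-p,
    replaced by a uniformly random wrong class with prob p. K classes."""
    K = logits.shape[1]
    logq = logits - np.log(np.sum(np.exp(logits), axis=1, keepdims=True))
    ce_true = -logq[np.arange(len(labels)), labels]
    ce_all = -logq.sum(axis=1)
    ce_wrong = (ce_all - ce_true) / (K - 1)      # mean CE over the wrong classes
    return np.mean((1 - p) * ce_true + p * ce_wrong)

rng = np.random.default_rng(0)
logits = rng.standard_normal((32, 10))           # K = 10 classes as for CIFAR10
labels = rng.integers(0, 10, size=32)
print(expected_label_noise_ce(logits, labels))
```

With p = 0 this reduces to the ordinary cross entropy on the clean labels.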
The initial transient with decreased test accuracy occurs because we start on the valley, and adding noise coupled with momentum causes the weights to approach their equilibrium distribution about the valley; for wider distributions the network is farther from the optimal point on the valley. As training proceeds, the test accuracy increases over the baseline as the Hessian decreases and the generalization capacity of the network increases. This happens most quickly for the momentum that matches our scaling law. The final step is a projection onto the zero-loss manifold. This step is necessary because the total width of the distribution around the zero-loss manifold scales with $\frac{1}{1-\beta}$, which would systematically distort the results at larger momentum, making them look worse than they are. We perform this projection to correct for this effect and put all momenta on an equal footing. The projection is done by training on the full batch with Adam and a learning rate of η = 0.0001 in order to accelerate training; we do not expect any significant systematic distortion from using Adam for the projection instead of gradient descent. To determine the optimal value of β we sweep several values of β and record the test accuracy after the procedure described above. To obtain a more precise value, instead of simply selecting the β with the highest test accuracy, we fit the accuracy A(β) to
$$A(\beta) = \begin{cases} a_{\max} + a_1(\beta - \beta^*) & \beta \le \beta^* \\ a_{\max} + a_2(\beta - \beta^*) & \beta \ge \beta^* \end{cases} \qquad (14)$$
in the parameters $a_{\max}$, $a_1$, $a_2$, and $\beta^*$, thereby extracting $\beta^*$ for each $\eta$.
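The piecewise-linear ("tent") fit of Eq. (14) can be sketched as follows: we scan candidate β* values and solve for the remaining, linear, parameters by least squares at each candidate. The accuracies below are synthetic, purely for illustration.

```python
import numpy as np

def fit_beta_star(betas, acc, candidates):
    """Fit A(beta) = a_max + a1*(beta-bs) below bs, a_max + a2*(beta-bs) above,
    by scanning bs candidates and solving the linear part by least squares."""
    best_sse, best_bs = np.inf, None
    for bs in candidates:
        below = np.clip(betas - bs, None, 0.0)   # active on the left branch
        above = np.clip(betas - bs, 0.0, None)   # active on the right branch
        X = np.stack([np.ones_like(betas), below, above], axis=1)
        coef, *_ = np.linalg.lstsq(X, acc, rcond=None)
        sse = np.sum((X @ coef - acc) ** 2)
        if sse < best_sse:
            best_sse, best_bs = sse, bs
    return best_bs

betas = np.linspace(0.5, 0.999, 50)
true_bs = 0.9                                    # assumed ground truth
acc = 92.0 + np.where(betas <= true_bs, 3.0, -8.0) * (betas - true_bs)
acc = acc + 0.01 * np.random.default_rng(0).standard_normal(betas.size)

bs_hat = fit_beta_star(betas, acc, np.linspace(0.6, 0.99, 79))
print("fitted beta* =", bs_hat)
```

The scan-plus-least-squares form avoids the non-smoothness of the kink location that a generic nonlinear fitter would have to handle.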

A.4 MLP ON FASHIONMNIST

We perform a similar experiment to that of Section A.3, but with a different model and dataset: a 6-layer MLP trained on FashionMNIST (Xiao et al., 2017). The MLP has ReLU activations after the first four layers, a tanh activation after the fifth layer, and a linear mapping to logits.

Figure 5: Scaling analysis for a 6-layer MLP on the FashionMNIST dataset. The exponent, though not very well determined, is consistent with our theoretical prediction of 2/3.

We use cross entropy loss with label noise $p = 0.2$ as before, and always start training from a reference point initialized on the zero-loss manifold. This point was obtained by gradient descent on the expected loss from a randomly initialized point, with learning rate $\eta = 0.002$ and without momentum. After this we sweep $\eta \in [10^{-5}, 10^{-4}]$ and $\beta \in [0.95,\, 1 - 2\eta]$ and train for 600 epochs with label noise and 400 without. This yields a test accuracy as a function of $\eta$ and $\beta$, from which we obtain the best momentum hyperparameter $\beta^*(\eta)$ as in A.3. We extract the scaling exponent by a linear fit between $\log(1 - \beta^*)$ and $\log\eta$. The scaling analysis is shown in Fig. 5 and shows that the exponent is consistent with our theory.

B REVIEW OF RELEVANT RESULTS FROM KATZENBERGER (1991)

In this Appendix we summarize the relevant conditions and theorems from Katzenberger (1991) that we use to prove our result on the limiting drift-diffusion. We refer to Katzenberger (1991) for some of the definitions and conditions cited below. In what follows, $(\Omega_n, \mathcal{F}_n, \{\mathcal{F}^n_t\}_{t\ge0}, P)$ denotes a filtered probability space, $Z_n$ an $\mathbb{R}^r$-valued càdlàg $\{\mathcal{F}^n_t\}$-semimartingale with $Z_n(0) = 0$, $A_n$ a real-valued càdlàg $\{\mathcal{F}^n_t\}$-adapted nondecreasing process with $A_n(0) = 0$, and $\sigma : U \to \mathbb{R}^{D\times r}$ a continuous function, where $U$ is a neighborhood of $\{0\}\times\Gamma$ as defined in the main text.
Also, $X_n$ is an $\mathbb{R}^D$-valued càdlàg $\{\mathcal{F}^n_t\}$-semimartingale satisfying
$$X_n(t) = X_n(0) + \int_0^t \sigma(X_n)\,dZ_n + \int_0^t F(X_n)\,dA_n \qquad (15)$$
for all $t \le \lambda_n(K)$ and all compact $K \subset U$, where
$$\lambda_n(K) = \inf\{t \ge 0 \,|\, X_n(t-) \notin K^\circ \text{ or } X_n(t) \notin K^\circ\} \qquad (16)$$
is the stopping time for $X_n(t)$ to leave $K^\circ$, the interior of $K$. For càdlàg real-valued semimartingales $X, Y$, let $[X,Y](t)$ be defined as the limit of the sums $\sum_{i=0}^{n-1}\left(X(t_{i+1})-X(t_i)\right)\left(Y(t_{i+1})-Y(t_i)\right)$, where $0 = t_0 < t_1 < \cdots < t_n = t$ and the limit is in probability as the mesh size goes to zero. If $X$ is an $\mathbb{R}^D$-valued semimartingale, we write $[X] = \sum_{i=1}^D [X^i, X^i]$.

Condition B.1. For every $T > \epsilon > 0$ and compact $K \subset U$,
$$\inf_{0 \le t \le (T \wedge \lambda_n(K)) - \epsilon} \left(A_n(t+\epsilon) - A_n(t)\right) \to \infty \qquad (19)$$
as $n \to \infty$, where the infimum of the empty set is taken to be $\infty$.

Condition B.2. For every compact $K \subset U$, $\{Z_n^{\lambda_n(K)}\}$ satisfies the following: For $n \ge 1$ let $Z_n$ be an $\{\mathcal{F}^n_t\}$-semimartingale with sample paths in $D_{\mathbb{R}^d}[0,\infty)$. Assume that for some $\delta > 0$ (allowing $\delta = \infty$) and every $n \ge 1$, there exist stopping times $\{\tau^k_n \,|\, k \ge 1\}$ and a decomposition of $Z_n - J_\delta(Z_n)$ into a local martingale $M_n$ plus a finite variation process $F_n$ such that $P[\tau^k_n \le k] \le 1/k$, $\{[M_n](t \wedge \tau^k_n) + T_{t \wedge \tau^k_n}(F_n) \,|\, n \ge 1\}$ is uniformly integrable for every $t \ge 0$ and $k \ge 1$, and
$$\lim_{\gamma \to 0} \limsup_{n \to \infty} P\left[\sup_{0 \le t \le T}\left(T_{t+\gamma}(F_n) - T_t(F_n)\right) > \epsilon\right] = 0 \qquad (20)$$
for every $\epsilon > 0$ and $T > 0$. Also, as $n \to \infty$ and for any $T > 0$,
$$\sup_{0 < t \le T \wedge \lambda_n(K)} |\Delta Z_n(t)| \to 0\,. \qquad (21)$$

Condition B.3. The process
$$\tilde Z_n(t) = \sum_{0 < s \le t} \Delta Z_n(s)\,\Delta A_n(s) \qquad (22)$$
exists, is an $\{\mathcal{F}^n_t\}$-semimartingale, and for every compact $K \subset U$, the sequence $\{\tilde Z_n^{\lambda_n(K)}\}$ is relatively compact and satisfies Condition 4.1 in Katzenberger (1991).

Theorem B.4 (Theorem 7.3 in Katzenberger (1991)). Assume that $\Gamma$ is $C^2$ and that for every $y \in \Gamma$ the matrix $\partial F(y)$ has $D - M$ eigenvalues in $\mathbb{D}(1)$. Assume that (B.1), (B.2) and (B.3) hold, that $\Phi$ is $C^2$ (or $F$ is $LC^2$), and that $X_n(0) \Rightarrow X(0) \in U$.
Let
$$Y_n(t) = X_n(t) - \psi(X_n(0), A_n(t)) + \Phi(X_n(0)) \qquad (23)$$
and, for a compact $K \subset U$, let
$$\mu_n(K) = \inf\{t \ge 0 \,|\, Y_n(t-) \notin K \text{ or } Y_n(t) \notin K\}\,. \qquad (24)$$
Then for every compact $K \subset U$, the sequence $\{(Y_n^{\mu_n(K)}, Z_n^{\mu_n(K)}, \mu_n(K))\}$ is relatively compact in $D_{\mathbb{R}^{2D\times r}}[0,\infty) \times [0,\infty]$ (see Katzenberger (1991) for details about the topology). If $(Y, Z, \mu)$ is a limit point of this sequence, then $(Y, Z)$ is a continuous semimartingale, $Y(t) \in \Gamma$ for every $t$ almost surely, $\mu \ge \inf\{t \ge 0 \,|\, Y(t) \notin K\}$ almost surely, and
$$Y(t) = Y(0) + \int_0^{t \wedge \mu} \partial\Phi(Y)\,\sigma(Y)\,dZ + \frac{1}{2}\sum_{ijkl}\int_0^{t \wedge \mu} \partial_{ij}\Phi(Y)\,\sigma_{ik}(Y)\,\sigma_{jl}(Y)\,d[Z^k, Z^l]\,. \qquad (25)$$

B.1 APPLYING THEOREM B.4

Recall the equations of motion of stochastic gradient descent with momentum:
$$\pi^{(n)}_{k+1} = \beta\,\pi^{(n)}_k - \nabla L(w^{(n)}_k) + \epsilon_n\,\sigma(w^{(n)}_k)\,\xi_k\,, \qquad w^{(n)}_{k+1} = w^{(n)}_k + \eta\,\pi^{(n)}_{k+1}\,, \qquad (26)$$
where $\sigma(w) \in \mathbb{R}^{D\times r}$ is the noise function evaluated at $w \in \mathbb{R}^D$, and $\xi_k \in \mathbb{R}^r$ is a noise vector drawn i.i.d. at every timestep $k$ with zero mean and unit variance. We now show that this equation satisfies all the properties required by Theorem B.4. The manifold $\Gamma$ is the fixed-point manifold of (non-stochastic) gradient descent. $\{0\}\times\Gamma$ is a $C^2$ manifold because $\Gamma$ is $C^2$, which follows from Assumption 3.1. The flow is $F(\pi, w) = ((\beta - 1)\pi - \nabla L(w),\; \eta(\beta\pi - \nabla L(w)))$. As shown in Appendix E, $dF$ has exactly $M$ zero eigenvalues on $\Gamma \cap K$. $F$ inherits the differentiability and local Lipschitz properties from $\nabla L$, and therefore satisfies the conditions of Theorem B.4. Next, the noise function $\tilde\sigma : \mathbb{R}^{2D} \to \mathbb{R}^{2D\times r}$ is continuous because $\sigma$ is. Now we define $A_n$ and $Z_n$ (as in the main text) so that $X_n$ reproduces the dynamics in Eq. (26), up to a new time parameter $t = k\epsilon_n^2$:
$$A_n(t) = \left\lfloor \frac{t}{\epsilon_n^2} \right\rfloor\,, \qquad Z_n(t) = \epsilon_n \sum_{k=1}^{A_n(t)} \xi_k\,. \qquad (27)$$
With these choices, Eq. (15) precisely corresponds to Eq. (26), up to the rescaling $t = k\epsilon_n^2$. Now we show that $A_n, Z_n$ satisfy the conditions of Theorem B.4. Clearly $A_n(0) = Z_n(0) = 0$ by definition.
Then
$$A_n(t+\epsilon) - A_n(t) = \left\lfloor \frac{t+\epsilon}{\epsilon_n^2} \right\rfloor - \left\lfloor \frac{t}{\epsilon_n^2} \right\rfloor \ge \frac{(t+\epsilon) - t}{\epsilon_n^2} - 2 = \epsilon\,\epsilon_n^{-2} - 2 \to \infty\,,$$
so Condition B.1 holds. This shows that we satisfy the conditions of Theorem B.4, and therefore we have the following

Lemma B.5. The SGD equations formulated as in (26) satisfy all the conditions of Theorem B.4.

B.2 EQUIVALENCE BETWEEN EQ. (1) AND EQ. (4)

Eq. (4) is the rewriting of Eq. (1) in the form presented in Katzenberger (1991), on which our theory is based. We show the equivalence below, and include the definitions of $A$ and $Z$ to keep this section self-contained. The statement is that Eq. (1), i.e.
$$\pi_{k+1} = \beta\,\pi_k - \nabla L(w_k) + \epsilon\,\sigma(w_k)\,\xi_k\,, \qquad w_{k+1} = w_k + \eta\,\pi_{k+1}\,, \qquad (29)$$
can be rewritten in the form of Eq. (4),
$$X_n(t) = X_n(0) + \int_0^t \tilde\sigma(X_n)\,dZ_n + \int_0^t F(X_n)\,dA_n\,, \qquad A_n(t) = \left\lfloor \frac{t}{\epsilon_n^2} \right\rfloor\,, \qquad Z_n(t) = \epsilon_n \sum_{k=1}^{A_n(t)} \xi_k\,, \qquad (31)$$
where $X_n(t) = (\pi_k, w_k)$ with the correspondence $t = k\epsilon_n^2$, $k$ denoting the SGD time step. Also, $\tilde\sigma(X) = (\sigma, \eta\sigma)$ and $F(X) = ((\beta - 1)\pi - \nabla L(w),\; \eta(\beta\pi - \nabla L(w)))$. Before we begin the proof, we will need the following fact (see Sec. 2 of Katzenberger (1991)):
$$\int_s^t f\,dg = \int_s^t f\,dg^c + \sum_{s < r \le t} f(r-)\,\Delta g(r)\,, \qquad (32)$$
where $f$ and $g$ are càdlàg functions; in particular they are right-continuous and have left limits everywhere. The integral above is taken with respect to the measure $dg$, the differential of the function $g$. The sum is taken over all $r \in (s, t]$ where $g$ is discontinuous, the notation $\Delta g(r) = g(r) - g(r-)$ indicates the jump of $g$ at $r$, and $g(r-) = \lim_{u \to r^-} g(u)$ is the left limit of $g$ at $r$. Finally, $g^c$ denotes the continuous part of $g$:
$$g^c(t) = g(t) - \sum_{0 < s \le t} \Delta g(s)\,, \qquad t \ge 0\,.$$
We now show by induction that $X_n(t = k\epsilon_n^2)$ solves the first equation above. Note that
$$dA_n(t) = \sum_{k=1}^{\infty} \delta(t - \epsilon_n^2 k)\,, \qquad dZ_n(t) = \epsilon_n \sum_{k=1}^{\infty} \xi_k\,\delta(t - \epsilon_n^2 k)\,.$$
For brevity we drop the subscript $n$ and let $k\epsilon^2 = t$. Consider
$$X(t + \epsilon^2) = X(0) + \int_0^{t+\epsilon^2} \tilde\sigma(X(s))\,dZ(s) + \int_0^{t+\epsilon^2} F(X(s))\,dA(s) \qquad (35)$$
$$= X(t) + \int_t^{t+\epsilon^2} \tilde\sigma(X(s))\,dZ(s) + \int_t^{t+\epsilon^2} F(X(s))\,dA(s) \qquad (36)$$
$$= X(t) + \sum_{t < s \le t+\epsilon^2} \tilde\sigma(X(s-))\,\Delta Z(s) + \sum_{t < s \le t+\epsilon^2} F(X(s-))\,\Delta A(s)\,,$$
where in the last step we use Eq. (32), the fact that $A$, $Z$, $\tilde\sigma$ and $F$ are càdlàg, and that $dA^c = dZ^c = 0$. The sums are taken over all $s \in (k\epsilon^2, (k+1)\epsilon^2]$ where $Z(s)$ and $A(s)$ are discontinuous.
The only point of discontinuity of $A$ and $Z$ in this interval is at $s = t + \epsilon^2$, so
$$X(t + \epsilon^2) = X(t) + \epsilon\,\tilde\sigma(X((t+\epsilon^2)-))\,\xi_{k+1} + F(X((t+\epsilon^2)-))\,,$$
because the jumps of $Z$ and $A$ at $s = t + \epsilon^2$ are $\epsilon\,\xi_{k+1}$ and $1$, respectively. Now we must determine the left limit of $X$ at $t + \epsilon^2$. Notice that for $0 < \delta < 1$,
$$X(t + \delta\epsilon^2) = X(t) + \sum_{t < s \le t+\delta\epsilon^2} \tilde\sigma(X(s-))\,\Delta Z(s) + \sum_{t < s \le t+\delta\epsilon^2} F(X(s-))\,\Delta A(s) = X(t)\,, \qquad (39)$$
because $A$ and $Z$ are continuous on $(t, t+\delta\epsilon^2]$. Hence the left limit is $X((t+\epsilon^2)-) = X(t)$. Putting these equations together, we find
$$X((k+1)\epsilon^2) = X(t + \epsilon^2) = X(k\epsilon^2) + \epsilon\,\tilde\sigma(X(k\epsilon^2))\,\xi_{k+1} + F(X(k\epsilon^2))\,,$$
which, using the definition $X(k\epsilon^2) = (\pi_k, w_k)$, is Eq. (29), thus proving the equivalence.
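The equivalence proved above can be sanity-checked numerically: iterating the jump equation step by step must reproduce the momentum-SGD recursion exactly. We use a 1-d quadratic loss L(w) = w²/2 and a constant noise function as an illustrative assumption.

```python
import numpy as np

rng = np.random.default_rng(0)
eta, beta, eps = 0.1, 0.9, 0.05
gradL = lambda w: w          # gradient of L(w) = w^2 / 2
sigma = 0.3                  # constant noise function (assumption)

# Direct momentum SGD, Eq. (29)
pi, w = 0.0, 1.0
xis = rng.standard_normal(100)
for xi in xis:
    pi = beta * pi - gradL(w) + eps * sigma * xi
    w = w + eta * pi

# Katzenberger form: X = (pi, w), jumps of size sigma~(X)*eps*xi + F(X),
# with F(X) = ((beta-1)pi - gradL(w), eta*(beta*pi - gradL(w))).
X = np.array([0.0, 1.0])
for xi in xis:
    F = np.array([(beta - 1) * X[0] - gradL(X[1]),
                  eta * (beta * X[0] - gradL(X[1]))])
    S = np.array([eps * sigma, eta * eps * sigma])   # (sigma, eta*sigma)*eps
    X = X + S * xi + F

print(np.allclose(X, [pi, w]))  # True: the two recursions coincide
```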

C EXPLICIT EXPRESSION OF LIMITING DIFFUSION IN MOMENTUM SGD

In this section we provide the proof of Theorem 3.4. Recall that $\Phi$ satisfies
$$\Phi(x + F(x)) = \Phi(x)\,, \qquad (42)$$
where $x = (\pi, w)$ and $F$ is given in (132). To obtain the explicit expression of the limiting drift-diffusion, according to Theorem B.4 and taking into account Assumption 3.2, we need to determine $\Phi$ up to its second derivatives. To this end, we expand Eq. (42) up to second order in $F$. In components, this reads
$$\eta\,\partial_{w_i}\Phi_j(\pi_i - g_i) + \partial_{\pi_i}\Phi_j(-C\eta^\gamma\pi_i - g_i) + \tfrac12\eta^2\,\partial_{w_i}\partial_{w_k}\Phi_j(\pi_i - g_i)(\pi_k - g_k) + \tfrac12\eta\,\partial_{w_i}\partial_{\pi_k}\Phi_j(\pi_i - g_i)(-C\eta^\gamma\pi_k - g_k) + \tfrac12\eta\,\partial_{\pi_i}\partial_{w_k}\Phi_j(-C\eta^\gamma\pi_i - g_i)(\pi_k - g_k) + \tfrac12\,\partial_{\pi_i}\partial_{\pi_k}\Phi_j(-C\eta^\gamma\pi_i - g_i)(-C\eta^\gamma\pi_k - g_k) = 0 \qquad (43)$$
subject to the boundary condition
$$\Phi(\pi, w)\big|_{w \in \Gamma,\;\pi = 0} = w\,. \qquad (44)$$
Here $g_i = \partial_i L$. In Eq. (43) we have already substituted $\beta = 1 - C\eta^\gamma$. In what follows, we find $\Phi$ to leading order in $\eta$ as $\eta \to 0$. We solve the above problem by performing a series expansion of $\Phi$ around a point $w \in \Gamma$ and $\pi = 0$, up to second order:
$$\Phi(\pi, w + \delta w) = \Phi_{00} + \Phi^i_{01}\,\delta w_i + \Phi^i_{10}\,\pi_i + \tfrac12\,\Phi^{ij}_{02}\,\delta w_i\,\delta w_j + \Phi^{ij}_{11}\,\pi_i\,\delta w_j + \tfrac12\,\Phi^{ij}_{20}\,\pi_i\,\pi_j + \cdots\,. \qquad (45)$$
For example, $\Phi^i_{01} = \partial\Phi/\partial w_i$, $\Phi^{ij}_{11} = \partial^2\Phi/\partial\pi_i\partial w_j$, $\Phi^{ij}_{20} = \partial^2\Phi/\partial\pi_i\partial\pi_j$. More precisely, we regard this as an expansion in powers of $\delta w$ and $\pi$: $\Phi_{00}$ is zeroth order, $\Phi_{10}, \Phi_{01}$ are first order, and the remaining terms are second order. We will occasionally write the component index of $\Phi$ explicitly, e.g. $\Phi^{k,ij}_{02} = \partial^2\Phi^k/\partial w_i\partial w_j$. It will be useful to introduce the longitudinal projector onto $\Gamma$, $P_L(w) : \mathbb{R}^D \to T_w\Gamma$, defined such that $P_L v = v$ for any vector $v \in T_w\Gamma$. The transverse projector is then $P_T = \mathrm{Id} - P_L$. We will also decompose various tensors using these projectors, e.g. $\Phi^{ij}_{11} = \Phi^{ij}_{11LL} + \Phi^{ij}_{11LT} + \Phi^{ij}_{11TL} + \Phi^{ij}_{11TT}$, with $\Phi^{ij}_{11LT} = \Phi^{kl}_{11} P^{ik}_L P^{jl}_T$. Note that the Hessian $H = \nabla g \in \mathbb{R}^{D\times D}$ satisfies $HH^\dagger = P_T$, where $H^\dagger$ denotes the pseudoinverse.
At zeroth order, we obviously have $\Phi_{00}(w) = w$.

Lemma C.1. The first order terms in the series expansion (45) are given by
$$\Phi_{01T} = \Phi_{10T} = 0\,, \qquad \Phi_{01L} = P_L\,, \qquad \Phi_{10L} = C^{-1}\eta^{1-\gamma}P_L\,. \qquad (48)$$
Proof. Suppose $\hat w(s) \in \mathbb{R}^D$ is a curve lying on $\Gamma$. Then, due to the boundary condition (44),
$$\partial_s\Phi(\pi = 0, \hat w) = \Phi^i_{01}\,\partial_s\hat w_i = \partial_s\hat w_i\,.$$
This means that $\Phi_{01L} = P_L$. Now, from (43) at first order,
$$\eta\,\Phi^i_{01}(\pi_i - \partial_j g_i\,\delta w_j) + \Phi^i_{10}(-C\eta^\gamma\pi_i - \partial_j g_i\,\delta w_j) = 0\,.$$
This condition should hold for any $\pi$ and $\delta w$; therefore we arrive at $\Phi_{10} = C^{-1}\eta^{1-\gamma}\Phi_{01}$ and
$$(\eta\,\Phi^i_{01} + \Phi^i_{10})\,\partial_k g_i = 0\,. \qquad (50)$$
Decomposing into longitudinal and transverse components, and noting that the Hessian satisfies $P_L H = 0$, the above equation becomes
$$\eta\,\Phi^i_{01T} + \Phi^i_{10T} = 0\,,$$
which together with (50) and the above discussion gives
$$\Phi_{01T} = \Phi_{10T} = 0\,, \qquad \Phi_{01L} = P_L\,, \qquad \Phi_{10L} = C^{-1}\eta^{1-\gamma}P_L\,,$$
which concludes the first order analysis.

We now need a Lemma, which requires Definition 3.3 in the main text; we repeat it here for convenience.

Definition C.2. For a symmetric matrix $H \in \mathbb{R}^{D\times D}$ and $W_H = \{\Sigma \in \mathbb{R}^{D\times D} : \Sigma = \Sigma^\top,\; HH^\dagger\Sigma = H^\dagger H\Sigma = \Sigma\}$, we define the operator $\bar{\mathcal{L}}_H : W_H \to W_H$ by
$$\bar{\mathcal{L}}_H S \equiv \{H, S\} + \tfrac12\,C^{-2}\eta^{1-2\gamma}\,[[S, H], H]\,,$$
with $[S, H] = SH - HS$.

Lemma C.3. The operator $\bar{\mathcal{L}}_H$ has a unique inverse.

Proof. Let us go to a basis where $H$ is diagonal, with eigenvalues $\lambda_i$. In components, the equation $\bar{\mathcal{L}}_H S = -C^{-1}\eta^{1-\gamma}M$ reads
$$(\lambda_i + \lambda_j)S_{ij} + \tfrac12\,C^{-2}\eta^{1-2\gamma}(\lambda_i - \lambda_j)^2 S_{ij} + C^{-1}\eta^{1-\gamma}M_{ij} = 0\,,$$
which has the unique solution
$$S_{ij} = -C^{-1}\eta^{1-\gamma}\left(\lambda_i + \lambda_j + \tfrac12\,C^{-2}\eta^{1-2\gamma}(\lambda_i - \lambda_j)^2\right)^{-1} M_{ij}\,. \qquad (55)$$

Lemma C.4.
The second order terms in the series expansion (45) are given by
$$\Phi^{j,ik}_{02LL} = -(H^\dagger)^{jl}\,\partial^L_i H^{ln} P^{nk}_L\,, \qquad \Phi^{j,ik}_{02LT} = -P^{jl}_L\,\partial^L_i H^{ln}(H^\dagger)^{nk}\,, \qquad \Phi^{j,ik}_{02TT} = O\!\left(\eta^{\min\{0,\,1-2\gamma\}}\right) \qquad (56)$$
$$\Phi^{j,ik}_{11LL} = -C^{-1}\eta^{1-\gamma}(H^\dagger)^{jl}\,\partial^L_i H^{ln} P^{nk}_L\,, \qquad \Phi^{j,ik}_{11TL} = -C^{-1}\eta^{1-\gamma} P^{jl}_L\,\partial^L_k H^{ln}(H^\dagger)^{ni} \qquad (57)$$
$$\Phi^{j,ik}_{11TT} = -C^{-1}\eta^{1-\gamma}\left(\bar{\mathcal{L}}^{-1}_H M^{(j)}\right)^{ik} - \tfrac12\,C^{-3}\eta^{2-3\gamma}\left[H, \bar{\mathcal{L}}^{-1}_H M^{(j)}\right]^{ik} \qquad (58)$$
$$\Phi^{j,ik}_{20LL} = -\tfrac12\,C^{-2}\eta^{2-2\gamma}\left((H^\dagger)^{jl}\,\partial^L_i H^{ln} P^{nk}_L + (H^\dagger)^{jl}\,\partial^L_k H^{ln} P^{ni}_L\right) \qquad (59)$$
$$\Phi^{j,ik}_{20TL} = -\tfrac12\,C^{-2}\eta^{2-2\gamma}\left(P^{jl}_L\,\partial^L_k H^{ln}(H^\dagger)^{ni} + P^{jl}_L\,\partial^L_i H^{ln}(H^\dagger)^{nk}\right) \qquad (60)$$
$$\Phi^{j,ik}_{20TT} = -C^{-2}\eta^{2-2\gamma}\left(\bar{\mathcal{L}}^{-1}_H(P_T\,\partial^L_j H\,P_T)\right)^{ik} \qquad (61)$$
where
$$(M^{(j)})^{kl} = P^{ki}_T\,\partial^L_j H^{ni}\,P^{nl}_T\,. \qquad (62)$$
Proof. Consider a path $\hat w(s)$ lying on $\Gamma$. From $P_T H = H$ we have
$$\frac{dP_T}{ds}H = (\mathrm{Id} - P_T)\frac{dH}{ds} = P_L\frac{dH}{ds}\,,$$
and thus, using $HH^\dagger = P_T$, we find
$$\frac{dP_T}{ds}P_T = P_L\frac{dH}{ds}H^\dagger\,, \qquad P_T\frac{dP_T}{ds} = H^\dagger\frac{dH}{ds}P_L\,,$$
where the second equation is obtained from the first by taking the transpose. Putting the last two relations together, we find
$$\frac{dP_L}{ds} = -\frac{dP_T}{ds} = -P_T\frac{dP_T}{ds} - \frac{dP_T}{ds}P_T = -H^\dagger\frac{dH}{ds}P_L - P_L\frac{dH}{ds}H^\dagger\,. \qquad (63)$$
From (48) we can then write
$$\partial^L_i\Phi^{j,k}_{01} = -(H^\dagger)^{jl}\,\partial^L_i H^{ln} P^{nk}_L - P^{jl}_L\,\partial^L_i H^{ln}(H^\dagger)^{nk}\,,$$
where $\partial^L_i = P^{ij}_L\partial_j$. This leads to
$$\Phi^{j,ik}_{02LL} = P^{kl}_L\,\partial^L_i\Phi^{j,l}_{01} = -(H^\dagger)^{jl}\,\partial^L_i H^{ln} P^{nk}_L \qquad (67)$$
$$\Phi^{j,ik}_{02LT} = P^{kl}_T\,\partial^L_i\Phi^{j,l}_{01} = -P^{jl}_L\,\partial^L_i H^{ln}(H^\dagger)^{nk}\,. \qquad (68)$$
Also note that $\Phi^{j,ik}_{02LT} = \Phi^{j,ki}_{02TL}$, so the only component of $\Phi_{02}$ still to be determined is $\Phi_{02TT}$. The next step is to expand Eq. (43) to second order:
$$\pi_i\pi_k\left[\eta(1 + C\eta^\gamma)\Phi^{ki}_{11} - C\eta^\gamma(1 - \tfrac12 C\eta^\gamma)\Phi^{ik}_{20} + \tfrac12\eta^2\Phi^{ik}_{02}\right] + \delta w_k\pi_i\left[\eta\Phi^{ik}_{02} - \eta^2 H^{kj}\Phi^{ji}_{02} - \eta\Phi^{il}_{11}H^{kl} - C\eta^\gamma\Phi^{ik}_{11} - \eta(1 - C\eta^\gamma)\Phi^{(ij)}_{11}H^{kj} - (1 - C\eta^\gamma)\Phi^{li}_{20}H^{kl}\right] - \delta w_k\delta w_l\left[\eta\Phi^{ik}_{02}H^{li} + \tfrac12\eta^2\Phi^{ji}_{02}H^{lj}H^{ki} + \Phi^{ik}_{11}H^{li} + \eta\Phi^{ji}_{11}H^{kj}H^{li} + \tfrac12(\eta\Phi^i_{01} + \Phi^i_{10})\partial_k H^{li} + \tfrac12\Phi^{ji}_{20}H^{lj}H^{ki}\right] = 0\,, \qquad (69)$$
where $A^{(ij)} = \tfrac12(A^{ij} + A^{ji})$ denotes the symmetric part.
Neglecting various terms that are subleading at small $\eta$ gives
$$\pi_i\pi_k\left[\eta\Phi^{ki}_{11} - C\eta^\gamma\Phi^{ik}_{20} + \tfrac12\eta^2\Phi^{ik}_{02}\right] + \delta w_k\pi_i\left[\eta\Phi^{ik}_{02} - C\eta^\gamma\Phi^{ik}_{11} - \Phi^{li}_{20}H^{kl}\right] - \delta w_k\delta w_l\left[\eta\Phi^{ik}_{02}H^{li} + \Phi^{ik}_{11}H^{li} + \tfrac12(\eta\Phi^i_{01} + \Phi^i_{10})\partial_k H^{li} + \tfrac12\Phi^{ji}_{20}H^{lj}H^{ki}\right] = 0\,. \qquad (70)$$
The first line immediately gives
$$\Phi^{ij}_{20} = C^{-1}\eta^{1-\gamma}\Phi^{(ij)}_{11} + \tfrac12\,C^{-1}\eta^{2-\gamma}\Phi^{ij}_{02}\,. \qquad (71)$$
Taking the second line of (70) and projecting the index $k$ onto the longitudinal part gives
$$\eta\beta\Phi^{ik}_{02LL} - C\eta^\gamma\Phi^{ik}_{11LL} = 0\,, \qquad (72)$$
$$\eta\beta\Phi^{ik}_{02TL} - C\eta^\gamma\Phi^{ik}_{11TL} = 0\,, \qquad (73)$$
and using (67), (68),
$$\Phi^{j,ik}_{11LL} = C^{-1}\eta^{1-\gamma}\Phi^{j,ik}_{02LL} = -C^{-1}\eta^{1-\gamma}(H^\dagger)^{jl}\,\partial^L_i H^{ln} P^{nk}_L\,, \qquad (74)$$
$$\Phi^{j,ik}_{11TL} = C^{-1}\eta^{1-\gamma}\Phi^{j,ik}_{02TL} = -C^{-1}\eta^{1-\gamma} P^{jl}_L\,\partial^L_k H^{ln}(H^\dagger)^{ni}\,. \qquad (75)$$
Projecting the second line of (70) on the transverse part of the index $k$ we have
$$\eta\Phi^{j,ik}_{02LT} - C\eta^\gamma\Phi^{j,ik}_{11LT} - \Phi^{j,li}_{20TL}H^{kl} = 0\,, \qquad (76)$$
$$\eta\Phi^{j,ik}_{02TT} - C\eta^\gamma\Phi^{j,ik}_{11TT} - \Phi^{j,li}_{20TT}H^{kl} = 0\,. \qquad (77)$$
Using (71) and taking (68) into account, Eq. (76) becomes, neglecting subleading terms in $\eta$,
$$-\eta P^{jl}_L\,\partial^L_i H^{ln}(H^\dagger)^{nk} - C\eta^\gamma\Phi^{ik}_{11LT} - \tfrac12\,C^{-1}\eta^{1-\gamma}\Phi^{li}_{11TL}H^{kl} - \tfrac12\,C^{-1}\eta^{1-\gamma}\Phi^{il}_{11LT}H^{kl} = 0\,. \qquad (78)$$
Using (75) in (78),
$$-\eta P^{jl}_L\,\partial^L_i H^{ln}(H^\dagger)^{nk} - C\eta^\gamma\Phi^{ik}_{11LT} + \tfrac12\,C^{-2}\eta^{2-2\gamma} P^{jp}_L\,\partial^L_i H^{pn}(H^\dagger)^{nl}H^{kl} - \tfrac12\,C^{-1}\eta^{1-\gamma}\Phi^{il}_{11LT}H^{kl} = 0\,,$$
and simplifying,
$$-\eta P^{jl}_L\,\partial^L_i H^{ln}(H^\dagger)^{nk} - C\eta^\gamma\Phi^{ik}_{11LT} + \tfrac12\,C^{-2}\eta^{2-2\gamma} P^{jl}_L\,\partial^L_i H^{ln} P^{nk}_T - \tfrac12\,C^{-1}\eta^{1-\gamma}\Phi^{il}_{11LT}H^{kl} = 0\,, \qquad (80)$$
which determines $\Phi_{11LT}$ in closed form. Indeed, the above has the form, in matrix notation,
$$-\tfrac12\,C^{-1}\eta^{1-\gamma}\Phi_{11LT}H - C\eta^\gamma\Phi_{11LT} = M\,,$$
and can be immediately inverted to solve for $\Phi_{11LT}$. The only two undetermined components now are $\Phi_{11TT}$ and $\Phi_{02TT}$. One condition is obtained from (77), which gives, taking (71) into account,
$$\eta\Phi^{j,ik}_{02TT} - C\eta^\gamma\Phi^{j,ik}_{11TT} - C^{-1}\eta^{1-\gamma}\Phi^{j,(li)}_{11TT}H^{kl} = 0\,,$$
and thus
$$\Phi^{j,ik}_{02TT} = C\eta^{\gamma-1}\Phi^{j,ik}_{11TT} + C^{-1}\eta^{-\gamma}\Phi^{j,(ni)}_{11TT}H^{kn}\,. \qquad (83)$$
Further taking the symmetric and antisymmetric parts in $ik$ of the above gives
$$\Phi^{j,ik}_{02TT} = C\eta^{\gamma-1}\Phi^{j,(ik)}_{11TT} + \tfrac12\,C^{-1}\eta^{-\gamma}\Phi^{j,(ni)}_{11TT}H^{kn} + \tfrac12\,C^{-1}\eta^{-\gamma}\Phi^{j,(nk)}_{11TT}H^{in}\,, \qquad (84)$$
$$0 = C\eta^{\gamma-1}\Phi^{j,[ik]}_{11TT} + \tfrac12\,C^{-1}\eta^{-\gamma}\Phi^{j,(ni)}_{11TT}H^{kn} - \tfrac12\,C^{-1}\eta^{-\gamma}\Phi^{j,(nk)}_{11TT}H^{in}\,. \qquad (85)$$
The other condition comes from the third line of (70) which, using (48) and (71), and neglecting subleading terms in $\eta$, gives
$$\eta\Phi^{j,ik}_{02}H^{li} + \eta\Phi^{j,il}_{02}H^{ki} + \Phi^{j,ik}_{11}H^{li} + \Phi^{j,il}_{11}H^{ki} + C^{-1}\eta^{1-\gamma}P^{ji}_L\,\partial_k H^{li} = 0\,. \qquad (86)$$
Projecting the $k$ and $l$ indices on the longitudinal part gives $P^{ji}_L\,\partial^L_k H^{li} P^{ln}_L = 0$, an identity that can be checked from (63). Projecting $k$ on the longitudinal part and $l$ on the transverse part gives
$$\eta\Phi^{j,ik}_{02TL}H^{li} + \Phi^{j,ik}_{11TL}H^{li} + C^{-1}\eta^{1-\gamma}P^{ji}_L\,\partial^L_k H^{ni} P^{nl}_T = 0\,, \qquad (87)$$
which is implied by (75); indeed,
$$(\eta\Phi^{j,ik}_{02TL} + \Phi^{j,ik}_{11TL})H^{li} = C^{-1}\eta^{1-\gamma}\Phi^{j,ik}_{02TL}H^{li} = -C^{-1}\eta^{1-\gamma}P^{jl}_L\,\partial^L_k H^{ln}(H^\dagger)^{ni}H^{li} = -C^{-1}\eta^{1-\gamma}P^{jl}_L\,\partial^L_k H^{ln} P^{nl}_T\,,$$
therefore (87) does not give a new condition. The only new condition comes from projecting the $k$ and $l$ indices of (86) on the transverse directions, giving
$$\eta\Phi^{j,ik}_{02TT}H^{li} + \eta\Phi^{j,il}_{02TT}H^{ki} + \Phi^{j,ik}_{11TT}H^{li} + \Phi^{j,il}_{11TT}H^{ki} + C^{-1}\eta^{1-\gamma}P^{ji}_L\,\partial^T_k H^{ni} P^{nl}_T = 0\,. \qquad (89)$$
Plugging in (83),
$$C\eta^\gamma\Phi^{j,ik}_{11TT}H^{li} + C^{-1}\eta^{1-\gamma}\Phi^{j,(in)}_{11TT}H^{kn}H^{li} + C\eta^\gamma\Phi^{j,il}_{11TT}H^{ki} + C^{-1}\eta^{1-\gamma}\Phi^{j,(in)}_{11TT}H^{ln}H^{ki} + \Phi^{j,ik}_{11TT}H^{li} + \Phi^{j,il}_{11TT}H^{ki} + C^{-1}\eta^{1-\gamma}P^{ji}_L\,\partial^T_k H^{ni}P^{nl}_T = 0\,. \qquad (90)$$
Neglecting subleading terms in $\eta$, this can be rewritten as
$$2C^{-1}\eta^{1-\gamma}\Phi^{j,(in)}_{11TT}H^{kn}H^{li} + \Phi^{j,ik}_{11TT}H^{li} + \Phi^{j,il}_{11TT}H^{ki} + C^{-1}\eta^{1-\gamma}\,\partial^L_j H^{ni} P^{ki}_T P^{nl}_T = 0\,, \qquad (91)$$
where we recall that the last term is symmetric in $k$ and $l$.
In matrix notation this reads
$$C^{-1}\eta^{1-\gamma}H\left(\Phi^j_{11TT} + \Phi^{j\,\top}_{11TT}\right)H + H\Phi^j_{11TT} + \Phi^{j\,\top}_{11TT}H + C^{-1}\eta^{1-\gamma}M^{(j)} = 0\,, \qquad (92)$$
where
$$(M^{(j)})^{kl} = P^{ki}_T\,\partial^L_j H^{ni}\,P^{nl}_T \qquad (93)$$
is the longitudinal derivative of the Hessian, projected on the transverse directions. Decomposing $\Phi^j_{11TT}$ into symmetric and antisymmetric parts, $\Phi^j_{11TT} = S^{(j)} + A^{(j)}$, Eq. (92) reads (suppressing the index $j$)
$$2C^{-1}\eta^{1-\gamma}HSH + HS + SH + HA - AH + C^{-1}\eta^{1-\gamma}M = 0\,. \qquad (94)$$
The first term is subleading in $\eta$, therefore we have
$$HS + SH + HA - AH + C^{-1}\eta^{1-\gamma}M = 0\,. \qquad (95)$$
This equation, together with (85), which in matrix form reads
$$C^{-1}\eta^{-\gamma}(SH - HS) + 2C\eta^{\gamma-1}A = 0\,, \qquad (96)$$
determines $S$ and $A$, and thus $\Phi_{11TT}$. Note that $M^{(j)}$ has only a longitudinal component in the index $j$, therefore the transverse parts of $S$ and $A$ vanish, i.e.

$$P^{pj}_T\,\Phi^{j,ik}_{11TT} = 0\,. \qquad (97)$$
Eq. (97) is natural, as the slow degrees of freedom are the longitudinal coordinates. Note that these two equations admit a unique solution. Indeed, solving (96) for $A$ gives
$$A = \tfrac12\,C^{-2}\eta^{1-2\gamma}\,[H, S]\,.$$
Plugging into (95), we find
$$\bar{\mathcal{L}}_H S = -C^{-1}\eta^{1-\gamma}M\,, \qquad (99)$$
where $\bar{\mathcal{L}}_H S \equiv \{H, S\} + \tfrac12\,C^{-2}\eta^{1-2\gamma}[[S,H],H]$ is the operator introduced in Definition C.2. By Lemma C.3, (99) admits a unique solution. Then
$$\Phi^j_{11TT} = -C^{-1}\eta^{1-\gamma}\,\bar{\mathcal{L}}^{-1}_H M^{(j)} - \tfrac12\,C^{-3}\eta^{2-3\gamma}\,[H, \bar{\mathcal{L}}^{-1}_H M^{(j)}]\,. \qquad (100)$$
The other components are (see Eqs. (67), (68), (74), (75), (83), (71))
$$\Phi^{j,ik}_{02LL} = -(H^\dagger)^{jl}\,\partial^L_i H^{ln} P^{nk}_L \qquad (101)$$
$$\Phi^{j,ik}_{02LT} = -P^{jl}_L\,\partial^L_i H^{ln}(H^\dagger)^{nk} \qquad (102)$$
$$\Phi^{j,ik}_{11LL} = -C^{-1}\eta^{1-\gamma}(H^\dagger)^{jl}\,\partial^L_i H^{ln} P^{nk}_L \qquad (103)$$
$$\Phi^{j,ik}_{11TL} = -C^{-1}\eta^{1-\gamma} P^{jl}_L\,\partial^L_k H^{ln}(H^\dagger)^{ni} \qquad (104)$$
$$\Phi^{j,ik}_{02TT} = C\eta^{\gamma-1}\Phi^{j,ik}_{11TT} + C^{-1}\eta^{-\gamma}\Phi^{j,(ni)}_{11TT}H^{kn} = O\!\left(\eta^{\min\{0,\,1-2\gamma\}}\right) \qquad (105)$$
$$\Phi^{ij}_{20} = C^{-1}\eta^{1-\gamma}\Phi^{(ij)}_{11} + \tfrac12\,C^{-1}\eta^{2-\gamma}\Phi^{ij}_{02} = C^{-1}\eta^{1-\gamma}\Phi^{(ij)}_{11}\,. \qquad (106)$$
The only term contributing at leading order in $\eta$ to the limiting diffusion equation is (106). Splitting it into longitudinal and transverse components, we find
$$\Phi^{j,ik}_{20LL} = C^{-1}\eta^{1-\gamma}\Phi^{j,(ik)}_{11LL} = -\tfrac12\,C^{-2}\eta^{2-2\gamma}\left((H^\dagger)^{jl}\,\partial^L_i H^{ln} P^{nk}_L + (H^\dagger)^{jl}\,\partial^L_k H^{ln} P^{ni}_L\right) \qquad (107)$$
$$\Phi^{j,ik}_{20TL} = C^{-1}\eta^{1-\gamma}\Phi^{j,(ik)}_{11TL} = -\tfrac12\,C^{-2}\eta^{2-2\gamma}\left(P^{jl}_L\,\partial^L_k H^{ln}(H^\dagger)^{ni} + P^{jl}_L\,\partial^L_i H^{ln}(H^\dagger)^{nk}\right) \qquad (108)$$
and, using (100) and (93), we have, in matrix notation,
$$\Phi^j_{20TT} = \tfrac12\,C^{-1}\eta^{1-\gamma}\left(\Phi^j_{11TT} + (\Phi^j_{11TT})^\top\right) = -C^{-2}\eta^{2-2\gamma}\,\bar{\mathcal{L}}^{-1}_H\!\left(P_T\,\partial^L_j H\,P_T\right)\,.$$
To write things more compactly, the following Lemma will be useful:

Lemma C.5. For any transverse symmetric matrix $T$:
$$\Phi^j_{20TT}[T] = -\tfrac12\,C^{-2}\eta^{2-2\gamma}\,M^{(j)}[\bar{\mathcal{L}}^{-1}_H T]\,.$$
Proof. From (61), $\Phi_{20TT}$ is proportional to the symmetric part of $\Phi_{11TT}$, which was denoted $S$ below Eq. (93) and satisfies Eq. (99). Therefore $\Phi_{20TT}$ also satisfies Eq. (99), up to an overall factor:
$$\bar{\mathcal{L}}_H\Phi^j_{20TT} = -\tfrac12\,C^{-2}\eta^{2-2\gamma}\,M^{(j)}\,.$$
Then, for any $T$:
$$\left(\bar{\mathcal{L}}_H\Phi^j_{20TT}\right)[T] = -\tfrac12\,C^{-2}\eta^{2-2\gamma}\,M^{(j)}[T]\,.$$
Moreover,
$$\left(\bar{\mathcal{L}}_H\Phi^j_{20TT}\right)[T] = \Phi^j_{20TT}[\bar{\mathcal{L}}_H T]\,.$$
Since the two equations above hold for any $T$ and $\bar{\mathcal{L}}_H$ is invertible, this implies
$$\Phi^j_{20TT}[T] = -\tfrac12\,C^{-2}\eta^{2-2\gamma}\,M^{(j)}[\bar{\mathcal{L}}^{-1}_H T]\,,$$
which is the statement of the lemma.

To leading order in $\eta$ we then have, for a symmetric matrix $V$, $\partial^2\Phi[V] = \sum_{i,j=1}^D \partial_i\partial_j\Phi\,V^{ij}$, and thus
$$\partial^2\Phi[V] = -\frac{1}{2C^2}\eta^{2-2\gamma}(\nabla^2 L)^\dagger\,\partial^2(\nabla L)[V_{LL}] - \frac{1}{C^2}\eta^{2-2\gamma}P_L\,\partial^2(\nabla L)[(H^\dagger)V_{TL}] - \frac{1}{2C^2}\eta^{2-2\gamma}P_L\,\partial^2(\nabla L)[\bar{\mathcal{L}}^{-1}_H V_{TT}]\,,$$
where $V_{LL} = P_L V P_L$, $V_{TL} = P_T V P_L$, and $V_{TT} = P_T V P_T$ are the longitudinal and transverse projections of $V$. We also have, using (48) and to leading order in $\eta$,
$$\partial\Phi\,\sigma\,dZ = \Phi_{10}\,\sigma\,dZ + \eta\,\Phi_{01}\,\sigma\,dZ = \left(C^{-1}\eta^{1-\gamma} + \eta\right)P_L\,\sigma\,dW\,,$$
where $dW$ is a Wiener process. Applying Theorem B.4, and taking into account that, to leading order in $dt$, $d[Z^i, Z^j] = \delta^{ij}dt$, we find
$$dY = \left(C^{-1}\eta^{1-\gamma} + \eta\right)P_L\,\sigma\,dW - \frac{1}{2C^2}\eta^{2-2\gamma}(\nabla^2 L)^\dagger\,\partial^2(\nabla L)[\Sigma_{LL}]\,dt - \frac{1}{C^2}\eta^{2-2\gamma}P_L\,\partial^2(\nabla L)[(H^\dagger)\Sigma_{TL}]\,dt - \frac{1}{2C^2}\eta^{2-2\gamma}P_L\,\partial^2(\nabla L)[\bar{\mathcal{L}}^{-1}_H\Sigma_{TT}]\,dt\,,$$
where $\Sigma = \sigma\sigma^\top$. For $\gamma < 1/2$, $\bar{\mathcal{L}}_H$ reduces to the Lyapunov operator at leading order in $\eta$, i.e. $\bar{\mathcal{L}}_H S = \{H, S\}$. For $\gamma > 1/2$, from (55) it is easy to see that the role of the divergent term proportional to $\eta^{1-2\gamma}$, when $\bar{\mathcal{L}}^{-1}_H$ acts on $S$, is to set to zero the off-diagonal entries of $S_{ij}$ at $O(\eta^{1-\gamma})$, i.e.
$$S_{ii} = -\frac{C^{-1}\eta^{1-\gamma}}{2\lambda_i}M_{ii}\,, \qquad S_{i\ne j} = 0\,.$$
Using Lemma B.5, we finally conclude the following Corollary, which is the formal version of Theorem 3.4 in the main text:

Corollary C.6. Consider the stochastic process defined in Eq. (26), parametrized by $\epsilon_n$, with initial conditions $(\pi_0, w_0) \in U$, under Assumptions 3.1 and 3.2. Fix a compact $K \subset U$.
Then the conclusions of Theorem B.4 apply, and $Y(t)$ satisfies the limiting diffusion equation
$$dY = \left(\tfrac{1}{C}\eta^{1-\gamma} + \eta\right)P_L\,\sigma\,dW - \frac{1}{2C^2}\eta^{2-2\gamma}(\nabla^2 L)^\dagger\,\partial^2(\nabla L)[\Sigma_{LL}]\,dt - \frac{1}{C^2}\eta^{2-2\gamma}P_L\,\partial^2(\nabla L)[(\nabla^2 L)^\dagger\Sigma_{TL}]\,dt - \frac{1}{2C^2}\eta^{2-2\gamma}P_L\,\partial^2(\nabla L)[\bar{\mathcal{L}}^{-1}_{\nabla^2 L}\Sigma_{TT}]\,dt\,, \qquad (121)$$
where $W(t)$ is a Wiener process.

Let us now consider the special case of label noise. In this case $\Sigma = cH$, so $\Sigma$ is purely transverse. Moreover, using (55), $\bar{\mathcal{L}}^{-1}_H H = \tfrac12 P_T$, and
$$dY = -\frac{1}{4C^2}\eta^{2-2\gamma}P_L\,\partial^2(\nabla L)[c\,P_T]\,dt = -\frac{1}{4C^2}\eta^{2-2\gamma}P_L\,\nabla\,\mathrm{Tr}(c\,\partial^2 L)\,dt\,.$$
This proves Corollary 3.5 in the main text.

D EFFECTIVE DRIFT IN UV MODEL

We start with the mean-square loss $L = \frac{1}{2P}\sum_{a=1}^{P}\|f(x_a) - y_a\|^2$, with the data covariance matrix $\Sigma$ as defined in the text. The trace of the Hessian on the zero-loss manifold ($L(w_*) = 0$) is given explicitly by
$$\mathrm{tr}\,H = \frac{1}{n}\left[m\,\mathrm{Tr}(\Sigma V V^\top) + \mathrm{Tr}\,\Sigma\;\mathrm{Tr}(U U^\top)\right]\,.$$
Taking gradients of this and plugging into Corollary 3.5, repeated in Eq. (121), leads to an explicit expression for the drift-diffusion along the manifold:
$$dY = -\frac{\epsilon^2\eta^{2-2\gamma}}{4PC^2}\,\frac{2}{n}\,\hat S_L\,Y\,dt\,, \qquad \hat S = \begin{pmatrix} \mathrm{Tr}\,\Sigma\;\mathbf{1}_n & 0 \\ 0 & m\,\Sigma \end{pmatrix}\,, \qquad \hat S_L = P_L\,\hat S\,P_L\,.$$
The simplification cited in Sec. 3.4 of the main text follows from the fact that for input and output dimension $d = m = 1$ we have $\Sigma = \mathrm{Tr}\,\Sigma = \mu_2$, and $\hat S$ is proportional to the identity. For matrix sensing, in order to compute the trace of the Hessian, we use (12) with $\xi_t = 0$ and with slightly different notation ($V$ instead of $V^\top$). The loss is then $L = \frac{1}{Pd}\sum_{i=1}^{P}\left(y_i - \mathrm{Tr}(A_i U V)\right)^2$, and we define the data covariance matrices
$$\hat\Sigma_1 = \frac{1}{P}\sum_{i=1}^{P} A_i A_i^\top \in \mathbb{R}^{d\times d}\,, \qquad \hat\Sigma_2 = \frac{1}{P}\sum_{i=1}^{P} A_i^\top A_i \in \mathbb{R}^{d\times d}\,.$$
Then the trace of the Hessian is
$$\mathrm{Tr}\,H = \frac{2}{d}\,\mathrm{Tr}\left(\hat\Sigma_2\,U U^\top + \hat\Sigma_1\,V V^\top\right)\,.$$
We find that the noise function is $\sigma_{\mu i} = \frac{2\epsilon}{Pd}\nabla_\mu f(A_i)$, where $f(A) = \mathrm{Tr}(U V A)$. Since the Hessian on the zero-loss manifold is $H_{\mu\nu} = \frac{2}{Pd}\sum_i \nabla_\mu f(A_i)\nabla_\nu f(A_i)$, we see that $\sigma\sigma^\top = \frac{2\epsilon^2}{Pd}H$. Therefore, we get
$$dY = -\frac{\eta^{2-2\gamma}}{4C^2}\,\frac{4\epsilon^2}{Pd^2}\,\hat S_L\,Y\,dt\,, \qquad \hat S = \begin{pmatrix} \hat\Sigma_2 & 0 \\ 0 & \hat\Sigma_1 \end{pmatrix}\,, \qquad \hat S_L = P_L\,\hat S\,P_L\,.$$
To get a crude estimate of the coefficient $C$ in the main text, we approximate the top eigenvalue of $\hat S$ by $\frac{1}{d}\mathrm{Tr}\,\hat\Sigma_1$. With this, we get
$$\eta^{-2+2\gamma}\,\tau_2^{-1} = \frac{\epsilon^2}{C^2 P d^2}\,\frac{1}{dP}\sum_{i=1}^{P}\mathrm{Tr}(A_i A_i^\top) \approx \frac{\epsilon^2}{C^2 P d^2}\,d\,\langle a_{ij}^2\rangle = \frac{\epsilon^2}{C^2\,dP}\,\langle a_{ij}^2\rangle\,.$$
Here we have denoted by $a_{ij}$ an arbitrary element of the data matrices $A$, with brackets signifying an average over the distribution of these elements. Assuming the fast initial timescale remains the same, with $\tau_1^{-1} \approx (C/2)\eta^\gamma$, and equating the two timescales at $\gamma = 2/3$, we get
$$C^3 = \frac{2\epsilon^2}{dP}\,\langle a_{ij}^2\rangle\,.$$
For the values used in our experiments, this gives $C \approx 0.12 \times P^{-1/3}$.
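The crude estimate derived above can be checked by plugging in the experiment's numbers (ε² = 0.1, d = 100, P = 2500, and ⟨a_ij²⟩ = 1 for standard Gaussian entries):

```python
# Numerical check of C^3 = 2*eps^2*<a_ij^2>/(d*P) with the paper's values.
eps2, d, P, a2 = 0.1, 100, 2500, 1.0
C = (2 * eps2 * a2 / (d * P)) ** (1 / 3)
prefactor = (2 * eps2 * a2 / d) ** (1 / 3)   # coefficient of P**(-1/3)
print(f"C = {C:.4f} = {prefactor:.3f} * P^(-1/3)")
```

The prefactor evaluates to about 0.126, consistent with the quoted C ≈ 0.12 × P^(-1/3).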

E LINEARIZATION ANALYSIS OF MOMENTUM GRADIENT DESCENT IN THE SCALING LIMIT

Here we elaborate on the discussion in Sec. 3.3, providing derivations of various results. We take a straightforward linearization of the deterministic (noise-free) gradient descent with momentum. Working in the extended phase space $x = (\pi, w)$, the dynamical updates are of the form
$$x_{t+1} = x_t + F(x_t), \qquad F(x_t) = \begin{pmatrix}(\beta - 1)\pi_t - \nabla L(w_t)\\ \eta\big(\beta\pi_t - \nabla L(w_t)\big)\end{pmatrix}. \tag{132}$$
The fixed point of the evolution $x^* = (0, w^*)$ will have the momentum variable $\pi = 0$, and the coordinate satisfying $\nabla L(w^*) = 0$. Linearizing the update (132) around this point,
$$\delta x_{t+1} = J(x^*)\,\delta x_t, \qquad J(x^*) = \begin{pmatrix}\beta & -\nabla^2 L(w^*)\\ \eta\beta & 1 - \eta\nabla^2 L(w^*)\end{pmatrix}.$$
Note $\nabla^2 L(w^*)$ is the Hessian of the loss function at the fixed point. The spectrum of the Jacobian can be written in terms of the eigenvalues $\lambda_i$ of the Hessian. This is accomplished by using a straightforward ansatz for the (unnormalized) eigenvectors of the Jacobian, $k_i = (\mu_i q_i, q_i)$, where $q_i$ are eigenvectors of the Hessian with eigenvalue $\lambda_i$. Solving the resulting coupled eigenvalue equations for eigenvalue $\kappa_i$,
$$1 - \eta\lambda_i + \eta\beta\mu_i = \kappa_i, \qquad -\lambda_i + \mu_i\beta = \mu_i\kappa_i.$$
For a fixed $\lambda_i$, there will be two solutions given by
$$\kappa_i^{\pm} = \tfrac{1}{2}\Big[1 + \beta - \eta\lambda_i \pm \sqrt{(1 + \beta - \eta\lambda_i)^2 - 4\beta}\Big], \qquad i = 1, \dots, D,$$
$$\mu_i^{\pm} = \tfrac{1}{2\beta\eta}\Big[\beta - 1 + \eta\lambda_i \pm \sqrt{(1 + \beta - \eta\lambda_i)^2 - 4\beta}\Big].$$
For the set of zero modes $\lambda_i = 0$, we get the following modes: $\kappa^{+} = 1$, corresponding to motion only along $w$, with eigenvector $k_i = (0, q_i)$. In addition, there is a mixed longitudinal mode which includes a component of $\pi$ along the zero manifold, $k_i = (\mu^{-} q_i, q_i)$, and has an eigenvalue $\kappa^{-} = \beta$. On the zero-loss manifold, we can assume the Hessian is positive semi-definite, and that the positive eigenvalues satisfy
$$0 < c_1 \le \lambda_i \le c_2$$
for constants $c_1, c_2$ independent of $\eta, \beta$. We now analyze the spectrum of the Jacobian one eigenvalue at a time, and then use these results to informally control the relaxation rate of off-manifold perturbations.
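The eigenvalue formula for the Jacobian can be checked numerically: build $J$ for a Hessian with a known spectrum (including a zero mode) and compare its eigenvalues against $\kappa_i^{\pm}$. A minimal sketch; the matrix size and hyperparameter values are arbitrary choices for the check, not values from the paper.

```python
import numpy as np

rng = np.random.default_rng(1)
D, eta, beta = 4, 0.05, 0.9                 # illustrative values
lam = np.array([0.0, 0.5, 1.0, 3.0])        # Hessian spectrum, incl. one zero mode
Q = np.linalg.qr(rng.normal(size=(D, D)))[0]
H = Q @ np.diag(lam) @ Q.T                  # symmetric Hessian with known eigenvalues

I = np.eye(D)
# Jacobian of the (pi, w) update linearized at the fixed point x* = (0, w*)
J = np.block([[beta * I,       -H          ],
              [eta * beta * I,  I - eta * H]])

# predicted spectrum: kappa_i+- = (1 + beta - eta lam_i +- sqrt((...)^2 - 4 beta)) / 2
A = (1 + beta - eta * lam).astype(complex)
disc = np.sqrt(A**2 - 4 * beta)
kappa_pred = np.concatenate([(A + disc) / 2, (A - disc) / 2])

kappa = np.linalg.eigvals(J)
assert np.allclose(np.sort_complex(kappa), np.sort_complex(kappa_pred))
```

For the zero mode the predicted pair is exactly $\{1, \beta\}$, matching the marginal longitudinal mode and the mixed longitudinal mode discussed above.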
It is useful first to consider the conditions for stability, i.e. $|\kappa_i| < 1$, stated in (138) and (139).

Proof of Case 1: The condition $\eta\lambda_i < (1 - \sqrt\beta)^2$ implies $A^2 > 4\beta$, where $A = 1 + \beta - \eta\lambda_i$, so both $\kappa_i^{\pm}$ are real. The condition for stability then requires $-1 < \kappa_i^{\pm} < +1$, and we satisfy both sides of this inequality as follows.

If $A = 1 + \beta - \eta\lambda_i > 0$, then $|\kappa_i^{-}| < 1$ and $\kappa_i^{+} > 0$, so we simply require $\kappa_i^{+} < 1$:
$$\tfrac{1}{2}\big(A + \sqrt{A^2 - 4\beta}\big) < 1 \;\Leftrightarrow\; \sqrt{A^2 - 4\beta} < 2 - A \;\Leftrightarrow\; A^2 - 4\beta < 4 - 4A + A^2 \;\Leftrightarrow\; A < 1 + \beta \;\Leftrightarrow\; \eta\lambda_i > 0.$$

If $A = 1 + \beta - \eta\lambda_i < 0$, then $|\kappa_i^{+}| < 1$ and $\kappa_i^{-} < 1$, so we only require $\kappa_i^{-} > -1$:
$$-1 < \tfrac{1}{2}\big({-|A|} - \sqrt{A^2 - 4\beta}\big) \;\Leftrightarrow\; \sqrt{A^2 - 4\beta} < 2 - |A| \;\Leftrightarrow\; A^2 - 4\beta < A^2 - 4|A| + 4 \;\Leftrightarrow\; |A| = \eta\lambda_i - 1 - \beta < 1 + \beta \;\Leftrightarrow\; \eta\lambda_i < 2(1 + \beta).$$

Finally, note that since $\eta^{1-\gamma} < \eta^{\gamma}$ for $\gamma < 1/2$, we have that $1 - \frac{c_1}{C}\eta^{1-\gamma} > 1 - C\eta^{\gamma} = \beta$, so indeed $\rho_2 < \rho_1$. Equipped with the upper and lower bounds on the spectrum, we obtain bounds on the relaxation rate. For the purely decaying modes at $\gamma < 1/2$, we use
$$\rho_2^{\,t} \le |\delta x_t^{T}| \le \rho_1^{\,t},$$
where $\delta x_t^{T}$ represents the projection of the fluctuations $\delta x_t$ onto the transverse and mixed longitudinal modes. After applying Eqs. (153) and (154), we arrive at the result quoted in the main text in Sec. 3.3. For $\gamma > 1/2$, the modes are oscillatory. However, the eigenvalues within the unit circle have norm either $\sqrt\beta$ or $\beta$, as reflected in Eqs. (155) and (156). This implies that we can estimate the decay rate of the envelope of the transverse and mixed longitudinal modes in this regime, thereby arriving at the second expression quoted in the main text in Sec. 3.3.
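The stability window $0 < \eta\lambda_i < 2(1+\beta)$ derived above can be confirmed by computing the spectral radius of the $2\times 2$ block of the Jacobian associated with a single Hessian eigenvalue. A small sketch, with an arbitrary choice of $\beta$, checking that the radius crosses 1 exactly at $\eta\lambda = 2(1+\beta)$:

```python
import numpy as np

def spectral_radius(eta, beta, lam):
    """Spectral radius of the 2x2 Jacobian block for one Hessian eigenvalue lam."""
    J = np.array([[beta,       -lam          ],
                  [eta * beta,  1 - eta * lam]])
    return np.max(np.abs(np.linalg.eigvals(J)))

beta = 0.5                      # illustrative value
edge = 2 * (1 + beta)           # predicted stability boundary: eta*lam = 2(1+beta)
assert spectral_radius(1.0, beta, edge - 0.01) < 1   # just inside: stable
assert spectral_radius(1.0, beta, edge + 0.01) > 1   # just outside: unstable
```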

F LINEARIZED SGD AND ORNSTEIN-UHLENBECK PROCESS ON THE VALLEY

In this appendix, we provide a derivation of some of the statements quoted in Sec. 1.1. To get there, we start with the basic model for momentum SGD (1) but linearize around a point on the valley $w_0 \in \Gamma$ where $L(w_0) = 0$ and $\nabla L(w_0) = 0$. Let $w_k = w_0 + \delta w_k$, and define the Hessian $H(w_0) = \nabla^2 L(w_0)$. Then
$$\pi_{k+1} = \beta\pi_k - H(w_0)\,\delta w_k + \sigma(w_0)\,\xi_k, \qquad \delta w_{k+1} = \delta w_k + \eta\,\pi_{k+1}.$$
Consider the projector $P_T^{\lambda}$ along a transverse nonzero eigenmode $\lambda$ of $H(w_0)$, and define $P_T^{\lambda}\delta w = X$ and $P_T^{\lambda}\pi = \Pi$. Let $\tilde\sigma = P_T^{\lambda}\sigma$, $P_T^{\lambda}H = \lambda P_T^{\lambda}$, and $\tilde\sigma\tilde\sigma^\top = \Lambda$. Then
$$\Pi_{k+1} = \beta\Pi_k - \lambda X_k + \tilde\sigma\,\xi_k, \qquad X_{k+1} = X_k + \eta\Pi_{k+1} = X_k + \eta\beta\Pi_k - \eta\lambda X_k + \eta\tilde\sigma\,\xi_k.$$
This is a simple OU process, and we can easily compute the second moments. Define
$$C_{12}(k) = \langle X_k \Pi_k\rangle, \qquad C_{11}(k) = \langle X_k X_k\rangle, \qquad C_{22}(k) = \langle \Pi_k \Pi_k\rangle.$$
We find, by taking the equations above, squaring them, then averaging over the noise,
$$C_{22}(k+1) = \beta^2 C_{22}(k) - 2\beta\lambda C_{12}(k) + \lambda^2 C_{11}(k) + \Lambda,$$
$$C_{12}(k+1) = \eta\beta^2 C_{22}(k) + \beta(1 - 2\eta\lambda) C_{12}(k) - \lambda(1 - \eta\lambda) C_{11}(k) + \eta\Lambda,$$
$$C_{11}(k+1) = \eta^2\beta^2 C_{22}(k) + 2\eta\beta(1 - \eta\lambda) C_{12}(k) + (1 - \eta\lambda)^2 C_{11}(k) + \eta^2\Lambda.$$
Next, assuming a stationary distribution implies $C(k+1) = C(k)$, which then allows us to solve for the equilibrium variance and extract the main quantity of interest, the variance of the weights. We find
$$C_{11}(k) = \frac{\eta^2(1+\beta)\,\Lambda}{(1-\beta)\,\eta\lambda\,\big(2(1+\beta) - \eta\lambda\big)}.$$
In the limit of small $\eta$ we extract the scaling behavior quoted in Sec. 1.1. We can also see how the mixing timescale $\tau_1$, discussed in Sec. 1.1, arises from this linearized analysis. Taking the expectation value of the OU process, the noise vanishes and the average values follow the linearized GD dynamics analyzed in Appendix E. Thus, from this linearized GD analysis we can extract the characteristic timescale for the OU process to approach its mean value.
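The stationary variance can be cross-checked by iterating the second-moment recursion to its fixed point. The sketch below uses illustrative values of $\eta, \beta, \lambda, \Lambda$ and assumes unit-variance noise $\langle\xi_k^2\rangle = 1$, so that $\Lambda$ enters the three recursions with weights $1$, $\eta$, $\eta^2$; it confirms the closed form $C_{11} = \eta^2(1+\beta)\Lambda / \big[(1-\beta)\eta\lambda(2(1+\beta)-\eta\lambda)\big]$.

```python
import numpy as np

eta, beta, lam, Lam = 0.01, 0.9, 1.0, 1.0   # illustrative values; Lam = noise variance

def step(C22, C12, C11):
    """One step of the exact second-moment recursion of the linearized OU process."""
    C22n = beta**2 * C22 - 2 * beta * lam * C12 + lam**2 * C11 + Lam
    C12n = (eta * beta**2 * C22 + beta * (1 - 2 * eta * lam) * C12
            - lam * (1 - eta * lam) * C11 + eta * Lam)
    C11n = (eta**2 * beta**2 * C22 + 2 * eta * beta * (1 - eta * lam) * C12
            + (1 - eta * lam)**2 * C11 + eta**2 * Lam)
    return C22n, C12n, C11n

C = (0.0, 0.0, 0.0)
for _ in range(5000):                       # iterate to the stationary fixed point
    C = step(*C)

C11_pred = eta**2 * (1 + beta) * Lam / ((1 - beta) * eta * lam * (2 * (1 + beta) - eta * lam))
assert np.isclose(C[2], C11_pred, rtol=1e-4)
```

In the small-$\eta$ limit the closed form reduces to $C_{11} \approx \eta\Lambda / (2(1-\beta)\lambda)$, which with $\beta = 1 - C\eta^{\gamma}$ gives the $\eta^{1-\gamma}$ scaling of the weight variance.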



Definition 3.3. For a symmetric matrix $H \in \mathbb{R}^{D\times D}$, and $W_H = \{\Sigma \in \mathbb{R}^{D\times D} : \Sigma = \Sigma^\top,\ H H^{\dagger}\Sigma = H^{\dagger}H\Sigma = \Sigma\}$, we define the operator $\mathcal{L}_H : W_H \to W_H$ with
$$\mathcal{L}_H S \equiv \{H, S\} + \tfrac{1}{2}C^{-2}\eta^{1-2\gamma}\,[[S, H], H],$$
where $[S, H] = SH - HS$ and $\{H, S\} = HS + SH$. It can be shown that the operator $\mathcal{L}_H$ is invertible (see Lemma C.3).
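The identity $\mathcal{L}^{-1}_H H = \tfrac{1}{2}P_T$ used in Appendix C follows directly from this definition: since the range projector $P_T$ commutes with $H$, the double-commutator term vanishes and $\mathcal{L}_H(\tfrac{1}{2}P_T) = \{H, \tfrac{1}{2}P_T\} = H$. A minimal numerical sketch (random low-rank PSD $H$; dimensions and hyperparameters illustrative):

```python
import numpy as np

rng = np.random.default_rng(2)
D, eta, gamma, C = 5, 0.1, 2 / 3, 1.0       # illustrative values
B = rng.normal(size=(D, 3))
H = B @ B.T                                 # PSD Hessian with a nontrivial kernel

def L_H(S):
    """L_H S = {H, S} + (1/2) C^{-2} eta^{1 - 2 gamma} [[S, H], H]."""
    anti = H @ S + S @ H
    comm = S @ H - H @ S
    return anti + 0.5 * C**-2 * eta**(1 - 2 * gamma) * (comm @ H - H @ comm)

P_T = H @ np.linalg.pinv(H)                 # projector onto the range of H
assert np.allclose(L_H(0.5 * P_T), H)       # hence L_H^{-1} H = (1/2) P_T
```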

, and $U \in \mathbb{R}^{m\times n}$. For $d = m = 1$, we refer to this as the vector UV model (Rennie & Srebro, 2005; Saxe et al., 2014; Lewkowycz et al., 2020).

Figure 1: Timescale of training as a function of γ. a): theoretical prediction, with the blue line representing the timescale for equilibration and the black line the timescale of the drift along Γ. The maximum of these two gives the overall timescale. In b), c), and d) we demonstrate this result with the vector UV model with linear, tanh, and ReLU activations, respectively.

Figure 2: Classification of CIFAR10 using a ResNet18 model. Subfigure a) shows our training protocol, which is to first use noiseless gradient descent (black) to reach the zero-loss manifold, then to perform SGD with label noise for various (blue, red) values of η and β, before finally projecting onto the valley (green) and measuring the test accuracy. Subfigure b) shows the scaling of the optimal momentum, β*, as a function of η. We perform a power-law fit whose exponent γ* = 0.660 matches very closely the value implied by the theory, γ = 2/3. Notice we also extract the constant C = 0.11.

Figure 3: The expected test error, a), and Hessian of the training loss, b), in matrix sensing (with d = 100, r = 5, and 5rd = 2500 samples) as a function of training epoch, plotted for different values of β at η = 0.1. The label noise variance is 0.1. Each curve represents a different value of β. The inset shows that the orange curve crosses below the blue curve before convergence of the Hessian. Therefore, the same value of β is optimal for both the Hessian and the expected test error: increasing or decreasing β from this value slows down generalization.

Figure 4: A sample of the training curves for different momentum hyperparameters β, with η = 0.001, for ResNet on CIFAR10. The speed of increase in accuracy is non-monotonic in β: the best performance is obtained by an intermediate value of β, consistent with our predictions.

$|\kappa_i| < 1$, which are stated in (138) and (139) below. Case 1: If $\eta\lambda_i < (1 - \sqrt\beta)^2$, then $\kappa_i^{\pm} \in \mathbb{R}$, and $|\kappa_i^{\pm}| < 1$ iff $0 < \eta\lambda_i < 2(1 + \beta)$. (138) Case 2: If $\eta\lambda_i > (1 - \sqrt\beta)^2$, then $\kappa_i^{\pm} \in \mathbb{C}$ and $|\kappa_i^{\pm}| = \sqrt\beta < 1$. (139)

$\langle\cdots\rangle$ denotes the noise average. To keep track of the displacements in the longitudinal directions, $\delta w_k^{L}$, we need to look at the cubic order in the Taylor expansion of $L(w_0 + \delta w_k)$, i.e. $\partial^2(\nabla L)[\delta w_k, \delta w_k]$. Let $P_L$ be the projector onto the tangent space. The expectation value of momentum, upon applying the longitudinal projector, is P

Additionally, Kunin et al. (2021) and Xie et al. (2021) studied the diffusive dynamics induced by SGD, both empirically and in a simple theoretical model.

(28) when we take $\epsilon_n \to 0$, thus recovering condition B.1. By definition $Z_n$ is a martingale. Notice also, by the definition of $Z_n$ and because the $\xi_k$ are i.i.d. with unit variance, that $Z_n(t)$ has variance $A_n(t)\,\epsilon_n^2 \le t$, which is uniformly bounded; hence $Z_n(t)$ is uniformly integrable for the stopping times $\tau_n^{k}$. Also note that $\Delta Z_n(k\epsilon_n^2) = |\epsilon_n \xi_k|$, which goes to zero in probability as $\epsilon_n$ becomes small because $\xi_k$ has bounded variance, and $\Delta Z_n(t)$ is zero otherwise. This shows that $Z_n$ satisfies condition B.2. Because $Z_n$ and $A_n$ are discontinuous at the same times, we automatically satisfy condition B.3, as pointed out by Katzenberger.


Proof of Case 2: The condition $\eta\lambda_i > (1 - \sqrt\beta)^2$ implies $A^2 < 4\beta$, where $A = 1 + \beta - \eta\lambda_i$. The corresponding eigenvalues of the Jacobian form a complex-conjugate pair, $\kappa_i^{\pm} = \tfrac{1}{2}\big(A \pm i\sqrt{4\beta - A^2}\big)$; since $\kappa_i^{+}\kappa_i^{-} = \beta$, they satisfy $|\kappa_i^{\pm}| = \sqrt\beta < 1$.

These results show us when GD with momentum is stable. Next, assuming stability, we want to estimate the rate of convergence to the fixed point. More precisely, we would like to determine the fastest mode as well as the slowest mode. To this end, we define the two quantities
$$\rho_1 = \max_i |\kappa_i|, \qquad \rho_2 = \min_i |\kappa_i|,$$
where the extrema run over the transverse and mixed longitudinal modes. Using the explicit scaling for the momentum, and in the limit of small learning rate, we prove the following for $\rho_1$.

Lemma E.1. Let $\beta = 1 - C\eta^{\gamma}$, with $\eta$ sufficiently small. For $\gamma < 1/2$, the condition for Case 1 (138) holds, and $\rho_1 = 1 - \frac{c_1}{C}\eta^{1-\gamma} + o(\eta^{1-\gamma})$. For $\gamma > 1/2$, the condition for Case 2 (139) holds, and $\rho_1 = \sqrt\beta$.

Proof: For small $\eta$, we find that
$$\frac{(1 - \sqrt\beta)^2}{\eta\lambda_i} \approx \frac{C^2}{4\lambda_i}\,\eta^{2\gamma - 1}.$$
When $\gamma > 1/2$, this expression tends to zero as $\eta \to 0$. Therefore, for sufficiently small $\eta$, the condition for Case 2 (139) will be satisfied and $|\kappa_i| = \sqrt\beta$ for all $\lambda_i > 0$. Since $\beta < \sqrt\beta$, this implies $\rho_1 = \sqrt\beta$ and $\rho_2 = \beta$. The scaling behavior in the Lemma follows by substitution.

For $\gamma < 1/2$, the same expression diverges as $\eta \to 0$, which means Case 1 (138) will obtain for all $\lambda_i$. Next, for small $\eta$,
$$1 + \beta - \eta\lambda_i > 1 + \beta - \eta c_2 = 2 - C\eta^{\gamma} - \eta c_2 > 0,$$
since $C$ and $c_2$ are order-one constants and $\eta \to 0$. In this case, the largest contribution from the nonzero eigenvalues $\lambda_i$ will come from $\kappa_i^{+}$; in particular, we find $\kappa_i^{+} \approx 1 - \frac{\lambda_i}{C}\eta^{1-\gamma}$. For $\gamma < 1/2$, we have the hierarchy $\eta^{\gamma} > \eta^{2\gamma} > \eta > \eta^{\gamma+1} > \eta^2$. This allows us to simplify the upper bound to $\rho_1 = 1 - \frac{c_1}{C}\eta^{1-\gamma} + o(\eta^{1-\gamma})$. Next, we find a lower bound for $\rho_2$; this is controlled by $\kappa_i^{-}$.
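Lemma E.1 can be verified numerically by evaluating $|\kappa_i^{\pm}|$ under the scaling $\beta = 1 - C\eta^{\gamma}$ in the two regimes. The sketch below uses illustrative choices of the nonzero Hessian eigenvalues $\lambda_i$, $C$, and $\eta$; it checks $\rho_1 = \sqrt\beta$ for $\gamma > 1/2$ and $\rho_1 \approx 1 - (c_1/C)\eta^{1-\gamma}$ for $\gamma < 1/2$.

```python
import numpy as np

def rho1(eta, gamma, C, lams=(0.5, 1.0, 2.0)):
    """Largest |kappa_i+-| over the nonzero Hessian eigenvalues, with beta = 1 - C eta^gamma."""
    beta = 1 - C * eta**gamma
    A = (1 + beta - eta * np.asarray(lams)).astype(complex)
    disc = np.sqrt(A**2 - 4 * beta)
    return np.max(np.abs(np.concatenate([(A + disc) / 2, (A - disc) / 2])))

eta, C, c1 = 1e-4, 1.0, 0.5                 # c1 = smallest nonzero Hessian eigenvalue
# gamma > 1/2: oscillatory regime, rho1 = sqrt(beta)
assert np.isclose(rho1(eta, 0.75, C), np.sqrt(1 - C * eta**0.75))
# gamma < 1/2: overdamped regime, rho1 ~ 1 - (c1/C) eta^(1 - gamma)
assert np.isclose(rho1(eta, 0.25, C), 1 - (c1 / C) * eta**0.75, rtol=1e-3)
```

The crossover of the two rates at $\gamma = 1/2$ is visible here as well: decreasing $\gamma$ below $1/2$ switches the slowest transverse mode from the oscillatory envelope $\sqrt\beta$ to the real overdamped eigenvalue.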

