FLATTER, FASTER: SCALING MOMENTUM FOR OPTIMAL SPEEDUP OF SGD

Abstract

Commonly used optimization algorithms often show a trade-off between good generalization and fast training times. For instance, stochastic gradient descent (SGD) tends to have good generalization, whereas adaptive gradient methods have superior training times. Momentum can help accelerate training with SGD, but so far there has been no principled way to select the momentum hyperparameter. Here we study the training dynamics arising from the interplay between SGD with label noise and momentum in the training of overparametrized neural networks. We find that scaling the momentum hyperparameter 1 − β with the learning rate to the power of 2/3 maximally accelerates training, without sacrificing generalization. To derive this result analytically, we develop an architecture-independent framework, whose main assumption is the existence of a degenerate manifold of global minimizers, as is natural in overparametrized models. Training dynamics display the emergence of two characteristic timescales that are well separated for generic values of the hyperparameters. The maximum acceleration of training is reached when these two timescales meet, which in turn determines the scaling limit we propose. We confirm our scaling rule on synthetic regression problems (matrix sensing and the teacher-student paradigm) and on classification with realistic datasets (ResNet-18 on CIFAR10, 6-layer MLP on FashionMNIST), suggesting the robustness of our scaling rule to variations in architecture and dataset.

1. INTRODUCTION

The modern paradigm for optimization of deep neural networks has engineers working with vastly overparametrized models and training to near-perfect accuracy (Zhang et al., 2017). In this setting, a model will typically have not just isolated minima in parameter space, but a continuous set of minimizers, not all of which generalize well. Liu et al. (2020) demonstrate that, depending on parameter initialization and hyperparameters, stochastic gradient descent (SGD) is capable of finding minima with wildly different test accuracies. Thus, the power of a particular optimization method lies in its ability to select a minimum that generalizes amongst this vast set. In other words, good generalization relies on the implicit bias or regularization of an optimization algorithm. There is a significant body of evidence that training deep nets with SGD leads to good generalization. Intuitively, SGD appears to prefer flatter minima (Keskar et al., 2017; Wu et al., 2018; Xie et al., 2020), and flatter minima generalize better (Hochreiter & Schmidhuber, 1997). More recently, a variant of SGD which introduces "algorithmic" label noise has proven especially amenable to rigorous treatment. In the overparametrized setting, Blanc et al. (2020) rigorously determined that SGD with label noise converges not just to any minimum, but to those minima that lead to the smallest trace norm of the Hessian. However, Li et al. (2022) show that the dynamics of this regularization happen on a timescale proportional to the inverse square of the learning rate η, much slower than the time to first converge to an interpolating solution. Therefore we consider the regime where the weights remain near the minimizers, which is responsible for significant regularization after the initial convergence of the train loss (Blanc et al., 2020). With the recent explosion in the size of both models and datasets, training time has become an important consideration in addition to asymptotic generalization error.
In this context, adaptive gradient methods such as Adam (Kingma & Ba, 2015) are widely preferred over variants of SGD, even though they often yield worse generalization errors in practical settings (Keskar & Socher, 2017; Wilson et al., 2017), though extensive hyperparameter tuning (Choi et al., 2019) or scheduling (Xie et al., 2022) can potentially obviate this problem. These two constraints motivate a careful analysis of how momentum accelerates SGD. Classic work on acceleration methods, which we refer to generally as momentum, has found a provable benefit in the deterministic setting, where gradient updates have no error. However, rigorous guarantees have been harder to find in the stochastic setting, and remain limited by strict conditions on the noise (Polyak, 1987; Kidambi et al., 2018) or on the model class and dataset structure (Lee et al., 2022). In this work, we show that there exists a scaling limit for SGD with momentum (SGDM) which provably increases the rate of convergence.

Notation.

In what follows, we denote by C^n, for n = 0, 1, . . ., the set of functions with continuous n-th derivatives. For any function f, ∂f[u] and ∂²f[u, v] will denote directional first and second derivatives along directions defined by vectors u, v ∈ R^D. We may occasionally also write ∂²f[Σ] = Σ_{i,j=1}^D ∂²f[e_i, e_j] Σ_{ij}. Given a submanifold Γ ⊂ R^D and w ∈ Γ, we denote by T_w Γ the tangent space to Γ at w, and by P_L(w) the projector onto T_w Γ (we will often omit the dependence on w and simply write P_L). Given a matrix H ∈ R^{D×D}, we will denote by H^⊤ its transpose, and by H† its pseudoinverse.
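As a numerical illustration of this notation, the directional derivatives ∂f[u] and ∂²f[u, v] can be approximated by finite differences, and the pseudoinverse H† computed directly; a minimal sketch, where the quadratic test function and the vectors are arbitrary examples, not objects from the paper:

```python
import numpy as np

def directional_grad(f, w, u, h=1e-5):
    # Central finite-difference estimate of the directional derivative ∂f[u] at w.
    return (f(w + h * u) - f(w - h * u)) / (2 * h)

def directional_hess(f, w, u, v, h=1e-4):
    # Second directional derivative ∂²f[u, v] via nested central differences.
    g = lambda x: directional_grad(f, x, u, h)
    return (g(w + h * v) - g(w - h * v)) / (2 * h)

# Example: f(w) = w·w, so ∂f[u] = 2 w·u and ∂²f[u, v] = 2 u·v.
f = lambda w: w @ w
w = np.array([1.0, 2.0])
u = np.array([1.0, 0.0])
v = np.array([0.0, 1.0])
print(directional_grad(f, w, u))     # ≈ 2 w·u = 2.0
print(directional_hess(f, w, u, v))  # ≈ 2 u·v = 0.0

# Pseudoinverse H† of a singular Hessian: degenerate directions are annihilated.
H = np.array([[2.0, 0.0], [0.0, 0.0]])
print(np.linalg.pinv(H))
```

The singular Hessian mimics the situation along a manifold of minimizers: curvature is nonzero only transverse to the manifold, and H† acts only on those transverse directions.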

1.1. HEURISTIC EXPLANATION FOR OPTIMAL MOMENTUM-BASED SPEEDUP

Deep neural networks typically possess a manifold of parametrizations with zero training error. Because the gradients of the loss function vanish along this manifold, the dynamics of the weights are completely frozen under gradient descent. However, as appreciated by Wei & Schwab (2019) and Blanc et al. (2020), noise can generate an average drift of the weights along the manifold. In particular, SGD noise can drive the weights to a lower-curvature region, which heuristically explains the good generalization properties of SGD. Separately, it is well known that adding momentum typically leads to acceleration in training (Sutskever et al., 2013). Below, we will see that there is a nontrivial interplay between the drift induced by noise and momentum, and find that acceleration along the valley is maximized by a particular hyperparameter choice in the limit of small learning rate. In this section we illustrate the main intuition leading to this prediction using heuristic arguments, and defer a more complete discussion to Sec. 3. We model momentum SGD with label noise using the following formulation:

π_{k+1} = β π_k − ∇L(w_k) + ε σ(w_k) ξ_k,    w_{k+1} = w_k + η π_{k+1},    (1)

where η is the learning rate, β is the (heavy-ball) momentum hyperparameter (as introduced by Polyak (1964)), w ∈ R^D denotes the weights and π ∈ R^D denotes the momentum (also called the auxiliary variable). L : R^D → R is the training loss, and σ : R^D → R^{D×r} is the noise function, whose dependence on the weights allows us to model the gradient noise due to SGD and label noise. Specifically, this admits modeling phenomena such as automatic variance reduction (Liu & Belkin, 2020) and expected smoothness, satisfying only very general assumptions such as those developed by Khaled & Richtárik (2020). Finally, ξ_k ∈ R^r is sampled i.i.d. at every timestep k from a distribution with zero mean and unit variance, and ε > 0 sets the noise strength.
We will now present a heuristic description of the drift dynamics that is induced by the noise along a manifold of minimizers Γ = {w : L(w) = 0} ⊆ R^D, in the limit ε → 0. In practice, this limit corresponds to choosing a small strength of label noise and a large minibatch size. Let us assume that the weights at initial time w_0 are already close to Γ, and π_0 = 0. Because of this, the gradients of L at w_0 are very small, and so only fluctuations transverse to the manifold will generate systematic drifts. Denoting by δw_k = w_k − w_0 the displacement of the weights after k timesteps, let us Taylor expand the first equation in (1) to get

π_{k+1} = β π_k − ∇²L(w_0)[δw_k] + ε σ(w_0) ξ_k.

By construction, the Hessian ∇²L(w_0) vanishes along the directions tangent to Γ, while in the transverse directions we have an Ornstein-Uhlenbeck (OU) process. The number of timesteps it takes for this process to relax to the stationary state, or mixing time, is τ_1 = Θ(1/(1 − β)) as β → 1, which can be anticipated from the first equation in (1) since π_{k+1} − π_k ∼ −(1 − β) π_k (see Sec. 3.3 for a more detailed derivation). After this time, the variance of this linearized OU process becomes independent of the time step, and can be estimated to be (see Appendix F)

⟨ (δw_k^T)^⊤ δw_k^T ⟩ = Θ( ε² η / (1 − β) ),

where δw_k^T denotes the transverse component of the displacement and ⟨· · ·⟩ denotes the noise average. To keep track of the displacements in the longitudinal directions, δw_k^L, we need to look at the cubic order in

