FLATTER, FASTER: SCALING MOMENTUM FOR OPTIMAL SPEEDUP OF SGD

Abstract

Commonly used optimization algorithms often exhibit a trade-off between good generalization and fast training: stochastic gradient descent (SGD) tends to generalize well, whereas adaptive gradient methods train faster. Momentum can help accelerate training with SGD, but so far there has been no principled way to select the momentum hyperparameter. Here we study the training dynamics arising from the interplay between SGD with label noise and momentum in the training of overparametrized neural networks. We find that scaling the momentum hyperparameter 1 - β with the learning rate to the power of 2/3 maximally accelerates training without sacrificing generalization. To derive this result analytically, we develop an architecture-independent framework whose main assumption is the existence of a degenerate manifold of global minimizers, as is natural in overparametrized models. The training dynamics display two characteristic timescales that are well separated for generic values of the hyperparameters. The maximum acceleration of training is reached when these two timescales meet, which in turn determines the scaling limit we propose. We confirm our scaling rule on synthetic regression problems (matrix sensing and the teacher-student paradigm) and on classification with realistic datasets (ResNet-18 on CIFAR10, a 6-layer MLP on FashionMNIST), suggesting that our scaling rule is robust to variations in architecture and dataset.

1. INTRODUCTION

The modern paradigm for optimizing deep neural networks has engineers working with vastly overparametrized models and training to near-perfect accuracy (Zhang et al., 2017). In this setting, a model typically has not just isolated minima in parameter space, but a continuous set of minimizers, not all of which generalize well. Liu et al. (2020) demonstrate that, depending on parameter initialization and hyperparameters, stochastic gradient descent (SGD) can find minima with wildly different test accuracies. Thus, the power of a particular optimization method lies in its ability to select a minimum that generalizes among this vast set. In other words, good generalization relies on the implicit bias or regularization of the optimization algorithm.

There is a significant body of evidence that training deep nets with SGD leads to good generalization. Intuitively, SGD appears to prefer flatter minima (Keskar et al., 2017; Wu et al., 2018; Xie et al., 2020), and flatter minima generalize better (Hochreiter & Schmidhuber, 1997). More recently, a variant of SGD that introduces "algorithmic" label noise has proven especially amenable to rigorous treatment. In the overparametrized setting, Blanc et al. (2020) rigorously showed that SGD with label noise converges not just to any minimum, but to those minima with the smallest trace norm of the Hessian. However, Li et al. (2022) show that the dynamics of this regularization occur on a timescale proportional to the inverse square of the learning rate η, much slower than the time to first converge to an interpolating solution. We therefore consider the regime in which training remains near the manifold of minimizers, where this slow drift is responsible for significant regularization after the initial convergence of the train loss (Blanc et al., 2020).

With the recent explosion in the size of both models and datasets, training time has become an important consideration in addition to asymptotic generalization error.
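To make the setup concrete, the following is a minimal sketch of the training procedure under study: heavy-ball SGD with algorithmic label noise on an overparametrized least-squares problem, with the momentum hyperparameter set by the proposed scaling 1 - β = η^(2/3). The problem dimensions, noise level, and step count here are illustrative choices, not values from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

# Tiny overparametrized regression: more parameters than samples,
# so a continuum of interpolating minimizers exists.
n_samples, n_params = 20, 100
X = rng.standard_normal((n_samples, n_params)) / np.sqrt(n_params)
y = rng.standard_normal(n_samples)

lr = 0.05
beta = 1.0 - lr ** (2.0 / 3.0)  # proposed scaling: 1 - beta = eta^(2/3)
label_noise = 0.1               # std of the "algorithmic" label noise (illustrative)

w = np.zeros(n_params)
v = np.zeros(n_params)
for step in range(2000):
    # Fresh label noise is resampled at every step.
    y_noisy = y + label_noise * rng.standard_normal(n_samples)
    grad = X.T @ (X @ w - y_noisy) / n_samples
    v = beta * v + grad         # heavy-ball momentum buffer
    w = w - lr * v

train_loss = 0.5 * np.mean((X @ w - y) ** 2)
```

After convergence of the train loss, it is the slow noise-driven drift along the manifold of interpolating solutions that performs the implicit regularization; the scaling of β above is what the paper argues maximally accelerates that drift.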
In this context, adaptive gradient methods such as Adam (Kingma & Ba, 2015) are widely preferred over variants of SGD, even though they often yield worse generalization error in practical settings (Keskar & Socher, 2017;

