A VIEW OF MINI-BATCH SGD VIA GENERATING FUNCTIONS: CONDITIONS OF CONVERGENCE, PHASE TRANSITIONS, BENEFIT FROM NEGATIVE MOMENTA

Abstract

Mini-batch SGD with momentum is a fundamental algorithm for learning large predictive models. In this paper we develop a new analytic framework to analyze noise-averaged properties of mini-batch SGD for linear models at constant learning rates, momenta and batch sizes. Our key idea is to consider the dynamics of the second moments of model parameters for a special family of "Spectrally Expressible" approximations. This allows us to obtain an explicit expression for the generating function of the sequence of loss values. By analyzing this generating function, we find, in particular, that 1) the SGD dynamics exhibits several convergent and divergent regimes depending on the spectral distributions of the problem; 2) the convergent regimes admit explicit stability conditions, and explicit loss asymptotics in the case of power-law spectral distributions; 3) the optimal convergence rate can be achieved at negative momenta. We verify our theoretical predictions by extensive experiments with MNIST, CIFAR10 and synthetic problems, and find a good quantitative agreement.

1. INTRODUCTION

We consider the classical mini-batch Stochastic Gradient Descent (SGD) algorithm (Robbins & Monro, 1951; Bottou & Bousquet, 2007) with momentum (Polyak, 1964): $w_{t+1} = w_t + v_{t+1}$, $v_{t+1} = -\alpha_t \nabla_w L_{B_t}(w_t) + \beta_t v_t$. Here, $L_B(w) = \frac{1}{b} \sum_{i=1}^{b} l(f(x_i, w), y_i)$ is the sampled loss of a model $\hat{y} = f(x, w)$, computed using a pointwise loss $l(\hat{y}, y)$ on a mini-batch $B = \{(x_i, y_i)\}_{i=1}^{b}$ of $b$ data points representing the target function $y = f^*(x)$. The momentum term $v_t$ carries information about gradients from previous iterations and is well known to significantly improve convergence both in general (Polyak, 1987) and for neural networks (Sutskever et al., 2013). Re-sampling the mini-batch $B_t$ at each SGD iteration $t$ creates a specific gradient noise, structured according to both the local geometry of the model $f(x, w)$ and the quality of the current approximation $\hat{y}$. In the context of modern deep learning, $f(x, w)$ is usually very complex, and the quantitative prediction of SGD behavior is a challenging task that remains far from complete. Our goal is to obtain explicit expressions characterizing the average-case convergence of mini-batch SGD for the classical least-squares problem of minimizing a quadratic objective $L(w)$. This setup is directly related to modern neural networks trained with a quadratic loss function, since networks can often be well described -- e.g., in the large-width limit (Jacot et al., 2018; Lee et al., 2019) or during the late stage of training (Fort et al., 2020) -- by their linearization w.r.t. the parameters $w$.
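For concreteness, the update rule above can be sketched on a synthetic noiseless least-squares problem. This is only an illustrative NumPy implementation, not the paper's analytic framework; all names and hyperparameter values (`lr`, `momentum`, `batch_size`, the problem dimensions) are arbitrary choices for the sketch.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic realizable least-squares problem: y = X @ w_star (no label noise).
n, d = 512, 16
X = rng.standard_normal((n, d))
w_star = rng.standard_normal(d)
y = X @ w_star

def full_loss(w):
    """Quadratic objective L(w): mean squared residual over the full dataset."""
    r = X @ w - y
    return 0.5 * np.mean(r ** 2)

def sgd_momentum(lr=0.1, momentum=0.5, batch_size=32, steps=500):
    """Heavy-ball updates with a freshly resampled mini-batch at each step:
    v_{t+1} = -lr * grad_B(w_t) + momentum * v_t,  w_{t+1} = w_t + v_{t+1}."""
    w = np.zeros(d)
    v = np.zeros(d)
    for _ in range(steps):
        idx = rng.choice(n, size=batch_size, replace=False)
        Xb, yb = X[idx], y[idx]
        grad = Xb.T @ (Xb @ w - yb) / batch_size  # gradient of the sampled loss L_B
        v = -lr * grad + momentum * v
        w = w + v
    return w

w_final = sgd_momentum()
print(full_loss(np.zeros(d)), full_loss(w_final))
```

Because the mini-batch is resampled at every iteration, the gradient is a noisy estimate of $\nabla L(w)$; in this realizable setting the noise vanishes at the optimum, so the iterates converge to $w^*$ for stable choices of learning rate and momentum.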

