ON THE ORIGIN OF IMPLICIT REGULARIZATION IN STOCHASTIC GRADIENT DESCENT

Abstract

For infinitesimal learning rates, stochastic gradient descent (SGD) follows the path of gradient flow on the full batch loss function. However, moderately large learning rates can achieve higher test accuracies, and this generalization benefit is not explained by convergence bounds, since the learning rate which maximizes test accuracy is often larger than the learning rate which minimizes training loss. To interpret this phenomenon we prove that for SGD with random shuffling, the mean SGD iterate also stays close to the path of gradient flow if the learning rate is small and finite, but on a modified loss. This modified loss is composed of the original loss function and an implicit regularizer, which penalizes the norms of the minibatch gradients. Under mild assumptions, when the batch size is small the scale of the implicit regularization term is proportional to the ratio of the learning rate to the batch size. We verify empirically that explicitly including the implicit regularizer in the loss can enhance the test accuracy when the learning rate is small.

1. INTRODUCTION

In the limit of vanishing learning rates, stochastic gradient descent with minibatch gradients (SGD) follows the path of gradient flow on the full batch loss function (Yaida, 2019). However, in deep networks, SGD often achieves higher test accuracies when the learning rate is moderately large (LeCun et al., 2012; Keskar et al., 2017). This generalization benefit is not explained by convergence rate bounds (Ma et al., 2018; Zhang et al., 2019), because it arises even for large compute budgets, for which smaller learning rates often achieve lower training losses (Smith et al., 2020). Although many authors have studied this phenomenon (Jastrzębski et al., 2018; Smith & Le, 2018; Chaudhari & Soatto, 2018; Shallue et al., 2018; Park et al., 2019; Li et al., 2019; Lewkowycz et al., 2020), it remains poorly understood, and it is an important open question in the theory of deep learning.

In a recent work, Barrett & Dherin (2021) analyzed the influence of finite learning rates on the iterates of gradient descent (GD). Their approach is inspired by backward error analysis, a method for the numerical analysis of ordinary differential equation (ODE) solvers (Hairer et al., 2006). The key insight of backward error analysis is that we can describe the bias introduced when integrating an ODE with finite step sizes by introducing an ancillary modified flow. This modified flow is derived to ensure that the discrete iterates of the numerical method lie on the path of the continuous solution to the modified flow. Using this technique, the authors show that if the learning rate $\epsilon$ is not too large, the discrete iterates of GD lie close to the path of gradient flow on a modified loss $\widetilde{C}_{GD}(\omega) = C(\omega) + (\epsilon/4) \| \nabla C(\omega) \|^2$. This modified loss is composed of the original loss $C(\omega)$ and an implicit regularizer, proportional to the learning rate, which penalizes the Euclidean norm of the gradient.
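The backward error analysis claim above can be checked numerically on a toy problem (our own illustration, not an experiment from the paper). For the quadratic loss $C(\omega) = \omega^2/2$, one GD step maps $\omega \to \omega(1 - \epsilon)$, while gradient flow on the modified loss $\widetilde{C}_{GD}(\omega) = (1/2)(1 + \epsilon/2)\omega^2$ has the exact solution $\omega(t) = \omega_0 e^{-(1 + \epsilon/2)t}$. Evaluated at $t = \epsilon$, the modified flow matches the discrete GD iterate to a higher order in $\epsilon$ than the original flow does:

```python
import numpy as np

# Toy check of backward error analysis on C(w) = w^2 / 2 (illustrative setup).
eps, w0 = 0.1, 1.0

# One discrete GD step: w <- w - eps * C'(w) = w * (1 - eps).
gd_step = w0 * (1.0 - eps)

# Gradient flow on the original loss, evaluated at t = eps: w0 * exp(-eps).
flow_original = w0 * np.exp(-eps)

# Gradient flow on the modified loss C_GD(w) = (1/2) * (1 + eps/2) * w^2,
# evaluated at t = eps: w0 * exp(-(1 + eps/2) * eps).
flow_modified = w0 * np.exp(-(1.0 + eps / 2.0) * eps)

# The modified flow tracks the GD iterate to O(eps^3) per step,
# while the original flow only tracks it to O(eps^2).
print(abs(gd_step - flow_original))   # O(eps^2) discrepancy
print(abs(gd_step - flow_modified))   # O(eps^3) discrepancy
```

Expanding $e^{-\epsilon - \epsilon^2/2} = 1 - \epsilon + O(\epsilon^3)$ shows why the $O(\epsilon^2)$ error term cancels on the modified flow.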
However these results only hold for full batch GD, while in practice SGD with small or moderately large batch sizes usually achieves higher test accuracies (Keskar et al., 2017; Smith et al., 2020) . In this work, we devise an alternative approach to backward error analysis, which accounts for the correlations between minibatches during one epoch of training. Using this novel approach, we prove that for small finite learning rates, the mean SGD iterate after one epoch, averaged over all possible sequences of minibatches, lies close to the path of gradient flow on a second modified loss C SGD (ω), which we define in equation 1. This new modified loss is also composed of the full batch loss function and an implicit regularizer, however the structure of the implicit regularizers for GD and SGD differ, and their modified losses can have different local and global minima. Our analysis therefore helps explain both why finite learning rates can aid generalization, and why SGD can achieve higher test accuracies than GD. We assume that each training example is sampled once per epoch, in line with best practice (Bottou, 2012), and we confirm empirically that explicitly including the implicit regularization term of SGD in the training loss can enhance the test accuracy when the learning rate is small. Furthermore, we prove that if the batch size is small and the gradients are sufficiently diverse, then the expected magnitude of the implicit regularization term of SGD is proportional to the ratio of the learning rate to the batch size (Goyal et al., 2017; Smith et al., 2018) . We note that many previous authors have sought to explain the generalization benefit of SGD using an analogy between SGD and stochastic differential equations (SDEs) (Mandt et al., 2017; Smith & Le, 2018; Jastrzębski et al., 2018; Chaudhari & Soatto, 2018) . 
However, this SDE analogy assumes that each minibatch is randomly sampled from the full dataset, which implies that some examples will be sampled multiple times in one epoch. Furthermore, the most common SDE analogy holds only for vanishing learning rates (Yaida, 2019), and therefore misses the generalization benefits of finite learning rates which we identify in this work. An important exception is Li et al. (2017), who applied backward error analysis to identify a modified SDE which holds when the learning rate is finite. However, this work still relies on the assumption that minibatches are sampled randomly. It also focused on the convergence rate, and did not discuss the performance of SGD on the test set.

Main Result. We now introduce our main result. We define the cost function over the parameters $\omega$ as $C(\omega) = (1/N) \sum_{j=1}^{N} C_j(\omega)$, which is the mean of the per-example costs $C_j(\omega)$, where $N$ denotes the training set size. Gradient flow follows the ODE $\dot{\omega} = -\nabla C(\omega)$, while gradient descent computes discrete updates $\omega_{i+1} = \omega_i - \epsilon \nabla C(\omega_i)$, where $\epsilon$ is the learning rate. For simplicity, we assume that the batch size $B$ perfectly splits the training set, such that $N \,\%\, B = 0$, where $\%$ denotes the modulo operation, and for convenience we define the number of batches per epoch $m = N/B$. We can therefore re-write the cost function as a sum over minibatches, $C(\omega) = (1/m) \sum_{k=0}^{m-1} \hat{C}_k(\omega)$, where the minibatch cost $\hat{C}_k(\omega) = (1/B) \sum_{j=kB+1}^{kB+B} C_j(\omega)$. In order to guarantee that we sample each example precisely once per epoch, we define SGD by the discrete update $\omega_{i+1} = \omega_i - \epsilon \nabla \hat{C}_{i \,\%\, m}(\omega_i)$. Informally, our main result is as follows. After one epoch, the mean iterate of SGD with a small but finite learning rate $\epsilon$, averaged over all possible shuffles of the batch indices, stays close to the path of gradient flow on a modified loss, $\dot{\omega} = -\nabla \widetilde{C}_{SGD}(\omega)$, where the modified loss $\widetilde{C}_{SGD}$ is given by:
$$\widetilde{C}_{SGD}(\omega) = C(\omega) + \frac{\epsilon}{4m} \sum_{k=0}^{m-1} \| \nabla \hat{C}_k(\omega) \|^2. \qquad (1)$$
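The quantities in Equation 1 can be computed directly on a small model. The sketch below evaluates the modified loss for a linear least-squares problem (a hypothetical setup of our own; the variable names and data are illustrative), using per-example gradients grouped into $m = N/B$ minibatches:

```python
import numpy as np

# Illustrative linear least-squares setup: C_j(w) = (1/2) * (x_j . w - y_j)^2,
# so grad C_j(w) = (x_j . w - y_j) * x_j. Names and data are hypothetical.
rng = np.random.default_rng(0)
N, B, d, eps = 12, 4, 3, 0.1
m = N // B                                     # batches per epoch; N % B == 0
X = rng.normal(size=(N, d))
y = rng.normal(size=N)
w = rng.normal(size=d)

residuals = X @ w - y
per_example_grads = residuals[:, None] * X     # shape (N, d)

# Minibatch gradients: grad C_hat_k(w), the mean over each block of B examples.
batch_grads = per_example_grads.reshape(m, B, d).mean(axis=1)

# Implicit regularizer and modified loss of Equation 1.
C_reg = (1.0 / (4.0 * m)) * np.sum(np.linalg.norm(batch_grads, axis=1) ** 2)
C = 0.5 * np.mean(residuals ** 2)              # full batch loss C(w)
C_sgd = C + eps * C_reg                        # modified loss, Equation 1
```

Since the regularizer is a sum of squared norms, $\widetilde{C}_{SGD}(\omega) \geq C(\omega)$ at every point, with equality only where every minibatch gradient vanishes.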
We emphasize that our analysis studies the mean evolution of SGD, not the path of individual trajectories. The modified loss $\widetilde{C}_{SGD}(\omega)$ is composed of the original loss $C(\omega)$ and an implicit regularizer $C_{reg}(\omega) = (1/4m) \sum_{k=0}^{m-1} \| \nabla \hat{C}_k(\omega) \|^2$. The scale of this implicit regularization term is proportional to the learning rate $\epsilon$, and it penalizes the mean squared norm of the gradient evaluated on a batch of $B$ examples. To help us compare the modified losses of GD and SGD, we can expand:
$$\widetilde{C}_{SGD}(\omega) = C(\omega) + \frac{\epsilon}{4} \| \nabla C(\omega) \|^2 + \frac{\epsilon}{4m} \sum_{i=0}^{m-1} \| \nabla \hat{C}_i(\omega) - \nabla C(\omega) \|^2. \qquad (2)$$
We arrive at Equation 2 from Equation 1 by noting that $\sum_{i=0}^{m-1} ( \nabla \hat{C}_i(\omega) - \nabla C(\omega) ) = 0$, so the cross terms vanish when we expand each squared norm. In the limit $B \to N$, we recover the modified loss of gradient descent, $\widetilde{C}_{GD}(\omega) = C(\omega) + (\epsilon/4) \| \nabla C(\omega) \|^2$, which penalizes "sharp" regions where the norm of the full-batch gradient ($\| \nabla C(\omega) \|^2$) is large. However, as shown by Equation 2, the modified loss of SGD penalizes both sharp regions where the full-batch gradient is large, and also "non-uniform" regions where the norms of the errors in the minibatch gradients ($\| \nabla \hat{C}_i(\omega) - \nabla C(\omega) \|^2$) are large (Wu et al., 2018). Although global minima of $C(\omega)$ are also global minima of $\widetilde{C}_{GD}(\omega)$, global minima of $C(\omega)$ may not be global (or even local) minima of $\widetilde{C}_{SGD}(\omega)$. Note, however, that $C(\omega)$ and $\widetilde{C}_{SGD}(\omega)$ do share the same global minima on over-parameterized models which can interpolate the training set (Ma et al., 2018). We verify in our experiments that the implicit regularizer can enhance the test accuracy of models trained with SGD.

Paper structure. In Section 2, we derive our main result (Equation 1), and we confirm empirically that we can close the generalization gap between small and large learning rates by including the implicit regularizer explicitly in the loss function. In Section 3, we confirm that Equation 1 satisfies the linear scaling rule between learning rate and batch size (Goyal et al., 2017).
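The identity that takes Equation 1 to Equation 2 is the usual decomposition of a mean squared norm into a squared mean plus a variance-like term. A quick numeric check (our own illustration, with random stand-in gradients):

```python
import numpy as np

# Check the identity behind Equation 2: for minibatch gradients g_k with
# full-batch gradient g_bar = (1/m) * sum_k g_k,
#   (1/m) * sum_k ||g_k||^2 = ||g_bar||^2 + (1/m) * sum_k ||g_k - g_bar||^2.
# This splits the SGD regularizer into "sharpness" and "non-uniformity" terms.
rng = np.random.default_rng(1)
m, d = 8, 5
g = rng.normal(size=(m, d))                    # stand-in minibatch gradients
g_bar = g.mean(axis=0)                         # stand-in full-batch gradient

lhs = np.mean(np.sum(g ** 2, axis=1))          # mean squared minibatch norm
rhs = np.sum(g_bar ** 2) + np.mean(np.sum((g - g_bar) ** 2, axis=1))
print(np.allclose(lhs, rhs))                   # prints True: identity holds
```

The cross terms cancel exactly because the deviations $g_k - \bar{g}$ sum to zero, which is why the identity holds for any set of vectors, not just gradients.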
In Section 4, we provide additional experiments which challenge the prevailing view that the generalization benefit of small batch SGD arises from the temperature of an associated SDE (Mandt et al., 2017; Park et al., 2019) .

