SELF-STABILIZATION: THE IMPLICIT BIAS OF GRADIENT DESCENT AT THE EDGE OF STABILITY

Abstract

Traditional analyses of gradient descent show that when the largest eigenvalue of the Hessian, also known as the sharpness S(θ), is bounded by 2/η, training is "stable" and the training loss decreases monotonically. Recent works, however, have observed that this assumption does not hold when training modern neural networks with full batch or large batch gradient descent. Most recently, Cohen et al. (2021) detailed two important phenomena. The first, dubbed progressive sharpening, is that the sharpness steadily increases throughout training until it reaches the instability cutoff 2/η. The second, dubbed edge of stability, is that the sharpness hovers at 2/η for the remainder of training while the loss continues decreasing, albeit non-monotonically. We demonstrate that, far from being chaotic, the dynamics of gradient descent at the edge of stability can be captured by a cubic Taylor expansion: as the iterates diverge in the direction of the top eigenvector of the Hessian due to instability, the cubic term in the local Taylor expansion of the loss function causes the curvature to decrease until stability is restored. This property, which we call self-stabilization, is a general property of gradient descent and explains its behavior at the edge of stability. A key consequence of self-stabilization is that gradient descent at the edge of stability implicitly follows projected gradient descent (PGD) under the constraint S(θ) ≤ 2/η. Our analysis provides precise predictions for the loss, sharpness, and deviation from the PGD trajectory throughout training, which we verify both empirically in a number of standard settings and theoretically under mild conditions. Our analysis uncovers the mechanism for gradient descent's implicit bias towards stability.

1. INTRODUCTION

1.1 GRADIENT DESCENT AT THE EDGE OF STABILITY

Almost all neural networks are trained using a variant of gradient descent, most commonly stochastic gradient descent (SGD) or Adam (Kingma & Ba, 2015). When deciding on an initial learning rate, many practitioners rely on intuition drawn from classical optimization. In particular, the following classical lemma, known as the "descent lemma," provides a common heuristic for choosing a learning rate in terms of the sharpness of the loss function:

Definition 1. Given a loss function L(θ), the sharpness at θ is defined to be S(θ) := λ_max(∇²L(θ)). When this eigenvalue is unique, the associated eigenvector is denoted by u(θ).

Lemma 1 (Descent Lemma). Assume that S(θ) ≤ ℓ for all θ. If θ_{t+1} = θ_t − η∇L(θ_t), then

    L(θ_{t+1}) ≤ L(θ_t) − (η(2 − ηℓ)/2) ∥∇L(θ_t)∥².

Here, the loss decrease is proportional to the squared gradient norm and is controlled by the quadratic η(2 − ηℓ) in η. This quadratic is maximized at η = 1/ℓ, a popular learning rate criterion. For any η < 2/ℓ, the descent lemma guarantees that the loss will decrease. As a result, learning rates below 2/ℓ are considered "stable" while those above 2/ℓ are considered "unstable."

¹The condition that ∇L(θ) · u(θ) = 0 is necessary to ensure the stability of the constrained trajectory.

Published as a conference paper at ICLR 2023

For quadratic loss functions, e.g. from linear regression, this is tight: any learning rate above 2/ℓ provably leads to an exponentially increasing loss. However, it has recently been observed that in neural networks, the descent lemma is not predictive of the optimization dynamics. Recently, Cohen et al. (2021) observed two important phenomena for gradient descent, which made more precise similar observations in Jastrzębski et al. (2019); Jastrzebski et al. (2020) for SGD:

Progressive Sharpening. Throughout most of the optimization trajectory, the gradient of the loss is negatively aligned with the gradient of the sharpness, i.e. ∇L(θ) · ∇S(θ) < 0. As a result, for any reasonable learning rate η, the sharpness increases throughout training until it reaches S(θ) = 2/η.
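The 2/ℓ stability threshold in the descent lemma can be illustrated numerically. The sketch below (all numerical values are our own illustrative choices) runs gradient descent on the one-dimensional quadratic L(θ) = ℓθ²/2, whose sharpness is the constant ℓ; a learning rate just below 2/ℓ drives the loss down, while one just above it causes geometric divergence.

```python
import numpy as np

def run_gd(eta, ell=10.0, steps=50, theta0=1.0):
    """Gradient descent on L(θ) = ℓθ²/2; returns the final loss."""
    theta = theta0
    for _ in range(steps):
        theta -= eta * ell * theta   # gradient of ℓθ²/2 is ℓθ
    return 0.5 * ell * theta**2

ell = 10.0                            # sharpness; stability cutoff is 2/ℓ = 0.2
stable = run_gd(eta=0.19, ell=ell)    # η < 2/ℓ: each step multiplies θ by (1 − ηℓ) = −0.9
unstable = run_gd(eta=0.21, ell=ell)  # η > 2/ℓ: each step multiplies θ by −1.1, so θ diverges
print(stable, unstable)
```

On a quadratic there is no cubic term to rescue the dynamics, which is exactly why instability is terminal here but not for neural network losses.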

Edge of Stability

Once the sharpness reaches 2/η (the "break-even" point in Jastrzebski et al. (2020)), it ceases to increase and remains around 2/η for the rest of training. The descent lemma no longer guarantees that the loss decreases, but the loss still continues decreasing, albeit non-monotonically.

Traditional non-convex optimization analyses involve Taylor expanding the loss function to second order around θ to prove loss decrease when η ≤ 2/S(θ). When this condition is violated, the iterates diverge exponentially in the top eigenvector direction u, thus leaving the region in which the loss function is locally quadratic. Understanding the dynamics therefore necessitates a cubic Taylor expansion. Our key insight is that the missing term in the Taylor expansion of the gradient after diverging in the u direction is ∇³L(θ)(u, u), which is conveniently equal to the gradient of the sharpness at θ:

Lemma 2 (Self-Stabilization Property). If the top eigenvalue of ∇²L(θ) is unique, then the sharpness S(θ) is differentiable at θ and ∇S(θ) = ∇³L(θ)(u(θ), u(θ)).

As the iterates move in the negative gradient direction, this term has the effect of decreasing the sharpness. The story of self-stabilization is thus that as the iterates diverge in the u direction, the strength of this movement in the −∇S(θ) direction grows until it forces the sharpness below 2/η, at which point the iterates in the u direction shrink and the dynamics re-enter the quadratic regime. This negative feedback loop prevents both the sharpness S(θ) and the movement in the top eigenvector direction u from growing out of control. As a consequence, we show that gradient descent implicitly solves the constrained minimization problem:

    min_θ L(θ)  subject to  S(θ) ≤ 2/η.    (1)

Specifically, if the stable set M is defined by M := {θ : S(θ) ≤ 2/η and ∇L(θ) · u(θ) = 0},¹ then the gradient descent trajectory {θ_t} tracks the following projected gradient descent trajectory {θ†_t}, which solves the constrained problem (Barber & Ha, 2017):

    θ†_{t+1} := proj_M(θ†_t − η∇L(θ†_t)),  where  proj_M(θ) := arg min_{θ′∈M} ∥θ − θ′∥.    (2)
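The identity in Lemma 2 can be checked numerically with finite differences. The sketch below uses a toy quartic loss of our own choosing (not from the paper): it compares a finite-difference gradient of the sharpness S(θ) against the vector whose i-th entry is ∂_i [uᵀ∇²L(θ)u] with u frozen at u(θ), i.e. ∇³L(θ)(u, u).

```python
import numpy as np

# Toy smooth loss with nonzero third derivatives (illustrative choice).
def L(t):
    x, y = t
    return x**4 + x**2 * y + 2.0 * y**2 + 0.5 * x * y

def hessian(t, h=1e-3):
    """∇²L(t) by central finite differences."""
    d = len(t)
    H = np.zeros((d, d))
    I = np.eye(d)
    for i in range(d):
        for j in range(d):
            H[i, j] = (L(t + h*I[i] + h*I[j]) - L(t + h*I[i] - h*I[j])
                       - L(t - h*I[i] + h*I[j]) + L(t - h*I[i] - h*I[j])) / (4*h*h)
    return H

def sharpness(t):
    return np.linalg.eigvalsh(hessian(t))[-1]  # top Hessian eigenvalue S(t)

theta = np.array([1.0, -0.5])  # a point where the top eigenvalue is well separated
_, V = np.linalg.eigh(hessian(theta))
u = V[:, -1]                   # top eigenvector u(θ)

h, d = 1e-3, len(theta)
I = np.eye(d)
# Left side of Lemma 2: ∇S(θ) by finite differences of the sharpness itself.
grad_S = np.array([(sharpness(theta + h*I[i]) - sharpness(theta - h*I[i])) / (2*h)
                   for i in range(d)])
# Right side: ∇³L(θ)(u, u), i.e. the gradient of θ ↦ uᵀ∇²L(θ)u with u frozen at u(θ).
cubic = np.array([(u @ hessian(theta + h*I[i]) @ u - u @ hessian(theta - h*I[i]) @ u) / (2*h)
                  for i in range(d)])
print(grad_S, cubic)
assert np.allclose(grad_S, cubic, atol=1e-3)
```

The agreement reflects the standard eigenvalue-perturbation fact behind Lemma 2: to first order, moving θ changes S(θ) only through the change of the Hessian along the frozen top eigenvector.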

