SELF-STABILIZATION: THE IMPLICIT BIAS OF GRADIENT DESCENT AT THE EDGE OF STABILITY

Abstract

Traditional analyses of gradient descent show that when the largest eigenvalue of the Hessian, also known as the sharpness S(θ), is bounded by 2/η, training is "stable" and the training loss decreases monotonically. Recent works, however, have observed that this assumption does not hold when training modern neural networks with full-batch or large-batch gradient descent. Most recently, Cohen et al. (2021) detailed two important phenomena. The first, dubbed progressive sharpening, is that the sharpness steadily increases throughout training until it reaches the instability cutoff 2/η. The second, dubbed edge of stability, is that the sharpness hovers at 2/η for the remainder of training while the loss continues to decrease, albeit non-monotonically. We demonstrate that, far from being chaotic, the dynamics of gradient descent at the edge of stability can be captured by a cubic Taylor expansion: as the iterates diverge in the direction of the top eigenvector of the Hessian due to instability, the cubic term in the local Taylor expansion of the loss causes the curvature to decrease until stability is restored. This property, which we call self-stabilization, is a general property of gradient descent and explains its behavior at the edge of stability. A key consequence of self-stabilization is that gradient descent at the edge of stability implicitly follows projected gradient descent (PGD) under the constraint S(θ) ≤ 2/η. Our analysis provides precise predictions for the loss, sharpness, and deviation from the PGD trajectory throughout training, which we verify both empirically in a number of standard settings and theoretically under mild conditions. Our analysis uncovers the mechanism for gradient descent's implicit bias towards stability.

1. INTRODUCTION

1.1 GRADIENT DESCENT AT THE EDGE OF STABILITY

Almost all neural networks are trained using a variant of gradient descent, most commonly stochastic gradient descent (SGD) or Adam (Kingma & Ba, 2015). When deciding on an initial learning rate, many practitioners rely on intuition drawn from classical optimization. In particular, the following classical result, known as the "descent lemma," provides a common heuristic for choosing a learning rate in terms of the sharpness of the loss function:

Definition 1. Given a loss function L(θ), the sharpness at θ is defined to be S(θ) := λ_max(∇²L(θ)). When this eigenvalue is unique, the associated eigenvector is denoted by u(θ).

Lemma 1 (Descent Lemma). Assume that S(θ) ≤ ℓ for all θ. If θ_{t+1} = θ_t − η∇L(θ_t), then

L(θ_{t+1}) ≤ L(θ_t) − (η(2 − ηℓ)/2) ∥∇L(θ_t)∥².

Here, the loss decrease is proportional to the squared gradient norm and is controlled by the quadratic η(2 − ηℓ) in η. This quadratic is maximized at η = 1/ℓ, a popular learning rate criterion. For any η < 2/ℓ, the descent lemma guarantees that the loss will decrease. As a result, learning rates below 2/ℓ are considered "stable" while those above 2/ℓ are considered "unstable." For quadratic loss functions, e.g. those arising from linear regression, this threshold is tight: any learning rate above 2/ℓ provably leads to exponentially increasing loss. However, it has recently been observed that for neural networks, the descent lemma is not predictive of the optimization dynamics. Recently, Cohen et al. (2021) observed two important phenomena for gradient descent, which made more precise similar observations in Jastrzębski et al. (2019); Jastrzebski et al. (2020) for SGD:

Progressive Sharpening. Throughout most of the optimization trajectory, the gradient of the loss is negatively aligned with the gradient of the sharpness, i.e. ∇L(θ) • ∇S(θ) < 0. As a result, for any reasonable learning rate η, the sharpness increases throughout training until it reaches S(θ) = 2/η.
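The stability threshold in Lemma 1 can be illustrated numerically on a small quadratic; the matrix A and the learning rates below are our own illustrative choices:

```python
import numpy as np

# Quadratic loss L(theta) = 0.5 * theta^T A theta with constant
# sharpness S = lambda_max(A) = 4, so the stability cutoff is 2/S = 0.5.
A = np.diag([4.0, 1.0])
loss = lambda th: 0.5 * th @ A @ th
grad = lambda th: A @ th

def run_gd(eta, steps=50):
    th = np.array([1.0, 1.0])
    losses = [loss(th)]
    for _ in range(steps):
        th = th - eta * grad(th)
        losses.append(loss(th))
    return losses

stable = run_gd(eta=0.25)   # eta = 1/S, the classical choice: monotone decrease
unstable = run_gd(eta=0.6)  # eta > 2/S = 0.5: the loss grows exponentially
print(all(b <= a for a, b in zip(stable, stable[1:])))  # True
print(unstable[-1] > unstable[0])                       # True
```

For a quadratic the threshold is exact; the point of the phenomena above is that for neural networks the trajectory does not simply diverge once S(θ) exceeds 2/η.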

Edge of Stability

Once the sharpness reaches 2/η (the "break-even" point in Jastrzebski et al. (2020)), it ceases to increase and remains around 2/η for the rest of training. The descent lemma no longer guarantees that the loss decreases, yet the loss continues decreasing, albeit non-monotonically.

In this work we explain the second stage, "edge of stability." We identify a new implicit bias of gradient descent which we call self-stabilization. Self-stabilization is the mechanism by which the sharpness remains bounded around 2/η, despite the continued force of progressive sharpening, and by which the gradient descent dynamics do not diverge, despite instability. Unlike progressive sharpening, which holds only for specific loss functions (e.g. those arising in neural network optimization (Cohen et al., 2021)), self-stabilization is a general property of gradient descent.

Traditional non-convex optimization analyses Taylor expand the loss function to second order around θ to prove loss decrease when η ≤ 2/S(θ). When this condition is violated, the iterates diverge exponentially in the top eigenvector direction, u, thus leaving the region in which the loss function is locally quadratic. Understanding the dynamics therefore necessitates a cubic Taylor expansion. Our key insight is that the missing term in the Taylor expansion of the gradient after diverging in the u direction is ∇³L(θ)(u, u), which is conveniently equal to the gradient of the sharpness at θ:

Lemma 2 (Self-Stabilization Property). If the top eigenvalue of ∇²L(θ) is unique, then the sharpness S(θ) is differentiable at θ and ∇S(θ) = ∇³L(θ)(u(θ), u(θ)).

As the iterates move in the negative gradient direction, this term has the effect of decreasing the sharpness.
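Lemma 2 can be sanity-checked by finite differences on a hand-built loss whose Hessian varies with θ; the specific cubic function below is our own illustration, not one from the paper:

```python
import numpy as np

# Hypothetical loss L(t1, t2) = (2 + t2) * t1^2 / 2 + 0.1 * t2^2, whose
# Hessian H(theta) = [[2 + t2, t1], [t1, 0.2]] depends on theta, so the
# sharpness S(theta) = lambda_max(H(theta)) is a non-constant function.
def hessian(th):
    t1, t2 = th
    return np.array([[2.0 + t2, t1], [t1, 0.2]])

def sharpness(th):
    return np.linalg.eigvalsh(hessian(th))[-1]   # eigvalsh is ascending

theta = np.array([0.3, 0.1])
w, V = np.linalg.eigh(hessian(theta))
u = V[:, -1]                                     # top eigenvector u(theta)

h = 1e-5
# grad S(theta) by central differences of the top eigenvalue
gradS = np.array([(sharpness(theta + h * e) - sharpness(theta - h * e)) / (2 * h)
                  for e in np.eye(2)])
# nabla^3 L(u, u) = gradient of theta -> u^T H(theta) u with u frozen
cubic = np.array([(u @ hessian(theta + h * e) @ u - u @ hessian(theta - h * e) @ u) / (2 * h)
                  for e in np.eye(2)])
print(np.allclose(gradS, cubic, atol=1e-6))      # True: Lemma 2 holds
```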
The story of self-stabilization is thus that as the iterates diverge in the u direction, the strength of this movement in the −∇S(θ) direction grows until it forces the sharpness below 2/η, at which point the iterates in the u direction shrink and the dynamics re-enter the quadratic regime. This negative feedback loop prevents both the sharpness S(θ) and the movement in the top eigenvector direction u from growing out of control. As a consequence, we show that gradient descent implicitly solves the constrained minimization problem

min_θ L(θ) such that S(θ) ≤ 2/η.

Specifically, if the stable set M is defined by

M := {θ : S(θ) ≤ 2/η and ∇L(θ) • u(θ) = 0},

then the gradient descent trajectory {θ_t} tracks the following projected gradient descent (PGD) trajectory {θ†_t}, which solves the constrained problem (Barber & Ha, 2017):

θ†_{t+1} = proj_M(θ†_t − η∇L(θ†_t)),  where proj_M(θ) := argmin_{θ′ ∈ M} ∥θ − θ′∥.   (2)

Our main contributions are as follows. First, we explain self-stabilization as a generic property of gradient descent for a large class of loss functions, and provide precise predictions for the loss, sharpness, and deviation from the constrained trajectory {θ†_t} throughout training (Section 4). Next, we prove that under mild conditions on the loss function (which we verify empirically for standard architectures and datasets), our predictions track the true gradient descent dynamics up to higher order error terms (Section 5). Finally, we verify our predictions by replicating the experiments in Cohen et al. (2021) and show that they model the true gradient descent dynamics (Section 6).

2. RELATED WORK

Xing et al. (2018) observed that for some neural networks trained by full-batch gradient descent, the loss is not monotonically decreasing. Wu et al. (2018) remarked that gradient descent cannot converge to minima where the sharpness exceeds 2/η but did not give a mechanism for avoiding such minima. Lewkowycz et al.
(2020) observed that when the initial sharpness is larger than 2/η, gradient descent "catapults" into a stable region and eventually converges. Jastrzębski et al. (2019) studied the sharpness along stochastic gradient descent trajectories and observed an initial increase (i.e. progressive sharpening) followed by a peak and eventual decrease. They also observed interesting relationships between the dynamics in the top eigenvector direction and the sharpness. Jastrzebski et al. (2020) conjectured a general characterization of stochastic gradient descent dynamics asserting that the sharpness tends to grow but cannot exceed a stability criterion given by their eq. (1), which reduces to S(θ) ≤ 2/η in the case of full-batch training. Cohen et al. (2021) demonstrated that for the special case of (full-batch) gradient descent training, the optimization dynamics exhibit a simple characterization. First, the sharpness rises until it reaches S(θ) = 2/η, at which point the dynamics transition into an "edge of stability" (EOS) regime where the sharpness oscillates around 2/η and the loss continues to decrease, albeit non-monotonically. Recent works have sought to provide theoretical analyses of the EOS phenomenon. Ma et al. (2022) analyzes EOS when the loss satisfies a "subquadratic growth" assumption. Ahn et al. (2022) argues that unstable convergence is possible when there exists a "forward invariant subset" near the set of minimizers. Arora et al. (2022) analyzes progressive sharpening and the EOS phenomenon for normalized gradient descent close to the manifold of global minimizers. Lyu et al. (2022) uses the EOS phenomenon to analyze the effect of normalization layers on sharpness for scale-invariant loss functions. Chen & Bruna (2022) show global convergence despite instability for certain 2D toy problems and in a 1-neuron student-teacher setting. The concurrent work Li et al.
(2022b) proves progressive sharpening for a two-layer network and analyzes the EOS dynamics through four stages similar to ours using the norm of the output layer as a proxy for sharpness. Beyond the EOS phenomenon itself, prior work has also shown that SGD with large step size or small batch size will lead to a decrease in sharpness (Keskar et al., 2017; Jastrzebski et al., 2017; Jastrzębski et al., 2019; Jastrzebski et al., 2020) . Gilmer et al. (2021) also describes connections between EOS, learning rate warm-up, and gradient clipping. At a high level, our proof relies on the idea that oscillations in an unstable direction prescribed by the quadratic approximation of the loss cause a longer term effect arising from the third-order Taylor expansion of the dynamics. This overall idea has also been used to analyze the implicit regularization of SGD (Blanc et al., 2020; Damian et al., 2021; Li et al., 2022a) . In those settings, oscillations come from the stochasticity, while in our setting the oscillations stem from instability.

3. SETUP

We denote the loss function by L ∈ C³(R^d). Let θ_t ∈ R^d follow gradient descent with learning rate η, i.e. θ_{t+1} := θ_t − η∇L(θ_t). Recall that M := {θ : S(θ) ≤ 2/η and ∇L(θ) • u(θ) = 0} is the set of stable points and proj_M(θ) := argmin_{θ′ ∈ M} ∥θ − θ′∥ is the orthogonal projection onto M. For notational simplicity, we shift time so that θ_0 is the first point such that S(proj_M(θ_0)) = 2/η. The constrained trajectory θ† is initialized with θ†_0 := proj_M(θ_0), after which it follows eq. (2).

Our key assumption is the existence of progressive sharpening along the constrained trajectory, which is captured by the progressive sharpening coefficient α(θ) := −∇L(θ) • ∇S(θ):

Assumption 1 (Progressive Sharpening). Let α(θ) := −∇L(θ) • ∇S(θ). Then α(θ†_t) > 0.

We focus on the regime in which there is a single unstable eigenvalue, and we leave understanding multiple unstable eigenvalues to future work. We thus make the following assumption on ∇²L(θ†_t):

Assumption 2 (Eigengap). For some absolute constant c < 2 we have λ₂(∇²L(θ†_t)) < c/η.

[Figure 1: the stages of self-stabilization in the coordinates x = u • (θ − θ⋆) and y = ∇S • (θ − θ⋆): Stage 1: Progressive Sharpening; Stage 2: Blowup; Stage 3: Self-Stabilization, where −∇³L(u, u) = −∇S.]

4. THE SELF-STABILIZATION PROPERTY OF GRADIENT DESCENT

In this section, we derive a set of equations that predict the displacement between the gradient descent trajectory {θ t } and the constrained trajectory {θ † t }. Viewed as a dynamical system, these equations give rise to a negative feedback loop, which prevents both the sharpness and the displacement in the unstable direction from diverging. These equations also allow us to predict the values of the sharpness and the loss throughout the gradient descent trajectory.

4.1. THE FOUR STAGES OF EDGE OF STABILITY: A HEURISTIC DERIVATION

The analysis in this section proceeds by a cubic Taylor expansion around a fixed reference point θ⋆ := θ†_0. For notational simplicity, we define the following quantities at θ⋆:

∇L := ∇L(θ⋆), ∇²L := ∇²L(θ⋆), u := u(θ⋆), ∇S := ∇S(θ⋆), α := α(θ⋆), β := ∥∇S∥²,

where α = −∇L • ∇S > 0 is the progressive sharpening coefficient at θ⋆. For simplicity, in Section 4 we assume that ∇S ⊥ u and ∇L, ∇S ∈ ker(∇²L), and we ignore higher order error terms. Our main argument in Section 5 does not require these assumptions and tracks all error terms explicitly.

We want to track the movement in the unstable direction u and the direction of changing sharpness ∇S, and thus define x_t := u • (θ_t − θ⋆) and y_t := ∇S • (θ_t − θ⋆). Note that y_t is approximately equal to the change in sharpness from θ⋆ to θ_t, since Taylor expanding the sharpness yields S(θ_t) ≈ S(θ⋆) + ∇S • (θ_t − θ⋆) = 2/η + y_t. At a high level, the mechanism for edge of stability can be described in 4 stages (see Figure 1):

Stage 1: Progressive Sharpening. While x_t, y_t are small, ∇L(θ_t) ≈ ∇L. In addition, because ∇L • ∇S < 0, gradient descent naturally increases the sharpness at every step. In particular,

y_{t+1} − y_t = ∇S • (θ_{t+1} − θ_t) ≈ −η∇L • ∇S = ηα.

The sharpness therefore increases linearly with rate ηα.

Stage 2: Blowup. As x_t measures the deviation from θ⋆ in the u direction, the dynamics of x_t can be modeled by gradient descent on a quadratic with sharpness S(θ_t) ≈ 2/η + y_t. In particular, the rule for gradient descent on a quadratic gives

x_{t+1} = x_t − ηu • ∇L(θ_t) ≈ x_t − ηS(θ_t)x_t ≈ x_t − η[2/η + y_t]x_t = −(1 + ηy_t)x_t.

When the sharpness exceeds 2/η, i.e. when y_t > 0, |x_t| begins to grow exponentially.

[Figure 2: The effect of X(0) (left): we plot the evolution of the ODE in eq. (4) with α = β = 1 for varying X(0); smaller X(0)'s correspond to larger curves. The four stages of edge of stability (right): we show how the four stages described in Section 4.1 and Figure 1 correspond to different parts of the curve generated by the ODE in eq. (4).]

Stage 3: Self-Stabilization. Once the movement in the u direction is sufficiently large, the loss is no longer locally quadratic, and understanding the dynamics necessitates a third order Taylor expansion. The missing cubic term in the Taylor expansion of ∇L(θ_t) is ∇³L(u, u) x_t²/2 = ∇S x_t²/2 by Lemma 2. This biases the optimization trajectory in the −∇S direction, which decreases the sharpness. Recalling β = ∥∇S∥², the new update for y_t becomes

y_{t+1} − y_t = ηα + ∇S • (−η∇S x_t²/2) = η(α − β x_t²/2).

Therefore once |x_t| > √(2α/β), the sharpness begins to decrease, and it continues to do so until the sharpness goes below 2/η and the dynamics return to stability.

Stage 4: Return to Stability. At this point |x_t| is still large from stages 1 and 2. However, the self-stabilization of stage 3 eventually drives the sharpness below 2/η so that y_t < 0. Because the rule for gradient descent on a quadratic with sharpness S(θ_t) = 2/η + y_t < 2/η is x_{t+1} ≈ −(1 + ηy_t)x_t, |x_t| begins to shrink exponentially and the process returns to stage 1.

Combining the updates for x_t, y_t in all four stages, we obtain the following simplified dynamics:

x_{t+1} ≈ −(1 + ηy_t)x_t and y_{t+1} ≈ y_t + η(α − β x_t²/2),   (3)

where we recall that α = −∇L • ∇S is the progressive sharpening coefficient and β = ∥∇S∥².
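The simplified dynamics in eq. (3) are easy to simulate directly; the constants α = β = 1, η = 0.01 and the initial point below are illustrative (compare Figure 2):

```python
import numpy as np

eta, alpha, beta = 0.01, 1.0, 1.0
x, y = 0.5, -0.05          # start slightly below the stability threshold
xs, ys = [], []
for _ in range(2000):
    # eq. (3): x flips sign every step; y rises at rate eta*alpha until
    # the cubic feedback term beta * x^2 / 2 overtakes it
    x, y = -(1 + eta * y) * x, y + eta * (alpha - beta * x**2 / 2)
    xs.append(x)
    ys.append(y)
xs, ys = np.array(xs), np.array(ys)
# Negative feedback: both the unstable coordinate and the sharpness
# excess y = S - 2/eta stay bounded, and y oscillates around 0.
print(np.abs(xs).max() < 5, np.abs(ys).max() < 3)        # True True
print((ys > 0).any() and (ys[ys.argmax():] < 0).any())   # True
```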

4.2. ANALYZING THE SIMPLIFIED DYNAMICS

We now analyze the dynamics in eq. (3). First, note that x_t changes sign at every iteration: x_{t+1} ≈ −x_t due to the instability in the u direction. While eq. (3) cannot be directly modeled by an ODE because of these rapid oscillations, we can instead model |x_t|, y_t, whose updates are controlled by η. As a consequence, we can couple the dynamics of |x_t|, y_t to the following ODE in X(t), Y(t):

X′(t) = X(t)Y(t) and Y′(t) = α − β X(t)²/2.   (4)

This system has the unique fixed point (X, Y) = (δ, 0), where δ := √(2α/β). We also note that after a change of variables this ODE can be written as a Lotka-Volterra predator-prey model, a classical example of a negative feedback loop. In particular, the following quantity is conserved:

Lemma 3. Let h(z) := z − log z − 1. Then g(X(t), Y(t)) := h(βX(t)²/(2α)) + Y(t)²/α is conserved.

Proof. d/dt g(X(t), Y(t)) = (βX(t)²Y(t)/α − 2Y(t)) + (2Y(t)/α)(α − βX(t)²/2) = 0.

As a result we can use the conservation of g to explicitly bound the size of the trajectory:

Corollary 1. For all t, X(0) ≤ X(t) ≲ δ√(log(δ/X(0))) and |Y(t)| ≲ √(α log(δ/X(0))).

The fluctuations in sharpness are thus Õ(√α), while the fluctuations in the unstable direction are Õ(δ). Moreover, the normalized displacement in the ∇S direction, i.e. (∇S/∥∇S∥) • (θ − θ⋆), is also bounded by Õ(δ), so the entire process remains bounded by Õ(δ). Note that the fluctuations increase as the progressive sharpening coefficient α grows, and decrease as the self-stabilization force β grows.

4.3. RELATIONSHIP WITH THE CONSTRAINED TRAJECTORY θ†_t

Equation (3) completely determines the displacement θ_t − θ⋆ in the u, ∇S directions, and Section 4.2 shows that these dynamics remain bounded by Õ(δ) where δ = √(2α/β). However, progress is still made in all other directions. Indeed, θ_t evolves in these orthogonal directions by −ηP⊥_{u,∇S}∇L at every step, where P⊥_{u,∇S} is the projection onto this orthogonal subspace. This can be interpreted as first taking a gradient step −η∇L and then projecting out the ∇S direction to ensure the sharpness does not change. Lemma 13, given in the Appendix, shows that this is precisely the update for θ†_t (eq. (2)) up to higher order terms. The preceding derivation thus implies that ∥θ_t − θ†_t∥ ≤ Õ(δ) and that this Õ(δ) error term is controlled by the self-stabilizing dynamics in eq. (3).
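The conserved quantity of Lemma 3 can be verified by integrating eq. (4) with a standard RK4 scheme; the step size and initial condition below are illustrative:

```python
import numpy as np

alpha, beta = 1.0, 1.0

def rhs(s):  # the ODE of eq. (4): X' = XY, Y' = alpha - beta X^2 / 2
    X, Y = s
    return np.array([X * Y, alpha - beta * X**2 / 2])

def g(s):    # conserved quantity of Lemma 3: h(beta X^2 / (2 alpha)) + Y^2 / alpha
    X, Y = s
    z = beta * X**2 / (2 * alpha)
    return (z - np.log(z) - 1) + Y**2 / alpha

s = np.array([0.5, 0.0])
dt = 1e-3
g0, drift = g(s), 0.0
for _ in range(10000):          # classical RK4 over 10 time units
    k1 = rhs(s)
    k2 = rhs(s + dt / 2 * k1)
    k3 = rhs(s + dt / 2 * k2)
    k4 = rhs(s + dt * k3)
    s = s + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    drift = max(drift, abs(g(s) - g0))
print(drift < 1e-8)  # True: g is conserved up to integrator error
```

With α = β = 1 the fixed point is (X, Y) = (√2, 0), and the orbit through (0.5, 0) is the closed Lotka-Volterra-type cycle on which g is constant.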

5. THE PREDICTED DYNAMICS AND THEORETICAL RESULTS

We now present the equations governing edge of stability for general loss functions.

5.1. NOTATION

Our general approach Taylor expands the gradient of each iterate θ_t around the corresponding iterate θ†_t of the constrained trajectory. We define the following Taylor expansion quantities at θ†_t:

Definition 2 (Taylor Expansion Quantities at θ†_t). ∇L_t := ∇L(θ†_t), ∇²L_t := ∇²L(θ†_t), ∇³L_t := ∇³L(θ†_t), ∇S_t := ∇S(θ†_t), u_t := u(θ†_t). Furthermore, for any vector-valued function v(θ), we define v⊥_t := P⊥_{u_t} v(θ†_t), where P⊥_{u_t} is the projection onto the orthogonal complement of u_t.

We also define the following quantities, which govern the dynamics near θ†_t:

Definition 3. Let α_t := −∇L_t • ∇S_t, β_t := ∥∇S⊥_t∥², and δ_t := √(2α_t/β_t). Furthermore, we define

β_{s→t} := ∇S⊥_{t+1} • [∏_{k=t}^{s+1} (I − η∇²L_k)P⊥_{u_k}] ∇S⊥_s and δ := sup_t δ_t.

Recall that α_t is the progressive sharpening force, β_t is the strength of the stabilization force, and δ_t controls the size of the deviations from θ†_t and was the fixed point in the x direction in Section 4.2. The scalars β_{s→t} capture the effect of the interactions between ∇S and the Hessian.

5.2. THE EQUATIONS GOVERNING EDGE OF STABILITY

We now introduce the equations governing edge of stability. We track the following quantities:

Definition 4. Define v_t := θ_t − θ†_t, x_t := u_t • v_t, and y_t := ∇S⊥_t • v_t.

Our predicted dynamics directly predict the displacement v_t; the full definition is deferred to Appendix C. However, they have a relatively simple form in the u_t, ∇S⊥_t directions:

Lemma 4 (Predicted Dynamics for x, y). The predicted analogues ⋆x_t, ⋆y_t of x_t, y_t satisfy

⋆x_{t+1} = −(1 + η⋆y_t)⋆x_t and ⋆y_{t+1} = η Σ_{s=0}^{t} β_{s→t} (δ_s² − ⋆x_s²)/2.   (5)

Note that when the β_{s→t} are constant, this update reduces to the simple case discussed in Section 4, which we analyze fully. When x_t is large, eq. (5) demonstrates that there is a self-stabilization force which acts to decrease y_t; however, unlike in Section 4, the strength of this force changes with t.

5.3. COUPLING THEOREM

We now show that, under a mild set of assumptions which we verify empirically in Appendix E, the true dynamics are accurately governed by the predicted dynamics. This lets us use the predicted dynamics to predict the loss, the sharpness, and the distance to the constrained trajectory θ†_t. Our errors depend on the unitless quantity ϵ, which we verify is small in Appendix E.

Definition 5. Let ϵ_t := η√α_t and ϵ := sup_t ϵ_t.

To control Taylor expansion errors, we require upper bounds on ∇³L and its Lipschitz constant:

Assumption 3. Let ρ₃, ρ₄ be the minimum constants such that for all θ, ∥∇³L(θ)∥_op ≤ ρ₃ and ∇³L is ρ₄-Lipschitz with respect to ∥•∥_op. Then we assume that ρ₄ = O(ηρ₃²).

Next, we require the following generalization of Assumption 1:

Assumption 4. For all t, −∇L_t • ∇S_t / (∥∇L_t∥ ∥∇S⊥_t∥) = Θ(1) and ∥∇S⊥_t∥ = Θ(ρ₃).

Finally, we require a set of "non-worst-case" assumptions, namely that ∇²L, ∇³L, and λ_min(∇²L) are well behaved in the directions orthogonal to u_t; this generalizes the eigengap assumption. We verify the assumptions on ∇²L and ∇³L empirically in Appendix E.

Assumption 5. For all t and all v, w ⊥ u_t, the normalized quantities ∥∇³L_t(v, w)∥ / (∥∇³L_t∥_op ∥v∥ ∥w∥) and |∇²L_t(⋆v⊥_t, ⋆v⊥_t)| / (∥∇²L_t∥ ∥⋆v⊥_t∥²) are small.

6. EXPERIMENTS

We verify that the predicted dynamics defined in eq. (5) accurately capture the dynamics of gradient descent at the edge of stability by replicating the experiments in Cohen et al. (2021) and tracking the deviation of gradient descent from the constrained trajectory. In Figure 3, we evaluate our theory on a 3-layer MLP and a 3-layer CNN trained with mean squared error (MSE) on a 5k subset of CIFAR10, and a 2-layer Transformer (Vaswani et al., 2017) trained with MSE on SST2 (Socher et al., 2013). We provide additional experiments varying the learning rate and loss function in Appendix G, which use the generalized predicted dynamics described in Section 7.2. For additional details, see Appendix D. Figure 3 confirms that the predicted dynamics of eq. (5) accurately predict the loss, the sharpness, and the distance from the constrained trajectory. In addition, while the gradient flow trajectory diverges from the gradient descent trajectory at a linear rate, the gradient descent and constrained trajectories remain close throughout training. In particular, the dynamics converge to the fixed point (|x_t|, y_t) = (δ_t, 0) described in Section 4.2 and ∥θ_t − θ†_t∥ → δ_t. This confirms our claim that gradient descent implicitly follows the constrained trajectory eq. (2). In Section 5, various assumptions on the model were made to obtain the EOS behavior; in Appendix E, we numerically verify these assumptions to ensure the validity of our theory.

[Figure 3 legend: Gradient Descent; Predicted Dynamics; Constrained Trajectory; Gradient Flow.]

Figure 3: We empirically demonstrate that the predicted dynamics given by eq. (5) track the true EOS dynamics. For each learning rate, the top row is a zoomed-in copy of the bottom row which isolates one cycle and is represented by the dashed rectangle. Reported sharpnesses are two-step averages for visual clarity. For additional experimental details, see Section 6 and Appendix D.

7. DISCUSSION

7.1. TAKEAWAYS FROM THE PREDICTED DYNAMICS

A key consequence of the predicted dynamics is that the loss and sharpness depend only on (⋆x_t, ⋆y_t), which are governed by the 2D dynamical system eq. (5). Therefore, understanding the EOS dynamics only requires analyzing this dynamical system, which is generally well behaved (Figure 3). Furthermore, we expect (⋆x_t, ⋆y_t) to converge to (±δ_t, 0), the fixed point of eq. (5). In fact, Figure 3 shows that after a few cycles, (⋆x_t, ⋆y_t) indeed converges to this fixed point. We can accurately predict its location, as well as the loss increase relative to the constrained trajectory due to ⋆x_t ≠ 0.

7.2. GENERALIZED PREDICTED DYNAMICS

In order for our cubic Taylor expansions to track the true gradients, we require a bound on the fourth derivative of the loss (Assumption 3). This is usually sufficient to capture the EOS dynamics as demonstrated by Figure 3 and Appendix E. However, this condition was violated in some of our experiments, especially when using logistic loss. To overcome this challenge, we developed a generalized form of the predicted dynamics whose definition we defer to Appendix F. These generalized predictions are qualitatively similar to those given by the predicted dynamics in Section 5; however, they precisely track the dynamics of gradient descent in a wider range of settings (see Appendix G).

7.3. IMPLICATIONS FOR NEURAL NETWORK TRAINING

Non-Monotonic Loss Decrease. An important property of EOS is that the loss decreases over long time scales, albeit non-monotonically. Our theory provides a clear explanation for this phenomenon. We show that the gradient descent trajectory remains close to the constrained trajectory (Sections 4 and 5). Since the constrained trajectory is stable, it satisfies a descent lemma (Lemma 14), and its loss decreases monotonically. Over short time periods, the loss is dominated by the rapid fluctuations of x_t described in Section 4. Over longer time periods, the loss decrease of the constrained trajectory overpowers the bounded fluctuations of x_t, leading to an overall loss decrease.

Generalization & the Role of Large Learning Rates. Prior work has shown that in neural networks, both decreasing the sharpness of the learned solution (Keskar et al., 2017; Dziugaite & Roy, 2017; Neyshabur et al., 2017; Jiang et al., 2020) and increasing the learning rate (Smith et al., 2018; Li et al., 2019; Lewkowycz et al., 2020) are correlated with better generalization. Our analysis shows that gradient descent implicitly constrains the sharpness to stay below 2/η, which suggests larger learning rates may improve generalization by reducing the sharpness. In Figure 4 we confirm that in a standard setting, full-batch gradient descent generalizes better with large learning rates.

Training Speed. Additional experiments in (Cohen et al., 2021, Appendix F) show that, despite the instability in the training process, larger learning rates lead to faster convergence. This phenomenon is explained by our analysis. Gradient descent is coupled to the constrained trajectory, which minimizes the loss while constraining movement in the u_t, ∇S⊥_t directions. Since only two directions are "off limits," the constrained trajectory can still move quickly in the orthogonal directions, using the large learning rate to accelerate convergence. We demonstrate this empirically in Figure 4.
We defer additional discussion of our work, including the effect of multiple unstable eigenvalues and connections to Sharpness Aware Minimization (Foret et al., 2021) , warm-up (Gilmer et al., 2021) , and scale-invariant loss functions (Lyu et al., 2022) to Appendix H.

7.4. FUTURE WORK

An important direction for future work is understanding the EOS dynamics when there are multiple unstable eigenvalues, which we briefly discuss in Appendix H. Another interesting direction is understanding the global convergence properties at EOS, including convergence to a KKT point of the constrained problem eq. (2). Next, our analysis focused on the EOS dynamics but left open the question of why neural networks exhibit progressive sharpening. Finally, we would like to understand the role of self-stabilization in stochastic gradient descent and how it interacts with the known implicit biases of SGD (Blanc et al., 2020; Damian et al., 2021; Li et al., 2022a).

A NOTATION

We denote by ∇^k L(θ) the k-th order derivative of the loss L at θ. Note that ∇^k L(θ) is a symmetric k-tensor in (R^d)^{⊗k} when θ ∈ R^d. For a symmetric k-tensor T and vectors u_1, ..., u_j ∈ R^d, we use T(u_1, ..., u_j) to denote the tensor contraction of T with u_1, ..., u_j, i.e.

[T(u_1, ..., u_j)]_{i_1,...,i_{k−j}} := T_{i_1,...,i_k} (u_1)_{i_{k−j+1}} ··· (u_j)_{i_k}.

We use P_{u_1,...,u_k} to denote the orthogonal projection onto span(u_1, ..., u_k) and P⊥_{u_1,...,u_k} the projection onto the corresponding orthogonal complement. For matrices A_1, ..., A_t, we define ∏_{k=1}^{t} A_k := A_1 ··· A_t and ∏_{k=t}^{1} A_k := A_t ··· A_1.

B A TOY MODEL FOR SELF-STABILIZATION

For α, β > 0, consider the function

L(x, y, z) := (2/η + √β y) x²/2 − (α/√β) y + z,

initialized at the point (x_0, 0, 0). Note that the constrained trajectory will follow x†_t = 0, y†_t = 0, z†_t = −ηt, as it cannot decrease the loss in the y direction without increasing the sharpness past 2/η. We therefore have:

∇L_t = (0, −α/√β, 1), u_t = (1, 0, 0), S_t = 2/η + √β y, ∇²L_t = S_t u_t u_tᵀ, ∇S_t = (0, √β, 0).

Note that this satisfies all of the assumptions in Section 4, with α = −∇L_t • ∇S_t and β = ∥∇S_t∥². This process follows eq. (4) in the x, y directions while it tracks the constrained trajectory θ†_t, moving linearly in the −P⊥_{u,∇S}∇L = (0, 0, −1) direction.
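Running gradient descent on this toy loss reproduces the expected behavior: the sharpness hovers around 2/η while z tracks the constrained trajectory z†_t = −ηt. The constants α = β = 1, η = 0.01 and the starting point x_0 below are illustrative:

```python
import numpy as np

eta, alpha, beta = 0.01, 1.0, 1.0   # illustrative constants
two_over_eta = 2 / eta              # = 200

def grad(x, y, z):
    # gradient of L(x, y, z) = (2/eta + sqrt(beta) y) x^2 / 2 - (alpha/sqrt(beta)) y + z
    sb = np.sqrt(beta)
    return np.array([(two_over_eta + sb * y) * x,
                     sb * x**2 / 2 - alpha / sb,
                     1.0])

theta = np.array([0.2, 0.0, 0.0])   # start at (x_0, 0, 0)
S_hist = []
for _ in range(2000):
    theta = theta - eta * grad(*theta)
    S_hist.append(two_over_eta + np.sqrt(beta) * theta[1])  # S = 2/eta + sqrt(beta) y
S_hist = np.array(S_hist)
x_fin, y_fin, z_fin = theta
print(abs(z_fin + 20) < 1e-6)                    # True: z tracks z_t^dagger = -eta*t
print(np.abs(S_hist - two_over_eta).max() < 5)   # True: sharpness hovers near 2/eta
print((S_hist > two_over_eta).any())             # True: ... and actually crosses it
```

In the (x, y) coordinates this run is exactly the simplified map of eq. (3), so the sharpness oscillates around 2/η while the loss decreases through the free z direction.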

C DEFINITION OF THE PREDICTED DYNAMICS

Published as a conference paper at ICLR 2023

Below, we present the full definition of the predicted dynamics:

Definition 6 (Predicted Dynamics, full). Define ⋆v_0 = v_0, and let ⋆x_t = ⋆v_t • u_t, ⋆y_t = ∇S⊥_t • ⋆v_t. Then

⋆v_{t+1} = P⊥_{u_{t+1}}[(I − η∇²L_t)P⊥_{u_t} ⋆v_t + η∇S⊥_t (δ_t² − ⋆x_t²)/2] − (1 + η⋆y_t)⋆x_t u_{t+1}.   (6)

For convenience, we define the map step_t : R^d → R^d as follows:

Definition 7. Given a vector v and a timestep t, define step_t(v) by

P⊥_{u_{t+1}} step_t(v) = P⊥_{u_{t+1}}[(I − η∇²L_t)P⊥_{u_t} v + η∇S⊥_t (δ_t² − x²)/2],   (7)
u_{t+1} • step_t(v) = −(1 + ηy)x,   (8)

where x = u_t • v and y = ∇S⊥_t • v. It is easy to see that ⋆v_{t+1} = step_t(⋆v_t).

Proof of Lemma 4. Defining A_t := (I − η∇²L_t)P⊥_{u_t}, we can unfold the recursion in eq. (6) to obtain the following formula for ⋆v_{t+1}:

⋆v_{t+1} = η Σ_{s=0}^{t} P⊥_{u_{t+1}} [∏_{k=t}^{s+1} A_k] ∇S⊥_s (δ_s² − ⋆x_s²)/2 − (1 + η⋆y_t)⋆x_t u_{t+1}.

It is then immediate that ⋆x_t = ⋆v_t • u_t and ⋆y_t = ∇S⊥_t • ⋆v_t have the following simple update:

⋆x_{t+1} = −(1 + η⋆y_t)⋆x_t and ⋆y_{t+1} = η Σ_{s=0}^{t} β_{s→t} (δ_s² − ⋆x_s²)/2,

where we recall that we have defined β_{s→t} := ∇S⊥_{t+1} • [∏_{k=t}^{s+1} A_k] ∇S⊥_s.
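The unfolding step in the proof of Lemma 4 can be checked numerically by iterating eq. (6) with random toy stand-ins for the Taylor quantities; all matrices and vectors below are our own illustrations, not quantities from a trained network:

```python
import numpy as np

# Iterate the predicted dynamics, eq. (6), with random stand-ins for the
# Taylor quantities at theta^dagger_t, and compare y*_{T+1} against the
# closed-form sum eta * sum_s beta_{s->T} (delta_s^2 - x*_s^2) / 2.
rng = np.random.default_rng(0)
d, T, eta = 4, 6, 0.1

def Pperp(u):
    return np.eye(d) - np.outer(u, u)

us, Ss, Hs = [], [], []
for t in range(T + 2):
    u = rng.normal(size=d)
    u /= np.linalg.norm(u)
    us.append(u)
    Ss.append(Pperp(u) @ rng.normal(size=d))   # nabla S_t^perp (orthogonal to u_t)
for t in range(T + 1):
    M = 0.1 * rng.normal(size=(d, d))
    Hs.append((M + M.T) / 2)                   # toy symmetric nabla^2 L_t
deltas = rng.uniform(0.5, 1.0, T + 1)
A = [(np.eye(d) - eta * Hs[t]) @ Pperp(us[t]) for t in range(T + 1)]

v = 0.3 * us[0]                                # v_0 purely in the unstable direction
cs = []
for t in range(T + 1):
    x, y = us[t] @ v, Ss[t] @ v
    c = (deltas[t] ** 2 - x ** 2) / 2
    cs.append(c)
    v = Pperp(us[t + 1]) @ (A[t] @ v + eta * Ss[t] * c) - (1 + eta * y) * x * us[t + 1]

y_closed = 0.0
for s in range(T + 1):
    prod = np.eye(d)
    for k in range(T, s, -1):                  # the product A_T A_{T-1} ... A_{s+1}
        prod = prod @ A[k]
    y_closed += eta * (Ss[T + 1] @ prod @ Ss[s]) * cs[s]

print(np.isclose(Ss[T + 1] @ v, y_closed))     # True: matches Lemma 4
```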

D EXPERIMENTAL DETAILS D.1 ARCHITECTURES

We evaluated our theory on four different architectures. The 3-layer MLP and CNN are exact copies of the MLP and CNN used in Cohen et al. (2021). The MLP has width 200, the CNN has width 32, and both use the swish activation (Ramachandran et al., 2017). We also evaluate on a ResNet18 with progressive widths 16, 32, 64, 128 and on a 2-layer Transformer with hidden dimension 64 and two attention heads.

D.2 DATA

We evaluated our theory on three primary tasks: CIFAR10 multi-class classification with both categorical MSE loss and cross-entropy loss, CIFAR10 binary classification (cats vs dogs) with binary MSE loss and logistic loss, and SST2 (Socher et al., 2013) with binary MSE loss and logistic loss.

D.3 EXPERIMENTAL SETUP

For every experiment, we tracked the gradient descent dynamics until they reached instability and then began tracking the constrained trajectory, gradient descent, gradient flow, and both our predicted dynamics (Section 5) and our generalized predicted dynamics (Appendix F). In addition, we tracked the various quantities on which we made assumptions in Section 5 in order to validate those assumptions. We also tracked the second eigenvalue of the Hessian at the constrained trajectory throughout training and stopped training once it reached 1.9/η, to ensure the existence of a single unstable eigenvalue. Finally, as the edge of stability dynamics are very sensitive to small perturbations when |x| is small (see Figure 2), we switched to computing gradients with 64-bit precision after first reaching instability to avoid propagating floating point errors. Eigenvalues were computed using the LOBPCG sparse eigenvalue solver in JAX (Bradbury et al., 2018). To compute the constrained trajectory, we computed a linearized approximation to proj_M inspired by Lemma 13, along with a Newton step in the u_t direction to ensure that ∇L • u = 0. Each linearized approximation step required recomputing the sharpness and top eigenvector, and each projection step consisted of three linearized projection steps, for a total of three eigenvalue computations per projection step. Our experiments were conducted in JAX (Bradbury et al., 2018), using https://github.com/locuslab/edge-of-stability as a reference for replicating the experimental setup used in Cohen et al. (2021). All experiments were conducted on two servers, each with 10 NVIDIA GPUs. Our code can be found at https://github.com/adamian98/EOS.
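Our experiments compute exact top eigenvalues with LOBPCG in JAX; as a self-contained illustration of the underlying primitive, the sketch below (function name, step size, and iteration count are our own choices) estimates the sharpness from gradient evaluations alone, via power iteration on finite-difference Hessian-vector products:

```python
import numpy as np

def sharpness_estimate(grad_fn, theta, iters=200, h=1e-5):
    """Power iteration on finite-difference Hessian-vector products."""
    v = np.random.default_rng(1).normal(size=theta.shape)
    v /= np.linalg.norm(v)
    lam = 0.0
    for _ in range(iters):
        # H v ~ (grad(theta + h v) - grad(theta - h v)) / (2 h)
        Hv = (grad_fn(theta + h * v) - grad_fn(theta - h * v)) / (2 * h)
        lam = v @ Hv                   # Rayleigh quotient estimate
        v = Hv / np.linalg.norm(Hv)
    return lam

# Sanity check on the quadratic 0.5 * theta^T A theta with lambda_max = 3:
A = np.diag([3.0, 1.0, 0.5])
est = sharpness_estimate(lambda th: A @ th, np.zeros(3))
print(round(est, 6))  # 3.0
```

Power iteration converges to the eigenvalue of largest magnitude, so this sketch assumes the top Hessian eigenvalue dominates |λ_min|, as is the case at the edge of stability.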

E EMPIRICAL VERIFICATION OF THE ASSUMPTIONS

For each of the experimental settings considered (MLP+MSE, CNN+MSE, CNN+Logistic, ResNet18+MSE, Transformer+MSE, Transformer+Logistic), we plot a number of quantities along the constrained trajectory to verify that the assumptions made in the main text hold. For each learning rate η we have 8 plots tracking various quantities, which verify the assumptions as follows: Assumption 1 is verified by the 1st plot, ϵ being small is verified by the 2nd plot, Assumption 4 is verified by the 3rd and 4th plots, Assumption 3 is verified by the 5th plot, and Assumption 5 is verified by the last 3 plots. As described in the experimental setup, training is stopped once the second eigenvalue is 1.9/η, so Assumption 2 always holds with c = 1.9 as well.

MLP+MSE on CIFAR10

[Verification plots. For each setting and each learning rate η shown (η ∈ {0.001, 0.002, 0.005, 0.01, 0.02} across settings), the eight panels plot, along the constrained trajectory: (1) α_t := -∇L_t · ∇S_t; (2) ϵ_t := η√α_t; (3) -∇L_t · ∇S_t / (∥∇L_t∥ ∥∇S_t^⊥∥); (4) ∥∇S_t^⊥∥ / ρ_3; (5) ρ_4 / (ηρ_3²); (6) ∇³L_t(v_t^⊥, v_t^⊥) / (∥∇³L_t∥_op ∥v_t^⊥∥²); (7) ∇³L_t(∇L_t, ∇L_t) / (∥∇³L_t∥_op ∥∇L_t∥²); and (8) ∇²L_t(v_t^⊥, v_t^⊥) / (∥∇²L_t∥ ∥v_t^⊥∥²).]

Transformer+MSE on SST2

[Verification plots for Transformer+MSE on SST2: panels for α_t := -∇L_t · ∇S_t, ϵ_t := η√α_t, and the alignment and derivative ratios tracked in Appendix E, for η ∈ {0.002, 0.005}.]

F THE GENERALIZED PREDICTED DYNAMICS

Our analysis relies on a cubic Taylor expansion of the gradient. However, in order for this Taylor expansion to accurately track the gradients, we need a bound on the fourth derivative of the loss (Assumption 3). Section 6 and Appendix E show that this approximation suffices to capture the dynamics of gradient descent at the edge of stability for many standard models when the loss criterion is the mean squared error. However, for certain architectures and loss functions, including ResNet18 and models trained with the logistic loss, this condition is often violated. In these situations, the loss function in the top eigenvector direction is either sub-quadratic, meaning that the quadratic Taylor expansion overestimates the loss and sharpness, or super-quadratic, meaning that the quadratic Taylor expansion underestimates the loss and sharpness. To capture this phenomenon, we derive a more general form of the predicted dynamics, which reduces to the standard predicted dynamics of Section 5 when the loss in the top eigenvector direction is approximately quadratic. In addition, Appendix G shows that the generalized predicted dynamics capture the dynamics of gradient descent at the edge of stability for both mean squared error and cross-entropy in all settings we tested.

F.1 DERIVING THE GENERALIZED PREDICTED DYNAMICS

To derive the generalized predicted dynamics, we will abstract away the dynamics in the top eigenvector direction. Specifically, for every t we define

F_t(x) := L(θ†_t + x u_t) - L(θ†_t) - x²/η.

We say that L is sub-quadratic at t if F_t(x) < 0 and super-quadratic if F_t(x) > 0. Note that knowing F_t is not sufficient to capture the dynamics in the u_t direction; specifically,

x_{t+1} = x_t - η u_t · ∇L(θ†_t + v_t) ≠ x_t - η u_t · ∇L(θ†_t + x_t u_t).

It is still critically important to track the effect that the movement in the ∇S_t^⊥ direction has on the dynamics of x. As in Section 4.1, the effect of the movement in the ∇S_t^⊥ direction is to change the sharpness by y_t. This gives us the generalized predicted dynamics update:

v⋆_{t+1} = P^⊥_{u_{t+1}} (I - η∇²L_t) P^⊥_{u_t} v⋆_t + η P^⊥_{u_{t+1}} ∇S_t^⊥ (δ_t² - x⋆_t²)/2 + x⋆_{t+1} u_{t+1},

where x⋆_{t+1} = -(1 + η y⋆_t) x⋆_t - η F′_t(x⋆_t). Note that when F_t = 0, so that the loss in the u_t direction is exactly quadratic, this reduces to the standard predicted dynamics update in eq. (6). Note also that the update for y is completely unchanged:

Lemma 5. Restricted to the u_t, ∇S_t directions, the generalized predicted dynamics v⋆_t satisfy

x⋆_{t+1} = -(1 + η y⋆_t) x⋆_t - η F′_t(x⋆_t) and y⋆_{t+1} = η Σ_{s=0}^{t} β_{s→t} (δ_s² - x⋆_s²)/2. (11)

The proof is identical to the proof of Lemma 4.
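To make the definition concrete, the following snippet evaluates F_t on three hypothetical one-dimensional slices of the loss along u_t, each with curvature exactly 2/η at the reference point; the slice functions are illustrative assumptions, not losses from our experiments.

```python
eta = 0.1

# Hypothetical 1-d slices x -> L(theta†_t + x u_t), each with curvature 2/eta at x = 0:
L_quad = lambda x: x**2 / eta            # exactly quadratic
L_sub = lambda x: x**2 / eta - x**4      # sub-quadratic: flattens away from 0
L_super = lambda x: x**2 / eta + x**4    # super-quadratic: steepens away from 0

def F(L, x):
    # F_t(x) := L(theta†_t + x u_t) - L(theta†_t) - x^2 / eta
    return L(x) - L(0.0) - x**2 / eta

x = 0.5
print(F(L_quad, x), F(L_sub, x), F(L_super, x))  # zero, negative, positive
```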

F.2 PROPERTIES OF THE GENERALIZED PREDICTED DYNAMICS

Note that due to the sign-flipping argument in Appendix I, we can assume that F_t is an even function, as the odd part only influences the dynamics through additional oscillations of period 2; throughout the remainder of this section we therefore assume F_t(x) = F_t(-x) (otherwise, we simply redefine F_t by its even part). Next, note that the fixed point of eq. (11) still occurs at x_t = δ_t, regardless of the shape of F_t, due to the need to stabilize the ∇S_t^⊥ direction. This contrasts with previous one-dimensional analyses of edge of stability, in which the fixed point in the top eigenvector direction depends strongly on the shape of F_t, the loss in the u_t direction. The limiting value of y_t can therefore be read off from the update for x_t: if (δ_t, y) is an orbit of period 2 of eq. (11), then

-δ_t = -(1 + ηy)δ_t - ηF′_t(δ_t) ⟹ y = -F′_t(δ_t)/δ_t.

In addition, note that the sharpness can no longer be approximated as S(θ_t) ≈ 2/η + y_t, as the sharpness now also changes along the u_t direction. In particular, it changes by F″_t(x), so that S(θ_t) ≈ 2/η + y_t + F″_t(x_t). Therefore, the limiting sharpness of eq. (11) is

S(θ_t) → 2/η - F′_t(δ_t)/δ_t + F″_t(δ_t).

When F_t = 0 and the loss is exactly quadratic in the u_t direction, this reduces to the fixed point predictions in Section 4.1. One interesting phenomenon observed by Cohen et al. (2021) is that with cross-entropy loss, the sharpness was never exactly 2/η, but usually hovered below it. This is not explained by the standard predicted dynamics, which predict that the fixed point has sharpness 2/η. However, using the generalized predicted dynamics eq. (11), we can give a clear explanation. When the loss is sub-quadratic, e.g. when F_t(x) = -ρ₄x⁴/24, we have

S(θ_t) → 2/η + ρ₄δ_t²/6 - ρ₄δ_t²/2 = 2/η - ρ₄δ_t²/3 < 2/η,

so the sharpness will converge to a value below 2/η. On the other hand, if the loss is super-quadratic, the sharpness converges to a value above 2/η.
More generally, whether the sharpness converges to a value above or below 2/η depends on the sign of F″_t(δ_t)δ_t - F′_t(δ_t). In our experiments in Appendix G, we observed both sub-quadratic and super-quadratic loss functions. In particular, the loss was usually sub-quadratic when it first reached instability, but gradually became super-quadratic as training progressed at the edge of stability.
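These fixed-point formulas can be verified numerically. The sketch below uses the sub-quadratic example F_t(x) = -ρ₄x⁴/24 from above; the values of η, ρ₄, and δ_t are arbitrary illustrative choices.

```python
eta, rho4, delta = 0.01, 5.0, 0.3  # illustrative values

def F_prime(x):    # derivative of F(x) = -rho4 * x^4 / 24
    return -rho4 * x**3 / 6

def F_second(x):   # second derivative of F
    return -rho4 * x**2 / 2

# period-2 orbit condition: y = -F'(delta)/delta
y = -F_prime(delta) / delta

x = delta
for _ in range(10):  # x update from eq. (11), with y frozen at the orbit value
    x = -(1 + eta * y) * x - eta * F_prime(x)
print(x)  # x returns to +delta every two steps: a period-2 orbit

# limiting sharpness 2/eta - F'(delta)/delta + F''(delta) = 2/eta - rho4*delta^2/3
S_limit = 2 / eta - F_prime(delta) / delta + F_second(delta)
print(S_limit, 2 / eta)  # S_limit sits below 2/eta for a sub-quadratic loss
```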

G ADDITIONAL EXPERIMENTS

G.1 THE BENEFIT OF LARGE LEARNING RATES: TRAINING TIME AND GENERALIZATION

We trained ResNet18 with full batch gradient descent on the full 50k training set of CIFAR10 with various learning rates, in addition to the commonly proposed learning rate schedule η t := 1/S(θ t ). We show that despite entering the edge of stability, large learning rates converge much faster. In addition, due to the self-stabilization effect of gradient descent, the final sharpness is bounded by 2/η which is smaller for larger learning rates and leads to better generalization (see Figure 4 ). 
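The self-stabilization effect invoked here can be reproduced on a minimal toy model (a hypothetical two-parameter loss constructed for illustration; it is not the example from Appendix B, and all constants are arbitrary). For L(x, y) = ½(s₀ + y)x² - αy, the sharpness is approximately s₀ + y, and ∇L · ∇S = ½x² - α is negative when |x| is small, so plain gradient descent progressively sharpens until S reaches 2/η, after which the growth of |x| drives the sharpness back down:

```python
def run(eta, T=400, alpha=0.01):
    # gradient descent on the toy loss L(x, y) = 0.5*(s0 + y)*x**2 - alpha*y
    s0 = 2 / eta - 0.05  # initialize just below the stability threshold 2/eta
    x, y = 0.1, 0.0
    for _ in range(T):
        # simultaneous update of both coordinates with the gradient at (x, y)
        x, y = x - eta * (s0 + y) * x, y - eta * (0.5 * x * x - alpha)
    return s0 + y        # approximate sharpness at the end of training

print(run(0.5), run(1 / 3))  # each hovers near its own 2/eta (4 and 6)
```

Consistent with the experiments in this section, the larger learning rate (η = 0.5) ends with the smaller sharpness, since self-stabilization pins the sharpness near 2/η.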

H ADDITIONAL DISCUSSION

A Precise Criterion for Self-Stabilization. Our theoretical results in Sections 4 and 5 give sufficient conditions for self-stabilization, as evidenced by Theorem 1. However, these assumptions may not be strictly necessary, and we believe that self-stabilization may hold under significantly weaker assumptions. An important open question is understanding the precise conditions on the loss function and learning rate that enable self-stabilization. Furthermore, in Theorem 1 in Section 5, we give a quantitative bound on the time for which the EOS trajectory and the constrained trajectory can be coupled. It is an interesting future direction to understand whether the coupling can be strengthened to hold for longer time periods, to allow for convergence to a KKT point of the constrained problem eq. (2).

Multiple Unstable Eigenvalues. Our work focuses on explaining edge of stability in the presence of a single unstable eigenvalue (Assumption 2). However, Cohen et al. (2021) observed that progressive sharpening appears to apply to all eigenvalues, even after the largest eigenvalue has become unstable. As a result, all of the top eigenvalues will successively enter edge of stability (see Figure 5). In particular, Figure 5 shows that the dynamics are fairly well behaved in the period when only a single eigenvalue is unstable, yet appear to be significantly more chaotic when multiple eigenvalues are unstable. One technical challenge in dealing with multiple eigenvalues is that, when the top eigenvalue is not unique, the sharpness is no longer differentiable, and it is unclear how to generalize our analysis. However, one might expect that gradient descent can still be coupled to projected gradient descent under the non-differentiable constraint S(θ) ≤ 2/η. When there are k unstable eigenvalues, with corresponding eigenvectors u_t^1, ..., u_t^k, the constrained update is roughly equivalent to projecting out the subspace span{∇³L_t(u_t^i, u_t^j) : i, j ∈ [k]} from the gradient update -η∇L_t. Demonstrating self-stabilization thus requires analyzing the dynamics in the subspace span({u_t^i : i ∈ [k]} ∪ {∇³L_t(u_t^i, u_t^j) : i, j ∈ [k]}). We leave investigating the dynamics of multiple unstable eigenvalues to future work.

Connection to Sharpness Aware Minimization (SAM). Foret et al. (2021) introduced the sharpness-aware minimization (SAM) algorithm, which aims to control sharpness by solving the optimization problem min_θ max_{∥δ∥≤ϵ} L(θ + δ). This is roughly equivalent to minimizing S(θ) over all global minimizers, so SAM tries to explicitly minimize the sharpness. Our analysis shows that gradient descent implicitly controls the sharpness and, for a fixed η, looks to minimize L(θ) subject to S(θ) ≤ 2/η.

Connection to Warmup. Gilmer et al. (2021) demonstrated that learning rate warmup, which consists of gradually increasing the learning rate, empirically enables training with a larger final learning rate. The self-stabilization property of gradient descent provides a plausible explanation for this phenomenon. If too large an initial learning rate η₀ is chosen (so that S(θ₀) is much greater than 2/η₀), the iterates may diverge before self-stabilization can decrease the sharpness to 2/η₀. On the other hand, if the learning rate is chosen so that S(θ₀) is only slightly greater than 2/η₀, self-stabilization will decrease the sharpness to 2/η₀. Repeatedly increasing the learning rate slightly can then lead to small decreases in sharpness without the iterates diverging, thus allowing training to proceed with a large learning rate.

Connection to Weight Decay and Sharpness Reduction. Lyu et al. (2022) proved that when the loss function is scale-invariant, gradient descent with weight decay and a sufficiently small learning rate leads to a reduction of the normalized sharpness S(θ/∥θ∥). In fact, the mechanism behind this sharpness reduction is exactly the self-stabilization force described in this paper, restricted to the setting of (Lyu et al., 2022). We present here a heuristic derivation of this equivalence. First, we show that any scale-invariant loss satisfies our assumptions when trained with weight decay.

Lemma 6. Let f be a scale-invariant loss function, i.e. f(θ) = f(cθ) for all c > 0, and let L(θ) = f(θ) + (λ/2)∥θ∥². Then for any local minimizer θ of f such that S(θ) = 2/η:
• ∇L(θ) ⊥ u(θ)
• ρ₄ = O(ηρ₃²)
• α(θ) > 0
• α(θ)/(∥∇L(θ)∥∥∇S(θ)∥) = Θ(1)
• ∥∇S(θ)∥ = Θ(ρ₃)

Our primary result is that gradient descent solves the constrained problem min_θ L(θ) such that S(θ) ≤ 2/η. Let S_f(θ) denote the largest eigenvalue of ∇²f(θ). To prove the equivalence to sharpness reduction, we will need the following identity from (Lyu et al., 2022), which follows from the scale invariance of f: S_f(θ) = S_f(θ/∥θ∥)/∥θ∥². Let θ̄ := θ/∥θ∥. Then we have the following equivalences between minimization problems:

min_θ L(θ) s.t. S(θ) ≤ 2/η
⇔ min_θ f(θ) + (λ/2)∥θ∥² s.t. S_f(θ) ≤ 2/η - λ
⇔ min_{θ̄, ∥θ∥} f(θ̄) + (λ/2)∥θ∥² s.t. S_f(θ̄)/∥θ∥² ≤ (2 - ηλ)/η
⇔ min_{θ̄} f(θ̄) + (ηλ/(2 - ηλ)) S_f(θ̄),

where the last line follows from the scale invariance of the loss function. In particular, if ηλ is sufficiently small and the dynamics are initialized near a global minimizer of the loss, this converges to the solution of the constrained problem min_{∥θ̄∥=1} S_f(θ̄) such that f(θ̄) = 0.
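The key identity S_f(θ) = S_f(θ/∥θ∥)/∥θ∥² can be checked numerically. Below is a small sketch using a hypothetical scale-invariant function f(θ) = θ₁²/∥θ∥² and finite-difference Hessians; the function, the step size h, and the test point are illustrative assumptions.

```python
import numpy as np

def f(th):
    # hypothetical scale-invariant function: f(c*th) = f(th) for any c != 0
    return th[0]**2 / (th @ th)

def hess(th, h=1e-4):
    # central finite-difference Hessian of f
    n = len(th)
    H = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            e_i = np.zeros(n); e_i[i] = h
            e_j = np.zeros(n); e_j[j] = h
            H[i, j] = (f(th + e_i + e_j) - f(th + e_i - e_j)
                       - f(th - e_i + e_j) + f(th - e_i - e_j)) / (4 * h * h)
    return H

def S_f(th):
    return np.linalg.eigvalsh(hess(th))[-1]  # largest Hessian eigenvalue

th = np.array([0.6, 0.8])  # unit norm, so th is its own normalization
c = 3.0
print(S_f(c * th), S_f(th) / c**2)  # the two values agree
```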

H.1 SCALE INVARIANT LEMMAS

Let θ denote an arbitrary parameter and let θ̄ = θ/∥θ∥. Throughout this section, let f be a scale-invariant function with non-vanishing Hessian.

Lemma 7.

∇f(θ) = P⊥_θ ∇f(θ̄)/∥θ∥ and ∇²f(θ) = [P⊥_θ ∇²f(θ̄) P⊥_θ - P⊥_θ ∇f(θ̄) θ̄ᵀ - θ̄ (P⊥_θ ∇f(θ̄))ᵀ]/∥θ∥².

Proof. We start with the equality f(θ) = f(θ̄). Differentiating with respect to θ and using ∇_θ θ̄ = P⊥_θ/∥θ∥ gives ∇f(θ) = P⊥_θ ∇f(θ̄)/∥θ∥. Differentiating this again gives

∇²f(θ) = [P⊥_θ ∇²f(θ̄) P⊥_θ - P⊥_θ ∇f(θ̄) θ̄ᵀ - θ̄ (P⊥_θ ∇f(θ̄))ᵀ]/∥θ∥².

A few corollaries immediately follow:

Corollary 2. For any critical point θ̄ of f, ∇²f(θ) = P⊥_θ ∇²f(θ̄) P⊥_θ / ∥θ∥².

Corollary 3. For any critical point θ̄ of f, u(θ) ⊥ ∇L(θ).

Proof. Note that from Corollary 2, the top eigenvector of ∇²f(θ) is perpendicular to θ. In addition, ∇²L(θ) = ∇²f(θ) + λI, so this is also the top eigenvector of ∇²L(θ), i.e. u(θ). Finally, ∇L(θ) = ∇f(θ) + λθ = λθ, which is parallel to θ and concludes the proof.

Lemma 8.

∇S(θ) = [P⊥_θ ∇S(θ̄) - 2(S(θ̄) - λ)θ̄]/∥θ∥³.

Proof. Let S_f(θ) denote the largest eigenvalue of ∇²f(θ). Then by scale invariance, ∇²f(θ) = ∇²f(θ̄)/∥θ∥², which implies S_f(θ) = S_f(θ̄)/∥θ∥². Differentiating this gives

∇S_f(θ) = [P⊥_θ ∇S_f(θ̄) - 2 S_f(θ̄) θ̄]/∥θ∥³.

Finally, from ∇²L(θ) = ∇²f(θ) + λI we have S(θ) = S_f(θ) + λ, so

∇S(θ) = [P⊥_θ ∇S(θ̄) - 2(S(θ̄) - λ)θ̄]/∥θ∥³.

I PROOFS

I.1 PROPERTIES OF THE CONSTRAINED TRAJECTORY

We next prove several useful properties of the constrained trajectory. First, we require the following auxiliary lemma, which shows that several quantities are Lipschitz in a neighborhood of the constrained trajectory:

Definition 8 (Lipschitz Sets). S_t := B(θ†_t, (2-c)/(4ηρ₃)), where c is the constant in Assumption 2 and B(x, r) denotes the ball of radius r centered at x.

Lemma 12 (Lipschitz Properties).
1. θ → ∇L(θ) is O(η⁻¹)-Lipschitz in each set S_t.
2. θ → ∇²L(θ) is ρ₃-Lipschitz with respect to ∥·∥₂.
3. θ → λ_i(∇²L(θ)) is ρ₃-Lipschitz.
4. θ → u(θ) is O(ηρ₃)-Lipschitz in each set S_t.
5. θ → ∇S(θ) is O(ηρ₃²)-Lipschitz in each set S_t.

Proof. The Lipschitzness of ∇²L(θ) follows immediately from the bound ∥∇³L(θ)∥_op ≤ ρ₃. Weyl's inequality then immediately implies the desired bound on the Lipschitz constant of the eigenvalues of ∇²L(θ). Therefore, for any t we have, for all θ ∈ S_t,

λ₁(∇²L(θ)) - λ₂(∇²L(θ)) ≥ λ₁(∇²L(θ†_t)) - λ₂(∇²L(θ†_t)) - 2ρ₃ · (2-c)/(4ηρ₃) ≥ (2-c)/(2η).

Next, from the derivative-of-eigenvector formula,

∥∇u(θ)∥₂ = ∥(λ₁(∇²L(θ))I - ∇²L(θ))† ∇³L(θ)(u(θ))∥₂ ≤ ρ₃/(λ₁(∇²L(θ)) - λ₂(∇²L(θ))) ≤ 2ηρ₃/(2-c) = O(ηρ₃),

which implies the bound on the Lipschitz constant of u restricted to S_t. Finally, because ∇S(θ) = ∇³L(θ)(u(θ), u(θ)),

∥∇²S(θ)∥₂ ≤ ∥∇⁴L(θ)∥_op + 2∥∇³L(θ)∥_op ∥∇u(θ)∥₂ ≤ O(ρ₄ + ηρ₃²) ≤ O(ηρ₃²),

where the second-to-last inequality follows from the bound on ∥∇u(θ)∥₂ restricted to S_t and the last inequality follows from Assumption 3.

Lemma 13 (First-order approximation of the constrained trajectory update {θ†_t}). For all t ≤ T,

θ†_{t+1} = θ†_t - ηP⊥_{u_t,∇S_t}∇L_t + O(ϵ² · η∥∇L_t∥) and S_t = 2/η.

Proof. We will prove by induction that S_t = 2/η for all t. The base case follows from the definitions of θ₀ and θ†₀. Next, assume S(θ†_t) = 2/η for some t ≥ 0. Let θ′ = θ†_t - η∇L_t. Then because θ†_t ∈ M, we have ∥θ†_{t+1} - θ′∥ ≤ ∥θ†_t - θ′∥ = η∥∇L_t∥.
Then because θ†_{t+1} = proj_M(θ′), the KKT conditions for this minimization problem imply that there exist x, y with y ≥ 0 such that

θ†_{t+1} = θ†_t - η∇L_t - x ∇_θ[∇L(θ) · u(θ)]|_{θ=θ†_{t+1}} - y∇S_{t+1}
= θ†_t - η∇L_t - x[S_{t+1}u_{t+1} + ∇u_{t+1}ᵀ∇L_{t+1}] - y∇S_{t+1}
= θ†_t - η∇L_t - x[S_{t+1}u_{t+1} + O(ηρ₃∥∇L_{t+1}∥)] - y∇S_{t+1}
= θ†_t - η∇L_t - x[S_t u_t + O(ηρ₃∥∇L_t∥)] - y[∇S_t + O(η²ρ₃²∥∇L_t∥)]
= θ†_t - η∇L_t - xS_t u_t - y∇S_t + O((|x|ηρ₃ + |y|η²ρ₃²)∥∇L_t∥).

Next, note that we can decompose ∇S_t = u_t(∇S_t · u_t) + ∇S_t^⊥:

θ†_{t+1} = θ†_t - η∇L_t - [xS_t + y(∇S_t · u_t)]u_t - y∇S_t^⊥ + O((|x|ηρ₃ + |y|η²ρ₃²)∥∇L_t∥).

Let s_t = ∇S_t^⊥/∥∇S_t^⊥∥. We can now perform the change of variables (x′, y′) = (xS_t + y(∇S_t · u_t), y∥∇S_t^⊥∥), with inverse (x, y) = ((x′ - y′(∇S_t · u_t)/∥∇S_t^⊥∥)/S_t, y′/∥∇S_t^⊥∥), to get

θ†_{t+1} = θ†_t - η∇L_t - x′u_t - y′s_t + O(η²ρ₃∥∇L_t∥(|x′| + |y′|)).

Note that

O(η²ρ₃∥∇L_t∥(|x′| + |y′|)) ≤ √(x′² + y′²)/2 (12)

for sufficiently small ϵ, so because ∥θ†_{t+1} - θ′∥ ≤ η∥∇L_t∥ we have √(x′² + y′²)/2 ≤ ∥θ†_{t+1} - θ′∥ ≤ η∥∇L_t∥, and hence x′, y′ = O(η∥∇L_t∥). Therefore,

θ†_{t+1} = θ†_t - η∇L_t - x′u_t - y′s_t + O(η³ρ₃∥∇L_t∥²) = θ†_t - η∇L_t - x′u_t - y′s_t + O(ϵ² · η∥∇L_t∥).

Then Taylor expanding ∇L_{t+1} around θ†_t gives

∇L_{t+1} · u_{t+1} = ∇L_t · u_t + (∇L_{t+1} - ∇L_t) · u_t + ∇L_{t+1} · (u_{t+1} - u_t)
= u_tᵀ∇²L_t(-η∇L_t - x′u_t - y′s_t) + O(ϵ² · ∥∇L_t∥)
= -x′S_t + O(ϵ² · ∥∇L_t∥).

Since θ†_{t+1} satisfies ∇L_{t+1} · u_{t+1} = 0, this gives x′ = O(ϵ² · η∥∇L_t∥). We can also Taylor expand S_{t+1} around θ†_t and use S_t = 2/η to get

S_{t+1} = 2/η + ∇S_t · (-η∇L_t - x′u_t - y′s_t) + O(η³ρ₃∥∇L_t∥²) + O(ϵ² · ρ₃η∥∇L_t∥)
= 2/η + ηα_t - y′∥∇S_t^⊥∥ + O(ϵ² · ρ₃η∥∇L_t∥).

Now note that for ϵ sufficiently small we have O(ϵ² · ρ₃η∥∇L_t∥) ≤ O(ϵ²) · ηα_t ≤ ηα_t. Therefore, if y′ = 0 we would have S_{t+1} > 2/η, which contradicts θ†_{t+1} ∈ M. Therefore y′ > 0, and hence y > 0, which by complementary slackness implies S_{t+1} = 2/η.
This then implies that

-η∇L_t · ∇S_t^⊥ - y′∥∇S_t^⊥∥ + O(ϵ² · ρ₃η∥∇L_t∥) = 0 ⟹ y′ = -η∇L_t · ∇S_t^⊥/∥∇S_t^⊥∥ + O(ϵ² · η∥∇L_t∥).

Putting it all together gives

θ†_{t+1} = θ†_t - ηP⊥_{∇S_t^⊥}∇L_t + O(ϵ² · η∥∇L_t∥) = θ†_t - ηP⊥_{u_t,∇S_t}∇L_t + O(ϵ² · η∥∇L_t∥),

where the last line follows from u_t · ∇L_t = 0.

Lemma 14 (Descent Lemma for θ†). For all t ≤ T, L(θ†_{t+1}) ≤ L(θ†_t) - Ω(η∥P⊥_{u_t,∇S_t}∇L_t∥²).

Proof of Theorem 1. First, by Lemma 17, we have ∥⋆v_t∥ ≤ O(δ). Next, by Lemma 16, we have θ_t - θ†_t = v_t = ⋆v_t + O(ϵδ). Next, we Taylor expand to calculate S(θ_t):

S(θ_t) = S(θ†_t) + ∇S_t · v_t + O(ηρ₃²∥v_t∥²)
= 2/η + ∇S_t^⊥ · v_t + (∇S_t · u_t)(u_t · v_t) + O(ηρ₃²δ²)
= 2/η + ∇S_t^⊥ · ⋆v_t + (∇S_t · u_t)(u_t · ⋆v_t) + O(ρ₃ϵδ + ηρ₃²δ²)
= 2/η + ⋆y_t + (∇S_t · u_t)⋆x_t + O(η⁻¹ϵ²).

Finally, we Taylor expand the loss:

L(θ_t) = L(θ†_t) + ∇L_t · v_t + ½v_tᵀ∇²L_t v_t + O(ρ₃∥v_t∥³)
= L(θ†_t) + x_t²/η + ½(v_t^⊥)ᵀ∇²L_t v_t^⊥ + O(ρ₁∥v_t∥ + ρ₃∥v_t∥³)
= L(θ†_t) + ⋆x_t²/η + ½(⋆v_t^⊥)ᵀ∇²L_t ⋆v_t^⊥ + O(η⁻¹δ²ϵ)
= L(θ†_t) + ⋆x_t²/η + O(η⁻¹δ²ϵ),

where the last line follows from Assumption 5.
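The eigenvalue perturbation bound used in the proof of Lemma 12 (Weyl's inequality: each eigenvalue of a symmetric matrix moves by at most the spectral norm of the perturbation) can be checked numerically; the random matrices below are an illustrative sketch.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((5, 5))
A = (A + A.T) / 2                   # symmetric "Hessian"
E = 0.1 * rng.standard_normal((5, 5))
E = (E + E.T) / 2                   # symmetric perturbation

lam_A = np.linalg.eigvalsh(A)       # eigenvalues in ascending order
lam_AE = np.linalg.eigvalsh(A + E)
spec_norm_E = np.linalg.norm(E, 2)  # spectral norm ||E||_2

# Weyl: |lambda_i(A+E) - lambda_i(A)| <= ||E||_2 for every index i
print(np.max(np.abs(lam_AE - lam_A)), spec_norm_E)
```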

I.3 PROOF OF AUXILIARY LEMMAS

Proof of Lemma 15. Taylor expanding the update for θ_{t+1} about θ†_t, we get

θ_{t+1} = θ_t - η∇L(θ_t) = θ_t - η∇L_t - η∇²L_t v_t - ½η∇³L_t(v_t, v_t) + O(ηρ₄∥v_t∥³).

Additionally, recall that the update for θ†_{t+1} is θ†_{t+1} = θ†_t - ηP⊥_{∇S_t^⊥}∇L_t + O(ϵ² · η∥∇L_t∥). Subtracting the previous two equations and expanding out ∇³L_t(v_t, v_t) via the non-worst-case bounds, we obtain

v_{t+1} = (I - η∇²L_t)v_t - η(∇L_t - P⊥_{∇S_t^⊥}∇L_t) - ½ηx_t²∇S_t - ηx_t∇³L_t(u_t, v_t^⊥) - ½η∇³L_t(v_t^⊥, v_t^⊥) + O(ηρ₄∥v_t∥³ + ϵ² · η∥∇L_t∥)
= (I - η∇²L_t)v_t - η[(∇L_t · ∇S_t^⊥)/∥∇S_t^⊥∥²]∇S_t^⊥ - ½ηx_t²∇S_t - ηx_t∇³L_t(u_t, v_t^⊥) + O(ηρ₃ϵ∥v_t∥² + ηρ₄∥v_t∥³ + ϵ² · η∥∇L_t∥)
= (I - η∇²L_t)v_t + η∇S_t^⊥(ϵ_t² - x_t²)/2 - ½ηx_t²(∇S_t · u_t)u_t - ηx_t∇³L_t(u_t, v_t^⊥) + O(ϵ² · ∥v_t∥²/δ + ϵ² · ∥v_t∥³/δ² + ϵ³δ)
= (I - η∇²L_t)v_t + η∇S_t^⊥(ϵ_t² - x_t²)/2 - ½ηx_t²(∇S_t · u_t)u_t - ηx_t∇³L_t(u_t, v_t^⊥) + O(ϵ²δ · max(1, ∥v_t∥/δ)³).

We would first like to compute the magnitude of v_{t+1}:

∥v_{t+1}∥ = O(∥v_t∥ + ηρ₃∥v_t∥² + η∥∇L_t∥ + ϵ²δ · max(1, ∥v_t∥/δ)³).

Observe that by definition of ϵ and δ, and since ∥v_t∥ ≤ ϵ⁻¹δ,

O(ηρ₃∥v_t∥²) ≤ O(∥v_t∥ · ϵ⁻¹ηρ₃δ) ≤ O(∥v_t∥ · ϵ⁻¹η√(ρ₁ρ₃)) ≤ O(∥v_t∥),
O(ϵ²δ · max(1, ∥v_t∥/δ)³) ≤ O(ϵ²δ + ∥v_t∥ · ϵ² · (ϵ⁻¹)²) ≤ O(ϵ²δ + ∥v_t∥).

Hence ∥v_{t+1}∥ = O(∥v_t∥ + η∥∇L_t∥ + ϵ²δ) = O(∥v_t∥ + ϵδ). Note that we can bound

∥u_{t+1} - u_t∥ · ∥v_{t+1}∥ = O(η²ρ₃∥∇L_t∥ · (∥v_t∥ + ϵδ)) = O(ϵ² · (∥v_t∥ + ϵδ)) ≤ O(ϵ² · max(∥v_t∥, δ)).
Therefore, the one-step update in the u_t direction is

x_{t+1} = v_{t+1} · u_{t+1} = v_{t+1} · u_t + O(ϵ² · max(∥v_t∥, δ))
= -v_t · u_t - ½ηx_t²(∇S_t · u_t) - ηx_t(∇S_t · v_t^⊥) + O(ϵ² · max(∥v_t∥, δ) + ϵ²δ · max(1, ∥v_t∥/δ)³)
= -x_t(1 + ηy_t) - ½ηx_t²(∇S_t · u_t) + O(ϵ²δ · max(1, ∥v_t∥/δ)³)
= -x_t(1 + ηy_t) - ½ηx_t²(∇S_t · u_t) + O(E_t),

where we have defined the error term E_t := ϵ²δ · max(1, ∥v_t∥/δ)³. The update in the v^⊥ direction is

v^⊥_{t+1} = P⊥_{u_{t+1}}[(I - η∇²L_t)v_t + η∇S_t^⊥(ϵ_t² - x_t²)/2] - ½ηx_t²(∇S_t · u_t)P⊥_{u_{t+1}}u_t - ηx_t P⊥_{u_{t+1}}∇³L_t(u_t, v_t^⊥) + O(E_t).

Therefore

∥v^⊥_{t+2} - ṽ^⊥_{t+2}∥ ≤ O(η∥∇S_{t+1}∥ |x̃_{t+1}² - x_{t+1}²| + E_t) ≤ O(ηρ₃∥v_t∥ |x̃_{t+1} - x_{t+1}| + E_t) ≤ O(η²ρ₃²∥v_t∥³ + E_t) ≤ O(ϵ² · ∥v_t∥³/δ² + E_t) = O(E_t).

Altogether, we get that ∥r_t∥ ≤ O(E_t) = O(ϵ²δ · max(1, ∥v_t∥/δ)³), as desired.

Proof of Lemma 16. Define w_t = 0 if t is even and w_t = r_{t-1} if t is odd, and define the auxiliary trajectory ṽ by ṽ_0 = v_0 and ṽ_{t+1} = step_t(ṽ_t) + w_t. We first claim that ṽ_t = v_t for all even t ≤ T, which we prove by induction on t. The base case holds by definition, so assume the result for some even t ≥ 0. Then

v_{t+2} = step_{t+1}(step_t(v_t)) + r_t = step_{t+1}(step_t(ṽ_t)) + r_t = step_{t+1}(ṽ_{t+1}) + w_{t+1} = ṽ_{t+2},

which completes the induction. Next, we prove by induction that for t ≤ T,

∥ṽ^⊥_t - ⋆v^⊥_t∥, |x̃_t - ⋆x_t| ≤ O(ϵδ) ≤ c₂δ.

By definition, ṽ_0 = v_0 = ⋆v_0, so the claim is clearly true for t = 0. Next, assume the claim holds for t. If t is even then ∥w_t∥ = 0; otherwise ∥ṽ_t∥ ≤ 2c₂δ, and thus ∥w_t∥ ≤ O(ϵ²δ · max(1, c₂)³) ≤ O(ϵ²δ).
First observe that

∥ṽ^⊥_{t+1} - ⋆v^⊥_{t+1}∥ ≤ ∥(I - η∇²L_t)(ṽ^⊥_t - ⋆v^⊥_t)∥ + ηρ₃ |x̃_t² - ⋆x_t²|/2 + ∥w_t∥
≤ (1 + η|λ_min(∇²L_t)|) ∥ṽ^⊥_t - ⋆v^⊥_t∥ + O(ϵ) · |x̃_t - ⋆x_t| + O(ϵ²δ)
≤ (1 + η|λ_min(∇²L_t)|) ∥ṽ^⊥_t - ⋆v^⊥_t∥ + O(ϵδ) · |x̃_t - ⋆x_t|/|⋆x_t| + O(ϵ²δ).

Next, note that

x̃_{t+1}/⋆x_{t+1} = [(1 + ηỹ_t)x̃_t + O(ϵ²δ)] / [(1 + η⋆y_t)⋆x_t + O(ϵ²δ)]
= [(1 + η⋆y_t)x̃_t + O(ϵ²δ) + O(ϵ) · ∥ṽ^⊥_t - ⋆v^⊥_t∥] / [(1 + η⋆y_t)⋆x_t + O(ϵ²δ)]
= x̃_t/⋆x_t + O(ϵ² + (ϵ/δ) ∥ṽ^⊥_t - ⋆v^⊥_t∥).

Therefore

|x̃_{t+1} - ⋆x_{t+1}|/|⋆x_{t+1}| ≤ |x̃_t - ⋆x_t|/|⋆x_t| + O(ϵ² + (ϵ/δ) ∥ṽ^⊥_t - ⋆v^⊥_t∥).

Let d_t := max(∥ṽ^⊥_t - ⋆v^⊥_t∥, δ |x̃_t - ⋆x_t|/|⋆x_t|). Then

∥ṽ^⊥_{t+1} - ⋆v^⊥_{t+1}∥ ≤ (1 + η|λ_min(∇²L_t)| + O(ϵ)) d_t + O(ϵ²δ) and δ |x̃_{t+1} - ⋆x_{t+1}|/|⋆x_{t+1}| ≤ (1 + O(ϵ)) d_t + O(ϵ²δ).



Footnotes.
1. The condition that ∇L(θ) · u(θ) = 0 is necessary to ensure the stability of the constrained trajectory.
2. Beginning in Section 5, the reference points for our Taylor expansions change at every step to minimize errors. However, fixing the reference point in this section simplifies the analysis, better illustrates the negative feedback loop, and motivates the definition of the constrained trajectory.
3. We give an explicit example of a loss function satisfying these assumptions in Appendix B.
4. A rigorous derivation of this update in terms of S(θ_t) instead of S(θ⋆) requires a third-order Taylor expansion around θ⋆; see Appendix I for more details.
5. For simplicity of exposition, we state these bounds on ∇³L globally; however, our proof only requires them in a small neighborhood of the constrained trajectory θ†.
6. This sub-quadratic phenomenon was also observed in (Ma et al., 2022).

Let ⋆v_t denote our predicted dynamics (defined in Appendix C). Letting ⋆x_t = u_t · ⋆v_t and ⋆y_t = ∇S_t^⊥ · ⋆v_t, we have

⋆x_{t+1} = -(1 + η⋆y_t)⋆x_t and ⋆y_{t+1} = η Σ_{s=0}^{t} β_{s→t} (δ_s² - ⋆x_s²)/2,

and we assume |λ_min(∇²L_t)| / ∥∇²L_t∥₂ ≤ O(ϵ). With these assumptions in place, we can state our main theorem, which guarantees that ⋆x, ⋆y, ⋆v predict the loss, sharpness, and deviation from the constrained trajectory up to higher-order terms:

Theorem 1. Let T := O(ϵ⁻¹) and assume that min_{t≤T} |⋆x_t| ≥ c₁δ. Then for any t ≤ T, we have

L(θ_t) = L(θ†_t) + ⋆x_t²/η + O(ϵδ²/η)  (Loss)
S(θ_t) = 2/η + ⋆y_t + (∇S_t · u_t)⋆x_t + O(ϵ²/η)  (Sharpness)
θ_t = θ†_t + ⋆v_t + O(ϵδ)  (Deviation from θ†)

The sharpness is controlled by the slowly evolving quantity ⋆y_t and the period-2 oscillations of (∇S_t · u_t)⋆x_t. This combination of gradual and rapid periodic behavior was observed by Cohen et al. (2021) and appears in our experiments. Theorem 1 also shows that the loss at θ_t spikes whenever ⋆x_t is large; on the other hand, when ⋆x_t is small, L(θ_t) approaches the loss of the constrained trajectory.

Lemma 16. Assume that there exist constants c₁, c₂ such that for all t ≤ T, ∥⋆v_t∥ ≤ c₂δ and |⋆x_t| ≥ c₁δ. Then, for all t ≤ T, we have ∥v_t - ⋆v_t∥ ≤ O(ϵδ).

Lemma 17. For t ≤ T, ∥⋆v_t∥ ≤ O(δ).

With these lemmas in hand, we can prove Theorem 1.




Figure 1: The four stages of edge of stability, demonstrated on a toy loss function (Appendix B).

Figure 3: We empirically demonstrate that the predicted dynamics given by eq. (5) track the true EOS dynamics. For each learning rate, the top row is a zoomed-in copy of the bottom row, isolating the single cycle marked by the dashed rectangle. Reported sharpnesses are two-step averages for visual clarity. For additional experimental details, see Section 6 and Appendix D.

Figure 4: Large learning rates converge faster and generalize better (ResNet18 and CIFAR10).

Figure 5: Edge of stability with multiple unstable eigenvalues. Each vertical line is the time at which the corresponding eigenvalue of the same color becomes unstable.

Therefore

d_{t+1} ≤ (1 + η|λ_min(∇²L_t)| + O(ϵ)) d_t + O(ϵ²δ) ≤ (1 + O(ϵ)) d_t + O(ϵ²δ),

so for t ≤ T we have d_{t+1} ≤ O(ϵδ). Therefore ∥ṽ^⊥_{t+1} - ⋆v^⊥_{t+1}∥, |x̃_{t+1} - ⋆x_{t+1}| ≤ O(ϵδ) ≤ c₂δ, so the induction is proven. Altogether, we get ∥ṽ_t - ⋆v_t∥ ≤ O(ϵδ) for all such t, as desired.

Proof of Lemma 17. Recall that x⋆_{t+1} = -(1 + ηy⋆_t)x⋆_t. Since β_{s→t} = O(ρ₃²), we have ⋆y_t ≤ O(ρ₃²) t η δ² = O(√(ρ₁ρ₃)). Therefore |⋆x_{t+1}| = (1 + η⋆y_t)|⋆x_t| ≤ (1 + O(ϵ))|⋆x_t|. Since t ≤ O(ϵ⁻¹), |⋆x_t| grows by at most a constant factor, and thus |⋆x_t| ≤ O(δ).


Proof of Lemma 15 (2-Step Lemma). We first expand one step of the true dynamics. The update for $v^\perp$ is
$$v^\perp_{t+1} = P^\perp_{u_{t+1}}(I - \eta \nabla^2 L_t) P^\perp_{u_t} v_t + \eta \nabla S^\perp_t + (\nabla S_t \cdot u_t)\, P^\perp_{u_{t+1}} u_t + \eta x_t P^\perp_{u_{t+1}} \nabla^3 L_t(u_t, v^\perp_t),$$
where the last two terms are error terms. First,
$$\|P^\perp_{u_{t+1}} u_t\| = \|u_t - u_{t+1}(u_{t+1}^\top u_t)\|_2 \le O(\|u_t - u_{t+1}\|).$$
Therefore we can control the first of the error terms as
$$\|(\nabla S_t \cdot u_t)\, P^\perp_{u_{t+1}} u_t\| \le O\big(\|u_t - u_{t+1}\| \cdot (\|v_t\| + \eta \rho_3 \|v_t\|^2)\big).$$
As for the second error term, we can decompose
$$\|\eta x_t P^\perp_{u_{t+1}} \nabla^3 L_t(u_t, v^\perp_t)\| \le \eta \|v_t\| \Big( \|P^\perp_{u_t} \nabla^3 L_t(u_t, v^\perp_t)\| + \|P^\perp_{u_t} - P^\perp_{u_{t+1}}\| \, \|\nabla^3 L_t(u_t, v^\perp_t)\| \Big).$$
By Assumption 5, we have $\|P^\perp_{u_t} \nabla^3 L_t(u_t, v^\perp_t)\| \le O(\epsilon \rho_3 \|v_t\|)$. Additionally, $\|P^\perp_{u_t} - P^\perp_{u_{t+1}}\| \le O(\|u_{t+1} - u_t\|)$. Therefore
$$\|\eta x_t P^\perp_{u_{t+1}} \nabla^3 L_t(u_t, v^\perp_t)\| \le O\big(\epsilon \rho_3 \|v_t\| \cdot \eta \|v_t\| + \eta \|v_t\| \, \|u_{t+1} - u_t\| \cdot \rho_3 \|v_t\|\big) \le O\big(\epsilon \eta \rho_3 \|v_t\|^2 + \eta \rho_3 \|v_t\|^2 \epsilon^2\big) \le O(\epsilon^2 \|v_t\|),$$
where we used $\eta \rho_3 \|v_t\| = O(1)$. Altogether, we have
$$v^\perp_{t+1} = P^\perp_{u_{t+1}}(I - \eta \nabla^2 L_t) P^\perp_{u_t} v_t + \eta \nabla S^\perp_t + O(E_t), \qquad \text{where} \quad E_{t+1} = \epsilon^2 \delta \cdot \max\Big(1, \frac{\|v_{t+1}\|}{\delta}\Big).$$

We next expand two steps of the $x$ dynamics:
$$x_{t+2} = x_t (1 + \eta y_t)(1 + \eta y_{t+1}) + \tfrac{1}{2}\eta^2 y_t x_t^2 \, \nabla S_t \cdot u_t + \eta \big(x_{t+1}^2 \nabla S_{t+1} \cdot u_{t+1} - x_t^2 \nabla S_t \cdot u_t\big) + O(E_t).$$
The first of these two error terms can be bounded as
$$\big|\tfrac{1}{2}\eta^2 y_t x_t^2 \, \nabla S_t \cdot u_t\big| \le O\big(\eta^2 \rho_3^2 \|v_t\|^3\big) \le O(\epsilon^2 \|v_t\|).$$
As for the second term, we can bound
$$|\nabla S_{t+1} \cdot u_{t+1} - \nabla S_t \cdot u_t| \le |u_{t+1} \cdot (\nabla S_{t+1} - \nabla S_t)| + |\nabla S_t \cdot (u_{t+1} - u_t)| \le \|\nabla S_{t+1} - \nabla S_t\| + O(\rho_3) \cdot \|u_{t+1} - u_t\| \le O\big(\eta^2 \rho_3^2 \|\nabla L_t\|\big) \le O(\epsilon^2 \rho_3).$$
Additionally, we have $x_{t+1} = -x_t + O(\eta \rho_3 \|v_t\|^2 + E_t)$, so
$$\eta \big| x_{t+1}^2 \nabla S_{t+1} \cdot u_{t+1} - x_t^2 \nabla S_t \cdot u_t \big| \le \eta x_t^2 \, |\nabla S_{t+1} \cdot u_{t+1} - \nabla S_t \cdot u_t| + \eta \, |x_{t+1}^2 - x_t^2| \, |\nabla S_{t+1} \cdot u_{t+1}| \le O\big(\eta \rho_3 \|v_t\|^2 \cdot \epsilon^2 + \eta \rho_3 \|v_t\| (\eta \rho_3 \|v_t\|^2 + E_t)\big) \le O(\epsilon^2 \|v_t\|).$$
Altogether, the two-step update for $x_t$ is
$$x_{t+2} = x_t (1 + \eta y_t)(1 + \eta y_{t+1}) + O(E_t),$$
and the two-step update for $v^\perp$ is
$$v^\perp_{t+2} = P^\perp_{u_{t+2}}(I - \eta \nabla^2 L_{t+1}) P^\perp_{u_{t+1}}(I - \eta \nabla^2 L_t) P^\perp_{u_t} v_t + \eta P^\perp_{u_{t+2}}(I - \eta \nabla^2 L_{t+1}) P^\perp_{u_{t+1}} \nabla S^\perp_t + \eta \nabla S^\perp_{t+1} + O(E_t).$$

We now compare with the predicted dynamics. Define $\tilde v_{t+1} = \mathrm{step}_t(v_t)$, $\tilde v_{t+2} = \mathrm{step}_{t+1}(\tilde v_{t+1})$, and $\tilde x_i = \tilde v_i \cdot u_i$, $\tilde y_i = \nabla S^\perp_i \cdot \tilde v_i$ for $i \in \{t+1, t+2\}$. By the definition of $\mathrm{step}$, one sees that
$$\|v^\perp_{t+1} - \tilde v^\perp_{t+1}\| \le O(E_t) \quad \text{and} \quad |x_{t+1} - \tilde x_{t+1}| \le O\big(\eta x_t^2 |\nabla S_t \cdot u_t|\big) + O(E_t) \le O(\eta \rho_3 \|v_t\|^2 + E_t).$$
The update for $\tilde x$ after applying $\mathrm{step}$ is
$$\tilde x_{t+2} = -\tilde x_{t+1}(1 + \eta \tilde y_{t+1}) = x_t (1 + \eta y_t)(1 + \eta \tilde y_{t+1}),$$
so
$$|x_{t+2} - \tilde x_{t+2}| \le O\big(|x_t| \, \eta \, |y_{t+1} - \tilde y_{t+1}|\big) + O(E_t) \le O\big(\eta \rho_3 \|v_t\| \, \|v^\perp_{t+1} - \tilde v^\perp_{t+1}\|\big) + O(E_t) \le O(E_t).$$
Additionally, the update for $\tilde v^\perp$ is
$$\tilde v^\perp_{t+2} = P^\perp_{u_{t+2}}(I - \eta \nabla^2 L_{t+1}) P^\perp_{u_{t+1}}(I - \eta \nabla^2 L_t) P^\perp_{u_t} v_t + \eta P^\perp_{u_{t+2}}(I - \eta \nabla^2 L_{t+1}) P^\perp_{u_{t+1}} \nabla S^\perp_t + \eta \nabla S^\perp_{t+1},$$
so $\|v^\perp_{t+2} - \tilde v^\perp_{t+2}\| \le O(E_t)$ and hence $\|r_t\| = \|v_{t+2} - \tilde v_{t+2}\| \le O(E_t)$.
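The negative-feedback mechanism analyzed in this proof can be observed numerically. Below is a minimal sketch on a toy two-parameter loss of our own construction (the loss $L(x,y) = \tfrac{1}{2}(s_0+y)x^2 - ay$ and the constants `eta`, `s0`, `a` are all our assumptions, not the paper's setting): the sharpness in the $x$ direction is $S = s_0 + y$, and $\nabla L \cdot \nabla S = \tfrac{1}{2}x^2 - a \approx -a < 0$ for small $x$, so gradient descent progressively sharpens until $S$ crosses $2/\eta$, after which the unstable $x$ direction grows and the $x^2$ feedback pushes the sharpness back down.

```python
# Toy model of self-stabilization (our construction, not the paper's setup).
# Loss: L(x, y) = 0.5 * (s0 + y) * x**2 - a * y, so the sharpness in the x
# direction is S = s0 + y, and grad L . grad S = 0.5 * x**2 - a < 0 for small
# x: gradient descent progressively sharpens at rate roughly eta * a.
eta, s0, a = 0.01, 197.0, 1.0   # stability threshold: 2 / eta = 200

x, y = 0.01, 0.0
sharp = []
for _ in range(6000):
    gx, gy = (s0 + y) * x, 0.5 * x**2 - a   # gradient of the toy loss
    x, y = x - eta * gx, y - eta * gy       # gradient descent step
    sharp.append(s0 + y)                    # current sharpness S = s0 + y

# S rises past 2/eta, the unstable direction grows, and the cubic feedback
# (the x**2 term in gy) pushes S back down: it hovers near 200 thereafter.
late = sharp[-2000:]
print(max(sharp) > 200, abs(sum(late) / len(late) - 200) < 12)
```

The oscillation around $2/\eta$ is a relaxation cycle: slow progressive sharpening while $x$ is small, then a brief burst of instability that reduces the sharpness, exactly the four-stage picture described in Section 4.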


ACKNOWLEDGEMENTS

AD acknowledges support from an NSF Graduate Research Fellowship. EN acknowledges support from a National Defense Science & Engineering Graduate Fellowship, and NSF grants CIF-1907661 and DMS-2014279. JDL, AD, and EN acknowledge support from the Sloan Research Fellowship, NSF CCF 2002272, NSF IIS 2107304, NSF CIF 2212262, and NSF CAREER award #2144994.

The authors would like to thank Jeremy Cohen, Kaifeng Lyu, and Lei Chen for helpful discussions throughout the course of this project. We would especially like to thank Jeremy Cohen for suggesting the term "self-stabilization" to describe the negative feedback loop derived in this paper.

Figure: CNN with logistic loss on CIFAR-10 (cats vs. dogs), comparing gradient descent against the predicted dynamics, the constrained trajectory, and gradient flow.
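Experiments like this one track the sharpness $S(\theta) = \lambda_{\max}(\nabla^2 L(\theta))$ along training. A minimal sketch of how $S(\theta)$ can be estimated without materializing the Hessian: power iteration on Hessian-vector products, here obtained by finite differences of a user-supplied gradient. The helper name `sharpness`, the step size `h`, and the quadratic sanity check are our assumptions, not the paper's code; note also that power iteration converges to the eigenvalue of largest magnitude.

```python
import numpy as np

def sharpness(grad_fn, theta, iters=100, h=1e-4, seed=0):
    """Estimate S(theta) = lambda_max(Hessian) by power iteration on
    Hessian-vector products, via finite differences of the gradient.
    Assumes the top eigenvalue is also the largest in magnitude."""
    rng = np.random.default_rng(seed)
    v = rng.standard_normal(theta.shape)
    v /= np.linalg.norm(v)
    lam = 0.0
    for _ in range(iters):
        # Hessian-vector product: H v ~ (g(theta + h v) - g(theta - h v)) / 2h
        hv = (grad_fn(theta + h * v) - grad_fn(theta - h * v)) / (2 * h)
        lam = float(v @ hv)          # Rayleigh quotient along the current v
        v = hv / np.linalg.norm(hv)
    return lam

# sanity check on a quadratic with known Hessian eigenvalues {3, 1}
A = np.array([[3.0, 0.0], [0.0, 1.0]])
g = lambda th: A @ th
print(sharpness(g, np.zeros(2)))  # expect ~3.0 (top Hessian eigenvalue)
```

In the neural-network experiments one would replace `grad_fn` with a full-batch gradient oracle; the finite-difference trick avoids second-order automatic differentiation entirely.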

Then
In particular, $\rho_4 = O(\eta \rho_3^2)$.

Proof. Note that $S(\theta) < 4/\eta$ implies that
Therefore,
and
by compactness of $S^{d-1}$. In addition, note that

Lemma 10. At any second-order stationary point $\theta$ of $f$,

Proof.

Lemma 11. At any second-order stationary point $\theta$ of $f$,

Proof.
where the last step follows from compactness of $S^{d-1}$ and the fact that $\nabla^2 f$ is non-vanishing.

Proof of Lemma 6. The lemma is simply a restatement of Corollary 3, Lemma 10, Lemma 9, and Lemma 11.

Proof. Taylor expanding $L(\theta^{\dagger}_{t+1})$ around $L(\theta^{\dagger}_{t})$ and using Lemma 13 gives
Next, note that because $\gamma_t = \Theta(1)$ we have $\|\nabla L_t\| = O\big(\|P^{\perp}_{u_t, \nabla S_t} \nabla L_t\|\big)$. Therefore, for $\epsilon$ sufficiently small,
Therefore,
which completes the proof.

Corollary 4. Let $L^{\star} = \min_{\theta} L(\theta)$. Then there exists $t \le T$ such that

Proof. Inductively applying Lemma 14, we have that there exists an absolute constant $c$ such that
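The constrained trajectory $\theta^\dagger$ analyzed above is projected gradient descent (PGD) under $S(\theta) \le 2/\eta$. The following toy sketch (our own construction; the loss $L(x,y) = \tfrac{1}{2}(s_0+y)x^2 + \tfrac{a}{2}(y - y_\star)^2$ and all constants are assumptions, not the paper's setting) compares plain gradient descent with an explicit projection onto $\{S \le 2/\eta\}$: both trajectories end up with sharpness near $2/\eta$, illustrating the claimed implicit PGD behavior.

```python
# Compare gradient descent with explicit PGD under S(theta) <= 2/eta on a
# toy loss (our construction): L = 0.5*(s0+y)*x**2 + 0.5*a*(y - y_star)**2.
# Unconstrained, the sharpness S = s0 + y is attracted to s0 + y_star = 205,
# which exceeds the stability threshold 2/eta = 200.
eta, s0, a, y_star = 0.01, 195.0, 0.5, 10.0
y_max = 2 / eta - s0                 # the constraint S <= 2/eta, i.e. y <= 5

def run(project, steps=4000):
    x, y, hist = 0.01, 0.0, []
    for _ in range(steps):
        gx, gy = (s0 + y) * x, 0.5 * x**2 + a * (y - y_star)
        x, y = x - eta * gx, y - eta * gy
        if project:
            y = min(y, y_max)        # projection onto {S(theta) <= 2/eta}
        hist.append(s0 + y)
    return sum(hist[-1000:]) / 1000  # average sharpness late in training

# PGD pins the sharpness at exactly 2/eta; plain GD hovers around the same
# value through self-stabilization, with no explicit constraint.
print(run(project=True), round(run(project=False), 1))
```

The deviation between the two runs plays the role of the residual that the lemmas above control along the true trajectory.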

I.2 PROOF OF THEOREM 1

We first require the following three lemmas, whose proofs are deferred to Appendix I.3.

Lemma 15 (2-Step Lemma). Let $r_t := v_{t+2} - \mathrm{step}_{t+1}(\mathrm{step}_t(v_t))$. Assume that $\|v_t\| \le \epsilon^{-1}\delta$. Then

