A NEW CHARACTERIZATION OF THE EDGE OF STABILITY BASED ON A SHARPNESS MEASURE AWARE OF BATCH GRADIENT DISTRIBUTION

Abstract

For full-batch gradient descent (GD), it has been empirically shown that the sharpness, the top eigenvalue of the Hessian, increases and then hovers above 2/(learning rate); this is called the "edge of stability" phenomenon. However, it is unclear why the sharpness is somewhat larger than 2/(learning rate) and how this observation extends to general mini-batch stochastic gradient descent (SGD). We propose a new sharpness measure, the interaction-aware sharpness, which accounts for the interaction between the batch gradient distribution and the loss landscape geometry. This leads to a more refined and general characterization of the edge of stability for SGD. Moreover, based on the analysis of a concentration measure of the batch gradient, we propose a more accurate scaling rule between batch size and learning rate, the Linear and Saturation Scaling Rule (LSSR).

1. INTRODUCTION

For full-batch GD, it has been empirically observed that the sharpness, the top eigenvalue of the Hessian, increases and then hovers above 2/(learning rate) (Cohen et al., 2021) as training proceeds. This observation provides a link between two empirical results regarding generalization: (i) using larger learning rates for GD can generalize better (Bjorck et al., 2018; Li et al., 2019b; Lewkowycz et al., 2020; Smith et al., 2020), and (ii) minima with low sharpness tend to generalize better (Hochreiter & Schmidhuber, 1997; Keskar et al., 2017). This observation has a significant implication for existing convergence analyses of neural network optimization, since it is contrary to the frequent assumption that the learning rate is less than 2/β (where β is an upper bound on the top eigenvalue of the Hessian), which ensures a decrease in the training loss (Nesterov, 2003; Schmidt, 2014; Martens, 2014; Bottou et al., 2018). Even though the training loss evolves non-monotonically over short timescales due to the violation of this assumption, interestingly, the loss is observed to decrease consistently over long timescales. This regime, in which GD typically operates, has been referred to as 'the edge of stability (EoS)' (Cohen et al., 2021). Many aspects of the EoS regime remain unexplained. For example, it is not clear why, and to what extent, the sharpness hovers above 2/(learning rate). Moreover, the inherent mechanism by which unstable optimization occurs consistently at the EoS without entirely diverging has not yet been elucidated. How this phenomenon generalizes beyond GD, especially to mini-batch SGD, is still an open question. In this paper, we provide a new characterization of the EoS for SGD, which can serve as an answer to the above questions.
As a tool to analyze the optimization process of SGD, we first propose a sharpness measure of the neural network loss landscape that is aware of the SGD batch gradient distribution (hence capturing the interaction between SGD and the loss landscape), which we refer to as the interaction-aware sharpness (IAS) (Section 2). Based on this measure, we define the stable and unstable regions in the neural network parameter space. We then scrutinize, both theoretically and empirically, the transition of the iterate from the stable to the unstable region (Section 4.1) and the mechanism by which it escapes from the unstable region, i.e., how optimization can occur at the EoS. We interpret the latter mechanism based on the non-quadraticity of the loss and the presence of asymmetric valleys in the loss landscape (He et al., 2019) (Section 4.2). Based on these analyses, we propose the notion of implicit interaction regularization (IIR), i.e., that the IAS is implicitly bounded during SGD, as an implicit bias of SGD (Section 4.3). The value that bounds the IAS is the ratio of a concentration measure of the batch gradient distribution of SGD to the learning rate. This is a more refined characterization of the EoS, as it shows that the IAS does not merely hover above a certain value, but rather hovers around it. More importantly, it applies naturally to SGD, since we make no impractical assumptions on the batch size or learning rate. Our new characterization of the EoS leads to a novel scaling rule between batch size and learning rate, derived from the idea of preserving a similar level of IIR (Section 5). This scaling rule, referred to as the Linear and Saturation Scaling Rule (LSSR), recovers the well-known linear scaling rule (LSR) (Jastrzębski et al., 2017; Masters & Luschi, 2018; Zhang et al., 2019; Shallue et al., 2018; Smith et al., 2020; 2021) for small batch sizes and reduces to no scaling (due to saturation) for large batch sizes.

2. GRADIENT DISTRIBUTION AND LOSS LANDSCAPE

In this section, we review some concepts required for further discussion. See Appendix A for a quick reference for the notations. To simplify the notations, we often omit the dependence on some variables and the subscript of the expectation operation when clear from the context. For a learning task, we use a parameterized model with model parameter θ ∈ Θ ⊂ R^m. Then we train the model using training data D = {x_i}_{i=1}^n and a loss function ℓ(x; θ). We denote the (total) training loss by L(θ) ≡ (1/n) Σ_{i=1}^n ℓ(x_i; θ) for training data D. At time step t, we update the parameter θ_t using GD: θ_{t+1} = θ_t − η∇_θL(θ_t) with a learning rate η > 0, or using SGD: θ_{t+1} = θ_t − ηg_ξ(θ_t) with a mini-batch gradient g_ξ(θ_t) ≡ (1/b) Σ_{x∈B^t_ξ} ∇_θℓ(x; θ_t) ∈ R^m for a mini-batch B^t_ξ ⊂ D of size b (1 ≤ b ≤ n). Here, we use the subscript ξ to denote the random batch sampling procedure. Now, we are ready to introduce some important matrices, C_b, S_b, and H. First, we define the covariance C_b(θ) ≡ Var_ξ[g_ξ(θ)] = E_ξ[(g_ξ(θ) − E_ξ[g_ξ(θ)])(g_ξ(θ) − E_ξ[g_ξ(θ)])^⊤] ∈ R^{m×m} and the second moment S_b(θ) ≡ E_ξ[g_ξ(θ)g_ξ(θ)^⊤] ∈ R^{m×m} of the batch gradient. They satisfy

C_b = (γ_{n,b}/b)(S_1 − S_n),  where γ_{n,b} = (n−b)/(n−1) for sampling without replacement and γ_{n,b} = 1 for sampling with replacement.   (1)

We provide a self-contained proof of (1) in Appendix B.1. We note that, for sampling without replacement, many previous works approximate γ_{n,b} ≈ 1 assuming b ≪ n (Jastrzębski et al., 2017; Hoffer et al., 2017; Smith et al., 2021), but we consider the whole range of 1 ≤ b ≤ n (0 ≤ γ_{n,b} ≤ 1 with γ_{n,1} = 1 and γ_{n,n} = 0). Second, we define the Hessian H(θ) ≡ ∇²_θL(θ) = (1/n) Σ_{i=1}^n ∇²_θℓ(x_i; θ) ∈ R^{m×m} and denote the i-th largest eigenvalue and its corresponding normalized eigenvector by λ_i(H) ∈ R and q_i(H) ∈ R^m, respectively, for i = 1, …, m. The operator norm ∥H∥ ≡ sup_{∥u∥=1} ∥Hu∥ of H equals the top eigenvalue λ_1. We emphasize that C_b and S_b represent the stochasticity of the batch gradients, and H represents the loss landscape geometry.
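Equation (1) can be checked numerically. The following sketch (with illustrative toy dimensions, not from the paper) enumerates all size-b batches sampled without replacement for a tiny problem, computes the exact covariance of the batch gradient, and compares it with (γ_{n,b}/b)(S_1 − S_n):

```python
import itertools
import numpy as np

rng = np.random.default_rng(4)
n, m, b = 6, 3, 2                               # illustrative toy sizes
G = rng.normal(size=(m, n))                     # column i plays the role of grad l(x_i)

grad_L = G.mean(axis=1)                         # full gradient
S1 = (G @ G.T) / n                              # (1/n) sum_i grad l_i grad l_i^T
Sn = np.outer(grad_L, grad_L)                   # second moment at b = n

# Exact covariance of the batch gradient over all C(n, b) batches (without replacement).
batches = list(itertools.combinations(range(n), b))
gs = np.array([G[:, list(B)].mean(axis=1) for B in batches])
C_exact = np.cov(gs.T, bias=True)               # population covariance over batches

gamma = (n - b) / (n - 1)                       # sampling coefficient gamma_{n,b}
C_formula = (gamma / b) * (S1 - Sn)             # right-hand side of Eq. (1)
print(np.max(np.abs(C_exact - C_formula)))      # agrees up to floating point
```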
Therefore, we can write one of our goals as follows: we aim to understand how the loss landscape geometry (H) and the gradient distribution (S_b) interact with each other during SGD training. We investigate this "interaction" in terms of the matrix product HS_b. To be specific, we consider the trace tr(HS_b) and its normalized value tr(HS_b)/tr(S_b), and we call the latter the interaction-aware sharpness:

Definition 1 (Interaction-Aware Sharpness (IAS)). ∥H∥_{S_b} ≡ tr(HS_b)/tr(S_b).

Here, tr(HS_b) ≤ ∥H∥ tr(S_b), i.e., ∥H∥_{S_b} ≤ ∥H∥, and the equality holds only when every g_ξ is aligned with the direction of the top eigenvector q_1 of H. Previous analyses of related interaction quantities rely on additional assumptions (Thomas et al., 2020); in this paper, we provide a new insight into the interaction tr(HS_b) without these assumptions.
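As a minimal sketch of Definition 1, the following computes the IAS for a toy least-squares model (the model, data, and sizes are illustrative assumptions, not from the paper), forming S_b from S_1 and S_n via Eq. (1) and checking ∥H∥_{S_b} ≤ ∥H∥:

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, b = 64, 5, 8                              # samples, parameters, batch size
X = rng.normal(size=(n, m))
y = rng.normal(size=n)
theta = rng.normal(size=m)

# Per-sample gradients of the least-squares loss l(x_i; theta) = (x_i^T theta - y_i)^2 / 2.
resid = X @ theta - y
per_sample_grads = X * resid[:, None]           # row i is grad of l(x_i; theta)

grad = per_sample_grads.mean(axis=0)            # full gradient E_xi[g_xi]
S1 = per_sample_grads.T @ per_sample_grads / n  # second moment of single-sample gradients
Sn = np.outer(grad, grad)                       # second moment at b = n (deterministic)
H = X.T @ X / n                                 # Hessian of L

# Second moment of the batch gradient via Eq. (1): S_b = (gamma/b) S_1 + (1 - gamma/b) S_n.
gamma = (n - b) / (n - 1)                       # sampling without replacement
Sb = (gamma / b) * S1 + (1 - gamma / b) * Sn

ias = np.trace(H @ Sb) / np.trace(Sb)           # interaction-aware sharpness ||H||_{S_b}
top_eig = np.linalg.eigvalsh(H)[-1]             # ||H|| = lambda_1
print(ias, top_eig)                             # IAS never exceeds the top eigenvalue
```

The same recipe applies to a neural network by replacing the per-sample gradients and Hessian-vector products accordingly.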

3. RELATED WORK

Convergence of full-batch GD (b = n) has instead been analyzed with an upper bound on the interaction tr(HS_n), together with further assumptions for stable optimization such as β-smoothness of the objective and 0 < η < 2/β (e.g., η = 1/β) (Nesterov, 2003; Schmidt, 2014; Martens, 2014; Bottou et al., 2018). However, such a bound may lose useful information about the interaction between H and S_n. Moreover, when we train a standard neural network with GD in practice, ∥H∥ (≤ β) increases in the early phase of training and the iterate enters the EoS where ∥H∥ ⪆ 2/η, i.e., η ⪆ 2/∥H∥ ≥ 2/β. This contradicts the assumption for stable optimization, and the iterate exhibits unstable behavior with a non-monotonically decreasing loss (Xing et al., 2018; Wu et al., 2018; Cohen et al., 2021). We further extend this discussion of unstable dynamics for GD at the EoS to the case of SGD. From the generalization perspective, many studies focus on the implicit bias of SGD toward better generalization (Neyshabur, 2017; Kwon et al., 2021). We provide a link between the batch gradient distribution and the sharpness: the model is implicitly regularized to have a low sharpness when the second moment of the batch gradient is large (see Section 4.3). We provide further discussion to reconcile our arguments with some previous studies in Appendix D.

4.1. UNSTABLE OPTIMIZATION

Using the second-order Taylor expansion of the total training loss L(θ) at θ_t, the change in the loss L_t = L(θ_t) as the SGD iterate moves from θ_t to θ_{t+1} at time step t can be expressed as follows:

L_{t+1} − L_t = −η∇L(θ_t)^⊤g_ξ + (η²/2) g_ξ^⊤H(θ_t)g_ξ + O(∥δ_t∥³),   (3)

where δ_t = θ_{t+1} − θ_t = −ηg_ξ. Given θ_t, the expected loss difference over batch sampling ξ is

E_ξ[L_{t+1}] − L_t = −η∇L(θ_t)^⊤E_ξ[g_ξ] + (η²/2) E_ξ[g_ξ^⊤H(θ_t)g_ξ] + ϵ   (taking E_ξ of both sides of (3))
 = −η∥∇L(θ_t)∥² + (η²/2) tr(E_ξ[H(θ_t)g_ξg_ξ^⊤]) + ϵ   (E_ξ[g_ξ] = ∇L and u^⊤v = tr(vu^⊤))
 = (η²/2) tr(S_n) ( tr(HS_b)/tr(S_n) − 2/η ) + ϵ   (by definition of S_b and S_n),   (7)

where ϵ = O(E_ξ[∥δ_t∥³]) and E_ξ[X] is the conditional expectation of X given θ_t. We consider a bounded region with a finite maximum loss near and including θ_t in which the loss is approximately quadratic. Then, when the instability condition tr(HS_b)/tr(S_n) > 2/η is met within the region, the loss (close to the expected loss) tends to increase and the iterate tends to escape from the region.

[Figure 1 caption: validation of (7) for SGD (top) and (12) for GD (bottom). In the early phase, until the iterate enters the unstable region, the points follow the blue line with slope η²/2 and x-intercept 2/η. For GD (bottom), the points are plotted after ∥H∥ exceeds 2/η, after which ∥H∥_{S_n} increases from 0 to 2/η in a few steps. For cross-entropy (CE) loss, the end point is marked with 'x' when the iterate enters the unstable region. For mean squared error (MSE) loss (bottom right), the graph continues for a few more steps after the iterate enters the EoS. We train 6CNN on CIFAR-10-8k with η = 0.02 (see Remark at the end of Section 4.2).]

Theorem 1. For SGD on a quadratic L, the expected loss increases, i.e., E_ξ[L_{t+1}] − L_t > 0, if and only if θ_t satisfies the instability condition tr(HS_b)/tr(S_n) > 2/η.
Furthermore, if the batch gradient g_ξ is normally distributed, then the following inequalities hold for any positive x:

P( L_{t+1} − E_ξ[L_{t+1}] ≥ 2√(βx) + (η²γ_{n,b}/b)∥H∥∥C_1∥x | θ_t ) ≤ exp(−x),
P( L_{t+1} − E_ξ[L_{t+1}] ≤ −2√(βx) | θ_t ) ≤ exp(−x),

where β = (η²γ_{n,b}/(2b)) ( v^⊤C_1v + (η²γ_{n,b}/(2b)) tr(HC_1HC_1) ) and v = (I − ηH)∇L. The proof is deferred to Appendix B.2. From the above theorem, we define the unstable and stable regions:

U ≡ {θ ∈ Θ : tr(HS_b)/tr(S_n) > 2/η} and S ≡ U^c, with boundary ∂S = {θ ∈ Θ : tr(HS_b)/tr(S_n) = 2/η}.

We emphasize the superiority of our ∂S over the previous EoS in that (i) ∂S provides a clearer "edge" since it considers the interaction between the gradient direction and the Hessian, and (ii) it is more general since it applies to SGD with any batch size 1 ≤ b ≤ n. Figure 1 (top row) empirically validates (7), showing the normalized loss difference (E_ξ[L_{t+1}] − L_t)/tr(S_n) against tr(HS_b)/tr(S_n) in the early phase of training before entering the unstable region. This result implies that the training loss L(θ) is approximately locally quadratic, i.e., ϵ ≈ 0, in the early phase. In particular, for full-batch GD (b = n), the instability condition can be rewritten as ∥H∥_{S_n} > 2/η, and from (7) we have the following relationship between the loss difference L_{t+1} − L_t and ∥H∥_{S_n}:

L_{t+1} − L_t = (η²/2) tr(S_n) ( ∥H∥_{S_n} − 2/η ) + ϵ.   (12)

[Figure 2 caption: (E_ξ[L_{t+1}] − L_t)/tr(S_n) against tr(HS_b)/tr(S_n) during training. After the iterate enters the EoS, the slope is often gentler than η²/2, especially in the unstable region. See the caption of Figure 1.]

Figure 1 (bottom row) shows that ∥H∥_{S_n} soars from 0 in a few steps after ∥H∥ exceeds 2/η, satisfying (12) with ϵ ≈ 0, before the iterate enters the unstable region. This result is consistent with the following theorem for a quadratic loss L and generalized momentum GD with (β_1, β_2):

δ_t = β_1δ_{t−1} − η∇_θL(θ_t + β_2δ_{t−1}),  θ_{t+1} = θ_t + δ_t,  where β_1, β_2 ∈ [0, 1).
Here, we have vanilla GD when β_1 = β_2 = 0, Polyak momentum when β_1 ∈ (0, 1) and β_2 = 0 (Polyak, 1963), and Nesterov momentum when β_1 = β_2 ∈ (0, 1) (Nesterov, 1983). We focus on vanilla GD since the analysis easily extends to the generalized momentum variants (see Appendix C.6 for details). The proof is deferred to Appendix B.3.

Theorem 2. For generalized momentum GD with (β_1, β_2) on a quadratic L, if 0 < λ_i < (2/η)γ(β_1, β_2) < λ_1 for all i ≠ 1, where γ(β_1, β_2) = (1 + β_1)/(1 + 2β_2), then q_1^⊤δ_t oscillates and diverges with the exponential growth |q_1^⊤δ_t| = Θ(e^{ct}) for some c > 0. Moreover, |cos(q_1, δ_t)| and ∥H∥_{S_n} increase to 1 and λ_1, respectively, as t → ∞, with 1 − |cos(q_1, δ_t)|, λ_1 − ∥H∥_{S_n} = O(e^{−2ct}).

Note that γ(β_1, β_2) = 1 for vanilla GD. To summarize, if ∥H∥ exceeds 2/η, then ∥H∥_{S_n} increases towards ∥H∥ at an exponential convergence rate and also exceeds 2/η in a few steps, i.e., the iterate enters the unstable region U. Together with Theorem 1, if we consider a bounded subregion V ⊂ U with a finite maximum loss near and including θ_t in which the loss is approximately quadratic, then the iterate tends to escape from the region V (Nar & Sastry, 2018; Wu et al., 2018; Cohen et al., 2021).
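The vanilla-GD case of Theorem 2 can be observed in a two-line simulation. The following sketch (eigenvalues 30 and 5 and η = 0.1 are illustrative choices satisfying λ_1 > 2/η > λ_2) runs GD on a quadratic and tracks the q_1-component of the step δ_t and its alignment with q_1:

```python
import numpy as np

rng = np.random.default_rng(1)
# Quadratic loss L(theta) = 0.5 * theta^T H theta with eigenvalues 30 and 5.
Q, _ = np.linalg.qr(rng.normal(size=(2, 2)))    # random orthonormal eigenvectors
H = Q @ np.diag([30.0, 5.0]) @ Q.T
q1 = Q[:, 0]                                    # top eigenvector

eta = 0.1                                       # lambda_1 = 30 > 2/eta = 20 > lambda_2 = 5
theta = rng.normal(size=2)
proj, cosines = [], []
for t in range(60):
    delta = -eta * (H @ theta)                  # vanilla GD step delta_t
    theta = theta + delta
    proj.append(q1 @ delta)                     # q1-component of the step
    cosines.append(abs(q1 @ delta) / np.linalg.norm(delta))

# The q1-component oscillates in sign and grows exponentially (multiplier 1 - eta*lambda_1 = -2),
# while the step direction aligns with q1, so ||H||_{S_n} -> lambda_1.
print(proj[-3:], cosines[-1])
```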

4.2. NON-QUADRATICITY, ASYMMETRIC VALLEYS AND THE EDGE OF STABILITY

In the previous section, we have shown that the training loss is approximately locally quadratic before the iterate enters the unstable region. However, after the iterate enters the unstable region, i.e., tr(HS_b)/tr(S_n) reaches and exceeds 2/η, the step size is relatively large for the sharp loss landscape, so the iterate jumps across the valley (Jastrzębski et al., 2019), and the higher-order terms ϵ in (7) and (12) become non-negligible and cause a different behavior of the iterate than in the stable region. Figure 2 shows empirical evidence for this non-quadraticity: after the SGD/GD iterate enters the unstable region, when the instability condition tr(HS_b)/tr(S_n) > 2/η is met, the normalized increase in the loss (E_ξ[L_{t+1}] − L_t)/tr(S_n) is often smaller than the quadratic approximation (7) predicts.

[Figure 2 caption: the training loss difference along the gradient descent direction, for each θ_t. Each plot is normalized and translated to have the same minimum value and the same zero where ∆L(zero) = L(zero) − L(θ_t) = 0. We also plot the quadratic baseline (cyan dashed curve). When ∥H∥_{S_n} < 2/η (red), the landscape usually becomes sharper across the valley (right-shifted), while the opposite is observed (left-shifted) when ∥H∥_{S_n} > 2/η (blue). We train 6CNN using GD with η = 0.04.]

We hypothesize that due to this non-quadraticity of the training loss, the iterate is discouraged from staying within the unstable region. Note that, for a globally quadratic loss, when the iterate is in the unstable region, it diverges within the unstable region. Figure 3 demonstrates the asymmetric valley (He et al., 2019), in which one side is sharp and the other is flat. In Figure 3 (left), we evaluate the directional sharpness ∥H_α∥_{S_n} along the gradient descent direction −η∇L(θ), where H_α ≡ H(θ − αη∇L(θ)) for α ∈ (1/4)·{1, 2, 3, 4, 5}, and compare ∥H_α∥_{S_n(θ)} with ∥H∥_{S_n(θ)}. At the sharp side, the iterate has a high ∥H∥_{S_n} > 2/η (blue), with the gradient ∇L and the top eigenvector q_1(H) of the Hessian being highly aligned (cf. Theorem 2).
However, when the loss landscape gets far from quadratic, the Hessian and its top eigenvector can change abruptly: q_1(H_α) would not always be aligned with q_1(H) and ∇L(θ), and ∥H_α∥_{S_n} tends to decrease. This is a possible explanation for the tendency of ∥H∥_{S_n} to decrease and then oscillate. See Appendix C.3 for detailed empirical evidence of the above arguments. Figure 3 (right) similarly shows that when the iterate is at a sharp side of the valley, it tends to jump to the other, flatter side, and vice versa. To summarize, we make the following observations for GD in order: (i) ∥H∥ increases in the beginning (the progressive sharpening (Jastrzębski et al., 2019; 2020; Cohen et al., 2021)), (ii) ∥H∥ exceeds 2/η, (iii) the gradient ∇L becomes more aligned with the top eigenvector q_1(H) in a few steps, (iv) ∥H∥_{S_n} reaches the threshold 2/η and the iterate jumps across the valley, (v) ∥H∥_{S_n} tends to decrease due to the non-quadraticity, and the process repeats, with ∥H∥_{S_n} oscillating around 2/η.
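The jump-to-the-flatter-side behavior in steps (iv)-(v) can be caricatured in one dimension. The following sketch (the piecewise-quadratic valley and its curvatures 30 and 2 are purely illustrative assumptions, not the paper's loss) starts GD on a sharp side where the curvature exceeds 2/η, overshoots across the valley to the flat side, and then converges there:

```python
# Toy asymmetric valley: f(x) = c x^2 / 2 with curvature f'' = 30 for x > 0 (sharp side)
# and f'' = 2 for x < 0 (flat side).
A_FLAT, C_SHARP = 2.0, 30.0

def grad(x):
    return (C_SHARP if x > 0 else A_FLAT) * x   # f'(x) = c x on each side

eta = 0.1                                       # sharp side: 30 > 2/eta = 20 (unstable)
x = 0.5                                         # start on the sharp side
xs = [x]
for _ in range(20):
    x = x - eta * grad(x)
    xs.append(x)

# The first step overshoots to the flat side (sign flip with larger |x|);
# on the flat side, curvature 2 < 2/eta, so GD converges toward the bottom.
print(xs[0], xs[1], xs[-1])
```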

4.3. IMPLICIT INTERACTION REGULARIZATION (IIR)

In the previous sections, we have shown that the SGD iterate is implicitly discouraged from staying in the unstable region. Now, we are ready to investigate this property from the regularization perspective. First, to understand the effect of batch size b on the batch gradient distribution, we define the following concentration measure ρ_b:

Definition 2 (a concentration measure of the batch gradient). We define ρ_b as the ratio of the squared norm of the total gradient ∥∇L∥² to the expected squared norm of the batch gradients E_ξ[∥g_ξ∥²], i.e., ρ_b ≡ ∥∇L∥² / E_ξ[∥g_ξ∥²] = tr(S_n)/tr(S_b).

Here, we can write ∥∇L∥² = ∥E_ξ[g_ξ]∥², and thus the ratio ρ_b = ∥E_ξ[g_ξ]∥² / E_ξ[∥g_ξ∥²] ≤ 1 is similar to the square of the mean resultant length R̄²_b ≡ ∥E_ξ[g_ξ/∥g_ξ∥]∥² ≤ 1 of the batch gradient g_ξ (Mardia et al., 2000), especially when std_ξ[∥g_ξ∥] ≪ E_ξ[∥g_ξ∥] (see Appendix C.5 for empirical evidence). Both ρ_b and R̄²_b are concentration measures and have lower values when the batch gradients g_ξ are more scattered. Therefore, it is natural to expect that the ratio ρ_b is small for a small batch size b, and we will revisit this in more detail in the following section (cf. (17)). We also note that ρ_n = R̄²_n = 1. Yin et al. (2018) call 1/(nρ_1) the gradient diversity. Now, we can rewrite the instability condition tr(HS_b)/tr(S_n) > 2/η (multiplying both sides by ρ_b) as follows:

∥H∥_{S_b} > 2ρ_b/η.   (16)

From Theorems 1 and 2, the instability condition (16) implies that the IAS ∥H∥_{S_b} is implicitly regularized and bounded to be less than 2ρ_b/η. We name this Implicit Interaction Regularization (IIR). We argue that the upper constraint 2ρ_b/η in IIR is crucial in determining the SGD dynamics: with a low constraint, SGD strongly bounds the IAS ∥H∥_{S_b}. We also note that IIR affects not only the magnitude ∥H∥ but also the directional interaction. In other words, IIR discourages the batch gradients from aligning with the top eigensubspace of the Hessian, i.e., the subspace spanned by the eigenvectors corresponding to the few largest eigenvalues of the Hessian (cf. Gur-Ari et al. (2018)). Figures 4(a)-4(b) show that, for GD (ρ_n = 1), the IAS ∥H∥_{S_n} (red) oscillates around 2/η and exhibits IIR. This result is consistent with Cohen et al. (2021), who observe that ∥H∥ hovers above 2/η for GD. This is because, as mentioned earlier, 2/η ≈ ∥H∥_{S_n} ≤ ∥H∥, and the equality holds only when the gradient ∇L and the top eigenvector q_1 of H are aligned, which is in general not the case. For this reason, IIR provides a tighter relation and more clearly identifies the EoS than Cohen et al. (2021). These results are also consistent with Theorem 2 in that ∥H∥_{S_n} suddenly increases from 0 to 2/η in a few steps after ∥H∥ exceeds 2/η. The upper constraint 2ρ_b/η of ∥H∥_{S_b} in IIR is higher when using a larger batch size, but limited to less than 2/η (ρ_b ≤ 1). We will further discuss this behavior with an investigation of ρ_b in the following section.
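To make Definition 2 concrete, the following sketch estimates ρ_b on a toy least-squares problem (the model and sizes are illustrative assumptions) in two ways: via Monte Carlo over sampled batches, and via the closed form implied by (1), tr(S_b) = (γ_{n,b}/b) tr(S_1) + (1 − γ_{n,b}/b) tr(S_n):

```python
import numpy as np

rng = np.random.default_rng(2)
n, m = 200, 10
X = rng.normal(size=(n, m))
y = rng.normal(size=n)
theta = rng.normal(size=m)

per_sample = X * (X @ theta - y)[:, None]       # per-sample gradients of 0.5*(x^T theta - y)^2
grad = per_sample.mean(axis=0)
tr_Sn = grad @ grad                             # tr(S_n) = ||grad L||^2
tr_S1 = (per_sample ** 2).sum(axis=1).mean()    # tr(S_1) = E[||g_single||^2]
rho = tr_Sn / tr_S1                             # rho = rho_1 < 1

def rho_b(b):
    """Closed-form rho_b from tr(S_b) = (gamma/b) tr(S_1) + (1 - gamma/b) tr(S_n)."""
    gamma = (n - b) / (n - 1)                   # sampling without replacement
    return tr_Sn / ((gamma / b) * tr_S1 + (1 - gamma / b) * tr_Sn)

def rho_b_mc(b, trials=4000):
    """Monte Carlo estimate of rho_b = ||grad L||^2 / E[||g_batch||^2]."""
    norms2 = []
    for _ in range(trials):
        idx = rng.choice(n, size=b, replace=False)
        g = per_sample[idx].mean(axis=0)
        norms2.append(g @ g)
    return tr_Sn / np.mean(norms2)

mc4 = rho_b_mc(4)
print(rho, rho_b(4), mc4, rho_b(n))             # rho_b grows with b; rho_n = 1
```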

5. LINEAR AND SATURATION SCALING RULE (LSSR)

In this section, we first introduce previous scaling rules for tuning the learning rate as the batch size varies. Then we explain why they fail and propose a new scaling rule based on IIR. The ratio b/η of batch size b to learning rate η has long been believed to be an important factor influencing the generalization performance, and the test accuracy has been observed to be similar when trained with the same ratio b/η = b′/η′, i.e., b′ = kb and η′ = kη for k > 0. This is called the linear scaling rule (LSR). The usual reasoning (Goyal et al., 2017) compares k steps of SGD with (b, η) to one step with (b′, η′) = (kb, kη):

θ_{t+k} − θ_t = −(η/b) Σ_{i=0}^{k−1} Σ_{x∈B^{t+i}_ξ} ∇_θℓ(x; θ_{t+i}) ≈ −(η/b) Σ_{i=0}^{k−1} Σ_{x∈B^{t+i}_ξ} ∇_θℓ(x; θ_t) = −(η′/b′) Σ_{x∈B^{t:t+k}_ξ} ∇_θℓ(x; θ_t),

assuming ∇_θℓ(θ_{t+i}) ≈ ∇_θℓ(θ_t) for 0 ≤ i < k, where B^{t:t+k}_ξ ≡ ∪_{i=0}^{k−1} B^{t+i}_ξ and |B^{t:t+k}_ξ| = kb = b′. However, this assumption is not accurate since, after entering the EoS, the gradient oscillates, mostly with a negative cosine value cos(g_ξ(θ_t), g_ξ(θ_{t+1})) < 0 between two consecutive gradients (see Figure 24). Based on the analysis of IIR with the ratio 2ρ_b/η in the previous section, we explore why LSR fails in the large-batch regime and provide a more accurate rule to achieve similar generalization performance for models trained with various choices of batch size and learning rate pairs (b, η). To this end, we investigate the concentration measure ρ_b = tr(S_n)/tr(S_b). By combining two equations, C_b = S_b − S_n (by definition) and C_b = (γ_{n,b}/b)(S_1 − S_n) in (1), we obtain S_b = C_b + S_n = (γ_{n,b}/b)S_1 + (1 − γ_{n,b}/b)S_n. Therefore, we have tr(S_b) = (γ_{n,b}/b) tr(S_1) + (1 − γ_{n,b}/b) tr(S_n), which leads to the following equation:

ρ_b ≡ tr(S_n)/tr(S_b) = tr(S_n) / ( (γ_{n,b}/b) tr(S_1) + (1 − γ_{n,b}/b) tr(S_n) ) = 1 / ( (γ_{n,b}/b)(1/ρ) + (1 − γ_{n,b}/b) ),   (17)

where ρ = ρ_1 = tr(S_n)/tr(S_1). Note that ρ is (much) smaller than 1 because ∇_θℓ(x_i) has a different direction for each x_i and tr(S_n) = ∥∇L∥² = ∥(1/n)Σ_i ∇_θℓ(x_i)∥² ≤ (1/n)Σ_i ∥∇_θℓ(x_i)∥² = tr(S_1); in other words, 1/ρ is (much) larger than 1 (see Appendix C.5). From (17), we obtain the approximation

ρ_b ≈ (b/γ_{n,b})ρ ≈ bρ if b is small, and ρ_b ≈ 1 if b is large.

First, as long as b is small, the first term (γ_{n,b}/b)(1/ρ) dominates and ρ_b is approximately linear in b, since 1/ρ ≫ 1. Second, as b becomes large, γ_{n,b}/b ≈ 0 and the second term (≈ 1) dominates the first term. Thus, ρ_b saturates to 1 and is no longer linearly related to b, and LSR is no longer valid. The above arguments also hold for batches sampled with replacement, where the only modification is γ_{n,b} = 1 for all b in (17). Figure 5 (right) empirically supports LSSR with the test accuracies obtained from various combinations of pairs (b, η): the optimal learning rate is almost linear in b when b is small, but it saturates when b is large. We also plot ρ_b/η = C (the yellow dashed curve) for some ρ and C, which gives a theoretical prediction of the pairs (b, η) that yield the optimal performance (see also Figure 8).

Remark (Experiments in Sections 4.3 and 5). We train models using vanilla SGD/GD without momentum and weight decay, a constant learning rate, and no data augmentation. For Figure 5, we use batch normalization (Ioffe & Szegedy, 2015) to achieve a zero training error even with a large b and a small η. In the lower right corner (red area) of each heatmap in Figure 5 (right), when b is too large or η is too small so that ∥θ_{t+1} − θ_t∥ = η∥g_ξ∥ is too small, it requires an exponentially large number of steps for the iterate to enter the EoS. Thus, in this case, the assumption in Goyal et al. (2017), ∇_θℓ(θ_t) ≈ ∇_θℓ(θ_{t+i}) for 0 ≤ i < k, approximately holds and the reasoning behind LSR is valid. However, this only holds for non-practical (b, η) pairs that show suboptimal performance. See Appendices C.4-C.5 for the results from other networks and hyperparameters.
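The LSSR prescription, keeping 2ρ_b/η constant by choosing η(b) ∝ ρ_b, can be computed directly from (17). In the following sketch, the values n = 50000 and ρ = 10⁻³ are illustrative assumptions (not measured values from the paper):

```python
import numpy as np

def rho_b(b, n=50000, rho=1e-3, with_replacement=False):
    """Concentration measure rho_b from Eq. (17)."""
    gamma = 1.0 if with_replacement else (n - b) / (n - 1)
    return 1.0 / ((gamma / b) * (1.0 / rho) + (1.0 - gamma / b))

# LSSR: choose eta(b) proportional to rho_b so that 2*rho_b/eta stays constant.
n, rho = 50000, 1e-3
bs = np.array([1, 2, 4, 8, 64, 4096, 16384, n])
rhos = np.array([rho_b(b, n, rho) for b in bs])
etas = rhos / rhos[0]                           # learning rate relative to b = 1

# Small b: rho_b ~ b * rho (linear regime -> recovers LSR);
# large b: rho_b -> 1 (saturation -> no further scaling).
print(np.round(rhos, 4))
print(np.round(etas, 1))
```

Doubling b from 1 to 2 roughly doubles the prescribed learning rate, while for b near n the prescribed learning rate barely changes, which is exactly the "linear and saturation" shape.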

6. CONCLUSION

From an analysis of the unstable dynamics of SGD and the instability condition, we clearly mark the edge of stability with the interaction-aware sharpness ∥H∥_{S_b} and show the presence of an implicit regularization effect on the interaction between the gradient distribution and the loss landscape geometry (IIR). Moreover, introducing the concentration measure ρ_b of the batch gradient, we link the second moment of the gradient distribution to the sharpness of the loss landscape and propose a new scaling rule, the Linear and Saturation Scaling Rule (LSSR). Owing to the simplicity of the analysis, we hope that our insights will motivate future work toward understanding various learning tasks.

A NOTATIONS

We summarize the notations for a quick reference.

t ∈ N — time step
θ ∈ Θ ⊂ R^m (or indexed θ_t); dim(θ) = m — model parameter
x ∈ X (or indexed x_i) — training sample
D = {x_i}_{i=1}^n; |D| = n — training data
ℓ(x; θ) — loss function
L(θ) ≡ (1/n)Σ_{i=1}^n ℓ(x_i; θ) = (1/|D|)Σ_{x∈D} ℓ(x; θ); L_t = L(θ_t) — (total) training loss
η > 0 — learning rate
B ⊂ D (or indexed B_ξ, B^t_ξ) — batch
b = |B|; 1 ≤ b ≤ n — batch size
g_ξ(θ) ≡ (1/b)Σ_{x∈B_ξ} ∇ℓ(x; θ) = (1/|B_ξ|)Σ_{x∈B_ξ} ∇ℓ(x; θ) — batch gradient
θ_{t+1} = θ_t − η∇_θL(θ_t) — GD
θ_{t+1} = θ_t − ηg_ξ(θ_t) — SGD
δ_t = θ_{t+1} − θ_t — displacement/velocity vector
C_b(θ) ≡ Var_ξ[g_ξ(θ)] ∈ R^{m×m} — covariance of the batch gradient
S_b(θ) ≡ E_ξ[g_ξ(θ)g_ξ(θ)^⊤] ∈ R^{m×m} — second moment of the batch gradient
γ_{n,b} = (n−b)/(n−1) for sampling without replacement; 1 for sampling with replacement — sampling coefficient
H(θ) ≡ ∇²_θL(θ) = E_{x∼D}[∇²_θℓ(x; θ)] ∈ R^{m×m} — Hessian
∥u∥ = (Σ_i u_i²)^{1/2} — Euclidean ℓ2-norm of a vector u
∥A∥ ≡ sup_{u≠0} ∥Au∥/∥u∥ — spectral (operator) norm of a matrix A
λ_i = λ_i(H) ∈ R — i-th largest eigenvalue of the Hessian
q_i = q_i(H) ∈ R^m — corresponding i-th eigenvector of the Hessian

B.1 PROOF OF (1)

Proof. Write G = [∇ℓ_1, …, ∇ℓ_n] ∈ R^{m×n} with ℓ_i = ℓ(x_i), and let w = [w_1, …, w_n]^⊤ be the random vector whose i-th element is (1/b) × (the number of times index i is sampled in B_ξ), so that g_ξ = Gw and ∇L = G(1/n)1. Further let v = bw. Then

C_b = E_ξ[g_ξg_ξ^⊤] − ∇L∇L^⊤   (18)
 = G E_ξ[ww^⊤] G^⊤ − G ((1/n)1)((1/n)1)^⊤ G^⊤   (19)
 = G ( E_ξ[ww^⊤] − (1/n²)11^⊤ ) G^⊤   (20)
 = G Var_ξ[w] G^⊤   (21)
 = (1/b²) G Var_ξ[v] G^⊤.   (22)

In case of sampling with replacement, we have v = v^(1) + … + v^(b), where v^(i) represents the sampling of a single sample, so Var_ξ[v] = b Var[v^(1)]. We have E[v^(1)] = (1/n)1 and

E[v^(1)(v^(1))^⊤]_{i,j} = P[i ∈ B^(1)] = 1/n if i = j,  P[i ∈ B^(1), j ∈ B^(1)] = 0 otherwise,

where |B^(1)| = 1. Thus,

Var_ξ[v] = b Var[v^(1)] = b ( (1/n)I − (1/n²)11^⊤ ).   (24)
In case of sampling without replacement, we have E_ξ[v] = (b/n)1 and

E_ξ[vv^⊤]_{i,j} = P[i ∈ B_ξ] = C(n−1, b−1)/C(n, b) = b/n if i = j,  P[i ∈ B_ξ, j ∈ B_ξ] = C(n−2, b−2)/C(n, b) = b(b−1)/(n(n−1)) otherwise,

where C(n_1, r_1) is the number of r_1-combinations from a set of n_1 elements. This leads to

E_ξ[vv^⊤] = ( b(b−1)/(n(n−1)) ) 11^⊤ + ( b/n − b(b−1)/(n(n−1)) ) I

and

Var_ξ[v] = E_ξ[vv^⊤] − (b²/n²)11^⊤ = ( b(b−1)/(n(n−1)) − b²/n² ) 11^⊤ + ( b(n−b)/(n(n−1)) ) I   (27)
 = ( b(b−n)/(n²(n−1)) ) 11^⊤ + ( b(n−b)/(n(n−1)) ) I   (28)
 = ( b(n−b)/(n−1) ) ( (1/n)I − (1/n²)11^⊤ ).   (29)

Putting the two cases together, from (22), (24), and (29), we have

C_b = (1/b²) · bγ_{n,b} · G ( (1/n)I − (1/n²)11^⊤ ) G^⊤ = (γ_{n,b}/b) ( (1/n)Σ_i ∇ℓ_i∇ℓ_i^⊤ − ∇L∇L^⊤ )   (30-31)
 = (γ_{n,b}/b)(S_1 − S_n),   (32)

which completes the proof of (1). □

B.2 PROOF OF THEOREM 1

First, we expect the variance of the loss at the next iteration (t + 1) to be small for a sufficiently large b.

Lemma 3. If the batch gradient g_ξ is normally distributed and the loss L is quadratic, then

Var_ξ[L_{t+1}|θ_t] = (η²γ_{n,b}/b) ( v^⊤C_1v + (η²γ_{n,b}/(2b)) tr(HC_1HC_1) ),

where v = (I − ηH)∇L.

Proof. Since L is quadratic, we have the following:

L_{t+1} = L_t − ηg^⊤g_ξ + (η²/2) g_ξ^⊤Hg_ξ   (g = ∇L)   (35)
 = constant − ηg^⊤(g + ε) + (η²/2)(g + ε)^⊤H(g + ε)   (ε ∼ N(0, Σ = C_b))   (36)
 = constant − ηg^⊤ε + η²g^⊤Hε + (η²/2) ε^⊤Hε   (37)
 = constant − η((I − ηH)g)^⊤ε + (η²/2) ε^⊤Hε.   (38)

Then the variance of L_{t+1} is Var_ξ[L_{t+1}|θ_t] = E_ξ[Q_ξ²] − E_ξ[Q_ξ]², where Q_ξ = −ηv^⊤ε + (η²/2) ε^⊤Hε. First, we can obtain the square of the expected value of Q_ξ as follows:

E_ξ[Q_ξ] = −ηv^⊤E_ξ[ε] + (η²/2) Σ_{i,j} H_{i,j}E_ξ[ε_iε_j] = (η²/2) Σ_{i,j} H_{i,j}Σ_{i,j} = (η²/2) tr(HΣ),
E_ξ[Q_ξ]² = (η⁴/4) tr(HΣ)²,   (39-40)

where the last equation holds since Σ is symmetric and Σ_{i,j} H_{i,j}Σ_{i,j} = Σ_{i,j} H_{i,j}Σ_{j,i} = Σ_i [HΣ]_{i,i} = tr(HΣ).
Second, we have the expected value of the square of Q_ξ as follows:

Q_ξ² = η² Σ_{i,i′} v_iv_{i′}ε_iε_{i′} − η³ Σ_{i,j,k} v_iH_{j,k}ε_iε_jε_k + (η⁴/4) Σ_{j,k,j′,k′} H_{j,k}H_{j′,k′}ε_jε_kε_{j′}ε_{k′},   (41)

and, since E_ξ[ε_iε_jε_k] = 0 and E[ε_jε_kε_{j′}ε_{k′}] = E[ε_jε_k]E[ε_{j′}ε_{k′}] + E[ε_jε_{j′}]E[ε_kε_{k′}] + E[ε_jε_{k′}]E[ε_{j′}ε_k] by Isserlis' theorem (Isserlis, 1918) for a zero-mean normal random vector ε,

E_ξ[Q_ξ²] = η² Σ_{i,i′} v_iv_{i′}Σ_{i,i′} + (η⁴/4) Σ_{j,k,j′,k′} H_{j,k}H_{j′,k′}( Σ_{j,k}Σ_{j′,k′} + Σ_{j,j′}Σ_{k,k′} + Σ_{j,k′}Σ_{j′,k} )   (42)
 = η² v^⊤Σv + (η⁴/4)( tr(HΣ)² + 2 tr(HΣHΣ) ).   (43)

Finally, we have the variance

Var_ξ[L_{t+1}|θ_t] = E_ξ[Q_ξ²] − E_ξ[Q_ξ]²   (44)
 = η² v^⊤Σv + (η⁴/2) tr(HΣHΣ)   (45)
 = (η²γ_{n,b}/b) ( v^⊤C_1v + (η²γ_{n,b}/(2b)) tr(HC_1HC_1) ).   (46) □

From Lemma 3, we can easily derive the following theorem from Chebyshev's inequality:

Theorem 4. For SGD on a quadratic loss L, the expected loss at the next step (t + 1) is

E_ξ[L_{t+1}] = L_t + (η²/2) tr(S_n) ( tr(HS_b)/tr(S_n) − 2/η ).

Further, if the batch gradient g_ξ is normally distributed, then given δ > 0 and α ≥ √(β/δ), where β = Var_ξ[L_{t+1}|θ_t] is given by Lemma 3, we have |L_{t+1} − E_ξ[L_{t+1}]| ≤ α with probability at least 1 − δ.

However, we can replace this inequality with a better exponential inequality. To do so, we need Lemma 6, a generalization of the following Lemma 5 (Laurent & Massart, 2000):

Lemma 5 (Laurent & Massart (2000)). For i.i.d. Gaussian variables (Y_1, …, Y_D) with mean 0 and variance 1, the following inequalities hold for any positive x:

P( Z − E[Z] ≥ 2∥a∥_2√x + 2∥a∥_∞x ) ≤ exp(−x),
P( Z − E[Z] ≤ −2∥a∥_2√x ) ≤ exp(−x),

where Z = Σ_i a_iY_i².

Lemma 6. For i.i.d. Gaussian variables (Y_1, …, Y_D) with mean 0 and variance 1, the following inequalities hold for any positive x:

P( Z − E[Z] ≥ 2√(∥a∥_2² + ∥c∥_2²/2)·√x + 2∥a∥_∞x ) ≤ exp(−x),   (69)
P( Z − E[Z] ≤ −2√(∥a∥_2² + ∥c∥_2²/2)·√x ) ≤ exp(−x),   (70)

where Z = Σ_i (a_iY_i² + c_iY_i).
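Before the proof, a quick Monte Carlo sanity check of Lemma 6 may be helpful. The dimensions, coefficient ranges, and x below are illustrative assumptions; the empirical tail probabilities should fall below exp(−x):

```python
import numpy as np

rng = np.random.default_rng(3)
D, trials = 8, 200000
a = rng.uniform(0.1, 1.0, size=D)               # nonnegative coefficients of Y_i^2
c = rng.uniform(-1.0, 1.0, size=D)              # coefficients of Y_i

Y = rng.normal(size=(trials, D))
Z = (Y**2) @ a + Y @ c                          # Z = sum_i a_i Y_i^2 + c_i Y_i
EZ = a.sum()                                    # E[Y_i^2] = 1 and E[Y_i] = 0

x = 1.5                                         # any positive x
scale = np.sqrt(a @ a + (c @ c) / 2.0)          # sqrt(||a||_2^2 + ||c||_2^2 / 2)
upper = 2 * scale * np.sqrt(x) + 2 * a.max() * x
lower = -2 * scale * np.sqrt(x)

p_up = np.mean(Z - EZ >= upper)                 # empirical upper-tail probability
p_lo = np.mean(Z - EZ <= lower)                 # empirical lower-tail probability
print(p_up, p_lo, np.exp(-x))                   # both tails should be below exp(-x)
```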
Proof. For a standard normal random variable Y_i ∼ N(0, 1), let ψ be the Cramér transform of Y_i² + c_iY_i − 1:

ψ(λ) = log E[exp(λ(Y_i² + c_iY_i − 1))] = log ∫_R p(y; N(0, 1)) exp(λ(y² + c_iy − 1)) dy   (54)
 = log ∫_R (1/√(2π)) exp( −y²/2 + λ(y² + c_iy − 1) ) dy   (55-56)
 = log ∫_R (1/√(2π)) exp( −(1/2)(1 − 2λ)y² + c_iλy − λ ) dy   (57)
 = log ∫_R exp(−λ) σ (1/√(2πσ²)) exp( −y²/(2σ²) ) exp(c_iλy) dy   (σ² = (1 − 2λ)^{−1})   (58-59)
 = −λ + log(σ) + log E_{y∼N(0,σ²)}[exp(c_iλy)]   (60)
 = −λ − (1/2)log(1 − 2λ) + (c_iλσ)²/2   (61)
 ≤ λ²(1 + c_i²/2) / (1 − 2λ),   (62)

where the last inequality holds for 0 < λ < 1/2 since

−λ − (1/2)log(1 − 2λ) = −λ + (1/2)Σ_{k≥1} (2λ)^k/k = (1/2)Σ_{k≥2} (2λ)^k/k = 2λ² Σ_{k≥0} (2λ)^k/(k+2) ≤ 2λ² Σ_{k≥0} (2λ)^k/2 = λ² Σ_{k≥0} (2λ)^k = λ²/(1 − 2λ).

Therefore,

log E[exp(λZ)] = Σ_i log E[ exp( a_iλ( Y_i² + (c_i/a_i)Y_i − 1 ) ) ] ≤ Σ_i (1 + (1/2)(c_i/a_i)²) a_i²λ² / (1 − 2a_iλ)   (63)
 ≤ (∥a∥_2² + ∥c∥_2²/2)λ² / (1 − 2∥a∥_∞λ).   (64)

We can obtain (69) from the following (Birgé & Massart, 1998): if log E[exp(λZ)] ≤ b_1λ²/(1 − b_2λ), then, for any positive x, P(Z ≥ 2√(b_1x) + b_2x) ≤ exp(−x). Also, for −1/2 < λ < 0, ψ(λ) ≤ λ²(1 + c_i²/2), and thus log E[exp(λZ)] ≤ (∥a∥_2² + ∥c∥_2²/2)λ², which leads to (70). □

Theorem 1. For SGD on a quadratic L, the expected loss increases, i.e., E_ξ[L_{t+1}] − L_t > 0, if and only if θ_t satisfies the instability condition tr(HS_b)/tr(S_n) > 2/η. Furthermore, if the batch gradient g_ξ is normally distributed, then the following inequalities hold for any positive x:

P( L_{t+1} − E_ξ[L_{t+1}] ≥ 2√(βx) + (η²γ_{n,b}/b)∥H∥∥C_1∥x | θ_t ) ≤ exp(−x),
P( L_{t+1} − E_ξ[L_{t+1}] ≤ −2√(βx) | θ_t ) ≤ exp(−x),

where β = (η²γ_{n,b}/(2b)) ( v^⊤C_1v + (η²γ_{n,b}/(2b)) tr(HC_1HC_1) ) and v = (I − ηH)∇L.

Proof. The first part of the statement is trivial from (7) with ϵ = 0 for a quadratic L.
Now, we focus on the concentration of L_{t+1} around its expected value: L_{t+1} − E[L_{t+1}] = Z − E[Z] for

Z = Q_ξ = −ηv⊤ε + (η²/2) ε⊤Hε (65)
= −η√(γ_{n,b}/b) (B⊤v)⊤ε̃ + (η² γ_{n,b}/(2b)) ε̃⊤B⊤HB ε̃ (66)
= c⊤Y + Y⊤ diag(a) Y (67)
= ∑_i (a_i Y_i² + c_i Y_i), (68)

where A = (η² γ_{n,b}/(2b)) B⊤HB = Q diag(a) Q⊤, c = −η√(γ_{n,b}/b) Q⊤B⊤v, and Y = [Y₁, …, Y_m]⊤ = Q⊤ε̃. Therefore, by rotation invariance, {Y_i} are i.i.d. Gaussian variables with mean 0 and variance 1, like ε̃. From Lemma 6, we have the following inequalities:

P(L_{t+1} − E[L_{t+1}] ≥ 2√((∥A∥_F² + ∥c∥₂²/2) x) + 2∥A∥x) ≤ exp(−x), (69)
P(L_{t+1} − E[L_{t+1}] ≤ −2√((∥A∥_F² + ∥c∥₂²/2) x)) ≤ exp(−x). (70)

Further, we have

∥A∥_F² = tr(A⊤A) = (η² γ_{n,b}/(2b))² tr(B⊤HBB⊤HB) = (η² γ_{n,b}/(2b))² tr(HBB⊤HBB⊤) (71)
= (η² γ_{n,b}/(2b))² tr(HC₁HC₁). (72)

Note that θ̃_t and δ̃_t have the same characteristic equation. Thus, by Lemma 8, δ̃_t is asymptotically stable when the following three conditions (C1-3) hold:

1 + p₁ + p₂ = 1 − (1 + β₁ − (1 + β₂)a) + β₁ − β₂a > 0, (C1) (87)
1 − p₁ + p₂ = 1 + (1 + β₁ − (1 + β₂)a) + β₁ − β₂a > 0, (C2) (88)
1 − p₂ = 1 − β₁ + β₂a > 0. (C3) (89)

C3 and C1 always hold because 1 − β₁ > 0 and 1 − (1 + β₁ − (1 + β₂)a) + β₁ − β₂a = a > 0. Lastly, since 1 − p₁ + p₂ = 1 + (1 + β₁ − (1 + β₂)a) + β₁ − β₂a = 2 + 2β₁ − (1 + 2β₂)a, the asymptotic convergence condition (C2) is equivalent to

λ < (2 + 2β₁)/((1 + 2β₂)η) = (2/η) · (1 + β₁)/(1 + 2β₂) = (2/η) γ(β₁, β₂). (92)

Therefore, along the direction of q₁, the sequence {q₁⊤δ_t} diverges, and along the directions of q_i (i > 1), the sequences {q_i⊤δ_t} converge to 0. The discriminant D(a₀) of the characteristic equation ϕ(x; a₀) is positive, as shown below, where a₀ = 2γ(β₁, β₂); i.e., ϕ(x; a₀) has two distinct real solutions.
D(a) = p₁² − 4p₂ (93)
= (1 + β₁ − (1 + β₂)a)² − 4(β₁ − β₂a), (94)

D(a₀) = ((1 + β₁)/(1 + 2β₂))² − 4(β₁ − 2β₂)/(1 + 2β₂)   (a₀ = 2γ(β₁, β₂)) (95)
= [(1 + β₁)² − 4(β₁ − 2β₂)(1 + 2β₂)]/(1 + 2β₂)² (96)
= [(1 − β₁)² − 8((β₁ − 1)β₂ − 2β₂²)]/(1 + 2β₂)² > 0. (97)

(94) implies that D(a) is a convex quadratic in a. Thus, to show that D(a) > 0 (i.e., that ϕ(x; a) has two distinct real solutions) for all a > a₀, it suffices to show that D′(a) ≥ 0 for a > a₀. The inequality

D′(a) = 2(1 + β₂)²a − 2(1 + β₁ − β₂ + β₁β₂) ≥ 0 (98)

holds if a ≥ (1 + β₁ − β₂ + β₁β₂)/(1 + β₂)². And D′(a) ≥ 0 for a > a₀ since a > a₀ > (1 + β₁ − β₂ + β₁β₂)/(1 + β₂)² from

a₀ − (1 + β₁ − β₂ + β₁β₂)/(1 + β₂)² = 2(1 + β₁)/(1 + 2β₂) − (1 + β₁ − β₂ + β₁β₂)/(1 + β₂)² (99)
= [2(1 + β₁)(1 + β₂)² − (1 + 2β₂)(1 + β₁ − β₂ + β₁β₂)]/[(1 + 2β₂)(1 + β₂)²] (100)
= (1 + β₁ + 3β₂ + β₁β₂ + 4β₂²)/[(1 + 2β₂)(1 + β₂)²] > 0. (101)

Now, for a > a₀, we want to show that the dominant solution x₁ (|x₁| > |x₂|) is negative, so that the general solution δ̃_t = c₁x₁ᵗ + c₂x₂ᵗ = x₁ᵗ (c₁ + c₂(x₂/x₁)ᵗ) oscillates, where c_i, x_i ∈ R and i ∈ {1, 2}. The sum of the two solutions of ϕ(x; a) is x₁ + x₂ = −p₁, and this is negative for a > a₀ from the following:

p₁ = −(1 + β₁) + (1 + β₂)a (102)
> −(1 + β₁) + (1 + β₂)a₀ (103)
= −(1 + β₁) + 2(1 + β₂)(1 + β₁)/(1 + 2β₂) (104)
= [2(1 + β₂)(1 + β₁) − (1 + β₁)(1 + 2β₂)]/(1 + 2β₂) (105)
= (1 + β₁)/(1 + 2β₂) > 0. (106)

Thus, x₁ for the dominant solution is negative, which leads to an oscillatory behavior of q₁⊤δ_t. And it exhibits the exponential growth |q₁⊤δ_t| = Θ(|x₁|ᵗ) = Θ(e^{ct}) for c = ln|x₁| > 0, since

|q₁⊤δ_t| = |x₁|ᵗ |c₁ + c₂(x₂/x₁)ᵗ|, (108)
(1/2)|c₁||x₁|ᵗ ≤ |q₁⊤δ_t| ≤ 2|c₁||x₁|ᵗ for t ≥ t₀ for some t₀ (109)

(λ₁ violates (92) with λ₁ > (2/η)γ(β₁, β₂), which implies |x₁| > 1).
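The oscillatory divergence predicted here can be observed by iterating the recurrence directly; the values β₁ = 0.5, β₂ = 0, and a = a₀ + 0.5 below are arbitrary illustrative choices:

```python
import numpy as np

beta1, beta2 = 0.5, 0.0                  # illustrative momentum parameters
gamma = (1 + beta1) / (1 + 2*beta2)
a0 = 2 * gamma                           # threshold for a = eta*lambda
a = a0 + 0.5                             # a > a0: unstable case

p1 = -(1 + beta1) + (1 + beta2)*a
p2 = beta1 - beta2*a

y = [1.0, 1.0]                           # iterate y_{t+2} = -p1*y_{t+1} - p2*y_t
for _ in range(50):
    y.append(-p1*y[-1] - p2*y[-2])

x1 = max(np.roots([1.0, p1, p2]), key=abs)   # dominant characteristic root
print(x1 < 0 and abs(x1) > 1)                # → True (negative, |x1| > 1)
print(y[-1]*y[-2] < 0 and abs(y[-1]) > 1e6)  # → True (oscillatory divergence)
```

The dominant root is negative with modulus above one, so the iterates flip sign every step while growing exponentially, exactly the behavior derived above.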
This makes δ_t asymptotically aligned with the top eigenvector q₁ of H, i.e., lim_{t→∞} |cos(q₁, δ_t)| = lim_{t→∞} |q₁⊤δ_t|/√(∑_{i=1}^m (q_i⊤δ_t)²) = lim_{t→∞} |δ̃_t⁽¹⁾|/√(∑_{i=1}^m δ̃_t⁽ⁱ⁾²) = 1, where δ̃_t⁽ⁱ⁾ = q_i⊤δ_t. Moreover, we can obtain the exponential convergence 1 − |cos(q₁, δ_t)| = O(e^{−2ct}) as follows:

|cos(q₁, δ_t)| = |δ̃_t⁽¹⁾|/√(δ̃_t⁽¹⁾² + s_t²)   (s_t² = ∑_{i>1} δ̃_t⁽ⁱ⁾² → 0), (110)

1 − |cos(q₁, δ_t)| = (√(δ̃_t⁽¹⁾² + s_t²) − |δ̃_t⁽¹⁾|)/√(δ̃_t⁽¹⁾² + s_t²) (111)
= s_t²/[√(δ̃_t⁽¹⁾² + s_t²)(√(δ̃_t⁽¹⁾² + s_t²) + |δ̃_t⁽¹⁾|)], (112)

0 ≤ 1 − |cos(q₁, δ_t)| ≤ 1/δ̃_t⁽¹⁾² for t ≥ t₁ for some t₁. (113)

The last inequality holds because |δ̃_t⁽¹⁾/s_t| diverges to ∞, so the denominator of (112) is at least δ̃_t⁽¹⁾², while s_t² ≤ 1 for sufficiently large t. Moreover, we have

∇_θL(θ_t) = Hθ_t + b (115)
= ∑_i λ_i q_i q_i⊤θ_t + ∑_i q_i q_i⊤b = ∑_i λ_i q_i (q_i⊤θ_t + q_i⊤b/λ_i) = ∑_i λ_i θ̃_t⁽ⁱ⁾ q_i, (116)
∥∇_θL(θ_t)∥² = ∑_i λ_i² θ̃_t⁽ⁱ⁾², (117)
H∇_θL(θ_t) = ∑_i λ_i q_i q_i⊤ ∑_j λ_j θ̃_t⁽ʲ⁾ q_j = ∑_i λ_i² θ̃_t⁽ⁱ⁾ q_i, (118)
∇_θL(θ_t)⊤H∇_θL(θ_t) = ∑_i λ_i³ θ̃_t⁽ⁱ⁾², (119)

where θ̃_t⁽ⁱ⁾ = q_i⊤θ_t + q_i⊤b/λ_i. Therefore, we have λ₁ − ∥H∥_{S_n} = O(e^{−2ct}) from the following:

∥H∥_{S_n} = ∇_θL⊤H∇_θL/∥∇_θL∥² = ∑_i λ_i³ θ̃_t⁽ⁱ⁾² / ∑_i λ_i² θ̃_t⁽ⁱ⁾² = [λ₁ θ̃_t⁽¹⁾² + (λ₂³/λ₁²) θ̃_t⁽²⁾² + ⋯]/[θ̃_t⁽¹⁾² + (λ₂²/λ₁²) θ̃_t⁽²⁾² + ⋯], (120)

λ₁ − ∥H∥_{S_n} = [(λ₂²/λ₁ − λ₂³/λ₁²) θ̃_t⁽²⁾² + ⋯]/[θ̃_t⁽¹⁾² + (λ₂²/λ₁²) θ̃_t⁽²⁾² + ⋯] = [(λ₂²/λ₁²)(λ₁ − λ₂) θ̃_t⁽²⁾² + ⋯]/[θ̃_t⁽¹⁾² + (λ₂²/λ₁²) θ̃_t⁽²⁾² + ⋯], (121)

0 ≤ λ₁ − ∥H∥_{S_n} ≤ 1/θ̃_t⁽¹⁾² for t ≥ t₂ for some t₂, (122)

since lim_{t→∞} θ̃_t⁽ⁱ⁾ = 0 for i > 1.

Remark. Due to the exponential convergence of ∥H∥_{S_n} to λ₁, it only takes a few steps (5-20) for ∥H∥_{S_n} to exceed (2/η)γ(β₁, β₂) (see Appendix C.3).

Remark. In practice, ∥H∥ = λ₁ keeps increasing after it exceeds (2/η)γ(β₁, β₂).
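The convergence of ∥H∥_{S_n} to λ₁ under unstable dynamics can be seen on a toy quadratic; H, η, and the initialization below are made up for illustration (plain GD, i.e., β₁ = β₂ = 0):

```python
import numpy as np

H = np.diag([4.0, 1.0, 0.5])        # toy quadratic L(θ) = θᵀHθ/2, λ1 = 4
eta = 0.6                           # λ1 = 4 > 2/η ≈ 3.33: unstable along q1
theta = np.array([0.01, 1.0, 1.0])  # starts nearly orthogonal to q1

sharp = []
for _ in range(60):
    g = H @ theta                         # full-batch gradient ∇L
    sharp.append(g @ H @ g / (g @ g))     # ∥H∥_{S_n} = ∇Lᵀ H ∇L / ∥∇L∥²
    theta = theta - eta * g

print(round(sharp[-1], 4))  # → 4.0 (converges to λ1)
```

The q₁-component grows like |1 − ηλ₁|ᵗ while the others shrink, so the interaction-aware sharpness climbs to λ₁ exponentially fast, as (122) predicts.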
Therefore, we may relax the assumption so that the eigenvalues λ₁(θ_t) > (2/η)γ(β₁, β₂) + ϵ₁ and ϵ₂ < λ_i(θ_t) < (2/η)γ(β₁, β₂) − ϵ₃ for i ≠ 1 may change within these bounds over t, for some ϵ₁, ϵ₂, ϵ₃ > 0 (the q_i's are fixed). We can then draw the same conclusion, except that the limit of ∥H∥_{S_n} may not exist because ∥H∥ varies with t; still, we can ensure that ∥H∥_{S_n} eventually exceeds (2/η)γ(β₁, β₂).

C EXPERIMENTAL SETTINGS AND ADDITIONAL FIGURES

We report the experimental results using vanilla SGD/GD without momentum and weight decay, a constant learning rate, and no data augmentation. The networks and datasets are listed below. Moreover, we use PyHessian (Yao et al., 2020) to compute the Hessian-vector product (e.g., H∇L), the top eigenvalue λ₁, and its corresponding eigenvector q₁ of the Hessian. For these computations, we use power iteration with a batch size of 2k, a tolerance of 0.001, and a maximum of 100 iterations.
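As a rough illustration of how a top eigenpair can be obtained from Hessian-vector products alone, here is a minimal power-iteration sketch mirroring the settings above (tolerance 0.001, 100 iterations); it is a simplified stand-in with a toy matrix, not PyHessian's actual code:

```python
import numpy as np

def top_eigenpair(hvp, dim, tol=1e-3, max_iter=100, seed=0):
    """Estimate the top eigenpair using only a Hessian-vector product hvp."""
    rng = np.random.default_rng(seed)
    v = rng.standard_normal(dim)
    v /= np.linalg.norm(v)
    lam = 0.0
    for _ in range(max_iter):
        Hv = hvp(v)
        lam_new = v @ Hv                 # Rayleigh quotient (∥v∥ = 1)
        v = Hv / np.linalg.norm(Hv)
        if abs(lam_new - lam) < tol * (abs(lam_new) + 1e-12):
            lam = lam_new
            break
        lam = lam_new
    return lam, v

H = np.diag([5.0, 2.0, 1.0])             # toy Hessian
lam1, q1 = top_eigenpair(lambda u: H @ u, 3)
print(round(lam1, 2))  # → 5.0
```

The same interface works when `hvp` is implemented with automatic differentiation on the training loss, which is how Hessian-free eigenvalue estimation is done at scale.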

C.1 FIGURE 1

In Figure 1 and Figure 2, we plot (E_ξ[L_{t+1}] − L_t)/tr(S_n) against tr(HS_b)/tr(S_n), which is equivalent to plotting (L_{t+1} − L_t)/tr(S_n) against ∥H∥_{S_n} for GD. Therefore, we expect the following linear relationship with slope η²/2 and x-intercept 2/η when the training loss L is locally quadratic, i.e., ϵ = 0:

(E_ξ[L_{t+1}] − L_t)/tr(S_n) = (η²/2)(tr(HS_b)/tr(S_n) − 2/η),   (L_{t+1} − L_t)/tr(S_n) = (η²/2)(∥H∥_{S_n} − 2/η). (124)

Figure 1 shows the behavior in the early phase, until the iterate enters the edge of stability. For GD, the points are plotted after ∥H∥ exceeds 2/η, after which ∥H∥_{S_n} increases from 0 to 2/η within a few steps. For cross-entropy loss, we mark the end point with 'x' when the iterate enters the unstable regime.
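The locally quadratic relation can be checked symbolically-free on a toy quadratic loss; the Hessian, learning rate, and starting point below are arbitrary choices for illustration:

```python
import numpy as np

rng = np.random.default_rng(1)
H = np.diag([3.0, 1.0, 0.5])     # made-up local Hessian
theta = rng.standard_normal(3)
eta = 0.4

L = lambda th: 0.5 * th @ H @ th
g = H @ theta                                    # ∇L, so tr(S_n) = ∥g∥² for GD
sharp_Sn = g @ H @ g / (g @ g)                   # interaction-aware sharpness
lhs = (L(theta - eta*g) - L(theta)) / (g @ g)    # (L_{t+1} − L_t)/tr(S_n)
rhs = eta**2/2 * (sharp_Sn - 2/eta)
print(np.isclose(lhs, rhs))  # → True
```

On an exactly quadratic loss the relation is an identity, which is why the early-phase points in Figure 1 fall on a line of slope η²/2 with x-intercept 2/η.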

C.5 FIGURE 5

Figures 34-35 provide additional information for Figure 5. Figure 36 shows that 1/ρ ≈ 100 is much larger than 1 and that std[∥g_ξ∥] is 2-3× smaller than E_ξ[∥g_ξ∥] even in the case of b = 1. Therefore, we use the approximation ∥g_ξ∥ ≈ E_ξ[∥g_ξ∥], and thus the square of the mean resultant length approximates the concentration measure ρ_b:

R̄_b² ≡ ∥E_ξ[g_ξ/∥g_ξ∥]∥² ≈ ∥E_ξ[g_ξ]/E_ξ[∥g_ξ∥]∥² = ∥E_ξ[g_ξ]∥²/E_ξ[∥g_ξ∥]² ≈ tr(S_n)/tr(S_b) = ρ_b.
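A small simulation with a hypothetical concentrated gradient distribution illustrates why R̄_b² ≈ ρ_b when std[∥g_ξ∥] ≪ E_ξ[∥g_ξ∥]; all values below are made up:

```python
import numpy as np

rng = np.random.default_rng(2)
mu = np.array([10.0, 0.0, 0.0])                    # stand-in for E_ξ[g_ξ]
G = mu + 0.5*rng.standard_normal((100_000, 3))     # concentrated batch gradients

rho_b = np.linalg.norm(mu)**2 / np.mean(np.sum(G**2, axis=1))  # tr(S_n)/tr(S_b)
R2 = np.linalg.norm(
    (G / np.linalg.norm(G, axis=1, keepdims=True)).mean(axis=0))**2
print(abs(rho_b - R2) < 0.01)  # → True: R̄_b² ≈ ρ_b under concentration
```

When the norm of g_ξ fluctuates little, normalizing by ∥g_ξ∥ and normalizing by E_ξ[∥g_ξ∥] give nearly the same mean direction, so the two measures agree.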

D DISCUSSION

We provide a new insight into the link between the batch gradient distribution and the sharpness of the loss landscape. In this section, we reconcile our arguments with previous studies. Recently, Li et al. (2021) suggested a necessary condition that the "noise-to-signal ratio" needs to be large for LSR (and the SDE assumption) to hold. This is consistent with our result on the linear regime (where b and ρ_b are small), because the noise-to-signal ratio is approximately the inverse of the "signal-to-noise" ratio ρ_b = tr(S_n)/tr(S_b), but defined for an equilibrium distribution. We provide not only a necessary condition but also a sufficient condition for LSR, together with a novel scaling rule, LSSR, applicable to every batch size, including those where LSR fails (the saturation regime).



These two matrices C_b and S_b are often called the second central and non-central moments, respectively. To avoid confusion, we use the term "second moment" only for the non-central S_b.

L(θ_{t+1}) − L(θ_t) ≤ ∇L⊤(θ_{t+1} − θ_t) + (β/2)∥θ_{t+1} − θ_t∥² = −η∥∇L∥² + (βη²/2)∥∇L∥² = −η(1 − βη/2)∥∇L∥², and thus the loss monotonically decreases when 0 < η < 2/β.

From https://github.com/wbaek/torchskeleton, Apache-2.0 license.
https://www.cs.toronto.edu/~kriz/cifar.html
https://cs.stanford.edu/~acoates/stl10/
https://tiny-imagenet.herokuapp.com
https://github.com/amirgholami/PyHessian, MIT license.



m×m and the second moment S_b(θ) ≡ E_ξ[g_ξ(θ)g_ξ(θ)⊤] ∈ R^{m×m} of the mini-batch gradient g_ξ(θ) over batch sampling for a batch size 1 ≤ b ≤ n. The covariance C_b and the second moment S_b satisfy not only C_b = S_b − S_n but also the following equation (Hoffer et al., 2017; Li et al., 2017; Wu et al., 2020): C_b = (γ_{n,b}/b)(S_1 − S_n) = (γ_{n,b}/b)C₁, where γ_{n,b} = (n − b)/(n − 1).
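The relation C_b = (γ_{n,b}/b)C₁ can be verified exactly by enumerating all batches of a tiny synthetic dataset (the per-example gradients below are random stand-ins):

```python
import numpy as np
from itertools import combinations

rng = np.random.default_rng(3)
n, b = 6, 2
g = rng.standard_normal((n, 3))                  # stand-in per-example gradients

C1 = np.cov(g, rowvar=False, bias=True)          # C₁ = S₁ − S_n (denominator n)
means = np.array([g[list(idx)].mean(axis=0)      # every batch of size b,
                  for idx in combinations(range(n), b)])  # without replacement
Cb = np.cov(means, rowvar=False, bias=True)      # exact covariance of g_ξ

gamma = (n - b) / (n - 1)
print(np.allclose(Cb, gamma / b * C1))  # → True
```

Enumerating all C(n, b) batches computes the covariance of the batch gradient exactly, so the finite-population correction γ_{n,b} appears with no sampling error.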

Figure 1: [An empirical validation of (7) for SGD (top) and (12) for GD (bottom)] In the early phase, until the iterate enters the unstable region, the data validate (7) and (12): the blue line has slope η²/2 and x-intercept 2/η. For GD (bottom), the points are plotted after ∥H∥ exceeds 2/η, after which ∥H∥_{S_n} increases from 0 to 2/η within a few steps. For cross-entropy (CE) loss, we mark the end point with 'x' when the iterate enters the unstable region. For mean squared error (MSE) loss (bottom right), we plot the graph for a few more steps after the iterate enters the EoS. We train 6CNN on CIFAR-10-8k with η = 0.02 (see the Remark at the end of Section 4.2).

respectively. It has been empirically shown that, for full-batch GD, ∥H∥ increases and then hovers above 2/η, and Cohen et al. (2021) mark the EoS with {θ ∈ Θ : ∥H(θ)∥ = 2/η}, but we mark it with

We observe a similar behavior for SGD, with tr(HS_b)/tr(S_n) oscillating around 2/η. The exact underlying mechanisms (e.g., progressive sharpening) (Arora et al., 2022; Li et al., 2022; Damian et al., 2022; Zhu et al., 2022) require further investigation, and we leave this as future work. Remark (Experiments in Sections 4.1 and 4.2). We report the experimental results using vanilla SGD/GD without momentum and weight decay, a constant learning rate, and no data augmentation. We train a simple 6-layer CNN (6CNN, m = 0.51M) on CIFAR-10-8k, where DATASET-n denotes a subset of DATASET with |D| = n and k = 2¹⁰ = 1024. See Appendix C.1-C.3 for results on other datasets, learning rates, and networks (ResNet-9 with m = 2.3M (He et al., 2016) and WRN-28-2 with m = 36M (Zagoruyko & Komodakis, 2016), where m = dim(θ)).
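For concreteness, the quantities in play (∥H∥_{S_b}, ρ_b, and the instability test ∥H∥_{S_b} > 2ρ_b/η) can be sketched as follows, with toy matrices standing in for the true H, S_b, and S_n:

```python
import numpy as np

def eos_quantities(H, Sb, Sn, eta):
    """Sketch of the measures used here; inputs are toy stand-ins."""
    sharp = np.trace(H @ Sb) / np.trace(Sb)   # ∥H∥_{S_b}
    rho_b = np.trace(Sn) / np.trace(Sb)       # concentration measure ρ_b
    unstable = sharp > 2 * rho_b / eta        # instability test (θ ∈ U)
    return sharp, rho_b, unstable

H = np.diag([4.0, 1.0])
g = np.array([1.0, 0.0])                      # full-batch gradient
Sn = np.outer(g, g)                           # S_n = ∇L∇Lᵀ
Sb = Sn + 0.5*np.eye(2)                       # S_b = S_n + C_b (toy noise)
s, r, u = eos_quantities(H, Sb, Sn, eta=1.0)
print(float(s), float(r), bool(u))  # → 3.25 0.5 True
```

For GD (b = n), S_b = S_n and ρ_n = 1, so the test reduces to ∥H∥_{S_n} > 2/η, the case plotted in Figures 4(a)-4(b).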

Figure 4: [Clearer indication of the EoS] (a)-(b): After a few steps of full-batch training, ∥H∥ (blue) hovers above 2/η (Cohen et al., 2021), but ∥H∥_{S_n} (red, defined in (2)) oscillates around 2/η (red dashed horizontal line). The EoS is more evident in the latter (red). We also observe a sharp increase in ∥H∥_{S_n} right after ∥H∥ exceeds 2/η. Curves are plotted for every step. We train a model on CIFAR-10-8k (n = 2¹³) using cross-entropy loss with η = 0.01/0.02, respectively. (c)-(d): We plot the curves of (c) ∥H∥_{S_b} and (d) ∥H∥_{S_b}/ρ_b when trained with various b. After a few steps, the curves in (c) reach the threshold 2ρ_b/η (see (d) together), which increases as b becomes larger when b ≪ n = 2¹³, and saturates to 2ρ_b/η ≈ 2/η when b is large. Curves are smoothed for visual clarity.

Figures 4(a)-4(b) show that, for GD (ρ_n = 1), the IAS ∥H∥_{S_n} (red) oscillates around 2/η and exhibits IIR. This result is consistent with Cohen et al. (2021) in that ∥H∥ hovers above 2/η for GD. This is because, as mentioned earlier, 2/η ≈ ∥H∥_{S_n} ≤ ∥H∥, and the equality holds only when the gradient ∇L and the top eigenvector q₁ of H are aligned, which is in general not the case. For this reason, IIR provides a tighter relation and identifies the EoS more clearly than Cohen et al. (2021). These results are also consistent with Theorem 2 in that ∥H∥_{S_n} suddenly increases from 0 to 2/η in a few steps after ∥H∥ exceeds 2/η (see Appendix C.3-C.4 for more). Moreover, IIR also applies to general SGD training with 1 ≤ b ≤ n. Figures 4(c)-4(d) show IIR for SGD with different batch sizes. The upper bound 2ρ_b

-25 in Appendix C.3). Moreover, LSR is known to fail when the batch size is large (Jastrzębski et al., 2017; Masters & Luschi, 2018; Zhang et al., 2019; Smith et al., 2020; 2021). On the other hand, Krizhevsky (2014) and Hoffer et al. (2017) proposed the square root scaling rule (SRSR) with another ratio √b/η, to keep the covariance of the parameter update constant for b ≪ n, based on Var_ξ[ηg_ξ] = η²C_b = (γ_{n,b}η²/b)C₁ ≈ (η²/b)C₁. However, Shallue et al. (2018) showed that neither LSR nor SRSR holds in general.
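Using S_b = S_n + C_b = S_n + (γ_{n,b}/b)C₁, the two regimes of ρ_b can be illustrated numerically; the trace values below are assumptions chosen to match ρ = 2⁻⁷ and n = 8k:

```python
import numpy as np

def rho_b(b, n, tr_Sn, tr_C1):
    gamma = (n - b) / (n - 1)
    return tr_Sn / (tr_Sn + gamma / b * tr_C1)   # tr(S_b) = tr(S_n) + γ/b·tr(C₁)

n, tr_Sn = 8192, 1.0
tr_C1 = tr_Sn * (2**7 - 1)        # chosen so that ρ = tr(S_n)/tr(S_1) = 2⁻⁷

ratio = rho_b(8, n, tr_Sn, tr_C1) / rho_b(4, n, tr_Sn, tr_C1)
print(round(ratio, 2))            # → 1.94: ρ_b roughly doubles with b (linear regime)
print(rho_b(n, n, tr_Sn, tr_C1))  # → 1.0: saturation regime (γ_{n,n} = 0)
```

When the noise term γ_{n,b}/b · tr(C₁) dominates the denominator, ρ_b grows almost linearly in b; once it becomes negligible, ρ_b saturates toward 1, which is the shape LSSR exploits.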

Figure 5 (left) demonstrates a new scaling rule with the ratio ρ_b/η, called the Linear and Saturation Scaling Rule (LSSR), with two regimes: (i) ρ_b is almost linear in b when b ≪ n (linear regime), and (ii) ρ_b saturates when b is large (saturation regime). Which regime applies depends on which part of the denominator

Figure 5: [Linear and Saturation Scaling Rule (LSSR)] Left: LSSR (red) in (17), LSR (black dotted line) (Goyal et al., 2017), and SRSR (blue dotted line) (Hoffer et al., 2017). For LSSR, we can observe both the linear and saturation regimes (n = 8k, ρ = 2⁻⁷). Right: Heatmaps of test accuracy for models trained with a large number of (b, η) pairs on CIFAR-10-8k, CIFAR-100-8k, STL-10-4k, and Tiny-ImageNet-32k (from left to right, from top to bottom). The accuracy does not follow either LSR or SRSR, but tends to follow LSSR. We also plot ρ_b/η = C (yellow dashed curve) for some ρ and C on each heatmap. Note that all plots are log-log, so a slope of 1 indicates a linear relationship.

of Shallue et al. (2018, Section 4.7) shows similar "linear and saturation" behaviors supportive of LSSR on other datasets (see also Figure 7 of Zhang et al. (2019, Section 4.3)).

, we use subsets of the datasets CIFAR-10 (Krizhevsky & Hinton, 2009), CIFAR-100 (Krizhevsky & Hinton, 2009), STL-10 (Coates et al., 2011), and Tiny-ImageNet (a subset of ImageNet (Deng et al., 2009) with 3 × 64 × 64 images and 200 object classes). We use a large number of epochs (800) and batch normalization

tr(A): the trace of a square matrix A.
∥H∥_{S_b} ≡ tr(HS_b)/tr(S_b): interaction-aware sharpness.
ρ_b ≡ tr(S_n)/tr(S_b): concentration measure of the batch gradient.
ρ ≡ tr(S_n)/tr(S_1): concentration measure of the per-example gradient.
U = {θ ∈ Θ : ∥H∥_{S_b} > 2ρ_b/η}: unstable regime.
S = U^c = {θ ∈ Θ : ∥H∥_{S_b} ≤ 2ρ_b/η}: stable regime.
∂S = {θ ∈ Θ : ∥H∥_{S_b} = 2ρ_b/η}: the edge of stability.
β₁, β₂ ∈ (0, 1]: hyperparameters for generalized momentum GD.

B PROOFS AND REMARKS

B.1 PROOF OF (1)

We provide a proof of (1) to make the paper self-contained. Similar proofs are given in Li et al. (2017); Hoffer et al. (2017); Wu et al. (2020).

Hε, where v = (I − ηH)∇L and ε ∼ N(0, C_b) from (38). For C_b = (γ_{n,b}/b)BB⊤, we can represent ε = √(γ_{n,b}/b) B ε̃ with ε̃ ∼ N(0, I), and

We use a simple 6-layer CNN (6CNN, m = 0.51M), ResNet-9 (He et al., 2016) (m = 2.3M), and WRN-28-2 (Zagoruyko & Komodakis, 2016) (m = 36M). We use subsets of the datasets CIFAR-10/100 (Krizhevsky & Hinton, 2009), STL-10 (Coates et al., 2011), and Tiny-ImageNet (a subset of ImageNet (Deng et al., 2009) with 3 × 64 × 64 images and 200 object classes), where DATASET-n denotes a subset of DATASET with |D| = n and k = 2¹⁰ = 1024.

Figure 6: with η = 0.01. See caption of Figure 1 for more details.

Figure 10: ResNet-9 with MSE and η = 0.02/0.04. See caption of Figure 1 for more details.

Figure 19: (DATASET, b, η) = (STL-10-4k, 128, 0.02), (CIFAR-100-4k, 512, 0.01). See caption of Figure 2 for more details.

Figure 20: 6CNN with η = 0.02. See caption of Figure 3 for more details.

Figure 22: Left: A toy two-dimensional loss function L(θ) = θ₁² + 16αθ₂², where α = 1.1. We optimize the loss by GD with η = 2/32 so that ∥H∥ = 32α > 2/η = 32. We also plot the GD trajectory in yellow, starting from (θ₁, θ₂) = (3, 0.1). Right: We show the exponential increase in |q₁⊤∇L| (purple) and the S-shaped increases in ∥H∥_{S_n} (red) and |cos(q₁, ∇L)| (green) toward ∥H∥ and 1, respectively, which empirically demonstrates Theorem 2. We also note that they start to increase in the order of |cos(q₁, ∇L)|, ∥H∥_{S_n}, and |q₁⊤∇L|.

Figure 23: After ∥H∥ (blue) exceeds 2/η (red dashed line) at t ≈ 80/176/400/1250/19/43/110/280, within a few steps (≈ 5/6/18/5/5/6/5/15), ∥H∥_{S_n} (red) starts to increase. As expected from Theorem 2, ∥H∥_{S_n} increases together with |cos(q₁, ∇L)| (green) and |q₁⊤∇L| (purple). They are observed to start to increase in the order of |cos(q₁, ∇L)|, ∥H∥_{S_n}, and |q₁⊤∇L|, as shown in Figure 22.

Figure 24: After ∥H∥ exceeds 2/η (not shown in this figure, see Figure 23), the cosine |cos(q₁(H_t), ∇L(θ_t))| (green) between the sharpest direction and the gradient gets large, where H_t = H(θ_t). Simultaneously, ∥H_t∥_{S_n} increases and exceeds 2/η (red solid > red dashed), with the iterate entering the unstable regime and oscillating with cos(∇L(θ_t), ∇L(θ_{t+1})) ≈ −1 (orange). However, due to the non-quadraticity, the sharpest direction changes, with |cos(q₁(H_{t+1}), q₁(H_t))| (cyan) close to 0. We train ResNet-9 on CIFAR-10-8k with η = 0.04 (top: steps 50-150, bottom: steps 0-800).

Figure 25: We train ResNet-9 on CIFAR-10-8k with η = 0.08 (top: steps 0-100, bottom: steps 0-500). See caption of Figure 24 for more details.

Figure 29: WRN-28-2, MSE, and η = 0.01/0.02/0.04. See caption of Figure 4 for more details.

Figure 31: IIR of tr(HS_b)/tr(S_n) ≤ 2/η for SGD. With the upper bound 2/η (red dashed line), this shows the effects of IIR more clearly than Figure 30. ResNet-9, CE, η = 0.08, and b ∈ {2¹², 2¹¹, …, 2⁷}.

Figure 33: Left/Right: See caption of Figure 4(c)/4(d) for more details. WRN-28-2, CIFAR-10-8k, η = 0.01.

Figure 34: (DATASET, MODEL, EPOCHS) = (CIFAR-100-32k, ResNet-9, 800 epochs), (Tiny-ImageNet-32k, ResNet-9, 400 epochs), (Tiny-ImageNet-32k, WRN-28-2, 800 epochs). Middle: The model is trained for 400 epochs, which is short compared to Figure 5 (bottom right heatmap). See caption of Figure 5 (right) for more details.

Figure 35: Left: We use γ n,b = 1 for sampling with replacement ('wr', dotted purple curve), which is almost equivalent to the "without replacement" counterpart (red). Right: LSSR for different ρ values with the corresponding LSR (dotted lines). See caption of Figure 5 (left) for more details.

Figure 38: We use β₁ = 0.1/0.2/0.4, and β₂ = 0 (right, Polyak momentum (Polyak, 1963)) or β₂ = β₁ (left, Nesterov momentum (Nesterov, 1983)) with DATASET = CIFAR-10-8k and η = 0.01 on ResNet-9 without BN. The red horizontal lines indicate (2/η)γ(β₁, β₂), where γ(β₁, β₂) = (1 + β₁)/(1 + 2β₂). See caption of Figure 4 for more details.

Jastrzębski et al. (2017) explain the optimization behavior of SGD with the SDE approximation dθ_t = −∇L(θ_t)dt + √(η/b) C₁^{1/2} dW(t), where W is an m-dimensional Brownian motion. Therefore, the same ratio η/b = η′/b′ leads to the same SDE, which implies LSR. Moreover, a large η/b implies a large diffusion term in the SDE, which has been linked with the efficiency of escaping from sharp local minima in Zhu et al. (2019). Our arguments are free from the problems raised for SDE-based analyses of SGD that assume vanishing learning rates (Mandt et al., 2016; 2017; Hu et al., 2019; Li et al., 2017; 2019a; Jastrzębski et al., 2017; Smith & Le, 2018; Chaudhari & Soatto, 2018), e.g., the mismatch with the practical finite-learning-rate regime or the inherent theoretical issues in the SDE approximations (Yaida, 2019; Li et al., 2021). We instead argue that a large second moment tr(S_b) (compared to tr(S_n)) and a large η lead to a loose constraint 2ρ_b/η on the interaction-aware sharpness. Wu et al. (2020) empirically show that what is important for the generalization performance of a neural network is not the class to which the gradient distribution belongs, but the second moment of the distribution. This is consistent with our arguments based on the interaction tr(HS_b) and the concentration measure ρ_b = tr(S_n)/tr(S_b), because they depend on the second moment S_b, not on the class of the gradient distribution.
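For intuition only, the SDE view can be sketched with an Euler-Maruyama discretization under an assumed constant isotropic noise factor; every value below is made up, and this is not the analysis used in this paper:

```python
import numpy as np

rng = np.random.default_rng(4)
H = np.diag([2.0, 0.5])                 # quadratic loss L(θ) = θᵀHθ/2
eta, b, dt, T = 0.1, 32, 0.01, 2000
C_sqrt = 0.3*np.eye(2)                  # assumed constant noise factor C^{1/2}

theta = np.array([2.0, 2.0])
for _ in range(T):
    dW = np.sqrt(dt)*rng.standard_normal(2)                # Brownian increment
    theta = theta - H @ theta * dt + np.sqrt(eta/b) * C_sqrt @ dW

print(np.linalg.norm(theta) < 0.5)      # → True: hovers near the minimum
```

In this picture only the ratio η/b enters the diffusion coefficient, which is why two runs with equal η/b follow the same SDE and why LSR is implied in this regime.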

Some studies investigate the interaction between the gradient distribution and the loss landscape geometry, represented by tr(HS_b), in the context of escaping efficiency (Zhu et al., 2019, Section 3.1), stationarity (Yaida, 2019, Section 2.2), and convergence (Thomas et al., 2020, Section 3.1.1). However, they require additional assumptions, such as a stochastic differential equation (SDE) approximation of SGD (Zhu et al., 2019), the existence of a stationary-state distribution of the model parameter (Yaida, 2019, Section 2.3.4), and strong convexity of the training loss function

Zhang et al., 2021; Gunasekar et al., 2017; Soudry et al., 2018; Jastrzębski et al., 2020; 2021; Barrett & Dherin, 2021; Smith et al., 2021). There are mainly two factors known to correlate with the generalization performance: the batch gradient distribution during training (Hoffer et al., 2017; Jastrzębski et al., 2017; Smith & Le, 2018; Zhu et al., 2019) and the sharpness of the loss landscape at the minimum (Hochreiter & Schmidhuber, 1997; Keskar et al., 2017; Dinh et al., 2017; Jiang et al., 2020; Foret et al., 2021;

∥H_α∥_{S_n} is usually larger than ∥H∥_{S_n}. On the other hand, when ∥H∥_{S_n} > 2/η (blue), ∥H_α∥_{S_n} is usually smaller than ∥H∥_{S_n}. Right:

Krizhevsky, 2014; Goyal et al., 2017; Jastrzębski et al., 2017; Smith & Le, 2018; Zhang et al., 2019). They argued that LSR holds because θ t+k


ACKNOWLEDGEMENTS S. Lee is also the lead author and the corresponding author. This work was done while S. Lee was at Korea Institute for Advanced Study. S. Lee was supported by a KIAS Individual Grant (AP083601) via the Center for AI and Natural Sciences at Korea Institute for Advanced Study. C. Jang was supported in part by IITP Artificial Intelligence Graduate School Program for Hanyang University funded by MSIT (Grant No. 2020-0-01373) and NRF/MSIT (Grant No. 2021M3E5D2A01019545). This work was supported by the Center for Advanced Computation at Korea Institute for Advanced Study. This work was supported by the research fund of Hanyang University (HY-202300000000552).

REPRODUCIBILITY STATEMENT

We refer the reader to the following pointers for reproducibility:
• Codes to reproduce some figures: Supplementary Material.
• Proofs of the claims: Appendix B.
• Experimental settings: Remarks at the end of Sections 4.2 and 5, and Appendix C.
• Additional figures for other hyperparameter settings: Appendix C.1-C.6.


These lead to ∥A∥_F² + ∥c∥₂²/2 = (η² γ_{n,b}/(2b))² tr(HC₁HC₁) + (η² γ_{n,b}/(2b)) v⊤C₁v = β/2, and then (9) follows from (70). Moreover, we have ∥A∥ ≤ (η² γ_{n,b}/(2b))∥H∥∥C₁∥, which leads to (8) from (69).

B.3 PROOF OF THEOREM 2

We consider the following linear homogeneous second-order equation for the sequence {y_t}:

y_{t+2} + p₁ y_{t+1} + p₂ y_t = 0 (74)

and its characteristic equation ϕ(x) = x² + p₁x + p₂ = 0. (75)

Lemma 8 (…, 2005). The conditions (C1-3) are necessary and sufficient for the equilibrium point (solution) of (74) to be asymptotically stable (i.e., all solutions converge to 0).

Theorem 2. For generalized momentum GD with (β₁, β₂), if λ₁ > (2/η)(1 + β₁)/(1 + 2β₂), then q₁⊤δ_t oscillates and diverges with the exponential growth |q₁⊤δ_t| = Θ(e^{ct}) for some c > 0. Moreover, |cos(q₁, δ_t)| and ∥H∥_{S_n} increase to 1 and λ₁ as t → ∞, respectively, with 1 − |cos(q₁, δ_t)|, λ₁ − ∥H∥_{S_n} = O(e^{−2ct}).

We provide a proof of Theorem 2, modified from the proofs in Appendix A of Cohen et al. (2021).

Proof. Put a quadratic training loss L(θ) = (1/2)θ⊤Hθ + b⊤θ + c. The update rule for generalized momentum GD with (β₁, β₂) on this quadratic L is θ_{t+1} = θ_t + β₁(θ_t − θ_{t−1}) − η∇_θL(θ_t + β₂(θ_t − θ_{t−1})), which recovers Polyak momentum for β₂ = 0 and Nesterov momentum for β₂ = β₁. Thus, for an eigenvalue/eigenvector pair (q, λ), the quantity θ̃_t = q⊤θ_t + q⊤b/λ satisfies the recurrence (74) with p₁ = −(1 + β₁ − (1 + β₂)a) and p₂ = β₁ − β₂a, where a = ηλ > 0.

C.6 GENERALIZED MOMENTUM VARIANTS

For the generalized momentum variants of GD with (β₁, β₂), when ∥H∥ is very small in the beginning, we may approximate L as linear, i.e., ∇_θL(θ) = g is constant with respect to θ in some region. Then δ_t converges to a fixed point δ of the momentum update, and, together with Theorem 2, we generalize (12) with the approximation (131), in which the x-intercept 2/η is replaced by (2/η)γ(β₁, β₂), where γ(β₁, β₂) = (1 + β₁)/(1 + 2β₂). In the corresponding figures of Appendix C.6, we empirically validate (131) similarly to Figures 1 and 4, respectively.

