A QUASISTATIC DERIVATION OF OPTIMIZATION ALGORITHMS' EXPLORATION ON THE MINIMA MANIFOLD

Abstract

A quasistatic approach is proposed to derive an optimization algorithm's effective dynamics on the manifold of minima when the iterator oscillates around the manifold. Compared with existing rigorous analyses, our derivation is simple and intuitive, applies widely, and produces easy-to-interpret results. As examples, we derive the manifold dynamics for SGD, SGD with momentum (SGDm), and Adam under different noise covariances, and justify the closeness of the derived manifold dynamics to the true dynamics through numerical experiments. We then use minima manifold dynamics to study and compare the properties of optimization algorithms. For SGDm, we show that scaling up the learning rate and batch size simultaneously accelerates exploration without affecting generalization, which confirms a benefit of large-batch training. For Adam, we show that the speed of its manifold dynamics changes with the direction of the manifold, because Adam is not rotationally invariant. This may cause slow exploration in high-dimensional parameter spaces.

1. INTRODUCTION

The ability of stochastic optimization algorithms to explore among (global) minima is believed to be one of the essential mechanisms behind the good generalization performance of stochastically trained over-parameterized neural networks. Until recently, research on this topic focused on how the iterator jumps between the attraction basins of many isolated minima and settles down around the flattest one Xie et al. (2020); Nguyen et al. (2019); Dai & Zhu (2020); Mori et al. (2021). However, for over-parameterized models the picture of isolated minima is not accurate, since global minima usually form manifolds of connected minima Cooper (2018). In addition to crossing barriers and jumping out of the attraction basin of one minimum, the optimizer also moves along the minima manifold and searches for better solutions Wang et al. (2021). Hence, understanding how optimization algorithms explore along the minima manifold is crucial to understanding how stochastic optimization algorithms find generalizing solutions for over-parameterized neural networks. Some recent works have begun to examine the exploration dynamics of Stochastic Gradient Descent (SGD) along minima manifolds. Many of these works have identified how a change of flatness along the minima manifold adds a driving force to SGD as it oscillates around the minima. For example, Damian et al. (2021) considered SGD training a neural network with label noise, and showed that the optimizer can find the flattest minimum among all global minima. A more recent work Li et al. (2021b) derived an effective stochastic dynamics for SGD on the manifold. The results in Li et al. (2021b) show that (as the learning rate tends to zero) the changing flatness exerts a force on the SGD iterator along the minima manifold and induces a slow dynamics on the manifold that helps SGD move to the vicinity of flatter minima.
In this work, we study the same question of flatness-driven exploration along the minima manifold. Instead of searching for a rigorous proof, we focus on simple and intuitive ways to derive the manifold dynamics. Specifically, we propose a quasistatic approach to derive the manifold dynamics for different optimization algorithms and stochastic noises. The main technique of our derivation is a time-scale decomposition of the motions perpendicular to and parallel with the minima manifold, which we call the normal component and the tangent component, respectively. We treat the normal component as infinitely faster than the tangent component, so that it is always at equilibrium given the tangent component. The effective dynamics of the tangent component, i.e. the manifold dynamics, is obtained by taking the expectation over the equilibrium distribution of the normal component. The main step in our analysis involves deriving the equilibrium covariance of an SDE. Compared with the theoretical analysis in Li et al. (2021b), our derivation and results are simpler and easier to interpret, and clearly identify the role played by each component of the optimization algorithm (noise covariance, learning rate, momentum). The following simple example demonstrates the main idea of our derivations.

A simple illustrative example: Consider a loss function f(x, y) = h(x)y², where h(x) > 0 is a differentiable function of x. The global minima of this function lie on the x-axis, forming a flat manifold, and h(x) controls the flatness of the loss function at any minimum (x, 0). Let z = [x, y]^T. We consider an SGD approximated by the SDE Li et al. (2017); Li et al. (2021a)

dz_t = -∇f(z_t)dt + √η D(z_t)dW_t,   (1)

where η is the learning rate, D is the square root of the covariance matrix of the gradient noise, and W_t is a Brownian motion.
For convenience of presentation, for points (x, y) close to the x-axis we assume the noise covariance aligns with the Hessian of the loss function at (x, 0), i.e.

D²(z) = (σ²/2) ∇²f(x, 0) = diag(0, σ²h(x)),

where σ > 0 is a scalar. Then, the SDE (equation 1) can be written as

dx_t = -h′(x_t)y_t² dt,   dy_t = -2h(x_t)y_t dt + σ√(η h(x_t)) dW_t,

with W_t being a 1-D Brownian motion. When y_t is close to 0, x_t moves much more slowly than y_t, because the drift of x depends quadratically on y while the drift of y depends only linearly on y. When this separation of speeds is large, the dynamics above can be approximated by the following quasistatic dynamics, which assumes y is always at equilibrium given x_t:

dx_t = -lim_{τ→∞} E_{y_τ}[h′(x_t)y_τ²] dt,   dy_τ = -2h(x_t)y_τ dτ + σ√(η h(x_t)) dW_τ.   (2)

Solving the Ornstein-Uhlenbeck process (equation 2), the equilibrium distribution of y_τ is N(0, ησ²/4), and hence the manifold dynamics is

dx_t/dt = -ησ² h′(x_t)/4.

This derivation shows that the slow effective dynamics along the manifold is a gradient flow minimizing the flatness h(x). This simple quasistatic derivation reveals the flatness-driven motion of SGD along the minima manifold, and recovers the same dynamics as given by Li et al. (2021b) in this specific case. On the left panel of Figure 1, we show an SGD trajectory for f(x, y) = (1 + x²)y², illustrating the exploration along the manifold due to the oscillation in the normal space. On the right panel, we verify the closeness of the manifold dynamics to the true SGD trajectories for the same objective function. The "Hessian noise" and "Isotropic noise" represent noises whose covariances are the Hessian matrix of f (as analyzed above) and the identity matrix (covered by the analysis in Section 2), respectively.

Theoretical applications of the manifold dynamics: The minima manifold dynamics of optimization algorithms can be used as a tool to study and compare the behaviors of optimization algorithms.
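This two-scale picture can be checked directly by simulating the SDE. Below is a minimal Euler-Maruyama sketch for h(x) = 1 + x² (the parameter values are illustrative, not our experimental settings): the time-averaged y² should match the predicted equilibrium variance ησ²/4, and x should drift toward the flattest minimum x = 0.

```python
import numpy as np

# Euler-Maruyama simulation of the SDE approximating SGD on f(x, y) = h(x) y^2
# with Hessian-aligned noise, for h(x) = 1 + x^2.  Parameter values are illustrative.
rng = np.random.default_rng(0)
eta, sigma = 0.1, 1.0            # learning rate and noise scale
dt, n_steps = 1e-3, 200_000
h = lambda x: 1.0 + x**2
dh = lambda x: 2.0 * x

x, y = 2.0, 0.0
xs, y2s = [], []
for _ in range(n_steps):
    dW = rng.normal(0.0, np.sqrt(dt))
    x += -dh(x) * y**2 * dt                                        # slow tangent motion
    y += -2.0 * h(x) * y * dt + sigma * np.sqrt(eta * h(x)) * dW   # fast normal OU process
    xs.append(x)
    y2s.append(y**2)

# Quasistatic predictions: Var(y) = eta*sigma^2/4, and dx/dt = -eta*sigma^2*h'(x)/4,
# i.e. x flows toward the flattest minimum x = 0.
print(np.mean(y2s), eta * sigma**2 / 4)
print(xs[0], xs[-1])
```

The empirical variance of y is independent of x here because the stationary variance ησ²h/(4h) = ησ²/4 does not depend on h.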
In Section 3, we illustrate how our derivations can be applied to study the behavior of SGD on a matrix factorization problem. Two more interesting applications are discussed in Sections 4 and 5. In Section 4, we focus on SGDm, and study the roles played by the learning rate, batch size, and momentum coefficient in its manifold dynamics. Based on the analysis, we explore approaches to reliably accelerate the manifold dynamics, which may help accelerate training. In particular, we show that scaling up the learning rate and batch size simultaneously accelerates exploration without affecting generalization, which confirms a benefit of large-batch training. In Section 5, we study adaptive gradient methods, and show that the speed of the manifold dynamics of Adam Kingma & Ba (2014) changes with the direction of the manifold, because Adam is not rotationally invariant. When the manifold does not align well with some axis direction, the exploration of Adam along the manifold is even slower than SGD with the same learning rate. This shows the sensitivity of Adam (and other adaptive gradient methods) to the parameterization, and a potential weakness of Adam in the exploration among global minima.

2. THE QUASISTATIC DERIVATION FOR MANIFOLD DYNAMICS

In this section, we introduce our quasistatic approach for deriving the minima manifold dynamics, i.e. the effective exploration dynamics of optimization algorithms on minima manifolds.

Notations: Let M be a smooth manifold in R^d with the Euclidean metric. Throughout this paper, we use z to denote points in R^d, including points of M, and use x and y to denote the components of z used in the quasistatic derivation for speed separation. For any z ∈ M, let T_zM be the tangent space of M at z, and let T_zM^⊥ be the orthogonal complement of T_zM, which is the normal space of M at z. Let P_M be the projection operator onto M, i.e. for any z ∈ R^d, P_M z gives the closest point on M to z (if it exists and is unique). P_M is well defined when z is close to M Lee (2013). In the paper, we drop the subscript M from P_M when there is no confusion.

Problem settings: We consider a boundaryless k-dimensional smooth manifold M in R^d, formed by (local or global) minima of a function f: R^d → R. Since we are interested in the behavior of stochastic optimizers near the minima manifold M, we ignore the landscape of f when z is far away from M, and only consider a quadratic expansion of f at M. Specifically, let H(·): M → R^{d×d} be a function on M that gives the Hessian of the loss function on the minima manifold. For any z ∈ M, we assume H(z) is positive semidefinite, and its 0-eigenspace is exactly T_zM. The loss functions that we consider take the form

f(z) = (z - Pz)^T H(Pz)(z - Pz),   z ∈ R^d.   (3)

For the optimization dynamics, we start from SGD, approximated by the following SDE Li et al. (2017); Li et al. (2021a):

dz_t = -∇f(z_t)dt + √η D(Pz_t)dW_t,   (4)

where W_t is a Brownian motion in R^d and D: M → R^{d×d} is the square root of the noise covariance. Later we will extend our analysis to SGDm and Adam. Strictly speaking, the noise of SGD depends on z, which is not always on M. However, in the settings that we study, we assume it depends only on Pz, because z is close to M.

2.1. FLAT MANIFOLD

We start from the case in which the minima manifold M is flat. The derivation in this case is similar to the 2-D example given in the introduction. For any z ∈ R^d, let z = [x^T, y^T]^T, with x ∈ R^k and y ∈ R^{d-k}. Without loss of generality, we assume M = {z = [x^T, y^T]^T : y = 0}, i.e. M is the linear subspace spanned by the first k axes. Then, for any z = [x^T, y^T]^T, we have Pz = [x^T, 0]^T and z - Pz = [0, y^T]^T. For the loss function f, since we assume its Hessian has zero eigenvalues along the tangent space of M, the Hessian must take the form

H(x) = [0, 0; 0, H̃(x)], with H̃(x) ∈ R^{(d-k)×(d-k)}.

Here, we can treat H as a function of x because it is defined on M. Hence, written as a function of x and y, the loss function (equation 3) becomes

f(x, y) = y^T H̃(x) y.   (5)

Next, we rewrite the SDE (equation 4) using x and y. Again, since the noise coefficient D in equation 4 is defined on M, it can be treated as a function of x. For any x, let

D(x) = [D̃11(x), D̃12(x); D̃21(x), D̃22(x)] ∈ R^{d×d}, with D̃11(x) ∈ R^{k×k}.

Then, D̃11 represents the noise in the tangent space, D̃22 represents the noise in the normal space, and D̃12, D̃21 represent the interaction of the tangent space and normal space noises. In many cases, the noise covariance matrix of SGD aligns with the Hessian Mori et al. (2021); hence, D̃22 dominates the other three blocks. (Actually, when the noise covariance strictly aligns with the Hessian, only D̃22 is nonzero.) In the following, we assume the interaction blocks D̃12 and D̃21 are 0, while D̃11 can still be nonzero. Then, by the form of the loss function in equation 5, the SDE (equation 4) can be written as the following system for x and y:

dx_t = -y_t^T ∂_x H̃(x_t) y_t dt + √η D̃11(x_t) dW_t^(1),
dy_t = -2H̃(x_t) y_t dt + √η D̃22(x_t) dW_t^(2),   (6)

where ∂_x H̃(x_t) is a (d-k) × (d-k) × k tensor containing the partial derivatives of H̃ with respect to x.
When z is close to M, y is small, in which case the dynamics of y is much faster than that of x, because the drift term for x depends quadratically on y while the drift term for y depends only linearly on y. Therefore, we can use a quasistatic dynamics to approximate the original dynamics. The quasistatic dynamics assumes that y is always at equilibrium:

dx_t = -lim_{τ→∞} E[y_τ^T ∂_x H̃(x_t) y_τ] dt + √η D̃11(x_t) dW_t^(1),   (7)
dy_τ = -2H̃(x_t) y_τ dτ + √η D̃22(x_t) dW_τ^(2).

Fixing x_t, the dynamics for y_τ is a linear SDE. We have lim_{τ→∞} E y_τ = 0 and lim_{τ→∞} E y_τ y_τ^T = V_t, where V_t ∈ R^{(d-k)×(d-k)} satisfies the Lyapunov equation

H̃(x_t) V_t + V_t H̃(x_t) = η D̃22(x_t) D̃22(x_t)^T / 2.   (8)

The derivations here are standard; readers can refer to textbooks or lecture notes such as Herzog (2013). Substituting the moments into the x-dynamics in equation 7, we have

dx_t = -∑_{i,j=1}^{d-k} (V_t)_{ij} ∇_x(H̃(x_t)_{ij}) dt + √η D̃11(x_t) dW_t^(1).   (9)

Understanding x as a vector on M, equation 9 gives the effective manifold dynamics on M. This result recovers the simple example in the introduction if we take f(x, y) = h(x)y² and D² = σ²H/2.
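The Lyapunov equation for V_t can be solved numerically with standard routines. Below is a minimal sketch using scipy.linalg.solve_continuous_lyapunov with a hypothetical positive-definite H̃ and a generic noise factor D̃22 (values are illustrative only); it verifies the residual of H̃V + VH̃ = ηD̃22D̃22^T/2.

```python
import numpy as np
from scipy.linalg import solve_continuous_lyapunov

# Solve  H V + V H = (eta/2) D22 D22^T  for the equilibrium covariance V of
# the fast normal component.  H and D22 are hypothetical stand-ins.
rng = np.random.default_rng(1)
m, eta = 4, 0.1                              # m = d - k normal directions
A = rng.normal(size=(m, m))
H = A @ A.T + m * np.eye(m)                  # positive-definite Hessian block
D22 = rng.normal(size=(m, m))                # generic normal-space noise factor

Q = 0.5 * eta * D22 @ D22.T
V = solve_continuous_lyapunov(H, Q)          # solves H V + V H^T = Q
residual = np.linalg.norm(H @ V + V @ H - Q)
print(residual)                              # should be near machine precision
```

Since H is symmetric here, the Sylvester/Lyapunov solver returns the symmetric equilibrium second moment directly.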

2.2. GENERAL MANIFOLD

For a general smooth manifold, the manifold dynamics can be derived locally by approximating M with a flat manifold. The resulting dynamics differs from equation 9 only in that the gradients are taken on the manifold. To see this, consider any point z₀ ∈ M. Without loss of generality, we assume there exists x₀ ∈ R^k such that z₀ = [x₀^T, 0]^T, and T_{z₀}M = {[x^T, 0]^T : x ∈ R^k}. Because M is a smooth manifold, around z₀ the projection operator onto T_{z₀}M, denoted by P_{z₀}, induces a 1-1 map between M and T_{z₀}M, and we have ‖z - P_{z₀}z‖ = O(‖z - z₀‖²). Let P_{z₀}^{-1}: T_{z₀}M → M be the inverse of this 1-1 map. With an abuse of notation, for z = [x^T, 0]^T ∈ T_{z₀}M, we sometimes also use P_{z₀}^{-1}x to denote P_{z₀}^{-1}z. Let O_{z₀} = {e₁, ..., e_d} be the standard orthonormal basis for R^d, with {e₁, ..., e_k} being an orthonormal basis for T_{z₀}M. Also because M is smooth, for any z ∈ M close to z₀, there exists an orthonormal basis O_z = {e₁^z, ..., e_d^z} for R^d which is close to O_{z₀}, such that {e₁^z, ..., e_k^z} form an orthonormal basis for T_zM. Specifically, for any 1 ≤ i ≤ d we have ‖e_i^z - e_i‖ = O(‖z - z₀‖²). Now, consider the SDE (equation 4) with the loss function f defined in equation 3. For any z ∈ M close to z₀, let

H(z) = [0, 0; 0, H̃(z)] and D(z) = [D̃11(z), 0; 0, D̃22(z)]

be the Hessian and the noise coefficient matrix expressed in O_z. The zeros in the expressions are due to our assumptions on H and the noise, i.e. T_zM is in the 0-eigenspace of H, and there is no interaction between the tangent space noise and the normal space noise. Now, we define a companion loss function f̃ whose minima manifold is T_{z₀}M by f̃(z) = y^T H̃(P_{z₀}^{-1}x) y, where z = [x^T, y^T]^T, and consider a companion SDE

dz_t = -∇f̃(z_t)dt + √η D(P_{z₀}^{-1}P_{z₀}z_t)dW_t.   (10)

The SDE above approximates an SGD minimizing f̃, whose minima manifold is flat.
Hence, using the results from the previous subsection, we can derive a manifold dynamics on T_{z₀}M,

dx_t = -∑_{i,j=1}^{d-k} (V_t)_{ij} ∇_x(H̃(P_{z₀}^{-1}x)_{ij}) dt + √η D̃11(P_{z₀}^{-1}x) dW_t^(1),   (11)

where V_t is obtained from H̃(P_{z₀}^{-1}x)V_t + V_t H̃(P_{z₀}^{-1}x) = η D̃22(P_{z₀}^{-1}x) D̃22(P_{z₀}^{-1}x)^T / 2. By the discussions above, around z₀ the SDE (equation 4) is close to equation 10, because f(z) ≈ f̃(z) and D(P_{z₀}^{-1}P_{z₀}z_t) ≈ D(Pz). Hence, the effective manifold dynamics of equation 4 is close to equation 11 in a neighborhood of z₀, and the approximation is better in smaller neighborhoods of z₀. Therefore, at z = z₀ the manifold dynamics is approximately equation 11 with x = x₀, which leads to

dz = -∑_{i,j=1}^{d-k} V_{ij} ∇_x(H̃(z)_{ij}) dt + √η D̃11(z) dW_t^(1).

Since the ∇_x(H̃(z)_{ij}) are gradients in the tangent space T_{z₀}M, under the Euclidean metric they are gradients on the manifold at z₀. Hence, at z₀ the manifold dynamics can be written as

dz = -∑_{i,j=1}^{d-k} V_{ij} ∇_M(H̃(z)_{ij}) dt + √η D̃11(z) dW_t^(1).   (12)

Since the above analysis holds for any z₀ ∈ M, the manifold dynamics of equation 4 is given by equation 12 at any z.

Examples: One interesting case is when the noise covariance matrix is proportional to the Hessian. In this case, D²(z) = σ²H(z) for any z ∈ M. Since H̃(z) contains all the nonzero eigenvalues of H(z) and the nonzero eigenspace corresponds to T_zM^⊥, we have D̃11 = 0 and D̃22(z) = σH̃(z)^{1/2}. In this case, V_t satisfies H̃(z_t)V_t + V_t H̃(z_t) = ησ²H̃(z_t)/2, which gives V_t = (ησ²/4) I. Substituting V_t into equation 12, we have the following effective dynamics:

dz_t = -(ησ²/4) ∇_M Tr(H̃(z_t)) dt.   (13)

In the case of a flat manifold, this result corresponds to the dynamics we derived in the introduction. Effectively, the SGD is minimizing (ησ²/4) Tr(H(z)) on the manifold using a gradient flow. (Note that Tr(H̃(z)) = Tr(H(z)) for z ∈ M.) For another case, if we assume the noise in T_zM^⊥ is isotropic with constant magnitude, we have D̃22(z) = σI for some σ.
Hence, we have V_t = (ησ²/4) H̃^{-1}(z_t), and the manifold dynamics becomes

dz_t = -(ησ²/4) ∇_M Tr(log H̃(z_t)) dt.   (14)

Effectively, the SGD is minimizing (ησ²/4) Tr(log H̃(z)) on the manifold with a gradient flow.

Remark 1. The manifold dynamics we derive is similar to that studied in Li et al. (2021b). Instead of providing a rigorous proof, our main contribution is to give a simple and intuitive quasistatic approach for deriving the manifold dynamics. Our method can be applied to a wide class of noise models, as well as to other optimizers such as SGD with momentum (see Section 2.3) and Adam (see Section 5).
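Both closed forms can be verified numerically. The sketch below (with a hypothetical positive-definite H̃; values are illustrative) checks that Hessian-aligned noise D̃22 = σH̃^{1/2} yields V = (ησ²/4)I, while isotropic normal noise D̃22 = σI yields V = (ησ²/4)H̃^{-1}.

```python
import numpy as np
from scipy.linalg import solve_continuous_lyapunov, sqrtm

# Check the two closed-form equilibria of  H V + V H = (eta/2) D22 D22^T.
# H is a hypothetical positive-definite normal-space Hessian block.
rng = np.random.default_rng(4)
m, eta, sigma = 4, 0.1, 1.0
A = rng.normal(size=(m, m))
H = A @ A.T + m * np.eye(m)

# Hessian-aligned noise: D22 = sigma * H^{1/2}  =>  V = (eta*sigma^2/4) I
D22 = sigma * np.real(sqrtm(H))
V = solve_continuous_lyapunov(H, 0.5 * eta * D22 @ D22.T)
print(np.allclose(V, eta * sigma**2 / 4 * np.eye(m)))

# Isotropic normal noise: D22 = sigma * I  =>  V = (eta*sigma^2/4) H^{-1}
V_iso = solve_continuous_lyapunov(H, 0.5 * eta * sigma**2 * np.eye(m))
print(np.allclose(V_iso, eta * sigma**2 / 4 * np.linalg.inv(H)))
```

Substituting these V back into equation 12 recovers the Tr(H̃) and Tr(log H̃) gradient flows, respectively.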

2.3. EXTENDING ANALYSIS TO SGD WITH MOMENTUM

The quasistatic approach we take can be extended to derive the effective manifold dynamics for SGDm. We consider the following SGDm scheme on the same loss function f studied above:

m_{k+1} = µm_k - η∇f(z_k),   z_{k+1} = z_k + m_{k+1},   (15)

where η is the learning rate, µ ∈ [0, 1) is the momentum factor, and m is the momentum. Following the derivation in Li et al. (2017), we consider the following SDE system that approximates equation 15:

dm_t = -((1-µ)/η · m_t + ∇f(z_t)) dt + √η D(Pz_t) dW_t,   dz_t = (1/η) m_t dt.   (16)

As in the discussion for SGD, the manifold dynamics can be obtained by assuming that M is flat and applying a quasistatic analysis to a decomposition into tangent and normal components. We put the details in Appendix A. The resulting manifold dynamics on a general manifold is

dm_t = -((1-µ)/η · m_t + ∑_{i,j=1}^{d-k} (V_t^{sgdm})_{ij} ∇_M(H̃(z_t)_{ij})) dt + √η D̃11(z_t) dW_t,   dz_t = (1/η) m_t dt,   (17)

where the form and derivation of V_t^{sgdm} are given in Appendix A (equation 39). Note that when η is small, equation 39 is close to

H̃V + VH̃ = η/(2(1-µ)) · D̃22 D̃22^T.

Let Ṽ^{sgdm} be the solution of this equation. We have V^{sgdm} ≈ Ṽ^{sgdm} = V^{sgd}/(1-µ), where V^{sgd} is the matrix V for SGD used in the previous sections. This shows that the momentum amplifies the flatness-driven force by a factor of 1/(1-µ). Besides this amplification, the momentum scheme itself also accelerates the manifold dynamics. To see this, when there is no noise along the manifold direction, i.e. D̃11 = 0, by Kovachki & Stuart (2021) the ODE

ż_t = -1/(1-µ) · ∑_{i,j=1}^{d-k} (V_t^{sgdm})_{ij} ∇_M(H̃(z_t)_{ij})   (18)

is a first-order approximation of the manifold dynamics in equation 17. (This approximation assumes the momentum is always at equilibrium.) The factor 1/(1-µ) is the acceleration brought by the momentum scheme. Combining the two effects, the approximate manifold dynamics for SGDm is

ż_t = -1/(1-µ)² · ∑_{i,j=1}^{d-k} (V_t^{sgd})_{ij} ∇_M(H̃(z_t)_{ij}),

which is 1/(1-µ)² times faster than that of SGD. The full derivation for SGDm is given in Appendix A.
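Equation 39 can be sanity-checked numerically: solve the full 2(d-k)-dimensional Lyapunov equation for the joint momentum/position system of Appendix A, extract the position block, and verify that it satisfies equation 39. A minimal sketch with hypothetical H̃, D̃22, and hyperparameter values:

```python
import numpy as np
from scipy.linalg import solve_continuous_lyapunov

# Build the joint (m_y, y) linear system from Appendix A, solve its stationary
# Lyapunov equation  A C + C A^T = -eta D D^T, and check equation 39 for the
# position block C3.  H and D22 are hypothetical stand-ins.
rng = np.random.default_rng(3)
m, eta, mu = 3, 0.1, 0.5
B = rng.normal(size=(m, m))
H = B @ B.T + m * np.eye(m)                  # normal-space Hessian block
D22 = rng.normal(size=(m, m))

A = np.block([[-(1 - mu) / eta * np.eye(m), -2 * H],
              [np.eye(m) / eta,             np.zeros((m, m))]])
Dfull = np.block([[D22], [np.zeros((m, m))]])
C = solve_continuous_lyapunov(A, -eta * Dfull @ Dfull.T)
C3 = C[m:, m:]                               # stationary covariance of y

# Equation 39:  H V + V H + eta/(1-mu)^2 (V H^2 + H^2 V - 2 H V H)
#             = eta/(2(1-mu)) D22 D22^T, with V = C3.
lhs = (H @ C3 + C3 @ H
       + eta / (1 - mu)**2 * (C3 @ H @ H + H @ H @ C3 - 2 * H @ C3 @ H))
rhs = eta / (2 * (1 - mu)) * D22 @ D22.T
print(np.linalg.norm(lhs - rhs))             # should be near machine precision
```

The block system is stable (the drift matrix has eigenvalues with negative real part), so the Lyapunov solution is the unique stationary second moment.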
Examples: We again consider the example where D̃11 = 0 and D̃22(z) = αH̃(z)^{1/2}. In this case, V^{sgdm} ≈ Ṽ^{sgdm} = ηα²/(4(1-µ)) · I. Then, the approximate effective dynamics according to equation 18 is

ż_t = -ηα²/(4(1-µ)²) · ∇_M Tr(H̃(z_t)).   (19)

As a numerical justification of the manifold dynamics we derived for SGDm, for the same function f(x, y) = (x² + 1)y² tested in Figure 1, we compare the true x-coordinate dynamics, the SGDm-like discretization of equation 17, and the ODE solution of equation 19. As shown in the left panel of Figure 2, the three dynamics are close for all µ tested. The results also show that the manifold dynamics gets faster for larger µ, as predicted by its expression.
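The 1/(1-µ)² speed-up can also be seen in a direct simulation of the discrete SGDm scheme (equation 15) on the toy loss f(x, y) = (1 + x²)y², with Hessian-aligned gradient noise injected by hand. The parameter values below are illustrative, not our experimental settings; the measured log-decay rate of x for µ = 0.5 should be roughly four times that for µ = 0.

```python
import numpy as np

# Discrete SGDm (eq. 15) on f(x, y) = h(x) y^2 with Hessian-aligned gradient
# noise (std sigma*sqrt(h) in the normal direction, none in the tangent one).
def run_sgdm(mu, eta=0.05, sigma=1.0, n_steps=4000, seed=0):
    rng = np.random.default_rng(seed)
    h, dh = (lambda x: 1.0 + x**2), (lambda x: 2.0 * x)
    x, y, mx, my = 2.0, 0.0, 0.0, 0.0
    for _ in range(n_steps):
        gx = dh(x) * y**2                                  # tangent gradient (noise-free)
        gy = 2.0 * h(x) * y + sigma * np.sqrt(h(x)) * rng.normal()
        mx = mu * mx - eta * gx
        my = mu * my - eta * gy
        x, y = x + mx, y + my
    return x

x_plain = run_sgdm(mu=0.0)
x_mom   = run_sgdm(mu=0.5)
# Predicted log-decay rate per iteration: eta^2 * sigma^2 / (2 * (1 - mu)^2),
# so the mu = 0.5 run should decay about 4x faster than the mu = 0 run.
rate_plain = np.log(2.0 / x_plain) / 4000
rate_mom   = np.log(2.0 / x_mom) / 4000
print(rate_plain, rate_mom, rate_mom / rate_plain)
```

The measured ratio fluctuates around the predicted 1/(1-µ)² = 4 because the drift is averaged over a finite, correlated sample of y².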

3. APPLICATION TO MATRIX FACTORIZATION PROBLEMS

In this section, we consider the objective function

f(U, V) = ‖UV^T - M‖_F², with U ∈ R^{m×p}, V ∈ R^{n×p}.

Let H(U, V) be the Hessian matrix of f. Theoretically, it is easy to verify that Tr(H(U, V)) is proportional to ‖U‖_F² + ‖V‖_F² wherever UV^T = M. Therefore, the manifold dynamics of SGD minimizes ‖U‖_F² + ‖V‖_F² and drives the iterator to the most balanced global minimum, which is also the flattest. Numerically, we take m = n = p = 5 and compare the true dynamics with the manifold dynamics. The experiments are initialized from U₀ = M, V₀ = I. The middle panel of Figure 2 shows the distance between the true dynamics and the manifold dynamics, as well as the distance traveled by the true dynamics, for SGD and SGDm. In the right panel, we inject an isotropic noise in the tangent space for SGD and make the same comparison. The results show good approximation of the manifold dynamics to the real dynamics. Under the Hessian noise, the SGD indeed moves towards minima with smaller Frobenius norm, in agreement with the manifold dynamics.
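The proportionality of Tr(H) to ‖U‖_F² + ‖V‖_F² on the manifold {UV^T = M} (for m = n) can be checked with a finite-difference Hessian trace. At a global minimum, the diagonal second derivatives are 2‖V_{:,j}‖² for each U_{ij} and 2‖U_{:,j}‖² for each V_{ij}, so Tr(H) = 2m‖V‖_F² + 2n‖U‖_F². A sketch with small illustrative sizes:

```python
import numpy as np

# Finite-difference check that Tr(Hessian of ||U V^T - M||_F^2) equals
# 2 m ||V||_F^2 + 2 n ||U||_F^2 on the minima manifold {U V^T = M}.
rng = np.random.default_rng(2)
m = n = p = 3
U = rng.normal(size=(m, p))
V = rng.normal(size=(n, p))
M = U @ V.T                                    # so (U, V) is a global minimum

def f(w):
    Uw, Vw = w[:m * p].reshape(m, p), w[m * p:].reshape(n, p)
    return np.sum((Uw @ Vw.T - M) ** 2)

w0, eps = np.concatenate([U.ravel(), V.ravel()]), 1e-4
trace = 0.0
for i in range(w0.size):                       # second central differences
    e = np.zeros_like(w0); e[i] = eps
    trace += (f(w0 + e) - 2.0 * f(w0) + f(w0 - e)) / eps**2

expected = 2 * m * np.sum(V**2) + 2 * n * np.sum(U**2)
print(trace, expected)
```

With m = n the expected trace is proportional to ‖U‖_F² + ‖V‖_F², which is what the manifold dynamics minimizes.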

4. LEARNING RATE, BATCH SIZE, AND MOMENTUM

With the manifold dynamics, we can study the impact of hyperparameters on the behavior of optimizers around the manifold of minima. Here, we focus on the learning rate η, batch size B, and momentum factor µ of the SGDm algorithm. Written in the SDE form, the batch size scales the covariance of the noise by a factor 1/B. Hence, following equation 16, the SDE with batch size B is

dm_t = -((1-µ)/η · m_t + ∇f(z_t)) dt + √(η/B) D(Pz_t) dW_t,   dz_t = (1/η) m_t dt.   (20)

To focus on the drift dynamics on the manifold and avoid the influence of the noise, we consider Hessian noise, which exists only in the normal space of the manifold, i.e. we assume D̃11 = 0 and D̃22(z) = σH̃(z)^{1/2}. Then, by equation 18, the first-order ODE on the minima manifold representing the manifold dynamics is

ż_t = -ησ²/(4B(1-µ)²) · ∇_M Tr(H̃(z_t)).   (21)

By equation 21, the manifold dynamics takes the same trajectory with different speeds for different hyperparameters. Let z̄_t be the trajectory of ż_t = -(σ²/4) ∇_M Tr(H̃(z_t)). Considering the discretization, an SGDm with learning rate η, batch size B, and momentum µ takes TB(1-µ)²/η² iterations to follow z̄_t until t = T. Hence, decreasing the batch size, or increasing the learning rate or the momentum factor, can accelerate the (discrete) manifold dynamics. We let s(η, µ, B) := η²/(B(1-µ)²) be the speed factor of the dynamics. The experimental results in the left panel of Figure 3 confirm that the speed factor indeed controls the dynamics' speed.

Implications in practical cases: For the training process of over-parameterized neural networks, the exploration around the minima manifold is an important source of implicit regularization. The driving force of the movement along the manifold is still the change of flatness. However, in this more complicated case, the discussion above may face two problems: (1) The loss in the directions perpendicular to the manifold may not be quadratic, which leads to a different manifold trajectory if the range of oscillation is different (i.e. if the iterator oscillates in a larger range, the flatness-driven force may change its direction). (2) The manifold dynamics is only a first-order approximation of the true dynamics, which may not be accurate over long time periods. The second problem is intrinsic to all studies that use continuous dynamics to approximate discrete dynamics. It does not impose a serious problem as long as the curvature on and around the minima manifold does not change drastically. The first problem motivates us to find ways to accelerate the manifold dynamics without changing the range of oscillation. By the discussion in Section 2.3, the range of the oscillation is given by V^{sgdm}, which here is ησ²/(4B(1-µ)) · I. Let r(η, µ, B) := η/(B(1-µ)) be the "range factor". Combining the discussions above, we want to increase the speed factor s(η, µ, B) while keeping the range factor r(η, µ, B) fixed. Since the ratio η/(1-µ) appears in both factors, to increase the dynamics' speed without changing the oscillation range we must also change B. Concretely, if we pick B′ = cB and η′, µ′ such that η′/(1-µ′) = c · η/(1-µ), then

η′/(B′(1-µ′)) = η/(B(1-µ)),   η′²/(B′(1-µ′)²) = c · η²/(B(1-µ)²),

i.e. the range factor is unchanged while the speed factor is multiplied by c. In the second panel of Figure 3
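The rescaling rule above amounts to simple arithmetic on the two factors; a sketch with illustrative hyperparameter values:

```python
# Speed factor s = eta^2 / (B (1-mu)^2) and range factor r = eta / (B (1-mu)).
# Scaling B' = c*B while keeping eta/(1-mu) scaled by c leaves r unchanged
# and multiplies s by c.  Values are illustrative.
def s_factor(eta, mu, B):
    return eta**2 / (B * (1 - mu)**2)

def r_factor(eta, mu, B):
    return eta / (B * (1 - mu))

eta, mu, B, c = 0.1, 0.9, 32, 4.0
eta2, mu2, B2 = c * eta, mu, c * B      # keep mu, scale eta and B by c

ratio_range = r_factor(eta2, mu2, B2) / r_factor(eta, mu, B)
ratio_speed = s_factor(eta2, mu2, B2) / s_factor(eta, mu, B)
print(ratio_range)   # -> 1.0  (oscillation range unchanged)
print(ratio_speed)   # -> 4.0  (manifold dynamics c times faster)
```

Scaling η and B together (here with µ fixed) is exactly the large-batch training recipe discussed above.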

5. ADAPTIVE GRADIENT METHODS AND ROTATIONAL INVARIANCE

We can also study the manifold dynamics of adaptive gradient methods. We start with experiments showing that the manifold dynamics of Adam changes with the direction of the manifold. This is due to the fact that adaptive gradient methods are not rotationally invariant. We consider the loss function

f(x, y) = (x sin θ + y cos θ)²((x cos θ - y sin θ)² + 1),   (23)

which is the counterclockwise rotation of (x² + 1)y² by θ. The minima manifold of f is the line x sin θ + y cos θ = 0. We run Adam Kingma & Ba (2014) on f with different θ under Hessian noise. The right panel of Figure 3 compares the dynamics projected onto the minima manifold. The results show that Adam moves very fast along the manifold when the manifold aligns well with an axis. When the manifold does not align with an axis, Adam moves much more slowly, sometimes even more slowly than plain SGD. This is because in Adam the adaptive learning rate is computed for each coordinate; hence, when the manifold direction is close to an axis, the learning rate along the manifold can be drastically increased due to the small gradient in the corresponding axis direction. Otherwise, all axis directions have large gradients due to the oscillation, and the learning rate along the manifold is not increased in a desirable way.

Using the SDE approximation for Adam recently derived in Malladi et al. (2022), we can derive and compare the manifold dynamics of Adam on the loss in equation 23 for different θ. Consider an Adam algorithm with hyperparameters (β₁, β₂, η, ϵ), where β₁ and β₂ are the momentum coefficients for the first and second order moments, respectively, η is the learning rate, and ϵ is the small number that prevents division by zero Kingma & Ba (2014). Following Malladi et al. (2022), let Σ be the gradient noise covariance matrix depending on the parameters, and let σ be an additional noise strength (i.e. the real noise covariance is σΣ). Define σ₀ = ση, ϵ₀ = ϵη, c₁ = (1-β₁)/η², c₂ = (1-β₂)/η², γ₁(t) = 1 - e^{-c₁t}, and γ₂(t) = 1 - e^{-c₂t}. Then the Adam trajectory is approximated by the SDE

dx_t = -(√(γ₂(t))/γ₁(t)) P_t^{-1} m_t dt,
dm_t = c₁(∇f(x_t) - m_t)dt + σ₀√c₁ Σ^{1/2}(x_t)dW_t,
du_t = c₂(diag(Σ(x_t)) - u_t)dt,
P_t = σ₀ diag(u_t)^{1/2} + ϵ₀√(γ₂(t)) I.   (24)

Here, f is the loss function, x is the parameter, and W_t is a Brownian motion. The time scale of the SDE above is t = kη², which is different from the usual time scale t = kη used for other optimization algorithms. Using the quasistatic approach, in Appendix B we derive the approximate manifold dynamics for two cases: (1) θ = 0, in which the minima manifold aligns with one axis, and (2) θ = π/4, in which the angles between the minima manifold and the coordinate axes are maximized. After some approximations detailed in Appendix B (such as γ₁(t) = γ₂(t) = 1, which holds when t is large), we have the following effective dynamics on the minima manifold:

θ = 0:  dm_{x,t} = c₁(σ₀h′(x_t)/(4√(h(x_t))) - m_{x,t}) dt,   dx_t = -(1/ϵ₀) m_{x,t} dt.   (25)

θ = π/4:  dm_{x,t} = c₁(σ₀h′(x_t)/(2√(2h(x_t))) - m_{x,t}) dt,   dx_t = -(√2/(σ₀√(h(x_t)))) m_{x,t} dt.   (26)

Here, x is the coordinate along the manifold direction, and m_x is the corresponding momentum. Equations 25 and 26 show the difference between the manifold dynamics for different θ. When θ = 0, the x dynamics is very fast because of the ϵ₀ in the denominator. When θ = π/4, the x dynamics is slower. If we further make a first-order approximation of the dynamics by assuming the momentum is always at equilibrium, as we did for SGDm in equation 18, we have the following manifold dynamics:

θ = 0:  ẋ = -σ₀h′(x)/(4ϵ₀√(h(x))),   θ = π/4:  ẋ = -h′(x)/(2h(x)).   (27)

Here we see that when θ = 0 we get a gradient flow minimizing h(x) on the minima manifold, while when θ = π/4 we get a gradient flow minimizing ln h(x).
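The gap between the two flows in equation 27 can be seen by integrating both ODEs for h(x) = 1 + x². The sketch below uses illustrative stand-in values for σ₀ and ϵ₀ (not the experimental settings); with a small ϵ₀, the axis-aligned flow reaches the flat minimum far sooner.

```python
import numpy as np

# Euler integration of the two approximate Adam manifold flows (equation 27)
# for h(x) = 1 + x^2.  sigma0 and eps0 are illustrative stand-ins.
sigma0, eps0, dt, n_steps = 0.1, 1e-3, 1e-3, 5000
h, dh = (lambda x: 1.0 + x**2), (lambda x: 2.0 * x)

x_axis, x_diag = 2.0, 2.0
for _ in range(n_steps):
    # theta = 0: manifold aligned with an axis, 1/eps0 makes the flow fast
    x_axis += -dt * sigma0 * dh(x_axis) / (4 * eps0 * np.sqrt(h(x_axis)))
    # theta = pi/4: manifold diagonal, no small denominator
    x_diag += -dt * dh(x_diag) / (2 * h(x_diag))
print(x_axis, x_diag)   # axis-aligned flow is essentially at x = 0 already
```

The chosen step size keeps the fast flow stable (dt times its local rate stays below 1), so the contrast reflects the ODEs rather than the discretization.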
The former dynamics is much faster due to the ϵ₀ in the denominator. For the details of the analysis, please see Appendix B. When the dimension of the parameter space is high, it is hard for the minima manifold to align well with the coordinate directions. Hence, the exploration of Adam (as well as other adaptive gradient methods) on the minima manifold is slower than that of SGD and SGDm. This may be one reason that Adam does not generalize as well as SGD in many cases Keskar & Socher (2017); Wilson et al. (2017).

Remark 3. Unlike the effective dynamics for SGD and SGDm, the dynamics in equation 27 do not depend on the learning rate η. This is because in the SDE approximation the time scale is t = kη². An η factor appears if we transform the time scale to t = kη.

A  THE DERIVATION FOR SGDM

Letting z = [x^T, y^T]^T with x in the tangent space and y in the normal space, the dynamics (equation 28) can be written as an SDE system for x and y:

dm_{x,t} = -((1-µ)/η · m_{x,t} + y_t^T ∂_x H̃(x_t) y_t) dt + √η D̃11(x_t) dW_t^(1),   dx_t = (1/η) m_{x,t} dt,
dm_{y,t} = -((1-µ)/η · m_{y,t} + 2H̃(x_t) y_t) dt + √η D̃22(x_t) dW_t^(2),   dy_t = (1/η) m_{y,t} dt.   (29)

Here, m_{x,t} and m_{y,t} denote the momenta of the x and y components, respectively. Again, since the y dynamics is faster than the x dynamics, we take a quasistatic approach by assuming y_t is always at equilibrium given x_t, and taking the expectation over y_t in the x dynamics. Noting that the two equations for m_{y,t} and y_t form a linear system of SDEs, we can compute the first and second moments of y_t at equilibrium and substitute the results into the equations for m_{x,t} and x_t, which gives the following effective dynamics for x:

dm_{x,t} = -((1-µ)/η · m_{x,t} + ∑_{i,j=1}^{d-k} (V_t^{sgdm})_{ij} ∇_x(H̃(x_t)_{ij})) dt + √η D̃11(x_t) dW_t,   dx_t = (1/η) m_{x,t} dt,

where the form and derivation of V_t^{sgdm} are given in the next subsection.
Finally, replacing x by z on M and taking gradients on M, we have the following effective dynamics on a general minima manifold:

dm_t = -((1-µ)/η · m_t + ∑_{i,j=1}^{d-k} (V_t^{sgdm})_{ij} ∇_M(H̃(z_t)_{ij})) dt + √η D̃11(z_t) dW_t,   dz_t = (1/η) m_t dt.

A.1  THE DERIVATION FOR V^{sgdm}

In this section, we derive V^{sgdm} from the SDE for SGDm (equation 16). Assuming x_t and m_{x,t} are fixed, we search for the equilibrium of the following system of y and m_y:

dm_{y,τ} = -((1-µ)/η · m_{y,τ} + 2H̃(x_t) y_τ) dτ + √η D̃22(x_t) dW_τ^(2),   dy_τ = (1/η) m_{y,τ} dτ.

The SDE system above is linear. Let u_τ = [m_{y,τ}^T, y_τ^T]^T; the SDE can be written as

du_τ = A u_τ dτ + √η D dB_τ,

where

A = [-(1-µ)/η · I, -2H̃(x_t); (1/η) I, 0] ∈ R^{2(d-k)×2(d-k)},   D = [D̃22; 0] ∈ R^{2(d-k)×(d-k)},

and B_τ is a Brownian motion. Let C_τ be the second moment matrix E[u_τ u_τ^T]. By Herzog (2013), C_τ satisfies the ODE

(d/dτ) C_τ = A C_τ + C_τ A^T + η D D^T.

Therefore, taking τ → ∞ and letting C_∞ = lim_{τ→∞} C_τ be the moment matrix at equilibrium, we have

A C_∞ + C_∞ A^T = -η D D^T.   (32)

By the definition of u, we have

C_∞ = [E m_{y,∞} m_{y,∞}^T, E m_{y,∞} y_∞^T; E y_∞ m_{y,∞}^T, E y_∞ y_∞^T].

We are interested in the block E y_∞ y_∞^T. Using the symmetry of C_∞, write C_∞ = [C₁, C₂; C₂^T, C₃]; we want to derive C₃. Substituting the blockwise C_∞ into equation 32, we have

[-(1-µ)/η · I, -2H̃(x); (1/η) I, 0] [C₁, C₂; C₂^T, C₃] + [C₁, C₂; C₂^T, C₃] [-(1-µ)/η · I, (1/η) I; -2H̃(x), 0] = [-η D̃22 D̃22^T, 0; 0, 0],

which gives

-2(1-µ)/η · C₁ - 2H̃(x)C₂^T - 2C₂H̃(x) = -η D̃22 D̃22^T,   (33)
(1/η) C₁ - (1-µ)/η · C₂ - 2H̃(x)C₃ = 0,   (34)
(1/η) C₂ + (1/η) C₂^T = 0.   (35)

By equation 35, C₂ is skew-symmetric. By definition, C₁ and C₃ are symmetric. Hence, adding equation 34 to its transpose, we obtain (2/η)C₁ - 2H̃(x)C₃ - 2C₃H̃(x) = 0, which gives

C₁ = η(H̃(x)C₃ + C₃H̃(x)).   (36)

Substituting equation 36 into equation 34, we have

C₂ = η/(1-µ) · (C₃H̃(x) - H̃(x)C₃).   (37)
Plugging equation 36 and equation 37 into equation 33, we have
$$2(1-\mu)\big(H(x)C_3 + C_3H(x)\big) + \frac{2\eta}{1-\mu}\big(C_3H(x)^2 + H(x)^2C_3 - 2H(x)C_3H(x)\big) = \eta D_{22}D_{22}^T.$$
Therefore, denoting $H = H(x)$, $V^{sgdm}$ is the solution of
$$HV + VH + \frac{\eta}{(1-\mu)^2}\big(VH^2 + H^2V - 2HVH\big) = \frac{\eta}{2(1-\mu)}D_{22}D_{22}^T.$$

B ADAM AND RMSPROP

In this section, we derive the effective dynamics equation 25 and equation 26 for Adam on the 2-D problem equation 23:
$$f(x, y) = (x\sin\theta + y\cos\theta)^2\big((x\cos\theta - y\sin\theta)^2 + 1\big). \tag{23}$$
We first describe the relation between two coordinate systems. Let $xOy$ be the coordinate system on which the loss function equation 23 is defined and on which Adam is run. Let $x'Oy'$ be the coordinate system obtained by rotating $xOy$ counterclockwise by $\theta$. Then, the $x'$ axis aligns with the direction of the minima manifold of $f$. In this coordinate system, $f$ takes the form
$$f(x, y) = \big((x')^2 + 1\big)(y')^2.$$
Let $R_\theta$ be the counterclockwise rotation matrix by $\theta$, i.e.
$$R_\theta = \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix}.$$
Then, for any vector $x$ in $xOy$ with coordinates $x'$ in $x'Oy'$, we have $x' = R_{-\theta}x$ and $x = R_\theta x'$. Recall the SDE for Adam derived in Malladi et al. (2022):
$$dx_t = -\frac{\gamma_2(t)}{\gamma_1(t)}P_t^{-1}m_t\,dt, \qquad dm_t = c_1\big(\nabla f(x_t) - m_t\big)dt + \sigma_0 c_1\Sigma^{1/2}(x_t)\,dW_t,$$
$$du_t = c_2\big(\mathrm{diag}(\Sigma(x_t)) - u_t\big)dt, \qquad P_t = \sigma_0\,\mathrm{diag}(u_t)^{1/2} + \epsilon_0\sqrt{\gamma_2(t)}\,I, \tag{40}$$
where $x$ is the parameter vector, $m$ is the first-moment vector, $u$ is the second-moment vector, and $\Sigma$ is the noise covariance. All these quantities are defined in the coordinate system $xOy$. Let $x'$, $m'$, $\Sigma'$ be the counterparts of $x$, $m$, $\Sigma$ in $x'Oy'$; then
$$x' = R_{-\theta}x, \qquad m' = R_{-\theta}m, \qquad \Sigma' = R_{-\theta}\Sigma R_\theta. \tag{41}$$
We do not consider a $u'$ as the counterpart of $u$ in $x'Oy'$, because $u$ is not rotationally invariant. This is the reason that Adam has different effective dynamics on the minima manifold for different $\theta$.
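The matrix equation for $V^{sgdm}$ is linear in $V$, so it can be solved by vectorization. The sketch below (hypothetical $H$ and $D_{22}$) does so and cross-checks the solution against the $\mathbb{E}y_\infty y_\infty^T$ block obtained from a direct Lyapunov solve of equation 32, which the derivation says it must equal.

```python
import numpy as np

eta, mu = 0.1, 0.9
H = np.array([[2.0, 0.5], [0.5, 1.0]])     # stand-in for H(x)
D22 = np.array([[1.0, 0.2], [0.2, 0.5]])   # stand-in noise block
S = D22 @ D22.T
d = H.shape[0]
I = np.eye(d)

# Solve  H V + V H + eta/(1-mu)^2 (V H^2 + H^2 V - 2 H V H)
#        = eta/(2(1-mu)) D22 D22^T  by Kronecker vectorization (H symmetric).
H2 = H @ H
L = (np.kron(I, H) + np.kron(H, I)
     + eta / (1 - mu) ** 2 * (np.kron(H2, I) + np.kron(I, H2) - 2 * np.kron(H, H)))
rhs = eta / (2 * (1 - mu)) * S
V = np.linalg.solve(L, rhs.flatten(order="F")).reshape((d, d), order="F")

# Cross-check: V should equal the y-block of the stationary covariance of the
# full (m_y, y) linear SDE, i.e. C3 from the 2(d-k)-dimensional Lyapunov solve.
A = np.block([[-(1 - mu) / eta * I, -2 * H], [I / eta, np.zeros((d, d))]])
Q = eta * np.block([[S, np.zeros((d, d))], [np.zeros((d, d)), np.zeros((d, d))]])
n = 2 * d
M = np.kron(np.eye(n), A) + np.kron(A, np.eye(n))
C = np.linalg.solve(M, -Q.flatten(order="F")).reshape((n, n), order="F")
assert np.allclose(V, C[d:, d:], atol=1e-10)   # V^sgdm = E[y y^T] at equilibrium
```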
By the relations in equation 41, the SDE equation 40 can be written as
$$dx'_t = -\frac{\gamma_2(t)}{\gamma_1(t)}R_{-\theta}P_t^{-1}R_\theta m'_t\,dt, \qquad dm'_t = c_1\big(R_{-\theta}\nabla f(x_t) - m'_t\big)dt + \sigma_0 c_1\Sigma'^{1/2}(x_t)\,dB_t, \tag{42}$$
$$du_t = c_2\big(\mathrm{diag}(\Sigma(x_t)) - u_t\big)dt, \qquad P_t = \sigma_0\,\mathrm{diag}(u_t)^{1/2} + \epsilon_0\sqrt{\gamma_2(t)}\,I,$$
where $B_t = R_{-\theta}W_t$ is a Brownian motion in the $x'Oy'$ system. Note that $R_{-\theta}\nabla f(x)$ is the gradient of $f$ in the $x'Oy'$ system; letting $x = [x, y]^T$ and $x' = [x', y']^T$, we have
$$R_{-\theta}\nabla f(x) = \begin{pmatrix} 2x'(y')^2 \\ 2\big((x')^2+1\big)y' \end{pmatrix}.$$
From now on, we denote $h(x) = x^2 + 1$. Then
$$R_{-\theta}\nabla f(x) = \begin{pmatrix} h'(x')(y')^2 \\ 2h(x')y' \end{pmatrix}.$$
In fact, our analysis works for any positive and differentiable function $h$. Under the Hessian noise assumption, we take
$$\Sigma'(x) = \begin{pmatrix} 0 & 0 \\ 0 & h(x') \end{pmatrix},$$
and then for $\Sigma$ we have
$$\Sigma(x) = R_\theta\Sigma'(x)R_{-\theta} = h(x')\begin{pmatrix} \sin^2\theta & -\sin\theta\cos\theta \\ -\sin\theta\cos\theta & \cos^2\theta \end{pmatrix}.$$
We do not put a $\sigma$ in front of $h(x')$ because in the derivation of the SDE equation 40 a strength factor $\sigma$ is absorbed into $\sigma_0$. By the expressions for $R_{-\theta}\nabla f(x)$ and $\Sigma'(x)$, letting $m' = [m_{x'}, m_{y'}]^T$ and $u = [u, v]^T$, we can write equation 42 as the following system for $x'$, $y'$, $m_{x'}$, $m_{y'}$, $u$, $v$:
$$\begin{pmatrix} dx'_t \\ dy'_t \end{pmatrix} = -\frac{\gamma_2(t)}{\gamma_1(t)}R_{-\theta}P_t^{-1}R_\theta\begin{pmatrix} m_{x',t} \\ m_{y',t} \end{pmatrix}dt,$$
$$dm_{x',t} = c_1\big(h'(x'_t)(y'_t)^2 - m_{x',t}\big)dt, \qquad dm_{y',t} = c_1\big(2h(x'_t)y'_t - m_{y',t}\big)dt + \sigma_0 c_1\sqrt{h(x'_t)}\,dB_t,$$
$$du_t = c_2\big(h(x'_t)\sin^2\theta - u_t\big)dt, \qquad dv_t = c_2\big(h(x'_t)\cos^2\theta - v_t\big)dt, \qquad P_t = \sigma_0\begin{pmatrix} u_t & 0 \\ 0 & v_t \end{pmatrix}^{1/2} + \epsilon_0\sqrt{\gamma_2(t)}\,I. \tag{43}$$
In equation 43, $B_t$ is a 1-D Brownian motion. Next, we consider two cases: $\theta = 0$ and $\theta = \pi/4$.

Case $\theta = 0$. When $\theta = 0$, we have $R_\theta = R_{-\theta} = I$. Also, since
$$P_t = \sigma_0\begin{pmatrix} u_t & 0 \\ 0 & v_t \end{pmatrix}^{1/2} + \epsilon_0\sqrt{\gamma_2(t)}\,I = \begin{pmatrix} \sigma_0\sqrt{u_t} + \epsilon_0\sqrt{\gamma_2(t)} & 0 \\ 0 & \sigma_0\sqrt{v_t} + \epsilon_0\sqrt{\gamma_2(t)} \end{pmatrix},$$
we have
$$P_t^{-1} = \begin{pmatrix} \frac{1}{\sigma_0\sqrt{u_t} + \epsilon_0\sqrt{\gamma_2(t)}} & 0 \\ 0 & \frac{1}{\sigma_0\sqrt{v_t} + \epsilon_0\sqrt{\gamma_2(t)}} \end{pmatrix}.$$
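The rotated covariance formula is easy to verify numerically. The minimal check below (with an arbitrary angle and an arbitrary value of $h$) confirms both the matrix form of $\Sigma(x)$ and the diagonal entries tracked by $u$ and $v$.

```python
import numpy as np

# Check of the rotated noise covariance: with Sigma' = diag(0, h) in the
# manifold-aligned frame, Sigma = R_theta Sigma' R_{-theta} has the entries
# stated in the text, and its diagonal is h * (sin^2 theta, cos^2 theta).
theta, h = 0.7, 3.0                      # illustrative values
c, s = np.cos(theta), np.sin(theta)
R = np.array([[c, -s], [s, c]])          # counterclockwise rotation R_theta

Sigma_prime = np.diag([0.0, h])
Sigma = R @ Sigma_prime @ R.T            # R_{-theta} = R_theta^T

expected = h * np.array([[s**2, -s * c], [-s * c, c**2]])
assert np.allclose(Sigma, expected)
assert np.allclose(np.diag(Sigma), [h * s**2, h * c**2])
```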
Therefore, equation 43 can be written as
$$dx'_t = -\frac{\gamma_2(t)}{\gamma_1(t)}\frac{m_{x',t}}{\sigma_0\sqrt{u_t} + \epsilon_0\sqrt{\gamma_2(t)}}dt, \qquad dy'_t = -\frac{\gamma_2(t)}{\gamma_1(t)}\frac{m_{y',t}}{\sigma_0\sqrt{v_t} + \epsilon_0\sqrt{\gamma_2(t)}}dt,$$
$$dm_{x',t} = c_1\big(h'(x'_t)(y'_t)^2 - m_{x',t}\big)dt, \qquad dm_{y',t} = c_1\big(2h(x'_t)y'_t - m_{y',t}\big)dt + \sigma_0 c_1\sqrt{h(x'_t)}\,dB_t, \tag{44}$$
$$du_t = -c_2u_t\,dt, \qquad dv_t = c_2\big(h(x'_t) - v_t\big)dt.$$
Using the quasistatic approach, we assume the dynamics of $y'$ and $m_{y'}$ is at equilibrium at any fixed $t$. Hence, fixing $x'_t$, we consider the system
$$dy'_\tau = -\frac{\gamma_2(t)}{\gamma_1(t)}\frac{m_{y',\tau}}{\sigma_0\sqrt{v_t} + \epsilon_0\sqrt{\gamma_2(t)}}d\tau, \qquad dm_{y',\tau} = c_1\big(2h(x'_t)y'_\tau - m_{y',\tau}\big)d\tau + \sigma_0 c_1\sqrt{h(x'_t)}\,dB_\tau,$$
and compute the second moment of $y'$ at equilibrium:
$$\lim_{\tau\to\infty}\mathbb{E}(y'_\tau)^2 = \frac{\sigma_0^2\gamma_2(t)}{4\gamma_1(t)\big(\sigma_0\sqrt{v_t} + \epsilon_0\sqrt{\gamma_2(t)}\big)}. \tag{45}$$
Substituting equation 45 into equation 44, we have the following effective dynamics on the minima manifold:
$$dx'_t = -\frac{\gamma_2(t)}{\gamma_1(t)}\frac{m_{x',t}}{\sigma_0\sqrt{u_t} + \epsilon_0\sqrt{\gamma_2(t)}}dt, \qquad dm_{x',t} = c_1\Big(\frac{\sigma_0^2\gamma_2(t)h'(x'_t)}{4\gamma_1(t)\big(\sigma_0\sqrt{v_t} + \epsilon_0\sqrt{\gamma_2(t)}\big)} - m_{x',t}\Big)dt, \tag{46}$$
$$du_t = -c_2u_t\,dt, \qquad dv_t = c_2\big(h(x'_t) - v_t\big)dt.$$
Next, we simplify the manifold dynamics equation 46 with a few approximations. First, solving the ODEs for $u_t$ and $v_t$, we get
$$u_t = u_0e^{-c_2t}, \qquad v_t = v_0e^{-c_2t} + c_2\int_0^t e^{c_2(s-t)}h(x'_s)\,ds.$$
We first assume $t$ is large enough that $e^{-c_2t}$ is close to $0$. Then, we can take $u_t = 0$, $v_t = c_2\int_0^t e^{c_2(s-t)}h(x'_s)\,ds$, and also $\gamma_1(t) = \gamma_2(t) = 1$. Moreover, since the dynamics of $x'$ is slow, we can assume the change of $h(x'_s)$ is slow compared with $e^{c_2s}$. In this case, we have
$$c_2\int_0^t e^{c_2(s-t)}h(x'_s)\,ds \approx h(x'_t).$$
Hence, we can approximate $u_t$ and $v_t$ by
$$u_t = 0, \qquad v_t = h(x'_t). \tag{47}$$
Substituting equation 47 into equation 46, and taking $\gamma_1(t) = \gamma_2(t) = 1$, we get the following approximate manifold dynamics:
$$dx'_t = -\frac{m_{x',t}}{\epsilon_0}dt, \qquad dm_{x',t} = c_1\Big(\frac{\sigma_0^2h'(x'_t)}{4\big(\sigma_0\sqrt{h(x'_t)} + \epsilon_0\big)} - m_{x',t}\Big)dt. \tag{48}$$
Finally, in the denominator in the dynamics of $m_{x'}$, the $\epsilon_0$ term is usually small compared with the $\sigma_0\sqrt{h(x'_t)}$ term.
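The adiabatic approximation $v_t \approx h(x'_t)$ in equation 47 can be checked by integrating the $v$ ODE against a slowly drifting $x'_t$. The drift rate and the value of $c_2$ below are hypothetical; the point is only that $v_t$ tracks $h(x'_t)$ up to an $O(\dot h/c_2)$ lag once the transient has decayed.

```python
# Check of the adiabatic approximation v_t ~ h(x'_t): integrate
# dv = c2 (h(x'_t) - v) dt with a slowly drifting x'_t and compare v
# against h(x'_t) after the transient. Parameter values are illustrative.
c2 = 5.0
dt, T = 1e-3, 20.0
v = 0.0
for step in range(int(T / dt)):
    t = step * dt
    x_slow = 5.0 - 0.05 * t          # slow manifold motion (hypothetical)
    h = x_slow**2 + 1.0              # h(x) = x^2 + 1 from the text
    v += c2 * (h - v) * dt

# At t = T: x' = 4.0 and h = 17.0; v should track h up to O(h_dot / c2).
assert abs(v - 17.0) / 17.0 < 0.01
```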
Hence, we can take $4\big(\sigma_0\sqrt{h(x'_t)} + \epsilon_0\big) \approx 4\sigma_0\sqrt{h(x'_t)}$ and write the following approximate manifold dynamics:
$$dx'_t = -\frac{m_{x',t}}{\epsilon_0}dt, \qquad dm_{x',t} = c_1\Big(\frac{\sigma_0h'(x'_t)}{4\sqrt{h(x'_t)}} - m_{x',t}\Big)dt. \tag{49}$$
Noting that $h'(x)/\sqrt{h(x)} = \big(2\sqrt{h(x)}\big)'$, the dynamics equation 49 can be understood as a gradient flow with momentum that minimizes $\sqrt{h}$ on the minima manifold. In the region where $h$ is close to linear, we can suppose $m_{x',t} = \sigma_0h'(x'_t)/\big(4\sqrt{h(x'_t)}\big)$ and write down the following first-order approximation of the dynamics:
$$\frac{dx'_t}{dt} = -\frac{\sigma_0h'(x'_t)}{4\epsilon_0\sqrt{h(x'_t)}}. \tag{50}$$
Note that due to the time scale chosen in the SDE equation 40, one time unit of the manifold dynamics that we derive corresponds to $1/\eta^2$ steps of the Adam trajectory, i.e. each Adam step corresponds to a time period of $\eta^2$. If we change the dynamics to the usual time scale, in which each step corresponds to a time period of $\eta$, an additional $\eta$ appears in the numerator of the dynamics. Therefore, the manifold dynamics is still a slow dynamics, an $\eta$ factor slower than the original Adam dynamics.

Case $\theta = \pi/4$. When $\theta = \pi/4$, by equation 43, the dynamics of $u$ and $v$ are
$$du_t = c_2\Big(\frac{h(x'_t)}{2} - u_t\Big)dt, \qquad dv_t = c_2\Big(\frac{h(x'_t)}{2} - v_t\Big)dt.$$
In this case, $u_t$ and $v_t$ have the same dynamics. If we assume $u_0 = v_0$, then we have $u_t = v_t$ for any $t \geq 0$. Then,
$$P_t^{-1} = \begin{pmatrix} \frac{1}{\sigma_0\sqrt{u_t} + \epsilon_0\sqrt{\gamma_2(t)}} & 0 \\ 0 & \frac{1}{\sigma_0\sqrt{v_t} + \epsilon_0\sqrt{\gamma_2(t)}} \end{pmatrix} = \frac{1}{\sigma_0\sqrt{u_t} + \epsilon_0\sqrt{\gamma_2(t)}}I.$$
Hence,
$$R_{-\theta}P_t^{-1}R_\theta = \frac{1}{\sigma_0\sqrt{u_t} + \epsilon_0\sqrt{\gamma_2(t)}}R_{-\theta}R_\theta = \frac{1}{\sigma_0\sqrt{u_t} + \epsilon_0\sqrt{\gamma_2(t)}}I.$$
The resulting dynamics, equation 55, is a gradient flow that minimizes $\ln h(x)$ on the minima manifold. Again, due to the time scale choice, the dynamics gets slower for smaller learning rates. Compared with the dynamics equation 50 for the case $\theta = 0$, this dynamics is slower because there is no $\epsilon_0$ in the denominator. Remark 4. The approximation steps in the derivations above are conducted intuitively, without rigorous proof.
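The speed gap between the two cases can be made concrete at a single point. The sketch below evaluates the $\theta = 0$ speed from equation 50 and, for $\theta = \pi/4$, a gradient flow on $\ln h$ as stated in the text; the $1/2$ prefactor in the latter is our reading of the derivation (hypothetical), but any $O(1)$ prefactor leaves the conclusion unchanged when $\epsilon_0$ is small.

```python
import numpy as np

# Speed comparison of the two manifold dynamics at a point, with h(x) = x^2 + 1.
# theta = 0 follows equation (50); theta = pi/4 uses a gradient flow on ln h
# with a 1/2 prefactor that is our reconstruction, not a quoted formula.
sigma0, eps0 = 1.0, 1e-3
x = 3.0
h = x**2 + 1.0
hp = 2.0 * x

speed_theta0 = sigma0 * hp / (4 * eps0 * np.sqrt(h))   # ~ 1/eps0: fast
speed_theta45 = hp / (2 * h)                           # no eps0: slow

# With eps0 = 1e-3 the aligned case is orders of magnitude faster.
assert speed_theta0 / speed_theta45 > 100
```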
The goal is to unveil and compare the essential components of the dynamics. If rigorous theorems are to be proved, additional conditions and assumptions need to be imposed.



Although equation 11 is a dynamics for $x \in \mathbb{R}^k$, it can be understood as a dynamics for $z$ on $T_{z_0}\mathcal{M}$, in which the $y$ component is always $0$. Our analysis works for any loss function of the form $f(x, y) = (x\sin\theta + y\cos\theta)^2\,h(x\cos\theta - y\sin\theta)$.
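For this family of loss functions, the line $\{x\sin\theta + y\cos\theta = 0\}$ is a manifold of zero-loss minima whenever $h > 0$. A small numerical check (with $h(s) = s^2 + 1$ and an arbitrary angle) confirms that the loss and its gradient vanish along that line.

```python
import numpy as np

# Numerical check that {x sin(theta) + y cos(theta) = 0} is a manifold of
# zero-loss critical points of f(x, y) = (x sin t + y cos t)^2 h(x cos t - y sin t),
# here with h(s) = s^2 + 1 and an illustrative angle theta.
theta = 0.7

def f(x, y):
    yp = x * np.sin(theta) + y * np.cos(theta)   # normal coordinate
    xp = x * np.cos(theta) - y * np.sin(theta)   # manifold coordinate
    return yp**2 * (xp**2 + 1.0)

for x in [-2.0, 0.0, 1.5]:
    y = -x * np.tan(theta)                       # point on the manifold
    eps = 1e-6
    gx = (f(x + eps, y) - f(x - eps, y)) / (2 * eps)   # central differences
    gy = (f(x, y + eps) - f(x, y - eps)) / (2 * eps)
    assert abs(f(x, y)) < 1e-12 and abs(gx) < 1e-6 and abs(gy) < 1e-6
```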



Figure 1: (left) The trajectory of SGD (real dynamics) with Hessian noise initialized from (5, 0). (middle) The x-coordinate of the real dynamics and the manifold dynamics for SGD with Hessian and isotropic noises.

and an SGD with Hessian noise, i.e. the noise covariance is proportional to the Hessian of $f$.

Figure 2: (left) SGDm and its manifold dynamics for different µ. (middle) The distance between the real optimization dynamics and the manifold dynamics for SGD and SGDm with Hessian noise, compared with the displacement of the optimization dynamics' iterators. (right) The distance and displacement curves for average trajectories of SGD when there is noise along the minima manifold.

Figure 3: (Left) The x-coordinate dynamics of three experiments for SGD with different hyperparameters, on the function f(x, y) = (1 + x²)y² with Hessian noise. Two experiments have the same s(η, µ, B), while the other has a smaller s(η, µ, B). Here the role of B is played by 1/σ². (Middle left) The moving average of the y-magnitude for the three experiments shown in the left panel. The three experiments have the same r(η, µ, B); hence the y-magnitudes are of the same order. (Middle right) The test accuracy curves of two neural network runs with different hyperparameters but the same s and r. Experiments are conducted with ResNet18 on the CIFAR100 dataset. (Right) The distance between the manifold projection of the Adam iterators and the origin, for the loss function equation 23 with different θ, compared with that for SGD.

, we show for a synthetic problem that r(η, µ, B) is indeed proportional to the range of oscillation. In the third panel, we show for neural networks that two runs with the same speed and range factors but different hyperparameters indeed follow the same test-loss curve with respect to epochs. On the other hand, if the oscillation range is not kept fixed, the training trajectories may converge to different solutions (with different training and test errors). Remark 2. By equation 22, increasing the batch size while adjusting the other hyperparameters accordingly can accelerate training without changing the training trajectory. However, this acceleration happens at the level of the number of iterations. Since the batch size changes accordingly, the number of samples used during training does not change: we train for fewer iterations, but the same number of epochs. This is shown in the third panel of Figure 3. Nevertheless, this usually saves time, because processing one big batch is faster than processing several small batches with the same total number of samples. Therefore, our results reveal a theoretical mechanism underpinning the empirical benefit of large-batch training You et al. (2019; 2017); Hoffer et al. (2017); Geiping et al. (2021).
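The effect described in Remark 2 can be illustrated on the synthetic problem from Figure 3. The sketch below (all hyperparameter values hypothetical) runs SGD on $f(x,y) = (1+x^2)y^2$ with Hessian-like noise in the $y$ direction, letting $1/\sigma^2$ play the role of $B$ as in the figure. Scaling $\eta \to \kappa\eta$ and $\sigma^2 \to \sigma^2/\kappa$ keeps the oscillation range around the manifold while making the per-iteration drift along the manifold faster.

```python
import numpy as np

# Illustration of Remark 2 on f(x, y) = (1 + x^2) y^2 with Hessian-like noise
# in the y direction (the role of batch size B is played by 1/sigma^2).
# Scaling eta -> k*eta and sigma^2 -> sigma^2/k keeps Var(y) (the oscillation
# range) while speeding up the drift along the minima manifold per iteration.
def run(eta, sigma, steps, seed):
    rng = np.random.default_rng(seed)
    x, y = 5.0, 0.0
    ys = []
    for _ in range(steps):
        h, hp = 1 + x**2, 2 * x
        gx, gy = hp * y**2, 2 * h * y          # exact gradient of f
        noise = sigma * np.sqrt(h) * rng.standard_normal()
        x -= eta * gx
        y -= eta * (gy + noise)                # noise enters like a stochastic gradient
        ys.append(y)
    return x, np.var(ys[steps // 2:])

eta, sigma, k = 0.001, 2.0, 4.0
x1, vy1 = run(eta, sigma, 100_000, seed=0)
x2, vy2 = run(k * eta, sigma / np.sqrt(k), 100_000, seed=1)

assert abs(vy1 - vy2) / vy1 < 0.2   # same oscillation range around the manifold
assert (5.0 - x2) > 2 * (5.0 - x1)  # scaled run drifts faster per iteration
```

Since the scaled run uses $\kappa$ times more samples per step, the speed-up is per iteration, not per epoch, matching the discussion above.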


A SGD WITH MOMENTUM

Consider the following SDE approximation for SGDm:
$$dm_t = -\Big(\frac{1-\mu}{\eta}m_t + \nabla f(x_t)\Big)dt + \sqrt{\eta}\,D(x_t)\,dW_t, \qquad dx_t = \frac{1}{\eta}m_t\,dt. \tag{28}$$
By the discussion for SGD, the manifold dynamics can be obtained by assuming that $\mathcal{M}$ is flat; the result for a non-flat manifold differs only by the gradient being taken on the manifold. In the flat case, we proceed as follows.

Therefore, equation 43 can be written as
$$dx'_t = -\frac{\gamma_2(t)}{\gamma_1(t)}\frac{m_{x',t}}{\sigma_0\sqrt{u_t} + \epsilon_0\sqrt{\gamma_2(t)}}dt, \qquad dy'_t = -\frac{\gamma_2(t)}{\gamma_1(t)}\frac{m_{y',t}}{\sigma_0\sqrt{u_t} + \epsilon_0\sqrt{\gamma_2(t)}}dt,$$
$$dm_{x',t} = c_1\big(h'(x'_t)(y'_t)^2 - m_{x',t}\big)dt, \qquad dm_{y',t} = c_1\big(2h(x'_t)y'_t - m_{y',t}\big)dt + \sigma_0 c_1\sqrt{h(x'_t)}\,dB_t, \tag{51}$$
$$du_t = c_2\Big(\frac{h(x'_t)}{2} - u_t\Big)dt.$$
The quasistatic step here still deals with the system of $y'$ and $m_{y'}$, and the results take the same form: we have
$$\lim_{\tau\to\infty}\mathbb{E}(y'_\tau)^2 = \frac{\sigma_0^2\gamma_2(t)}{4\gamma_1(t)\big(\sigma_0\sqrt{u_t} + \epsilon_0\sqrt{\gamma_2(t)}\big)},$$
and the following effective dynamics on the minima manifold:
$$dx'_t = -\frac{\gamma_2(t)}{\gamma_1(t)}\frac{m_{x',t}}{\sigma_0\sqrt{u_t} + \epsilon_0\sqrt{\gamma_2(t)}}dt, \qquad dm_{x',t} = c_1\Big(\frac{\sigma_0^2\gamma_2(t)h'(x'_t)}{4\gamma_1(t)\big(\sigma_0\sqrt{u_t} + \epsilon_0\sqrt{\gamma_2(t)}\big)} - m_{x',t}\Big)dt, \tag{52}$$
$$du_t = c_2\Big(\frac{h(x'_t)}{2} - u_t\Big)dt.$$
To simplify the manifold dynamics above, we take the same approximation steps as in the $\theta = 0$ case. We first solve the $u$ dynamics, which gives
$$u_t = u_0e^{-c_2t} + \frac{c_2}{2}\int_0^t e^{c_2(s-t)}h(x'_s)\,ds.$$
Still assuming $t$ is large and the change of $h(x'_s)$ slow compared with $e^{c_2s}$, we can approximately take
$$u_t = \frac{h(x'_t)}{2}, \qquad \gamma_1(t) = \gamma_2(t) = 1.$$
Substituting the above approximations into equation 52, we obtain
$$dx'_t = -\frac{m_{x',t}}{\sigma_0\sqrt{h(x'_t)/2} + \epsilon_0}dt, \qquad dm_{x',t} = c_1\Big(\frac{\sigma_0^2h'(x'_t)}{4\big(\sigma_0\sqrt{h(x'_t)/2} + \epsilon_0\big)} - m_{x',t}\Big)dt. \tag{53}$$
Dropping the $\epsilon_0$ terms in the denominators, equation 53 is approximated by
$$dx'_t = -\frac{\sqrt{2}\,m_{x',t}}{\sigma_0\sqrt{h(x'_t)}}dt, \qquad dm_{x',t} = c_1\Big(\frac{\sigma_0h'(x'_t)}{2\sqrt{2h(x'_t)}} - m_{x',t}\Big)dt. \tag{54}$$
Finally, if we assume $m_{x'}$ is close to its stationary solution, i.e. $m_{x',t} = \sigma_0h'(x'_t)/\big(2\sqrt{2h(x'_t)}\big)$, we have the following first-order dynamics for $x'$ that approximates the manifold dynamics:
$$\frac{dx'_t}{dt} = -\frac{h'(x'_t)}{2h(x'_t)}. \tag{55}$$
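A notable feature of the $\theta = \pi/4$ case is that $\sigma_0$ cancels from the closed-loop speed, leaving a gradient flow on $\ln h$ as stated in the text. The small check below verifies this cancellation numerically; the stationary-momentum prefactors used here follow our reading of the reduction and are not quoted formulas.

```python
import numpy as np

# Check of the theta = pi/4 reduction: with u = v = h/2 at equilibrium and
# eps0 dropped, substituting the stationary momentum into the x' equation
# gives a speed -h'/(2h) that is independent of sigma0. The prefactors are
# our reconstruction of the derivation (hypothetical), not quoted formulas.
x = 2.0
h = x**2 + 1.0
hp = 2.0 * x

speeds = []
for sigma0 in [0.1, 1.0, 10.0]:
    u = h / 2.0                                    # stationary second moment
    m_star = sigma0 * hp / (4.0 * np.sqrt(u))      # stationary momentum
    speeds.append(-m_star / (sigma0 * np.sqrt(u))) # dx'/dt at equilibrium

# sigma0 cancels: every speed equals -h'/(2h), a gradient flow on (1/2) ln h.
assert all(abs(s + hp / (2 * h)) < 1e-12 for s in speeds)
```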

