HOW DOES ADAPTIVE OPTIMIZATION IMPACT LOCAL NEURAL NETWORK GEOMETRY?

Abstract

Adaptive optimization methods are well known to achieve superior convergence relative to vanilla gradient methods. The traditional viewpoint in optimization, particularly in convex optimization, explains this improved performance by arguing that, unlike vanilla gradient schemes, adaptive algorithms mimic the behavior of a second-order method by adapting to the global geometry of the loss function. We argue that in the context of neural network optimization, this traditional viewpoint is insufficient, and advocate instead for a local trajectory analysis. For iterate trajectories produced by running a generic optimization algorithm OPT, we introduce R_med^OPT, a statistic analogous to the condition number of the loss Hessian evaluated at the iterates. Through extensive experiments, we show that adaptive methods such as Adam bias the trajectories towards regions where R_med^Adam is small, where one might expect faster convergence. By contrast, vanilla gradient methods like SGD bias the trajectories towards regions where R_med^SGD is comparatively large. We complement these empirical observations with a theoretical result that provably demonstrates this phenomenon in the simplified setting of a two-layer linear network. We view our findings as evidence for the need for a new explanation of the success of adaptive methods, one that is different from the conventional wisdom.

1. INTRODUCTION

The efficient minimization of a parameterized loss function is a core primitive in statistics, optimization and machine learning. Gradient descent (GD), which iteratively updates a parameter vector with a step along the gradient of the loss function evaluated at that vector, is a simple yet canonical algorithm that has been applied to solve such minimization problems with enormous success. However, in modern machine learning, and especially deep learning, one frequently encounters problems where the loss functions are high-dimensional, non-convex and non-smooth. The optimization landscape of such problems is thus extremely challenging, and in these settings gradient descent often suffers from prohibitively high iteration complexity. To deal with these difficulties and improve optimization efficiency, practitioners have in recent years developed many variants of GD.

One prominent class of these GD variants is the family of adaptive algorithms (Duchi et al., 2011; Tieleman et al., 2012; Kingma & Ba, 2015). At a high level, adaptive methods scale the gradient with an adaptively selected preconditioning matrix, which is constructed via a moving average of past gradients. These methods are reminiscent of second-order methods, since they construct approximations to the Hessian of the loss function, while remaining computationally feasible because they eschew full computation of the Hessian. A vast line of empirical work has demonstrated the superiority of adaptive methods over GD for optimizing deep neural networks, especially on Natural Language Processing (NLP) tasks with transformers (Vaswani et al., 2017; Devlin et al., 2019).

From a theoretical perspective, adaptive methods are well understood in the traditional context of convex optimization. For instance, Duchi et al.
(2011) show that when the loss function is convex, the Adagrad algorithm yields regret guarantees that are provably as good as those obtained by using the best (diagonal) preconditioner in hindsight. The key mechanism underlying this improved performance is that the loss function has some global geometric property (such as sparsity or a coordinate-wise bounded Lipschitz constant), and the algorithm adapts to this global geometry by adaptively selecting learning rates for features that are more informative.

However, in non-convex optimization, and deep learning in particular, it is highly unclear whether this simple characterization suffices to explain the superiority of adaptive methods over GD. Indeed, for large-scale neural networks, global guarantees on the geometric properties of the loss are typically vacuous. For instance, for a 20-layer feedforward neural network, scaling up the weights in each layer by a factor of 1.5 scales up the global Lipschitz constant of the network by a factor of at least e^10. Hence it only makes sense to study convergence by looking at the local geometry of the loss along the trajectory of the optimization algorithm (Arora et al., 2018).

[Figure 1 caption, right panel: the 10th largest value over the median of the diagonal of the loss Hessian (which can be viewed as a variant of R_med^OPT(t) defined in eq. (1)) for Adam and SGD+M. Since the full Hessian is too big, we selected several layers and randomly sampled 200 coordinates per layer for the computation.]

Moreover, the interaction between an optimization algorithm and neural network geometry is highly complex: recent work has shown that the geometric characteristics of iterates encountered during optimization are highly dependent on the choice of optimization algorithm and associated hyperparameters (Lewkowycz et al., 2020; Cohen et al., 2021). For instance, Cohen et al.
(2021) demonstrate that when training neural networks with GD, the maximum eigenvalue of the Hessian evaluated at the GD iterates first increases and then plateaus at the level 2/(step size). The viewpoint from convex optimization, where a loss function has some (potentially) non-uniform but fixed underlying geometry that we must adapt to, is thus insufficient for neural networks, since the choice of optimization algorithm can interact with and significantly influence the observed geometry.

To provide another example of this interactive phenomenon, we consider the following experiment. On the same network training loss function f, we run stochastic gradient descent with momentum (SGD+M) and Adam to obtain two different trajectories. We select an iterate x_Adam from the Adam trajectory and an iterate x_SGD from the SGD trajectory such that f(x_Adam) = f(x_SGD). We then run SGD+M twice, once from x_Adam and once from x_SGD. If the underlying geometry of the loss function f were truly fixed, we would not expect a significant difference in performance between the two runs. However, as shown in Figure 1 (left), running SGD+M from x_Adam achieves lower loss than running it from x_SGD, suggesting that Adam may bias the trajectory towards a region that is more favorable for rapid training. This motivates the following question:

How does adaptive optimization impact the observed geometry of a neural network loss function, relative to SGD (with momentum)?

The remainder of this paper is dedicated to answering the above question. To this end, for each iterate in a trajectory produced by running an optimization algorithm OPT, where the Hessian at the t-th iterate is H^(t) ∈ R^{d×d}, we define the second-order statistic R_med^OPT(t) as follows.
For the t-th iterate in the trajectory, let R_med^OPT(t) be the ratio of the maximum of the absolute entries of the diagonal of H^(t) to the median of those entries. Concretely,

    R_med^OPT(t) = max_{i∈[d]} |H^(t)_{ii}| / median_{i∈[d]} |H^(t)_{ii}|.    (1)

This statistic measures the uniformity of the diagonal of the Hessian: a smaller value of R_med^OPT(t) implies that the Hessian has a more uniform diagonal. It can also be viewed as a stable variant of the condition number. (We use the median rather than the minimum because, when a parameter has little impact on the loss, the second derivative with respect to that parameter is almost zero, which would make the ratio max_{i∈[d]} |H^(t)_{ii}| / min_{i∈[d]} |H^(t)_{ii}| blow up to infinity; the median is more stable.) Instead of eigenvalues, we choose diagonal entries because the adaptive methods used in practice are coordinate-wise and can be viewed as diagonal scaling approaches. In Appendix A.9 we discuss this intuition in detail and compare R_med^OPT(t) with singular-value-based metrics. As a supplementary result, in Appendix E we demonstrate that the loss Hessian approaches a diagonal matrix during training for both Adam and SGD+M. Prior theoretical work on overparameterized neural networks has shown that a smaller condition number of the Hessian, of the Neural Tangent Kernel (Jacot et al., 2018), etc., can yield a faster convergence rate for (S)GD (Liu et al., 2022). As for (diagonal) adaptive methods (e.g. Adagrad), they were originally designed to adapt to non-uniform diagonal geometry. Intuitively, a smaller R_med^OPT(t), i.e. a more uniform diagonal geometry, could lead to faster convergence.
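To make the statistic concrete, here is a minimal numpy sketch (ours, not the authors' code) of computing R_med^OPT(t) from a vector of diagonal Hessian entries. The optional `eps_frac` stabilizer is an assumption of this sketch, mirroring the small term the experiments later add to the denominator to guard against tiny medians.

```python
import numpy as np

def r_med(hessian_diag, eps_frac=0.0):
    """Uniformity statistic of eq. (1): max over median of |diagonal| entries.

    `eps_frac` optionally adds a small fraction of the max to the denominator
    (a stabilizer against tiny medians); 0.0 gives the plain definition.
    """
    a = np.abs(np.asarray(hessian_diag, dtype=float))
    return a.max() / (np.median(a) + eps_frac * a.max())

# A nearly uniform diagonal gives R_med close to 1 ...
print(r_med([1.0, 1.01, 0.99, 1.0]))       # → 1.01
# ... while a few dominant coordinates inflate it.
print(r_med([100.0, 1.0, 1.0, 1.0, 1.0]))  # → 100.0
```

A small R_med thus certifies that no coordinate's curvature dominates the bulk, without being destroyed by a handful of near-zero entries the way max/min would be.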
Armed with this statistic, we make the following contributions:

• On a wide variety of transformer architectures and language modeling datasets, we conduct experiments comparing how R_med^Adam(t) and R_med^SGDM(t) evolve over time when Adam and SGD+M are run from the same initialization with their respective optimal (initial) learning rates. In each case, we demonstrate that the Adam trajectory attains R_med^Adam(t) values that are significantly smaller than the R_med^SGDM(t) values found by SGD+M. We show a simple example of this phenomenon in Figure 1 (right). This suggests that, relative to SGD+M, Adam biases the optimization trajectory towards a region where the Hessian diagonal is more uniform. We call this phenomenon the uniformity of diagonal geometry for adaptive methods. As an aside, we observe that larger improvements in the optimization performance of Adam over SGD+M are correlated with larger gaps between R_med^Adam(t) and R_med^SGDM(t). This suggests that a region where the Hessian diagonal is more uniform is also a region that is more amenable to rapid optimization.

• We complement our empirical results with a theoretical analysis of this phenomenon in the simplified setting of large-batch Adam and SGD+M on a two-layer linear network with d-dimensional input and hidden layer and one-dimensional output. We show that for a wide range of t, R_med^Adam(t) = 1 ± o(1) while R_med^SGDM(t) = Ω(log d). Our proof reveals that Adam induces the weight matrices to become low rank with leading singular vectors that exhibit a certain type of uniformity (see Section 6 for discussion), a fact that we also observe empirically in large-scale neural networks, suggesting that this may be a mechanism by which adaptive methods bias trajectories towards uniformity of diagonal geometry.

2. RELATED WORK

Existing analyses of adaptive methods. The vast majority of prior theoretical work on adaptive methods has focused on the blackbox setting (Duchi et al., 2011; Kingma & Ba, 2015; Chen et al., 2020; Reddi et al., 2018; Ward et al., 2020; Défossez et al., 2020; Ene et al., 2021). These works make minimal assumptions about the structure of the loss function, beyond (possibly) some global properties such as convexity or smoothness. These global properties (governed by parameters such as the smoothness constant) are assumed to hold over the entire domain. Hence this style of analysis is worst case, since the resulting convergence bounds depend polynomially on these global parameters. However, as we show in Section 3.1, in neural networks these parameters are prohibitively large. This worst-case analysis is hence unlikely to explain the success of adaptive methods on neural networks. By contrast, our focus is on analyzing the local trajectory induced by running the optimization method.

Existing analyses of (S)GD on neural networks. There is an extensive literature on the analysis of GD/SGD in the non-blackbox setting, e.g. on overparameterized neural networks (Du et al., 2018; Ji & Telgarsky, 2020; Allen-Zhu et al., 2019a;b; Arora et al., 2019a; Liu et al., 2022). However, it is unclear how to translate these analyses of GD/SGD into an analysis that explains the gap between GD/SGD and adaptive methods.

Influence of algorithms on the loss geometry. In many simple convex settings, e.g. linear or logistic regression and the Neural Tangent Kernel (Jacot et al., 2018), the loss geometry is fixed and not influenced by the learning algorithm. However, in neural networks the interaction between algorithms and loss landscapes is more complicated. Lewkowycz et al. (2020) find a so-called catapult effect of the initial learning rate on the training trajectory of SGD and the related loss curvature. Cohen et al.
(2021) demonstrate that while training neural networks with GD, the maximum eigenvalue of the Hessian evaluated at the GD iterates first increases and then plateaus at a level that is inversely proportional to the step size. However, Cohen et al. (2021) leave open the problem of whether similar interactive phenomena occur in algorithms that are not GD, including adaptive methods.

3.1. ISSUES OF PRIOR ANALYSES ON ADAPTIVE METHODS

As mentioned in Section 2, existing work on adaptive algorithms has mainly focused on blackbox analyses that assume global worst-case parameters. However, these global bounds can be extremely bad for complicated deep learning models, as discussed in Section 1. To see this, we initialized a transformer model [3] with the default initialization in Pytorch but with a large gain [4], and computed the smoothness parameter (denoted l) and the condition number (denoted κ) of the loss Hessian on one layer. We observed that setting the gain to a large constant (e.g. 800) results in extremely large l and κ (l ≥ 10^7 and κ ≥ 10^10), which makes the convergence rates from prior blackbox analyses vacuous. The failure of global worst-case analysis implies that we need to focus on the local trajectory of the algorithms. However, it is unclear that two different optimization algorithms will encounter the same geometry along their local trajectories. In particular, although in theory adaptive algorithms can yield convergence rates with better dependence on certain local geometric quantities than SGD (with momentum), the local geometry along the trajectory of an adaptive algorithm could still be much worse than that of SGD (with momentum). This motivates us to study the local geometry, in particular comparing that obtained by adaptive methods with that of SGD (with momentum). Motivated by the diagonal scaling of Adagrad and Adam for neural network training, we ask the following main question in this paper: How does the local diagonal geometry (the diagonal of the loss Hessian) along the trajectory of adaptive algorithms compare to that of SGD (with momentum)?

3.2. OVERVIEW OF THE EXPERIMENTS

As discussed in Section 1, we consider R_med^OPT(t), defined in eq. (1), as a measurement of the uniformity of the diagonal of the loss Hessian. We conduct experiments on different NLP tasks to examine R_med^OPT(t), since on language models adaptive methods have shown significantly faster convergence than SGD (with momentum). The details of these experiments are given in Section 4. To explore potentially different patterns across layers, we do the computation layer by layer. On a wide variety of transformer architectures and language modeling datasets, starting from the same initialization, we observe the following: when we train the neural network using Adam, the statistic R_med^OPT(t), which measures the uniformity of diagonal geometry, is smaller than when we train using SGD+M from the same initialization, except in the first several layers. Table 1 shows a typical example of R_med^Adam(t) compared to R_med^SGDM(t) on a sentence classification task using BERT-small (Turc et al., 2019; Bhargava et al., 2021) (see Section 4.1 for details). We repeated the experiment 12 times starting from the same initialization. Table 1 shows the averaged R_med^Adam(t) and R_med^SGDM(t) in some randomly selected layers (excluding the first several). We also report the averaged ratio R_med^SGDM(t)/R_med^Adam(t) and its standard deviation in brackets. [5] Figure 2 shows the training losses of one of these 12 experiments.

[3] https://pytorch.org/tutorials/beginner/transformer_tutorial.html
[4] This refers to the gain parameter in some commonly used initialization functions of Pytorch, e.g. torch.nn.init.xavier_uniform_().
[5] The R_med^SGDM(t) values in Table 1 are roughly 1.4 to 2 times the R_med^Adam(t) values in the corresponding layers for most layers. In practice this can be considered significant, since it might imply 1.4 to 2 times faster convergence.
To understand this phenomenon from a more principled point of view, we also provide a formal proof of the statement in a simplified setting: large-batch Adam and SGD+M on a 2-layer linear network. Although simple, the choice of a 2-layer linear network to understand learning dynamics is common in prior work (e.g. Tian et al., 2021). Section 3.3 below describes the theoretical setup.

3.3. SETUP OF THE THEORETICAL ANALYSIS

Notation. Let [d] = {1, 2, ..., d}. We use ‖·‖_2 to denote the ℓ2 norm of a vector and ‖·‖_F to denote the Frobenius norm of a matrix. Let ⟨·,·⟩ be the Euclidean inner product between vectors or matrices. Let N(µ, σ^2) be the one-dimensional Gaussian distribution with mean µ and variance σ^2. For a scalar (vector, matrix) A which evolves over time, we use A^(t) to denote its value at time t.

Let there be m data points. The data matrix is X ∈ R^{d_x×m} and the label matrix is Y ∈ R^{d_y×m}. We assume that the input dataset is whitened, i.e. Λ_xx := (1/m) X X^T ∈ R^{d_x×d_x} is the identity matrix. The parameters of a 2-layer linear network are given by W := (W_2, W_1), where W_i ∈ R^{d_i×d_{i−1}} for i = 1, 2, with d_2 = d_y and d_0 = d_x. We consider the square loss

    L(W) := (1/2m) ‖W_2 W_1 X − Y‖_F^2.

Denote A := (1/m) Y X^T ∈ R^{d_y×d_x}. Arora et al. (2019b) show that with a whitened dataset,

    L(W) = L̃(W) + c,    where L̃(W) := (1/2) ‖W_2 W_1 − A‖_F^2

and c does not depend on W. We consider the following model with small Gaussian initialization.

Assumption 1 (Setup). The input covariance Λ_xx := (1/m) X X^T ∈ R^{d_x×d_x} is the identity matrix. The input and hidden layers are both of dimension d, i.e. d_1 = d_0 = d. Without loss of generality, we can assume that A is a row vector (i.e. d_2 = 1) whose coordinates are positive and Θ(1) in terms of d.

Assumption 2 (Gaussian Initialization). For all i, j: w_2^(0)[i] ∼ N(0, 1/d^{2α}) and W_1^(0)[i, j] ∼ N(0, 1/d^{4α}) are independently initialized with sufficiently large α > 0.

Denote by Ã and Λ̃_xx the batch versions of A and Λ_xx. We make the following large-batch assumption; we emphasize that large batches are commonly used in NLP tasks (e.g. Brown et al., 2020).

Assumption 3 (Large Batch). For the randomly selected batches, assume E[Ã] = A and E[Λ̃_xx] = Λ_xx; for all i, j ∈ [d]: E[(Ã_i − A_i)^2] ≤ σ^2, E[(Λ̃_xx[i, j] − Λ_xx[i, j])^2] ≤ σ^2, and σ^2 = O(1/poly(d)).

Denote by g^(t) the batch gradient at time t.
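As a sanity check on the reduction from L to L̃, the following numpy sketch (with toy dimensions of our own choosing; X is constructed so that Λ_xx is exactly the identity) verifies that L(W) and L̃(W) differ by a constant that does not depend on W:

```python
import numpy as np

rng = np.random.default_rng(0)
d, m = 4, 8                                              # toy sizes (our choice)
X = np.sqrt(m / 2) * np.hstack([np.eye(d), np.eye(d)])   # (1/m) X X^T = I exactly
Y = rng.normal(size=(1, m))                              # 1-dimensional output (d_y = 1)
A = (Y @ X.T) / m

def L_full(W2, W1):
    # L(W) = (1/2m) ||W2 W1 X - Y||_F^2
    return 0.5 / m * np.linalg.norm(W2 @ W1 @ X - Y) ** 2

def L_tilde(W2, W1):
    # L~(W) = (1/2) ||W2 W1 - A||_F^2
    return 0.5 * np.linalg.norm(W2 @ W1 - A) ** 2

# The gap L - L~ should be the same constant c for any two parameter settings.
W2a, W1a = rng.normal(size=(1, d)), rng.normal(size=(d, d))
W2b, W1b = rng.normal(size=(1, d)), rng.normal(size=(d, d))
c1 = L_full(W2a, W1a) - L_tilde(W2a, W1a)
c2 = L_full(W2b, W1b) - L_tilde(W2b, W1b)
print(abs(c1 - c2) < 1e-8)  # → True
```

The identity follows from expanding ‖W_2W_1X − Y‖_F^2 and using XX^T = mI, so the cross term collapses onto A; here c = ‖Y‖_F^2/(2m) − ‖A‖_F^2/2.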
The update rules of SGD+M and Adam are:

SGD+M:
    u^(t+1) = β u^(t) + g^(t),
    W^(t+1) = W^(t) − η u^(t+1).

Adam:
    m^(t+1) = β_1 m^(t) + (1 − β_1) g^(t),
    v^(t+1) = β_2 v^(t) + (1 − β_2) g^(t) ⊙ g^(t),
    η_t = η · √(1 − β_2^{t+1}) / (1 − β_1^{t+1}),
    W^(t+1) = W^(t) − η_t m^(t+1) / (√(v^(t+1)) + ξ),

where η is the learning rate, β, β_1, β_2 are momentum parameters, and ξ is added for numerical stability. All operations on vectors are element-wise.

Here and throughout, the notation f(x) = O(g(x)) (resp. f(x) = Ω(g(x)), f(x) = Θ(g(x))) means that there exist constants C_1, C_2 > 0 such that f(x) ≤ C_2 g(x) (resp. f(x) ≥ C_1 g(x), C_1 g(x) ≤ f(x) ≤ C_2 g(x)). We will also use the tilde notation Õ(·), Ω̃(·), Θ̃(·) to hide factors that are logarithmic in d. In our theoretical analysis, "with high probability", or "w.h.p." for short, means with probability at least 1 − 1/poly(d).
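The update rules above can be sketched in a few lines of numpy; the hyperparameter values and the toy quadratic below are illustrative choices of ours, not those used in the paper's experiments:

```python
import numpy as np

def sgdm_step(w, u, grad, eta=1e-3, beta=0.9):
    """Heavy-ball SGD: u <- beta*u + g, then w <- w - eta*u."""
    u = beta * u + grad
    return w - eta * u, u

def adam_step(w, m, v, grad, t, eta=0.05, beta1=0.9, beta2=0.999, xi=1e-8):
    """Adam with bias correction folded into the step size eta_t."""
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    eta_t = eta * np.sqrt(1 - beta2 ** (t + 1)) / (1 - beta1 ** (t + 1))
    return w - eta_t * m / (np.sqrt(v) + xi), m, v

# Minimize the toy quadratic f(w) = 0.5 * w . (D w) with ill-conditioned D.
D = np.array([100.0, 1.0])
w_s = np.array([1.0, 1.0])
w_a = np.array([1.0, 1.0])
u = np.zeros(2)
m = np.zeros(2)
v = np.zeros(2)
for t in range(200):
    w_s, u = sgdm_step(w_s, u, D * w_s)
    w_a, m, v = adam_step(w_a, m, v, D * w_a, t)
ok = np.abs(w_s).max() < 1.0 and np.abs(w_a).max() < 1.0
print(bool(ok))  # → True
```

Both iterates shrink toward the optimum at 0; note that Adam's element-wise division by √(v) rescales each coordinate's step individually, which is exactly the diagonal preconditioning the paper's statistic is designed to probe.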

4. THE UNIFORMITY OF DIAGONAL GEOMETRY

As mentioned in Section 3.2, we computed R_med^OPT(t), defined in eq. (1), on different language models. In this section we present the results of SGD+M and Adam on different architectures and datasets; in Appendix A we present the results of other adaptive algorithms. During training we started from the same initial weights and used the same learning rate schedule (constant or decreasing) for SGD+M and Adam. We tuned and chose the best (initial) learning rate for SGD+M; the (initial) learning rate of Adam was set to a value under which Adam converged faster than SGD+M with its best learning rate. The concrete values are stated later in this section. We used large batch sizes to keep the training procedure stable, and likewise when computing the Hessian. Due to the extremely large dimension, we did the computation on uniformly sampled coordinates, more precisely 200 coordinates per layer.
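The paper computes the Hessian entries on the sampled coordinates directly; as a rough stand-in for that sampling idea (our own sketch, not the authors' procedure), one can estimate H_ii on a chosen subset of coordinates by central finite differences:

```python
import numpy as np

def sampled_hessian_diag(f, x, idx, h=1e-4):
    """Estimate H_ii = d^2 f / dx_i^2 for coordinates in `idx` via
    central finite differences: (f(x+h e_i) - 2 f(x) + f(x-h e_i)) / h^2."""
    f0 = f(x)
    diag = []
    for i in idx:
        e = np.zeros_like(x)
        e[i] = h
        diag.append((f(x + e) - 2.0 * f0 + f(x - e)) / h ** 2)
    return np.array(diag)

# Toy check on a quadratic with a known diagonal: f(x) = 0.5 * sum(c_i x_i^2).
c = np.array([4.0, 1.0, 9.0, 1.0])
f = lambda x: 0.5 * np.dot(c, x * x)
x0 = np.ones(4)
idx = [0, 2]  # "sampled" coordinates, standing in for the paper's 200 per layer
est = sampled_hessian_diag(f, x0, idx)
print(np.round(est, 3))  # → [4. 9.]
```

For a quadratic the estimator is exact up to floating-point error; on a real network loss one would instead use exact second derivatives via automatic differentiation, as the paper does, since finite differences are noisy at stochastic-loss scales.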

4.1. EXPERIMENTS ON REAL DATASETS

Sentence classification task on BERT-small. We fine-tuned BERT-small (Turc et al., 2019; Bhargava et al., 2021) on the IMDB dataset (Maas et al., 2011): the task is to classify whether movie reviews are positive or negative. The momentum parameter β of SGD was set to 0.9, and the two momentum parameters (β_1, β_2) of Adam were set to (0.9, 0.999). We trained the model with linearly decreasing learning rates for 10 epochs (2500 iterations); the initial learning rates of SGD+M and Adam were 0.001 and 5e-5, respectively. As mentioned in Section 3.2, Figure 2 and Table 1 show the training losses and the comparison between R_med^Adam(t) and R_med^SGDM(t), respectively.

Translation task. We trained a Seq2Seq network that uses a Transformer to solve a machine translation task on Multi30k (Elliott et al., 2016) (CC BY-NC-SA 4.0): the goal is a German-to-English translation model. The momentum parameter β of SGD was set to 0.9, and the two momentum parameters (β_1, β_2) of Adam were set to (0.9, 0.98). We trained the model with constant learning rates (0.03 for SGD+M and 1e-4 for Adam) for 60 epochs (1800 iterations). The experiments were repeated 8 times starting from the same initialization. Figure 3 shows the corresponding training losses.

4.2. EXPERIMENTS ON RANDOM DATASETS

We used the same model and momentum parameters as in the translation task described in Section 4.1, but generated random integers as targets. As in the setting with real targets, the model was trained from the same initialization for both algorithms.

Overall, through extensive experiments on language models, we demonstrate that, starting from the same initialization, the R_med^OPT(t) values found by Adam are smaller than those found by SGD+M, except in the first several layers. This suggests that Adam is biased towards a region with a more uniform diagonal Hessian than SGD+M. In Appendix A.10 we also validate this observation on in-distribution test data.

Positive correlation between uniformity of the diagonal Hessian and fast convergence. We observe that on the random dataset, SGD+M plateaus after about 400 steps and thus converges much more slowly relative to Adam than it does on the real dataset (see Figure 3). Correspondingly, the gaps between R_med^SGDM(t) and R_med^Adam(t) are more significant on random data than on real data (see Table 2b). In Appendix A.4 we conduct another experiment in which we switch from SGD to Adam in the middle of training and compare with a model trained by Adam from the beginning; both the loss gap and the gap in R_med^OPT(t) are gradually closed after the switch (see Figure 7 and Table 8). Hence we find a positive correlation between fast convergence and the uniformity of the diagonal of the loss Hessian, suggesting that a region with a more uniform Hessian diagonal is also a region that is more amenable to fast optimization. In Appendix A we study other adaptive algorithms (Adagrad, RMSprop and AMSGrad) and make similar observations: all of these adaptive methods converge faster than SGD or SGD+M and also bias the trajectory towards regions with smaller R_med^OPT(t), suggesting that the uniformity of the diagonal Hessian might be a universal mechanism (partially) explaining why adaptive algorithms optimize faster than SGD (with momentum).
More discussion of the trajectory difference. Since our comparison between R_med^Adam(t) and R_med^SGDM(t) is conditioned on the same iteration, at which SGD+M has larger training loss than Adam, there is a potential alternative explanation of the Hessian diagonal uniformity: the global minimum has a uniform Hessian, and Adam simply converges to it faster than SGD+M. To rule out this possibility, in Appendix A.3 we add a comparison of R_med^Adam(t) and R_med^SGDM(t'), where t, t' are picked such that the t-th Adam iterate and the t'-th SGD+M iterate have the same training loss. The results (in Table 7) show that R_med^Adam(t) < R_med^SGDM(t') for most layers, demonstrating that the trajectories of Adam and SGD+M are truly different and that the difference arises because Adam biases the local geometry (as opposed to merely converging faster).

Adding regularization. Practitioners usually add weight decay (equivalent to ℓ2 regularization) to encourage better generalization. In Appendix A.7 we compare SGD+M and Adam when both use a small weight decay value (0.001). The results in Figure 13a and Table 9 suggest that in this case the positive correlation between R_med^OPT(t) and convergence speed still holds: Adam converges faster than SGD+M, and in most layers except the first several, the R_med^Adam(t) values are smaller than the R_med^SGDM(t) values. [9] This shows the robustness of our observation under weak regularization. However, under large weight decay parameters, we observed cases where Adam still converged faster but the R_med^Adam(t) values were larger rather than smaller. In the case of strong regularization, the adaptivity of Adam requires further exploration, and we hope to find new mechanisms in future work.

[9] To prevent R_med^OPT(t) from getting too large due to a tiny median, we added an additional term 0.001 · max_{i∈[d]} |H^(t)_{ii}| to the denominator of eq. (1) when computing it.

Image tasks.
Although in this paper we focus on language models, where Adam shows significantly faster convergence, we also include supplementary results in Appendix A.8 on image tasks, where SGD+M performs better. On a residual network trained on CIFAR-10, we observed that Adam did not converge faster than SGD+M (see Figure 13b) and, at the same time, the R_med^Adam(t) values were no longer smaller than the R_med^SGDM(t) values during training (see Table 10). This supports the connection between local diagonal geometry and convergence speed from another angle: when the Hessian diagonal of Adam is not more uniform than that of SGD+M, its convergence is not faster either.

5. THEORETICAL ANALYSIS

In Section 4 we empirically demonstrated the uniformity of diagonal geometry. In this section we theoretically analyze this property for large-batch Adam and SGD+M on a two-layer linear network with 1-dimensional output. Let R_med,k^SGDM(t) (resp. R_med,k^Adam(t)) denote the per-layer statistic, defined in (3), obtained by SGD+M (resp. Adam).

Theorem 1.

1. For any p > 0, pick 0 < ε < 1/d^p, η ≤ O(ε/d^{7α/4+4}) and α ≥ 4(p + 2). Suppose σ ≤ η^{3/2}/d^{α/2+1}. Then there exist T_{SGD,1}, T_{SGD,2} such that w.h.p.,

    L̃(W_SGD^{(T_{SGD,1})}) = Θ(d),    L̃(W_SGD^{(T_{SGD,2})}) ≤ Õ(1/d^p),

and for all t ∈ [T_{SGD,1}, T_{SGD,2}]: R_med,k^SGDM(t) = Ω(log d), k = 1, 2.

2. For any p > 0, pick η ≤ O(1/d^{3α}), ξ ≤ η/d^{3α−1}, α ≥ (p + 4)/3 and β_2 = β_1^2. Suppose σ ≤ η^{3/2} ξ^2 / d^{13/4}. Then there exist T_{Adam,1}, T_{Adam,2} such that w.h.p.,

    L̃(W_Adam^{(T_{Adam,1})}) = Θ(d),    L̃(W_Adam^{(T_{Adam,2})}) ≤ Õ(1/d^p),

and for all t ∈ [T_{Adam,1}, T_{Adam,2}]: R_med,k^Adam(t) = 1 ± Õ(η^{1/4} + 1/d^{α/2−1/4}), k = 1, 2.

An immediate corollary of this theorem gives the difference between iterates of Adam and SGD+M that have the same loss.

Corollary 1. Under the setup of Theorem 1, w.h.p., for any t ∈ [T_{SGD,1}, T_{SGD,2}] and t' ∈ [T_{Adam,1}, T_{Adam,2}] such that L̃(W_SGD^{(t)}) = L̃(W_Adam^{(t')}) ∈ [Ω(1/d^p), Θ(d)], we have

    R_med,k^SGDM(t) = Ω(log d),    R_med,k^Adam(t') = 1 ± Õ(η^{1/4} + 1/d^{α/2−1/4}),    k = 1, 2.

Theorem 1 and Corollary 1 tell us that during a long training period, in which the loss decreases from Θ(d) to Õ(1/d^p), the diagonal of the loss Hessian for Adam stays uniform in the sense that, for each layer, its diagonal elements have roughly the same value, i.e. R_med,k^Adam(t) = 1 ± o(1) for k = 1, 2. In contrast, the diagonal of the loss Hessian for SGD+M is less uniform. Appendix B gives a proof sketch of Theorem 1; the detailed proofs can be found in Appendices C and D.

6. THE LOW RANK STRUCTURE OF WEIGHT MATRICES AND UNIFORMITY OF LEADING SINGULAR VECTORS

The proof sketch in Appendix B highlights one crucial intuition behind Theorem 1: after T_{SGD,1} (resp. T_{Adam,1}) steps, W_1 for SGD+M (resp. Adam) becomes an approximately rank-1 matrix. Consider the left singular vector u := [u_1, u_2, ..., u_d]^T corresponding to the leading singular value σ_1. We can show that the distribution of u_1^2, u_2^2, ..., u_d^2 for Adam is more uniform than that for SGD+M. This property, which we call the uniformity of the leading singular vector, is related to the uniformity of the diagonal of the loss Hessian; see Appendix F for details. Similar low-rank bias after training has been studied in prior work (e.g. Gunasekar et al., 2017; Li et al., 2018; Chou et al., 2020).

For more complicated models, we want to check whether the weight matrices also have low-rank structure and, if so, whether we can still observe the uniformity of the leading singular vectors. More formally, for a weight matrix W ∈ R^{m×n} in some layer, we check: (A) whether W is approximately low rank; and (B) if so, whether the ratio R_u (defined in Remark 1 below, generalizing the rank-1 case) obtained by Adam is smaller than that of SGD+M. After reviewing the weight matrices obtained in our different settings, we observed that (A) and (B) hold for many layers in those models. For example, on the translation task mentioned in Section 4.1, we found 12 layers with approximately low-rank structure, and for 10 of them the R_u values obtained by Adam were smaller than those found by SGD+M. Figure 4 shows the result for one typical layer; results for more layers can be found in Appendix A.5.

Remarks.

1. The definition of R_u is based on the connection between the diagonal of the loss Hessian and the weight matrices. Appendix F shows that for a 2-layer linear network,

    R_med,2^OPT(t) = max_i ‖W_1^(t)[i, :]‖_2^2 / median_i ‖W_1^(t)[i, :]‖_2^2.

When W_1 ∈ R^{m×n} is approximately rank k, i.e. W_1 ≈ Σ_{i=1}^k σ_i u_i v_i^T, write u_i = [u_{i1}, u_{i2}, ..., u_{im}]^T and v_i = [v_{i1}, v_{i2}, ..., v_{in}]^T. Then for the j-th row,

    ‖W_1[j, :]‖_2^2 ≈ ‖Σ_{i=1}^k σ_i u_{ij} v_i^T‖_2^2 = Σ_{i=1}^k σ_i^2 u_{ij}^2.

Defining ũ = [ũ_1, ũ_2, ..., ũ_m]^T := Σ_{i=1}^k σ_i^2 u_i ⊙ u_i (where ⊙ is the element-wise product), we have ‖W_1[j, :]‖_2^2 ≈ ũ_j, and we set R_u := max_i ũ_i / median_i ũ_i. Although in multi-layer nonlinear neural networks the connection between the diagonal of the loss Hessian and the weight matrices is more complicated, and R_med,2^OPT(t) may depend on the product of many weight matrices rather than a single matrix, we still believe this definition of R_u is a reasonable ratio to consider.

2. One may also consider the right singular vectors v_1, v_2, ..., v_k, the corresponding ṽ = [ṽ_1, ṽ_2, ..., ṽ_n]^T := Σ_{i=1}^k σ_i^2 v_i ⊙ v_i, and R_v := max_i ṽ_i / median_i ṽ_i for Adam and SGD+M. However, on this translation task, among the 12 layers that were approximately low rank, R_v of Adam was smaller for only 6 of them; i.e. we did not observe uniformity of the leading right singular vector for Adam. Results for R_v can be found in Appendix A.5. One possible reason is that the right singular vectors of a weight matrix are closer to the input data than the left singular vectors and are more easily influenced by the data, and therefore may not show uniformity.
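The quantities in Remark 1 are straightforward to compute from an SVD. The following sketch (with a synthetic rank-1 example of our own) contrasts a perfectly uniform leading left singular vector with a spiky one:

```python
import numpy as np

def r_u(W, k=1):
    """R_u from Remark 1: u_tilde_j = sum_{i<=k} sigma_i^2 * U[j, i]^2,
    R_u = max_j u_tilde_j / median_j u_tilde_j."""
    U, s, _ = np.linalg.svd(W, full_matrices=False)
    ut = (U[:, :k] ** 2) @ (s[:k] ** 2)
    return ut.max() / np.median(ut)

rng = np.random.default_rng(0)
v = rng.normal(size=5)
u_flat = np.ones(6) / np.sqrt(6)            # uniform leading direction
u_spiky = np.array([10.0, 1, 1, 1, 1, 1])   # one dominant coordinate
print(round(r_u(np.outer(u_flat, v)), 6))   # → 1.0
print(round(r_u(np.outer(u_spiky, v)), 3))  # → 100.0
```

For a rank-1 matrix the common factor σ_1^2 cancels in the ratio, so R_u reduces to max_j u_j^2 / median_j u_j^2, which is exactly the row-norm ratio R_med,2^OPT(t) of the remark.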

7. CONCLUSION AND FUTURE WORK

We demonstrate that adaptive optimization methods bias the training trajectory towards a region where the diagonal of the loss Hessian is more uniform, through extensive experiments on language models and a theoretical analysis in the simplified setting of two-layer linear networks. Although our findings may not directly lead to an improved algorithm for practical use, they provide a new way of thinking when designing new algorithms: in contrast with the traditional view, which tries to design a method that performs well under bad loss geometry, our findings suggest designing algorithms that implicitly avoid regions with bad geometry. There are many future directions along this line. For example, our theoretical results on two-layer linear networks may generalize to multi-layer networks; indeed, it has been conjectured that the key-value-query structure in language models can be approximated by a three-layer linear network. Hence the generalization to multi-layer networks might provide a closer connection to real deep models, and is an interesting and challenging future direction. Moreover, it may also be possible to relax our large-batch assumption (Assumption 3) and prove similar results in the general stochastic setting.

A MORE EXPERIMENTS OF THE UNIFORMITY OF DIAGONAL GEOMETRY

A.1 VANILLA SGD VS. ADAGRAD

In this section, we present the $R^{\mathrm{OPT}}_{\mathrm{med}}(t)$ values defined in eq. (1) obtained by vanilla SGD and Adagrad on a language modeling task. The task is to assign a probability to the likelihood of a given word (or sequence of words) following a sequence of words. We trained a transformer model on both Wikitext-2 (Merity et al., 2017) (CC BY-SA 3.0) and a random dataset (generated by using random integers as targets). The model has roughly 8 layers (not counting normalization and dropout layers). The setup is the same as in Section 3.2. We used the same learning rate schedule (constant or decreasing) for SGD and Adagrad. We tuned and chose the best (initial) learning rate for SGD; the (initial) learning rate of Adagrad was set to a value under which Adagrad converged faster than SGD with its best (initial) learning rate. We used large batch sizes to make training more stable, and also used large batch sizes when computing the Hessian. Due to the extremely large dimension, we performed the computation on uniformly sampled coordinates, more precisely 200 coordinates per layer. We tried different initializations (normal and uniform) by using different gains in the PyTorch initialization scheme.
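For concreteness, the per-layer statistic can be estimated exactly as described: sample coordinates, estimate the corresponding diagonal Hessian entries, and take the max over the median. A minimal sketch on a toy quadratic loss, where the finite-difference estimator and all names are illustrative stand-ins for the actual per-layer Hessian computation:

```python
import numpy as np

def diag_hessian_fd(loss, theta, idx, eps=1e-4):
    """Finite-difference estimate of H_ii on a sampled coordinate set idx."""
    h = np.empty(len(idx))
    base = loss(theta)
    for out, i in enumerate(idx):
        e = np.zeros_like(theta); e[i] = eps
        h[out] = (loss(theta + e) - 2 * base + loss(theta - e)) / eps ** 2
    return h

def r_med(diag):
    """max |H_ii| / median |H_ii| over the sampled coordinates."""
    a = np.abs(diag)
    return a.max() / np.median(a)

# Toy quadratic loss 0.5 * sum_i c_i theta_i^2 has H_ii = c_i exactly.
c = np.array([1.0, 2.0, 4.0, 8.0])
loss = lambda th: 0.5 * (c * th ** 2).sum()
theta = np.ones(4)
diag = diag_hessian_fd(loss, theta, idx=[0, 1, 2, 3])
print(round(r_med(diag), 2))   # max/median = 8/3 here
```

In practice one would replace the finite-difference loop with exact per-coordinate second derivatives (e.g. double backpropagation) on the 200 sampled coordinates per layer.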

A.1.1 EXPERIMENTS ON REAL DATASET

Figure 5a shows the training losses on the real dataset (Wikitext-2).

The following experiments were conducted on the translation task described in Section 4.1. The learning rates we used were 2.5e-5 for RMSprop, 0.0005 for AMSGrad and 0.03 for SGD+M. Both RMSprop and SGD+M used momentum parameter 0.9. The two momentum parameters $(\beta_1, \beta_2)$ of AMSGrad were $(0.9, 0.98)$. Figure 6 shows the training losses and Table 6 shows the corresponding $R^{\mathrm{OPT}}_{\mathrm{med}}(t)$.

We also consider another training schedule, "Adam after SGD", where we switch from SGD to Adam in the middle of training to see whether the loss and $R^{\mathrm{OPT}}_{\mathrm{med}}(t)$ can catch up with a model trained by Adam from the very beginning. Again, we used the same model as in the translation task of Section 4.1. Here we did not add a momentum term to SGD, in order to get a larger gap between SGD and Adam than in the case with momentum; we want to see whether this larger gap can be closed after switching to Adam in the middle of training. As shown in Figure 7 and Table 8, both the loss gap and the gap in $R^{\mathrm{OPT}}_{\mathrm{med}}(t)$ closed after a period of training following the switch, which provides evidence of the connection between convergence speed and the uniformity of the diagonal of the loss Hessian.

A.5 THE LOW RANK STRUCTURE

In this section, we present more results for the experiments in Section 6. We examined the weights of the model trained for the translation task in Section 4.1. Among roughly 30 layers, we observed that for 12 layers, at least the weight matrices obtained by Adam after training have approximately low-rank structure. Figure 8 shows examples of layers with and without the low-rank structure. We then studied the uniformity of the leading singular vectors of these 12 layers, i.e. computed $R_u$ and $R_v$ as defined in (B) and the second remark of Section 6. For 10 out of these 12 layers, the $R_u$ values of Adam were smaller than those of SGD, which implies uniformity of the leading left singular vectors under Adam. However, we did not observe significant uniformity for Adam in terms of the leading right singular vectors ($R_v$); the second remark of Section 6 discusses possible reasons. Figure 9 shows how $R_u$ and $R_v$ changed over time in some layers.

In this section we present the uniformity of the diagonal geometry of adaptive methods from another perspective. Denote by $H_{ii}$ the $(i,i)$-th element of the loss Hessian $H$ and by $g_i$ the $i$-th element of the gradient. It is conjectured that when $|H_{ii}|$ is large, the corresponding $|g_i|$ is usually large as well. For adaptive methods, we can regard the update per step as the learning rate times an "adaptive gradient"; let $g_{\mathrm{adapt},i}$ denote the $i$-th component of this adaptive gradient. Through experiments on language models, we found that the $|g_{\mathrm{adapt},i}|$ are quite uniform across $i$ and do not align with $|H_{ii}|$ the way the true gradient magnitudes $|g_i|$ do. In the experiments, we first sorted the $|H_{ii}|$ in ascending order, $|H_{i_1,i_1}| \le |H_{i_2,i_2}| \le \cdots \le |H_{i_d,i_d}|$ (for $H \in \mathbb{R}^{d\times d}$), and then plotted the corresponding $|g_{i_k}|$ and $|g_{\mathrm{adapt},i_k}|$ for $k \in [d]$.

A.6.1 SGD VS. ADAGRAD

Here we compare SGD and Adagrad on the language modeling task on Wikitext-2 described in Section A.1.
We observed that the figures for all layers are quite similar, so we select one layer as an example, shown in Figure 10. (Figure 10: $|g_{i_k}|$ and $|g_{\mathrm{adapt},i_k}|$ for coordinates sorted such that $|H_{i_1,i_1}| \le |H_{i_2,i_2}| \le \cdots \le |H_{i_d,i_d}|$, where $H \in \mathbb{R}^{d\times d}$; experiments conducted on the model described in Section A.1, results shown for the 12-th layer.)
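The qualitative effect described above can be reproduced on synthetic data: if gradient magnitudes track the Hessian diagonal, an Adagrad-style rescaling flattens them. A minimal sketch (the alignment between $|g_i|$ and $|H_{ii}|$ is simulated here, not measured on a real model):

```python
import numpy as np

rng = np.random.default_rng(1)
d = 1000

# Conjecture from the text: larger |H_ii| tends to come with a larger |g_i|.
# We emulate this by taking |g_i| proportional to |H_ii| (purely synthetic).
h_diag = np.sort(np.abs(rng.standard_normal(d))) + 0.1
g = h_diag * rng.choice([-1.0, 1.0], size=d)

# Adagrad-style adaptive gradient: g_i / sqrt(accumulated g_i^2). After a
# single accumulation step this is g / |g| = sign(g): uniform in magnitude.
g_adapt = g / np.sqrt(g ** 2 + 1e-12)

spread = lambda x: np.abs(x).max() / np.median(np.abs(x))
print(round(spread(g), 2), round(spread(g_adapt), 2))
```

The raw gradient magnitudes inherit the spread of $|H_{ii}|$, while the adaptive gradient magnitudes collapse to a uniform profile, matching the plots described in this section.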

A.6.2 SGD WITH MOMENTUM VS. ADAM

Here we compare Adam and SGD+M on the tasks described in Section 4.1. Again, we select one layer as an example for each task. Figure 11 shows the results on the sentence classification task and Figure 12 shows the results on the translation task.

A.7 ADDING REGULARIZATION AND OTHER TRICKS

In this section, we add weight decay to both Adam and SGD+M on the translation task described in Section 4. The momentum parameter $\beta$ of SGD+M was set to 0.9, and the two momentum parameters $(\beta_1, \beta_2)$ of Adam were set to $(0.9, 0.98)$. For both algorithms, the weight decay parameter was 0.001. We trained the model using constant learning rates for 60 epochs (1800 iterations). We tuned and chose the best learning rate, 0.03, for SGD+M; the learning rate of Adam was set to 0.0001, under which Adam converged faster than SGD+M with its best learning rate. Figure 13a shows the training losses, and Table 9 shows the values of $R^{\mathrm{Adam}}_{\mathrm{med}}(t)$, $R^{\mathrm{SGDM}}_{\mathrm{med}}(t)$ and the ratio $R^{\mathrm{SGDM}}_{\mathrm{med}}(t)/R^{\mathrm{Adam}}_{\mathrm{med}}(t)$ in some randomly selected layers.

(Figure 11: $|g_{i_k}|$ and $|g_{\mathrm{adapt},i_k}|$ for coordinates sorted such that $|H_{i_1,i_1}| \le |H_{i_2,i_2}| \le \cdots \le |H_{i_d,i_d}|$, $H \in \mathbb{R}^{d\times d}$; sentence classification task of Section 4.1, 12-th layer. Figure 12: the same quantities for coordinates sorted by $\{|H_{i_k,i_k}|\}_{k=1}^d$; translation task of Section 4.1, 5-th layer.)

A.8 RESULTS ON IMAGE TASKS

We trained a ResNet on the CIFAR-10 dataset and compared the convergence speed and $R^{\mathrm{OPT}}_{\mathrm{med}}(t)$ of SGD+M and Adam. The momentum parameter $\beta$ of SGD+M was set to 0.9, and the two momentum parameters $(\beta_1, \beta_2)$ of Adam were set to $(0.9, 0.98)$. The model was trained using constant learning rates for 41 epochs (2050 iterations). We tuned and chose the best learning rates for both algorithms: 0.5 for SGD+M and 0.005 for Adam. Figure 13b shows the training losses, and Table 10 shows the values of $R^{\mathrm{Adam}}_{\mathrm{med}}(t)$, $R^{\mathrm{SGDM}}_{\mathrm{med}}(t)$ and the ratio $R^{\mathrm{SGDM}}_{\mathrm{med}}(t)/R^{\mathrm{Adam}}_{\mathrm{med}}(t)$.

In Section 4, through extensive experiments on language models, we demonstrated that when we train the neural network using Adam, the value of $R^{\mathrm{OPT}}_{\mathrm{med}}(t)$, which measures the non-uniformity of the diagonal geometry, is smaller than when we train using SGD+M from the same initialization, for most of the layers. We are aware that Hessian singular values, rather than diagonal entries, are also commonly used to measure the loss geometry. Hence in this section we compare our diagonal-based metric with singular value-based metrics.

First, we believe that our metric has a natural connection to the mechanism that underlies adaptive methods. Adaptive methods in practice choose coordinate-wise adaptive learning rates. From a high-level perspective, this procedure can be viewed as adapting to the loss smoothness with respect to each coordinate. The smoothness with respect to a given coordinate is measured by the second derivative in that coordinate, and therefore corresponds to a diagonal entry of the loss Hessian.
Our metric, which measures these diagonal entries, is thus fundamentally intertwined with the mechanism that underlies adaptive methods.

Next, we empirically demonstrate that our metric $R^{\mathrm{OPT}}_{\mathrm{med}}(t)$ is a reasonable proxy for singular value-based metrics. Define the singular value-based metric
$$S^{\mathrm{OPT}}_{\mathrm{med}}(t) := \frac{\max\{\sigma_i(t)\}_{i=1}^d}{\mathrm{median}\{\sigma_i(t)\}_{i=1}^d}$$
as an analogue of our diagonal-based metric $R^{\mathrm{OPT}}_{\mathrm{med}}(t)$, where $\{\sigma_i(t)\}_{i=1}^d$ denotes the singular values of the loss Hessian $H(t) \in \mathbb{R}^{d\times d}$ at the $t$-th iterate. We compare $S^{\mathrm{OPT}}_{\mathrm{med}}(t)$ along the trajectories of Adam and SGD+M in the translation task described in Section 4.1. Table 11 suggests that, when measured by singular values, Adam is also biased towards a region with smaller $S^{\mathrm{OPT}}_{\mathrm{med}}(t)$ than SGD+M, similar to the observation for $R^{\mathrm{OPT}}_{\mathrm{med}}(t)$. This is expected because, as we demonstrate in Appendix E, the loss Hessian approaches a diagonal matrix during training. The fact that our diagonal-based metric and the singular value-based metric give the same result also shows the robustness of our observation to the choice of metric: there does exist a geometric bias of Adam towards more uniform regions even when measured by different metrics.

Finally, our metric is often easier to compute empirically and to analyze theoretically than singular value-based metrics such as $S^{\mathrm{OPT}}_{\mathrm{med}}(t)$, for two reasons. 1. From the empirical computation perspective, if the loss Hessian is $d \times d$, then computing its singular values in general requires computing the whole matrix with $d^2$ elements, whereas our metric only requires the $d$ diagonal entries. 2. From the theoretical analysis perspective, in Appendix F we show that the diagonal of the loss Hessian in linear networks is connected to the weight matrices by simple formulas.
These straightforward formulas simplify the analysis and allow us to connect our metric to the low-rank structure of the weight matrices and the uniformity of their leading singular vectors (see Section 6 for more discussion). However, these connections fail to hold for singular value-based metrics: the formulas for singular values are very complicated even in linear networks, making it almost impossible to theoretically analyze any singular value-based metric.

In this section we compare $R^{\mathrm{Adam}}_{\mathrm{med}}(t)$ and $R^{\mathrm{SGDM}}_{\mathrm{med}}(t)$ on in-distribution test data. The task is the translation task described in Section 4.1. Table 12 validates that on in-distribution test data, Adam is also biased towards a region with smaller $R^{\mathrm{OPT}}_{\mathrm{med}}(t)$ than SGD+M, similar to what happens on the training data shown in Table 2a. This is expected because the two datasets share the same distribution. One other point worth emphasizing is that in real language tasks the dataset is typically very large and the model sees each training example only once; hence the training behavior usually implies similar in-distribution test behavior.
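The agreement between the two metrics on a near-diagonal Hessian, together with the cost difference ($d$ diagonal entries versus the full $d \times d$ matrix), can be illustrated with a small synthetic sketch (the matrix here is synthetic, not a real loss Hessian):

```python
import numpy as np

rng = np.random.default_rng(2)
d = 200

# A Hessian that is close to diagonal, as observed empirically in Appendix E.
diag = np.abs(rng.standard_normal(d)) + 0.5
H = np.diag(diag) + 1e-3 * rng.standard_normal((d, d))
H = (H + H.T) / 2                                     # symmetrize

r_med = diag.max() / np.median(diag)                  # needs only d entries
s = np.linalg.svd(H, compute_uv=False)                # needs the full d x d matrix
s_med = s.max() / np.median(s)
print(round(r_med, 2), round(s_med, 2))
```

For a nearly diagonal matrix the singular values approximately equal the absolute diagonal entries, so the two ratios agree to within a few percent, while the diagonal-based one avoids forming the full matrix.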

B PROOF SKETCH OF THEOREM 1

Now we give a proof sketch of Theorem 1, which contains three major steps; the detailed proofs can be found in Appendices F, C and D.

First, we relate the diagonal of the Hessian to the weight matrices $W_1, W_2$. Under Assumption 1, denote by $W_1[i,:]$ the $i$-th row of $W_1$, and write $W_2 := [w_{21}, w_{22}, \ldots, w_{2d}]$. Since the input dataset is whitened, we can show that
$$R^{\mathrm{OPT}}_{\mathrm{med},1}(t) = \frac{\max_i \big(w^{(t)}_{2i}\big)^2}{\mathrm{median}_i \big(w^{(t)}_{2i}\big)^2}, \qquad R^{\mathrm{OPT}}_{\mathrm{med},2}(t) = \frac{\max_i \big\|W^{(t)}_1[i,:]\big\|_2^2}{\mathrm{median}_i \big\|W^{(t)}_1[i,:]\big\|_2^2}.$$
Next, due to the one-dimensional output, we can prove that $W_1$ converges to an approximately rank-1 matrix. More precisely, we have
$$W^{(t)}_1 = u^{(t)} v^{(t)T} + R^{(t)}_1, \qquad W^{(t)}_2 = c^{(t)} u^{(t)T} + R^{(t)T}_2,$$
where $c^{(t)}$ is a scalar, $u^{(t)}, v^{(t)}, R^{(t)}_2 \in \mathbb{R}^d$ and $R^{(t)}_1 \in \mathbb{R}^{d\times d}$. Denote the $i$-th coordinates of $u^{(t)}, v^{(t)}, R^{(t)}_2$ by $u^{(t)}_i, v^{(t)}_i, R^{(t)}_{2i}$, respectively, and the $(i,j)$-th element of $R^{(t)}_1$ by $R^{(t)}_1[i,j]$. We have that $\forall i,j \in [d]$: $|R^{(t)}_{2i}| \ll |c^{(t)} u^{(t)}_i|$ and $|R^{(t)}_1[i,j]| \ll |u^{(t)}_i v^{(t)}_j|$. Using the rank-1 structure, we can further simplify $R^{\mathrm{OPT}}_{\mathrm{med},1}(t)$ and $R^{\mathrm{OPT}}_{\mathrm{med},2}(t)$ to
$$R^{\mathrm{OPT}}_{\mathrm{med},k}(t) \approx \frac{\max_i \big(u^{(t)}_i\big)^2}{\mathrm{median}_i \big(u^{(t)}_i\big)^2}, \quad k = 1, 2.$$
The final step is a detailed analysis of $u^{(t)}$. For SGD+M, we can prove that $u^{(t)} \approx C(t)[X_1, X_2, \ldots, X_d]^T$, where $C(t) \in \mathbb{R}$ and $X_i, i \in [d]$, are i.i.d. Gaussian variables. Then with high probability $\frac{\max_i (u^{(t)}_i)^2}{\mathrm{median}_i (u^{(t)}_i)^2} = \Omega(\log d)$. For Adam, we can prove that $\forall i \in [d]: u^{(t)}_i \in \{\pm 1\}$, which gives $\frac{\max_i (u^{(t)}_i)^2}{\mathrm{median}_i (u^{(t)}_i)^2} = 1$. Substituting into eq. (4) completes the proof.
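The final step of the sketch is easy to visualize numerically: a Gaussian $u$ (the SGD+M case) has a max-over-median ratio growing like $\log d$, while a sign vector (the Adam case) has ratio exactly 1. A small simulation of just this statistic (the vectors are sampled directly, not produced by actually running either optimizer):

```python
import numpy as np

rng = np.random.default_rng(3)
d = 10_000

# The statistic from the proof sketch: max_i u_i^2 / median_i u_i^2.
ratio = lambda u: (u ** 2).max() / np.median(u ** 2)

# SGD+M case: u is approximately a scaled i.i.d. Gaussian vector, so the
# ratio grows like log d (max of d squared Gaussians ~ 2 log d, median ~ 0.45).
u_sgd = rng.standard_normal(d)

# Adam case: every coordinate of u ends up in {+1, -1}, so the ratio is 1.
u_adam = rng.choice([-1.0, 1.0], size=d)

print(round(ratio(u_adam), 1), ratio(u_sgd) > np.log(d))
```

The overall scaling $C(t)$ cancels in the ratio, which is why the statistic isolates the shape of $u^{(t)}$ rather than its magnitude.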

C ANALYSIS OF SGD+M

Note that $A = \frac{1}{m} Y X^T$ and $\Lambda_{xx} := \frac{1}{m} X X^T$. Denote $g^{(t)}_k := \nabla_{W_k} L(W^{(t)})$, $k = 1, 2$. We have
$$g^{(t)}_1 = W^{(t)T}_2 \big(W^{(t)}_2 W^{(t)}_1 - A\big), \qquad g^{(t)}_2 = \big(W^{(t)}_2 W^{(t)}_1 - A\big) W^{(t)T}_1.$$
Let $\tilde A^{(t)}$, $\tilde \Lambda^{(t)}_{xx}$ and $\tilde g^{(t)}_k$, $k = 1, 2$, be the corresponding batch versions at time $t$. Let $E^{(t)} := W^{(t)}_2 W^{(t)}_1 - A$, and use $E^{(t)}_i$, $A_i$ and $\big(W^{(t)}_2 W^{(t)}_1\big)_i$ to denote the $i$-th coordinates of $E^{(t)}$, $A$ and $W^{(t)}_2 W^{(t)}_1$, respectively. By eq. (2), the update rules of $W_1$ and $W_2$ for SGD+M are
$$W^{(t+1)}_1 = W^{(t)}_1 - \eta\sum_{\tau=0}^{t}\beta^{t-\tau} W^{(\tau)T}_2\big(W^{(\tau)}_2 W^{(\tau)}_1 - A\big) - \eta\sum_{\tau=0}^{t}\beta^{t-\tau} Dg^{(\tau)}_1,$$
$$W^{(t+1)}_2 = W^{(t)}_2 - \eta\sum_{\tau=0}^{t}\beta^{t-\tau}\big(W^{(\tau)}_2 W^{(\tau)}_1 - A\big)W^{(\tau)T}_1 - \eta\sum_{\tau=0}^{t}\beta^{t-\tau} Dg^{(\tau)}_2,$$
where
$$Dg^{(t)}_1 := \tilde g^{(t)}_1 - g^{(t)}_1 = W^{(t)T}_2\Big(W^{(t)}_2 W^{(t)}_1\big(\tilde\Lambda^{(t)}_{xx} - \Lambda_{xx}\big) - \big(\tilde A^{(t)} - A\big)\Big),$$
$$Dg^{(t)}_2 := \tilde g^{(t)}_2 - g^{(t)}_2 = \Big(W^{(t)}_2 W^{(t)}_1\big(\tilde\Lambda^{(t)}_{xx} - \Lambda_{xx}\big) - \big(\tilde A^{(t)} - A\big)\Big)W^{(t)T}_1.$$
Based on the magnitudes of $W_2$ and $W_1$, we can intuitively divide the training procedure into two phases.
1. First phase: the first several iterations, when $W_1$ and $W_2$ are "small" so that $W_2 W_1 - A \approx -A$.
2. Second phase: later iterations, when $W_2 W_1$ can no longer be ignored.
More formally, the boundary between the first and second phase is defined below.

Definition 1 (End of the first phase). The end of the first phase, denoted $T_1$, is defined as
$$T_1 := \inf\Big\{t \ge 0 : \exists i, j \in [d] : \big|w^{(t)}_{2i}\big| \ge \frac{1}{d^{\alpha/2}} \text{ or } \big|W^{(t)}_1[i,j]\big| \ge \frac{1}{d^{\alpha/2}}\Big\}.$$
By Assumption 2 and the assumption that $\forall j \in [d]: A_j > 0$, $A_j = \Theta(1)$, at the beginning, w.h.p., $\forall j \in [d]: (W_2 W_1)_j - A_j < 0$. During training, each $(W_2 W_1)_j$ increases and approaches $A_j$. We hope that, by choosing a small learning rate, when $(W_2 W_1)_j$ overshoots for some coordinate $j$, i.e. $(W_2 W_1)_j > A_j$, it is already close to convergence. To analyze this overshooting issue more carefully, we first define the following "almost overshooting time".

Definition 2 (Almost overshooting time). For $\epsilon > 0$, denote $\epsilon_0 := \frac{1}{d^{\frac14\alpha - 1}} + \sqrt{\epsilon}\log d$.
Define
$$T_2 := \inf\Big\{t \ge 0 : \exists j \in [d] : \big(W^{(t)}_2 W^{(t)}_1\big)_j - A_j \ge -\sqrt{\epsilon_0}\Big\}.$$

Definition 3 (Convergence time). For $\epsilon > 0$, we define the convergence time
$$T_3 := \inf\big\{t \ge 0 : \|E^{(t)}\|_2^2 \le \epsilon\big\}.$$

We can first show that after the first phase, i.e. when $t = T_1$, $W_1$ becomes an approximately rank-1 matrix, as described in the following lemma.

Lemma 1. Under Assumptions 1, 2 and 3, suppose $\sigma \le \frac{\eta^{3/2}}{d^{\alpha/2+1}}$. Picking $\eta \le O\big(\frac{1}{d^\alpha}\big)$, we have that when $t = T_1$, $L(W^{(T_1)}) = \Theta(d)$, and that
$$W^{(T_1)}_1 = R^{(T_1)}_1 + u^{(T_1)} v^{(T_1)T}, \qquad W^{(T_1)}_2 = R^{(T_1)T}_2 + c^{(T_1)} u^{(T_1)T},$$
where $c^{(T_1)} \in \mathbb{R}$, $u^{(T_1)}, v^{(T_1)}, R^{(T_1)}_2 \in \mathbb{R}^d$ and $R^{(T_1)}_1 \in \mathbb{R}^{d\times d}$. Denote the $i$-th coordinates of $u^{(T_1)}, v^{(T_1)}, R^{(T_1)}_2$ by $u^{(T_1)}_i, v^{(T_1)}_i, R^{(T_1)}_{2i}$, respectively, and the $(i,j)$-th element of $R^{(T_1)}_1$ by $R^{(T_1)}_1[i,j]$. Then w.h.p., $\forall 1 \le i, j \le d$:
$$\left|\frac{R^{(T_1)}_1[i,j]}{u^{(T_1)}_i v^{(T_1)}_j}\right| \le \tilde O\Big(\frac{1}{d^{\frac14\alpha-1}}\Big), \qquad \left|\frac{R^{(T_1)}_{2i}}{c^{(T_1)} u^{(T_1)}_i}\right| \le \tilde O\Big(\frac{1}{d^{\frac14\alpha-1}}\Big).$$

The following lemma tells us that this approximate rank-1 structure is preserved for $T_1 \le t \le \min\{T_2, T_3\}$.

Lemma 2. Under Assumptions 1, 2 and 3, suppose $\sigma \le \frac{\eta^{3/2}}{d^{\alpha/2+1}}$. Picking $\eta \le O\big(\frac{\epsilon}{d^{\frac{7\alpha}{4}+4}}\big)$, we have w.h.p. for $T_1 \le t \le \min\{T_2, T_3\}$,
$$W^{(t)}_1 = u^{(T_1)} v^{(t)T} + R^{(t)}_1, \qquad W^{(t)}_2 = c^{(t)} u^{(T_1)T} + R^{(t)T}_2,$$
where $\forall 1 \le i, j \le d$: $\big|\frac{R^{(t)}_1[i,j]}{u^{(T_1)}_i v^{(t)}_j}\big| \le \tilde O(\epsilon_0)$, $\big|\frac{R^{(t)}_{2i}}{c^{(t)} u^{(T_1)}_i}\big| \le \tilde O(\epsilon_0)$, and $\epsilon_0$ is defined in Definition 2. Moreover, when $t = \min\{T_2, T_3\}$, $L(W^{(t)}) = O(\epsilon_0 d)$.

The following lemma gives a more detailed description of $u^{(T_1)}$.

Lemma 3. The $u^{(T_1)}$ in Lemmas 1 and 2 can be written as $u^{(T_1)} = X + Y$, where $X_i, i \in [d]$, are i.i.d. Gaussian random variables and w.h.p. $\forall i \in [d]: \frac{|Y_i|}{|X_i|} \le \tilde O\big(\frac{1}{d^{\frac14\alpha-\frac12}}\big)$.

Now we are ready to prove the SGD+M part of Theorem 1.
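The closed-form gradients $g_1, g_2$ at the start of this appendix rely on the data being whitened ($\Lambda_{xx} = I$). A minimal NumPy sanity check on whitened toy data (dimensions, seed and variable names are ours, purely illustrative):

```python
import numpy as np

rng = np.random.default_rng(4)
d, m = 5, 20

# Whitened toy inputs: (1/m) X X^T = I_d.
q, _ = np.linalg.qr(rng.standard_normal((m, d)))   # q: m x d, orthonormal columns
X = np.sqrt(m) * q.T                               # d x m
Y = rng.standard_normal((1, m))                    # one-dimensional output
A = Y @ X.T / m                                    # A = (1/m) Y X^T, shape 1 x d

W1 = 0.1 * rng.standard_normal((d, d))
W2 = 0.1 * rng.standard_normal((1, d))

def loss(W1, W2):
    """L(W) = (1/2m) ||W2 W1 X - Y||^2."""
    return 0.5 / m * np.sum((W2 @ W1 @ X - Y) ** 2)

# Closed-form gradients used in the analysis (valid because Lambda_xx = I):
g1 = W2.T @ (W2 @ W1 - A)          # gradient w.r.t. W1
g2 = (W2 @ W1 - A) @ W1.T          # gradient w.r.t. W2

# Central finite differences on one entry of each weight matrix.
eps = 1e-6
E1 = np.zeros_like(W1); E1[2, 3] = eps
E2 = np.zeros_like(W2); E2[0, 1] = eps
fd1 = (loss(W1 + E1, W2) - loss(W1 - E1, W2)) / (2 * eps)
fd2 = (loss(W1, W2 + E2) - loss(W1, W2 - E2)) / (2 * eps)
print(np.isclose(fd1, g1[2, 3]), np.isclose(fd2, g2[0, 1]))
```

Whitening removes $\Lambda_{xx}$ from the gradient expressions, which is what makes the two-phase analysis in this appendix tractable.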

C.1 PROOF OF THE SGD+M PART OF THEOREM 1

Define $T_{\mathrm{SGD},1} = T_1$ and $T_{\mathrm{SGD},2} = \min\{T_2, T_3\}$. Picking $\eta \le O\big(\frac{\epsilon}{d^{\frac{7\alpha}{4}+4}}\big)$, we can apply Lemmas 1 and 2 to conclude that $L(W^{(T_{\mathrm{SGD},1})}) = \Theta(d)$ and $L(W^{(T_{\mathrm{SGD},2})}) = O(\epsilon_0 d)$. For any $p > 0$, picking $0 < \epsilon < \frac{1}{d^p}$ and $\alpha \ge 4(p+2)$, we have $L(W^{(T_{\mathrm{SGD},2})}) = O(\epsilon_0 d) \le \tilde O\big(\frac{1}{d^p}\big)$. Moreover, when $t \in [T_{\mathrm{SGD},1}, T_{\mathrm{SGD},2}]$, the conditions of Lemma 30 are satisfied with $\delta = \tilde O(\epsilon_0)$. We can therefore apply Lemma 30 to get
$$R^{\mathrm{SGDM}}_{\mathrm{med},1}(t),\ R^{\mathrm{SGDM}}_{\mathrm{med},2}(t) \ \ge\ \left(\frac{1 - \tilde O(\epsilon_0)}{1 + \tilde O(\epsilon_0)}\right)^2 \cdot \frac{\max_i \big(u^{(T_1)}_i\big)^2}{\mathrm{median}_i \big(u^{(T_1)}_i\big)^2}.$$
By Lemma 3, $u^{(T_1)} = X + Y$, where w.h.p. $\forall i \in [d]: \frac{|Y_i|}{|X_i|} \le \tilde O\big(\frac{1}{d^{\frac14\alpha-\frac12}}\big)$. This fact yields
$$\frac{\max_i \big(u^{(T_1)}_i\big)^2}{\mathrm{median}_i \big(u^{(T_1)}_i\big)^2} \ \ge\ \left(\frac{1 - \tilde O\big(\frac{1}{d^{\frac14\alpha-\frac12}}\big)}{1 + \tilde O\big(\frac{1}{d^{\frac14\alpha-\frac12}}\big)}\right)^2 \frac{\max_i X_i^2}{\mathrm{median}_i X_i^2},$$
where $X_i, i \in [d]$, are i.i.d. Gaussian random variables by Lemma 3. To prove the concentration of $\mathrm{median}_i X_i^2$, we borrow Proposition 12 in Chapter 2.3 of Lerasle (2019). Setting $K = N = d$ in that proposition, we have
$$\mathbb{P}\Big(\big|\mathrm{median}_i X_i^2 - \mathbb{E}[X_1^2]\big| > 2\sqrt{\mathrm{Var}(X_1^2)}\Big) \le e^{-\frac{d}{8}}.$$
Denote by $\sigma^2$ the variance of the $X_i$, $i \in [d]$. Then $\mathbb{E}[X_i^2] = \sigma^2$ and $\mathrm{Var}(X_i^2) = 2\sigma^4$, hence
$$\mathbb{P}\Big(\big|\mathrm{median}_i X_i^2 - \sigma^2\big| > 2\sqrt{2}\sigma^2\Big) \le e^{-\frac{d}{8}}.$$
That means, with high probability, $\mathrm{median}_i X_i^2 \le C\sigma^2$ for some $C > 0$. By Lemma 34 in Appendix G, we know that w.h.p. $\max_{1\le i\le d} X_i^2 = \sigma^2\,\Omega(\log d)$, which gives us w.h.p.
$$\frac{\max_{1\le i\le d} X_i^2}{\mathrm{median}_i X_i^2} = \Omega(\log d).$$
Hence we have proved that $R^{\mathrm{SGDM}}_{\mathrm{med},1}(t), R^{\mathrm{SGDM}}_{\mathrm{med},2}(t) \ge \Omega(\log d)$.
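The two probabilistic facts used above, concentration of the median of squared Gaussians and logarithmic growth of their maximum, are easy to check by simulation (a quick empirical illustration, not a proof):

```python
import numpy as np

rng = np.random.default_rng(5)
d = 100_000
x2 = rng.standard_normal(d) ** 2   # squared i.i.d. standard Gaussians

# median(X_i^2) concentrates near the chi-square(1) median (about 0.455),
# while max(X_i^2) grows like 2 log d, so max/median = Omega(log d).
med, mx = np.median(x2), x2.max()
print(round(med, 2), mx / med > np.log(d))
```

With $\sigma^2 = 1$ the median sits well below the constant $C$ from the concentration bound, and the ratio comfortably exceeds $\log d$.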

C.2 PROOF OF LEMMA 1

In the first phase, $W_2 W_1$ is "small", and we write the update equations in the following way:
$$W^{(t+1)}_1 = W^{(t)}_1 - \eta\sum_{\tau=0}^{t}\beta^{t-\tau} W^{(\tau)T}_2\big(W^{(\tau)}_2 W^{(\tau)}_1 - A\big) - \eta\sum_{\tau=0}^{t}\beta^{t-\tau} Dg^{(\tau)}_1 = W^{(t)}_1 + \frac{\eta}{1-\beta}W^{(t)T}_2 A + \frac{\eta}{1-\beta}r^{(t)}_1, \qquad (5)$$
where
$$r^{(t)}_1 = -\beta^{t+1}W^{(t)T}_2 A + (1-\beta)\sum_{\tau=0}^{t}\beta^{t-\tau}\big(W^{(\tau)T}_2 - W^{(t)T}_2\big)A - (1-\beta)\sum_{\tau=0}^{t}\beta^{t-\tau}W^{(\tau)T}_2 W^{(\tau)}_2 W^{(\tau)}_1 - (1-\beta)\sum_{\tau=0}^{t}\beta^{t-\tau}Dg^{(\tau)}_1.$$
Similarly, we have
$$W^{(t+1)}_2 = W^{(t)}_2 - \eta\sum_{\tau=0}^{t}\beta^{t-\tau}\big(W^{(\tau)}_2 W^{(\tau)}_1 - A\big)W^{(\tau)T}_1 = W^{(t)}_2 + \frac{\eta}{1-\beta}AW^{(t)T}_1 + \frac{\eta}{1-\beta}r^{(t)}_2, \qquad (6)$$
where
$$r^{(t)}_2 = -\beta^{t+1}AW^{(t)T}_1 + (1-\beta)\sum_{\tau=0}^{t}\beta^{t-\tau}A\big(W^{(\tau)T}_1 - W^{(t)T}_1\big) - (1-\beta)\sum_{\tau=0}^{t}\beta^{t-\tau}W^{(\tau)}_2 W^{(\tau)}_1 W^{(\tau)T}_1 - (1-\beta)\sum_{\tau=0}^{t}\beta^{t-\tau}Dg^{(\tau)}_2.$$

The following lemma gives an explicit formula for $W^{(t)}_2$.

Lemma 4. Let $\lambda_1 < \lambda_2$ be the two roots of the quadratic equation $x^2 - 2x + 1 - \frac{\eta^2}{(1-\beta)^2}\|A\|_2^2 = 0$. Picking $\eta < \frac{1-\beta}{\|A\|_2}$, we have
$$W^{(t)}_2 = C_1\lambda_1^t + \big(C_2 + r^{(t)}_5\big)\lambda_2^t, \qquad C_1 = -\frac{W^{(1)}_2 - \lambda_2 W^{(0)}_2}{\lambda_2 - \lambda_1}, \quad C_2 = \frac{W^{(1)}_2 - \lambda_1 W^{(0)}_2}{\lambda_2 - \lambda_1},$$
where $r^{(t)}_5$ will be specified in the proof.

We can prove that in the first phase $r_5$ is "small". More specifically, denote its $i$-th coordinate by $r^{(t)}_{5i}$, and the $i$-th coordinate of $C_2$ by $C_{2i}$. The following lemmas tell us that $\forall i \in [d]$, $|r^{(t)}_{5i}| \le O\big(\frac{1}{d^{p(\alpha)}}\big)$, where w.h.p. $O\big(\frac{1}{d^{p(\alpha)}}\big) \ll \min_{i\in[d]}|C_{2i}|$. We first have the following bounds on $r^{(t)}_1[i,j]$, $r^{(t)}_{2i}$ and $r^{(t)}_{5i}$ for $i \in [d]$.

Lemma 5. Under Assumptions 1, 2 and 3, suppose $\sigma \le \frac{\eta^{3/2}}{d^{\alpha/2+1}}$ and pick $\eta \le O\big(\frac{1}{d^\alpha}\big)$. Then w.h.p. for all $t \le T_1$, $\forall i, j \in [d]$: $\big|r^{(t)}_1[i,j]\big| \le \tilde O\big(\frac{1}{d^{\frac32\alpha-1}}\big)$, $\big|r^{(t)}_{2i}\big| \le \tilde O\big(\frac{1}{d^{\frac32\alpha-2}}\big)$.

Lemma 6. Under the conditions of Lemma 5, we have w.h.p. for all $t \le T_1$, $\forall i \in [d]$: $\big|r^{(t)}_{5i}\big| \le \tilde O\big(\frac{1}{d^{\frac32\alpha-1}}\big)$.
Next we prove upper and lower bounds on $|C_{1i}|$ and $|C_{2i}|$ for $i \in [d]$.

Lemma 7. Under Assumptions 1, 2 and 3, suppose $\sigma \le \frac{\eta^{3/2}}{d^{\alpha/2+1}}$ and pick $\eta < \frac{1-\beta}{\|A\|_2}$. Then: i) w.h.p., $\forall i \in [d]: |C_{1i}| \le \tilde O\big(\frac{1}{d^\alpha}\big)$, $|C_{2i}| \le \tilde O\big(\frac{1}{d^\alpha}\big)$; ii) $C_2$ can be written as $C_2 = \frac12(C_3 + C_4)$, where $C_{3i}, i \in [d]$, are i.i.d. Gaussian random variables and w.h.p. $\forall i \in [d]: \frac{|C_{4i}|}{|C_{3i}|} \le \tilde O\big(\frac{1}{d^{\frac14\alpha-\frac12}}\big)$; iii) w.h.p., $\forall i \in [d]$: $|C_{1i}| \ge \Omega\big(\frac{1}{d^{\frac54\alpha}}\big)$, $|C_{2i}| \ge \Omega\big(\frac{1}{d^{\frac54\alpha}}\big)$.

Now we are ready to prove Lemma 1. Lemma 4 tells us that $W^{(t)}_2 = C_1\lambda_1^t + (C_2 + r^{(t)}_5)\lambda_2^t$, where $\lambda_1 = 1 - \frac{\eta}{1-\beta}\|A\|_2$ and $\lambda_2 = 1 + \frac{\eta}{1-\beta}\|A\|_2$. Under the conditions of Theorem 1, picking $\eta \le O\big(\frac{1}{d^\alpha}\big)$, Lemmas 6 and 7 imply that w.h.p. $\forall t \le T_1$, $\forall 1 \le i \le d$,
$$\big|r^{(t)}_{5i}\big| \le \tilde O\Big(\frac{1}{d^{\frac32\alpha-1}}\Big), \quad |C_{2i}| \ge \Omega\Big(\frac{1}{d^{\frac54\alpha}}\Big), \quad |C_{2i}| \le \tilde O\Big(\frac{1}{d^\alpha}\Big), \quad \frac{|r^{(t)}_{5i}|}{|C_{2i}|} \le \tilde O\Big(\frac{1}{d^{\frac14\alpha-1}}\Big). \qquad (7)$$
We first prove that $|w^{(t)}_{2i}|$ reaches $\frac{1}{d^{\alpha/2}}$ for some coordinate $i$ before any $|W^{(t)}_1[k,j]|$, $k, j \in [d]$, does. To see this, first note that
$$W^{(t)}_1 = W^{(t-1)}_1 + \frac{\eta}{1-\beta}W^{(t-1)T}_2 A + \frac{\eta}{1-\beta}r^{(t-1)}_1 = W^{(0)}_1 + \frac{\eta}{1-\beta}\sum_{\tau=0}^{t-1}W^{(\tau)T}_2 A + \frac{\eta}{1-\beta}\sum_{\tau=0}^{t-1}r^{(\tau)}_1$$
$$= W^{(0)}_1 + \frac{\eta}{1-\beta}\sum_{\tau=0}^{t-1}r^{(\tau)}_1 + \frac{\eta}{1-\beta}\Big(C_1\sum_{\tau=0}^{t-1}\lambda_1^\tau + \sum_{\tau=0}^{t-1}\lambda_2^\tau r^{(\tau)}_5\Big)^T A + \frac{\eta}{1-\beta}\sum_{\tau=0}^{t-1}\lambda_2^\tau\, C_2^T A$$
$$:= W^{(0)}_1 + \frac{\eta}{1-\beta}\sum_{\tau=0}^{t-1}r^{(\tau)}_1 + \Big(C_1\sum_{\tau=0}^{t-1}\lambda_1^\tau + \sum_{\tau=0}^{t-1}\lambda_2^\tau r^{(\tau)}_5\Big)^T v^{(t)T} + u^{(t)} v^{(t)T},$$
where
$$v^{(t)T} = \frac{\eta}{1-\beta}A \quad \text{and} \quad u^{(t)} = \sum_{\tau=0}^{t-1}\lambda_2^\tau\, C_2^T. \qquad (8)$$
Moreover, we have
$$W^{(t)}_2 = C_1\lambda_1^t + \big(C_2 + r^{(t)}_5\big)\lambda_2^t := C_1\lambda_1^t + r^{(t)}_5\lambda_2^t + c^{(t)}u^{(t)T}, \qquad c^{(t)} = \frac{\lambda_2^t}{\sum_{\tau=0}^{t-1}\lambda_2^\tau}.$$
For $t \le T_1$, by eq. (7), we get that w.h.p. $\forall 1 \le i, j \le d$:
$$\left|\frac{\sum_{\tau=0}^{t-1}\lambda_2^\tau r^{(\tau)}_{5i}\, v^{(t)}_j}{u^{(t)}_i v^{(t)}_j}\right| \le \frac{\tilde O\big(\frac{1}{d^{\frac32\alpha-1}}\big)\sum_{\tau=0}^{t-1}\lambda_2^\tau}{\Omega\big(\frac{1}{d^{\frac54\alpha}}\big)\sum_{\tau=0}^{t-1}\lambda_2^\tau} \le \tilde O\Big(\frac{1}{d^{\frac14\alpha-1}}\Big), \qquad \left|\frac{\lambda_2^t r^{(t)}_{5i}}{c^{(t)}u^{(t)}_i}\right| = \frac{|r^{(t)}_{5i}|}{|C_{2i}|} \le \tilde O\Big(\frac{1}{d^{\frac14\alpha-1}}\Big).$$
For $t \le T_1$, by Lemma 5, $\forall 1 \le i, j \le d$: $\big|r^{(t)}_1[i,j]\big| \le \tilde O\big(\frac{1}{d^{\frac32\alpha-1}}\big)$.
Then we have that w.h.p.
$$\left|\frac{\frac{\eta}{1-\beta}\sum_{\tau=0}^{t-1}r^{(\tau)}_1[i,j]}{u^{(t)}_i v^{(t)}_j}\right| = \left|\frac{\sum_{\tau=0}^{t-1}r^{(\tau)}_1[i,j]}{\sum_{\tau=0}^{t-1}\lambda_2^\tau\, C_{2i}A_j}\right| \le \left|\frac{\sum_{\tau=0}^{t-1}r^{(\tau)}_1[i,j]}{\sum_{\tau=0}^{t-1}C_{2i}A_j}\right| \le \frac{\sum_{\tau=0}^{t-1}\tilde O\big(\frac{1}{d^{\frac32\alpha-1}}\big)}{\sum_{\tau=0}^{t-1}\Omega\big(\frac{1}{d^{\frac54\alpha}}\big)} = \tilde O\Big(\frac{1}{d^{\frac14\alpha-1}}\Big).$$
Here we used $\forall i \in [d]: A_i = \Theta(1)$ by Assumption 1. Since $\lambda_1 = 1 - \frac{\eta}{1-\beta}\|A\|_2$, we have $|C_{1i}\lambda_1^t| \le |C_{1i}| \le \tilde O\big(\frac{1}{d^\alpha}\big)$ and
$$\left|C_{1i}\sum_{\tau=0}^{t-1}\lambda_1^\tau\, v^{(t)}_j\right| = \frac{\eta A_j}{1-\beta}\left|C_{1i}\sum_{\tau=0}^{t-1}\lambda_1^\tau\right| \le \frac{\eta A_j |C_{1i}|}{(1-\beta)(1-\lambda_1)} \le \frac{A_j |C_{1i}|}{\|A\|_2} \le \tilde O\Big(\frac{1}{d^{\alpha+\frac12}}\Big).$$
Using the Gaussian tail bound and a union bound, we have w.h.p. $\forall 1 \le i, j \le d$: $\big|W^{(0)}_1[i,j]\big| = \tilde O\big(\frac{1}{d^{2\alpha}}\big)$. Combining the above bounds together yields that for $t \le T_1$ and $\forall i, j \in [d]$,
$$W^{(t)}_1[i,j] = R^{(t)}_{11}[i,j] + u^{(t)}_i v^{(t)}_j\big(1 + e^{(t)}_1[i,j]\big), \qquad w^{(t)}_{2i} = R^{(t)}_{21,i} + c^{(t)}u^{(t)}_i\big(1 + e^{(t)}_{2i}\big), \qquad (9)$$
where $\forall i, j \in [d]$: $\big|R^{(t)}_{11}[i,j]\big| \le \tilde O\big(\frac{1}{d^{\alpha+\frac12}}\big)$, $\big|R^{(t)}_{21,i}\big| \le \tilde O\big(\frac{1}{d^\alpha}\big)$ and $\big|e^{(t)}_1[i,j]\big|, \big|e^{(t)}_{2i}\big| \le \tilde O\big(\frac{1}{d^{\frac14\alpha-1}}\big)$. Further we notice that for $t \le T_1$, we have $\forall j \in [d]$,
$$\frac{v^{(t)}_j}{c^{(t)}} = \frac{\eta A_j}{1-\beta}\cdot\frac{\sum_{\tau=0}^{t-1}\lambda_2^\tau}{\lambda_2^t} = \frac{\eta A_j}{1-\beta}\cdot\frac{\lambda_2^t - 1}{\lambda_2^t(\lambda_2 - 1)} = \frac{A_j(\lambda_2^t - 1)}{\lambda_2^t\|A\|_2} \le \frac{A_j}{\|A\|_2} = O\Big(\frac{1}{\sqrt d}\Big),$$
which yields $\big|u^{(t)}_i v^{(t)}_j\big| \le O\big(\frac{1}{\sqrt d}\big)\big|c^{(t)}u^{(t)}_i\big|$. Together with eq. (9), this gives us that $|w^{(t)}_{2i}|$ reaches $\frac{1}{d^{\alpha/2}}$ for some $i \in [d]$ before any $|W^{(t)}_1[k,j]|$, $k, j \in [d]$, i.e.
$$T_1 = \inf\Big\{t \ge 0 : \exists i \in [d] : \big|w^{(t)}_{2i}\big| \ge \frac{1}{d^{\frac\alpha2}}\Big\}.$$
Further, we know that at time $T_1$, $\big|c^{(T_1)}u^{(T_1)}_{i_0}\big| = |C_{2i_0}|\lambda_2^{T_1} = \Theta\big(\frac{1}{d^{\alpha/2}}\big)$ for some $i_0 \in [d]$, which means w.h.p.
$$\frac{\Theta\big(\frac{1}{d^{\frac\alpha2}}\big)}{\tilde O\big(\frac{1}{d^\alpha}\big)} \le \lambda_2^{T_1} \le \frac{\Theta\big(\frac{1}{d^{\frac\alpha2}}\big)}{\Omega\big(\frac{1}{d^{\frac54\alpha}}\big)} \;\Rightarrow\; \Omega\big(d^{\frac\alpha2}\big) \le \lambda_2^{T_1} = \Big(1 + \frac{\eta}{1-\beta}\|A\|_2\Big)^{T_1} \le \tilde O\big(d^{\frac34\alpha}\big) \;\Rightarrow\; T_1 = \Theta\Big(\frac{\log d}{\eta\|A\|_2}\Big). \qquad (10)$$
This is the length of the first phase. As for $\big|c^{(T_1)}u^{(T_1)}_i\big|$ and $\big|u^{(T_1)}_i v^{(T_1)}_j\big|$ for the other coordinates, we have w.h.p. $\forall 1 \le i, j \le d$,
$$\big|u^{(T_1)}_i v^{(T_1)}_j\big| = \frac{\eta}{1-\beta}\sum_{\tau=0}^{T_1-1}\lambda_2^\tau\,|C_{2i}A_j| = \frac{\eta}{1-\beta}\cdot\frac{\lambda_2^{T_1}-1}{\lambda_2-1}|C_{2i}A_j| \overset{(i)}{=} \frac{\lambda_2^{T_1}-1}{\|A\|_2}|C_{2i}A_j| \ge \frac{\Omega\big(d^{\alpha/2}\big)}{\Theta\big(\sqrt d\big)}\,\Omega\Big(\frac{1}{d^{\frac54\alpha}}\Big) = \Omega\Big(\frac{1}{d^{\frac34\alpha+\frac12}}\Big),$$
$$\big|c^{(T_1)}u^{(T_1)}_i\big| = |C_{2i}|\lambda_2^{T_1} \ge \Omega\big(d^{\frac\alpha2}\big)\,\Omega\Big(\frac{1}{d^{\frac54\alpha}}\Big) = \Omega\Big(\frac{1}{d^{\frac34\alpha}}\Big).$$
Here in (i) we used $\lambda_2 = 1 + \frac{\eta}{1-\beta}\|A\|_2$. Then we have that at time $T_1$, $\forall i, j \in [d]$: $\big|\frac{R^{(T_1)}_{11}[i,j]}{u^{(T_1)}_i v^{(T_1)}_j}\big| \le \tilde O\big(\frac{1}{d^{\frac14\alpha}}\big)$ and $\big|\frac{R^{(T_1)}_{21,i}}{c^{(T_1)}u^{(T_1)}_i}\big| \le \tilde O\big(\frac{1}{d^{\frac14\alpha}}\big)$. Together with eq. (9), we have the following weight structure:
$$W^{(T_1)}_1 = R^{(T_1)}_1 + u^{(T_1)}v^{(T_1)T}, \qquad W^{(T_1)}_2 = R^{(T_1)T}_2 + c^{(T_1)}u^{(T_1)T},$$
where w.h.p., $\forall 1 \le i, j \le d$: $\big|\frac{R^{(T_1)}_1[i,j]}{u^{(T_1)}_i v^{(T_1)}_j}\big| \le \tilde O\big(\frac{1}{d^{\frac14\alpha-1}}\big)$, $\big|\frac{R^{(T_1)}_{2i}}{c^{(T_1)}u^{(T_1)}_i}\big| \le \tilde O\big(\frac{1}{d^{\frac14\alpha-1}}\big)$.

Finally, we consider the loss. Since $\forall j \in [d]: \big(W^{(T_1)}_2 W^{(T_1)}_1\big)_j - A_j = -\Theta(1)$, we know that $L(W^{(T_1)}) = \Theta(d)$.

C.3 PROOF OF LEMMA 3

Eq. (8) tells us that $u^{(T_1)} = \sum_{\tau=0}^{T_1-1}\lambda_2^\tau\, C_2^T$, and Lemma 7 ii) tells us that $C_2 = \frac12(C_3 + C_4)$ with $C_{3i}, i \in [d]$, i.i.d. Gaussian. Combining these two facts finishes the proof.

C.4 PROOF OF LEMMA 4

Replacing $t$ by $t-1$ in eq. (6), we get
$$W^{(t)}_2 = W^{(t-1)}_2 + \frac{\eta}{1-\beta}AW^{(t-1)T}_1 + \frac{\eta}{1-\beta}r^{(t-1)}_2. \qquad (11)$$
Subtracting eq. (11) from eq. (6) and substituting eq. (5) yields
$$W^{(t+1)}_2 - W^{(t)}_2 = W^{(t)}_2 - W^{(t-1)}_2 + \frac{\eta^2}{(1-\beta)^2}\|A\|_2^2\, W^{(t-1)}_2 + \frac{\eta^2}{(1-\beta)^2}A\,r^{(t-1)T}_1 + \frac{\eta}{1-\beta}\big(r^{(t)}_2 - r^{(t-1)}_2\big),$$
$$\Rightarrow\; W^{(t+1)}_2 = 2W^{(t)}_2 - \Big(1 - \frac{\eta^2}{(1-\beta)^2}\|A\|_2^2\Big)W^{(t-1)}_2 + r^{(t)}_3,$$
where $r^{(t)}_3 := \frac{\eta^2}{(1-\beta)^2}A\,r^{(t-1)T}_1 + \frac{\eta}{1-\beta}\big(r^{(t)}_2 - r^{(t-1)}_2\big)$. For the equation $x^2 - 2x + 1 - \frac{\eta^2}{(1-\beta)^2}\|A\|_2^2 = 0$, the roots are $\lambda_1 = 1 - \frac{\eta}{1-\beta}\|A\|_2$ and $\lambda_2 = 1 + \frac{\eta}{1-\beta}\|A\|_2$. We have that
$$W^{(t+1)}_2 - \lambda_2 W^{(t)}_2 = \lambda_1\big(W^{(t)}_2 - \lambda_2 W^{(t-1)}_2\big) + r^{(t)}_3$$
$$\Rightarrow\; W^{(t)}_2 - \lambda_2 W^{(t-1)}_2 = \lambda_1^{t-1}\big(W^{(1)}_2 - \lambda_2 W^{(0)}_2\big) + \sum_{\tau=1}^{t-1}\lambda_1^{t-1-\tau}r^{(\tau)}_3 := \lambda_1^{t-1}\big(W^{(1)}_2 - \lambda_2 W^{(0)}_2\big) + r^{(t)}_4.$$
We further have
$$W^{(t)}_2 = \lambda_2^t W^{(0)}_2 + \sum_{\tau=0}^{t-1}\lambda_2^{t-1-\tau}\lambda_1^\tau\big(W^{(1)}_2 - \lambda_2 W^{(0)}_2\big) + \sum_{\tau=1}^{t}\lambda_2^{t-\tau}r^{(\tau)}_4 = \lambda_2^t W^{(0)}_2 + \frac{\lambda_2^t - \lambda_1^t}{\lambda_2 - \lambda_1}\big(W^{(1)}_2 - \lambda_2 W^{(0)}_2\big) + \sum_{\tau=1}^{t}\lambda_2^{t-\tau}r^{(\tau)}_4$$
$$= C_1\lambda_1^t + C_2\lambda_2^t + \sum_{\tau=1}^{t}\lambda_2^{t-\tau}r^{(\tau)}_4 = C_1\lambda_1^t + \big(C_2 + r^{(t)}_5\big)\lambda_2^t,$$
where $r^{(t)}_5 = \sum_{\tau=1}^{t}\lambda_2^{-\tau}r^{(\tau)}_4$, $C_1 = -\frac{W^{(1)}_2 - \lambda_2 W^{(0)}_2}{\lambda_2 - \lambda_1}$ and $C_2 = \frac{W^{(1)}_2 - \lambda_1 W^{(0)}_2}{\lambda_2 - \lambda_1}$.
C.5 PROOF OF LEMMA 5

Write $r^{(t)}_1 = -\beta^{t+1}W^{(t)T}_2 A + q^{(t)}_{12} + q^{(t)}_{13} + q^{(t)}_{14}$, where $q^{(t)}_{12} = (1-\beta)\sum_{\tau=0}^{t}\beta^{t-\tau}\big(W^{(\tau)T}_2 - W^{(t)T}_2\big)A$, $q^{(t)}_{13} = -(1-\beta)\sum_{\tau=0}^{t}\beta^{t-\tau}W^{(\tau)T}_2 W^{(\tau)}_2 W^{(\tau)}_1$ and $q^{(t)}_{14} = -(1-\beta)\sum_{\tau=0}^{t}\beta^{t-\tau}Dg^{(\tau)}_1$. And write $r^{(t)}_2 = -\beta^{t+1}AW^{(t)T}_1 + q^{(t)}_{22} + q^{(t)}_{23} + q^{(t)}_{24}$, where $q^{(t)}_{22} = (1-\beta)\sum_{\tau=0}^{t}\beta^{t-\tau}A\big(W^{(\tau)T}_1 - W^{(t)T}_1\big)$, $q^{(t)}_{23} = -(1-\beta)\sum_{\tau=0}^{t}\beta^{t-\tau}W^{(\tau)}_2 W^{(\tau)}_1 W^{(\tau)T}_1$ and $q^{(t)}_{24} = -(1-\beta)\sum_{\tau=0}^{t}\beta^{t-\tau}Dg^{(\tau)}_2$.

Let us first bound $q^{(t)}_{12}[i,j]$ and $q^{(t)}_{22,i}$. For any $\tau \le T_1$, we have $\forall i \in [d]$:
$$\big|\big(W^{(\tau)}_2 W^{(\tau)}_1\big)_i\big| = \Big|\sum_{j=1}^{d}w^{(\tau)}_{2j}W^{(\tau)}_1[j,i]\Big| \le \sum_{j=1}^{d}\big|w^{(\tau)}_{2j}\big|\big|W^{(\tau)}_1[j,i]\big| \le \sum_{j=1}^{d}\frac{1}{d^\alpha} = \frac{1}{d^{\alpha-1}},$$
and thus $\forall i \in [d]: \big|E^{(\tau)}_i\big| = O(1)$. Then we have for all $i, j \in [d]$,
$$\big|W^{(\tau+1)}_1[i,j] - W^{(\tau)}_1[i,j]\big| \le \eta\sum_{k=0}^{\tau}\beta^{\tau-k}\big|w^{(k)}_{2i}E^{(k)}_j\big| \le \eta\sum_{k=0}^{\tau}\beta^{\tau-k}O\Big(\frac{1}{d^{\alpha/2}}\Big) = \eta\,O\Big(\frac{1}{d^{\alpha/2}}\Big),$$
$$\big|w^{(\tau+1)}_{2i} - w^{(\tau)}_{2i}\big| \le \eta\sum_{k=0}^{\tau}\beta^{\tau-k}\sum_{j=1}^{d}\big|E^{(k)}_j W^{(k)}_1[i,j]\big| \le \eta\sum_{k=0}^{\tau}\beta^{\tau-k}O\Big(\frac{1}{d^{\alpha/2-1}}\Big) = \eta\,O\Big(\frac{1}{d^{\alpha/2-1}}\Big).$$
That gives us, $\forall i, j \in [d]$,
$$\big|q^{(t)}_{12}[i,j]\big| \le (1-\beta)\sum_{\tau=0}^{t}\beta^{t-\tau}\big|w^{(\tau)}_{2i} - w^{(t)}_{2i}\big|A_j \le \eta(1-\beta)\sum_{\tau=0}^{t}O\Big(\frac{\beta^{t-\tau}(t-\tau)}{d^{\alpha/2-1}}\Big) = O\Big(\frac{\eta}{d^{\alpha/2-1}}\Big),$$
$$\big|q^{(t)}_{22,i}\big| \le (1-\beta)\sum_{\tau=0}^{t}\beta^{t-\tau}\sum_{j=1}^{d}A_j\big|W^{(\tau)}_1[i,j] - W^{(t)}_1[i,j]\big| \le \eta(1-\beta)\sum_{\tau=0}^{t}O\Big(\frac{\beta^{t-\tau}(t-\tau)}{d^{\alpha/2-1}}\Big) = O\Big(\frac{\eta}{d^{\alpha/2-1}}\Big).$$
Then we bound $q^{(t)}_{13}[i,j]$ and $q^{(t)}_{23,i}$. We have for $\forall i, j \in [d]$,
$$\big|q^{(t)}_{13}[i,j]\big| \le (1-\beta)\sum_{\tau=0}^{t}\beta^{t-\tau}\big|w^{(\tau)}_{2i}\big|\big|\big(W^{(\tau)}_2 W^{(\tau)}_1\big)_j\big| \le (1-\beta)\sum_{\tau=0}^{t}\beta^{t-\tau}\frac{1}{d^{\frac\alpha2}}\cdot\frac{1}{d^{\alpha-1}} = O\Big(\frac{1}{d^{\frac32\alpha-1}}\Big),$$
$$\big|q^{(t)}_{23,i}\big| \le (1-\beta)\sum_{\tau=0}^{t}\beta^{t-\tau}\sum_{j=1}^{d}\big|\big(W^{(\tau)}_2 W^{(\tau)}_1\big)_j\big|\big|W^{(\tau)}_1[i,j]\big| \le (1-\beta)\sum_{\tau=0}^{t}\beta^{t-\tau}\sum_{j=1}^{d}\frac{1}{d^{\alpha-1+\frac\alpha2}} = O\Big(\frac{1}{d^{\frac32\alpha-2}}\Big).$$
Finally, we use Lemma 31 to bound $q^{(t)}_{14}[i,j]$ and $q^{(t)}_{24,i}$. For $t \le T_1$, the conditions of Lemma 31 are satisfied, which gives w.h.p.
$$\big|Dg^{(t)}_1[i,j]\big| = \big|\tilde g^{(t)}_1[i,j] - g^{(t)}_1[i,j]\big| \le O\Big(\frac{1}{d^{\frac{3\alpha}{2}-3}}\Big)\sigma\sqrt{\frac{d^{\alpha+1}}{\eta}}\log d + O\Big(\frac{1}{d^{\frac\alpha2}}\Big)\sigma\sqrt{\frac{d^{\alpha+2}}{\eta}}\log d \le \tilde O\Big(\frac{1}{d^{\frac\alpha2}}\Big)\sigma\sqrt{\frac{d^{\alpha+2}}{\eta}},$$
$$\big|Dg^{(t)}_{2i}\big| = \big|\tilde g^{(t)}_{2i} - g^{(t)}_{2i}\big| \le O\Big(\frac{1}{d^{\frac{3\alpha}{2}-4}}\Big)\sigma\sqrt{\frac{d^{\alpha+1}}{\eta}}\log d + O\Big(\frac{1}{d^{\frac\alpha2-1}}\Big)\sigma\sqrt{\frac{d^{\alpha+2}}{\eta}}\log d \le \tilde O\Big(\frac{1}{d^{\frac\alpha2-1}}\Big)\sigma\sqrt{\frac{d^{\alpha+2}}{\eta}}.$$
By picking $\sigma \le \frac{\eta^{3/2}}{d^{\alpha/2+1}}$, we have w.h.p. for $\forall t \le T_1$ and $\forall i, j \in [d]$, $\big|Dg^{(t)}_1[i,j]\big| \le \eta\,\tilde O\big(\frac{1}{d^{\frac\alpha2}}\big)$ and $\big|Dg^{(t)}_{2i}\big| \le \eta\,\tilde O\big(\frac{1}{d^{\frac\alpha2-1}}\big)$, which yields
$$\big|q^{(t)}_{14}[i,j]\big| \le (1-\beta)\sum_{\tau=0}^{t}\beta^{t-\tau}\big|Dg^{(\tau)}_1[i,j]\big| \le \eta\,\tilde O\Big(\frac{1}{d^{\frac\alpha2}}\Big), \qquad \big|q^{(t)}_{24,i}\big| \le (1-\beta)\sum_{\tau=0}^{t}\beta^{t-\tau}\big|Dg^{(\tau)}_{2i}\big| \le \eta\,\tilde O\Big(\frac{1}{d^{\frac\alpha2-1}}\Big).$$
Combining all the above bounds and substituting $\eta \le O\big(\frac{1}{d^\alpha}\big)$ gives, for $\forall t \le T_1$ and $\forall i, j \in [d]$,
$$\big|r^{(t)}_1[i,j]\big| \le \beta^{t+1}\big|w^{(t)}_{2i}\big|A_j + \tilde O\Big(\frac{1}{d^{\frac32\alpha-1}}\Big), \qquad \big|r^{(t)}_{2i}\big| \le \beta^{t+1}\sum_{j=1}^{d}A_j\big|W^{(t)}_1[i,j]\big| + \tilde O\Big(\frac{1}{d^{\frac32\alpha-2}}\Big). \qquad (12)$$
For $t \le T_1$, we have $\forall i, j \in [d]$: $\big|w^{(t)}_{2i}\big|A_j \le O\big(\frac{1}{d^{\alpha/2}}\big)$ and $\sum_{j=1}^{d}A_j\big|W^{(t)}_1[i,j]\big| \le O\big(\frac{1}{d^{\alpha/2-1}}\big)$, which gives $\big|r^{(t)}_1[i,j]\big| \le O\big(\frac{1}{d^{\alpha/2}}\big)$ and $\big|r^{(t)}_{2i}\big| \le O\big(\frac{1}{d^{\alpha/2-1}}\big)$. Substituting into eqs. (5) and (6) yields that for $t \le T_1$ and $\forall i, j \in [d]$,
$$\big|W^{(t+1)}_1[i,j] - W^{(t)}_1[i,j]\big| \le O\Big(\frac{\eta}{d^{\alpha/2}}\Big), \qquad \big|w^{(t+1)}_{2i} - w^{(t)}_{2i}\big| \le O\Big(\frac{\eta}{d^{\alpha/2-1}}\Big).$$
Hence for $t \le \min\big\{\frac{\alpha\log d}{\log(1/\beta)}, T_1\big\}$, we have $\forall i, j \in [d]$,
$$\big|W^{(t)}_1[i,j]\big| \le \big|W^{(0)}_1[i,j]\big| + \frac{\alpha\log d}{\log(1/\beta)}O\Big(\frac{\eta}{d^{\alpha/2}}\Big) \le \tilde O\Big(\frac{1}{d^{\frac{3\alpha}{2}}}\Big), \qquad \big|w^{(t)}_{2i}\big| \le \big|w^{(0)}_{2i}\big| + \frac{\alpha\log d}{\log(1/\beta)}O\Big(\frac{\eta}{d^{\alpha/2-1}}\Big) \le \tilde O\Big(\frac{1}{d^{\frac{3\alpha}{2}-1}}\Big).$$
Then we know that $T_1 > \frac{\alpha\log d}{\log(1/\beta)}$, and we also get tighter bounds on $\big|W^{(t)}_1[i,j]\big|$ and $\big|w^{(t)}_{2i}\big|$ for $t \le \frac{\alpha\log d}{\log(1/\beta)}$. Now we use these new bounds to analyze $\big|r^{(t)}_1[i,j]\big|$ and $\big|r^{(t)}_{2i}\big|$ again. When $t \le \frac{\alpha\log d}{\log(1/\beta)}$, we have for all $i, j \in [d]$, $\beta^{t+1}\big|w^{(t)}_{2i}\big|A_j \le \big|w^{(t)}_{2i}\big|A_j \le \tilde O\big(\frac{1}{d^{\frac{3\alpha}{2}-1}}\big)$ and $\beta^{t+1}\sum_{j=1}^{d}A_j\big|W^{(t)}_1[i,j]\big| \le \sum_{j=1}^{d}A_j\big|W^{(t)}_1[i,j]\big| \le \tilde O\big(\frac{1}{d^{\frac{3\alpha}{2}-1}}\big)$. When $\frac{\alpha\log d}{\log(1/\beta)} < t \le T_1$, we have $\beta^{t+1} \le \frac{1}{d^\alpha}$, suggesting that $\forall i, j \in [d]$: $\beta^{t+1}\big|w^{(t)}_{2i}\big|A_j \le \frac{1}{d^\alpha}\tilde O\big(\frac{1}{d^{\frac\alpha2}}\big) \le \tilde O\big(\frac{1}{d^{\frac{3\alpha}{2}}}\big)$ and $\beta^{t+1}\sum_{j=1}^{d}A_j\big|W^{(t)}_1[i,j]\big| \le \frac{1}{d^\alpha}\tilde O\big(\frac{1}{d^{\frac\alpha2-1}}\big) \le \tilde O\big(\frac{1}{d^{\frac{3\alpha}{2}-1}}\big)$. Substituting into eq. (12) completes the proof.
Under review as a conference paper at ICLR 2023

C.6 PROOF OF LEMMA 6

Based on the bounds in Lemma 5, we have
$$\big|r^{(t)}_{3i}\big| = \Big|\frac{\eta^2}{(1-\beta)^2}\sum_{j=1}^{d}A_j r^{(t-1)}_1[i,j] + \frac{\eta}{1-\beta}\big(r^{(t)}_{2i} - r^{(t-1)}_{2i}\big)\Big| \le \frac{\eta^2}{(1-\beta)^2}\sum_{j=1}^{d}A_j\big|r^{(t-1)}_1[i,j]\big| + \frac{\eta}{1-\beta}\big|r^{(t)}_{2i}\big| + \frac{\eta}{1-\beta}\big|r^{(t-1)}_{2i}\big| \le \eta^2\,\tilde O\Big(\frac{1}{d^{\frac32\alpha-2}}\Big) + 2\eta\,\tilde O\Big(\frac{1}{d^{\frac32\alpha-2}}\Big) = \eta\,\tilde O\Big(\frac{1}{d^{\frac32\alpha-2}}\Big).$$
Since $\lambda_1 = 1 - \frac{\eta}{1-\beta}\|A\|_2$, $\lambda_2 = 1 + \frac{\eta}{1-\beta}\|A\|_2$, and $\|A\|_2 = \Theta(\sqrt d)$, we have that
$$\big|r^{(t)}_{4i}\big| = \Big|\sum_{\tau=1}^{t-1}\lambda_1^{t-1-\tau}r^{(\tau)}_{3i}\Big| \le \sum_{\tau=1}^{t-1}\lambda_1^{t-1-\tau}\,\tilde O\Big(\frac{\eta}{d^{\frac32\alpha-2}}\Big) \le \frac{\eta}{1-\lambda_1}\tilde O\Big(\frac{1}{d^{\frac32\alpha-2}}\Big) = \tilde O\Big(\frac{1}{d^{\frac32(\alpha-1)}}\Big),$$
$$\big|r^{(t)}_{5i}\big| = \Big|\sum_{\tau=1}^{t}\lambda_2^{-\tau}r^{(\tau)}_{4i}\Big| \le \eta\sum_{\tau=1}^{t}\lambda_2^{-\tau}\,\tilde O\Big(\frac{1}{d^{\frac32(\alpha-1)}}\Big) \le \frac{\eta}{\lambda_2-1}\tilde O\Big(\frac{1}{d^{\frac32(\alpha-1)}}\Big) = \tilde O\Big(\frac{1}{d^{\frac32\alpha-1}}\Big).$$

C.7 PROOF OF LEMMA 7

For the equation $x^2 - 2x + 1 - \frac{\eta^2}{(1-\beta)^2}\|A\|_2^2 = 0$, the roots are $\lambda_1 = 1 - \frac{\eta}{1-\beta}\|A\|_2$ and $\lambda_2 = 1 + \frac{\eta}{1-\beta}\|A\|_2$, which gives us
$$C_2 = \frac{W^{(1)}_2 - \lambda_1 W^{(0)}_2}{\lambda_2 - \lambda_1} = \frac{W^{(0)}_2 + \eta AW^{(0)T}_1 + \eta\tilde r^{(0)}_2 - W^{(0)}_2 + \frac{\eta}{1-\beta}\|A\|_2 W^{(0)}_2}{\frac{2\eta}{1-\beta}\|A\|_2} = \frac12 W^{(0)}_2 + \frac{1-\beta}{2\|A\|_2}AW^{(0)T}_1 + \frac{1-\beta}{2\|A\|_2}\tilde r^{(0)}_2, \qquad (13)$$
where $\tilde r^{(0)}_2 = -W^{(0)}_2 W^{(0)}_1 W^{(0)T}_1 - Dg^{(0)}_2$. Note that this is slightly different from the definition of $r^{(0)}_2$ in eq. (6). Now let us bound the $i$-th coordinate of $\tilde r^{(0)}_2$. In Section C.5 we have shown that w.h.p. for $\forall t \le T_1$ and $\forall i \in [d]$, $\big|Dg^{(t)}_{2i}\big| \le \eta\,\tilde O\big(\frac{1}{d^{\frac\alpha2-1}}\big) = \tilde O\big(\frac{1}{d^{\frac{3\alpha}{2}-1}}\big)$, which also applies to $t = 0$. Using the Gaussian tail bound and a union bound, w.p. at least $1-\delta$, for every $1 \le i, j \le d$, we have
$$\big|w^{(0)}_{2i}\big| \le \sqrt{\frac{2}{d^{2\alpha}}\log\frac{2d}{\delta}}, \qquad \big|W^{(0)}_1[i,j]\big| \le \sqrt{\frac{2}{d^{4\alpha}}\log\frac{2d^2}{\delta}}.$$
Then we have that w.p. at least $1-\delta$, $\forall 1 \le i \le d$:
$$\big|\big(W^{(0)}_2 W^{(0)}_1\big)_i\big| = \Big|\sum_{j=1}^{d}w^{(0)}_{2j}W^{(0)}_1[j,i]\Big| \le \sum_{j=1}^{d}\big|w^{(0)}_{2j}\big|\big|W^{(0)}_1[j,i]\big| \le d\sqrt{\frac{2}{d^{2\alpha}}\log\frac{2d}{\delta}}\sqrt{\frac{2}{d^{4\alpha}}\log\frac{2d^2}{\delta}} \le \frac{2}{d^{3\alpha-1}}\log\frac{2d^2}{\delta},$$
$$\Rightarrow\; \big|\tilde r^{(0)}_{2i}\big| \le \sum_{j=1}^{d}\big|\big(W^{(0)}_2 W^{(0)}_1\big)_j\big|\big|W^{(0)}_1[i,j]\big| + \big|Dg^{(0)}_{2i}\big| \le d\cdot\frac{2}{d^{3\alpha-1}}\log\frac{2d^2}{\delta}\sqrt{\frac{2}{d^{4\alpha}}\log\frac{2d^2}{\delta}} + \tilde O\Big(\frac{1}{d^{\frac{3\alpha}{2}-1}}\Big) = \tilde O\Big(\frac{1}{d^{\frac{3\alpha}{2}-1}}\Big). \qquad (14)$$
Next, we bound the $i$-th coordinate of $W^{(0)}_2 + \frac{1-\beta}{\|A\|_2}AW^{(0)T}_1$, i.e. $w^{(0)}_{2i} + \frac{1-\beta}{\|A\|_2}A\big(W^{(0)}_1[i,:]\big)^T$.
By independence under Assumption 2, we have that
$$\mathrm{Var}\Big(w^{(0)}_{2i} + \frac{1-\beta}{\|A\|_2}A\big(W^{(0)}_1[i,:]\big)^T\Big) = \mathrm{Var}\big(w^{(0)}_{2i}\big) + \frac{(1-\beta)^2}{\|A\|_2^2}\sum_{j=1}^{d}A_j^2\,\mathrm{Var}\big(W^{(0)}_1[i,j]\big) = \frac{1}{d^{2\alpha}} + \frac{(1-\beta)^2}{\|A\|_2^2}\sum_{j=1}^{d}A_j^2\frac{1}{d^{4\alpha}} = O\Big(\frac{1}{d^{2\alpha}}\Big).$$
Using the Gaussian tail bound and a union bound, w.p. at least $1-\delta$, for every $1 \le i \le d$, we have
$$\Big|w^{(0)}_{2i} + \frac{1-\beta}{\|A\|_2}A\big(W^{(0)}_1[i,:]\big)^T\Big| \le O\Big(\sqrt{\frac{1}{d^{2\alpha}}\log\frac{d}{\delta}}\Big) = \tilde O\Big(\frac{1}{d^\alpha}\Big).$$
Since for $X \sim N(0, \sigma^2)$ we have $\mathbb{P}(|X| \le t) \le \frac{2t}{\sqrt{2\pi}\sigma}$, for a fixed $i$,
$$\mathbb{P}\Big(\Big|w^{(0)}_{2i} + \frac{1-\beta}{\|A\|_2}A\big(W^{(0)}_1[i,:]\big)^T\Big| \le \frac{1}{d^{\frac54\alpha}}\Big) \le O\Bigg(\frac{2/d^{\frac54\alpha}}{\sqrt{2\pi/d^{2\alpha}}}\Bigg) = \Theta\Big(\frac{1}{d^{\frac\alpha4}}\Big).$$
Then by a union bound, we have that w.p. at least $1 - \frac{1}{d^{\frac\alpha4-1}}$, for every $1 \le i \le d$,
$$\Big|w^{(0)}_{2i} + \frac{1-\beta}{\|A\|_2}A\big(W^{(0)}_1[i,:]\big)^T\Big| \ge \Theta\Big(\frac{1}{d^{\frac54\alpha}}\Big).$$
Now define $C_3 := W^{(0)}_2 + \frac{1-\beta}{\|A\|_2}AW^{(0)T}_1$ and $C_4 := \frac{1-\beta}{\|A\|_2}\tilde r^{(0)}_2$. We get that $C_{3i}, i \in [d]$, are i.i.d. Gaussian random variables and that $C_2 = \frac12(C_3 + C_4)$, where w.h.p. for all $i \in [d]$,
$$|C_{3i}| \le \tilde O\Big(\frac{1}{d^\alpha}\Big), \qquad |C_{3i}| \ge \Theta\Big(\frac{1}{d^{\frac54\alpha}}\Big), \qquad |C_{4i}| \overset{(i)}{\le} \tilde O\Big(\frac{1}{d^{\frac{3\alpha}{2}-\frac12}}\Big), \qquad (15)$$
where (i) follows from eq. (14) and the fact that $\|A\|_2 = \Theta(\sqrt d)$. Then we get that w.h.p. $\forall i \in [d]$:
$$\frac{|C_{4i}|}{|C_{3i}|} \le \frac{\tilde O\big(\frac{1}{d^{\frac{3\alpha}{2}-\frac12}}\big)}{\Omega\big(\frac{1}{d^{\frac54\alpha}}\big)} = \tilde O\Big(\frac{1}{d^{\frac14\alpha-\frac12}}\Big).$$
Substituting eq. (15) into eq. (13), we get that w.h.p.
$$|C_{2i}| = \Theta\Big(\Big|w^{(0)}_{2i} + \frac{1-\beta}{\|A\|_2}A\big(W^{(0)}_1[i,:]\big)^T\Big|\Big) \in \Big[\Omega\Big(\frac{1}{d^{\frac54\alpha}}\Big),\ \tilde O\Big(\frac{1}{d^\alpha}\Big)\Big].$$
Similarly, noting that
$$C_1 = -\frac{W^{(1)}_2 - \lambda_2 W^{(0)}_2}{\lambda_2 - \lambda_1} = -\frac{W^{(0)}_2 + \eta AW^{(0)T}_1 + \eta\tilde r^{(0)}_2 - W^{(0)}_2 - \frac{\eta}{1-\beta}\|A\|_2 W^{(0)}_2}{\frac{2\eta}{1-\beta}\|A\|_2} = \frac12 W^{(0)}_2 - \frac{1-\beta}{2\|A\|_2}AW^{(0)T}_1 - \frac{1-\beta}{2\|A\|_2}\tilde r^{(0)}_2,$$
we can use the same techniques to get that i) w.p. at least $1-\delta$, $\forall i \in [d]: |C_{1i}| \le \tilde O\big(\frac{1}{d^\alpha}\big)$; ii) w.p. at least $1 - \delta - \frac{1}{d^{\frac\alpha4-1}}$, $\forall i \in [d]: |C_{1i}| \ge \Omega\big(\frac{1}{d^{\frac54\alpha}}\big)$.

C.8 PROOF OF LEMMA 2

The proof in Section C.2 tells us that at the end of the first phase (when $t = T_1$),
$$W^{(T_1)}_1 = u^{(T_1)}v^{(T_1)T} + R^{(T_1)}_1, \qquad W^{(T_1)}_2 = c^{(T_1)}u^{(T_1)T} + R^{(T_1)T}_2, \qquad \text{where } v^{(T_1)T} = \frac{\eta A}{1-\beta}, \quad c^{(T_1)} = \frac{\lambda_2^{T_1}}{\sum_{\tau=0}^{T_1-1}\lambda_2^\tau}. \qquad (16)$$
(16)
Denote the $i$-th coordinates of $u^{(t)}$, $v^{(t)}$, $R^{(t)}_2$ as $u^{(t)}_i$, $v^{(t)}_i$, $R^{(t)}_{2i}$, respectively, and denote the $(i,j)$-th element of $R^{(t)}_1$ as $R^{(t)}_1[i,j]$. For $t \ge T_1$, we prove by induction that
$$W^{(t)}_1 = u^{(T_1)} v^{(t)T} + R^{(t)}_1, \qquad W^{(t)}_2 = c^{(t)} u^{(T_1)T} + R^{(t)T}_2, \tag{17}$$
where
$$v^{(t+1)T} = v^{(t)T} - \eta_t c^{(t)} E^{(t)}, \qquad R^{(t+1)}_1 = R^{(t)}_1 - \eta_t R^{(t)}_2 E^{(t)} + r^{(t)}_1,$$
$$c^{(t+1)} = c^{(t)} - \eta_t E^{(t)} v^{(t)}, \qquad R^{(t+1)T}_2 = R^{(t)T}_2 - \eta_t E^{(t)} R^{(t)T}_1 + r^{(t)}_2,$$
with
$$r^{(t)}_1 := \eta\sum_{\tau=0}^{t}\beta^{t-\tau}\left(W^{(t)T}_2 E^{(t)} - W^{(\tau)T}_2 E^{(\tau)}\right) - \eta\sum_{\tau=0}^{t}\beta^{t-\tau} Dg^{(\tau)}_1, \qquad E^{(t)} := W^{(t)}_2 W^{(t)}_1 - A, \qquad \eta_t = \eta\sum_{\tau=0}^{t}\beta^{t-\tau},$$
and
$$r^{(t)}_2 := \eta\sum_{\tau=0}^{t}\beta^{t-\tau}\left(E^{(t)} W^{(t)T}_1 - E^{(\tau)} W^{(\tau)T}_1\right) - \eta\sum_{\tau=0}^{t}\beta^{t-\tau} Dg^{(\tau)}_2.$$
Note that we reuse the notation $r^{(t)}_1$, $r^{(t)}_2$ to represent the error terms. The base case is already given by eq. (16). Suppose our lemma holds for $t$; then for $t+1$, using the same techniques as in eq. (5) and eq. (6), we have that
$$W^{(t+1)}_1 = W^{(t)}_1 - \eta\sum_{\tau=0}^{t}\beta^{t-\tau} W^{(\tau)T}_2 E^{(\tau)} - \eta\sum_{\tau=0}^{t}\beta^{t-\tau} Dg^{(\tau)}_1 = W^{(t)}_1 - \eta_t W^{(t)T}_2 E^{(t)} + r^{(t)}_1,$$
$$W^{(t+1)}_2 = W^{(t)}_2 - \eta\sum_{\tau=0}^{t}\beta^{t-\tau} E^{(\tau)} W^{(\tau)T}_1 - \eta\sum_{\tau=0}^{t}\beta^{t-\tau} Dg^{(\tau)}_2 = W^{(t)}_2 - \eta_t E^{(t)} W^{(t)T}_1 + r^{(t)}_2.$$
Plugging in the inductive hypothesis yields
$$W^{(t+1)}_1 = u^{(T_1)} v^{(t)T} + R^{(t)}_1 - \eta_t\left(c^{(t)} u^{(T_1)} + R^{(t)}_2\right) E^{(t)} + r^{(t)}_1 = u^{(T_1)}\left(v^{(t)T} - \eta_t c^{(t)} E^{(t)}\right) + R^{(t)}_1 - \eta_t R^{(t)}_2 E^{(t)} + r^{(t)}_1,$$
$$W^{(t+1)}_2 = c^{(t)} u^{(T_1)T} + R^{(t)T}_2 - \eta_t E^{(t)}\left(v^{(t)} u^{(T_1)T} + R^{(t)T}_1\right) + r^{(t)}_2 = \left(c^{(t)} - \eta_t E^{(t)} v^{(t)}\right) u^{(T_1)T} + R^{(t)T}_2 - \eta_t E^{(t)} R^{(t)T}_1 + r^{(t)}_2.$$
This implies that our lemma holds for $t+1$, which completes the proof.
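The one-step algebra behind this induction can be sanity-checked numerically. Below is a minimal sketch (dimension, step size, and magnitudes are illustrative; the error terms $r^{(t)}_1$, $r^{(t)}_2$ are set to zero and $E^{(t)}$ is taken as an arbitrary row vector), verifying that a gradient step on $W_1 = u v^T + R_1$ and $W_2 = c u^T + R_2^T$ reproduces the claimed coefficient updates:

```python
import random

random.seed(1)
dim, eta = 3, 0.1
u = [random.random() for _ in range(dim)]
v = [random.random() for _ in range(dim)]
c = random.random()
R1 = [[0.01 * random.random() for _ in range(dim)] for _ in range(dim)]
R2 = [0.01 * random.random() for _ in range(dim)]
E = [random.random() - 0.5 for _ in range(dim)]  # arbitrary error row vector

# Direct update of the full matrices:
#   W1 <- W1 - eta * W2^T E,  W2 <- W2 - eta * E W1^T.
W1 = [[u[i] * v[j] + R1[i][j] for j in range(dim)] for i in range(dim)]
W2 = [c * u[i] + R2[i] for i in range(dim)]  # W2 stored as a length-d row
W1_new = [[W1[i][j] - eta * W2[i] * E[j] for j in range(dim)] for i in range(dim)]
W2_new = [W2[i] - eta * sum(E[j] * W1[i][j] for j in range(dim)) for i in range(dim)]

# Claimed decomposition: v' = v - eta*c*E, R1' = R1 - eta*R2 E,
#                        c' = c - eta*<E, v>, R2' = R2 - eta*E R1^T.
v2 = [v[j] - eta * c * E[j] for j in range(dim)]
R1n = [[R1[i][j] - eta * R2[i] * E[j] for j in range(dim)] for i in range(dim)]
c2 = c - eta * sum(E[j] * v[j] for j in range(dim))
R2n = [R2[i] - eta * sum(E[j] * R1[i][j] for j in range(dim)) for i in range(dim)]

assert all(abs(W1_new[i][j] - (u[i] * v2[j] + R1n[i][j])) < 1e-12
           for i in range(dim) for j in range(dim))
assert all(abs(W2_new[i] - (c2 * u[i] + R2n[i])) < 1e-12 for i in range(dim))
```

The check confirms that the left factor $u^{(T_1)}$ is untouched by the update, which is what preserves the rank-1-plus-residual structure across steps.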

Now we analyze the error terms
R (t) 1 [i, j] and R (t) 2i . Eq. ( 16) tells us that c (T1) and ∀i ∈ [d], v (T1) i are all positive. We first prove by induction that for all T 1 ≤ t ≤ T 2 , c (t) > 0, ∀i ∈ [d], v (t) i > 0. The above discussion already proves the base case. Suppose at time t, we have c (t) > 0, ∀i ∈ [d], v (t) i > 0. Note that when T 1 ≤ t < T 2 , ∀i ∈ [d] : E (t) i ≤ 0, then for t + 1, v (t+1) i = v (t) i -η t c (t) E (t) i > 0, c (t+1) = c (t) -η t d i=1 E (t) i v (t) i > 0. Therefore by induction, we have proved that for all T 1 ≤ t ≤ T 2 , c (t) > 0, ∀i ∈ [d], v (t) i > 0. Now we prove that for all T 1 ≤ t ≤ T 2 , ∀1 ≤ i, j ≤ d : 0 ≤ R (t) 1 [i, j] u (T1) i v (t) j ≤ δ i + t-1 τ =T1 (τ ) i , 0 ≤ R (t) 2i c (t) u (T1) i ≤ δ i + t-1 τ =T1 (τ ) i , ( ) where δ i := max    max j R (T1) 1 [i, j] u (T1) i v (T1) j , R (T1) 2i c (T1) u (T1) i    , (t) i := max    max j r (t) 1 [i, j] u (T1) i v (t) j , r (t) 2i c (t) u (T1) i    . The left hand sides of the inequalities are trivial since we have proved that c (t) > 0, ∀i ∈ [d], v i > 0 for all T 1 ≤ t ≤ T 2 . Now we prove the right hand sides by induction. The base case is already verified by the definition of δ i . Suppose eq.( 18) holds for T 1 ≤ t < T 2 . Then for t + 1, using ∀i ∈ [d] : E (t) i ≤ 0 and v (t+1) ≥ v (t) , c (t+1) ≥ c (t) , we can get that ∀1 ≤ i, j ≤ d R (t+1) 1 [i, j] u (T1) i v (t+1) j ≤ R (t) 1 [i, j] 1 u (T 1 ) i + η t R (t) 2i 1 u (T 1 ) i -E (t) j v (t) j + η t c (t) -E (t) j + r (t) 1 [i, j] u (T1) i v (t) j ≤ δ i + t-1 τ =T1 (τ ) i v (t) j + η t δ i + t-1 τ =T1 (τ ) i c (t) -E (t) j v (t) j + η t c (t) -E (t) j + (t) i = δ i + t τ =T1 (τ ) i . Similarly, we have that ∀1 ≤ i ≤ d R (t+1) 2i c (t+1) u (T1) i ≤ R (t) 2i 1 u (T 1 ) i + η t d j=1 -E (t) j R (t) 1 [i, j] 1 u (T 1 ) i c (t) + η t d j=1 -E (t) j v (t) j + r (t) 2i u (T1) i c (t) ≤ δ i + t-1 τ =T1 (τ ) i c (t) + η t δ i + t-1 τ =T1 (τ ) i d j=1 -E (t) j v (t) j c (t) + η t d j=1 -E (t) j v (t) j + (t) i = δ i + t τ =T1 (τ ) i . 
Therefore by induction, eq. ( 18) holds for all t in the second phase. So far we have proved the rank 1 structure stated in Lemma 2. The remaining part of the proof is given by the following lemma, whose proof is deferred to Section C.9. , we have that w.h.p. for T 1 ≤ t ≤ min{T 2 , T 3 }, ∀1 ≤ i, j ≤ d : 0 ≤ R (t) 1 [i, j] u (T1) i v (t) j ≤ Õ( 0 ), 0 ≤ R (t) 2i c (t) u (T1) i ≤ Õ( 0 ), and that when t = min{T 2 , T 3 }, we have E (t) 2 2 = O( 0 d). C.9 PROOF OF LEMMA 8 We first have the following lemma which describes the structure of v (t) for t ≥ T 1 . Lemma 9. Under Assumption 1, 2 and 3, for t ≥ T 1 , we can write v (t)T as v (t)T = a (t) A + R (t)T v , with a (T1) = η 1-β , R (T1)T v = [0, 0, ..., 0], a (t+1) = 1 -η t c (t) d (t) a (t) + η t c (t) , R (t+1) v = 1 -η t c (t) d (t) R (t) v -η t c (t) R (t) 3 , where d (t) := c (t) u (T1) 2 + R (t)T 2 u (T1) , R (t)T 3 := c (t) u (T1)T R (t) 1 + R (t)T 2 R (t) 1 . Moreover, we have that W (t) 2 W (t) 1 = d (t) v (t)T + R (t)T 3 = d (t) a (t) A + d (t) R (t)T v + R (t)T 3 . ( ) We prove Lemma 8 by induction. Denote the i-th coordinate of R (t) 3 and R (t) v as R (t) 3i and R (t) vi , respectively. The following lemmas constitute the inductive part. Lemma 10. Under Assumption 1, 2 and 3, suppose σ ≤ η 3/2 d α/2+1 and pick η ≤ O 1 d α . Consider any t such that T 1 ≤ t < min{T 2 , T 3 }. Suppose for all T 1 ≤ τ ≤ t, we have ∀i, j ∈ , we have that at time t + 1, [d] : w (τ ) 2i ≤ O d 1/4 , W (τ ) 1 [i, j] ≤ O 1 d 1/4 , then we have that ∀i, j ∈ [d] : r (t) 1 [i, j] = Õ η 2 d 11/4 , r ∀1 ≤ i, j ≤ d : 0 ≤ R (t+1) 1 [i, j] u (T1) i v (t+1) j ≤ Õ( 0 ), 0 ≤ R (t+1) 2i c (t+1) u (T1) i ≤ Õ( 0 ). where 0 is defined in Definition 2. Lemma 12. Under the conditions of Lemma 10 and pick η ≤ O d 7α 4 +4 , we have that at time t + 1, 0 ≤ R (t+1)T 2 u (T1) c (t+1) u (T1) 2 ≤ Õ( 0 ), ∀j ∈ [d] : 0 ≤ R (t+1) 3j c (t+1) u (T1) 2 v (t) j ≤ Õ( 0 ). Moreover, ∀j ∈ [d] : R (t+1) 3j A j ≤ Õ( 0 ). ( ) Lemma 13. 
Under the conditions of Lemma 10 and pick η ≤ O d 7α 4 +4 , if we further suppose that ∀j ∈ [d] : v (t) j c (t) = Θ 1 √ d , R (t) 3j Aj and R (t) vj a (t) Aj are of order Õ( 0 ), then we have that at time t + 1, By combining Lemma 10, 11 and 13, we can prove by induction that for all T 1 ≤ t ≤ min{T 2 , T 3 }, eq. ( 19) holds (which follows from Lemma 11), and (A) ∀i, j ∈ [d] : E (t) j E (t) i = Θ(1), (B) ∀j ∈ [d] : v (t+1) j c (t+1) = Θ 1 √ d , (C) ∀i, j ∈ [d] : w (t+1) 2i ≤ O d 1/4 , W (t+1) 1 [i, j] ≤ O 1 d 1/4 , (D) ∀j ∈ [d], R ∀i, j ∈ [d] : E (t) j E (t) i = Θ(1), which follows from the part (A) of Lemma 13. Now the only thing to verify is the base case, i.e. when t = T 1 . More specifically, we want to prove that 1) ∀i, j ∈ [d] :  w (T1) 2i ≤ O d 1/4 , W (T1) 1 [i, j] ≤ O 1 d 1/4 and that 2) ∀j ∈ [d] : v (T 1 ) j c (T 1 ) = Θ 1 √ d , R (t) v , R (t) 3 . So far we have proved eq. ( 19) in Lemma 8. Now let's prove when t = min{T 2 , T 3 }, we have that E (t) 2 2 = O( 0 d). If min{T 2 , T 3 } = T 3 , by Definition 3, we have E (t) 2 2 ≤ . If min{T 2 , T 3 } = T 2 , by Definition 2, there exists j ∈ [d] such that E (t) j = -Θ √ 0 . Combining with eq. ( 22) gives us ∀i ∈ [d] : E (t) i = -Θ √ 0 . Combining these two cases, we get that when t = min{T 2 , T 3 }, E (t) 2 2 ≤ max{ , Θ ( 0 d)} = O ( 0 d).

C.10 PROOF OF LEMMA 9

We prove this lemma by induction. The base case ($t = T_1$) of $v^{(t)}$ is verified by eq. (16). Suppose at time $t$ we have $v^{(t)T} = a^{(t)} A + R^{(t)T}_v$; then by eq. (17), we have that
$$W^{(t)}_2 W^{(t)}_1 = \left(c^{(t)} u^{(T_1)T} + R^{(t)T}_2\right)\left(u^{(T_1)} v^{(t)T} + R^{(t)}_1\right) = \left(c^{(t)}\left\|u^{(T_1)}\right\|_2^2 + R^{(t)T}_2 u^{(T_1)}\right) v^{(t)T} + c^{(t)} u^{(T_1)T} R^{(t)}_1 + R^{(t)T}_2 R^{(t)}_1 = d^{(t)} v^{(t)T} + R^{(t)T}_3 = d^{(t)} a^{(t)} A + d^{(t)} R^{(t)T}_v + R^{(t)T}_3,$$
where $d^{(t)} := c^{(t)}\left\|u^{(T_1)}\right\|_2^2 + R^{(t)T}_2 u^{(T_1)}$ and $R^{(t)T}_3 := c^{(t)} u^{(T_1)T} R^{(t)}_1 + R^{(t)T}_2 R^{(t)}_1$. That gives us
$$v^{(t+1)T} = v^{(t)T} - \eta_t c^{(t)} E^{(t)} = a^{(t)} A + R^{(t)T}_v - \eta_t c^{(t)}\left(d^{(t)} a^{(t)} A + d^{(t)} R^{(t)T}_v + R^{(t)T}_3 - A\right) = \left(\left(1 - \eta_t c^{(t)} d^{(t)}\right) a^{(t)} + \eta_t c^{(t)}\right) A + \left(1 - \eta_t c^{(t)} d^{(t)}\right) R^{(t)T}_v - \eta_t c^{(t)} R^{(t)T}_3 := a^{(t+1)} A + R^{(t+1)T}_v.$$
Therefore we have proved by induction that for $t$ in the second phase, $v^{(t)T} = a^{(t)} A + R^{(t)T}_v$. The above steps also prove eq. (20).
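The coefficient recursion of Lemma 9 can be checked numerically. A minimal sketch (illustrative values; for simplicity $d^{(t)}$ and $R^{(t)}_3$ are held fixed across steps, whereas in the lemma they evolve with the residuals):

```python
import random

random.seed(0)
dim, eta_t, c, d_t = 5, 0.01, 0.5, 2.0
A = [random.uniform(0.5, 1.5) for _ in range(dim)]
R3 = [random.uniform(-0.1, 0.1) for _ in range(dim)]
a, Rv = 0.3, [0.0] * dim
v = [a * A[j] + Rv[j] for j in range(dim)]

for _ in range(10):
    # Direct update: v <- v - eta_t * c * E with E = d_t * v + R3 - A.
    E = [d_t * (a * A[j] + Rv[j]) + R3[j] - A[j] for j in range(dim)]
    v = [v[j] - eta_t * c * E[j] for j in range(dim)]
    # Coefficient recursion from Lemma 9:
    #   a  <- (1 - eta_t*c*d_t)*a + eta_t*c
    #   Rv <- (1 - eta_t*c*d_t)*Rv - eta_t*c*R3
    a = (1 - eta_t * c * d_t) * a + eta_t * c
    Rv = [(1 - eta_t * c * d_t) * Rv[j] - eta_t * c * R3[j] for j in range(dim)]

# The two descriptions of v^{(t)} stay in agreement.
assert all(abs(v[j] - (a * A[j] + Rv[j])) < 1e-12 for j in range(dim))
```

Note how the recursion drives $a^{(t)}$ toward the fixed point $1/d^{(t)}$ while the residual $R^{(t)}_v$ is contracted by the same factor $1 - \eta_t c^{(t)} d^{(t)}$.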

C.11 PROOF OF LEMMA 10

Write r (t) 1 = q (t) 11 + q (t) 12 where we have q (t) 11 = η t τ =0 β t-τ W (t)T 2 E (t) -W (τ )T 2 E (τ ) , q (t) 12 = -η t τ =0 β t-τ Dg (τ ) 1 . Write r (t) 2 = q (t) 21 + q (t) 22 where q (t) 22 = -η t τ =0 β t-τ Dg (τ ) 2 , q (t) 21 = η t τ =0 β t-τ E (t) W (t)T 1 -E (τ ) W (τ )T 1 . Let's first bound q (t) 11 [i, j] and q (t) 21,i . By definition of T 2 , we know that for T 1 ≤ τ ≤ t, ∀i ∈ [d] : E (τ ) i = O(1). Then we have for all i, j ∈ [d], W (τ +1) 1 [i, j] -W (τ ) 1 [i, j] ≤ η τ k=0 β τ -k w (k) 2i E (k) j ≤ η τ k=0 β τ -k O d 1/4 = ηO d 1/4 , w (τ +1) 2i -w (τ ) 2i ≤ η τ k=0 β τ -k d j=1 E (k) j W (k) 1 [j, i] ≤ η τ k=0 β τ -k O d 3/4 = ηO d 3/4 . (23) Note that E (τ +1) j -E (τ ) j = d i=1 w (τ +1) 2i -w (τ ) 2i W (τ ) 1 [i, j] + w (τ ) 2i W (τ +1) 1 [i, j] -W (τ ) 1 [i, j] + d i=1 w (τ +1) 2i -w (τ ) 2i W (τ +1) 1 [i, j] -W (τ ) 1 [i, j] . We can further get that for ∀j ∈ [d], E (τ +1) j -E (τ ) j ≤ ηdO d 3/4 O d -1/4 + ηdO d 1/4 O d 1/4 + η 2 dO d 3/4 O d 1/4 = O ηd 3/2 + η 2 d 2 = O ηd 3/2 . Combining the above inequalities gives us ∀i, j ∈ [d], q (t) 11 [i, j] = η t τ =0 β t-τ w (t) 2i E (t) j -w (τ ) 2i E (τ ) j ≤ η t τ =0 β t-τ w (t) 2i -w (τ ) 2i E (t) j + w (τ ) 2i E (t) j -E (τ ) j ≤ η 2 t τ =0 β t-τ (t -τ ) O d 3/4 O(1) + O d 1/4 O d 3/2 = O η 2 d 7/4 , q (t) 12,i = η t τ =0 β t-τ d j=1 E (t) j W (t) 1 [i, j] -E (τ ) j W (τ ) 1 [i, j] ≤ η t τ =0 β t-τ d j=1 E (t) j W (t) 1 [i, j] -W (τ ) 1 [i, j] + E (t) j -E (τ ) j W (τ ) 1 [i, j] ≤ η 2 d t τ =0 β t-τ (t -τ ) O(1)O d 1/4 + O d 3/2 O d -1/4 = O η 2 d 9/4 . Next let's bound q (t) 12 [i, j] and q (t) 22,i . By the assumption of this lemma and the analysis before T 1 , we know that for all τ ≤ t, the M (τ ) 1 , M (τ ) 2 in Lemma 31 are upper bounded by O 1 d 1/4 and O d 1/4 , respectively. In the theorem we consider the training period before T SGD,2 so the time T in Lemma 31 is set as T SGD,2 . In the following sections, we will prove that T SGD,2 ≤ O d α log( √ d/ ) η . 
Then by Lemma 31, we have with probability at least 1 -1 d , for ∀τ ≤ t and ∀i, j ∈ [d], Dg (τ ) 1 [i, j] = g(τ) 1 [i, j] -g (τ ) 1 [i, j] ≤ O d 13 4 σ d α+1 η log d + O d 1 4 σ d α+2 η log d ≤ Õ d 13 4 σ d α+1 η , Dg (τ ) 2i = g(τ) 2i -g (τ ) 2i ≤ O d 15 4 σ d α+1 η log d + O d 3 4 σ d α+2 η log d = Õ d 15 4 σ d α+1 η . By picking σ ≤ η 3/2 d α/2+1 , we have Dg (τ ) 1 [i, j] ≤ η Õ d 11 4 and Dg (τ ) 2i ≤ η Õ d 13 4 , which yields q (t) 12 [i, j] ≤ η t τ =0 β t-τ Dg (τ ) 1 [i, j] ≤ Õ η 2 d 11 4 , q (t) 22,i ≤ η t τ =0 β t-τ Dg (τ ) 2i ≤ Õ η 2 d 13 4 . Combining the above bounds, we get that ∀i, j ∈ [d], r (t) 1 [i, j] ≤ Õ η 2 d 11 4 , r (t) 2i ≤ Õ η 2 d 13 4 . By the analysis in Section C.2, we know that at time T 1 , for some i 0 ∈ [d], c (T1) u (T1) i0 = Θ 1 d α 2 , and for ∀i, j ∈ [d], we have c (T1) u (T1) i = Ω 1 d 3α 4 and u (T1) i v (T1) j = Ω 1 d 3α 4 + 1 2 , which gives us ∀i, j ∈ [d], r (t) 1 [i, j] u (T1) i v (t) j ≤ r (t) 1 [i, j] u (T1) i v (T1) j = Õ η 2 d 3 4 α+ 13 4 , r (t) 2i c (t) u (T1) i ≤ r (t) 2i c (T1) u (T1) i = Õ η 2 d 3 4 α+ 13 4 . Hence we get the bound ∀i ∈ [d] :  r (t) 2 W (t) 1 j ≤ d i=1 r (t) 2i W (t) 1 [i, j] = O η 2 d 4 , W (t) 2 r (t) 1 j ≤ d i=1 w (t) 2i r (t) 1 [i, j] = O η 2 d 4 . Combining with eq. ( 23), we get E (t+1) =E (t) + W (t+1) 2 -W (t) 2 W (t) 1 + W (t) 2 W (t+1) 1 -W (t) 1 + W (t+1) 2 -W (t) 2 W (t+1) 1 -W (t) 1 =E (t) -η t E (t) W (t)T 1 W (t) 1 + r (t) 2 W (t) 1 -η t W (t) 2 W (t)T 2 E (t) + W (t) 2 r (t) 1 + O η 2 d =E (t) I -η t W (t)T 1 W (t) 1 -η t W (t) 2 2 2 I + O η 2 d 4 + O η 2 d 4 + O η 2 d . Then we have E (t+1) 2 ≤ E (t) 2 I -η t W (t)T 1 W (t) 1 -η t W (t) 2 2 2 I 2 + O η 2 d 4 ≤ 1 -η t W (t) 2 2 2 E (t) 2 + O η 2 d 4 . When T 1 ≤ t < T 2 , we have proved that c (t) is increasing over time in Section C.8, which implies that W (t) 2 2 2 ≥ C W (T1) 2 2 2 since c (t) u (T1)T is the leading term of W (t) 2 . 
Combining with η t ≥ η gives us E (t+1) 2 ≤ 1 -ηC W (T1) 2 2 2 E (t) 2 + O η 2 d 4 , ⇒ E (t) 2 ≤ O η 2 d 4 ηC W (T1) 2 2 2 + 1 -ηC W (T1) 2 2 2 t-T1    E (T1) 2 - O η 2 d 4 ηC W (T1) 2 2 2    (i) ≤ O    ηd 4 W (T1) 2 2 2    + exp -ηC W (T1) 2 2 2 (t -T 1 ) O √ d , where (i) uses E (T1) 2 = O( √ d). By picking η ≤ O d 7α 4 +4 and noticing that W (T1) 2 2 2 ≥ Ω 1 d α , we have ηd 4 W (T 1 ) 2 2 2 < √ 2 . Hence when t -T 1 ≥ Θ log √ d/ η W (T 1 ) 2 2 2 , we have that E (t) 2 ≤ √ , i.e. E (t) 2 2 ≤ . That means after at most O log √ d/ η W (T 1 ) 2 2 2 steps from T 1 , either t ≥ T 2 , or we have E (t) 2 2 ≤ . In other words, min{T 2 , T 3 } ≤ T 1 + O log √ d/ η W (T 1 ) 2 2 2 ≤ O d α log √ d/ η . Now we are ready to bound eq. 18. Combining min{T 2 , T 3 } ≤ O d α log( √ d/ ) η and Lemma 10 yields that for t + 1 ≤ min{T 2 , T 3 }, ∀i ∈ [d], t+1 τ =T1 (τ ) i ≤ (t + 1 -T 1 ) Õ η 2 d 3 4 α+ 13 4 ≤ Õ ηd 7 4 α+ 13 4 log d = Õ log d . Lemma 1 tells us that δ i = Õ 1 d 1 4 α-1 . Substituting these bounds into eq. ( 18) completes the proof.
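The unrolling of the one-step contraction above can be checked numerically. A sketch with illustrative constants, where `rho` plays the role of $\eta C\|W^{(T_1)}_2\|_2^2$ and `b` the role of the $O(\eta^2 d^4)$ additive term:

```python
# Recursion e_{t+1} = (1 - rho) * e_t + b unrolls to the closed form
#   e_t = b/rho + (1 - rho)^t * (e_0 - b/rho),
# i.e. geometric decay toward the fixed point b/rho (values illustrative).
rho, b, e0, T = 0.05, 1e-4, 10.0, 500
e = e0
for t in range(T):
    e = (1 - rho) * e + b
closed_form = b / rho + (1 - rho) ** T * (e0 - b / rho)
assert abs(e - closed_form) < 1e-9
assert e < b / rho + 1e-3  # after T steps, e has decayed close to b/rho
```

This mirrors the proof: after roughly $\log(\sqrt d/\epsilon)/\rho$ steps the transient term is below the target, so either $t \ge T_2$ or $\|E^{(t)}\|_2^2 \le \epsilon$.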

C.13 PROOF OF LEMMA 12

The proof in Section C.8 tells us that for T 1 ≤ τ ≤ T 2 , c (τ ) > 0, ∀j ∈ [d] : v (τ ) j > 0, which gives us 0 ≤ R (t+1)T 2 u (T 1 ) c (t+1) u (T 1 ) 2 and 0 ≤ R (t+1) 3j c (t+1) u (T 1 ) 2 v (t+1) j . By Lemma 11, we have that ∀1 ≤ i, j ≤ d : 0 ≤ R (t+1) 1 [i, j] u (T1) i v (t+1) j ≤ Õ( 0 ), 0 ≤ R (t+1) 2i c (t+1) u (T1) i ≤ Õ( 0 ), which gives us R (t+1)T 2 u (T1) c (t+1) u (T1) 2 ≤ d i=1 u (T1) i R (t+1) 2i c (t+1) d i=1 u (T1) i 2 ≤ Õ( 0 )c (t+1) d i=1 u (T1) i 2 c (t+1) d i=1 u (T1) i 2 = Õ( 0 ). Lemma 9 tells us that R (t+1)T 3 = c (t+1) u (T1)T R (t+1) 1 + R (t+1)T 2 R (t+1) 1 . And we have that c (t+1) u (T1)T R (t+1) 1 j c (t+1) u (T1) 2 v (t+1) j ≤ c (t+1) d i=1 u (T1) i R (t+1) 1 [i, j] c (t+1) d i=1 u (T1) i 2 v (t+1) j ≤ Õ( 0 )c (t+1) d i=1 u (T1) i 2 v (t+1) j c (t+1) d i=1 u (T1) i 2 v (t+1) j = Õ( 0 ), R (t+1)T 2 R (t+1) 1 j c (t+1) u (T1) 2 v (t+1) j ≤ d i=1 R (t+1) 2i R (t+1) 1 [i, j] c (t+1) d i=1 u (T1) i 2 v (t+1) j ≤ Õ 2 0 c (t+1) d i=1 u (T1) i 2 v (t+1) j c (t+1) d i=1 u (T1) i 2 v (t+1) j = Õ 2 0 . Therefore R (t+1) 3j c (t+1) u (T1) 2 v (t+1) j ≤ c (t+1) u (T1)T R (t+1) 1 j c (t+1) u (T1) 2 v (t+1) j + R (t+1)T 2 R (t+1) 1 j c (t+1) u (T1) 2 v (t+1) j ≤ Õ( 0 ). By Lemma 9, W (t+1) 2 W (t+1) 1 = c (t+1) u (T1) 2 v (t+1)T + R (t+1)T 2 u (T1) v (t+1)T + R (t+1)T 3 . Then we have that ∀j ∈ [d], W (t+1) 2 W (t+1) 1 j = c (t+1) u (T1) 2 v (t+1) j 1 + e (t+1) j , where e (t+1) j ≤ Õ( 0 ). ( ) Since t < T 2 , we have ∀j ∈ [d] : W (t+1) 2 W (t+1) 1 j Aj = O(1), which yields 0 ≤ c (t+1) u (T1) 2 v (t+1) j A i ≤ O(1), which proves eq. ( 21), since ∀j ∈ , we can apply the technique when proving eq. ( 24) to show that eq. ( 24) also holds at time t. Since [d] : 0 ≤ R (t+1) 3j c (t+1) u (T 1 ) 2 v (t+1) j ≤ Õ( 0 ). R (t) vj a (t) Aj ≤ Õ( 0 ), we get that v (t) j = a (t) A j + R (t) vj = a (t) A j 1 + e (t) vj , where e (t) vj ≤ Õ( 0 ). 
Substituting into the time t version of eq.( 24) yields ∀j ∈ [d] : W (t) 2 W (t) 1 j = a (t) c (t) u (T1) 2 A j 1 + ẽ(t) j , where ẽ(t) j ≤ Õ( 0 ), That gives us ∀j ∈ [d] : E (t) j = A j a (t) c (t) u (T1) 2 -1 + a (t) c (t) u (T1) 2 ẽ(t) j . Since t < T 2 , we have E (t) j < - √ 0 . Combining with A j = Θ(1), gives us a (t) c (t) u (T1) 2 -1 = -Ω √ 0 . Then we can rewrite E (t) j as ∀j ∈ [d], E (t) j = A j a (t) c (t) u (T1) 2 -1 1 + a (t) c (t) u (T1) 2 a (t) c (t) u (T1) 2 -1 ẽ(t) j := A j a (t) c (t) u (T1) 2 -1 1 + e (t)

, where e (t) Ej = Õ( √ 0 ). Hence ∀i, j ∈ [d] : E (t) j E (t) i = Θ(1). (B) Note that we assume ∀j ∈ [d] : v (t) j c (t) = Θ 1 √ d , then we have for j ∈ [d], -E (t) v (t) c (t) -E (t) j = d i=1 -E (t) i v (t) i c (t) -E (t) j = d i=1 E (t) i E (t) j • v (t) i c (t) = d i=1 Θ 1 √ d = Θ √ d , ⇒ c (t) -E (t) j -E (t) v (t) = Θ 1 √ d . Then for t + 1, we have that for j ∈ [d], v (t+1) j c (t+1) = v (t) j + η t c (t) -E (t) j c (t) + η t -E (t) v (t) = Θ 1 √ d . (C) Combining eq. ( 25) and ∀j ∈ [d] : A j = Θ(1), we know that a (t+1) Ai , we first prove that 1 -η t c (t) d (t) > 0. c (t+1) u (T1) 2 v (t+1) j ≤ O(1), which yields ∀j ∈ [d], u (T1) 2 v (t+1) j 2 ≤ v (t+1) j c (t+1) O(1) = O 1 √ d , c (t+1) 2 u (T1) 2 ≤ c (t+1) v (t+1) j O(1) = O √ d . (26) Hence ∀i, j ∈ [d], c (t+1) u (T1) i = O d 1/4 ⇒ w (t+1) 2i ≤ c (t+1) u (T1) i + R (t+1) 2i (i) = O d 1/4 , u (T1) i v (t+1) j = O 1 d 1/4 ⇒ W (t+1) 1 [i, j] ≤ u (T1) i v (t+1) j + R (t+1) 1 [i, j] (ii) = O 1 d 1/4 , It is not hard to prove that eq.( 26) also holds for time t. Recall that d (t) = c (t) u (T1) 2 + R (t)T 2 u (T1) and Lemma 12 tells us that 0 ≤ R (t)T 2 u (T 1 ) c (t) u (T 1 ) 2 ≤ Õ( 0 ), then we have c (t) d (t) = c (t) 2 u (T1) 2 + c (t) R (t)T 2 u (T1) ≤ O √ d . Under the conditions of Lemma 10, and pick η ≤ O d 7α 4 +4 , we have that 1 -η t c (t) d (t) ≥ 1 -ηc (t) d (t) > 0. The assumption ∀j ∈ [d] :

$\left|\frac{R^{(t)}_{3j}}{A_j}\right| \le \tilde O(\varepsilon_0)$ together with $c^{(t)} > 0$ gives us $\frac{\eta_t c^{(t)}\left|R^{(t)}_{3j}\right|}{\eta_t c^{(t)} A_j} \le \tilde O(\varepsilon_0)$. Combining with the assumption $\left|\frac{R^{(t)}_{vi}}{a^{(t)} A_i}\right| \le \tilde O(\varepsilon_0)$ yields
$$\forall i \in [d]:\quad \left|\frac{R^{(t+1)}_{vi}}{a^{(t+1)} A_i}\right| \le \frac{\left(1-\eta_t c^{(t)} d^{(t)}\right)\left|R^{(t)}_{vi}\right| + \eta_t c^{(t)}\left|R^{(t)}_{3i}\right|}{\left(1-\eta_t c^{(t)} d^{(t)}\right) a^{(t)} A_i + \eta_t c^{(t)} A_i} \le \tilde O(\varepsilon_0).$$

D ANALYSIS OF ADAM

Note that A = 1 m Y X T , Λ xx := 1 m XX T . Denote g (t) k := ∇ W k L(W (t) ), k = 1, 2. We have that g (t) 1 = W (t)T 2 W (t) 2 W (t) 1 -A , g (t) 2 = W (t) 2 W (t) 1 -A W (t)T 1 . Let Ã(t) , Λ(t) xx and g(t) k , k = 1, 2 be the corresponding batch versions at time t. Let E (t) := W (t) 2 W (t) 1 -A, and denote E (t) j as the j-th component of E (t) . We also denote ∆w (t) 2i := w (t+1) 2i - w (t) 2i , ∆W (t) 1 [i, j] := W (t+1) 1 [i, j] -W (t) 1 [i, j] . By eq. ( 2), the update equations of Adam are given by η t = η • 1 -β t+1 2 1 -β t+1 1 , g (t) 1 [i, j] = w (t) 2i E (t) j , g (t) 2i = E (t) , W (t) 1 [i, :] , W (t+1) 1 [i, j] -W (t) 1 [i, j] = -η t m (t) 1 [i, j] v (t) 1 [i, j] = -η t (1 -β 1 ) t τ =0 β t-τ 1 g(τ) 1 [i, j] (1 -β 2 ) t τ =0 β t-τ 2 g(τ) 1 [i, j] 2 + ξ = -η t (1 -β 1 ) t τ =0 β t-τ 1 g (τ ) 1 [i, j] + r (t) 1n [i, j] (1 -β 2 ) t τ =0 β t-τ 2 g (τ ) 1 [i, j] 2 + r (t) 1d [i, j] + ξ , w (t+1) 2i -w (t) 2i = -η t m (t) 2i v (t) 2i = -η t (1 -β 1 ) t τ =0 β t-τ 1 g(τ) 2i (1 -β 2 ) t τ =0 β t-τ 2 g(τ) 2i 2 + ξ = -η t (1 -β 1 ) t τ =0 β t-τ 1 g (τ ) 2i + r (t) 2n,i (1 -β 2 ) t τ =0 β t-τ 2 g (τ ) 2i 2 + +r (t) 2d,i + ξ . ( ) where Dg (t) 1 := g(t) 1 -g (t) 1 and Dg (t) 2 := g(t) 2 -g (t) 2 , and r (t) 1n [i, j] := (1 -β 1 ) t τ =0 β t-τ 1 Dg (τ ) 1 [i, j], r (t) 1d [i, j] = (1 -β 2 ) t τ =0 β t-τ 2 2g (τ ) 1 [i, j]Dg (τ ) 1 [i, j] + Dg (τ ) 1 [i, j] 2 , r (t) 2n,i := (1 -β 1 ) t τ =0 β t-τ 1 Dg (τ ) 2i , r (t) 2d,i = (1 -β 2 ) t τ =0 β t-τ 2 2g (τ ) 2i Dg (τ ) 2i + Dg (τ ) 2i 2 . ( ) Denote the i-th coordinate of W 2 W 1 and A as (W 2 W 1 ) i and A i , respectively. By Assumption 2 and the assumption that ∀i ∈ [d] : A i > 0, A i = Ω(1), at the beginning, w.h.p., ∀i ∈ [d] : (W 2 W 1 ) i -A i < 0. Based on this, we divide the training procedure into two phases (note that these two phases are different from those of SGD+M). 1. First phase: when the error (W 2 W 1 ) i -A i is negative and its absolute value is big for all i ∈ [d]. 2. 
Second phase: when (W 2 W 1 ) i -A i is close to zero for some coordinate i ∈ [d]. More formally, we define the boundary between the two phases below. Definition 4 (End of the first phase). The end of the first phase (denoted as T 1 ) is defined as T 1 = inf t > 0 : ∃i ∈ [d] : E (t) i ≥ - √ ηd . In the second phase, we define some time points. Definition 5. Define T g := inf t > T 1 : ∃i ∈ [d] : g (t) 2i ≤ d √ η . For t < T 1 , we have ∀i ∈ [d] : E (t) i < 0 by Definition 4. For t > T 1 , some E (t) i may flip the sign and become positive. For certain coordinate i, we define the following "flip time". Definition 6. Define T f,i := inf t > T 1 : E (t) i ≥ - √ ηd . Define T f := max i T f,i as the largest "flip time" over all i ∈ [d], i.e. the "flip time" of the last E i which flips the sign. Moreover, denote T := min {T g , T f }. We can first show that after a few steps in the first phase, W 1 will become an approximately rank-1 matrix, as described in the following lemma. Lemma 14. Under Assumption 1, 2 and 3, suppose σ ≤ η 3/2 ξ 2 d 13/4 . By picking η ≤ O 1 d 3α , ξ ≤ η d 3α-1 , and β 2 = β 2 1 , there exists t inc > 0 such that w.h.p. for t inc ≤ t < T 1 , ∀i, j ∈ [d] : w (t) 2i = sign w (0) 2i η (t -t inc ) + R (t) 2i , W (t) 1 [i, j] = sign w (0) 2i η (t -t inc ) + R (t) 1 [i, j], where R (t) 1 [i,j] η(t-tinc) = Õ √ η + 1 η(t-tinc)d α , R (t) 2i η(t-tinc) = Õ √ η + 1 η(t-tinc)d α . Specially, when t = T 1 , we have that ∀i, j ∈ [d] : w (T1) 2i = sign w (0) 2i η (T 1 -t inc ) + R (T1) 2i , W (T1) 1 [i, j] = sign w (0) 2i η(T 1 -t inc ) + R (T1) 1 [i, j], where η (T 1 -t inc ) = Θ 1 √ d and R (T1) 1 [i, j] η (T 1 -t inc ) = Õ √ η + 1 d α-1 2 , R (T1) 2i η (T 1 -t inc ) = Õ √ η + 1 d α-1 2 . The following lemma tells us that this approximate rank-1 structure is preserved when T 1 ≤ t ≤ T . Lemma 15. Under Assumption 1, 2 and 3, suppose σ ≤ η 3/2 ξ 2 d 13/4 . By picking η ≤ O 1 d 3α , ξ ≤ η d 3α-1 , and β 2 = β 2 1 , we have w.h.p. 
for T 1 ≤ t < T , ∀i, j ∈ [d] : w (t) 2i = sign w (0) 2i c (t) + R (t) 2i , W (t) 1 [i, j] = sign w (0) 2i V (t) j + R (t) 1 [i, j], where R (t) 2i c (t) = Õ √ η + 1 d α-1/2 , R (t) 1 [i, j] V (t) j ≤ Õ η 1 4 + 1 d α 2 -1 4 , and that L W ( T ) ≤ Õ ηd 4 . Now we are ready to prove the Adam part of Theorem 1.

D.1 PROOF OF THE ADAM PART OF THEOREM 1

Define T Adam,1 = t inc + 1 ηd α 2 . Note that this choice of T Adam,1 gives η (T Adam,1 -t inc ) = 1 d α 2 . By picking η ≤ O 1 d 3α , ξ ≤ η d 3α-1 and β 2 = β 2 1 , we can apply Lemma 14 to get that ∀i, j ∈ [d] : w (TAdam,1) 2i = Θ 1 d α 2 , W (TAdam,1) 1 [i, j] = Θ 1 d α 2 , and therefore ∀i ∈ [d] : E (TAdam,1) i = -Θ(1) and L W (TAdam,1) = Θ(d). Define T Adam,2 = T . By Lemma 15, we have L W (TAdam,2) = Õ ηd 4 . For any p > 0, by picking α ≥ p+4 3 , we have L W (TAdam,2) = Õ ηd 4 ≤ Õ 1 d p . Moreover, combining Lemma 14 and 15, we get that when t ∈ [T Adam,1 , T Adam,2 ], the conditions in Lemma 30 are satisfied with δ = Õ η 1 4 + 1 d α 2 -1 4 . The i-th component of the u vector (denoted as u i ) is sign w (0) 2i . That means ∀i ∈ [d] : u 2 i = 1 and maxi(ui) 2 median(ui) 2 = 1. Then we can apply Lemma 30 and get that R Adam med,1 (t), R Adam med,2 (t) ∈ 1 -δ 1 + δ 2 max i (u i ) 2 median(u i ) 2 , 1 + δ 1 -δ 2 max i (u i ) 2 median(u i ) 2 = 1 -δ 1 + δ 2 , 1 + δ 1 -δ 2 , ⇒ R Adam med,1 (t), R Adam med,2 (t) = 1 ± O(δ) = 1 ± Õ η 1 4 + 1 d α 2 -1 4 .
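As an illustration of the Adam dynamics analyzed in this section, here is a minimal scalar sketch of the update in eq. (27) on a two-layer product $w_2 w_1 \approx a$. All hyperparameters are illustrative (not the paper's asymptotic regime), with $\beta_2 = \beta_1^2$ as assumed in Lemma 14:

```python
# Scalar two-layer model: loss gradient components g1 = w2*e, g2 = e*w1
# with e = w2*w1 - a, optimized by Adam with bias correction.
eta, beta1, xi = 0.002, 0.9, 1e-3
beta2 = beta1 ** 2
a, w1, w2 = 2.0, 0.01, 0.01
m1 = v1 = m2 = v2 = 0.0
for t in range(5000):
    e = w2 * w1 - a                          # E^(t), scalar version
    g1, g2 = w2 * e, e * w1                  # gradients for W1 and W2
    m1 = beta1 * m1 + (1 - beta1) * g1       # first-moment EMAs
    m2 = beta1 * m2 + (1 - beta1) * g2
    v1 = beta2 * v1 + (1 - beta2) * g1 ** 2  # second-moment EMAs
    v2 = beta2 * v2 + (1 - beta2) * g2 ** 2
    # Bias-corrected step size: eta_t = eta*sqrt(1-beta2^(t+1))/(1-beta1^(t+1)).
    eta_t = eta * (1 - beta2 ** (t + 1)) ** 0.5 / (1 - beta1 ** (t + 1))
    w1 -= eta_t * m1 / (v1 ** 0.5 + xi)
    w2 -= eta_t * m2 / (v2 ** 0.5 + xi)
assert abs(w2 * w1 - a) < 0.1  # the product has (approximately) fit a
```

In the first phase both weights grow at a nearly constant, gradient-magnitude-independent rate (the sign-descent behavior of Lemma 14), and in the second phase the product settles near $a$.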

D.2 PROOF OF LEMMA 14

For some time t, we introduce two conditions. Condition 1. ∀τ ∈ [H] : sign g (t-τ ) 1 [i, j] = s (t) 1 [i, j], (1 -β 1 ) H τ =0 β (τ ) 1 g (t-τ ) 1 [i, j] ≥ Ω(ξ). Condition 2. ∀τ ∈ [H] : sign g (t-τ ) 2i = s (t) 2i , (1 -β 1 ) H τ =0 β (τ ) 1 g (t-τ ) 2i ≥ Ω(ξ). Next prove that, under Assumption 1 and 2, by picking η ≤ O 1 d 3α , ξ ≤ η d 3α-1 , and β 2 = β 2 1 , there exists t inc > 0 such that for t inc ≤ t < T 1 , the weights can be approximated in the following way. W (t+1) 1 [i, j] = W (t) 1 [i, j] -η sign g (t) 1 [i, j] + e (t) 1 [i, j] , w (t+1) 2i = w (t) 2i -η sign g (t) 2i + e (t) 2i , where e (t) 1 [i, j] = Õ √ η , e (t) 2i = Õ √ η . Before we dive into the proof, let's introduce some useful lemmas. The following lemma reflects our key idea: converting the exponential average in Adam to a finite-step average, and trying to bound the stochastic error terms in eq. ( 28). Lemma 16. Under Assumption 1, 2 and 3 and pick β 2 = β 2 1 . Let M (t) 1 := max i,j∈[d],τ ≤t W (τ ) 1 [i, j] , M (t) 2 := max i,j∈[d],τ ≤t w (τ ) 2i , G (t) 1 := max i,j∈[d],τ ≤t g (τ ) 1 [i, j] and G (t) 2 := max i,j∈[d],τ ≤t g (τ ) 2i . We have that w.h.p., for all t ≤ Õ 1 √ dη and ∀i, j ∈ [d], ∆W (t) 1 [i, j] = -η t (1 -β 1 ) H τ =0 β τ 1 g (t-τ ) 1 [i, j] + (t) 1n [i, j] (1 -β 2 ) H τ =0 β τ 2 g (t-τ ) 1 [i, j] 2 + (t) 1d [i, j] + ξ , ∆w (t) 2i = -η t (1 -β 1 ) H τ =0 β τ 1 g (t-τ ) 2i + (t) 2n,i (1 -β 2 ) H τ =0 β τ 2 g (t-τ ) 2i 2 + (t) 2d,i + ξ , where H ≥ 1 1-β1 log max G (t) 1 ,G (t) 2 , G (t) 1 2 , G (t) 2 2 ηξ 2 and (t) 1n [i, j] ≤ O(ηξ 2 ) + O D (t) 1 , (t) 1d [i, j] ≤ O(ηξ 2 ) + O D (t) 1 G (t) 1 + D (t) 1 2 , (t) 2n,i ≤ O(ηξ 2 ) + O D (t) 2 , (t) 2d,i ≤ O(ηξ 2 ) + O D (t) 2 G (t) 2 + D (t) 2 2 , with D (t) 1 ≤ Õ   d 3 M (t) 1 M (t) 2 2 σ d 1/2 η   + Õ   M (t) 2 σ d 3/2 η   , D (t) 2 ≤ Õ   d 4 M (t) 1 2 M (t) 2 σ d 1/2 η   + Õ   dM (t) 1 σ d 3/2 η   . Corollary 2. Under the conditions of Lemma 16 and suppose σ  ≤ η 3/2 ξ 2 d 13/4 . Consider any t ≤ Õ 1 √ dη . 
If M (t) 1 , M (t) 2 ≤ Õ 1 √ d , G (t) 1 ≤ Õ 1 √ d , G (t) 2 ≤ Õ √ d , [i, j] , (t) 1d [i, j] , (t) 2n,i , (t) 2d,i ≤ Õ(ηξ 2 ). The following lemma analyzes the magnitude of weights during a short period at the beginning. Lemma 17. Under Assumption 1, 2 and 3, suppose σ ≤ η 3/2 ξ 2 d 13/4 . Pick ξ ≤ 1 d 3 2 α , then there exists some time point t inc ∈ (H, T 1 ), such that w.h.p., for t ≤ t inc , for every i, j ∈ [d], ∆W (t) 1 [i, j] ≤ Õ(η), ∆w (t) 2i ≤ Õ(η), W (t) 1 [i, j] ≤ O 1 d 3 2 α+1 , Ω 1 d 3 2 α ≤ w (t) 2i ≤ O 1 d α . Specifically, when t = t inc , we have sign w (tinc) 2i = sign W (tinc) 1 [i, j] = sign w (0) 2i , W (tinc) 1 [i, j] = Θ 1 d 3 2 α+1 and g (tinc) 1 [i, j] ≥ Ω(ξ), g (tinc) 2i ≥ Ω(ξ). Moreover, Condition 1 and 2 are satisfied for t = t inc . The s (t) 1 [i, j] and s (t) 2i in the conditions are both -sign w (0) 2i . The following lemma gives us lower bounds of g (t) 1 [i, j] and g (t) 2i . Lemma 18. Under Assumption 1, 2 and 3, suppose σ ≤ η 3/2 ξ 2 d 13/4 . Pick ξ ≤ η d 3α-1 , η ≤ O 1 d 3α . Consider t inc in Lemma 17. We have w.h.p. for any t ∈ [t inc , T 1 ), and for ∀i, j ∈ [d], sign ∆W (t) 1 [i, j] = sign ∆w (t) 2i = sign w (0) 2i and that ∀i, j ∈ [d] : g (t) 1 [i, j] ≥ Ω √ η , g (t) 2i ≥ Ω √ ηd . Moreover, we have ∀τ ≤ t, ∀i, j ∈ [d] : W (τ ) 1 [i, j] ≤ Õ 1 √ d , w (τ ) 2i ≤ Õ 1 √ d and g (τ ) 1 [i, j] ≤ Õ 1 √ d , g (τ ) 2i ≤ Õ √ d . The following lemma shows that when t inc ≤ t < T 1 , we have ∀i, j ∈ [d] : g (t) 2i g (t) 2i -g (t-1) 2i and that g (t) 1 [i, j] g (t) 1 [i, j] -g (t-1) 1 [i, j] . Lemma 19. Under Assumption 1, 2 and 3, suppose σ ≤ η 3/2 ξ 2 d 13/4 . Pick ξ ≤ η d 3α-1 , η ≤ O 1 d 3α . For t inc in Lemma 17, we have that w.h.p. for t inc ≤ t < T 1 and τ ≤ t, ∀i, j ∈ [d], g (t) 1 [i, j] -g (t-τ ) 1 [i, j] g (t) 1 [i, j] = Õ ( √ ητ ) , g 2i -g (t-τ ) 2i g (t) 2i = Õ ( √ ητ ) , g (t) 1 [i, j] 2 -g (t-τ ) 1 [i, j] 2 g (t) 1 [i, j] 2 = Õ ( √ ητ ) + Õ ητ 2 , g (t) 2i 2 -g (t-τ ) 2i 2 g (t) 2i 2 = Õ ( √ ητ ) + Õ ητ 2 . 
( ) Equipped with these lemmas, now let's prove eq. ( 29). For any t ∈ [t inc , T 1 ), by Lemma 18, we know that M (t) 1 , M (t) 2 ≤ Õ 1 √ d , and that G (t) 1 ≤ Õ 1 √ d , G (t) 2 ≤ Õ √ d . At the end of the proof for this lemma, we will show that T 1 = Θ 1 √ dη . Then we can pick H := 1 1-β1 log d ηξ 2 and apply Lemma 16 and Corollary 2 to get that, w.h.p., for all t ∈ [t inc , T 1 ) and ∀i, j ∈ [d], eq. ( 27) can be written as ∆W (t) 1 [i, j] = -η t (1 -β 1 ) H τ =0 β τ 1 g (t-τ ) 1 [i, j] + (t) 1n [i, j] (1 -β 2 ) H τ =0 β τ 2 g (t-τ ) 1 [i, j] 2 + (t) 1d [i, j] + ξ , ∆w (t) 2i = -η t (1 -β 1 ) H τ =0 β τ 1 g (t-τ ) 2i + (t) 2n,i (1 -β 2 ) H τ =0 β τ 2 g (t-τ ) 2i 2 + (t) 2d,i + ξ , ( ) where ∀i, j ∈ [d], (t) 1n [i, j] , (t) 1d [i, j] , (t) 2n,i , (t) 2d,i ≤ Õ(ηξ 2 ). Let's first look at the update of W (t) 1 [i, j]. For t in the first phase, we write the RHS of eq. ( 32) as (1 -β 1 ) H τ =0 β τ 1 g (t-τ ) 1 [i, j] + (t) 1n [i, j] (1 -β 2 ) H τ =0 β τ 2 g (t-τ ) 1 [i, j] 2 + (t) 1d [i, j] + ξ = (1 -β 1 )g (t) 1 [i, j] H τ =0 β τ 1 + (1 -β 1 ) H τ =0 β τ 1 g (t-τ ) 1 [i, j] -g (t) 1 [i, j] + (t) 1n [i, j] (1 -β 2 ) g (t) 1 [i, j] 2 H τ =0 β τ 2 + (1 -β 2 ) H τ =0 β τ 2 g (t-τ ) 1 [i, j] 2 -g (t) 1 [i, j] 2 + (t) 1d [i, j] + ξ := g (t) 1 [i, j](1 -β H+1 1 ) + e (t) 1n [i, j] + (t) 1n [i, j] g (t) 1 [i, j] 2 (1 -β H+1 2 ) + e (t) 1d [i, j] + (t) 1d [i, j] + ξ , where e (t) 1n [i, j] := (1 -β 1 ) H τ =0 β τ 1 g (t-τ ) 1 [i, j] -g (t) 1 [i, j] , e (t) 1d [i, j] := (1 -β 2 ) H τ =0 β τ 2 g (t-τ ) 1 [i, j] 2 -g (t) 1 [i, j] 2 . We have already shown that (t) 1n [i, j] , (t) 1d [i, j] ≤ Õ(ηξ 2 ). By Lemma 19, we have that ∀i, j ∈ [d], e (t) 1n [i, j] ≤ (1 -β 1 ) H τ =0 β τ 1 g (t-τ ) 1 [i, j] -g (t) 1 [i, j] ≤ g (t) 1 [i, j] Õ ( √ η) (1 -β 1 ) H τ =0 β τ 1 τ = g (t) 1 [i, j] Õ ( √ η) . Similarly, we have ∀i, j ∈ [d], e (t) 1d [i, j] ≤ g (t) 1 [i, j] 2 Õ ( √ η) (1 -β 2 ) H τ =0 β τ 1 τ + g (t) 1 [i, j] 2 Õ ( √ η) (1 -β 2 ) H τ =0 β τ 1 τ 2 = g (t) 1 [i, j] 2 Õ ( √ η) . 
By Lemma 18, we know that g (t) 1 [i, j] = Ω √ η . Then we have that ∀i, j ∈ [d] : (t) 1n [i, j] ≤ Õ(ηξ 2 ) ≤ Õ ( √ η) g (t) 1 [i, j] , (t) 1d [i, j] ≤ Õ ( √ η) ξ 2 . Therefore by Lemma 33 in Appendix G, we have g (t) 1 [i, j] 1 -β H+1 1 + e (t) 1n [i, j] + (t) 1n [i, j] g (t) 1 [i, j] 2 1 -β H+1 2 + e (t) 1d [i, j] + (t) 1d [i, j] + ξ = 1 -β H+1 1 1 -β H+1 2 sign g (t) 1 [i, j] + ẽ(t) 1 [i, j] , where ẽ(t) 1 [i, j] = Õ √ η . Since β ∈ (0, 1), we know that log β ≤ β -1 < 0. Then our choice of H gives us H = 1 1-β1 log d ηξ 2 ≥ log ηξ 2 d log β1 and H > 1 1-β2 log d ηξ 2 ≥ log ηξ 2 d log β2 , which implies that β H 1 , β H 2 ≤ ηξ 2 /d. Hence for t ≥ t inc > H, η t 1-β H+1 1 √ 1-β H+1 2 = η √ 1-β t+1 2 √ 1-β H+1 2 1-β H+1 1 1-β t+1 1 = η(1 ± O(η)). Combining all of the above yields that W (t+1) 1 [i, j] = W (t) 1 [i, j] -η t (1 -β 1 ) t τ =0 β τ 1 g (t-τ ) 1 [i, j] (1 -β 2 ) t τ =0 β τ 2 g (t-τ ) 1 [i, j] 2 + ξ = W (t) 1 [i, j] -η t 1 -β H+1 1 1 -β H+1 2 sign g (t) 1 [i, j] + ẽ(t) 1 [i, j] = W (t) 1 [i, j] -η sign g (t) 1 [i, j] + e (t) 1 [i, j] , where e (t) 1 [i, j] = Õ √ η . The proof for w (t) 2i is similar. So far we have successfully proved eq. ( 29). By sign ∆W (t) 1 [i, j] = sign ∆w (t) 2i = sign w (0) 2i in Lemma 18, we know that sign -g (t) 1 [i, j] = sign -g (t) 2i = sign w (0) 2i , which gives us ∀i, j ∈ [d] : w (t) 2i = sign w (0) 2i η (t -t inc ) + R (t) 2i , W (t) 1 [i, j] = sign w (0) 2i η (t -t inc ) + R (t) 1 [i, j], where R (t) 1 [i,j] η(t-tinc) = Õ √ η + W (t inc ) 1 [i,j] η(t-tinc) and R (t) 2i η(t-tinc) = Õ √ η + w (t inc ) 2i η(t-tinc) . Now it suffices to show that ∀i, j ∈ [d] : w (tinc) 2i ≤ O 1 d α , W (tinc) 1 [i, j] ≤ O 1 d α , which is implied by Lemma 17. Finally to complete the proof, we show that T 1 = Θ 1 √ dη . When t = T 1 , we have ∀j ∈ [d] : d i=1 w (T1) 2i W (T1) 1 [i, j] = Θ(1) . Combining with the above results, we know that dη 2 (T 1 -t inc ) 2 = Θ(1), i.e. η(T 1 -t inc ) = Θ 1 √ d . 
In Section D.5, we will prove that $t_{\mathrm{inc}} = \Theta\left(\frac{1}{\eta d^{\frac{3}{2}\alpha+1}}\right)$. Then we have $T_1 = \Theta\left(\frac{1}{\sqrt d\,\eta}\right)$.
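The sign-descent approximation established in eq. (29) — each bias-corrected Adam step is $-\eta\,(\mathrm{sign}(g) + \tilde O(\sqrt\eta))$ once the gradient sign is stable — can be checked in the stationary-gradient limit (illustrative values for all constants):

```python
# After many steps with a constant-sign gradient, the Adam step converges to
# -eta * m / (sqrt(v) + xi) ≈ -eta * sign(g), independent of |g|.
eta, beta1, xi = 1e-3, 0.9, 1e-12
beta2 = beta1 ** 2  # as in Lemma 14
g = -0.05           # gradient with constant sign and magnitude
m = v = 0.0
for t in range(200):  # long enough for the EMAs to saturate
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g ** 2
    eta_t = eta * (1 - beta2 ** (t + 1)) ** 0.5 / (1 - beta1 ** (t + 1))
step = -eta_t * m / (v ** 0.5 + xi)
assert abs(step - eta) < 1e-6  # step ≈ -eta * sign(g) = +eta
```

The magnitude $|g| = 0.05$ cancels between numerator and denominator, which is exactly why the trajectory grows at rate $\eta$ per step in the first phase regardless of how small the gradients are.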

D.3 PROOF OF LEMMA 16

For certain t and H, we write eq. ( 27) as ∆W (t) 1 [i, j] = -η t (1 -β 1 ) H τ =0 β τ 1 g (t-τ ) 1 [i, j] + (t) 1n [i, j] (1 -β 2 ) H τ =0 β τ 2 g (t-τ ) 1 [i, j] 2 + (t) 1d [i, j] + ξ , ∆w (t) 2i = -η t (1 -β 1 ) H τ =0 β τ 1 g (t-τ ) 2i + (t) 2n,i (1 -β 2 ) H τ =0 β τ 2 g (t-τ ) 2i 2 + (t) 2d,i + ξ , where (t) 1n [i, j] := (1 -β 1 ) t τ =H+1 β τ 1 g (t-τ ) 1 [i, j] :=q (t) 1n [i,j] +r (t) 1n [i, j], (t) 2n,i := (1 -β 1 ) t τ =H+1 β τ 1 g (t-τ ) 2i :=q (t) 2n,i +r (t) 2n,i , (t) 1d [i, j] := (1 -β 2 ) t τ =H+1 β τ 2 g (t-τ ) 1 [i, j] 2 :=q (t) 1d [i,j] +r (t) 1d [i, j], (t) 2d,i := (1 -β 2 ) t τ =H+1 β τ 2 g (t-τ ) 2i 2 :=q (t) 2d,i +r (t) 2d,i , r (t) 1n [i, j], r (t) 1d [i, j], r (t) 2n,i , r 2d,i are defined in eq. ( 28). Since β 2 = β 2 1 < β 1 , then if we pick H ≥ 1 1-β1 log max G (t) 1 ,G (t) 2 , G (t) 1 2 , G (t) 2 2 ηξ 2 , we can get that H ≥ 1 1-β1 log G (t) 1 ηξ 2 , H ≥ 1 1-β2 log G (t) 1 2 ηξ 2 , H ≥ 1 1-β1 log G (t) 2 ηξ 2 , H ≥ 1 1-β2 log G (t) 2 2 ηξ 2 . Hence we can apply Lemma 32 in Appendix G to get that q (t) 1n [i, j] , q (t) 1d [i, j] , q (t) 2n,i , q (t) 2d,i ≤ ηξ 2 . Pick T in Lemma 31 as of order Õ 1 √ dη . By Lemma 31, we have with probability at least 1 -1 d , for all t ≤ T , ∀τ ≤ t and ∀i, j ∈ [d], Dg (τ ) 1 [i, j] = g(τ) 1 [i, j] -g (τ ) 1 [i, j] ≤ Õ   d 3 M (t) 1 M (t) 2 2 σ d 1/2 η   + Õ   M (t) 2 σ d 3/2 η   := D (t) 1 , Dg (τ ) 2i = g(τ) 2i -g (τ ) 2i ≤ Õ   d 4 M (t) 1 2 M (t) 2 σ d 1/2 η   + Õ   dM (t) 1 σ d 3/2 η   := D (t) 2 . Plugging into eq. ( 28) gives us r (t) 1n [i, j] ≤ (1 -β 1 ) t τ =0 β t-τ 1 Dg (τ ) 1 [i, j] ≤ O D (t) 1 , r (t) 1d [i, j] ≤ (1 -β 2 ) t τ =0 β t-τ 2 2g (τ ) 1 [i, j]Dg (τ ) 1 [i, j] + Dg (τ ) 1 [i, j] 2 ≤ O D (t) 1 G (t) 1 + D (t) 1 2 , r (t) 2n,i ≤ (1 -β 1 ) t τ =0 β t-τ 1 Dg (τ ) 2i ≤ O D (t) 2 , r (t) 2d,i ≤ (1 -β 2 ) t τ =0 β t-τ 2 2g (τ ) 2i Dg (τ ) 2i + Dg (τ ) 2i 2 ≤ O D (t) 2 G (t) 2 + D (t) 2 2 . 
D.4 PROOF OF COROLLARY 2

Since $G_1^{(t)} \le \tilde O\big(\frac{1}{\sqrt d}\big)$ and $G_2^{(t)} \le \tilde O(\sqrt d)$, the choice $H := \frac{1}{1-\beta_1}\log\frac{d}{\eta\xi^2}$ is at least $\frac{1}{1-\beta_1}\log\frac{\max\{G_1^{(t)},G_2^{(t)},(G_1^{(t)})^2,(G_2^{(t)})^2\}}{\eta\xi^2}$. By $M_1^{(t)}, M_2^{(t)} \le \tilde O\big(\frac{1}{\sqrt d}\big)$, $G_1^{(t)} \le \tilde O\big(\frac{1}{\sqrt d}\big)$, $G_2^{(t)} \le \tilde O(\sqrt d)$ and the assumption $\sigma \le \frac{\eta^{3/2}\xi^2}{d^{13/4}}$, we get that $D_1^{(t)} \le \tilde O\big(d^{7/4}\sigma\eta^{-1/2}\big)$ and $D_2^{(t)} \le \tilde O\big(d^{11/4}\sigma\eta^{-1/2}\big)$, which yields
$$\forall i,j\in[d]:\quad \big|\epsilon_{1n}^{(t)}[i,j]\big|,\ \big|\epsilon_{1d}^{(t)}[i,j]\big|,\ \big|\epsilon_{2n,i}^{(t)}\big|,\ \big|\epsilon_{2d,i}^{(t)}\big| \le \tilde O(\eta\xi^2).$$

D.5 PROOF OF LEMMA 17

The proof is based on the following two lemmas.

Lemma 20. Under Assumptions 1 and 2, with probability at least $1-\frac{1}{d^{\frac\alpha2-1}}$, for every $1\le i\le d$, $\big|w_{2i}^{(0)}\big| \ge \frac{\sqrt\pi}{d^{\frac32\alpha}}$; and with probability at least $1-\delta$ for any given $\delta>0$, $\forall i,j\in[d]$: $\big|w_{2i}^{(0)}\big| \le \sqrt{\frac{2}{d^{2\alpha}}\log\frac{2d}{\delta}}$ and $\big|W_1^{(0)}[i,j]\big| \le \sqrt{\frac{2}{d^{4\alpha}}\log\frac{2d^2}{\delta}}$.

Lemma 21. Under Assumptions 1, 2 and 3, suppose $\sigma \le \frac{\eta^{3/2}\xi^2}{d^{13/4}}$. Pick $\beta_2 = \beta_1^2$, $\xi\in(0,1)$, $\eta < \frac14$. Consider any time point $t \le \tilde O\big(\frac{1}{\sqrt d\eta}\big)$. If $\forall \tau\le t$, $\forall i,j\in[d]$: $|W_1^{(\tau)}[i,j]| \le \tilde O\big(\frac{1}{\sqrt d}\big)$, $|w_{2i}^{(\tau)}| \le \tilde O\big(\frac{1}{\sqrt d}\big)$ and $|g_1^{(\tau)}[i,j]| \le \tilde O\big(\frac{1}{\sqrt d}\big)$, $|g_{2i}^{(\tau)}| \le \tilde O(\sqrt d)$, then
$$\big|\Delta W_1^{(t)}[i,j]\big| \le \tilde O(\eta),\qquad \big|\Delta w_{2i}^{(t)}\big| \le \tilde O(\eta),$$
where the $\tilde O$ notation depends on $H = \frac{1}{1-\beta_1}\log\frac{d}{\eta\xi^2}$. Furthermore, if for certain $i,j\in[d]$ Condition 1 (resp. Condition 2) is satisfied, then
$$\operatorname{sign}\big(\Delta W_1^{(t)}[i,j]\big) = -s_1^{(t)}[i,j],\quad \big|\Delta W_1^{(t)}[i,j]\big| = \Theta(\eta)\qquad \Big(\text{resp. } \operatorname{sign}\big(\Delta w_{2i}^{(t)}\big) = -s_{2i}^{(t)},\quad \big|\Delta w_{2i}^{(t)}\big| = \Theta(\eta)\Big).$$

Now we prove Lemma 17. Define $t_d := \inf\big\{t: \exists i,j: |W_1^{(t)}[i,j]| > \frac1d \text{ or } |w_{2i}^{(t)}| > \frac1d\big\}$. We want to find a time point $t_{\mathrm{inc}}$ before $t_d$ for which the lemma holds. During the period $t < t_d$, we have $\forall j\in[d]$, $E_j^{(t)} = -\Theta(1)$ (which means $t_d < T_1$), and therefore for all $i,j\in[d]$, $|g_1^{(t)}[i,j]| \le \frac1d$ and $|g_{2i}^{(t)}| \le 1$. Then we can use Lemma 21 to get that for $t \le \min\big\{t_d, \frac{1}{\sqrt d\eta}\big\}$, we have $|\Delta W_1^{(t)}[i,j]| \le \tilde O(\eta)$ and $|\Delta w_{2i}^{(t)}| \le \tilde O(\eta)$. Hence $t_d \ge \Omega\big(\frac{1}{\eta d}\big)$. Define $t_{\mathrm{sign}} = \inf\big\{t < \min\{t_d, \frac{1}{\sqrt d\eta}\}: \exists i\in[d]: |w_{2i}^{(t)}| \le \frac{1}{d^{\frac32\alpha}}\big\}$. By Lemma 20, w.h.p. $\forall i\in[d]$: $|w_{2i}^{(0)}| \ge \frac{\sqrt\pi}{d^{\frac32\alpha}}$; combining with $|\Delta w_{2i}^{(t)}| \le \tilde O(\eta)$ gives us that w.h.p.
$$t_{\mathrm{sign}} \ge \frac{\sqrt\pi-1}{d^{\frac32\alpha}}\Big/\tilde O(\eta) = \Omega\Big(\frac{1}{\eta d^{\frac32\alpha}}\Big).$$
Now let us analyze the behavior of $W_1$ during the period $t < t_{\mathrm{sign}}$. Consider any $i,j\in[d]$. By definition, $\operatorname{sign}(w_{2i}^{(t)}) = \operatorname{sign}(w_{2i}^{(0)})$. Note that $E_j^{(t)} = -\Theta(1)$; then we have $\operatorname{sign}(g_1^{(t)}[i,j]) = -\operatorname{sign}(w_{2i}^{(0)})$ and $|g_1^{(t)}[i,j]| = \Omega\big(\frac{1}{d^{\frac32\alpha}}\big) = \Omega(\xi)$ by our choice of $\xi$.
Then we know that Condition 1 is satisfied with $s_1^{(t)}[i,j] = -\operatorname{sign}(w_{2i}^{(0)})$ (for all $H < t \le t_{\mathrm{sign}}$), which by Lemma 21 yields $\operatorname{sign}(\Delta W_1^{(t)}[i,j]) = \operatorname{sign}(w_{2i}^{(0)})$ and $|\Delta W_1^{(t)}[i,j]| = \Theta(\eta)$. Lemma 20 tells us that w.h.p. $\forall i,j\in[d]$: $|W_1^{(0)}[i,j]| = \tilde O\big(\frac{1}{d^{2\alpha}}\big)$. For any $i,j$, if initially $\operatorname{sign}(W_1^{(0)}[i,j]) = \operatorname{sign}(w_{2i}^{(0)})$, then for the following steps before $t_{\mathrm{sign}}$ we will have $\operatorname{sign}(W_1^{(t)}[i,j]) = \operatorname{sign}(w_{2i}^{(0)})$. If initially $\operatorname{sign}(W_1^{(0)}[i,j]) \ne \operatorname{sign}(w_{2i}^{(0)})$, then after at most $t_0 = \tilde O\big(\frac{1}{\eta d^{2\alpha}}\big)$ steps, $W_1[i,j]$ will flip its sign. Note that $t_0 = \tilde O\big(\frac{1}{\eta d^{2\alpha}}\big)$ is smaller than $t_{\mathrm{sign}}$. Hence we have shown that at some time point $t_0$, we have $\forall i,j\in[d]$: $\operatorname{sign}(W_1^{(t)}[i,j]) = \operatorname{sign}(w_{2i}^{(t)}) = \operatorname{sign}(w_{2i}^{(0)})$.

Now we analyze the period $t \ge t_0$. When $t_0 < t \le t_{\mathrm{sign}}$, we still have $\operatorname{sign}(\Delta W_1^{(t)}[i,j]) = \operatorname{sign}(w_{2i}^{(0)})$ and $|\Delta W_1^{(t)}[i,j]| = \Theta(\eta)$. Combining these two with the fact $\operatorname{sign}(W_1^{(t_0)}[i,j]) = \operatorname{sign}(w_{2i}^{(0)})$, we know that for all $t\in[t_0, t_{\mathrm{sign}}]$, $\operatorname{sign}(W_1^{(t)}[i,j]) = \operatorname{sign}(w_{2i}^{(0)})$ and $\forall i,j\in[d]$: $|W_1^{(t+1)}[i,j]| = |W_1^{(t)}[i,j]| + \Theta(\eta)$. Then at a certain step $t_{\mathrm{inc}}$ satisfying $t_{\mathrm{inc}} = t_0 + \Theta\big(\frac{1}{\eta d^{\frac32\alpha+1}}\big) \in (H, t_{\mathrm{sign}})$, we will have $\forall\, t_{\mathrm{inc}}-H \le \tau \le t_{\mathrm{inc}}$, $\forall i,j\in[d]$: $|W_1^{(\tau)}[i,j]| = \Theta\big(\frac{1}{d^{\frac32\alpha+1}}\big)$, and therefore
$$\big|g_{2i}^{(\tau)}\big| = \Big|\sum_{j=1}^{d}W_1^{(\tau)}[i,j]E_j^{(\tau)}\Big| = \sum_{j=1}^{d}\big|W_1^{(\tau)}[i,j]E_j^{(\tau)}\big| = \Theta\Big(\frac{1}{d^{\frac32\alpha}}\Big) = \Omega(\xi).$$
For $t \le t_{\mathrm{inc}}$, we have $\forall i,j\in[d]$: $|W_1^{(t)}[i,j]| = O\big(\frac{1}{d^{\frac32\alpha+1}}\big)$. Since $t_{\mathrm{inc}} < t_{\mathrm{sign}}$, we have $|w_{2i}^{(t_{\mathrm{inc}})}| = \Omega\big(\frac{1}{d^{\frac32\alpha}}\big)$. For $t \le t_{\mathrm{inc}}$, note that $|\Delta w_{2i}^{(t)}| \le \tilde O(\eta)$ and $t_{\mathrm{inc}} = t_0 + \Theta\big(\frac{1}{\eta d^{\frac32\alpha+1}}\big) = \Theta\big(\frac{1}{\eta d^{\frac32\alpha+1}}\big)$; combining with the upper bound in Lemma 20 yields
$$\big|w_{2i}^{(t)}\big| \le \big|w_{2i}^{(0)}\big| + t_{\mathrm{inc}}\,\tilde O(\eta) \le \tilde O\Big(\frac{1}{d^{\frac32\alpha+1}}\Big) \le O\Big(\frac{1}{d^{\alpha}}\Big).$$
Moreover, $\forall\, t_{\mathrm{inc}}-H \le \tau \le t_{\mathrm{inc}}$, $\forall i\in[d]$: $\operatorname{sign}(g_{2i}^{(\tau)}) = -\operatorname{sign}(w_{2i}^{(0)})$. Then Condition 2 is satisfied with $s_{2i}^{(t)} = -\operatorname{sign}(w_{2i}^{(0)})$ for $t = t_{\mathrm{inc}}$.
In the analysis of $g_1^{(t)}[i,j]$, we have already shown that Condition 1 is satisfied for all $t \le t_{\mathrm{sign}}$ (and thus for $t = t_{\mathrm{inc}}$), which completes the proof.

D.6 PROOF OF LEMMA 20

Since for $X \sim \mathcal N(0,\sigma^2)$ we have $\mathbb P(|X| \le t) \le \frac{2t}{\sqrt{2\pi}\,\sigma}$, for a fixed $i$,
$$\mathbb P\Big(\big|w_{2i}^{(0)}\big| \le \frac{\sqrt\pi}{d^{\frac32\alpha}}\Big) \le \frac{2\sqrt\pi/d^{\frac32\alpha}}{\sqrt{2\pi}\cdot\sqrt{2/d^{2\alpha}}} = \frac{1}{d^{\frac\alpha2}}.$$
Then by a union bound, with probability at least $1-\frac{1}{d^{\frac\alpha2-1}}$, for every $1\le i\le d$, $\big|w_{2i}^{(0)}\big| \ge \frac{\sqrt\pi}{d^{\frac32\alpha}}$. As for the upper bounds, using the Gaussian tail bound and a union bound, with probability at least $1-\delta$,
$$\forall i,j\in[d]:\quad \big|w_{2i}^{(0)}\big| \le \sqrt{\frac{2}{d^{2\alpha}}\log\frac{2d}{\delta}},\qquad \big|W_1^{(0)}[i,j]\big| \le \sqrt{\frac{2}{d^{4\alpha}}\log\frac{2d^2}{\delta}}.$$
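The anti-concentration step above bounds the probability mass near zero by the peak density times the interval length. A quick Monte Carlo sanity check of the inequality $\mathbb P(|X|\le t) \le \frac{2t}{\sqrt{2\pi}\sigma}$ (with illustrative values of $\sigma$ and $t$, not the ones dictated by Assumption 2):

```python
import numpy as np

rng = np.random.default_rng(0)
sigma, t = 2.0, 2.0   # illustrative values; the proof uses t = sqrt(pi)/d^{3a/2}
X = rng.normal(0.0, sigma, size=1_000_000)

# P(|X| <= t) is the integral of the density over [-t, t], which is at most
# (interval length) * (max density) = 2t * 1/(sqrt(2*pi)*sigma).
empirical = np.mean(np.abs(X) <= t)
bound = 2 * t / (np.sqrt(2 * np.pi) * sigma)
assert empirical <= bound
print(f"P(|X|<=t) ~ {empirical:.4f} <= bound {bound:.4f}")
```

The bound is tight as $t\to 0$ (the density is flattest at its peak), which is the regime the proof uses.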

D.7 PROOF OF LEMMA 21

We analyze the order of magnitude of $\Delta W_1^{(t)}[i,j]$; the analysis of $\Delta w_{2i}^{(t)}$ is similar. Consider $t \le \tilde O\big(\frac{1}{\sqrt d\eta}\big)$. By assumption, $M_1^{(t)}, M_2^{(t)} \le \tilde O\big(\frac{1}{\sqrt d}\big)$, $G_1^{(t)} \le \tilde O\big(\frac{1}{\sqrt d}\big)$, $G_2^{(t)} \le \tilde O(\sqrt d)$, and $\sigma \le \frac{\eta^{3/2}\xi^2}{d^{13/4}}$. Hence we can pick $H := \frac{1}{1-\beta_1}\log\frac{d}{\eta\xi^2}$ and apply Lemma 16 and Corollary 2 to get that, w.h.p., for all $t \le \tilde O\big(\frac{1}{\sqrt d\eta}\big)$ and $\forall i,j\in[d]$, eq. (27) can be written as
$$\Delta W_1^{(t)}[i,j] = -\eta_t\,\frac{(1-\beta_1)\sum_{\tau=0}^{H}\beta_1^{\tau}g_1^{(t-\tau)}[i,j] + \epsilon_{1n}^{(t)}[i,j]}{\sqrt{(1-\beta_2)\sum_{\tau=0}^{H}\beta_2^{\tau}\big(g_1^{(t-\tau)}[i,j]\big)^2 + \epsilon_{1d}^{(t)}[i,j]}+\xi},\qquad
\Delta w_{2i}^{(t)} = -\eta_t\,\frac{(1-\beta_1)\sum_{\tau=0}^{H}\beta_1^{\tau}g_{2i}^{(t-\tau)} + \epsilon_{2n,i}^{(t)}}{\sqrt{(1-\beta_2)\sum_{\tau=0}^{H}\beta_2^{\tau}\big(g_{2i}^{(t-\tau)}\big)^2 + \epsilon_{2d,i}^{(t)}}+\xi},\tag{33}$$
where $\forall i,j\in[d]$: $\big|\epsilon_{1n}^{(t)}[i,j]\big|, \big|\epsilon_{1d}^{(t)}[i,j]\big|, \big|\epsilon_{2n,i}^{(t)}\big|, \big|\epsilon_{2d,i}^{(t)}\big| \le \tilde O(\eta\xi^2)$.

On one hand, using $\big|\epsilon_{1n}^{(t)}[i,j]\big|, \big|\epsilon_{1d}^{(t)}[i,j]\big| \le \tilde O(\eta\xi^2)$, $\beta_2 = \beta_1^2$, and $\sqrt{x+y} \ge \sqrt x - \sqrt{|y|}$ when $x\ge 0$, $x+y\ge 0$, we get from eq. (33) that
$$\big|\Delta W_1^{(t)}[i,j]\big| \le \eta_t\,\frac{(1-\beta_1)\sum_{\tau=0}^{H}\beta_1^{\tau}\big|g_1^{(t-\tau)}[i,j]\big| + \tilde O(\eta\xi^2)}{\sqrt{(1-\beta_2)\sum_{\tau=0}^{H}\beta_2^{\tau}\big(g_1^{(t-\tau)}[i,j]\big)^2} - \tilde O(\sqrt\eta\,\xi) + \xi}
\overset{(i)}{\le} \eta_t\,\frac{(1-\beta_1)\sqrt{H+1}\sqrt{\sum_{\tau=0}^{H}\beta_2^{\tau}\big(g_1^{(t-\tau)}[i,j]\big)^2} + \tilde O(\eta\xi^2)}{\sqrt{(1-\beta_2)\sum_{\tau=0}^{H}\beta_2^{\tau}\big(g_1^{(t-\tau)}[i,j]\big)^2} + \xi/2} \le O\big(\sqrt H\,\eta\big) = \tilde O(\eta),$$
where $(i)$ uses the Cauchy–Schwarz inequality for the numerator.

On the other hand, when $\operatorname{sign}\big(g_1^{(t-H)}[i,j]\big) = \operatorname{sign}\big(g_1^{(t-H+1)}[i,j]\big) = \cdots = \operatorname{sign}\big(g_1^{(t)}[i,j]\big) = s_1^{(t)}[i,j]$, we have
$$\operatorname{sign}\Big(\sum_{\tau=0}^{H}\beta_1^{\tau}g_1^{(t-\tau)}[i,j]\Big) = s_1^{(t)}[i,j],\qquad
\Big|\sum_{\tau=0}^{H}\beta_1^{\tau}g_1^{(t-\tau)}[i,j]\Big| \ge \sqrt{\sum_{\tau=0}^{H}\beta_2^{\tau}\big(g_1^{(t-\tau)}[i,j]\big)^2}.$$
If we further have $(1-\beta_1)\big|\sum_{\tau=0}^{H}\beta_1^{\tau}g_1^{(t-\tau)}[i,j]\big| \ge \Omega(\xi)$, then combining with $\big|\epsilon_{1n}^{(t)}[i,j]\big| \le \tilde O(\eta\xi^2) < \xi$ we will get
$$\operatorname{sign}\big(\Delta W_1^{(t)}[i,j]\big) = -\operatorname{sign}\Big(\sum_{\tau=0}^{H}\beta_1^{\tau}g_1^{(t-\tau)}[i,j] + \epsilon_{1n}^{(t)}[i,j]\Big) = -\operatorname{sign}\Big(\sum_{\tau=0}^{H}\beta_1^{\tau}g_1^{(t-\tau)}[i,j]\Big) = -s_1^{(t)}[i,j].$$
Using $\sqrt{x+y} \le \sqrt{|x|}+\sqrt{|y|}$, we obtain
$$\big|\Delta W_1^{(t)}[i,j]\big| \ge \eta_t\,\frac{(1-\beta_1)\big|\sum_{\tau=0}^{H}\beta_1^{\tau}g_1^{(t-\tau)}[i,j]\big| - \tilde O(\eta\xi^2)}{\sqrt{(1-\beta_2)\sum_{\tau=0}^{H}\beta_2^{\tau}\big(g_1^{(t-\tau)}[i,j]\big)^2} + \tilde O(\sqrt\eta\,\xi) + \xi}
\ge \eta_t\,\frac{\frac{1-\beta_1}{2}\big|\sum_{\tau=0}^{H}\beta_1^{\tau}g_1^{(t-\tau)}[i,j]\big|}{2\max\Big\{\sqrt{(1-\beta_2)\sum_{\tau=0}^{H}\beta_2^{\tau}\big(g_1^{(t-\tau)}[i,j]\big)^2},\ \frac32\xi\Big\}} = \Omega(\eta).$$
Together with the upper bound, this completes the proof.
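The two bounds above say that the history-truncated Adam step behaves like a sign step: its magnitude is at most $\tilde O(\eta)$, and exactly $\Theta(\eta)$ once the last $H$ gradients share a sign and dominate $\xi$. A minimal numeric sketch of this normalized update (hypothetical gradient histories; $\beta_2 = \beta_1^2$ as in the lemma):

```python
import numpy as np

def adam_step_size(grads, eta=1e-2, beta1=0.9, xi=1e-8):
    """Magnitude of one history-truncated Adam step, as in eq. (27):
    eta * |(1-b1) sum_tau b1^tau g_{t-tau}| /
          (sqrt((1-b2) sum_tau b2^tau g_{t-tau}^2) + xi),  with b2 = b1^2."""
    beta2 = beta1 ** 2
    g = np.asarray(grads, dtype=float)[::-1]   # g[0] = most recent gradient
    w = np.arange(len(g))
    num = (1 - beta1) * np.sum(beta1 ** w * g)
    den = np.sqrt((1 - beta2) * np.sum(beta2 ** w * g ** 2)) + xi
    return eta * abs(num) / den

# Sign-consistent histories of wildly different scales all give a step of
# order eta (scale-invariance), while a sign-alternating history gives a
# much smaller step because the numerator nearly cancels.
steps = [adam_step_size(s * np.ones(50)) for s in (1e-4, 1.0, 1e4)]
alt = adam_step_size((-1.0) ** np.arange(50))
print(steps, alt)
```

This is only an illustration of the mechanism the proof formalizes; the lemma's $\Theta(\eta)$ constant additionally depends on $H$ and the $\epsilon$ error terms.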

D.8 PROOF OF LEMMA 18

The proof is based on the following lemma, which gives a coarse analysis of the magnitude of the weights and their per-step increments during the first phase.

Lemma 22. Under Assumptions 1, 2 and 3, suppose $\sigma \le \frac{\eta^{3/2}\xi^2}{d^{13/4}}$. Pick $\xi \le \min\big\{\frac{\eta}{d^{3\alpha-1}}, \frac{1}{d^{\frac32\alpha}}\big\}$. For $t_{\mathrm{inc}}$ in Lemma 17, we have w.h.p. for all $t_{\mathrm{inc}} \le t \le T_1$, $\forall i,j\in[d]$:
$$\operatorname{sign}\big(\Delta W_1^{(t)}[i,j]\big) = \operatorname{sign}\big(\Delta w_{2i}^{(t)}\big) = \operatorname{sign}\big(w_{2i}^{(0)}\big),\qquad \big|\Delta W_1^{(t)}[i,j]\big| = \Theta(\eta),\qquad \big|\Delta w_{2i}^{(t)}\big| = \Theta(\eta),$$
$$\operatorname{sign}\big(W_1^{(t)}[i,j]\big) = \operatorname{sign}\big(w_{2i}^{(t)}\big) = \operatorname{sign}\big(w_{2i}^{(0)}\big),\qquad \big|W_1^{(t)}[i,j]\big| = \tilde O\Big(\frac{1}{\sqrt d}\Big),\qquad \big|w_{2i}^{(t)}\big| = \tilde O\Big(\frac{1}{\sqrt d}\Big).$$
In particular, at the end of the first phase ($t = T_1$), we have $\forall i,j\in[d]$: $\big|w_{2i}^{(T_1)}\big| = \Theta\big(\frac{1}{\sqrt d}\big)$ and $\big|W_1^{(T_1)}[i,j]\big| = \Theta\big(\frac{1}{\sqrt d}\big)$.

Now we return to the proof of Lemma 18. For $t_{\mathrm{inc}} \le t < T_1$, since $E_j^{(t)} = \big(W_2^{(t)}W_1^{(t)}\big)_j - A_j = \sum_{i=1}^{d}w_{2i}^{(t)}W_1^{(t)}[i,j] - A_j$, we have
$$\Delta E_j^{(t)} := E_j^{(t+1)} - E_j^{(t)} = \sum_{i=1}^{d}\Big(w_{2i}^{(t+1)}W_1^{(t+1)}[i,j] - w_{2i}^{(t+1)}W_1^{(t)}[i,j] + w_{2i}^{(t+1)}W_1^{(t)}[i,j] - w_{2i}^{(t)}W_1^{(t)}[i,j]\Big) = \sum_{i=1}^{d}\Big(w_{2i}^{(t+1)}\Delta W_1^{(t)}[i,j] + \Delta w_{2i}^{(t)}W_1^{(t)}[i,j]\Big).\tag{34}$$
Combining Lemma 22 and eq. (34) gives us $\forall j\in[d]$: $\Delta E_j^{(t)} > 0$ and
$$\Delta E_j^{(t)} = \sum_{i=1}^{d}\Big(w_{2i}^{(t+1)}\Delta W_1^{(t)}[i,j] + \Delta w_{2i}^{(t)}W_1^{(t)}[i,j]\Big) \le \sum_{i=1}^{d}\tilde O\Big(\eta\,\frac{1}{\sqrt d}\Big) = \tilde O\big(\eta\sqrt d\big).\tag{35}$$
Let us first analyze $g_1^{(t)}[i,j]$. Note that
$$\Delta g_1^{(t)}[i,j] = w_{2i}^{(t+1)}E_j^{(t+1)} - w_{2i}^{(t+1)}E_j^{(t)} + w_{2i}^{(t+1)}E_j^{(t)} - w_{2i}^{(t)}E_j^{(t)} = w_{2i}^{(t+1)}\Delta E_j^{(t)} + \Delta w_{2i}^{(t)}E_j^{(t)},\tag{36}$$
where $\operatorname{sign}\big(w_{2i}^{(t+1)}\Delta E_j^{(t)}\big) = \operatorname{sign}\big(w_{2i}^{(0)}\big)$ while $\operatorname{sign}\big(\Delta w_{2i}^{(t)}E_j^{(t)}\big) = -\operatorname{sign}\big(w_{2i}^{(0)}\big)$. Now we analyze the sign of $\Delta g_1^{(t)}[i,j]$ when $t_{\mathrm{inc}} \le t < T_1$. Using $\big|w_{2i}^{(t_{\mathrm{inc}}+1)}\big| = \tilde O\big(\frac{1}{d^\alpha}\big)$ and eq. (35), we get that $\big|w_{2i}^{(t_{\mathrm{inc}}+1)}\Delta E_j^{(t_{\mathrm{inc}})}\big| \le \tilde O\big(\frac{1}{d^\alpha}\cdot\sqrt d\,\eta\big)$, while on the other hand $\big|\Delta w_{2i}^{(t_{\mathrm{inc}})}E_j^{(t_{\mathrm{inc}})}\big| = \Theta(\eta)$.

That means $\operatorname{sign}\big(\Delta g_1^{(t_{\mathrm{inc}})}[i,j]\big) = -\operatorname{sign}\big(w_{2i}^{(0)}\big)$. Noting that $\operatorname{sign}\big(g_1^{(t_{\mathrm{inc}})}[i,j]\big) = -\operatorname{sign}\big(w_{2i}^{(t_{\mathrm{inc}})}\big) = -\operatorname{sign}\big(w_{2i}^{(0)}\big)$, we know that $\big|g_1^{(t)}[i,j]\big|$ increases at $t = t_{\mathrm{inc}}$. In the following steps, $\big|g_1^{(t)}[i,j]\big|$ keeps increasing as long as $\big|\Delta w_{2i}^{(t)}E_j^{(t)}\big| > \big|w_{2i}^{(t+1)}\Delta E_j^{(t)}\big|$. Since $\big|W_1^{(t)}[i,j]\big|$ and $\big|w_{2i}^{(t)}\big|$ keep increasing while $\big|\Delta W_1^{(t)}[i,j]\big|$ and $\big|\Delta w_{2i}^{(t)}\big|$ remain $\Theta(\eta)$, by eq. (35) the trend of $\Delta E_j^{(t)}$ is to increase. On the other hand, $\big|E_j^{(t)}\big|$ keeps decreasing, since $E_j^{(t)} < 0$ while $\Delta E_j^{(t)} > 0$. Then after some time point we will have $\big|\Delta w_{2i}^{(t)}E_j^{(t)}\big| < \big|w_{2i}^{(t+1)}\Delta E_j^{(t)}\big|$, and in the following steps $\big|g_1^{(t)}[i,j]\big|$ has the trend to decrease. In particular, when $t = T_1-1$, we have $\big|E_j^{(t)}\big| = \Theta\big(\sqrt{\eta d}\big)$ and $\big|W_1^{(t)}[i,j]\big| = \Theta\big(\frac{1}{\sqrt d}\big)$, $\big|w_{2i}^{(t)}\big| = \Theta\big(\frac{1}{\sqrt d}\big)$ by Lemma 22, which gives us
$$\Delta E_j^{(t)} = \sum_{i=1}^{d}\Big(w_{2i}^{(t+1)}\Delta W_1^{(t)}[i,j] + \Delta w_{2i}^{(t)}W_1^{(t)}[i,j]\Big) \le \sum_{i=1}^{d}\Theta\Big(\eta\,\frac{1}{\sqrt d}\Big) = \Theta\big(\eta\sqrt d\big).$$
Hence $\big|w_{2i}^{(t+1)}\Delta E_j^{(t)}\big| = \Theta(\eta) > \big|\Delta w_{2i}^{(t)}E_j^{(t)}\big| = \Theta\big(\eta\sqrt{\eta d}\big)$. Therefore we have proved that for $t_{\mathrm{inc}} \le t < T_1$, the trend of $\big|g_1^{(t)}[i,j]\big|$ is to first increase and then decrease. In order to prove $\big|g_1^{(t)}[i,j]\big| = \Omega(\sqrt\eta)$, it suffices to show that $\big|g_1^{(t_{\mathrm{inc}})}[i,j]\big| = \Omega(\sqrt\eta)$ and $\big|g_1^{(T_1)}[i,j]\big| = \Omega(\sqrt\eta)$. When $t = t_{\mathrm{inc}}$,
$$\big|g_1^{(t_{\mathrm{inc}})}[i,j]\big| = \big|w_{2i}^{(t_{\mathrm{inc}})}\big|\cdot\big|E_j^{(t_{\mathrm{inc}})}\big| = \Omega\Big(\frac{1}{d^{\frac32\alpha}}\Big)\cdot\Theta(1) = \Omega(\sqrt\eta).$$
When $t = T_1$,
$$\big|g_1^{(T_1)}[i,j]\big| = \big|w_{2i}^{(T_1)}\big|\cdot\big|E_j^{(T_1)}\big| = \Theta\Big(\frac{1}{\sqrt d}\Big)\cdot\Theta\big(\sqrt{\eta d}\big) = \Theta(\sqrt\eta).$$
As for $g_{2i}^{(t)}$: for every $i\in[d]$, the entries $W_1^{(t)}[i,j]$ for different $j$ have the same sign. Combining with $\forall j\in[d]$: $E_j^{(t)} < 0$ gives us
$$\big|g_{2i}^{(t)}\big| = \Big|\sum_{j=1}^{d}E_j^{(t)}W_1^{(t)}[i,j]\Big| = \sum_{j=1}^{d}\big|E_j^{(t)}W_1^{(t)}[i,j]\big|.$$
Then it suffices to show that for $t_{\mathrm{inc}} \le t < T_1$, $\big|E_j^{(t)}W_1^{(t)}[i,j]\big| = \Omega(\sqrt\eta)$, which can be proven using the same technique as above. Finally, for $\forall \tau\le t$, $\forall i,j\in[d]$, the upper bounds of $\big|W_1^{(\tau)}[i,j]\big|$ and $\big|w_{2i}^{(\tau)}\big|$ are already given in Lemma 22.
As for $g_1^{(\tau)}[i,j]$ and $g_{2i}^{(\tau)}$, we have $\big|g_1^{(\tau)}[i,j]\big| = \big|w_{2i}^{(\tau)}E_j^{(\tau)}\big| = \tilde O\big(\frac{1}{\sqrt d}\big)$ and $\big|g_{2i}^{(\tau)}\big| \le \sum_{j=1}^{d}\big|E_j^{(\tau)}W_1^{(\tau)}[i,j]\big| = \tilde O(\sqrt d)$.

D.9 PROOF OF LEMMA 22

For any $i,j\in[d]$ and any $t$ in the interval $[t_{\mathrm{inc}}, T_1]$, we prove by induction that:
(A) $\big|W_1^{(t)}[i,j]\big| = \tilde O\big(\frac{1}{\sqrt d}\big)$, $\big|w_{2i}^{(t)}\big| = \tilde O\big(\frac{1}{\sqrt d}\big)$;
(B) $\forall \tau\in[t-H, t]$: $\operatorname{sign}\big(W_1^{(\tau)}[i,j]\big) = \operatorname{sign}\big(w_{2i}^{(\tau)}\big) = \operatorname{sign}\big(w_{2i}^{(0)}\big)$;
(C) $\big|g_1^{(t)}[i,j]\big| \ge \Omega(\xi)$, $\big|g_{2i}^{(t)}\big| \ge \Omega(\xi)$.
The base case $t = t_{\mathrm{inc}}$ was already proven in Lemma 17. For $t\in[t_{\mathrm{inc}}, T_1)$, suppose (B) and (C) hold for time $t$ and (A) holds for all $\tau\in[t_{\mathrm{inc}}, t]$. From (A), we get that $\forall \tau\in[t_{\mathrm{inc}}, t]$: $\big|g_1^{(\tau)}[i,j]\big| = \big|w_{2i}^{(\tau)}E_j^{(\tau)}\big| = \tilde O\big(\frac{1}{\sqrt d}\big)$ and $\big|g_{2i}^{(\tau)}\big| \le \sum_{j=1}^{d}\big|E_j^{(\tau)}W_1^{(\tau)}[i,j]\big| = \tilde O(\sqrt d)$. Since for $t < T_1$, $\forall j\in[d]$: $E_j^{(t)} < 0$, from (B) we know that $\forall \tau\in[t-H, t]$: $\operatorname{sign}\big(g_1^{(\tau)}[i,j]\big) = \operatorname{sign}\big(g_{2i}^{(\tau)}\big) = -\operatorname{sign}\big(w_{2i}^{(0)}\big)$. Combining with (C) tells us that Conditions 1 and 2 are satisfied. In Section D.2 we have shown that $T_1 = \Theta\big(\frac{1}{\sqrt d\eta}\big)$. Then for $t\in[t_{\mathrm{inc}}, T_1)$, we can use Lemma 21 to get that $\forall\, t_{\mathrm{inc}}\le\tau\le t$, $\forall i,j\in[d]$:
$$\operatorname{sign}\big(\Delta W_1^{(\tau)}[i,j]\big) = \operatorname{sign}\big(\Delta w_{2i}^{(\tau)}\big) = \operatorname{sign}\big(w_{2i}^{(0)}\big),\qquad \big|\Delta W_1^{(\tau)}[i,j]\big| = \Theta(\eta),\qquad \big|\Delta w_{2i}^{(\tau)}\big| = \Theta(\eta).$$
Since at $t = t_{\mathrm{inc}}$ we have $\operatorname{sign}\big(W_1^{(t_{\mathrm{inc}})}[i,j]\big) = \operatorname{sign}\big(w_{2i}^{(t_{\mathrm{inc}})}\big) = \operatorname{sign}\big(w_{2i}^{(0)}\big)$, we get that for $t_{\mathrm{inc}}\le\tau\le t$, $\forall i,j\in[d]$: $\big|W_1^{(\tau+1)}[i,j]\big| = \big|W_1^{(\tau)}[i,j]\big| + \Theta(\eta)$ and $\big|w_{2i}^{(\tau+1)}\big| = \big|w_{2i}^{(\tau)}\big| + \Theta(\eta)$. Now for $t+1$, we have $\forall i,j\in[d]$:
$$\operatorname{sign}\big(W_1^{(t+1)}[i,j]\big) = \operatorname{sign}\big(w_{2i}^{(0)}\big),\qquad \big|W_1^{(t+1)}[i,j]\big| = \big|W_1^{(t_{\mathrm{inc}})}[i,j]\big| + (t+1-t_{\mathrm{inc}})\,\Theta(\eta),$$
$$\operatorname{sign}\big(w_{2i}^{(t+1)}\big) = \operatorname{sign}\big(w_{2i}^{(0)}\big),\qquad \big|w_{2i}^{(t+1)}\big| = \big|w_{2i}^{(t_{\mathrm{inc}})}\big| + (t+1-t_{\mathrm{inc}})\,\Theta(\eta).$$
That means $\forall \tau\in[t+1-H, t+1]$: $\operatorname{sign}\big(W_1^{(\tau)}[i,j]\big) = \operatorname{sign}\big(w_{2i}^{(\tau)}\big) = \operatorname{sign}\big(w_{2i}^{(0)}\big)$, which proves (B) for time $t+1$. On the other hand, we get that $\big|W_1^{(t+1)}[i,j]\big| \ge \big|W_1^{(t_{\mathrm{inc}})}[i,j]\big| = \Theta\big(\frac{1}{d^{\frac32\alpha+1}}\big)$ and $\big|w_{2i}^{(t+1)}\big| \ge \big|w_{2i}^{(t_{\mathrm{inc}})}\big| = \Omega\big(\frac{1}{d^{\frac32\alpha}}\big)$. Since $t+1 \le T_1$, we have $\forall j\in[d]$: $\big|E_j^{(t+1)}\big| \ge \sqrt{\eta d}$.
Then
$$\big|g_1^{(t+1)}[i,j]\big| = \big|w_{2i}^{(t+1)}E_j^{(t+1)}\big| \ge \Omega\Big(\frac{1}{d^{\frac32\alpha}}\sqrt{\eta d}\Big) = \Omega(\xi),\qquad
\big|g_{2i}^{(t+1)}\big| = \Big|\sum_{j=1}^{d}E_j^{(t+1)}W_1^{(t+1)}[i,j]\Big| = \sum_{j=1}^{d}\big|E_j^{(t+1)}W_1^{(t+1)}[i,j]\big| \ge d\,\Theta\Big(\frac{1}{d^{\frac32\alpha+1}}\Big)\sqrt{\eta d} = \Omega(\xi).$$
This proves (C) at time $t+1$. Since $t+1 \le T_1$, which means $\forall j\in[d]$: $\big|(W_2^{(t+1)}W_1^{(t+1)})_j\big| \le O(1)$, we obtain that
$$\sum_{i=1}^{d}\big|w_{2i}^{(t+1)}W_1^{(t+1)}[i,j]\big| = \Big|\sum_{i=1}^{d}w_{2i}^{(t+1)}W_1^{(t+1)}[i,j]\Big| = \sum_{i=1}^{d}\Big(\big|w_{2i}^{(t_{\mathrm{inc}})}\big| + (t+1-t_{\mathrm{inc}})\Theta(\eta)\Big)\Big(\big|W_1^{(t_{\mathrm{inc}})}[i,j]\big| + (t+1-t_{\mathrm{inc}})\Theta(\eta)\Big) \le O(1).$$
Note that $\big|W_1^{(t_{\mathrm{inc}})}[i,j]\big|, \big|w_{2i}^{(t_{\mathrm{inc}})}\big| < \frac1d$ (since $t_{\mathrm{inc}} < t_d$); we get that $(t+1-t_{\mathrm{inc}})\,\Theta(\eta) = O\big(\frac{1}{\sqrt d}\big)$, which gives us $\big|w_{2i}^{(t+1)}\big| = \tilde O\big(\frac{1}{\sqrt d}\big)$ and $\big|W_1^{(t+1)}[i,j]\big| = \tilde O\big(\frac{1}{\sqrt d}\big)$, and hence (A) holds at time $t+1$. Therefore by induction, (A), (B), (C) hold for all $t_{\mathrm{inc}} \le t \le T_1$. Then applying Lemma 21, we get that for all $t_{\mathrm{inc}} \le t \le T_1$, $\forall i,j\in[d]$: $\big|\Delta W_1^{(t)}[i,j]\big| = \Theta(\eta)$ and $\big|\Delta w_{2i}^{(t)}\big| = \Theta(\eta)$.

Let us first prove eq. (30). By Lemma 18, for $t_{\mathrm{inc}} \le t < T_1$, we have $\forall i,j\in[d]$: $\big|g_1^{(t)}[i,j]\big| = \Omega(\sqrt\eta)$ and $\big|g_{2i}^{(t)}\big| = \Omega(d\sqrt\eta)$. Then it suffices to show that for $t_{\mathrm{inc}} \le t < T_1$, $\big|g_1^{(t)}[i,j] - g_1^{(t-\tau)}[i,j]\big| = \tau\,\tilde O(\eta)$ and $\big|g_{2i}^{(t)} - g_{2i}^{(t-\tau)}\big| = \tau\,\tilde O(\eta d)$; for this it suffices to show that for $t < T_1$, $\big|g_1^{(t+1)}[i,j] - g_1^{(t)}[i,j]\big| = \tilde O(\eta)$ and $\big|g_{2i}^{(t+1)} - g_{2i}^{(t)}\big| = \tilde O(\eta d)$. By Lemmas 17 and 22, we know that for $t < T_1$, $\forall i,j\in[d]$: $\big|\Delta W_1^{(t)}[i,j]\big| \le \tilde O(\eta)$, $\big|\Delta w_{2i}^{(t)}\big| \le \tilde O(\eta)$, and $\big|W_1^{(t)}[i,j]\big| \le \tilde O\big(\frac{1}{\sqrt d}\big)$, $\big|w_{2i}^{(t)}\big| \le \tilde O\big(\frac{1}{\sqrt d}\big)$. Then the bound $\big|\Delta E_j^{(t)}\big| \le \tilde O(\eta\sqrt d)$ in eq. (35) holds for all $t < T_1$ (not only $t_{\mathrm{inc}} \le t < T_1$). Substituting these bounds into eq. (36) gives us, for all $t < T_1$,
$$\big|g_1^{(t+1)}[i,j] - g_1^{(t)}[i,j]\big| \le \big|w_{2i}^{(t+1)}\Delta E_j^{(t)}\big| + \big|\Delta w_{2i}^{(t)}E_j^{(t)}\big| = \tilde O\Big(\frac{1}{\sqrt d}\Big)\tilde O\big(\eta\sqrt d\big) + \Theta(\eta)\,O(1) = \tilde O(\eta).$$
Similarly, we have $\big|g_{2i}^{(t+1)} - g_{2i}^{(t)}\big| = \tilde O(\eta d)$, which proves eq. (30). Note that for $a,b\in\mathbb R$:
$$\frac{|a^2-b^2|}{a^2} = \frac{\big|a^2 - (a-(a-b))^2\big|}{a^2} = \frac{\big|2a(a-b) - (a-b)^2\big|}{a^2} \le 2\,\frac{|a-b|}{|a|} + \Big(\frac{|a-b|}{|a|}\Big)^2.$$
Combining $\big|g_{2i}^{(t)}\big| = \Omega(d\sqrt\eta)$ with the bound $\big|g_{2i}^{(t+1)} - g_{2i}^{(t)}\big| = \tilde O(\eta d)$, we know that the $g_{2i}^{(t)}$ parts in eq. (30) and eq. (31) still hold.
Then we can use the same strategy as in Section D.2 to prove that the $w_{2i}^{(t)}$ part of eq. (29) still holds, which gives us
$$\forall i\in[d]:\quad w_{2i}^{(t+1)} = w_{2i}^{(t)} - \eta\operatorname{sign}\big(g_{2i}^{(t)}\big) + e_{2i}^{(t)},\qquad \big|e_{2i}^{(t)}\big| = \tilde O(\sqrt\eta).$$
By Lemma 14, at the end of the first phase ($t = T_1$),
$$\forall i\in[d]:\quad w_{2i}^{(T_1)} = \operatorname{sign}\big(w_{2i}^{(0)}\big)c^{(T_1)} + R_{2i}^{(T_1)},\qquad \frac{\big|R_{2i}^{(T_1)}\big|}{c^{(T_1)}} = \tilde O\Big(\sqrt\eta + \frac{1}{d^{\alpha-1/2}}\Big).$$
Combining with $\forall i\in[d]$, $\forall t \le T$: $\operatorname{sign}\big(g_{2i}^{(t)}\big) = -\operatorname{sign}\big(w_{2i}^{(0)}\big)$ yields that during the second phase, for $t \le T$, we have
$$\forall i\in[d]:\quad w_{2i}^{(t)} = \operatorname{sign}\big(w_{2i}^{(0)}\big)c^{(t)} + R_{2i}^{(t)},\qquad \frac{\big|R_{2i}^{(t)}\big|}{c^{(t)}} = \tilde O\Big(\sqrt\eta + \frac{1}{d^{\alpha-1/2}}\Big),$$
and $\big|w_{2i}^{(t)}\big| = \Theta\big(\frac{1}{\sqrt d}\big)$, $\big|W_1^{(t)}[i,j]\big| = \tilde O\big(\frac{1}{\sqrt d}\big)$. The base case ($t = T_1$) was already proven by Lemma 22. Now suppose for some $t$ with $T_1 \le t < T$, for all $\tau$ with $T_1\le\tau\le t$, we have $\forall i\in[d]$: $\operatorname{sign}\big(w_{2i}^{(\tau)}\big) = \operatorname{sign}\big(W_1^{(\tau)}[i,j_0]\big) = \operatorname{sign}\big(w_{2i}^{(0)}\big)$ and $\forall i,j\in[d]$: $\big|w_{2i}^{(\tau)}\big| = \Theta\big(\frac{1}{\sqrt d}\big)$, $\big|W_1^{(\tau)}[i,j]\big| = \tilde O\big(\frac{1}{\sqrt d}\big)$. Using these bounds, we get $\forall j\in[d]$:
$$\big|E_j^{(\tau)}\big| \le \sum_{i=1}^{d}\big|w_{2i}^{(\tau)}W_1^{(\tau)}[i,j]\big| + |A_j| = O(1),$$
which then yields the two upper bounds $\big|g_1^{(\tau)}[i,j]\big| = \big|w_{2i}^{(\tau)}E_j^{(\tau)}\big| = \tilde O\big(\frac{1}{\sqrt d}\big)$ and $\big|g_{2i}^{(\tau)}\big| \le \sum_{j=1}^{d}\big|E_j^{(\tau)}W_1^{(\tau)}[i,j]\big| = \tilde O(\sqrt d)$. By the definition of $T_g$, we know that for all $T_1\le\tau\le t$, $\forall i\in[d]$: $\big|g_{2i}^{(\tau)}\big| \ge d\sqrt\eta = \Omega(\xi)$ and $\operatorname{sign}\big(g_{2i}^{(\tau)}\big) = -\operatorname{sign}\big(w_{2i}^{(0)}\big)$, which implies that Condition 2 is satisfied for all $i\in[d]$. At the end of the proof of this lemma, we will show that $T = \Theta\big(\frac{1}{\sqrt d\eta}\big)$. Together with the upper bound on $\big|g_{2i}^{(\tau)}\big|$, we can apply Lemma 21 to get that, w.h.p., for $T_1\le\tau\le t$, $\operatorname{sign}\big(\Delta w_{2i}^{(\tau)}\big) = \operatorname{sign}\big(w_{2i}^{(0)}\big)$ and $\big|\Delta w_{2i}^{(\tau)}\big| = \Theta(\eta)$. Similarly, $\big|g_1^{(\tau)}[i,j_0]\big| = \big|w_{2i}^{(\tau)}E_{j_0}^{(\tau)}\big| = \Omega(\sqrt\eta) = \Omega(\xi)$ and $\operatorname{sign}\big(g_1^{(\tau)}[i,j_0]\big) = -\operatorname{sign}\big(w_{2i}^{(0)}\big)$. That means Condition 1 is satisfied for all $i\in[d]$ and $j_0$. Using the same technique as for $w_{2i}^{(\tau)}$, we get that for $T_1\le\tau\le t$, $\forall i\in[d]$: $\big|W_1^{(\tau+1)}[i,j_0]\big| = \big|W_1^{(\tau)}[i,j_0]\big| + \Theta(\eta)$, $\operatorname{sign}\big(W_1^{(t+1)}[i,j_0]\big) = \operatorname{sign}\big(w_{2i}^{(0)}\big)$, and $\forall i,j\in[d]$: $\big|\Delta W_1^{(\tau)}[i,j]\big| = \tilde O(\eta)$.
Now we analyze the magnitudes of $w_{2i}^{(t+1)}$ and $W_1^{(t+1)}[i,j]$. Let us first analyze $w_{2i}^{(t+1)}$. By Lemma 14, when $t = T_1$, $\forall i,j\in[d]$,
$$\frac{w_{2i}^{(T_1)}}{w_{2j}^{(T_1)}} = 1 \pm \tilde O\Big(\sqrt\eta + \frac{1}{d^{\alpha-1/2}}\Big),\qquad \frac{W_1^{(T_1)}[i,j_0]}{w_{2i}^{(T_1)}} = 1 \pm \tilde O\Big(\sqrt\eta + \frac{1}{d^{\alpha-1/2}}\Big).$$
Combining with the facts that for $T_1\le\tau\le t$, $\big|W_1^{(\tau+1)}[i,j_0]\big| = \big|W_1^{(\tau)}[i,j_0]\big| + \Theta(\eta)$ and $\big|w_{2i}^{(\tau+1)}\big| = \big|w_{2i}^{(\tau)}\big| + \Theta(\eta)$ yields $\frac{W_1^{(t+1)}[i,j_0]}{w_{2i}^{(t+1)}} = \Theta(1)$. Since we just proved $\forall i\in[d]$: $\operatorname{sign}\big(w_{2i}^{(t+1)}\big) = \operatorname{sign}\big(W_1^{(t+1)}[i,j_0]\big) = \operatorname{sign}\big(w_{2i}^{(0)}\big)$, we get that
$$\big|(W_2W_1)^{(t+1)}_{j_0}\big| = \Big|\sum_{i=1}^{d}w_{2i}^{(t+1)}W_1^{(t+1)}[i,j_0]\Big| = \sum_{i=1}^{d}\big|w_{2i}^{(t+1)}W_1^{(t+1)}[i,j_0]\big| = O(1),$$
which gives us $\big|w_{2i}^{(t+1)}\big| = \tilde O\big(\frac{1}{\sqrt d}\big)$. Recall that we have shown $\big|w_{2i}^{(t+1)}\big| \ge \Omega\big(\frac{1}{\sqrt d}\big)$; hence $\big|w_{2i}^{(t+1)}\big| = \Theta\big(\frac{1}{\sqrt d}\big)$.

Now we prove $\big|W_1^{(t+1)}[i,j]\big| = \tilde O\big(\frac{1}{\sqrt d}\big)$. We have proved that for $T_1\le\tau\le t$, $\forall i,j\in[d]$, $\big|\Delta W_1^{(\tau)}[i,j]\big| = \tilde O(\eta)$ and $\big|w_{2i}^{(\tau+1)}\big| - \big|w_{2i}^{(\tau)}\big| = \Theta(\eta)$; then $\forall i,j\in[d]$,
$$\frac{\big|W_1^{(t+1)}[i,j]\big|}{\big|w_{2i}^{(t+1)}\big|} \le \frac{\big|W_1^{(T_1)}[i,j]\big| + \sum_{\tau=T_1}^{t}\big|W_1^{(\tau+1)}[i,j] - W_1^{(\tau)}[i,j]\big|}{\big|w_{2i}^{(T_1)}\big| + \sum_{\tau=T_1}^{t}\big(\big|w_{2i}^{(\tau+1)}\big| - \big|w_{2i}^{(\tau)}\big|\big)} \le \frac{\big|W_1^{(T_1)}[i,j]\big| + (t+1-T_1)\,\tilde O(\eta)}{\big|w_{2i}^{(T_1)}\big| + (t+1-T_1)\,\Theta(\eta)} = \tilde O(1),$$
where the last equality uses $\frac{W_1^{(T_1)}[i,j]}{w_{2i}^{(T_1)}} = 1 \pm \tilde O\big(\sqrt\eta + \frac{1}{d^{\alpha-1/2}}\big)$. Since we already proved $\big|w_{2i}^{(t+1)}\big| = \Theta\big(\frac{1}{\sqrt d}\big)$, we get $\big|W_1^{(t+1)}[i,j]\big| = \tilde O\big(\frac{1}{\sqrt d}\big)$. Therefore by induction, for all $t$ in the interval $[T_1, T)$, we have $\forall i,j\in[d]$: $\big|w_{2i}^{(t)}\big| = \Theta\big(\frac{1}{\sqrt d}\big)$, $\big|W_1^{(t)}[i,j]\big| = \tilde O\big(\frac{1}{\sqrt d}\big)$. From the proof we also get $\forall i\in[d]$: $\big|w_{2i}^{(t+1)}\big| > \big|w_{2i}^{(t)}\big|$, $\big|\Delta w_{2i}^{(t)}\big| = \Theta(\eta)$, and $\big|\Delta W_1^{(t)}[i,j]\big| \le \tilde O(\eta)$.

Now we verify that $T = \Theta\big(\frac{1}{\sqrt d\eta}\big)$: this follows by combining $\forall i\in[d]$: $\big|w_{2i}^{(T)}\big| = \Theta\big(\frac{1}{\sqrt d}\big)$ with the fact that for all $t\in[T_1, T)$, $\big|w_{2i}^{(t+1)}\big| - \big|w_{2i}^{(t)}\big| = \Theta(\eta)$.

We prove this lemma by induction. The base case ($t = T_1$) can be verified by Lemma 14. Now suppose for $t$ in the interval $[T_1, T)$ we have $\forall i,j\in[d]$: $\big|W_1^{(t)}[i,j]\big| = \Omega\big(\frac{1}{\sqrt d}\big)$. For $t\in[T_1, T)$, by the proof of Lemma 26 (Section D.13), we know that $\forall \tau\le t$, $\forall i,j\in[d]$: $\big|w_{2i}^{(\tau)}\big| = \Theta\big(\frac{1}{\sqrt d}\big)$, $\big|W_1^{(\tau)}[i,j]\big| = \tilde O\big(\frac{1}{\sqrt d}\big)$, $\big|g_1^{(\tau)}[i,j]\big| \le \tilde O\big(\frac{1}{\sqrt d}\big)$, $\big|g_{2i}^{(\tau)}\big| \le \tilde O(\sqrt d)$, and $T_1 = \Theta\big(\frac{1}{\sqrt d\eta}\big)$. Then we can pick $H := \frac{1}{1-\beta_1}\log\frac{d}{\eta\xi^2}$ and apply Lemma 16 and Corollary 2 to get that, w.h.p., for all $t\in[T_1, T)$ and $\forall i,j\in[d]$, the update of $W_1$ can be written as
$$W_1^{(t+1)}[i,j] = W_1^{(t)}[i,j] - \eta_t\,\frac{(1-\beta_1)\sum_{\tau=0}^{H}\beta_1^{\tau}g_1^{(t-\tau)}[i,j] + \epsilon_{1n}^{(t)}[i,j]}{\sqrt{(1-\beta_2)\sum_{\tau=0}^{H}\beta_2^{\tau}\big(g_1^{(t-\tau)}[i,j]\big)^2 + \epsilon_{1d}^{(t)}[i,j]} + \xi},$$
where $\big|\epsilon_{1n}^{(t)}[i,j]\big|, \big|\epsilon_{1d}^{(t)}[i,j]\big| \le \tilde O(\eta\xi^2)$. By Lemma 23, we have for $1\le i,j\le d$:
$$g_1^{(t)}[i,j] = w_{2i}^{(t)}E_j^{(t)} = c^{(t)}\operatorname{sign}\big(w_{2i}^{(0)}\big)E_j^{(t)} + R_{g,1}^{(t)}[i,j],\qquad \frac{\big|R_{g,1}^{(t)}[i,j]\big|}{c^{(t)}\big|E_j^{(t)}\big|} = \tilde O\Big(\sqrt\eta + \frac{1}{d^{\alpha-1/2}}\Big),$$
$$\Rightarrow\quad \sum_{\tau=0}^{H}\beta_1^{\tau}g_1^{(t-\tau)}[i,j] = \operatorname{sign}\big(w_{2i}^{(0)}\big)\sum_{\tau=0}^{H}\beta_1^{\tau}c^{(t-\tau)}E_j^{(t-\tau)} + \sum_{\tau=0}^{H}\beta_1^{\tau}R_{g,1}^{(t-\tau)}[i,j].\tag{37}$$
Using the fact that for $a,b\in\mathbb R$, $\frac{|a^2-b^2|}{a^2} \le 2\frac{|a-b|}{|a|} + \big(\frac{|a-b|}{|a|}\big)^2$, we get
$$\big(g_1^{(t)}[i,j]\big)^2 = \Big(c^{(t)}\operatorname{sign}\big(w_{2i}^{(0)}\big)E_j^{(t)} + R_{g,1}^{(t)}[i,j]\Big)^2 =: \big(c^{(t)}E_j^{(t)}\big)^2 + R_{g\mathrm{sqr},1}^{(t)}[i,j],\qquad \frac{\big|R_{g\mathrm{sqr},1}^{(t)}[i,j]\big|}{\big(c^{(t)}E_j^{(t)}\big)^2} = \tilde O\Big(\sqrt\eta + \frac{1}{d^{\alpha-1/2}}\Big).$$
That yields
$$\sum_{\tau=0}^{H}\beta_2^{\tau}\big(g_1^{(t-\tau)}[i,j]\big)^2 = \sum_{\tau=0}^{H}\beta_2^{\tau}\big(c^{(t-\tau)}E_j^{(t-\tau)}\big)^2 + \sum_{\tau=0}^{H}\beta_2^{\tau}R_{g\mathrm{sqr},1}^{(t-\tau)}[i,j].\tag{38}$$
Since $\big(c^{(t-\tau)}E_j^{(t-\tau)}\big)^2 > 0$, in eq. (38) we have
$$\frac{\Big|\sum_{\tau=0}^{H}\beta_2^{\tau}R_{g\mathrm{sqr},1}^{(t-\tau)}[i,j]\Big|}{\sum_{\tau=0}^{H}\beta_2^{\tau}\big(c^{(t-\tau)}E_j^{(t-\tau)}\big)^2} = \tilde O\Big(\sqrt\eta + \frac{1}{d^{\alpha-1/2}}\Big).\tag{39}$$
However, we cannot similarly bound $\big|\sum_{\tau=0}^{H}\beta_1^{\tau}R_{g,1}^{(t-\tau)}[i,j]\big|$ against $\big|\sum_{\tau=0}^{H}\beta_1^{\tau}c^{(t-\tau)}E_j^{(t-\tau)}\big|$ in eq. (37), because $c^{(t-\tau)}E_j^{(t-\tau)}$ may not have the same sign for $\tau = 0,1,\dots,H$.
To deal with eq. (37), we consider two cases, according to whether $\big|\sum_{\tau=0}^{H}\beta_1^{\tau}R_{g,1}^{(t-\tau)}[i,j]\big|$ is small or large relative to $\big|\sum_{\tau=0}^{H}\beta_1^{\tau}c^{(t-\tau)}E_j^{(t-\tau)}\big|$.

Case 1. $\Big|(1-\beta_1)\sum_{\tau=0}^{H}\beta_1^{\tau}R_{g,1}^{(t-\tau)}[i,j] + \epsilon_{1n}^{(t)}[i,j]\Big| \le \delta\,(1-\beta_1)\Big|\sum_{\tau=0}^{H}\beta_1^{\tau}c^{(t-\tau)}E_j^{(t-\tau)}\Big|$, where $\delta = \eta^{\frac14} + \frac{1}{d^{\frac\alpha2-\frac14}}$.

Note that from eq. (39) we have
$$(1-\beta_2)\Big|\sum_{\tau=0}^{H}\beta_2^{\tau}R_{g\mathrm{sqr},1}^{(t-\tau)}[i,j]\Big| \le \tilde O\Big(\sqrt\eta + \frac{1}{d^{\alpha-1/2}}\Big)(1-\beta_2)\sum_{\tau=0}^{H}\beta_2^{\tau}\big(c^{(t-\tau)}E_j^{(t-\tau)}\big)^2.$$
Combining with $\big|\epsilon_{1d}^{(t)}[i,j]\big| \le \tilde O(\eta\xi^2) \le \tilde O\big(\big(\eta^{\frac14} + \frac{1}{d^{\frac\alpha2-\frac14}}\big)^2\xi^2\big)$, we can apply Lemma 33 to get
$$W_1^{(t+1)}[i,j] - W_1^{(t)}[i,j] = -\eta_t\,\frac{(1-\beta_1)\operatorname{sign}\big(w_{2i}^{(0)}\big)\sum_{\tau=0}^{H}\beta_1^{\tau}c^{(t-\tau)}E_j^{(t-\tau)} + (1-\beta_1)\sum_{\tau=0}^{H}\beta_1^{\tau}R_{g,1}^{(t-\tau)}[i,j] + \epsilon_{1n}^{(t)}[i,j]}{\sqrt{(1-\beta_2)\sum_{\tau=0}^{H}\beta_2^{\tau}\big(c^{(t-\tau)}E_j^{(t-\tau)}\big)^2 + (1-\beta_2)\sum_{\tau=0}^{H}\beta_2^{\tau}R_{g\mathrm{sqr},1}^{(t-\tau)}[i,j] + \epsilon_{1d}^{(t)}[i,j]} + \xi}$$
$$= -\eta_t\,\frac{1-\beta_1}{\sqrt{1-\beta_2}}\cdot\frac{\operatorname{sign}\big(w_{2i}^{(0)}\big)\sum_{\tau=0}^{H}\beta_1^{\tau}c^{(t-\tau)}E_j^{(t-\tau)}}{\sqrt{\sum_{\tau=0}^{H}\beta_2^{\tau}\big(c^{(t-\tau)}E_j^{(t-\tau)}\big)^2} + \xi}\Big(1 + e_1^{(t)}[i,j]\Big) =: -\operatorname{sign}\big(w_{2i}^{(0)}\big)\,v_j^{(t)}\Big(1 + e_1^{(t)}[i,j]\Big),$$
where $\big|e_1^{(t)}[i,j]\big| = \tilde O\big(\eta^{\frac14} + \frac{1}{d^{\frac\alpha2-\frac14}}\big)$. Since $\big|W_1^{(t+1)}[i,j] - W_1^{(t)}[i,j]\big| = \tilde O(\eta)$, we get that $v_j^{(t)} = \tilde O(\eta)$.

Case 2. $\Big|(1-\beta_1)\sum_{\tau=0}^{H}\beta_1^{\tau}R_{g,1}^{(t-\tau)}[i,j] + \epsilon_{1n}^{(t)}[i,j]\Big| > \delta\,(1-\beta_1)\Big|\sum_{\tau=0}^{H}\beta_1^{\tau}c^{(t-\tau)}E_j^{(t-\tau)}\Big|$, where $\delta = \eta^{\frac14} + \frac{1}{d^{\frac\alpha2-\frac14}}$.

Since $\frac{|R_{g,1}^{(t)}[i,j]|}{c^{(t)}|E_j^{(t)}|} = \tilde O\big(\sqrt\eta + \frac{1}{d^{\alpha-1/2}}\big)$, we have
$$(1-\beta_1)\Big|\sum_{\tau=0}^{H}\beta_1^{\tau}R_{g,1}^{(t-\tau)}[i,j]\Big| \le \tilde O\Big(\sqrt\eta + \frac{1}{d^{\alpha-1/2}}\Big)(1-\beta_1)\sum_{\tau=0}^{H}\beta_1^{\tau}c^{(t-\tau)}\big|E_j^{(t-\tau)}\big|
\overset{(i)}{\le} \tilde O\Big(\sqrt\eta + \frac{1}{d^{\alpha-1/2}}\Big)\sqrt{H+1}\,(1-\beta_1)\sqrt{\sum_{\tau=0}^{H}\beta_2^{\tau}\big(c^{(t-\tau)}E_j^{(t-\tau)}\big)^2}
\overset{(ii)}{=} \tilde O\Big(\sqrt\eta + \frac{1}{d^{\alpha-1/2}}\Big)\sqrt{(1-\beta_2)\sum_{\tau=0}^{H}\beta_2^{\tau}\big(g_1^{(t-\tau)}[i,j]\big)^2},$$
where $(i)$ uses the Cauchy–Schwarz inequality and $\beta_2 = \beta_1^2$, and $(ii)$ uses eqs. (38) and (39). Combining with $\big|\epsilon_{1n}^{(t)}[i,j]\big| \le \tilde O(\eta\xi^2) \le \tilde O\big(\sqrt\eta + \frac{1}{d^{\alpha-1/2}}\big)\,\xi$ gives us
$$(1-\beta_1)\Big|\sum_{\tau=0}^{H}\beta_1^{\tau}c^{(t-\tau)}E_j^{(t-\tau)}\Big| < \frac{\Big|(1-\beta_1)\sum_{\tau=0}^{H}\beta_1^{\tau}R_{g,1}^{(t-\tau)}[i,j] + \epsilon_{1n}^{(t)}[i,j]\Big|}{\eta^{\frac14} + \frac{1}{d^{\frac\alpha2-\frac14}}}
\le \tilde O\Big(\eta^{\frac14} + \frac{1}{d^{\frac\alpha2-\frac14}}\Big)\Bigg(\sqrt{\sum_{\tau=0}^{H}\beta_2^{\tau}\big(g_1^{(t-\tau)}[i,j]\big)^2 + \epsilon_{1d}^{(t)}[i,j]} + \xi\Bigg),$$
which implies $\big|W_1^{(t+1)}[i,j] - W_1^{(t)}[i,j]\big| \le \eta\,\tilde O\big(\eta^{\frac14} + \frac{1}{d^{\frac\alpha2-\frac14}}\big)$.

Since $k, j$ are arbitrary, we have proved that at time $t+1$, $\forall i,j\in[d]$: $\big|W_1^{(t+1)}[i,j]\big| \ge \Omega\big(\frac{1}{\sqrt d}\big)$. Therefore by induction, we conclude that when $T_1\le t< T$, for all $i,j\in[d]$, $\big|W_1^{(t)}[i,j]\big| = \Omega\big(\frac{1}{\sqrt d}\big)$. The remaining part of this lemma has also been proved by the analysis above.

Meanwhile, $\big|W_1^{(t)}[i,j]\big|$ keeps increasing. If $T_{f,j} < T$, then for $T_{f,j} \le t < T$ we will have $-\tilde O\big(\sqrt{\eta d}\big) \le E_j^{(t)} \le \tilde O\big(\sqrt{\eta d}\big)$.

Now we start proving Lemma 25. At time $T$, denote $S := \{j : T_{f,j} < T\}$, i.e., the set of coordinates whose $E_j$ have passed their "flip time". By Lemma 27, we know that $\forall j\in S$: $\big|E_j^{(T)}\big| \le \tilde O\big(\sqrt{\eta d}\big)$, and $\big|g_{2i_0}^{(T)}\big| \le O(d\sqrt\eta)$. Then
$$\Big|\sum_{j\in S^c}E_j^{(T)}W_1^{(T)}[i_0,j]\Big| = \Big|\sum_{j=1}^{d}E_j^{(T)}W_1^{(T)}[i_0,j] - \sum_{j\in S}E_j^{(T)}W_1^{(T)}[i_0,j]\Big| \le \big|g_{2i_0}^{(T)}\big| + \sum_{j\in S}\big|E_j^{(T)}W_1^{(T)}[i_0,j]\big| \le O(d\sqrt\eta) + d\,\tilde O\big(\sqrt{\eta d}\big)\,\tilde O\Big(\frac{1}{\sqrt d}\Big) = \tilde O(d\sqrt\eta).$$
By Lemma 24, we know that when $T_1\le t< T$, for all $i,j\in[d]$, $\big|W_1^{(t)}[i,j]\big| = \Omega\big(\frac{1}{\sqrt d}\big)$. Since the update per step satisfies $\big|\Delta W_1^{(t)}[i,j]\big| \le \tilde O(\eta)$, we know that $\operatorname{sign}\big(W_1^{(t)}[i,j]\big)$ remains unchanged during this period and $\operatorname{sign}\big(W_1^{(t)}[i,j]\big) = \operatorname{sign}\big(W_1^{(T_1)}[i,j]\big) = \operatorname{sign}\big(w_{2i}^{(0)}\big)$, independent of $j$. Combining with $\forall j\in S^c$: $E_j^{(T)} < 0$ gives us that the terms $E_j^{(T)}W_1^{(T)}[i_0,j]$ for different $j$ have the same sign. Therefore, for any $j_0\in S^c$,
$$\tilde O(d\sqrt\eta) \ge \Big|\sum_{j\in S^c}E_j^{(T)}W_1^{(T)}[i_0,j]\Big| = \sum_{j\in S^c}\big|E_j^{(T)}W_1^{(T)}[i_0,j]\big| \ge \big|E_{j_0}^{(T)}\big|\,\big|W_1^{(T)}[i_0,j_0]\big| \ge \big|E_{j_0}^{(T)}\big|\,\Omega\Big(\frac{1}{\sqrt d}\Big)
\;\Rightarrow\; \big|E_{j_0}^{(T)}\big| \le \tilde O\big(d\sqrt{\eta d}\big).$$
Note that the above inequality holds for any $j_0\in S^c$, which means $\forall j\in S^c$: $\big|E_j^{(T)}\big| \le \tilde O\big(d\sqrt{\eta d}\big)$.
Combining with the fact that $\forall j\in S$: $\big|E_j^{(T)}\big| \le \tilde O\big(\sqrt{\eta d}\big) \le \tilde O\big(d\sqrt{\eta d}\big)$ completes the proof of Lemma 25.

We first recall that when $T_1 \le t < T$, Lemma 26 gives us, for all $i\in[d]$, $\big|w_{2i}^{(t)}\big| = \Theta\big(\frac{1}{\sqrt d}\big)$ and $\big|W_1^{(t)}[i,j]\big| = \tilde O\big(\frac{1}{\sqrt d}\big)$. Then eq. (35), obtained in the first-phase analysis, still holds, which tells us that the change of $E_j^{(t)}$ per step satisfies $\big|E_j^{(t+1)} - E_j^{(t)}\big| = \tilde O\big(\eta\sqrt d\big)$ for all $T_1\le t< T$. We divide the analysis into two cases, based on whether $E_j^{(t)} \ge \sqrt{\eta d}$ or $E_j^{(t)} \le -\sqrt{\eta d}$. By Lemma 24, we know that when $T_1\le t< T$, $\forall i\in[d]$, $\big|W_1^{(t)}[i,j]\big| = \Omega\big(\frac{1}{\sqrt d}\big)$. Since the update per step satisfies $\big|\Delta W_1^{(t)}[i,j]\big| \le \tilde O(\eta)$, $\operatorname{sign}\big(W_1^{(t)}[i,j]\big)$ remains unchanged during this period and equals $\operatorname{sign}\big(W_1^{(T_1)}[i,j]\big) = \operatorname{sign}\big(w_{2i}^{(0)}\big)$. By the analysis of $w_{2i}^{(t)}$, we have $w_{2i}^{(t+1)} = w_{2i}^{(t)} + \operatorname{sign}\big(w_{2i}^{(0)}\big)\Delta_{2i}^{(t)}$, where $\Delta_{2i}^{(t)} = \eta\big(1 \pm \tilde O(\sqrt\eta)\big)$.

Case 1. Consider some time point $t$ such that $E_j^{(t)} \le -\sqrt{\eta d}$. Note that for all $i\in[d]$, $\big|g_1^{(t)}[i,j]\big| = \big|w_{2i}^{(t)}E_j^{(t)}\big| = \Omega(\sqrt\eta)$ and $\operatorname{sign}\big(g_1^{(t)}[i,j]\big) = -\operatorname{sign}\big(w_{2i}^{(t)}\big) = -\operatorname{sign}\big(w_{2i}^{(0)}\big)$. By Lemma 23, for all $i\in[d]$ we have $W_1^{(t+1)}[i,j] = W_1^{(t)}[i,j] + \operatorname{sign}\big(w_{2i}^{(0)}\big)\Delta_1^{(t)}[i,j]$ with $\Delta_1^{(t)}[i,j] = \eta\big(1\pm\tilde O(\sqrt\eta)\big)$. That gives us
$$E_j^{(t+1)} = \sum_{i=1}^{d}w_{2i}^{(t+1)}W_1^{(t+1)}[i,j] - A_j = \sum_{i=1}^{d}\Big(w_{2i}^{(t)} + \operatorname{sign}\big(w_{2i}^{(0)}\big)\Delta_{2i}^{(t)}\Big)\Big(W_1^{(t)}[i,j] + \operatorname{sign}\big(w_{2i}^{(0)}\big)\Delta_1^{(t)}[i,j]\Big) - A_j$$
$$= \sum_{i=1}^{d}\Big(w_{2i}^{(t)}W_1^{(t)}[i,j] + \operatorname{sign}\big(w_{2i}^{(0)}\big)\big(w_{2i}^{(t)}\Delta_1^{(t)}[i,j] + \Delta_{2i}^{(t)}W_1^{(t)}[i,j]\big) + \Delta_{2i}^{(t)}\Delta_1^{(t)}[i,j]\Big) - A_j
\overset{(i)}{=} E_j^{(t)} + \sum_{i=1}^{d}\Big(\big|w_{2i}^{(t)}\big|\Delta_1^{(t)}[i,j] + \Delta_{2i}^{(t)}\big|W_1^{(t)}[i,j]\big| + \Delta_{2i}^{(t)}\Delta_1^{(t)}[i,j]\Big) \;\Rightarrow\; E_j^{(t+1)} > E_j^{(t)},$$
where $(i)$ is because $\operatorname{sign}\big(w_{2i}^{(t)}\big) = \operatorname{sign}\big(W_1^{(t)}[i,j]\big) = \operatorname{sign}\big(w_{2i}^{(0)}\big)$. Therefore we have proved that $E_j^{(t)}$ will increase in the next step. After that, for $\tau\ge t+1$, as long as $E_j^{(\tau)} \le -\sqrt{\eta d}$, the above analysis holds and $E_j^{(\tau)}$ keeps increasing until $E_j^{(\tau)} > -\sqrt{\eta d}$ or we reach $T$.

Case 2. Consider some time point $t$ such that $E_j^{(t)} \ge \sqrt{\eta d}$. We will prove that $E_j^{(t)}$ decreases after a short period, during which it changes by at most $\tilde O\big(\sqrt{\eta d}\big)$. By similar arguments as in Case 1, we can get that $W_1^{(t+1)}[i,j] = W_1^{(t)}[i,j] - \operatorname{sign}\big(w_{2i}^{(0)}\big)\Delta_1^{(t)}[i,j]$, where $\Delta_1^{(t)}[i,j] = \eta\big(1\pm\tilde O(\sqrt\eta)\big)$. Then
$$E_j^{(t+1)} = \sum_{i=1}^{d}\Big(w_{2i}^{(t)} + \operatorname{sign}\big(w_{2i}^{(0)}\big)\Delta_{2i}^{(t)}\Big)\Big(W_1^{(t)}[i,j] - \operatorname{sign}\big(w_{2i}^{(0)}\big)\Delta_1^{(t)}[i,j]\Big) - A_j
\overset{(i)}{=} E_j^{(t)} - \sum_{i=1}^{d}\Big(\big|w_{2i}^{(t)}\big|\Delta_1^{(t)}[i,j] - \Delta_{2i}^{(t)}\big|W_1^{(t)}[i,j]\big| + \Delta_{2i}^{(t)}\Delta_1^{(t)}[i,j]\Big),$$
where $(i)$ is again because $\operatorname{sign}\big(w_{2i}^{(t)}\big) = \operatorname{sign}\big(W_1^{(t)}[i,j]\big) = \operatorname{sign}\big(w_{2i}^{(0)}\big)$. Now $E_j^{(t+1)}$ may not be smaller than $E_j^{(t)}$, but we will show that after at most $t_s$ steps for some $t_s$, we will have $E_j^{(t+t_s+1)} < E_j^{(t+t_s)}$. To see this, first note that by the bounds on $\Delta_1^{(t)}[i,j]$ and $\Delta_{2i}^{(t)}$, we get $\Delta_1^{(t)}[i,j] \ge \Delta_{2i}^{(t)} - \eta\,\tilde O(\sqrt\eta)$. Consider the step $t+t_s$ at which $\forall i\in[d]$: $\big|w_{2i}^{(t+t_s)}\big| \ge \big|W_1^{(t+t_s)}[i,j]\big| + \sqrt\eta$. Hence
$$E_j^{(t+t_s)} - E_j^{(t+t_s+1)} = \sum_{i=1}^{d}\Big(\big|w_{2i}^{(t+t_s)}\big|\Delta_1^{(t+t_s)}[i,j] - \Delta_{2i}^{(t+t_s)}\big|W_1^{(t+t_s)}[i,j]\big| + \Delta_{2i}^{(t+t_s)}\Delta_1^{(t+t_s)}[i,j]\Big)$$
$$\ge \sum_{i=1}^{d}\Big(\big|W_1^{(t+t_s)}[i,j]\big|\big(\Delta_1^{(t+t_s)}[i,j] - \Delta_{2i}^{(t+t_s)}\big) + \sqrt\eta\,\Delta_1^{(t+t_s)}[i,j] + \Delta_{2i}^{(t+t_s)}\Delta_1^{(t+t_s)}[i,j]\Big)
\ge \sum_{i=1}^{d}\Big(-\eta\,\tilde O(\sqrt\eta)\big|W_1^{(t+t_s)}[i,j]\big| + \eta\sqrt\eta + \Delta_{2i}^{(t+t_s)}\Delta_1^{(t+t_s)}[i,j]\Big) > 0,$$
where the last inequality uses $\forall i,j\in[d]$: $\big|W_1^{(t+t_s)}[i,j]\big| = \tilde O\big(\frac{1}{\sqrt d}\big)$. Therefore $E_j^{(t+t_s+1)} < E_j^{(t+t_s)}$. There are two sub-cases. For i), $t_s \le \sqrt\eta/\Omega(\eta) = \tilde O\big(\frac{1}{\sqrt\eta}\big)$. Combining with the fact that for all $T_1\le\tau\le T$, $\big|E_j^{(\tau+1)} - E_j^{(\tau)}\big| = \tilde O\big(\eta\sqrt d\big)$ gives us $\big|E_j^{(t+t_s)} - E_j^{(t)}\big| \le \tilde O\big(\eta t_s\sqrt d\big) = \tilde O\big(\sqrt{\eta d}\big)$. For ii), we reach $T$ before $\forall i\in[d]$: $\big|w_{2i}^{(t+t_s)}\big| \ge \big|W_1^{(t+t_s)}[i,j]\big| + \sqrt\eta$. Then we have $T - t \le \sqrt\eta/\Omega(\eta) = \tilde O\big(\frac{1}{\sqrt\eta}\big)$, which yields $\big|E_j^{(T)} - E_j^{(t)}\big| \le \tilde O\big(\eta(T-t)\sqrt d\big) \le \tilde O\big(\sqrt{\eta d}\big)$.

Combining the above two cases, we find that if for some $t$, $E_j \ge \sqrt{\eta d}$, then after at most $t_s$ steps $E_j$ will decrease, and it keeps decreasing until $E_j < \sqrt{\eta d}$ or we reach $T$; during these steps, $E_j$ can increase by at most $\tilde O\big(\sqrt{\eta d}\big)$. If for some $t$, $E_j \le -\sqrt{\eta d}$, then after one step it will increase, and it keeps increasing until $E_j > -\sqrt{\eta d}$ or we reach $T$. That means once $E_j$ overshoots for some coordinate $j$, it will zigzag in a small region around zero, namely $\big[-\tilde O\big(\sqrt{\eta d}\big),\, \tilde O\big(\sqrt{\eta d}\big)\big]$.

E HESSIAN TENDS TO BECOME MORE AND MORE DIAGONAL DURING TRAINING

In this section, we empirically demonstrate that, in practice, the loss Hessian tends to become more and more diagonal during training. We also give a rigorous theoretical analysis for a two-layer network under Assumptions 1 and 2.

E.1 EMPIRICAL RESULTS

Let us first define the diagonal domination ratio of the $i$-th coordinate at time $t$:
$$r^{\mathrm{OPT}}_{\mathrm{diag},i}(t) := \frac{\sqrt{\sum_{j\ne i}\big(H^{(t)}[i,j]\big)^2}}{\big|H^{(t)}[i,i]\big|}.$$
To measure the diagonal domination of the whole Hessian, we need to consider the distribution of $r^{\mathrm{OPT}}_{\mathrm{diag},i}(t)$ over the coordinates $i$. Figure 14 shows the mean and median of $r^{\mathrm{SGDM}}_{\mathrm{diag},i}(t)$ and $r^{\mathrm{Adam}}_{\mathrm{diag},i}(t)$ on the sentence classification task (see Section 4.1). Here we chose 4 layers (Layers #6, 12, 17 and 22) and computed the Hessians across these 4 layers. Since the number of parameters is very large, we performed the computation by random sampling. As we can see, for both $r^{\mathrm{SGDM}}_{\mathrm{diag},i}(t)$ and $r^{\mathrm{Adam}}_{\mathrm{diag},i}(t)$, the mean and median tend to decrease over time, although there is some oscillation.
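The per-coordinate statistic and its mean/median summary are straightforward to compute from a (sampled) Hessian matrix. A small self-contained sketch, with a toy matrix standing in for the sampled layer Hessian:

```python
import numpy as np

def diag_domination(H):
    """r_i = sqrt(sum_{j != i} H[i, j]^2) / |H[i, i]| for each row i.
    Small r_i means row i of the Hessian is strongly diagonally dominated."""
    H = np.asarray(H, dtype=float)
    diag = np.diag(H)
    off_sq = (H ** 2).sum(axis=1) - diag ** 2   # squared off-diagonal mass per row
    return np.sqrt(np.maximum(off_sq, 0.0)) / np.abs(diag)

# Toy stand-in for a sampled Hessian: strong diagonal, weak coupling.
rng = np.random.default_rng(0)
n = 100
H = np.diag(1.0 + rng.random(n)) + 0.01 * rng.standard_normal((n, n))
r = diag_domination(H)
print("mean:", r.mean(), "median:", np.median(r))
```

For a perfectly diagonal matrix every $r_i$ is zero, while for a "uniform" matrix with entries of equal magnitude $r_i \approx \sqrt{n}$, matching the $\Theta(d)$ baseline used in the theoretical analysis below.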

E.2 THEORETICAL ANALYSIS

To simplify the theoretical analysis, we consider the mean of $r^{\mathrm{OPT}}_{\mathrm{diag},i}(t)$ over all coordinates and define
$$R^{\mathrm{OPT}}_{\mathrm{diag}}(t) := \operatorname{mean}_i\, r^{\mathrm{OPT}}_{\mathrm{diag},i}(t).\tag{40}$$
We consider a 2-layer network under Assumptions 1 and 2, and have two goals in our proof: 1. To show that $R^{\mathrm{OPT}}_{\mathrm{diag}}(t)$ after training is smaller than before training ($t=0$). 2. Note that in our setting (see Assumption 1), the Hessian is a $(d^2+d)\times(d^2+d)$ matrix. For a completely "uniform" matrix of the same size, we have $R^{\mathrm{OPT}}_{\mathrm{diag}}(t) = \Theta\big(\sqrt{d^2+d}\big) = \Theta(d)$. Hence our second goal is to show that $R^{\mathrm{OPT}}_{\mathrm{diag}}(t)$ after training is of lower order than $\Theta(d)$.

Theorem 2. Consider the ratio $R^{\mathrm{OPT}}_{\mathrm{diag}}(t)$ defined in eq. (40). Under Assumptions 1 and 2, we have that before training ($t=0$), with high probability,
$$R^{\mathrm{OPT}}_{\mathrm{diag}}(0) \ge \Omega\big(d^{4\alpha-\frac32}\big).\tag{41}$$
For SGD+M defined in eq. (3): for any $p>0$, picking the same hyperparameters as in Theorem 1, for $T_{\mathrm{SGD},1}, T_{\mathrm{SGD},2}$ mentioned in Theorem 1, we have with constant probability, for any $t\in[T_{\mathrm{SGD},1}, T_{\mathrm{SGD},2}]$,
$$R^{\mathrm{SGDM}}_{\mathrm{diag}}(t) \le \tilde O\big(\sqrt d + q^{(t)}\big),\tag{42}$$
where the trend of $q^{(t)}$ is to decrease over time and $q^{(T_{\mathrm{SGD},2})} \le \tilde O\big(\frac{1}{d^{p/2-1}}\big) = o(d)$.
For Adam defined in eq. (3): for any $p>0$, picking the same hyperparameters as in Theorem 1, for $T_{\mathrm{Adam},1}, T_{\mathrm{Adam},2}$ mentioned in Theorem 1, we have with high probability, for any $t\in[T_{\mathrm{Adam},1}, T_{\mathrm{Adam},2}]$,
$$R^{\mathrm{Adam}}_{\mathrm{diag}}(t) \le \tilde O\big(\sqrt d + r^{(t)}\big),\tag{43}$$
where the trend of $r^{(t)}$ is to decrease over time and $r^{(T_{\mathrm{Adam},2})} \le \tilde O\big(\frac{1}{d^{\frac{p-1}{2}}}\big) = o(\sqrt d)$.

E.3 PROOF OF THEOREM 2

Lemma 4.3 of Kawaguchi (2016) gives the following forms of the Hessian. For any $k\in\{1,2,\dots,H+1\}$,
$$\nabla_{\mathrm{vec}(W_k)}\big(\nabla_{\mathrm{vec}(W_k)}L(W)\big) = \big((W_{H+1}\cdots W_{k+1})^T(W_{H+1}\cdots W_{k+1})\big)\otimes\big((W_{k-1}\cdots W_1)(W_{k-1}\cdots W_1)^T\big),$$
and for $k\in\{2,3,\dots,H+1\}$,
$$\nabla_{\mathrm{vec}(W_k)}\big(\nabla_{\mathrm{vec}(W_1)}L(W)\big) = \big(C^T(W_{H+1}\cdots W_{k+1})\big)\otimes\big((W_{k-1}\cdots W_1)^T\big) + \big[(W_{k-1}\cdots W_2)^T\otimes I\big]\big[I_{d_{k-1}}\otimes\big(r(W_{H+1}\cdots W_{k+1})\big)_{\cdot,1}\ \cdots\ I_{d_{k-1}}\otimes\big(r(W_{H+1}\cdots W_{k+1})\big)_{\cdot,d_k}\big],$$
where $r = (W_{H+1}\cdots W_1 - A)^T$ and $C = W_{H+1}W_H\cdots W_2$. For the 2-layer linear network, write the Hessian as
$$H := \begin{pmatrix} H_{22} & H_{21}^T \\ H_{21} & H_{11} \end{pmatrix};$$
then we have
$$H_{11} = (W_2^TW_2)\otimes I_d \in \mathbb R^{d^2\times d^2},\qquad H_{22} = W_1W_1^T \in \mathbb R^{d\times d},\qquad H_{21} = W_2^T\otimes W_1^T + I_d\otimes(W_2W_1-A)^T \in \mathbb R^{d^2\times d}.$$
Intuitively, before training the elements of $W_1$ and $W_2$ are very close to zero, and $W_2W_1 - A \approx -A$. Since the elements of $A$ are $\Theta(1)$, the magnitudes of the elements of $H_{21}$ are much bigger than those of $H_{11}$ and $H_{22}$. After training, for both SGD+M and Adam, $W_2W_1 - A \approx 0$; then $H_{21} \approx W_2^T\otimes W_1^T$ and the magnitudes of its elements are no longer much larger than those of $H_{11}$ and $H_{22}$. From the formula for $H_{11}$, we know that all of its diagonal entries are nonzero, while among its $d^4 - d^2$ off-diagonal entries only $d^3 - d^2$ are nonzero, which helps us bound $R^{\mathrm{OPT}}_{\mathrm{diag}}(t)$.
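The block formulas above can be checked numerically: for the two-layer loss $L = \frac12\|W_2W_1 - A\|^2$, a finite-difference Hessian of the gradient matches $H_{11}$, $H_{22}$ and $H_{21}$, with numpy's row-major `ravel` playing the role of $\mathrm{vec}$. A small-$d$ sketch (random weights, not tied to the paper's initialization):

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4
W1 = rng.normal(size=(d, d))
w2 = rng.normal(size=(1, d))   # W_2 as a row vector
A = rng.normal(size=(1, d))

def grad(W1, w2):
    """Gradients of L = 0.5 * ||W2 W1 - A||^2, flattened in (W2, W1) order."""
    E = w2 @ W1 - A            # residual, shape (1, d)
    g1 = w2.T @ E              # dL/dW1, shape (d, d)
    g2 = E @ W1.T              # dL/dW2, shape (1, d)
    return np.concatenate([g2.ravel(), g1.ravel()])

# Numerical Hessian via central differences on the gradient.
n = d + d * d
theta = np.concatenate([w2.ravel(), W1.ravel()])
unpack = lambda th: (th[d:].reshape(d, d), th[:d].reshape(1, d))
H_num = np.zeros((n, n))
h = 1e-4
for k in range(n):
    e = np.zeros(n); e[k] = h
    H_num[:, k] = (grad(*unpack(theta + e)) - grad(*unpack(theta - e))) / (2 * h)

# Analytic blocks (row-major vec convention, matching numpy's ravel).
H22 = W1 @ W1.T
H11 = np.kron(w2.T @ w2, np.eye(d))
H21 = np.kron(w2.T, W1.T) + np.kron(np.eye(d), (w2 @ W1 - A).T)

assert np.allclose(H_num[:d, :d], H22, atol=1e-6)
assert np.allclose(H_num[d:, d:], H11, atol=1e-6)
assert np.allclose(H_num[d:, :d], H21, atol=1e-6)
print("Hessian blocks match")
```

Note that with row-major flattening, $\mathrm{vec}(BX) = (B\otimes I)\,\mathrm{vec}(X)$, which is why `np.kron(w2.T @ w2, np.eye(d))` reproduces $H_{11}$ exactly as stated.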

E.3.1 PROOF OF EQ. (41)

Let us first analyze the weights and the Hessian before training ($t=0$). For ease of notation, we omit the superscript $(t)$. For the $i$-th row with $1\le i\le d$, i.e., the $i$-th row of the submatrix $[H_{22}\ \ H_{21}^T]$, we have
$$\sum_{j\ne i}H^2[i,j] = \sum_{j\ne i}H_{22}^2[i,j] + \sum_{j=1}^{d^2}H_{21}^2[j,i] \ge \sum_{j=(i-1)d+1}^{id}H_{21}^2[j,i] = \sum_{j=1}^{d}\big(w_{2i}W_1[i,j] + (W_2W_1-A)_j\big)^2 = \Theta(d).$$
On the other hand, for the diagonal elements we have w.h.p.
$$\big|H[i,i]\big| = \big|H_{22}[i,i]\big| = \|W_1[i,:]\|_2^2 = \sum_{j=1}^{d}W_1^2[i,j] \le \tilde O\Big(\frac{1}{d^{4\alpha-1}}\Big).$$
Then we have that for $1\le i\le d$,
$$\frac{\sqrt{\sum_{j\ne i}H^2[i,j]}}{\big|H[i,i]\big|} \ge \frac{\Omega(\sqrt d)}{\tilde O\big(\frac{1}{d^{4\alpha-1}}\big)} = \Omega\big(d^{4\alpha-\frac12}\big).$$
For the $(id+k)$-th row with $1\le i\le d$, $1\le k\le d$, i.e., the $((i-1)d+k)$-th row of the submatrix $[H_{21}\ \ H_{11}]$, we have
$$\sum_{j\ne id+k}H^2[id+k,j] = \sum_{j\ne(i-1)d+k}H_{11}^2\big[(i-1)d+k,\,j\big] + \sum_{j=1}^{d}H_{21}^2\big[(i-1)d+k,\,j\big] \ge H_{21}^2\big[(i-1)d+k,\,i\big] = \big(w_{2i}W_1[i,k] + (W_2W_1-A)_k\big)^2 = \Theta(1).$$
On the other hand, for the diagonal elements we have w.h.p.
$$\big|H[id+k,\,id+k]\big| = \big|H_{11}\big[(i-1)d+k,\,(i-1)d+k\big]\big| = w_{2i}^2 \le \tilde O\Big(\frac{1}{d^{2\alpha}}\Big).$$
Then we have that for $1\le i\le d$, $1\le k\le d$,
$$\frac{\sqrt{\sum_{j\ne id+k}H^2[id+k,j]}}{\big|H[id+k,id+k]\big|} \ge \frac{\Omega(1)}{\tilde O\big(\frac{1}{d^{2\alpha}}\big)} = \Omega\big(d^{2\alpha}\big).$$
Taking the average, we obtain that before training, i.e., when $t=0$,
$$R^{\mathrm{OPT}}_{\mathrm{diag}}(0) \ge \frac{d\,\Omega\big(d^{4\alpha-\frac12}\big) + d^2\,\Omega\big(d^{2\alpha}\big)}{d^2+d} = \Omega\big(d^{4\alpha-\frac32}\big).$$

E.3.2 PROOF OF EQ. (42)

The proof is based on the lemma below.

Lemma 28. Suppose the weight matrices have the following structure: $W_1 = uv^T + R_1$, $W_2 = cu^T + R_2^T$, where $\forall 1\le i,j\le d$: $\frac{|R_1[i,j]|}{|u_iv_j|} \le \delta$, $\frac{|R_{2i}|}{|cu_i|} \le \delta$, $\delta\in(0,1)$. Then for $1\le i\le d$,
$$\frac{\sqrt{\sum_{j\ne i}H^2[i,j]}}{\big|H[i,i]\big|} \le \frac{1+\delta}{1-\delta}\Big(1 + \frac{|c|}{\|v\|_2}\Big)\sqrt{\frac{\sum_{j=1}^{d}u_j^2}{u_i^2}} + \frac{\|E\|_2}{(1-\delta)^2\,u_i^2\|v\|_2^2},$$
and for $1\le i\le d$, $1\le k\le d$,
$$\frac{\sqrt{\sum_{j\ne id+k}H^2[id+k,j]}}{\big|H[id+k,id+k]\big|} \le \frac{1+\delta}{1-\delta}\Big(1 + \frac{|v_k|}{|c|}\Big)\sqrt{\frac{\sum_{j=1}^{d}u_j^2}{u_i^2}} + \frac{|E_k|}{(1-\delta)^2\,c^2u_i^2}.$$
Now we are ready to prove eq. (42).
By the analyses in Section C.1, we know that for $t\in[T_{\mathrm{SGD},1}, T_{\mathrm{SGD},2}]$, the weights obtained by GD with momentum satisfy
$$W_1^{(t)} = u^{(T_1)}v^{(t)T} + R_1^{(t)}, \qquad W_2^{(t)} = c^{(t)}u^{(T_1)T} + R_2^{(t)T},$$
where $T_{\mathrm{SGD},1} = T_1$ and
$$\forall 1\le i,j\le d: \quad \bigg|\frac{R_1^{(t)}[i,j]}{u_i^{(T_1)}v_j^{(t)}}\bigg| \le \tilde O(\epsilon_0), \qquad \bigg|\frac{R_{2i}^{(t)}}{c^{(t)}u_i^{(T_1)}}\bigg| \le \tilde O(\epsilon_0).$$
Here $\epsilon_0$ is defined in Definition 2. Since $u^{(T_1)}$ does not depend on the time $t$ in the period $(T_{\mathrm{SGD},1}, T_{\mathrm{SGD},2}]$, we write $u^{(T_1)}$ as $u$ for ease of notation. Hence by Lemma 28, when $t\in[T_{\mathrm{SGD},1}, T_{\mathrm{SGD},2}]$, we have for $1\le i\le d$,
$$\frac{\sqrt{\sum_{j\neq i}H^{(t)}[i,j]^2}}{|H^{(t)}[i,i]|} \le \frac{1+\tilde O(\epsilon_0)}{1-\tilde O(\epsilon_0)}\Big(1+\frac{c^{(t)}}{\|v^{(t)}\|_2}\Big)\frac{\sqrt{\sum_{j=1}^d u_j^2}}{|u_i|} + \frac{\|E^{(t)}\|_2}{\big(1-\tilde O(\epsilon_0)\big)^2u_i^2\|v^{(t)}\|_2^2} = O\bigg(\Big(1+\frac{c^{(t)}}{\|v^{(t)}\|_2}\Big)\frac{\sqrt{\sum_{j=1}^d u_j^2}}{|u_i|}\bigg) + O\bigg(\frac{\|E^{(t)}\|_2}{u_i^2\|v^{(t)}\|_2^2}\bigg), \tag{44}$$
and for $1\le i\le d$, $1\le k\le d$,
$$\frac{\sqrt{\sum_{j\neq id+k}H^{(t)}[id+k,j]^2}}{|H^{(t)}[id+k,id+k]|} \le \frac{1+\tilde O(\epsilon_0)}{1-\tilde O(\epsilon_0)}\Big(1+\frac{v_k^{(t)}}{c^{(t)}}\Big)\frac{\sqrt{\sum_{j=1}^d u_j^2}}{|u_i|} + \frac{|E_k^{(t)}|}{\big(1-\tilde O(\epsilon_0)\big)^2\big(c^{(t)}\big)^2u_i^2} = O\bigg(\Big(1+\frac{v_k^{(t)}}{c^{(t)}}\Big)\frac{\sqrt{\sum_{j=1}^d u_j^2}}{|u_i|}\bigg) + O\bigg(\frac{|E_k^{(t)}|}{\big(c^{(t)}\big)^2u_i^2}\bigg). \tag{45}$$
By Lemma 3, we have $u = X + Y$, where $X_i$, $i\in[d]$, are i.i.d. Gaussian random variables and w.h.p.
$$\forall i\in[d]: \quad \frac{|Y_i|}{|X_i|} \le \tilde O\Big(\frac{1}{d^{\frac14\alpha-\frac12}}\Big) =: \delta_{xy}, \tag{46}$$
which yields
$$\forall i\in[d]: \quad \frac{\sqrt{\sum_{j=1}^d u_j^2}}{|u_i|} \le \frac{1+\delta_{xy}}{1-\delta_{xy}}\cdot\frac{\sqrt{\sum_{j=1}^d X_j^2}}{|X_i|}, \qquad \frac{1}{u_i^2} \le \frac{1}{(1-\delta_{xy})^2}\cdot\frac{1}{X_i^2}. \tag{47}$$
By the proof in Section C.8, we know that for $t\in[T_{\mathrm{SGD},1}, T_{\mathrm{SGD},2}]$, all of $v_i^{(t)}$, $i\in[d]$, and $c^{(t)}$ are positive. The induction in Section C.9 further gives us that for $t\in[T_{\mathrm{SGD},1}, T_{\mathrm{SGD},2}]$, w.h.p.
$$\forall k\in[d]: \quad \frac{v_k^{(t)}}{c^{(t)}} = \Theta\Big(\frac{1}{\sqrt d}\Big), \quad\text{which yields}\quad \frac{c^{(t)}}{\|v^{(t)}\|_2} = \Theta(1).$$
Combining with eq. (47), we obtain
$$\Big(1+\frac{c^{(t)}}{\|v^{(t)}\|_2}\Big)\frac{\sqrt{\sum_{j=1}^d u_j^2}}{|u_i|} \le O\bigg(\frac{\sqrt{\sum_{j=1}^d X_j^2}}{|X_i|}\bigg), \qquad \Big(1+\frac{v_k^{(t)}}{c^{(t)}}\Big)\frac{\sqrt{\sum_{j=1}^d u_j^2}}{|u_i|} \le O\bigg(\frac{\sqrt{\sum_{j=1}^d X_j^2}}{|X_i|}\bigg). \tag{48}$$
By the proof in Section C.8, we also know that for $t\in[T_{\mathrm{SGD},1}, T_{\mathrm{SGD},2}]$, all of $v_i^{(t)}$, $i\in[d]$, and $c^{(t)}$ are positive and monotonically increasing. On the other hand, the proofs in Sections C.2 and C.9 tell us that w.h.p. $\|E^{(t)}\|_2$ (resp. $|E_k^{(t)}|$ for each $k\in[d]$) decreases from $\Theta(\sqrt d)$ (resp. $\Theta(1)$) when $t = T_{\mathrm{SGD},1}$ to $O(\sqrt{\epsilon_0 d})$ (resp. $O(\sqrt{\epsilon_0})$) when $t = T_{\mathrm{SGD},2}$. Therefore, the trend of $\frac{\|E^{(t)}\|_2}{u_i^2\|v^{(t)}\|_2^2}$ and $\frac{|E_k^{(t)}|}{(c^{(t)})^2u_i^2}$ is to decrease over time, and when $t = T_{\mathrm{SGD},2}$ we have w.h.p.
$$\forall k\in[d]: \quad |E_k^{(t)}| = O(\sqrt{\epsilon_0}), \qquad \|E^{(t)}\|_2 = O\big(\sqrt{\epsilon_0 d}\big). \tag{49}$$
Moreover, when $t = T_{\mathrm{SGD},2}$, the inequality in eq. (26) becomes an equality, i.e.
$$c^2\|u\|_2^2 = \Theta(\sqrt d) \quad\text{and}\quad \forall j\in[d]: \ \|u\|_2^2v_j^2 = \Theta\Big(\frac{1}{\sqrt d}\Big).$$
Using $u = X + Y$ and eq. (46), we have
$$c^2\|X\|_2^2 = \Theta(\sqrt d), \qquad \forall j\in[d]: \ \|X\|_2^2v_j^2 = \Theta\Big(\frac{1}{\sqrt d}\Big) \ \Rightarrow\ \|X\|_2^2\|v\|_2^2 = \Theta(\sqrt d),$$
which together with the second inequality in eq. (47) yields
$$\frac{1}{u_i^2\|v\|_2^2} \le \frac{1}{(1-\delta_{xy})^2}\cdot\frac{1}{X_i^2\|v\|_2^2} = \Theta\bigg(\frac{\sum_{j=1}^d X_j^2}{X_i^2\sqrt d}\bigg), \qquad \frac{1}{c^2u_i^2} \le \frac{1}{(1-\delta_{xy})^2}\cdot\frac{1}{c^2X_i^2} = \Theta\bigg(\frac{\sum_{j=1}^d X_j^2}{X_i^2\sqrt d}\bigg).$$
Combining with eq. (49), we get
$$\frac{\|E^{(t)}\|_2}{u_i^2\|v^{(t)}\|_2^2} \le O\bigg(\frac{\sum_{j=1}^d X_j^2}{X_i^2}\cdot\sqrt{\epsilon_0}\bigg), \qquad \frac{|E_k^{(t)}|}{\big(c^{(t)}\big)^2u_i^2} \le O\bigg(\frac{\sum_{j=1}^d X_j^2}{X_i^2}\cdot\sqrt{\frac{\epsilon_0}{d}}\bigg). \tag{50}$$
Substituting eq. (48) and (50) into eq.
(44) and (45) gives us, for all $1\le i\le d$:
$$\frac{\sqrt{\sum_{j\neq i}H^{(t)}[i,j]^2}}{|H^{(t)}[i,i]|} \le O\bigg(\frac{\sqrt{\sum_{j=1}^d X_j^2}}{|X_i|}\bigg) + q_{1i}^{(t)},$$
where the trend of $q_{1i}^{(t)}$ is to decrease over time and $q_{1i}^{(T_{\mathrm{SGD},2})} \le O\big(\frac{\sum_{j=1}^d X_j^2}{X_i^2}\cdot\sqrt{\epsilon_0}\big)$. We also have, for all $1\le i\le d$, $1\le k\le d$:
$$\frac{\sqrt{\sum_{j\neq id+k}H^{(t)}[id+k,j]^2}}{|H^{(t)}[id+k,id+k]|} \le O\bigg(\frac{\sqrt{\sum_{j=1}^d X_j^2}}{|X_i|}\bigg) + q_{2i}^{(t)},$$
where the trend of $q_{2i}^{(t)}$ is to decrease over time and $q_{2i}^{(T_{\mathrm{SGD},2})} \le O\big(\frac{\sum_{j=1}^d X_j^2}{X_i^2}\cdot\sqrt{\epsilon_0/d}\big)$. Averaging over all $d^2+d$ rows, $R^{\mathrm{SGDM}}_{\mathrm{diag}}(t) \le O\big(\frac1d\sum_{i=1}^d\frac{\sqrt{\sum_{j=1}^d X_j^2}}{|X_i|}\big) + q^{(t)}$, where the trend of $q^{(t)}$ is to decrease over time and
$$q^{(T_{\mathrm{SGD},2})} \le \frac{1}{d^2+d}\sum_{i=1}^d O\bigg(\frac{\sum_{j=1}^d X_j^2}{X_i^2}\cdot\sqrt{\epsilon_0}\bigg) + \frac{d}{d^2+d}\sum_{i=1}^d O\bigg(\frac{\sum_{j=1}^d X_j^2}{X_i^2}\cdot\sqrt{\frac{\epsilon_0}{d}}\bigg) \le O\bigg(\frac{1}{d}\sum_{i=1}^d\frac{\sum_{j=1}^d X_j^2}{X_i^2}\cdot\sqrt{\frac{\epsilon_0}{d}}\bigg).$$
Denote by $\sigma^2$ the variance of the $X_i$, $i\in[d]$. By concentration of the chi-squared distribution, we know that with probability at least $1-\delta$ for $\delta>0$, $\sum_{j=1}^d X_j^2 = \Theta(\sigma^2 d)$. Therefore, combining with Lemma 35, with constant probability,
$$R^{\mathrm{SGDM}}_{\mathrm{diag}}(t) = \tilde O(\sqrt d) + q^{(t)},$$
where the trend of $q^{(t)}$ is to decrease over time and $q^{(T_{\mathrm{SGD},2})} \le \tilde O\big(d\sqrt{\epsilon_0/d}\big)$. For any $p>0$, by picking the same hyperparameters as in Theorem 1, we have $\epsilon_0 d \le \tilde O\big(\frac{1}{d^p}\big)$ and hence $q^{(T_{\mathrm{SGD},2})} \le \tilde O\big(d^{p/2-1}\big)$.

Hence by Lemma 28, when $t\in[T_{\mathrm{Adam},1}, T_{\mathrm{Adam},2}]$, we have for $1\le i\le d$,
$$\frac{\sqrt{\sum_{j\neq i}H^{(t)}[i,j]^2}}{|H^{(t)}[i,i]|} \le \frac{1+\delta}{1-\delta}\Big(1+\frac{c^{(t)}}{\|v^{(t)}\|_2}\Big)\frac{\sqrt{\sum_{j=1}^d u_j^2}}{|u_i|} + \frac{\|E^{(t)}\|_2}{(1-\delta)^2u_i^2\|v^{(t)}\|_2^2} = O\bigg(\Big(1+\frac{c^{(t)}}{\|v^{(t)}\|_2}\Big)\sqrt d\bigg) + O\bigg(\frac{\|E^{(t)}\|_2}{\|v^{(t)}\|_2^2}\bigg), \tag{51}$$
and for $1\le i\le d$, $1\le k\le d$,
$$\frac{\sqrt{\sum_{j\neq id+k}H^{(t)}[id+k,j]^2}}{|H^{(t)}[id+k,id+k]|} \le \frac{1+\delta}{1-\delta}\Big(1+\frac{v_k^{(t)}}{c^{(t)}}\Big)\frac{\sqrt{\sum_{j=1}^d u_j^2}}{|u_i|} + \frac{|E_k^{(t)}|}{(1-\delta)^2\big(c^{(t)}\big)^2u_i^2} = O\bigg(\Big(1+\frac{v_k^{(t)}}{c^{(t)}}\Big)\sqrt d\bigg) + O\bigg(\frac{|E_k^{(t)}|}{\big(c^{(t)}\big)^2}\bigg). \tag{52}$$
Recall the following facts about Adam. (A) By Lemma 14, we know that for $t\in[T_{\mathrm{Adam},1}, T_1]$ (where $T_1$ is defined in Definition 4), w.h.p. $\forall k\in[d]$: $v\ldots$ Combining (A) and (B), we get that the trend of $\frac{\|E^{(t)}\|_2}{\|v^{(t)}\|_2^2}$ and $\frac{|E_k^{(t)}|}{(c^{(t)})^2}$ is to decrease over time, and when $t = T_{\mathrm{Adam},2}$ we have w.h.p.
$$\frac{\|E^{(t)}\|_2}{\|v^{(t)}\|_2^2} \le \tilde O\big(d^2\sqrt\eta\big), \qquad \frac{|E_k^{(t)}|}{\big(c^{(t)}\big)^2} \le \tilde O\big(d^2\sqrt{\eta d}\big). \tag{53}$$
Substituting (A) and eq. (53) into eq.
(51) and (52) gives us w.h.p., for all $1\le i\le d$:
$$\frac{\sqrt{\sum_{j\neq i}H^{(t)}[i,j]^2}}{|H^{(t)}[i,i]|} \le O(\sqrt d) + r_{1i}^{(t)},$$
where the trend of $r_{1i}^{(t)}$ is to decrease over time. We also have, for all $1\le i\le d$, $1\le k\le d$:
$$\frac{\sqrt{\sum_{j\neq id+k}H^{(t)}[id+k,j]^2}}{|H^{(t)}[id+k,id+k]|} \le \tilde O(\sqrt d) + r_{2i}^{(t)},$$
where the trend of $r_{2i}^{(t)}$ is to decrease over time. Averaging over all rows, $R^{\mathrm{Adam}}_{\mathrm{diag}}(t) = \tilde O(\sqrt d) + r^{(t)}$, where the trend of $r^{(t)}$ is to decrease over time and
$$r^{(T_{\mathrm{Adam},2})} \le \frac{1}{d^2+d}\sum_{i=1}^d\tilde O\big(d^2\sqrt\eta\big) + \frac{d}{d^2+d}\sum_{i=1}^d\tilde O\big(d^2\sqrt{\eta d}\big) \le \tilde O\big(d^2\sqrt{\eta d}\big).$$
For any $p>0$, by picking the same hyperparameters as in Theorem 1, we have $\eta d^4 \le \tilde O\big(\frac{1}{d^p}\big)$ and hence $r^{(T_{\mathrm{Adam},2})} \le \tilde O\big(d^{(1-p)/2}\big)$.

Proof of Lemma 28. By the assumed weight structure, we get that $\forall i\in[d]$:
$$(1-\delta)^2(cu_i)^2 \le w_{2i}^2 \le (1+\delta)^2(cu_i)^2, \qquad (1-\delta)^2u_i^2\|v\|_2^2 \le \|W_1[i,:]\|_2^2 \le (1+\delta)^2u_i^2\|v\|_2^2.$$
For the $i$-th row, where $1\le i\le d$ (i.e. the $i$-th row of the submatrix $[H_{22}\;\;H_{21}^T]$), by the triangle inequality we have
$$\sqrt{\sum_{j\neq i}H^2[i,j]} \le \sqrt{\sum_{j\neq i}H_{22}^2[i,j]} + \sqrt{\sum_{j=1}^{d^2}H_{21}^2[j,i]} \le \sqrt{\sum_{j\neq i}\langle W_1[i,:], W_1[j,:]\rangle^2} + \sqrt{\sum_{j=1}^d w_{2j}^2\sum_{k=1}^d W_1^2[i,k]} + \|E\|_2 \le \|W_1[i,:]\|_2\bigg(\sqrt{\sum_{j\neq i}\|W_1[j,:]\|_2^2} + \sqrt{\sum_{j=1}^d w_{2j}^2}\bigg) + \|E\|_2.$$
Then we have that for $1\le i\le d$,
$$\frac{\sqrt{\sum_{j\neq i}H^2[i,j]}}{|H[i,i]|} \le \frac{\sqrt{\sum_{j\neq i}\|W_1[j,:]\|_2^2}}{\|W_1[i,:]\|_2} + \frac{\sqrt{\sum_{j=1}^d w_{2j}^2}}{\|W_1[i,:]\|_2} + \frac{\|E\|_2}{\|W_1[i,:]\|_2^2} \le \frac{1+\delta}{1-\delta}\cdot\frac{\sqrt{\sum_{j\neq i}u_j^2}}{|u_i|} + \frac{1+\delta}{1-\delta}\cdot\frac{|c|\sqrt{\sum_{j=1}^d u_j^2}}{|u_i|\,\|v\|_2} + \frac{\|E\|_2}{(1-\delta)^2u_i^2\|v\|_2^2} \le \frac{1+\delta}{1-\delta}\Big(1+\frac{|c|}{\|v\|_2}\Big)\frac{\sqrt{\sum_{j=1}^d u_j^2}}{|u_i|} + \frac{\|E\|_2}{(1-\delta)^2u_i^2\|v\|_2^2}.$$
For the $(id+k)$-th row, where $1\le i\le d$, $1\le k\le d$ (i.e. the $((i-1)d+k)$-th row of the submatrix $[H_{21}\;\;H_{11}]$), by the triangle inequality again we have
$$\sqrt{\sum_{j\neq id+k}H^2[id+k,j]} \le |w_{2i}|\bigg(\sqrt{\sum_{j\neq i}w_{2j}^2} + \sqrt{\sum_{j=1}^d W_1^2[j,k]}\bigg) + |E_k|.$$
Then we have that for $1\le i\le d$, $1\le k\le d$,
$$\frac{\sqrt{\sum_{j\neq id+k}H^2[id+k,j]}}{|H[id+k,id+k]|} \le \frac{\sqrt{\sum_{j\neq i}w_{2j}^2}}{|w_{2i}|} + \frac{\sqrt{\sum_{j=1}^d W_1^2[j,k]}}{|w_{2i}|} + \frac{|E_k|}{w_{2i}^2} \le \frac{1+\delta}{1-\delta}\cdot\frac{\sqrt{\sum_{j\neq i}u_j^2}}{|u_i|} + \frac{1+\delta}{1-\delta}\cdot\frac{|v_k|\sqrt{\sum_{j=1}^d u_j^2}}{|c|\,|u_i|} + \frac{|E_k|}{(1-\delta)^2c^2u_i^2} \le \frac{1+\delta}{1-\delta}\Big(1+\frac{|v_k|}{|c|}\Big)\frac{\sqrt{\sum_{j=1}^d u_j^2}}{|u_i|} + \frac{|E_k|}{(1-\delta)^2c^2u_i^2}.$$
For ease of notation, let us now drop the superscript OPT and $(t)$ and write $R^{\mathrm{OPT}}_{\mathrm{med},1}(t)$ as $R_{\mathrm{med},1}$ and $R^{\mathrm{OPT}}_{\mathrm{med},2}(t)$ as $R_{\mathrm{med},2}$. For a 2-layer linear network, $H = 1$. Consider the Hessian w.r.t. $W_1$: we have $M_1M_1^T = W_2^TW_2$ and $N_1^TN_1$ is an identity matrix. Under Assumption 1, we know that $W_2$ is a row vector, which can be denoted $W_2 = [w_{21}, w_{22}, \ldots, w_{2d_1}]$. Then we have
$$(M_1M_1^T)_{a,a} = w_{2a}^2, \quad (N_1^TN_1)_{b,b} = 1 \ \Rightarrow\ R_{\mathrm{med},1} = \frac{\max_i w_{2i}^2}{\mathrm{median}(w_{2i}^2)}.$$
Similarly, for the Hessian w.r.t. $W_2$, we have that $M_2M_2^T$ is an identity matrix and $N_2^TN_2 = W_1W_1^T$. Therefore,
$$(M_2M_2^T)_{a,a} = 1, \quad (N_2^TN_2)_{b,b} = \|W_1[b,:]\|_2^2 \ \Rightarrow\ R_{\mathrm{med},2} = \frac{\max_i\|W_1[i,:]\|_2^2}{\mathrm{median}(\|W_1[i,:]\|_2^2)}.$$
Hence we have related the uniformity of the diagonal Hessian to that of the weight matrices. In the detailed analysis, for both GD and Adam, we can prove that $W_1$ converges to an approximately rank-1 matrix. The following lemma allows us to use this rank-1 structure to compute $R_{\mathrm{med},1}$ and $R_{\mathrm{med},2}$.

Lemma 30. Suppose $W_1\in\mathbb R^{d\times d}$ and $W_2\in\mathbb R^{1\times d}$ have the following structure: $W_1 = uv^T + R_1$, $W_2 = cu^T + R_2$, where $u\in\mathbb R^d$, $v\in\mathbb R^d$, $R_1\in\mathbb R^{d\times d}$, $R_2\in\mathbb R^{1\times d}$, and $\forall 1\le i,j\le d$: $\frac{|R_1[i,j]|}{|u_iv_j|}\le\delta$, $\frac{|R_{2i}|}{|cu_i|}\le\delta$, $\delta\in(0,1)$. Then we have
$$R_{\mathrm{med},1},\ R_{\mathrm{med},2} \in \bigg[\frac{(1-\delta)^2}{(1+\delta)^2}\cdot\frac{\max_i u_i^2}{\mathrm{median}(u_i^2)},\ \frac{(1+\delta)^2}{(1-\delta)^2}\cdot\frac{\max_i u_i^2}{\mathrm{median}(u_i^2)}\bigg].$$

Proof. Let us first consider $R_{\mathrm{med},1}$. We have
$$\forall i\in[d]: \ (1-\delta)^2(cu_i)^2 \le w_{2i}^2 \le (1+\delta)^2(cu_i)^2 \ \Rightarrow\ (1-\delta)^2\max_i(cu_i)^2 \le \max_i w_{2i}^2 \le (1+\delta)^2\max_i(cu_i)^2, \quad (1-\delta)^2\,\mathrm{median}\big((cu_i)^2\big) \le \mathrm{median}\big(w_{2i}^2\big) \le (1+\delta)^2\,\mathrm{median}\big((cu_i)^2\big),$$
which yields
$$\frac{(1-\delta)^2}{(1+\delta)^2}\cdot\frac{\max_i u_i^2}{\mathrm{median}(u_i^2)} \le R_{\mathrm{med},1} = \frac{\max_i w_{2i}^2}{\mathrm{median}(w_{2i}^2)} \le \frac{(1+\delta)^2}{(1-\delta)^2}\cdot\frac{\max_i u_i^2}{\mathrm{median}(u_i^2)}.$$
Similarly, for $R_{\mathrm{med},2}$, we have that
$$\forall i,j\in[d]: \ (1-\delta)^2(u_iv_j)^2 \le W_1^2[i,j] \le (1+\delta)^2(u_iv_j)^2 \ \Rightarrow\ (1-\delta)^2u_i^2\|v\|_2^2 \le \|W_1[i,:]\|_2^2 \le (1+\delta)^2u_i^2\|v\|_2^2,$$
which gives
$$(1-\delta)^2\max_i u_i^2\|v\|_2^2 \le \max_i\|W_1[i,:]\|_2^2 \le (1+\delta)^2\max_i u_i^2\|v\|_2^2, \qquad (1-\delta)^2\,\mathrm{median}\big(u_i^2\|v\|_2^2\big) \le \mathrm{median}\big(\|W_1[i,:]\|_2^2\big) \le (1+\delta)^2\,\mathrm{median}\big(u_i^2\|v\|_2^2\big),$$
which yields
$$\frac{(1-\delta)^2}{(1+\delta)^2}\cdot\frac{\max_i u_i^2\|v\|_2^2}{\mathrm{median}(u_i^2\|v\|_2^2)} \le R_{\mathrm{med},2} = \frac{\max_i\|W_1[i,:]\|_2^2}{\mathrm{median}(\|W_1[i,:]\|_2^2)} \le \frac{(1+\delta)^2}{(1-\delta)^2}\cdot\frac{\max_i u_i^2\|v\|_2^2}{\mathrm{median}(u_i^2\|v\|_2^2)}.$$
That means
$$\frac{(1-\delta)^2}{(1+\delta)^2}\cdot\frac{\max_i u_i^2}{\mathrm{median}(u_i^2)} \le R_{\mathrm{med},2} \le \frac{(1+\delta)^2}{(1-\delta)^2}\cdot\frac{\max_i u_i^2}{\mathrm{median}(u_i^2)}.$$

Suppose $\forall\tau\le t$: $b^{(\tau)} \le B$. Then for any $\epsilon>0$, the truncated version $\tilde a^{(t)} = (1-\beta)\sum_{\tau=0}^{H}\beta^\tau b^{(t-\tau)}$ with $H \ge \frac{1}{1-\beta}\log\frac{B}{\epsilon} = \Omega\big(\frac{1}{1-\beta}\log\frac{B}{\epsilon}\big)$ satisfies $|a^{(t)} - \tilde a^{(t)}| \le \epsilon$.

Proof. We have
$$|a^{(t)} - \tilde a^{(t)}| \le (1-\beta)\sum_{\tau=H+1}^{t}\beta^\tau b^{(t-\tau)} \le (1-\beta)\sum_{\tau=H+1}^{t}\beta^\tau B \le B\beta^{H+1}.$$
To make this less than $\epsilon$, it suffices to choose $H \ge \log(\epsilon/B)/\log\beta$. Since $\beta\in(0,1)$, we know that $\log\beta \le \beta-1 < 0$; we also have $\log(\epsilon/B) < 0$. Then it suffices to choose
$$H \ge \frac{\log(\epsilon/B)}{\beta-1} \ge \frac{\log(\epsilon/B)}{\log\beta}, \quad\text{i.e.}\quad H \ge \frac{1}{1-\beta}\log\frac{B}{\epsilon} = \Omega\Big(\frac{1}{1-\beta}\log\frac{B}{\epsilon}\Big).$$

Lemma 34. Suppose $X_1, X_2, \ldots, X_d$ are i.i.d. Gaussian with mean 0 and variance $\sigma^2$. Then for $0<\delta<\frac1e$, we have with probability at least $1-\delta$,
$$\max_{1\le i\le d}X_i^2 \ge \sigma^2\Big(C_1\log d - C_2\log\log\frac1\delta\Big)$$
for some $C_1, C_2 > 0$.

Proof. It suffices to assume that $\sigma^2 = 1$ and prove that w.p. at least $1-\delta$, $\max_{1\le i\le d}X_i^2 \ge C_1\log d - C_2\log\log\frac1\delta$. First, by the lower bound on the Gaussian tail, there exist $\alpha, \beta > 0$ such that $P(|X_i|>x) = 2P(X_i>x) \ge \alpha e^{-\beta x^2}$ for $x\ge0$. Then, by independence, we have

Consider $X_i$ for some fixed $i$. Since $X_i \sim \mathcal N(0,1)$, we have $P(|X_i|\le t) \le \frac{2t}{\sqrt{2\pi}}$. Therefore, with constant probability,
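The two statistics and the robustness bound of Lemma 30 are easy to check numerically. The sketch below is our own illustration (the function name is ours), assuming $W_2$ is a row vector as in Assumption 1:

```python
import numpy as np

def r_med(W1, W2):
    """R_med,1 and R_med,2 for a 2-layer linear net with W2 a row vector.

    R_med,1 = max_i w_{2i}^2 / median_i w_{2i}^2            (diag Hessian w.r.t. W1)
    R_med,2 = max_i ||W1[i,:]||^2 / median_i ||W1[i,:]||^2  (diag Hessian w.r.t. W2)
    """
    w2_sq = W2.ravel() ** 2
    row_sq = np.sum(W1 ** 2, axis=1)
    return w2_sq.max() / np.median(w2_sq), row_sq.max() / np.median(row_sq)
```

For weights with the rank-1-dominated structure of Lemma 30, both ratios must land within a factor $((1+\delta)/(1-\delta))^{\pm2}$ of $\max_i u_i^2/\mathrm{median}(u_i^2)$; in particular, a rank-1 $W_1$ whose factor $u$ has uniform magnitudes forces both ratios close to 1.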

$$\sum_{i=1}^d\frac{1}{|X_i|} = \sum_{k=0}^{\log_2\frac dC-1}\sum_{i\in I_k}\frac{1}{|X_i|} + \sum_{i:|X_i|>1}\frac{1}{|X_i|} \le \sum_{k=0}^{\log_2\frac dC-1}O\Big(d + \sqrt{d\cdot2^{k+1}\log d} + 2^{k+1}\log d\Big) + d = O\Big(d\log_2\frac dC\Big) + O\Big(\sqrt{d\log d}\cdot(\sqrt2)^{\log_2\frac dC}\Big) + O\Big(\frac dC\log d\Big) + d = \tilde O(d).$$



Recall that the main theoretical bound in the original Adagrad paper (Duchi et al., 2011) is in terms of the diagonal scaling.
In Assumption 2 we assume Gaussian initialization. Due to the rotational invariance of the Gaussian distribution, we can assume without loss of generality that all coordinates of A are positive.
https://huggingface.co/docs/transformers/v4.16.2/en/training
https://pytorch.org/tutorials/beginner/translation_transformer.html
https://pytorch.org/tutorials/beginner/transformer_tutorial.html
We borrowed the implementation at https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/ and replaced the "layers" array [2,2,2] with [1,1,1].
https://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf



Figure 1: (left) Training losses of SGD+M starting from x_SGD and x_Adam. (right) The 10th largest value over the median of the diagonal of the loss Hessian (which can be viewed as a variant of R^OPT_med(t) defined in eq. (1)) for Adam and SGD+M. Since the full Hessian is too big, we selected several layers and randomly sampled 200 coordinates per layer for the computation.

Figure 2: Training losses of Adam and SGD+M on the sentence classification task described in Section 4.1.

(left) shows the training losses for one of them.

Figure 3: Training losses of Adam and SGD+M for the translation task on (left) Multi30k (see Section 4.1) and (right) data with randomly generated targets (see Section 4.2).

Figure 5: Training losses of Adagrad and SGD on wikitext-2 (left) and random data (right)

Figure 6: Training losses of RMSprop, AMSGrad and SGD+M on the translation task described in Section 4.1.

Figure 7: Training losses of SGD, Adam after SGD and Adam for the translation task

Figure 8: Examples of layers with approximately low rank structure (right) and without low rank structure (left)

Figure 9: Ru and Rv for Adam and SGD with momentum in some layers

Figure 10: How the true gradient ({|g_{i_k}|}_{k=1}^d) and "adaptive gradient" ({|g_{adapt,i_k}|}_{k=1}^d) align with the diagonal of the Hessian ({|H_{i_k,i_k}|}_{k=1}^d). Here coordinates are sorted such that |H_{i_1,i_1}| ≤ |H_{i_2,i_2}| ≤ ... ≤ |H_{i_d,i_d}| (suppose H ∈ R^{d×d}). Experiments were conducted on the model described in Section A.1. This figure shows the results for the 12th layer.

Figure 11: How the true gradient ({|g_{i_k}|}_{k=1}^d) and "adaptive gradient" ({|g_{adapt,i_k}|}_{k=1}^d) align with the diagonal of the Hessian ({|H_{i_k,i_k}|}_{k=1}^d). Here coordinates are sorted such that |H_{i_1,i_1}| ≤ |H_{i_2,i_2}| ≤ ... ≤ |H_{i_d,i_d}| (suppose H ∈ R^{d×d}). Experiments were conducted on the sentence classification task described in Section 4.1. This figure shows the results for the 12th layer.

Figure 12: How the true gradient ({|g_{i_k}|}_{k=1}^d) and "adaptive gradient" ({|g_{adapt,i_k}|}_{k=1}^d) align with the diagonal of the Hessian ({|H_{i_k,i_k}|}_{k=1}^d). Here coordinates are sorted such that |H_{i_1,i_1}| ≤ |H_{i_2,i_2}| ≤ ... ≤ |H_{i_d,i_d}| (suppose H ∈ R^{d×d}). Experiments were conducted on the translation task described in Section 4.1. This figure shows the results for the 5th layer.

Figure 13: (a) Training losses of Adam and SGD+M for the translation task, both with weight decay. (b) Training losses of Adam and SGD+M for a ResNet trained on CIFAR-10.

$\ldots)^2$ in Lemma 31 are upper bounded by $\frac{1}{d^{\alpha/2}}$. In the theorem we consider the training period before $T_{\mathrm{SGD},2}$, so the time $T$ in Lemma 31 is set to $T_{\mathrm{SGD},2}$. In the following sections, we will prove that $T_{\mathrm{SGD},2} \le O\big(\frac{d^\alpha\log(\sqrt d/\epsilon)}{\eta}\big)$. Then by Lemma 31, we have with probability at least $1-\frac1d$, for all $t\le T_1$ and all $i,j\in[d]$,

different from those defined in Section C.2, but we abuse the notation and still use r

Under Assumption 1, 2 and 3, suppose σ ≤ η 3/2 d α/2+1 . By picking η ≤ O

t) 2i = Õ η 2 d 13/4 . Moreover, we can get that ∀i ∈ [d] : is defined in eq. (18). Lemma 11. Under the conditions of Lemma 10 and pick η ≤ O

Aj are of order Õ( 0 ).

Aj are of order Õ( 0 ). All of them can be verified by the proof in Section C.2 and the definition of

first try to bound the length of min{T 2 , T 3 }. More formally, we prove that under the conditions of Lemma 10 and pick η ≤ O of Lemma 10, we know that ∀j ∈ [d] :

where (i) and (ii) use Lemma 11.(D) The fact that ∀j ∈ [d] : 0 ) was already proved in Lemma 12 in eq.(21). To analyze R (t+1) vi

the end of the first phase, we have ∀j ∈ [d] : W for ∀i, j ∈ [d]. D.10 PROOF OF LEMMA 19

D.13 PROOF OF LEMMA 26

By the definition of $T_f$, there exists $j_0\in[d]$ such that $E^{(\tau)}_{j_0} < -\sqrt{\eta d}$ for $T_1 \le t \le T$. We prove by induction that during this period, $\forall i\in[d]$: $\operatorname{sign}(w\ldots$, $j\in[d]$:

(η). Specially, when τ = t, we get the lower bound w

have that ∀i ∈ [d] :

$\eta$), we immediately get that $T - T_1 = \Theta\big(\frac{1}{\sqrt d\,\eta}\big)$. In Section D.2 we have shown that $T_1 = \Theta\big(\frac{1}{\sqrt d\,\eta}\big)$; then we get $T = \Theta\big(\frac{1}{\sqrt d\,\eta}\big)$.

D.14 PROOF OF LEMMA 24

our lemma will immediately follow. If S c = φ, we have T = min {T g , T f } = T g and that ∀j ∈ S c : E ( T ) j < 0. By the definition of T g , we know that ∃i 0 ∈ [d] :

23, we have for all i ∈ [d], w

Figure 14: Mean and median of r^SGDM_diag,i(t) and r^Adam_diag,i(t) for the full Hessian across the four layers (#6, 12, 17, 22).

+ q (t) ,

$\tilde O\big(d^{p/2-1}\big) = o(d)$.

E.3.3 PROOF OF EQ. (43)

By the analyses in Section D.1, we know that for $t\in[T_{\mathrm{Adam},1}, T_{\mathrm{Adam},2}]$, the weights obtained by Adam satisfy the structure of Lemma 28, where $\forall i\in[d]$: $u_i = \operatorname{sign}(w\ldots$

$v_k^{(t)} = c^{(t)} = \eta(t - t_{\mathrm{inc}})$. In particular, when $t = T_{\mathrm{Adam},1}$, $\forall k\in[d]$: $v_k^{(t)} = c^{(t)} = \frac{1}{d^{\alpha/2}}$. Lemmas 24 and 26 tell us that for $t\in[T_1, T_{\mathrm{Adam},2}]$, w.h.p.
$$\forall i,j\in[d]: \quad |W_1[i,j]| = \Theta\Big(\frac{1}{\sqrt d}\Big), \qquad |w_{2i}| = \Theta\Big(\frac{1}{\sqrt d}\Big),$$
which gives us $\forall k\in[d]$: $v\ldots$ This means that when $t\in[T_{\mathrm{Adam},1}, T_{\mathrm{Adam},2}]$, $\forall k\in[d]$: $v\ldots$ Lemmas 14 and 25 tell us that w.h.p. $\|E^{(t)}\|_2$ (resp. $|E_k^{(t)}|$ for each $k\in[d]$) decreases from $\Theta(d)$ (resp. $\Theta(1)$) when $t = T_{\mathrm{Adam},1}$ to $\tilde O\big(d^2\sqrt\eta\big)$ (resp. $\tilde O\big(d\sqrt{\eta d}\big)$) when $t = T_{\mathrm{Adam},2}$.

$\tilde O\big(d^2\sqrt\eta\big)$.

By the assumed weight structure, we get that $\forall i\in[d]$: $(1-\delta)^2(cu_i)^2 \le w_{2i}^2 \le (1+\delta)^2(cu_i)^2$.

For the $(id+k)$-th row, where $1\le i\le d$, $1\le k\le d$ (i.e. the $((i-1)d+k)$-th row of the submatrix $[H_{21}\;\;H_{11}]$), by the triangle inequality again, we have

Suppose $a, b, c, e_a, e_b, e_c \in \mathbb R$ with $b>0$, $c>0$ satisfy $b+e_b+e_c>0$, $|e_a|\le\delta|a|$, $|e_b|\le\delta b$, $|e_c|\le\delta^2c^2$ with $0<\delta\ll1$. Then we have
$$\frac{a+e_a}{\sqrt{b+e_b+e_c}+c} = \frac{a}{\sqrt b+c}(1+R), \qquad\text{where } |R| = O(\delta).$$

Proof. We have
$$\frac{a+e_a}{\sqrt{b+e_b+e_c}+c} = \frac{a}{\sqrt b+c} + \Big(\frac{a}{\sqrt{b+e_b+e_c}+c} - \frac{a}{\sqrt b+c}\Big) + \frac{e_a}{\sqrt{b+e_b+e_c}+c}.$$
Here the denominator of (i) uses $b+e_b \ge b(1-\delta) > 0$ and $\sqrt{x+y} \ge \sqrt x - \sqrt{|y|}$ when $x\ge0$, $x+y\ge0$. Now let us bound $|q_4|$. If $e_c > 0$, we have $e_c = |e_c|$ and $|q_4| \le \frac{\sqrt{e_c}}{\sqrt{e_c}} = 1$, since $b+e_b \ge b(1-\delta) > 0$. If $e_c \le 0$, note that $b+e_b+e_c>0$, so $|e_c| < b+e_b \le b(1+\delta)$, which yields $|q_4| \le \frac{\sqrt{|e_c|}}{\sqrt b} = O(1)$. Combining the above bounds gives us $|q_1| \le |q_3| + \delta|q_4| = O(\delta)$. On the other hand, $|q_2|$ can be bounded by
$$|q_2| \le \frac{\delta(\sqrt b+c)}{\sqrt{b+e_b}-\sqrt{|e_c|}+c} \le \frac{\delta(\sqrt b+c)}{\sqrt{b(1-\delta)}+c(1-\delta)} = O(\delta).$$
Then $|R| \le |q_1| + |q_2| = O(\delta)$.
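This stability estimate, which mirrors the robustness of an Adam-style ratio $a/(\sqrt b + c)$ to small perturbations, can be checked numerically. The sketch below is our own illustration (the function name and the constant 5 in the usage are empirical choices, not part of the lemma):

```python
import numpy as np

def relative_error(a, b, c, e_a, e_b, e_c):
    """R such that (a + e_a) / (sqrt(b + e_b + e_c) + c) = a / (sqrt(b) + c) * (1 + R)."""
    perturbed = (a + e_a) / (np.sqrt(b + e_b + e_c) + c)
    clean = a / (np.sqrt(b) + c)
    return perturbed / clean - 1.0
```

For perturbations satisfying the lemma's constraints with small $\delta$, the observed $|R|$ stays within a small constant multiple of $\delta$.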

$$P\Big(\max_{1\le i\le d}|X_i| \le x\Big) = \big(1 - P(|X_i|>x)\big)^d \le \big(1-\alpha e^{-\beta x^2}\big)^d \le \exp\big(-d\alpha e^{-\beta x^2}\big),$$
where the last inequality uses $1-x \le e^{-x}$ for $x\in[0,1]$. Setting $\exp(-d\alpha e^{-\beta x^2}) = \delta$ and solving for $x^2$, we get that the claimed bound holds w.p. at least $1-\delta$.

Lemma 35. Suppose $X_1, X_2, \ldots, X_d$ are i.i.d. Gaussian with mean 0 and variance $\sigma^2$. Then we have, with constant probability,

It suffices to assume that σ 2 = 1 and prove that with constant probability,

Then we know that, with probability at least $1-\Theta(d^{-1})$, $|X_i| \ge \frac Cd$ for some $C>0$. Then by the union bound, with constant probability, $\forall i\in[d]$: $|X_i| \ge \frac Cd$. Now we split the interval $[\frac Cd, 1]$ into the subintervals $I_k = \{i : |X_i| \in [2^{-k-1}, 2^{-k}]\}$ for $k = 0, 1, \ldots, \log_2\frac dC - 1$. Let $p_k = P\big(|X_i|\in[2^{-k-1}, 2^{-k}]\big)$; we know that $|I_k| \sim \mathrm{Binomial}(d, p_k)$ and $p_k \le C_1\cdot2^{-k-1}$. Then, by concentration of binomial random variables, we have w.p. at least $1-d^{-p}$ for $p>0$,
$$|I_k| = O\Big(dp_k + \sqrt{dp_k\log d} + \log d\Big) = O\Big(d\cdot2^{-k-1} + \sqrt{d\cdot2^{-k-1}\log d} + \log d\Big).$$
Then we have
$$\sum_{i\in I_k}\frac{1}{|X_i|} \le |I_k|\,2^{k+1} = O\Big(d + \sqrt{d\cdot2^{k+1}\log d} + 2^{k+1}\log d\Big), \qquad k = 0, 1, \ldots, \log_2\frac dC - 1.$$
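The two Gaussian facts used here, Lemma 34's $\max_i X_i^2 = \Omega(\log d)$ and the $\tilde O(d)$ control of $\sum_i 1/|X_i|$ from the computation above, can be checked with a short Monte Carlo sketch. This is our own illustration; the thresholds in the usage below ($\log d/2$ and $10\log d$) are loose empirical choices, not the lemmas' constants:

```python
import numpy as np

rng = np.random.default_rng(0)

def gaussian_extremes(d, trials=30):
    """Monte Carlo medians of max_i X_i^2 and (1/d) * sum_i 1/|X_i| for X_i ~ N(0,1)."""
    maxes, sums = [], []
    for _ in range(trials):
        x = rng.standard_normal(d)
        maxes.append(np.max(x ** 2))
        sums.append(np.sum(1.0 / np.abs(x)) / d)
    return np.median(maxes), np.median(sums)
```

For $d = 1000$, the median of $\max_i X_i^2$ sits near $2\log d \approx 13.8$, while $\frac1d\sum_i 1/|X_i|$ stays polylogarithmic; the heavy tail of $1/|X_i|$ (its expectation is infinite) is exactly why the statement only holds with constant probability.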

R Adam med (t) and R SGDM med (t) in some layers, on the sentence classification task (see Section 4.1).

Table 2a shows the averaged R Adam

R Adam med (t) and R SGDM med (t) in some layers for the translation task. (a) on Multi30k and (b) on data with randomly generated targets.

Table 2b shows the averaged R Adam med (t), R SGDM med (t) and

Since the weights and Hessians in different layers may have different magnitudes, we compute the R^OPT_med(t) found by SGD+M (resp. Adam) w.r.t. W_k at time t, where k = 1, 2.

Table 3 (resp. Table 4) shows the R^OPT_med(t) for Adagrad and SGD under uniform (resp. normal) initialization with different gains.





R OPT med (t) of Adagrad and SGD for random data

R^RMSprop_med(t), where t, t' are picked such that the t-th Adam iterate and the t'-th SGD+M iterate have the same training loss. The details of the tasks are described in Section 4.1. Table 7 shows the results of R^Adam_med(t) and R^SGDM_med(t') in some layers.





R Adam

R Adam

S Adam med (t) and S SGDM med (t) in some layers for the translation task.

R Adam med (t) and R SGDM med (t) in some layers for the translation task.

Lemma 7 tells us that $C_2$ can be written as $C_2 := \frac12(C_3 + C_4)$, where $C_{3i}$, $i\in[d]$, are i.i.d. Gaussian random variables, and that w.h.p. $\forall i\in[d]$: $\frac{|C_{4i}|}{|C_{3i}|} \le \tilde O(\ldots$

∈ [d]. Then we know that eq.(35) still holds, which gives us ∀j ∈ [d] : E

Now we analyze the lower bound of W

= β 2 1 . Consider certain coordinate j. For T 1 ≤ t < min T , T f,j , we have ∀i ∈ [d] :

completes the proof. At time the "flip time" t = T f,j , by definition, E

Lemma 31. Let $A = \frac1mYX^T$, $\Lambda_{xx} := \frac1mXX^T$, $g_k = \nabla_{W_k}L(W^{(t)})$, $k = 1,2$. Denote by $\tilde A^{(t)}$, $\tilde\Lambda^{(t)}$, $\tilde g_k^{(t)}$, $k = 1,2$, the corresponding mini-batch versions at time $t$. Let $M\ldots$

Proof. By Assumption 3 and Chebyshev's inequality, we have for fixed $i,j\in[d]$ and $t\le T$, $\ldots \frac{Td^2\sigma^2}{\lambda^2}$, which gives us with probability at least $1-\frac1d$, $\ldots$ Note that for all $t\le T$ and $\forall i\in[d]$, $\ldots$ Then we have with probability at least $1-\frac1d$, for all $t\le T$ and $\forall i\in[d]$, $\ldots$ Combining with the bounds on $\tilde\Lambda^{(t)} - \Lambda_{xx}$ and $\tilde A^{(t)} - A$, we get that with probability at least $1-\frac1d$, for all $t\le T$ and $\forall i,j\in[d]$, $\ldots$

Lemma 32. Consider two sequences $\{a^{(t)}\}_{t\ge0}$, $\{b^{(t)}\}_{t\ge0}$, which satisfy $a^{(t)} = (1-\beta)\sum_{\tau=0}^{t}\beta^\tau b^{(t-\tau)}$, $\beta\in(0,1)$.
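Lemma 32 (and its truncated version stated earlier) is easy to check numerically. The sketch below is our own illustration with hypothetical helper names; the full average is computed by the standard EMA recursion, and the truncated one keeps only the most recent $H+1$ terms:

```python
import math
import numpy as np

def ema(seq, beta):
    """a_t = (1 - beta) * sum_{tau=0}^{t} beta^tau * b_{t-tau}, via the recursion."""
    a = 0.0
    for b in seq:
        a = beta * a + (1 - beta) * b
    return a

def truncated_ema(seq, beta, H):
    """The same average, keeping only the most recent H + 1 terms."""
    tail = list(seq)[-(H + 1):]
    return (1 - beta) * sum(beta ** tau * b for tau, b in enumerate(reversed(tail)))
```

With $0 \le b^{(t)} \le B$ and $H = \lceil\frac{1}{1-\beta}\log\frac{B}{\epsilon}\rceil$, the gap between the two is at most $B\beta^{H+1} \le \epsilon$, matching the lemma.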

REPRODUCIBILITY STATEMENT

The training details (e.g. hyperparameters) of the experiments are specified in Section 4. The source code is provided in the supplemental material. For the theoretical results, Section 3.3 states the full set of assumptions, and Sections C and D in the appendix provide complete proofs.

Then eq. (31) immediately follows from eq. (30).

D.11 PROOF OF LEMMA 15

We divide Lemma 15 into the following three lemmas; combining them immediately gives us the whole proof. The first lemma below gives the structure of $W_2$ in the second phase, and that of $W_1$ under some conditions.

Lemma 23. Under Assumptions 1, 2 and 3, suppose $\ldots$, where $\ldots$, and moreover $\ldots$

The second lemma below also analyzes the structure of $W_1$, but removes the conditions in Lemma 23.

Lemma 24. Under Assumptions 1, 2 and 3, suppose $\sigma \le \eta$ $\ldots$, where $\ldots$

The third lemma proves the convergence of Adam at time $T$.

Lemma 25. Under Assumptions 1, 2 and 3, suppose $\ldots$

D.12 PROOF OF LEMMA 23

The proof is based on the following lemma, which gives a coarse analysis of the magnitude of the weights and their increments per step during the second phase.

Lemma 26. Under Assumptions 1, 2 and 3, suppose $\ldots$ Moreover, we have that $\forall i,j\in[d]$: $\ldots$

Equipped with Lemma 26, we are ready to prove Lemma 23. We will only prove the results for $w\ldots$ Consider certain $i,j\in[d]$ and the period from $T_1$ to $t$. Denote by $\mathcal T$ the set of time points at which Case 1 is satisfied. By Lemma 26, we know that $\eta(t - T\ldots$ By the first-phase analysis, we have that $\ldots$, where $\ldots$ Combining with the analysis of Case 1, we have that $\ldots$, where $\ldots$ Combining the above results together yields $\ldots$ by the inductive hypothesis. Therefore, we have that for any $j\in[d]$ and any $i_1, i_2\in[d]$, $\ldots$ By Lemma 23, we know that the $w_{2i}$ with different $i$ are also roughly equal, i.e. $\ldots$, where $k$ can be any index in $\{1, 2, \ldots, d\}$ and the last equality uses $\forall i\in[d]$: $\ldots$

F CONNECTION BETWEEN DIAGONAL OF LOSS HESSIAN AND WEIGHTS

The partial derivative of the cost function with respect to $W_i$, for each $i$, is given by: $\ldots$ In our experiments, we were interested in the diagonal elements of the Hessian. These are given by: $\ldots$ for each possible $i, a, b$. For ease of notation, define for each $i$ the quantities $M_i := W_{i+1}^T\cdots W_{H+1}^T$ and $N_i := W_1^T\cdots W_{i-1}^T$. Then we have the following lemma.

Lemma 29. The diagonal elements of the Hessian of the cost function are given by
$$\frac{\partial^2L}{\partial W_i[a,b]^2} = (M_iM_i^T)_{a,a}\,(N_i^TN_i)_{b,b}.$$

Proof. We have: $\ldots$ This implies that: $\ldots$, where the last step follows since $M_i$ and $N_i$ are not functions of $W_i$. Since $\ldots$, where $C_i$ and $D_i$ are not functions of $W_i$. Now, Equation 74 in the Matrix Cookbook shows us that for any matrices $A$ and $X$ we have: $\ldots$ Note that $W_i\in\mathbb R^{d_i\times d_{i-1}}$; we can then apply this to obtain that: $\ldots$ This completes the proof.
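Lemma 29 is straightforward to verify numerically. The sketch below is our own code (not the paper's); it computes the claimed diagonal $(M_iM_i^T)_{a,a}(N_i^TN_i)_{b,b}$ for a deep linear network with cost $\frac12\|W_{H+1}\cdots W_1 - A\|_F^2$:

```python
import numpy as np

def hessian_diagonal(Ws, A, i):
    """Diagonal Hessian of L = 0.5 * ||W_n ... W_1 - A||_F^2 w.r.t. the factor Ws[i].

    Ws is the list [W_1, ..., W_n]; the 0-based index i picks W_{i+1} in the
    paper's notation.  Returns D with D[a, b] = (M M^T)[a, a] * (N^T N)[b, b],
    where M = W_{i+2}^T ... W_n^T and N = W_1^T ... W_i^T.
    """
    M = np.eye(Ws[i].shape[0])
    for W in Ws[i + 1:]:
        M = M @ W.T
    N = np.eye(Ws[0].shape[1])
    for W in Ws[:i]:
        N = N @ W.T
    diag_M = np.sum(M ** 2, axis=1)   # (M M^T)[a, a] = ||M[a, :]||^2
    diag_N = np.sum(N ** 2, axis=0)   # (N^T N)[b, b] = ||N[:, b]||^2
    return np.outer(diag_M, diag_N)
```

Since the product $W_n\cdots W_1$ is linear in each individual entry of $W_i$, the loss is exactly quadratic in that entry, so a central finite difference reproduces these diagonal values to machine precision.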

