UNDERSTANDING THE GENERALIZATION OF ADAM IN LEARNING NEURAL NETWORKS WITH PROPER REGULARIZATION

Abstract

Adaptive gradient methods such as Adam have gained increasing popularity in deep learning optimization. However, it has been observed that in many deep learning applications such as image classification, Adam can converge to a different solution with a worse test error compared to (stochastic) gradient descent, even with fine-tuned regularization. In this paper, we provide a theoretical explanation for this phenomenon: we show that in the nonconvex setting of learning over-parameterized two-layer convolutional neural networks starting from the same random initialization, for a class of data distributions (inspired by image data), Adam and gradient descent (GD) can converge to different global solutions of the training objective with provably different generalization errors, even with weight decay regularization. In contrast, we show that if the training objective is convex and weight decay regularization is employed, any optimization algorithm, including Adam and GD, will converge to the same solution if training succeeds. This suggests that the generalization gap between Adam and SGD in the presence of weight decay regularization is closely tied to the nonconvex landscape of deep learning optimization, and cannot be covered by the recent neural tangent kernel (NTK) based analysis.

1. INTRODUCTION

Adaptive gradient methods (Duchi et al., 2011; Hinton et al., 2012; Kingma & Ba, 2015; Reddi et al., 2018) such as Adam are very popular optimizers for training deep neural networks. By adjusting the learning rate coordinate-wise based on historical gradient information, they are known to automatically choose appropriate learning rates and achieve fast convergence in training. Because of this advantage, Adam and its variants are widely used in deep learning. Despite their fast convergence, adaptive gradient methods have been observed to achieve worse generalization performance than gradient descent and stochastic gradient descent (SGD) (Wilson et al., 2017; Luo et al., 2019; Chen et al., 2020; Zhou et al., 2020) in many deep learning tasks such as image classification (we have conducted simple deep learning experiments comparing Adam and SGD on the CIFAR-10 dataset to justify this; the results are reported in Table 1). Even with explicit weight decay regularization, achieving a good test error with adaptive gradient methods seems challenging. Moreover, we have also visualized the first layer of AlexNet trained by Adam and SGD in Figure 1, where we can observe a clear difference between the two: the model learned by Adam is more "noisy" than that learned by SGD. Several recent works provided theoretical explanations of this generalization gap between Adam and GD by showing that Adam and GD have different implicit biases. Wilson et al. (2017); Agarwal et al. (2019) considered a linear regression setting and showed that Adam can fail when learning an over-parameterized linear model on certain specifically designed data, while SGD can learn the linear model and achieve zero test error. This example in linear regression offers valuable insights into the difference between SGD and Adam.
However, there is a gap between their theoretical results and practical observations, since they consider a convex optimization setting, and the difference between Adam and SGD is no longer observed once weight decay regularization is added. In fact, as we will show in this paper (Theorem 4.2), regularization can successfully correct the different implicit biases and push different algorithms to find the same solution, since the regularized training loss of a convex model becomes strongly convex and therefore has a unique global optimum. For this reason, we argue that examples in the convex setting cannot fully capture the differences between GD and Adam for training neural networks. More recently, Zhou et al. (2020) studied the expected escaping time of Adam and SGD from a local basin, and utilized this to explain the difference between SGD and Adam. However, their results do not take the network architecture into consideration, and do not provide an analysis of test errors either. In this paper, we aim at answering the following question:

Why is there a generalization gap between Adam and gradient descent in learning neural networks, even with weight decay regularization?

Specifically, we study Adam and GD for training neural networks with weight decay regularization on an image-like data model, and demonstrate the different behaviors of Adam and GD based on the notion of feature learning/noise memorization decomposition. Inspired by the experimental observation in Figure 1, where Adam tends to overfit the noise component of the data, we consider a model where the data are generated as a combination of feature and noise patches, and analyze the convergence and generalization of Adam and GD for training a two-layer convolutional neural network (CNN). The contributions of this paper are summarized as follows.

• We establish global convergence guarantees for Adam and GD with weight decay regularization. We show that, starting from the same random initialization, Adam and GD can both train a two-layer convolutional neural network to achieve zero training error after polynomially many iterations, despite the nonconvex optimization landscape.

• We further show that GD and Adam in fact converge to different global solutions with different generalization performance: on the considered image-like data model, GD achieves nearly zero test error, while the generalization performance of the model found by Adam is no better than a random guess. In particular, we show that this gap is due to the different training behaviors of Adam and GD: Adam is more likely to fit dense noise and outputs a model that is largely contributed by the noise patches, while GD prefers to fit the training data using their feature patches and finds a solution that is mainly composed of the true features.

• We also show that in convex settings with weight decay regularization, Adam and gradient descent converge to the same solution and therefore exhibit no test error difference. This suggests that the difference between Adam and GD cannot be fully explained by linear models or by neural networks trained in the "almost convex" neural tangent kernel (NTK) regime (Jacot et al., 2018; Allen-Zhu et al., 2019b; Du et al., 2019a; Zou et al., 2019). It also demonstrates that the inferior generalization performance of Adam is closely tied to the nonconvex landscape of deep learning optimization, and cannot be fixed by adding regularization.

2. RELATED WORK

In this section, we discuss the works that are closely related to our paper.

Generalization gap between Adam and SGD. The worse generalization of Adam compared with SGD has also been observed in recent works and has motivated new variants of neural network training algorithms. Keskar & Socher (2017) proposed to switch between Adam and SGD to achieve better generalization. Merity et al. (2018) proposed a variant of the averaged stochastic gradient method to achieve good generalization performance for LSTM language models. Luo et al. (2019) proposed to use dynamic bounds on learning rates to achieve a smooth transition from adaptive methods to SGD and improve generalization. Our theoretical results for GD and Adam can also provide theoretical insights into the effectiveness of these empirical studies.

Optimization and generalization in deep learning. Our work is also closely related to the recent line of work studying the optimization and generalization guarantees of neural networks in the neural tangent kernel (NTK) regime (Jacot et al., 2018) or lazy training regime (Chizat et al., 2019). In particular, recent works (Du et al., 2019b;a; Allen-Zhu et al., 2019b; Zou et al., 2019) showed that optimization only happens within a small neighborhood around the random initialization and proved the global convergence of GD and SGD when the neural network is sufficiently wide. Moreover, the generalization ability of GD/SGD has been further studied in the same setting (Allen-Zhu et al., 2019a; Arora et al., 2019a;b; Ji & Telgarsky, 2020; Chen et al., 2021), which suggests that wide neural networks trained by GD/SGD can learn a low-dimensional function class.
Moreover, Allen-Zhu & Li (2019); Bai & Lee (2019) initiated the study of learning neural networks beyond the NTK regime, as the NTK regime differs from practical DNN training. Our analysis in this paper is also beyond NTK, and gives a detailed comparison between GD and Adam.

Feature learning by neural networks. This paper is also closely related to several recent works that studied how neural networks learn features. Allen-Zhu & Li (2020a) showed that adversarial training purifies the learned features by removing certain "dense mixtures" in the hidden-layer weights of the network. Allen-Zhu & Li (2020b) studied how ensembling and knowledge distillation work in deep learning when the data have "multi-view" features. Frei et al. (2022b) studied feature learning for two-layer networks and demonstrated its superior performance over linear models. Shen et al. (2022) explored the benefit of data augmentation by showing its ability to achieve more effective feature learning. This paper studies a different aspect of feature learning by Adam and GD, and shows that GD can learn the features while Adam may fail even with proper regularization.

3. PROBLEM SETUP AND PRELIMINARIES

We consider learning a CNN with Adam and GD based on n independent training examples {(x_i, y_i)}_{i=1}^n generated from a data model D. In the following, we first introduce our data model D, and then explain our neural network model and the details of the training algorithms.

Data model. We consider a data model where the data inputs consist of feature and noise patches. Such a data model is motivated by image classification problems, where the label of an image usually depends only on part of the image, and the other parts, showing random objects or features belonging to other classes, can be considered as noise. When using a CNN to fit the data, the convolution operation is applied to each patch of the data input separately. We claim that our data model is more practical than those considered in Wilson et al. (2017); Reddi et al. (2018), which are handcrafted for showing the failure of Adam in terms of either convergence or generalization (detailed illustrations of the data models in these works are deferred to the appendix). For simplicity, we only consider the case where the data consist of one feature patch and one noise patch; our result can be easily extended to the setting with multiple feature/noise patches. The detailed definition of our data model is given in Definition 3.1.

Definition 3.1. Each data point (x, y) with x ∈ R^{2d} and y ∈ {-1, 1} is generated as follows: x = [x_1, x_2], where one of x_1 and x_2 is the feature patch, consisting of the feature vector y·v, which is assumed to be 1-sparse, and the other is the noise patch, consisting of a noise vector ξ. Without loss of generality, we assume v = [1, 0, ..., 0]. The noise vector ξ is generated according to the following process:
• Randomly select s coordinates from [d]\{1} uniformly, denoted by a vector s ∈ {0, 1}^d.
• Generate ξ' from the distribution N(0, σ_p^2 I), and then mask off the first coordinate and the other d - s - 1 unselected coordinates, i.e., ξ = ξ' ⊙ s.
• Add feature noise to ξ, i.e., ξ = ξ - α·y·v, where 0 < α < 1 is the strength of the feature noise.

In particular, throughout this paper we set d = Ω(n^4), s = Θ(d^{1/2}/n^2), σ_p^2 = Θ(1/(s·polylog(n))) and α = Θ(σ_p·polylog(n)). The most natural way to think of our data model is to treat x as the output of some intermediate layer of a CNN. In the literature, Papyan et al. (2017) pointed out that the outputs of an intermediate layer of a CNN are usually sparse, and Yang (2019) discussed the setting where the hidden nodes in such an intermediate layer are sampled independently. This motivates us to study sparse features and entry-wise independent noise in our model. In this paper, we focus on the case where the feature vector v is 1-sparse and the noise vector is s-sparse for simplicity. However, these sparsity assumptions can be generalized to settings where the feature and the noise are denser, as long as a sparsity gap between feature and noise exists. Note that in Definition 3.1, each data input consists of two patches: a feature patch y·v that is positively correlated with the label, and a noise patch ξ that contains the "feature noise" -α·y·v as well as random Gaussian noise. Importantly, the feature noise -α·y·v in the noise patch plays a pivotal role in both the training and test processes: it connects the noise overfitting in the training process with the inferior generalization ability in the test process. Moreover, we would like to clarify that the data distribution considered in our paper is an extreme case, where we assume there is only one feature vector and all data points carry a feature noise, since we believe this is the simplest model that captures the fundamental difference between Adam and SGD. With this data model, we aim to show why Adam and SGD perform differently.
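As a concrete illustration, the sampling procedure of Definition 3.1 can be sketched in NumPy as follows; the function name and the toy parameter values below are our own illustrative choices (the theory takes d = Ω(n^4), s = Θ(d^{1/2}/n^2), σ_p^2 = Θ(1/(s·polylog(n))) and α = Θ(σ_p·polylog(n))):

```python
import numpy as np

def generate_data(n, d, s, sigma_p, alpha, rng):
    """Sample n points (x, y) from the data model of Definition 3.1 (a sketch)."""
    v = np.zeros(d)
    v[0] = 1.0  # the 1-sparse feature vector v = [1, 0, ..., 0]
    X, Y = [], []
    for _ in range(n):
        y = rng.choice([-1, 1])
        # s-sparse Gaussian noise supported outside the first coordinate
        mask = np.zeros(d)
        mask[rng.choice(np.arange(1, d), size=s, replace=False)] = 1.0
        xi = rng.normal(0.0, sigma_p, size=d) * mask
        xi -= alpha * y * v  # the "feature noise" -alpha * y * v
        # the feature/noise patches are assigned to x_1, x_2 at random
        patches = [y * v, xi]
        if rng.random() < 0.5:
            patches.reverse()
        X.append(np.concatenate(patches))
        Y.append(y)
    return np.stack(X), np.array(Y)
```

Each sample then has one patch with a single nonzero entry (the feature) and one patch with s + 1 nonzero entries (the masked Gaussian noise plus the feature noise at the first coordinate).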
Our theoretical results and analysis techniques can also be extended to more practical settings where there are multiple feature vectors and multiple patches, and each data point can contain a single feature or multiple features, together with pure random noise or feature noise.

Two-layer CNN model. We consider a two-layer CNN model F with truncated polynomial activation function σ(z) = (max{0, z})^q, where q ≥ 3, and fix the weights of the second layer to be all 1's. Given the data (x, y), the j-th output of the CNN can be formulated as

F_j(W, x) = Σ_{r=1}^m [σ(⟨w_{j,r}, x_1⟩) + σ(⟨w_{j,r}, x_2⟩)] = Σ_{r=1}^m [σ(⟨w_{j,r}, y·v⟩) + σ(⟨w_{j,r}, ξ⟩)],   (3.1)

where m is the width of the network, w_{j,r} ∈ R^d denotes the r-th CNN filter, and W is the collection of model weights. For ease of analysis, we set the output layer to all 1's; our analyses and results still apply with random second-layer weights. Besides, the motivation for using the polynomial ReLU activation function is to guarantee that the loss function is (locally) smooth and to amplify the learned patterns. It can be replaced by a smoothed ReLU activation function (e.g., the activation function used in Allen-Zhu & Li (2020b)). If we assume the input data distribution is Gaussian, we can also handle the ReLU activation function (Li et al., 2020). A set of similar smoothed ReLU-type activation functions has also been widely considered in studying the generalization performance of two-layer neural networks from different aspects (Frei et al., 2022a; Cao et al., 2022; Shen et al., 2022; Chen et al., 2022). Moreover, we emphasize that x_1 and x_2 denote two data patches, which are randomly and independently assigned the feature vector or the noise vector for each data point; the learner has no knowledge of which one is the feature patch (or noise patch). In this paper we assume the width of the network is polylogarithmic in the training sample size, i.e., m = polylog(n).
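To make (3.1) concrete, a direct NumPy implementation of the forward pass could look as follows (a sketch; the shape convention W ∈ R^{2×m×d}, with index 0 for the logit j = -1 and index 1 for j = +1, is our own choice):

```python
import numpy as np

def activation(z, q=3):
    # truncated polynomial ReLU: sigma(z) = max(0, z)^q with q >= 3
    return np.maximum(z, 0.0) ** q

def cnn_forward(W, x, q=3):
    """Two-layer CNN of (3.1): the m filters W[j] (shape (m, d)) are applied
    to both patches of x = [x_1, x_2], and the second layer is all 1's."""
    d = x.shape[0] // 2
    x1, x2 = x[:d], x[d:]
    # F_j(W, x) = sum_r sigma(<w_{j,r}, x_1>) + sigma(<w_{j,r}, x_2>)
    return np.array([activation(W[j] @ x1, q).sum() +
                     activation(W[j] @ x2, q).sum() for j in range(2)])
```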
We assume j ∈ {-1, 1} in order to make the logit index consistent with the data label. Moreover, we assume each weight entry is initialized as an independent Gaussian random variable N(0, σ_0^2) with σ_0 = Θ(d^{-1/4}).

Training objective. Given the training data points {(x_i, y_i)}_{i=1,...,n}, we learn the model parameter W by optimizing the empirical loss function with weight decay regularization

L(W) = (1/n) Σ_{i=1}^n L_i(W) + (λ/2)·||W||_F^2,   (3.2)

where L_i(W) = -log( e^{F_{y_i}(W, x_i)} / Σ_{j∈{-1,1}} e^{F_j(W, x_i)} ) denotes the individual loss for the data point (x_i, y_i) and λ ≥ 0 is the regularization parameter. In particular, the regularization parameter can be arbitrary as long as it satisfies λ ∈ (0, λ_0) with λ_0 = Θ(1/(d^{(q-1)/4}·n·polylog(n))). We claim that λ_0 is the largest feasible regularization parameter for which the training process does not get stuck at the origin (recall that L(W) has zero gradient at W = 0).

Training algorithms. In this paper, we consider full-batch gradient descent and Adam. In particular, starting from W^{(0)} = {w^{(0)}_{j,r} : j ∈ {±1}, r ∈ [m]}, the gradient descent update rule is w^{(t+1)}_{j,r} = w^{(t)}_{j,r} - η·∇_{w_{j,r}} L(W^{(t)}), where η is the learning rate. Meanwhile, Adam stores historical gradient information in the momentum m^{(t)} and a second-moment vector v^{(t)}:

m^{(t+1)}_{j,r} = β_1·m^{(t)}_{j,r} + (1 - β_1)·∇_{w_{j,r}} L(W^{(t)}),   (3.3)
v^{(t+1)}_{j,r} = β_2·v^{(t)}_{j,r} + (1 - β_2)·[∇_{w_{j,r}} L(W^{(t)})]^2,   (3.4)

and entry-wise adjusts the learning rate:

w^{(t+1)}_{j,r} = w^{(t)}_{j,r} - η·m^{(t+1)}_{j,r} / sqrt(v^{(t+1)}_{j,r}),   (3.5)

where β_1, β_2 are the hyperparameters of Adam (a popular choice in practice is β_1 = 0.9 and β_2 = 0.99), which are treated as constants in our paper, and in (3.3)-(3.5) the square (·)^2, square root √·, and division ·/· all denote entry-wise operations.
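For reference, the objective (3.2) is a two-class cross-entropy plus weight decay; a minimal sketch, assuming a `forward` function that returns the logit pair [F_{-1}(W, x), F_{+1}(W, x)] (the helper names are our own):

```python
import numpy as np

def regularized_loss(W, data, forward, lam):
    """Empirical objective (3.2): average cross-entropy plus (lam/2)*||W||_F^2."""
    total = 0.0
    for x, y in data:
        F = forward(W, x)
        f_y = F[(y + 1) // 2]  # logit of the true label (index 0 <-> y = -1)
        # L_i(W) = -log( e^{F_y} / (e^{F_{-1}} + e^{F_{+1}}) )
        total += np.log(np.exp(F).sum()) - f_y
    return total / len(data) + 0.5 * lam * (W ** 2).sum()
```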
We would like to clarify that the original Adam paper (Kingma & Ba, 2015) normalizes the gradient m^{(t)}_{j,r} via [v^{(t)}_{j,r} + ε]^{1/2}, while the small bias term ε is ignored in our paper. In practice, tuning ε can help improve the generalization ability of Adam (Choi et al., 2019), as it allows a trade-off between the normalized gradient update and the plain gradient update; considering a tunable ε is beyond the focus of this paper. For ease of analysis, we do not consider the initialization bias correction in the original Adam paper either, and set m^{(0)}_{j,r} = ∇_{w_{j,r}} L(W^{(0)}) and v^{(0)}_{j,r} = [∇_{w_{j,r}} L(W^{(0)})]^2.
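The resulting update, i.e., (3.3)-(3.5) with no ε term and no bias correction, can be sketched as:

```python
import numpy as np

def adam_step(w, m, v, grad, eta, beta1=0.9, beta2=0.99):
    """One step of the Adam variant analyzed in this paper: entry-wise
    square/sqrt/division, no epsilon term and no bias correction."""
    m = beta1 * m + (1 - beta1) * grad       # (3.3)
    v = beta2 * v + (1 - beta2) * grad ** 2  # (3.4)
    w = w - eta * m / np.sqrt(v)             # (3.5)
    return w, m, v
```

With the initialization m^{(0)} = ∇L(W^{(0)}) and v^{(0)} = [∇L(W^{(0)})]^2, a constant gradient makes this update coincide exactly with sign gradient descent, which previews the approximation used in Section 5.1.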

4. MAIN RESULTS

In this section we state the main theorems of this paper. We first provide the learning guarantees of Adam and gradient descent for training a two-layer CNN model in the following theorem. Recall that in this setting the training objective is nonconvex.

Theorem 4.1 (Nonconvex setting). Consider a two-layer CNN defined in (3.1) with d = Ω(n^4) and the regularized training objective (3.2) with regularization parameter λ > 0. Suppose the network width is m = polylog(n) and the data distribution follows Definition 3.1. Then we have the following guarantees on the training and test errors of the models trained by Adam and gradient descent:

1. Suppose we run Adam for T = poly(n)/η iterations with η = 1/poly(n). Then with probability at least 1 - O(n^{-1}), we can find a NN model W*_Adam such that ||∇L(W*_Adam)||_1 ≤ 1/(Tη). Moreover, the model W*_Adam also satisfies:
• Training error is zero: (1/n) Σ_{i=1}^n 1[F_{y_i}(W*_Adam, x_i) ≤ F_{-y_i}(W*_Adam, x_i)] = 0.
• Test error is high: P_{(x,y)~D}[F_y(W*_Adam, x) ≤ F_{-y}(W*_Adam, x)] ≥ 1/2.

2. Suppose we run gradient descent for T = poly(n)/η iterations with learning rate η = 1/poly(n). Then with probability at least 1 - O(n^{-1}), we can find a NN model W*_GD such that ||∇L(W*_GD)||_F^2 ≤ 1/(Tη). Moreover, the model W*_GD also satisfies:
• Training error is zero: (1/n) Σ_{i=1}^n 1[F_{y_i}(W*_GD, x_i) ≤ F_{-y_i}(W*_GD, x_i)] = 0.
• Test error is nearly zero: P_{(x,y)~D}[F_y(W*_GD, x) ≤ F_{-y}(W*_GD, x)] = 1/poly(n).

From the optimization perspective, Theorem 4.1 shows that both Adam and GD are guaranteed to find a point with a very small gradient that also achieves zero classification error on the training data. Moreover, given the same iteration number T and learning rate η, Adam is guaranteed to find a point with gradient norm up to 1/(Tη) in the ℓ_1 metric, while gradient descent is only guaranteed to find a point with gradient norm up to 1/√(Tη) in the ℓ_2 metric.
More specifically, let ε be a sufficiently small quantity; ignoring other problem parameters, we can set η = O(ε) for Adam and η = O(ε^2) for GD, so that Adam and GD need T = O(ε^{-2}) and T = O(ε^{-4}) iterations, respectively, to find a first-order ε-stationary point. This suggests that Adam can enjoy a faster convergence rate than GD in the training process, which is consistent with practical findings. We would also like to point out that there is no contradiction between our result and the recent work (Reddi et al., 2019) showing that Adam can fail to converge, as the counterexample in Reddi et al. (2019) is for the online version of Adam, while we study full-batch Adam. In terms of test performance, the generalization abilities of the two solutions are largely different, even with weight decay regularization: the output of gradient descent generalizes well and achieves nearly zero test error, while the output of Adam gives nearly 1/2 test error. This gap is due to two major aspects of the training process: (1) at the early stage of training, where weight decay has a negligible effect, Adam and GD behave very differently: Adam prefers the denser patch and thus tends to fit the noise vectors ξ, while gradient descent prefers the patch of larger ℓ_2 norm and thus learns the feature patch; (2) at the late stage of training, where the weight decay regularization can no longer be ignored, both Adam and gradient descent are forced to converge to a local minimum of the regularized objective, which maintains the pattern learned in the early stage. Consequently, the model learned by Adam is biased towards the noise patch and fits the feature noise vector -αyv, which is opposite in direction to the true feature vector and therefore leads to a test error no better than a random guess. More details about the training behaviors of Adam and GD are given in Section 5. Experimental justification is provided in the appendix.
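Written out, the iteration counts follow directly from the two stationarity bounds (constants and other problem parameters suppressed):

```latex
\text{Adam: } \frac{1}{T\eta} \le \epsilon \ \text{ with } \eta = \Theta(\epsilon)
  \ \Longrightarrow\ T = O(\epsilon^{-2}),
\qquad
\text{GD: } \frac{1}{\sqrt{T\eta}} \le \epsilon \ \text{ with } \eta = \Theta(\epsilon^{2})
  \ \Longrightarrow\ T = O(\epsilon^{-4}).
```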
Theorem 4.1 shows that when optimizing a nonconvex training objective, Adam and gradient descent converge to different global solutions with different generalization errors, even with weight decay regularization. In comparison, the following theorem gives the learning guarantees of Adam and gradient descent when optimizing convex and smooth training objectives (e.g., a linear model F(w, x) = ⟨w, x⟩ with logistic loss).

Theorem 4.2 (Convex setting). For any convex and smooth training objective with positive regularization parameter λ, suppose we run Adam and gradient descent for T = poly(n)/η iterations. Then with probability at least 1 - n^{-1}, the obtained parameters W*_Adam and W*_GD satisfy ||∇L(W*_Adam)||_1 ≤ 1/(Tη) and ||∇L(W*_GD)||_2^2 ≤ 1/(Tη), respectively. Moreover, letting F(W, x) ∈ R be the output of the convex model with parameter W and input x, it holds that:
• Training errors are the same: (1/n) Σ_{i=1}^n 1[sgn(F(W*_Adam, x_i)) ≠ sgn(F(W*_GD, x_i))] = 0.
• Test errors are nearly the same: P_{(x,y)~D}[sgn(F(W*_Adam, x)) ≠ sgn(F(W*_GD, x))] ≤ 1/poly(n).

Theorem 4.2 shows that when optimizing a convex and smooth training objective (e.g., a linear model with logistic loss) with weight decay regularization, Adam and gradient descent converge to almost the same solution and enjoy very similar generalization performance. The proof relies on the strong convexity of the regularized training objective and the convergence (to a first-order stationary point) guarantees of Adam (Défossez et al., 2020) and GD. Combining this result with Theorem 4.1, it is clear that the inferior generalization performance of Adam is closely tied to the nonconvex landscape of deep learning, and cannot be remedied by standard weight decay regularization.
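Theorem 4.2 is easy to verify numerically. The sketch below (our own illustrative setup, not an experiment from the paper) runs GD and the ε-free Adam variant of Section 3 on ℓ_2-regularized logistic regression, whose objective is strongly convex and hence has a unique minimizer:

```python
import numpy as np

def grad_reg_logistic(w, X, y, lam):
    # gradient of (1/n) * sum_i log(1 + exp(-y_i <w, x_i>)) + (lam/2)*||w||^2
    coef = -y / (1.0 + np.exp(y * (X @ w)))  # = -y_i * sigmoid(-y_i <w, x_i>)
    return X.T @ coef / len(y) + lam * w

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))
y = np.sign(X @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.normal(size=50))
lam = 0.1

w_gd = np.zeros(3)                      # plain gradient descent
for _ in range(2000):
    w_gd -= 0.5 * grad_reg_logistic(w_gd, X, y, lam)

w_ad = np.zeros(3)                      # Adam, no epsilon / bias correction
m = grad_reg_logistic(w_ad, X, y, lam)
v = m ** 2
for _ in range(20000):
    g = grad_reg_logistic(w_ad, X, y, lam)
    m = 0.9 * m + 0.1 * g
    v = 0.99 * v + 0.01 * g ** 2
    w_ad -= 1e-3 * m / np.sqrt(v)
```

Both iterates end up near the unique regularized minimizer, in line with the theorem; the ε-free Adam update never vanishes, so it oscillates in a small O(η) neighborhood of the minimizer rather than converging exactly.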

5. PROOF OUTLINE OF THE MAIN RESULTS

In this section we provide the proof sketch of Theorem 4.1 and explain the different generalization abilities of the models found by gradient descent and Adam. Before moving to the proof of the main results, we first give the following lemma, which shows that for data generated from the distribution D in Definition 3.1, with high probability all noise vectors {ξ_i}_{i=1,...,n} have nearly disjoint supports.

Lemma 5.1. Let {(x_i, y_i)}_{i=1,...,n} be the training dataset generated according to Definition 3.1. Recall that x_i = [y_i·v, ξ_i] (or x_i = [ξ_i, y_i·v]), and let B_i = supp(ξ_i)\{1} be the support of ξ_i except the first coordinate. Then with probability at least 1 - n^{-2}, B_i ∩ B_j = ∅ for all i ≠ j.

This lemma implies that the optimization of each coordinate of the model parameter W, except the first one, is mostly determined by only one training data point. Technically, this lemma greatly simplifies the analysis for Adam, so that we can better illustrate its optimization behavior and explain the generalization gap between Adam and gradient descent.

Proof outline. For both Adam and gradient descent, we show that the training process can be decomposed into two stages. In the first stage, which we call the pattern learning stage, the weight decay regularization is less important and can be ignored, while the algorithms learn the pattern of the training data. In particular, we show that in the pattern learning stage, the optimization algorithms have different algorithmic biases: Adam tends to fit the noise patch while gradient descent mainly learns the feature patch. In the second stage, which we call the regularization stage, the effect of regularization cannot be neglected, and it forces the algorithm to converge to some local stationary point.
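Lemma 5.1 is a birthday-problem calculation: the expected number of overlapping support pairs is about C(n,2)·s^2/(d-1), which vanishes when d = Ω(n^4) and s = Θ(d^{1/2}/n^2). A quick Monte Carlo check (toy sizes of our own choosing):

```python
import numpy as np
from itertools import combinations

def supports_disjoint(n, d, s, rng):
    """Sample n noise supports of size s from {2, ..., d} (coordinate 1 is
    reserved for the feature) and test pairwise disjointness."""
    supports = [set(rng.choice(np.arange(1, d), size=s, replace=False))
                for _ in range(n)]
    return all(not (a & b) for a, b in combinations(supports, 2))

rng = np.random.default_rng(0)
n = 10
d = n ** 4                             # 10_000
s = max(int(np.sqrt(d)) // n ** 2, 1)  # = 1 at these toy sizes
disjoint_trials = sum(supports_disjoint(n, d, s, rng) for _ in range(100))
```

With these sizes the expected number of colliding pairs per trial is roughly 45/10^4, so almost all trials are fully disjoint.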
However, due to the nonconvex landscape of the training objective, the pattern learned in the first stage remains unchanged, even when running an infinite number of iterations.

5.1. PROOF SKETCH FOR ADAM

Recall that in each iteration of Adam, the model weight is updated using a moving average of gradients, normalized by a moving average of the squared gradients. As pointed out in Balles & Hennig (2018); Bernstein et al. (2018), Adam behaves similarly to sign gradient descent (signGD) when using a sufficiently small step size or when the moving average parameters β_1, β_2 are nearly zero, which is also justified in our Lemma C.2. Specifically, we show that for constant β_1 and β_2, the Adam update on coordinates with large gradient (e.g., |∇L(W^{(t)})[k]| > η) is well approximated by the signGD update (i.e., sign(∇L(W^{(t)})[k])). This motivates us to first understand the optimization behavior of signGD and then extend it to Adam via their similarity. In particular, signGD updates the model parameter according to the rule

w^{(t+1)}_{j,r} = w^{(t)}_{j,r} - η·sgn(∇_{w_{j,r}} L(W^{(t)})).

Recall that each data point has a feature patch and a noise patch. By Lemma 5.1 and the data distribution (see Definition 3.1), all noise vectors {ξ_i}_{i=1,...,n} are supported on disjoint coordinates except the first coordinate. For x_i, let B_i denote its support excluding the first coordinate. In the subsequent analysis, we always assume that the B_i's are disjoint, i.e., B_i ∩ B_j = ∅ if i ≠ j. Next we characterize two aspects of the training process: feature learning and noise memorization. Mathematically, we focus on two quantities: ⟨w^{(t)}_{j,r}, j·v⟩ and ⟨w^{(t)}_{y_i,r}, ξ_i⟩. In particular, given the training data point (x_i, y_i) with x_i = [y_i·v, ξ_i], a larger ⟨w^{(t)}_{y_i,r}, y_i·v⟩ implies better feature learning and a larger ⟨w^{(t)}_{y_i,r}, ξ_i⟩ represents better noise memorization.
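The asymmetry that drives the analysis can be seen in a toy computation: a signGD step moves every coordinate by exactly η, so the per-step gain of ⟨w, u⟩ along a direction u that the gradient opposes scales with the ℓ_1 norm of u: the 1-sparse feature gains η, while an s-sparse noise vector with entries of size σ_p gains η·s·σ_p (the toy numbers below are our own):

```python
import numpy as np

def signgd_step(w, grad, eta):
    # sign gradient descent: every nonzero-gradient coordinate moves by eta
    return w - eta * np.sign(grad)

d, s, sigma_p, eta = 100, 25, 0.2, 0.01
rng = np.random.default_rng(1)
v = np.zeros(d)
v[0] = 1.0                                     # 1-sparse feature direction
xi = np.zeros(d)
xi[1:1 + s] = sigma_p * rng.choice([-1, 1], s)  # s-sparse noise direction
w = np.zeros(d)
w_next = signgd_step(w, -(v + xi), eta)  # gradient pointing against both
gain_feature = (w_next - w) @ v          # grows by eta
gain_noise = (w_next - w) @ xi           # grows by eta * s * sigma_p
```

This is exactly the rate gap stated in Lemma 5.2 below.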
Then, for the feature vector v, whose only nonzero entry is the first coordinate, signGD gives

⟨w^{(t+1)}_{j,r}, jv⟩ = ⟨w^{(t)}_{j,r}, jv⟩ - η·⟨sgn(∇_{w_{j,r}} L(W^{(t)})), jv⟩
  = ⟨w^{(t)}_{j,r}, jv⟩ + jη·sgn( Σ_{i=1}^n y_i·ℓ^{(t)}_{j,i}·[σ'(⟨w^{(t)}_{j,r}, y_i·v⟩) - α·σ'(⟨w^{(t)}_{j,r}, ξ_i⟩)] - nλ·w^{(t)}_{j,r}[1] ),   (5.1)

where ℓ^{(t)}_{j,i} := 1[y_i = j] - logit_j(F, x_i) and logit_j(F, x_i) = e^{F_j(W^{(t)}, x_i)} / Σ_{k∈{-1,1}} e^{F_k(W^{(t)}, x_i)}. From (5.1) we can observe three terms in the signed gradient: the first term represents the gradient over the feature patch, the second term stems from the feature noise in the noise patch (see Definition 3.1), and the last term is the gradient of the weight decay regularization. On the other hand, the memorization of the noise vector ξ_i can be described as

⟨w^{(t+1)}_{y_i,r}, ξ_i⟩ - ⟨w^{(t)}_{y_i,r}, ξ_i⟩ = -η·⟨sgn(∇_{w_{y_i,r}} L(W^{(t)})), ξ_i⟩
  = η Σ_{k∈B_i∪{1}} sgn( ℓ^{(t)}_{y_i,i}·σ'(⟨w^{(t)}_{y_i,r}, ξ_i⟩)·ξ_i[k] - nλ·w^{(t)}_{y_i,r}[k] )·ξ_i[k].   (5.2)

Throughout the proof, we show that the training process of Adam can be decomposed into two stages: the pattern learning stage and the regularization stage. In the first stage, the algorithm quickly learns the pattern of the training data without being affected by the regularization term. In the second stage, the training data have already been correctly classified since the pattern has been well captured, and the regularization plays an important role in the training process, guiding the model to converge.

Stage I: Learning the pattern. At the beginning of training, the neural network output is smaller than some constant for all data, so all training data remain under-fitted and provide large gradients for model training. We refer to this phase as Stage I. In this stage, the effect of weight decay regularization can be ignored due to our choice of λ.
We will show that in this stage the inner product ⟨w^{(t)}_{y_i,r}, ξ_i⟩ grows much faster than ⟨w^{(t)}_{j,r}, j·v⟩, since feature learning only makes use of the first coordinate of the gradient, while noise memorization can take advantage of all the coordinates in B_i (see (5.2), and note that |B_i| = s ≫ 1).

Lemma 5.2 (General results in Stage I). Suppose the training data is generated according to Definition 3.1, and assume λ = o(σ_0^{q-2}·σ_p/n) and η = 1/poly(d). Then for any t ≤ T_0 with T_0 = O(1/(η·s·σ_p)) and any i ∈ [n],

⟨w^{(t+1)}_{j,r}, j·v⟩ ≤ ⟨w^{(t)}_{j,r}, j·v⟩ + Θ(η),    ⟨w^{(t+1)}_{y_i,r}, ξ_i⟩ = ⟨w^{(t)}_{y_i,r}, ξ_i⟩ + Θ(η·s·σ_p).

Since ⟨w^{(t)}_{y_i,r}, ξ_i⟩ enjoys a much faster growth rate than ⟨w^{(t)}_{j,r}, j·v⟩, after a certain number of iterations the learning of the noise patch dominates the learning of the feature patch (i.e., α·σ'(⟨w^{(t)}_{j,r}, ξ_i⟩) > σ'(⟨w^{(t)}_{j,r}, y_i·v⟩)). Thus, by (5.1), the model will tend to fit the feature noise in the noise patch (i.e., -α·y_i·v), leading to a flipped feature learning phenomenon.

Lemma 5.3 (Flipping the feature learning). Suppose the training data is generated according to Definition 3.1, α ≥ Θ((s·σ_p)^{1-q} ∨ σ_0^{q-1}) and σ_0 ≤ O((s·σ_p)^{-1}). Then for any t ∈ [T_r, T_0] with T_r = O(σ_0/(η·s·σ_p·α^{1/(q-1)})) ≤ T_0,

⟨w^{(t+1)}_{j,r}, j·v⟩ = ⟨w^{(t)}_{j,r}, j·v⟩ - Θ(η).

Moreover, it holds that (1) w^{(T_0)}_{j,r}[1] = -sgn(j)·Ω(1/(s·σ_p)); (2) w^{(T_0)}_{j,r}[k] = sgn(ξ_i[k])·Ω(1/(s·σ_p)) or w^{(T_0)}_{j,r}[k] = ±O(η) for k ∈ B_i with y_i = j; and (3) w^{(T_0)}_{j,r}[k] = ±O(η) otherwise.

From Lemma 5.3 it can be observed that at iteration T_0, the sign of the first coordinate of w^{(T_0)}_{j,r} is opposite to that of the true feature j·v. This implies that at the end of the first training stage, the model is biased towards the noise patch to fit the feature noise.

Stage II: Regularizing the model. In this stage, as the neural network output becomes larger, part of the training data starts to be well fitted and yields smaller gradients.
As a consequence, the feature learning and noise memorization processes slow down and the weight decay regularization term can no longer be ignored. Although weight decay regularization prevents the model weights from growing too large, it maintains the pattern learned in Stage I: it cannot push the model to "forget" the noise and learn the feature, and the algorithm stops at some local stationary point. We summarize these results in the following lemma.

Lemma 5.4 (Maintaining the pattern). If α = O(s·σ_p^2/n) and η = o(λ), then, letting r* = argmax_{r∈[m]} ⟨w^{(t)}_{y_i,r}, ξ_i⟩, for any t ≥ T_0, i ∈ [n], j ∈ {±1} and r ∈ [m] it holds that

⟨w^{(t)}_{y_i,r*}, ξ_i⟩ = Θ(1),    Σ_{k∈B_i} |w^{(t)}_{y_i,r*}[k]|·|ξ_i[k]| = Θ(1),    ⟨w^{(t)}_{j,r}, sgn(j)·v⟩ ∈ [-o(1), O(λ^{-1}η)].

Lemma 5.4 shows that in the second stage, ⟨w^{(t)}_{y_i,r*}, ξ_i⟩ remains large while ⟨w^{(t)}_{y_i,r}, y_i·v⟩ is still negative, or positive but extremely small. Next we show that within polynomially many steps, the algorithm is guaranteed to find a point with small gradient.

Lemma 5.5 (Convergence guarantee). If η = O(d^{-1/2}), then for any t it holds that

L(W^{(t+1)}) - L(W^{(t)}) ≤ -η·||∇L(W^{(t)})||_1 + Θ(η^2·d).

Lemma 5.5 shows that we can pick a sufficiently small η and T = poly(n)/η to ensure that the algorithm finds a point with gradient ℓ_1 norm up to O(1/(Tη)). Then, given the results in Lemma 5.4, the algorithm output W* can be precisely characterized, and we can show that ⟨w*_{y_i,r}, y_i·v⟩ < 0. This implies that the output model is biased to fit the feature noise -αyv rather than the true feature v, so on a fresh test example the model fails to recognize the true feature. Also note that the noise in the test data is nearly independent of the noise in the training data. Consequently, the model cannot identify the label of the test data and therefore performs no better than a random guess.
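For completeness, the stationarity bound follows from Lemma 5.5 by the standard telescoping argument (constants suppressed):

```latex
\sum_{t=0}^{T-1}\Big(L(W^{(t+1)}) - L(W^{(t)})\Big)
  \le -\eta \sum_{t=0}^{T-1}\big\|\nabla L(W^{(t)})\big\|_1 + \Theta(T\eta^2 d)
\ \Longrightarrow\
\min_{0 \le t < T}\big\|\nabla L(W^{(t)})\big\|_1
  \le \frac{L(W^{(0)}) - \min_{W} L(W)}{T\eta} + \Theta(\eta d),
```

so taking T = poly(n)/η and η small enough makes the right-hand side O(1/(Tη)).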

5.2. PROOF SKETCH FOR GRADIENT DESCENT

Similar to the proof for Adam, we decompose the entire training process into two stages.

Stage I: Learning the pattern. In this stage the gradient of the training loss is large and the effect of regularization can be ignored. Unlike Adam, which is sensitive to the sparsity of the feature or noise vectors, gradient descent focuses on their $\ell_2$ norms: the vector (either feature or noise) with larger $\ell_2$ norm is more likely to be discovered and learned by GD. Since the feature vector has a larger $\ell_2$ norm than the noise, we can show that gradient descent learns the feature vector very quickly while barely memorizing the noise.

Lemma 5.6. Let $\Lambda^{(t)}_j = \max_{r\in[m]} \langle \mathbf{w}^{(t)}_{j,r}, j\cdot\mathbf{v}\rangle$, $\Gamma^{(t)}_{j,i} = \max_{r\in[m]} \langle \mathbf{w}^{(t)}_{j,r}, \boldsymbol{\xi}_i\rangle$, and $\Gamma^{(t)}_j = \max_{i: y_i = j} \Gamma^{(t)}_{j,i}$. Let $T_j$ be the iteration number at which $\Lambda^{(t)}_j$ reaches $\Theta(1/m)$; then $T_j = \Theta(\sigma_0^{2-q})$ for all $j \in \{-1,1\}$. Moreover, let $T_0 = \max_j T_j$; then for all $t \le T_0$ it holds that $\Gamma^{(t)}_j = O(\sigma_0)$ for all $j \in \{-1,1\}$.

Stage II: Regularizing the model. Similar to Lemma 5.4, we show that in the second stage, where the impact of weight decay regularization can no longer be ignored, the pattern of the training data learned in the first stage remains unchanged.

Lemma 5.7. If $\eta \le O(\sigma_0)$, it holds that $\Lambda^{(t)}_j = \Theta(1)$ and $\Gamma^{(t)}_j = O(\sigma_0)$ for all $t \ge \min_j T_j$.

The following lemma further shows that within polynomially many steps, gradient descent is guaranteed to find a point with small gradient.

Lemma 5.8. If the learning rate satisfies $\eta = o(1)$, then for any $t \ge 0$ it holds that
$$L(\mathbf{W}^{(t+1)}) - L(\mathbf{W}^{(t)}) \le -\frac{\eta}{2}\,\|\nabla L(\mathbf{W}^{(t)})\|_F^2.$$

Lemma 5.8 shows that we can pick a sufficiently small $\eta$ and $T = \mathrm{poly}(n)/\eta$ to ensure that gradient descent finds a point whose gradient has $\ell_2$ norm up to $O(1/(T\eta)^{1/2})$.
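The $\ell_2$ intuition above can be sketched with a toy one-dimensional surrogate: a plain gradient step moves $\mathbf{w}$ along the patch vector itself, so with a degree-$q$ activation the one-step change of $\langle \mathbf{w}, \mathbf{u}\rangle$ scales like $\eta\, q\, \langle \mathbf{w}, \mathbf{u}\rangle^{q-1}\|\mathbf{u}\|_2^2$. With $\|\mathbf{v}\|_2 = 1$ and $\|\boldsymbol{\xi}\|_2^2 \approx s\sigma_p^2 < 1$ (the paper takes $\sigma_p^2 = \Theta(1/(s\cdot\mathrm{polylog}(n)))$), the feature is learned first. Constants and the saturation cap are illustrative, not the paper's configuration.

```python
import numpy as np

# Toy surrogate for the GD picture in Section 5.2: track the scalar
# inner products a = <w, v> and b = <w, xi> under the growth rate
# eta * q * (.)^{q-1} * ||patch||_2^2, capping at 1.0 to mimic the
# saturation once the margin is of constant order.
rng = np.random.default_rng(1)
d, s = 2000, 100
sigma_p = np.sqrt(1.0 / (s * 10.0))      # so ||xi||_2^2 ~ 0.1 < ||v||_2^2 = 1
v = np.zeros(d); v[0] = 1.0
xi = np.zeros(d)
xi[rng.choice(np.arange(1, d), s, replace=False)] = sigma_p * rng.standard_normal(s)

a = b = 0.05                             # <w, v> and <w, xi> at initialization
eta, q = 0.05, 3
for _ in range(500):
    a = min(a + eta * q * a ** (q - 1) * float(v @ v), 1.0)
    b = min(b + eta * q * b ** (q - 1) * float(xi @ xi), 1.0)
print(a, b)    # feature saturates, noise stays small
```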
By Lemma 5.7, it is clear that the output model of GD learns the feature vector well while memorizing almost nothing from the noise vectors, and can therefore achieve nearly zero test error.

Experiments. We perform experiments on synthetic data (generated according to Definition 3.1) to validate our theoretical findings: Adam performs stronger noise memorization than feature learning, while GD performs stronger feature learning than noise memorization, when trained on the data distribution constructed in Definition 3.1. We consider both the two-layer CNN model studied in this paper and a 5-layer CNN model for further justification. The experimental setup and results are deferred to Appendix A due to the page limit.
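The synthetic data can be sketched with a generator consistent with how Definition 3.1 is used in this section: the feature patch is $y_i\cdot\mathbf{v}$ with $\mathbf{v} = \mathbf{e}_1$, and the noise patch is an $s$-sparse Gaussian vector supported on $\{2,\dots,d\}$ plus the feature-noise term $-\alpha y_i\mathbf{v}$. The exact distribution is given in Definition 3.1; the function name and all constants below are illustrative only.

```python
import numpy as np

# Sketch of a data generator with the properties of Definition 3.1 used
# here: 1-sparse feature v = e_1, s-sparse Gaussian noise on {2,...,d},
# and feature noise -alpha * y_i * v placed in the noise patch.
def make_dataset(n, d, s, sigma_p, alpha, seed=0):
    rng = np.random.default_rng(seed)
    y = rng.choice([-1, 1], size=n)
    v = np.zeros(d); v[0] = 1.0
    feature = y[:, None] * v                     # feature patch: y_i * v
    noise = np.zeros((n, d))
    for i in range(n):
        supp = rng.choice(np.arange(1, d), size=s, replace=False)
        noise[i, supp] = sigma_p * rng.standard_normal(s)
        noise[i, 0] = -alpha * y[i]              # feature noise -alpha * y_i * v
    return feature, noise, y

X_feat, X_noise, y = make_dataset(n=20, d=1000, s=30, sigma_p=0.2, alpha=0.5)
print(X_feat.shape, X_noise.shape)
```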

6. CONCLUDING REMARKS AND FUTURE WORK

In this paper, we study the generalization of Adam and compare it with gradient descent. We show that when training neural networks, Adam and GD starting from the same initialization can converge to different global solutions of the training objective with significantly different generalization errors, even with proper regularization. Our analysis reveals the fundamental difference between Adam and GD in learning features versus noise, and demonstrates that this difference is closely tied to the nonconvex landscape of neural networks. We would also like to remark on several important research directions. First, our current result is for two-layer networks. Extending the results to deep networks is an important next step, where we will not only look at the input data but also consider the output of each intermediate layer as an "input". Second, our current data model is motivated by image data (i.e., sparse features and denser noise), where Adam has been observed to generalize worse than SGD. In fact, our theoretical analysis leads to the opposite conclusion on the generalization comparison between Adam and GD if the noise is sparse and the feature is denser. Therefore, it would also be interesting to explore whether this is the case in other machine learning tasks such as natural language processing, where Adam is often observed to perform better than SGD.

A EXPERIMENTS

A.1 EXPERIMENT DETAILS FOR FIGURE 1

Two-layer CNN model. We first consider the exact two-layer CNN model studied in the paper. We set the network width $m = 20$, activation function $\sigma(z) = \max\{0, z\}^3$, total iteration number $T = 1\times10^4$, and learning rate $\eta = 5\times10^{-5}$ for Adam (with the default choices of $\beta_1$ and $\beta_2$ in PyTorch) and $\eta = 0.02$ for GD. We first report the results for Adam in Figure 2(a): it can be seen that the algorithm performs feature learning in the first few iterations and then entirely forgets the feature (fitting the feature noise instead), i.e., the feature learning is flipped, which verifies Lemma 5.3. Meanwhile, noise memorization happens throughout the entire training process and enjoys a much faster rate than feature learning, which verifies Lemma 5.2. In addition, we can observe two stages in the growth of $\min_i\max_r \langle \mathbf{w}_{1,r}, \boldsymbol{\xi}_i\rangle$: in the first stage it increases linearly, while in the second stage its growth gradually slows down and it remains of constant order. This verifies Lemma 5.2 and Lemma 5.4. For GD, from Figure 2(b), it can be seen that feature learning dominates noise memorization: feature learning increases to a constant level in the first stage and then remains of constant order in the second stage, while noise memorization stays at a low level, nearly the same as at initialization. This verifies Lemmas 5.6 and 5.7.

5-layer CNN model. We further perform numerical simulations for deeper neural network models. In particular, we consider a 5-layer CNN model: the first layer is exactly the same as in the two-layer CNN model, followed by a 4-layer MLP with ReLU activation. The number of neurons is set to $m = 20$ for all layers. The total iteration number is $T = 1\times10^4$; the learning rate is $\eta = 5\times10^{-5}$ for Adam and $\eta = 0.1$ for GD.
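The two-layer model above can be sketched directly in numpy. The forward pass is $F_j(\mathbf{W}, \mathbf{x}) = \sum_r \big[\sigma(\langle \mathbf{w}_{j,r}, \mathbf{x}_{\text{feat}}\rangle) + \sigma(\langle \mathbf{w}_{j,r}, \mathbf{x}_{\text{noise}}\rangle)\big]$ with $\sigma(z) = \max\{0,z\}^3$, paired with the cross-entropy loss over labels $j \in \{-1, +1\}$. Dimensions and initialization scales are illustrative placeholders, not the paper's configuration.

```python
import numpy as np

# Minimal numpy sketch of the two-layer CNN and its cross-entropy loss.
def sigma(z):
    return np.maximum(z, 0.0) ** 3               # cubic ReLU, q = 3

def F(W, x_feat, x_noise):
    # W: dict mapping j in {-1, +1} to an (m, d) weight matrix
    return {j: (sigma(W[j] @ x_feat) + sigma(W[j] @ x_noise)).sum()
            for j in (-1, 1)}

def loss(W, x_feat, x_noise, y):
    logits = F(W, x_feat, x_noise)
    zmax = max(logits.values())                  # for numerical stability
    logZ = zmax + np.log(sum(np.exp(l - zmax) for l in logits.values()))
    return logZ - logits[y]                      # -log softmax prob of label y

rng = np.random.default_rng(0)
m, d = 20, 100
W = {j: 0.01 * rng.standard_normal((m, d)) for j in (-1, 1)}
x_feat = np.zeros(d); x_feat[0] = 1.0
x_noise = 0.1 * rng.standard_normal(d)
val = loss(W, x_feat, x_noise, y=1)
print(val)   # near log 2 at small initialization, since both logits are tiny
```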
We then show the feature learning and noise memorization of the first layer of this neural network model by calculating the inner products $\max_r \langle \mathbf{w}_{1,r}, \mathbf{v}\rangle$ and $\min_i\max_r \langle \mathbf{w}_{1,r}, \boldsymbol{\xi}_i\rangle$, where $\mathbf{w}_{1,r}$ denotes the weights of the $r$-th neuron in the first layer. The results are shown in Figure 3. It can be observed that, when applied to the data distribution in Definition 3.1, Adam tends to perform stronger noise memorization than feature learning, while GD performs stronger feature learning and nearly negligible noise memorization. Moreover, we also visualize the first layer of the 5-layer CNN trained by Adam and GD in Figure 4. It can be seen that the CNN model found by Adam is clearly more "noisy" than that found by GD. This is consistent with our theoretical findings and our empirical observation on the real-world dataset (i.e., Figure 1).
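The two diagnostics just described can be computed in a few lines from the first-layer weight matrix. The toy arrays below are illustrative stand-ins for the trained weights, the feature $\mathbf{v}$, and the noise vectors $\boldsymbol{\xi}_i$.

```python
import numpy as np

# Feature learning = max_r <w_{1,r}, v>;
# noise memorization = min_i max_r <w_{1,r}, xi_i>.
def feature_and_noise_scores(W1, v, xis):
    # W1: (m, d) first-layer weights; v: (d,); xis: (n, d)
    feature = np.max(W1 @ v)
    noise = np.min(np.max(W1 @ xis.T, axis=0))
    return feature, noise

m, d, n = 4, 6, 3
W1 = np.arange(m * d, dtype=float).reshape(m, d) / 10
v = np.eye(d)[0]
xis = np.eye(d)[1:1 + n]        # toy one-hot "noise" vectors
f, g = feature_and_noise_scores(W1, v, xis)
print(f, g)
```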

B EXTENSIONS TO MINI-BATCH STOCHASTIC GRADIENTS

One natural extension of our paper is proving the separation between mini-batch SGD (without replacement) and mini-batch Adam, which we believe is not difficult. In particular, let $\mathcal{I}_t$ of size $B$ be the set of indices of the mini-batch data used in the $t$-th iteration. The update rule of SGD is
$$\mathbf{w}^{(t+1)}_{j,r} = \mathbf{w}^{(t)}_{j,r} - \eta\cdot\Big[\frac{1}{B}\sum_{i\in\mathcal{I}_t}\nabla_{\mathbf{w}_{j,r}} L_i(\mathbf{W}^{(t)}) + \lambda \mathbf{w}^{(t)}_{j,r}\Big].$$
The update rule of mini-batch Adam is
$$\mathbf{m}^{(t+1)}_{j,r} = \beta_1 \mathbf{m}^{(t)}_{j,r} + (1-\beta_1)\cdot\Big[\frac{1}{B}\sum_{i\in\mathcal{I}_t}\nabla_{\mathbf{w}_{j,r}} L_i(\mathbf{W}^{(t)}) + \lambda\mathbf{w}^{(t)}_{j,r}\Big],$$
$$\mathbf{v}^{(t+1)}_{j,r} = \beta_2 \mathbf{v}^{(t)}_{j,r} + (1-\beta_2)\cdot\Big[\frac{1}{B}\sum_{i\in\mathcal{I}_t}\nabla_{\mathbf{w}_{j,r}} L_i(\mathbf{W}^{(t)}) + \lambda\mathbf{w}^{(t)}_{j,r}\Big]^2,$$
$$\mathbf{w}^{(t+1)}_{j,r} = \mathbf{w}^{(t)}_{j,r} - \eta\cdot \mathbf{m}^{(t)}_{j,r}\big/\big(\sqrt{\mathbf{v}^{(t)}_{j,r}} + \epsilon\big).$$
Here the bias term is set as $\epsilon = \Theta(\lambda\sigma_0)$. We claim that this parameter is introduced to guarantee that the regularization term does not dominate the training process when using stochastic gradients in Adam. We then take a closer look at the speeds of feature learning and noise learning for mini-batch SGD and Adam, focusing on the period in which $|\langle \mathbf{w}^{(t)}_{j,r}, \mathbf{v}\rangle|, |\langle \mathbf{w}^{(t)}_{j,r}, \boldsymbol{\xi}_i\rangle| = o(1)$ for all $j$, $i$, and $r$ (i.e., the pattern learning stage). This further implies that $|\ell^{(t)}_{j,i}| = 0.5 \pm o(1)$ for all $j$, $i$, and $t$. Thus, in the following, we assume that all $|\ell^{(t)}_{j,i}|$ have nearly the same magnitude.

Feature learning. First, according to Definition 3.1, the feature vector $\mathbf{v}$ and the feature noise are the same for all data points, which implies that the learning of the feature coordinate is largely the same as for the full-batch algorithms. In particular, for mini-batch Adam, we can show that the update of the first coordinate (i.e., the feature coordinate) is similar to sign-GD when using a sufficiently small learning rate $\eta = 1/\mathrm{poly}(d)$, since all stochastic gradients $\nabla L_i(\mathbf{W}^{(t)})$ share the same component in this coordinate. Then, using the fact that the $|\ell^{(t)}_{j,i}|$'s are nearly the same for all $i$, we have
$$\langle \mathbf{w}^{(t+1)}_{j,r}, j\mathbf{v}\rangle \sim \langle \mathbf{w}^{(t)}_{j,r}, j\mathbf{v}\rangle + j\eta\cdot\mathrm{sgn}\Big(\sum_{i=1}^n y_i\ell^{(t)}_{j,i}\big[\sigma'(\langle \mathbf{w}^{(t)}_{j,r}, y_i\mathbf{v}\rangle) - \alpha\sigma'(\langle \mathbf{w}^{(t)}_{j,r}, \boldsymbol{\xi}_i\rangle)\big] - n\lambda\mathbf{w}^{(t)}_{j,r}[1]\Big),$$
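The mini-batch Adam recursion above can be written directly as code. This is a sketch under stated assumptions: the weight-decay term is folded into the gradient as in the displayed update, `grad` is an illustrative stand-in gradient, and standard Adam indexing (the freshly updated $\mathbf{m}$, $\mathbf{v}$) plus a small default `eps` is used in place of the paper's $\epsilon = \Theta(\lambda\sigma_0)$.

```python
import numpy as np

# One step of the mini-batch Adam recursion with weight decay folded into
# the gradient, as in the displayed update rules (illustrative sketch).
def adam_step(w, m, v, grad, lam, eta, beta1=0.9, beta2=0.999, eps=1e-8):
    g = grad + lam * w               # (1/B) sum_i grad L_i + lambda * w
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g ** 2
    w = w - eta * m / (np.sqrt(v) + eps)
    return w, m, v

# Usage: drive ||w||^2 toward zero with a stand-in gradient 2w.
w = np.array([1.0, -2.0, 0.5])
m = np.zeros_like(w); v = np.zeros_like(w)
for _ in range(100):
    w, m, v = adam_step(w, m, v, grad=2 * w, lam=0.01, eta=0.05)
print(np.linalg.norm(w))
```

Note the sign-GD flavor: because of the $\sqrt{\mathbf{v}}$ normalization, every coordinate moves by roughly $\eta$ per step regardless of its gradient magnitude.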
which is the same as for full-batch Adam (see (5.1)). For SGD, using the fact that the $|\ell^{(t)}_{j,i}|$'s are nearly the same for all $i$, we can get
$$\langle \mathbf{w}^{(t+1)}_{j,r}, j\mathbf{v}\rangle \sim (1-\eta\lambda)\cdot\langle \mathbf{w}^{(t)}_{j,r}, j\cdot\mathbf{v}\rangle + \frac{\eta}{n}\cdot j\cdot\Big[\sum_{i=1}^n y_i\ell^{(t)}_{j,i}\sigma'(\langle \mathbf{w}^{(t)}_{j,r}, y_i\mathbf{v}\rangle) - \alpha\sum_{i=1}^n y_i\ell^{(t)}_{j,i}\sigma'(\langle \mathbf{w}^{(t)}_{j,r}, \boldsymbol{\xi}_i\rangle)\Big],$$
which is also the same as for GD (see (C.28)).

Noise memorization. Note that due to the normalization term $\sqrt{\mathbf{v}^{(t)}_{j,r}}$ in the Adam update, all coordinates are updated by nearly the same amount. Therefore, we only need to count the number of coordinates that are updated by full-batch Adam and mini-batch Adam. Recall that we have shown that using mini-batch gradients does not affect the feature learning. The noise memorization, however, is slightly different: in each iteration, full-batch Adam can update $\Theta(ns)$ coordinates, while mini-batch Adam can only update $\Theta(Bs)$ coordinates. To show this, note that for any coordinate $k \ne 1$, the gradient momentum of full-batch Adam is
$$\mathbf{m}^{(t)}_{j,r}[k] \sim \sum_{\tau=0}^{\bar\tau} \beta_1^{\tau}(1-\beta_1)\cdot\Big[\frac{1}{n}\sum_{i\in[n]}\nabla_{\mathbf{w}_{j,r}} L_i(\mathbf{W}^{(t-\tau)})[k] + \lambda\mathbf{w}^{(t-\tau)}_{j,r}[k]\Big],$$
while for mini-batch Adam,
$$\mathbf{m}^{(t)}_{j,r}[k] \sim \sum_{\tau=0}^{\bar\tau} \beta_1^{\tau}(1-\beta_1)\cdot\Big[\frac{1}{B}\sum_{i\in\mathcal{I}_{t-\tau}}\nabla_{\mathbf{w}_{j,r}} L_i(\mathbf{W}^{(t-\tau)})[k] + \lambda\mathbf{w}^{(t-\tau)}_{j,r}[k]\Big],$$
where we only maintain the most recent $\bar\tau = \mathrm{polylog}(n)$ gradients, since for $\tau \ge \bar\tau$ the decaying factor $\beta_1^{\tau}$ becomes negligible. Therefore, by comparing the above two equations and applying Definition 3.1, it is clear that full-batch Adam can update all noise coordinates, i.e., $k \in \cup_{i\in[n]}\mathcal{B}_i$, a set of size $\Theta(ns)$. In contrast, mini-batch Adam can only update a subset of the noise coordinates, i.e., $k \in \cup_{\tau\in[\bar\tau]}\cup_{i\in\mathcal{I}_{t-\tau}}\mathcal{B}_i$, a set of size $\bar\tau Bs = \widetilde\Theta(Bs)$.
This further implies that in each epoch (one pass over the data, $\Theta(n/B)$ steps), the noise coordinates in $\mathcal{B}_i$ are updated by mini-batch Adam in at most $\bar\tau = \widetilde\Theta(1)$ steps, while within the same number of iterations, the noise coordinates in $\mathcal{B}_i$ are updated by full-batch Adam for $\Theta(n/B)$ steps, suggesting that mini-batch Adam admits a slower rate of noise memorization by a $\Theta(n/B)$ factor. For SGD, it is easy to show that the rate of noise memorization remains nearly the same as that of GD. In particular, during each training epoch ($\Theta(n/B)$ steps), SGD learns the noise vector $\boldsymbol{\xi}_i$ in only one step, with the mini-batch gradient $\frac{1}{B}\nabla L_i(\mathbf{W}^{(\tau)})$ for some $\tau$ in this epoch, while within the same number of steps, GD learns the noise vector $\boldsymbol{\xi}_i$ in all $\Theta(n/B)$ steps but with strength $\frac{1}{n}\nabla L_i(\mathbf{W}^{(\tau)})$, giving the same total learning ability. This suggests that SGD admits nearly the same rate of noise memorization as GD. Overall, we can deliver the following lemma, which characterizes the feature learning and noise memorization of SGD and stochastic-gradient Adam in training Stage I.

Lemma B.1 (SGD, informal). Suppose the training data is generated according to Definition 3.1. Then, given proper configurations of $\lambda$ and $\epsilon$ and a sufficiently small learning rate $\eta$, define $\Lambda^{(t)}_j = \max_{r\in[m]}\langle \mathbf{w}^{(t)}_{j,r}, j\cdot\mathbf{v}\rangle$ and $\Gamma^{(t)}_{j,i} = \max_{r\in[m]}\langle \mathbf{w}^{(t)}_{j,r}, \boldsymbol{\xi}_i\rangle$. For any $t_0$ satisfying $\Lambda^{(t_0)}_j, \Gamma^{(t_0)}_{j,i} = o(1/\mathrm{polylog}(n))$, we have the following one-epoch update of feature learning and noise memorization for SGD:
$$\Lambda^{(t_0+\lceil n/B\rceil)}_j \ge \Lambda^{(t_0)}_j + \eta\cdot\Theta\Big(\frac{n}{B}\cdot(\Lambda^{(t_0)}_j)^{q-1}\Big), \qquad \Gamma^{(t_0+\lceil n/B\rceil)}_{j,i} \le \Gamma^{(t_0)}_{j,i} + \eta\cdot\Theta\Big(\frac{\eta s\sigma_p^2}{B}\cdot(\Gamma^{(t_0)}_{j,i})^{q-1}\Big).$$
Similarly, the corresponding one-epoch update for mini-batch Adam satisfies
$$\langle \mathbf{w}^{(t_0+\lceil n/B\rceil)}_{j,r}, j\cdot\mathbf{v}\rangle \le \langle \mathbf{w}^{(t_0)}_{j,r}, j\cdot\mathbf{v}\rangle + \Theta(n\eta/B), \qquad \langle \mathbf{w}^{(t_0+\lceil n/B\rceil)}_{y_i,r}, \boldsymbol{\xi}_i\rangle = \langle \mathbf{w}^{(t_0)}_{y_i,r}, \boldsymbol{\xi}_i\rangle + \Theta(\eta s\sigma_p).$$
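The coordinate-counting argument above can be checked with a small simulation: per momentum window, full-batch Adam's averaged gradient involves the supports $\mathcal{B}_i$ of all $n$ samples (roughly $ns$ coordinates, disjoint by Lemma C.1), while mini-batch Adam's involves only the supports of the last $\bar\tau$ mini-batches (roughly $\bar\tau Bs$ coordinates). All constants are illustrative choices.

```python
import numpy as np

# Count how many coordinates appear in the momentum window of full-batch
# vs. mini-batch Adam when each sample i has an s-sparse support B_i.
rng = np.random.default_rng(0)
n, B, s, d, tau_bar = 64, 8, 10, 100_000, 3

supports = [rng.choice(d, size=s, replace=False) for _ in range(n)]  # the B_i
full = set(np.concatenate(supports))           # full-batch: all i in [n]
order = rng.permutation(n)
recent = order[:tau_bar * B]                   # last tau_bar mini-batches
mini = set(np.concatenate([supports[i] for i in recent]))
print(len(full), len(mini))                    # ~ n*s  vs  ~ tau_bar*B*s
```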
To sum up, we have shown that (1) mini-batch SGD and mini-batch Adam do not change the learning speed of the feature vector $\mathbf{v}$ compared to their full-batch counterparts, i.e., Lemma C.3 and (C.31) (after converting to $n/B$ iterations); and (2) mini-batch Adam reduces the noise memorization speed of full-batch Adam by a $\Theta(n/B)$ factor, while mini-batch SGD has nearly the same noise memorization speed as full-batch GD (by comparison with Lemma C.3 and (C.32)). Additionally, recall that in our paper, the separation between Adam and GD is characterized by a $\mathrm{poly}(d)$ factor: the speed of feature learning in Adam and GD and the rate of noise memorization in GD are both of order $O(\eta)$ per step, while the rate of noise memorization in Adam is proportional to the number of nonzero entries, which is of order $\eta\cdot\mathrm{poly}(d)$. Therefore, the separation between mini-batch SGD and mini-batch Adam in terms of generalization error still holds under a stronger over-parameterization condition $s\sigma_p = \Theta\big(d^{1/4}/(n\,\mathrm{polylog}(n))\big) = \omega(n/B)$ (in contrast, the over-parameterization condition for full-batch Adam is $s\sigma_p = \omega(1)$).

C PROOF OF THEOREM 4.1: NONCONVEX CASE

At the beginning of the proof, we first present the following useful lemma.

C.1 PRELIMINARIES

We first recall the magnitudes of all parameters:
$$d = \mathrm{poly}(n),\quad \eta = \frac{1}{\mathrm{poly}(n)},\quad s = \Theta\Big(\frac{d^{1/2}}{n^2}\Big),\quad \sigma_p^2 = \Theta\Big(\frac{1}{s\cdot\mathrm{polylog}(n)}\Big),\quad \sigma_0^2 = \Theta\Big(\frac{1}{d^{1/2}}\Big),$$
$$m = \mathrm{polylog}(n),\quad \alpha = \Theta\big(\sigma_p\cdot\mathrm{polylog}(n)\big),\quad \lambda = O\Big(\frac{1}{d^{(q-1)/4}\, n\cdot\mathrm{polylog}(n)}\Big).$$
Here $\mathrm{poly}(n)$ denotes a polynomial function of $n$ whose degree is a sufficiently large constant, and $\mathrm{polylog}(n)$ denotes a polynomial function of $\log(n)$ whose degree is a sufficiently large constant. Based on this parameter configuration, we claim that the following relations hold, which will be used frequently in the subsequent proof:
$$\lambda = o\Big(\frac{\sigma_0^{q-2}\sigma_p}{n}\Big),\quad \alpha = \omega\big((s\sigma_p)^{1-q}\vee\sigma_0^{q-1}\big),\quad \sigma_0 = o\Big(\frac{1}{s\sigma_p}\Big),\quad \alpha = o\Big(\frac{s\sigma_p^2}{n}\Big),\quad \eta = o(\lambda\sigma_0^q\sigma_p^q).$$

Lemma C.1 (Non-overlapping support). Let $\{(\mathbf{x}_i, y_i)\}_{i=1,\dots,n}$ be the training dataset sampled according to Definition 3.1. Moreover, let $\mathcal{B}_i = \mathrm{supp}(\boldsymbol{\xi}_i)\setminus\{1\}$ be the support of $\mathbf{x}_i$ except the first coordinate. Then with probability at least $1 - n^{-2}$, $\mathcal{B}_i\cap\mathcal{B}_j = \emptyset$ for all $i, j \in [n]$.

Proof of Lemma C.1. For any fixed $k \in [n]$ and $j \in \mathrm{supp}(\boldsymbol{\xi}_k)\setminus\{1\}$, by the model assumption we have $\mathbb{P}\{(\boldsymbol{\xi}_i)_j \ne 0\} = s/(d-1)$ for all $i \in [n]\setminus\{k\}$. Therefore, by the independence of the data samples, we have
$$\mathbb{P}\big(\exists i\in[n]\setminus\{k\}: (\boldsymbol{\xi}_i)_j \ne 0\big) = 1 - [1 - s/(d-1)]^{n}.$$
Applying a union bound over all $k\in[n]$ and $j\in\mathrm{supp}(\boldsymbol{\xi}_k)\setminus\{1\}$, we obtain
$$\mathbb{P}\big(\exists k\in[n],\, j\in\mathrm{supp}(\boldsymbol{\xi}_k)\setminus\{1\},\, i\in[n]\setminus\{k\}: (\boldsymbol{\xi}_i)_j\ne 0\big) \le n\cdot s\cdot\big\{1-[1-s/(d-1)]^n\big\}. \tag{C.1}$$
By the data distribution assumption we have $s \le \sqrt{d}/(2n^2)$, which clearly implies $s/(d-1)\le 1/2$. Therefore,
$$n\cdot s\cdot\big\{1-[1-s/(d-1)]^n\big\} = n\cdot s\cdot\big\{1-\exp[n\log(1-s/(d-1))]\big\} \le n\cdot s\cdot\big[1-\exp(-2ns/(d-1))\big] \le n\cdot s\cdot\big[1-\exp(-4ns/d)\big] \le n\cdot s\cdot(4ns/d) = 4n^2s^2/d \le n^{-2},$$
where the first inequality follows from $\log(1-z)\ge -2z$ for $z\in[0,1/2]$, the second inequality follows from $s/(d-1)\le 2s/d$, the third inequality follows from $1-\exp(-z)\le z$ for $z\in\mathbb{R}$, and the last inequality follows from the assumption $s\le\sqrt{d}/(2n^2)$.
Plugging the bound above into (C.1) finishes the proof.
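Lemma C.1 is easy to spot-check with a small Monte Carlo experiment: with $s \le \sqrt{d}/(2n^2)$, the supports $\mathcal{B}_i$ ($s$ coordinates drawn from $\{2,\dots,d\}$) should be pairwise disjoint with probability at least $1 - n^{-2}$. The parameters below are small-scale illustrative choices satisfying the assumption.

```python
import numpy as np

# Monte Carlo check of Lemma C.1 (non-overlapping supports).
rng = np.random.default_rng(0)
n, trials = 4, 2000
d = 10_000        # sqrt(d)/(2 n^2) = 100/32 = 3.125
s = 3             # s <= 3.125, as the lemma requires

disjoint = 0
for _ in range(trials):
    supports = [set(rng.choice(np.arange(2, d + 1), s, replace=False))
                for _ in range(n)]
    if all(supports[i].isdisjoint(supports[j])
           for i in range(n) for j in range(i + 1, n)):
        disjoint += 1
frac = disjoint / trials
print(frac)   # should be at least 1 - 1/n^2 = 0.9375
```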

C.2 PROOF FOR ADAM

Before moving to the detailed proof, we first state the update rules of feature learning and noise memorization when the sign of the gradient is applied:
$$\langle \mathbf{w}^{(t+1)}_{j,r}, j\mathbf{v}\rangle = \langle \mathbf{w}^{(t)}_{j,r}, j\mathbf{v}\rangle - \eta\cdot\big\langle \mathrm{sgn}\big(\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t)})\big), j\mathbf{v}\big\rangle = \langle \mathbf{w}^{(t)}_{j,r}, j\mathbf{v}\rangle + j\eta\cdot\mathrm{sgn}\Big(\sum_{i=1}^n y_i\ell^{(t)}_{j,i}\big[\sigma'(\langle \mathbf{w}^{(t)}_{j,r}, y_i\mathbf{v}\rangle) - \alpha\sigma'(\langle \mathbf{w}^{(t)}_{j,r}, \boldsymbol{\xi}_i\rangle)\big] - n\lambda\mathbf{w}^{(t)}_{j,r}[1]\Big), \tag{C.2}$$
where $\ell^{(t)}_{j,i} := \mathbb{1}_{y_i=j} - \mathrm{logit}_j(F, \mathbf{x}_i)$ and $\mathrm{logit}_j(F, \mathbf{x}_i) = \frac{e^{F_j(\mathbf{W},\mathbf{x}_i)}}{\sum_{k\in\{-1,1\}} e^{F_k(\mathbf{W},\mathbf{x}_i)}}$. From (C.2) we can observe three terms in the signed gradient: the first term represents the gradient over the feature patch, the second term stems from the feature noise term in the noise patch (see Definition 3.1), and the last term is the gradient of the weight decay regularization. On the other hand, the memorization of the noise vector $\boldsymbol{\xi}_i$ can be described by the following update rule:
$$\langle \mathbf{w}^{(t+1)}_{y_i,r}, \boldsymbol{\xi}_i\rangle = \langle \mathbf{w}^{(t)}_{y_i,r}, \boldsymbol{\xi}_i\rangle - \eta\cdot\big\langle \mathrm{sgn}\big(\nabla_{\mathbf{w}_{y_i,r}} L(\mathbf{W}^{(t)})\big), \boldsymbol{\xi}_i\big\rangle = \langle \mathbf{w}^{(t)}_{y_i,r}, \boldsymbol{\xi}_i\rangle + \eta\cdot\sum_{k\in\mathcal{B}_i} \big\langle \mathrm{sgn}\big(\ell^{(t)}_{y_i,i}\sigma'(\langle \mathbf{w}^{(t)}_{y_i,r}, \boldsymbol{\xi}_i\rangle)\boldsymbol{\xi}_i[k] - n\lambda\mathbf{w}^{(t)}_{y_i,r}[k]\big), \boldsymbol{\xi}_i[k]\big\rangle - \alpha y_i\eta\cdot\mathrm{sgn}\Big(\sum_{i'=1}^n y_{i'}\ell^{(t)}_{y_i,i'}\big[\sigma'(\langle \mathbf{w}^{(t)}_{y_i,r}, y_{i'}\mathbf{v}\rangle) - \alpha\sigma'(\langle \mathbf{w}^{(t)}_{y_i,r}, \boldsymbol{\xi}_{i'}\rangle)\big] - n\lambda\mathbf{w}^{(t)}_{y_i,r}[1]\Big). \tag{C.3}$$
In this subsection we first provide the following lemma, which shows that for most coordinates (those with a slightly large gradient), the Adam update is similar to the signGD update (up to constant factors). In the remaining proof for Adam, we repeatedly apply this lemma to obtain signGD-like results for Adam (similar to the technical lemmas in Section 5). Besides, the proofs of all lemmas in Section 5 can be viewed as simplified versions of the proofs of the technical lemmas for Adam, and are thus omitted in the paper.

Lemma C.2 (Closeness to signGD). Recall the update rule of Adam, and let $\mathbf{W}^{(t)}$ be the $t$-th iterate of the Adam algorithm. Suppose that $\langle \mathbf{w}^{(t)}_{j,r}, \mathbf{v}\rangle, \langle \mathbf{w}^{(t)}_{j,r}, \boldsymbol{\xi}_i\rangle = \Theta(1)$ for all $j\in\{\pm1\}$ and $r\in[m]$. Then if $\beta_2 \ge \beta_1^2$, we have:

• For all $k\in[d]$, $\Big|\frac{\mathbf{m}^{(t)}_{j,r}[k]}{\sqrt{\mathbf{v}^{(t)}_{j,r}[k]}}\Big| \le \Theta(1)$.
• For every $k\notin\cup_{i=1}^n\mathcal{B}_i$ (including $k=1$), we have either $|\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t)})[k]| \le \Theta(\eta)$ or $\frac{\mathbf{m}^{(t)}_{j,r}[k]}{\sqrt{\mathbf{v}^{(t)}_{j,r}[k]}} = \mathrm{sgn}\big(\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t)})[k]\big)\cdot\Theta(1)$.

• For every $k\in\mathcal{B}_i$, we have either $|\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t)})[k]| \le \Theta\big(\eta n^{-1}s\sigma_p|\ell^{(t)}_{j,i}|\big) \le \Theta(\eta s\sigma_p)$ or $\frac{\mathbf{m}^{(t)}_{j,r}[k]}{\sqrt{\mathbf{v}^{(t)}_{j,r}[k]}} = \mathrm{sgn}\big(\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t)})[k]\big)\cdot\Theta(1)$.

Proof. First recall that the gradient $\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t)})$ can be calculated as
$$\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t)}) = -\frac{1}{n}\Big[\sum_{i=1}^n y_i\ell^{(t)}_{j,i}\sigma'(\langle \mathbf{w}^{(t)}_{j,r}, y_i\mathbf{v}\rangle)\cdot\mathbf{v} + \sum_{i=1}^n \ell^{(t)}_{j,i}\cdot\sigma'(\langle \mathbf{w}^{(t)}_{j,r}, \boldsymbol{\xi}_i\rangle)\cdot\boldsymbol{\xi}_i\Big] + \lambda\mathbf{w}^{(t)}_{j,r}.$$
More specifically, for the first coordinate of $\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t)})$, we have
$$\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t)})[1] = -\frac{1}{n}\Big[\sum_{i=1}^n y_i\ell^{(t)}_{j,i}\sigma'(\langle \mathbf{w}^{(t)}_{j,r}, y_i\mathbf{v}\rangle) - \alpha\sum_{i=1}^n y_i\ell^{(t)}_{j,i}\cdot\sigma'(\langle \mathbf{w}^{(t)}_{j,r}, \boldsymbol{\xi}_i\rangle)\Big] + \lambda\mathbf{w}^{(t)}_{j,r}[1]. \tag{C.4}$$
For any $k\in\mathcal{B}_i$, by Lemma C.1 we know that the gradient over this coordinate only depends on the training data $\boldsymbol{\xi}_i$; therefore,
$$\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t)})[k] = -\frac{1}{n}\ell^{(t)}_{j,i}\sigma'(\langle \mathbf{w}^{(t)}_{j,r}, \boldsymbol{\xi}_i\rangle)\boldsymbol{\xi}_i[k] + \lambda\mathbf{w}^{(t)}_{j,r}[k]. \tag{C.5}$$
For the remaining coordinates, we have
$$\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t)})[k] = \lambda\mathbf{w}^{(t)}_{j,r}[k]. \tag{C.6}$$
Now let us focus on the moving-averaged gradient $\mathbf{m}^{(t)}_{j,r}$ and squared gradient $\mathbf{v}^{(t)}_{j,r}$. We first show that for all $k\in[d]$,
$$\Big|\frac{\mathbf{m}^{(t)}_{j,r}[k]}{\sqrt{\mathbf{v}^{(t)}_{j,r}[k]}}\Big| \le \Theta(1). \tag{C.7}$$
By the update rule of $\mathbf{m}^{(t)}_{j,r}$, we have
$$\mathbf{m}^{(t)}_{j,r}[k] = \beta_1\mathbf{m}^{(t-1)}_{j,r}[k] + (1-\beta_1)\cdot\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t)})[k] = \sum_{\tau=0}^{t}\beta_1^{\tau}(1-\beta_1)\cdot\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t-\tau)})[k].$$
Similarly, we also have
$$\mathbf{v}^{(t)}_{j,r}[k] = \sum_{\tau=0}^{t}\beta_2^{\tau}(1-\beta_2)\cdot\big[\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t-\tau)})[k]\big]^2.$$
Then by the Cauchy–Schwarz inequality,
$$\big[\mathbf{m}^{(t)}_{j,r}[k]\big]^2 \le \Big(\sum_{\tau=0}^{t}\frac{[\beta_1^{\tau}(1-\beta_1)]^2}{\alpha_\tau^2}\cdot\big[\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t-\tau)})[k]\big]^2\Big)\cdot\Big(\sum_{\tau=0}^{t}\alpha_\tau^2\Big).$$
Let $\alpha_\tau^2 = \frac{[\beta_1^\tau(1-\beta_1)]^2}{\beta_2^\tau(1-\beta_2)}$, which forms an exponentially decaying sequence if $\beta_2 \ge \beta_1^2$. Therefore $\sum_{\tau=0}^t \alpha_\tau^2 = \Theta(1)$, and the above inequality implies $\big[\mathbf{m}^{(t)}_{j,r}[k]\big]^2 \le \mathbf{v}^{(t)}_{j,r}[k]\cdot\Theta(1)$, which proves (C.7). Now we are ready to prove the main argument of this lemma.
Note that $\mathbf{m}^{(t)}_{j,r}$ is a weighted average of all historical gradients, where the weights decay exponentially fast. We can therefore take a threshold $\bar\tau = \mathrm{polylog}(\eta^{-1})$ such that
$$\sum_{\tau=\bar\tau}^{t}\beta_1^{\tau}(1-\beta_1) = \frac{1}{\mathrm{poly}(\eta^{-1})}.$$
Then for each $k\in[d]$ we have
$$\mathbf{m}^{(t)}_{j,r}[k] = \sum_{\tau=0}^{\bar\tau}\beta_1^{\tau}(1-\beta_1)\cdot\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t-\tau)})[k] + \sum_{\tau=\bar\tau}^{t}\beta_1^{\tau}(1-\beta_1)\cdot\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t-\tau)})[k] = \sum_{\tau=0}^{\bar\tau}\beta_1^{\tau}(1-\beta_1)\cdot\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t-\tau)})[k] \pm \frac{1}{\mathrm{poly}(\eta^{-1})},$$
where in the last equality we use the fact that $|\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t-\tau)})[k]| = O(1)$ for all $k\in[d]$. Similarly, for $\mathbf{v}^{(t)}_{j,r}$,
$$\mathbf{v}^{(t)}_{j,r}[k] = \sum_{\tau=0}^{\bar\tau}\beta_2^{\tau}(1-\beta_2)\cdot\big[\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t-\tau)})[k]\big]^2 \pm \frac{1}{\mathrm{poly}(\eta^{-1})},$$
where we slightly abuse notation by using the same $\bar\tau$. Then we have
$$\frac{\mathbf{m}^{(t)}_{j,r}[k]}{\sqrt{\mathbf{v}^{(t)}_{j,r}[k]}} = \frac{\sum_{\tau=0}^{\bar\tau}\beta_1^{\tau}(1-\beta_1)\cdot\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t-\tau)})[k] \pm \frac{1}{\mathrm{poly}(\eta^{-1})}}{\sqrt{\sum_{\tau=0}^{\bar\tau}\beta_2^{\tau}(1-\beta_2)\cdot\big[\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t-\tau)})[k]\big]^2 \pm \frac{1}{\mathrm{poly}(\eta^{-1})}}}.$$
In order to prove the main argument of this lemma, the key is to show that within $\bar\tau$ iterations the gradient $\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t)})[k]$ barely changes. In particular, by (C.7), the update of each coordinate in one step is at most $\Theta(\eta)$. This implies that for all $\tau\in[t-\bar\tau, t]$,
$$\big|\langle \mathbf{w}^{(t)}_{j,r}, \mathbf{v}\rangle - \langle \mathbf{w}^{(\tau)}_{j,r}, \mathbf{v}\rangle\big| \le \Theta(\eta\bar\tau),\qquad \big|\langle \mathbf{w}^{(t)}_{j,r}, \boldsymbol{\xi}_i\rangle - \langle \mathbf{w}^{(\tau)}_{j,r}, \boldsymbol{\xi}_i\rangle\big| \le \Theta(\eta\bar\tau s\sigma_p),\qquad \big|\mathbf{w}^{(t)}_{j,r}[k] - \mathbf{w}^{(\tau)}_{j,r}[k]\big| \le \Theta(\eta\bar\tau).$$
Then, applying the facts that $|\langle \mathbf{w}^{(\tau)}_{j,r}, \mathbf{v}\rangle| \le \Theta(1)$ and $|\langle \mathbf{w}^{(\tau)}_{j,r}, \boldsymbol{\xi}_i\rangle| \le \Theta(1)$, we further have
$$\big|F_j(\mathbf{W}^{(\tau)}, \mathbf{x}_i) - F_j(\mathbf{W}^{(t)}, \mathbf{x}_i)\big| \le \Theta(m\eta\bar\tau s\sigma_p) = \widetilde\Theta(\eta\bar\tau s\sigma_p),$$
where we use the facts that $m = \widetilde\Theta(1)$ and $s\sigma_p = \omega(1)$. Then it holds that
$$\mathrm{logit}_j(F^{(\tau)}, \mathbf{x}_i) = \frac{e^{F_j(\mathbf{W}^{(\tau)},\mathbf{x}_i)}}{\sum_{k\in\{-1,1\}} e^{F_k(\mathbf{W}^{(\tau)},\mathbf{x}_i)}} \le \frac{e^{F_j(\mathbf{W}^{(t)},\mathbf{x}_i)+\widetilde\Theta(\eta\bar\tau s\sigma_p)}}{e^{F_j(\mathbf{W}^{(t)},\mathbf{x}_i)+\widetilde\Theta(\eta\bar\tau s\sigma_p)} + e^{F_{-j}(\mathbf{W}^{(t)},\mathbf{x}_i)-\widetilde\Theta(\eta\bar\tau s\sigma_p)}},$$
which, together with $\widetilde\Theta(\eta\bar\tau s\sigma_p) = o(1)$, yields $\ell^{(\tau)}_{j,i} \le \mathrm{sgn}(\ell^{(t)}_{j,i})\cdot\Theta(|\ell^{(t)}_{j,i}|)$. Similarly, we can also show that $\ell^{(\tau)}_{j,i} \ge \mathrm{sgn}(\ell^{(t)}_{j,i})\cdot\Theta(|\ell^{(t)}_{j,i}|)$, which further implies $\ell^{(\tau)}_{j,i} = \mathrm{sgn}(\ell^{(t)}_{j,i})\cdot\Theta(|\ell^{(t)}_{j,i}|)$ for all $\tau\in[t-\bar\tau, t]$.
Note that $|\ell^{(\tau)}_{j,i}| \le 1$; then it holds that
$$\ell^{(\tau)}_{j,i}\sigma'(\langle \mathbf{w}^{(\tau)}_{j,r}, \mathbf{v}\rangle) \le \mathrm{sgn}(\ell^{(t)}_{j,i})\cdot\Theta(|\ell^{(t)}_{j,i}|)\cdot\sigma'(\langle \mathbf{w}^{(t)}_{j,r}, \mathbf{v}\rangle) + \Theta(|\ell^{(t)}_{j,i}|)\cdot\Theta(\eta\bar\tau).$$
We can similarly derive the following:
$$\ell^{(\tau)}_{j,i}\sigma'(\langle \mathbf{w}^{(\tau)}_{j,r}, \mathbf{v}\rangle) \ge \mathrm{sgn}(\ell^{(t)}_{j,i})\cdot\Theta(|\ell^{(t)}_{j,i}|)\cdot\sigma'(\langle \mathbf{w}^{(t)}_{j,r}, \mathbf{v}\rangle) - \Theta(|\ell^{(t)}_{j,i}|)\cdot\Theta(\eta\bar\tau),$$
$$\ell^{(\tau)}_{j,i}\sigma'(\langle \mathbf{w}^{(\tau)}_{j,r}, \boldsymbol{\xi}_i\rangle) \le \mathrm{sgn}(\ell^{(t)}_{j,i})\cdot\Theta(|\ell^{(t)}_{j,i}|)\cdot\sigma'(\langle \mathbf{w}^{(t)}_{j,r}, \boldsymbol{\xi}_i\rangle) + \Theta(|\ell^{(t)}_{j,i}|)\cdot\Theta(\eta\bar\tau s\sigma_p),$$
$$\ell^{(\tau)}_{j,i}\sigma'(\langle \mathbf{w}^{(\tau)}_{j,r}, \boldsymbol{\xi}_i\rangle) \ge \mathrm{sgn}(\ell^{(t)}_{j,i})\cdot\Theta(|\ell^{(t)}_{j,i}|)\cdot\sigma'(\langle \mathbf{w}^{(t)}_{j,r}, \boldsymbol{\xi}_i\rangle) - \Theta(|\ell^{(t)}_{j,i}|)\cdot\Theta(\eta\bar\tau s\sigma_p).$$
Combining the above results and applying (C.4), (C.5), and (C.6), we can show that for the first coordinate,
$$\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(\tau)})[1] = \Theta\big(\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t)})[1]\big) \pm \Theta\Big(\frac{1}{n}\sum_{i=1}^n|\ell^{(t)}_{j,i}|\Big)\cdot O(\eta\bar\tau) \pm \Theta(\lambda\eta\bar\tau);$$
for any $k\in\mathcal{B}_i$,
$$\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(\tau)})[k] = \Theta\big(\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t)})[k]\big) \pm \Theta\Big(\frac{|\ell^{(t)}_{j,i}|}{n}\Big)\cdot O(\eta\bar\tau s\sigma_p) \pm \Theta(\lambda\eta\bar\tau);$$
and for the remaining coordinates,
$$\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(\tau)})[k] = \Theta\big(\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t)})[k]\big) \pm \Theta(\lambda\eta\bar\tau).$$
Now we can plug the above results into the formulas for $\mathbf{m}^{(t)}_{j,r}$ and $\mathbf{v}^{(t)}_{j,r}$. For $k\notin\cup_i\mathcal{B}_i$ we have
$$\frac{\mathbf{m}^{(t)}_{j,r}[k]}{\sqrt{\mathbf{v}^{(t)}_{j,r}[k]}} = \frac{\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t)})[k] \pm \widetilde\Theta(\eta)}{\Theta\big(|\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t)})[k]| \pm \widetilde\Theta(\eta)\big)}.$$
For $k\in\mathcal{B}_i$ we have
$$\frac{\mathbf{m}^{(t)}_{j,r}[k]}{\sqrt{\mathbf{v}^{(t)}_{j,r}[k]}} = \frac{\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t)})[k] \pm \widetilde\Theta\big(\frac{\eta s\sigma_p|\ell^{(t)}_{j,i}|}{n}\big) \pm \widetilde\Theta(\lambda\eta)}{\Theta\big(|\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t)})[k]| \pm \widetilde\Theta\big(\frac{\eta s\sigma_p|\ell^{(t)}_{j,i}|}{n}\big) \pm \widetilde\Theta(\lambda\eta)\big)}.$$
Then we can conclude that for $k=1$ or $k\notin\mathcal{B}_i$ for all $i$, we have either $|\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t)})[k]| \le \Theta(\eta)$ or $\frac{\mathbf{m}^{(t)}_{j,r}[k]}{\sqrt{\mathbf{v}^{(t)}_{j,r}[k]}} = \mathrm{sgn}\big(\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t)})[k]\big)\cdot\Theta(1)$. For any $k\in\mathcal{B}_i$, we have either $|\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t)})[k]| \le \Theta\big(\eta n^{-1}s\sigma_p|\ell^{(t)}_{j,i}| + \lambda\eta\big)$ or $\frac{\mathbf{m}^{(t)}_{j,r}[k]}{\sqrt{\mathbf{v}^{(t)}_{j,r}[k]}} = \mathrm{sgn}\big(\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t)})[k]\big)\cdot\Theta(1)$. This completes the proof.

Lemma C.3 (Lemma 5.2, restated). Suppose the training data is generated according to Definition 3.1, and assume $\lambda = o(\sigma_0^{q-2}\sigma_p/n)$ and $\eta = 1/\mathrm{poly}(d)$. Then for any $t \le T_0$ with $T_0 = O\big(\frac{1}{\eta s\sigma_p}\big)$ and any $i\in[n]$,
$$\langle \mathbf{w}^{(t+1)}_{j,r}, j\cdot\mathbf{v}\rangle \le \langle \mathbf{w}^{(t)}_{j,r}, j\cdot\mathbf{v}\rangle + \Theta(\eta), \qquad \langle \mathbf{w}^{(t+1)}_{y_i,r}, \boldsymbol{\xi}_i\rangle = \langle \mathbf{w}^{(t)}_{y_i,r}, \boldsymbol{\xi}_i\rangle + \Theta(\eta s\sigma_p).$$
Proof.
At initialization, we have
$$|\langle \mathbf{w}^{(0)}_{j,r}, \mathbf{v}\rangle| = \Theta(\sigma_0),\qquad |\langle \mathbf{w}^{(0)}_{j,r}, \boldsymbol{\xi}_i\rangle| = \Theta(s^{1/2}\sigma_p\sigma_0 + \alpha\sigma_0) = \Theta(s^{1/2}\sigma_p\sigma_0),\qquad \mathbf{w}^{(0)}_{j,r}[k] = \Theta(\sigma_0),$$
which also implies $|\ell^{(0)}_{j,i}| = \Theta(1)$. Then, recalling that $\lambda = o(\sigma_0^{q-2}\sigma_p/n)$, $\alpha = o(1)$, and $s^{1/2}\sigma_p = O(1)$, we have
$$\mathrm{sgn}\Big(\sum_{i=1}^n y_i\ell^{(0)}_{j,i}\sigma'(\langle \mathbf{w}^{(0)}_{j,r}, y_i\mathbf{v}\rangle) - \alpha\sum_{i=1}^n y_i\ell^{(0)}_{j,i}\sigma'(\langle \mathbf{w}^{(0)}_{j,r}, \boldsymbol{\xi}_i\rangle) - n\lambda\mathbf{w}^{(0)}_{j,r}[1]\Big) = \mathrm{sgn}\Big(j\cdot\Theta(n\sigma_0^{q-1}) - j\cdot\Theta\big(\alpha n(s^{1/2}\sigma_p\sigma_0)^{q-1}\big) \pm o(\sigma_0^{q-1}\sigma_p)\Big) = \mathrm{sgn}(j).$$
Since $\mathbf{v}$ is 1-sparse, by Lemma C.2 we have
$$\langle \mathbf{w}^{(1)}_{j,r}, j\cdot\mathbf{v}\rangle = \langle \mathbf{w}^{(0)}_{j,r}, j\cdot\mathbf{v}\rangle - \eta\,\big\langle \mathbf{m}^{(0)}_{j,r}\big/\sqrt{\mathbf{v}^{(0)}_{j,r}},\, j\cdot\mathbf{v}\big\rangle \le \langle \mathbf{w}^{(0)}_{j,r}, j\cdot\mathbf{v}\rangle + \Theta(\eta).$$
Now suppose the first inequality holds for iterations $0, \dots, t-1$. Then we have $\langle \mathbf{w}^{(t)}_{j,r}, j\cdot\mathbf{v}\rangle \le \Theta(\eta T_0) \le O(\sigma_0)$. Besides, note that $\ell^{(t)}_{j,i} = \mathbb{1}_{j=y_i} - \mathrm{logit}_j(F^{(t)}, \mathbf{x}_i)$, so $\mathrm{sgn}(y_i\ell^{(t)}_{j,i}) = \mathrm{sgn}(j)$, where we recall that $j\in\{-1,1\}$. Therefore, given that $\lambda = o(\sigma_0^{q-2}\sigma_p/n)$, $\alpha = o(1)$, $s^{1/2}\sigma_p = O(1)$, and assuming $\ell^{(t)}_{j,i} = \Theta(1)$ (which will be verified later),
$$\mathrm{sgn}\Big(\sum_{i=1}^n y_i\ell^{(t)}_{j,i}\sigma'(\langle \mathbf{w}^{(t)}_{j,r}, y_i\mathbf{v}\rangle) - \alpha\sum_{i=1}^n y_i\ell^{(t)}_{j,i}\sigma'(\langle \mathbf{w}^{(t)}_{j,r}, \boldsymbol{\xi}_i\rangle) - n\lambda\mathbf{w}^{(t)}_{j,r}[1]\Big) = \mathrm{sgn}\Big(j\cdot\Theta(n\sigma_0^{q-1}) - j\cdot\Theta\big(\alpha n(s^{1/2}\sigma_p\sigma_0)^{q-1}\big) \pm o(\sigma_0^{q-1}\sigma_p)\Big) = \mathrm{sgn}(j).$$
Since $\mathbf{v}$ is 1-sparse, by Lemma C.2 the following inequality naturally holds:
$$\langle \mathbf{w}^{(t+1)}_{j,r}, j\cdot\mathbf{v}\rangle = \langle \mathbf{w}^{(t)}_{j,r}, j\cdot\mathbf{v}\rangle - \eta\,\big\langle \mathbf{m}^{(t)}_{j,r}\big/\sqrt{\mathbf{v}^{(t)}_{j,r}},\, j\cdot\mathbf{v}\big\rangle \le \langle \mathbf{w}^{(t)}_{j,r}, j\cdot\mathbf{v}\rangle + \Theta(\eta).$$
Additionally, in terms of the memorization of noise, we first consider the iterate at initialization. By the condition $\eta = o(1/d) = o(1/(s\sigma_p))$, and noting that for a sufficiently large fraction of $k\in\mathcal{B}_i$ (e.g., 0.99) we have $|\boldsymbol{\xi}_i[k]| \ge \Theta(\sigma_p) \ge \Theta\big(\eta n^{-1}s\sigma_p|\ell^{(0)}_{j,i}|\big)$, it follows that
$$\mathrm{sgn}\big(\nabla_{\mathbf{w}_{y_i,r}} L(\mathbf{W}^{(0)})[k]\big) = -\mathrm{sgn}\big(\ell^{(0)}_{y_i,i}\sigma'(\langle \mathbf{w}^{(0)}_{y_i,r}, \boldsymbol{\xi}_i\rangle)\boldsymbol{\xi}_i[k] - n\lambda\mathbf{w}^{(0)}_{y_i,r}[k]\big) = -\mathrm{sgn}\big(\Theta\big((s^{1/2}\sigma_p\sigma_0)^{q-1}\sigma_p\big)\cdot\mathrm{sgn}(\boldsymbol{\xi}_i[k]) \pm o(\sigma_0^{q-1}\sigma_p)\big) = -\mathrm{sgn}(\boldsymbol{\xi}_i[k]).
$$
(C.8) Therefore, by Lemma C.2 and according to (C.3), we have
$$\langle \mathbf{w}^{(1)}_{y_i,r}, \boldsymbol{\xi}_i\rangle = \langle \mathbf{w}^{(0)}_{y_i,r}, \boldsymbol{\xi}_i\rangle - \eta\,\big\langle \mathbf{m}^{(0)}_{y_i,r}\big/\sqrt{\mathbf{v}^{(0)}_{y_i,r}},\, \boldsymbol{\xi}_i\big\rangle \ge \langle \mathbf{w}^{(0)}_{y_i,r}, \boldsymbol{\xi}_i\rangle + \Theta(\eta)\cdot\sum_{k\in\mathcal{B}_i}\big\langle \mathrm{sgn}(\boldsymbol{\xi}_i[k]), \boldsymbol{\xi}_i[k]\big\rangle - O(\eta s\sigma_p) - O(\eta\alpha) = \langle \mathbf{w}^{(0)}_{y_i,r}, \boldsymbol{\xi}_i\rangle + \Theta(\eta s\sigma_p),$$
where in the first inequality the term $O(\eta s\sigma_p)$ accounts for the coordinates with $|\boldsymbol{\xi}_i[k]| \le O(\sigma_p)$ (for which we cannot use the sign information of $\nabla_{\mathbf{w}_{y_i,r}} L(\mathbf{W}^{(0)})$ and instead directly bound the Adam update by $\Theta(1)$), and the last equality is due to the facts that $|\mathcal{B}_i| \ge s-1$ and $\alpha = o(1)$. For general $t$, we consider the following induction hypothesis:
$$\langle \mathbf{w}^{(t+1)}_{y_i,r}, \boldsymbol{\xi}_i\rangle = \langle \mathbf{w}^{(t)}_{y_i,r}, \boldsymbol{\xi}_i\rangle + \Theta(\eta s\sigma_p), \tag{C.9}$$
which has already been verified for $t=0$. By Hypothesis (C.9), the following holds at time $t$:
$$\langle \mathbf{w}^{(t)}_{y_i,r}, \boldsymbol{\xi}_i\rangle = \langle \mathbf{w}^{(0)}_{y_i,r}, \boldsymbol{\xi}_i\rangle + \Theta(t\eta s\sigma_p) = \Theta(s^{1/2}\sigma_p\sigma_0 + t\eta s\sigma_p).$$
Meanwhile, we have the following upper bound on $|\mathbf{w}^{(t)}_{j,r}[k]|$:
$$|\mathbf{w}^{(t)}_{j,r}[k]| \le |\mathbf{w}^{(t-1)}_{j,r}[k]| + \eta\,\big|\mathrm{sgn}\big(\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t-1)})[k]\big)\big| \le |\mathbf{w}^{(0)}_{j,r}[k]| + t\eta = \Theta(\sigma_0 + t\eta). \tag{C.10}$$
Besides, it is easy to verify that for any $t \le T_0 = \Theta\big(\frac{1}{s\sigma_p\eta m}\big) = \widetilde\Theta\big(\frac{1}{s\sigma_p\eta}\big)$, we have $\langle \mathbf{w}^{(t)}_{y_i,r}, \boldsymbol{\xi}_i\rangle, \langle \mathbf{w}^{(t)}_{y_i,r}, j\cdot\mathbf{v}\rangle < \Theta(1/m)$ and thus $|\ell^{(t)}_{j,i}| = \Theta(1)$. Then, similar to (C.8), we have
$$\mathrm{sgn}\big(\nabla_{\mathbf{w}_{y_i,r}} L(\mathbf{W}^{(t)})[k]\big) = -\mathrm{sgn}\big(\ell^{(t)}_{y_i,i}\sigma'(\langle \mathbf{w}^{(t)}_{y_i,r}, \boldsymbol{\xi}_i\rangle)\boldsymbol{\xi}_i[k] - n\lambda\mathbf{w}^{(t)}_{y_i,r}[k]\big) = -\mathrm{sgn}\Big(\Theta\big((s^{1/2}\sigma_p\sigma_0 + t\eta s\sigma_p)^{q-1}\sigma_p\big)\cdot\mathrm{sgn}(\boldsymbol{\xi}_i[k]) \pm o\big(\sigma_0^{q-2}\sigma_p(\sigma_0 + t\eta)\big)\Big) = -\mathrm{sgn}(\boldsymbol{\xi}_i[k]). \tag{C.11}$$
This further implies
$$\langle \mathbf{w}^{(t+1)}_{y_i,r}, \boldsymbol{\xi}_i\rangle \ge \langle \mathbf{w}^{(t)}_{y_i,r}, \boldsymbol{\xi}_i\rangle - \Theta(\eta)\cdot\sum_{k\in\mathcal{B}_i}\big\langle \mathrm{sgn}\big(\nabla_{\mathbf{w}_{y_i,r}} L(\mathbf{W}^{(t)})[k]\big), \boldsymbol{\xi}_i[k]\big\rangle - O(\eta^2 s^2\sigma_p^2) - O(\eta\alpha) = \langle \mathbf{w}^{(t)}_{y_i,r}, \boldsymbol{\xi}_i\rangle + \Theta(\eta s\sigma_p),$$
where the term $-O(\eta^2 s^2\sigma_p^2)$ is contributed by the gradient coordinates that are smaller than $\Theta(\eta s\sigma_p)$. This verifies Hypothesis (C.9) at time $t$ and thus completes the proof.

From Lemma C.3, noting that $s\sigma_p = \omega(1)$, it can be seen that $\langle \mathbf{w}^{(t)}_{y_i,r}, \boldsymbol{\xi}_i\rangle$ increases much faster than $\langle \mathbf{w}^{(t)}_{j,r}, j\cdot\mathbf{v}\rangle$.
By looking at the update rule of $\langle \mathbf{w}^{(t)}_{j,r}, j\cdot\mathbf{v}\rangle$ (see (C.2)), it keeps increasing only when, roughly speaking, $\sigma'(\langle \mathbf{w}^{(t)}_{j,r}, j\cdot\mathbf{v}\rangle) > \alpha\sigma'(\langle \mathbf{w}^{(t)}_{j,r}, \boldsymbol{\xi}_i\rangle)$. Since $\langle \mathbf{w}^{(t)}_{j,r}, \boldsymbol{\xi}_i\rangle$ increases much faster than $\langle \mathbf{w}^{(t)}_{j,r}, j\cdot\mathbf{v}\rangle$, it can be anticipated that after a certain number of iterations, $\langle \mathbf{w}^{(t)}_{j,r}, j\cdot\mathbf{v}\rangle$ will start to decrease. In the following lemma, we provide an upper bound on the iteration number at which this decrease occurs.

Lemma C.4 (Lemma 5.3, restated). Suppose the training data is generated according to Definition 3.1, $\alpha \ge \Theta\big((s\sigma_p)^{1-q}\vee\sigma_0^{q-1}\big)$ and $\sigma_0 < O((s\sigma_p)^{-1})$. Then for any $t\in[T_r, T_0]$ with $T_r = O\big(\frac{\sigma_0}{\eta s\sigma_p\,\alpha^{1/(q-1)}}\big) \le T_0$,
$$\langle \mathbf{w}^{(t+1)}_{j,r}, j\cdot\mathbf{v}\rangle = \langle \mathbf{w}^{(t)}_{j,r}, j\cdot\mathbf{v}\rangle - \Theta(\eta).$$
Moreover, it holds that
$$\mathbf{w}^{(T_0)}_{j,r}[k] = \begin{cases} -\mathrm{sgn}(j)\cdot\Omega\big(\frac{1}{s\sigma_p}\big), & k = 1,\\ \mathrm{sgn}(\boldsymbol{\xi}_i[k])\cdot\Omega\big(\frac{1}{s\sigma_p}\big) \text{ or } \pm O(\eta), & k\in\mathcal{B}_i \text{ with } y_i = j,\\ \pm O(\eta), & \text{otherwise.}\end{cases}$$
Proof. Recall from Lemma C.3 that for any $t \le T_0$ we have
$$\langle \mathbf{w}^{(t+1)}_{j,r}, j\cdot\mathbf{v}\rangle \le \langle \mathbf{w}^{(t)}_{j,r}, j\cdot\mathbf{v}\rangle + \Theta(\eta) \le \langle \mathbf{w}^{(0)}_{j,r}, j\cdot\mathbf{v}\rangle + \Theta(t\eta),$$
$$\langle \mathbf{w}^{(t+1)}_{y_s,r}, \boldsymbol{\xi}_s\rangle = \langle \mathbf{w}^{(t)}_{y_s,r}, \boldsymbol{\xi}_s\rangle + \Theta(\eta s\sigma_p) \le \langle \mathbf{w}^{(0)}_{y_s,r}, \boldsymbol{\xi}_s\rangle + \Theta(t\eta s\sigma_p).$$
Besides, by Lemma C.2 we also have $|\mathbf{w}^{(t)}_{j,r}[k]| \le |\mathbf{w}^{(0)}_{j,r}[k]| + O(t\eta)$. Then it can be verified that for some $T_r = O\big(\frac{\sigma_0}{\eta s\sigma_p\,\alpha^{1/(q-1)}}\big)$, we have for all $i\in[n]$ and $t\in[T_r, T_0]$,
$$\alpha\sigma'(\langle \mathbf{w}^{(t)}_{y_i,r}, \boldsymbol{\xi}_i\rangle) \ge C\cdot\sigma'(\langle \mathbf{w}^{(t)}_{j,r}, j\cdot\mathbf{v}\rangle) + \lambda n|\mathbf{w}^{(t)}_{j,r}[1]|$$
for some constant $C$. This further implies that
$$\mathrm{sgn}\big(\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t)})[1]\big) = -\mathrm{sgn}\Big(\sum_{i=1}^n y_i\ell^{(t)}_{j,i}\sigma'(\langle \mathbf{w}^{(t)}_{j,r}, y_i\mathbf{v}\rangle) - \alpha\sum_{i=1}^n y_i\ell^{(t)}_{j,i}\sigma'(\langle \mathbf{w}^{(t)}_{j,r}, \boldsymbol{\xi}_i\rangle) - n\lambda\mathbf{w}^{(t)}_{j,r}[1]\Big) = -\mathrm{sgn}\Big(-\alpha\sum_{i=1}^n y_i\ell^{(t)}_{j,i}\sigma'(\langle \mathbf{w}^{(t)}_{j,r}, \boldsymbol{\xi}_i\rangle)\Big) = \mathrm{sgn}(j),$$
where we use the fact that $\mathrm{sgn}(y_i\ell^{(t)}_{j,i}) = \mathrm{sgn}(j)$ for all $i\in[n]$. Then by Lemma C.2 and (C.2), for all $t\in[T_r, T_0]$,
$$\langle \mathbf{w}^{(t+1)}_{j,r}, j\cdot\mathbf{v}\rangle = \langle \mathbf{w}^{(t)}_{j,r}, j\cdot\mathbf{v}\rangle - \Theta(\eta)\cdot\mathrm{sgn}(j)\cdot\mathrm{sgn}\big(\nabla_{\mathbf{w}_{j,r}} L(\mathbf{W}^{(t)})[1]\big) = \langle \mathbf{w}^{(t)}_{j,r}, j\cdot\mathbf{v}\rangle - \Theta(\eta).$$
Then at iteration $T_0$, for the first coordinate we have
$$\mathbf{w}^{(T_0)}_{j,r}[1] = \mathbf{w}^{(0)}_{j,r}[1] + \mathrm{sgn}(j)\cdot\Theta(T_r\eta) - \mathrm{sgn}(j)\cdot\Theta\big((T_0 - T_r)\eta\big) = -\mathrm{sgn}(j)\cdot\Omega\Big(\frac{1}{s\sigma_p}\Big).$$
For any $k\in\mathcal{B}_i$ with $y_i = j$, the coordinate either increases at a rate of $\Theta(\eta)$ per step or falls into a region around zero. As a consequence, we have either $\mathbf{w}^{(T_0)}_{j,r}[k] \in [-\Theta(\eta), \Theta(\eta)]$ or
$$\mathbf{w}^{(T_0)}_{j,r}[k] = \mathbf{w}^{(0)}_{j,r}[k] + \mathrm{sgn}(\boldsymbol{\xi}_i[k])\cdot\Theta(T_0\eta) = \mathrm{sgn}(\boldsymbol{\xi}_i[k])\cdot\Omega\Big(\frac{1}{s\sigma_p}\Big).$$
For the remaining coordinates, the update is determined by the regularization term, and since $T_0\eta = \omega(\sigma_0)$, these coordinates eventually fall into the region around zero. By Lemma C.2 it is clear that $\mathbf{w}^{(T_0)}_{j,r}[k] \in [-\Theta(\eta), \Theta(\eta)]$.

Lemma C.5 (Lemma 5.4, restated). If $\alpha = O\big(\frac{s\sigma_p^2}{n}\big)$ and $\eta = o(\lambda)$, then let $r^* = \arg\max_{r\in[m]}\langle \mathbf{w}^{(t)}_{y_i,r}, \boldsymbol{\xi}_i\rangle$; for any $t\ge T_0$, $i\in[n]$, $j\in\{\pm1\}$ and $r\in[m]$, it holds that
$$\langle \mathbf{w}^{(t)}_{y_i,r^*}, \boldsymbol{\xi}_i\rangle = \Theta(1),\qquad \sum_{k\in\mathcal{B}_i}|\mathbf{w}^{(t)}_{y_i,r^*}[k]|\cdot|\boldsymbol{\xi}_i[k]| = \Theta(1),\qquad \forall r\in[m],\ \langle \mathbf{w}^{(t)}_{j,r}, \mathrm{sgn}(j)\cdot\mathbf{v}\rangle \in \Big[-O\Big(\frac{n\alpha}{s\sigma_p^2}\Big),\, O(\lambda^{-1}\eta)\Big].$$
Proof. The proof relies on the following three induction hypotheses:
$$\langle \mathbf{w}^{(t)}_{y_i,r^*}, \boldsymbol{\xi}_i\rangle = \Omega(1), \tag{C.12}$$
$$\sum_{k\in\mathcal{B}_i}|\mathbf{w}^{(t)}_{y_i,r^*}[k]|\cdot|\boldsymbol{\xi}_i[k]| = \Theta(1), \tag{C.13}$$
$$\forall r\in[m],\ \langle \mathbf{w}^{(t)}_{j,r}, \mathrm{sgn}(j)\cdot\mathbf{v}\rangle \in \Big[-O\Big(\frac{n\alpha}{s\sigma_p^2}\Big),\, O(\lambda^{-1}\eta)\Big]. \tag{C.14}$$
Verifying Hypothesis (C.12). We first verify Hypothesis (C.12). Recall that the update rule for $\langle \mathbf{w}^{(t)}_{y_i,r}, \boldsymbol{\xi}_i\rangle$ is given as follows:
$$\langle \mathbf{w}^{(t+1)}_{y_i,r}, \boldsymbol{\xi}_i\rangle = \langle \mathbf{w}^{(t)}_{y_i,r}, \boldsymbol{\xi}_i\rangle - \eta\,\big\langle \mathbf{m}^{(t)}_{y_i,r}\big/\sqrt{\mathbf{v}^{(t)}_{y_i,r}},\, \boldsymbol{\xi}_i\big\rangle \ge \langle \mathbf{w}^{(t)}_{y_i,r}, \boldsymbol{\xi}_i\rangle - \Theta(\eta)\cdot\big\langle \mathrm{sgn}\big(\nabla_{\mathbf{w}_{y_i,r}} L(\mathbf{W}^{(t)})\big), \boldsymbol{\xi}_i\big\rangle - \Theta(\eta^2 s^2\sigma_p^2)$$
$$= \langle \mathbf{w}^{(t)}_{y_i,r}, \boldsymbol{\xi}_i\rangle + \Theta(\eta)\cdot\sum_{k\in\mathcal{B}_i}\big\langle \mathrm{sgn}\big(\ell^{(t)}_{y_i,i}\sigma'(\langle \mathbf{w}^{(t)}_{y_i,r}, \boldsymbol{\xi}_i\rangle)\boldsymbol{\xi}_i[k] - n\lambda\mathbf{w}^{(t)}_{y_i,r}[k]\big), \boldsymbol{\xi}_i[k]\big\rangle - \alpha y_i\,\Theta(\eta)\cdot\mathrm{sgn}\Big(\sum_{i'=1}^n y_{i'}\ell^{(t)}_{y_i,i'}\big[\sigma'(\langle \mathbf{w}^{(t)}_{y_i,r}, y_{i'}\mathbf{v}\rangle) - \alpha\sigma'(\langle \mathbf{w}^{(t)}_{y_i,r}, \boldsymbol{\xi}_{i'}\rangle)\big] - n\lambda\mathbf{w}^{(t)}_{y_i,r}[1]\Big) - \Theta(\eta^2 s^2\sigma_p^2). \tag{C.15}$$
Note that for any $a$ and $b$ we have $\mathrm{sgn}(a-b)\cdot a \ge |a| - 2|b|$. It then follows that
$$\sum_{k\in\mathcal{B}_i}\big\langle \mathrm{sgn}\big(\ell^{(t)}_{y_i,i}\sigma'(\langle \mathbf{w}^{(t)}_{y_i,r}, \boldsymbol{\xi}_i\rangle)\boldsymbol{\xi}_i[k] - n\lambda\mathbf{w}^{(t)}_{y_i,r}[k]\big), \boldsymbol{\xi}_i[k]\big\rangle \ge \sum_{k\in\mathcal{B}_i}\Big[|\boldsymbol{\xi}_i[k]| - \frac{2n\lambda|\mathbf{w}^{(t)}_{y_i,r}[k]|}{\ell^{(t)}_{y_i,i}\sigma'(\langle \mathbf{w}^{(t)}_{y_i,r}, \boldsymbol{\xi}_i\rangle)}\Big] \ge \Theta(s\sigma_p) - \Theta\Big(\frac{n\lambda}{\ell^{(t)}_{y_i,i}\sigma_p}\Big),$$
where the last inequality follows from Hypotheses (C.12) and (C.13).
Further recalling that $\lambda = o(\sigma_0^{q-2}\sigma_p/n)$, plugging the above inequality into (C.15) gives
$$\langle \mathbf{w}^{(t+1)}_{y_i,r}, \boldsymbol{\xi}_i\rangle \ge \langle \mathbf{w}^{(t)}_{y_i,r}, \boldsymbol{\xi}_i\rangle + \Theta(\eta s\sigma_p) - \Theta\Big(\frac{\eta n\lambda}{\ell^{(t)}_{y_i,i}\sigma_p}\Big) - \Theta(\eta^2 s^2\sigma_p^2) \ge \langle \mathbf{w}^{(t)}_{y_i,r}, \boldsymbol{\xi}_i\rangle + \Theta(\eta s\sigma_p) - \Theta(\alpha\eta) - \Theta\Big(\frac{\eta\sigma_0^{q-2}}{\ell^{(t)}_{y_i,i}}\Big). \tag{C.16}$$
Then it is clear that $\langle \mathbf{w}^{(t)}_{y_i,r}, \boldsymbol{\xi}_i\rangle$ increases by $\Theta(\eta s\sigma_p)$ if $\ell^{(t)}_{y_i,i}$ is larger than some constant of order $\Omega\big(\frac{n\lambda}{s\sigma_p^2}\big) = \Omega\big(\frac{\sigma_0^{q-2}}{s\sigma_p}\big)$. We first show that as soon as there is an iterate $\mathbf{W}^{(\tau)}$ satisfying $\ell^{(\tau)}_{y_i,i} \le O\big(\frac{n\lambda}{s\sigma_p^2}\big)$ for some $\tau \le t$, then $\ell^{(\tau')}_{y_i,i}$ remains smaller than some constant of order $O\big(\frac{n\lambda}{s\sigma_p^2}\big)$ for all $\tau'\in[\tau, t+1]$. To prove this, we first note that
$$\langle \mathbf{w}^{(t+1)}_{y_i,r}, \boldsymbol{\xi}_i\rangle \ge \langle \mathbf{w}^{(t)}_{y_i,r}, \boldsymbol{\xi}_i\rangle + \Theta(\eta s\sigma_p),\qquad \langle \mathbf{w}^{(t+1)}_{-y_i,r}, \boldsymbol{\xi}_i\rangle \le \langle \mathbf{w}^{(t)}_{-y_i,r}, \boldsymbol{\xi}_i\rangle + O(\alpha\eta),\qquad |\langle \mathbf{w}^{(t+1)}_{j,r}, \mathbf{v}\rangle| \le |\langle \mathbf{w}^{(t)}_{j,r}, \mathbf{v}\rangle| + O(\eta). \tag{C.17}$$
Therefore, we have
$$\ell^{(t+1)}_{y_i,i} = \frac{e^{F_{-y_i}(\mathbf{W}^{(t+1)},\mathbf{x}_i)}}{\sum_{j\in\{-1,1\}} e^{F_j(\mathbf{W}^{(t+1)},\mathbf{x}_i)}} = \frac{1}{1 + \exp\big[\sum_{r=1}^m\big(\sigma(\langle \mathbf{w}^{(t+1)}_{y_i,r}, \mathbf{v}\rangle) + \sigma(\langle \mathbf{w}^{(t+1)}_{y_i,r}, \boldsymbol{\xi}_i\rangle) - \sigma(\langle \mathbf{w}^{(t+1)}_{-y_i,r}, \mathbf{v}\rangle) - \sigma(\langle \mathbf{w}^{(t+1)}_{-y_i,r}, \boldsymbol{\xi}_i\rangle)\big)\big]}$$
$$\le \frac{1}{1 + \exp\big[\sum_{r=1}^m\big(\sigma(\langle \mathbf{w}^{(t)}_{y_i,r}, \mathbf{v}\rangle) + \sigma(\langle \mathbf{w}^{(t)}_{y_i,r}, \boldsymbol{\xi}_i\rangle) - \sigma(\langle \mathbf{w}^{(t)}_{-y_i,r}, \mathbf{v}\rangle) - \sigma(\langle \mathbf{w}^{(t)}_{-y_i,r}, \boldsymbol{\xi}_i\rangle)\big) + \Theta(\eta s\sigma_p^2)\big]} \le \frac{1}{1 + \exp\big[\sum_{r=1}^m\big(\sigma(\langle \mathbf{w}^{(t)}_{y_i,r}, \mathbf{v}\rangle) + \sigma(\langle \mathbf{w}^{(t)}_{y_i,r}, \boldsymbol{\xi}_i\rangle) - \sigma(\langle \mathbf{w}^{(t)}_{-y_i,r}, \mathbf{v}\rangle) - \sigma(\langle \mathbf{w}^{(t)}_{-y_i,r}, \boldsymbol{\xi}_i\rangle)\big)\big]} = \ell^{(t)}_{y_i,i},$$
where the first inequality follows from (C.17). Therefore, as long as $\ell^{(t)}_{y_i,i}$ is larger than some constant $b = O\big(\frac{n\lambda}{s\sigma_p^2}\big)$, the Adam algorithm prevents it from increasing further. Besides, since $m\eta s\sigma_p^2 = o(1)$, we must have $\ell^{(t+1)}_{y_i,i} \in [0.5\,\ell^{(t)}_{y_i,i},\, 2\,\ell^{(t)}_{y_i,i}]$. As a consequence, we can deduce that $\ell^{(t)}_{y_i,i}$ cannot become larger than $2b$: otherwise, there must exist an iterate $\mathbf{W}^{(\tau')}$ with $\tau' \le t$ such that $\ell^{(\tau')}_{y_i,i}\in[b, 2b]$ and $\ell^{(\tau'+1)}_{y_i,i} \ge \ell^{(\tau')}_{y_i,i}$, which contradicts the fact that $\ell^{(\tau')}_{y_i,i}$ should decrease when $\ell^{(\tau')}_{y_i,i} \ge b$. Therefore, we can claim that if $\ell^{(\tau)}_{y_i,i} \le b = O\big(\frac{n\lambda}{s\sigma_p^2}\big)$ for some $\tau\le t$, then
$$\ell^{(\tau')}_{y_i,i} \le O\Big(\frac{n\lambda}{s\sigma_p^2}\Big) \tag{C.18}$$
for all $\tau'\in[\tau, t+1]$.
Then further note that

2ℓ'^{(t+1)}_{y_i,i} ≥ ℓ'^{(t)}_{y_i,i} = e^{F_{−y_i}(W^{(t)},x_i)} / Σ_{j∈{−1,1}} e^{F_j(W^{(t)},x_i)} ≥ exp(−Σ_{r=1}^m [σ(⟨w^{(t)}_{y_i,r}, y_i v⟩) + σ(⟨w^{(t)}_{y_i,r}, ξ_i⟩)]) ≥ exp(−Θ(m·max_{r∈[m]} σ(⟨w^{(t)}_{y_i,r}, ξ_i⟩))), (C.19)

where in the last inequality we use Hypothesis (C.14). Then, by the fact that ℓ'^{(t+1)}_{y_i,i} ≤ O(nλ/(sσ_p²)) = o(1) and m = Θ(1), it is clear that exp(−Θ(m·max_{r∈[m]} σ(⟨w^{(t+1)}_{y_i,r}, ξ_i⟩))) = o(1), so that max_{r∈[m]} ⟨w^{(t+1)}_{y_i,r}, ξ_i⟩ = Ω(1). This verifies Hypothesis (C.12).

Verifying Hypothesis (C.13). Now we will verify Hypothesis (C.13). First, note that we have already shown ⟨w^{(t+1)}_{y_i,r*}, ξ_i⟩ = Ω(1), so it holds that

Σ_{k∈B_i} |w^{(t+1)}_{y_i,r*}[k]|·|ξ_i[k]| + α|w^{(t+1)}_{y_i,r*}[1]| ≥ ⟨w^{(t+1)}_{y_i,r*}, ξ_i⟩ = Ω(1).

By Hypothesis (C.14), we have |w^{(t+1)}_{y_i,r*}[1]| ≤ |w^{(t)}_{y_i,r*}[1]| + η = o(1). Besides, since each coordinate of ξ_i is a Gaussian random variable, max_{k∈B_i} |ξ_i[k]| = O(σ_p). This immediately implies that Σ_{k∈B_i} |w^{(t+1)}_{y_i,r*}[k]|·|ξ_i[k]| = Ω(1).

Then we will prove the upper bound on Σ_{k∈B_i} |w^{(t+1)}_{y_i,r}[k]|·|ξ_i[k]|. Recall that, by Lemma C.2, for any k ∈ B_i such that |∇_{w_{y_i,r}} L(W^{(t)})[k]| ≥ Θ(n^{−1}ηsσ_p ℓ'^{(t)}_{y_i,i}), we have

w^{(t+1)}_{y_i,r}[k] = w^{(t)}_{y_i,r}[k] + Θ(η)·sgn(ℓ'^{(t)}_{y_i,i} σ'(⟨w^{(t)}_{y_i,r}, ξ_i⟩)ξ_i[k] − nλw^{(t)}_{y_i,r}[k]).

Note that, by Lemma C.4, for every k ∈ B_i we have either w^{(T_0)}_{y_i,r}[k] = sgn(ξ_i[k])·Θ(1/(sσ_p)) or |w^{(T_0)}_{y_i,r}[k]| ≤ η. Then, during the training process after T_0, we have either sgn(w^{(t)}_{y_i,r}[k]) = sgn(ξ_i[k]) or sgn(ξ_i[k])·w^{(t)}_{y_i,r}[k] ≥ −O(η): if at some iteration t' we have sgn(w^{(t')}_{y_i,r}[k]) = −sgn(ξ_i[k]) but sgn(w^{(t'−1)}_{y_i,r}[k]) = sgn(ξ_i[k]), then after τ = O(1) steps (see the proof of Lemma C.2 for the definition of τ) the gradient must point in the same direction as ξ_i[k], which pushes w_{y_i,r}[k] back toward zero or makes it positive along the direction of ξ_i[k].
Therefore, based on this property, we have the following bound on the inner product ⟨w^{(t)}_{y_i,r}, ξ_i⟩:

⟨w^{(t)}_{y_i,r}, ξ_i⟩ = Σ_{k∈B_i∪{1}} w^{(t)}_{y_i,r}[k]·ξ_i[k] ≥ Σ_{k∈B_i∪{1}} |w^{(t)}_{y_i,r}[k]|·|ξ_i[k]| − O(η)·Σ_{k∈B_i∪{1}} |ξ_i[k]| = Σ_{k∈B_i∪{1}} |w^{(t)}_{y_i,r}[k]|·|ξ_i[k]| − O(ηsσ_p),

where the inequality follows from the fact that any entry w^{(t)}_{y_i,r}[k] whose sign differs from that of ξ_i[k] satisfies |w^{(t)}_{y_i,r}[k]| ≤ O(η). Then let

B^{(t)}_i = Σ_{k∈B_i∪{1}} w^{(t)}_{y_i,r}[k]·1(|w^{(t)}_{y_i,r}[k]| ≥ O(η))·|ξ_i[k]|,

which satisfies B^{(T_0)}_i = Θ(1) by Lemma C.4. Suppose B^{(t)}_i keeps increasing and reaches some value of order Θ(log(dnη^{−1})); then, according to the inequality above,

⟨w^{(t)}_{y_i,r}, ξ_i⟩ = Θ(log(dnη^{−1})) − Θ(ηsσ_p) = Θ(log(dnη^{−1})),

where we use the condition that η = O((sσ_p)^{−1}). Then by Hypotheses (C.12) and (C.14) we know that |⟨w^{(t)}_{j,r}, v⟩| = o(1), ⟨w^{(t)}_{y_i,r*}, ξ_i⟩ = Ω(1), and |⟨w^{(t)}_{−y_i,r*}, ξ_i⟩| = O(dη) + α|⟨w^{(t)}_{−y_i,r*}, v⟩| = o(1); then, similar to (C.19), it holds that

ℓ'^{(t)}_{y_i,i} = e^{F_{−y_i}(W^{(t)},x_i)} / Σ_{j∈{−1,1}} e^{F_j(W^{(t)},x_i)} ≤ exp(−Θ(σ(⟨w^{(t)}_{y_i,r*}, ξ_i⟩))) ≤ poly(d^{−1}, n^{−1}, η).

Therefore, at this point we have, for all k ∈ B_i,

ℓ'^{(t)}_{y_i,i} σ'(⟨w^{(t)}_{y_i,r}, ξ_i⟩)|ξ_i[k]| ≤ poly(d^{−1}, n^{−1}, η)·Θ(log^{q−1}(dnη^{−1}))·Θ(σ_p) ≤ nλη.

Then, for all k with |w^{(t)}_{y_i,r}[k]| ≥ O(η), the sign of the gradient satisfies

sgn(∇_{w_{y_i,r}} L(W^{(t)})[k]) = −sgn(ℓ'^{(t)}_{y_i,i} σ'(⟨w^{(t)}_{y_i,r}, ξ_i⟩)ξ_i[k] − nλw^{(t)}_{y_i,r}[k]) = sgn(nλw^{(t)}_{y_i,r}[k] − O(nλη)) = sgn(w^{(t)}_{y_i,r}[k]).

Then, noting that |∇_{w_{y_i,r}} L(W^{(t)})[k]| = Θ(|λw^{(t)}_{y_i,r}[k]|) ≥ Θ(n^{−1}ηsσ_p ℓ'^{(t)}_{y_i,i} + λη), by the update rule of w^{(t)}_{y_i,r}[k] and Lemma C.2 we know that the sign of the gradient dominates the update process. Then we have

|w^{(t+1)}_{y_i,r}[k]| = |w^{(t)}_{y_i,r}[k] − Θ(η)·sgn(w^{(t)}_{y_i,r}[k])| ≤ |w^{(t)}_{y_i,r}[k]|,

which implies that w^{(t)}_{y_i,r}[k]·1(|w^{(t)}_{y_i,r}[k]| ≥ O(η)) decreases, so that B^{(t)}_i also decreases. Therefore, we can conclude that B^{(t)}_i will not exceed Θ(log(dnη^{−1})).
Then combining the results for all i ∈ [n] gives

Σ_{k∈B_i} |w^{(t)}_{y_i,r*}[k]|·|ξ_i[k]| ≤ B^{(t)}_i + O(sησ_p) ≤ Θ(log(dnη^{−1})) + O(1) = Θ(1),

where in the first inequality we again use the condition that η = o(1/d) = o((sσ_p)^{−1}). This verifies Hypothesis (C.13). Notably, this also implies that ⟨w^{(t)}_{y_i,r*}, ξ_i⟩ = max_{r∈[m]} ⟨w^{(t)}_{y_i,r}, ξ_i⟩ ≤ Θ(1).

Verifying Hypothesis (C.14). In order to verify Hypothesis (C.14), let us first recall the update rule of ⟨w^{(t)}_{j,r}, v⟩:

⟨w^{(t+1)}_{j,r}, v⟩ = ⟨w^{(t)}_{j,r}, v⟩ − η·⟨m^{(t)}_{j,r}/√(v^{(t)}_{j,r}), v⟩,

where

⟨m^{(t)}_{j,r}/√(v^{(t)}_{j,r}), v⟩ = −sgn(Σ_{i=1}^n y_i ℓ'^{(t)}_{j,i} σ'(⟨w^{(t)}_{j,r}, y_i v⟩) − α Σ_{i=1}^n y_i ℓ'^{(t)}_{j,i} σ'(⟨w^{(t)}_{j,r}, ξ_i⟩) − nλw^{(t)}_{j,r}[1])·Θ(1).

Without loss of generality we assume j = 1; then by Lemma C.4 we know that w^{(T_0)}_{1,r}[1] = −Ω(1/(sσ_p)). In the remaining proof, we will show that either w^{(t+1)}_{1,r}[1] ∈ [0, Θ(λ^{−1}η)] or w^{(t+1)}_{1,r}[1] ∈ [−O(nα/(sσ_p²)), 0].

First we will show that w^{(t+1)}_{1,r}[1] ∈ [0, Θ(λ^{−1}η)] for all r. Note that at the beginning of this stage we have w^{(T_0)}_{1,r}[1] < 0. In order for the sign of w^{(t)}_{1,r}[1] to flip, there must be some iteration t' ≤ t that satisfies w^{(t')}_{1,r}[1] ∈ [0, Θ(λ^{−1}η)]; therefore

−n∇_{w_{1,r}} L(W^{(t')})[1] = Σ_{i=1}^n y_i ℓ'^{(t')}_{1,i} σ'(⟨w^{(t')}_{1,r}, y_i v⟩) − α Σ_{i=1}^n y_i ℓ'^{(t')}_{1,i} σ'(⟨w^{(t')}_{1,r}, ξ_i⟩) − nλw^{(t')}_{1,r}[1] ≤ n(|w^{(t')}_{1,r}[1]|^{q−2} − λ)·w^{(t')}_{1,r}[1] ≤ −Θ(nη) ≤ 0,

where the second inequality holds since η = o(λ^{(q−1)/(q−2)}). Note that |∇_{w_{1,r}} L(W^{(t')})[1]| ≥ Θ(η); then by Lemma C.2 we know that Adam behaves like sign gradient descent, and thus w^{(t'+1)}_{1,r}[1] = w^{(t')}_{1,r}[1] − Θ(η), i.e., it starts to decrease. This implies that if w^{(t+1)}_{1,r}[1] is positive, then it cannot exceed Θ(λ^{−1}η) = o(1). Next we prove that if w^{(t+1)}_{1,r}[1] is negative, then |w^{(t+1)}_{1,r}[1]| = O(nα/(sσ_p²)).
In this case, we have for all t' ≤ t,

−n∇_{w_{1,r}} L(W^{(t')})[1] = Σ_{i=1}^n y_i ℓ'^{(t')}_{1,i} σ'(⟨w^{(t')}_{1,r}, y_i v⟩) − α Σ_{i=1}^n y_i ℓ'^{(t')}_{1,i} σ'(⟨w^{(t')}_{1,r}, ξ_i⟩) − nλw^{(t')}_{1,r}[1]
≥ −Σ_{i:y_i=1} |ℓ'^{(t')}_{1,i}|·Θ(α) + nλ|w^{(t')}_{1,r}[1]| + Σ_{i:y_i=−1} |ℓ'^{(t')}_{1,i}|·|w^{(t')}_{1,r}[1]|^{q−1}
≥ −Σ_{i:y_i=1} |ℓ'^{(t')}_{1,i}|·Θ(α) + nλ|w^{(t')}_{1,r}[1]|,

where in the first inequality we use Hypotheses (C.13) and (C.14) to get that ⟨w^{(t')}_{y_i,r}, ξ_i⟩ ≤ Σ_{k∈B_i} |w^{(t')}_{y_i,r}[k]|·max_{k∈B_i} |ξ_i[k]| + α|⟨w^{(t')}_{y_i,r}, v⟩| = Θ(1). Recall from (C.18) that we have |ℓ'^{(t')}_{1,i}| = O(nλ/(sσ_p²)); therefore, if w^{(t')}_{1,r}[1] is smaller than some value of order −Θ(nα/(sσ_p²))·polylog(d), then

−n∇_{w_{1,r}} L(W^{(t')})[1] ≥ −Θ(αn²λ/(sσ_p²)) + Θ(nλ·(nα/(sσ_p²))·polylog(d)) ≥ Θ(nη),

which by Lemma C.2 implies that w^{(t')}_{1,r}[1] will increase. Therefore, we can conclude that w^{(t+1)}_{1,r}[1] ∈ [−O(nα/(sσ_p²)), 0] in this case, which verifies Hypothesis (C.14).

Lemma C.6 (Lemma 5.5, restated). If the step size satisfies η = O(d^{−1/2}), then for any t it holds that

L(W^{(t+1)}) − L(W^{(t)}) ≤ −η‖∇L(W^{(t)})‖_1 + Θ(η²d).

Proof. Let ΔF_{j,i} = F_j(W^{(t+1)}, x_i) − F_j(W^{(t)}, x_i). Then, regarding the loss function,

L_i(W) = −log(e^{F_{y_i}(W,x_i)} / Σ_j e^{F_j(W,x_i)}) = −F_{y_i}(W, x_i) + log(Σ_j e^{F_j(W,x_i)}).

It is clear that the function L_i(W) is 1-smooth with respect to the vector [F_{−1}(W, x_i), F_1(W, x_i)]. Then, based on the definition of ΔF_{j,i}, we have

L_i(W^{(t+1)}) − L_i(W^{(t)}) ≤ Σ_j (∂L_i(W^{(t)})/∂F_j(W^{(t)}, x_i))·ΔF_{j,i} + Σ_j (ΔF_{j,i})². (C.20)

Moreover, note that F_j(W^{(t)}, x_i) = Σ_{r=1}^m [σ(⟨w^{(t)}_{j,r}, y_i v⟩) + σ(⟨w^{(t)}_{j,r}, ξ_i⟩)]. By the results that ⟨w^{(t)}_{j,r}, v⟩ ≤ Θ(1) and ⟨w^{(t)}_{j,r}, ξ_i⟩ ≤ Θ(1), for any η = O(d^{−1/2}) we have

⟨w^{(t+1)}_{j,r}, v⟩ ≤ ⟨w^{(t)}_{j,r}, v⟩ + η ≤ Θ(1), ⟨w^{(t+1)}_{j,r}, ξ_i⟩ ≤ ⟨w^{(t)}_{j,r}, ξ_i⟩ + Θ(ηs^{1/2}) ≤ Θ(1),

which implies that the smoothness parameters of the functions σ(⟨w, y_i v⟩) and σ(⟨w, ξ_i⟩) are at most Θ(1) for any w on the path between w^{(t)}_{j,r} and w^{(t+1)}_{j,r}.
Then we can apply a first-order Taylor expansion to σ(⟨w^{(t)}_{j,r}, y_i v⟩) and σ(⟨w^{(t)}_{j,r}, ξ_i⟩) and bound the second-order error as follows:

|σ(⟨w^{(t+1)}_{j,r}, y_i v⟩) − σ(⟨w^{(t)}_{j,r}, y_i v⟩) − ⟨∇_{w_{j,r}} σ(⟨w^{(t)}_{j,r}, y_i v⟩), w^{(t+1)}_{j,r} − w^{(t)}_{j,r}⟩| ≤ Θ(‖w^{(t+1)}_{j,r} − w^{(t)}_{j,r}‖_2²) = Θ(η²d), (C.21)

where the last equality is due to Lemma C.2, i.e.,

‖w^{(t+1)}_{j,r} − w^{(t)}_{j,r}‖_2² = η²‖m^{(t)}_{j,r}/√(v^{(t)}_{j,r})‖_2² ≤ Θ(η²d).

Similarly, we can also show that

|σ(⟨w^{(t+1)}_{j,r}, ξ_i⟩) − σ(⟨w^{(t)}_{j,r}, ξ_i⟩) − ⟨∇_{w_{j,r}} σ(⟨w^{(t)}_{j,r}, ξ_i⟩), w^{(t+1)}_{j,r} − w^{(t)}_{j,r}⟩| ≤ Θ(η²d). (C.22)

Combining the above bounds on the second-order errors, we have

|ΔF_{j,i} − ⟨∇_W F_j(W^{(t)}, x_i), W^{(t+1)} − W^{(t)}⟩| ≤ Θ(mη²d) = Θ(η²d), (C.23)

where the last equality is due to our assumption that m = Θ(1). Besides, by (C.21), (C.22) and the convexity of the function σ(x), we also have

σ(⟨w^{(t+1)}_{j,r}, y_i v⟩) − σ(⟨w^{(t)}_{j,r}, y_i v⟩) ≤ |⟨∇_{w_{j,r}} σ(⟨w^{(t)}_{j,r}, y_i v⟩), w^{(t+1)}_{j,r} − w^{(t)}_{j,r}⟩| + Θ(η²d) = Θ(η|σ'(⟨w^{(t+1)}_{j,r}, y_i v⟩)|·‖v‖_1) + Θ(η²d) = Θ(η + η²d);
σ(⟨w^{(t+1)}_{j,r}, ξ_i⟩) − σ(⟨w^{(t)}_{j,r}, ξ_i⟩) ≤ |⟨∇_{w_{j,r}} σ(⟨w^{(t)}_{j,r}, ξ_i⟩), w^{(t+1)}_{j,r} − w^{(t)}_{j,r}⟩| + Θ(η²d) = Θ(η|σ'(⟨w^{(t+1)}_{j,r}, ξ_i⟩)|·‖ξ‖_1) + Θ(η²d) = Θ(ηsσ_p + η²d).

These bounds further imply that

|ΔF_{j,i}| ≤ Θ(m·(ηsσ_p + η²d)) = Θ(ηsσ_p + η²d). (C.24)

Now we can plug (C.23) and (C.24) into (C.20) and get

L_i(W^{(t+1)}) − L_i(W^{(t)}) ≤ Σ_j (∂L_i(W^{(t)})/∂F_j(W^{(t)}, x_i))·ΔF_{j,i} + Σ_j (ΔF_{j,i})² ≤ Σ_j (∂L_i(W^{(t)})/∂F_j(W^{(t)}, x_i))·⟨∇_W F_j(W^{(t)}, x_i), W^{(t+1)} − W^{(t)}⟩ + Θ(η²d) + Θ((ηsσ_p + η²d)²) = ⟨∇L_i(W^{(t)}), W^{(t+1)} − W^{(t)}⟩ + Θ(η²d), (C.25)

where in the second inequality we use the fact that L_i(W) is 1-Lipschitz with respect to F_j(W, x_i), and the last equality is due to our assumption that σ_p = O(s^{−1/2}), so that Θ((ηsσ_p + η²d)²) = O(η²d). Now we are ready to characterize the behavior of the entire training objective L(W) = n^{−1} Σ_{i=1}^n L_i(W) + λ‖W‖_F².
Note that λ‖W‖_F² is 2λ-smooth, where λ = o(1). Then applying (C.25) for all i ∈ [n] gives

L(W^{(t+1)}) − L(W^{(t)}) = (1/n) Σ_{i=1}^n [L_i(W^{(t+1)}) − L_i(W^{(t)})] + λ(‖W^{(t+1)}‖_F² − ‖W^{(t)}‖_F²) ≤ ⟨∇L(W^{(t)}), W^{(t+1)} − W^{(t)}⟩ + Θ(η²d),

where the inequality uses the fact that ‖W^{(t+1)} − W^{(t)}‖_F² = Θ(η²d). Recall that we have

w^{(t+1)}_{j,r} − w^{(t)}_{j,r} = −η·m^{(t)}_{j,r}/√(v^{(t)}_{j,r}).

Then by Lemma C.2, we know that m^{(t)}_{j,r}[k]/√(v^{(t)}_{j,r}[k]) is close to the sign of the gradient whenever ∇L(W^{(t)})[k] is large. Then we have

⟨∇_{w_{j,r}} L(W^{(t)}), m^{(t)}_{j,r}/√(v^{(t)}_{j,r})⟩ ≥ Θ(‖∇_{w_{j,r}} L(W^{(t)})‖_1) − Θ(d·η) − Θ(ns·ηsσ_p) ≥ Θ(‖∇_{w_{j,r}} L(W^{(t)})‖_1) − Θ(dη),

where the second and third terms on the right-hand side of the first inequality are contributed by the small-gradient coordinates k ∉ ∪_{i=1}^n B_i and k ∈ ∪_{i=1}^n B_i respectively, and the last inequality is by the fact that ns²σ_p = O(d). Therefore, based on this fact, (C.25) further leads to

L(W^{(t+1)}) − L(W^{(t)}) ≤ −η‖∇L(W^{(t)})‖_1 + Θ(η²d),

which completes the proof.

Lemma C.7 (Generalization Performance of Adam). Let W* = argmin_{W∈{W^{(1)},…,W^{(T)}}} ‖∇L(W)‖_1. Then for all training data, we have

(1/n) Σ_{i=1}^n 1[F_{y_i}(W*, x_i) ≤ F_{−y_i}(W*, x_i)] = 0.

Moreover, in terms of the test data (x, y) ∼ D, we have

P_{(x,y)∼D}[F_y(W*, x) ≤ F_{−y}(W*, x)] ≥ 1/2.

Proof. By Lemma C.6, we know that the algorithm will converge to a point with very small gradient (up to O(ηd) in ℓ_1 norm). Then, in terms of a noise vector ξ_i, we have

Σ_{k∈B_i} |∇_{w_{y_i,r}} L(W*)[k]| ≤ O(ηd). (C.26)

Note that n∇_{w_{y_i,r}} L(W*)[k] = ℓ'*_{y_i,i} σ'(⟨w*_{y_i,r}, ξ_i⟩)ξ_i[k] − nλw*_{y_i,r}[k], where ℓ'*_{y_i,i} = 1 − logit_{y_i}(F*, x_i). Then, by the triangle inequality and (C.26), we have for any r ∈ [m],

Σ_{k∈B_i} |ℓ'*_{y_i,i}| σ'(⟨w*_{y_i,r}, ξ_i⟩)|ξ_i[k]| − nλ Σ_{k∈B_i} |w*_{y_i,r}[k]| ≤ n Σ_{k∈B_i} |∇_{w_{y_i,r}} L(W*)[k]| ≤ O(nηd).

Then by Lemma C.5, letting r* = argmax_{r∈[m]} ⟨w*_{y_i,r}, ξ_i⟩, we have ⟨w*_{y_i,r*}, ξ_i⟩ = Θ(1) and Σ_{k∈B_i} |w*_{y_i,r*}[k]|·|ξ_i[k]| = Θ(1).
Note that |ξ_i[k]| = O(σ_p), so we have Σ_{k∈B_i} |w*_{y_i,r*}[k]| ≥ Θ(1/σ_p). Then, according to the inequality above, it holds that

|ℓ'*_{y_i,i}|·Θ(sσ_p) ≥ Θ(nλ Σ_{k∈B_i} |w*_{y_i,r}[k]| − nηd) ≥ Θ(nλ/σ_p),

where the second inequality is due to our choice of η. This further implies that |ℓ'*_{y_i,i}| = |ℓ'*_{−y_i,i}| = Θ(nλ/(sσ_p²)) by combining the above results with (C.18). Let us now move to the gradient with respect to the first coordinate. In particular, since ‖∇L(W*)‖_1 ≤ O(ηd), we have

|n∇_{w_{j,r}} L(W*)[1]| = |Σ_{i=1}^n y_i ℓ'*_{j,i} σ'(⟨w*_{j,r}, y_i v⟩) − α Σ_{i=1}^n y_i ℓ'*_{j,i} σ'(⟨w*_{j,r}, ξ_i⟩) − nλw*_{j,r}[1]| ≤ O(nηd). (C.27)

Published as a conference paper at ICLR 2023

Then, noting that sgn(y_i ℓ'*_{j,i}) = sgn(j), it is clear that w*_{j,r*}[1]·j ≤ 0, since otherwise

|n∇_{w_{j,r*}} L(W*)[1]| ≥ α Σ_{i=1}^n |y_i ℓ'*_{j,i}|·(σ'(⟨w*_{j,r*}, ξ_i⟩) − σ'(⟨w*_{j,r*}, y_i v⟩)) ≥ Θ(αn²λ/(sσ_p²)) ≥ Ω(nηd),

which contradicts (C.27). Therefore, using the fact that w*_{j,r*}[1]·j ≤ 0, we have

|n∇_{w_{j,r*}} L(W*)[1]| = |α Σ_{i:y_i=j} y_i ℓ'*_{j,i} σ'(⟨w*_{j,r*}, ξ_i⟩) − Σ_{i:y_i=−j} y_i ℓ'*_{j,i} σ'(|w*_{j,r*}[1]|) − nλ|w*_{j,r*}[1]||.

Then, applying (C.27) and using the fact that |ℓ'*_{y_i,i}| = |ℓ'*_{−y_i,i}| = Θ(nλ/(sσ_p²)) for all i ∈ [n], it is clear that

|w*_{j,r*}[1]| ≥ Θ(α^{1/(q−1)} ∧ nα/(sσ_p²)) ≥ Θ(nα/(sσ_p²)),

where the second inequality is due to our choice of σ_p and α. Then, combining with Lemma C.5 and the fact that w*_{j,r*}[1]·j < 0, we have w*_{j,r*}[1]·j ≤ −Θ(nα/(sσ_p²)). Consequently,

F_{y_i}(W*, x_i) = Σ_{r=1}^m [σ(⟨w*_{y_i,r}, y_i v⟩) + σ(⟨w*_{y_i,r}, ξ_i⟩)] = Θ(1), F_{−y_i}(W*, x_i) = Σ_{r=1}^m [σ(⟨w*_{−y_i,r}, −y_i v⟩) + σ(⟨w*_{−y_i,r}, ξ_i⟩)] = o(1),

which directly implies that the model W* correctly classifies all training data and thus achieves zero training error. Now consider the test data (x, y) with x = [yv, ξ], generated according to Definition 3.1. Note that for each neuron, its weight w*_{j,r} can be decomposed into two parts: the first coordinate and the remaining d − 1 coordinates.
As previously discussed, for any j ∈ [2] we have sgn(j)·w*_{j,r*}[1] ≤ −Θ(nα/(sσ_p²)), and sgn(j)·w*_{j,r}[1] ≤ Θ(λ^{−1}η) for r ≠ r*. Therefore, using the fact that Θ(nα/(sσ_p²)) = ω(λ^{−1}η) and Lemma C.5, given the test data (x, y) we have

F_y(W*, x) = Σ_{r=1}^m [σ(⟨w*_{y,r}, yv⟩) + σ(⟨w*_{y,r}, ξ⟩)] ≤ Σ_{r=1}^m Θ([α·nα/(sσ_p²) + ζ_{y,r}]_+^q),
F_{−y}(W*, x) = Σ_{r=1}^m [σ(⟨w*_{−y,r}, yv⟩) + σ(⟨w*_{−y,r}, ξ⟩)] ≥ Θ(|w*_{−y,r*}[1]|^q + [ζ_{−y,r*}]_+^q) ≥ Θ((nα/(sσ_p²))^q + [ζ_{−y,r*}]_+^q),

where the random variables ζ_{y,r} and ζ_{−y,r} are symmetric and independent of v. Besides, since α = o(1), it is clear that α·nα/(sσ_p²) = o(nα/(sσ_p²)). Therefore, if the random noise terms ζ_{y,r} and ζ_{−y,r} are dominated by the feature-noise term ⟨w*_{−y,r*}, yv⟩, we directly get F_y(W*, x) ≤ F_{−y}(W*, x) (recall that m = Θ(1)), which means the model has been biased by the feature noise, and the true feature vector in the test data gives no "positive" contribution to the classification. Also note that ζ_y and ζ_{−y} are independent of v, which implies that if the random noise dominates the feature-noise term, the model W* incurs at least 0.5 error on the test data. In sum, we can conclude that with probability at least 1/2 it holds that F_y(W*, x) ≤ F_{−y}(W*, x), which implies that the output of Adam achieves at least 1/2 test error.

Then (C.35) implies that if Γ^{(t)}_1 ≤ Θ(log(1/σ_0)/m), we have

Λ^{(t+1)}_1 ≥ (1 − ηλ)Λ^{(t)}_1 + Θ(ηα)·Λ^{(t)}_1 − Θ(α^q η) ≥ Λ^{(t)}_1 + Θ(ηα)·Λ^{(t)}_1 ≥ Λ^{(t)}_1,

where the second inequality is due to λ = o(α). This implies that Λ^{(t)}_1 keeps increasing in this case, so it is impossible that Λ^{(t)}_1 ≤ Θ(1/m), which completes the proof of the first part. For the second part, (C.28) implies that

Λ^{(t+1)}_1 ≤ (1 − ηλ)Λ^{(t)}_1 + Θ(η/n)·Σ_{i:y_i=1} |ℓ'^{(t)}_{1,i}|·(Λ^{(t)}_1)^{q−1}.
(C.36) Consider the case when Γ

L_i(W^{(t+1)}) − L_i(W^{(t)}) ≤ Σ_j (∂L_i(W^{(t)})/∂F_j(W^{(t)}, x_i))·ΔF_{j,i} + Σ_j (ΔF_{j,i})² = ⟨∇L_i(W^{(t)}), W^{(t+1)} − W^{(t)}⟩ + Θ(η²‖∇L(W^{(t)})‖_F²). (C.44)

Taking the sum over i ∈ [n] and applying the smoothness of the regularization function λ‖W‖_F², we can get

L(W^{(t+1)}) − L(W^{(t)}) = (1/n) Σ_{i=1}^n [L_i(W^{(t+1)}) − L_i(W^{(t)})] + λ(‖W^{(t+1)}‖_F² − ‖W^{(t)}‖_F²) ≤ ⟨∇L(W^{(t)}), W^{(t+1)} − W^{(t)}⟩ + Θ(η²‖∇L(W^{(t)})‖_F²) = −(η − Θ(η²))·‖∇L(W^{(t)})‖_F² ≤ −(η/2)‖∇L(W^{(t)})‖_F²,

where the last inequality is due to our choice of step size η = o(1), which gives η − Θ(η²) ≥ η/2. This completes the proof. This implies that GD can also achieve at most 1/poly(n) test error.

D PROOF OF THEOREM 4.2: CONVEX CASE

Theorem D.1 (Convex setting, restated). Assume the model is over-parameterized. Then, for any convex and smooth training objective with positive regularization parameter λ, suppose we run Adam and gradient descent for T = poly(n)/η iterations. Then, with probability at least 1 − n^{−1}, the obtained parameters W*_Adam and W*_GD satisfy ‖∇L(W*_Adam)‖_1 ≤ 1/(Tη) and ‖∇L(W*_GD)‖_2² ≤ 1/(Tη) respectively. Moreover, it holds that:

• Training errors are the same: (1/n) Σ_{i=1}^n 1[sgn(F(W*_Adam, x_i)) ≠ y_i] = (1/n) Σ_{i=1}^n 1[sgn(F(W*_GD, x_i)) ≠ y_i].

• Test errors are nearly the same: P_{(x,y)∼D}[sgn(F(W*_Adam, x)) ≠ y] = P_{(x,y)∼D}[sgn(F(W*_GD, x)) ≠ y] ± o(1).

Proof. The proof is straightforward by applying the same proof technique used for Lemmas C.6 and C.13, where we only need the smoothness of the loss function. It is then clear that both Adam and GD provably find a point with a sufficiently small gradient. Note that the training objective becomes strongly convex when weight decay regularization is added, implying that the entire training objective has only one stationary point, i.e., one point with a sufficiently small gradient.
This further implies that the points found by Adam and GD must be exactly the same, and thus GD and Adam must have nearly the same training and test performance. Besides, when the problem is sufficiently over-parameterized, with proper (feasibly small) regularization we can still guarantee zero training error.

Data model in Wilson et al. (2017). In particular, given the binary label y_i ∈ {−1, 1}, the feature vector x_i is set as

x_i[j] = y_i if j = 1; 1 if j = 2, 3; 1 if j = 4 + 5(i−1), …, 4 + 5(i−1) + 2(1 − y_i); and 0 otherwise.

Data model in Reddi et al. (2018). In particular, Reddi et al. (2018) consider a one-dimensional optimization objective. In each iteration of Adam, the stochastic gradient is taken based on the function f_t(x), defined as

f_t(x) = Cx if t mod 3 = 1, and f_t(x) = −x otherwise.

It can thus be seen that in these two prior works, each coordinate of the feature vector (or the objective function) is hard-coded. In contrast, our data model allows randomness in the data generation process. This implies that our theory holds for data points generated from a distribution, while these prior works only cover one particular dataset or optimization objective.
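The hard-coded one-dimensional objective of Reddi et al. (2018) described above is small enough to write out directly. The sketch below (Python; the name `grad_f` and the choice C = 3 are ours, not from the paper) encodes the gradient oracle and checks that the average gradient over one period is (C − 2)/3 > 0, i.e., the rare large gradient C is what carries the signal that Adam's coordinate-wise normalization suppresses.

```python
# Sketch of the hard-coded 1-D objective from Reddi et al. (2018), as
# described above: f_t(x) = C*x when t mod 3 == 1, and f_t(x) = -x otherwise.
# The helper name grad_f and the constant C = 3 are our choices.

C = 3.0  # any C > 2 makes the average gradient positive

def grad_f(t: int) -> float:
    """Stochastic gradient of f_t at any x (both pieces are linear in x)."""
    return C if t % 3 == 1 else -1.0

# Over one period the gradients are (-1, C, -1), so the average gradient is
# (C - 2) / 3 > 0: the true objective increases in the +x direction on
# average, yet two out of three steps point the other way.
avg = sum(grad_f(t) for t in range(3)) / 3
```

Running the oracle for a few periods makes the imbalance concrete without committing to any particular optimizer implementation.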
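The convex-case conclusion above (Theorem D.1: with weight decay, any successful optimizer reaches essentially the same solution) can also be illustrated numerically. The sketch below is a stand-in under our own assumptions, not the paper's neural-network objective: a strongly convex quadratic with weight decay, plain GD, and standard Adam with default betas.

```python
import numpy as np

# Illustration of the convex-case claim: with weight decay, a strongly convex
# objective has a unique stationary point, so Adam and GD end up at (nearly)
# the same solution. The quadratic below is our stand-in for "any convex and
# smooth training objective"; all constants are our choices.

A = np.diag([1.0, 2.0, 3.0])
b = np.array([1.0, -2.0, 0.5])
lam = 0.1  # weight decay strength

def grad(w):
    # gradient of L(w) = 0.5 * w^T A w - b^T w + lam * ||w||^2
    return A @ w - b + 2 * lam * w

w_star = np.linalg.solve(A + 2 * lam * np.eye(3), b)  # unique minimizer

w_gd = np.zeros(3)
for _ in range(2000):                      # plain gradient descent
    w_gd -= 0.1 * grad(w_gd)

w_adam = np.zeros(3)
m, v = np.zeros(3), np.zeros(3)
beta1, beta2, eps, eta = 0.9, 0.999, 1e-8, 1e-3
for t in range(1, 10001):                  # standard Adam with bias correction
    g = grad(w_adam)
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g ** 2
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    w_adam -= eta * m_hat / (np.sqrt(v_hat) + eps)
```

Both iterates land near `w_star` up to the optimizers' step-size floors, in contrast to the nonconvex CNN setting analyzed in the main proofs, where the two algorithms select different global minima.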



Our theory still holds for mini-batch stochastic gradient descent, which we discuss in the appendix.

Recall that all data inputs have a nonzero first coordinate by Definition 3.1.



Figure 1: Visualization of the first layer of AlexNet trained by Adam and SGD on the CIFAR-10 dataset. Both algorithms are run for 100 epochs with weight decay regularization and standard data augmentations. Clearly, the model learned by Adam is more "noisy" than that learned by SGD, implying that Adam is more likely to overfit the noise in the training data.

provide theoretical insights into the effectiveness of these empirical studies.

Optimization and generalization in deep learning. Our work is also closely related to the recent line of work studying the optimization and generalization guarantees of neural networks in the neural tangent kernel (NTK) regime (Jacot et al., 2018) or lazy training regime (Chizat et al., 2019). In particular, recent works (Du et al., 2019b;a; Allen-Zhu et al., 2019b; Zou et al., 2019) showed that the optimization only happens within a small neighborhood around the random initialization and proved the global convergence of GD and SGD when the neural network is sufficiently wide. Moreover, the generalization ability of GD/SGD has been further studied in the same setting (Allen-Zhu et al., 2019a; Arora et al., 2019a;b; Ji & Telgarsky, 2020; Chen et al., 2021), which suggests that a wide neural network trained by GD/SGD can learn a low-dimensional function class. Moreover, Allen-Zhu & Li (2019); Bai & Lee (2019) initiated the study of learning neural networks beyond the NTK regime, as it differs from practical DNN training. Our analysis in this paper is also beyond NTK, and gives a detailed comparison between GD and Adam.

Feature learning by neural networks.
This paper is also closely related to several recent works that study how neural networks learn features. Allen-Zhu & Li (2020a) showed that adversarial training purifies the learned features by removing certain "dense mixtures" in the hidden-layer weights of the network. Allen-Zhu & Li (2020b) studied how ensemble and knowledge distillation work in deep learning when the data have "multi-view" features. Frei et al. (2022b) studied feature learning for two-layer networks and demonstrated their superior performance over linear models. Shen et al. (2022) explored the benefit of data augmentation by showing its ability to achieve more effective feature learning. This paper studies a different aspect of feature learning by Adam and GD, and shows that GD can learn the features while Adam may fail even with proper regularization.
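The mechanism contrasted above, GD learning the feature while Adam memorizes noise, comes down to how each update distributes magnitude across coordinates. The toy sketch below (our construction, not the paper's model) shows the single-step difference: GD grows a coordinate in proportion to its gradient, while a sign-based (Adam-like) step grows every nonzero-gradient coordinate at the same rate η, so many small noise coordinates keep pace with one large feature coordinate.

```python
import numpy as np

# Toy single-step comparison (our construction): one "feature" coordinate with
# a large gradient component and s "noise" coordinates with small components.

s, eta = 100, 0.01
g = np.zeros(1 + s)
g[0] = 1.0          # feature-aligned gradient component
g[1:] = 0.1         # sparse-noise gradient components

w_gd = eta * g                 # one GD step from w = 0
w_sign = eta * np.sign(g)      # one sign-descent (Adam-like) step from w = 0

gd_ratio = w_gd[0] / w_gd[1]         # GD: feature grows 10x faster
sign_ratio = w_sign[0] / w_sign[1]   # sign step: noise keeps pace (ratio 1)
```

This is only the one-step caricature of the full multi-stage analysis in Appendix C, but it is the reason the noise coordinates can reach constant correlation under Adam before the feature does.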

The experiments in Figure 1 are performed by running (stochastic-gradient) Adam and SGD to train AlexNet on the CIFAR-10 dataset. Specifically, the first layer of AlexNet is set with kernel size 11, stride 4, and padding 2, in order to match the size of CIFAR-10 images. For the input data, we use standard random-crop and horizontal-flip data augmentations. For model training, we set the batch size to 64, the number of epochs to 100, and the regularization parameter to λ = 5 × 10^{-4}. Besides, we set the learning rate η = 0.01 for SGD and η = 1 × 10^{-5} for Adam, where β_1 and β_2 are set to their default values in PyTorch.

A.2 NUMERICAL EXPERIMENTS ON SYNTHETIC DATA

In this section we perform numerical experiments on synthetic data generated according to Definition 3.1 to verify our main results. In particular, we set the problem dimension d = 1000, the training sample size n = 200 (100 positive examples and 100 negative examples), the feature vector v = [1, 0, …, 0], the noise sparsity s = 0.1d = 100, the standard deviation of the noise σ_p = 1/s^{1/2} = 0.1, the feature-noise strength α = 0.2, the initialization scale σ_0 = 0.01, and the regularization parameter λ = 1 × 10^{-5}.
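A generator for this synthetic dataset can be sketched as follows. Definition 3.1 is not reproduced in this excerpt, so the exact layout is our reading of the surrounding proofs: each input has a feature patch y·v and a noise patch whose first coordinate carries the feature noise −αy and whose remaining support is s-sparse Gaussian. All names and the two-patch representation are our assumptions.

```python
import numpy as np

# Sketch of the synthetic data described above (d=1000, n=200, s=100,
# sigma_p=0.1, alpha=0.2). The two-patch layout x = [y*v, xi], with the
# feature noise -alpha*y placed on the first coordinate of the noise patch,
# is our reading of Definition 3.1, which is not reproduced here.

rng = np.random.default_rng(0)
d, n, s, sigma_p, alpha = 1000, 200, 100, 0.1, 0.2

def make_dataset():
    y = np.repeat([1, -1], n // 2)          # 100 positive, 100 negative
    X1 = np.zeros((n, d))                   # feature patches
    X2 = np.zeros((n, d))                   # noise patches
    for i in range(n):
        X1[i, 0] = y[i]                     # feature patch: y * v, v = e_1
        X2[i, 0] = -alpha * y[i]            # feature noise on coordinate 1
        support = rng.choice(np.arange(1, d), size=s, replace=False)
        X2[i, support] = rng.normal(0.0, sigma_p, size=s)  # sparse noise
    return X1, X2, y
```

With these constants, each noise patch has exactly s + 1 nonzero coordinates, matching the sparsity level used in the analysis.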

Figure 2: Visualization of the feature learning (max_r ⟨w_{1,r}, v⟩) and noise memorization (min_i max_r ⟨w_{1,r}, ξ_i⟩) during training of the two-layer CNN model.
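The two quantities tracked in Figure 2 reduce to simple matrix operations. The sketch below uses small placeholder arrays (W, v, Xi are ours, chosen only to fix shapes) to show how the feature-learning and noise-memorization curves would be computed at one training step.

```python
import numpy as np

# The two curves of Figure 2: feature learning max_r <w_{1,r}, v> and noise
# memorization min_i max_r <w_{1,r}, xi_i>. W, v, and Xi are placeholder
# arrays standing in for the trained filters, the feature, and noise patches.

rng = np.random.default_rng(1)
m, d, n = 4, 50, 8
W = rng.normal(size=(m, d))     # filters w_{1,r} of the class-(+1) neurons
v = np.eye(d)[0]                # feature vector e_1
Xi = rng.normal(size=(n, d))    # noise patches xi_i

feature_learning = np.max(W @ v)                        # max_r <w_{1,r}, v>
noise_memorization = np.min(np.max(W @ Xi.T, axis=0))   # min_i max_r <.,xi_i>
```

Logging these two scalars over training iterations reproduces the qualitative comparison in the figure: under Adam the second quantity overtakes the first.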

Figure 4: Visualization of the first layer of the 5-layer CNN model on the synthetic dataset.

(Stochastic-gradient Adam, informal). Suppose the training data is generated according to Definition 3.1. Then, given proper configurations of λ and B, for any t_0 ≤ T_0 with T_0 = O(n/(ηBsσ_p)) and any i ∈ [n], we have the following one-epoch update of feature learning and noise memorization:

Using the fact that τ = Θ(1), λ = o(1), and |ℓ'^{(t)}_{j,i}| ≤ 1, we have, for all k = 1 or k ∉ B_i for any i,

they hold for all τ ≤ t and r ∈ [m], i ∈ [n], and j ∈ [2]. It is clear that all hypotheses hold when t = T_0 according to Lemma C.4.

for all r ∈ [m], by (C.16).

Then by Lemma C.2, we know that if |∇ wj,r L(W (t) )[1]| ≤ Θ(η), then |m

Lemma C.11 (Stage I of GD: part II). Without loss of generality, assume T_1 < T_{−1}. Then it holds that Λ^{(t)}_1 = Θ(1) for all t ∈ [T_1, T_{−1}].

Proof. Recall from (C.29) that we have the following general lower bound for the increase of Λ^{(t)}_1, where the inequality is by Lemma C.10. Note that by Lemma C.8, we have Γ^{(t)}_j = O(σ_0) for all t ≤ T_{−1}. Then the above inequality leads to the desired lower bound on Λ^{(t)}_1, using the fact that α = ω(σ_0). The remaining proof consists of two parts: (1) proving Λ^{(t)}_j ≥ Θ(1/m) = Θ(1), and (2) proving Λ^{(t)}_j ≤ Θ(log(1/λ)). Without loss of generality, we consider j = 1. For the first part, we first note that Lemma C.8 implies Λ^{(T_1)}_1 ≥ Θ(1/m). Then we consider the case when Λ^{(t)}_1 ≤ Θ(log(1/α)/m); it holds that for all y_i = 1,

ℓ'^{(t)}_{1,i} ≤ exp(Σ_{r=1}^m [σ(⟨w^{(t)}_{−1,r}, y_i v⟩) + σ(⟨w^{(t)}_{−1,r}, ξ_i⟩) − σ(⟨w^{(t)}_{1,r}, y_i v⟩) − σ(⟨w^{(t)}_{1,r}, ξ_i⟩)]) ≤ exp(−Θ(Λ^{(t)}_1)) ≤ exp(−Θ(log(1/λ))) = Θ(poly(λ)).

Combining the above inequalities for every r ∈ [m], we have

Σ_j (ΔF_{j,i})² ≤ Θ(mη²‖∇L(W^{(t)})‖_F²) = Θ(η²‖∇L(W^{(t)})‖_F²). (C.43)

Now we can plug (C.42) and (C.43) into (C.40), which gives

Lemma C.14 (Generalization Performance of GD). Let W* = argmin_{W∈{W^{(1)},…,W^{(T)}}} ‖∇L(W)‖_F. Then for all training data, we have

(1/n) Σ_{i=1}^n 1[F_{y_i}(W*, x_i) ≤ F_{−y_i}(W*, x_i)] = 0.

Moreover, in terms of the test data (x, y) ∼ D, we have

P_{(x,y)∼D}[F_y(W*, x) ≤ F_{−y}(W*, x)] = o(1).

Proof. By Lemma C.12 it is clear that all training data are correctly classified, so the training error is zero. Besides, for test data (x, y) with x = [yv, ξ], it is clear that with high probability ⟨w*_{y,r}, yv⟩ = Θ(1) and [⟨w*_{y,r}, ξ⟩]_+ ≤ O(σ_0); then

F_y(W*, x) = Σ_{r=1}^m [σ(⟨w*_{y,r}, yv⟩) + σ(⟨w*_{y,r}, ξ⟩)] ≥ Ω(1).

If j = −y, we have with probability at least 1 − 1/poly(n) that ⟨w*_{−y,r}, yv⟩ ≤ 0 and [⟨w*_{−y,r}, ξ⟩]_+ ≤ O(α), which leads to

F_{−y}(W*, x) = Σ_{r=1}^m [σ(⟨w*_{−y,r}, yv⟩) + σ(⟨w*_{−y,r}, ξ⟩)] ≤ O(mα^q) = O(α^q) = o(1).

DISCUSSION ON THE DATA MODELS IN WILSON ET AL. (2017); REDDI ET AL. (2018)

Data model in Wilson et al. (2017).



Greg Yang. Scaling limits of wide neural networks with weight sharing: Gaussian process behavior, gradient independence, and neural tangent kernel derivation. arXiv preprint arXiv:1902.04760, 2019.

Pan Zhou, Jiashi Feng, Chao Ma, Caiming Xiong, Steven Chu Hong Hoi, et al. Towards theoretically understanding why SGD generalizes better than Adam in deep learning. Advances in Neural Information Processing Systems, 33, 2020.

Difan Zou, Yuan Cao, Dongruo Zhou, and Quanquan Gu. Gradient descent optimizes over-parameterized deep ReLU networks. Machine Learning, Oct 2019.



Table 1: Training and test errors achieved by GD and Adam.

Now we are ready to evaluate the training error and test error. In terms of training error, it is clear that by Lemma C.5, we have w * yi,r

ACKNOWLEDGEMENTS

We thank the anonymous reviewers and area chair for their helpful comments. YL is supported by the National Science Foundation CCF-2145703. QG is supported in part by the National Science Foundation CAREER Award 1906169, IIS-2008981 and the Sloan Research Fellowship. 


C.3 PROOF FOR GRADIENT DESCENT

Recall that the feature learning and noise memorization of gradient descent can be formulated in terms of

ℓ'^{(t)}_{y_i,s} σ'(⟨w^{(t)}_{y_i,r}, ξ_s⟩) and ℓ'^{(t)}_{y_i,s} σ'(⟨w^{(t)}_{y_i,r}, y_s v⟩). (C.28)

Then, similar to the analysis for Adam, we decompose the gradient descent process into multiple stages and characterize the algorithmic behavior separately. The following lemma characterizes the first training stage, i.e., the stage where all outputs F_j(W^{(t)}, x_i) remain at a constant level for all j and i.

Lemma C.8 (Lemma 5.6, restated). Suppose the training data is generated according to Definition 3.1. Let T_j be the iteration number at which Λ^{(t)}_j reaches Θ(1/m).

We first provide the following useful lemma.

Lemma C.9. Let {x_t, y_t}_{t=1,…} be two positive sequences that satisfy the stated recursions for some A = Θ(1) and B = o(1). Then, for any q ≥ 3, and supposing y_0 = O(x_0) and η < O(x_0), we have for every C ∈ [x_0, O(1)]: letting T_x be the first iteration such that x_t ≥ C, we will show y_t ≤ 2x_0 for all t ≤ T_x. In particular, let T_x η = C' x_0^{2−q} for some absolute constant C', and assume C' B 2^{q−1} < 1 (this is true since B = o(1)). We first make the following induction hypothesis on y_t for all t ≤ T_x:

y_t ≤ y_0 + tηB(2x_0)^{q−1}.

Note that for any t ≤ T_0, this hypothesis clearly implies that y_t ≤ 2x_0. Then we are able to verify the hypothesis at time t + 1 based on the recursive upper bound on y_t. Therefore, we can conclude that y_t ≤ 2x_0 for all t ≤ T_x. This completes the proof.

Then we will consider the training period in which |ℓ'^{(t)}_{j,i}| = Θ(1) for all j, i, and t. Similarly, letting r* = argmax_r ⟨w^{(t)}_{y_i,r}, ξ_i⟩, we also have the corresponding bound according to (C.28). Then, by our definition of ℓ'_{j,i}, we further get the analogous bound for all j ∈ {−1, 1}, where the last equality is by our assumption that α = O(sσ_p²/n).
Then we will prove the main argument for general t, which is based on the two induction hypotheses (C.31) and (C.32). Note that when t = 0, we have already verified these two hypotheses in (C.29) and (C.30), where we use the fact that λ = o(σ_0^{q−2}σ_p/n) ≤ (Λ^{(0)}_j)^{q−2} and α = o(1). Suppose that (C.29) and (C.30) hold for all iterations τ ≤ t. At time t + 1, for all τ ≤ t, the corresponding bound holds, as sσ_p²/n = o(1) and Λ^{(t)}_j increases faster than Γ^{(t)}_j. Besides, we can also show that λΓ^{(t)}_j ≤ (Γ^{(t)}_j)^{q−1}, which has been verified at time t = 0, since Γ^{(t)}_j keeps increasing. Therefore, (C.29) verifies Hypothesis (C.31) at t + 1. Additionally, (C.30) verifies Hypothesis (C.32) at t + 1. Then the claim on T_j follows by Lemma C.9. Moreover, Lemma C.9 also shows that Γ^{(T_j)}_j = O(σ_0). This completes the proof.

Lemma C.10. For all i ∈ [n] and t ≤ T_{−y_i}, the stated bound on ⟨w^{(t)}_{−y_i,r}, ξ_i⟩ holds.

Proof. First of all, for j ∈ {±1}, by the definition of T_j the first-stage bounds apply. Moreover, with the same proof as Lemma C.8, the corresponding lower bound holds for all t ≤ T_j. Now, by the update form of GD, we have the coordinate-wise update for any k ∈ B_i. Note that ℓ'^{(t)}_{−y_i,i} σ'(⟨w^{(t)}_{−y_i,r}, ξ_i⟩) < 0, which bounds the update. Therefore, for all r and i the claimed bound follows, where the third inequality follows by (C.33). This completes the proof.

Note that for different j, the iteration numbers at which Λ^{(t)}_j reaches Θ(1/m) differ, and the effect of gradient descent on the feature learning (i.e., the increase of ⟨w_{j,r}, j·v⟩) becomes weaker. In the following lemma we give a characterization of this stage. Then (C.36) further implies that Λ^{(t)}_1 will decrease. As a result, we can conclude that Λ^{(t)}_1 will not exceed Θ(log(1/λ)), which completes the proof of the second part.

Lemma C.12 (Lemma 5.7, restated).

Proof. We will prove the desired argument based on the following three induction hypotheses (C.37), (C.38), and (C.39). In terms of Hypothesis (C.37), we can apply Hypotheses (C.38) and (C.39) to (C.34) and get the desired bound, where in the last inequality we use the fact that α ≥ σ_0.
This verifies Hypothesis (C.37). In order to verify Hypothesis (C.38), we have the corresponding bound according to (C.37), where the last equality holds since α = o(1). Recursively applying the above inequality from T_{−1} to t gives the claimed bound, and then Hypothesis (C.39) follows. Now let us look at the rate of memorizing noise. By (C.28), and using the fact that α² ≤ O(sσ_p²/n), we have the stated bound.

Proof. The proof of this lemma is similar to that of Lemma C.6, basically relying on the smoothness of the loss function L(W) given certain constraints on the inner products ⟨w_{j,r}, v⟩ and ⟨w_{j,r}, ξ_i⟩. Let ΔF_{j,i} = F_j(W^{(t+1)}, x_i) − F_j(W^{(t)}, x_i); we can get the corresponding Taylor expansion of the loss function L_i(W^{(t+1)}). In particular, by Lemma C.12, we know that ⟨w^{(t)}_{j,r}, y_i v⟩ ≤ Θ(1) and ⟨w^{(t)}_{j,r}, ξ_i⟩ ≤ Θ(σ_0) ≤ Θ(1). Then, similar to (C.21), we can apply a first-order Taylor expansion to F_j(W^{(t+1)}, x_i), which requires characterizing the second-order error of the Taylor expansions of σ(⟨w_{j,r}, ·⟩).

