PROVABLE SHARPNESS-AWARE MINIMIZATION WITH ADAPTIVE LEARNING RATE

Abstract

The sharpness-aware minimization (SAM) optimizer has been extensively explored because it converges fast and trains deep neural networks efficiently by introducing an extra perturbation step that flattens the loss landscape of deep learning models. Combinations of SAM with an adaptive learning rate (AdaSAM) have also been explored to train large-scale deep neural networks, but without theoretical guarantees, due to the dual difficulties of analyzing the perturbation step and the coupled adaptive learning rate. In this paper, we analyze the convergence rate of AdaSAM in the stochastic non-convex setting. We theoretically show that AdaSAM admits an O(1/√(bT)) convergence rate and enjoys a linear speedup property with respect to the mini-batch size b. To the best of our knowledge, we are the first to provide a non-trivial convergence rate for SAM with an adaptive learning rate. To decouple the two stochastic gradient steps from the adaptive learning rate, we introduce a delayed second-order momentum term in the convergence analysis, which makes the two steps independent when taking an expectation. We then bound the resulting terms by showing that the adaptive learning rate has a limited range, which makes our analysis feasible. Finally, we conduct experiments on several NLP tasks, which show that AdaSAM achieves superior performance compared with the SGD, AMSGrad, and SAM optimizers.

1. INTRODUCTION

Sharpness-aware minimization (SAM) Foret et al. (2021) is a powerful optimizer for training large-scale deep learning models that minimizes the gap between training and generalization performance. It has achieved remarkable results with various deep neural networks, such as ResNet He et al. (2016), vision transformers Dosovitskiy et al. (2021); Xu et al. (2021), and language models Devlin et al. (2018); He et al. (2020), on extensive benchmarks. However, SAM-type methods suffer from several issues during training, especially huge computation costs and a heavy hyper-parameter tuning procedure. SAM needs to compute gradients twice per update, compared with classic optimizers such as SGD, Adam Kingma & Ba (2015), and AMSGrad Reddi et al. (2018), due to the extra perturbation step. Hence, SAM must forward- and back-propagate twice for one parameter update, doubling the computation cost of the classic optimizers. Moreover, since there are two steps in the training process, it needs twice as many hyper-parameters, which makes learning rate tuning unbearable and costly. Adaptive learning rate methods scale the gradients based on historical gradient information to accelerate convergence. These methods, such as Adagrad Duchi et al. (2011), Adam, and AMSGrad, have been studied for solving NLP tasks Zhang et al. (2020). However, they may converge to local minima with poor generalization performance, and many studies have investigated how to improve the generalization ability of adaptive learning rate methods. Recently, several works have tried to ease the learning rate tuning in SAM by combining SAM with an adaptive learning rate. For example, Zhuang et al. (2022) train ViT models with an adaptive method and Kwon et al. (2021) train an NLP model. Although remarkable performance has been achieved, the convergence of these methods remains unknown because of the adaptive learning rate used in SAM.
Directly analyzing the convergence of such a method is difficult due to the two optimization steps, especially the adaptive learning rate adopted in the second step. In this paper, we analyze the convergence rate of SAM with an adaptive learning rate, dubbed AdaSAM, in the nonconvex stochastic setting. To circumvent the difficulty, we develop a technique to decouple the two-step training of SAM from the adaptive learning rate. The analysis is mainly divided into two parts. The first part analyzes the SAM procedure; the second analyzes the step that adopts the adaptive learning rate. We introduce a second-order momentum term from the previous iteration, which is related to the adaptive learning rate but independent of SAM when taking an expectation. We can then bound the term composed of the SAM update and the previous second-order momentum because the adaptive learning rate has a limited range. We prove that AdaSAM enjoys a linear speedup property with respect to the batch size. Furthermore, we apply AdaSAM to fine-tune the RoBERTa model on the GLUE benchmark. We show that AdaSAM achieves the best performance in our experiments, winning on 6 of the 8 tasks, and that the linear speedup can be clearly observed. We summarize our contributions as follows:
• We present the first convergence guarantee for the adaptive SAM method under the nonconvex setting. Our results suggest that a large mini-batch helps convergence due to the linear speedup with respect to the batch size.
• We conduct a series of experiments on various tasks. The results show that AdaSAM outperforms most state-of-the-art optimizers, and the linear speedup is verified.

2. PRELIMINARY

2.1 PROBLEM SETUP

In this work, we focus on the following stochastic nonconvex optimization problem

min_{x∈R^d} f(x) := E_{ξ∼D} f_ξ(x),    (1)

where d is the dimension of the variable x, D is the unknown distribution of the data samples, f_ξ(x) is a smooth and possibly non-convex function, and f_{ξ_i}(x) denotes the objective function at the data point ξ_i sampled from the distribution D. In machine learning, this covers empirical risk minimization as a special case, with f the loss function. When the dataset D covers N data points, i.e., D = {ξ_i, i = 1, 2, ..., N}, Problem (1) reduces to the following finite-sum problem:

min_{x∈R^d} f(x) := (1/N) Σ_i f_{ξ_i}(x).    (2)

Without additional declaration, we write f_i(x) for f_{ξ_i}(x) for simplicity.

Notations. f_i(x) is the i-th loss function, x ∈ R^d is the model parameter, and d is the parameter dimension. We denote the ℓ2 norm by ∥·∥_2. The Hadamard product of two vectors a, b is denoted a ⊙ b. For a vector a ∈ R^d, √a denotes the vector whose j-th entry (√a)^(j) equals the square root of a_j.

Sharpness. Sharpness has been observed to be closely related to generalization (2019); therefore, it is popular to consider the two as tightly linked. Sharpness-aware minimization, proposed in Foret et al. (2021), targets flat minimizers by minimizing the training loss uniformly over an entire neighborhood. Specifically, Foret et al. (2021) aims to solve the following minimax saddle point problem:
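To make the finite-sum setting concrete, the sketch below (with hypothetical quadratic per-sample losses; the names `full_grad` and `minibatch_grad` are ours, not the paper's) illustrates that a mini-batch gradient is an unbiased estimator of the full gradient in (2):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy finite-sum objective: f(x) = (1/N) * sum_i f_i(x), with quadratic
# per-sample losses f_i(x) = 0.5 * ||x - a_i||^2 (a hypothetical choice),
# so that grad f_i(x) = x - a_i and grad f(x) = x - mean_i a_i.
N, d = 1000, 5
A = rng.normal(size=(N, d))  # data points a_i

def full_grad(x):
    # Full gradient: grad f(x) = x - (1/N) * sum_i a_i
    return x - A.mean(axis=0)

def minibatch_grad(x, b):
    # Unbiased estimator: sample b indices uniformly and average grad f_i.
    idx = rng.choice(N, size=b, replace=False)
    return x - A[idx].mean(axis=0)

x = np.zeros(d)
# Averaging many mini-batch gradients approaches the full gradient.
est = np.mean([minibatch_grad(x, 32) for _ in range(2000)], axis=0)
print(np.linalg.norm(est - full_grad(x)))  # small
```

With b = N the estimator coincides with the full gradient, which is the sense in which the mini-batch variance bound σ²/b of Assumption 3.3 later vanishes as b grows.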


min_x max_{∥δ∥≤ρ} f(x + δ) + λ∥x∥_2²,    (3)

where ρ ≥ 0 and λ ≥ 0 are two hyperparameters. That is, instead of the original loss f(x), the loss perturbed over a neighborhood of x is minimized. Using a first-order Taylor expansion of f(x + δ) with respect to δ, the inner maximization is approximately solved by

δ*(x) = arg max_{∥δ∥≤ρ} f(x + δ) ≈ arg max_{∥δ∥≤ρ} f(x) + δ^⊤∇f(x) = arg max_{∥δ∥≤ρ} δ^⊤∇f(x) = ρ ∇f(x)/∥∇f(x)∥.

Dropping the quadratic regularization term, (3) is simplified to the following minimization problem:

min_x f(x + ρ ∇f(x)/∥∇f(x)∥).    (4)

The stochastic gradient of f(x + ρ∇f(x)/∥∇f(x)∥) on a mini-batch b involves a Hessian-vector product, so SAM further approximates the gradient by

∇_x f_b(x + ρ ∇f_b(x)/∥∇f_b(x)∥) ≈ ∇_x f_b(x)|_{x + ρ∇f_b(x)/∥∇f_b(x)∥}.

Then, SGD is applied along the negative direction -∇_x f_b(x)|_{x + ρ∇f_b(x)/∥∇f_b(x)∥} to solve the minimization problem (4). It is easy to see that SAM requires gradient back-propagation twice, for ∇f_b(x) and ∇_x f_b(x)|_{x + ρ∇f_b(x)/∥∇f_b(x)∥}. Due to the hyperparameter ρ, one needs to carefully tune both ρ and the learning rate in SAM; in practice, ρ is predefined to control the radius of the neighborhood. Several variants of SAM have been proposed to improve its performance. Recently, a few works Zhuang et al. (2022); Kwon et al. (2021) have empirically incorporated an adaptive learning rate into SAM and shown impressive generalization accuracy, but their convergence has never been analyzed. ESAM Du et al. (2022) proposes an efficient method that sparsifies the gradients to alleviate the doubled back-propagation cost. ASAM Kwon et al. (2021) modifies SAM by adaptively scaling the neighborhood so that the sharpness is invariant to parameter re-scaling. GSAM Zhuang et al. (2022) simultaneously minimizes the perturbed function and a newly defined surrogate gap function to further improve the flatness of minimizers. Liu et al.
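The two forward/backward passes of one SAM update can be sketched as follows. This is a minimal illustration of the update rule derived above, not the paper's implementation; `grad_fn`, the toy loss, and all constants are hypothetical choices of ours:

```python
import numpy as np

def sam_step(x, grad_fn, lr=0.1, rho=0.05, eps=1e-12):
    """One SAM update (sketch): ascent perturbation, then descent step.

    grad_fn(x) stands in for a (possibly mini-batch) gradient oracle.
    The small eps guards against division by a zero gradient norm.
    """
    g = grad_fn(x)                                 # first forward/backward pass
    delta = rho * g / (np.linalg.norm(g) + eps)    # delta = rho * g / ||g||
    g_perturbed = grad_fn(x + delta)               # second forward/backward pass
    return x - lr * g_perturbed                    # SGD step with perturbed gradient

# Usage on the toy loss f(x) = 0.5 * ||x||^2, whose gradient is x:
x = np.array([3.0, -4.0])
for _ in range(100):
    x = sam_step(x, lambda z: z)
print(x)  # close to the minimum at 0
```

The doubled cost discussed above is visible directly: `grad_fn` is called twice per step.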
(2022) also study SAM in the large-batch training scenario and periodically update the perturbed gradient. On the other hand, some works analyze the convergence of SAM, such as Andriushchenko & Flammarion (2022), which does not consider the normalization step, i.e., the normalization in ∇f_b(x)/∥∇f_b(x)∥.

Adaptive optimizer. Adaptive learning rate methods automatically adjust the learning rate based on historical gradients. The first adaptive method was Adagrad Duchi et al. (2011), which achieves better results than other first-order methods under the convex setting.

3. ADASAM: SAM WITH ADAPTIVE LEARNING RATE

In this section, we introduce SAM with an adaptive learning rate (AdaSAM) built on the AMSGrad optimizer. We then present the convergence results of AdaSAM and, finally, give the proof sketch for the main theorem.

AdaSAM is described in Algorithm 1. In each iteration, a mini-batch gradient estimate g_t at the point x_t + δ(x_t) with batch size b is computed, i.e., g_t = ∇_x f_b(x)|_{x_t+δ(x_t)} = (1/b) Σ_{i∈B} ∇f_{ξ_i}(x_t + δ(x_t)). Here, δ(x_t) is the extra perturbation step of SAM, given by δ(x_t) = ρ s_t/∥s_t∥, where s_t = ∇_x f_b(x)|_{x_t} = (1/b) Σ_{i∈B} ∇f_{ξ_i}(x_t). Then, exponential moving averages of g_t and of the second-order term [g_t]² are accumulated as m_t and v_t, respectively, and AdaSAM updates the iterate along -m_t with the adaptive learning rate γη_t:

m_t = β_1 m_{t-1} + (1 - β_1) g_t;
v_t = β_2 v_{t-1} + (1 - β_2) [g_t]²;
v̂_t = max(v̂_{t-1}, v_t);
η_t = 1/√(v̂_t);
x_{t+1} = x_t - γ m_t ⊙ η_t.

Remark 3.1. Below, we give several comments on AdaSAM:
• When β_2 = 1, the adaptive learning rate reduces to a diminishing one, as in SGD; AdaSAM then recovers the classic SAM optimizer.
• If we drop the 8-th line, v̂_t = max(v̂_{t-1}, v_t), our algorithm becomes a variant of Adam. The counterexample showing that Adam does not converge in Reddi et al. (2018) also holds for that SAM variant, while AdaSAM converges.
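A minimal NumPy sketch of Algorithm 1 may help fix ideas. Initializing v̂ at eps² is our implementation assumption that realizes the bound (η_t)_j ≤ 1/ϵ used later in the analysis, and the toy loss in the usage example is hypothetical:

```python
import numpy as np

def adasam(grad_fn, x0, T=500, lr=0.1, rho=0.05,
           beta1=0.9, beta2=0.999, eps=1e-8):
    """Sketch of Algorithm 1 (AdaSAM): SAM perturbation + AMSGrad update.

    grad_fn(x) stands in for a mini-batch gradient oracle. Initializing
    v_hat at eps**2 (our assumption) keeps eta = 1/sqrt(v_hat) <= 1/eps.
    """
    x = x0.astype(float).copy()
    m = np.zeros_like(x)
    v = np.zeros_like(x)
    v_hat = np.full_like(x, eps**2)
    for _ in range(T):
        s = grad_fn(x)                                           # s_t, first pass
        g = grad_fn(x + rho * s / (np.linalg.norm(s) + 1e-12))   # g_t, second pass
        m = beta1 * m + (1 - beta1) * g                          # first moment
        v = beta2 * v + (1 - beta2) * g * g                      # second moment
        v_hat = np.maximum(v_hat, v)                             # AMSGrad max step
        eta = 1.0 / np.sqrt(v_hat)                               # adaptive step sizes
        x = x - lr * m * eta
    return x

# Usage on the toy quadratic f(x) = 0.5 * ||x||^2 (gradient is x):
x_star = adasam(lambda z: z, np.array([2.0, -3.0]))
print(np.linalg.norm(x_star))  # near 0
```

Setting `beta1 = 0` removes the momentum term, matching the first case analyzed in Section 3.1.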

3.1. CONVERGENCE ANALYSIS

Before stating our convergence analysis, we first introduce some useful assumptions.

Assumption 3.2 (L-smoothness). f_i and f are differentiable with Lipschitz gradients: ∥∇f_i(x) - ∇f_i(y)∥ ≤ L∥x - y∥ and ∥∇f(x) - ∇f(y)∥ ≤ L∥x - y∥ for all x, y ∈ R^d and i = 1, 2, ..., N. This also implies the descent inequality, i.e., f_i(y) ≤ f_i(x) + ⟨∇f_i(x), y - x⟩ + (L/2)∥y - x∥².

Assumption 3.3 (Bounded variance). The gradient estimator is unbiased and the variance of the stochastic gradient is bounded: E∇f_i(x) = ∇f(x) and E∥∇f_i(x) - ∇f(x)∥² ≤ σ². With a mini-batch of size b, we then have E∥∇f_b(x) - ∇f(x)∥² ≤ σ²/b.

Assumption 3.4 (Bounded stochastic gradients). The stochastic gradient is uniformly bounded, i.e., ∥∇f_i(x)∥_∞ ≤ G for any i = 1, ..., N.

Remark 3.5. The above assumptions are commonly used in convergence proofs for adaptive stochastic gradient methods, such as Cutkosky & Orabona (2019); Huang et al. (2021); Zhou et al. (2018); Chen et al. (2019).

Note that AdaSAM combines AMSGrad and SAM. We briefly explain the main ideas in the analyses of AMSGrad Reddi et al. (2018) and SAM Andriushchenko & Flammarion (2022). The main step in the AMSGrad analysis is to estimate the expectation E[x_{t+1} - x_t] = -E m_t ⊙ η_t = -E(1 - β_1) g_t ⊙ η_t - E β_1 m_{t-1} ⊙ η_t, conditioned on the filtration σ(x_t). First, consider the situation β_1 = 0, which excludes momentum. Some works Zaheer et al. (2018); Savarese et al. (2021) apply a delay technique to split the dependence between g_t and η_t, that is,

E g_t ⊙ η_t = E[g_t ⊙ η_{t-1}] + E[g_t ⊙ (η_t - η_{t-1})] = ∇f(x_t) ⊙ η_{t-1} + E[g_t ⊙ (η_t - η_{t-1})].

The second term E[g_t ⊙ (η_t - η_{t-1})] is dominated by the first term ∇f(x_t) ⊙ η_{t-1}. Then, it is not difficult to obtain the convergence result of AMSGrad.
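The σ²/b mini-batch variance bound in Assumption 3.3 can be checked numerically. The following is a toy simulation with hypothetical Gaussian gradient noise, not an experiment from the paper:

```python
import numpy as np

rng = np.random.default_rng(1)

# Empirical check of Assumption 3.3: if the per-sample gradient noise has
# variance sigma^2 per coordinate, a size-b mini-batch gradient satisfies
# E||g_b - grad f||^2 = sigma^2 * d / b for d coordinates.
sigma, d, trials = 2.0, 10, 20000
true_grad = np.ones(d)  # hypothetical ground-truth gradient

def batch_grad(b):
    # Mini-batch gradient = true gradient + average of b noise samples.
    noise = rng.normal(scale=sigma, size=(b, d)).mean(axis=0)
    return true_grad + noise

for b in (1, 4, 16):
    errs = np.array([batch_grad(b) - true_grad for _ in range(trials)])
    emp_var = np.mean(np.sum(errs**2, axis=1))   # empirical E||g_b - grad f||^2
    print(b, emp_var, sigma**2 * d / b)          # empirical vs. sigma^2 * d / b
```

The empirical variance shrinks linearly in b, which is exactly the effect the linear speedup result exploits.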
When we apply the same strategy to AdaSAM, we find that E g_t ⊙ η_{t-1} cannot be handled similarly, because E g_t = E ∇_x f_b(x + ρ∇f_b(x)/∥∇f_b(x)∥) ≠ ∇f(x_t). Inspired by (Andriushchenko & Flammarion, 2022, Lemma 16), our key observation is that

E ∇_x f_b(x + ρ∇f_b(x)/∥∇f_b(x)∥) ≈ E ∇_x f_b(x + ρ∇f(x)/∥∇f(x)∥) = ∇_x f(x + ρ∇f(x)/∥∇f(x)∥),

and we prove that the remaining terms, such as E[(∇_x f_b(x + ρ∇f_b(x)/∥∇f_b(x)∥) - ∇_x f_b(x + ρ∇f(x)/∥∇f(x)∥)) ⊙ η_{t-1}], have small values that do not dominate the convergence rate. When we then include the momentum method, the term E m_{t-1} ⊙ η_t cannot be ignored. By introducing an auxiliary sequence z_t = x_t + (β_1/(1-β_1))(x_t - x_{t-1}), we have E[z_{t+1} - z_t] = E[(β_1/(1-β_1)) γ m_{t-1} ⊙ (η_{t-1} - η_t) - γ g_t ⊙ η_t]. The first term contains the momentum term, which has a small value, and we can remove it without hurting the convergence rate.

Theorem 3.6. Under Assumptions 3.2, 3.3, and 3.4, and with a fixed learning rate γ satisfying γ ≤ ϵ/(16L), the sequence {x_t} generated by Algorithm 1 satisfies

(1/T) Σ_{t=0}^{T-1} E∥∇f(x_t)∥_2² ≤ 2G(f(x_0) - f*)/(γT) + (8GγL/ϵ)(σ²/(bϵ)) + Φ,

where

Φ = 45GL²ρ²/ϵ + (2G³/((1-β_1)T)) d(1/ϵ - 1/G) + (8GγL/ϵ)(Lρ²/ϵ) + 2(4 + (β_1/(1-β_1))²) γLG³ d(ϵ^{-2} - G^{-2})/T + 6γ²L²β_1² dG³/((1-β_1)²ϵ³).

Corollary 3.7 (mini-batch linear speedup). Under the same conditions as Theorem 3.6, when we choose the learning rate γ = O(√(b/T)) and the neighborhood size ρ = O(1/√(bT)), we have

(1/T) Σ_{t=0}^{T-1} E∥∇f(x_t)∥_2² = O(1/√(bT)) + O(1/(bT)) + O(1/T) + O(1/(b^{1/2} T^{3/2})) + O(b^{1/2}/T^{3/2}) + O(b/T).

When T is sufficiently large, the first term dominates and we achieve the linear speedup convergence rate with respect to the mini-batch size b, i.e., (1/T) Σ_{t=0}^{T-1} E∥∇f(x_t)∥_2² = O(1/√(bT)).

Remark 3.8. To reach an O(δ) stationary point, when the batch size is 1 we need T = O(1/δ²) iterations; when the batch size is b, we need to run only T = O(1/(bδ²)) steps.
The method with batch size b is thus b times faster than the method with batch size 1.
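The iteration counts in Remark 3.8 can be made concrete with a small calculation. The constant c absorbing the O(·) factors is hypothetical:

```python
# Iteration counts implied by Remark 3.8: reaching an O(delta) stationary
# point needs T = O(1/(b * delta^2)) steps with mini-batch size b, so the
# step count (not the total sample count) drops linearly in b.
def iters_needed(delta, b, c=1.0):
    """Hypothetical constant c absorbs the O(.) factors."""
    return c / (b * delta**2)

for b in (1, 8, 64):
    print(b, iters_needed(0.01, b))

# Note the total number of gradient samples T * b stays constant: the
# speedup is in steps, assuming per-step cost grows sublinearly with b
# (e.g., through data parallelism).
```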

3.2. PROOF SKETCH

In this part, we give the proof sketch of Theorem 3.6; for the complete proof, please see the Appendix. We first introduce an auxiliary sequence z_t = x_t + (β_1/(1-β_1))(x_t - x_{t-1}). By the L-smoothness condition, we have

f(z_{t+1}) ≤ f(z_t) + ⟨∇f(z_t), z_{t+1} - z_t⟩ + (L/2)∥z_{t+1} - z_t∥².    (8)

Applying it to the sequence {z_t} and using the delay strategy yields

f(z_{t+1}) - f(z_t) ≤ ⟨∇f(z_t), (γβ_1/(1-β_1)) m_{t-1} ⊙ (η_{t-1} - η_t)⟩ + (L/2)∥z_{t+1} - z_t∥²
+ ⟨∇f(z_t), (γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t s_t/∥s_t∥) ⊙ (η_{t-1} - η_t)⟩
+ ⟨∇f(z_t) - ∇f(x_t), -(γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t s_t/∥s_t∥) ⊙ η_{t-1}⟩
+ ⟨∇f(x_t), (γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t ∇f(x_t)/∥∇f(x_t)∥) ⊙ η_{t-1} - (γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t s_t/∥s_t∥) ⊙ η_{t-1}⟩
+ ⟨∇f(x_t), -(γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t ∇f(x_t)/∥∇f(x_t)∥) ⊙ η_{t-1}⟩.    (9)

From Lemma B.5, Lemma B.6, and Lemma B.7 in the Appendix, we can bound the above terms in (9) as follows:

⟨∇f(z_t), (γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t s_t/∥s_t∥) ⊙ (η_{t-1} - η_t)⟩ ≤ γG²∥η_{t-1} - η_t∥_1,    (10)

⟨∇f(z_t), (γβ_1/(1-β_1)) m_{t-1} ⊙ (η_{t-1} - η_t)⟩ ≤ (γβ_1/(1-β_1)) G²∥η_{t-1} - η_t∥_1,    (11)

⟨∇f(x_t), (γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t ∇f(x_t)/∥∇f(x_t)∥) ⊙ η_{t-1} - (γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t s_t/∥s_t∥) ⊙ η_{t-1}⟩ ≤ (γ/(2µ²)) ∥∇f(x_t) ⊙ √η_{t-1}∥² + 2µ²γL²ρ_t²/ϵ.    (12)

Then we substitute them into (9) and take the conditional expectation to get

E f(z_{t+1}) - f(z_t) ≤ E⟨∇f(x_t), -(γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t ∇f(x_t)/∥∇f(x_t)∥) ⊙ η_{t-1}⟩ + (L/2) E∥z_{t+1} - z_t∥² + (γ/(2µ²)) ∥∇f(x_t) ⊙ √η_{t-1}∥² + 2µ²γL²ρ_t²/ϵ + (γ/(1-β_1)) G²∥η_{t-1} - η_t∥_1 + E⟨∇f(z_t) - ∇f(x_t), -(γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t s_t/∥s_t∥) ⊙ η_{t-1}⟩,    (13)

where µ > 0 is to be determined.
From Lemma B.8, Lemma B.10, and Lemma B.9 in the Appendix, we have

E⟨∇f(x_t), -(γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t ∇f(x_t)/∥∇f(x_t)∥) ⊙ η_{t-1}⟩ ≤ -γ∥∇f(x_t) ⊙ √η_{t-1}∥² + (γ/(2α²)) E∥∇f(x_t) ⊙ √η_{t-1}∥² + γα²L²ρ_t²/(2ϵ),    (14)

(L/2) E∥z_{t+1} - z_t∥² ≤ (LG²γ²β_1²/(1-β_1)²) E∥η_t - η_{t-1}∥² + γ²L( 3((1+β)/(βϵ)) (E∥∇f(x_t) ⊙ √η_{t-1}∥² + Lρ_t²/ϵ + σ²/(bϵ)) + (1+β) G² E∥η_t - η_{t-1}∥² ),    (15)

E⟨∇f(z_t) - ∇f(x_t), -(γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t s_t/∥s_t∥) ⊙ η_{t-1}⟩ ≤ (γ³L²β_1²/(2ϵ(1-β_1)²)) (1/λ_1² + 1/λ_2² + 1/λ_3²) (dG_∞²/ϵ²) + (γλ_1²/2) ∥∇f(x_t) ⊙ √η_{t-1}∥² + (γL²ρ_t²/(2ϵ)) (λ_2² + 4λ_3²).    (16)

Next, we substitute these into (13). Taking the expectation over all history information yields

E f(x_{t+1}) - E f(x_t) ≤ -γ(1 - 1/(2µ²) - 1/(2α²) - 3γL(1+β)/(βϵ) - λ_1²/2) E∥∇f(x_t) ⊙ √η_{t-1}∥² + 2µ²γL²ρ_t²/ϵ + (γ/(1-β_1)) G² E∥η_{t-1} - η_t∥_1 + (γ³L²β_1²/(2ϵ(1-β_1)²)) (1/λ_1² + 1/λ_2² + 1/λ_3²) (dG_∞²/ϵ²) + (γL²ρ_t²/(2ϵ)) (λ_2² + 4λ_3²) + γα²L²ρ_t²/(2ϵ) + (3γ²L(1+β)/(βϵ)) (Lρ_t²/ϵ + σ²/(bϵ)) + γ²LG² ((β_1/(1-β_1))² + 1 + β) E∥η_t - η_{t-1}∥².    (17)

We set µ² = α² = 8, β = 3, λ_1² = 1/4, λ_2² = λ_3² = 1, and choose γ such that 2γL/ϵ ≤ 1/8. Note that η_t is bounded. We have

(γ/(2G)) E∥∇f(x_t)∥² ≤ (γ/2) E∥∇f(x_t) ⊙ √η_{t-1}∥²    (18)
≤ -E f(x_{t+1}) + E f(x_t) + 45γL²ρ_t²/(2ϵ) + (γ/(1-β_1)) G² E∥η_{t-1} - η_t∥_1 + (4γ²L/ϵ)(Lρ_t²/ϵ + σ²/(bϵ)) + (4 + (β_1/(1-β_1))²) γ²LG² E∥η_t - η_{t-1}∥² + (3γ³L²β_1²/(1-β_1)²) dG_∞²/ϵ³.    (19)

Then, telescoping from t = 0 to t = T-1, with γ a constant and ρ_t ≡ ρ, it follows that

(1/T) Σ_{t=0}^{T-1} E∥∇f(x_t)∥² ≤ 2G(f(x_0) - f*)/(γT) + (8GγL/ϵ)(σ²/(bϵ)) + 45GL²ρ²/ϵ + (2G³/((1-β_1)T)) d(1/ϵ - 1/G) + (8GγL/ϵ)(Lρ²/ϵ) + 2(4 + (β_1/(1-β_1))²) γLG³ d(ϵ^{-2} - G^{-2})/T + 6γ²L²β_1² dG³/((1-β_1)²ϵ³),

which completes the proof.
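The weighted Young's inequality ⟨a, b⟩ ≤ (λ²/2)∥a∥² + (1/(2λ²))∥b∥² (Lemma B.1 in the Appendix), used repeatedly above through the parameters µ, α, and λ_i, can be sanity-checked numerically. This is an illustration of the inequality, not part of the proof:

```python
import numpy as np

rng = np.random.default_rng(0)

# Numerical check of the weighted Young's inequality from Lemma B.1:
#   <a, b> <= (lam2 / 2) * ||a||^2 + (1 / (2 * lam2)) * ||b||^2
def young_holds(a, b, lam2):
    lhs = float(a @ b)
    rhs = 0.5 * lam2 * float(a @ a) + 0.5 / lam2 * float(b @ b)
    return lhs <= rhs + 1e-12  # tolerance for floating-point rounding

checks = []
for _ in range(1000):
    a = rng.normal(size=16)
    b = rng.normal(size=16)
    # Includes the paper's choices lambda_1^2 = 1/4 and mu^2 = alpha^2 = 8.
    for lam2 in (0.25, 1.0, 8.0):
        checks.append(young_holds(a, b, lam2))
print(all(checks))  # True
```

The inequality is what lets the cross terms in (14)-(16) trade a small coefficient on ∥∇f(x_t) ⊙ √η_{t-1}∥² against the ρ_t²-sized remainder terms.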

4. EXPERIMENTS

In this section, we apply AdaSAM to training language models and compare it with SGD, AMSGrad, and SAM to show its effectiveness. Due to space limitations, more experimental result visualizations, task descriptions, implementation details, and ablation studies are placed in the Appendix.

Implementations. We conduct our experiments with a widely used pretrained language model, RoBERTa-large, in the open-source toolkit fairseq, with 24 transformer layers and a hidden size of 1024. For fine-tuning on each task, we use different combinations of hyper-parameters, including the learning rate, the number of epochs, and the batch size. In particular, for RTE, STS-B, and MRPC of the GLUE benchmark, we first fine-tune the pretrained RoBERTa-large model on the MNLI dataset and then continue fine-tuning the resulting RoBERTa-large-MNLI model on the corresponding single-task corpus for better performance, as many prior works did Liu et al. (2019); He et al. (2020). All models are trained on an NVIDIA DGX SuperPOD cluster, in which each machine contains 8 x 40GB A100 GPUs.

4.2. RESULTS ON GLUE BENCHMARK

We conduct the experiments with the SGD-based methods and the adaptive learning rate methods, respectively; each group contains SAM and its base optimizer. Table 1 compares AMSGrad and SAM with an adaptive learning rate. For SAM with an adaptive learning rate, we tune the neighborhood size over {0.01, 0.005, 0.001}. The results show that SAM with an adaptive learning rate outperforms AMSGrad on 6 of the 8 tasks, with QNLI and QQP the exceptions; overall, it improves on AMSGrad by 0.28. Table 2 shows the results of SGD and SAM. Comparing the results of Table 1 and Table 2, we find that the adaptive learning rate methods are better than the SGD-based methods, and SAM with an adaptive learning rate achieves the best metric on 6 tasks. In general, SAM with an adaptive learning rate is better than the other methods. In addition, we conduct experiments with the momentum set to 0 to evaluate the influence of the adaptive learning rate alone. Table 3 shows that SAM with an adaptive learning rate outperforms AMSGrad on 6 of the 8 tasks, with SST-2 and RTE the exceptions. In Table 4, we compare SGD and SAM; without momentum, SAM outperforms SGD on all tasks. In this setting, SAM with an adaptive learning rate is again better than the other methods. Comparing the results of Table 1 and Table 3 suggests that the adaptive learning rate matters more than the momentum: without a momentum term, SAM with an adaptive learning rate improves on AMSGrad by 0.74, whereas with a momentum term it improves by only 0.28.

4.3. MINI-BATCH SPEEDUP

In this part, we test the performance with different batch sizes to validate the linear speedup property. The experiments are conducted on the MRPC, RTE, and CoLA tasks, with batch sizes set to 4, 8, 16, and 32, respectively. We scale the learning rate with √N, where N is the batch size, which is similar to Li et al. (2021). The results show that the training loss decreases faster as the batch size increases, and the loss curve with batch size 32 needs roughly half as many iterations as the curve with batch size 16.
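The square-root learning-rate scaling rule used here can be sketched as follows. The base values are illustrative, not the paper's exact settings:

```python
import math

def scaled_lr(base_lr, base_bsz, bsz):
    """Square-root learning-rate scaling for the speedup experiments:
    gamma is scaled with sqrt(N), where N is the batch size, relative to
    a base configuration (base_lr and base_bsz are hypothetical values)."""
    return base_lr * math.sqrt(bsz / base_bsz)

for bsz in (4, 8, 16, 32):
    print(bsz, scaled_lr(1e-5, 4, bsz))
```

Doubling the batch size thus multiplies the learning rate by √2, which matches the γ = O(√(b/T)) choice in Corollary 3.7 for fixed T.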

5. CONCLUSION

In this work, we study the convergence rate of the sharpness-aware minimization (SAM) optimizer with the adaptive learning rate of AMSGrad in the stochastic non-convex setting. To the best of our knowledge, we are the first to provide a non-trivial O(1/√(bT)) convergence rate for SAM with an adaptive learning rate, which achieves a linear speedup property with respect to the mini-batch size b. We have conducted extensive experiments on several NLP tasks, which verify that AdaSAM achieves superior performance compared with the AMSGrad and SAM optimizers. Future work includes extending AdaSAM to the distributed setting and reducing the cost of the twice-repeated gradient back-propagation.

A EXPERIMENTAL SETTINGS

Table 5: Experimental settings and data divisions for the different downstream tasks. Notably, for each task in the GLUE benchmark, we provide the number of classes ("classes"), the learning rate ("lr"), the batch size ("bsz"), the total number of updates ("total"), the number of warmup updates ("warmup"), and the number of GPUs ("GPUs") used during fine-tuning.

CoLA is a single-sentence task. Each sentence has a label of 1 or -1: 1 indicates a grammatical sentence, while -1 indicates an ungrammatical one. The Matthews correlation coefficient, dubbed mcc, is used as our evaluation metric. STS-B is a similarity and paraphrase task. Each sample is a pair of sentences, annotated with a score from 1 to 5 based on the similarity between the two. The metrics are the Pearson and Spearman correlation coefficients, dubbed p/s. RTE is an inference task in which each sample has two sentences: if the two sentences have an entailment relation, we view them as a positive sample; if not, they compose a negative sample. The metric for RTE is accuracy, dubbed acc. SST-2 is a single-sentence task, and its metric is accuracy. MNLI is a sentence-level task with 3 classes: entailment, contradiction, and neutral. MRPC is a task of classifying whether the sentences in a pair are equivalent. QNLI is a question-answering task: if the sentence contains the answer to the question, the sample is positive. QQP is a social question-answering task consisting of question pairs from Quora; it determines whether the questions are equivalent. The metric for MNLI, MRPC, QNLI, and QQP is accuracy.

B PROOF

We set z_t = x_t + (β_1/(1-β_1))(x_t - x_{t-1}) for t ≥ 0, and we assume x_{-1} = 0 and m_{-1} = 0. We have

z_{t+1} - z_t = x_{t+1} + (β_1/(1-β_1))(x_{t+1} - x_t) - x_t - (β_1/(1-β_1))(x_t - x_{t-1})    (21)
= (1/(1-β_1))(x_{t+1} - x_t) - (β_1/(1-β_1))(x_t - x_{t-1})    (22)
= -(1/(1-β_1)) γ m_t ⊙ η_t + (β_1/(1-β_1)) γ m_{t-1} ⊙ η_{t-1}    (23)
= -(1/(1-β_1)) γ (β_1 m_{t-1} + (1-β_1) g_t) ⊙ η_t + (β_1/(1-β_1)) γ m_{t-1} ⊙ η_{t-1}    (24)
= (β_1/(1-β_1)) γ m_{t-1} ⊙ (η_{t-1} - η_t) - γ g_t ⊙ η_t.    (25)

By applying L-smoothness, we have

f(z_{t+1}) ≤ f(z_t) + ⟨∇f(z_t), z_{t+1} - z_t⟩ + (L/2)∥z_{t+1} - z_t∥².    (26)

Re-organizing it, we have

f(z_{t+1}) - f(z_t) ≤ ⟨∇f(z_t), z_{t+1} - z_t⟩ + (L/2)∥z_{t+1} - z_t∥²    (27)
= ⟨∇f(z_t), (γβ_1/(1-β_1)) m_{t-1} ⊙ (η_{t-1} - η_t)⟩ + ⟨∇f(z_t), -γ g_t ⊙ η_t⟩ + (L/2)∥z_{t+1} - z_t∥²    (28)
= ⟨∇f(z_t), (γβ_1/(1-β_1)) m_{t-1} ⊙ (η_{t-1} - η_t)⟩ + (L/2)∥z_{t+1} - z_t∥² + ⟨∇f(z_t), (γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t s_t/∥s_t∥) ⊙ (η_{t-1} - η_t)⟩ + ⟨∇f(z_t), -(γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t s_t/∥s_t∥) ⊙ η_{t-1}⟩    (29)
= ⟨∇f(z_t), (γβ_1/(1-β_1)) m_{t-1} ⊙ (η_{t-1} - η_t)⟩ + (L/2)∥z_{t+1} - z_t∥² + ⟨∇f(z_t), (γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t s_t/∥s_t∥) ⊙ (η_{t-1} - η_t)⟩ + ⟨∇f(z_t) - ∇f(x_t), -(γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t s_t/∥s_t∥) ⊙ η_{t-1}⟩ + ⟨∇f(x_t), -(γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t s_t/∥s_t∥) ⊙ η_{t-1}⟩    (30)
= ⟨∇f(z_t), (γβ_1/(1-β_1)) m_{t-1} ⊙ (η_{t-1} - η_t)⟩ + (L/2)∥z_{t+1} - z_t∥² + ⟨∇f(z_t), (γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t s_t/∥s_t∥) ⊙ (η_{t-1} - η_t)⟩ + ⟨∇f(z_t) - ∇f(x_t), -(γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t s_t/∥s_t∥) ⊙ η_{t-1}⟩ + ⟨∇f(x_t), (γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t ∇f(x_t)/∥∇f(x_t)∥) ⊙ η_{t-1} - (γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t s_t/∥s_t∥) ⊙ η_{t-1}⟩ + ⟨∇f(x_t), -(γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t ∇f(x_t)/∥∇f(x_t)∥) ⊙ η_{t-1}⟩.    (31)

From Lemma B.5, Lemma B.6, and Lemma B.7, we have

⟨∇f(z_t), (γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t s_t/∥s_t∥) ⊙ (η_{t-1} - η_t)⟩ ≤ γG²∥η_{t-1} - η_t∥_1,    (32)

⟨∇f(z_t), (γβ_1/(1-β_1)) m_{t-1} ⊙ (η_{t-1} - η_t)⟩ ≤ (γβ_1/(1-β_1)) G²∥η_{t-1} - η_t∥_1,    (33)

⟨∇f(x_t), (γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t ∇f(x_t)/∥∇f(x_t)∥) ⊙ η_{t-1} - (γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t s_t/∥s_t∥) ⊙ η_{t-1}⟩ ≤ (γ/(2µ²)) ∥∇f(x_t) ⊙ √η_{t-1}∥² + 2µ²γL²ρ_t²/ϵ.    (34)
Taking the conditional expectation, we have

E f(z_{t+1}) - f(z_t) ≤ E⟨∇f(x_t), -(γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t ∇f(x_t)/∥∇f(x_t)∥) ⊙ η_{t-1}⟩ + (L/2) E∥z_{t+1} - z_t∥² + (γ/(2µ²)) ∥∇f(x_t) ⊙ √η_{t-1}∥² + 2µ²γL²ρ_t²/ϵ + (γ/(1-β_1)) G²∥η_{t-1} - η_t∥_1 + E⟨∇f(z_t) - ∇f(x_t), -(γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t s_t/∥s_t∥) ⊙ η_{t-1}⟩,    (35)

where µ > 0 is to be determined. For the first term, the term (L/2) E∥z_{t+1} - z_t∥², and the last term on the right-hand side, we introduce Lemma B.8, Lemma B.10, and Lemma B.9, respectively. Taking the expectation over the whole process, we have

E f(z_{t+1}) - E f(z_t) ≤ (γ/(2µ²)) E∥∇f(x_t) ⊙ √η_{t-1}∥² + 2µ²γL²ρ_t²/ϵ + (γ/(1-β_1)) G² E∥η_{t-1} - η_t∥_1 - γ E∥∇f(x_t) ⊙ √η_{t-1}∥² + (γ/(2α²)) E∥∇f(x_t) ⊙ √η_{t-1}∥² + γα²L²ρ_t²/(2ϵ) + (LG²γ²β_1²/(1-β_1)²) E∥η_t - η_{t-1}∥² + γ²L( 3((1+β)/(βϵ)) (E∥∇f(x_t) ⊙ √η_{t-1}∥² + Lρ_t²/ϵ + σ²/(bϵ)) + (1+β) G² E∥η_t - η_{t-1}∥² ) + (γ³L²β_1²/(2ϵ(1-β_1)²)) (1/λ_1² + 1/λ_2² + 1/λ_3²) (dG_∞²/ϵ²) + (γλ_1²/2) ∥∇f(x_t) ⊙ √η_{t-1}∥² + (γL²ρ_t²/(2ϵ)) (λ_2² + 4λ_3²)    (40)
= -γ(1 - 1/(2µ²) - 1/(2α²) - 3γL(1+β)/(βϵ) - λ_1²/2) E∥∇f(x_t) ⊙ √η_{t-1}∥² + 2µ²γL²ρ_t²/ϵ + (γ/(1-β_1)) G² E∥η_{t-1} - η_t∥_1 + γα²L²ρ_t²/(2ϵ) + (3γ²L(1+β)/(βϵ)) (Lρ_t²/ϵ + σ²/(bϵ)) + γ²LG² ((β_1/(1-β_1))² + 1 + β) E∥η_t - η_{t-1}∥² + (γ³L²β_1²/(2ϵ(1-β_1)²)) (1/λ_1² + 1/λ_2² + 1/λ_3²) (dG_∞²/ϵ²) + (γL²ρ_t²/(2ϵ)) (λ_2² + 4λ_3²).    (41)

We set µ² = α² = 8, β = 3, λ_1² = 1/4, λ_2² = λ_3² = 1, and choose γ such that 2γL/ϵ ≤ 1/8. So we have

E f(x_{t+1}) - E f(x_t) ≤ -(γ/2) E∥∇f(x_t) ⊙ √η_{t-1}∥² + 16γL²ρ_t²/ϵ + (γ/(1-β_1)) G² E∥η_{t-1} - η_t∥_1 + 4γL²ρ_t²/ϵ + (4γ²L/ϵ)(Lρ_t²/ϵ + σ²/(bϵ)) + (4 + (β_1/(1-β_1))²) γ²LG² E∥η_t - η_{t-1}∥² + (3γ³L²β_1²/(1-β_1)²) dG_∞²/ϵ³ + 5γL²ρ_t²/(2ϵ).    (42)

We re-arrange it and use the fact that η_t is bounded.
We have

(γ/(2G)) E∥∇f(x_t)∥² ≤ (γ/2) E∥∇f(x_t) ⊙ √η_{t-1}∥²    (43)
≤ -E f(x_{t+1}) + E f(x_t) + 45γL²ρ_t²/(2ϵ) + (γ/(1-β_1)) G² E∥η_{t-1} - η_t∥_1 + (4γ²L/ϵ)(Lρ_t²/ϵ + σ²/(bϵ)) + (4 + (β_1/(1-β_1))²) γ²LG² E∥η_t - η_{t-1}∥² + (3γ³L²β_1²/(1-β_1)²) dG_∞²/ϵ³.    (44)

Summing it from t = 0 to t = T-1, and assuming γ and ρ_t ≡ ρ are constant, we have

(1/T) Σ_{t=0}^{T-1} E∥∇f(x_t)∥² ≤ 2G(E f(x_0) - E f(x_T))/(γT) + 45GL²ρ²/ϵ + (2G³/((1-β_1)T)) E Σ_{t=0}^{T-1} ∥η_{t-1} - η_t∥_1 + (8GγL/ϵ)(Lρ²/ϵ + σ²/(bϵ)) + (2(4 + (β_1/(1-β_1))²) γLG³/T) E Σ_{t=0}^{T-1} ∥η_t - η_{t-1}∥² + 6γ²L²β_1² dG³/((1-β_1)²ϵ³)    (45)
≤ 2G(f(x_0) - f*)/(γT) + 45GL²ρ²/ϵ + (2G³/((1-β_1)T)) d(1/ϵ - 1/G) + (8GγL/ϵ)(Lρ²/ϵ + σ²/(bϵ)) + 2(4 + (β_1/(1-β_1))²) γLG³ d(ϵ^{-2} - G^{-2})/T + 6γ²L²β_1² dG³/((1-β_1)²ϵ³)    (46)
= 2G(f(x_0) - f*)/(γT) + (8GγL/ϵ)(σ²/(bϵ)) + 45GL²ρ²/ϵ + (2G³/((1-β_1)T)) d(1/ϵ - 1/G) + (8GγL/ϵ)(Lρ²/ϵ) + 2(4 + (β_1/(1-β_1))²) γLG³ d(ϵ^{-2} - G^{-2})/T + 6γ²L²β_1² dG³/((1-β_1)²ϵ³).    (47)

B.1 TECHNICAL LEMMAS

Lemma B.1. Given two vectors a, b ∈ R^d, we have ⟨a, b⟩ ≤ (λ²/2)∥a∥² + (1/(2λ²))∥b∥² for any parameter λ ∈ (1, +∞).

Proof. RHS = (λ²/2) Σ_{j=1}^d (a)_j² + (1/(2λ²)) Σ_{j=1}^d (b)_j² ≥ Σ_{j=1}^d 2√((λ²/2)(a)_j² · (1/(2λ²))(b)_j²) = Σ_{j=1}^d |(a)_j| × |(b)_j| ≥ LHS.

Lemma B.2. For any vectors x, y ∈ R^d, we have ∥x ⊙ y∥² ≤ ∥x∥² ∥y∥_∞² ≤ ∥x∥² ∥y∥².

Proof. The first inequality follows from Σ_{i=1}^d x_i² y_i² ≤ Σ_{i=1}^d x_i² ∥y∥_∞²; the second follows from ∥y∥_∞² ≤ ∥y∥².

Lemma B.3. η_t is bounded, i.e., 1/G_∞ ≤ (η_t)_j ≤ 1/ϵ.

Proof. The gradient is bounded by G and (η_t)_j = 1/√((v̂_t)_j). Following the update rule, we have 1/G_∞ ≤ (η_t)_j ≤ 1/ϵ.

Lemma B.4. For the term defined in the algorithm, we have (1/T) E Σ_{t=0}^{T-1} ∥η_{t-1} - η_t∥_1 ≤ (d/T)(1/ϵ - 1/G).

Proof. (η_t)_i, the i-th coordinate of η_t, decreases as t increases.
So we have

(1/T) E Σ_{t=0}^{T-1} ∥η_{t-1} - η_t∥_1 = E (1/T) Σ_{i=1}^d Σ_{t=0}^{T-1} |(η_{t-1})_i - (η_t)_i| ≤ E (1/T) Σ_{i=1}^d ((η_{-1})_i - (η_{T-1})_i) ≤ E (1/T) Σ_{i=1}^d (1/ϵ - 1/G) = (d/T)(1/ϵ - 1/G).

Lemma B.5. For the term defined in the algorithm, we have

⟨∇f(z_t), (γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t s_t/∥s_t∥) ⊙ (η_{t-1} - η_t)⟩ ≤ γG²∥η_{t-1} - η_t∥_1.    (52)

Proof.

⟨∇f(z_t), (γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t s_t/∥s_t∥) ⊙ (η_{t-1} - η_t)⟩ ≤ γ Σ_{j=1}^d |(∇f(z_t))^{(j)}| × |((1/b) Σ_{i∈B} ∇f_i(x_t + ρ_t s_t/∥s_t∥) ⊙ (η_{t-1} - η_t))^{(j)}|    (53)
≤ γG Σ_{j=1}^d |((1/b) Σ_{i∈B} ∇f_i(x_t + ρ_t s_t/∥s_t∥) ⊙ (η_{t-1} - η_t))^{(j)}|    (54)
≤ (γG/b) Σ_{j=1}^d Σ_{i∈B} |(∇f_i(x_t + ρ_t s_t/∥s_t∥) ⊙ (η_{t-1} - η_t))^{(j)}|    (55)
= (γG/b) Σ_{j=1}^d Σ_{i∈B} |(∇f_i(x_t + ρ_t s_t/∥s_t∥))^{(j)} × (η_{t-1} - η_t)^{(j)}|    (56)
≤ (γG²/b) Σ_{j=1}^d Σ_{i∈B} |(η_{t-1} - η_t)^{(j)}|    (57)
= γG² ∥η_{t-1} - η_t∥_1.    (58)

Lemma B.6. For the term defined in the algorithm, we have

⟨∇f(z_t), (γβ_1/(1-β_1)) m_{t-1} ⊙ (η_{t-1} - η_t)⟩ ≤ (γβ_1/(1-β_1)) G² ∥η_{t-1} - η_t∥_1.    (59)

Proof.

⟨∇f(z_t), (γβ_1/(1-β_1)) m_{t-1} ⊙ (η_{t-1} - η_t)⟩ ≤ (γβ_1/(1-β_1)) Σ_{j=1}^d |(∇f(z_t))^{(j)}| × |(m_{t-1} ⊙ (η_{t-1} - η_t))^{(j)}|    (60)
≤ (γβ_1/(1-β_1)) G Σ_{j=1}^d |(m_{t-1} ⊙ (η_{t-1} - η_t))^{(j)}|    (61)
= (γβ_1/(1-β_1)) G Σ_{j=1}^d |(m_{t-1})^{(j)} × (η_{t-1} - η_t)^{(j)}|    (62)
≤ (γβ_1/(1-β_1)) G² Σ_{j=1}^d |(η_{t-1} - η_t)^{(j)}|    (63)
= (γβ_1/(1-β_1)) G² ∥η_{t-1} - η_t∥_1.    (64)

Lemma B.7. For the term defined in the algorithm, we have

⟨∇f(x_t), (γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t ∇f(x_t)/∥∇f(x_t)∥) ⊙ η_{t-1} - (γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t s_t/∥s_t∥) ⊙ η_{t-1}⟩ ≤ (γ/(2µ²)) ∥∇f(x_t) ⊙ √η_{t-1}∥² + 2µ²γL²ρ_t²/ϵ.    (65)

Proof.
⟨∇f(x_t), (γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t ∇f(x_t)/∥∇f(x_t)∥) ⊙ η_{t-1} - (γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t s_t/∥s_t∥) ⊙ η_{t-1}⟩
= ⟨∇f(x_t) ⊙ √η_{t-1}, (γ/b) Σ_{i∈B} (∇f_i(x_t + ρ_t ∇f(x_t)/∥∇f(x_t)∥) - ∇f_i(x_t + ρ_t (Σ_{i∈B} ∇f_i(x_t))/∥Σ_{i∈B} ∇f_i(x_t)∥)) ⊙ √η_{t-1}⟩    (66)
≤ (γ/(2µ²)) ∥∇f(x_t) ⊙ √η_{t-1}∥² + (µ²γ/(2b²)) ∥Σ_{i∈B} (∇f_i(x_t + ρ_t ∇f(x_t)/∥∇f(x_t)∥) - ∇f_i(x_t + ρ_t s_t/∥s_t∥)) ⊙ √η_{t-1}∥²    (67)
≤ (γ/(2µ²)) ∥∇f(x_t) ⊙ √η_{t-1}∥² + (µ²γ/(2b)) Σ_{i∈B} ∥(∇f_i(x_t + ρ_t ∇f(x_t)/∥∇f(x_t)∥) - ∇f_i(x_t + ρ_t s_t/∥s_t∥)) ⊙ √η_{t-1}∥²    (68)
≤ (γ/(2µ²)) ∥∇f(x_t) ⊙ √η_{t-1}∥² + (µ²γ/(2b)) Σ_{i∈B} ∥∇f_i(x_t + ρ_t ∇f(x_t)/∥∇f(x_t)∥) - ∇f_i(x_t + ρ_t s_t/∥s_t∥)∥² × ∥√η_{t-1}∥_∞²    (69)
≤ (γ/(2µ²)) ∥∇f(x_t) ⊙ √η_{t-1}∥² + (µ²γL²ρ_t²/(2bϵ)) Σ_{i∈B} ∥∇f(x_t)/∥∇f(x_t)∥ - s_t/∥s_t∥∥²    (70)
≤ (γ/(2µ²)) ∥∇f(x_t) ⊙ √η_{t-1}∥² + 2µ²γL²ρ_t²/ϵ.    (71)

Lemma B.8. For the term defined in the algorithm, we have

E⟨∇f(x_t), -(γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t ∇f(x_t)/∥∇f(x_t)∥) ⊙ η_{t-1}⟩ ≤ -γ∥∇f(x_t) ⊙ √η_{t-1}∥² + (γ/(2α²)) E∥∇f(x_t) ⊙ √η_{t-1}∥² + γα²L²ρ_t²/(2ϵ).    (72)

Proof.

E⟨∇f(x_t), -(γ/b) Σ_{i∈B} ∇f_i(x_t + ρ_t ∇f(x_t)/∥∇f(x_t)∥) ⊙ η_{t-1}⟩
= -γ∥∇f(x_t) ⊙ √η_{t-1}∥² + E⟨∇f(x_t), (γ/b) Σ_{i∈B} (∇f(x_t) - ∇f_i(x_t + ρ_t ∇f(x_t)/∥∇f(x_t)∥)) ⊙ η_{t-1}⟩    (73)
= -γ∥∇f(x_t) ⊙ √η_{t-1}∥² + E⟨∇f(x_t), (γ/b) Σ_{i∈B} (∇f_i(x_t) - ∇f_i(x_t + ρ_t ∇f(x_t)/∥∇f(x_t)∥)) ⊙ η_{t-1}⟩    (74)
≤ -γ∥∇f(x_t) ⊙ √η_{t-1}∥² + (γ/(2α²)) E∥∇f(x_t) ⊙ √η_{t-1}∥² + (γα²/2) E∥(1/b) Σ_{i∈B} (∇f_i(x_t) - ∇f_i(x_t + ρ_t ∇f(x_t)/∥∇f(x_t)∥)) ⊙ √η_{t-1}∥²    (75)
≤ -γ∥∇f(x_t) ⊙ √η_{t-1}∥² + (γ/(2α²)) E∥∇f(x_t) ⊙ √η_{t-1}∥² + (γα²/(2ϵ)) E∥(1/b) Σ_{i∈B} (∇f_i(x_t) - ∇f_i(x_t + ρ_t ∇f(x_t)/∥∇f(x_t)∥))∥²    (76)
≤ -γ∥∇f(x_t) ⊙ √η_{t-1}∥² + (γ/(2α²)) E∥∇f(x_t) ⊙ √η_{t-1}∥² + (γα²/(2bϵ)) E Σ_{i∈B} ∥∇f_i(x_t) - ∇f_i(x_t + ρ_t ∇f(x_t)/∥∇f(x_t)∥)∥²    (77)
≤ -γ∥∇f(x_t) ⊙ √η_{t-1}∥² + (γ/(2α²)) E∥∇f(x_t) ⊙ √η_{t-1}∥² + (γα²L²ρ_t²/(2bϵ)) E Σ_{i∈B} ∥∇f(x_t)/∥∇f(x_t)∥∥²    (78)
= -γ∥∇f(x_t) ⊙ √η_{t-1}∥² + (γ/(2α²)) E∥∇f(x_t) ⊙ √η_{t-1}∥² + γα²L²ρ_t²/(2ϵ).    (79)

Lemma B.9.
For the term defined in the algorithm, we have
$$
\mathbb{E}\Big\langle \nabla f(z_t)-\nabla f(x_t),\ -\frac{\gamma_t}{b}\sum_{i\in B}\nabla f_i\Big(x_t+\rho_t\frac{s_t}{\|s_t\|}\Big)\odot\eta_{t-1}\Big\rangle
\le \frac{\gamma^3 L^2\beta_1^2}{2\epsilon(1-\beta_1)^2}\Big(\frac{1}{\lambda_1^2}+\frac{1}{\lambda_2^2}+\frac{1}{\lambda_3^2}\Big)\frac{dG_\infty^2}{\epsilon^2}
+\frac{\gamma\lambda_1^2}{2}\,\big\|\nabla f(x_t)\odot\sqrt{\eta_{t-1}}\big\|^2+\frac{\gamma L^2\rho_t^2}{2\epsilon}\big(\lambda_2^2+4\lambda_3^2\big).
$$

Proof. Write $g_i^{s}=\nabla f_i\big(x_t+\rho_t\frac{s_t}{\|s_t\|}\big)$ and $g_i^{\nabla}=\nabla f_i\big(x_t+\rho_t\frac{\nabla f(x_t)}{\|\nabla f(x_t)\|}\big)$, where $s_t=\sum_{i\in B}\nabla f_i(x_t)$. Then
\begin{align}
&\mathbb{E}\Big\langle \nabla f(z_t)-\nabla f(x_t),\ -\frac{\gamma_t}{b}\sum_{i\in B}g_i^{s}\odot\eta_{t-1}\Big\rangle \tag{81}\\
&\quad=\gamma\,\mathbb{E}\Big\langle \big(\nabla f(x_t)-\nabla f(z_t)\big)\odot\sqrt{\eta_{t-1}},\ \frac{1}{b}\sum_{i\in B}g_i^{s}\odot\sqrt{\eta_{t-1}}\Big\rangle \tag{82}\\
&\quad=\gamma\,\mathbb{E}\big\langle \big(\nabla f(x_t)-\nabla f(z_t)\big)\odot\sqrt{\eta_{t-1}},\ \nabla f(x_t)\odot\sqrt{\eta_{t-1}}\big\rangle \notag\\
&\quad\quad+\gamma\,\mathbb{E}\Big\langle \big(\nabla f(x_t)-\nabla f(z_t)\big)\odot\sqrt{\eta_{t-1}},\ \frac{1}{b}\sum_{i\in B}\big(g_i^{\nabla}-\nabla f_i(x_t)\big)\odot\sqrt{\eta_{t-1}}\Big\rangle \notag\\
&\quad\quad+\gamma\,\mathbb{E}\Big\langle \big(\nabla f(x_t)-\nabla f(z_t)\big)\odot\sqrt{\eta_{t-1}},\ \frac{1}{b}\sum_{i\in B}\big(g_i^{s}-g_i^{\nabla}\big)\odot\sqrt{\eta_{t-1}}\Big\rangle \tag{83}\\
&\quad\le \frac{\gamma}{2}\Big(\frac{1}{\lambda_1^2}+\frac{1}{\lambda_2^2}+\frac{1}{\lambda_3^2}\Big)\mathbb{E}\big\|\big(\nabla f(x_t)-\nabla f(z_t)\big)\odot\sqrt{\eta_{t-1}}\big\|^2+\frac{\gamma\lambda_1^2}{2}\,\big\|\nabla f(x_t)\odot\sqrt{\eta_{t-1}}\big\|^2 \notag\\
&\quad\quad+\frac{\gamma\lambda_2^2}{2}\,\mathbb{E}\Big\|\frac{1}{b}\sum_{i\in B}\big(g_i^{\nabla}-\nabla f_i(x_t)\big)\odot\sqrt{\eta_{t-1}}\Big\|^2+\frac{\gamma\lambda_3^2}{2}\,\mathbb{E}\Big\|\frac{1}{b}\sum_{i\in B}\big(g_i^{s}-g_i^{\nabla}\big)\odot\sqrt{\eta_{t-1}}\Big\|^2 \tag{84}\\
&\quad\le \frac{\gamma}{2}\Big(\frac{1}{\lambda_1^2}+\frac{1}{\lambda_2^2}+\frac{1}{\lambda_3^2}\Big)\mathbb{E}\big\|\big(\nabla f(x_t)-\nabla f(z_t)\big)\odot\sqrt{\eta_{t-1}}\big\|^2+\frac{\gamma\lambda_1^2}{2}\,\big\|\nabla f(x_t)\odot\sqrt{\eta_{t-1}}\big\|^2+\frac{\gamma\lambda_2^2 L^2\rho_t^2}{2\epsilon}+\frac{2\lambda_3^2\gamma L^2\rho_t^2}{\epsilon} \tag{85}\\
&\quad\le \frac{\gamma L^2}{2\epsilon}\Big(\frac{1}{\lambda_1^2}+\frac{1}{\lambda_2^2}+\frac{1}{\lambda_3^2}\Big)\mathbb{E}\|z_t-x_t\|^2+\frac{\gamma\lambda_1^2}{2}\,\big\|\nabla f(x_t)\odot\sqrt{\eta_{t-1}}\big\|^2+\frac{\gamma\lambda_2^2 L^2\rho_t^2}{2\epsilon}+\frac{2\lambda_3^2\gamma L^2\rho_t^2}{\epsilon} \tag{86}\\
&\quad=\frac{\gamma^3 L^2\beta_1^2}{2\epsilon(1-\beta_1)^2}\Big(\frac{1}{\lambda_1^2}+\frac{1}{\lambda_2^2}+\frac{1}{\lambda_3^2}\Big)\big\|m_{t-1}\odot\eta_{t-1}\big\|^2+\frac{\gamma\lambda_1^2}{2}\,\big\|\nabla f(x_t)\odot\sqrt{\eta_{t-1}}\big\|^2+\frac{\gamma\lambda_2^2 L^2\rho_t^2}{2\epsilon}+\frac{2\lambda_3^2\gamma L^2\rho_t^2}{\epsilon} \notag\\
&\quad\le \frac{\gamma^3 L^2\beta_1^2}{2\epsilon(1-\beta_1)^2}\Big(\frac{1}{\lambda_1^2}+\frac{1}{\lambda_2^2}+\frac{1}{\lambda_3^2}\Big)\frac{dG_\infty^2}{\epsilon^2}+\frac{\gamma\lambda_1^2}{2}\,\big\|\nabla f(x_t)\odot\sqrt{\eta_{t-1}}\big\|^2+\frac{\gamma L^2\rho_t^2}{2\epsilon}\big(\lambda_2^2+4\lambda_3^2\big),
\end{align}
where (84) applies Young's inequality with parameters $\lambda_1,\lambda_2,\lambda_3$, (85) bounds the perturbation terms via $L$-smoothness as in Lemma B.7, (86) uses the $L$-smoothness of $f$, and the last two steps use $z_t-x_t=\frac{\gamma\beta_1}{1-\beta_1}\,m_{t-1}\odot\eta_{t-1}$ together with $\|m_{t-1}\|_\infty\le G_\infty$ and $\|\eta_{t-1}\|_\infty\le 1/\epsilon$.

Lemma B.10. For the term defined in the algorithm, we have
$$
\frac{L}{2}\,\mathbb{E}\|z_{t+1}-z_t\|^2
\le \frac{LG^2\gamma^2\beta_1^2}{(1-\beta_1)^2}\,\mathbb{E}\|\eta_t-\eta_{t-1}\|^2
+\gamma_t^2 L\Big(\frac{3(1+\beta)}{\beta\epsilon}\Big(\mathbb{E}\big\|\nabla f(x_t)\odot\sqrt{\eta_{t-1}}\big\|^2+\frac{L^2\rho_t^2}{\epsilon}+\frac{\sigma^2}{b\epsilon}\Big)+(1+\beta)G^2\,\mathbb{E}\|\eta_t-\eta_{t-1}\|^2\Big). \tag{89}
$$

Proof.
Write $g_i^{s}=\nabla f_i\big(x_t+\rho_t\frac{s_t}{\|s_t\|}\big)$ and $g_t=\frac{1}{b}\sum_{i\in B}g_i^{s}$. Then
\begin{align}
\frac{L}{2}\,\mathbb{E}\|z_{t+1}-z_t\|^2
&=\frac{L}{2}\,\mathbb{E}\Big\|\frac{\gamma\beta_1}{1-\beta_1}\,m_{t-1}\odot(\eta_t-\eta_{t-1})-\gamma\,g_t\odot\eta_t\Big\|^2 \tag{90}\\
&\le \frac{L\gamma^2\beta_1^2}{(1-\beta_1)^2}\,\mathbb{E}\big\|m_{t-1}\odot(\eta_t-\eta_{t-1})\big\|^2+L\,\mathbb{E}\big\|\gamma_t\,g_t\odot\eta_t\big\|^2 \tag{91}\\
&\le \frac{LG^2\gamma^2\beta_1^2}{(1-\beta_1)^2}\,\mathbb{E}\|\eta_t-\eta_{t-1}\|^2+L\,\mathbb{E}\big\|\gamma_t\,g_t\odot\eta_t\big\|^2 \tag{92}\\
&=\frac{LG^2\gamma^2\beta_1^2}{(1-\beta_1)^2}\,\mathbb{E}\|\eta_t-\eta_{t-1}\|^2+\gamma_t^2 L\,\mathbb{E}\big\|g_t\odot\eta_{t-1}+g_t\odot(\eta_t-\eta_{t-1})\big\|^2 \tag{93}\\
&\le \frac{LG^2\gamma^2\beta_1^2}{(1-\beta_1)^2}\,\mathbb{E}\|\eta_t-\eta_{t-1}\|^2+\gamma_t^2 L\Big(\Big(1+\frac{1}{\beta}\Big)\mathbb{E}\big\|g_t\odot\eta_{t-1}\big\|^2+(1+\beta)\,\mathbb{E}\big\|g_t\odot(\eta_t-\eta_{t-1})\big\|^2\Big) \tag{94}\\
&\le \frac{LG^2\gamma^2\beta_1^2}{(1-\beta_1)^2}\,\mathbb{E}\|\eta_t-\eta_{t-1}\|^2+\gamma_t^2 L\Big(\Big(1+\frac{1}{\beta}\Big)\mathbb{E}\big\|g_t\odot\eta_{t-1}\big\|^2+(1+\beta)G^2\,\mathbb{E}\|\eta_t-\eta_{t-1}\|^2\Big). \tag{95}
\end{align}
It remains to bound $\mathbb{E}\|g_t\odot\eta_{t-1}\|^2$:
\begin{align}
\mathbb{E}\big\|g_t\odot\eta_{t-1}\big\|^2
&\le \mathbb{E}\big\|g_t\odot\sqrt{\eta_{t-1}}\big\|^2\cdot\big\|\sqrt{\eta_{t-1}}\big\|_\infty^2 \tag{96}\\
&\le \frac{1}{\epsilon}\,\mathbb{E}\big\|g_t\odot\sqrt{\eta_{t-1}}\big\|^2 \tag{97}\\
&\le \frac{3}{\epsilon}\Big(\mathbb{E}\big\|\nabla f(x_t)\odot\sqrt{\eta_{t-1}}\big\|^2+\mathbb{E}\Big\|\Big(\frac{1}{b}\sum_{i\in B}\nabla f_i(x_t)-\nabla f(x_t)\Big)\odot\sqrt{\eta_{t-1}}\Big\|^2+\mathbb{E}\Big\|\frac{1}{b}\sum_{i\in B}\big(g_i^{s}-\nabla f_i(x_t)\big)\odot\sqrt{\eta_{t-1}}\Big\|^2\Big) \tag{98}\\
&\le \frac{3}{\epsilon}\Big(\mathbb{E}\big\|\nabla f(x_t)\odot\sqrt{\eta_{t-1}}\big\|^2+\mathbb{E}\Big\|\frac{1}{b}\sum_{i\in B}\big(g_i^{s}-\nabla f_i(x_t)\big)\odot\sqrt{\eta_{t-1}}\Big\|^2+\frac{\sigma^2}{b\epsilon}\Big) \tag{99}\\
&\le \frac{3}{\epsilon}\Big(\mathbb{E}\big\|\nabla f(x_t)\odot\sqrt{\eta_{t-1}}\big\|^2+\frac{1}{\epsilon}\,\mathbb{E}\Big\|\frac{1}{b}\sum_{i\in B}\big(g_i^{s}-\nabla f_i(x_t)\big)\Big\|^2+\frac{\sigma^2}{b\epsilon}\Big) \tag{100}\\
&\le \frac{3}{\epsilon}\Big(\mathbb{E}\big\|\nabla f(x_t)\odot\sqrt{\eta_{t-1}}\big\|^2+\frac{1}{\epsilon b}\,\mathbb{E}\sum_{i\in B}\big\|g_i^{s}-\nabla f_i(x_t)\big\|^2+\frac{\sigma^2}{b\epsilon}\Big) \tag{101}\\
&\le \frac{3}{\epsilon}\Big(\mathbb{E}\big\|\nabla f(x_t)\odot\sqrt{\eta_{t-1}}\big\|^2+\frac{L^2\rho_t^2}{\epsilon}+\frac{\sigma^2}{b\epsilon}\Big),
\end{align}
where (99) uses the bounded-variance assumption and the last step uses $L$-smoothness together with $\big\|\rho_t\,s_t/\|s_t\|\big\|=\rho_t$. Substituting this bound into (95) and using $\big(1+\frac{1}{\beta}\big)\frac{3}{\epsilon}=\frac{3(1+\beta)}{\beta\epsilon}$ completes the proof.

C ADDITIONAL EXPERIMENT ILLUSTRATIONS

C.1 EXPERIMENT ILLUSTRATIONS

We conduct experiments on the GLUE benchmark with AdaSAM, AMSGrad, SAM, and SGD, respectively. In these experiments, the optimizers do not use the momentum term (β1 = 0). As a supplement to Table 3 and Table 4, Figure 2 and Figure 3 show the detailed loss and evaluation-metric curves versus the number of training steps. The loss curve of AdaSAM decreases faster than those of SAM and SGD on all tasks, and at a speed similar to that of AMSGrad. The metric curves of AdaSAM and AMSGrad show that the adaptive learning rate methods outperform SGD and SAM, and that AdaSAM decreases as fast as AMSGrad on all tasks.
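For concreteness, one AdaSAM update can be sketched as below. This is a minimal sketch, assuming the adaptive term takes the AMSGrad form η_t = 1/(√v̂_t + ε); the function name, hyper-parameter defaults, and the toy objective are illustrative rather than the exact algorithm box of the paper.

```python
import numpy as np

def adasam_step(x, grad_fn, state, gamma=1e-3, rho=0.05,
                beta1=0.0, beta2=0.999, eps=1e-8):
    """One AdaSAM update: a SAM perturbation step followed by an
    AMSGrad-style adaptive step. `grad_fn(x)` returns a stochastic
    mini-batch gradient at x."""
    g = grad_fn(x)
    # Ascent step: perturb the weights along the normalized gradient.
    x_adv = x + rho * g / (np.linalg.norm(g) + 1e-12)
    g_adv = grad_fn(x_adv)
    # AMSGrad statistics computed from the perturbed gradient.
    state["m"] = beta1 * state["m"] + (1 - beta1) * g_adv
    state["v"] = beta2 * state["v"] + (1 - beta2) * g_adv ** 2
    state["v_hat"] = np.maximum(state["v_hat"], state["v"])
    eta = 1.0 / (np.sqrt(state["v_hat"]) + eps)  # adaptive learning rate eta_t
    return x - gamma * state["m"] * eta, state

# Toy objective f(x) = 0.5 * ||x||^2, whose gradient is x itself.
x = np.array([1.0, -2.0])
state = {"m": np.zeros(2), "v": np.zeros(2), "v_hat": np.zeros(2)}
for _ in range(200):
    x, state = adasam_step(x, lambda z: z, state)
```

With beta1 = 0 the momentum buffer reduces to the current perturbed gradient, matching the setting of this subsection; setting beta1 = 0.9 recovers the momentum variant studied in the ablation.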

C.2 ABLATION STUDY

In this section, we conduct experiments to evaluate the impact of momentum (β1 = 0.9) on the different optimizers. The experiment results are shown in Figure 4 and Figure 5, and are also summarized in Table 1 and Table 2. SAM with an adaptive learning rate (AdaSAM with β1 = 0.9) converges as fast as AMSGrad, and both are faster than SGD with momentum and SAM with momentum. The generalization ability of AdaSAM with momentum is much better than that of SAM with momentum on all tasks. Besides, AdaSAM also outperforms AMSGrad on the GLUE benchmark except for the QNLI and QQP tasks.
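The momentum term this ablation toggles can be illustrated on a toy ill-conditioned quadratic. The sketch below uses only an exponentially averaged momentum buffer (not the full AdaSAM update); the matrix, learning rate, and step count are arbitrary illustrative choices.

```python
import numpy as np

# Gradient descent on an ill-conditioned quadratic f(x) = 0.5 * x^T A x,
# with and without an exponentially averaged momentum term (beta1 = 0.9),
# mirroring the momentum toggle in this ablation.
A = np.diag([1.0, 50.0])

def run(beta1, steps=100, lr=0.015):
    x = np.array([1.0, 1.0])
    m = np.zeros(2)
    for _ in range(steps):
        g = A @ x                        # gradient of the quadratic
        m = beta1 * m + (1 - beta1) * g  # momentum buffer (m = g when beta1 = 0)
        x = x - lr * m
    return np.linalg.norm(x)

plain = run(beta1=0.0)     # no momentum
momentum = run(beta1=0.9)  # with momentum
```

Both runs converge from the initial point; the momentum buffer damps oscillation along the ill-conditioned direction, which is the qualitative effect the β1 = 0.9 setting probes.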



https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz
https://github.com/facebookresearch/fairseq

Due to the space limitation, we show the details of the dataset and training setting in Appendix A.



Sharpness-aware minimization. Previous work shows that sharp minima may lead to poor generalization whereas flat minima perform better (Jiang et al., 2020; Keskar et al., 2017; He et al.).

Figure 1: The linear speedup of SAM-AMSGrad with mini-batch sizes of 4, 8, 16, and 32.

Figure 2: The loss and evaluation metric vs. steps on MRPC, RTE, CoLA, SST-2, STS-B and MNLI (β1 = 0).

Figure 4: The loss and evaluation metric vs. steps on MRPC, RTE, CoLA, SST-2, STS-B, MNLI (β1 = 0.9).

Evaluating AMSGrad and AdaSAM on the GLUE benchmark with β1 = 0.9

Evaluating SGD and SAM on the GLUE benchmark with β1 = 0.9

Evaluating AMSGrad and AdaSAM on the GLUE benchmark without momentum (β1 = 0)

Evaluating SGD and SAM on the GLUE benchmark with β1 = 0. We select the perturbation size from {0.01, 0.005} for SAM. The results show that SAM outperforms SGD on 7 of the 8 tasks (all except RTE), and that SAM significantly improves performance.
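For reference, the SAM-on-SGD update compared against plain SGD above can be sketched as follows; the function name and the toy quadratic are illustrative, and `rho` plays the role of the perturbation size selected from {0.01, 0.005}.

```python
import numpy as np

def sam_sgd_step(x, grad_fn, lr=0.01, rho=0.01):
    """One SAM step on top of SGD: ascend to the perturbed point,
    then descend with the gradient evaluated there."""
    g = grad_fn(x)
    # Perturbation of size rho along the normalized gradient direction.
    x_adv = x + rho * g / (np.linalg.norm(g) + 1e-12)
    return x - lr * grad_fn(x_adv)

# Toy quadratic f(x) = 0.5 * ||x||^2, so grad f(x) = x.
x = np.array([1.0, -1.0])
for _ in range(500):
    x = sam_sgd_step(x, lambda z: z)
```

Each update costs two gradient evaluations, which is the doubled computation relative to SGD discussed in the introduction.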

