PROVABLE SHARPNESS-AWARE MINIMIZATION WITH ADAPTIVE LEARNING RATE

Abstract

The sharpness-aware minimization (SAM) optimizer has been extensively explored because it converges fast and trains deep neural networks efficiently by introducing an extra perturbation step that flattens the loss landscape of deep learning models. A combination of SAM with an adaptive learning rate (AdaSAM) has also been explored to train large-scale deep neural networks, but without a theoretical guarantee, due to the dual difficulties of analyzing the perturbation step and the coupled adaptive learning rate. In this paper, we analyze the convergence rate of AdaSAM in the stochastic non-convex setting. We theoretically show that AdaSAM admits an O(1/√(bT)) convergence rate and enjoys a linear speedup property with respect to the mini-batch size b. To the best of our knowledge, we are the first to provide a non-trivial convergence rate for SAM with an adaptive learning rate. To decouple the two stochastic gradient steps from the adaptive learning rate, we introduce a delayed second-order momentum term into the convergence analysis, which makes the two steps independent when taking an expectation. We then bound them by showing that the adaptive learning rate has a limited range, which makes our analysis feasible. Finally, we conduct experiments on several NLP tasks, which show that AdaSAM achieves superior performance compared with the SGD, AMSGrad, and SAM optimizers.



1 INTRODUCTION

Sharpness-aware minimization (SAM) Foret et al. (2021) is a powerful optimizer for training large-scale deep learning models that minimizes the gap between training performance and generalization performance. It has achieved remarkable results on various deep neural networks, such as ResNet He et al. (2016), vision transformers Dosovitskiy et al. (2021); Xu et al. (2021), and language models Devlin et al. (2018); He et al. (2020), across extensive benchmarks. However, SAM-type methods suffer from several issues when training deep neural networks, in particular a heavy computation cost and a burdensome hyper-parameter tuning procedure. Due to the extra perturbation step, SAM needs to compute gradients twice per update compared with classic optimizers such as SGD, Adam Kingma & Ba (2015), and AMSGrad Reddi et al. (2018). Hence, SAM has to forward and back propagate twice for one parameter update, doubling the computation cost of the classic optimizers. Moreover, since the training process involves two steps, it requires twice as many hyper-parameters, which makes learning rate tuning unwieldy and costly.

Adaptive learning rate methods scale the gradients based on historical gradient information to accelerate convergence. Such methods, including Adagrad Duchi et al. (2011), Adam, and AMSGrad, have been studied for solving NLP tasks Zhang et al. (2020). However, they may converge to sharp local minima that yield poor generalization performance, and many studies have investigated how to improve their generalization ability. Recently, several works have tried to ease the learning rate tuning in SAM by combining it with an adaptive learning rate: for example, Zhuang et al. (2022) train ViT models with an adaptive method and Kwon et al. (2021) train an NLP model. Although remarkable performance has been achieved, the convergence of these methods remains unknown because the adaptive learning rate is used inside SAM. Directly analyzing their convergence is difficult due to the two optimization steps, especially the adaptive learning rate adopted in the second step.

In this paper, we analyze the convergence rate of SAM with an adaptive learning rate, dubbed AdaSAM, in the nonconvex stochastic setting. To circumvent the difficulty in the analysis, we develop a technique that decouples the two-step training of SAM from the adaptive learning rate. The analysis proceeds in two parts: the first part analyzes the SAM perturbation step, and the second analyzes the update step that adopts the adaptive learning rate. We introduce a second-order momentum term from the previous iteration, which is related to the adaptive learning rate yet independent of SAM when taking an expectation. We can then bound the term composed of the SAM step and the previous second-order momentum because the adaptive learning rate has a limited range. We prove that AdaSAM enjoys a linear speedup property with respect to the batch size. Furthermore, we apply AdaSAM to train the RoBERTa model on the GLUE benchmark; AdaSAM achieves the best performance on 6 of the 8 tasks, and the linear speedup is clearly observed. In the end, we summarize our contributions as follows:

• We present the first convergence guarantee of the adaptive SAM method under the nonconvex setting. Our results suggest that a large mini-batch can help convergence due to the linear speedup with respect to the batch size.

• We conduct a series of experiments on various tasks. The results show that AdaSAM outperforms most state-of-the-art optimizers and that the linear speedup is verified.

2 PRELIMINARY

2.1 PROBLEM SETUP

In this work, we focus on the following stochastic nonconvex optimization problem:

    min_{x ∈ R^d} f(x) := E_{ξ∼D} f_ξ(x),    (1)

where d is the dimension of the variable x, D is the unknown distribution of the data samples, f_ξ(x) is a smooth and possibly non-convex function, and f_{ξ_i}(x) denotes the objective function at the data point ξ_i sampled according to the data distribution D. In machine learning, this covers empirical risk minimization as a special case, with f the loss function. When the dataset D contains N data points, i.e., D = {ξ_i, i = 1, 2, . . . , N}, problem (1) reduces to the following finite-sum problem:

    min_{x ∈ R^d} f(x) := (1/N) Σ_{i=1}^{N} f_{ξ_i}(x).    (2)

Without additional declaration, we write f_i(x) for f_{ξ_i}(x) for simplicity.

Notations. f_i(x) is the i-th loss function, x ∈ R^d is the model parameter, and d is the parameter dimension. We denote the ℓ2 norm as ∥·∥_2. The Hadamard product of two vectors a, b is denoted a ⊙ b. For a vector a ∈ R^d, √a denotes the vector whose j-th entry (√a)_(j) equals the square root of a_(j).

2.2 SHARPNESS-AWARE MINIMIZATION

Previous work shows that sharp minima may lead to poor generalization whereas flat minima perform better Jiang et al. (2020); Keskar et al. (2017); He et al. (2019); it is therefore common to regard sharpness as closely related to generalization. Sharpness-aware minimization, proposed in Foret et al. (2021), aims to find flat minimizers by minimizing the training loss uniformly over an entire neighborhood. Specifically, Foret et al. (2021) solves the following minimax saddle point problem:

    min_{x ∈ R^d} max_{∥δ∥ ≤ ρ} f(x + δ) + λ∥x∥_2^2,    (3)

where ρ ≥ 0 and λ ≥ 0 are two hyperparameters. That is, the perturbed loss of f(x) over a neighborhood is minimized instead of the original loss f(x). By a first-order Taylor expansion of f(x + δ) with respect to δ, the inner maximization problem is approximately solved via

    δ*(x) = arg max_{∥δ∥≤ρ} f(x + δ) ≈ arg max_{∥δ∥≤ρ} f(x) + δ^⊤∇f(x) = arg max_{∥δ∥≤ρ} δ^⊤∇f(x) = ρ ∇f(x)/∥∇f(x)∥.
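The closed-form perturbation δ*(x) = ρ∇f(x)/∥∇f(x)∥ turns SAM into a two-pass update: one gradient pass to form the perturbation, a second pass at the perturbed point to take the descent step. Below is a minimal NumPy sketch of this update; the function name, the toy quadratic loss, and the hyper-parameter values are our own illustration, not the paper's implementation.

```python
import numpy as np

def sam_step(x, grad_fn, lr=0.1, rho=0.05, eps=1e-12):
    """One SAM update: perturb x along the normalized gradient, then
    descend using the gradient evaluated at the perturbed point."""
    g = grad_fn(x)                               # first gradient pass
    delta = rho * g / (np.linalg.norm(g) + eps)  # delta* = rho * g / ||g||
    g_sam = grad_fn(x + delta)                   # second pass at x + delta
    return x - lr * g_sam

# Toy smooth loss f(x) = 0.5 * ||x||^2, whose gradient is simply x.
grad_fn = lambda z: z
x = np.array([1.0, -2.0])
for _ in range(100):
    x = sam_step(x, grad_fn)
```

The small eps guards the division when the gradient vanishes; the two calls to grad_fn reflect the doubled forward/backward cost discussed in the introduction.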

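The linear speedup with respect to the mini-batch size b rests on the fact that the variance of the mini-batch gradient estimator for problem (1) scales as 1/b. A quick numerical check on toy one-dimensional "per-sample gradients" (entirely our own construction, not data from the paper):

```python
import numpy as np

rng = np.random.default_rng(0)
per_sample = rng.normal(size=100_000)  # toy per-sample gradient values

def minibatch_variance(b, trials=20_000):
    """Empirical variance of the size-b mini-batch gradient estimator."""
    estimates = rng.choice(per_sample, size=(trials, b)).mean(axis=1)
    return estimates.var()

v_small, v_large = minibatch_variance(8), minibatch_variance(32)
ratio = v_small / v_large  # quadrupling b should roughly quarter the variance
```

Up to sampling noise, the ratio is close to 32/8 = 4, matching the 1/b scaling that underlies the O(1/√(bT)) rate.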

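Combining the pieces above, AdaSAM feeds the SAM gradient into an adaptive update. The sketch below is a minimal reading of that combination using an AMSGrad-style second-order momentum with a running maximum, so the effective learning rate lr/(√v̂ + ε) has the limited range our analysis relies on; the function name, toy loss, and hyper-parameter values are our own illustration, not the authors' released code.

```python
import numpy as np

def adasam(x, grad_fn, steps=200, lr=0.05, rho=0.05,
           beta1=0.9, beta2=0.99, eps=1e-8):
    """AMSGrad-style adaptive step applied to the SAM gradient."""
    m = np.zeros_like(x)      # first-order momentum
    v = np.zeros_like(x)      # second-order momentum
    v_hat = np.zeros_like(x)  # running max of v (AMSGrad)
    for _ in range(steps):
        g = grad_fn(x)
        delta = rho * g / (np.linalg.norm(g) + 1e-12)  # SAM perturbation
        g_sam = grad_fn(x + delta)                     # gradient at x + delta
        m = beta1 * m + (1 - beta1) * g_sam
        v = beta2 * v + (1 - beta2) * g_sam ** 2
        v_hat = np.maximum(v_hat, v)                   # non-decreasing v_hat
        x = x - lr * m / (np.sqrt(v_hat) + eps)        # adaptive step
    return x

x = adasam(np.array([1.0, -2.0]), lambda z: z)  # toy quadratic loss
```

Because v_hat never decreases, the per-coordinate step size is bounded, which is the property exploited when bounding the coupled SAM and adaptive-learning-rate terms in the analysis.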