PROVABLE SHARPNESS-AWARE MINIMIZATION WITH ADAPTIVE LEARNING RATE

Abstract

The sharpness-aware minimization (SAM) optimizer has been extensively explored, as it can converge fast and train deep neural networks efficiently by introducing an extra perturbation step that flattens the loss landscape of deep learning models. Combining SAM with an adaptive learning rate (AdaSAM) has also been explored to train large-scale deep neural networks, but without theoretical guarantees, due to the dual difficulty of analyzing the perturbation step and the coupled adaptive learning rate. In this paper, we analyze the convergence rate of AdaSAM in the stochastic non-convex setting. We theoretically show that AdaSAM admits an O(1/√(bT)) convergence rate and exhibits a linear speedup property with respect to the mini-batch size b. To the best of our knowledge, we are the first to provide a non-trivial convergence rate for SAM with an adaptive learning rate. To decouple the two stochastic gradient steps from the adaptive learning rate, we introduce a delayed second-order momentum term in the convergence analysis, which decomposes the two steps and makes them independent when taking an expectation. We then bound them by showing that the adaptive learning rate has a limited range, which makes our analysis feasible. Finally, we conduct experiments on several NLP tasks, which show that AdaSAM achieves superior performance compared with the SGD, AMSGrad, and SAM optimizers.
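For context, non-convex convergence rates of this form are typically stated in terms of the averaged (or minimum) expected squared gradient norm over T iterations. The following is a generic statement of such a bound, not the paper's exact theorem; the constants and conditions are illustrative:

```latex
% Generic form of a stochastic non-convex convergence bound with
% mini-batch size b and iteration count T (illustrative, not the
% paper's exact theorem statement):
\[
  \frac{1}{T}\sum_{t=1}^{T} \mathbb{E}\left\|\nabla f(x_t)\right\|^2
  \;\le\; \mathcal{O}\!\left(\frac{1}{\sqrt{bT}}\right).
\]
% Reaching accuracy eps requires bT = Omega(1/eps^2), so doubling the
% mini-batch size b roughly halves the number of iterations T needed:
% this is the linear speedup property with respect to b.
```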



Introduction

Sharpness-aware minimization (SAM) Foret et al. (2021) is a powerful optimizer for training large-scale deep learning models by minimizing the gap between the training performance and the generalization performance. It has achieved remarkable results on various deep neural networks, such as ResNet He et al. (2016), vision transformers Dosovitskiy et al. (2021); Xu et al. (2021), and language models Devlin et al. (2018); He et al. (2020), on extensive benchmarks. However, SAM-type methods suffer from several issues when training deep neural networks, in particular a heavy computation cost and a costly hyper-parameter tuning procedure. Due to the extra perturbation step, SAM needs to compute gradients twice per update compared with classic optimizers such as SGD, Adam Kingma & Ba (2015), and AMSGrad Reddi et al. (2018). Hence, SAM has to forward- and back-propagate twice for one parameter update, incurring double the computation cost of the classic optimizers. Moreover, since the training process consists of two steps, SAM requires twice as many hyper-parameters, which makes learning rate tuning costly and burdensome.

Adaptive learning rate methods scale the gradients based on historical gradient information to accelerate convergence. These methods, such as Adagrad Duchi et al. (2011), Adam, and AMSGrad, have been studied for solving NLP tasks Zhang et al. (2020). However, they might converge to local minima, which leads to poor generalization performance. Many studies have investigated how to improve the generalization ability of adaptive learning rate methods. Recently, several works have tried to ease the learning rate tuning in SAM by combining SAM with an adaptive learning rate. For example, Zhuang et al. (2022) train ViT models with an adaptive method, and Kwon et al. (2021) train an NLP model. Although remarkable performance has been achieved, the convergence of these methods is still unknown, since the adaptive learning rate is used inside SAM. Directly analyzing their convergence is difficult due to the two optimization steps, especially the adaptive learning rate adopted in the second step.
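To make the two-step structure concrete, below is a minimal NumPy sketch of a SAM-style perturbation step combined with an AMSGrad-style adaptive update, i.e., an AdaSAM-like iteration. The function name, default hyper-parameters, and the `grad_fn` interface are illustrative assumptions, not the paper's exact algorithm.

```python
import numpy as np

def adasam_step(w, grad_fn, state, lr=1e-3, rho=0.05,
                beta1=0.9, beta2=0.999, eps=1e-8):
    """One illustrative AdaSAM-style step: a SAM perturbation followed
    by an AMSGrad-style adaptive update.

    `grad_fn(w)` returns a stochastic gradient at `w`; `state` holds the
    first moment `m`, second moment `v`, and its running max `v_hat`.
    """
    # First forward/backward pass: gradient at the current point,
    # used only to build the sharpness-seeking perturbation.
    g = grad_fn(w)
    perturb = rho * g / (np.linalg.norm(g) + 1e-12)

    # Second forward/backward pass: gradient at the perturbed point.
    # This is why SAM costs twice as much as a classic optimizer.
    g_adv = grad_fn(w + perturb)

    # AMSGrad-style adaptive update driven by the perturbed gradient.
    state["m"] = beta1 * state["m"] + (1 - beta1) * g_adv
    state["v"] = beta2 * state["v"] + (1 - beta2) * g_adv ** 2
    state["v_hat"] = np.maximum(state["v_hat"], state["v"])  # AMSGrad max
    return w - lr * state["m"] / (np.sqrt(state["v_hat"]) + eps)
```

On a toy quadratic objective, repeatedly applying this step drives the parameters toward the minimizer, while the `v_hat` maximum keeps the effective learning rate non-increasing per coordinate, which is the AMSGrad trick that keeps the adaptive step size in a bounded range.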

