K-SAM: SHARPNESS-AWARE MINIMIZATION AT THE SPEED OF SGD

Abstract

Sharpness-Aware Minimization (SAM) has recently emerged as a robust technique for improving the accuracy of deep neural networks. However, SAM incurs a high computational cost in practice, requiring up to twice as much computation as vanilla SGD. The computational challenge posed by SAM arises because each iteration requires both ascent and descent steps and thus double the gradient computations. To address this challenge, we propose to compute gradients in both stages of SAM on only the top-k samples with the highest loss. K-SAM is simple and extremely easy to implement while providing significant generalization boosts over vanilla SGD at little to no additional cost.

1. INTRODUCTION

Methods for promoting good generalization are of tremendous value to deep learning practitioners and a point of fascination for deep learning theorists. While machine learning models can easily achieve perfect training accuracy on any dataset given enough parameters, it is unclear when a model will generalize well to test data. Improving the ability of models to extrapolate knowledge learned during training to held-out test samples is the key to performing well in the wild. Recently, a line of work has argued that the geometry of the loss landscape is a major contributor to the generalization performance of deep learning models. A number of researchers have argued that flatter minima lead to models that generalize better (Keskar et al., 2016; Xing et al., 2018; Jiang et al., 2019; Smith et al., 2021). Underlying this work is the intuition that small changes in parameters yield perturbations to decision boundaries, so that flat minima yield wide-margin decision boundaries (Huang et al., 2020).

Motivated by these investigations, Foret et al. (2020) propose an effective algorithm, Sharpness-Aware Minimization (SAM), to optimize models toward flatter minima and better generalization performance. The algorithm performs one-step adversarial training in parameter space, finding a loss-function minimum that is "flat" in the sense that perturbing the network parameters in worst-case directions still yields low training loss. This simple concept achieves impressive performance on a wide variety of tasks. For example, Foret et al. (2020) achieved notable improvements on various benchmark vision datasets (e.g., CIFAR-10, ImageNet) by simply swapping out the optimizer. Later, Chen et al. (2021) found that SAM improves the sample complexity and performance of vision transformer models, making these transformers competitive with ResNets even without pre-training.
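Concretely, each SAM iteration first takes an ascent step to a worst-case perturbation of the weights within a radius rho, then descends using the gradient computed at that perturbed point. The following is a minimal sketch of this two-pass update on a toy quadratic loss with an analytic gradient; the rho and lr values are illustrative choices, not taken from the paper.

```python
import numpy as np

# Toy loss L(w) = 0.5 * w^T A w with analytic gradient A @ w.
# In practice each SAM step costs two forward/backward passes;
# rho and lr below are illustrative values only.

def grad(w, A):
    return A @ w

def sam_step(w, A, rho=0.05, lr=0.05):
    g = grad(w, A)                                 # first pass: gradient at w
    eps = rho * g / (np.linalg.norm(g) + 1e-12)    # ascent: worst-case perturbation
    g_adv = grad(w + eps, A)                       # second pass: gradient at w + eps
    return w - lr * g_adv                          # descent with perturbed gradient

A = np.diag([1.0, 10.0])
w = np.array([1.0, 1.0])
for _ in range(200):
    w = sam_step(w, A)
```

The two calls to `grad` make explicit why SAM roughly doubles the per-step cost relative to SGD, which needs only the first call.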
Moreover, further innovations to the SAM setup, such as modifying the radius of the adversarial step to be invariant to parameter re-scaling (Kwon et al., 2021), yield additional improvements to generalization on vision tasks. In addition, Bahri et al. (2021) recently found that SAM not only works in the vision domain but also improves the performance of language models on GLUE, SuperGLUE, Web Questions, Natural Questions, Trivia QA, and TyDiQA (Wang et al., 2018; 2019; Joshi et al., 2017; Clark et al., 2020). Despite the simplicity of SAM, the improved performance comes at the steep cost of twice as much compute, given that SAM requires two forward and backward passes for each optimization step: one for the ascent step and another for the descent step. This additional cost may make SAM too expensive for widespread adoption by practitioners, motivating studies to decrease its computational cost. For example, Efficient SAM (Du et al., 2021) decreases the computational cost of SAM by using only the examples with the largest increase in loss for the descent step. Bahri et al. (2021) randomly select a subset of examples for the ascent step, making that step faster. However, both methods still require a full forward and backward pass for either the ascent or descent step.
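In contrast, the idea described in the abstract restricts both gradient computations to the k samples with the highest loss. The following is a hypothetical sketch of such a top-k update on a toy linear-regression problem; the selection rule, k, rho, and lr are illustrative assumptions, not the paper's exact implementation.

```python
import numpy as np

# Toy per-sample losses loss_i(w) = 0.5 * (x_i @ w - y_i)^2 for a linear model.
# k, rho, and lr are illustrative values, not taken from the paper.

def per_sample_losses(w, X, y):
    r = X @ w - y
    return 0.5 * r ** 2

def grad_on_subset(w, X, y, idx):
    # mean gradient over the selected samples only
    r = X[idx] @ w - y[idx]
    return X[idx].T @ r / len(idx)

def ksam_step(w, X, y, k=4, rho=0.05, lr=0.1):
    # 1. select the k samples with the highest loss (one cheap forward pass)
    losses = per_sample_losses(w, X, y)
    idx = np.argsort(losses)[-k:]
    # 2. ascent step computed on the top-k subset only
    g = grad_on_subset(w, X, y, idx)
    eps = rho * g / (np.linalg.norm(g) + 1e-12)
    # 3. descent step, also on the top-k subset only
    g_adv = grad_on_subset(w + eps, X, y, idx)
    return w - lr * g_adv

rng = np.random.default_rng(0)
X = rng.normal(size=(32, 5))
w_true = rng.normal(size=5)
y = X @ w_true
w = np.zeros(5)
for _ in range(200):
    w = ksam_step(w, X, y)
```

Because both backward passes touch only k of the batch's samples, the per-step gradient cost shrinks roughly in proportion to k over the batch size.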

