K-SAM: SHARPNESS-AWARE MINIMIZATION AT THE SPEED OF SGD

Abstract

Sharpness-Aware Minimization (SAM) has recently emerged as a robust technique for improving the accuracy of deep neural networks. However, SAM incurs a high computational cost in practice, requiring up to twice as much computation as vanilla SGD. The computational challenge posed by SAM arises because each iteration requires both ascent and descent steps, and thus double the gradient computations. To address this challenge, we propose to compute gradients in both stages of SAM on only the top-k samples with the highest loss. K-SAM is simple and extremely easy to implement while providing significant generalization boosts over vanilla SGD at little to no additional cost.

1. INTRODUCTION

Methods for promoting good generalization are of tremendous value to deep learning practitioners and a point of fascination for deep learning theorists. While machine learning models can easily achieve perfect training accuracy on any dataset given enough parameters, it is unclear when a model will generalize well to test data. Improving the ability of models to extrapolate their knowledge, learned during training, to held-out test samples is the key to performing well in the wild. Recently, a line of work has argued that the geometry of the loss landscape is a major contributor to the generalization performance of deep learning models. A number of researchers have argued that flatter minima lead to models that generalize better (Keskar et al., 2016; Xing et al., 2018; Jiang et al., 2019; Smith et al., 2021). Underlying this work is the intuition that small changes in parameters yield perturbations to decision boundaries, so that flat minima yield wide-margin decision boundaries (Huang et al., 2020). Motivated by these investigations, Foret et al. (2020) propose an effective algorithm, Sharpness-Aware Minimization (SAM), to optimize models toward flatter minima and better generalization performance. The algorithm performs one-step adversarial training in parameter space, finding a loss-function minimum that is "flat" in the sense that perturbations of the network parameters in worst-case directions still yield low training loss. This simple concept achieves impressive performance on a wide variety of tasks. For example, Foret et al. (2020) achieved notable improvements on various benchmark vision datasets (e.g., CIFAR-10, ImageNet) by simply swapping out the optimizer. Later, Chen et al. (2021) found that SAM improves the sample complexity and performance of vision transformers, making them competitive with ResNets even without pre-training.
Moreover, further innovations to the SAM setup, such as modifying the radius of the adversarial step to be invariant to parameter re-scaling (Kwon et al., 2021), yield additional improvements to generalization on vision tasks. In addition, Bahri et al. (2021) recently found that SAM not only works in the vision domain but also improves the performance of language models on GLUE, SuperGLUE, Web Questions, Natural Questions, Trivia QA, and TyDiQA (Wang et al., 2018; 2019; Joshi et al., 2017; Clark et al., 2020). Despite the simplicity of SAM, the improved performance comes at the steep cost of twice as much compute, since SAM requires two forward and backward passes for each optimization step: one for the ascent step and another for the descent step. This additional cost may make SAM too expensive for widespread adoption by practitioners, motivating studies that decrease its computational cost. For example, Efficient SAM (Du et al., 2021) reduces the cost of SAM by using only the examples with the largest increases in loss for the descent step, while Bahri et al. (2021) randomly select a subset of examples for the ascent step, making that step faster. However, both methods still require a full forward and backward pass over all samples in the batch for either the ascent or the descent step, so they remain more expensive than vanilla SGD. In comparison to these works, we develop a version of SAM that is as fast as SGD, so that practitioners can adopt our variant at no additional cost. In this paper, we propose K-SAM, a simple modification to SAM that reduces its computational cost to that of vanilla SGD while achieving performance comparable to the original SAM optimizer. K-SAM exploits two observations: a small subset of training examples is sufficient for both gradient computation steps (Du et al., 2021; Fan et al., 2017; Bahri et al., 2021), and the examples with the largest losses dominate the average gradient over a large batch.
To decrease computational cost, we use only the K examples with the largest loss values in each training batch for both gradient computations. When K is chosen properly, K-SAM can be as fast as vanilla SGD while improving generalization by seeking flat minima, as SAM does. We empirically verify the effectiveness of our approach across datasets and models. We demonstrate that a small number of high-loss samples produces gradients that are well aligned with the gradient of the full batch loss. Moreover, we show that our method can achieve performance comparable to the original (non-accelerated) SAM on vision tasks, such as image classification on CIFAR-{10, 100}, and on language models evaluated on the GLUE benchmark, while keeping the training cost roughly as low as that of vanilla SGD. On the other hand, we observe that on large-scale many-class image classification tasks, such as ImageNet, the gradients within a batch are broadly distributed, so that a very small number of highest-loss samples is not representative of the batch. This phenomenon makes it hard to achieve training efficiency and strong generalization simultaneously by subsampling. Nonetheless, we show that K-SAM can achieve generalization comparable to SAM at around 65% of the training cost.
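The selection rule above can be sketched in a few lines: one cheap forward pass scores every example in the batch, and only the K examples with the largest losses are kept for both SAM gradient computations. The following is a minimal illustration on a toy linear model; the model and helper names (`per_example_loss`, `top_k_batch`) are ours, not from the paper.

```python
import numpy as np

def per_example_loss(w, X, y):
    # Logistic loss of a linear model, one value per example (labels y in {-1, +1}).
    return np.log1p(np.exp(-y * (X @ w)))

def top_k_batch(w, X, y, k):
    # Score the whole batch once, then keep only the k highest-loss examples.
    losses = per_example_loss(w, X, y)
    idx = np.argsort(losses)[-k:]        # indices of the k largest losses
    return X[idx], y[idx]

rng = np.random.default_rng(0)
X = rng.normal(size=(128, 10))
y = rng.choice([-1.0, 1.0], size=128)
w = rng.normal(size=10)

# Both the ascent and descent gradients of SAM would then be computed
# on (X_k, y_k) only, instead of the full 128-example batch.
X_k, y_k = top_k_batch(w, X, y, k=16)
```

With K well below the batch size, the two gradient passes of SAM touch roughly as many examples in total as one full-batch SGD pass, which is the source of the claimed speedup.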

2.1. SHARPNESS-AWARE MINIMIZATION

In this section, we briefly introduce how SAM simultaneously minimizes the loss and decreases loss sharpness, and we detail why it requires additional gradient computations that in turn double the training cost. Instead of finding parameters that merely yield low training loss, SAM attempts to find a parameter vector whose entire neighborhood possesses uniformly low loss, thus leading to a flat minimum. Formally, SAM achieves this goal by solving the mini-max optimization problem

$$\min_w \max_{\|\epsilon\|_2 \le \rho} L_S(w + \epsilon), \tag{1}$$

where $w$ are the parameters of the neural network, $L$ is the loss function, $S$ is the training set, and $\epsilon$ is a small perturbation within an $\ell_2$ ball of radius $\rho$. In order to solve the outer minimization problem, SAM applies a first-order approximation to the inner maximization problem:

$$\epsilon^* = \arg\max_{\|\epsilon\|_2 \le \rho} L_S(w + \epsilon) \approx \arg\max_{\|\epsilon\|_2 \le \rho} L_S(w) + \epsilon^T \nabla_w L_S(w) = \rho\, \nabla_w L_S(w) / \|\nabla_w L_S(w)\|_2. \tag{2}$$

After computing the approximate maximizer $\epsilon^*$, SAM obtains the "sharpness-aware" gradient for the descent step:

$$\nabla L_S^{SAM}(w) \approx \nabla_w L_S(w)\big|_{w + \epsilon^*}. \tag{3}$$

In short, given a base optimizer, e.g., SGD, instead of computing the gradient at the current parameter vector $w$, SAM updates the model parameters using the gradient with respect to the perturbed model at $w + \epsilon^*$. Therefore a SAM update requires two forward and backward passes on each sample in a batch: a gradient ascent step to obtain the perturbation $\epsilon^*$ and a gradient descent step to update the current model, which doubles the training time relative to the base optimizer.
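The two-pass structure of a SAM update can be made concrete with a small sketch on a toy logistic-regression model. This is an illustration of Eqs. (2) and (3) under our own choices of model, function names, and hyperparameters (`lr`, `rho`), not the paper's implementation.

```python
import numpy as np

def loss_and_grad(w, X, y):
    # Mean logistic loss L_S(w) and its gradient, labels y in {-1, +1}.
    logits = X @ w
    p = 1.0 / (1.0 + np.exp(y * logits))          # sigmoid(-y * logits)
    loss = np.mean(np.log1p(np.exp(-y * logits)))
    grad = -(X.T @ (y * p)) / len(y)
    return loss, grad

def sam_step(w, X, y, lr=0.5, rho=0.05):
    # Pass 1 (ascent, Eq. 2): normalized gradient gives the worst-case
    # perturbation eps on the l2 ball of radius rho.
    _, g = loss_and_grad(w, X, y)
    eps = rho * g / (np.linalg.norm(g) + 1e-12)
    # Pass 2 (descent, Eq. 3): gradient evaluated at the perturbed point w + eps.
    _, g_sam = loss_and_grad(w + eps, X, y)
    return w - lr * g_sam

# Toy separable data so that training visibly reduces the loss.
rng = np.random.default_rng(0)
X = rng.normal(size=(64, 5))
y = np.sign(X @ rng.normal(size=5))

w = np.zeros(5)
for _ in range(100):
    w = sam_step(w, X, y)
```

Each call to `sam_step` invokes `loss_and_grad` twice on the full batch, which is exactly the 2x cost over SGD that K-SAM targets by shrinking the batch used in both calls.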

2.2. EFFICIENT SHARPNESS-AWARE MINIMIZATION

Recently, several works have improved the efficiency of SAM while retaining its performance benefits. Bahri et al. (2021) improve the efficiency of SAM by reducing the computational cost of the gradient

