HOW DOES SHARPNESS-AWARE MINIMIZATION MINIMIZE SHARPNESS?

Abstract

Sharpness-Aware Minimization (SAM) is a highly effective regularization technique for improving the generalization of deep neural networks in various settings. However, the underlying workings of SAM remain elusive because of various intriguing approximations in its theoretical characterizations. SAM intends to penalize one notion of sharpness of the model but implements a computationally efficient variant; moreover, a third notion of sharpness is used for proving generalization guarantees. The subtle differences between these notions of sharpness can indeed lead to significantly different empirical results. This paper rigorously nails down the exact sharpness notion that SAM regularizes and clarifies the underlying mechanism. We also show that the two steps of approximation in the original motivation of SAM individually lead to inaccurate local conclusions, but their combination accidentally reveals the correct effect when full-batch gradients are applied. Furthermore, we prove that the stochastic version of SAM in fact regularizes the third notion of sharpness mentioned above, which is most likely the preferred notion for practical performance. The key mechanism behind this intriguing phenomenon is the alignment between the gradient and the top eigenvector of the Hessian when SAM is applied.

1. INTRODUCTION

Modern deep nets are often overparametrized and have the capacity to fit even randomly labeled data (Zhang et al., 2016). Thus, a small training loss does not necessarily imply good generalization. Yet, standard gradient-based training algorithms such as SGD are able to find generalizable models. Recent empirical and theoretical studies suggest that generalization is well-correlated with the sharpness of the loss landscape at the learned parameter (Keskar et al., 2016; Dinh et al., 2017; Dziugaite et al., 2017; Neyshabur et al., 2017; Jiang et al., 2019). Partly motivated by these studies, Foret et al. (2021); Wu et al. (2020); Zheng et al. (2021); Norton et al. (2021) propose to penalize the sharpness of the landscape to improve generalization. We refer to this method as Sharpness-Aware Minimization (SAM) and focus on the version of Foret et al. (2021). Despite its empirical success, the underlying workings of SAM remain elusive because of the various intriguing approximations made in its derivation and analysis. Three different notions of sharpness are involved: SAM intends to optimize the first notion, the sharpness along the worst direction, but actually implements a computationally efficient surrogate, the sharpness along the direction of the gradient. Moreover, the analysis of generalization uses a third notion of sharpness to prove generalization guarantees, which admits the first notion as an upper bound. The subtle differences between the three notions can lead to very different biases (see Figure 1 for a demonstration). More concretely, let $L$ be the training loss, $x$ be the parameter, and $\rho$ be the perturbation radius, a hyperparameter requiring tuning. The first notion corresponds to the optimization problem (1) below, where we call $R^{\mathrm{Max}}_\rho(x) = L^{\mathrm{Max}}_\rho(x) - L(x)$ the worst-direction sharpness at $x$. SAM intends to minimize the original training loss plus the worst-direction sharpness at $x$.
$$\min_x L^{\mathrm{Max}}_\rho(x), \quad \text{where } L^{\mathrm{Max}}_\rho(x) = \max_{\|v\|_2 \le 1} L(x + \rho v). \tag{1}$$

However, even evaluating $L^{\mathrm{Max}}_\rho(x)$ is computationally expensive, not to mention minimizing it. Thus, Foret et al. (2021); Zheng et al. (2021) introduce a second notion of sharpness, which approximates the worst-case direction in (1) by the direction of the gradient, as defined in (2). We call $R^{\mathrm{Asc}}_\rho(x) = L^{\mathrm{Asc}}_\rho(x) - L(x)$ the ascent-direction sharpness at $x$.

$$\min_x L^{\mathrm{Asc}}_\rho(x), \quad \text{where } L^{\mathrm{Asc}}_\rho(x) = L\left(x + \rho \nabla L(x) / \|\nabla L(x)\|_2\right). \tag{2}$$

To compute the gradient of $L^{\mathrm{Asc}}_\rho$ efficiently, Foret et al. (2021); Zheng et al. (2021) omit the gradient through the other occurrence of $x$ and approximate the gradient of ascent-direction sharpness by the gradient taken after a one-step ascent, i.e., $\nabla L^{\mathrm{Asc}}_\rho(x) \approx \nabla L\left(x + \rho \nabla L(x)/\|\nabla L(x)\|_2\right)$, which yields the update rule of Sharpness-Aware Minimization (SAM), where $\eta$ is the learning rate:

$$x(t+1) = x(t) - \eta \nabla L\left(x(t) + \rho \nabla L(x(t)) / \|\nabla L(x(t))\|_2\right). \tag{3}$$

Intriguingly, the generalization bound of SAM upper-bounds the generalization error by a third notion of sharpness, called the average-direction sharpness, $R^{\mathrm{Avg}}_\rho(x)$, defined formally below:

$$R^{\mathrm{Avg}}_\rho(x) = L^{\mathrm{Avg}}_\rho(x) - L(x), \quad \text{where } L^{\mathrm{Avg}}_\rho(x) = \mathbb{E}_{g \sim \mathcal{N}(0, I)} L\left(x + \rho g / \|g\|_2\right).$$

The worst-direction sharpness is an upper bound of the average-direction sharpness, and is thus a looser bound for the generalization error. In other words, the generalization theory in Foret et al. (2021); Wu et al. (2020) in fact motivates directly minimizing the average-direction sharpness (as opposed to the worst-direction sharpness that SAM intends to optimize). The three notions and their biases are summarized in Table 1:

Table 1:
Type of loss      | Notation                  | Definition                                                    | Bias (among minimizers)
Worst-direction   | $L^{\mathrm{Max}}_\rho$   | $\max_{\|v\|_2 \le 1} L(x + \rho v)$                          | $\min_x \lambda_1(\nabla^2 L(x))$ (Thm G.3)
Ascent-direction  | $L^{\mathrm{Asc}}_\rho$   | $L(x + \rho \nabla L(x)/\|\nabla L(x)\|_2)$                   | $\min_x \lambda_{\min}(\nabla^2 L(x))$ (Thm G.4)
Average-direction | $L^{\mathrm{Avg}}_\rho$   | $\mathbb{E}_{g \sim \mathcal{N}(0,I)} L(x + \rho g/\|g\|_2)$  | $\min_x \operatorname{Tr}(\nabla^2 L(x))$ (Thm G.5)

In this paper, we analyze the biases introduced by penalizing these various notions of sharpness, as well as the bias of SAM (Equation 3).
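To make the update rule (3) concrete, the following is a minimal NumPy sketch of one full-batch SAM step on a toy quadratic loss. The gradient oracle `grad_fn`, the quadratic model, and the hyperparameter values are illustrative assumptions for this sketch, not the implementation used in the experiments.

```python
import numpy as np

def sam_step(x, grad_fn, rho=0.05, lr=0.1, eps=1e-12):
    """One full-batch SAM update (Equation 3): take a normalized ascent step
    of radius rho, then descend using the gradient at the perturbed point."""
    g = grad_fn(x)
    x_asc = x + rho * g / (np.linalg.norm(g) + eps)  # ascent to x + rho * g / ||g||
    return x - lr * grad_fn(x_asc)                   # descend with the perturbed gradient

# Toy quadratic loss L(x) = x^T H x / 2 with Hessian H = diag(3, 1).
H = np.diag([3.0, 1.0])
loss = lambda x: 0.5 * x @ H @ x
grad = lambda x: H @ x

x = np.array([1.0, 1.0])
for _ in range(50):
    x = sam_step(x, grad)
```

Note that, because the descent direction is evaluated at the perturbed point, the iterates hover near the minimizer at a scale governed by $\eta$ and $\rho$ rather than converging exactly, which foreshadows why the dynamics of (3) differ from plain minimization of (2).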
Our analysis of SAM is performed for a small perturbation radius ρ and learning rate η, under the setting where the minimizers of the loss form a manifold, following the setup of Fehrman et al. (2020); Li et al. (2021). In particular, we make the following theoretical contributions.

1. We prove that full-batch SAM indeed minimizes worst-direction sharpness (Theorem 4.5).
2. Surprisingly, when the batch size is 1, SAM minimizes average-direction sharpness (Theorem 5.4).
3. We provide a characterization (Theorems 4.2 and 5.3) of what several sharpness regularizers (including all three notions of sharpness in Table 1) bias towards among the minimizers, as the perturbation radius ρ goes to zero.

Surprisingly, both heuristic approximations made for SAM individually lead to inaccurate conclusions: (1) minimizing worst-direction sharpness and minimizing ascent-direction sharpness induce different biases among minimizers, and (2) SAM does not minimize ascent-direction sharpness. The key mechanism behind the bias of SAM is the alignment between the gradient and the top eigenspace of the Hessian of the original loss in the later phase of training: the angle between them gradually decreases to the level of O(ρ). The worst-direction sharpness starts to decrease once such alignment is established (see Section 4.3). Interestingly, this alignment is not implied by the minimization problem (2); rather, it is an implicit property of the specific update rule of SAM. The alignment property holds for full-batch SAM and for SAM with batch size one, but does not necessarily hold in the general mini-batch case.
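The biases in Table 1 can be sanity-checked numerically on a toy quadratic: at a minimizer of $L(x) = \frac{1}{2} x^\top H x$, the worst-direction sharpness equals $\rho^2 \lambda_1 / 2$, while the average-direction sharpness concentrates around $\rho^2 \operatorname{Tr}(H) / (2d)$. The sketch below only illustrates these two identities; the Hessian, dimension, and sample size are our choices for the demonstration, not values from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
d, rho = 5, 1e-2
H = np.diag([4.0, 3.0, 2.0, 1.0, 0.5])  # Hessian at the minimizer x = 0

# Worst-direction sharpness: R^Max = rho^2 * lambda_1 / 2, attained along
# the top eigenvector of H.
R_max = 0.5 * rho**2 * 4.0

# Average-direction sharpness: Monte Carlo average of L(rho * g/||g||)
# over uniformly random directions g/||g|| (L(0) = 0, so loss = sharpness).
g = rng.standard_normal((200_000, d))
u = g / np.linalg.norm(g, axis=1, keepdims=True)
R_avg = np.mean(0.5 * rho**2 * np.sum(u**2 * np.diag(H), axis=1))

# Theory predicts R_avg ~ rho^2 * Tr(H) / (2 d), since E[u_i^2] = 1/d.
```

Because $R^{\mathrm{Max}}$ depends only on $\lambda_1$ while $R^{\mathrm{Avg}}$ depends on the full trace, minimizing the two over a manifold of minimizers selects different points whenever the spectra trade off, which is exactly the discrepancy the theorems quantify.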



Sharpness and Generalization. The study of the connection between sharpness and generalization can be traced back to Hochreiter et al. (1997). Keskar et al. (2016) observe a positive correlation between the batch size, the generalization error, and the sharpness of the loss landscape. Jastrzebski et al. (2017) extend this by finding a correlation between the sharpness and the ratio of learning rate to batch size. Dinh et al. (2017) show that one can easily construct networks with good generalization but arbitrarily large sharpness via reparametrization. Dziugaite et al. (2017); Neyshabur et al. (2017); Wei et al. (2019a;b) give theoretical guarantees on the generalization error using sharpness-related measures. Jiang et al. (2019) perform a large-scale empirical study on various generalization measures and show that sharpness-based measures have the highest correlation with generalization.

Background on Sharpness-Aware Minimization. Foret et al. (2021); Zheng et al. (2021) concurrently propose to minimize the loss at a point perturbed from the current parameter in the worst direction to improve generalization. Wu et al. (2020) propose an almost identical method for a different purpose, the robust generalization of adversarial training. Kwon et al. (2021) propose a different metric for SAM to fix the rescaling problem pointed out by Dinh et al. (2017). Liu et al. (2022) propose a more computationally efficient version of SAM.

Table 1: Definitions and biases of different notions of sharpness-aware loss. The corresponding sharpness is defined as the difference between the sharpness-aware loss and the original loss. Here $\lambda_1$ denotes the largest eigenvalue and $\lambda_{\min}$ the smallest non-zero eigenvalue of the Hessian $\nabla^2 L(x)$.

