HOW DOES SHARPNESS-AWARE MINIMIZATION MINIMIZE SHARPNESS?

Abstract

Sharpness-Aware Minimization (SAM) is a highly effective regularization technique for improving the generalization of deep neural networks in various settings. However, the underlying working of SAM remains elusive because of various intriguing approximations in its theoretical characterizations. SAM intends to penalize a notion of sharpness of the model but implements a computationally efficient variant; moreover, a third notion of sharpness was used for proving generalization guarantees. The subtle differences between these notions of sharpness can indeed lead to significantly different empirical results. This paper rigorously nails down the exact sharpness notion that SAM regularizes and clarifies the underlying mechanism. We also show that the two steps of approximation in the original motivation of SAM individually lead to inaccurate local conclusions, but their combination accidentally reveals the correct effect, when full-batch gradients are applied. Furthermore, we also prove that the stochastic version of SAM in fact regularizes the third notion of sharpness mentioned above, which is most likely to be the preferred notion for practical performance. The key mechanism behind this intriguing phenomenon is the alignment between the gradient and the top eigenvector of the Hessian when SAM is applied.

1. INTRODUCTION

Modern deep nets are often overparametrized and have the capacity to fit even randomly labeled data (Zhang et al., 2016). Thus, a small training loss does not necessarily imply good generalization. Yet, standard gradient-based training algorithms such as SGD are able to find generalizable models. Recent empirical and theoretical studies suggest that generalization is well-correlated with the sharpness of the loss landscape at the learned parameter (Keskar et al., 2016; Dinh et al., 2017; Dziugaite et al., 2017; Neyshabur et al., 2017; Jiang et al., 2019). Partly motivated by these studies, Foret et al. (2021); Wu et al. (2020); Zheng et al. (2021); Norton et al. (2021) propose to penalize the sharpness of the landscape to improve generalization. We refer to this method as Sharpness-Aware Minimization (SAM) and focus on the version of Foret et al. (2021).

Despite its empirical success, the underlying working of SAM remains elusive because of the various intriguing approximations made in its derivation and analysis. There are three different notions of sharpness involved: SAM intends to optimize the first notion, the sharpness along the worst direction, but actually implements a computationally efficient notion, the sharpness along the direction of the gradient. In the analysis of generalization, however, a third notion of sharpness is used to prove generalization guarantees, which admits the first notion as an upper bound. The subtle differences between the three notions can lead to very different biases (see Figure 1 for a demonstration).

More concretely, let $L$ be the training loss, $x$ be the parameter, and $\rho$ be the perturbation radius, a hyperparameter requiring tuning. The first notion corresponds to the following optimization problem (1), where we call $R^{\mathrm{Max}}_\rho(x) = L^{\mathrm{Max}}_\rho(x) - L(x)$ the worst-direction sharpness at $x$. SAM intends to minimize the original training loss plus the worst-direction sharpness at $x$:

$$\min_x L^{\mathrm{Max}}_\rho(x), \quad \text{where } L^{\mathrm{Max}}_\rho(x) = \max_{\|v\|_2 \le 1} L(x + \rho v). \tag{1}$$

However, even evaluating $L^{\mathrm{Max}}_\rho(x)$ is computationally expensive, not to mention optimizing it. Thus Foret et al. (2021); Zheng et al. (2021) have introduced a second notion of sharpness, which approximates the worst-case direction in (1) by the direction of the gradient, as defined below in (2). We call $R^{\mathrm{Asc}}_\rho(x) = L^{\mathrm{Asc}}_\rho(x) - L(x)$ the ascent-direction sharpness at $x$:

$$\min_x L^{\mathrm{Asc}}_\rho(x), \quad \text{where } L^{\mathrm{Asc}}_\rho(x) = L\left(x + \rho \nabla L(x) / \|\nabla L(x)\|_2\right). \tag{2}$$
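To make the gap between the two notions concrete, the following sketch evaluates both sharpness quantities on a toy two-dimensional quadratic loss with an ill-conditioned Hessian. The loss, the point `x`, and the radius `rho` are illustrative choices, not from the paper; the worst direction in (1) is approximated by a dense sweep over unit directions, which is only feasible in 2D.

```python
import numpy as np

# Toy quadratic loss L(x) = 0.5 * x^T H x with an ill-conditioned
# (diagonal) Hessian; all values here are illustrative.
H = np.diag([10.0, 0.1])

def loss(x):
    return 0.5 * x @ H @ x

def grad(x):
    return H @ x

rho = 0.1                      # perturbation radius
x = np.array([0.3, 2.0])       # an arbitrary evaluation point

# Ascent-direction sharpness R_asc (eq. 2): perturb along the
# normalized gradient -- the direction SAM actually uses.
g = grad(x)
R_asc = loss(x + rho * g / np.linalg.norm(g)) - loss(x)

# Worst-direction sharpness R_max (eq. 1): maximize over all unit
# directions, approximated here by a fine angular sweep (2D only).
thetas = np.linspace(0.0, 2.0 * np.pi, 10000, endpoint=False)
dirs = np.stack([np.cos(thetas), np.sin(thetas)], axis=1)
R_max = max(loss(x + rho * v) for v in dirs) - loss(x)

# The gradient direction is one candidate among all unit directions,
# so the ascent-direction sharpness can never exceed the worst-case one.
assert R_asc <= R_max + 1e-6
print(f"R_max = {R_max:.4f}, R_asc = {R_asc:.4f}")
```

By construction $R^{\mathrm{Asc}}_\rho(x) \le R^{\mathrm{Max}}_\rho(x)$; how large the gap is depends on how well the gradient aligns with the top eigenvector of the Hessian, which is exactly the alignment property discussed later in the paper.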

