AN ADAPTIVE POLICY TO EMPLOY SHARPNESS-AWARE MINIMIZATION

Abstract

Sharpness-aware minimization (SAM), which searches for flat minima via min-max optimization, has been shown to improve model generalization. However, since each SAM update requires computing two gradients, its computational cost and training time are both doubled compared to standard empirical risk minimization (ERM). Recent state-of-the-art methods accelerate SAM by reducing the fraction of SAM updates, switching between SAM and ERM updates randomly or periodically. In this paper, we design an adaptive policy to employ SAM based on the loss landscape geometry. Two efficient algorithms, AE-SAM and AE-LookSAM, are proposed. We theoretically show that AE-SAM has the same convergence rate as SAM. Experimental results on various datasets and architectures demonstrate the efficiency and effectiveness of the adaptive policy.

1. INTRODUCTION

Despite great success in many applications (He et al., 2016; Zagoruyko & Komodakis, 2016; Han et al., 2017), deep networks are often over-parameterized and capable of memorizing all training data. The training loss landscape is complex and nonconvex, with many local minima of different generalization abilities. Many studies have investigated the relationship between the loss surface's geometry and generalization performance (Hochreiter & Schmidhuber, 1994; McAllester, 1999; Keskar et al., 2017; Neyshabur et al., 2017; Jiang et al., 2020), and found that flatter minima generalize better than sharper ones (Dziugaite & Roy, 2017; Petzka et al., 2021; Chaudhari et al., 2017; Keskar et al., 2017; Jiang et al., 2020).

Sharpness-aware minimization (SAM) (Foret et al., 2021) is the current state-of-the-art for seeking flat minima by solving a min-max optimization problem. In the SAM algorithm, each update consists of two forward-backward computations: one for computing the perturbation and the other for computing the actual update direction. Since these two computations are not parallelizable, SAM doubles the computational overhead as well as the training time compared to empirical risk minimization (ERM).

Several algorithms (Du et al., 2022a; Zhao et al., 2022b; Liu et al., 2022) have been proposed to improve the efficiency of SAM. ESAM (Du et al., 2022a) uses fewer samples to compute the gradients and updates fewer parameters, but each update still requires two gradient computations; thus, ESAM does not alleviate the training-speed bottleneck. Instead of using the SAM update at every iteration, recent state-of-the-art methods (Zhao et al., 2022b; Liu et al., 2022) use SAM randomly or periodically. Specifically, SS-SAM (Zhao et al., 2022b) selects SAM or ERM according to a Bernoulli trial, while LookSAM (Liu et al., 2022) employs SAM at every k-th step. Though more efficient, the random or periodic use of SAM is suboptimal as it is not geometry-aware.
Intuitively, the SAM update is more useful in sharp regions than in flat regions. In this paper, we propose an adaptive policy to employ SAM based on the geometry of the loss landscape: the SAM update is used when the model is in a sharp region, while the ERM update is used in flat regions to reduce the fraction of SAM updates. To measure sharpness, we use the squared stochastic gradient norm and model it by a normal distribution, whose parameters are estimated by exponential moving average. Experimental results on standard benchmark datasets demonstrate the superiority of the proposed policy.

Our contributions are summarized as follows:
(i) We propose an adaptive policy to use the SAM or ERM update based on the loss landscape geometry.
(ii) We propose an efficient algorithm, called AE-SAM (Adaptive policy to Employ SAM), to reduce the fraction of SAM updates. We also theoretically study its convergence rate.
(iii) The proposed policy is general and can be combined with any SAM variant. In this paper, we integrate it with LookSAM (Liu et al., 2022) and propose AE-LookSAM.
(iv) Experimental results on various network architectures and datasets (with and without label noise) verify the superiority of AE-SAM and AE-LookSAM over existing baselines.

Notations. Vectors (e.g., x) and matrices (e.g., X) are denoted by lowercase and uppercase boldface letters, respectively. For a vector x, its ℓ₂-norm is ‖x‖. N(µ, σ²) is the univariate normal distribution with mean µ and variance σ². diag(x) constructs a diagonal matrix with x on the diagonal. Moreover, I_A(x) denotes the indicator function for a given set A, i.e., I_A(x) = 1 if x ∈ A, and 0 otherwise.
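The idea of tracking the squared stochastic gradient norm with an exponential moving average and triggering SAM when it looks unusually large can be sketched as follows. This is a minimal illustration, not the paper's exact algorithm: the EMA update form, the threshold µ + c·σ, and the toy stream of gradient norms are all assumptions made for this example.

```python
import numpy as np

def ema_update(mu, var, value, beta=0.9):
    """Exponential moving average estimates of the mean and variance
    of a scalar stream (here: squared stochastic gradient norms)."""
    mu_new = beta * mu + (1 - beta) * value
    var_new = beta * var + (1 - beta) * (value - mu_new) ** 2
    return mu_new, var_new

def use_sam(sq_grad_norm, mu, var, c=1.0):
    """Employ the SAM update when the squared gradient norm is large
    relative to its running distribution (i.e., the region looks sharp);
    otherwise fall back to the cheaper ERM update."""
    return bool(sq_grad_norm >= mu + c * np.sqrt(var))

# Toy stream of squared gradient norms: mostly flat region, one sharp spike.
mu, var = 1.0, 0.1
decisions = []
for g2 in [1.0, 0.9, 1.1, 5.0, 1.0]:
    decisions.append(use_sam(g2, mu, var))
    mu, var = ema_update(mu, var, g2)
print(decisions)  # → [False, False, False, True, False]
```

Only the sharp spike (5.0) triggers a SAM update; the remaining iterations use ERM, which is how the fraction of expensive SAM steps is reduced.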

2. RELATED WORK

We are given a training set D with i.i.d. samples {(x_i, y_i) : i = 1, ..., n}. Let f(x; w) be a model parameterized by w. Its empirical risk on D is L(D; w) = (1/n) Σ_{i=1}^n ℓ(f(x_i; w), y_i), where ℓ(·, ·) is a loss (e.g., the cross-entropy loss for classification). Model training aims to learn a model from the training data that generalizes well on the test data.

Generalization and Flat Minima. The connection between model generalization and loss landscape geometry has been theoretically and empirically studied in (Keskar et al., 2017; Dziugaite & Roy, 2017; Jiang et al., 2020). Recently, Jiang et al. (2020) conducted large-scale experiments and found that sharpness-based measures (flatness) are related to the generalization of minimizers. Although flatness can be characterized by the Hessian's eigenvalues (Keskar et al., 2017; Dinh et al., 2017), handling the Hessian explicitly is computationally prohibitive. To address this issue, practical algorithms seek flat minima by injecting noise into the optimizers (Zhu et al., 2019; Zhou et al., 2019; Orvieto et al., 2022; Bisla et al., 2022), introducing regularization (Chaudhari et al., 2017; Zhao et al., 2022a; Du et al., 2022b), averaging model weights during training (Izmailov et al., 2018; He et al., 2019; Cha et al., 2021), or performing sharpness-aware minimization (SAM) (Foret et al., 2021; Kwon et al., 2021; Zhuang et al., 2022; Kim et al., 2022).

SAM. The state-of-the-art SAM (Foret et al., 2021) and its variants (Kwon et al., 2021; Zhuang et al., 2022; Kim et al., 2022; Zhao et al., 2022a) search for flat minima by solving the following min-max optimization problem:

min_w max_{‖ε‖≤ρ} L(D; w + ε),   (1)

where ρ > 0 is the radius of the perturbation ε. The above can also be rewritten as min_w L(D; w) + R(D; w), where R(D; w) ≡ max_{‖ε‖≤ρ} L(D; w + ε) − L(D; w) is a regularizer that penalizes sharp minimizers (Foret et al., 2021).
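To make the regularizer R(D; w) concrete, the sketch below estimates it with the same first-order reasoning SAM relies on: for small ρ, the worst-case perturbation points along the gradient, so R(D; w) ≈ ρ‖∇L(D; w)‖. The toy quadratic loss is an assumption made purely for illustration.

```python
import numpy as np

def loss(w):
    """Toy quadratic 'training loss' L(w) = 0.5 * ||w||^2 (illustrative only)."""
    return 0.5 * np.sum(w ** 2)

def grad(w):
    """Gradient of the toy loss."""
    return w

def sharpness_penalty(w, rho=0.05):
    """First-order estimate of R(w) = max_{||eps|| <= rho} L(w + eps) - L(w):
    ascend a distance rho along the normalized gradient direction."""
    g = grad(w)
    eps = rho * g / (np.linalg.norm(g) + 1e-12)
    return loss(w + eps) - loss(w)

w = np.array([3.0, 4.0])  # ||grad L(w)|| = 5, so R(w) ≈ rho * 5 = 0.25
print(sharpness_penalty(w))
```

The measured penalty (≈ 0.251) matches the linearized prediction ρ‖∇L(w)‖ = 0.25 up to a second-order term, which is why penalizing R steers optimization toward flatter regions where the gradient norm around w is small.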
As solving the inner maximization in (1) exactly is computationally infeasible for nonconvex losses, SAM solves it approximately by a first-order Taylor approximation, leading to the update rule:

w_{t+1} = w_t − η∇L(B_t; w_t + ρ_t∇L(B_t; w_t)),   (2)

where B_t is a mini-batch of data, η is the step size, and ρ_t = ρ/‖∇L(B_t; w_t)‖. Although SAM has been shown to be effective in improving the generalization of deep networks, a major drawback is that each update in (2) requires two forward-backward computations. Specifically, SAM first calculates the gradient of L(B_t; w) at w_t to obtain the perturbation, then calculates the gradient of L(B_t; w) at w_t + ρ_t∇L(B_t; w_t) to obtain the update direction for w_t. As a result, SAM doubles the computational overhead compared to ERM.

Efficient Variants of SAM. Several algorithms have been proposed to accelerate the SAM algorithm. ESAM (Du et al., 2022a) uses fewer samples to compute the gradients and only updates part of the model in the second step, but still requires computing most of the gradients. Another direction is to reduce the number of SAM updates during training. SS-SAM (Zhao et al., 2022b) randomly selects
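The two sequential gradient computations in update rule (2) can be sketched on a toy problem. This is a minimal illustration under stated assumptions: the quadratic loss, its closed-form gradient, and the hyperparameter values stand in for a real network's mini-batch forward-backward passes.

```python
import numpy as np

def loss(w):
    """Toy quadratic loss standing in for L(B_t; w) (illustrative only)."""
    return 0.5 * np.sum(w ** 2)

def grad(w):
    """Gradient of the toy loss, standing in for one forward-backward pass."""
    return w

def sam_step(w, eta=0.1, rho=0.05):
    """One SAM update per rule (2): two sequential gradient computations."""
    g = grad(w)                                   # pass 1: gradient at w_t
    eps = rho * g / (np.linalg.norm(g) + 1e-12)   # perturbation rho_t * grad
    g_adv = grad(w + eps)                         # pass 2: gradient at w_t + eps
    return w - eta * g_adv                        # descend along perturbed gradient

w = np.array([1.0, -2.0])
for _ in range(100):
    w = sam_step(w)
print(loss(w))  # small: the iterates converge toward the minimum
```

Because `g_adv` depends on the perturbation computed from `g`, the two passes cannot run in parallel; this sequential dependence is exactly the 2x overhead discussed above.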

