AN ADAPTIVE POLICY TO EMPLOY SHARPNESS-AWARE MINIMIZATION

Abstract

Sharpness-aware minimization (SAM), which searches for flat minima via min-max optimization, has been shown to improve model generalization. However, since each SAM update requires computing two gradients, its computational cost and training time are both doubled compared to standard empirical risk minimization (ERM). Recent state-of-the-art methods accelerate SAM by switching between SAM and ERM updates randomly or periodically, thereby reducing the fraction of SAM updates. In this paper, we design an adaptive policy that employs SAM based on the geometry of the loss landscape. Two efficient algorithms, AE-SAM and AE-LookSAM, are proposed. We show theoretically that AE-SAM has the same convergence rate as SAM. Experimental results on various datasets and architectures demonstrate the efficiency and effectiveness of the adaptive policy.

1. INTRODUCTION

Despite great success in many applications (He et al., 2016; Zagoruyko & Komodakis, 2016; Han et al., 2017), deep networks are often over-parameterized and capable of memorizing all training data. The training loss landscape is complex and nonconvex, with many local minima of differing generalization ability. Many studies have investigated the relationship between the loss surface's geometry and generalization performance (Hochreiter & Schmidhuber, 1994; McAllester, 1999; Keskar et al., 2017; Neyshabur et al., 2017; Jiang et al., 2020), and found that flatter minima generalize better than sharper ones (Dziugaite & Roy, 2017; Petzka et al., 2021; Chaudhari et al., 2017; Keskar et al., 2017; Jiang et al., 2020). Sharpness-aware minimization (SAM) (Foret et al., 2021) is the current state-of-the-art approach for seeking flat minima, formulated as a min-max optimization problem. Each SAM update consists of two forward-backward computations: one to compute the perturbation and the other to compute the actual update direction. Since these two computations cannot be parallelized, SAM doubles both the computational overhead and the training time compared to empirical risk minimization (ERM).

Several algorithms (Du et al., 2022a; Zhao et al., 2022b; Liu et al., 2022) have been proposed to improve the efficiency of SAM. ESAM (Du et al., 2022a) uses fewer samples to compute the gradients and updates fewer parameters, but each update still requires two gradient computations; thus, ESAM does not remove the training-speed bottleneck. Instead of using the SAM update at every iteration, recent state-of-the-art methods (Zhao et al., 2022b; Liu et al., 2022) use SAM only randomly or periodically. Specifically, SS-SAM (Zhao et al., 2022b) selects SAM or ERM according to a Bernoulli trial, while LookSAM (Liu et al., 2022) employs SAM once every k steps. Though more efficient, such random or periodic use of SAM is suboptimal, as it is not geometry-aware.
Intuitively, the SAM update is more useful in sharp regions than in flat regions. In this paper, we propose an adaptive policy to employ SAM based on the geometry of the loss landscape. The SAM update is used when the model is in a sharp region, while the ERM update is used in flat regions to reduce the fraction of SAM updates. To measure sharpness, we use the squared stochastic gradient norm and model it by a normal distribution, whose parameters are estimated by an exponential moving average. Experimental results on standard benchmark datasets demonstrate the superiority of the proposed policy.

Our contributions are summarized as follows: (i) We propose an adaptive policy that chooses between the SAM and ERM updates based on the loss landscape geometry. (ii) We propose an efficient algorithm, called AE-SAM (Adaptive policy to Employ SAM), to reduce the fraction of SAM updates, and theoretically study its convergence rate. (iii) The proposed policy is general and can be combined with any SAM variant; in this paper, we integrate it with LookSAM (Liu et al., 2022) and propose AE-LookSAM. (iv) Experimental results on various network architectures and datasets (with and without label noise) verify the superiority of AE-SAM and AE-LookSAM over existing baselines.

Notations. Vectors (e.g., x) and matrices (e.g., X) are denoted by lowercase and uppercase boldface letters, respectively. For a vector x, its ℓ2-norm is ‖x‖. N(µ, σ²) is the univariate normal distribution with mean µ and variance σ². diag(x) constructs a diagonal matrix with x on the diagonal. Moreover, I_A(x) denotes the indicator function for a given set A, i.e., I_A(x) = 1 if x ∈ A, and 0 otherwise.

2. RELATED WORK

We are given a training set D with i.i.d. samples {(x_i, y_i) : i = 1, ..., n}. Let f(x; w) be a model parameterized by w. Its empirical risk on D is L(D; w) = (1/n) Σ_{i=1}^n ℓ(f(x_i; w), y_i), where ℓ(·, ·) is a loss (e.g., the cross-entropy loss for classification). Model training aims to learn a model from the training data that generalizes well on the test data.

Generalization and Flat Minima. The connection between model generalization and loss landscape geometry has been studied both theoretically and empirically (Keskar et al., 2017; Dziugaite & Roy, 2017; Jiang et al., 2020). Recently, Jiang et al. (2020) conducted large-scale experiments and found that sharpness-based measures (flatness) are related to the generalization of minimizers. Although flatness can be characterized by the Hessian's eigenvalues (Keskar et al., 2017; Dinh et al., 2017), handling the Hessian explicitly is computationally prohibitive. To address this issue, practical algorithms seek flat minima by injecting noise into the optimizer (Zhu et al., 2019; Zhou et al., 2019; Orvieto et al., 2022; Bisla et al., 2022), introducing regularization (Chaudhari et al., 2017; Zhao et al., 2022a; Du et al., 2022b), averaging model weights during training (Izmailov et al., 2018; He et al., 2019; Cha et al., 2021), or sharpness-aware minimization (SAM) (Foret et al., 2021; Kwon et al., 2021; Zhuang et al., 2022; Kim et al., 2022).

SAM. The state-of-the-art SAM (Foret et al., 2021) and its variants (Kwon et al., 2021; Zhuang et al., 2022; Kim et al., 2022; Zhao et al., 2022a) search for flat minima by solving the following min-max optimization problem:

min_w max_{‖ε‖≤ρ} L(D; w + ε),   (1)

where ρ > 0 is the radius of the perturbation ε. Problem (1) can also be rewritten as min_w L(D; w) + R(D; w), where R(D; w) ≡ max_{‖ε‖≤ρ} L(D; w + ε) - L(D; w) is a regularizer that penalizes sharp minimizers (Foret et al., 2021).
As solving the inner maximization in (1) exactly is computationally infeasible for nonconvex losses, SAM solves it approximately by a first-order Taylor approximation, leading to the update rule

w_{t+1} = w_t - η ∇L(B_t; w_t + ρ_t ∇L(B_t; w_t)),   (2)

where B_t is a mini-batch of data, η is the step size, and ρ_t = ρ / ‖∇L(B_t; w_t)‖. Although SAM has been shown to be effective in improving the generalization of deep networks, a major drawback is that each update in (2) requires two forward-backward computations: SAM first computes the gradient of L(B_t; w) at w_t to obtain the perturbation, then computes the gradient of L(B_t; w) at w_t + ρ_t ∇L(B_t; w_t) to obtain the update direction for w_t. As a result, SAM doubles the computational overhead compared to ERM.

Efficient Variants of SAM. Several algorithms have been proposed to accelerate SAM. ESAM (Du et al., 2022a) uses fewer samples to compute the gradients and updates only part of the model in the second step, but still requires computing most of the gradients. Another direction is to reduce the number of SAM updates during training: SS-SAM (Zhao et al., 2022b) selects the SAM or ERM update according to a Bernoulli trial, while LookSAM (Liu et al., 2022) employs SAM once every k iterations. Intuitively, the SAM update is more suitable for sharp regions than for flat regions. However, the mixing policies in SS-SAM and LookSAM are not adaptive to the loss landscape. In this paper, we design an adaptive policy to employ SAM based on the loss landscape geometry.
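To make the two-gradient structure of update (2) concrete, the following is a minimal NumPy sketch on a toy least-squares problem. The function names and the toy data are ours, not from the paper; a practical implementation would operate on deep-network parameters instead.

```python
import numpy as np

def loss_grad(w, X, y):
    """Gradient of the mini-batch loss 0.5 * mean((Xw - y)^2)."""
    return X.T @ (X @ w - y) / len(y)

def sam_step(w, X, y, eta=0.1, rho=0.05, eps=1e-12):
    """One SAM update (2): ascend to w + rho * g/||g||, then descend with the
    gradient taken at the perturbed point (two forward-backward passes)."""
    g = loss_grad(w, X, y)                           # first gradient
    w_adv = w + rho * g / (np.linalg.norm(g) + eps)  # perturbed weights
    return w - eta * loss_grad(w_adv, X, y)          # second gradient

rng = np.random.default_rng(0)
X = rng.normal(size=(32, 4))
w_true = np.ones(4)
y = X @ w_true
w = np.zeros(4)
for _ in range(500):
    w = sam_step(w, X, y)
print(np.linalg.norm(w - w_true))
```

Because the perturbation is re-normalized at every step, the iterates hover within O(ηρ) of the minimizer rather than converging exactly, which becomes visible if rho is made large.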

3. METHOD

In this section, we propose an adaptive policy to employ SAM. The idea is to use the ERM update when w_t is in a flat region, and the SAM update only when the loss landscape is locally sharp. We start by introducing a sharpness measure (Section 3.1), then propose an adaptive policy based on it (Section 3.2). Next, we present two algorithms (AE-SAM and AE-LookSAM) and study their convergence.

3.1. SHARPNESS MEASURE

Though sharpness can be characterized by the Hessian's eigenvalues (Keskar et al., 2017; Dinh et al., 2017), they are expensive to compute. A widely-used approximation is based on the gradient magnitude diag([∇L(B_t; w_t)]²) (Bottou et al., 2018; Khan et al., 2018), where [v]² denotes the elementwise square of a vector v. As ‖∇L(B_t; w_t)‖² equals the trace of diag([∇L(B_t; w_t)]²), it is reasonable to choose ‖∇L(B_t; w_t)‖² as a sharpness measure. ‖∇L(B_t; w_t)‖² is also related to the gradient variance Var(∇L(B_t; w_t)), another sharpness measure (Jiang et al., 2020). Specifically,

Var(∇L(B_t; w_t)) ≡ E_{B_t} ‖∇L(B_t; w_t) - ∇L(D; w_t)‖² = E_{B_t} ‖∇L(B_t; w_t)‖² - ‖∇L(D; w_t)‖².   (3)

With appropriate smoothness assumptions on L, both SAM and ERM can be shown theoretically to converge to critical points of L(D; w) (i.e., points where ∇L(D; w) = 0) (Reddi et al., 2016; Andriushchenko & Flammarion, 2022). Thus, it follows from (3) that Var(∇L(B_t; w_t)) = E_{B_t} ‖∇L(B_t; w_t)‖² when w_t is a critical point of L(D; w). Jiang et al. (2020) conducted extensive experiments and showed empirically that Var(∇L(B_t; w_t)) is positively correlated with the generalization gap: the smaller Var(∇L(B_t; w_t)), the better the model with parameters w_t generalizes. This finding also explains why SAM generalizes better than ERM. Figure 1 shows the gradient variance w.r.t. the number of epochs for SAM and ERM on CIFAR-100 with various network architectures (experimental details are in Section 4.1). As can be seen, SAM always has a much smaller variance than ERM. Figure 2 shows the expected squared norm of the stochastic gradient w.r.t. the number of epochs on CIFAR-100. As shown, SAM achieves a much smaller E_{B_t} ‖∇L(B_t; w_t)‖² than ERM.
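The decomposition in (3) is just the bias-variance identity for the stochastic gradient and can be checked numerically. The snippet below does so for a toy least-squares model with single-sample "mini-batches", so that the expectation over B_t is an exact average; all names and the toy setup are ours.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 1024, 8
X = rng.normal(size=(n, d))
y = rng.normal(size=n)
w = rng.normal(size=d)

# Per-sample gradients of the loss 0.5 * (x.w - y)^2; row i is the gradient
# on the singleton batch {(x_i, y_i)}.
per_sample = (X @ w - y)[:, None] * X
full_grad = per_sample.mean(axis=0)          # gradient of L(D; w)

# Left side of (3): E ||grad_B - grad_D||^2.
var_lhs = np.mean(np.sum((per_sample - full_grad) ** 2, axis=1))
# Right side of (3): E ||grad_B||^2 - ||grad_D||^2.
var_rhs = np.mean(np.sum(per_sample ** 2, axis=1)) - np.sum(full_grad ** 2)
print(var_lhs, var_rhs)
```

The two sides agree to floating-point precision, and the same identity holds in expectation for any batch size.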

3.2. ADAPTIVE POLICY TO EMPLOY SAM

As E_{B_t} ‖∇L(B_t; w_t)‖² changes with t (Figure 2), the sharpness at w_t also changes along the optimization trajectory. As a result, we need to estimate E_{B_t} ‖∇L(B_t; w_t)‖² at every iteration. One could sample a large number of mini-batches and compute the mean of the stochastic gradient norms, but this is computationally expensive. To address this problem, we model ‖∇L(B_t; w_t)‖² with a simple distribution and estimate the distribution's parameters in an online manner. Figure 3(a) shows the distribution of ‖∇L(B_t; w_t)‖² over 400 mini-batches at different training stages on CIFAR-100 using ResNet-18. As can be seen, the distribution follows a bell curve. Figure 3(b) shows the corresponding quantile-quantile (Q-Q) plot (Wilk & Gnanadesikan, 1968); the closer the curve is to a straight line, the closer the distribution is to normal. Figure 3 thus suggests that ‖∇L(B_t; w_t)‖² can be modeled with a normal distribution N(µ_t, σ²_t). We use an exponential moving average (EMA), which is popularly used in adaptive gradient methods (e.g., RMSProp (Tieleman & Hinton, 2012), AdaDelta (Zeiler, 2012), Adam (Kingma & Ba, 2015)), to estimate its mean and variance:

µ_t = δ µ_{t-1} + (1 - δ) ‖∇L(B_t; w_t)‖²,   (4)
σ²_t = δ σ²_{t-1} + (1 - δ) (‖∇L(B_t; w_t)‖² - µ_t)²,   (5)

where δ ∈ (0, 1) controls the forgetting rate. Empirically, we use δ = 0.9. Since ∇L(B_t; w_t) is already available during training, this EMA update does not involve additional gradient computations (the cost of the norm operator is negligible). Using µ_t and σ²_t, we employ SAM only at iterations where ‖∇L(B_t; w_t)‖² is relatively large (i.e., the loss landscape is locally sharp). Specifically, when ‖∇L(B_t; w_t)‖² ≥ µ_t + c_t σ_t (where c_t is a threshold), SAM is used; otherwise, ERM is used. When c_t → -∞, the policy reduces to SAM; when c_t → ∞, it becomes ERM. Note that during the early training stage, the model is still underfitting and w_t is far from the region of final convergence; minimizing the empirical loss is then more important than seeking a locally flat region.
Andriushchenko & Flammarion (2022) also empirically observe that the SAM update is more effective in boosting performance towards the end of training. We therefore design a schedule that linearly decreases c_t from λ2 to λ1 (both pre-set values):

c_t = g_{λ1,λ2}(t) ≡ (t/T) λ1 + (1 - t/T) λ2,

where T is the total number of iterations. The whole procedure, called Adaptive policy to Employ SAM (AE-SAM), is shown in Algorithm 1.

AE-LookSAM.

The proposed adaptive policy can be combined with any SAM variant. Here, we consider integrating it with LookSAM (Liu et al., 2022). When ‖∇L(B_t; w_t)‖² ≥ µ_t + c_t σ_t, SAM is used and the update direction for w_t is decomposed into two orthogonal directions as in LookSAM: (i) the ERM update direction, which reduces the training loss, and (ii) the direction that biases the model towards a flat region. When ‖∇L(B_t; w_t)‖² < µ_t + c_t σ_t, ERM is performed and the second direction of the previous SAM update is reused to compose an approximate SAM direction. The procedure, called AE-LookSAM, is also shown in Algorithm 1.

Algorithm 1 (AE-SAM and AE-LookSAM). Require: training set D, step size η, radius ρ; λ1 and λ2 for g_{λ1,λ2}(t); initialization w_0, µ_{-1} = 0, σ²_{-1} = e^{-10}; and α for AE-LookSAM. For t = 0, ..., T - 1: update µ_t and σ²_t by EMA, compute the update direction g_s (by SAM if ‖∇L(B_t; w_t)‖² ≥ µ_t + c_t σ_t, and by ERM otherwise), and set w_{t+1} = w_t - η g_s. Return w_T.
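The following is a minimal NumPy sketch of the AE-SAM loop of Algorithm 1 on a toy least-squares problem, under our own simplifications: the toy data, batch sampling, and hyperparameters are ours, and LookSAM's direction reuse is omitted. It tracks the EMA statistics of the squared gradient norm and switches between the SAM and ERM updates using the linearly decayed threshold c_t.

```python
import numpy as np

def grad(w, X, y):
    """Mini-batch gradient of 0.5 * mean((Xw - y)^2)."""
    return X.T @ (X @ w - y) / len(y)

def ae_sam(X, y, T=300, b=16, eta=0.1, rho=0.05, delta=0.9,
           lam1=-1.0, lam2=1.0, seed=0):
    rng = np.random.default_rng(seed)
    w = np.zeros(X.shape[1])
    mu, sigma2, n_sam = 0.0, 1e-10, 0
    for t in range(T):
        idx = rng.choice(len(y), size=b, replace=False)
        g = grad(w, X[idx], y[idx])
        sq = float(g @ g)
        # EMA estimates of the mean and variance of ||grad||^2 (Section 3.2).
        mu = delta * mu + (1 - delta) * sq
        sigma2 = delta * sigma2 + (1 - delta) * (sq - mu) ** 2
        # Linearly decrease the threshold c_t from lam2 to lam1.
        c = (t / T) * lam1 + (1 - t / T) * lam2
        if sq >= mu + c * np.sqrt(sigma2):       # locally sharp: SAM update
            w_adv = w + rho * g / (np.sqrt(sq) + 1e-12)
            g = grad(w_adv, X[idx], y[idx])
            n_sam += 1
        w = w - eta * g                          # otherwise a plain ERM step
    return w, 100.0 * n_sam / T                  # final weights and %SAM

rng = np.random.default_rng(1)
X = rng.normal(size=(256, 4))
w_true = rng.normal(size=4)
y = X @ w_true
w, pct_sam = ae_sam(X, y)
print(pct_sam)
```

Only the iterations that take the SAM branch pay for a second gradient, which is exactly what the %SAM metric of Section 4 counts.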

3.3. CONVERGENCE ANALYSIS

In this section, we study the convergence of any algorithm A whose update at each iteration can be either SAM or ERM. Due to this mixing of SAM and ERM updates, analyzing its convergence is more challenging than analyzing that of SAM. The following assumptions on smoothness and bounded variance of stochastic gradients are standard in the literature on nonconvex optimization (Ghadimi & Lan, 2013; Reddi et al., 2016) and on SAM (Andriushchenko & Flammarion, 2022; Abbas et al., 2022; Qu et al., 2022).

Assumption 3.1 (Smoothness). L(D; w) is β-smooth in w, i.e., ‖∇L(D; w) - ∇L(D; v)‖ ≤ β ‖w - v‖.

Assumption 3.2 (Bounded variance of stochastic gradients). E_{(x_i,y_i)∼D} ‖∇ℓ(f(x_i; w), y_i) - ∇L(D; w)‖² ≤ σ².

Let ξ_t indicate whether SAM or ERM is used at iteration t (i.e., ξ_t = 1 for SAM, and 0 for ERM). For example, ξ_t = I_{{w: ‖∇L(B_t; w)‖² ≥ µ_t + c_t σ_t}}(w_t) for the proposed AE-SAM, while ξ_t is sampled from a Bernoulli distribution for SS-SAM (Zhao et al., 2022b).

Theorem 3.3. Let b be the mini-batch size. If the step size η = 1/(4β√T) and ρ = 1/T^{1/4}, algorithm A satisfies

min_{0≤t≤T-1} E‖∇L(D; w_t)‖² ≤ 32β(L(D; w_0) - E L(D; w_T)) / (√T (7 - 6ζ)) + (1 + ζ + 5β²ζ)σ² / (b√T (7 - 6ζ)),   (6)

where ζ = (1/T) Σ_{t=0}^{T-1} ξ_t ∈ [0, 1] is the fraction of SAM updates, and the expectation is taken over the random training samples. All proofs are in Appendix A. Note that a larger ζ leads to a larger upper bound in (6). When ζ = 1, the above reduces to the bound for SAM (Corollary A.2 in Appendix A.1).

4. EXPERIMENTS

In this section, we evaluate the proposed AE-SAM and AE-LookSAM on several standard benchmarks. As the SAM update doubles the computational overhead compared to the ERM update, the training speed is mainly determined by how often the SAM update is used. Hence, we evaluate efficiency by measuring the fraction of SAM updates used: %SAM ≡ 100 × #{iterations using SAM}/T. The total number of iterations, T, is the same for all methods.

4.1. CIFAR-10 AND CIFAR-100

Setup.

In this section, experiments are performed on the CIFAR-10 and CIFAR-100 datasets (Krizhevsky & Hinton, 2009) using four network architectures: ResNet-18 (He et al., 2016), WideResNet-28-10 (denoted WRN-28-10) (Zagoruyko & Komodakis, 2016), PyramidNet-110 (Han et al., 2017), and ViT-S16 (Dosovitskiy et al., 2021). Following the setup in (Liu et al., 2022; Foret et al., 2021; Zhao et al., 2022a), we use a batch size of 128, an initial learning rate of 0.1, a cosine learning rate schedule, and the SGD optimizer with momentum 0.9 and weight decay 0.0001. The number of training epochs is 300 for PyramidNet-110, 1200 for ViT-S16, and 200 for ResNet-18 and WRN-28-10. 10% of the training set is used as the validation set. As in Foret et al. (2021), we perform a grid search for the radius ρ over {0.01, 0.02, 0.05, 0.1, 0.2, 0.5} using the validation set. Similarly, α is selected by a grid search over {0.1, 0.3, 0.6, 0.9}. For the c_t schedule g_{λ1,λ2}(t), λ1 = -1 and λ2 = 1 for AE-SAM; λ1 = 0 and λ2 = 2 for AE-LookSAM.

Baselines. The proposed AE-SAM and AE-LookSAM are compared with the following baselines: (i) ERM; (ii) SAM (Foret et al., 2021); and its more efficient variants, including (iii) ESAM (Du et al., 2022a), which uses part of the weights to compute the perturbation and part of the samples to compute the SAM update direction; these two techniques reduce the computational cost, but may not always accelerate SAM, particularly in parallel training (Li et al., 2020); (iv) SS-SAM (Zhao et al., 2022b), which randomly selects SAM or ERM according to a Bernoulli trial with success probability 0.5 (the scheme with the best performance in (Zhao et al., 2022b)); and (v) LookSAM (Liu et al., 2022), which uses SAM once every k = 5 steps. Each experiment is repeated five times with different random seeds.

Results. Table 1 shows the testing accuracy and fraction of SAM updates (%SAM). Methods are grouped based on %SAM.
As can be seen, AE-SAM achieves higher accuracy than SAM while using SAM in only about 50% of the updates. SS-SAM and AE-SAM have comparable %SAM (about 50%), and AE-SAM achieves higher accuracy than SS-SAM (statistically significant under a pairwise t-test at the 95% level). Finally, LookSAM and AE-LookSAM have comparable %SAM (about 20%), and AE-LookSAM also achieves higher accuracy than LookSAM. These improvements confirm the superiority of the adaptive policy.

4.2. ImageNet

Setup. In this section, we perform experiments on ImageNet (Russakovsky et al., 2015), which contains 1000 classes and 1.28 million images. ResNet-50 (He et al., 2016) is used. Following the setup in Du et al. (2022a), we train the network for 90 epochs using an SGD optimizer with momentum 0.9, weight decay 0.0001, initial learning rate 0.1, a cosine learning rate schedule, and batch size 512. As in (Foret et al., 2021; Du et al., 2022a), ρ = 0.05. For the c_t schedule g_{λ1,λ2}(t), λ1 = -1 and λ2 = 1 for AE-SAM; λ1 = 0 and λ2 = 2 for AE-LookSAM. k = 5 is used for LookSAM. Experiments are repeated with three different random seeds.

Results. Table 2 shows the testing accuracy and fraction of SAM updates. As can be seen, with only half of the iterations using SAM, AE-SAM achieves performance comparable with SAM. Compared with LookSAM, AE-LookSAM performs better (again statistically significant), verifying that the proposed adaptive policy is more effective than LookSAM's periodic policy.

4.3. ROBUSTNESS TO LABEL NOISE

Setup. In this section, we study whether the more efficient SAM variants affect SAM's robustness to training label noise. Following the setup in Foret et al. (2021), we conduct experiments on a corrupted version of CIFAR-10, in which a fraction of the training labels are randomly flipped (while the test set is kept clean). The ResNet-18 and ResNet-32 networks are used. They are trained for 200 epochs using SGD with momentum 0.9, weight decay 0.0001, batch size 128, initial learning rate 0.1, and a cosine learning rate schedule. For LookSAM, the SAM update is used once every k = 2 steps. For AE-SAM and AE-LookSAM, we set λ1 = -1 and λ2 = 1 in their c_t schedules g_{λ1,λ2}(t), so that their fractions of SAM updates (approximately 50%) are comparable with those of SS-SAM and LookSAM. Experiments are repeated with five different random seeds.

Results. Table 3 shows the testing accuracy and fraction of SAM updates. As can be seen, AE-LookSAM achieves comparable performance with SAM but is faster, as only half of the iterations use the SAM update. Compared with ESAM, SS-SAM, and LookSAM, AE-LookSAM performs better. The improvement is particularly noticeable at higher noise levels (e.g., 80%). Figure 4 shows the training and testing accuracies w.r.t. the number of epochs at a noise level of 80% using ResNet-18. As can be seen, SAM is robust to the label noise, while ERM and SS-SAM suffer heavily from overfitting. AE-SAM and LookSAM can alleviate the overfitting problem to a certain extent. AE-LookSAM, by combining the adaptive policy with LookSAM, achieves the same high level of robustness as SAM.

4.4. EFFECTS OF λ1 AND λ2

In this experiment, we study the effects of λ1 and λ2 on AE-SAM. We use the same setup as in Section 4.1, where λ1 and λ2 (with λ1 ≤ λ2) are chosen from {0, ±1, ±2}. Results on AE-LookSAM using the label noise setup in Section 4.3 are shown in Appendix B.4. Figure 5 shows the effect on the fraction of SAM updates.
For a fixed λ 2 , increasing λ 1 increases the threshold c t , and the condition ∇L(B t ; w t ) 2 ≥ µ t + c t σ t becomes more difficult to satisfy. Thus, as can be seen, the fraction of SAM updates is reduced. The same applies when λ 2 increases. A similar trend is also observed on the testing accuracy (Figure 6 ).
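Under the normal model of Section 3.2, the probability that an iteration uses SAM at threshold c is 1 - Φ(c), where Φ is the standard normal CDF, which makes the roles of λ1 and λ2 easy to read off. The short calculation below (our own sanity check, not from the paper) tabulates this probability for the grid values and confirms that a symmetric schedule from λ2 = 1 down to λ1 = -1 gives roughly 50% SAM updates on average.

```python
from math import erf, sqrt

def sam_fraction(c):
    """P(||grad||^2 >= mu + c * sigma) = 1 - Phi(c) under the normal model."""
    return 1 - 0.5 * (1 + erf(c / sqrt(2)))

for c in (-2, -1, 0, 1, 2):
    print(c, round(sam_fraction(c), 3))

# Average over a linear schedule decreasing from c = 1 to c = -1.
n_steps = 1000
avg = sum(sam_fraction(1 - 2 * t / (n_steps - 1)) for t in range(n_steps)) / n_steps
print(round(avg, 3))
```

Larger thresholds suppress SAM sharply: c = 2 triggers it roughly 2% of the time, c = -2 roughly 98%, consistent with the trends in Figure 5.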

4.5. CONVERGENCE

In this experiment, we study whether the iterates w_t (where t is the number of epochs) obtained by AE-SAM approach critical points of L(D; w), as suggested by Theorem 3.3. Figure 7 shows ‖∇L(D; w_t)‖² w.r.t. t for the experiment in Section 4.1. As can be seen, in all settings, ‖∇L(D; w_t)‖² converges to 0. In Appendix B.5, we also verify the convergence of AE-SAM's training loss on CIFAR-10 and CIFAR-100 (Figure 14), and that AE-SAM and SS-SAM have comparable convergence speeds (Figure 15), which agrees with Theorem 3.3 as both have comparable fractions of SAM updates (Table 1).

5. CONCLUSION

In this paper, we proposed an adaptive policy to employ SAM based on the loss landscape geometry. Using the policy, we proposed an efficient algorithm (called AE-SAM) to reduce the fraction of SAM updates during training. We theoretically and empirically analyzed the convergence of AE-SAM. Experimental results on a number of datasets and network architectures verify the efficiency and effectiveness of the adaptive policy. Moreover, the proposed policy is general and can be combined with other SAM variants, as demonstrated by the success of AE-LookSAM.

A PROOFS

A.1 PROOF OF THEOREM 3.3

Theorem 3.3. Let b be the mini-batch size. If the step size η = 1/(4β√T) and ρ = 1/T^{1/4}, algorithm A satisfies

min_{0≤t≤T-1} E‖∇L(D; w_t)‖² ≤ 32β(L(D; w_0) - E L(D; w_T)) / (√T (7 - 6ζ)) + (1 + ζ + 5β²ζ)σ² / (b√T (7 - 6ζ)),

where ζ = (1/T) Σ_{t=0}^{T-1} ξ_t ∈ [0, 1].

Lemma A.1 (Andriushchenko & Flammarion (2022)). Under Assumptions 3.1 and 3.2, for all t and ρ > 0, we have

E ∇L(B_t; w + ρ∇L(B_t; w))ᵀ ∇L(D; w) ≥ (1/2 - ρβ) ‖∇L(D; w)‖² - β²ρ²σ² / (2b).

Proof of Theorem 3.3. Let g_t ≡ (1/b) Σ_{(x_i,y_i)∈B_t} ∇ℓ(f(x_i; w_t), y_i), h_t ≡ (1/b) Σ_{(x_i,y_i)∈B_t} ∇ℓ(f(x_i; w_t + ρg_t), y_i), and ĝ_t ≡ ∇L(D; w_t). By the Taylor expansion and the β-smoothness of L(D; w), we have

L(D; w_{t+1}) ≤ L(D; w_t) + ĝ_tᵀ(w_{t+1} - w_t) + (β/2)‖w_{t+1} - w_t‖²
  ≤ L(D; w_t) - η ĝ_tᵀ((1-ξ_t)g_t + ξ_t h_t) + (βη²/2)‖(1-ξ_t)g_t + ξ_t h_t‖²
  = L(D; w_t) - η(1-ξ_t)ĝ_tᵀg_t - ηξ_t ĝ_tᵀh_t + (βη²/2)[(1-ξ_t)‖g_t‖² + ξ_t‖h_t‖² + 2ξ_t(1-ξ_t)g_tᵀh_t]   (9)
  = L(D; w_t) - η(1-ξ_t)ĝ_tᵀg_t - ηξ_t ĝ_tᵀh_t + (βη²/2)[(1-ξ_t)‖g_t‖² + ξ_t‖h_t‖²],   (10)

where we have used ξ_t(1-ξ_t) = 0 (as ξ_t ∈ {0, 1}), ξ_t² = ξ_t, and (1-ξ_t)² = 1-ξ_t to obtain (9), whose last term vanishes. Taking expectation w.r.t. w_t on both sides of (10), we have

E L(D; w_{t+1}) ≤ E L(D; w_t) - η(1-ξ_t)E‖ĝ_t‖² - ηξ_t E ĝ_tᵀh_t + (βη²(1-ξ_t)/2) E‖g_t‖² + (βη²ξ_t/2) E‖h_t‖².   (11)

Claim 1: E‖g_t‖² = E‖g_t - ĝ_t‖² + E‖ĝ_t‖² = σ²/b + E‖ĝ_t‖², which follows from Assumption 3.2.

Claim 2: E‖h_t‖² ≤ 2(1 + ρ²β²)σ²/b - (1 - 2ρ²β²)E‖ĝ_t‖² + 2E ĝ_tᵀh_t, which is derived as follows:

E‖h_t‖² = E‖h_t - ĝ_t‖² - E‖ĝ_t‖² + 2E ĝ_tᵀh_t
  ≤ 2E‖h_t - g_t‖² + 2E‖g_t - ĝ_t‖² - E‖ĝ_t‖² + 2E ĝ_tᵀh_t
  ≤ 2ρ²β² E‖g_t‖² + 2σ²/b - E‖ĝ_t‖² + 2E ĝ_tᵀh_t   (12)
  ≤ 2ρ²β²(σ²/b + E‖ĝ_t‖²) + 2σ²/b - E‖ĝ_t‖² + 2E ĝ_tᵀh_t   (13)
  = 2(1 + ρ²β²)σ²/b - (1 - 2ρ²β²)E‖ĝ_t‖² + 2E ĝ_tᵀh_t,

where (12) follows from ‖h_t - g_t‖ ≤ ρβ‖g_t‖ and Assumption 3.2, and (13) follows from Claim 1.

Substituting Claims 1 and 2 into (11), we obtain

E L(D; w_{t+1}) ≤ E L(D; w_t) - η(1-ξ_t)E‖ĝ_t‖² - ηξ_t E ĝ_tᵀh_t + (βη²(1-ξ_t)/2)(σ²/b + E‖ĝ_t‖²) + (βη²ξ_t/2)[2(1+ρ²β²)σ²/b - (1-2ρ²β²)E‖ĝ_t‖² + 2E ĝ_tᵀh_t]   (15)
  = E L(D; w_t) - η[1 - ξ_t - βη(1-ξ_t)/2 + βηξ_t(1-2ρ²β²)/2] E‖ĝ_t‖² - ηξ_t(1-ηβ) E ĝ_tᵀh_t + [βη²(1-ξ_t)/2 + βη²ξ_t(1+ρ²β²)] σ²/b
  ≤ E L(D; w_t) - η[1 - ξ_t - βη(1-ξ_t)/2 + βηξ_t(1-2ρ²β²)/2 + ξ_t(1-ηβ)(1/2 - ρβ)] E‖ĝ_t‖² + [βη²(1-ξ_t)/2 + βη²ξ_t(1+ρ²β²) + ηξ_t(1-ηβ)β²ρ²/2] σ²/b   (16)
  ≤ E L(D; w_t) - η[1 - (1 + βη - 2ρβ)ξ_t/2 - βη/2] E‖ĝ_t‖² + [η + ξ_t(η + 2ηρ²β² + βρ² - ηβ²ρ²)] ηβσ²/(2b),   (17)

where (15) follows from Claims 1 and 2, and (16) follows from Lemma A.1 and 1 - ηβ > 0. As η < 1/(4β), we have 1 + βη - 2ρβ ≤ 3/2 and βη < 1/4; thus, 1 - (1 + βη - 2ρβ)ξ_t/2 - βη/2 > 0. Summing over t on both sides of (17) and rearranging, we obtain

min_{0≤t≤T-1} E‖ĝ_t‖² ≤ (L(D; w_0) - E L(D; w_T)) / (η Σ_{t=0}^{T-1}[1 - (1+βη-2ρβ)ξ_t/2 - βη/2]) + (Σ_{t=0}^{T-1}[η + ξ_t(η + ηρ²β² + βρ²)]) βσ² / (2b Σ_{t=0}^{T-1}[1 - (1+βη-2ρβ)ξ_t/2 - βη/2])
  = (L(D; w_0) - E L(D; w_T)) / (Tη(1 - γζ/2 - βη/2)) + T(η + ηκζ + βρ²ζ)βσ² / (2bT(1 - γζ/2 - βη/2))   (18)
  = (L(D; w_0) - E L(D; w_T)) / (Tη(1 - γζ/2 - βη/2)) + (1 + κζ + 4β²ζ)ηβσ² / (2b(1 - γζ/2 - βη/2))
  = (L(D; w_0) - E L(D; w_T)) / (Tη(1 - γζ/2 - βη/2)) + (1 + κζ + 4β²ζ)σ² / (8b√T(1 - γζ/2 - βη/2))   (19)
  ≤ 32β(L(D; w_0) - E L(D; w_T)) / (√T(7 - 6ζ)) + (1 + ζ + 5β²ζ)σ² / (b√T(7 - 6ζ)),

where γ = 1 + βη - 2ρβ ≤ 3/2, κ = 1 + ρ²β², ρ² = 1/√T, and ζ = (1/T)Σ_{t=0}^{T-1} ξ_t ∈ [0, 1]. This finishes the proof.

Corollary A.2. Setting ζ = 1 in Theorem 3.3, SAM (Foret et al., 2021) satisfies

min_{0≤t≤T-1} E‖∇L(D; w_t)‖² ≤ 32β(L(D; w_0) - E L(D; w_T))/√T + (2 + 5β²)σ²/(b√T).

Corollary A.3. Let b be the mini-batch size. For an appropriately chosen step size η, algorithm A satisfies

min_{0≤t≤T-1} E‖∇L(D; w_t)‖² ≤ 32β(L(D; w_0) - E L(D; w_T)) / (√T b(7 - 6ζ)) + (1 + ζ + 5β²ζ)σ² / (√T b(7 - 6ζ)),

where ζ = (1/T)Σ_{t=0}^{T-1} ξ_t ∈ [0, 1].

Proof. It follows from (18) that

min_{0≤t≤T-1} E‖ĝ_t‖² ≤ (L(D; w_0) - E L(D; w_T)) / (Tη(1 - γζ/2 - βη/2)) + ηβ(1 + κζ + 4β²ζ)σ² / (2b(1 - γζ/2 - βη/2))   (23)
  ≤ 4β(L(D; w_0) - E L(D; w_T)) / (√T b(7/8 - 3ζ/4)) + (1 + ζ + 5β²ζ)σ² / (8√T b(7/8 - 3ζ/4))   (24)
  = 32β(L(D; w_0) - E L(D; w_T)) / (√T b(7 - 6ζ)) + (1 + ζ + 5β²ζ)σ² / (√T b(7 - 6ζ)).

A.2 CONVERGENCE OF FULL-BATCH GRADIENT DESCENT FOR AE-SAM

Theorem A.4. Under Assumption 3.1, with full-batch gradient descent, if ρ < 1/(2β) and η < 1/β, algorithm A satisfies

min_{0≤t≤T-1} ‖∇L(D; w_t)‖² ≤ (L(D; w_0) - L(D; w_T)) / (Tη(1 - βη/2 - βρζ)),

where ζ = (1/T)Σ_{t=0}^{T-1} ξ_t ∈ [0, 1].

Lemma A.5 (Lemma 7 in Andriushchenko & Flammarion (2022)). Let L(D; w) be a β-smooth function. For any ρ > 0, we have

∇L(D; w)ᵀ ∇L(D; w + ρ∇L(D; w)) ≥ (1 - ρβ) ‖∇L(D; w)‖².

Proof of Theorem A.4. Let g_t ≡ ∇L(D; w_t) and h_t ≡ ∇L(D; w_t + ρ∇L(D; w_t)) be the update directions of ERM and SAM, respectively. By the Taylor expansion and the β-smoothness of L(D; w), we have

L(D; w_{t+1}) ≤ L(D; w_t) - η g_tᵀ((1-ξ_t)g_t + ξ_t h_t) + (βη²/2)‖(1-ξ_t)g_t + ξ_t h_t‖²
  = L(D; w_t) - η(1-ξ_t)‖g_t‖² - ηξ_t g_tᵀh_t + (βη²/2)[(1-ξ_t)‖g_t‖² + ξ_t‖h_t‖² + 2ξ_t(1-ξ_t)g_tᵀh_t]   (28)
  = L(D; w_t) - η[1 - ξ_t - βη(1-ξ_t)/2]‖g_t‖² + (βη²ξ_t/2)‖h_t‖² - ηξ_t g_tᵀh_t,   (29)

where we have used ξ_t(1-ξ_t) = 0 (as ξ_t ∈ {0, 1}), ξ_t² = ξ_t, and (1-ξ_t)² = 1-ξ_t to obtain (28), whose last term vanishes. As ‖h_t‖² = ‖h_t - g_t‖² - ‖g_t‖² + 2g_tᵀh_t, it follows from (29) that

L(D; w_{t+1}) ≤ L(D; w_t) - η[1 - ξ_t - βη(1-ξ_t)/2 + βηξ_t/2]‖g_t‖² + (βη²ξ_t/2)‖h_t - g_t‖² - η(1-βη)ξ_t g_tᵀh_t
  ≤ L(D; w_t) - η[1 - ξ_t - βη(1-ξ_t)/2 + βηξ_t/2]‖g_t‖² + (β³η²ρ²ξ_t/2)‖g_t‖² - η(1-βη)ξ_t g_tᵀh_t   (30)
  ≤ L(D; w_t) - η[1 - ξ_t - βη(1-ξ_t)/2 + βηξ_t/2 - β³ηρ²ξ_t/2 + (1-βη)(1-βρ)ξ_t]‖g_t‖²   (31)
  = L(D; w_t) - η[1 - βη/2 - β³ηρ²ξ_t/2 - βρξ_t + β²ηρξ_t]‖g_t‖²
  ≤ L(D; w_t) - η(1 - βη/2 - βρξ_t)‖g_t‖²,   (32)

where we have used ‖h_t - g_t‖² = ‖∇L(D; w_t + ρ∇L(D; w_t)) - ∇L(D; w_t)‖² ≤ β²ρ²‖∇L(D; w_t)‖² = β²ρ²‖g_t‖² to obtain (30), Lemma A.5 to obtain (31), and β²ηρξ_t - β³ηρ²ξ_t/2 = β²ηρξ_t(1 - βρ/2) ≥ 0 (as ρ < 1/(2β)) to obtain (32). Summing over t from t = 0 to T-1 on both sides of (32) and rearranging, we have

Σ_{t=0}^{T-1} η(1 - βη/2 - βρξ_t)‖g_t‖² ≤ L(D; w_0) - L(D; w_T).   (33)

As ρ < 1/(2β) and η < 1/β, it follows that 1 - βη/2 - βρξ_t > 0 for all t. Thus, (33) implies

min_{0≤t≤T-1} ‖g_t‖² ≤ (L(D; w_0) - L(D; w_T)) / (Σ_{t=0}^{T-1} η(1 - βη/2 - ξ_tβρ)) = (L(D; w_0) - L(D; w_T)) / (Tη(1 - βη/2 - βρζ)),

where ζ = (1/T)Σ_{t=0}^{T-1} ξ_t ∈ [0, 1], and we finish the proof.

B ADDITIONAL EXPERIMENTAL RESULTS

B.1 DISTRIBUTION OF STOCHASTIC GRADIENT NORMS

Figure 8 shows the distributions of the stochastic gradient norms for ResNet-18, WRN-28-10, and PyramidNet-110 on CIFAR-10 and CIFAR-100. As can be seen, the distribution follows a bell curve in all settings. Figure 9 shows the corresponding Q-Q plots; the curves are close to straight lines.

B.2 EFFECT OF k ON LOOKSAM

In this experiment, we demonstrate that LookSAM is sensitive to the choice of k. Table 4 shows the testing accuracy and fraction of SAM updates when using LookSAM on noisy CIFAR-10 with k ∈ {2, 3, 4, 5} and the ResNet-18 model. As can be seen, k = 2 yields much better performance than k ∈ {3, 4, 5}, particularly at higher noise levels (e.g., 80%).

Figure 11: Accuracies with number of epochs on CIFAR-10 with 20%, 40%, 60%, and 80% noise level using ResNet-32. Best viewed in color.



Footnotes: (1) Results for other architectures and for CIFAR-10 are shown in Figures 8 and 9 of Appendix B.1. (2) Note that normality is not needed in the theoretical analysis (Section 3.3). (3) The performance of LookSAM can be sensitive to the value of k; Table 4 of Appendix B.2 shows that k = 2 leads to the best performance in this experiment. (4) Results for other noise levels and for ResNet-32 are shown in Figures 10 and 11 of Appendix B.3, respectively.



Figure 1: Variance of gradient on CIFAR-100. Best viewed in color.

Figure 3(a) shows ‖∇L(B_t; w_t)‖² of 400 mini-batches at different training stages (epoch = 60, 120, and 180).

Figure 2: Squared stochastic gradient norms E B ∇L(B; w t ) 2 on CIFAR-100. Best viewed in color.

Figure 3: Stochastic gradient norms { L(B t ; w t ) 2 : B t ∼ D} of ResNet-18 on CIFAR-100 are approximately normally distributed. Best viewed in color.


Figure 4: Accuracies with number of training epochs on CIFAR-10 (with 80% noise labels) using ResNet-18. Best viewed in color.

Figure 5: Effects of λ 1 and λ 2 on fraction of SAM updates using ResNet-18. Best viewed in color.

Figure 6: Effects of λ 1 and λ 2 on testing accuracy using ResNet-18. Best viewed in color.

Figure 7: Squared gradient norms of AE-SAM with number of epochs. Best viewed in color.




Figure 8: Distributions of stochastic gradient norms on CIFAR-10 (top) and CIFAR-100 (bottom). Best viewed in color.

Figure 9: Q-Q plots of stochastic gradient norms on CIFAR-10 (top) and CIFAR-100 (bottom). Best viewed in color.

Figure10: Accuracies with number of epochs on CIFAR-10 with 20%, 40%, 60%, and 80% noise level using ResNet-18. Best viewed in color.


Table 1: Means and standard deviations of testing accuracy and fraction of SAM updates (%SAM) on CIFAR-10 and CIFAR-100. Methods are grouped based on %SAM. The highest accuracy in each group is underlined; the highest accuracy for each network architecture (across all groups) is in bold.

Table 2: Means and standard deviations of testing accuracy and fraction of SAM updates (%SAM) on ImageNet using ResNet-50. Methods are grouped based on %SAM. The highest accuracy in each group is underlined; the highest across all groups is in bold.

Table 3: Testing accuracy and fraction of SAM updates on CIFAR-10 with different levels of label noise. The best accuracy is in bold and the second best is underlined.


Table 4: Effects of k in LookSAM on CIFAR-10 with different levels of label noise using ResNet-18.

B.3 MORE RESULTS ON ROBUSTNESS TO LABEL NOISE

Figure 10 (resp. Figure 11) shows the accuracy curves at noise levels of 20%, 40%, 60%, and 80% with ResNet-18 (resp. ResNet-32). As can be seen, in all settings, AE-LookSAM is as robust to label noise as SAM.

ACKNOWLEDGMENTS

This work was supported by NSFC key grant 62136005, NSFC general grant 62076118, and Shenzhen fundamental research program JCYJ20210324105000003. This research was supported in part by the Research Grants Council of the Hong Kong Special Administrative Region (Grant 16200021).


B.4 EFFECTS OF λ1 AND λ2 ON AE-LOOKSAM

In this experiment, we study the effects of λ1 and λ2 on AE-LookSAM. The experiment is performed on CIFAR-10 with 80% label noise, using the same setup as in Section 4.3. Figure 12 shows the effects of λ1 and λ2 on the fraction of SAM updates. Again, as in Section 4.4, for a fixed λ2, increasing λ1 always reduces the fraction of SAM updates. Figure 13 shows the effects of λ1 and λ2 on the testing accuracy. As can be seen, the observations are similar to those in Section 4.4.

B.5 CONVERGENCE OF THE TRAINING LOSS

Figure 14 shows the training losses of AE-SAM w.r.t. the number of epochs on CIFAR-10 and CIFAR-100. As can be seen, AE-SAM converges with various network architectures. Figure 15 shows the training losses w.r.t. the number of epochs for AE-SAM and SS-SAM. As can be seen, AE-SAM and SS-SAM converge at comparable speeds, which agrees with Theorem 3.3, as both have comparable fractions of SAM updates (Table 1).

