RANDOMIZED SHARPNESS-AWARE TRAINING FOR BOOSTING COMPUTATIONAL EFFICIENCY IN DEEP LEARNING

Abstract

By driving optimizers to converge to flat minima, sharpness-aware learning algorithms (such as SAM) have shown the power to achieve state-of-the-art performance. However, these algorithms generally incur one extra forward-backward propagation at each training iteration, which largely burdens the computation, especially for large-scale models. To this end, we propose an efficient training scheme called Randomized Sharpness-Aware Training (RST). In RST, the optimizer performs a Bernoulli trial at each iteration to choose randomly between a base algorithm (SGD) and a sharpness-aware algorithm (SAM), with a probability arranged by a predefined scheduling function. Because base-algorithm iterations are mixed in, the overall count of propagation pairs can be largely reduced. We also give a theoretical analysis of the convergence of RST. We then empirically study the computation cost and effect of various types of scheduling functions, and give directions on setting appropriate scheduling functions. Further, we extend RST to a general framework (G-RST), where the regularization degree on sharpness can be adjusted freely for any scheduling function. We show that G-RST can outperform SAM in most cases while saving 50% of the extra computation cost.

1. INTRODUCTION

Deep neural networks (DNNs) have shown great capabilities in solving many real-world complex tasks (He et al., 2016; Redmon et al., 2016; Devlin et al., 2018). However, it is quite challenging to train them efficiently to achieve good performance, especially for today's severely overparameterized networks (Dosovitskiy et al., 2021; Han et al., 2017). Although such numerous parameters can improve the expressiveness of DNNs, they may also complicate the geometry of the loss surface and generate more global and local minima in this huge weight space. Leveraging the finding that flat minima can exhibit better generalization ability, Foret et al. (2021) propose a sharpness-aware learning method called SAM, in which the loss geometry is connected to the optimization to guide optimizers to converge to flat minima. Training with SAM has been shown to significantly improve model performance on various tasks (Foret et al., 2021; Chen et al., 2021). On the other hand, the computation cost of SAM is almost twice that of vanilla stochastic gradient descent (SGD), since SAM incurs one additional forward-backward propagation at each training iteration, which largely burdens the computation in practice. Recently, techniques have been introduced to improve the computational efficiency of SAM. Specifically, instead of using the full batch of samples, Bahri et al. (2021) and Du et al. (2021a) select only part of the batch to approximate the two forward-backward propagations. Although the computation cost can be reduced to some extent, the forward-backward propagation count of the SAM training scheme does not essentially change. Mi et al. (2022) randomly mask out part of the weights during optimization to reduce the expected amount of gradient computation at each iteration. However, the efficiency improvement of such a method is strongly limited by the chain rule of gradient computation (Du et al., 2021a).
Besides, Liu et al. (2022) propose to repeatedly reuse the past vertical gradient components in SAM to reduce the incurred computational overhead. Meanwhile, random selection is a powerful technique for boosting optimization efficiency, particularly in the field of gradient boosting (Friedman, 2001), where a small set of learners in a gradient boosting machine is selected randomly to be optimized under a certain rule (Lu & Mazumder, 2020; Konstantinov et al., 2021). Inspired by such randomization schemes in gradient boosting, we present a simple but efficient training scheme, called Randomized Sharpness-Aware Training (RST). In RST, the learning process is randomized: at each training iteration, the optimizer randomly selects between a base learning algorithm and a sharpness-aware learning algorithm with a given probability. This selecting probability is arranged by a custom scheduling function predefined before training. The scheduling function not only controls how much the propagation count is reduced, but also impacts the model performance. Our contributions can be summarized as follows:
1. We propose a simple but efficient training scheme, called RST, which reduces the propagation count by randomly mixing a base learning algorithm (SGD) and a sharpness-aware learning algorithm (SAM).
2. We interpret our RST scheme from the perspective of gradient norm regularization (GNR) (Zhao et al., 2022), and theoretically prove the convergence of the RST scheme.
3. We empirically study the effect of arranging different scheduling functions, covering three typical types of function families with six function groups in total.
4. We extend RST to a general framework (G-RST), where the GNR algorithm is mixed in such that the regularization degree on the gradient norm can be adjusted freely.
By training both CNN models and ViT models on commonly-used datasets, we show that G-RST can outperform SAM in most cases while saving at least 50% of the extra computation cost.

1.1. OTHER RELATED WORKS

We would like to discuss works associated with the research on flat minima. Hochreiter & Schmidhuber (1997) are the first to point out that the flatness of minima can be associated with model generalization: models with better generalization should converge to flat minima. This claim has been supported extensively by both empirical evidence and theoretical demonstrations (Keskar et al., 2017; Dinh et al., 2017). In the meantime, researchers have also been fascinated by how to implement practical algorithms that force models to converge to such flat minima. By formulating this problem as a specific minimax optimization, Foret et al. (2021) introduce the SAM training scheme, which successfully guides optimizers to converge to flat minima. Further, Zheng et al. (2021) perform gradient descent twice to solve the inner maximization and the outer minimization of this minimax optimization respectively. Kwon et al. (2021) propose the Adaptive SAM training scheme, which improves SAM to remain stable under weight rescaling operations. Zhao et al. (2022) seek flat minima by explicitly penalizing the gradient norm of the loss function. Unlike SAM-related training schemes, and without a restriction to a neighborhood region, Du et al. (2022) propose to minimize the KL-divergence between the output distributions yielded by the current model and a moving average of past models, similar to the idea of knowledge distillation.

2.1. RANDOMIZED SHARPNESS-AWARE TRAINING (RST)

The general idea of RST follows a randomization scheme, where the learning process is randomized. Specifically, at each training iteration t, the optimizer performs a Bernoulli trial to choose between base learning algorithms and sharpness-aware learning algorithms. Here, we first consider mixing the two most commonly-used algorithms, SGD and SAM. Thus, in each Bernoulli trial, the optimizer performs the SAM algorithm with probability p(t) or the SGD algorithm with probability 1 − p(t). Here, p(t) can be a predefined custom function of the iteration t, and we call it the scheduling function of RST. Apparently, the sample space of this Bernoulli trial corresponds to the set Ω = {SGD, SAM}. Correspondingly, a random variable can be defined on this sample space, X(t) : Ω → {0, 1}, where X(t) = 0 denotes performing the SGD algorithm while X(t) = 1 denotes performing the SAM algorithm. In summary, X(t) ∼ Bernoulli(p(t)), and the training trajectory evolves as

θ_0 --X(1)--> θ_1 --> · · · --> θ_t --X(t+1)--> θ_{t+1},  X(t) ∈ {0, 1}.

Algorithm 1 Randomized Sharpness-Aware Training (RST)
1: Input: initial weight θ_0, iteration count T, learning rate η, neighborhood size ρ, scheduling function p(t)
2: for t = 0, 1, ..., T − 1 do
3:   Compute the gradient g = ∇_θ L(θ_t).
4:   Perform the Bernoulli trial with probability p_t and record the result X_t.
5:   if X_t = 0 then  ▷ Implement SGD algorithm
6:     g_t = g.
7:   else  ▷ Implement SAM algorithm
8:     g_t = ∇_θ L(θ_t + ε_t) with ε_t = ρ · g/||g||.
9:   end if
10:  Update weight θ_{t+1} = θ_t − η · g_t.
11: end for
12: return final weight θ = θ_T

Algorithm 1 shows the complete implementation of training with the RST scheme. Compared to the SAM training scheme, every time the SGD algorithm is selected instead of the SAM algorithm, RST saves one forward-backward propagation. Therefore, at training iteration t, the expected propagation count η_t in RST is

η_t = 2 · p_t + 1 · (1 − p_t) = 1 + p_t,

so the extra expected propagation count Δη_t = p_t ∈ [0, 1] is bounded between that of the vanilla SGD scheme (0) and that of the SAM scheme (1). Obviously, the scheduling function p(t) directly controls how many propagations are saved: Δη is larger when the SAM optimization is performed with a higher probability. Also, an appropriate schedule can further improve model performance, while a bad one may largely harm training. We provide a detailed study of scheduling functions in the later sections.
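As an illustration, one RST iteration can be sketched on a toy quadratic loss. This is a minimal sketch assuming plain NumPy; `loss_grad`, `lr`, `rho`, and the constant schedule are illustrative stand-ins, not the paper's actual implementation.

```python
# Minimal RST sketch (Algorithm 1) on a toy quadratic loss.
import numpy as np

rng = np.random.default_rng(0)

def loss(theta):
    return 0.5 * np.sum(theta ** 2)

def loss_grad(theta):
    return theta  # gradient of the toy quadratic loss

def rst_step(theta, p_t, lr=0.1, rho=0.05):
    """One RST iteration: a Bernoulli trial between SGD and SAM."""
    g = loss_grad(theta)                      # first propagation (always)
    if rng.random() < p_t:                    # X_t = 1 -> SAM branch
        eps = rho * g / (np.linalg.norm(g) + 1e-12)
        g_t = loss_grad(theta + eps)          # second propagation (SAM only)
    else:                                     # X_t = 0 -> SGD branch
        g_t = g                               # reuse g, one propagation saved
    return theta - lr * g_t

theta = rng.normal(size=5)
p = lambda t, T: 0.5                          # constant scheduling function
T = 200
for t in range(T):
    theta = rst_step(theta, p(t, T))
```

With the constant schedule p = 0.5 above, roughly half of the iterations skip the second propagation, matching Δη = 0.5.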

2.2. UNDERSTANDING RST FROM GRADIENT NORM REGULARIZATION

From the previous demonstration, the gradient of RST at training iteration t can be expressed as

g_t = (1 − X_t) · ∇_θL(θ_t) + X_t · ∇_θL(θ_t + ε_t),  (4)

where ε_t = ρ · ∇_θL(θ_t)/||∇_θL(θ_t)||. The expectation of this gradient over X is

E_X[g_t] = (1 − p_t) ∇_θL(θ_t) + p_t ∇_θL(θ_t + ε_t).  (5)

According to Zhao et al. (2022), gradients in the form of Equation 5 can be interpreted as gradient norm regularization (GNR) of the loss function. Specifically, when imposing a penalty on the gradient norm during training with a penalty coefficient γ, i.e. minimizing L(θ) + γ||∇_θL(θ)||, the corresponding gradient can be approximated via a combination of ∇_θL(θ_t) and ∇_θL(θ_t + ε_t),

g_t^(gnr) = (1 − γ/ρ) ∇_θL(θ_t) + (γ/ρ) ∇_θL(θ_t + ε_t),  (6)

meaning that SAM is one special implementation of gradient norm regularization, with γ_sam = ρ. From Equation 5 and Equation 6, we can see that p_t in Equation 5 has an effect equivalent to the term γ/ρ in GNR. This means the equivalent penalty coefficient in RST is

γ_rst = p_t · ρ = p_t · γ_sam.  (7)

Compared to the SAM training scheme, the penalty degree is reduced by a factor of p_t in RST.
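The equivalence in Equation 7 can be checked numerically: plugging γ = p_t · ρ into the GNR gradient of Equation 6 reproduces the expected RST gradient of Equation 5 exactly. A minimal sketch on a toy quadratic loss (all names illustrative):

```python
# Numerical check of Equation 7 on a toy quadratic loss.
import numpy as np

def grad(theta):
    return theta  # gradient of L(theta) = 0.5 ||theta||^2

theta = np.array([1.0, -2.0, 0.5])
rho, p_t = 0.05, 0.3

g = grad(theta)
eps = rho * g / np.linalg.norm(g)

# Expected RST gradient over the Bernoulli trial (Equation 5)
g_rst = (1 - p_t) * grad(theta) + p_t * grad(theta + eps)

# GNR gradient (Equation 6) with gamma chosen as gamma_rst = p_t * rho
gamma = p_t * rho
g_gnr = (1 - gamma / rho) * grad(theta) + (gamma / rho) * grad(theta + eps)

assert np.allclose(g_rst, g_gnr)
```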

2.3. CONVERGENCE ANALYSIS OF RST

In this section, we give an analysis of the convergence of RST.

Theorem 1. Assume the loss function L(·) is β-smooth, i.e. ||∇L(θ_1) − ∇L(θ_2)|| ≤ β||θ_1 − θ_2|| for all θ_1, θ_2 ∈ Θ. For iteration steps T ≥ 0, learning rate α_t ≤ 1/β and √(p_t) ρ ≤ 1/β, we have

min_{t∈{0,1,...,T−1}} ||∇L(θ_t)||² ≤ 2(L(θ_0) − L*) / Σ_{t∈{0,1,...,T−1}} α_t + (Σ_{t∈{0,1,...,T−1}} α_t p_t ρ² β²) / (Σ_{t∈{0,1,...,T−1}} α_t).

We provide the detailed proof in the Appendix.

Corollary 1. For a constant learning rate α_t = C/β (or a cosine learning rate schedule) and constant scheduling probability p, we have

min_{t∈{0,1,...,T−1}} ||∇L(θ_t)||² ≤ 2β(L(θ_0) − L*)/(CT) + pρ²β².

Corollary 2. For a decayed learning rate α_t = C/t and constant scheduling probability p, we have

min_{t∈{0,1,...,T−1}} ||∇L(θ_t)||² ≤ 2(L(θ_0) − L*)/(C log T) + pρ²β².

Corollaries 1 and 2 show the convergence of common implementations in practice.

Theorem 2. Assume the loss function L(·) is β-smooth and satisfies the Polyak-Lojasiewicz condition, i.e. (1/2)||∇L(θ_t)||² ≥ L(θ_t) − L*. For iteration steps T ≥ 0, learning rate α_t ≤ 1/β and √(p_t) ρ ≤ 1/β, we have

(E_X[L(θ_T)] − L*) / (L(θ_0) − L*) ≤ Π_{t∈{0,1,...,T−1}} [1 − α_t (1 − p_t ρ_t² β²)].

The Appendix shows the proof. Theorem 2 indicates that RST enjoys a linear convergence rate.
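As a sanity check of Corollary 1, one can run RST on a toy β-smooth quadratic (β = 1, L* = 0) and verify that the minimum squared gradient norm along the trajectory stays below 2β(L(θ_0) − L*)/(CT) + pρ²β². A minimal sketch, with all constants chosen purely for illustration:

```python
# Empirical check of the Corollary 1 bound on a 1-smooth toy quadratic.
import numpy as np

rng = np.random.default_rng(1)

def loss(theta):
    return 0.5 * np.sum(theta ** 2)    # beta-smooth with beta = 1, L* = 0

def grad(theta):
    return theta

beta, rho, p, C, T = 1.0, 0.1, 0.5, 0.5, 500
alpha = C / beta                       # constant learning rate, alpha <= 1/beta

theta = rng.normal(size=10)
L0 = loss(theta)
min_sq_grad_norm = np.inf

for t in range(T):
    g = grad(theta)
    min_sq_grad_norm = min(min_sq_grad_norm, float(np.sum(g ** 2)))
    if rng.random() < p:               # SAM branch of RST
        eps = rho * g / (np.linalg.norm(g) + 1e-12)
        g = grad(theta + eps)
    theta = theta - alpha * g

bound = 2 * beta * (L0 - 0.0) / (C * T) + p * rho ** 2 * beta ** 2
assert min_sq_grad_norm <= bound       # Corollary 1 holds on this run
```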

3. EMPIRICAL STUDY OF SCHEDULING FUNCTION p(t)

In this section, we would investigate the computation efficiency and the impact on model performance when training with the RST scheme under different types of scheduling functions p(t).

3.1. BASIC SETTING AND BASELINES

In our investigation of the effect of scheduling functions, we train models with different scheduling functions from scratch on the image classification tasks of the Cifar-{10, 100} datasets, and compare the corresponding convergence performance and the incurred extra computation overhead. For models, we choose the ResNet18 (He et al., 2016) and WideResNet-28-10 (Zagoruyko & Komodakis, 2016) architectures as our main targets. For data augmentation, we follow the basic strategy, where each image is randomly flipped horizontally, padded with four extra pixels, and finally cropped randomly to 32 × 32. Except for the scheduling functions implemented in the RST schemes, all involved models are trained for 200 epochs with exactly the same hyperparameters. For each training case, we run with five different seeds and report the mean and standard deviation over these five runs. All training details can be found in the Appendix. We also report additional results on other model architectures and another data augmentation strategy in the Appendix. Before our investigations on scheduling functions in RST, we first clarify the baselines, where models are trained with the vanilla SGD scheme and the SAM scheme. Table 1 shows the corresponding results, including the testing error rate (Error column), the training time (Time column), and the extra expected propagation count (Δη column). For the training time, we report the total wall time spent training for 200 epochs on four Nvidia A100 GPUs. From the table, we find that compared to the SGD scheme, the SAM scheme can indeed improve model performance, but in the meantime incurs much more computation (102% more for ResNet18 and 83% more for WideResNet28-10).

3.2. IMPLEMENTATION OF SCHEDULING FUNCTION

Here, we will focus on studying three types of function families, which can cover most scheduling patterns. Table 2 shows the basic information regarding the three scheduling function families.

Table 2: Scheduling function families and their expected extra propagation count Δη.
  Family      Scheduling function p(t)                          Propagation count Δη
  Constant    p_c(t) = a_c                                      a_c
  Piecewise   p_p(t) = a_p (t < b_p·T), 1 − a_p (t ≥ b_p·T)     1 + 2·a_p·b_p − b_p − a_p
  Linear      p_l(t) = a_l·t + b_l                              p_l(T/2)

Constant Function Family In the constant scheduling function family, the scheduling probability is p_c(t) = a_c, where a_c ∈ [0, 1]. Optimizers select the SAM algorithm with a fixed probability a_c and the SGD algorithm with probability 1 − a_c during the whole training process. This implies that the extra computation overhead for a constant scheduling function is proportional to the scheduling probability a_c. Here, we experimentally investigate a group of implementations with constant functions, where the scheduling probability a_c is set from 0.1 to 0.9 with an interval of 0.1. Figure 1A shows the scheduling functions of this group and Figure 1B shows the relationship between the extra expected propagation count Δη (x-axis) and the extra practical training wall time (y-axis) incurred by selecting the SAM algorithm in RST. We can see that for both ResNet18 and WideResNet28-10, all points lie very close to the reference line (x = y): the actual extra training wall time is almost fully determined by the theoretical Δη. Therefore, we directly use Δη to indicate the extra computation cost of RST in the following demonstrations. Figure 1C then shows the corresponding testing error rates of the two models with error bars (shaded area) on Cifar10 (left) and Cifar100 (right). In the figure, the x-axis denotes the extra Δη, and the markers are scaled by the actual training wall time. The endpoints on both sides of the lines denote the testing error rates of training with the SGD scheme and the SAM scheme.
Firstly, we find that even with the lowest probability a_c = 0.1, as long as the SAM algorithm is involved during training, testing error rates are generally reduced compared to training with only the SGD algorithm. On the other hand, model performance does not improve continuously as the probability of selecting the SAM algorithm grows. Secondly, compared to the SAM scheme, testing error rates already reach comparable levels at a_c = 0.6 in RST, which saves about 40% of the computation overhead. In particular, at around a_c = 0.8, models achieve the best performance, slightly outperforming the SAM scheme (3.65%/19.61% for ResNet18 and 2.71%/16.17% for WideResNet28-10 in RST). Additionally, we can see from the error bars that despite the randomness introduced in RST, training remains fairly stable over the five runs.

Piecewise Function Family For the first group, we set a_p = 0 and change the stage-related parameter b_p from 0.1 to 0.9 with an interval of 0.1. The optimizer then actually behaves in a deterministic manner: it performs the SGD algorithm in the first b_p·T iterations and then switches to the SAM algorithm for the rest. Therefore, the larger b_p is, the longer the SGD algorithm is performed, and the less extra computation overhead is incurred. From the results, we find that for all training cases in this group, performance improves gradually as more iterations are implemented with the SAM algorithm, eventually surpassing the SAM scheme. The best performance of this group is very close to that of the constant group (3.66%/19.47% for ResNet18 and 2.69%/16.31% for WideResNet28-10 in this group). Next, in the second group, we arrange training in the opposite way from piecewise group 1, keeping all settings except deploying a_p = 1.
Optimizers thus perform the SAM algorithm in the first b_p·T iterations and then switch to SGD for the remaining steps. Under this arrangement, models could not actually reach good performance. The results show that training needs to accumulate sufficient SAM iterations to completely outperform the SGD scheme; models reach competitive performance only when the SGD algorithm is performed in just the last few iterations. Intuitively, the implementation pattern of piecewise group 2 somewhat goes against the core of sharpness-aware learning: frequently implementing the SGD algorithm near the end of training is harmful to convergence to flat minima. Unlike the previous patterns, in piecewise group 3 we fix b_p = 0.5 and change a_p from 0.1 to 0.9 with an interval of 0.1. Optimizers now pick the SAM algorithm with probability a_p for the first half of the training iterations and then switch the probability to 1 − a_p for the rest. For all training instances in this group, we have Δη = 0.5, and the actual training wall times of these cases are rather close (Time[m]: +8.2(±0.4) for ResNet18 and +13.9(±0.7) for WideResNet28-10). Note that the results of this group are plotted against a_p, not the propagation count. We can see in the results that model performance gradually improves as the probability of implementing the SAM algorithm in the second stage grows. This again confirms the earlier observation that frequently implementing the SGD algorithm near the end of training should be avoided.

Linear Function Family For linear scheduling functions, the selecting probability p(t) is scheduled linearly, changing monotonically with either an increasing or a decreasing pattern. Optimizers select the SAM algorithm with decreasing probability when a_l ≤ 0 and with increasing probability when a_l ≥ 0. Notably, from the summary Table 2, the computation overhead of such an implementation is decided by the scheduling probability at T/2.
We focus on two typical groups of linear scheduling functions in our experiments. Figure 2 shows the scheduling functions and the results. In the first group, we schedule the functions to pass through two given points, where the first point is (T/2, m) and the second is either (0, 0) or (1, 1) depending on the value of m. Here, the parameter m denotes the probability at training iteration T/2, and we set it from 0.1 to 0.9 with an interval of 0.1. Clearly, in this group, the probability of selecting the SAM algorithm increases over the iterations. Also, as m increases, the SAM algorithm experiences an overall higher probability of selection. We find in the results that model performance keeps improving as more SAM iterations are performed, and the performance trend in this group is quite similar to that in piecewise group 1. In fact, these two groups share very close selection patterns in general: the scheduling probability changes instantaneously in piecewise group 1 but gradually in this group. As for the second group, the scheduling functions pass through the two points (T/2, 0.5) and (0, b_l). This means training will always incur 0.5 extra propagation count in expectation, Δη = 0.5. From the results, we find that, similar to piecewise group 3, model performance also progressively improves, but more mildly. Likewise, these two groups also have close selection patterns, in the same way as piecewise group 1 and linear group 1.
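The three families above and their expected extra propagation counts Δη from Table 2 can be sketched as follows; parameter names mirror the paper's a_c, a_p, b_p, a_l, b_l, and the finite-T averages match the closed forms up to discretization error.

```python
# Scheduling-function families and their expected extra propagation
# count delta_eta (the mean of p(t) over training).
import numpy as np

T = 1000

def constant(a_c):
    return lambda t: a_c

def piecewise(a_p, b_p):
    # probability a_p before iteration b_p * T, then 1 - a_p afterwards
    return lambda t: a_p if t < b_p * T else 1 - a_p

def linear(a_l, b_l):
    return lambda t: a_l * t + b_l

def delta_eta(p):
    """Expected extra propagations per iteration, averaged over training."""
    return float(np.mean([p(t) for t in range(T)]))

# Closed forms from Table 2
assert abs(delta_eta(constant(0.7)) - 0.7) < 1e-6
a_p, b_p = 0.2, 0.5
assert abs(delta_eta(piecewise(a_p, b_p)) - (1 + 2*a_p*b_p - b_p - a_p)) < 1e-2
p_l = linear(0.8 / T, 0.1)             # increases from 0.1 to 0.9
assert abs(delta_eta(p_l) - p_l(T / 2)) < 1e-2
```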

3.3. SUMMARY

To give a summary view of these scheduling functions, Figure 3 gives a scatter plot for WideResNet28-10 of model performance against the incurred extra propagation count for all the scheduling function cases. From the previous demonstrations and the figure, we can conclude:
• Avoid scheduling the SGD algorithm with relatively high probability near the end of training, since doing so largely harms training.
• Generally, scheduling the SAM algorithm with a higher overall probability brings better model performance, and the best model performance under RST outperforms the SAM scheme.
• Compared to other schedules, simple constant scheduling functions give decent model performance. We therefore recommend constant scheduling functions in practice for both their simplicity and effectiveness.

4. GENERAL FRAMEWORK FOR RST

Recall from Equation 7 that SAM training actually regularizes the gradient norm with γ_sam = ρ, and that mixing the SGD and SAM algorithms in RST scales this penalty by a factor of p_t. However, when the scheduling probability p_t is low, RST may be unable to provide a sufficient equivalent regularization effect on the gradient norm. This motivates us to extend RST to a general form (G-RST), which mixes the SGD algorithm with the GNR algorithm (Equation 6) such that G-RST can freely adjust the scaling of the penalty degree on the gradient norm,

g_t = (1 − X_t) · g_t^(sgd) + X_t · g_t^(gnr).  (8)

In this way, G-RST gains an extra degree of freedom to control the scaled penalty via γ in GNR, which gives

γ_rst = p_t · γ_gnr.  (9)

This allows training to impose arbitrary regularization on the gradient norm while enjoying a high probability of selecting the SGD algorithm. In our following experiments, we use constant scheduling functions because of their efficiency and simplicity, as demonstrated previously. Here, we consider p(t) = 0.5, so we need to set γ_gnr = 2ρ when mixing, to provide an equivalent regularization to that of the SAM scheme. We first train models with G-RST on the Cifar datasets, involving both CNN models and ViT models (Dosovitskiy et al., 2021). For CNN models, we keep the basic settings the same as in the previous section. As for ViT models, we train each case for 1200 epochs and adopt some further data augmentation to get the best performance. Note that the base algorithm switches to Adam for the ViT models. All training details are reported in the Appendix. Table 3 shows the corresponding results of these models on the Cifar datasets. We observe from the table that compared to the SAM scheme, G-RST can further improve model performance to some extent while saving 50% of the extra computation overhead for all training cases.
This indicates that adjusting the penalty coefficient in RST can give a comparable effect to that in SAM. Following the same setting of p_t as on the Cifar datasets, we train ResNet-{50, 101} models on ImageNet for 100 epochs to further investigate the effectiveness of G-RST on a large-scale dataset. Table 4 shows the final results, where each case is trained over three random seeds. Likewise, we find that G-RST also gives better model performance while incurring 50% less extra computation than the SAM scheme, which again confirms the effectiveness of G-RST.
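A G-RST iteration differs from plain RST only in the sharpness-aware branch, where the GNR combination of Equation 6 is used with a rescaled γ_gnr = γ_rst / p_t. A minimal sketch on a toy quadratic loss; names and constants are illustrative, not the paper's actual implementation.

```python
# Minimal G-RST sketch: mix SGD with a GNR step whose penalty is
# rescaled so that the equivalent gamma_rst = p_t * gamma_gnr matches
# a target degree (e.g. SAM's gamma = rho).
import numpy as np

rng = np.random.default_rng(2)

def grad(theta):
    return theta  # gradient of L(theta) = 0.5 ||theta||^2

def grst_step(theta, p_t, rho=0.05, gamma_rst=0.05, lr=0.1):
    g = grad(theta)
    if rng.random() < p_t:
        gamma_gnr = gamma_rst / p_t          # rescaled penalty (Equation 9)
        eps = rho * g / (np.linalg.norm(g) + 1e-12)
        c = gamma_gnr / rho
        g_t = (1 - c) * g + c * grad(theta + eps)  # GNR gradient (Equation 6)
    else:
        g_t = g
    return theta - lr * g_t

theta = rng.normal(size=5)
for t in range(300):
    # p_t = 0.5 with gamma_rst = rho reproduces SAM's equivalent penalty
    theta = grst_step(theta, p_t=0.5, gamma_rst=0.05)
```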

5. CONCLUSION

We propose a simple but efficient training scheme, called Randomized Sharpness-Aware Training (RST), to reduce the computation overhead of sharpness-aware training. In RST, the optimizer is scheduled to randomly select between a base learning algorithm and a sharpness-aware learning algorithm at each training iteration. Such a scheme can be interpreted as regularization on the gradient norm with a scaling effect. We theoretically prove that RST converges within finite training iterations. As for the scheduling functions, we empirically show that simple constant scheduling functions achieve results comparable to other scheduling functions. Finally, we extend RST to a general framework (G-RST), where the regularization effect can be adjusted freely. We show that G-RST can outperform SAM to some extent while saving 50% of the extra computation cost.

A PROOF OF THEOREM 1 & 2

A.1 PROOF OF THEOREM 1

In randomized sharpness-aware training (RST), weights θ_t are updated stochastically with a random variable X_t ∼ B(1, p_t),

θ_{t+1} = θ_t − α_t ∇L(θ_t + X_t ρ_t ∇L(θ_t)).

For β-smooth functions, we have

L(θ_1) ≤ L(θ_2) + ∇L(θ_2)^T (θ_1 − θ_2) + (β/2)||θ_1 − θ_2||².  (10)

Setting θ_1 = θ_{t+1} and θ_2 = θ_t,

L(θ_{t+1}) ≤ L(θ_t) + ⟨∇L(θ_t), θ_{t+1} − θ_t⟩ + (β/2)||θ_{t+1} − θ_t||²
           = L(θ_t) − α_t⟨∇L(θ_t), ∇L(θ_t + X_t ρ_t ∇L(θ_t))⟩ + (α_t² β/2)||∇L(θ_t + X_t ρ_t ∇L(θ_t))||².  (11)

For α_t ≤ 1/β,

L(θ_{t+1}) ≤ L(θ_t) − α_t⟨∇L(θ_t), ∇L(θ_t + X_t ρ_t ∇L(θ_t))⟩ + (α_t/2)||∇L(θ_t + X_t ρ_t ∇L(θ_t))||².  (12)

Next, adding and subtracting (α_t/2)||∇L(θ_t)||²,

L(θ_{t+1}) ≤ L(θ_t) + (α_t/2)||∇L(θ_t + X_t ρ_t ∇L(θ_t)) − ∇L(θ_t)||² − (α_t/2)||∇L(θ_t)||²
           ≤ L(θ_t) + (α_t/2)||β X_t ρ_t ∇L(θ_t)||² − (α_t/2)||∇L(θ_t)||²
           = L(θ_t) − (α_t/2)(1 − X_t² ρ_t² β²)||∇L(θ_t)||².  (13)

So, with ρ_t ≤ 1/β, the loss decreases continuously during training,

L(θ_{t+1}) ≤ L(θ_t) ≤ L(θ_{t−1}) ≤ · · · ≤ L(θ_0).  (14)

Rearranging Equation 13,

(α_t/2)(1 − X_t² ρ_t² β²)||∇L(θ_t)||² ≤ L(θ_t) − L(θ_{t+1}).  (15)

Taking the expectation over X (noting E[X_t²] = E[X_t] = p_t for a Bernoulli variable) gives

(1 − p_t)(α_t/2)||∇L(θ_t)||² + p_t (α_t/2)(1 − ρ_t² β²)||∇L(θ_t)||² ≤ E_X[L(θ_t)] − E_X[L(θ_{t+1})],

i.e.

(α_t/2)(1 − p_t ρ_t² β²)||∇L(θ_t)||² ≤ E_X[L(θ_t)] − E_X[L(θ_{t+1})].  (16)

For ρ_t = ρ/||∇L(θ_t)|| in the SAM optimization, Equation 16 becomes

(α_t/2)||∇L(θ_t)||² − (α_t p_t ρ² β²)/2 ≤ E_X[L(θ_t)] − E_X[L(θ_{t+1})],

so

(α_t/2)||∇L(θ_t)||² ≤ E_X[L(θ_t)] − E_X[L(θ_{t+1})] + (α_t p_t ρ² β²)/2.  (17)

Then, summing over the training steps,

Σ_{t∈{0,1,...,T−1}} (α_t/2)||∇L(θ_t)||² ≤ L(θ_0) − L* + Σ_{t∈{0,1,...,T−1}} (α_t p_t ρ² β²)/2.  (18)

Here, L(θ_0) is the loss of the initialized model and L* denotes the optimal value, L* = min L(θ). Since min_{t} ||∇L(θ_t)||² is bounded by the weighted average of the ||∇L(θ_t)||², we have

min_{t∈{0,1,...,T−1}} ||∇L(θ_t)||² ≤ 2(L(θ_0) − L*) / Σ_{t∈{0,1,...,T−1}} α_t + Ξ,  (19)

where

Ξ = (Σ_{t∈{0,1,...,T−1}} α_t p_t ρ² β²) / (Σ_{t∈{0,1,...,T−1}} α_t).

Generally, Equation 19 indicates that for an ε-suboptimal termination criterion ||∇L(θ_t)|| ≤ ε, hybrid training satisfies this convergence condition in finite training steps. Further, for constant learning rate schedules α_t = C/β or cosine learning rate schedules α_t = (2C/β)·(1/2 + (1/2) cos(tπ/T)), and constant scheduling functions p_t = p, we have

min_{t∈{0,1,...,T−1}} ||∇L(θ_t)||² ≤ 2β(L(θ_0) − L*)/(CT) + pρ²β².

Here, we use Σ_{t=0}^{T} cos(tπ/T) = 0, which we prove in the following lemma. In other words, ε is associated with O(1/T). For the decayed learning rate schedule α_t = C/t and constant scheduling functions p_t = p, we have

min_{t∈{0,1,...,T−1}} ||∇L(θ_t)||² ≤ 2(L(θ_0) − L*)/(C log T) + pρ²β².

In other words, ε is associated with O(1/log T).

Lemma 1. For t ∈ {0, 1, 2, ..., T}, we have Σ_{t=0}^{T} cos(tπ/T) = 0.

Proof. We use Euler's identity, e^{ix} = cos x + i sin x, so cos x = ℜ{e^{ix}} and sin x = ℑ{e^{ix}}, where ℜ{·} and ℑ{·} denote the real and imaginary parts. Then,

Σ_{t=0}^{T} cos(tπ/T) = ℜ{ Σ_{t=0}^{T} e^{i tπ/T} }
                      = ℜ{ e^0 (1 − e^{i (T+1)π/T}) / (1 − e^{i π/T}) }
                      = ℜ{ e^{i (T+1)π/(2T)} (e^{−i (T+1)π/(2T)} − e^{i (T+1)π/(2T)}) / [ e^{i π/(2T)} (e^{−i π/(2T)} − e^{i π/(2T)}) ] }
                      = ℜ{ e^{i Tπ/(2T)} · sin((T+1)π/(2T)) / sin(π/(2T)) }
                      = cos(Tπ/(2T)) · sin((T+1)π/(2T)) / sin(π/(2T))
                      = 0,  (since cos(π/2) = 0).

A.2 PROOF OF THEOREM 2

From the Polyak-Lojasiewicz condition,

(1/2)||∇L(θ_t)||² ≥ L(θ_t) − L*.

From the previous Equation 16, we have

α_t (1 − p_t ρ_t² β²)(E_X[L(θ_t)] − L*) ≤ E_X[L(θ_t)] − E_X[L(θ_{t+1})]
                                        = (E_X[L(θ_t)] − L*) − (E_X[L(θ_{t+1})] − L*),

so

(E_X[L(θ_{t+1})] − L*) / (E_X[L(θ_t)] − L*) ≤ 1 − α_t (1 − p_t ρ_t² β²).

Then, performing iterative multiplication over the training steps gives

(E_X[L(θ_T)] − L*) / (L(θ_0) − L*) ≤ Π_{t∈{0,1,...,T−1}} [1 − α_t (1 − p_t ρ_t² β²)].

This ends the proof.

B.1 TRIGONOMETRIC SCHEDULING FUNCTION

We use WideResNet28-10 to further investigate trigonometric scheduling functions p_tr(t) in RST. Here, we confine the trigonometric functions to sine and cosine functions, and more specifically focus on four scheduling functions,

p_cos1(t) = 1/2 + (1/2) cos(tπ/T)
p_cos2(t) = 1 − p_cos1(t) = 1/2 − (1/2) cos(tπ/T)
p_sin1(t) = sin(tπ/T)
p_sin2(t) = 1 − p_sin1(t) = 1 − sin(tπ/T)

Note that all these functions take values between 0 and 1. Figure 4 shows the scheduling plots of the four functions and Table 5 shows the final results. From the table, training with these trigonometric scheduling functions incurs a 50% extra expected average propagation count for the cosine functions, 2/π ≈ 64% for p_sin1, and 1 − 2/π ≈ 36% for p_sin2. For the cosine functions, their scheduling-probability patterns are quite close to linear functions, which can lead them to yield very similar performance. As for the sinusoidal functions, the probability increases (or decreases) monotonically for the first half of the iterations and then reverses for the rest. Compared to the cosine functions, when SAM is implemented with higher total frequency, the corresponding results are better. Additionally, the results also confirm that performance degrades when SGD is frequently selected near the end of training. In summary, training with such complex trigonometric scheduling functions does not present better results than training with simple constant scheduling functions, and we still recommend using simple constant scheduling functions in practical implementation.
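The expected extra propagation counts of these schedules are simply the averages of p(t) over training, which can be checked numerically (illustrative sketch):

```python
# Numerical check of the expected extra propagation counts for the
# four trigonometric schedules: cosine schedules average to 1/2,
# sin(t*pi/T) averages to 2/pi (~0.64), and its complement to 1 - 2/pi.
import numpy as np

T = 100_000
t = np.arange(T)

p_cos1 = 0.5 + 0.5 * np.cos(t / T * np.pi)
p_cos2 = 1.0 - p_cos1
p_sin1 = np.sin(t / T * np.pi)
p_sin2 = 1.0 - p_sin1

assert abs(p_cos1.mean() - 0.5) < 1e-3
assert abs(p_cos2.mean() - 0.5) < 1e-3
assert abs(p_sin1.mean() - 2 / np.pi) < 1e-3   # ~0.6366, i.e. ~64%
assert abs(p_sin2.mean() - (1 - 2 / np.pi)) < 1e-3
```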

B.2 γ rst IN G-RST

Based on the demonstrations of G-RST, we know that G-RST can adjust the regularization effect on the gradient norm freely for a given selecting probability. Therefore, we perform further tuning on the γ_gnr mixed into RST to present the relationship between model performance and the equivalent regularization degree γ_rst. Here, we perform a grid search over the selecting probability from 0.1 to 0.9 with an interval of 0.2, and then set γ_gnr (Equation 9 in the main paper) in RST to fix the equivalent regularization effect γ_rst across 0.5 to 1.5. Table 5 shows the corresponding 2D plot. From the table, we find that when the selecting probability p_t is very low, models cannot be trained to achieve good performance even if we impose a high regularization penalty. This is mainly because, based on γ_rst = γ_gnr · p_t, a very high γ_gnr must be mixed in for these low p_t to obtain a fair equivalent effect γ_rst. When γ_gnr is very high in GNR, according to Zhao et al. (2022), it causes a loss of precision in the approximations involving the Hessian multiplication. Secondly, we also find from the figure that when the equivalent regularization degree γ_rst is in the range of about 0.8 to 1, models achieve better performance than elsewhere; imposing too much regularization on the gradient norm instead harms performance. For a fixed γ_rst, increasing the selecting probability p_t somewhat improves model performance, but not in a significant manner. In summary, it is recommended to set a moderate selecting probability combined with a proper γ_gnr that leads to a γ_rst in this range. In this way, training enjoys a gain in computation efficiency and gives satisfactory performance at the same time. Table 3 in the main paper actually follows this recommendation.

B.3 EXPERIMENT RESULTS WHEN USING CUTOUT REGULARIZATION

In addition to the basic data augmentation strategy used in the previous sections, we also investigate the effect of Cutout regularization (Devries & Taylor, 2017). We choose WideResNet28-10 as the main experimental target, with the same training hyperparameters as those used in the previous sections. The tables below show the final results, with training scheduled separately by constant scheduling functions (Table 6), the first group of piecewise scheduling functions (Table 7), the first group of linear scheduling functions (Table 8), and trigonometric scheduling functions (Table 9). From the results, we reach the same conclusions as in the summary sections. In short, constant scheduling functions are a good choice for practical implementation: they are simple to implement and yield at least comparable performance to other scheduling functions. The tables show that RST again boosts computational efficiency while attaining better generalization than training with the SAM scheme.

B.5 USING RST SCHEME ON OTHER SAM VARIANTS

In this section, we further show the effectiveness of RST on SAM variants, using ASAM (Kwon et al., 2021) and GSAM (Zhuang et al., 2022) as investigation targets. For both ASAM and GSAM, we compare standard training against our RST and G-RST schemes. Based on the previous demonstrations, the selecting probability in RST and G-RST is set to a constant 0.5. For G-RST, since the essence of these SAM variants is regularizing the gradient norm, we double the regularization effect, as in the previous experiments. Table 12 shows the final results. When applying RST to ASAM and GSAM, we obtain results similar to applying RST to SAM. Specifically, since RST and G-RST randomly select between the sharpness-aware learning algorithm and the base learning algorithm, computational efficiency is largely improved for both ASAM and GSAM. As demonstrated before, RST weakens the regularization effect, so the corresponding performance is somewhat lower than standard sharpness-aware training. When doubling the regularization effect in G-RST, we obtain results comparable to standard sharpness-aware training, which again confirms the effectiveness of our method.

B.6 MIXING RST SCHEME WITH OTHER EFFICIENT SAM TECHNIQUES

In RST, the optimizer chooses between the base learning algorithm and the sharpness-aware algorithm at each iteration. When the sharpness-aware algorithm is selected, other efficient techniques can be adopted at the same time to further improve training efficiency. Here, we study the effect of mixing RST separately with LookSAM (Liu et al., 2022) and with weight masking techniques (Mi et al., 2022; Du et al., 2021b). Table 13 shows the corresponding results. For all these efficient techniques, RST further improves computational efficiency. However, if the selecting probability in RST is relatively low (0.5 in the table), the mixing may harm performance. On the other hand, by properly raising the selecting probability (0.75 in the table), it is possible to achieve results comparable to these efficient techniques alone.
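To make the compounding of the two savings concrete, the sketch below simulates which iterations actually incur the extra SAM propagation when RST's Bernoulli trial is combined with a LookSAM-style rule that refreshes the SAM perturbation only on every k-th selected SAM step. This is a simplified illustration of the scheduling interaction only (the function name and the refresh rule are assumptions, not the papers' exact implementations).

```python
import random

def rst_looksam_schedule(T, p_t, k, seed=0):
    """Return the iterations that incur an extra SAM propagation when RST
    (constant selecting probability p_t) is mixed with a LookSAM-style rule
    that refreshes the SAM gradient only every k-th selected SAM step."""
    random.seed(seed)
    extra, sam_count = [], 0
    for t in range(T):
        if random.random() < p_t:       # RST Bernoulli trial picks the SAM branch
            if sam_count % k == 0:      # LookSAM-style refresh: every k-th SAM step
                extra.append(t)
            sam_count += 1
    return extra

# In expectation, the fraction of iterations with an extra propagation is
# roughly p_t / k, i.e. the RST and LookSAM savings compound.
```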



Following Liu et al. (2022), LookSAM(5) denotes recomputing the SAM gradient every five iterations. Unlike LookSAM, ESAM and SSAM are both implemented with the git repository https://github.com/Mi-Peng/Sparse-Sharpness-Aware-Minimization, run on one A100 GPU; the SGD baseline is also obtained with this repository.



Figure 1: (A) Scheduling function plots for the constant scheduling function family. (B) Scatter plot of the extra expected propagation count η (x-axis) against the extra practical training wall time (y-axis) incurred in RST. (C) Testing error rates of ResNet18 and WideResNet-28-10 models with error bars on Cifar10 (left) and Cifar100 (right) when training with these constant scheduling functions. The markers are scaled by the training wall time.

Piecewise Function Family. Generally, the selecting probability in a piecewise function experiences a stage conversion during training. In the first stage, optimizers perform the SAM algorithm with probability a_p for the first b_p·T training iterations; in the second stage, this probability changes to 1 − a_p for the remaining iterations. In our investigation, we consider three typical groups of piecewise scheduling functions; Figure 2 shows the corresponding scheduling function plots and their final results.
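The two-stage rule just described can be written as a one-line schedule. This is a minimal sketch under the reading above (probability a_p before iteration b_p·T, then 1 − a_p); the function name is chosen here for illustration.

```python
def p_piecewise(t, T, a_p, b_p):
    """Piecewise scheduling probability for selecting the SAM algorithm:
    a_p for the first b_p * T training iterations, 1 - a_p afterwards."""
    return a_p if t < b_p * T else 1.0 - a_p

# E.g. with a_p = 0.2 and b_p = 0.5, SAM is selected with probability 0.2
# in the first half of training and 0.8 in the second half.
```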

Figure 2: Testing error rates of ResNet18 and WideResNet-28-10 models with error bars on Cifar10 (left) and Cifar100 (right) when training separately with the three groups of piecewise scheduling functions and the two groups of linear scheduling functions in RST.

For the slope m, we set it from 0.1 to 0.9 with an interval of 0.1. Clearly, in this group, the probability of selecting the SAM algorithm increases over the iterations. Also, as m increases, the SAM algorithm experiences an overall higher probability of selection. The results show that performing the SAM algorithm more often steadily improves model performance, and the performance trend in this group is quite similar to that in piecewise group 1. In fact, these two groups share very close selection patterns in general: the scheduling probability changes instantaneously in piecewise group 1 but gradually in this group.

Figure 3: Summary plot of model performance against the extra computation overhead for all the scheduling function cases.

performance will also progressively become higher, but more mildly. Likewise, these two groups have close selection patterns, analogous to the relation between piecewise group 1 and linear group 1.

Figure 4: Scheduling function plots for the four trigonometric scheduling functions. The blue points stand for the instance of random variable X.

Figure 5: 2D image plot between the selecting probability p t and the equivalent regularization degree γ rst for WideResNet28-10 when training with RST.

Algorithm 1 Randomized Sharpness-Aware Training (RST). Input: training set S = {(x_i, y_i)}_{i=0}^{N}; loss function L(·); batch size B; learning rate α; total iterations T; SAM neighborhood radius ρ; scheduling function p(t).
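A minimal, self-contained sketch of one RST iteration follows, using plain Python lists on a toy quadratic loss rather than a deep network; the Bernoulli trial, the SAM ascent step of radius ρ, and the single-propagation SGD branch follow Algorithm 1, while the function names and the toy objective are illustrative assumptions.

```python
import random

def rst_step(w, grad_fn, lr, rho, p_t):
    """One RST iteration on weight vector w (list of floats).
    grad_fn(w) returns the loss gradient at w; p_t is the scheduled
    probability of taking a SAM step instead of a plain SGD step."""
    g = grad_fn(w)                          # first propagation (always needed)
    if random.random() < p_t:               # Bernoulli trial: SAM branch
        norm = sum(x * x for x in g) ** 0.5 or 1.0
        w_adv = [wi + rho * gi / norm for wi, gi in zip(w, g)]  # ascent step
        g = grad_fn(w_adv)                  # second propagation, SAM branch only
    return [wi - lr * gi for wi, gi in zip(w, g)]

# Toy usage: minimize f(w) = sum(w_i^2), whose gradient is 2w.
random.seed(0)
w = [1.0, -2.0]
for t in range(200):
    w = rst_step(w, lambda v: [2.0 * x for x in v], lr=0.1, rho=0.05, p_t=0.5)
```

With p_t = 0 every iteration reduces to plain SGD (one propagation); with p_t = 1 it reduces to SAM (two propagations), so the expected extra propagation count per iteration is exactly p_t.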

Here, p_t denotes the scheduling probability p(t) at training iteration t. Equation 2 indicates that RST incurs p_t extra propagation counts in expectation compared with vanilla SGD training. Further, the average extra expected propagation count ∆η over the total T training iterations is

∆η = (1/T) Σ_{t=1}^{T} p_t.

Testing error rate of ResNet18 and WideResNet28-10 models on Cifar10 and Cifar100 datasets when training with the SGD scheme and SAM scheme respectively.

Scheduling functions p(t) and extra propagation counts ∆η of the three function families.

Testing error rate of CNN models and ViT models on Cifar10 and Cifar100 datasets when training with SGD, SAM and the G-RST where p(t) = 0.5.

Testing error rate of CNN models on ImageNet datasets when training with SGD, SAM and the G-RST where p(t) = 0.5.

Testing error rate of WideResNet28-10 models on Cifar10 and Cifar100 datasets when training with the four trigonometric scheduling functions.

Testing error rate of WideResNet28-10 models on Cifar10 and Cifar100 datasets with Cutout regularization when training with constant scheduling functions.

Testing error rate of WideResNet28-10 models on Cifar10 and Cifar100 datasets with Cutout regularization when training with the first group of piecewise scheduling functions.

Testing error rate of WideResNet28-10 models on Cifar10 and Cifar100 datasets with Cutout regularization when training with the first group of linear scheduling functions.

Testing error rate of WideResNet28-10 models on Cifar10 and Cifar100 datasets with Cutout regularization when training with trigonometric scheduling functions. From the previous results, we see that constant scheduling functions already provide representative results, so here we only investigate training with constant scheduling functions to make comparisons with the baselines.

Testing error rate of VGG16-BN models on Cifar10 and Cifar100 datasets when training with constant scheduling functions.

Testing error rate of ViT-S16 models on Cifar10 and Cifar100 datasets when training with constant scheduling functions. Note that the hyperparameters are different from those in the previous section: here we train for only 300 epochs without mixup augmentation.


C TRAINING DETAILS

Code is available at github.com/JustNobody0204/Submission-ICLR2023. The basic training hyperparameters are listed below:

