RANDOMIZED SHARPNESS-AWARE TRAINING FOR BOOSTING COMPUTATIONAL EFFICIENCY IN DEEP LEARNING

Abstract

By driving optimizers to converge to flat minima, sharpness-aware learning algorithms (such as SAM) have shown the power to achieve state-of-the-art performance. However, these algorithms generally incur one extra forward-backward propagation at each training iteration, which largely burdens the computation, especially for large-scale models. To this end, we propose an efficient training scheme, called Randomized Sharpness-Aware Training (RST). In RST, the optimizer performs a Bernoulli trial at each iteration to choose randomly between a base algorithm (SGD) and a sharpness-aware algorithm (SAM), with a probability arranged by a predefined scheduling function. Due to the mixture of base algorithm steps, the overall count of propagation pairs can be largely reduced. We also give a theoretical analysis of the convergence of RST. We then empirically study the computation cost and effect of various types of scheduling functions, and give guidance on setting appropriate scheduling functions. Further, we extend RST to a general framework (G-RST), in which the degree of regularization on sharpness can be adjusted freely for any scheduling function. We show that G-RST outperforms SAM in most cases while saving 50% of the extra computation cost.

1. INTRODUCTION

Deep neural networks (DNNs) have shown great capability in solving many real-world complex tasks (He et al., 2016; Redmon et al., 2016; Devlin et al., 2018). However, it is quite challenging to train them efficiently to good performance, especially for today's severely over-parameterized networks (Dosovitskiy et al., 2021; Han et al., 2017). Although such numerous parameters improve the expressiveness of DNNs, they may complicate the geometry of the loss surface and generate more global and local minima in the huge hypothesis space. Leveraging the finding that flat minima tend to exhibit better generalization ability, Foret et al. (2021) propose a sharpness-aware learning method called SAM, where the loss geometry is connected to the optimization to guide optimizers toward flat minima. Training with SAM has been shown to significantly improve model performance on various tasks (Foret et al., 2021; Chen et al., 2021). On the other hand, the computation cost of SAM is almost twice that of vanilla stochastic gradient descent (SGD), since it incurs one additional forward-backward propagation at each training iteration, which largely burdens the computation in practice. Recently, techniques have been introduced to improve the computational efficiency of SAM. Specifically, instead of using the full batch of samples, Bahri et al. (2021) and Du et al. (2021a) select only part of the batch samples to approximate the two forward-backward propagations. Although the computation cost can be reduced to some extent, the forward-backward propagation count of the SAM training scheme does not change essentially. Mi et al. (2022) randomly mask out part of the weights during optimization, reducing the expected amount of gradient computation at each iteration. However, the efficiency improvement of such a method is strongly limited by the chain rule of gradient computation (Du et al., 2021a).
Besides, Liu et al. (2022) propose to reuse components of past SAM descent gradients to reduce the incurred computational overhead. Meanwhile, random selection is a powerful technique for boosting optimization efficiency, particularly in the field of gradient boosting (Friedman, 2001), where a small set of learners in gradient boosting machines is selected randomly to be optimized under a certain rule (Lu & Mazumder, 2020; Konstantinov et al., 2021). Inspired by this randomization scheme in gradient boosting, we present a simple but efficient training scheme, called Randomized Sharpness-Aware Training (RST). In RST, the learning process is randomized: at each training iteration, the optimizer randomly selects between a base learning algorithm and a sharpness-aware learning algorithm with a given probability. This selection probability is arranged by a custom scheduling function predefined before training. The scheduling function not only controls how much the propagation count is reduced, but also impacts model performance. Our contributions can be summarized as follows:

1. We propose a simple but efficient training scheme, called RST, which reduces the propagation count by randomly mixing a base learning algorithm (SGD) with a sharpness-aware learning algorithm (SAM).
2. We interpret our RST scheme from the perspective of gradient norm regularization (GNR) (Zhao et al., 2022), and theoretically prove the convergence of the RST scheme.
3. We empirically study the effect of arranging different scheduling functions, covering three typical types of function families with six function groups in total.
4. We extend RST to a general framework (G-RST), in which a GNR algorithm is mixed in so that the degree of regularization on the gradient norm can be adjusted freely.
By training both CNN and ViT models on commonly-used datasets, we show that G-RST outperforms SAM in most cases while saving at least 50% of the extra computation cost.
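To make the role of the scheduling function concrete, the sketch below shows a few illustrative choices of p(t). The specific function families (constant, linear, cosine) and the horizon T here are our own assumptions for illustration, not necessarily the six groups studied in the experiments; what matters is that the average of p(t) over training sets the expected extra propagation cost.

```python
import math

# Illustrative scheduling functions p(t) for a run of T iterations.
# Each returns the probability of taking the SAM step at iteration t.
T = 100

def constant(t, c=0.5):
    # Flat schedule: a fixed fraction c of iterations use SAM.
    return c

def linear_increase(t):
    # Favor cheap SGD steps early and sharpness-aware steps late.
    return t / (T - 1)

def cosine_increase(t):
    # Smooth ramp from 0 to 1 over the whole run.
    return 0.5 * (1.0 - math.cos(math.pi * t / (T - 1)))

# All three schedules have a mean p of 0.5 over the run, so in expectation
# they all save half of SAM's extra forward-backward passes.
avg = sum(linear_increase(t) for t in range(T)) / T
```

Schedules with the same average cost can still behave very differently in terms of final accuracy, which is why the choice of family matters beyond the expected propagation count.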

1.1. OTHER RELATED WORKS

We would like to discuss works associated with research on flat minima. Hochreiter & Schmidhuber (1997) are the first to point out that the flatness of minima could be associated with model generalization, i.e., that models with better generalization should converge to flat minima. Such a claim has been supported extensively by both empirical evidence and theoretical analysis (Keskar et al., 2017; Dinh et al., 2017). In the meantime, researchers have also been fascinated by how to design practical algorithms that force models to converge to such flat minima. One line of work proposes to minimize the KL-divergence between the output distributions yielded by the current model and a moving average of past models, similar in spirit to knowledge distillation.

2. METHOD

2.1. RANDOMIZED SHARPNESS-AWARE TRAINING (RST)

The general idea of RST follows a randomization scheme, where the learning process itself is randomized. Specifically, at each training iteration t, the optimizer performs a Bernoulli trial to choose between a base learning algorithm and a sharpness-aware learning algorithm. Here, we first consider mixing the two most commonly-used algorithms, SGD and SAM. Thus, in each Bernoulli trial, the optimizer performs the SAM algorithm with probability p(t) or the SGD algorithm with probability 1 - p(t). Here, p(t) can be a predefined custom function of the iteration t, which we call the scheduling function of RST. Apparently, the sample space for this Bernoulli trial corresponds to the set Ω = {SGD, SAM}. Correspondingly, a random variable can be defined on this sample space, X(t) : Ω → {0, 1}, where X(t) = 0 denotes performing the SGD algorithm while X(t) = 1 denotes performing the SAM algorithm. In summary, X(t) ∼ Bernoulli(p(t)), and the training trajectory is

θ_0 --X(1)--> θ_1 --X(2)--> ⋯ --X(t)--> θ_t --X(t+1)--> θ_{t+1},  X(t) ∈ {0, 1}.
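The per-iteration trial above can be sketched as a minimal training loop. This is our own toy illustration on a quadratic loss, not the paper's implementation: the loss, learning rate, perturbation radius rho, and step counts are all placeholder choices, and a real training run would use mini-batch gradients of a network.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy quadratic loss L(w) = 0.5 * ||w||^2 with gradient w, standing in
# for a mini-batch loss; lr and rho are illustrative values only.
def grad(w):
    return w

def sgd_step(w, lr=0.1):
    # Base algorithm: one forward-backward pass.
    return w - lr * grad(w)

def sam_step(w, lr=0.1, rho=0.05):
    # Sharpness-aware algorithm: two forward-backward passes.
    g = eps_dir = grad(w)
    eps = rho * eps_dir / (np.linalg.norm(g) + 1e-12)  # ascend to the worst-case neighbor
    return w - lr * grad(w + eps)                      # descend with the perturbed gradient

def rst_train(w, steps, p):
    """Run RST for `steps` iterations; p(t) is the scheduling function."""
    n_extra = 0  # extra propagation pairs incurred by SAM steps
    for t in range(steps):
        if rng.random() < p(t):   # Bernoulli trial, X(t) ~ Bernoulli(p(t))
            w = sam_step(w)
            n_extra += 1
        else:
            w = sgd_step(w)
    return w, n_extra

# A constant schedule p(t) = 0.5 halves the expected extra propagations vs. SAM.
w, n_extra = rst_train(np.ones(4), steps=200, p=lambda t: 0.5)
```

With a constant schedule p(t) = c, the expected number of extra forward-backward pairs over T iterations is simply cT, which makes the cost accounting in later sections immediate.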



By formulating this problem as a specific minimax optimization, Foret et al. (2021) introduce the SAM training scheme, which successfully guides optimizers to converge to flat minima. Further, Zheng et al. (2021) perform gradient descent twice to solve the inner maximization and the outer minimization of this minimax optimization, respectively. Kwon et al. (2021) propose the Adaptive SAM training scheme, which improves SAM to remain stable under weight rescaling operations. Zhao et al. (2022) seek flat minima by explicitly penalizing the gradient norm of the loss function. Unlike SAM-related training schemes, Du et al. (2022) seek flat minima without a restriction to a neighborhood region.
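The link between the SAM minimax objective and gradient norm penalization can be made explicit with a first-order Taylor expansion; the following is our own summary of this standard argument, not an equation reproduced from the cited works.

```latex
\max_{\|\epsilon\|_2 \le \rho} L(w+\epsilon)
\;\approx\; \max_{\|\epsilon\|_2 \le \rho}
\left[\, L(w) + \epsilon^{\top} \nabla L(w) \,\right]
\;=\; L(w) + \rho \,\|\nabla L(w)\|_2 ,
\qquad
\epsilon^{*} \;=\; \rho \,\frac{\nabla L(w)}{\|\nabla L(w)\|_2}.
```

That is, to first order, minimizing the SAM objective amounts to minimizing the loss plus a penalty of ρ times its gradient norm, which is why mixing in a GNR algorithm lets the regularization degree be tuned freely in G-RST.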

