RANDOMIZED SHARPNESS-AWARE TRAINING FOR BOOSTING COMPUTATIONAL EFFICIENCY IN DEEP LEARNING

Abstract

By driving optimizers to converge to flat minima, sharpness-aware learning algorithms such as SAM have been shown to achieve state-of-the-art performance. However, these algorithms generally incur one extra forward-backward propagation at each training iteration, which substantially burdens the computation, especially for large-scale models. To this end, we propose an efficient training scheme called Randomized Sharpness-Aware Training (RST). At each iteration, the optimizer in RST performs a Bernoulli trial to choose randomly between a base algorithm (SGD) and a sharpness-aware algorithm (SAM), with a probability determined by a predefined scheduling function. Because base-algorithm steps are mixed in, the overall number of propagation pairs can be largely reduced. We also provide a theoretical analysis of the convergence of RST. We then empirically study the computation cost and effect of various types of scheduling functions, and give guidance on setting appropriate scheduling functions. Further, we extend RST to a general framework (G-RST), in which the degree of regularization on sharpness can be adjusted freely for any scheduling function. We show that G-RST can outperform SAM in most cases while saving 50% of the extra computation cost.
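The Bernoulli-trial selection described above can be sketched in a few lines; the toy quadratic loss, the `linear_schedule` function, and all hyperparameter values below are illustrative assumptions, not the paper's actual configuration.

```python
import random

# Toy 1-D loss (w - 1)^2 with an analytic gradient, standing in for a network loss.
def grad(w):
    return 2.0 * (w - 1.0)

def sgd_step(w, lr):
    return w - lr * grad(w)                     # one forward-backward pass

def sam_step(w, lr, rho=0.05):
    g = grad(w)                                 # pass 1: gradient at w
    w_adv = w + rho * g / (abs(g) + 1e-12)      # ascend to a nearby sharper point
    return w - lr * grad(w_adv)                 # pass 2: descend with perturbed gradient

def linear_schedule(t, total, p_end=0.5):
    # Hypothetical scheduling function: SAM probability grows linearly to p_end.
    return p_end * t / total

def rst_train(w0, steps=200, lr=0.1):
    w, sam_steps = w0, 0
    for t in range(steps):
        if random.random() < linear_schedule(t, steps):   # Bernoulli trial
            w, sam_steps = sam_step(w, lr), sam_steps + 1
        else:
            w = sgd_step(w, lr)
    return w, sam_steps
```

Under this linear schedule the expected per-iteration cost is 1 + E[p_t] ≈ 1.25 propagation pairs, versus 2 for applying SAM at every step.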

1. INTRODUCTION

Deep neural networks (DNNs) have shown great capabilities in solving many complex real-world tasks (He et al., 2016; Redmon et al., 2016; Devlin et al., 2018). However, it is quite challenging to train them efficiently to good performance, especially for today's severely over-parameterized networks (Dosovitskiy et al., 2021; Han et al., 2017). Although such a large number of parameters can improve the expressiveness of DNNs, it may also complicate the geometry of the loss surface and produce many more global and local minima in this huge weight space. Leveraging the finding that flat minima tend to exhibit better generalization, Foret et al. (2021) propose a sharpness-aware learning method called SAM, which connects the loss geometry to the optimization so as to guide optimizers toward flat minima. Training with SAM has been shown to significantly improve model performance on various tasks (Foret et al., 2021; Chen et al., 2021). On the other hand, the computation cost of SAM is almost twice that of vanilla stochastic gradient descent (SGD), since it incurs one additional forward-backward propagation at each training iteration, which substantially burdens the computation in practice.

Recently, several techniques have been introduced to improve the computational efficiency of SAM. Specifically, instead of using the full batch of samples, Bahri et al. (2021) and Du et al. (2021a) select only part of the batch to approximate the two forward-backward propagations. Although the computation cost can be reduced to some extent, the number of forward-backward propagations in the SAM training scheme remains essentially unchanged. Mi et al. (2022) randomly mask out part of the weights during optimization to reduce the expected amount of gradient computation at each iteration. However, the efficiency improvement of such a method is strongly limited by the chain rule of gradient computation (Du et al., 2021a).
Besides, Liu et al. (2022) propose to repeatedly reuse the components of past SAM gradients that are orthogonal to the descent direction, reducing the incurred computational overhead. Meanwhile, random selection is a powerful technique for boosting optimization efficiency, particularly in the field of gradient boosting (Friedman, 2001), where a small set of learners in

