SALR: SHARPNESS-AWARE LEARNING RATES FOR IMPROVED GENERALIZATION

Abstract

In an effort to improve generalization in deep learning, we propose SALR: a sharpness-aware learning rate update technique designed to recover flat minimizers. Our method dynamically updates the learning rate of gradient-based optimizers based on the local sharpness of the loss function. This allows optimizers to automatically increase learning rates at sharp valleys to increase the chance of escaping them. We demonstrate the effectiveness of SALR when adopted by various algorithms over a broad range of networks. Our experiments indicate that SALR improves generalization, converges faster, and drives solutions to significantly flatter regions.

1. INTRODUCTION

Figure 1: A conceptual sketch of flat and sharp minima (Keskar et al., 2017).

Generalization in deep learning has recently been an active area of research. Efforts to improve generalization over the past two decades have brought about many cornerstone advances and techniques: dropout (Gal & Ghahramani, 2016), batch normalization (Ioffe & Szegedy, 2015), data augmentation (Shorten & Khoshgoftaar, 2019), weight decay (Loshchilov & Hutter, 2019), adaptive gradient-based optimization (Kingma & Ba, 2015), architecture design and search (Radosavovic et al., 2020), and ensembles and their Bayesian counterparts (Garipov et al., 2018; Izmailov et al., 2018), amongst many others. More recently, researchers have discovered that the concept of sharpness/flatness plays a fundamental role in generalization. Though sharpness was first discussed in the context of neural networks in the early work of Hochreiter & Schmidhuber (1997), it was only brought to the forefront of deep learning research by the seminal paper of Keskar et al. (2017). While investigating the decreased generalization performance observed when large batch sizes are used in stochastic gradient descent (SGD) (LeCun et al., 2012), Keskar et al. (2017) noticed that this phenomenon can be explained by the ability of smaller batches to reach flat minimizers. Such flat minimizers, in turn, generalize well because they are robust to low-precision arithmetic or noise in the parameter space (Dinh et al., 2017; Kleinberg et al., 2018), as shown in Figure 1. Since then, the generalization ability of flat minimizers has been repeatedly observed in many recent works (Neyshabur et al., 2017a; Goyal et al., 2017; Li et al., 2018; Izmailov et al., 2018). Indeed, flat minimizers can potentially tie together many of the aforementioned approaches aimed at improving generalization.
For instance: (1) higher gradient variance when batches are small increases the probability of avoiding sharp regions (the same can be said for SGD compared to GD) (Kleinberg et al., 2018); (2) averaging over multiple hypotheses leads to wider optima in ensembles and Bayesian deep learning (Izmailov et al., 2018); (3) regularization techniques such as dropout, as well as over-parameterization, can reshape the loss landscape into one that allows first-order methods to favor wide valleys (Chaudhari et al., 2019; Allen-Zhu et al., 2019). In this paper we study the direct problem of developing an algorithm that converges to flat minimizers. Specifically, we introduce SALR: a sharpness-aware learning rate designed to explore the loss surface of an objective function and avoid undesired sharp local minima. SALR dynamically updates the learning rate based on the sharpness of the neighborhood of the current solution. The idea is simple: automatically increase the learning rate at relatively sharp valleys in an effort to escape them. One of the key features of SALR is that it can be fitted into any gradient-based method, such as Adagrad (Duchi et al., 2011) or Adam (Kingma & Ba, 2015), and also into recent approaches for escaping sharp valleys such as Entropy-SGD (Chaudhari et al., 2019).
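To make the idea concrete, the following toy sketch contrasts a sharpness-scaled learning rate with a fixed one on a one-dimensional loss that has a sharp valley and a flat valley. The scaling rule, loss, and all function names here are illustrative assumptions, not the paper's actual algorithm: sharpness is taken to be the local second derivative, and the learning rate is scaled proportionally to it.

```python
def sharp_branch(w):
    # Toy loss: min(50*w^2, (w - 4)^2 + 0.5), i.e. a sharp valley at w = 0
    # (curvature 100) and a flat valley at w = 4 (curvature 2).
    return 50 * w**2 < (w - 4)**2 + 0.5

def grad(w):
    return 100 * w if sharp_branch(w) else 2 * (w - 4)

def curvature(w):
    # Local second derivative, used as the sharpness measure in this sketch.
    return 100.0 if sharp_branch(w) else 2.0

def salr_gd(w, base_lr, ref_sharpness, steps=200):
    # Hypothetical sharpness-aware schedule: scale the base learning rate by
    # local sharpness relative to a reference, so that steps in sharp regions
    # become large enough to be unstable there (and hence escapable).
    for _ in range(steps):
        lr = base_lr * curvature(w) / ref_sharpness
        w = w - lr * grad(w)
    return w

w_salr = salr_gd(0.1, base_lr=0.015, ref_sharpness=2.0)  # escapes to the flat valley near 4

w_fixed = 0.1
for _ in range(200):
    w_fixed = w_fixed - 0.015 * grad(w_fixed)            # trapped in the sharp valley at 0
```

With the fixed rate 0.015, the sharp-valley contraction factor is |1 - 0.015 * 100| = 0.5 < 1, so plain gradient descent settles into the sharp minimum; the scaled rate makes that same region unstable while leaving the flat valley contractive.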

1.1. RELATED WORK

From a theoretical perspective, the generalization of deep learning solutions has been explained through multiple lenses. One of these is uniform stability (Bottou & Le Cun, 2005; Bottou & Bousquet, 2008; Hardt et al., 2016; Gonen & Shalev-Shwartz, 2017; Bottou et al., 2018). An algorithm is uniformly stable if, for all data sets differing in only one element, nearly the same outputs are produced (Bousquet & Elisseeff, 2002). Hardt et al. (2016) show that SGD satisfies this property and derive a generalization bound for models learned with SGD. Besides that, the importance of flatness for generalization has been theoretically highlighted through PAC-Bayes bounds (Dziugaite & Roy, 2017; Neyshabur et al., 2017b; Wang et al., 2018). These papers highlight the ability to derive non-vacuous generalization bounds based on the sharpness of a model class, while arguing that relatively flat solutions yield tight bounds. From an algorithmic perspective, approaches to recover flat minima are still limited. Most notably, Chaudhari et al. (2019) developed the Entropy-SGD algorithm. Entropy-SGD defines a local-entropy-based objective which smooths the energy landscape based on its local geometry, in turn allowing SGD to attain flatter solutions. Indeed, this approach was motivated by earlier work in statistical physics (Baldassi et al., 2015; 2016) which proves the existence of non-isolated solutions that generalize well in networks with discrete weights; such non-isolated solutions correspond to flat minima in continuous settings. The authors then propose a set of approaches based on ensembles and replicas of the loss to favor wide solutions. Relatedly, recent methods in Bayesian deep learning (BDL) have also shown potential to recover flat minima.
BDL averages over multiple hypotheses weighted by their posterior probabilities (ensembles being a special case of BDL (Izmailov et al., 2018)). One example is the stochastic weight averaging (SWA) algorithm proposed by Izmailov et al. (2018). SWA simply averages over multiple points along the trajectory of SGD to find potentially flatter solutions than those of SGD. Another example is SWA-Gaussian (SWAG), which defines a Gaussian posterior approximation over neural network weights; samples are then drawn from the approximated distribution to perform Bayesian model averaging (Maddox et al., 2019). Here we also note the recent work by Patel (2017), which partially motivates our method. Building on the aforementioned observations of Keskar et al. (2017), Patel (2017) shows that the learning rate threshold above which batch SGD diverges, when run on quadratic optimization problems, increases for larger batch sizes. In general non-convex settings, given a problem with N local minimizers, one can compute N lower-bound thresholds for local divergence of batch SGD. The number of minimizers to which batch SGD can converge is non-decreasing in the batch size. This explains the tendency of small-batch SGD to converge to flatter minimizers than large-batch SGD. The former result links the choice of batch size, and its effect on generalization, to the choice of the learning rate. With the learning rate being a tunable parameter, developing a dynamic choice of it that targets convergence to flat minimizers has, to our knowledge, not been studied before.
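The divergence threshold underlying Patel's argument can be seen on a single quadratic. For gradient descent on f(w) = (c/2) w^2 with curvature c, each step multiplies w by (1 - lr*c), so the iterates diverge exactly when lr > 2/c: a learning rate between the thresholds of a sharp and a flat quadratic escapes the former but converges on the latter. The function below is a minimal illustration of this standard fact, not code from the cited works.

```python
def gd_quadratic(curvature, lr, w0=1.0, steps=100):
    # Gradient descent on f(w) = curvature/2 * w^2; the gradient is curvature * w.
    # Each step multiplies w by (1 - lr * curvature), so the iterates diverge
    # exactly when |1 - lr * curvature| > 1, i.e. when lr > 2 / curvature.
    w = w0
    for _ in range(steps):
        w = w - lr * curvature * w
    return w

lr = 0.05                            # between 2/100 = 0.02 and 2/2 = 1.0
w_sharp = gd_quadratic(100.0, lr)    # lr > 2/100: diverges from the sharp minimum
w_flat = gd_quadratic(2.0, lr)       # lr < 2/2:   converges to the flat minimum
```

This is the sense in which the learning rate selects which minimizers remain reachable: raising it above a minimum's local threshold removes that minimum from the set of possible limits.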

2. GENERAL FRAMEWORK

In this paper, we propose a framework that dynamically chooses a Sharpness-Aware Learning Rate to promote convergence to flat minimizers. More specifically, our proposed method locally approximates sharpness at the current iterate and dynamically adjusts the learning rate accordingly. In sharp regions, relatively large learning rates are attained to increase the chance of escaping that region. In contrast, when the current iterate belongs to a flat region, our method returns a relatively small learning rate.
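The precise sharpness measure used by SALR is specified later; as an illustration of how sharpness can be approximated locally from gradients alone, one common proxy is the largest Hessian eigenvalue, estimated by power iteration on finite-difference Hessian-vector products. The function below is a generic sketch of that proxy (all names are ours, not the paper's).

```python
import numpy as np

def sharpness_estimate(w, grad_fn, iters=20, eps=1e-4, seed=0):
    # Hypothetical local sharpness proxy: the magnitude of the dominant Hessian
    # eigenvalue at w, estimated by power iteration using finite-difference
    # Hessian-vector products  Hv ~ (grad(w + eps*v) - grad(w)) / eps.
    rng = np.random.default_rng(seed)
    v = rng.standard_normal(w.shape)
    v /= np.linalg.norm(v)
    g0 = grad_fn(w)
    lam = 0.0
    for _ in range(iters):
        hv = (grad_fn(w + eps * v) - g0) / eps
        lam = float(np.linalg.norm(hv))   # current eigenvalue-magnitude estimate
        v = hv / (lam + 1e-12)            # re-normalize the direction
    return lam

# Toy quadratic with curvatures 100 and 2 along the two axes:
grad = lambda w: np.array([100.0 * w[0], 2.0 * w[1]])
lam_max = sharpness_estimate(np.zeros(2), grad)  # recovers the dominant curvature, 100
```

Such an estimate only requires extra gradient evaluations, which is what makes sharpness cheap enough to query at every iterate.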



From a different viewpoint, Choromanska et al. (2015), Kawaguchi (2016), Poggio et al. (2017), and Mohri et al. (2018) attribute generalization to the complexity of the hypothesis space. Using measures such as Rademacher complexity (Mohri & Rostamizadeh, 2009) and the Vapnik-Chervonenkis (VC) dimension (Sontag, 1998), these works show that deep hypothesis spaces are typically more advantageous in representing complex functions.

