SALR: SHARPNESS-AWARE LEARNING RATES FOR IMPROVED GENERALIZATION

Abstract

In an effort to improve generalization in deep learning, we propose SALR: a sharpness-aware learning rate update technique designed to recover flat minimizers. Our method dynamically updates the learning rate of gradient-based optimizers based on the local sharpness of the loss function. This allows optimizers to automatically raise learning rates in sharp valleys, increasing the chance of escaping them. We demonstrate the effectiveness of SALR when adopted by various algorithms over a broad range of networks. Our experiments indicate that SALR improves generalization, converges faster, and drives solutions to significantly flatter regions.
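The abstract describes scaling the learning rate by the local sharpness of the loss. As a minimal illustrative sketch (not the paper's actual algorithm), one can proxy local sharpness by the gradient norm and scale a base learning rate by the ratio of the current sharpness estimate to its running average, so that sharper regions receive larger steps; the class name, the proxy, and all hyperparameters below are assumptions for illustration:

```python
import numpy as np

def sharpness_proxy(grad):
    # Cheap local-sharpness proxy via the gradient norm (an assumption;
    # the paper may use a different sharpness estimator).
    return np.linalg.norm(grad)

class SALRSchedule:
    """Hypothetical sketch of a sharpness-aware learning rate: the base
    learning rate is scaled by current sharpness relative to its
    exponential moving average, enlarging steps in sharp regions."""

    def __init__(self, base_lr=0.1, beta=0.9):
        self.base_lr = base_lr
        self.beta = beta          # smoothing factor for the running average
        self.avg_sharpness = None

    def step(self, grad):
        s = sharpness_proxy(grad)
        if self.avg_sharpness is None:
            self.avg_sharpness = s
        else:
            self.avg_sharpness = (self.beta * self.avg_sharpness
                                  + (1.0 - self.beta) * s)
        # Sharper-than-average regions get a learning rate above base_lr.
        return self.base_lr * s / (self.avg_sharpness + 1e-12)

# Toy quadratic f(w) = 0.5 * sum(a * w**2); the steeper coordinate
# contributes more to the sharpness proxy, inflating the step size.
sched = SALRSchedule(base_lr=0.1)
w = np.array([1.0, 1.0])
a = np.array([10.0, 1.0])
for _ in range(5):
    grad = a * w
    lr = sched.step(grad)
    w = w - lr * grad
```

On the first step the running average equals the current estimate, so the schedule returns the base learning rate; subsequent steps raise or lower it as sharpness deviates from its average.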

1. INTRODUCTION

Figure 1: A conceptual sketch of flat and sharp minima (Keskar et al., 2017).

Generalization in deep learning has recently been an active area of research. Efforts to improve generalization over the past two decades have produced many cornerstone advances and techniques: dropout (Gal & Ghahramani, 2016), batch normalization (Ioffe & Szegedy, 2015), data augmentation (Shorten & Khoshgoftaar, 2019), weight decay (Loshchilov & Hutter, 2019), adaptive gradient-based optimization (Kingma & Ba, 2015), architecture design and search (Radosavovic et al., 2020), and ensembles and their Bayesian counterparts (Garipov et al., 2018; Izmailov et al., 2018), among many others. Recently, however, researchers have discovered that the concept of sharpness/flatness plays a fundamental role in generalization. Though sharpness was first discussed in the context of neural networks in the early work of Hochreiter & Schmidhuber (1997), it was brought to the forefront of deep learning research only after the seminal paper by Keskar et al. (2017). While investigating the decreased generalization performance of stochastic gradient descent (SGD) with large batch sizes (LeCun et al., 2012), Keskar et al. (2017) noticed that this phenomenon can be explained by the ability of smaller batches to reach flat minimizers. Such flat minimizers in turn generalize well, as they are robust to low-precision arithmetic or noise in the parameter space (Dinh et al., 2017; Kleinberg et al., 2018), as shown in Figure 1.

Since then, the generalization ability of flat minimizers has been repeatedly observed in many recent works (Neyshabur et al., 2017a; Goyal et al., 2017; Li et al., 2018; Izmailov et al., 2018). Indeed, flat minimizers can potentially tie together many of the aforementioned approaches to generalization. For instance: (1) higher gradient variance, when batches are small, increases the probability of avoiding sharp regions (the same can be said of SGD compared to GD) (Kleinberg et al., 2018); (2) averaging over multiple hypotheses leads to wider optima in ensembles and Bayesian deep learning (Izmailov et al., 2018); (3) regularization techniques such as dropout or over-parameterization can reshape the loss landscape into one in which first-order methods favor wide valleys (Chaudhari et al., 2019; Allen-Zhu et al., 2019). In this paper we study the direct problem of developing an algorithm that converges to flat minimizers. Specifically, we introduce SALR: a sharpness-aware learning rate designed to explore the loss landscape in search of flat minima.

