WIDE-MINIMA DENSITY HYPOTHESIS AND THE EXPLORE-EXPLOIT LEARNING RATE SCHEDULE

Anonymous

Abstract

Several papers argue that wide minima generalize better than narrow minima. In this paper, through detailed experiments, we not only corroborate the generalization properties of wide minima, but also provide empirical evidence for a new hypothesis: the density of wide minima is likely lower than the density of narrow minima. Further, motivated by this hypothesis, we design a novel explore-exploit learning rate schedule. On a variety of image and natural language datasets, compared to their original hand-tuned learning rate baselines, we show that our explore-exploit schedule can achieve either up to 0.84% higher absolute accuracy within the original training budget, or up to 57% reduced training time while matching the originally reported accuracy. For example, we achieve state-of-the-art (SOTA) accuracy on the IWSLT'14 (DE-EN) and WMT'14 (DE-EN) datasets by simply modifying the learning rate schedule of an already high-performing model.

1. INTRODUCTION

One of the fascinating properties of deep neural networks (DNNs) is their ability to generalize well, i.e., to deliver high accuracy on an unseen test dataset. It is well known that the learning rate (LR) schedule plays an important role in generalization performance (Keskar et al., 2016; Wu et al., 2018; Goyal et al., 2017). In this paper, we study the question: what are the key properties of a learning rate schedule that help DNNs generalize well during training?

We start with a series of experiments training Resnet18 on Cifar-10 over 200 epochs. We vary the number of epochs trained at a high LR of 0.1, called the explore epochs, from 0 to 100, and divide the remaining epochs equally between training at LRs of 0.01 and 0.001. Note that the training loss typically stagnates around 50 epochs at LR 0.1. Despite that, we find that as the number of explore epochs increases to 100, the average test accuracy also increases. We also find that the minima found in higher test-accuracy runs are wider than the minima from lower test-accuracy runs, corroborating past work on wide minima and generalization (Keskar et al., 2016; Hochreiter & Schmidhuber, 1997; Jastrzebski et al., 2017; Wang et al., 2018). Particularly surprisingly, even when using fewer explore epochs, a few runs out of many trials still resulted in high test accuracies! Thus, we find not only that an initial exploration phase with a high learning rate is essential for the good generalization of DNNs, but also that this exploration phase needs to run for sufficient time, even if the training loss stagnates much earlier. Further, even when the exploration phase is not given sufficient time, a few runs still reach high test accuracy. To explain these observations, we hypothesize that, in the DNN loss landscape, the density of narrow minima is significantly higher than that of wide minima.
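The schedules used in these experiments can be sketched as a simple function of the epoch index. The helper below is our own illustration (the name `step_lr` and its signature are hypothetical; the paper provides no code):

```python
def step_lr(epoch, explore_epochs, total_epochs=200, lrs=(0.1, 0.01, 0.001)):
    """Piecewise-constant LR schedule from the Cifar-10/Resnet18 experiments:
    a high 'explore' LR for the first explore_epochs, with the remaining
    epochs split equally between the two lower LRs."""
    if epoch < explore_epochs:
        return lrs[0]  # explore phase: high LR
    half = (total_epochs - explore_epochs) // 2
    if epoch < explore_epochs + half:
        return lrs[1]  # first decay step
    return lrs[2]      # final decay step
```

For example, the 100-explore-epoch run corresponds to `step_lr(epoch, 100)`: LR 0.1 for epochs 0-99, 0.01 for epochs 100-149, and 0.001 for epochs 150-199.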
A large learning rate can escape narrow minima easily, as the optimizer can jump out of them with large steps. However, once it reaches a wide minimum, it is likely to get stuck there (if the "width" of the minimum is large compared to the step size). With fewer explore epochs, a large learning rate might occasionally get lucky in finding a wide minimum, but it usually lands in a narrow minimum due to their higher density. As the explore duration increases, the probability of eventually landing in a wide minimum also increases. Thus, a minimum duration of explore is necessary to land in a wide minimum with high probability. This hypothesis also explains heuristic observations such as cosine decay's better generalization compared to linear decay. Moreover, the hypothesis enables a principled learning rate schedule design that explicitly accounts for the requisite explore duration.

Motivated by the hypothesis, we design a novel explore-exploit learning rate schedule. The initial explore phase optimizes at a high learning rate in order to arrive in the vicinity of a wide minimum; this is followed by an exploit phase, which descends to the bottom of this wide minimum. We give the explore phase enough time so that the probability of landing in a wide minimum is high. For the exploit phase, we experimented with multiple schemes and found a simple, parameterless linear decay to zero to be effective. Thus, our proposed learning rate schedule optimizes at a constant high learning rate for some minimum duration, followed by a linear decay to zero. We call this learning rate schedule the Knee schedule. We extensively evaluate the Knee schedule across a wide range of models and datasets, ranging from NLP (BERT pre-training, Transformer on WMT'14 (EN-DE) and IWSLT'14 (DE-EN)) to CNNs (ImageNet on ResNet-50, Cifar-10 on ResNet18), and spanning multiple optimizers: SGD with Momentum, Adam, RAdam, and LAMB.
In all cases, the Knee schedule improves the test accuracy over state-of-the-art hand-tuned learning rate schedules when trained with the original training budget. The explore duration is a hyper-parameter in the Knee schedule, but even if we simply fix the explore duration at 50% of the total training budget, it still outperforms prior schemes. We also experimented with reducing the training budget, and found that the Knee schedule can achieve the same accuracy as the baseline under significantly reduced training budgets. For the BERT LARGE pretraining, WMT'14 (EN-DE), and ImageNet experiments, we are able to train with 33%, 57%, and 44% less training budget, respectively, for the same test accuracy. This corresponds to significant savings in GPU compute, e.g., over 1000 V100 GPU-hours saved for BERT LARGE pretraining.

The main contributions of our work are:

1. A hypothesis of lower density of wide minima in the DNN loss landscape, backed by extensive experiments, that explains why a high learning rate needs to be maintained for a sufficient duration to achieve good generalization.

2. An explanation, via this hypothesis, of the good performance of heuristic-based schemes such as cosine decay, and a principled approach to designing learning rate decay schemes.

3. An explore-exploit learning rate schedule motivated by the hypothesis, called the Knee schedule, that outperforms prior heuristic-based learning rate schedules, including achieving state-of-the-art results on the IWSLT'14 (DE-EN) and WMT'14 (DE-EN) datasets.
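The Knee schedule itself reduces to a few lines: hold a constant high learning rate during the explore phase, then decay linearly to zero. The sketch below is our own minimal rendering of that description (the name `knee_lr` and its signature are hypothetical, not from the paper):

```python
def knee_lr(step, total_steps, explore_steps, max_lr):
    """Knee schedule: hold max_lr for explore_steps (explore phase),
    then decay linearly to zero over the remaining (exploit) steps."""
    if step < explore_steps:
        return max_lr  # explore phase: constant high LR
    frac = (step - explore_steps) / (total_steps - explore_steps)
    return max_lr * (1.0 - frac)  # exploit phase: linear decay to 0
```

In a framework such as PyTorch, a variant of this function returning a multiplier (rather than an absolute LR) could be passed to `torch.optim.lr_scheduler.LambdaLR`; setting `explore_steps` to half of `total_steps` corresponds to the fixed 50% explore fraction discussed above.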

2. WIDE-MINIMA DENSITY HYPOTHESIS

Many popular learning rate (LR) schedules, such as step decay schedules for image datasets, start training at a high LR and then reduce the LR periodically. For example, consider Cifar-10 on Resnet-18, trained using a typical step LR schedule of 0.1, 0.01, and 0.001 for 100, 50, and 50 epochs, respectively. In many such schedules, even though the training loss stagnates after several epochs at the high LR, one still needs to continue training at the high LR in order to get good generalization. For example, Figure 1 shows the training loss for Cifar-10 on Resnet-18, trained with a fixed LR of 0.1 (orange curve), compared to a model trained via a step schedule with the LR reduced at epoch 50 (blue curve). As can be seen from the figure, the training loss stagnates after ≈ 50 epochs for the orange curve, and locally it makes sense to reduce the learning rate to decrease the loss. However, as shown in Table 1, generalization is directly correlated with the duration of training at the high LR, with the highest test accuracy achieved when the high LR is used for 100 epochs, well past the point where the training loss stagnates.

To understand this phenomenon, we perform another experiment. We train Cifar-10 on Resnet-18 for 200 epochs, using a high LR of 0.1 for only 30 epochs, and then LRs of 0.01 and 0.001 for 85 epochs each. We repeat this training 50 times with different random weight initializations. On average, as expected, this training yields a low test accuracy of 94.81. However, in 1 of the 50 runs, the test accuracy reaches 95.24, even higher than the average accuracy of 95.1 obtained when training at the high LR for 100 epochs!

Hypothesis. To explain the above observations, i.e., that using a high learning rate for a short duration results in low average test accuracy with rare occurrences of high test accuracy, while using the same

