STABLE WEIGHT DECAY REGULARIZATION

Abstract

Weight decay is a popular regularization technique for training deep neural networks. Modern deep learning libraries mainly use L2 regularization as the default implementation of weight decay. Loshchilov & Hutter (2018) demonstrated that L2 regularization is not identical to weight decay for adaptive gradient methods, such as Adaptive Momentum Estimation (Adam), and proposed Adam with Decoupled Weight Decay (AdamW). However, we find that the popular implementations of weight decay in modern deep learning libraries, including L2 regularization and decoupled weight decay, often damage performance. First, L2 regularization is unstable weight decay for all optimizers that use Momentum, such as stochastic gradient descent (SGD). Second, decoupled weight decay is highly unstable for all adaptive gradient methods. We further propose the Stable Weight Decay (SWD) method to fix the unstable weight decay problem from a dynamical perspective. SWD makes significant improvements over L2 regularization and decoupled weight decay in our experiments. Simply fixing weight decay in Adam via SWD, with no extra hyperparameter, can outperform complex Adam variants that have more hyperparameters.

1. INTRODUCTION

Weight decay is a popular and even necessary regularization technique for training deep neural networks that generalize well (Krogh & Hertz, 1992). People commonly use L2 regularization as "weight decay" when training deep neural networks and interpret it as a Gaussian prior over the model weights. This is true for vanilla SGD. However, Loshchilov & Hutter (2018) revealed that, when the learning rate is adaptive, the commonly used L2 regularization is not identical to the vanilla weight decay proposed by Hanson & Pratt (1989),

θ_t = (1 - λ_0) θ_{t-1} - η g_t,     (1)

where λ_0 is the weight decay hyperparameter, θ_t denotes the model parameters at the t-th step, η is the learning rate, and g_t is the gradient of the minibatch loss function L(θ) at θ_{t-1}.

We try to answer when the unstable weight decay problem happens and how to mitigate it. While Loshchilov & Hutter (2018) discovered that weight decay should be decoupled from the gradients, we further discovered that weight decay should be coupled with the effective learning rate. We organize our main findings as follows.

1. The effect of weight decay can be interpreted as iteration-wise rescaling of the loss landscape and the learning rate at the same time. We formulate the weight decay rate based on the rescaling ratio per unit stepsize. We call weight decay unstable if the weight decay rate is not constant during training (see Definition 1). Our empirical analysis suggests that the unstable weight decay problem may undesirably damage the performance of popular optimizers.

2. L2 regularization is unstable weight decay in all optimizers that use Momentum. Most popular optimizers in modern deep learning libraries use Momentum and L2 regularization at the same time, including SGD (with Momentum). Unfortunately, L2 regularization often damages performance in the presence of Momentum.

3. Decoupled weight decay is unstable weight decay in adaptive gradient methods.
All adaptive gradient methods used L2 regularization as weight decay until Loshchilov & Hutter (2018) proposed decoupled weight decay. However, decoupled weight decay only solves part of the unstable weight decay problem: it remains unstable in the presence of an adaptive learning rate.

4. Always make weight decay rates stable. We propose the Stable Weight Decay (SWD) method, which applies a bias correction factor to decoupled weight decay to make weight decay more stable during training. SWD makes significant improvements over L2 regularization and decoupled weight decay in our experiments. We display the test performance in Table 1. Adam with SWD (AdamS) is displayed in Algorithm 4.
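To make the distinction concrete, the sketch below implements a single Adam-style step under the three weight decay variants discussed above. The NumPy code and toy hyperparameters are illustrative only; in particular, the "swd" branch, which rescales decoupled decay by the square root of the mean second-moment estimate, is an assumed reading of the bias correction factor described here, not a reproduction of Algorithm 4.

```python
import numpy as np

def adam_step(theta, g, m, v, t, lr=1e-3, b1=0.9, b2=0.999,
              eps=1e-8, wd=0.0, mode="decoupled"):
    """One Adam step under three weight decay variants.

    mode="l2":        fold wd*theta into the gradient (L2 regularization);
                      the decay term then passes through the momentum buffer
                      and is divided by sqrt(v_hat), so the effective decay
                      rate varies per coordinate and per iteration.
    mode="decoupled": AdamW-style decay applied directly to theta.
    mode="swd":       decoupled decay rescaled by sqrt(mean(v_hat)) -- an
                      ASSUMED form of the paper's bias correction factor.
    """
    if mode == "l2":
        g = g + wd * theta
    m = b1 * m + (1 - b1) * g            # first-moment estimate
    v = b2 * v + (1 - b2) * g ** 2       # second-moment estimate
    m_hat = m / (1 - b1 ** t)            # bias-corrected moments
    v_hat = v / (1 - b2 ** t)
    update = m_hat / (np.sqrt(v_hat) + eps)
    if mode == "decoupled":
        update = update + wd * theta
    elif mode == "swd":
        update = update + wd * theta / np.sqrt(v_hat.mean())
    return theta - lr * update, m, v
```

Note how, in the "l2" branch, the decay term enters the momentum buffers and the adaptive denominator; that coupling is precisely what findings 2 and 3 above identify as the source of instability.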

2. A DYNAMICAL PERSPECTIVE ON WEIGHT DECAY

In this section, we study how weight decay affects learning dynamics and how to quantitatively measure the effect of weight decay from the viewpoint of learning dynamics. We also reveal the hidden weight decay problem in SGD with Momentum.

A dynamical perspective on weight decay. We first present a new theoretical tool for understanding the effect of weight decay. The vanilla weight decay described by Hanson & Pratt (1989) is given by Equation 1. A more popular implementation for vanilla SGD in modern deep learning libraries is

θ_t = (1 - ηλ) θ_{t-1} - η ∂L(θ_{t-1})/∂θ,

where we denote the training loss of one minibatch as L(θ) and weight decay is coupled with the learning rate. We define new coordinates w_t ≡ θ_t (1 - ηλ)^{-t}, an iteration-dependent rescaled system of θ. In the system of w, we may define the loss function of w as L_w(w_t) ≡ L((1 - ηλ)^t w_t) = L(θ_t).
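For vanilla SGD without Momentum, the coupled update above coincides exactly with L2 regularization, since θ_{t-1} - η(g + λθ_{t-1}) = (1 - ηλ)θ_{t-1} - ηg. A quick numerical check with toy values (assuming nothing beyond the two update rules):

```python
import numpy as np

rng = np.random.default_rng(0)
theta = rng.normal(size=5)   # current parameters theta_{t-1}
g = rng.normal(size=5)       # minibatch gradient dL/dtheta at theta_{t-1}
lr, lam = 0.1, 0.01          # learning rate eta and decay lambda

# L2 regularization: fold lam * theta into the gradient.
theta_l2 = theta - lr * (g + lam * theta)

# Coupled weight decay: shrink theta, then take the gradient step.
theta_wd = (1 - lr * lam) * theta - lr * g

assert np.allclose(theta_l2, theta_wd)  # identical for vanilla SGD
```

The equivalence breaks as soon as a momentum buffer or an adaptive denominator sits between the gradient and the update, which is the crux of the findings in Section 1.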



Table 1: Test performance comparison of optimizers. We report the mean and standard deviation (as subscripts) of the optimal test errors over three runs of each experiment. AdamS generalizes significantly better than popular adaptive gradient methods and often compares favorably with the baseline optimizer SGD.

