ADAPTIVE WEIGHT DECAY: ON THE FLY WEIGHT DECAY TUNING FOR IMPROVING ROBUSTNESS

Abstract

We introduce adaptive weight decay, which automatically tunes the hyper-parameter for weight decay during each training iteration. For classification problems, we propose changing the value of the weight decay hyper-parameter on the fly, based on the strength of the updates from the classification loss (i.e., the gradient of the cross-entropy) and from the regularization loss (i.e., the ℓ2-norm of the weights). We show that this simple modification can result in large improvements in adversarial robustness, an area which suffers from robust overfitting, without requiring extra data. Specifically, our reformulation results in a 20% relative robustness improvement on CIFAR-100 and a 10% relative robustness improvement on CIFAR-10 compared to traditional weight decay. In addition, the method has other desirable properties, such as lower sensitivity to the learning rate and smaller weight norms; the latter contributes to robustness against label noise and to pruning.

1. INTRODUCTION

Modern deep learning models have exceeded human capability on many computer vision tasks. Due to their high capacity for memorizing training examples (Zhang et al., 2021), their generalization heavily relies on the training algorithm. To reduce memorization, several approaches have been taken, including regularization and augmentation. Some of these augmentation techniques alter the network input (DeVries & Taylor, 2017; Chen et al., 2020; Cubuk et al., 2019; 2020; Müller & Hutter, 2021), some alter hidden states of the network (Srivastava et al., 2014; Ioffe & Szegedy, 2015; Gastaldi, 2017; Yamada et al., 2019), some alter the expected output (Warde-Farley & Goodfellow, 2016; Kannan et al., 2018), and some affect multiple levels (Zhang et al., 2017; Yun et al., 2019; Hendrycks et al., 2019b). Another popular approach to prevent overfitting is the use of regularizers, such as weight decay (Plaut et al., 1986; Krogh & Hertz, 1991). Such methods prevent over-fitting by eliminating solutions that memorize training examples. Regularization methods are attractive beyond generalization on clean data, as they are crucial in adversarial and noisy-data settings.

In this paper, we revisit weight decay, a regularizer mainly used to avoid overfitting. The rest of the paper is organized as follows: In Section 2, we revisit tuning the hyper-parameters for weight decay. We introduce adaptive weight decay in Section 3, and further discuss its properties in Section 4. More specifically, we discuss the benefits of adaptive weight decay in the setting of adversarial training in subsection 4.1, noisy labels in subsection 4.2, and additional properties related to robustness in subsection 4.3.

2. WEIGHT DECAY

Weight decay, which encourages network weights to have smaller magnitudes (Zhang et al., 2018), has been widely adopted to improve generalization. Although other forms of weight decay have been studied (Loshchilov & Hutter, 2017), we focus on the popular ℓ2-norm variant. More precisely, we look at classification problems with cross-entropy as the main loss and weight decay as the regularizer, which was popularized by Krizhevsky et al. (2017):

Loss_w(x, y) = CrossEntropy(w(x), y) + (λ_wd / 2) ∥w∥₂²,    (1)

where w is the network parameters, (x, y) is the training data, and λ_wd is the hyper-parameter controlling how much weight decay penalizes the norm of the weights relative to the main loss (i.e., the cross-entropy loss). For instance, if λ_wd is negligible, the optimization will likely over-fit the data, whereas if λ_wd is too large, the optimization will collapse to a low-norm solution that does not fit the training data. Consequently, finding the correct value for the weight decay hyper-parameter is crucial. Models trained using the right hyper-parameter for weight decay tend to have higher bias, which translates to less over-fitting and, thus, better generalization (Krogh & Hertz, 1991).

Table 1: Effect of moving along the diagonal of the 2D grid search over learning rate and λ_wd on the updates in one gradient descent step.

                                        Setting 1          Setting 2
lr                                      0.01               0.1
λ_wd                                    0.005              0.0005
Weight decay update (−w_t λ_wd lr)      −w_t × 5 × 10⁻⁵    −w_t × 5 × 10⁻⁵
Cross-entropy update (−∇w_t lr)         −∇w_t × 10⁻²       −∇w_t × 10⁻¹
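The objective in eq. 1 can be sketched in a few lines of plain Python. This is a minimal illustration, not the paper's implementation; the function names (`cross_entropy`, `regularized_loss`) and the toy scalar weights are our own choices for exposition.

```python
import math

def cross_entropy(logits, label):
    """Cross-entropy of a single example: log-sum-exp(logits) - logits[label].

    Computed with the max-subtraction trick for numerical stability.
    """
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return lse - logits[label]

def regularized_loss(weights, logits, label, lambda_wd):
    """Loss_w(x, y) = CrossEntropy(w(x), y) + (lambda_wd / 2) * ||w||_2^2."""
    sq_norm = sum(w * w for w in weights)
    return cross_entropy(logits, label) + 0.5 * lambda_wd * sq_norm
```

With `lambda_wd = 0` the objective reduces to plain cross-entropy; as `lambda_wd` grows, the penalty term dominates and pushes the optimum toward small-norm weights, which is the trade-off the text describes.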

2.1. TUNING HYPER-PARAMETERS

When tuning λ_wd, it is crucial to search over the learning rate lr simultaneously,¹ as illustrated in Figure 1 for CIFAR-10 and CIFAR-100. Otherwise, in the case that the 2D grid search is not computationally practical, we show that separate 1D grid searches over learning rate and weight decay are not optimal; see Appendix C. To better understand the relationship between adjacent cells on the same diagonal of the grid search, let us consider the two cells corresponding to the hyper-parameter settings (λ_wd = 0.005, lr = 0.01) and (λ_wd = 0.0005, lr = 0.1). We compare one step of gradient descent using these hyper-parameters. To derive the parameter w at step t + 1 from its value at step t, we have:

w_{t+1} = w_t − ∇w_t lr − w_t λ_wd lr,

where ∇w_t is the gradient computed from the cross-entropy loss and w_t λ_wd is the gradient computed from the weight decay term in eq. 1. By comparing the two settings, we realize that the update coming from weight decay (i.e., −w_t λ_wd lr) remains the same, but the update coming from cross-entropy (i.e., −∇w_t lr) differs by a factor of 10, as shown in Table 1. In other words, by moving along cells on the same diagonal in Fig. 1, we change the importance of the cross-entropy term relative to weight decay while keeping the updates from weight decay intact. A question we ask is: what is the optimal ratio between the update coming from cross-entropy and the update coming from weight decay in settings that generally perform better?
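The diagonal argument above can be checked numerically. The sketch below (our own illustration, with an arbitrary scalar weight and gradient) splits one SGD step into its two update terms for the two settings from Table 1:

```python
def sgd_updates(w, grad_ce, lr, lambda_wd):
    """Decompose one SGD step w_{t+1} = w_t - lr*grad_ce - lr*lambda_wd*w
    into its (cross-entropy update, weight-decay update) terms."""
    return -lr * grad_ce, -lr * lambda_wd * w

# Arbitrary toy values for the current weight and its cross-entropy gradient.
w, grad = 1.0, 1.0
ce1, wd1 = sgd_updates(w, grad, lr=0.01, lambda_wd=0.005)   # Setting 1
ce2, wd2 = sgd_updates(w, grad, lr=0.1,  lambda_wd=0.0005)  # Setting 2
# The weight-decay updates coincide (both -w * 5e-5), while the
# cross-entropy update is 10x larger in Setting 2, matching Table 1.
```

Since the weight-decay update is the product lr × λ_wd × w, any two settings with the same lr × λ_wd product (i.e., on the same diagonal of the grid) apply identical decay, and only the data-fitting term changes strength.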



¹ See Appendix C for more details on the importance of 2D grid search.



Figure 1: Grid search over different values of learning rate and weight decay, and the resulting accuracy of a WRN28-10 trained with SGD on (a) CIFAR-10 and (b) CIFAR-100, and (c) a ResNet50 optimized with ASAM (Kwon et al., 2021) on CIFAR-100.

