ADAPTIVE WEIGHT DECAY: ON THE FLY WEIGHT DECAY TUNING FOR IMPROVING ROBUSTNESS

Abstract

We introduce adaptive weight decay, which automatically tunes the hyper-parameter for weight decay during each training iteration. For classification problems, we propose changing the value of the weight decay hyper-parameter on the fly based on the strength of the updates from the classification loss (i.e., the gradient of the cross-entropy) and the regularization loss (i.e., the ℓ2-norm of the weights). We show that this simple modification can yield large improvements in adversarial robustness, an area which suffers from robust overfitting, without requiring extra data. Specifically, our reformulation yields a 20% relative robustness improvement on CIFAR-100 and a 10% relative robustness improvement on CIFAR-10 compared to traditional weight decay. In addition, the method has other desirable properties, such as lower sensitivity to the learning rate and smaller weight norms; the latter contributes to robustness against label noise and to amenability to pruning.
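As a rough illustration of the on-the-fly tuning described above, one plausible instantiation sets the weight decay coefficient each iteration from the ratio between the cross-entropy gradient norm and the weight norm. The function name, the scalar `lam_awd`, and the exact ratio below are assumptions for this sketch, not the paper's definition:

```python
import numpy as np

def adaptive_lambda(grad_ce, w, lam_awd=0.02, eps=1e-8):
    """Hypothetical per-iteration weight decay coefficient.

    Scales a base hyper-parameter lam_awd by the ratio of the
    cross-entropy gradient norm to the weight norm, so the
    regularization strength tracks the classification signal.
    """
    return lam_awd * np.linalg.norm(grad_ce) / (np.linalg.norm(w) + eps)
```

Under such a rule, when the classification gradient is large relative to the weights, regularization is applied more strongly; as the weights grow, the effective decay coefficient shrinks.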

1. INTRODUCTION

Modern deep learning models have exceeded human capability on many computer vision tasks. Due to their high capacity for memorizing training examples (Zhang et al., 2021), their generalization heavily relies on the training algorithm. To reduce memorization, several approaches have been taken, including regularization and augmentation. Some of these augmentation techniques alter the network input (DeVries & Taylor, 2017; Chen et al., 2020; Cubuk et al., 2019; 2020; Müller & Hutter, 2021), some alter hidden states of the network (Srivastava et al., 2014; Ioffe & Szegedy, 2015; Gastaldi, 2017; Yamada et al., 2019), some alter the expected output (Warde-Farley & Goodfellow, 2016; Kannan et al., 2018), and some affect multiple levels (Zhang et al., 2017; Yun et al., 2019; Hendrycks et al., 2019b). Another popular approach to prevent overfitting is the use of regularizers, such as weight decay (Plaut et al., 1986; Krogh & Hertz, 1991). Such methods prevent overfitting by eliminating solutions that memorize training examples. Regularization methods are attractive beyond generalization on clean data, as they are crucial in adversarial and noisy-data settings. In this paper, we revisit weight decay, a regularizer mainly used to avoid overfitting.

The rest of the paper is organized as follows: In Section 2, we revisit tuning the hyper-parameter for weight decay. We introduce adaptive weight decay in Section 3 and further discuss its properties in Section 4. More specifically, we discuss the benefits of adaptive weight decay in the setting of adversarial training in Subsection 4.1, noisy labels in Subsection 4.2, and additional properties related to robustness in Subsection 4.3.

2. WEIGHT DECAY

Weight decay, which encourages network weights to have smaller magnitudes (Zhang et al., 2018), has been widely adopted to improve generalization. Although other forms of weight decay have been studied (Loshchilov & Hutter, 2017), we focus on the popular ℓ2-norm variant. More precisely, we look at classification problems with cross-entropy as the main loss and weight decay as the regularizer, a combination popularized by Krizhevsky et al. (2017):

Loss_w(x, y) = CrossEntropy(w(x), y) + (λ_wd / 2) ∥w∥₂²,   (1)
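Equation (1) can be sketched directly in code. The minimal linear model `w(x) = W @ x` below is a hypothetical stand-in for the network; the point is only how the cross-entropy term and the (λ_wd / 2) ∥w∥₂² penalty combine:

```python
import numpy as np

def cross_entropy(logits, y):
    # Numerically stable softmax cross-entropy for a single example.
    z = logits - logits.max()
    log_probs = z - np.log(np.exp(z).sum())
    return -log_probs[y]

def loss_with_weight_decay(w, x, y, lam_wd):
    # w(x) here is a toy linear classifier; any network output works.
    logits = w @ x
    ce = cross_entropy(logits, y)
    reg = 0.5 * lam_wd * np.sum(w ** 2)  # (lambda_wd / 2) * ||w||_2^2
    return ce + reg
```

With λ_wd = 0 this reduces to plain cross-entropy; increasing λ_wd adds a penalty proportional to the squared Frobenius norm of the weights, which is what pushes the weights toward smaller magnitudes.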
