Amos: AN ADAM-STYLE OPTIMIZER WITH ADAPTIVE WEIGHT DECAY TOWARDS MODEL-ORIENTED SCALE

Abstract

We present Amos, a stochastic gradient-based optimizer designed for training deep neural networks. It can be viewed as an Adam optimizer with theoretically supported, adaptive learning-rate decay and weight decay. A key insight behind Amos is that it leverages model-specific information to determine the initial learning-rate and decaying schedules. When used for pre-training BERT variants and T5, Amos consistently converges faster than the state-of-the-art settings of AdamW, achieving better validation loss within ≤ 70% training steps and time, while requiring ≤ 51% memory for slot variables. Our code is open-sourced at: https://anonymous-url.

1. INTRODUCTION

The Adam (Kingma & Ba, 2015) optimizer is widely used for training deep neural networks, demonstrating fast convergence especially in the early stages of training. Although previous works have found issues with the theoretical convergence of Adam as training proceeds (Reddi et al., 2018), in practice these are remedied by various learning-rate schedules and weight decay (Loshchilov & Hutter, 2019). Specifically, Adam with linear learning-rate decay and constant weight decay is the standard setting for pre-training large language models such as BERT (Devlin et al., 2019). However, these decay settings are usually ad-hoc, increase the number of hyper-parameters, and may introduce additional complexities in usage. For example, the linearly decaying learning-rate schedule requires knowing the number of training steps in advance, which makes it nontrivial to continuously train a model after the learning-rate decays to 0. In this work, we present Amos, a new optimizer with a theoretically supported and adaptive schedule for learning-rate and weight decay, which can significantly outperform the state-of-the-art AdamW settings for pre-training language models, provide guidance for hyper-parameter tuning, reduce memory usage, and train continuously without having to specify the number of training steps a priori. A key insight behind Amos is a hyper-parameter η, provided by the model architecture, which indicates the expected scale of the trainable weights θ of the model (§2); theoretically, we assume that an optimal point θ* exists within the diameter ‖θ*‖ ≤ η. Deep neural networks are likely to satisfy such a constraint without degrading performance, because there exist many good local minima; and we show that an appropriate η for Amos can improve generalization and accelerate convergence (§5.2).
In this work, η is calculated in a consistent way from the input/output scale of neural network components, which is hinted by the model design (§A.4). Given η, Amos decides a learning-rate per variable, and its L2 regularization leads the trained weights to the specified scale. The decay of the learning-rate is then determined by the L2 regularizer. Thus, Amos performs better because it can utilize the model-oriented information η efficiently; the name Amos stands for "Adaptive weight decay towards Model-Oriented Scale". Empirically, we focus on the Transformer architecture (Vaswani et al., 2017) since it is predominant in pre-trained language models (Bommasani et al., 2021), but add additional experiments on LSTM (Gers et al., 2000) and ResNet (He et al., 2016). We apply Amos to the pre-training of 4 models: BERT (Devlin et al., 2019), two Transformer variants with relative position representations (Su et al., 2021; Shaw et al., 2018), and the T5 model (Raffel et al., 2020); some with various model sizes and batch sizes. In all experiments, Amos consistently outperforms the state-of-the-art setting, achieving better validation loss within ≤ 70% training steps and time (§5.1). Compared to AdamW, the memory usage for slot variables is reduced to ≤ 51% in Amos (§A.8). In addition, Amos does not calculate the learning-rate from a maximum number of training steps, so one can seamlessly continue training from any checkpoint, which is not trivial for AdamW with linear learning-rate decay (§5.1).

2. THE ALGORITHM

For notation, we denote the model weights by θ̃; an online learning algorithm recursively calculates a sequence of weights θ̃₁, θ̃₂, . . . from initial weights θ̃₀ and a training example z_t at each step t = 0, 1, . . .. An optimizer uses the gradient g̃_t = ∇ℓ(z_t; θ̃_t) to compute a weight update θ̃_{t+1} ← θ̃_t − δ̃_t, in order to minimize the loss function ℓ(z; θ̃). In neural network models, the model weights θ̃ are an array of trainable tensors (i.e. variables) collected from all model components; we view a variable and its slices as subsets of the model weights (e.g. θ ⊆ θ̃ is a variable slice that functions in part of the model). We use a bold letter to denote an array (e.g. θ̃_t, θ_t), and the same normal letter to denote a scalar element of that array (e.g. θ_t) for describing element-wise operations. We use a tilde for information of the whole model (e.g. θ̃_t), and drop the tilde to indicate subsets (e.g. θ_t).

To start, we recall the update rule of the RMSProp optimizer (Tieleman & Hinton, 2012), which computes the weight update by δ_t ← (α/√v_t) g_t, where α is a scalar learning-rate and v_t a running average of the squared gradients g_t². Based on this, Adam (Kingma & Ba, 2015) replaces g_t with its running average m_t (i.e. momentum), and adopts bias-corrected running averages m̂_t, v̂_t. Further, AdamW (Loshchilov & Hutter, 2019) allows a schedule for the learning-rate α_t (depending on the step t) and adds a weight decay: δ_t ← α_t((1/√v̂_t) m̂_t + γ θ_t), where γ is a constant hyper-parameter. For pre-training Transformer variants, the learning-rate schedule α_t is set to linearly decay to 0 after warm-up; therefore, a maximum number of training steps, before the learning-rate decays to 0, has to be set as a hyper-parameter. Amos, with a similar construction, has the following update rule:

δ_t ← d_t((ξη/√v̂_t) g_t + ½ γ_t θ_t), where γ_t ← c_t (ξ²/√v̂_t) M₂(g_t)².

Here, M₂(a) := √((1/k) Σᵢ₌₁ᵏ aᵢ²) denotes the quadratic mean of the entries of an array a ∈ Rᵏ, and d_t, c_t ≤ 1 are decaying factors maintained by the algorithm.
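For concreteness, the update rule can be sketched in NumPy for a single variable slice. This is a minimal illustration rather than the full Algorithm 1: it omits bias correction and the schedules that drive the decaying factors d_t and c_t (both held at 1 here), and the function name and default values are our own.

```python
import numpy as np

def amos_update(theta, g, v, eta, xi=0.1, beta=0.999, d_t=1.0, c_t=1.0, eps=1e-18):
    """One simplified Amos-style step for a single variable slice.

    theta: current weights; g: gradient; v: running average of squared
    gradients; eta: model-oriented scale; xi: global learning-rate.
    d_t and c_t stand in for the decaying factors of the full algorithm
    and are simply held at 1.0 in this sketch.
    """
    v = beta * v + (1.0 - beta) * g ** 2         # update second-moment average
    rsqrt_v = 1.0 / np.sqrt(v + eps)             # 1 / sqrt(v_t), element-wise
    m2_g = np.sqrt(np.mean(g ** 2))              # quadratic mean M2(g_t)
    gamma = c_t * xi ** 2 * rsqrt_v * m2_g ** 2  # adaptive weight-decay factor
    delta = d_t * (xi * eta * rsqrt_v * g + 0.5 * gamma * theta)
    return theta - delta, v
```

Note that both terms of the update share the factor d_t, so the weight decay shrinks together with the effective learning-rate as training proceeds.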
The update rule consists of a gradient descent part (the term containing g_t) and an L2 regularization part (the term containing θ_t)¹, similar to AdamW. The full Amos is shown in Algorithm 1. We explain several novel aspects below.

Model-oriented scale: For each variable a ⊆ θ̃ in the model weights, we specify the scale η(a) we expect a to converge to, i.e. M₂(a*) ≈ η for an optimal θ̃* ⊇ a*. Different variables may have different η's. For the common case of a linear transformation, y = xW + u (W, u ⊆ θ̃, W ∈ R^{m×n}, x ∈ Rᵐ), we calculate η(W) by assuming that x is random Gaussian with standard deviation σ_x, and y random Gaussian with standard deviation σ_y; so we have η(W) = σ_y/(σ_x√m) in order to satisfy the input/output standard deviations (assuming the entries of W to be Gaussian as well). Additionally, we set η(u) = σ_y/2 to ensure that u has a slightly smaller magnitude than xW. The input/output standard deviations can be hinted by other layers in the model; for example, the activation function GELU (Hendrycks & Gimpel, 2016) usually expects its inputs to have standard deviation ≈ 1, because its non-linearity mostly lies within that range; also, the output standard deviation of LayerNormalization (Ba et al., 2016) is expected to be 1. For Transformer variants, we will discuss the input/output standard deviations of all types of non-linear layers and derive η in §A.4.

Factored initial learning-rate: In Amos, we use ξη as the initial learning-rate, where η is the model-oriented scale specified for each variable, and ξ is a global learning-rate shared across all variables. For online optimizers, the learning-rate is generally affected by both data and model; by factoring the initial learning-rate into ξ and η, we disentangle the two to some extent: while ξ is tuned and may depend on the data, η is calculated from the model architecture.
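The scale rules for a linear layer can be written as two small helpers (a sketch; the function names are ours, not part of the paper):

```python
import math

def eta_linear(m, sigma_x=1.0, sigma_y=1.0):
    """Model-oriented scale for the weight W of a linear layer y = x @ W + u,
    with W of shape (m, n): eta(W) = sigma_y / (sigma_x * sqrt(m)), so that
    Gaussian entries of W map std-sigma_x inputs to std-sigma_y outputs."""
    return sigma_y / (sigma_x * math.sqrt(m))

def eta_bias(sigma_y=1.0):
    """eta(u) = sigma_y / 2, keeping the bias slightly smaller than x @ W."""
    return sigma_y / 2.0
```

For instance, a linear layer with fan-in m = 768 and unit input/output standard deviations (e.g. fed by LayerNormalization and feeding GELU) would get η(W) = 1/√768 ≈ 0.036 and η(u) = 0.5.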

¹ Following Loshchilov & Hutter (2019), we decouple the gradient of an L2 regularization term (taking the form of a weight decay) from the adaptive gradient normalization factor 1/√v̂_t. Loshchilov & Hutter (2019) point out that, when an adaptive optimizer is used, the decoupled weight decay is not equivalent to L2 regularization without explicit decoupling, and that the former is more appropriate. In this work, we always treat L2 regularization as decoupled weight decay, and use the two terms interchangeably.
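The distinction can be seen in a minimal sketch (single steps without momentum or bias correction; the setup and names are ours): folding the L2 gradient into the normalized update rescales the decay by 1/√v̂_t, whereas the decoupled form applies the decay to θ_t directly.

```python
import numpy as np

def step_l2_coupled(theta, g, v, alpha, gamma):
    """L2 penalty folded into the gradient: the decay term gamma * theta
    is divided by sqrt(v) together with the gradient."""
    return theta - alpha * (g + gamma * theta) / np.sqrt(v + 1e-18)

def step_decoupled(theta, g, v, alpha, gamma):
    """Decoupled weight decay (AdamW-style): only the gradient is
    normalized; the decay acts on theta directly."""
    return theta - alpha * (g / np.sqrt(v + 1e-18) + gamma * theta)
```

With a non-trivial v the two updates differ: in the coupled form, weights whose gradients have a large second moment are decayed less, which is exactly the coupling that Loshchilov & Hutter (2019) argue against.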

