Amos: AN ADAM-STYLE OPTIMIZER WITH ADAPTIVE WEIGHT DECAY TOWARDS MODEL-ORIENTED SCALE

Abstract

We present Amos, a stochastic gradient-based optimizer designed for training deep neural networks. It can be viewed as an Adam optimizer with theoretically supported, adaptive learning-rate decay and weight decay. A key insight behind Amos is that it leverages model-specific information to determine the initial learning-rate and decaying schedules. When used for pre-training BERT variants and T5, Amos consistently converges faster than the state-of-the-art settings of AdamW, achieving better validation loss within ≤ 70% training steps and time, while requiring ≤ 51% memory for slot variables. Our code is open-sourced at: https://anonymous-url.

1. INTRODUCTION

The Adam (Kingma & Ba, 2015) optimizer is widely used for training deep neural networks, demonstrating fast convergence especially in the early stages of training. Although previous works have found issues regarding the theoretical convergence of Adam as the training proceeds (Reddi et al., 2018) , in practice it is remedied by various learning-rate schedules and weight decay (Loshchilov & Hutter, 2019) . Specifically, Adam with linear learning-rate decay and constant weight decay is the standard setting for pre-training large language models such as BERT (Devlin et al., 2019) . However, these decay settings are usually ad-hoc, increase the number of hyper-parameters, and may introduce additional complexities in usage. For example, the linearly decaying learning-rate schedule requires knowing the number of training steps in advance, which makes it nontrivial to continuously train a model after the learning-rate decays to 0. In this work, we present Amos, a new optimizer with a theoretically supported and adaptive schedule for learning-rate and weight decay, which can significantly outperform the state-of-the-art AdamW settings for pre-training language models, provide guidance for hyper-parameter tuning, reduce the memory usage, and train continuously without having to specify the number of training steps a priori. A key insight behind Amos is a hyper-parameter η to be provided by the model architecture, which indicates the expected scale of the trainable weights θ of the model ( § 2), i.e., theoretically we assume that an optimal point θ * exists within the | θ * | ≤ η diameter. Deep neural networks are likely to satisfy such a constraint without degrading performance, because there exist many good local minima; and we show that an appropriate η for Amos can improve generalization and accelerate convergence ( § 5.2). 
In this work, η is calculated in a consistent way from the input/output scale of neural network components, which is hinted by the model design ( § A.4). Given η, Amos decides a learning-rate per variable, and its L2 regularization will lead the trained weights to the specified scale. The decay of the learning-rate is then determined by the L2 regularizer. Thus, Amos performs better because it can utilize the model-oriented information η efficiently; the name Amos stands for "Adaptive weight decay towards Model-Oriented Scale". Empirically, we focus on the Transformer architecture (Vaswani et al., 2017) since it is pre-dominant in pre-trained language models (Bommasani et al., 2021) , but add additional experiments on LSTM (Gers et al., 2000) and ResNet (He et al., 2016) . We apply Amos to the pre-training of 4 models: BERT (Devlin et al., 2019) , two Transformer variants with relative position representations (Su et al., 2021; Shaw et al., 2018) , and the T5 model (Raffel et al., 2020) ; some with various model sizes and batch sizes. In all experiments, Amos consistently outperforms the state-of-the-art setting, achieving better validation loss within ≤ 70% training steps and time ( § 5.1) . Compared to AdamW, the memory usage for slot variables is reduced to ≤ 51% in Amos ( § A.8). In addition, Amos does not calculate learning-rate from a maximum number of training steps, so one can seamlessly continue training from any checkpoints, which is not trivial for AdamW with linear learning-rate decay ( § 5.1).

2. THE ALGORITHM

For notation, we denote model weights by $\tilde{\boldsymbol{\theta}}$, and an online learning algorithm recursively calculates a sequence of weights, $\tilde{\boldsymbol{\theta}}_1, \tilde{\boldsymbol{\theta}}_2, \ldots$, from initial weights $\tilde{\boldsymbol{\theta}}_0$ and training examples $z_t$ at each step $t = 0, 1, \ldots$. An optimizer uses the gradient $\tilde{\mathbf{g}}_t = \nabla \ell(z_t; \tilde{\boldsymbol{\theta}}_t)$ to compute a weight update $\tilde{\boldsymbol{\theta}}_{t+1} \leftarrow \tilde{\boldsymbol{\theta}}_t - \tilde{\boldsymbol{\delta}}_t$, in order to minimize the loss function $\ell(z; \tilde{\boldsymbol{\theta}})$. In neural network models, the model weights $\tilde{\boldsymbol{\theta}}$ are an array of trainable tensors (i.e. variables) collected from all model components; we view a variable and its slices as subsets of the model weights (e.g. $\boldsymbol{\theta} \subseteq \tilde{\boldsymbol{\theta}}$ is a variable slice that functions in part of the model). We use a bold letter to denote an array (e.g. $\boldsymbol{\theta}_t$, $\tilde{\boldsymbol{\theta}}_t$), and the same normal letter to denote a scalar element of that array (e.g. $\theta_t$) for describing element-wise operations. We use tilde for information of the whole model (e.g. $\tilde{\boldsymbol{\theta}}_t$), and drop tilde to indicate subsets (e.g. $\boldsymbol{\theta}_t$).

To start, we recall the update rule of the RMSProp optimizer (Tieleman & Hinton, 2012), which computes the weight update by $\delta_t \leftarrow \frac{\alpha}{\sqrt{v_t}}\, g_t$, where $\alpha$ is a scalar learning-rate and $v_t$ a running average of the squared gradients $g_t^2$. Based on this, Adam (Kingma & Ba, 2015) replaces $g_t$ with its running average $m_t$ (i.e. momentum), and adopts bias correction $\hat{m}_t$, $\hat{v}_t$ for running averages. Further, AdamW (Loshchilov & Hutter, 2019) allows a schedule for learning-rate $\alpha_t$ (depending on the step $t$) and adds a weight decay:

$$\delta_t \leftarrow \alpha_t \left( \frac{1}{\sqrt{\hat{v}_t}}\, \hat{m}_t + \gamma\, \theta_t \right),$$

where $\gamma$ is a constant hyper-parameter. For pre-training Transformer variants, the learning-rate schedule $\alpha_t$ is set to linearly decay to 0 after warm-up. Therefore, a maximum number of training steps before the learning-rate decays to 0 has to be set as a hyper-parameter. Amos, with a similar construction, has the following update rule:

$$\delta_t \leftarrow d_t \left( \frac{\xi \eta}{\sqrt{\hat{v}_t}}\, g_t + \frac{1}{2} \gamma_t\, \theta_t \right) \quad \text{where} \quad \gamma_t \leftarrow c_t\, \frac{\xi^2}{\hat{v}_t}\, M_2(\mathbf{g}_t)^2. \tag{1}$$

Here, $M_2(\mathbf{a}) := \sqrt{\frac{1}{k} \sum_{i=1}^{k} a_i^2}$ denotes the quadratic mean of entries of an array $\mathbf{a} \in \mathbb{R}^k$.
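As a minimal sketch of Equation 1 for a single variable, the snippet below computes the quadratic mean $M_2$ and the resulting update $\delta_t$. The function names `quadratic_mean` and `amos_delta` are hypothetical, and the sketch assumes that $\hat{v}_t$ is one scalar shared by every entry of the variable.

```python
import math


def quadratic_mean(a):
    """M_2(a): square root of the mean of the squared entries of a."""
    return math.sqrt(sum(x * x for x in a) / len(a))


def amos_delta(g, theta, v_hat, xi, eta, c_t=1.0, d_t=1.0):
    """Element-wise update delta_t of Equation 1 for one variable.

    Assumption: v_hat is a single scalar shared by all entries of the variable.
    """
    gamma_t = c_t * xi ** 2 / v_hat * quadratic_mean(g) ** 2
    grad_scale = xi * eta / math.sqrt(v_hat)
    return [d_t * (grad_scale * gi + 0.5 * gamma_t * ti)
            for gi, ti in zip(g, theta)]


g = [3.0, 4.0]
theta = [0.0, 0.0]
# Illustrative choice: pretend the running average equals the current mean square.
v_hat = quadratic_mean(g) ** 2
delta = amos_delta(g, theta, v_hat, xi=0.1, eta=1.0)
```

With $\theta_t = 0$ and $\hat{v}_t = M_2(\mathbf{g}_t)^2$, the quadratic mean of the update equals $\xi\eta$ regardless of the gradient magnitude, which is one way to see the role of the gradient normalization.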
The update rule consists of a gradient descent part (the term containing $g_t$) and an L2 regularization part (the term containing $\theta_t$), similar to AdamW. The full Amos is shown in Algorithm 1. We explain several novel aspects below.

Model-oriented scale: For each variable $\mathbf{a} \subseteq \tilde{\boldsymbol{\theta}}$ in the model weights, we specify the scale $\eta(\mathbf{a})$ we expect $\mathbf{a}$ to converge to, i.e. $M_2(\mathbf{a}^*) \approx \eta$ for an optimal $\tilde{\boldsymbol{\theta}}^* \supseteq \mathbf{a}^*$. Different variables may have different scale $\eta$'s. For a common case of a linear transformation, $\mathbf{y} = \mathbf{x}W + \mathbf{u}$ ($W, \mathbf{u} \subseteq \tilde{\boldsymbol{\theta}}$, $W \in \mathbb{R}^{m \times n}$, $\mathbf{x} \in \mathbb{R}^m$), we calculate $\eta(W)$ by assuming that $\mathbf{x}$ is random Gaussian with standard deviation $\sigma_x$, and $\mathbf{y}$ random Gaussian with standard deviation $\sigma_y$; so we have $\eta(W) = \sigma_y / (\sigma_x \sqrt{m})$ in order to satisfy the input/output standard deviation (assuming entries of $W$ to be Gaussian as well). Additionally, we set $\eta(\mathbf{u}) = \sigma_y / 2$ to ensure that $\mathbf{u}$ has a slightly smaller magnitude than $\mathbf{x}W$. The input/output standard deviation can be hinted by other layers in the model; for example, the activation function GELU (Hendrycks & Gimpel, 2016) usually expects the inputs to have standard deviation ≈ 1, because its non-linearity mostly lies within that range; also the output standard deviation of LayerNormalization (Ba et al., 2016) is expected to be 1. For Transformer variants, we will discuss the input/output standard deviation of all types of non-linear layers and derive $\eta$ in § A.4.

Factored initial learning-rate: In Amos, we use $\xi\eta$ as the initial learning-rate, where $\eta$ is the model-oriented scale specified for each variable, and $\xi$ is a global learning-rate shared across all variables. For online optimizers, the learning-rate is generally affected by both data and model; by factoring the initial learning-rate into $\xi$ and $\eta$, we disentangle the two to some extent: While $\xi$ is tuned and may depend on the data, $\eta$ is calculated from the model architecture.
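The scale rule for a linear transformation can be written as two small helpers. This is only an illustrative sketch: the function names `eta_linear_kernel` and `eta_bias` are hypothetical, and the example values are the ones quoted for the LSTM kernel in § 5.2 (input scale 1/4, input dimension 512, output scale 1).

```python
import math


def eta_linear_kernel(sigma_x, sigma_y, m):
    """eta(W) = sigma_y / (sigma_x * sqrt(m)) for a kernel W with m input rows."""
    return sigma_y / (sigma_x * math.sqrt(m))


def eta_bias(sigma_y):
    """eta(u) = sigma_y / 2, so the bias stays slightly smaller than xW."""
    return sigma_y / 2.0


# Values quoted in Section 5.2 for the LSTM kernel.
eta_kernel = eta_linear_kernel(sigma_x=0.25, sigma_y=1.0, m=512)
eta_u = eta_bias(sigma_y=1.0)
```

For these inputs `eta_kernel` agrees with the value $1/\sqrt{32}$ used in § 5.2.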
Adaptive L2 regularization: Unlike AdamW which uses a constant $\gamma$ for weight decay, the Amos weight decay $\tilde{\gamma}_t$ is intended to control the scale of trained variables, rather than regularize the loss function; so $\gamma_t$ decays to 0 at $t \to \infty$ to be less biased, and it is adaptive in the sense that $\tilde{\gamma}_t$ depends on $\tilde{\mathbf{g}}_t$ so that the variables not getting gradient updates are not regularized. Thus, the L2 regularization is robust to sparse gradients, and it does not introduce any additional hyper-parameter. We will give a heuristic derivation of the form of $\gamma_t$ in § 4.

Decay factors: $\tilde{d}_t$, $\tilde{c}_t$ are per-parameter decay factors such that $d_0 = c_0 = 1$ and $d_t$, $c_t$ monotonically decrease to 0 at $t \to \infty$. We provide a theoretical derivation of the asymptotic behavior of these factors in § A.2, together with a default form that works well empirically in all our experiments. The decay factors do not depend on a maximum number of training steps, thus enabling arbitrary continuous training.

Memory Reduction: Most previous optimizers operate element-wise, so the slot variables (e.g. the running averages $\tilde{\mathbf{v}}_t$, $\tilde{\mathbf{m}}_t$ in Adam) have the same shape as $\tilde{\boldsymbol{\theta}}$, which can be memory consuming. In Amos, two slot variables ($\tilde{\mathbf{v}}_t$, $\tilde{\mathbf{b}}_t$ in Algorithm 1) are shared by certain slices in the model weights, reducing the memory usage of these slot variables. For example, if $\mathbb{R}^{m \times n} \ni W \subseteq \tilde{\boldsymbol{\theta}}$ is a linear transformation, the corresponding $\mathbf{v}_t \in \mathbb{R}^{1 \times n}$ is shared by the input dimension of $W$, reducing the memory usage by a factor of $m$. As a result, in Equation 1 and Algorithm 1, $v_t$, $b_t$, $c_t$ and $d_t$ are reduced and become scalars, to be used and updated by vector-valued $\mathbf{g}_t$ and $\boldsymbol{\theta}_t$. In this work, we reduce the input dimension of linear transformations, the embed dimension of embedding matrix, and all dimensions for other variables by default. An ablative study with different settings is found in § A.8.
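To illustrate the memory reduction for a linear transformation, the sketch below reduces the squared gradient of an $m \times n$ kernel over its input dimension, leaving one statistic per output column. The helper name `reduced_mean_square` and the nested-list layout (rows indexed by the input dimension) are assumptions made for this example only.

```python
def reduced_mean_square(grad):
    """Mean of squared gradients over the input dimension of an m x n kernel.

    grad is a nested list with m rows (input dimension) and n columns.
    Returns n values, i.e. M_2(g)^2 for each column slice of the kernel.
    """
    m = len(grad)
    n = len(grad[0])
    return [sum(grad[i][j] ** 2 for i in range(m)) / m for j in range(n)]


grad_w = [[1.0, 2.0],
          [3.0, 4.0]]
column_stats = reduced_mean_square(grad_w)   # one statistic per output column
full_slots = len(grad_w) * len(grad_w[0])    # element-wise slot count (as in Adam)
reduced_slots = len(column_stats)            # slot count after the reduction
```

Here the slot variable shrinks from $m \times n$ entries to $n$ entries, matching the factor-of-$m$ saving described above.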

Algorithm 1

The Amos optimizer at step $t$.

Input $\tilde{\mathbf{g}}_t = \nabla \ell(z_t; \tilde{\boldsymbol{\theta}}_t)$: The gradient of loss on a random example $z_t$.
Input $\tilde{\boldsymbol{\theta}}_t$: Trainable model weights at step $t$.
Input $\tilde{\mathbf{v}}_{t-1}$, $\tilde{\mathbf{b}}_t$: Slot variables of shape broadcastable to $\tilde{\boldsymbol{\theta}}$, initialized to 0.
Input (Optional) $\tilde{\mathbf{m}}_t$: Slot variable of the same shape as $\tilde{\boldsymbol{\theta}}$, initialized to 0 for momentum.
Hyper-parameter $\xi$: Global learning-rate.
Hyper-parameter $\tilde{\eta}$: Expected scale for model weights $\tilde{\boldsymbol{\theta}}$.
Hyper-parameter $\tilde{c}_t$: Decay factor for L2 regularization. Defaults to $c_t = \left(1 + \frac{1}{4}\sqrt{\xi}\, b_t\right)^{-\frac{1}{2}}$.
Hyper-parameter $\tilde{d}_t$: Decay factor for learning-rate. Defaults to $d_t = \left(1 + \frac{1}{4}\sqrt{\xi\eta}\, b_t\right)^{-1}$.
Hyper-parameter $\beta \in [0, 1)$: Exponential decay rate for running average $\tilde{\mathbf{v}}_t$.

1: (Optional) $g_t \leftarrow \frac{\chi}{\max(\chi, |g_t|)}\, g_t$    Gradient clipping with hyper-parameter $\chi > 0$.
2: $v_t \leftarrow \beta v_{t-1} + (1 - \beta)\, M_2(\mathbf{g}_t)^2$    Running average of squared gradients.
3: $\hat{v}_t \leftarrow v_t / (1 - \beta^t)$    Bias correction.
4: $\gamma_t \leftarrow c_t\, \frac{\xi^2}{\hat{v}_t}\, M_2(\mathbf{g}_t)^2$    Adaptive L2 regularization strength.

Hyper-parameter Tuning: The running average $v_t$ in Amos is a low-cost estimator for $E[M_2(\mathbf{g}_t)^2]$, where the expectation is taken over the example $z_t$ randomly drawn from the training data. It is similar to $v_t$ in Adam except that the mean square $M_2(\mathbf{g}_t)^2$ is used instead of element-wise $g_t^2$, due to the memory reduction. Thus, the hyper-parameter $\beta$ behaves similarly to $\beta_2$ in Adam: Since the estimator mostly depends on the previous $1/(1 - \beta)$ steps, $\beta$ should be close enough to 1 to make the estimator accurate, but not so large that the model weights in the previous $1/(1 - \beta)$ steps differ too much from the current step. We set $\beta = 0.999$ by default (the same as $\beta_2$ in Adam), and it is found that $\beta$ should be smaller with larger batch size (Shazeer & Stern, 2018; Liu et al., 2019).
The global learning-rate $\xi$ can depend on the step $t$ to follow a warm-up schedule at the beginning of training, but a schedule with learning-rate decay is not necessary, since the decay factor $\tilde{d}_t$ is already included in Algorithm 1; most of the time $\xi$ remains a constant. While this constant is the major hyper-parameter to be tuned, a good rule of thumb is to set $\xi$ to the same order of magnitude as $1/\sqrt{N}$, where $N$ is the number of independent batches in the training set (see § 4 for a justification). This value is usually larger than the typical learning-rates used for Adam. It also implies that $\xi$ should be in proportion to the square-root of the batch size, which we observe in practice as well (§ A.5). In addition, Algorithm 1 includes optional gradient clipping and momentum. Momentum in Amos is applied after the main update rule (unlike Adam, which applies it before). It can improve performance for pre-training Transformer variants, but consumes memory because the slot variable $\tilde{\mathbf{m}}_t$ must have the same shape as $\tilde{\boldsymbol{\theta}}$. When momentum is applied, its decay rate $\mu$ is typically set to 0.9.
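The pieces described in this section can be combined into a single step for one variable, sketched below in plain Python. This is a hedged illustration, not the authors' implementation: the name `amos_step` and the `state` dictionary are hypothetical, all slot variables are assumed to be fully reduced to scalars, the clipping is read element-wise, the update of $b_t$ is taken from § A.2, and the momentum is assumed to be a running average of $\delta_t$ with decay rate $\mu$.

```python
import math


def amos_step(theta, g, state, xi, eta, beta=0.999, mu=None, chi=None):
    """Apply one Amos-style step to a single variable (lists of floats).

    state holds scalar slots 't', 'v', 'b' and, if momentum is used, a list 'm'.
    """
    k = len(g)
    if chi is not None:
        # Optional clipping. Assumption: |g_t| is read element-wise.
        g = [chi / max(chi, abs(x)) * x for x in g]
    g2 = sum(x * x for x in g) / k                         # M_2(g_t)^2
    if g2 == 0.0 and state["v"] == 0.0:
        return theta  # assumption: skip while no gradient has been seen yet
    state["t"] += 1
    state["v"] = beta * state["v"] + (1.0 - beta) * g2     # running average
    v_hat = state["v"] / (1.0 - beta ** state["t"])        # bias correction
    b = state["b"]
    c_t = (1.0 + 0.25 * math.sqrt(xi) * b) ** -0.5         # default c_t
    d_t = 1.0 / (1.0 + 0.25 * math.sqrt(xi * eta) * b)     # default d_t
    gamma_t = c_t * xi ** 2 / v_hat * g2                   # adaptive L2 strength
    scale = xi * eta / math.sqrt(v_hat)
    delta = [d_t * (scale * gi + 0.5 * gamma_t * ti)       # Equation 1
             for gi, ti in zip(g, theta)]
    state["b"] = b + gamma_t * (1.0 + b)                   # b_t update (Section A.2)
    if mu is not None:
        # Momentum after the main update rule. Assumed form: running average of delta.
        state["m"] = [mu * mi + (1.0 - mu) * di
                      for mi, di in zip(state["m"], delta)]
        delta = state["m"]
    return [ti - di for ti, di in zip(theta, delta)]


# Toy usage: minimize 0.5 * ||theta - target||^2, whose gradient is theta - target.
target = [1.0, 1.0]
theta = [0.0, 0.0]
state = {"t": 0, "v": 0.0, "b": 0.0}
for _ in range(50):
    g = [ti - yi for ti, yi in zip(theta, target)]
    theta = amos_step(theta, g, state, xi=0.1, eta=1.0)
```

Note that the step count never enters the decay factors directly; they depend only on the accumulated $b_t$, which is why training can resume from any checkpoint that stores the slot variables.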

3. RELATED WORK

Besides RMSProp (Tieleman & Hinton, 2012) , Adam (Kingma & Ba, 2015) and AdamW (Loshchilov & Hutter, 2019) , the number of previous works on optimization is vast, so we focus on some directly related alternatives below. Also, we note that Amos is a stochastic first-order optimizer, in contrast to recent progress in the second-order optimization methods (Gupta et al., 2018) . The convergence of stochastic optimizers has been studied in terms of stochastic approximation (Bottou, 1998) , regret (Hazan, 2019) , or nonconvex stochastic programming (Ghadimi & Lan, 2013) . In particular, Reddi et al. ( 2018) observed cases of non-convergence of Adam (with constant learning-rate) and proposed a fix. In our work, we analyze the behavior of Amos in an intuitive and heuristic manner, but leave a rigorous convergence proof (e.g. based on regret) to future work.

AdaGrad:

The update rule of AdaGrad (Duchi et al., 2011) is $\delta_t \leftarrow \frac{\alpha}{\sqrt{b_t}}\, g_t$, where $b_{t+1} \leftarrow b_t + g_t^2$ is similar to the $b_t$ in Algorithm 1, in the sense that both AdaGrad and Amos use a (weighted) sum of squared gradients to decay learning-rates. Such decay is "adaptive" because the learning-rate will decay more for parameters getting more updates, which is suitable for sparse gradients. On the other hand, conventional wisdom is that the learning-rate in AdaGrad might decay "too fast" in some cases, which makes the convergence slow, and Adam mitigates this issue by using a running average of squared gradients instead of the decay factor. However, the AdamW setting suggests that the normalization factor by running average of squared gradients is not a replacement for learning-rate decay; one still needs a linearly decaying learning-rate schedule for better convergence. Thus, Amos integrates both the Adam-style gradient normalization and the AdaGrad-style learning-rate decay; with gradient normalization, the learning-rate can actually decay faster and it converges faster.

SGD with L2 Regularization: For the classic Stochastic Gradient Descent (SGD) algorithm, it is recommended to decay the learning-rate by the factor $\frac{1}{\lambda t}$, where $\lambda$ is the smallest eigenvalue of the Hessian (Murata, 1998). Although $\lambda$ is generally unknown, adopting an L2 regularizer of strength $\lambda'$ guarantees that $\lambda \geq \lambda'$, so one can set the learning-rate to $\frac{1}{\lambda' t}$ (Bottou, 2012). In Amos, we adopt a similar idea to heuristically derive the learning-rate decay (see § A.3 for more detailed discussion), by connecting the decaying speed with the strength of L2 regularization (i.e., the L2 strength $\gamma_t$ in Algorithm 1 also appears in the update of $b_t$). Unlike SGD, both the learning-rate and L2 regularization in Amos decay adaptively. The adaptive L2 regularization, in particular, is a novel component unseen in previous optimizers.
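The "adaptive" decay of AdaGrad can be seen in a few lines. The sketch below is a generic element-wise AdaGrad step, not code from this work; accumulating $b$ before dividing and adding a small `eps` are assumptions made here to avoid division by zero at the first step.

```python
import math


def adagrad_step(theta, g, b, alpha, eps=1e-8):
    """One element-wise AdaGrad step; returns (new_theta, new_b)."""
    # Assumption: accumulate the squared gradient first, then normalize.
    new_b = [bi + gi * gi for bi, gi in zip(b, g)]
    new_theta = [ti - alpha / (math.sqrt(bi) + eps) * gi
                 for ti, gi, bi in zip(theta, g, new_b)]
    return new_theta, new_b


theta, b = [0.0, 0.0], [0.0, 0.0]
for step in range(8):
    # Entry 0 receives a gradient every step, entry 1 only every fourth step.
    g = [1.0, 1.0 if step % 4 == 0 else 0.0]
    theta, b = adagrad_step(theta, g, b, alpha=0.1)
effective_lr = [0.1 / math.sqrt(bi) for bi in b]
```

The entry with sparse gradients keeps a larger effective learning-rate than the entry updated at every step.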

LAMB:

The LAMB optimizer (You et al., 2020) and its origin LARS (You et al., 2017) share several similar aspects with Amos. The idea of layer-wise learning-rate in LAMB and LARS is similar to the per-variable learning-rate η in Amos; they all normalize the gradients in some way; and they all imply scaling up the learning-rate as the batch size increases. In our experiments, scaling the global learning-rate of Amos in proportion to the square-root of the batch size indeed works ( § A.5), although we leave a systematic study of scaling-up to extremely large batch sizes and comparing with LAMB and LARS to future work. AdaFactor: In Adam, the slot variable ṽt for maintaining running average of squared gradients requires the same amount of memory as the model weights θ. In order to reduce the memory usage, AdaFactor (Shazeer & Stern, 2018) proposes to use nonnegative matrix factorization to decompose any matrix into two vectors. In contrast, Amos reduces memory usage by simply reducing some axes of the slot variables and broadcasting to the shape of model weights. This reduction is more efficient than AdaFactor, and our experiments suggest that it will not degrade performance ( § A.8).

4. DERIVATION OF AMOS

In this section, we heuristically derive the Amos update rule (Equation 1). We start from a general form of the weight update for a given variable $\boldsymbol{\theta}$,

$$\boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - \alpha_t\, \mathbf{g}_t \quad \text{where} \quad \tilde{\mathbf{g}}_t = \nabla \ell(z_t; \tilde{\boldsymbol{\theta}}_t), \tag{2}$$

and gradually pin down to the specific form of Equation 1. Here, the step size $\alpha_t > 0$ is a scalar (due to our memory reduction mechanism in § 2) and is shared across the elements of the vector-valued $\mathbf{g}_t, \boldsymbol{\theta}_t \in \mathbb{R}^k$. We are focusing on a subset of model parameters, and note that $\alpha_t$ may differ for different variables. Then, the following Descent Lemma (Murata, 1998) provides a sanity check for a wide range of possible forms of $\alpha_t$, while also suggesting some constraints. Its proof can be found in § A.1.

Lemma 4.1 (Descent Lemma). If $\alpha_t$ does not depend on $z_t$, then there exists $\epsilon_t > 0$ such that $E_t[E_{t+1}[\ell(z_{t+1}; \tilde{\boldsymbol{\theta}}_{t+1})]] \leq E_t[\ell(z_t; \tilde{\boldsymbol{\theta}}_t)]$ for any $\alpha_t < \epsilon_t$, where $E_t[\cdot]$ denotes the expectation taken over the random example $z_t$ drawn from the training data at step $t$, while conditioned on $z_{t-1}, \ldots, z_0$ of the previous steps.

In light of Lemma 4.1, we require (I) $\alpha_t$ does not depend on $z_t$ (but may differ for different variables), and (II) $\alpha_t$ decays to 0 at $t \to \infty$, so the step-size can be sufficiently small that the Descent Lemma applies and Equation 2 will always make progress on average. In the Amos update rule, $\alpha_t = d_t\, \frac{\xi\eta}{\sqrt{\hat{v}_t}}$ and $\hat{v}_t$ depends on $z_t$, which seems to violate requirement (I) above. However, $\hat{v}_t$ should be regarded as an approximation of $E[M_2(\mathbf{g}_t)^2]$, where $E[\cdot]$ denotes the expectation taken over examples randomly drawn from the training data, which is independent of $z_t$. Next, we add an L2-regularization term to Equation 2:

$$\boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - (\alpha_t\, \mathbf{g}_t + \rho_t\, \boldsymbol{\theta}_t) \tag{3}$$

where $\rho_t \geq 0$ can depend on $\mathbf{g}_t$ (hence "adaptive"), but we require (III) $E[\rho_t]$ does not depend on $\mathbf{g}_t$.
The intuition is that an L2-regularization should have the same strength across all variables, rather than be affected by the typical gradient magnitude on each variable. It is the same intuition that motivates the weight decay decoupled from gradient adaptive factors (Loshchilov & Hutter, 2019). The first challenge for Amos is to keep a balance between $\alpha_t$ and $\rho_t$, so that $M_2(\boldsymbol{\theta}_t)$ will converge to the pre-specified, per-variable hyper-parameter $\eta$. In order to achieve this, we state some intuitions about the magnitudes of $\mathbf{g}_t$, $E[\mathbf{g}_t]$ and $\rho_t \boldsymbol{\theta}_t$, as a guide for our heuristic derivation. For deep neural networks, the $\mathbf{g}_t$'s for different $z_t$'s appear to be randomly noisy, so they will cancel out when being averaged to $E[\mathbf{g}_t]$; which means that $M_2(E[\mathbf{g}_t])$ is usually much smaller than $M_2(\mathbf{g}_t)$. On the other hand, $\boldsymbol{\theta}_t$ does not depend on $z_t$, and it changes slowly between different steps, so the update by $\rho_t \boldsymbol{\theta}_t$ is easier to accumulate than $\alpha_t \mathbf{g}_t$. This means that the magnitude of $\rho_t \boldsymbol{\theta}_t$ can be kept smaller than $\alpha_t \mathbf{g}_t$ while still competing with $\alpha_t E[\mathbf{g}_t]$. In Amos, $\rho_t = d_t\, \frac{1}{2} c_t\, \frac{\xi^2}{\hat{v}_t}\, M_2(\mathbf{g}_t)^2$ decays to 0 faster than $\alpha_t$ (due to the extra decay factor $c_t$), which we assume will make $\rho_t \boldsymbol{\theta}_t$ small enough compared to $\alpha_t \mathbf{g}_t$ when $t$ is large.

Quantitatively, we consider the error $\tilde{\boldsymbol{\varepsilon}}_t = \tilde{\boldsymbol{\theta}}_t - \tilde{\boldsymbol{\theta}}^*$, where $\tilde{\boldsymbol{\theta}}^*$ is a local minimum. Equation 3 implies

$$\begin{aligned} M_2(\boldsymbol{\varepsilon}_{t+1})^2 &= M_2(\boldsymbol{\varepsilon}_t)^2 - \frac{2}{k} (\alpha_t \mathbf{g}_t + \rho_t \boldsymbol{\theta}_t) \cdot \boldsymbol{\varepsilon}_t + M_2(\alpha_t \mathbf{g}_t + \rho_t \boldsymbol{\theta}_t)^2 \\ &\approx M_2(\boldsymbol{\varepsilon}_t)^2 - \frac{2}{k} (\alpha_t \mathbf{g}_t + \rho_t \boldsymbol{\theta}_t) \cdot \boldsymbol{\varepsilon}_t + \alpha_t^2\, M_2(\mathbf{g}_t)^2, \end{aligned} \tag{4}$$

where we investigate a time point $t$ large enough that the model nearly converges. At this point, $\rho_t \boldsymbol{\theta}_t$ is small compared to $\alpha_t \mathbf{g}_t$, so it can be approximately omitted in the third term. And we should have $E[\mathbf{g}_t] \approx 0$ and $M_2(\boldsymbol{\theta}_t) \approx \eta$ if the trained weights converge to scale $\eta$. So taking $E[\cdot]$ of Equation 4, we should get

$$E[M_2(\boldsymbol{\varepsilon}_{t+1})^2] \approx M_2(\boldsymbol{\varepsilon}_t)^2 - \frac{2}{k} E[\rho_t]\, \boldsymbol{\theta}_t \cdot \boldsymbol{\varepsilon}_t + \alpha_t^2\, E[M_2(\mathbf{g}_t)^2].$$
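The first line of Equation 4 is an exact algebraic identity for the update of Equation 3, and it can be checked numerically. The sketch below is only a sanity check under arbitrary illustrative values; the helper names `mean_square`, `dot`, and `equation4_sides` are hypothetical.

```python
def mean_square(a):
    """M_2(a)^2: mean of the squared entries of a."""
    return sum(x * x for x in a) / len(a)


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def equation4_sides(theta, theta_star, g, alpha, rho):
    """Return both sides of the exact (first) line of Equation 4."""
    k = len(theta)
    eps = [t - s for t, s in zip(theta, theta_star)]
    update = [alpha * gi + rho * ti for gi, ti in zip(g, theta)]
    theta_next = [t - u for t, u in zip(theta, update)]        # Equation 3
    eps_next = [t - s for t, s in zip(theta_next, theta_star)]
    lhs = mean_square(eps_next)
    rhs = mean_square(eps) - 2.0 / k * dot(update, eps) + mean_square(update)
    return lhs, rhs


lhs, rhs = equation4_sides(theta=[0.5, -0.2, 0.1],
                           theta_star=[0.4, 0.0, 0.3],
                           g=[1.0, -2.0, 0.5],
                           alpha=0.05, rho=0.001)
```

The second line of Equation 4 then replaces `mean_square(update)` by $\alpha_t^2 M_2(\mathbf{g}_t)^2$, which is accurate only when $\rho_t \boldsymbol{\theta}_t$ is small relative to $\alpha_t \mathbf{g}_t$.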
Furthermore, in order for the model to converge, we should have $E[M_2(\boldsymbol{\varepsilon}_{t+1})^2] \leq M_2(\boldsymbol{\varepsilon}_t)^2$ from the above. Hence, we should have

$$\alpha_t^2\, E[M_2(\mathbf{g}_t)^2] \leq \frac{2}{k} E[\rho_t]\, \boldsymbol{\theta}_t \cdot \boldsymbol{\varepsilon}_t \leq 2 E[\rho_t]\, M_2(\boldsymbol{\theta}_t)\, M_2(\boldsymbol{\varepsilon}_t) \approx 2 E[\rho_t]\, \eta\, M_2(\boldsymbol{\varepsilon}_t)$$

as a necessary condition for the trained weights to converge to scale $\eta$. By setting $\rho_t$ to the smallest possible, we get

$$2 \rho_t\, \eta\, M_2(\boldsymbol{\varepsilon}_t) = \alpha_t^2\, M_2(\mathbf{g}_t)^2, \tag{5}$$

which is an important relation connecting $\rho_t$ to $\alpha_t$. We require (IV) Equation 5 to be satisfied throughout the course of training, and use it ubiquitously in our derivation. It is out of the scope of this work to prove whether Equation 5 actually makes $M_2(\boldsymbol{\theta}_t)$ converge to $\eta$; but the requirements so far already determine a basic form of the Amos update rule (as shown in Lemma 4.2 below), and our experiments suggest that Amos indeed brings the trained weights to the specified scale (§ 5.2).

Lemma 4.2 (Basic Form of Amos). Assume Equation 5, requiring that $\alpha_t$ does not depend on $z_t$ and $E[\rho_t]$ does not depend on $\mathbf{g}_t$. Then, we have

$$\alpha_t \propto \frac{1}{\sqrt{E[M_2(\mathbf{g}_t)^2]}} \quad \text{and} \quad \rho_t \propto \frac{M_2(\mathbf{g}_t)^2}{E[M_2(\mathbf{g}_t)^2]}.$$

The proof is found in § A.1. It is noteworthy that the Adam-style gradient normalization naturally occurs in $\alpha_t$. Based on Lemma 4.2, Amos is derived by specifying the initial learning-rate and decay schedule. For that, we need the following assumption to quantify the magnitudes of $\mathbf{g}_t$ and $E[\mathbf{g}_t]$.

Assumption 1. A scalar $\xi > 0$ exists such that $\frac{M_2(E[\mathbf{g}_t])}{\sqrt{E[M_2(\mathbf{g}_t)^2]}} \geq \xi$ for all $t$ and across all variables.

This assumption formalizes two intuitions, i.e. randomly noisy $\mathbf{g}_t$ will cancel out when being averaged to $E[\mathbf{g}_t]$ (so $\xi$ has a small value), and as the training proceeds, $M_2(\mathbf{g}_t)$ will decrease along with $M_2(E[\mathbf{g}_t])$ (so the ratio remains larger than a constant $\xi > 0$). Assumption 1 is verified by our experiments (§ A.6).
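The cancellation behind Assumption 1 can be illustrated with synthetic noise. The sketch below replaces the true expectation by an empirical mean over $N$ sampled gradient vectors, which is an assumption of this example rather than the procedure of § A.6; the function name `noise_ratio` and the Gaussian samples are hypothetical stand-ins for real gradients.

```python
import math
import random


def mean_square(a):
    """M_2(a)^2: mean of the squared entries of a."""
    return sum(x * x for x in a) / len(a)


def noise_ratio(samples):
    """Empirical M_2(mean of g) / sqrt(mean of M_2(g)^2) over a list of gradients."""
    n = len(samples)
    k = len(samples[0])
    mean_g = [sum(s[i] for s in samples) / n for i in range(k)]
    expected_sq = sum(mean_square(s) for s in samples) / n
    return math.sqrt(mean_square(mean_g)) / math.sqrt(expected_sq)


random.seed(0)
N, k = 100, 200
# Pure zero-mean noise: every "gradient" is an independent Gaussian vector.
samples = [[random.gauss(0.0, 1.0) for _ in range(k)] for _ in range(N)]
ratio = noise_ratio(samples)
```

For pure zero-mean noise the ratio comes out close to $1/\sqrt{N}$, which is small but strictly positive.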
The value of $\xi$ is related to the global learning-rate in Amos (as shown in Lemma 4.3 below), which is tuned as a hyper-parameter in practice. However, we also provide an intuitive estimation of $\xi$, which is usually a good start for hyper-parameter tuning. The intuition is to view the canceling out of $\mathbf{g}_t$ averaged to $E[\mathbf{g}_t]$ as similar to the average of $N$ i.i.d. samples drawn from a distribution of mean 0. According to the Law of Large Numbers, the variance of the average (i.e. $M_2(E[\mathbf{g}_t])^2$) is about $1/N$ of the variance of the distribution (i.e. $E[M_2(\mathbf{g}_t)^2]$), so $\xi \approx \frac{1}{\sqrt{N}}$. In reality, the gradients of deep neural networks, computed over mini-batches, appear to be highly random. So $N$ is usually of the same order of magnitude as the number of independent batches in the training data.

Now, we can derive the optimal initial learning-rate as below, under an ideal condition that $\mathbf{g}_0$ points to the same direction as $\boldsymbol{\varepsilon}_0$. The proof is found in § A.1.

Lemma 4.3 (Initial Learning-rate). Assume Equation 2, Assumption 1, $\alpha_0 = \alpha / \sqrt{E[M_2(\mathbf{g}_0)^2]}$ and that $\mathbf{g}_0$ points to the same direction as $\boldsymbol{\varepsilon}_0$. Then,

$$E[M_2(\boldsymbol{\varepsilon}_1)^2] \leq M_2(\boldsymbol{\varepsilon}_0)^2 - 2 \alpha \xi\, M_2(\boldsymbol{\varepsilon}_0) + \alpha^2$$

and the RHS achieves its minimum at $\alpha = \xi\, M_2(\boldsymbol{\varepsilon}_0) \approx \xi\eta$.

Lemma 4.3 suggests the initial learning-rate $\alpha_0 = \frac{\xi\eta}{\sqrt{E[M_2(\mathbf{g}_0)^2]}}$. Then, we get $\rho_0 = \frac{1}{2} \xi^2\, \frac{M_2(\mathbf{g}_0)^2}{E[M_2(\mathbf{g}_0)^2]}$ from Equation 5. By adding the decay factors, we arrive at the Amos update rule (Equation 1):

$$\delta_t \leftarrow \alpha_t\, g_t + \rho_t\, \theta_t, \quad \text{where} \quad \alpha_t = d_t\, \frac{\xi\eta}{\sqrt{E[M_2(\mathbf{g}_t)^2]}} \quad \text{and} \quad \rho_t = d_t\, \frac{1}{2} \gamma_t = d_t\, \frac{1}{2} c_t\, \xi^2\, \frac{M_2(\mathbf{g}_t)^2}{E[M_2(\mathbf{g}_t)^2]}. \tag{6}$$

Here, $d_t$ and $c_t$ monotonically decrease to 0 and $d_0 = c_0 = 1$. In particular, $c_t$ decaying to 0 ensures that $\rho_t$ decays to 0 faster than $\alpha_t$, so $\rho_t \boldsymbol{\theta}_t$ can be sufficiently small compared to $\alpha_t \mathbf{g}_t$ for large $t$, which justifies the approximation of Equation 4. In § A.2, we will further derive that $c_t = (1 + p b_t)^{-\frac{1}{2}}$ and $d_t = (1 + q b_t)^{-1}$, where $p$, $q$ are constants, together with the update rule of $b_t$. The specific $p = \frac{1}{4}\sqrt{\xi}$ and $q = \frac{1}{4}\sqrt{\xi\eta}$ are found through experiments and work well in practice.
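As a quick consistency check of Equation 6 against Equation 5, the sketch below evaluates $\alpha_t$ and $\rho_t$ with the default decay factors and verifies the relation at $t = 0$, assuming $M_2(\boldsymbol{\varepsilon}_0) \approx \eta$ holds exactly. The function name `amos_rates` and the numeric values are illustrative assumptions.

```python
import math


def amos_rates(b, g2, expected_g2, xi, eta):
    """alpha_t and rho_t of Equation 6 with the default c_t and d_t.

    g2 stands for M_2(g_t)^2 and expected_g2 for E[M_2(g_t)^2].
    """
    p = 0.25 * math.sqrt(xi)
    q = 0.25 * math.sqrt(xi * eta)
    c_t = (1.0 + p * b) ** -0.5
    d_t = 1.0 / (1.0 + q * b)
    alpha_t = d_t * xi * eta / math.sqrt(expected_g2)
    rho_t = d_t * 0.5 * c_t * xi ** 2 * g2 / expected_g2
    return alpha_t, rho_t


xi, eta, g2 = 0.1, 0.5, 4.0
alpha0, rho0 = amos_rates(b=0.0, g2=g2, expected_g2=4.0, xi=xi, eta=eta)
# Equation 5 at t = 0, taking M_2(eps_0) = eta as an idealization.
lhs = 2.0 * rho0 * eta * eta
rhs = alpha0 ** 2 * g2
```

Both sides agree at $b_0 = 0$; for $b_t > 0$ the relation holds only if $c_t M_2(\boldsymbol{\varepsilon}_t) = d_t \eta$, which is the condition analyzed in § A.2.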

5. EXPERIMENTS

We focus on the Transformer model (Vaswani et al., 2017), and pre-train several variants as below.

BERT: A Transformer Encoder model with learned position embeddings (Devlin et al., 2019). We experiment with the base (12-layer 768-hidden) and large (24-layer 1024-hidden) model sizes.

RoPE: A Transformer Encoder variant with the Rotary Position Encoding (Su et al., 2021). RoPE is integrated in some recent large-scale language models (Chowdhery et al., 2022). It encodes relative positions but the encoding is not learned. We experiment with the base (12-layer 768-hidden) and large (24-layer 1024-hidden) model sizes.

Relative Position Embeddings (RPE): A Transformer Encoder variant with learned relative position embeddings (Shaw et al., 2018). It achieves better performance but the pre-training is more costly on TPU (Tian et al., 2021). We experiment with the base (12-layer 768-hidden) model size.

T5 Encoder-Decoder (T5): A Transformer Encoder-Decoder model implemented by Raffel et al. (2020). We experiment with the large (24-layer 1024-hidden) model size.

For encoder-only models, we pre-train with the Masked Language Modeling loss (Devlin et al., 2019) on Wikipedia and the Books Corpus (Zhu et al., 2015). In all experiments, across different model architectures, model sizes, batch sizes, datasets and loss functions, Amos (pink curve) outperforms the state-of-the-art AdamW setting, with the loss always significantly lower beyond 30% of the training procedure, and the validation loss achieving the final value of AdamW-300k within < 70% training steps or time. For BERT-base (Figure 1a), Amos achieves the same within only 145k steps (< 50%), and the Amos checkpoint at 150k outperforms the final checkpoint of AdamW-300k in fine-tuning on MNLI (Williams et al., 2018) as well (§ A.7).
In Figure 1a , we also tried starting from the final checkpoint of AdamW-200k and resetting the learning-rate as if it is linearly decaying to max training step 300k (AdamW-Cont.). The loss spikes higher and does not go further lower than the value at 200k, suggesting that the hyper-parameter of max training steps has to be set a priori, and continuous training is not trivial with AdamW. In addition, we tried a learning-rate schedule (AdamW-rsqrt) that takes the same value at step 10k but adopts a decay in proportion to t -1/2 (where t is the step) beyond. Although this setting does not require max training steps, it converges slower than both AdamW-200k and AdamW-300k. For the RPE model (Figure 1c ), we tried setting η of the relative position embeddings to a smaller value (Amos-*Scale, see § A.4 for more details), and found significant impact especially on the validation loss. Similar results are observed when we change η for a certain type of layers in the BERT-large model (Figure 2a , Amos-*Scale, see § A.4). It suggests that the model-specific information η indeed contributes to the performance of Amos, which according to previous work (Kaplan et al., 2020) is unlikely achieved by tuning the learning-rate schedule alone.

5.2. SCALES OF TRAINED VARIABLES

In Figure 3 we show how the scale of entries of some variables evolves as the training proceeds. With AdamW, both the token embeddings and the bias converge to similar scales (Figure 3ab); while with Amos the token embeddings converge to ≈ $1/d$ (where $d$ is the hidden size) and the bias to ≈ 0.5, as specified by the hyper-parameter $\eta$. It shows that the algorithm of Amos can lead variables to converge to drastically different scales, which is unlikely with AdamW. In Figure 3c, comparing Amos and Amos-*Scale, the relative position embeddings in a typical layer of the RPE model converge to different scales, which shows that the scale is indeed controlled by the hyper-parameter $\eta$. Recall that Figure 1c shows this has an impact on the performance.

In order to further illustrate the relation among the optimizer, validation performance and the scale of variables, we train a single layer LSTM on the Penn Tree Bank (PTB) corpus (Marcus et al., 1993). The model size is 256 for hidden states and 1024 for memory. We set dropout rate 0.55 for hidden states (which is important for training on PTB) and 0.1 for memory. Sequence length and batch size are set to 64. We compare Amos, AdamW, and Adam (without weight decay). For Amos, the global learning-rate is set to 0.01 and $\eta$ for the LSTM kernel is set to $\frac{1}{\sqrt{32}}$ (calculated from input scale $\frac{1}{4}$, input dimension 512 and output scale 1). For AdamW and Adam, the learning-rate is set to 0.0015 (about the same as Amos for the LSTM kernel), and the weight decay is set to 0.01 for AdamW. The results are shown in Figure 4. Without weight decay, the scale of the LSTM kernel trained by Adam can keep increasing; so Adam is better than AdamW on training loss but worse on validation perplexity (i.e. the model trained by Adam generalizes worse). On the other hand, Amos achieves the same training loss as Adam, while keeping the scale of the kernel as specified. It results in a much better validation perplexity which matches the state-of-the-art. Overall, we conclude that controlling the scale of trained variables can help the generalization performance of deep neural networks, and the model-specific information from $\eta$ enables Amos to do this.

6. CONCLUSION

We have presented the Amos optimizer, which uses an adaptive L2 regularizer to control learning-rate decay and guides trained weights towards a specified model-oriented scale. It demonstrates faster convergence than the state-of-the-art in pre-training language models, where the training process is long and the decaying schedule is crucial. On the other hand, its ability to control the scale of trained weights also brings better generalization to small models such as a single layer LSTM. Besides pre-training, we expect Amos to have advantages in fine-tuning as well, especially for multi-modal models that combine heterogeneous components of varied scales and/or pre-trained with different recipes. Hopefully, the model-specific information $\eta$ can help us fine-tune such models that were previously difficult with other optimizers (Liang et al., 2022; Kumar et al., 2022).

Proof of Lemma 4.2. In the LHS of Equation 5, only $\rho_t$ can depend on $z_t$; while in the RHS, $M_2(\mathbf{g}_t)^2$ depends on $z_t$ but $\alpha_t$ does not. In order to satisfy Equation 5 on every $z_t$, it is necessary that $\rho_t$ has a $M_2(\mathbf{g}_t)^2$ factor: $\rho_t \propto M_2(\mathbf{g}_t)^2$. Moreover, we require that $E[\rho_t]$ does not depend on $\mathbf{g}_t$, so $\rho_t$ should be normalized by $E[M_2(\mathbf{g}_t)^2]$: $\rho_t \propto \frac{M_2(\mathbf{g}_t)^2}{E[M_2(\mathbf{g}_t)^2]}$. This, substituted back into Equation 5, implies that $\alpha_t \propto \frac{1}{\sqrt{E[M_2(\mathbf{g}_t)^2]}}$.

Proof of Lemma 4.3. Equation 2 implies $\boldsymbol{\varepsilon}_1 = \boldsymbol{\varepsilon}_0 - \alpha_0\, \mathbf{g}_0$. Taking $E[M_2(\cdot)^2]$ of this equation, we have

$$E[M_2(\boldsymbol{\varepsilon}_1)^2] = M_2(\boldsymbol{\varepsilon}_0)^2 - \frac{2}{k} \alpha_0\, E[\mathbf{g}_0] \cdot \boldsymbol{\varepsilon}_0 + \alpha_0^2\, E[M_2(\mathbf{g}_0)^2] = M_2(\boldsymbol{\varepsilon}_0)^2 - \frac{2}{k} \alpha\, \frac{E[\mathbf{g}_0] \cdot \boldsymbol{\varepsilon}_0}{\sqrt{E[M_2(\mathbf{g}_0)^2]}} + \alpha^2.$$

Since $\mathbf{g}_0$ and $\boldsymbol{\varepsilon}_0$ point to the same direction, we have $\frac{1}{k} E[\mathbf{g}_0] \cdot \boldsymbol{\varepsilon}_0 = M_2(E[\mathbf{g}_0])\, M_2(\boldsymbol{\varepsilon}_0)$. By Assumption 1 we have $M_2(E[\mathbf{g}_0]) / \sqrt{E[M_2(\mathbf{g}_0)^2]} \geq \xi$. Hence,

$$E[M_2(\boldsymbol{\varepsilon}_1)^2] \leq M_2(\boldsymbol{\varepsilon}_0)^2 - 2 \alpha \xi\, M_2(\boldsymbol{\varepsilon}_0) + \alpha^2.$$

The RHS above is a quadratic function of $\alpha$, which achieves its minimum at $\alpha = \xi\, M_2(\boldsymbol{\varepsilon}_0)$. Finally, since (usually) $\boldsymbol{\theta}_0$ is initialized close to 0, and $M_2(\boldsymbol{\theta}^*) \approx \eta$, we have $M_2(\boldsymbol{\varepsilon}_0) \approx \eta$.

A.2 HEURISTIC DERIVATION OF DECAY FACTORS

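Substituting Equation 6 into Equation 5, we get the following equivalent of Equation 5:

$$c_t\, M_2(\boldsymbol{\varepsilon}_t) = d_t\, \eta. \tag{7}$$

Without knowing any specific relation among $\mathbf{g}_t$, $\boldsymbol{\theta}_t$ and $\boldsymbol{\varepsilon}_t$, we found it difficult to theoretically decide an optimal $c_t$. Given that $c_t$ decreases to 0, we set $c_t$ to decrease according to $M_2(\boldsymbol{\varepsilon}_t)$ in Amos, i.e. $c_t \sim r\, M_2(\boldsymbol{\varepsilon}_t)$, where $r$ is a constant and $\sim$ denotes asymptotically equal at $t \to \infty$. Thus, by Equation 7 we have $d_t \sim \frac{r}{\eta} M_2(\boldsymbol{\varepsilon}_t)^2$. We will analyze the evolution of $M_2(\boldsymbol{\varepsilon}_t)^2$ to derive $c_t$ and $d_t$. Taking $E[\cdot]$ of Equation 4, and applying Equation 6 and Equation 7, we get

$$E[M_2(\boldsymbol{\varepsilon}_{t+1})^2] \approx M_2(\boldsymbol{\varepsilon}_t)^2 - \frac{2}{k} \left( c_t\, \xi\, M_2(\boldsymbol{\varepsilon}_t)\, \frac{E[\mathbf{g}_t]}{\sqrt{E[M_2(\mathbf{g}_t)^2]}} + \frac{d_t}{2} E[\gamma_t]\, \boldsymbol{\theta}_t \right) \cdot \boldsymbol{\varepsilon}_t + c_t^2\, \xi^2\, M_2(\boldsymbol{\varepsilon}_t)^2. \tag{8}$$

As in the derivation of the initial learning-rate, we make an optimistic estimation that $E[\mathbf{g}_t]$ and $\boldsymbol{\varepsilon}_t$ have the same direction. Then, applying Assumption 1 we have

$$\frac{1}{k}\, \frac{E[\mathbf{g}_t]}{\sqrt{E[M_2(\mathbf{g}_t)^2]}} \cdot \boldsymbol{\varepsilon}_t = \frac{M_2(E[\mathbf{g}_t])}{\sqrt{E[M_2(\mathbf{g}_t)^2]}}\, M_2(\boldsymbol{\varepsilon}_t) \geq \xi\, M_2(\boldsymbol{\varepsilon}_t),$$

and Equation 8 implies

$$\begin{aligned} E[M_2(\boldsymbol{\varepsilon}_{t+1})^2] &\leq M_2(\boldsymbol{\varepsilon}_t)^2 - 2 c_t\, \xi^2\, M_2(\boldsymbol{\varepsilon}_t)^2 - \frac{d_t}{k} E[\gamma_t]\, \boldsymbol{\theta}_t \cdot \boldsymbol{\varepsilon}_t + c_t^2\, \xi^2\, M_2(\boldsymbol{\varepsilon}_t)^2 \\ &\leq M_2(\boldsymbol{\varepsilon}_t)^2 - c_t\, \xi^2\, M_2(\boldsymbol{\varepsilon}_t)^2 - \frac{d_t}{k} E[\gamma_t]\, \boldsymbol{\theta}_t \cdot \boldsymbol{\varepsilon}_t \end{aligned} \tag{9}$$

where in the last inequality we have used the fact that $c_t \leq 1$. Now, in order to estimate $\boldsymbol{\theta}_t \cdot \boldsymbol{\varepsilon}_t$, we assume that $\boldsymbol{\theta}_t$ will be evenly distributed on the hypersphere of radius $M_2(\boldsymbol{\varepsilon}_t)$ around $\boldsymbol{\theta}^*$ as the training proceeds. Then, if $k \geq 3$, for most $\boldsymbol{\theta}_t$ from the distribution we will have $\boldsymbol{\theta}^* \cdot \boldsymbol{\varepsilon}_t \approx 0$. In this case, we have $\frac{1}{k} \boldsymbol{\theta}_t \cdot \boldsymbol{\varepsilon}_t = \frac{1}{k} (\boldsymbol{\varepsilon}_t + \boldsymbol{\theta}^*) \cdot \boldsymbol{\varepsilon}_t \approx M_2(\boldsymbol{\varepsilon}_t)^2$, and "on average" it is safe to assume that $\frac{1}{k} \boldsymbol{\theta}_t \cdot \boldsymbol{\varepsilon}_t \geq q\, M_2(\boldsymbol{\varepsilon}_t)^2$ for some constant $q > 0$. Then, Equation 9 becomes

$$E[M_2(\boldsymbol{\varepsilon}_{t+1})^2] \leq M_2(\boldsymbol{\varepsilon}_t)^2 - E[\gamma_t]\, (1 + d_t q)\, M_2(\boldsymbol{\varepsilon}_t)^2 \tag{10}$$

where we have used the fact that $E[\gamma_t] = c_t\, \xi^2$.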
In light of Equation 10, we consider the following asymptotic difference equation: e t+1 ∼ e t -γ t (1 + d t q)e t (11) where e t is intended to follow the asymptotic behavior of M 2 (ε t ) 2 . Since we have d t ∼ r η M 2 (ε t ) 2 , it is natural to assume d t ∼ r η e t . Then, we transform Equation 11 as the following: 1 e t+1 ∼ 1 e t • 1 1 -γ t (1 + d t q) ∼ 1 e t 1 + γ t (1 + d t q) ∼ 1 e t + γ t ( 1 e t + qr η ), in which we have used the approximation 1/(1 -x) ∼ 1 + x applied to x = γ t (1 + d t q). Thus, the update rule of b t in Algorithm 1 can be revealed by setting b t = η qr 1 e t : b t+1 = b t + γ t (b t + 1). And d t ∼ r η e t implies d t ∼ 1 qb t , so we set d t = 1 1 + qb t to satisfy both the asymptotic behavior and d 0 = 1. Similarly, since c t ∼ r M 2 (ε t ) we have c t ∼ 1 √ pb t where p = q rη . So we set c t = 1 √ 1 + pb t to satisfy the asymptotic behavior and c 0 = 1.
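The recursion for $b_t$ and the resulting decay factors can be sketched in a few lines of Python. This is a minimal illustration, not the authors' implementation: it replaces $\gamma_t$ by its expectation $c_t\xi^2$ as used in Equation 10, starts from $b_0 = 0$, and the values of `xi`, `p` and `q` are hypothetical.

```python
import math


def amos_decay_factors(b, p, q):
    """Return (c_t, d_t) with c_t = 1/sqrt(1 + p*b_t) and d_t = 1/(1 + q*b_t)."""
    c = 1.0 / math.sqrt(1.0 + p * b)
    d = 1.0 / (1.0 + q * b)
    return c, d


def simulate_decay(num_steps, xi, p, q):
    """Iterate b_{t+1} = b_t + gamma_t * (1 + b_t) from b_0 = 0.

    Assumption of this sketch: gamma_t is replaced by its expectation
    E[gamma_t] = c_t * xi**2, so no gradients are needed.
    """
    b = 0.0
    history = []
    for _ in range(num_steps):
        c, d = amos_decay_factors(b, p, q)
        history.append((b, c, d))
        gamma = c * xi ** 2
        b = b + gamma * (1.0 + b)
    return history


# Hypothetical constants, chosen only for illustration.
history = simulate_decay(num_steps=1000, xi=0.05, p=0.25, q=0.05)
```

Both factors start at 1 and decrease as $b_t$ accumulates, which is the behavior the derivation asks for.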

A.3 CONNECTION TO SGD

The derivation of decay factors in Amos (§ A.2) is largely inspired by SGD (Murata, 1998). In this section, we recall the theory of the learning-rate schedule of SGD and discuss its relation with Amos. The update rule of SGD is simply $\delta_t \leftarrow \alpha_t g_t$, where $\alpha_t$ is a scalar learning-rate. It is recommended to set the learning-rate schedule to $\alpha_t = \frac{\alpha}{1 + \alpha\lambda t}$, where $\alpha$ is the initial learning-rate and $\lambda$ is the smallest eigenvalue of the Hessian (Bottou, 2012). This is based on the following discussion.

Lemma A.1. Assume $\theta_t$ is in a neighborhood of a local minimum $\theta^*$, such that the gradient $E[g_t]$ is approximated by $H\varepsilon_t$ via Taylor expansion. Here, $H = E[\nabla^2 \ell(z_t; \theta^*)]$ is the Hessian at $\theta^*$. Let $0 < \lambda$ be the smallest eigenvalue of $H$. Then,
$$E[M_2(\varepsilon_{t+1})^2] \le M_2(\varepsilon_t)^2 - 2\lambda\alpha_t\, M_2(\varepsilon_t)^2 + \alpha_t^2\, E[M_2(g_t)^2] \qquad (12)$$
and the minimum of the RHS of Equation 12 is achieved by
$$\alpha_t = \lambda\, \frac{M_2(\varepsilon_t)^2}{E[M_2(g_t)^2]} \quad\text{and}\quad E[M_2(\varepsilon_{t+1})^2] \le M_2(\varepsilon_t)^2 - \frac{\lambda^2 M_2(\varepsilon_t)^4}{E[M_2(g_t)^2]}. \qquad (13)$$

Proof. Since $\theta^*$ is a local minimum, we have $E[\nabla \ell(z_t; \theta^*)] = 0$ and $E[g_t] \approx H\varepsilon_t$, where $H$ is positive definite. Given $\lambda$ the smallest eigenvalue of $H$, we have $E[g_t]\cdot\varepsilon_t \ge \lambda \|\varepsilon_t\|^2$. Applying this to $E[M_2(\cdot)^2]$ of Equation 2, we get Equation 12. Now the RHS is a quadratic function of $\alpha_t$, and it takes its minimum at Equation 13. So the lemma follows.

Note that both Amos and SGD analyze the evolution of $M_2(\varepsilon_t)^2$ by estimating $\alpha_t g_t\cdot\varepsilon_t$. For SGD this is achieved by approximating $E[g_t]$ with the Hessian. For Amos, on the other hand, we have to make Assumption 1 due to the gradient normalization factor $1/\sqrt{E[M_2(g_t)^2]}$. In both cases, the learning-rate decay is derived by setting $\alpha_t$ in terms of $M_2(\varepsilon_t)$ so that $M_2(\varepsilon_t)^2$ decreases fast, and then solving for the asymptotic behavior of $M_2(\varepsilon_t)$.

Heuristic derivation of $\alpha_t$: We assume $\lim_{t\to\infty} E[M_2(g_t)^2] = \nu > 0$. In light of Equation 13, we consider the following asymptotic difference equation:
$$e_{t+1} \sim e_t - \frac{\lambda^2}{\nu}\, e_t^2 \qquad (14)$$
where $e_t$ is intended to follow the asymptotic behavior of $M_2(\varepsilon_t)^2$. We transform Equation 14 as:
$$\frac{1}{e_{t+1}} \sim \frac{1}{e_t}\cdot\frac{1}{1 - \frac{\lambda^2}{\nu} e_t} \sim \frac{1}{e_t}\left(1 + \frac{\lambda^2}{\nu} e_t\right) = \frac{1}{e_t} + \frac{\lambda^2}{\nu},$$
so we have $\frac{1}{e_t} \sim \frac{\lambda^2}{\nu}\, t$. Now, since $\alpha_t = \lambda\, \frac{M_2(\varepsilon_t)^2}{E[M_2(g_t)^2]}$ we have $\alpha_t \sim \frac{\lambda}{\nu}\, e_t \sim \frac{1}{\lambda t}$. So $\alpha_t = \frac{\alpha}{1 + \alpha\lambda t}$ satisfies both the asymptotic behavior and $\alpha_0 = \alpha$.

In the above derivation, the assumption $\lim_{t\to\infty} E[M_2(g_t)^2] = \nu > 0$ states that $E[M_2(g_t)^2]$ will converge to some non-zero value and will not further decrease. This is often described intuitively as "the stochastic noise of sampled gradients does not vanish", a characteristic feature in the theory of SGD. It is in drastic contrast with Assumption 1: We assume that $E[M_2(g_t)^2]$ decreases along with $M_2(E[g_t])$ in Amos. Ma et al. (2018) pointed out that the vanishing of $E[M_2(g_t)^2]$ might lead to faster convergence; but to our knowledge, Amos is the first work to use the vanishing of $E[M_2(g_t)^2]$ to actually develop an optimizer that empirically converges faster.

For SGD, the hyper-parameter $\lambda$ is generally unknown; but if we adopt an L2 regularizer of strength $\lambda'$, it is guaranteed that $\lambda \ge \lambda'$, so one can safely set the learning-rate to $\frac{\alpha}{1 + \alpha\lambda' t}$ (Bottou, 2012). In Amos, the strength of L2 regularization $\gamma_t$ takes a similar role in controlling the speed of learning-rate decay. We expect this work to inspire more theoretical investigation into this principle.
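The asymptotic claims of this heuristic derivation can be checked numerically. The sketch below iterates the difference equation of Equation 14 with equality in place of $\sim$ and compares $1/e_t$ against the predicted growth $\frac{\lambda^2}{\nu} t$. The constants `lam`, `nu` and the starting value `e0` are hypothetical and chosen only so that the recursion stays positive.

```python
def sgd_schedule(alpha, lam, t):
    """Learning-rate schedule alpha_t = alpha / (1 + alpha * lam * t)."""
    return alpha / (1.0 + alpha * lam * t)


def iterate_error(e0, lam, nu, num_steps):
    """Iterate e_{t+1} = e_t - (lam**2 / nu) * e_t**2 for num_steps steps."""
    e = e0
    for _ in range(num_steps):
        e = e - (lam ** 2 / nu) * e ** 2
    return e


# Hypothetical constants for illustration only.
lam, nu, steps = 0.5, 2.0, 100_000
e_final = iterate_error(e0=1.0, lam=lam, nu=nu, num_steps=steps)

# Ratio of the simulated 1/e_t to the predicted asymptote (lam^2 / nu) * t.
asymptote_ratio = (1.0 / e_final) / ((lam ** 2 / nu) * steps)

# The schedule alpha / (1 + alpha*lam*t) behaves like 1/(lam*t) for large t.
schedule_ratio = sgd_schedule(alpha=0.1, lam=lam, t=steps) * lam * steps
```

Both ratios approach 1 as the number of steps grows, in line with $1/e_t \sim \frac{\lambda^2}{\nu} t$ and $\alpha_t \sim \frac{1}{\lambda t}$.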

A.4 THE CALCULATION OF η

As explained in § 2, for a linear transformation $y = xW + u$ ($W, u \subseteq \theta$, $W \in \mathbb{R}^{m \times n}$, $x \in \mathbb{R}^m$), we set $\eta(W) = \sigma_y / (\sigma_x \sqrt{m})$ and $\eta(u) = \sigma_y / 2$, where $\sigma_x$ is the standard deviation of entries of $x$ and $\sigma_y$ the standard deviation of entries of $y$. The values of $\sigma_x$ and $\sigma_y$ are constrained by connected layers, and non-linear layers usually expect entries of input/output tensors to lie in some approximate range. In Table 1, we show 3 types of non-linear layers that occur in Transformer, and specify their input/output range (i.e. expected standard deviation) used for calculating η.

For activations, e.g. GELU (Hendrycks & Gimpel, 2016) in the Multi-Layer Perceptron (MLP) block, the input range is set to 1 because the non-linearity of the activation function mostly lies within that range; and the output range is set to $1/\sqrt{2}$ because the activation function, similar to ReLU (Nair & Hinton, 2010), will map negative values (which account for 1/2 of the input dimension) to close to 0 and approximately retain positive values. For Softmax of $n$ classes, the input range is set to 1 because the derivative of $\exp(x)$ is close to 1 within the $|x| \le 1$ range (so Softmax is most sensitive to values within this range); and the output range is set to $1/\sqrt{n}$ because the output is an $n$-dimension vector of L2 norm $\le 1$ (so the quadratic mean of entries is $\le 1/\sqrt{n}$). For LayerNormalization (Ba et al., 2016), the input range is arbitrary because the input will be normalized. The output range is expected to be 1. We will discuss the calculation of η for specific models in the next sub-sections.
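The rule for η is simple enough to write down directly. The following sketch applies $\eta(W) = \sigma_y/(\sigma_x\sqrt{m})$ and $\eta(u) = \sigma_y/2$ to two hypothetical kernels. The function names and the sizes (hidden size 768, MLP size $4 \times 768$) are illustrative assumptions, and the activation output range is taken as $1/\sqrt{2}$, i.e. the quadratic mean that results when half of the unit-scale entries are mapped to 0.

```python
import math


def eta_kernel(sigma_x, sigma_y, m):
    """eta(W) = sigma_y / (sigma_x * sqrt(m)) for a kernel with input dimension m."""
    return sigma_y / (sigma_x * math.sqrt(m))


def eta_bias(sigma_y):
    """eta(u) = sigma_y / 2 for a bias vector."""
    return sigma_y / 2.0


# Hypothetical sizes for illustration: hidden size d and MLP intermediate size m.
d = 768
m = 4 * d

# Kernel fed by a LayerNormalization output (input range 1, output range 1).
eta_after_layernorm = eta_kernel(sigma_x=1.0, sigma_y=1.0, m=d)

# Kernel fed by an activation output (input range 1/sqrt(2), output range 1).
eta_after_activation = eta_kernel(sigma_x=1.0 / math.sqrt(2.0), sigma_y=1.0, m=m)

# Bias of a linear transformation with output range 1.
eta_of_bias = eta_bias(sigma_y=1.0)
```

With $m = 4d$, the two kernel scales differ by exactly a factor of $\sqrt{2}$, so a seemingly small change of η can be tested in isolation.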

A.4.1 BERT, ROPE AND RPE

For BERT, RoPE and RPE, the multi-headed attention layer receives the hidden state $x$, and the linear transformations $xQ$ (i.e. the query) and $xK$ (i.e. the key) are expected to have standard deviation 1 so that the dot-product $\frac{1}{\sqrt{h}}(xQ)\cdot(xK)$ (i.e. attention score) has standard deviation 1 as well (and this is why there is the scaling factor $1/\sqrt{h}$, where $h$ is the size per-head), which is expected by the Softmax for calculating the attention probability. Therefore, the output ranges of $Q$ and $K$ are 1. For RoPE, the dot-product is replaced by a bi-linear form which encodes relative positions, but this does not change the scale because the bi-linear form is orthogonal. For other linear transformations in the model, the outputs are either fed into the activation function of an MLP (which requires input range 1), or serve as a summand in a Residual Connection where the residual part comes from a LayerNormalization (which has range 1). So all the linear transformations have output range 1 in these model architectures.

Thus, we set the η of bias in all linear transformations to 0.5, and the η for kernels is categorized by the input range and dimension, as we show in Table 2. The input embeddings (i.e. token embeddings, position embeddings and segment-type embeddings) are inputs to LayerNormalization so their scales are not constrained there; but the token embeddings are also used as the linear kernel for producing the logits of token generation, which expects input range 1 (because it comes from LayerNormalization) and input dimension $d$ (where $d$ is the hidden size), so η is set to $1/\sqrt{d}$. For the linear kernel of the MLP output layer (MLP/Dense2/Kernel), the input range is $1/\sqrt{2}$ because it comes from a non-linear activation, and the input dimension $m$ is the size of the intermediate activation in the MLP, so η is $\sqrt{2/m}$. For all other linear kernels, the input range is 1 because it comes from LayerNormalization, and the input dimension is the hidden size $d$. So η is $1/\sqrt{d}$.

The relative position embeddings in the RPE model are used as input to the key and value transformations at each layer, similar to the hidden state. We set η to 0.5 so its scale is close to the hidden state (which has scale 1) but will not dominate it.

Experiments with Amos-*Scale. In § 5.1, we have experimented with pre-training RPE and BERT-large with different η (Amos-*Scale). For RPE (Figure 1c), we tried setting η of the relative position embeddings to $1/\sqrt{d}$ instead of 0.5. For BERT-large (Figure 2a), we tried setting η of MLP/Dense2/Kernel to $1/\sqrt{d}$ instead of $\sqrt{2/m}$. Both changes had an impact on performance. Especially for BERT-large, $1/\sqrt{d}$ and $\sqrt{2/m}$ only differ by a $\sqrt{2}$ factor (because $m = 4d$), yet the performance gap is significant. This illustrates the importance of setting η appropriately.

A.4.2 T5

For the T5 model, η is set as in Table 3. It is different from Table 2, due to several differences between the T5 architecture and BERT, as discussed below.

1. Linear transformations do not have bias terms in T5.
2. Attention score is calculated by $(xQ)\cdot(xK)$ in T5, without the scaling factor. Instead, the query kernel $Q$ is initialized to a smaller scale $1/\sqrt{hd}$, with an extra $1/\sqrt{h}$ factor compared to $K$. Thus, we accordingly set η of the query kernel to $1/\sqrt{hd}$.
3. The token embeddings are no longer re-used for producing logits of token generation. So we set η to 1, which is the same as the scale for initialization.
4. The MLP activation function (i.e. gated-GELU) used in T5 is different from BERT. Still, η for the linear kernel of the MLP output (MLP/wo/Kernel) is set to the same value.
5. We set η of the relative attention bias to 0.5 so its scale is close to the attention score (which has scale 1) but will not dominate it.

A.5 DETAILED EXPERIMENT SETTINGS AND LEARNING-RATE SEARCH

In this section, we discuss detailed settings of the pre-training experiments in § 5.1. The hyper-parameters and required computation resources are shown in Table 4. For pre-training BERT with AdamW, we follow the settings of Liu et al. (2019). For RPE, pre-training on TPU is slow, so we use a different configuration with more TPU cores to train the base-sized model. For T5, we found that using β = 0.98 for Amos and AdamW causes training instability, so we decrease the value to β = 0.95. The settings of AdaFactor follow Raffel et al. (2020) and Shazeer & Stern (2018).

For encoder-only models (i.e. BERT, RoPE and RPE) trained on the Wikipedia+Books corpus, we use the Penn TreeBank corpus (Marcus et al., 1993) as the validation set. The training precision is float32. The number of warm-up steps is set to 10k for AdamW and 20k for Amos.

For T5, the training loss is cross-entropy with an extra regularization term, $(\log Z)^2$ (where $Z$ is the normalization factor in Softmax), which makes the logits close to mean 0 and self-normalized. The number of warm-up steps is set to 10k for both AdamW and AdaFactor. Learning-rate decay is in proportion to $t^{-1/2}$ (where $t$ is the step) for AdaFactor and linear for AdamW.

For pre-training BERT-base, we present a learning-rate search in Figure 5. For AdamW (Figure 5a), a smaller learning-rate significantly slows down the convergence, while a larger one results in a bumpy validation loss but almost the same performance. On the other hand, both smaller and larger learning-rates can degrade performance for Amos (Figure 5b and Figure 5c). Comparing Figure 5b and Figure 5c, we also verify a theoretical prediction about the global learning-rate of Amos in § 4, i.e. the best learning-rate for Amos is in proportion to the square-root of the batch size: training with 4× the batch size matches 2× the learning-rate.
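The square-root relation between batch size and learning-rate can be written as a one-line helper. This is only a sketch of the stated proportionality; the function name and the numeric learning-rate are hypothetical, and the batch sizes are those of Figure 5.

```python
import math


def scale_learning_rate(base_lr, base_batch_size, new_batch_size):
    """Scale a learning-rate in proportion to the square root of the batch size."""
    return base_lr * math.sqrt(new_batch_size / base_batch_size)


# Hypothetical base learning-rate; going from batch size 256 to 1024 (4x) doubles it.
scaled_lr = scale_learning_rate(base_lr=0.01, base_batch_size=256, new_batch_size=1024)
```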

A.6 VERIFICATION OF ASSUMPTION 1

In Assumption 1, we have assumed that $\frac{M_2(E[g_t])}{\sqrt{E[M_2(g_t)^2]}} \ge \xi > 0$ for all $t$ and across all variables. $E[g_t]$ and $E[M_2(g_t)^2]$ can be estimated by taking the running average of $g_t$ and $M_2(g_t)^2$, respectively; so in Figure 6 we track the pre-training of the BERT-base model, calculate the running averages with exponential decay rate 0.98, and show some typical plots of the ratio. We note two characteristics of the plots: (1) the ratios are increasing as the training proceeds, which suggests that taking a global constant ξ to satisfy Assumption 1 is indeed possible; (2) starting points on the left of these plots are similar across different learning rates, which suggests that it is detectable in the early stage of training whether a learning-rate is too small or too large. In fact, in all plots for all variables we can see that the ratio $\frac{M_2(E[g_t])}{\sqrt{E[M_2(g_t)^2]}} \ge 0.01$; the appropriate global learning-rate can be read from these plots.
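The tracked ratio can be estimated with two exponential running averages, as described above. The sketch below is a plain-Python illustration for a single variable stored as a list of floats. The class name, the bias correction for the zero initialization, and the toy gradients are assumptions of this sketch, not details taken from the experiments.

```python
import math
import random


def quadratic_mean(xs):
    """M_2(x): square root of the mean of squared entries."""
    return math.sqrt(sum(x * x for x in xs) / len(xs))


class RatioTracker:
    """Running estimate of M_2(E[g_t]) / sqrt(E[M_2(g_t)^2]) for one variable."""

    def __init__(self, size, decay=0.98):
        self.decay = decay
        self.step = 0
        self.mean_g = [0.0] * size  # running average of g_t
        self.mean_sq = 0.0          # running average of M_2(g_t)^2

    def update(self, g):
        d = self.decay
        self.step += 1
        self.mean_g = [d * m + (1.0 - d) * x for m, x in zip(self.mean_g, g)]
        self.mean_sq = d * self.mean_sq + (1.0 - d) * quadratic_mean(g) ** 2
        # Bias correction for the zero initialization (an assumption of this sketch).
        correction = 1.0 - d ** self.step
        numerator = quadratic_mean([m / correction for m in self.mean_g])
        denominator = math.sqrt(self.mean_sq / correction)
        return numerator / denominator


# Toy gradients: a fixed signal of 0.1 per entry plus unit Gaussian noise.
random.seed(0)
tracker = RatioTracker(size=64)
for _ in range(1000):
    ratio = tracker.update([0.1 + random.gauss(0.0, 1.0) for _ in range(64)])
```

A noise-free gradient gives a ratio of 1, while pure noise drives the ratio towards a small value set by the averaging window, so the ratio behaves like a signal-to-noise measure.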

A.7 FINE-TUNING RESULTS

In Table 5, we show fine-tuning results on the MNLI (Williams et al., 2018) dataset. We compare checkpoints pre-trained for 150k and 300k steps with Amos, and the final checkpoints of AdamW-200k and AdamW-300k. We fine-tune all checkpoints using the Adam optimizer with learning-rate 5e-6 and batch size 16, and evaluate by the best accuracy on the MNLI dev set, measured every 1k steps over 200k training steps. We run each experiment 3 times and report the mean and standard deviation. The checkpoint pre-trained for 150k steps by Amos already outperforms the final checkpoint of AdamW-300k. Thus, the faster convergence by Amos in pre-training indeed transfers to better performance in fine-tuning; we can save 50% of the pre-training cost by using Amos instead of AdamW.

Table 5: Fine-tuned accuracy on MNLI dev set. We show the mean and standard deviation of 3 runs.

A.8 ABLATION OF MEMORY REDUCTION

In this section, we experiment with different settings of the memory reduction. We compare the current setting of reducing the input dimension for linear transformations (Reduce 1Axis), to no memory reduction at all (No Reduce), and the setting of reducing both axes for linear transformations (Reduce Dense). For embedding matrices, no axis is reduced in the No Reduce setting, and the embed dimension is reduced for both Reduce 1Axis and Reduce Dense. We have tried reducing both axes for embedding matrices as well, but found the training unstable in this setting.

The comparison of memory usage for slot variables is shown below. AdaFactor (No Momentum) Reduce Dense < Reduce 1Axis AdamW No Reduce. Without memory reduction, Amos (No Reduce) consumes more memory than AdamW because it has more slot variables (ṽt, bt, mt vs. ṽt, mt). When memory reduction is applied, the memory usage of ṽt and bt becomes negligible compared to the momentum mt, so Amos (Reduce 1Axis and Reduce Dense) requires less than 51% of the slot-variable memory of AdamW. The memory reduction method used by Amos is more efficient than the matrix factorization used by AdaFactor, but in the pre-training of T5 (Figure 2c), AdaFactor achieved favorable performance (although slightly worse in the end than AdamW with linear learning-rate decay) without using momentum, reducing the memory usage further. Whether Amos can achieve a similar performance without using momentum is not yet clear.
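A rough count of slot-variable entries shows where the "less than 51%" figure comes from. The sketch below considers a single $m \times n$ kernel and assumes that, under Reduce 1Axis, the two reduced slots each keep one entry per output column while the momentum stays full-size. The function names and the kernel size are hypothetical.

```python
def adamw_slot_entries(m, n):
    """AdamW keeps two full-size slots per kernel: second moment and momentum."""
    return 2 * m * n


def amos_reduce_1axis_slot_entries(m, n):
    """Amos with Reduce 1Axis, under the assumption of this sketch:
    two slots reduced over the input dimension (n entries each)
    plus a full-size momentum (m * n entries)."""
    return m * n + 2 * n


# Hypothetical square kernel for illustration.
m, n = 768, 768
memory_ratio = amos_reduce_1axis_slot_entries(m, n) / adamw_slot_entries(m, n)
```

Under these assumptions the ratio is $\frac{1}{2} + \frac{1}{m}$, which stays below 51% whenever the input dimension exceeds 100.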

A.9 TRAINING RESNET50 ON IMAGENET

In this section, we apply Amos to the training of ResNet50 (He et al., 2016) on the ImageNet dataset (Deng et al., 2009). ResNet50 is a deep Convolutional Neural Network of 50 layers, with Batch Normalization (Ioffe & Szegedy, 2015) and Residual Connection. ImageNet is a 1000-class image classification dataset.

Because Amos does not have a hyper-parameter to adjust the strength of L2 regularization, we tried an ad hoc setting Amos-Extra, where the Amos update rule (Equation 1) is replaced by
$$\delta_t \leftarrow d_t \left( \frac{\xi\eta}{\sqrt{v_t}}\, g_t + \left( \tfrac{1}{2}\gamma_t + 0.001 \right) \theta_t \right)$$
with everything else kept the same (we also tried other constants, but 0.001 was the best). As shown in Figure 8a, Amos-Extra (0.242 lowest error rate) significantly improves the performance on ImageNet.

In Figure 8b, we compare the out-of-the-box Amos with Adam (no weight decay). The learning-rate schedule of Adam is set to cosine decay with 5% warmup, and the number of training steps is set to 140k. The base learning-rate is tuned by a random search on a log scale between 1e-5 and 1e-2, with 25 runs. Other hyper-parameters are set to the default (i.e. $\beta_1 = 0.9$ and $\beta_2 = 0.999$). Amos outperforms all the 25 runs; the best 6 of the 25 are shown in Figure 8b. As alternative settings for Amos, we have also tried β = 0.98, 0.999, or ξ = 0.02, or even changed the decay factors to $c_t = \left(1 + \frac{1}{16}\sqrt{\xi}\, b_t\right)^{-1/2}$ and $d_t = \left(1 + \frac{1}{16}\sqrt{\xi\eta}\, b_t\right)^{-1}$. All the other settings converge to almost the same validation error rate, sometimes with slightly slower convergence.

In Figure 8c, we compare Amos-Extra with the state-of-the-art settings of AdamW. The learning-rate schedule of AdamW is set to cosine decay with 5% warmup, and the number of training steps is set to 187k. The base learning-rate, weight decay strength, and label smoothing rate (which defaults to 0.1 for other experiments) are tuned by random search, on a log scale between 1e-4 and 1e-2, a log scale between 1e-2 and 1.0, and a linear scale between 0.0 and 0.2, respectively, with 25 runs. Other hyper-parameters are set to the default (i.e. $\beta_1 = 0.9$ and $\beta_2 = 0.999$). Among the 25 runs, 9 outperform Amos-Extra; these are shown in Figure 8c. The best performing settings of AdamW gain their advantage close to the end of training, which is probably due to the interaction between weight decay and the cosine learning-rate schedule. On the other hand, Amos-Extra demonstrates faster and more stable convergence.

To conclude, when applied to ResNet50 on ImageNet, Amos can outperform Adam out-of-the-box, and become comparable to the state-of-the-art AdamW settings by adding a small constant weight decay term. However, the extra weight decay term is ad hoc, cannot be covered by our current theory (because we have assumed that the L2 regularization is weak enough and decays to 0, so that it does not bias the loss function but only constrains the scale of trained variables), and is probably not the optimal way to strengthen L2 regularization. We leave the problem of searching for a more general working theory that enables stronger L2 regularization to future work.
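The Amos-Extra modification amounts to adding a constant to the weight decay coefficient inside the update. The sketch below writes that single step in plain Python for one variable stored as a list. It assumes that $d_t$ multiplies both terms of the update and that $v_t$ is a scalar running average; the function name and all numeric inputs are hypothetical, and the computation of $v_t$, $\gamma_t$ and $d_t$ is left out.

```python
import math


def amos_extra_delta(theta, g, v, gamma, d, xi, eta, extra=0.001):
    """One Amos-Extra step for a single variable:
    delta = d * (xi * eta / sqrt(v) * g + (gamma / 2 + extra) * theta),
    applied element-wise to the lists g and theta.
    """
    scale = xi * eta / math.sqrt(v)
    decay = 0.5 * gamma + extra
    return [d * (scale * g_i + decay * theta_i) for g_i, theta_i in zip(g, theta)]


# Hypothetical values for illustration only.
theta = [1.0, -2.0]
delta = amos_extra_delta(theta, g=[0.5, 0.5], v=0.25, gamma=0.0, d=1.0, xi=0.1, eta=1.0)
new_theta = [t - dl for t, dl in zip(theta, delta)]
```

Setting `extra=0.0` recovers the form of the unmodified update in this sketch, which makes the effect of the constant easy to isolate.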



1. Following Loshchilov & Hutter (2019), we decouple the gradient of an L2 regularization term (taking the form of a weight decay) from the adaptive gradient normalization factor $1/\sqrt{v_t}$. When an adaptive optimizer is used, Loshchilov & Hutter (2019) point out that the decoupled weight decay is not equivalent to the L2 regularization without explicit decoupling, and the former is more appropriate. In this work, we always treat L2 regularization as decoupled weight decay, and use the two terms interchangeably.
2. As the training proceeds, $E[g_t]$ will converge to ≈ 0, so Assumption 1 is related to the observation that, for highly expressive models, $g_t = \nabla \ell(z_t; \theta^*)$ can get close to 0 for every $z_t$ in the training data (Ma et al., 2018). However, Assumption 1 only requires that $M_2(g_t)$ decreases as fast as $E[g_t]$, which is empirically verified (§ A.6). Whether $M_2(g_t)$ actually converges to 0 is not guaranteed (because the training may stop early, or $E[g_t]$ may not get to exactly 0 due to L2 regularization, etc.) and is not used in our theory. On the other hand, $M_2(g_t)$ is always large compared to $E[g_t]$, because ξ is a small value.
3. https://en.wikipedia.org/wiki/Main_Page
4. We have tried different learning-rates in preliminary experiments and the best was chosen. A learning-rate search for BERT-base is presented in § A.5.
5. In our JAX (Bradbury et al., 2018) implementation, the running time per training step is almost the same for all optimizers (AdamW, Amos and AdaFactor).
6. See Melis et al. (2020) for a setting that achieves the state-of-the-art performance for a single layer LSTM on PTB. It uses RMSProp and dynamically decays the learning-rate by watching the performance on the validation set. To our knowledge, no previous work has been able to achieve the state-of-the-art with a straightforward setting of the optimizer as we do with Amos.
7. It is not useful in this work to provide a rigorous definition of "on average". We only point out its deep connection with Stein's example (Stein, 1956): if $k \ge 3$, an estimator with L2 regularization can be better than the maximum likelihood estimator without L2.
8. https://github.com/google/init2winit
9. https://en.wikipedia.org/wiki/Fisher%E2%80%93Tippett%E2%80%93Gnedenko_theorem#Gumbel_distribution



Figure 1: Pre-training 3 models of the base (12-layer 768-hidden) size: (a) BERT, (b) RoPE and (c) RPE. We show training loss on the top and validation loss on the bottom.

Figure 3: Plots of the quadratic mean of entries of variables over pre-training steps.

Figure 5: Validation loss for pre-training BERT-base. We compare different learning-rates for (a) AdamW with batch size 1024, (b) Amos with batch size 1024 and (c) Amos with batch size 256.

In Figure 2c, we plot cross-entropy for validation loss instead of the loss used for training. The training precision of T5 is bfloat16. Possibly because linear transformations in T5 do not have bias terms, we found the model easier to train than BERT, and Amos can be applied without warm-up of learning-rate.

Figure 7: Pre-training BERT-base using Amos with different memory reduction settings.

In Figure 7, we show the training and validation loss of pre-training BERT-base by Amos with different memory reduction settings. Reduce Dense is slightly worse in training loss compared to No Reduce, but not so much in validation loss. On the other hand, Reduce 1Axis is almost the same as No Reduce in training loss, and generalizes even slightly better in validation loss than the other two. So the current Reduce 1Axis setting for Amos is favorable.

Figure 8: Training ResNet50 on ImageNet. We plot error rate of the validation set.

Updated model weights θ t+1 ← θ t -δ t . Output: Updated slot variables rt , bt+1 and optional mt+1 .

Table 1: The input/output range of non-linear layers we specify in this work for calculating η.

Table 2: The η calculated for variables in BERT, RoPE and RPE. MLP/Dense2/Kernel is the linear kernel for the output layer of the MLP block. Other linear kernels include e.g. query, key and value kernels in the multi-headed attention layer.

Table 3: The η calculated for variables in T5. MLP/wo/Kernel is the linear kernel for the output layer of the MLP block.




Ethics Statement. This work includes pre-training language models, which have the potential risk of inherited bias from the training data. Our empirical contribution is on accelerating the pre-training process and thus does not focus on addressing such risk. For fair comparison, the pre-training data we have used are the same as previous works, and consequently the models we trained to evaluate our approach are similar to those already open-sourced. We refer to Bommasani et al. (2021) for a discussion of the risks of pre-trained language models.

Reproducibility Statement. Proofs of the lemmas in § 4 are given in § A.1. Following the derivation of the Amos update rule, a heuristic derivation of the asymptotic behavior of the Amos decay factors is found in § A.2, and its connection with SGD is discussed in § A.3. Assumption 1 in our derivation is verified by experiments in § A.6. We explain the calculation of η for the Transformer models in § A.4.

A APPENDIX

A.1 PROOF OF LEMMAS

Proof of Lemma 4.1. We have where denotes element-wise multiplication, o( αt )/ αt → 0 at αt → 0, and arrays are flattened to vectors for the dot-product. Since $z_{t+1}$ and $z_t$ are drawn from the same distribution, we have Moreover, because αt does not depend on $z_t$, we have and the lemma follows by taking αt small enough so that o( αt ) can be omitted.

In order to calculate the hyper-parameter η for ResNet, we specify the input/output range of 3 types of non-linear layers in Table 6. This is similar to Transformers, with the only specialty that the output range of a Max-pooling layer is set to $1/\sqrt{2 \ln n}$, where $n$ is the patch size. This is because the maximum of $n$ normally distributed random variables$^{9}$ has a standard deviation of about $1/\sqrt{2 \ln n}$.

The calculated η for different types of variables in ResNet is shown in Table 7. BatchNormalization is treated the same as LayerNormalization in Transformer. The projection kernel of the first residual block is scaled up by $\sqrt{2 \ln n}$ because of its previous max-pooling layer of patch size $n$. The 2nd and 3rd convolution kernels in each residual block are scaled up by $\sqrt{2}$ because their inputs come from a ReLU activation. The variables for bias and other linear kernels are treated the same as in Transformer.

Settings of Amos-*Scale. We also tried an Amos-*Scale setting where the η for the projection kernel of the first residual block is set to $1/\sqrt{d}$ instead of $\sqrt{(2 \ln n)/d}$ (in ResNet50, $n = 3 \times 3 = 9$).

A.9.2 RESULTS

In Figure 8a, we show the validation error rate of Amos and Amos-*Scale, where the error rate for Amos (0.261 lowest) is slightly better than the Amos-*Scale setting (0.263 lowest). Furthermore, it is known that a strong L2 regularization is beneficial for many popular image classification tasks (Loshchilov & Hutter, 2019), but Amos does not have a hyper-parameter to adjust the strength of L2 regularization.

