AMOS: AN ADAM-STYLE OPTIMIZER WITH ADAPTIVE WEIGHT DECAY TOWARDS MODEL-ORIENTED SCALE

Abstract

We present Amos, a stochastic gradient-based optimizer designed for training deep neural networks. It can be viewed as an Adam optimizer with theoretically supported, adaptive learning-rate decay and weight decay. A key insight behind Amos is that it leverages model-specific information to determine the initial learning rate and decay schedules. When used for pre-training BERT variants and T5, Amos consistently converges faster than the state-of-the-art settings of AdamW, achieving better validation loss within ≤ 70% of the training steps and time, while requiring ≤ 51% of the memory for slot variables. Our code is open-sourced at: https://anonymous-url.

1. INTRODUCTION

The Adam (Kingma & Ba, 2015) optimizer is widely used for training deep neural networks, demonstrating fast convergence especially in the early stages of training. Although previous works have found issues with the theoretical convergence of Adam as training proceeds (Reddi et al., 2018), in practice these are remedied by various learning-rate schedules and weight decay (Loshchilov & Hutter, 2019). Specifically, Adam with linear learning-rate decay and constant weight decay is the standard setting for pre-training large language models such as BERT (Devlin et al., 2019). However, these decay settings are usually ad hoc, increase the number of hyper-parameters, and may introduce additional complexities in usage. For example, the linearly decaying learning-rate schedule requires knowing the number of training steps in advance, which makes it nontrivial to continue training a model after the learning rate has decayed to 0. In this work, we present Amos, a new optimizer with a theoretically supported, adaptive schedule for learning rate and weight decay, which significantly outperforms the state-of-the-art AdamW settings for pre-training language models, provides guidance for hyper-parameter tuning, reduces memory usage, and trains continuously without having to specify the number of training steps a priori. A key insight behind Amos is a hyper-parameter η, provided by the model architecture, which indicates the expected scale of the trainable weights θ of the model (§ 2); i.e., theoretically we assume that an optimal point θ* exists within the ball ‖θ*‖ ≤ η. Deep neural networks are likely to satisfy such a constraint without degrading performance, because they have many good local minima; and we show that an appropriate η for Amos can improve generalization and accelerate convergence (§ 5.2).
In this work, η is calculated in a consistent way from the input/output scale of neural network components, which is hinted by the model design (§ A.4). Given η, Amos decides a per-variable learning rate, and its L2 regularization leads the trained weights to the specified scale. The decay of the learning rate is then determined by the L2 regularizer. Thus, Amos performs better because it utilizes the model-oriented information η efficiently; the name Amos stands for "Adaptive weight decay towards Model-Oriented Scale". Empirically, we focus on the Transformer architecture (Vaswani et al., 2017), since it is predominant in pre-trained language models (Bommasani et al., 2021), but we add experiments on LSTM (Gers et al., 2000) and ResNet (He et al., 2016) as well. We apply Amos to the pre-training of 4 models: BERT (Devlin et al., 2019), two Transformer variants with relative position representations (Su et al., 2021; Shaw et al., 2018), and the T5 model (Raffel et al., 2020), some with various model sizes and batch sizes. In all experiments, Amos consistently outperforms the state-of-the-art setting, achieving better validation loss within ≤ 70% of the training steps and time (§ 5.1). Compared to AdamW,
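To build intuition for "weight decay towards a model-oriented scale", the following toy sketch shows how a decay term can drive a weight vector's RMS toward a target scale η: decay is applied only while the weights are larger than η, so at convergence they settle near the specified scale. This is an illustration of the general idea only, not the actual Amos update rule (which is derived later in the paper); the function name, the decay strength, and the RMS-based trigger are all assumptions made for this sketch.

```python
import numpy as np

def decay_toward_scale(theta, eta, strength=0.01):
    # Toy illustration (NOT the Amos algorithm): shrink the weights
    # multiplicatively while their RMS exceeds the target scale eta,
    # and leave them untouched once they are within scale.
    rms = np.sqrt(np.mean(theta ** 2))
    if rms > eta:
        theta = theta * (1.0 - strength)
    return theta

# Weights initialized well above the target scale eta = 0.5.
theta = np.random.RandomState(0).normal(scale=2.0, size=1000)
for _ in range(2000):
    theta = decay_toward_scale(theta, eta=0.5)

final_rms = float(np.sqrt(np.mean(theta ** 2)))
# final_rms settles just below the target scale 0.5
```

In Amos itself the analogous mechanism is an L2 regularizer whose strength is set per variable from η, and the same regularizer also determines how the learning rate decays over training.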

