EFFICIENT LARGE-SCALE TRANSFORMER TRAINING VIA RANDOM AND LAYERWISE TOKEN DROPPING

Abstract

Large-scale transformer models have become the de-facto architectures for various machine learning applications, e.g., CV and NLP. However, those large models also introduce prohibitive training costs. To mitigate this issue, we propose a novel random and layerwise token dropping method (random-LTD), which skips the computation of a subset of the input tokens at all middle layers. In particular, random-LTD achieves considerable speedups while maintaining accuracy comparable to the standard training baseline. Compared to other token dropping methods, random-LTD does not require (1) any importance-score-based metrics, (2) any special token treatment (e.g., [CLS]), or (3) full-sequence-length training for most layers (only the first and the last layers process the full sequence). In addition, we propose a new LayerToken learning rate schedule for pretraining problems, which resolves the heavy tuning requirement of our proposed training mechanism. Finally, we demonstrate that random-LTD can be applied to a broad range of applications, including GPT and BERT pretraining as well as ViT and GPT finetuning tasks. Our results show that random-LTD can save about 33.3% of the theoretical compute cost and 25.6% of the wall-clock training time while achieving zero-shot evaluation results on GPT-3 1.3B similar to the baseline.

1. INTRODUCTION

Large-scale transformers have demonstrated supreme performance on natural language processing (Tenney et al., 2019; Radford et al., 2019; Raffel et al., 2019), computer vision (Dosovitskiy et al., 2020), and other applications (Gong et al., 2021; Guo et al., 2021). However, both the pretraining procedure and some downstream finetuning tasks (e.g., long-document summarization) are time-consuming and resource-hungry. Thus, there is a need to speed up training and reduce the compute cost for large-scale transformer pretraining and finetuning. Recently, Hou et al. (2022) adapted the token pruning/dropping/bypassing technique (Kim et al., 2021; Goyal et al., 2020; Kim & Cho, 2020) from BERT inference to BERT pretraining by skipping the computation of part of the input tokens at some middle layers. The results of Hou et al. (2022) (referred to as TokenBypass) show that it can theoretically reduce the pretraining cost by 25% for both BERT-base and BERT-large without losing accuracy on finetuning tasks. Although it achieves a great speedup, TokenBypass (1) needs an importance-score metric to determine the dropped tokens and special token treatment to keep important tokens (e.g., [CLS]), both of which require manual design; (2) has to keep the first half of the layers and the last layer (in total, half of the depth) in full-sequence-length training, which limits its layer-bypassing ability; and (3) solely focuses on the BERT Masked-LM pretraining task and has not been applied to other tasks, e.g., causal-LM. In this work, we address those challenges and introduce our random and layerwise token-dropping method (random-LTD). In summary, our contributions are as follows:

• All tokens are treated equally, without any special token treatment or importance-score measurement (i.e., no manual design), and are dropped in a purely random manner. Meanwhile, instead of fully bypassing the dropped tokens for all middle layers (Hou et al., 2022), each layer in random-LTD drops tokens independently of the other layers. This helps the multi-head attention in the middle layers capture the dependency relations across different tokens, as suggested in (Vig & Belinkov, 2019).

• random-LTD applies token dropping at all middle layers except the very first and last layers, which further reduces manual design and increases training efficiency. We also propose a new monotonic sequence length growth method as training evolves to (1) reduce the gradient noise introduced by random-LTD for better convergence and (2) close the training-inference (autoregressive generation) gap for GPT models, since random-LTD breaks the autoregressive manner in middle layers during training.

• To reduce the tuning effort for the newly proposed training procedure, we introduce a new LayerToken learning rate schedule, which scales the learning rate based on the sum of consumed tokens of each layer for pretraining tasks.1 We show its superb performance for random-LTD on GPT/BERT pretraining compared to the standard iteration-based learning rate schedule.

• We extensively test random-LTD on pretraining tasks, including GPT and BERT pretraining, and on finetuning tasks, including causal-LM finetuning for GPT and image classification for ViT. For all tasks, random-LTD achieves accuracy similar to the original baseline with up to 33.3% theoretical cost savings and up to 25.6% wall-clock time savings.

• Finally, we show that random-LTD has a potential regularization effect, which can be used for both pretraining and finetuning problems.
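To make the LayerToken idea in the contributions above concrete, the sketch below is our own simplification (not the paper's exact schedule): it replaces the iteration count in a standard warmup-plus-cosine-decay schedule with the cumulative number of (layer, token) pairs actually computed. Because random-LTD computes fewer LayerTokens per step than the baseline, such a schedule decays more slowly per iteration. All function and parameter names here are illustrative assumptions.

```python
import math

def layertoken_lr(consumed_layertokens, total_layertokens, peak_lr,
                  warmup_layertokens, min_lr=0.0):
    """Hypothetical LayerToken-indexed schedule: linear warmup then cosine
    decay, driven by consumed (layer, token) pairs instead of iterations."""
    if consumed_layertokens < warmup_layertokens:
        # Linear warmup until `warmup_layertokens` pairs have been consumed.
        return peak_lr * consumed_layertokens / warmup_layertokens
    # Cosine decay from peak_lr down to min_lr over the remaining budget.
    progress = (consumed_layertokens - warmup_layertokens) / (
        total_layertokens - warmup_layertokens)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))
```

A training loop would accumulate, at every step, the sum over layers of (kept sequence length × batch size) and pass that running total as `consumed_layertokens`; under random-LTD the middle layers contribute fewer tokens to this sum than in baseline training.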

2. BACKGROUND

The transformer (Vaswani et al., 2017) architecture is a stack of transformer layers, each of which has two main ingredients, i.e., the multi-head attention (MHA) and the feed-forward connection network (FFC). Suppose the transformer has l layers, denoted L_1, ..., L_l. Let X_i ∈ R^(s×d) be the output tensor of the i-th transformer layer, and X_0 the input of the transformer (after embedding). Here s is the sequence length and d is the hidden dimension.

Token dropping (or token bypassing/pruning) (Kim et al., 2021; Goyal et al., 2020; Kim & Cho, 2020; Press et al., 2021; Wang et al., 2021) was originally proposed for BERT inference to reduce the computational overhead. In this case, if token i (X_{j,i}) is decided to be dropped at layer j (L_j), the compute cost of this token through all remaining layers (L_k, where k > j) is eliminated. As such, the sequence length s_i of the i-th layer's input X_{i-1} is a non-increasing array, i.e., s_0 ≥ s_1 ≥ ... ≥ s_l. However, such a configuration has been shown to be unstable for adaptive token-dropping inference (Kim & Cho, 2020). Therefore, Kim & Cho (2020) utilize the sandwich rule and distillation from (Yu & Huang, 2019) to stabilize training and boost accuracy. But these two methods also significantly increase the training cost; thus, such techniques cannot be applied to speed up the pretraining procedure.

Recently, Hou et al. (2022) extended token dropping from inference to BERT pretraining (referred to as TokenBypass). Hou et al. (2022) use several importance scores/metrics to determine the dropped tokens, e.g., the cumulative loss and the frequency of each token. To overcome the training instability issue, the authors propose two main mechanisms: (1) the sandwich token dropping rule, where the first few layers (L_1 to L_i) and the last few layers (L_{l-j} to L_l) of the BERT model process all tokens (i.e., no token dropping) and the middle layers (L_i to L_{l-j}) bypass s' ≤ s tokens; in particular, the authors test only on encoder transformers (12-layer BERT-base and 24-layer BERT-large) and set i = l/2 - 1, j = 1, s' = s/2; and (2) special token treatment, where special tokens (e.g., [MASK], [CLS], [SEP]) are never dropped.

Compared to TokenBypass from (Hou et al., 2022), our random-LTD (1) does not require an importance-score metric, special token treatment, or the sandwich token dropping rule, which dramatically reduces the manual design effort; and (2) has been broadly tested on pretraining tasks, including GPT and BERT, as well as finetuning tasks, including ViT classification and GPT causal-LM. Meanwhile, we find that directly applying TokenBypass to causal-LM leads to severe accuracy degradation.

Please see the detailed description of random-LTD in Section 3 and our extensive evaluation in Sections 4 and 5. We also include a thorough discussion of other efficient training methods in Appendix A.

1 Note that the numbers of consumed tokens for different layers are different.

3.1. RANDOM AND LAYERWISE TOKEN DROPPING METHOD

Layerwise Token Dropping Mechanism. As pointed out in Section 2, existing inference and training token dropping methods either permanently drop tokens from the compute graph at intermediate layers, or at least make some tokens fully skip a consecutive series of middle layers. However, several works (Vig & Belinkov, 2019; Michel et al., 2019; Voita et al., 2019) have shown that MHA focuses on different tokens at different layer depths.
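The layerwise mechanism can be sketched in a few lines. The following is a minimal illustrative sketch under our own naming, not the paper's implementation: each middle layer samples its own uniform random subset of token positions, runs only those tokens through the layer, and merges the outputs back, so tokens dropped at that layer pass through it unchanged (rather than being bypassed for all middle layers at once, as in TokenBypass).

```python
import random

def random_ltd_layer(layer_fn, hidden_states, keep_len):
    """One random-LTD step for a single middle layer (illustrative).

    Samples `keep_len` token positions uniformly at random, applies the
    layer only to those tokens, and scatters the outputs back in place.
    """
    seq_len = len(hidden_states)
    kept = sorted(random.sample(range(seq_len), keep_len))  # fresh draw per layer
    processed = layer_fn([hidden_states[i] for i in kept])
    merged = list(hidden_states)  # dropped tokens pass through untouched
    for out, i in zip(processed, kept):
        merged[i] = out
    return merged

def forward(layers, x, keep_len):
    """First and last layers see the full sequence; each middle layer
    independently drops its own random subset of tokens."""
    x = layers[0](x)
    for layer in layers[1:-1]:
        x = random_ltd_layer(layer, x, keep_len)
    return layers[-1](x)
```

In a real transformer, `hidden_states` would be an s×d activation tensor and the gather/scatter would use indexed tensor operations, but the control flow is the same: no importance scores, no protected special tokens, and a different random subset at every middle layer.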

