PROGRESSIVELY STACKING 2.0: A MULTI-STAGE LAYERWISE TRAINING METHOD FOR BERT TRAINING SPEEDUP

Abstract

Pre-trained language models such as BERT have achieved significant accuracy gains in many natural language processing tasks. Despite its effectiveness, the huge number of parameters makes training a BERT model computationally very challenging. In this paper, we propose an efficient multi-stage layerwise training (MSLT) approach to reduce the training time of BERT. We decompose the whole training process into several stages. Training starts from a small model with only a few encoder layers, and we gradually increase the depth of the model by adding new encoder layers. At each stage, we only train the top few encoder layers (those near the output layer) that are newly added; the parameters of the layers trained in previous stages are not updated in the current stage. In BERT training, the backward computation is much more time-consuming than the forward computation, especially in the distributed training setting, where the backward computation time also includes the communication time for gradient synchronization. In the proposed training strategy, only the top few layers participate in backward computation, while most layers only participate in forward computation. Hence both computation and communication efficiency are greatly improved. Experimental results show that the proposed method achieves more than 110% training speedup without significant performance degradation.

1. INTRODUCTION

In recent years, pre-trained language models such as BERT (Devlin et al., 2018), XLNet (Yang et al., 2019), and GPT (Radford et al., 2018) have shown powerful performance in various areas, especially in the field of natural language processing (NLP). By being pre-trained on unlabeled datasets and fine-tuned on small labeled datasets for specific downstream tasks, BERT achieved significant breakthroughs in eleven NLP tasks (Devlin et al., 2018). Due to its success, many variants of BERT have been proposed, such as RoBERTa (Liu et al., 2019b), ALBERT (Lan et al., 2019), and StructBERT (Wang et al., 2019), most of which yielded new state-of-the-art results. Despite the accuracy gains, these models usually involve a large number of parameters (e.g., BERT-Base has more than 110M parameters and BERT-Large more than 340M), and they are generally trained on large-scale datasets. Hence, training these models is quite time-consuming and requires a lot of computing and storage resources. Even training a BERT-Base model costs at least $7k (Strubell et al., 2019), let alone larger models such as BERT-Large. Such a high cost is not affordable for many researchers and institutions. Therefore, improving training efficiency is a critical issue for making BERT more practical.

Some pioneering attempts have been made to accelerate the training of BERT. You et al. (2019) proposed a layerwise adaptive large-batch optimization method (LAMB), which is able to train a BERT model in 76 minutes. However, this tens-of-times speedup relies on a huge amount of computing and storage resources that is unavailable to common users. Lan et al. (2019) proposed the ALBERT model, which shares parameters across all hidden layers, so memory consumption is greatly reduced and training speed is also improved due to lower communication overhead. Gong et al.
(2019) proposed a progressively stacking method, which trains a deep BERT network by progressively stacking from a shallow one. Utilizing the similarity of the attention distributions across different layers, this strategy achieves about 25% speedup without significant performance loss.

Figure 1: The framework of the MSLT method. Green blocks denote layers that only participate in forward computation and are not updated. Red blocks denote layers that are trained in the current stage and participate in both forward and backward computation.

Progressively stacking provides a novel training strategy, namely training a BERT model from shallow to deep. However, progressively stacking is only highly efficient in the initial stage, when the model depth is small. As training goes on, the model depth increases and the training speed decreases; the low efficiency of the later stages limits the overall speedup of progressively stacking. Note that in the progressively stacking method, the bottom layers are trained for longer than the top layers. However, we observe that although the bottom layers are updated all the time, they do not change significantly in the later stages in terms of their attention distributions, which reflect the functionality of the encoder layers to some extent (Gong et al., 2019). In other words, most optimization of the bottom layers is already finished in the early stage, when the model is shallow. Motivated by this observation, in this work we propose a novel multi-stage layerwise training (MSLT) approach, which greatly improves the training efficiency of BERT. We decompose the training process of BERT into several stages, as shown in Fig. 1. We start training from a small BERT model with only a few encoder layers and gradually add new encoder layers.
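The growth schedule just described, starting shallow, stacking new encoder layers, and training only the newly added top layers at each stage, can be sketched in plain Python. This is an illustrative sketch, not the authors' code; `num_layers` and `layers_per_stage` are assumed example values rather than the paper's exact configuration:

```python
def mslt_stages(num_layers=12, layers_per_stage=3):
    """Return (frozen_layers, trainable_layers) index lists per growth stage.

    At each stage, `layers_per_stage` new encoder layers are stacked on top.
    Only the new top layers (plus the output layer, not modeled here) are
    trained; previously trained layers run forward-only. Illustrative sketch.
    """
    stages = []
    for depth in range(layers_per_stage, num_layers + 1, layers_per_stage):
        frozen = list(range(0, depth - layers_per_stage))      # trained earlier
        trainable = list(range(depth - layers_per_stage, depth))  # newly added
        stages.append((frozen, trainable))
    return stages

for i, (frozen, trainable) in enumerate(mslt_stages()):
    print(f"stage {i}: forward-only layers {frozen}, trained layers {trainable}")
```

In a real framework, freezing the bottom layers would amount to excluding their parameters from the optimizer (or disabling their gradients), which is what removes both their backward computation and their gradient-synchronization traffic.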
At each stage (except the first), only the output layer and the newly added top encoder layers are updated, while the layers trained in previous stages are kept fixed. After all the encoder layers have been trained, to make the network better behaved, we further retrain the model by updating all the layers together. Since the whole model has already been well trained, this stage only requires a few steps (about 20% of the total steps). Compared with the progressively stacking method, which requires a lot of steps (about 70% of the total steps (Gong et al., 2019)) to train the whole model, our method is much more time-efficient.

Experimental results demonstrate the effectiveness and efficiency of the proposed method in two aspects: 1) with the same data throughput (same number of training steps), our method achieves performance comparable to the original training method while consuming much less training time; 2) with the same training time, our method achieves better performance than the original method. According to the results, the proposed method achieves more than 110% speedup without significant performance degradation.

To avoid misunderstanding, it should be mentioned that some widely known methods such as model compression (Han et al., 2015a;b) and knowledge distillation (Yim et al., 2017; Hinton et al., 2015; Sanh et al., 2019) are designed to speed up the network in the inference phase; that is, they are applied after the model has been trained. In this paper, we focus on speeding up model training.
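As a rough back-of-the-envelope check on where the speedup comes from, one can count layer-level backward updates per training run. The 80%/20% split between the growth phase and the final joint stage follows the paper; the cost model (backward cost proportional to the number of updated layers, uniform steps per stage) is a simplifying assumption of this sketch, not a claim from the paper:

```python
def backward_layer_updates(total_steps=1000, num_layers=12,
                           layers_per_stage=3, joint_frac=0.2):
    """Count layer-level backward updates for MSLT vs. standard training.

    Simplified cost model (an assumption, not the paper's measurement):
    each step costs one backward pass per layer being updated.
    """
    growth_steps = int(total_steps * (1 - joint_frac))
    num_stages = num_layers // layers_per_stage
    steps_per_stage = growth_steps // num_stages
    # Growth phase: only the newly added top layers receive gradients.
    mslt = num_stages * steps_per_stage * layers_per_stage
    # Final joint stage: all layers are updated together.
    mslt += (total_steps - growth_steps) * num_layers
    baseline = total_steps * num_layers  # all layers updated every step
    return mslt, baseline

mslt, baseline = backward_layer_updates()
print(f"MSLT: {mslt} layer-updates vs. baseline: {baseline}")
```

Under these assumptions MSLT performs 4,800 backward layer-updates versus 12,000 for standard training; since the backward pass (including gradient synchronization) dominates BERT training time, this is consistent in spirit with the reported >110% overall speedup, though the real saving depends on per-layer costs the sketch ignores.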

