GSHARD: SCALING GIANT MODELS WITH CONDITIONAL COMPUTATION AND AUTOMATIC SHARDING

Abstract

Neural network scaling has been critical for improving model quality in many real-world machine learning applications with vast amounts of training data and compute. Although this trend of scaling is affirmed to be a sure-fire approach for better model quality, there are challenges on the path, such as computation cost, ease of programming, and efficient implementation on parallel devices. In this paper we demonstrate conditional computation as a remedy to the above-mentioned impediments, and demonstrate its efficacy and utility. We make extensive use of GShard, a module composed of a set of lightweight annotation APIs and an extension to the XLA compiler, to enable large-scale models with up to trillions of parameters. GShard and conditional computation enable us to scale up a multilingual neural machine translation Transformer model with Sparsely-Gated Mixture-of-Experts. We demonstrate that such a giant model with 600 billion parameters can be trained efficiently on 2048 TPU v3 cores in 4 days to achieve far superior quality for translation from 100 languages to English compared to the prior art.

1. INTRODUCTION

Scaling neural networks brings dramatic quality gains over a wide array of machine learning problems such as computer vision, language understanding and neural machine translation (Devlin et al., 2018; Mahajan et al., 2018; Arivazhagan et al., 2019; Huang et al., 2019; Brown et al., 2020b). This general tendency motivated recent studies to scrutinize the factors playing a critical role in the success of scaling, including the amount of training data, the model size, and the computation being utilized, as found by past studies (Advani & Saxe, 2017; Hestness et al., 2019; Geiger et al., 2020). While the final model quality was found to have a power-law relationship with these factors (Hestness et al., 2017; Kaplan et al., 2020), the significant quality gains brought by larger models also came with various practical challenges. Training efficiency, which we define as the amount of compute and time used to achieve superior model quality against the best existing system, is oftentimes left out. In this study, we strive to improve model quality while remaining training-efficient. We built a 600-billion-parameter sequence-to-sequence Transformer model with Sparsely-Gated Mixture-of-Experts layers, which enjoys sub-linear computation cost and O(1) compilation time. We trained this model with 2048 TPU v3 devices for 4 days on a multilingual machine translation task and achieved far superior translation quality compared to prior art when translating 100 languages to English with a single non-ensemble model. We conducted experiments with various model sizes and found that the translation quality increases as the model gets bigger, yet the total wall-time to train only increases sub-linearly with respect to the model size, as illustrated in Figure 1. To train such an extremely large model, we relied on the following key design choices.

