BRANCH-TRAIN-MERGE: EMBARRASSINGLY PARALLEL TRAINING OF EXPERT LANGUAGE MODELS

Abstract

We present Branch-Train-Merge (BTM), a communication-efficient algorithm for embarrassingly parallel training of language models (LMs). We show it is possible to independently train subparts of a new class of LMs on different subsets of the data, eliminating the massive multi-node synchronization currently required to train LMs. BTM learns a set of independent EXPERT LMs (ELMs), each specialized to a different textual domain, such as scientific or legal text. These ELMs can be added and removed to update data coverage, ensembled to generalize to new domains, or averaged to collapse back to a single LM for efficient inference. New ELMs are learned by branching from (mixtures of) ELMs in the current set, further training on new domains, and then merging the resulting models back into the set for future use. Experiments show that BTM improves in- and out-of-domain perplexities as compared to GPT-style transformer LMs, when controlling for training cost. Through extensive analysis, we show that these results are robust to different ELM initialization schemes, but require expert domain specialization; ensembles with random data splits do not perform well. Our results suggest that extreme parallelism could be used to efficiently scale LMs in future work.

1. INTRODUCTION

Training and inference in language models (LMs) typically require access to supercomputers that can achieve the massive multi-node synchronization required to compute model activations and gradients (Brown et al., 2020; Fedus et al., 2022; Zhang et al., 2022). We develop a new class of LMs that is instead embarrassingly parallel: different parts of the model are independently trained on different subsets of the data, with no need for multi-node training or inference (Figure 1). Our new ELMFOREST1 model consists of a set of EXPERT LMs (ELMs), each specialized to a distinct domain in the training corpus, e.g., scientific or legal text. ELMs are independently functional LMs with no shared parameters, unlike recent mixture-of-experts models that only specialize the transformer feedforward layers to domains (Gururangan et al., 2022). ELMs can be added or removed at any time to update data coverage, ensembled to generalize to new domains, or parameter averaged to collapse back to a single LM for more efficient inference. We also present the Branch-Train-Merge (BTM) algorithm for learning ELMs. BTM repeatedly expands the ELMFOREST by adding one or more new ELMs in parallel. Each new ELM is first branched by initializing a new LM with an average of the parameters of the most relevant LMs in the current set, then further trained on new domains with a standard cross-entropy loss, and finally merged into the ELMFOREST (Figure 2). The ELMFOREST is initialized with a single LM, trained on heterogeneous data to establish strong shared representations for future domain specialization. When evaluated in- and out-of-domain, ELMFORESTs trained with BTM outperform GPT-style transformer LMs, a domain-specialized mixture-of-experts baseline (Gururangan et al., 2022), and non-domain-based ensemble baselines across a range of computational budgets: up to 1.3B parameters per ELM trained for 7000 GPU-hours in aggregate (§4.2).
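The branch-train-merge cycle described above can be sketched in a few lines. This is an illustrative sketch only, not the authors' implementation: the function names (`average_params`, `btm_step`, `train_fn`) and the representation of an ELM as a dict of parameter tensors are our assumptions, and the relevance weights for branching would in practice come from a domain-similarity estimate.

```python
# Hypothetical sketch of one BTM step. Each ELM is modeled as a dict
# mapping parameter names to values; a real implementation would operate
# on framework tensors (e.g., a PyTorch state_dict).

def average_params(elms, weights):
    """Branch: initialize a new ELM as a weighted average of the
    parameters of the most relevant existing ELMs."""
    return {
        name: sum(w * elm[name] for elm, w in zip(elms, weights))
        for name in elms[0]
    }

def btm_step(elm_forest, relevance_weights, new_domain_data, train_fn):
    """Branch from the current forest, train on a new domain with a
    standard cross-entropy loss (inside train_fn), and merge the
    resulting expert back into the forest."""
    init_params = average_params(elm_forest, relevance_weights)  # branch
    new_elm = train_fn(init_params, new_domain_data)             # train
    elm_forest.append(new_elm)                                   # merge
    return elm_forest
```

Because each `btm_step` touches only its own parameters, many such steps can run in parallel on disjoint domains with no gradient synchronization between them.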
These gains are biggest for ELMFOREST ensembles, which activate all experts during inference, but also hold when combining experts with parameter averaging. We also extensively analyze these results: expert specialization to domains is crucial (§5.1), while the compute budget allocation (§5.2) and the choice of training data for the initial ELM (§5.3) are much less so. We release our code and models publicly.2



1 ELMFOREST: Expert Language Models For Efficient Sparse Training.
2 URL anonymized for review.

