BRANCH-TRAIN-MERGE: EMBARRASSINGLY PARALLEL TRAINING OF EXPERT LANGUAGE MODELS

Abstract

We present Branch-Train-Merge (BTM), a communication-efficient algorithm for embarrassingly parallel training of language models (LMs). We show it is possible to independently train subparts of a new class of LMs on different subsets of the data, eliminating the massive multi-node synchronization currently required to train LMs. BTM learns a set of independent EXPERT LMs (ELMs), each specialized to a different textual domain, such as scientific or legal text. These ELMs can be added and removed to update data coverage, ensembled to generalize to new domains, or averaged to collapse back to a single LM for efficient inference. New ELMs are learned by branching from (mixtures of) ELMs in the current set, further training on new domains, and then merging the resulting models back into the set for future use. Experiments show that BTM improves in- and out-of-domain perplexities as compared to GPT-style transformer LMs, when controlling for training cost. Through extensive analysis, we show that these results are robust to different ELM initialization schemes, but require expert domain specialization; ensembles with random data splits do not perform well. Our results suggest that extreme parallelism could be used to efficiently scale LMs in future work.

1. INTRODUCTION

Training and inference in language models (LMs) typically require access to supercomputers that can achieve the massive multi-node synchronization required to compute model activations and gradients (Brown et al., 2020; Fedus et al., 2022; Zhang et al., 2022). We develop a new class of LMs that is instead embarrassingly parallel: different parts of the model are independently trained on different subsets of the data, with no need for multi-node training or inference (Figure 1). Our new ELMFOREST model consists of a set of EXPERT LMs (ELMs), each specialized to a distinct domain in the training corpus, e.g., scientific or legal text. ELMs are independently functional LMs with no shared parameters, unlike recent mixture-of-experts models that only specialize the transformer feedforward layers to domains (Gururangan et al., 2022). ELMs can be added or removed at any time to update data coverage, ensembled to generalize to new domains, or parameter averaged to collapse back to a single LM for more efficient inference. We also present the Branch-Train-Merge (BTM) algorithm for learning ELMs. BTM repeatedly expands the ELMFOREST by adding one or more new ELMs in parallel. Each new ELM is first branched by initializing a new LM with an average of the parameters of the most relevant LMs in the current set, then further trained on new domains with a standard cross-entropy loss, and finally merged into the ELMFOREST (Figure 2). The ELMFOREST is initialized with a single LM, trained on heterogeneous data to establish strong shared representations for future domain specialization. When evaluated in- and out-of-domain, ELMFORESTs trained with BTM outperform GPT-style transformer LMs, a domain-specialized mixture-of-experts baseline (Gururangan et al., 2022), and non-domain-based ensemble baselines across a range of computational budgets, up to 1.3B parameters per ELM trained for 7000 GPU-hours in aggregate (§4.2).
These gains are biggest for ELMFOREST ensembles, which activate all experts during inference, but also hold when combining experts with parameter averaging. We also conduct extensive analysis of these results: expert specialization to domains is crucial (§5.1), while the compute budget allocation (§5.2) and the choice of training data for the initial ELM (§5.3) matter much less. We release our code and models publicly.
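BTM's branch step initializes each new expert from a (weighted) average of existing expert parameters, and the same operation collapses an ELMFOREST back into a single LM for efficient inference. The following is a minimal sketch of such parameter averaging, assuming each model is represented as a plain dict of parameter lists; the function name `average_params` and the toy two-parameter models are ours, not the paper's implementation:

```python
def average_params(state_dicts, weights=None):
    """Average several models' parameters, optionally weighted.

    Each model is a dict mapping parameter name -> list of floats.
    With no weights given, a uniform average is taken.
    """
    if weights is None:
        weights = [1.0 / len(state_dicts)] * len(state_dicts)
    assert abs(sum(weights) - 1.0) < 1e-8, "weights must sum to 1"
    averaged = {}
    for name in state_dicts[0]:
        averaged[name] = [
            sum(w * sd[name][i] for w, sd in zip(weights, state_dicts))
            for i in range(len(state_dicts[0][name]))
        ]
    return averaged

# Two toy "experts", each with a single two-element weight vector.
elm_a = {"w": [1.0, 2.0]}
elm_b = {"w": [3.0, 4.0]}
init = average_params([elm_a, elm_b])                   # uniform average
skewed = average_params([elm_a, elm_b], [0.75, 0.25])   # branch closer to elm_a
```

In practice the same averaging can weight the "most relevant" experts more heavily when branching toward a new domain, as the algorithm description above suggests.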

2. ELMFORESTS

We define an ELMFOREST to be a set of EXPERT LMs (ELMs), each independently trained to specialize to a different subset of a corpus. ELMs are inspired by the experts in earlier MoE models (Jacobs et al., 1991), but we define ours to be domain specialists and specialize the full LM instead of individual components. We follow Gururangan et al. (2022) and define domains by provenance, or the source of the document (e.g., legal document, computer science research paper), which yields simple and interpretable corpus segmentations useful for identifying ELMs in our experiments.3 Potential extensions to multi-lingual, multi-modal, multi-task, or other types of data splits are left for future work. ELMs remain independent throughout training and inference, enabling the functions below.

2.1. ADDING AND REMOVING ELMS

Existing control techniques that steer LMs towards (Keskar et al., 2019; Gururangan et al., 2020; Dathathri et al., 2020) or away from (Welleck et al., 2019) certain behaviors tend to be expensive, require retraining the model, or provide no strong guarantees on test-time behavior (Gehman et al., 2020). In contrast, ELMFORESTs allow for explicit inference-time application of constraints on the provenance of training data: we can modify the domain coverage of an ELMFOREST at any time by incorporating new ELMs specialized to different domains or removing existing ELMs from the set.
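Because ELMs share no parameters, adding or removing a domain is plain set manipulation over independently stored models; nothing else needs retraining. A hypothetical sketch of this bookkeeping (the `ELMForest` class and its method names are our illustration, not an API from the paper):

```python
class ELMForest:
    """A set of independent domain experts, keyed by domain name."""

    def __init__(self):
        self.experts = {}  # domain name -> expert LM (any object)

    def add(self, domain, elm):
        # Extend data coverage with a newly trained expert.
        self.experts[domain] = elm

    def remove(self, domain):
        # Retract a domain, e.g. to enforce a provenance constraint
        # at inference time; no other expert is affected.
        self.experts.pop(domain, None)

    def domains(self):
        return sorted(self.experts)

forest = ELMForest()
forest.add("legal", "elm-legal")
forest.add("science", "elm-science")
forest.remove("legal")  # exclude legal-provenance training data
```

The strings stand in for actual trained models; the point is that coverage changes are constant-time registry updates rather than training jobs.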

2.2. INFERENCE FROM ELMFORESTS

ELMFORESTs support two inference modes, which trade off efficiency for performance.

Output Ensembling In the first method, we ensemble the output probabilities of multiple ELMs, which allows us to generalize to texts of unknown domain. We use the cached prior method proposed in Gururangan et al. (2022). Consider the probabilistic view of language modeling, where we estimate p(X_t | x_<t). We introduce a domain variable D alongside each sequence. Then the next-step conditional distribution on the history x_<t marginalizes over the k domains:

p(X_t | x_<t) = Σ_{j=1}^{k} p(X_t | x_<t, D = j) · p(D = j | x_<t)    (1)

We estimate a domain posterior, or the probability of a sequence belonging to each of the k domains, using Bayes' rule:

p(D = j | x_<t) = p(x_<t | D = j) · p(D = j) / p(x_<t)
                = p(x_<t | D = j) · p(D = j) / Σ_{j'=1}^{k} p(x_<t | D = j') · p(D = j')    (2)

3 See §A.1 for a discussion of the possible limitations of this domain definition.

(Footnote: ELMFOREST stands for Expert Language Models For Efficient Sparse Training. URL anonymized for review.)

Figure 1: Fully Synchronized vs. Embarrassingly Parallel Training (§3). (a) In fully synchronized data-parallel training of a TRANSFORMER-LM, all parameters are synchronized across all GPUs. This synchronization incurs hefty cross-node communication costs. (b) In embarrassingly parallel training (our work), individual models are trained on each domain, eliminating expensive cross-node parameter synchronization between those models.
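Equations (1) and (2) can be exercised numerically: each expert scores the history, Bayes' rule yields the domain posterior, and the experts' next-token distributions are mixed under that posterior. A toy sketch with made-up numbers (the function names and the uniform prior are our simplifying assumptions; the paper's cached prior method estimates the prior rather than fixing it):

```python
def domain_posterior(history_likelihoods, prior):
    """p(D = j | x_<t) via Bayes' rule, as in Eq. (2).

    history_likelihoods[j] = p(x_<t | D = j); prior[j] = p(D = j).
    """
    joint = [lik * p for lik, p in zip(history_likelihoods, prior)]
    z = sum(joint)  # the normalizer p(x_<t)
    return [j / z for j in joint]

def ensemble_next_token(expert_probs, posterior):
    """Mixture of Eq. (1): sum_j p(X_t | x_<t, D = j) * p(D = j | x_<t)."""
    vocab = len(expert_probs[0])
    return [
        sum(post * probs[v] for post, probs in zip(posterior, expert_probs))
        for v in range(vocab)
    ]

# Two experts, a three-token vocabulary, and a uniform prior.
likelihoods = [0.02, 0.08]             # p(x_<t | D = j) for each expert
posterior = domain_posterior(likelihoods, [0.5, 0.5])
expert_probs = [[0.7, 0.2, 0.1],       # expert 1's next-token distribution
                [0.1, 0.3, 0.6]]       # expert 2's next-token distribution
mixed = ensemble_next_token(expert_probs, posterior)
```

Here the second expert explains the history better, so the posterior shifts toward it and the ensembled distribution leans toward that expert's predictions, which is exactly how the ensemble adapts to texts of unknown domain.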

