A CLOSER LOOK AT CODISTILLATION FOR DISTRIBUTED TRAINING

Abstract

Codistillation has been proposed as a mechanism to share knowledge among concurrently trained models by encouraging them to represent the same function through an auxiliary loss. This contrasts with the more commonly used fully-synchronous data-parallel stochastic gradient descent methods, where different model replicas average their gradients (or parameters) at every iteration and thus maintain identical parameters. We investigate codistillation in a distributed training setup, complementing previous work which focused on extremely large batch sizes. Surprisingly, we find that even at moderate batch sizes, models trained with codistillation can perform as well as models trained with synchronous data-parallel methods, despite using a much weaker synchronization mechanism. These findings hold across a range of batch sizes and learning rate schedules, as well as different kinds of models and datasets. Obtaining this level of accuracy, however, requires properly accounting for the regularization effect of codistillation, which we highlight through several empirical observations. Overall, this work contributes to a better understanding of codistillation and how to best take advantage of it in a distributed computing environment.

1. INTRODUCTION

Several recent improvements in the performance of machine learning models can be attributed to scaling the training of neural network models (He et al., 2016; Goyal et al., 2017; Vaswani et al., 2017; Devlin et al., 2018; Shoeybi et al., 2019; Huang et al., 2019; Kaplan et al., 2020; Lepikhin et al., 2020; Brown et al., 2020). Most approaches to scaling up training leverage some form of data parallelism (using multiple workers to compute gradients on different training samples in parallel), and the most common approach to data-parallel training is synchronous first-order optimization. In synchronous data-parallel training, several replicas of a neural network model are created, each on a different worker. The workers process different mini-batches locally at each step using an optimizer such as Stochastic Gradient Descent (SGD) or Adam (Kingma & Ba, 2015), and the replicas synchronize (i.e., average either gradients or parameters) at every step by communicating either with a centralized parameter server (Li et al., 2014) or using all-reduce (Goyal et al., 2017). More computing resources can be used in parallel by increasing the number of workers, effectively increasing the batch size used to compute a stochastic gradient. Increasing the batch size reduces the gradient's variance and ideally makes it possible to increase the learning rate in proportion to the number of workers. By doing so, the number of steps required to reach a given model quality is also reduced in proportion to the number of workers, and a near-linear speedup is achieved (Goyal et al., 2017). However, it has been observed that the linear learning rate scaling strategy leads to performance degradation for very large batch sizes (Goyal et al., 2017), and even with more principled learning rate scaling mechanisms, synchronous SGD with larger batches eventually yields diminishing returns (Johnson et al., 2020).
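The synchronous update described above can be illustrated with a minimal sketch. The toy quadratic loss, the helper names `per_sample_grad` and `sync_dataparallel_step`, and all numerical values below are illustrative assumptions, not part of the methods studied in this paper; the sketch only shows how per-worker gradients are averaged (as an all-reduce would do) and how the learning rate is scaled linearly with the number of workers.

```python
import numpy as np

# Assumed toy loss per sample: L(w) = 0.5 * ||w - sample||^2,
# so the per-sample gradient is simply (w - sample).
def per_sample_grad(w, sample):
    return w - sample

def sync_dataparallel_step(w, minibatches, base_lr):
    """One step of synchronous data-parallel SGD (illustrative sketch).

    Each of the K workers computes a gradient on its own mini-batch;
    the gradients are then averaged (mimicking an all-reduce) and a
    single update is applied, so every replica stays identical.
    """
    k = len(minibatches)
    # Each worker averages per-sample gradients over its local mini-batch.
    worker_grads = [np.mean([per_sample_grad(w, s) for s in mb], axis=0)
                    for mb in minibatches]
    avg_grad = np.mean(worker_grads, axis=0)  # all-reduce (average)
    lr = base_lr * k                          # linear learning-rate scaling
    return w - lr * avg_grad

rng = np.random.default_rng(0)
w = np.zeros(4)
# 8 workers, each holding a mini-batch of 16 noisy samples around a target.
target = np.ones(4)
batches = [target + 0.1 * rng.standard_normal((16, 4)) for _ in range(8)]
for _ in range(20):
    w = sync_dataparallel_step(w, batches, base_lr=0.01)
```

Averaging over all K mini-batches is what makes the effective batch size grow with the number of workers, which is precisely why the variance of the stochastic gradient shrinks and a larger learning rate becomes tolerable.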
Synchronous data-parallel methods ensure that all models are precisely synchronized at every step during training. This incurs substantial communication overhead, which increases with the number of replicas and can quickly become a bottleneck limiting the utilization of processing units (e.g., GPUs or TPUs), especially when devices communicate over commodity interconnects such as Ethernet. A number of approaches have been proposed to reduce communication overhead, including using mixed precision (Jia et al., 2018) or other forms of compression (Alistarh et al., 2017; Bernstein

