WHY (AND WHEN) DOES LOCAL SGD GENERALIZE BETTER THAN SGD?

Abstract

Local SGD is a communication-efficient variant of SGD for large-scale training, where multiple GPUs perform SGD independently and average the model parameters periodically. It has been recently observed that Local SGD can not only achieve its design goal of reducing communication overhead but also lead to higher test accuracy than the corresponding SGD baseline (Lin et al., 2020b), though the training regimes in which this happens are still under debate (Ortiz et al., 2021). This paper aims to understand why (and when) Local SGD generalizes better based on Stochastic Differential Equation (SDE) approximation. The main contributions of this paper include (i) the derivation of an SDE that captures the long-term behavior of Local SGD in the small learning rate regime, showing how noise drives the iterate to drift and diffuse after it has come close to the manifold of local minima, (ii) a comparison between the SDEs of Local SGD and SGD, showing that Local SGD induces a stronger drift term that can result in a stronger regularization effect, e.g., a faster reduction of sharpness, and (iii) empirical evidence validating that a small learning rate and a sufficiently long training time enable the generalization improvement over SGD, while removing either of the two conditions eliminates the improvement.

1. INTRODUCTION

As deep models have grown larger, training them within reasonable wall-clock time has motivated new distributed environments and new variants of gradient-based training. Recall that Stochastic Gradient Descent (SGD) tries to solve min_{θ∈R^d} E_{ξ∼D}[ℓ(θ; ξ)], where θ ∈ R^d is the parameter vector of the model and ℓ(θ; ξ) is the loss on a data sample ξ drawn from the training distribution D, e.g., the uniform distribution over the training set. SGD with learning rate η and batch size B performs the update θ_{t+1} ← θ_t − η g_t at each step, where g_t = (1/B) ∑_{i=1}^{B} ∇ℓ(θ_t; ξ_{t,i}) for a batch of B independent samples ξ_{t,1}, …, ξ_{t,B} ∼ D. Parallel SGD tries to improve wall-clock time when the batch size B is large. It distributes the gradient computation to K ≥ 2 workers, each of which processes a local batch of B_loc := B/K samples and computes the average gradient over that local batch; g_t is then obtained by averaging the local gradients across the K workers. However, large-batch training is known to hurt generalization (Keskar et al., 2017; Jastrzębski et al., 2017), and reducing this generalization gap is the goal of much subsequent research. It was suggested that the gap arises because larger batches reduce the level of noise in the batch gradient (see Appendix A for more discussion). The Linear Scaling Rule (Krizhevsky, 2014; Goyal et al., 2017; Jastrzębski et al., 2017) tries to fix this by increasing the learning rate in proportion to the batch size. This is found to reduce the generalization gap for (parallel) SGD, but does not eliminate it entirely. To reduce the generalization gap further, Lin et al. (2020b) discovered that a variant of SGD, called Local SGD (Yu et al., 2019; Wang & Joshi, 2019; Zhou & Cong, 2018), can be used as a strong component.
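The updates above can be sketched in NumPy. The function names and the toy `grad_fn` interface here are illustrative, not part of the paper; the sketch only shows that parallel SGD's two-level averaging (per worker, then across workers) is mathematically identical to a single SGD step on the full batch when B is divisible by K.

```python
import numpy as np

def sgd_step(theta, grad_fn, batch, lr):
    """One SGD step: theta <- theta - lr * (mean gradient over the batch)."""
    g = np.mean([grad_fn(theta, xi) for xi in batch], axis=0)
    return theta - lr * g

def parallel_sgd_step(theta, grad_fn, batch, lr, K):
    """Parallel SGD: split the batch of size B across K workers. Each worker
    averages gradients over its local batch of B/K samples; the local
    gradients are then averaged across workers. For B divisible by K this
    mean-of-means equals the full-batch mean, so the step matches sgd_step."""
    local_batches = np.array_split(batch, K)
    local_grads = [np.mean([grad_fn(theta, xi) for xi in b], axis=0)
                   for b in local_batches]
    return theta - lr * np.mean(local_grads, axis=0)
```

Under the Linear Scaling Rule discussed above, `lr` would additionally be scaled in proportion to the batch size B.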
Perhaps surprisingly, Local SGD itself is not designed for improving generalization, but for reducing the high communication cost of synchronization among the workers, another important issue that often bottlenecks large-batch training (Seide et al., 2014; Strom, 2015; Chen et al., 2016; Recht et al., 2011). Instead of averaging the local gradients at every step as in parallel SGD, Local SGD lets the K workers train their models locally and averages the local model parameters whenever they finish H local steps. Every worker samples a new batch at each local step, and in this paper we focus on the case where all workers draw samples with or without replacement from the same training set. See Appendix C for the pseudocode.

More specifically, Lin et al. (2020b) proposed Post-local SGD, a hybrid method that starts with parallel SGD (mathematically equivalent to Local SGD with H = 1) and switches to Local SGD with H > 1 after a fixed number of steps t_0. They showed through extensive experiments that Post-local SGD significantly outperforms parallel SGD in test accuracy when t_0 is carefully chosen. In Figure 1, we reproduce this phenomenon on both CIFAR-10 and ImageNet. As suggested by the success of Post-local SGD, Local SGD can improve the generalization of SGD merely by adding more local steps (while fixing the other hyperparameters), at least when training starts from a model pre-trained by SGD. But the underlying mechanism is not well understood, and there is controversy about when this phenomenon occurs (see Section 2.1 for a survey). The current paper tries to understand: Why does Local SGD generalize better? Under what general conditions does this generalization benefit arise? Previous theoretical research on Local SGD is mainly restricted to the convergence rate for minimizing a convex or non-convex objective (see Appendix A for a survey).
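One round of the Local SGD scheme described above can be sketched as follows. This is a minimal NumPy sketch, not the paper's pseudocode from Appendix C; the names `local_sgd_round` and `sample_fn` are illustrative. Each of the K workers runs H independent SGD steps on its own parameter copy, and then all copies are averaged; setting H = 1 recovers parallel SGD, and Post-local SGD corresponds to running with H = 1 up to step t_0 and with H > 1 afterwards.

```python
import numpy as np

def local_sgd_round(thetas, grad_fn, sample_fn, lr, H):
    """One round of Local SGD: each worker runs H local SGD steps on its own
    copy of the parameters, then the copies are averaged (synchronization)."""
    new_thetas = []
    for theta in thetas:                  # loop over the K workers
        for _ in range(H):                # H local steps per worker
            batch = sample_fn()           # fresh local batch at each step
            g = np.mean([grad_fn(theta, xi) for xi in batch], axis=0)
            theta = theta - lr * g
        new_thetas.append(theta)
    avg = np.mean(new_thetas, axis=0)     # average the model parameters
    return [avg.copy() for _ in thetas]   # all workers restart from the average
```

Communication happens once per round rather than once per step, which is where the H-fold reduction in synchronization cost comes from.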
A related line of work (Stich, 2018; Yu et al., 2019; Khaled et al., 2020) showed that Local SGD has a slower convergence rate than parallel SGD after running the same number of steps/epochs. This convergence result suggests that Local SGD may implicitly regularize the model through insufficient optimization, but it does not explain why parallel SGD with early stopping, which may incur an even higher training loss, still generalizes worse than Post-local SGD.

Our Contributions. In this paper, we provide the first theoretical understanding of why (and when) switching from parallel SGD to Local SGD improves generalization.

1. In Section 2.2, we conduct ablation studies on CIFAR-10 and ImageNet and identify a clean setting where adding local steps to SGD consistently improves generalization: if the learning rate is small and the total number of steps is sufficient, Local SGD eventually generalizes better than the corresponding (parallel) SGD baseline.

2. In Section 3.2, we derive a special SDE that characterizes the long-term behavior of Local SGD in the small learning rate regime, inspired by a previous work (Li et al., 2021b) that proposed this type of SDE for modeling SGD. These SDEs can track the dynamics after the iterate has come close to a manifold of minima. In this regime, the expected gradient is near zero, but the gradient noise can drive the iterate to wander around. In contrast to the conventional SDE (3) for



Figure 1: Post-local SGD (H > 1) generalizes better than SGD (H = 1). (a) CIFAR-10, B = 4096, ResNet-56; (b) ImageNet, B = 8192, ResNet-50. We switch to Local SGD at the first learning rate decay (epoch #250) for CIFAR-10 and at the second learning rate decay (epoch #100) for ImageNet. See Appendix M.1 for training details.

