MIME: MIMICKING CENTRALIZED STOCHASTIC ALGORITHMS IN FEDERATED LEARNING

Abstract

Federated learning (FL) is a challenging setting for optimization due to the heterogeneity of the data across different clients. This heterogeneity has been shown to cause client drift, which can significantly degrade the performance of algorithms designed for the FL setting. In contrast, centralized learning with centrally collected data is not affected by such drift and has seen great empirical and theoretical progress with innovations such as momentum and adaptivity. In this work, we propose a general algorithmic framework, MIME, which mitigates client drift and adapts arbitrary centralized optimization algorithms such as SGD and Adam to the federated learning setting. MIME uses a combination of control-variates and server-level statistics (e.g. momentum) at every client-update step to ensure that each local update mimics that of the centralized method run on i.i.d. data. Our thorough theoretical and empirical analyses strongly establish MIME's superiority over other baselines.

Under review as a conference paper at ICLR 2021

MIME applies an SVRG-style correction to mimic the updates of the centralized algorithm run on i.i.d. data. These global statistics are computed only at the server level and kept fixed throughout the local steps, thereby avoiding bias due to the atypical local data of any single client.

Contributions. We summarize our main results below.
• We formalize the cross-device federated learning problem and propose a new framework, MIME, that can adapt arbitrary centralized algorithms to this setting.
• We prove that incorporating server momentum into each local client update reduces client drift and leads to optimal statistical rates.
• Further, we quantify the usefulness of performing multiple local updates on a single client by carefully tracking the bias (client drift) introduced. This is the first analysis showing improved rates from taking multiple additional steps for general smooth functions.
• Finally, we also propose a simpler variant, MIMELITE, with empirical performance similar to MIME. We report the results of a thorough experimental analysis demonstrating that both MIME and MIMELITE are faster than FEDAVG.

Related work. Analysis of FedAvg: Much of the recent work in federated learning has focused on analyzing FEDAVG. For identical clients, FEDAVG coincides with parallel SGD, for which Zinkevich et al. (2010) derived an analysis with asymptotic convergence. Sharper and more refined analyses of the same method, sometimes called local SGD, were provided by Stich (2019), and more recently by Stich and Karimireddy (2019), Patel and Dieuleveut (2019), and Khaled et al.

1. INTRODUCTION

Federated learning has become an important paradigm in large-scale machine learning, where the training data remains distributed over a large number of clients, which may be mobile phones or network sensors (Konečnỳ et al., 2016b;a; McMahan et al., 2017; Mohri et al., 2019; Kairouz et al., 2019). A centralized model, here referred to as a server model, is then trained without ever transmitting client data over the network, thereby providing some basic level of data privacy and security. Two important settings are distinguished in federated learning (Kairouz et al., 2019, Table 1): the cross-device and the cross-silo settings. The cross-silo setting corresponds to a relatively small number of reliable clients, typically organizations such as medical or financial institutions. In contrast, in the cross-device federated learning setting, the number of clients may be extremely large and include, for example, all 3.5 billion active Android phones (Holst, 2019). Thus, in that setting, we may never make even a single pass over the entire clients' data during training. The cross-device setting is further characterized by resource-poor clients communicating over a highly unreliable network. Together, the essential features of this setting give rise to unique challenges not present in the cross-silo setting. Here, we are interested in the cross-device setting, for which we will formalize and study stochastic optimization algorithms. The de facto standard algorithm for this setting is FEDAVG (McMahan et al., 2017), which performs multiple SGD updates on the available clients before communicating to the server. While this approach can reduce the total amount of communication required, performing multiple steps on the same client can lead to 'over-fitting' to its atypical local data, a phenomenon known as client drift (Karimireddy et al., 2020).
Furthermore, algorithmic innovations such as momentum (Sutskever et al., 2013; Cutkosky and Orabona, 2019), adaptivity (Kingma and Ba, 2014; Zaheer et al., 2018; Zhang et al., 2019), and clipping (You et al., 2017; 2019; Zhang et al., 2020) are critical to the success of deep learning applications and need to be incorporated into the client updates, replacing the SGD update of FEDAVG. Perhaps due to such deficiencies, there exists a large gap in performance between the centralized setting, where data is centrally collected on the server, and the federated setting (Zhao et al., 2018; Hsieh et al., 2019; Hsu et al., 2019; Karimireddy et al., 2020). To overcome such deficiencies, we propose a new framework, MIME, that mitigates client drift and adapts arbitrary centralized optimization algorithms, e.g. SGD with momentum or Adam, to the federated setting. In each local client update, MIME uses global statistics, e.g. momentum, together with an SVRG-style correction.

In addition to the communication constraints discussed above, the cross-device setting poses the following challenges:
2. Each client is likely to participate at most once, due to the extremely large number of clients; furthermore, each individual client may have very little data of its own.
3. There may be wide heterogeneity (non-i.i.d.-ness) due to the differences in the data distributions of the clients.
Thus, our objective will be to minimize the following quantity within the fewest number of client-server communication rounds:
f(x) = E_{i∼D}[f_i(x)], where f_i(x) := (1/n_i) Σ_{ν=1}^{n_i} f_i(x; ζ_{i,ν}). (1)
Here, f_i denotes the loss function of client i and {ζ_{i,1}, …, ζ_{i,n_i}} its local data. Since the number of clients is extremely large while the size of each local dataset is rather modest, we represent the former as an expectation and the latter as a finite sum. In each round, the algorithm samples a subset of clients (of size S) and performs some updates to the server model.
There is some inherent tension between the second and the third challenge outlined above: if there exists a client with arbitrarily different data whom we may never encounter during training, then there is no hope of actually minimizing f. Thus, for (1) to be tractable, it is necessary to assume bounded dissimilarity between the different f_i.

(A1) G²-BGD or bounded gradient dissimilarity: there exists G ≥ 0 such that
E_{i∼D}[‖∇f_i(x) − ∇f(x)‖²] ≤ G², ∀x.

Next, we also characterize the variance in the Hessians.

(A2) δ-BHD or bounded Hessian dissimilarity: almost surely, f is δ-weakly convex, i.e. ∇²f_i(x) ⪰ −δI, and the loss function of any client i satisfies
‖∇²f_i(x; ζ) − ∇²f(x)‖ ≤ δ, ∀x.

Note that if f_i(·; ζ) is L-smooth, (A2) is always satisfied with δ ≤ 2L, and hence it is more of a definition than an assumption. However, in realistic examples we expect the clients to be similar and hence δ ≪ L. In addition, we assume that f(x) is bounded from below by f* and is L-smooth, as is standard.
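To make assumption (A1) concrete, the following minimal sketch estimates the gradient dissimilarity empirically for hypothetical quadratic clients (the setup and all names are our own illustration, not from the paper):

```python
# A toy illustration of measuring the gradient dissimilarity in (A1):
# each client i holds a quadratic loss f_i(x) = 0.5 * (x - b_i)^2, so
# grad f_i(x) = x - b_i, grad f(x) = x - mean(b), and
# E_i ||grad f_i(x) - grad f(x)||^2 = Var(b), uniformly in x.
def client_grad(x, b_i):
    return x - b_i

def dissimilarity_G2(x, bs):
    """Empirical E_i ||grad f_i(x) - grad f(x)||^2 over the clients."""
    g_global = sum(client_grad(x, b) for b in bs) / len(bs)
    return sum((client_grad(x, b) - g_global) ** 2 for b in bs) / len(bs)

bs = [-2.0, 0.0, 2.0]            # client optima; Var(b) = 8/3
for x in (0.0, 1.0, -5.0):       # (A1) holds with the same G^2 at every x
    print(round(dissimilarity_G2(x, bs), 6))   # prints 2.666667 each time
```

For these clients the bound in (A1) is tight with G² = Var(b), independent of x; for clients with differing Hessians the left-hand side would grow with ‖x‖, which is why (A2) is tracked separately.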

3. USING MOMENTUM TO REDUCE CLIENT DRIFT

In this section we examine the tension between reducing communication by running multiple client updates each round, and the degradation in performance due to client drift (Karimireddy et al., 2020). To simplify the discussion, we assume a single client is sampled each round and that clients use full-batch gradients.

Server-only approach. A simple way to avoid the issue of client drift is to take no local steps. We sample a client i ∼ D and run SGDm with momentum parameter β and step size η:
x_t = x_{t−1} − η((1 − β)∇f_i(x_{t−1}) + βm_{t−1}), m_t = (1 − β)∇f_i(x_{t−1}) + βm_{t−1}. (2)
Here, the gradient ∇f_i(x_{t−1}) is unbiased, i.e. E[∇f_i(x_{t−1})] = ∇f(x_{t−1}), and hence we are guaranteed convergence. However, this strategy can be communication-intensive, and we are likely to spend all our time waiting for communication with very little time spent on computing the gradients.

FedAvg approach. To reduce the overall number of communication rounds required, we need to make more progress in each round of communication. Starting from y_0 = x_{t−1}, FEDAVG (McMahan et al., 2017) runs multiple SGD steps on the sampled client i ∼ D,
y_k = y_{k−1} − η∇f_i(y_{k−1}) for k ∈ [K], (3)
and then a pseudo-gradient g̃_t = −(y_K − x_{t−1}) replaces ∇f_i(x_{t−1}) in the SGDm algorithm (2). This is referred to as server momentum since it is computed and applied only at the server level (Hsu et al., 2019). However, such updates give rise to client drift, resulting in performance worse than the naive server-only strategy (2). This is because, by using multiple local updates, (3) starts overfitting to the local client data, optimizing f_i(x) instead of the actual global objective f(x).


Figure 1: Client drift in FEDAVG (left) and MIME (right), illustrated for 2 clients with 3 local steps and momentum parameter β = 0.5. The local SGD updates of FEDAVG (shown using arrows for clients 1 and 2) move towards the average of the client optima (x*_1 + x*_2)/2, which can be quite different from the true global optimum x*. Server momentum only speeds up the convergence to the wrong point in this case. In contrast, MIME uses unbiased momentum and applies it locally at every update. This keeps the updates of MIME closer to the true optimum x*.

The net effect is that FEDAVG moves towards an incorrect point (see Fig. 1, left). If K is sufficiently large, approximately y_K ≈ x*_i, where x*_i := argmin_x f_i(x), and so
E_{i∼D}[g̃_t] ≈ x_{t−1} − E_{i∼D}[x*_i].
Further, the server momentum is based on g̃_t and hence is also biased. Thus, it cannot correct for the client drift. We next see how a different way of using momentum can mitigate client drift.

Mime approach. FEDAVG experiences client drift because both the momentum and the client updates are biased. To fix the former, we compute the momentum using only global statistics, as in (2):
m_t = (1 − β)∇f_i(x_{t−1}) + βm_{t−1}.
To reduce the bias in the local updates, we apply this unbiased momentum at every step:
y_k = y_{k−1} − η((1 − β)∇f_i(y_{k−1}) + βm_{t−1}) for k ∈ [K]. (5)
Note that the momentum term is kept fixed during the local updates, i.e. no local momentum is used; only global momentum is applied locally. Since m_{t−1} is a moving average of unbiased gradients computed over multiple clients, it is intuitively a good approximation of the general direction of the updates. By taking a convex combination of the local gradient with m_{t−1}, the update (5) is potentially also less biased. In this way, MIME combines the communication benefits of taking multiple local steps with the prevention of client drift (see Fig. 1, right). Appendix C makes this intuition precise.
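The drift towards the average of the client optima can be checked numerically. The following sketch uses two hypothetical 1-d quadratic clients (all constants are our own illustration, not from the paper): with unequal curvatures, the expected FedAvg pseudo-gradient vanishes at the average of the client optima even though the true gradient there is nonzero.

```python
# Client-drift on two toy clients f_i(x) = 0.5 * a_i * (x - b_i)^2.
# With unequal curvatures a_i, the average of the client optima differs
# from the true optimum argmin f, and the expected FedAvg pseudo-gradient
# vanishes at the wrong point.
A, B = [1.0, 4.0], [-1.0, 1.0]        # client curvatures and optima

def grad(i, x):                        # grad f_i(x)
    return A[i] * (x - B[i])

def fedavg_pseudograd(i, x, eta=0.01, K=2000):
    """-(y_K - x) after K local SGD steps on client i, as in (3)."""
    y = x
    for _ in range(K):
        y -= eta * grad(i, y)
    return -(y - x)

x_avg = sum(B) / len(B)                             # average of optima: 0.0
x_star = sum(a * b for a, b in zip(A, B)) / sum(A)  # true optimum: 0.6
pg = 0.5 * (fedavg_pseudograd(0, x_avg) + fedavg_pseudograd(1, x_avg))
g = 0.5 * (grad(0, x_avg) + grad(1, x_avg))
# Expected pseudo-gradient is ~0 at x_avg, yet the true gradient is not:
print(abs(pg) < 1e-3, g)    # prints: True -1.5
```

Because the expected pseudo-gradient is (approximately) zero at x_avg = 0 while the true optimum sits at 0.6, FedAvg with large K stalls at the wrong point, exactly the situation depicted in Fig. 1 (left).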

4. MIME FRAMEWORK

In this section we describe how to adapt arbitrary centralized algorithms (and not just SGDm) to the federated learning problem (1) while ensuring there is no client drift. Algorithm 1 describes the two variants MIME and MIMELITE, which consist of three components: i) a base algorithm we are trying to mimic, ii) how we compute the global statistics, and iii) the local client updates.

Base algorithm. We assume the centralized base algorithm we are imitating can be decomposed into two steps: an update step U, which updates the parameters x, and a statistics step V(·), which keeps track of global statistics s. Each step of the base algorithm B = (U, V) uses a gradient g to perform
x ← x − η U(g, s), s ← V(g, s). (BASEALG)
V may track multiple statistics, which we represent collectively as s. While SGDm (2) is clearly of this form, the appendix shows this for other algorithms such as Adam.

Algorithm 1: Mime and MimeLite
input: initial x and s, learning rate η, and base algorithm B = (U, V)
for each round t = 1, …, T do
    sample subset S of clients
    communicate (x, s) to all clients i ∈ S
    communicate c ← (1/|S|) Σ_{j∈S} ∇f_j(x)  (only for Mime)
    on client i ∈ S in parallel do
        initialize local model y_i ← x
        for k = 1, …, K do
            sample mini-batch ζ from local data
            y_i ← y_i − η U(∇f_i(y_i; ζ) − ∇f_i(x; ζ) + c, s)  (Mime)
            y_i ← y_i − η U(∇f_i(y_i; ζ), s)  (MimeLite)
        end for
        compute full local-batch gradient ∇f_i(x)
        communicate (y_i, ∇f_i(x))
    end on client
    s ← V((1/|S|) Σ_{i∈S} ∇f_i(x), s)  (update optimization statistics)
    x ← (1/|S|) Σ_{i∈S} y_i  (update server parameters)
end for

Compute statistics globally, apply locally. When updating the statistics of the base algorithm, we use only gradients computed at the server parameters. Further, the statistics remain fixed throughout the local updates of the clients. This ensures that they remain unbiased and representative of the global function f(·).
At the end of the round, the server performs
s ← V((1/|S|) Σ_{i∈S} ∇f_i(x), s), where ∇f_i(x) = (1/n_i) Σ_{ν=1}^{n_i} ∇f_i(x; ζ_{i,ν}). (STATS)
Note that we use full-batch gradients computed at the server parameters x, not the client parameters y_i.

Local client updates. Each client i ∈ S performs K updates using the step U of the base algorithm and a mini-batch gradient. Two variants are possible, corresponding to MIME and MIMELITE. Starting from y_i ← x, repeat the following K times:
y_i ← y_i − η U(g, s), where g ← ∇f_i(y_i; ζ) − ∇f_i(x; ζ) + (1/|S|) Σ_{j∈S} ∇f_j(x) (Mime) or g ← ∇f_i(y_i; ζ) (MimeLite). (CLTSTEP)
MIMELITE simply uses the local mini-batch gradient, whereas MIME uses an SVRG-style correction (Johnson and Zhang, 2013). This is done to reduce the noise from sampling a local mini-batch. While this correction yields faster rates in theory (and in practice for convex problems), in deep learning applications we found that MIMELITE closely matches the performance of MIME. Finally, two modifications are made in practical FL: we weight all averages across the clients by the number of datapoints n_i, and we perform K epochs instead of K steps (McMahan et al., 2017). The former modifies objective (1) with f_i weighted by n_i; the latter has been empirically observed to perform better, but lacks strong justification (Wang et al., 2020b).
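The round structure of Algorithm 1 can be condensed into a short sketch. This is our own illustration for a scalar parameter with full-batch client gradients (so the mini-batch gradients in (CLTSTEP) coincide with the gradients used for the statistic c); the base algorithm is passed in as the (U, V) pair of the (BASEALG) decomposition, here SGDm:

```python
# A condensed, hypothetical sketch of Algorithm 1 (scalar parameter,
# full-batch client gradients). U returns the update direction, so that
# x <- x - eta * U(g, s); V returns the new statistics.
def sgdm_U(g, s, beta=0.9):
    return (1 - beta) * g + beta * s      # s is the momentum

def sgdm_V(g, s, beta=0.9):
    return (1 - beta) * g + beta * s

def mime_round(x, s, clients, U, V, eta=0.1, K=5, lite=False):
    """One round; clients[i] maps parameters y to grad f_i(y)."""
    c = sum(g(x) for g in clients) / len(clients)  # server-level gradient
    ys = []
    for g_i in clients:
        y = x
        for _ in range(K):
            # (CLTSTEP): SVRG-corrected gradient for Mime, plain for MimeLite
            g = g_i(y) if lite else g_i(y) - g_i(x) + c
            y -= eta * U(g, s)            # statistics s stay fixed locally
        ys.append(y)
    return sum(ys) / len(ys), V(c, s)     # (STATS): server params and stats

# Two toy clients f_i(y) = 0.5 * (y - b_i)^2 with b = +/-1; optimum at 0.
x, s = 5.0, 0.0
clients = [lambda y: y - 1.0, lambda y: y + 1.0]
for _ in range(100):
    x, s = mime_round(x, s, clients, sgdm_U, sgdm_V)
print(abs(x) < 0.1)   # the server model approaches the global optimum
```

Note how the structure enforces the two design rules of this section: s is computed only from server-level gradients (via c) and is read, never written, inside the local loop.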

5. THEORETICAL ANALYSIS OF MIME

Table 1: Number of communication rounds required to reach accuracy ε (ignoring constants and log factors) for non-convex and strongly convex objectives.

Algorithm | Non-convex | Strongly convex
MVR (server-only) | (G/√S)^{3/2} (1/ε^{3/2}) + L/ε | —
FEDAVG¹ / FedSGD (Karimireddy et al., 2020) | G²/(Sε²) + G^{3/2}/ε^{3/2} + L/ε | G²/(μSε) + G/(μ√ε) + L/μ
SCAFFOLD² (Karimireddy et al., 2020) | (N/S)^{2/3} L/ε | N/S + L/μ
MIME³: MimeSGD | G²/(Sε²) + δ/ε | G²/(μSε) + δ/μ
MIME³: MimeMVR | (G/√S)^{3/2} (1/ε^{3/2}) + δ/ε | —
Lower bound (Arjevani et al., 2019) | Ω((G/√S)^{3/2} (1/ε^{3/2})) | Ω(G²/(μSε))

¹ Requires K ≥ σ²/G² local updates, with within-client variance σ².
² In cross-device FL, the total number of clients N can be of the same order as the number of rounds (since we make only a few passes over all clients), or even ∞, making the bounds vacuous.
³ Requires K ≥ L/δ local updates.

Theorem I. For an L-smooth f with G² gradient dissimilarity (A1), δ Hessian dissimilarity (A2), and F := f(x⁰) − f*, let us run MimeMVR for T rounds and generate an output x_out. This output satisfies E‖∇f(x_out)‖² ≤ ε under the following conditions:
• PL/strongly convex, without momentum: for η = Õ(min{1/(δK + μK + L), 1/(μT)}), β = 0, and T = Õ(LG²/(μSε) + ((L + δK)/(μK)) log(F/ε)),
• Non-convex, without momentum: for η = O(min{1/(δK + L), (SF/(G²TK²))^{1/2}}), β = 0, and T = O(LG²F/(Sε²) + (L + δK)F/(Kε)),
• Non-convex, with momentum: for η = O(min{1/(δK + L), (SF/(G²TK³))^{1/3}}), β = 1 − O(δ²/(TG²)^{2/3}), and T = O(((1 + δ)G²F/(Sε²))^{3/4} + (L + δK)F/(Kε)).
The expectation in E‖∇f(x_out)‖² ≤ ε is taken over the sampling of the clients during the run of the algorithm, the sampling of the mini-batches in the local updates, and the choice of x_out (which is chosen randomly from the client iterates y_i, as described in the appendix).

Table 1 shows that the rate (ignoring constants) of FEDAVG on non-convex functions is G²/(Sε²) + G^{3/2}/ε^{3/2} + L/ε. This is slower than simply running SGD, which obtains a rate of G²/(Sε²) + L/ε. In contrast, MimeSGD obtains a rate of G²/(Sε²) + δ/ε with δ ≪ L, thus improving upon both SGD and FedAvg.
While asymptotically these three rates may seem equivalent, in machine learning we care about low-accuracy settings where ε is not too small (Bottou, 2010), and so the lower-order terms matter, as additionally evidenced by our experiments (Sec. 6).

Figure 2: Convergence on simulated data, all with momentum (β = 0.5). FedAvg gets slower as the gradient dissimilarity G increases (to the right). MimeLite shows a similar pattern, but is consistently better than FedAvg. Mime is significantly faster than both and is unaffected by heterogeneity (G).

We can also compare with SCAFFOLD (Karimireddy et al., 2020), which obtains a rate of (N/S)^{2/3} L/ε, where N is the total number of clients. While asymptotically this is a faster rate, N in the cross-device setting is potentially infinite, or at least comparable to the total number of training rounds, making these bounds vacuous. This too is reflected in our experiments (Fig. 5).

The momentum β used in our analysis is of the order of 1 − O((TG²)^{−2/3}), i.e. as T increases, the momentum parameter asymptotically approaches 1. In contrast, previous analyses of distributed momentum (e.g. Yu et al. (2019a)) prove rates of the form G²/(S(1 − β)ε²), which are worse than that of standard SGD by a factor of 1/(1 − β). Thus, ours is the first result which theoretically showcases the usefulness of large momentum in distributed and federated learning. Our theory suggests that the momentum parameter should be increased as G increases, i.e. as the clients become more heterogeneous, there is stronger client drift and hence we need more momentum to compensate.

Our analysis is highly non-trivial and involves three crucial ingredients: i) computing the momentum at the server level to ensure that it remains unbiased, and then applying it locally during every client update to reduce variance, ii) carefully keeping track of the bias introduced by the additional local steps, and iii) an SVRG correction to allow the use of mini-batches. Our experiments (Sec. 6) verify that the first two theoretical insights are indeed applicable in deep learning settings as well, whereas the latter seems to matter more in convex settings. See Appendix C, where we make this discussion more concrete, and Appendices F-G for detailed proofs and theorem statements.

6. EXPERIMENTAL ANALYSIS

We run experiments on simulated and real datasets to confirm our theory. Our main findings are: i) MIME outperforms FEDAVG across all settings, ii) its SVRG correction is useful for convex problems, and iii) momentum significantly improves performance for non-convex problems.

We consider four algorithms: SERVER-ONLY, FEDAVG, MIME, and MIMELITE. Each of these adapts the base optimizers SGD, SGDm, and Adam. The SERVER-ONLY method computes a full-batch gradient on each of the sampled clients and uses their aggregate directly in the base optimizer (akin to (2)). For FEDAVG, we follow Reddi et al. (2020), who run multiple epochs of SGD on each sampled client and then aggregate the net client updates. This aggregated update is used as a pseudo-gradient in the base optimizer (called the server optimizer). The learning rate for the server optimizer is fixed to 1, as in (Wang et al., 2020c). This is done to ensure all algorithms have the same number of hyper-parameters. Finally, MIME and MIMELITE follow Algorithm 1 and also run a fixed number of epochs on each client. Aggregation is weighted by the number of samples on the clients.

Figure 3: Mime and MimeLite have very similar performance and are consistently the best. FedAvg is even worse than the server-only baselines. Also, Mime makes better use of momentum than FedAvg, with a large increase in performance (right).
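The FedAvg baseline as run here (local SGD feeding a pseudo-gradient to a server optimizer) can be sketched as follows. The quadratic clients and all names are our own illustration, not code from the paper; the server optimizer is SGDm with learning rate fixed to 1, as described above.

```python
# Sketch of the FedAvg baseline (following Reddi et al., 2020): clients run
# local SGD, and the negative average client update is consumed by a server
# SGDm optimizer as a pseudo-gradient.
def fedavg_round(x, m, clients, eta=0.1, K=5, beta=0.9, server_lr=1.0):
    """clients[i] maps parameters y to grad f_i(y); returns (x, m)."""
    deltas = []
    for g_i in clients:
        y = x
        for _ in range(K):              # local SGD steps (epochs in practice)
            y -= eta * g_i(y)
        deltas.append(y - x)
    pseudo_grad = -sum(deltas) / len(deltas)
    # Server SGDm step, mirroring eq. (2) with the pseudo-gradient as g.
    x = x - server_lr * ((1 - beta) * pseudo_grad + beta * m)
    m = (1 - beta) * pseudo_grad + beta * m
    return x, m
```

On homogeneous quadratic clients this converges quickly; the client-drift issue of Sec. 3 only appears once the clients disagree.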

6.1. SIMULATED CONVEX EXPERIMENTS

Our simulated experiments use two clients, each with a simple scalar quadratic loss, as in (Karimireddy et al., 2020). We use full-batch gradients, with both clients participating in every round. The simulated data has Hessian dissimilarity δ = 1 (A2) and smoothness L = 2. We vary the gradient dissimilarity (A1) over G ∈ {1, 10, 100}. All algorithms use momentum with β = 0.5, and their learning rates were tuned up to a tolerance of 5E-3 to ensure the lowest loss after 60 rounds. The results are collected in Fig. 2. When G is small, we see that FEDAVG can outperform the SERVER-ONLY (SGDm) baseline, though its loss quickly plateaus. On increasing G, FEDAVG becomes even slower. MIMELITE differs from FEDAVG only in how the momentum is used. In all settings, it slightly outperforms FEDAVG, though it too sees a substantial slowdown as we increase G. This reflects our theory, which predicts that for convex cases momentum does not give significant gains. MIME, on the other hand, is substantially faster than all other methods and is even unaffected by changing G. Thus, in this simple convex setting, the SVRG correction completely eliminates client drift.

6.2. REAL WORLD EXPERIMENTS

We run real-world deep learning experiments on EMNIST62 with a 2-layer MLP model and on CIFAR100 with ResNet20, both accessed through TensorFlow Federated (TFF, 2020a). All methods run 10 local epochs with batch size 20, and the learning rates for all methods were individually tuned. We refer to Appendix A for additional details and results.

Mime > server-only > FedAvg. FedAvg is slower than even the naive server-only methods, which make no local updates. This perfectly mirrors our theory that Mime > server-only > FedAvg. MimeLite closely matches the performance of Mime, possibly because the SVRG correction is not necessary in deep learning (Defazio and Bottou, 2019).

With momentum > without momentum. Fig. 3 (right) examines the impact of momentum on FedAvg and Mime. Momentum slightly improves the performance of FedAvg, whereas it has a significant impact on the performance of Mime. This is also in line with our theory and confirms that Mime's strategy of applying momentum locally at every client update makes better use of it.

Fixed statistics > updated statistics. Finally, we check how the performance of Mime changes if, instead of keeping the momentum fixed throughout a round, we let it change locally; the momentum is then reset at the end of the round, ignoring the changes the clients made to it. Appendix B.1 shows that this consistently worsens performance, confirming that it is better to keep the statistics fixed. Together, the above observations validate all aspects of the Mime (and MimeLite) design: compute statistics at the server level, and apply them unchanged at every client update.

7. CONCLUSION

Our work initiated a formal study of the cross-device federated learning problem. We argued that the natural heterogeneity among the clients gives rise to client drift and significantly hampers the performance of approaches such as FEDAVG. We then showed how momentum can be an excellent tool to overcome this client drift if used correctly.

A.1 EXPERIMENTAL SETUP

We use TensorFlow Federated datasets (TFF, 2020a) to generate the datasets. Our federated learning simulation code is written in JAX (Frostig et al., 2018). Our ResNet18 model is based on (Haiku, 2020) (ResNet v2), and following (Hsieh et al., 2019; Reddi et al., 2020) we replace batch norm with group norm with 2 groups. Black and white were reversed in EMNIST62 (i.e. pixel values subtracted from 1) to make the images similar to MNIST. CIFAR100 used the usual pre-processing (normalization and centering) and data augmentation (random crop and horizontal flipping), following (kuangliu, 2020, accessed June 4, 2020).

A.2 PRACTICALITY OF EXPERIMENTS

In the experiments, we only measured the number of communication rounds, ignoring that MIME actually needs twice the number of bits per round and that the SERVER-ONLY methods have a much smaller computational requirement. This is standard in the federated learning setting, as introduced by McMahan et al. (2017), and is justified because most of the time in cross-device FL is spent establishing connections with devices rather than performing useful work such as communication or computation. In other words, latency, and not bandwidth or computation, is critical in cross-device FL. However, one can certainly envision cases where this is not true. Incorporating communication-compression strategies (Suresh et al., 2017; Alistarh et al., 2017; Karimireddy et al., 2019) or client-model compression strategies (Caldas et al., 2018a; Frankle and Carbin, 2019; Hamer et al., 2020) into our MIME framework can potentially address such issues; these are important future research directions.

As we already discussed, we believe both the datasets and the tasks studied here are quite realistic. We now discuss our choice of the other parameters in the experimental setup (number of training rounds, sampled clients, batch size, etc.). Each round of federated learning takes 2-3 minutes in the real world and is relatively independent of the size of the communication (Bonawitz et al., 2019), implying that training for 1000 rounds takes 1.4-2 days. This underscores the importance of ensuring that algorithms for federated learning converge in as few rounds as possible, as well as have easy-to-set default hyper-parameters. Thus, in our experimental setup we keep all parameters other than the learning rate at their default values. In practice, this learning rate can be set using a small dataset on the server (as in (Hard et al., 2020)).
The choice of batch size 10 was made both keeping in mind the limited memory available to each client and to match prior work. Finally, while we limit ourselves to sampling 20 workers per round due to computational constraints, in real-world FL thousands of devices are often available for training simultaneously each round (Bonawitz et al., 2019). They also note that the probability of each of these devices being available has clear patterns and is far from uniform sampling. Conducting a large-scale experimental study which mimics these alternate forms of heterogeneity is an important direction for future work.

A.3 HYPERPARAMETER SEARCH

We use EMNIST62 with the MLP model as a 'test-bed' for exploring different algorithms, since it is both a representative task of cross-device FL and computationally efficient. All plots reported are for this setting. A more fine-grained search over hyper-parameters to report the final test accuracies is made over the rest of the tasks/datasets. For all SGDm methods, we pick momentum β = 0.9. For Adam methods, we fix β₁ = 0.9, β₂ = 0.99, and ε = 10⁻³, similar to (Reddi et al., 2020). None of the algorithms use weight decay, clipping, etc. The learning rate is then tuned to obtain the best test accuracy. For all experiments, unless explicitly mentioned otherwise, the learning rate is searched over the grid
η ∈ {10, 1, 10⁻¹, 10⁻², 10⁻³, 10⁻⁴, 10⁻⁵}.

A.4 COMPARISON WITH PREVIOUS RESULTS

As far as we are aware, (Reddi et al., 2020) is the only prior work which conducts a systematic experimental study of federated learning algorithms over multiple realistic datasets. The algorithms comparable across the two works (e.g. FedSGD, FedSGDm, and FedAdam) have qualitatively similar performance, with one exception: FedAdam consistently underperforms FedSGDm. We believe this difference arises because (Reddi et al., 2020) additionally tune the server learning rate and the ε parameter of Adam. As we explain in Section A.2, we chose to keep these parameters at default values to compare methods in the 'low-tuning' setting; this is also the default/recommended behavior in TensorFlow Federated (TFF, 2020b) and (Wang et al., 2020c). We also note that while FedAdam struggles in this setup, MimeAdam and MimeLiteAdam are very stable and often even outperform their SGDm counterparts.

Figure 5: FedProx with SGDm run on EMNIST62 with a 2-hidden-layer (300u-100) MLP. For FedProx and SCAFFOLD, in addition to tuning the learning rate, we search for the best server momentum β ∈ {0, 0.9, 0.99}. FedProx uses an additional regularizer weight µ, which we search over {0.1, 0.5, 1} (note that FedProx with µ = 0 is the same as FedAvg). The best test accuracy (plotted here) was obtained with β = 0 for both, and µ = 0.1 for FedProx.

Note that FedProx is the slowest method here (in fact, it is even slower than FedAvg). The additional regularizer does not seem to reduce client drift, while still slowing down convergence (Karimireddy et al., 2020; Wang et al., 2020b). SCAFFOLD is also slower than Mime in this setup. This is because SCAFFOLD was designed for the cross-silo setting and not the cross-device setting. The large number of clients (N = 3.4k) means that each client is on average visited fewer than 6 times during the entire training (20 clients per round for 1k rounds). Hence, the stored client control variate is quite stale (from about 200 rounds ago), which slows down convergence.
This perfectly reflects our theoretical understanding that when the number of clients N is large relative to the number of training rounds (as is true in the cross-device setting), SCAFFOLD is outperformed by MIME. For FEDAVG, we searched over both the client and server learning rates, whereas for MIME and MIMELITE we search only over the client (base) learning rate. This search is performed over the grid (0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.4, 0.8). For momentum, we chose the best of (0.9, 0.99), and for Adam we varied β₁ in (0.9, 0.99), β₂ in (0.99, 0.999), and ε in (0.01, 0.001, 0.0001). In this setting as well, MIME and MIMELITE outperform FEDAVG, but with a smaller margin.


Base algorithm | Statistics s | Update step U(g, s) | Statistics step V(g, s)
SGD | — | x − ηg | —
SGDm | m | x − η((1 − β)g + βm) | m = (1 − β)g + βm
RMSProp | v | x − η g/(ε + √v) | v = (1 − β)g² + βv
Adam | m, v | x − η((1 − β₁)g + β₁m)/(ε + √v) | m = (1 − β₁)g + β₁m, v = (1 − β₂)g² + β₂v
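The rows of the table above can be written out directly as (U, V) pairs on a scalar parameter, following the (BASEALG) decomposition x ← x − η U(g, s), s ← V(g, s). This is a sketch of our own, not library code: U returns the update direction and only reads s, V only updates s; ε is the usual numerical-stability constant, and the bias correction of standard Adam is omitted, as in the table.

```python
import math

# (U, V) pairs for the base optimizers in the table above.
def sgdm_U(g, s, beta=0.9):
    return (1 - beta) * g + beta * s          # s is the momentum m

def sgdm_V(g, s, beta=0.9):
    return (1 - beta) * g + beta * s

def rmsprop_U(g, s, eps=1e-3):
    return g / (eps + math.sqrt(s))           # s is the second moment v

def rmsprop_V(g, s, beta=0.99):
    return (1 - beta) * g * g + beta * s

def adam_U(g, s, beta1=0.9, eps=1e-3):
    m, v = s                                  # s collects both statistics
    return ((1 - beta1) * g + beta1 * m) / (eps + math.sqrt(v))

def adam_V(g, s, beta1=0.9, beta2=0.99):
    m, v = s
    return ((1 - beta1) * g + beta1 * m, (1 - beta2) * g * g + beta2 * v)
```

Plugging any of these pairs into Algorithm 1 yields the corresponding Mime or MimeLite variant.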

C PROOF OVERVIEW

In this section, we give proof sketches of the main components of Theorem I: i) how momentum reduces the effect of client drift, ii) how local steps can take advantage of Hessian similarity, and iii) why the SVRG correction improves constants.

Improving the statistical term via momentum. Note that the statistical (first) term in Theorem I without momentum (β = 0) for the convex case is LG²/(μSε). This is (up to constants) optimal and cannot be improved. For the non-convex case, however, using β = 0 gives the usual rate of LG²/(Sε²). This can be improved to ((1 + δ)G²F/(Sε²))^{3/4} using momentum. This matches a similar improvement in the centralized setting (Cutkosky and Orabona, 2019; Tran-Dinh et al., 2019) and is in fact optimal (Arjevani et al., 2019).

Let us examine why momentum improves the statistical term. Assume that we sample a single client i_t in round t and that we use full-batch gradients. Also, let the local client update at step k of round t be of the form
y ← y − η d_k. (6)
The ideal choice of update is of course d_k = ∇f(y), but this is unattainable. Instead, MIME with momentum β = 1 − a uses
d_k^{SGDm} = m̃_k ← a∇f_{i_t}(y) + (1 − a)m_{t−1},
where m_{t−1} is the momentum computed at the server. The variance of this update can then be bounded as
E‖m̃_k − ∇f(y)‖² ≤ a² E‖∇f_{i_t}(y) − ∇f(y)‖² + (1 − a) E‖m_{t−1} − ∇f(y)‖²
≈ a²G² + (1 − a) E‖m_{t−1} − ∇f(x_{t−2})‖²
≈ aG².
The last step follows by unrolling the recursion on the variance of m. We also assumed that η is small enough that y ≈ x_{t−2}. In this way, momentum can reduce the variance of the update from G² to aG² by using past gradients computed on different clients. Formalizing the above sketch requires slightly modifying the momentum algorithm, similar to (Cutkosky and Orabona, 2019); this is carried out in Appendix G.

Improving the optimization term via local steps. The optimization (second) term in Theorem I for the convex case is (δK + L)/(μK), and for the non-convex case (with or without momentum) it is (δK + L)/(Kε).
In contrast, the optimization term of the server-only methods is L/μ and L/ε, respectively. Since in most cases δ ≪ L, the former can be significantly smaller than the latter. This rate also suggests that the best choice for the number of local updates is K = L/δ, i.e. we should perform more client updates when the clients have more similar Hessians. This generalizes the results of (Karimireddy et al., 2020) from quadratics to all functions.

This improvement is due to a careful analysis of the bias in the gradients computed during the local update steps. Note that for client parameters y_{k−1}, the gradient satisfies E[∇f_{i_t}(y_{k−1})] ≠ E[∇f(y_{k−1})], since y_{k−1} was itself computed using the same loss function f_{i_t}. In fact, only the first gradient, computed at x_{t−1}, is unbiased. Dropping the subscripts k and t, we can bound this bias as:
E[∇f_i(y) − ∇f(y)] = E[(∇f_i(y) − ∇f_i(x)) + (∇f(x) − ∇f(y))] + E_i[∇f_i(x)] − ∇f(x),
where ∇f_i(y) − ∇f_i(x) ≈ ∇²f_i(x)(y − x), ∇f(x) − ∇f(y) ≈ ∇²f(x)(x − y), and E_i[∇f_i(x)] − ∇f(x) = 0 since the first gradient is unbiased. Hence,
E[∇f_i(y) − ∇f(y)] ≈ E[(∇²f_i(x) − ∇²f(x))(y − x)] ⪅ δ E‖y − x‖.
Thus, the Hessian dissimilarity (A2) controls the bias, and hence the usefulness of local updates. This intuition is made formal using Lemma 3.

Mini-batches via the SVRG correction. In our previous discussion of momentum and local steps, we assumed that the clients compute full-batch gradients and that only one client is sampled per round. However, in practice, a large number S of clients are sampled, and the clients use mini-batch gradients. The SVRG correction reduces this within-client variance, since
Var(∇f_i(y_i; ζ) − ∇f_i(x; ζ) + (1/|S|) Σ_{j∈S} ∇f_j(x)) ⪅ L²‖y_i − x‖² + G²/S ≈ G²/S.
Here, we used the smoothness of f_i(·; ζ) and assumed that y_i ≈ x, since we do not move too far within a single round. Thus, the SVRG correction allows us to use mini-batch gradients in the local updates while still ensuring that the variance is of the order G²/S.
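The variance-reduction claim for the SVRG correction can be checked numerically under a toy model of our own (not from the paper): with per-sample gradients ∇f_i(y; z) = y − z, the corrected estimate is constant in z, so its within-client variance vanishes, while the plain mini-batch gradient keeps the full per-sample variance.

```python
import random

# Within-client variance of the plain vs SVRG-corrected gradient estimate,
# for per-sample gradients grad f_i(y; z) = y - z (a toy model).
random.seed(0)
data = [random.gauss(0.0, 1.0) for _ in range(1000)]

def variance(samples):
    mu = sum(samples) / len(samples)
    return sum((v - mu) ** 2 for v in samples) / len(samples)

x, y = 0.0, 0.01                         # y stays close to x within a round
full_grad_x = x - sum(data) / len(data)  # full-batch grad f_i(x)
plain = [y - z for z in data]                       # grad f_i(y; z)
corrected = [(y - z) - (x - z) + full_grad_x for z in data]
print(variance(plain) > 0.5, variance(corrected) < 1e-12)  # True True
```

Here the per-sample Hessian is the same for every z, so the correction removes the within-client variance exactly; in general, the residual variance is bounded by the L²‖y_i − x‖² term above.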

D TECHNICALITIES

We examine some additional definitions and introduce some technical lemmas.

D.1 ASSUMPTIONS AND DEFINITIONS

We make precise a few definitions and explain some of their implications. We first discuss the two assumptions on the dissimilarity between the gradients (A1) and the Hessians (A2). Loosely, these two quantities extend the concepts of variance and smoothness, which appear in centralized SGD analysis, to the federated learning setting. Just as variance and smoothness are completely orthogonal concepts, we can have settings where $G^2$ (gradient dissimilarity) is large while $\delta$ (Hessian dissimilarity) is small, or vice versa. Our assumption of bounded $G$ gradient dissimilarity can easily be extended to the $(G, B)$ gradient dissimilarity used by (Karimireddy et al., 2019):
$$\mathbb{E}_i\|\nabla f_i(x)\|^2 \le G^2 + B^2 \|\nabla f(x)\|^2\,. \tag{7}$$
All the proofs in the paper extend in a straightforward manner to this weaker notion. Since it does not present any novel technical challenge, we omit it in the rest of the proofs. Note, however, that this weaker notion can capture the fact that by increasing the model capacity we can reduce $G$. In the extreme case, by taking a sufficiently over-parameterized model, it is possible to make $G = 0$ in certain settings (Vaswani et al., 2018). However, this comes at the cost of increased resource requirements (i.e. higher memory and compute requirements per step) and can also cause other constants (e.g. $B$ and $L$) to increase.

The second crucial definition we use in this work is that of $\delta$-bounded Hessian dissimilarity (A2). This has been used previously in the analyses of distributed (Shamir et al., 2014; Arjevani and Shamir, 2015; Reddi et al., 2016) and federated learning (Karimireddy et al., 2020), but was restricted to quadratics. Here, we show how to extend both the notion and the analysis to general smooth functions. The main manner in which we use this assumption is in Lemma 3, to claim that for any $x$ and $y$ the following holds:
$$\mathbb{E}\|\nabla f_i(y;\zeta) - \nabla f_i(x;\zeta) + \nabla f(x) - \nabla f(y)\|^2 \le \delta^2 \|y - x\|^2\,. \tag{8}$$
Here the expectation is both over $\zeta$ and over the choice of client $i$. To understand what this condition means, it is illuminating to define $\Psi_i(z;\zeta) = f_i(z;\zeta) - f(z)$. Then, we can rewrite (A2) and (8) respectively as
$$\|\nabla^2 \Psi_i(z;\zeta)\| \le \delta \quad \text{and} \quad \mathbb{E}\|\nabla\Psi_i(y;\zeta) - \nabla\Psi_i(x;\zeta)\|^2 \le \delta^2 \|y - x\|^2\,.$$
Thus, (8) and (A2) are both notions of smoothness of $\Psi_i(\cdot;\zeta)$ (a formal definition of smoothness follows below). The latter closely matches the notion of squared-smoothness used by Arjevani et al. (2019) and is a promising relaxation of (A2). However, we run into technical issues since in our case the variable $y$ can itself be random and depend on the choice of the client $i$. Extending our results to this weaker notion of Hessian similarity and proving tight non-convex lower bounds is an exciting theoretical challenge. Finally, note that if the functions $f_i(x;\zeta)$ are assumed to be $L$-smooth as in (Shamir et al., 2014; Arjevani and Shamir, 2015; Karimireddy et al., 2020), then $\Psi_i(x;\zeta)$ is $2L$-smooth. Thus, we always have $\delta \le 2L$. But, as Shamir et al. (2014) show, it is possible to have $\delta \ll L$ if the data distribution amongst the clients is similar. Further, the lower bound of Arjevani and Shamir (2015) proves that Hessian similarity is the crucial quantity capturing the number of rounds of communication required for distributed/federated optimization.

We next define smoothness and strong convexity, which we use repeatedly in the paper.

(A3) $f$ is $L$-smooth and satisfies
$$\|\nabla f(x) - \nabla f(y)\| \le L \|x - y\|\,, \quad \text{for any } x, y\,. \tag{9}$$
Assumption (A3) implies the following quadratic upper bound on $f$:
$$f(y) \le f(x) + \langle \nabla f(x),\, y - x\rangle + \tfrac{L}{2}\|y - x\|^2\,. \tag{10}$$
Further, if $f$ is twice-differentiable, (A3) implies that $\|\nabla^2 f(x)\| \le L$ for any $x$.

(A4) $f$ is $\mu$-PL strongly convex (Karimi et al., 2016) for $\mu > 0$ if it satisfies
$$\|\nabla f(x)\|^2 \ge 2\mu\big(f(x) - f^\star\big)\,.$$
Note that PL-strong convexity is much weaker than the standard notion of strong-convexity (Karimi et al., 2016) .
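To illustrate that gradient and Hessian dissimilarity are indeed orthogonal, here is a small numerical sketch (an illustrative construction of ours, not from the paper): quadratic clients with identical Hessians but shifted optima have $G^2 > 0$ yet $\delta = 0$, while clients with rescaled Hessians have $\delta > 0$ even at a point where all gradients agree.

```python
import numpy as np

rng = np.random.default_rng(2)
d, N = 4, 20
x = rng.normal(size=d)

# Case 1: f_i(z) = ½‖z − b_i‖². All Hessians equal I, so δ = 0, but G² > 0.
b = rng.normal(size=(N, d))
grads = x[None, :] - b                     # ∇f_i(x) = x − b_i
G2 = np.mean(np.sum((grads - grads.mean(0)) ** 2, axis=1))
delta1 = 0.0                               # ‖∇²f_i − ∇²f‖ = ‖I − I‖ = 0

# Case 2: f_i(z) = ½ s_i ‖z‖². At z = 0 all gradients vanish, but δ > 0.
s = 1 + 0.5 * rng.uniform(-1, 1, size=N)   # per-client curvature scales
delta2 = np.max(np.abs(s - s.mean()))      # ‖s_i I − s̄ I‖ = |s_i − s̄|
print(G2, delta1, delta2)
```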

D.2 SOME TECHNICAL LEMMAS

Now we cover some technical lemmas which are useful for computations later on. First, we state a relaxed triangle inequality for the squared $\ell_2$ norm.

Lemma 1 (relaxed triangle inequality). Let $\{v_1, \dots, v_\tau\}$ be $\tau$ vectors in $\mathbb{R}^d$. Then the following are true:
1. $\|v_i + v_j\|^2 \le (1 + c)\|v_i\|^2 + \big(1 + \tfrac{1}{c}\big)\|v_j\|^2$ for any $c > 0$, and
2. $\big\|\sum_{i=1}^{\tau} v_i\big\|^2 \le \tau \sum_{i=1}^{\tau} \|v_i\|^2$.

Proof. The first statement, for any $c > 0$, follows from the identity
$$\|v_i + v_j\|^2 = (1 + c)\|v_i\|^2 + \big(1 + \tfrac{1}{c}\big)\|v_j\|^2 - \big\|\sqrt{c}\, v_i - \tfrac{1}{\sqrt{c}}\, v_j\big\|^2\,.$$
For the second statement, we use the convexity of $x \mapsto \|x\|^2$ and Jensen's inequality:
$$\Big\|\frac{1}{\tau}\sum_{i=1}^{\tau} v_i\Big\|^2 \le \frac{1}{\tau}\sum_{i=1}^{\tau} \|v_i\|^2\,.$$

Next we state an elementary lemma about expectations of norms of random vectors.

Lemma 2 (separating mean and variance). Let $\{\Xi_1, \dots, \Xi_\tau\}$ be $\tau$ random variables in $\mathbb{R}^d$, not necessarily independent. First suppose that their means are $\mathbb{E}[\Xi_i] = \xi_i$ and their variances are bounded as $\mathbb{E}\|\Xi_i - \xi_i\|^2 \le \sigma^2$. Then the following holds:
$$\mathbb{E}\Big\|\sum_{i=1}^{\tau} \Xi_i\Big\|^2 \le \Big\|\sum_{i=1}^{\tau} \xi_i\Big\|^2 + \tau^2 \sigma^2\,.$$
Now instead suppose that their conditional means are $\mathbb{E}[\Xi_i \mid \Xi_{i-1}, \dots, \Xi_1] = \xi_i$, i.e. the variables $\{\Xi_i - \xi_i\}$ form a martingale difference sequence, and that the variance is bounded by $\mathbb{E}\|\Xi_i - \xi_i\|^2 \le \sigma^2$ as before. Then we can show the tighter bound
$$\mathbb{E}\Big\|\sum_{i=1}^{\tau} \Xi_i\Big\|^2 \le 2\,\mathbb{E}\Big\|\sum_{i=1}^{\tau} \xi_i\Big\|^2 + 2\tau \sigma^2\,.$$

Proof. For any random vector $X$, $\mathbb{E}\|X\|^2 = \mathbb{E}\|X - \mathbb{E}[X]\|^2 + \|\mathbb{E}[X]\|^2$, implying
$$\mathbb{E}\Big\|\sum_{i=1}^{\tau} \Xi_i\Big\|^2 = \Big\|\sum_{i=1}^{\tau} \xi_i\Big\|^2 + \mathbb{E}\Big\|\sum_{i=1}^{\tau} (\Xi_i - \xi_i)\Big\|^2\,.$$
Expanding the second term using the relaxed triangle inequality (Lemma 1) proves the first claim:
$$\mathbb{E}\Big\|\sum_{i=1}^{\tau} (\Xi_i - \xi_i)\Big\|^2 \le \tau \sum_{i=1}^{\tau} \mathbb{E}\|\Xi_i - \xi_i\|^2 \le \tau^2 \sigma^2\,.$$
For the second statement, $\xi_i$ is not deterministic and depends on $\Xi_{i-1}, \dots, \Xi_1$. Hence we have to resort to the cruder relaxed triangle inequality to claim
$$\mathbb{E}\Big\|\sum_{i=1}^{\tau} \Xi_i\Big\|^2 \le 2\,\mathbb{E}\Big\|\sum_{i=1}^{\tau} \xi_i\Big\|^2 + 2\,\mathbb{E}\Big\|\sum_{i=1}^{\tau} (\Xi_i - \xi_i)\Big\|^2$$
and then use the tighter expansion of the second term:
$$\mathbb{E}\Big\|\sum_{i=1}^{\tau} (\Xi_i - \xi_i)\Big\|^2 = \sum_{i,j} \mathbb{E}\big[(\Xi_i - \xi_i)^\top (\Xi_j - \xi_j)\big] = \sum_i \mathbb{E}\|\Xi_i - \xi_i\|^2 \le \tau \sigma^2\,.$$
The cross terms in the above expression have zero mean since $\{\Xi_i - \xi_i\}$ form a martingale difference sequence.
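These inequalities are easy to sanity-check numerically. The following sketch (with arbitrary illustrative dimensions and an independent-noise model) verifies Lemma 1 deterministically and the first claim of Lemma 2 by Monte Carlo:

```python
import numpy as np

rng = np.random.default_rng(3)
tau, d, c = 7, 3, 0.5
v = rng.normal(size=(tau, d))

# Lemma 1.1: ‖v_i + v_j‖² ≤ (1 + c)‖v_i‖² + (1 + 1/c)‖v_j‖²
lhs1 = np.sum((v[0] + v[1]) ** 2)
rhs1 = (1 + c) * np.sum(v[0] ** 2) + (1 + 1 / c) * np.sum(v[1] ** 2)

# Lemma 1.2: ‖Σ_i v_i‖² ≤ τ Σ_i ‖v_i‖²
lhs2 = np.sum(v.sum(axis=0) ** 2)
rhs2 = tau * np.sum(v ** 2)

# Lemma 2 (first claim): E‖Σ Ξ_i‖² ≤ ‖Σ ξ_i‖² + τ² σ²  with  E‖Ξ_i − ξ_i‖² ≤ σ²
xi = rng.normal(size=(tau, d))
sigma2 = 0.3
trials = [np.sum((xi + rng.normal(scale=np.sqrt(sigma2 / d), size=(tau, d))).sum(axis=0) ** 2)
          for _ in range(4000)]
lhs3 = np.mean(trials)
rhs3 = np.sum(xi.sum(axis=0) ** 2) + tau ** 2 * sigma2
print(lhs1 <= rhs1, lhs2 <= rhs2, lhs3 <= rhs3)
```

For independent noise the true value of the left-hand side in the last check is $\|\sum_i \xi_i\|^2 + \tau\sigma^2$, so the $\tau^2\sigma^2$ bound of Lemma 2 is loose by a factor of $\tau$, exactly as the martingale refinement in the lemma suggests.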

E PROPERTIES OF FUNCTIONS WITH δ BOUNDED HESSIAN DISSIMILARITY

We now study two lemmas which hold for any functions satisfying (A2). The first is closely related to the notion of smoothness (A3).

Lemma 3 (similarity). The following holds for any two functions $f_i(\cdot;\zeta)$ and $f(\cdot)$ satisfying (A2), and any $x, y$:
$$\|\nabla f_i(y;\zeta) - \nabla f_i(x;\zeta) + \nabla f(x) - \nabla f(y)\|^2 \le \delta^2 \|y - x\|^2\,.$$

Proof. Consider the function $\Psi(z) := f_i(z;\zeta) - f(z)$. By assumption (A2), we know that $\|\nabla^2 \Psi(z)\| \le \delta$ for all $z$, i.e. $\Psi$ is $\delta$-smooth. By standard arguments based on taking limits (Nesterov, 2018), this implies that
$$\|\nabla \Psi(y) - \nabla \Psi(x)\| \le \delta \|y - x\|\,.$$
Plugging the definition of $\Psi$ back into this inequality proves the lemma.

Next, we see how weakly-convex functions satisfy a weaker notion of "averaging does not hurt".

Lemma 4 (averaging). Suppose $f$ is $\delta$-weakly convex. Then, for any $\gamma \ge \delta$, any sequence of parameters $\{y_i\}_{i \in \mathcal{S}}$, and any $x$:
$$\frac{1}{|\mathcal{S}|}\sum_{i \in \mathcal{S}} \Big(f(y_i) + \frac{\gamma}{2}\|x - y_i\|^2\Big) \ge f(\bar{y}) + \frac{\gamma}{2}\|x - \bar{y}\|^2\,, \quad \text{where } \bar{y} := \frac{1}{|\mathcal{S}|}\sum_{i \in \mathcal{S}} y_i\,.$$

Proof. Since $f$ is $\delta$-weakly convex, $\Phi(z) := f(z) + \frac{\gamma}{2}\|z - x\|^2$ is convex. This proves the claim, since by Jensen's inequality $\frac{1}{|\mathcal{S}|}\sum_{i \in \mathcal{S}} \Phi(y_i) \ge \Phi(\bar{y})$.
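For quadratics, Lemma 3 can be verified directly. Here is a small sketch (an illustrative setup of ours) with client Hessians $A_i = \bar A + E_i$ and $\delta = \max_i \|E_i\|$, for which the left-hand side of the lemma equals $\|E_i (y - x)\|^2$ exactly:

```python
import numpy as np

rng = np.random.default_rng(4)
d, N = 4, 30

# Quadratic clients f_i(z) = ½ zᵀ A_i z with A_i = Ā + E_i and ‖E_i‖ ≤ δ.
E = rng.normal(scale=0.05, size=(N, d, d))
E = (E + E.transpose(0, 2, 1)) / 2          # symmetrise the perturbations
E -= E.mean(axis=0)                         # so that the A_i average to Ā
delta = max(np.max(np.abs(np.linalg.eigvalsh(Ei))) for Ei in E)  # spectral norms

x, y = rng.normal(size=d), rng.normal(size=d)
# ∇f_i(y) − ∇f_i(x) + ∇f(x) − ∇f(y) = (A_i − Ā)(y − x) = E_i (y − x)
lhs = np.mean([np.sum((Ei @ (y - x)) ** 2) for Ei in E])
rhs = delta ** 2 * np.sum((y - x) ** 2)
print(lhs <= rhs)                           # True: ‖E_i v‖ ≤ ‖E_i‖ ‖v‖ ≤ δ ‖v‖
```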

F ANALYSIS OF MIMESGD (WITHOUT MOMENTUM)

Let us rewrite the MimeSGD update using notation convenient for analysis. In each round $t$, we sample a set of clients $\mathcal{S}^t$ with $|\mathcal{S}^t| = S$. The server communicates the server parameters $x^{t-1}$ as well as the average gradient across the sampled clients,
$$c^{t-1} = \frac{1}{S}\sum_{i \in \mathcal{S}^t} \nabla f_i(x^{t-1})\,.$$
Note that computing $c^{t-1}$ itself requires two rounds of communication. From a theoretical viewpoint, however, this only changes the number of communication rounds by a constant factor, so we ignore this issue; practically, we recommend MimeLite if the additional communication is a concern. Each client $i \in \mathcal{S}^t$ then makes a copy $y^t_{i,0} = x^{t-1}$ and performs $K$ local client updates. In local update $k \in [K]$, the client samples a mini-batch $\zeta^t_{i,k}$ and computes
$$y^t_{i,k} = y^t_{i,k-1} - \eta\big(\nabla f_i(y^t_{i,k-1};\zeta^t_{i,k}) - \nabla f_i(x^{t-1};\zeta^t_{i,k}) + c^{t-1}\big)\,. \tag{12}$$
After $K$ such local updates, the server aggregates the new client parameters as $x^t = \frac{1}{S}\sum_{i \in \mathcal{S}^t} y^t_{i,K}$.

Variance of update. Consider the local update at step $k$ on client $i$; dropping the superscript $t$,
$$y_{i,k} = y_{i,k-1} - \eta d_{i,k}\,, \quad \text{where } d_{i,k} := \nabla f_i(y_{i,k-1};\zeta_{i,k}) - \nabla f_i(x;\zeta_{i,k}) + c\,.$$

Lemma 5. Given that assumptions (A1) and (A2) are satisfied, each client update satisfies
$$\mathbb{E}\|d_{i,k}\|^2 \le \frac{3G^2}{S} + 3\delta^2\|y_{i,k-1} - x\|^2 + 3\|\nabla f(y_{i,k-1})\|^2\,.$$

Proof. Starting from the definition of $d_{i,k}$ and the relaxed triangle inequality,
$$\begin{aligned}
\|d_{i,k}\|^2 &= \|\nabla f_i(y_{i,k-1};\zeta_{i,k}) - \nabla f_i(x;\zeta_{i,k}) + c\|^2 \\
&= \|\nabla f_i(y_{i,k-1};\zeta_{i,k}) - \nabla f_i(x;\zeta_{i,k}) + \nabla f(x) - \nabla f(y_{i,k-1}) + (c - \nabla f(x)) + \nabla f(y_{i,k-1})\|^2 \\
&\le 3\|\nabla f_i(y_{i,k-1};\zeta_{i,k}) - \nabla f_i(x;\zeta_{i,k}) + \nabla f(x) - \nabla f(y_{i,k-1})\|^2 + 3\|c - \nabla f(x)\|^2 + 3\|\nabla f(y_{i,k-1})\|^2 \\
&\le 3\delta^2\|y_{i,k-1} - x\|^2 + 3\|c - \nabla f(x)\|^2 + 3\|\nabla f(y_{i,k-1})\|^2\,.
\end{aligned}$$
We used Lemma 3 to bound the first term. Taking expectations on both sides and bounding the second term via (A1) yields the lemma.

Distance moved in each round. We show that the distance moved by a client during the $K$ updates of a round can be controlled.
To further reduce the notational burden, we drop the subscripts $i, k$ and write $y$ for $y_{i,k-1}$ and $y^+$ for $y_{i,k}$.

Lemma 6. For updates following (12) with $\eta \le \frac{1}{4K\delta}$ and satisfying (A1) and (A2), we have at any step $k$:
$$\mathbb{E}\|y^+ - x\|^2 \le \Big(1 + \frac{2}{K}\Big)\mathbb{E}\|y - x\|^2 + \frac{6K\eta^2 G^2}{S} + 6K\eta^2\,\mathbb{E}\|\nabla f(y)\|^2\,.$$

Proof. Starting from the update (12) and the relaxed triangle inequality (Lemma 1) with $c = K \ge 1$,
$$\begin{aligned}
\mathbb{E}\|y^+ - x\|^2 &= \mathbb{E}\|y - \eta d - x\|^2 \le \Big(1 + \frac{1}{K}\Big)\mathbb{E}\|y - x\|^2 + (1 + K)\eta^2\,\mathbb{E}\|d\|^2 \\
&\le \Big(1 + \frac{1}{K}\Big)\mathbb{E}\|y - x\|^2 + \frac{3(1+K)\eta^2 G^2}{S} + 3(1+K)\eta^2\delta^2\,\mathbb{E}\|y - x\|^2 + 3(1+K)\eta^2\,\mathbb{E}\|\nabla f(y)\|^2 \\
&\le \Big(1 + \frac{1}{K} + 6K\eta^2\delta^2\Big)\mathbb{E}\|y - x\|^2 + \frac{6K\eta^2 G^2}{S} + 6K\eta^2\,\mathbb{E}\|\nabla f(y)\|^2\,.
\end{aligned}$$
The second-to-last step used the variance bound of Lemma 5. The proof now follows from the restriction on the step size, since $16K^2\eta^2\delta^2 \le 1$ and hence $6K\eta^2\delta^2 \le \frac{1}{K}$.

Progress in one client update. We now have the tools required to track the progress made in one update.

Lemma 7. For any constant $\mu \ge 0$ and each step of MimeSGD with step size $\eta \le \min\big(\frac{1}{18L}, \frac{1}{756\delta K}, \frac{1}{42\mu K}\big)$, given that (A1)–(A3) hold, we have
$$\frac{\eta}{4}\,\mathbb{E}\|\nabla f(y^t_{i,k-1})\|^2 \le \mathcal{A}^t_{i,k-1} - \mathcal{A}^t_{i,k} + \frac{(255KL\eta^2)G^2}{2S}\,,$$
where we define
$$\mathcal{A}^t_{i,k} := \mathbb{E}[f(y^t_{i,k})] + \delta\Big(1 + \frac{3}{K}\Big)^{K-k}\mathbb{E}\|y^t_{i,k} - x^{t-1}\|^2\,, \quad \mathcal{A}^t_{i,k-1} := \mathbb{E}[f(y^t_{i,k-1})] + \delta(1 - \mu\eta)\Big(1 + \frac{3}{K}\Big)^{K-k+1}\mathbb{E}\|y^t_{i,k-1} - x^{t-1}\|^2\,.$$

Proof. The assumption that $f$ is $L$-smooth implies the quadratic upper bound (10). Using it here, we have
$$\mathbb{E}[f(y^+)] - \mathbb{E}[f(y)] \le \underbrace{-\eta\,\mathbb{E}\big[\langle\nabla f(y),\, \nabla f_i(y;\zeta) - \nabla f_i(x;\zeta) + c\rangle\big]}_{\mathcal{T}_1} + \underbrace{\frac{L\eta^2}{2}\mathbb{E}\|d\|^2}_{\mathcal{T}_2}\,.$$
Let us examine the terms $\mathcal{T}_1$ and $\mathcal{T}_2$ separately. By the variance bound of Lemma 5, we have
$$\mathcal{T}_2 \le \frac{3L\eta^2 G^2}{2S} + \frac{3L\eta^2\delta^2}{2}\mathbb{E}\|y - x\|^2 + \frac{3L\eta^2}{2}\mathbb{E}\|\nabla f(y)\|^2\,.$$
To simplify $\mathcal{T}_1$, the biggest obstacle is that $\mathbb{E}[\nabla f_i(y;\zeta)] \ne \nabla f(y)$, since $y$ itself depends on the sampling of the client $i$. Only the server gradient is unbiased: $\mathbb{E}[c] = \nabla f(x)$.
Instead, we use the similarity of the functions as in Lemma 3:
$$\begin{aligned}
\mathcal{T}_1 &= -\eta\,\mathbb{E}\big[\langle\nabla f(y),\, \nabla f_i(y;\zeta) - \nabla f_i(x;\zeta) + \nabla f(x)\rangle\big] \\
&\le -\frac{\eta}{2}\mathbb{E}\|\nabla f(y)\|^2 + \frac{\eta}{2}\mathbb{E}\|\nabla f_i(y;\zeta) - \nabla f_i(x;\zeta) + \nabla f(x) - \nabla f(y)\|^2 \\
&\le -\frac{\eta}{2}\mathbb{E}\|\nabla f(y)\|^2 + \frac{\eta\delta^2}{2}\mathbb{E}\|y - x\|^2\,.
\end{aligned}$$
The first inequality used that for any $a, b$ we have $-2\langle a, b\rangle = \|a - b\|^2 - \|a\|^2 - \|b\|^2 \le \|a - b\|^2 - \|a\|^2$; the second used the similarity Lemma 3. Combining $\mathcal{T}_1$ and $\mathcal{T}_2$, we have
$$\mathbb{E}[f(y^+)] - \mathbb{E}[f(y)] \le \frac{(3L\eta^2 - \eta)}{2}\mathbb{E}\|\nabla f(y)\|^2 + \frac{(\eta\delta^2 + 3L\eta^2\delta^2)}{2}\mathbb{E}\|y - x\|^2 + \frac{3L\eta^2 G^2}{2S}\,.$$
To bound the distance between $y$ and $x$, we use Lemma 6 multiplied on both sides by $\delta(1 + \frac{3}{K})^{K-k}$. Note that
$$\delta \le \delta\Big(1 + \frac{3}{K}\Big)^{K-k} \le 21\delta\,.$$
This gives us, for any constant $\mu \ge 0$,
$$\begin{aligned}
\delta\Big(1+\frac{3}{K}\Big)^{K-k}\mathbb{E}\|y^+ - x\|^2 &\le \delta\Big(1+\frac{3}{K}\Big)^{K-k}\Big(1+\frac{2}{K}\Big)\mathbb{E}\|y-x\|^2 + 6K\delta\Big(1+\frac{3}{K}\Big)^{K-k}\frac{\eta^2 G^2}{S} + 6K\delta\Big(1+\frac{3}{K}\Big)^{K-k}\eta^2\,\mathbb{E}\|\nabla f(y)\|^2 \\
&\le \delta(1-\mu\eta)\Big(1+\frac{3}{K}\Big)^{K-(k-1)}\mathbb{E}\|y-x\|^2 + \Big(1+\frac{3}{K}\Big)^{K-k}\Big(\mu\eta\delta - \frac{\delta}{K}\Big)\mathbb{E}\|y-x\|^2 + \frac{126K\delta\eta^2 G^2}{S} + 126K\delta\eta^2\,\mathbb{E}\|\nabla f(y)\|^2 \\
&\le \delta(1-\mu\eta)\Big(1+\frac{3}{K}\Big)^{K-(k-1)}\mathbb{E}\|y-x\|^2 + \Big(21\mu\eta\delta - \frac{\delta}{K}\Big)\mathbb{E}\|y-x\|^2 + \frac{126K\delta\eta^2 G^2}{S} + 126K\delta\eta^2\,\mathbb{E}\|\nabla f(y)\|^2\,.
\end{aligned}$$
The second inequality used that $(1+\frac{3}{K})(1-\mu\eta) = (1+\frac{2}{K}) + \big(\frac{1}{K} - (1+\frac{3}{K})\mu\eta\big)$, which is at least $1+\frac{2}{K}$ by our restriction $\eta \le \frac{1}{42\mu K}$: this implies $(1+\frac{3}{K})\mu\eta \le 4\mu\eta < \frac{1}{10K}$, so $\frac{1}{K} - (1+\frac{3}{K})\mu\eta > 0$. Adding the two bounds, we get the recursion
$$\underbrace{\mathbb{E}[f(y^+)] + \delta\Big(1+\frac{3}{K}\Big)^{K-k}\mathbb{E}\|y^+ - x\|^2}_{=:\,\mathcal{A}_{i,k}} \le \underbrace{\mathbb{E}[f(y)] + \delta(1-\mu\eta)\Big(1+\frac{3}{K}\Big)^{K-(k-1)}\mathbb{E}\|y - x\|^2}_{=:\,\mathcal{A}_{i,k-1}} + \frac{(252K\delta\eta^2 + 3L\eta^2 - \eta)}{2}\mathbb{E}\|\nabla f(y)\|^2 + \Big(\frac{\eta\delta^2 + 3L\eta^2\delta^2 + 42\mu\eta\delta}{2} - \frac{\delta}{K}\Big)\mathbb{E}\|y - x\|^2 + \frac{(3L\eta^2 + 252K\delta\eta^2)G^2}{2S}\,.$$
Now, our constraint $\eta \le \min\big(\frac{1}{18L}, \frac{1}{756\delta K}\big)$ implies $252K\delta\eta^2 + 3L\eta^2 \le \frac{\eta}{2}$ and $K(\eta\delta^2 + 3L\eta^2\delta^2 + 42\mu\eta\delta) \le 2\delta$. Plugging this into the bound above and recalling that $\delta \le L$ finishes the proof.

Convergence for PL strongly-convex functions. We unroll the one-step progress Lemma 7 to compute a linear rate.

Theorem II. Suppose that (A1)–(A4) are satisfied for $\mu > 0$.
Then the updates of MimeSGD with step size $\eta = \min\big(\eta_{\max}, \tilde{O}\big(\frac{1}{\mu TK}\big)\big)$, for $\eta_{\max} = \min\big(\frac{1}{18L}, \frac{1}{756\delta K}, \frac{1}{42\mu K}\big)$, satisfy
$$\mathbb{E}\|\nabla f(x^{\mathrm{out}})\|^2 \le \tilde{O}\Big(\frac{LG^2}{\mu TS} + \frac{F}{\eta_{\max}}\exp\Big(-\frac{\mu TK}{18L + 756\delta K + 42\mu K}\Big)\Big)\,,$$
where we define $F := f(x^0) - f^\star$, $\bar{y}^t_k$ is chosen to be $y^t_{i,k}$ for $i \in \mathcal{S}^t$ uniformly at random, and the output $x^{\mathrm{out}}$ is $\bar{y}^t_k$ with probability proportional to $(1 - \frac{\eta\mu}{4})^{K(T-t) + (K-k)}$.

Proof. By PL strong convexity (A4), we have
$$\frac{\eta}{4}\|\nabla f(y)\|^2 \ge \frac{\eta}{8}\|\nabla f(y)\|^2 + \frac{\eta\mu}{4}\big(f(y) - f^\star\big)\,.$$
Using this, we can tighten the one-step progress Lemma 7 to
$$\frac{\eta}{8}\mathbb{E}\|\nabla f(y^t_{i,k-1})\|^2 \le \Big(1 - \frac{\mu\eta}{4}\Big)\underbrace{\Big(\mathbb{E}[f(y^t_{i,k-1}) - f^\star] + \delta\Big(1+\frac{3}{K}\Big)^{K-k+1}\mathbb{E}\|y^t_{i,k-1} - x^{t-1}\|^2\Big)}_{=:\,\Phi^t_{i,k-1}} - \underbrace{\Big(\mathbb{E}[f(y^t_{i,k}) - f^\star] + \delta\Big(1+\frac{3}{K}\Big)^{K-k}\mathbb{E}\|y^t_{i,k} - x^{t-1}\|^2\Big)}_{=:\,\Phi^t_{i,k}} + \frac{(255KL\eta^2)G^2}{2S}\,.$$
Now take a weighted sum over the steps $k$ with weights $(1 - \frac{\eta\mu}{4})^{K-k}$, which telescopes:
$$\frac{\eta}{8}\sum_{k\in[K]}\Big(1-\frac{\eta\mu}{4}\Big)^{K-k}\mathbb{E}\|\nabla f(y^t_{i,k-1})\|^2 \le \Big(1-\frac{\mu\eta}{4}\Big)^K\Phi^t_{i,0} - \Phi^t_{i,K} + \sum_{k\in[K]}\Big(1-\frac{\eta\mu}{4}\Big)^{K-k}\frac{(255KL\eta^2)G^2}{2S}\,.$$
By the initialization $y^t_{i,0} = x^{t-1}$ we have $\Phi^t_{i,0} = \mathbb{E}[f(x^{t-1}) - f^\star]$, and by the averaging Lemma 4, $\frac{1}{S}\sum_{i\in\mathcal{S}^t}\Phi^t_{i,K} \ge \mathbb{E}[f(x^t) - f^\star]$. Hence, averaging over the clients gives the one-round progress bound
$$\frac{\eta}{8S}\sum_{k\in[K]}\sum_{i\in\mathcal{S}^t}\Big(1-\frac{\eta\mu}{4}\Big)^{K-k}\mathbb{E}\|\nabla f(y^t_{i,k-1})\|^2 \le \Big(1-\frac{\mu\eta}{4}\Big)^K\mathbb{E}[f(x^{t-1}) - f^\star] - \mathbb{E}[f(x^t) - f^\star] + \sum_{k\in[K]}\Big(1-\frac{\eta\mu}{4}\Big)^{K-k}\frac{(255KL\eta^2)G^2}{2S}\,.$$
Taking a further weighted average over the rounds $t \in [T]$, with the weight of round $t$ proportional to $(1-\frac{\eta\mu}{4})^{K(T-t)}$, telescopes once more and gives
$$\frac{\eta}{8S}\sum_{t\in[T]}\sum_{k\in[K]}\sum_{i\in\mathcal{S}^t}\Big(1-\frac{\eta\mu}{4}\Big)^{K(T-t)+(K-k)}\mathbb{E}\|\nabla f(y^t_{i,k-1})\|^2 \le \mathbb{E}[f(x^0) - f^\star] + \sum_{t\in[T]}\sum_{k\in[K]}\Big(1-\frac{\eta\mu}{4}\Big)^{K(T-t)+(K-k)}\frac{(255KL\eta^2)G^2}{2S}\,.$$
Finally, choosing the step size as in Lemma 23 of (Karimireddy et al., 2020) yields the desired rate.

Convergence for general functions. We unroll the one-step progress Lemma 7 to compute a sublinear rate.

Theorem III. Suppose that (A1)–(A3) are satisfied.
Then the updates of MimeSGD with step size $\eta = \min\big(\eta_{\max}, \frac{\sqrt{FS}}{\sqrt{255K^2LG^2T}}\big)$, for $\eta_{\max} = \min\big(\frac{1}{18L}, \frac{1}{756\delta K}\big)$, satisfy
$$\frac{1}{KTS}\sum_{k\in[K]}\sum_{t\in[T]}\sum_{i\in\mathcal{S}^t}\mathbb{E}\|\nabla f(y^t_{i,k-1})\|^2 \le O\Big(\frac{G\sqrt{LF}}{\sqrt{TS}} + \frac{(L + \delta K)F}{TK}\Big)\,,$$
where we define $F := f(x^0) - f^\star$.

Proof. Summing the bound of Lemma 7 (with $\mu = 0$) over the local steps of one round gives
$$\frac{\eta}{4}\sum_{k=1}^{K}\mathbb{E}\|\nabla f(y^t_{i,k-1})\|^2 \le \mathcal{A}^t_{i,0} - \mathcal{A}^t_{i,K} + \frac{255K^2L\eta^2G^2}{2S}\,.$$
By the initialization $y^t_{i,0} = x^{t-1}$, we have $\mathcal{A}^t_{i,0} = \mathcal{A}^t_{j,0} = \mathbb{E}[f(x^{t-1})]$ for all $i, j \in \mathcal{S}^t$. Furthermore, by Lemma 4,
$$\frac{1}{|\mathcal{S}^t|}\sum_{i\in\mathcal{S}^t}\mathcal{A}^t_{i,K} \ge \mathbb{E}[f(x^t)] + \delta\,\mathbb{E}\|x^{t-1} - x^t\|^2 \ge \mathbb{E}[f(x^t)] = \mathcal{A}^{t+1}_{i,0}\,.$$
This means we can keep unrolling over all rounds, obtaining
$$\frac{\eta}{4S}\sum_{t=1}^{T}\sum_{k=1}^{K}\sum_{i\in\mathcal{S}^t}\mathbb{E}\|\nabla f(y^t_{i,k-1})\|^2 \le \mathcal{A}^1_{i,0} - \mathcal{A}^T_{i,K} + \frac{255TK^2L\eta^2G^2}{2S}\,.$$
Noting that $\mathcal{A}^1_{i,0} - \mathcal{A}^T_{i,K} \le (f(x^0) - f^\star) - (\mathbb{E}[f(x^T)] - f^\star) \le F$ and plugging in the choice of step size proves the theorem.
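To make the update (12) concrete, here is a minimal simulation of one MimeSGD run on toy quadratic clients $f_i(y) = \tfrac{1}{2}\|y - b_i\|^2$ with full-batch gradients (an illustrative setup of ours, not the paper's experiments):

```python
import numpy as np

rng = np.random.default_rng(5)
d, N, S, K, eta = 3, 50, 5, 10, 0.05
b = rng.normal(size=(N, d))             # client optima for f_i(y) = ½‖y − b_i‖²

grad = lambda i, y: y - b[i]            # full-batch client gradient (ζ omitted)
x = rng.normal(size=d)                  # server parameters x^0

for _ in range(30):                     # communication rounds
    St = rng.choice(N, S, replace=False)
    c = np.mean([grad(i, x) for i in St], axis=0)      # server statistic c^{t−1}
    ys = []
    for i in St:
        y = x.copy()
        for _ in range(K):              # K local steps with SVRG correction (12)
            y = y - eta * (grad(i, y) - grad(i, x) + c)
        ys.append(y)
    x = np.mean(ys, axis=0)             # server aggregation of client parameters

print(np.linalg.norm(x - b.mean(0)))    # distance to the global optimum
```

For these quadratics the correction $\nabla f_i(y) - \nabla f_i(x)$ exactly removes the client-specific pull toward $b_i$, so every client follows the same drift-free trajectory; the residual error is only the client-sampling noise in $c^{t-1}$.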

G ANALYSIS OF MIMEMVR (WITH MOMENTUM BASED VARIANCE REDUCTION)

In this section we show how momentum-based variance reduction (Cutkosky and Orabona, 2019) can be used to reduce the variance of the updates and improve convergence. It should be noted that MVR does not exactly fit the MIME framework (BASEALG), since it requires computing gradients at two points on the same batch. However, it is straightforward to extend the idea of MIME to MVR, as we now do. We use MVR as a theoretical justification for why the usual momentum works well in practice. An interesting future direction would be to adapt the algorithm and analysis of (Cutkosky and Mehta, 2020), which does fit the MIME framework.

MimeMVR algorithm. We now formally describe the MimeMVR algorithm. In each round $t$, we sample a set of clients $\mathcal{S}^t$ with $|\mathcal{S}^t| = S$. The server communicates the server parameters $x^{t-1}$, the momentum $m^{t-1}$, and the average gradient across the sampled clients,
$$c^{t-1} = \frac{1}{S}\sum_{i\in\mathcal{S}^t}\nabla f_i(x^{t-2})\,.$$
Note that both $c^{t-1}$ and $m^{t-1}$ use gradients and parameters from previous rounds (unlike in the previous section). Each client $i \in \mathcal{S}^t$ then makes a copy $y^t_{i,0} = x^{t-1}$ and performs $K$ local client updates. In local update $k \in [K]$, the client samples a mini-batch $\zeta^t_{i,k}$ and computes $y^t_{i,k} = y^t_{i,k-1} - \eta d^t_{i,k}$, where
$$d^t_{i,k} = a\big(\nabla f_i(y^t_{i,k-1};\zeta^t_{i,k}) - \nabla f_i(x^{t-2};\zeta^t_{i,k}) + c^{t-1}\big) + (1-a)\,m^{t-1} + (1-a)\big(\nabla f_i(y^t_{i,k-1};\zeta^t_{i,k}) - \nabla f_i(x^{t-2};\zeta^t_{i,k})\big)\,. \tag{15}$$
After $K$ such local updates, the server aggregates the new client parameters as $x^t = \frac{1}{S}\sum_{j\in\mathcal{S}^t}y^t_{j,K}$. The momentum term is updated at the end of the round, for $a \ge 0$, as
$$m^t = \underbrace{a\Big(\frac{1}{S}\sum_{j\in\mathcal{S}^t}\nabla f_j(x^{t-1})\Big) + (1-a)\,m^{t-1}}_{\text{SGDm}} + \underbrace{(1-a)\Big(\frac{1}{S}\sum_{j\in\mathcal{S}^t}\big(\nabla f_j(x^{t-1}) - \nabla f_j(x^{t-2})\big)\Big)}_{\text{correction}}\,. \tag{17}$$
As we can see, the momentum update of MVR can be broken down into the usual SGDm update and a correction. Intuitively, this correction term is very small since $f_i$ is smooth and $x^{t-1} \approx x^{t-2}$.
Another way of looking at the update (17) is to note that if all functions are identical, i.e. $f_j = f_k$ for all $j, k$, then (17) reduces to using the exact gradient. Thus MimeMVR maintains an exponential moving average of only the variance terms, reducing its bias. We refer to (Cutkosky and Orabona, 2019) for a more detailed explanation of MVR.

Momentum variance bound. We compute the variance of the server momentum $m^{t-1}$. Define the error $V^t := m^t - \nabla f(x^{t-1})$. Its expected norm can be bounded as follows.

Lemma 8. For the momentum update (17), given (A1) and (A2), the following holds for any $a \in [0, 1]$ and $V^t := m^t - \nabla f(x^{t-1})$:
$$\mathbb{E}\|V^t\|^2 \le (1-a)\,\mathbb{E}\|V^{t-1}\|^2 + 2\delta^2\,\mathbb{E}\|x^{t-1} - x^{t-2}\|^2 + \frac{2a^2G^2}{S}\,.$$

Proof. Starting from the momentum update (17),
$$V^t = (1-a)V^{t-1} + (1-a)\Big(\frac{1}{S}\sum_{j\in\mathcal{S}^t}\big(\nabla f_j(x^{t-1}) - \nabla f_j(x^{t-2})\big) - \nabla f(x^{t-1}) + \nabla f(x^{t-2})\Big) + a\Big(\frac{1}{S}\sum_{j\in\mathcal{S}^t}\big(\nabla f_j(x^{t-1}) - \nabla f(x^{t-1})\big)\Big)\,.$$
The term $V^{t-1}$ does not contain any information from round $t$ and hence is statistically independent of the remaining terms, which have mean zero. We can therefore separate the zero-mean noise terms from $V^{t-1}$ following Lemma 2, and then apply the relaxed triangle inequality (Lemma 1) to claim
$$\begin{aligned}
\mathbb{E}\|V^t\|^2 &\le (1-a)^2\,\mathbb{E}\|V^{t-1}\|^2 + 2(1-a)^2\,\mathbb{E}\Big\|\frac{1}{S}\sum_{j\in\mathcal{S}^t}\big(\nabla f_j(x^{t-1}) - \nabla f_j(x^{t-2})\big) - \nabla f(x^{t-1}) + \nabla f(x^{t-2})\Big\|^2 + 2a^2\,\mathbb{E}\Big\|\frac{1}{S}\sum_{j\in\mathcal{S}^t}\big(\nabla f_j(x^{t-1}) - \nabla f(x^{t-1})\big)\Big\|^2 \\
&\le (1-a)^2\,\mathbb{E}\|V^{t-1}\|^2 + 2(1-a)^2\delta^2\,\mathbb{E}\|x^{t-1} - x^{t-2}\|^2 + \frac{2a^2G^2}{S}\,.
\end{aligned}$$
The last inequality used the Hessian similarity Lemma 3 to bound the second term and the heterogeneity bound (A1) to bound the third. Finally, note that $(1-a)^2 \le (1-a) \le 1$ for $a \in [0, 1]$.

Update variance bound. We next examine the variance of the local update direction $d^t_{i,k}$.

Lemma 9. For the client update (15), given (A1) and (A2), the following holds for any $a \in [0, 1]$:
$$\mathbb{E}\|d^t_{i,k} - \nabla f(y^t_{i,k-1})\|^2 \le 3\,\mathbb{E}\|V^{t-1}\|^2 + 3\delta^2\,\mathbb{E}\|y^t_{i,k-1} - x^{t-2}\|^2 + \frac{3a^2G^2}{S}\,.$$

Proof.
Starting from the client update (15), we can rewrite
$$d^t_{i,k} - \nabla f(y^t_{i,k-1}) = (1-a)V^{t-1} + \Big(\nabla f_i(y^t_{i,k-1};\zeta^t_{i,k}) - \nabla f_i(x^{t-2};\zeta^t_{i,k}) - \nabla f(y^t_{i,k-1}) + \nabla f(x^{t-2})\Big) + a\Big(\frac{1}{S}\sum_{j\in\mathcal{S}^t}\nabla f_j(x^{t-2}) - \nabla f(x^{t-2})\Big)\,.$$
We can use the relaxed triangle inequality (Lemma 1) to claim
$$\begin{aligned}
\mathbb{E}\|d^t_{i,k} - \nabla f(y^t_{i,k-1})\|^2 &\le 3(1-a)^2\,\mathbb{E}\|V^{t-1}\|^2 + 3\,\mathbb{E}\big\|\nabla f_i(y^t_{i,k-1};\zeta^t_{i,k}) - \nabla f_i(x^{t-2};\zeta^t_{i,k}) - \nabla f(y^t_{i,k-1}) + \nabla f(x^{t-2})\big\|^2 + 3a^2\,\mathbb{E}\Big\|\frac{1}{S}\sum_{j\in\mathcal{S}^t}\nabla f_j(x^{t-2}) - \nabla f(x^{t-2})\Big\|^2 \\
&\le 3\,\mathbb{E}\|V^{t-1}\|^2 + 3\delta^2\,\mathbb{E}\|y^t_{i,k-1} - x^{t-2}\|^2 + \frac{3a^2G^2}{S}\,.
\end{aligned}$$
The last inequality used the Hessian similarity Lemma 3 to bound the second term, the heterogeneity bound (A1) to bound the third, and $(1-a)^2 \le 1$ for $a \in [0, 1]$.

Distance moved in each step. We show that the distance moved by a client in each local step can be controlled.

Lemma 10. For the MimeMVR updates (15) with $\eta \le \frac{1}{6K\delta}$, given (A1) and (A2), the following holds:
$$\Delta^t_{i,k} \le \Big(1 + \frac{1}{K}\Big)\Delta^t_{i,k-1} + \frac{18\eta^2Ka^2G^2}{S} + 18\eta^2K\,\mathbb{E}\|V^{t-1}\|^2 + 6\eta^2K\,\mathbb{E}\|\nabla f(y^t_{i,k-1})\|^2\,,$$
where we define $\Delta^t_{i,k} := \max\big(\mathbb{E}\|y^t_{i,k} - x^{t-2}\|^2,\ \mathbb{E}\|y^t_{i,k} - x^{t-1}\|^2,\ \mathbb{E}\|x^{t-1} - x^{t-2}\|^2\big)$.

Proof. Starting from the MimeMVR update (15) and the relaxed triangle inequality with $c = 2K$,
$$\begin{aligned}
\mathbb{E}\|y^t_{i,k} - x^{t-2}\|^2 &= \mathbb{E}\|y^t_{i,k-1} - \eta d^t_{i,k} - x^{t-2}\|^2 \le \Big(1 + \frac{1}{2K}\Big)\mathbb{E}\|y^t_{i,k-1} - x^{t-2}\|^2 + (2K+1)\eta^2\,\mathbb{E}\|d^t_{i,k}\|^2 \\
&\le \Big(1 + \frac{1}{2K}\Big)\mathbb{E}\|y^t_{i,k-1} - x^{t-2}\|^2 + 6K\eta^2\,\mathbb{E}\|d^t_{i,k} - \nabla f(y^t_{i,k-1})\|^2 + 6K\eta^2\,\mathbb{E}\|\nabla f(y^t_{i,k-1})\|^2 \\
&\le \Big(1 + \frac{1}{2K} + 18K\eta^2\delta^2\Big)\mathbb{E}\|y^t_{i,k-1} - x^{t-2}\|^2 + 18K\eta^2\,\mathbb{E}\|V^{t-1}\|^2 + \frac{18K\eta^2a^2G^2}{S} + 6K\eta^2\,\mathbb{E}\|\nabla f(y^t_{i,k-1})\|^2\,.
\end{aligned}$$
The last inequality used the update variance bound of Lemma 9. Since $\eta \le \frac{1}{6K\delta}$ implies $18K\eta^2\delta^2 \le \frac{1}{2K}$, the coefficient of the first term is at most $1 + \frac{1}{K}$. Similar computations for $\mathbb{E}\|y^t_{i,k} - x^{t-1}\|^2$ yield the lemma.

Progress in one step. We now have all the tools required to compute the progress made in each step.

Lemma 11. Consider any step of MimeMVR with step size $\eta \le \min\big(\frac{1}{L}, \frac{1}{40\delta K}\big)$ and momentum parameter $a = 1536\eta^2\delta^2K^2$.
Then, given that (A1)–(A3) hold, we have
$$\mathbb{E}[f(y^t_{i,k})] + \frac{3\eta}{a}\mathbb{E}\|V^t\|^2 + \frac{8\eta\delta^2K}{a}\Big(1+\frac{2}{K}\Big)^{K-k}\Delta^t_{i,k} \le \mathbb{E}[f(y^t_{i,k-1})] + \frac{3\eta}{a}\mathbb{E}\|V^{t-1}\|^2 + \frac{8\eta\delta^2K}{a}\Big(1+\frac{2}{K}\Big)^{K-(k-1)}\Delta^t_{i,k-1} - \frac{\eta}{4}\mathbb{E}\|\nabla f(y^t_{i,k-1})\|^2 + \frac{11136\,\eta^3\delta^2K^2G^2}{S}\,.$$

Proof. The assumption that $f$ is $L$-smooth implies the quadratic upper bound (10):
$$f(y^t_{i,k}) - f(y^t_{i,k-1}) \le -\eta\langle\nabla f(y^t_{i,k-1}),\, d^t_{i,k}\rangle + \frac{L\eta^2}{2}\|d^t_{i,k}\|^2 = -\frac{\eta}{2}\|\nabla f(y^t_{i,k-1})\|^2 + \frac{L\eta^2 - \eta}{2}\|d^t_{i,k}\|^2 + \frac{\eta}{2}\|d^t_{i,k} - \nabla f(y^t_{i,k-1})\|^2\,.$$
The equality used the fact that for any $a, b$: $-2\langle a, b\rangle = \|a - b\|^2 - \|a\|^2 - \|b\|^2$. The second term can be dropped since $\eta \le \frac{1}{L}$. Taking expectations on both sides and using the update variance bound Lemma 9,
$$\mathbb{E}[f(y^t_{i,k})] - \mathbb{E}[f(y^t_{i,k-1})] \le -\frac{\eta}{2}\mathbb{E}\|\nabla f(y^t_{i,k-1})\|^2 + \frac{3\eta a^2G^2}{2S} + \frac{3\eta}{2}\mathbb{E}\|V^{t-1}\|^2 + \frac{3\eta\delta^2}{2}\mathbb{E}\|y^t_{i,k-1} - x^{t-2}\|^2\,.$$
Multiplying the momentum variance bound Lemma 8 by $\frac{3\eta}{a}$, we have
$$\frac{3\eta}{a}\mathbb{E}\|V^t\|^2 \le \frac{3\eta}{a}\mathbb{E}\|V^{t-1}\|^2 - 3\eta\,\mathbb{E}\|V^{t-1}\|^2 + \frac{6\eta\delta^2}{a}\mathbb{E}\|x^{t-1} - x^{t-2}\|^2 + \frac{6\eta aG^2}{S}\,.$$
We also multiply the distance bound Lemma 10 by $\frac{8\eta\delta^2K}{a}(1+\frac{2}{K})^{K-k}$. Note that for any $K \ge 1$ and $k \in [K]$, we have $1 \le (1+\frac{2}{K})^{K-k} \le 8$. Then we get
$$\frac{8\eta\delta^2K}{a}\Big(1+\frac{2}{K}\Big)^{K-k}\Delta^t_{i,k} \le \frac{8\eta\delta^2K}{a}\Big(1+\frac{2}{K}\Big)^{K-(k-1)}\Delta^t_{i,k-1} - \frac{8\eta\delta^2}{a}\Delta^t_{i,k-1} + \frac{1152\,\eta^3\delta^2K^2a\,G^2}{S} + \frac{1152\,\eta^3\delta^2K^2}{a}\mathbb{E}\|V^{t-1}\|^2 + \frac{384\,\eta^3\delta^2K^2}{a}\mathbb{E}\|\nabla f(y^t_{i,k-1})\|^2\,,$$
where recall that we defined $\Delta^t_{i,k} := \max\big(\mathbb{E}\|y^t_{i,k} - x^{t-2}\|^2,\ \mathbb{E}\|y^t_{i,k} - x^{t-1}\|^2,\ \mathbb{E}\|x^{t-1} - x^{t-2}\|^2\big)$. Combining the three inequalities together, we get
$$\begin{aligned}
\mathbb{E}[f(y^t_{i,k})] + \frac{3\eta}{a}\mathbb{E}\|V^t\|^2 + \frac{8\eta\delta^2K}{a}\Big(1+\frac{2}{K}\Big)^{K-k}\Delta^t_{i,k} \le\; & \mathbb{E}[f(y^t_{i,k-1})] + \frac{3\eta}{a}\mathbb{E}\|V^{t-1}\|^2 + \frac{8\eta\delta^2K}{a}\Big(1+\frac{2}{K}\Big)^{K-(k-1)}\Delta^t_{i,k-1} \\
&+ \Big(\frac{384\,\eta^3\delta^2K^2}{a} - \frac{\eta}{2}\Big)\mathbb{E}\|\nabla f(y^t_{i,k-1})\|^2 + \Big(1152\,\eta^2\delta^2K^2 + 6 + \frac{3a}{2}\Big)\frac{a\eta G^2}{S} \\
&+ \Big(\frac{3\eta}{2} + \frac{1152\,\eta^3\delta^2K^2}{a} - 3\eta\Big)\mathbb{E}\|V^{t-1}\|^2 + \Big(\frac{3}{2} + \frac{6}{a} - \frac{8}{a}\Big)\eta\delta^2\,\Delta^t_{i,k-1}\,.
\end{aligned}$$
Note that $\frac{1152\,\eta^3\delta^2K^2}{a} = \frac{3\eta}{4}$ since we defined $a = 1536\eta^2\delta^2K^2$; further, $a \le 1$ with this definition since we assumed $\eta \le \frac{1}{40\delta K}$. Similarly, the definition of $a$ implies $\frac{384\,\eta^3\delta^2K^2}{a} = \frac{\eta}{4}$.
Thus, the coefficients of the $\mathbb{E}\|V^{t-1}\|^2$ and $\Delta^t_{i,k-1}$ terms are non-positive and can be dropped, and we can simplify the above expression to
$$\mathbb{E}[f(y^t_{i,k})] + \frac{3\eta}{a}\mathbb{E}\|V^t\|^2 + \frac{8\eta\delta^2K}{a}\Big(1+\frac{2}{K}\Big)^{K-k}\Delta^t_{i,k} \le \mathbb{E}[f(y^t_{i,k-1})] + \frac{3\eta}{a}\mathbb{E}\|V^{t-1}\|^2 + \frac{8\eta\delta^2K}{a}\Big(1+\frac{2}{K}\Big)^{K-(k-1)}\Delta^t_{i,k-1} - \frac{\eta}{4}\mathbb{E}\|\nabla f(y^t_{i,k-1})\|^2 + \frac{11136\,\eta^3\delta^2K^2G^2}{S}\,.$$
This proves the lemma.

Progress in one round. Summing over all the steps within a round gives the progress made in a full round.

Lemma 12. For any round of MimeMVR with step size $\eta \le \min\big(\frac{1}{L}, \frac{1}{40\delta K}\big)$ and momentum parameter $a = 1536\eta^2\delta^2K^2$, given that (A1)–(A3) hold, we have
$$\frac{\eta}{4KS}\sum_{k\in[K],\,j\in\mathcal{S}^t}\mathbb{E}\|\nabla f(y^t_{j,k-1})\|^2 \le \Phi^{t-1} - \Phi^t + \frac{11136\,\eta^3\delta^2K^2G^2}{S}\,,$$
where we define the sequence $\Phi^t := \frac{1}{K}\mathbb{E}[f(x^t)] + \frac{3\eta}{a}\mathbb{E}\|V^t\|^2 + \frac{8\eta\delta^2}{a}\mathbb{E}\|x^t - x^{t-1}\|^2$.

Proof. We start by summing Lemma 11 over the $K$ client updates:
$$\frac{\eta}{4}\sum_{k\in[K]}\mathbb{E}\|\nabla f(y^t_{i,k-1})\|^2 \le \mathbb{E}[f(y^t_{i,0})] + \frac{3\eta K}{a}\mathbb{E}\|V^{t-1}\|^2 + \frac{8\eta\delta^2K}{a}\Big(1+\frac{2}{K}\Big)^{K}\Delta^t_{i,0} - \mathbb{E}[f(y^t_{i,K})] - \frac{3\eta K}{a}\mathbb{E}\|V^t\|^2 - \frac{8\eta\delta^2K}{a}\Delta^t_{i,K} + \frac{11136\,\eta^3\delta^2K^3G^2}{S}\,.$$
Because $y^t_{i,0} = x^{t-1}$, all three distances in $\Delta^t_{i,0}$ reduce to $\mathbb{E}\|x^{t-1} - x^{t-2}\|^2$, and we can simplify
$$\mathbb{E}[f(y^t_{i,0})] + \frac{8\eta\delta^2K}{a}\Delta^t_{i,0} \le \mathbb{E}[f(x^{t-1})] + \frac{8\eta\delta^2K}{a}\mathbb{E}\|x^{t-1} - x^{t-2}\|^2\,.$$
Then, by the averaging Lemma 4,
$$\frac{1}{S}\sum_{j\in\mathcal{S}^t}\Big(\mathbb{E}[f(y^t_{j,K})] + \frac{8\eta\delta^2K}{a}\Delta^t_{j,K}\Big) \ge \frac{1}{S}\sum_{j\in\mathcal{S}^t}\Big(\mathbb{E}[f(y^t_{j,K})] + \frac{8\eta\delta^2K}{a}\mathbb{E}\|x^{t-1} - y^t_{j,K}\|^2\Big) \ge \mathbb{E}[f(x^t)] + \frac{8\eta\delta^2K}{a}\mathbb{E}\|x^{t-1} - x^t\|^2\,.$$
Averaging the recursion over the sampled clients and dividing the summation by $K$, we get
$$\frac{\eta}{4KS}\sum_{k\in[K],\,j\in\mathcal{S}^t}\mathbb{E}\|\nabla f(y^t_{j,k-1})\|^2 \le \underbrace{\frac{1}{K}\mathbb{E}[f(x^{t-1})] + \frac{3\eta}{a}\mathbb{E}\|V^{t-1}\|^2 + \frac{8\eta\delta^2}{a}\mathbb{E}\|x^{t-1} - x^{t-2}\|^2}_{=\,\Phi^{t-1}} - \underbrace{\Big(\frac{1}{K}\mathbb{E}[f(x^t)] + \frac{3\eta}{a}\mathbb{E}\|V^t\|^2 + \frac{8\eta\delta^2}{a}\mathbb{E}\|x^t - x^{t-1}\|^2\Big)}_{=\,\Phi^t} + \frac{11136\,\eta^3\delta^2K^2G^2}{S}\,.$$

Theorem IV (non-convex convergence of MimeMVR). Let us run MimeMVR with step size $\eta \le \min\Big(\frac{1}{15K}\big(\frac{FS}{T\delta^2G^2}\big)^{1/3}, \frac{1}{L}, \frac{1}{40\delta K}\Big)$ and momentum parameter $a = 1536\eta^2\delta^2K^2$. Then, given that (A1)–(A3) hold, we have
$$\frac{1}{KST}\sum_{t\in[T]}\sum_{k\in[K]}\sum_{j\in\mathcal{S}^t}\mathbb{E}\|\nabla f(y^t_{j,k-1})\|^2 \le O\Bigg(\Big(\frac{(1+\delta)G\sqrt{F}}{\sqrt{S}\,T}\Big)^{2/3} + \frac{(L+\delta K)F}{KT}\Bigg)\,,$$
where we define $F := f(x^0) - f^\star$.

Proof.
Unrolling the one-round progress Lemma 12 over the $T$ rounds and averaging gives
$$\frac{1}{KST}\sum_{t\in[T]}\sum_{k\in[K]}\sum_{j\in\mathcal{S}^t}\mathbb{E}\|\nabla f(y^t_{j,k-1})\|^2 \le \frac{4(\Phi^0 - \Phi^T)}{\eta T} + \frac{44544\,\eta^2\delta^2K^2G^2}{S} \le \frac{4\big(f(x^0) - f^\star\big)}{\eta KT} + \frac{44544\,\eta^2\delta^2K^2G^2}{S}\,.$$
Our choice of step size now yields the desired rate.
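The decomposition of (17) into an SGDm part and a correction can be written out directly. The following sketch (toy quadratic clients and illustrative constants, not from the paper) also checks the claim above that, for identical clients with $m^{t-1} = \nabla f(x^{t-2})$, the MVR update returns the exact gradient $\nabla f(x^{t-1})$:

```python
import numpy as np

rng = np.random.default_rng(6)
d, N, S, a = 3, 40, 8, 0.2
b = rng.normal(size=(N, d))
grad = lambda j, x: x - b[j]                 # ∇f_j for f_j(x) = ½‖x − b_j‖²

x_prev = rng.normal(size=d)                  # x^{t−2}
x_curr = x_prev - 0.1 * rng.normal(size=d)   # x^{t−1}
m_prev = np.mean([grad(j, x_prev) for j in range(N)], axis=0)  # m^{t−1}

St = rng.choice(N, S, replace=False)
g_curr = np.mean([grad(j, x_curr) for j in St], axis=0)
g_prev = np.mean([grad(j, x_prev) for j in St], axis=0)

# MVR momentum update (17): the usual SGDm part plus the small correction term
sgdm_part = a * g_curr + (1 - a) * m_prev
correction = (1 - a) * (g_curr - g_prev)
m_new = sgdm_part + correction

# Sanity check: with identical clients and m^{t−1} = ∇f(x^{t−2}), (17) returns
# exactly ∇f(x^{t−1}), since a·g + (1−a)·g_prev + (1−a)·(g − g_prev) = g.
m_id = a * grad(0, x_curr) + (1 - a) * grad(0, x_prev) \
       + (1 - a) * (grad(0, x_curr) - grad(0, x_prev))
print(np.allclose(m_id, grad(0, x_curr)))    # True
```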



The momentum based variance reduction (MVR), introduced by Cutkosky and Orabona (2019), is a modification of the standard SGDm algorithm to make it amenable to analysis. All our theory uses MVR, while our experiments use SGDm.



Figure 2: SGDm (dashed black), FedSGDm (top), MimeLiteSGDm (middle), and MimeSGDm (bottom) on simulated data, all with momentum (β = 0.5). FedAvg gets slower as the gradient dissimilarity (G) increases (to the right). MimeLite shows a similar pattern, but is consistently better than FedAvg. Mime is significantly faster than both and is unaffected by heterogeneity (G).

By incorporating momentum based variance reduction, MimeMVR matches the lower bound of $\Omega(G/\sqrt{S})$ by Arjevani et al. (2019).

Figure 3: Server-only, FedAvg, Mime, and MimeLite with SGDm (left) and Adam (middle) run on (top) EMNIST62 with a 2 hidden layer (300u-100) MLP and (bottom) a ResNet20 run on CIFAR100. Mime and MimeLite have very similar performance and are consistently the best. FedAvg is even worse than the server-only baselines. Also, Mime makes better use of momentum than FedAvg, with a large increase in performance (right).

Figure 3 shows the results: Mime > MimeLite > Server-only > FedAvg. Mime and MimeLite have the best performance.

Figure 5: Comparison with SCAFFOLD and FedProx for cross-device FL: Mime, SCAFFOLD, and FedProx with SGDm run on EMNIST62 with a 2 hidden layer (300u-100) MLP. For FedProx and SCAFFOLD, in addition to tuning the learning rate, we search for the best server momentum β ∈ [0, 0.9, 0.99]. FedProx uses an additional regularizer weight µ, which we search over [0.1, 0.5, 1] (note that FedProx with µ = 0 is the same as FedAvg). The best test accuracies (plotted here) were obtained with β = 0 for both methods and µ = 0.1 for FedProx. Note that FedProx is the slowest method here (in fact, it is even slower than FedAvg): the additional regularizer does not seem to reduce client drift, while still slowing down convergence (Karimireddy et al., 2020; Wang et al., 2020b). SCAFFOLD is also slower than Mime in this setup. This is because SCAFFOLD was designed for the cross-silo setting and not the cross-device setting. The large number of clients (N = 3.4k) means that each client is on average visited fewer than 6 times during the entire training (20 clients per round for 1k rounds). Hence, the stored client control variate is quite stale (from about 200 rounds ago), which slows down convergence. This perfectly reflects our theoretical understanding that when the number of clients N is large relative to the number of training rounds (as is true in the cross-device setting), SCAFFOLD is outperformed by MIME.



Number of communication rounds required to reach $\|\nabla f(x)\|^2 \le \epsilon$ for $L$-smooth functions (log factors are ignored) with $S$ clients sampled each round. $G^2$ bounds the gradient dissimilarity (A1), and $\delta$ bounds the Hessian dissimilarity (A2). FEDAVG is slower than the server-only methods due to additional drift terms. Convergence of SCAFFOLD depends on the total number of clients.

Based on this observation, we introduced a new framework, MIME, which not only overcomes client drift but also adapts arbitrary centralized algorithms such as Adam to the federated setting without any additional hyper-parameters. We demonstrated the superiority of MIME via strong convergence guarantees and empirical evaluations.

In all cases we report the top-1 test accuracy in our experiments. EMNIST uses the metadata indicating the original author of the characters to separate them into multiple clients, yielding a naturally partitioned dataset. Table 2 summarizes the statistics of the different datasets. Note that the average number of rounds a client participates in (computed as sampled clients × number of rounds / number of clients) indicates how much of the training data is seen, with SHAKESPEARE being closest to the cross-silo setting and STACKOVERFLOW being the most cross-device in nature. Details about the datasets used and the experiment setting.




Additional algorithmic details: Decomposing base algorithms into a parameter update (U) and statistics tracking (V).
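As an illustration of this decomposition (the function names and interface below are our own sketch, not the paper's code), SGD with momentum splits into an update step U, which clients apply with the server statistics held fixed, and a statistics-tracking step V, which only the server executes once per round:

```python
import numpy as np

# SGD with momentum split into MIME's two components:
#   U: computes a parameter update from a gradient and the *frozen* statistics s
#   V: updates the tracked statistics s (run only at the server, once per round)
def sgdm_U(grad, s, lr=0.1, beta=0.9):
    return -lr * ((1 - beta) * grad + beta * s["m"])

def sgdm_V(grad, s, beta=0.9):
    return {"m": (1 - beta) * grad + beta * s["m"]}

s = {"m": np.zeros(2)}                  # server statistics, fixed during local steps
y = np.array([1.0, -2.0])
for _ in range(5):                      # local steps: apply U with s frozen
    g = 2 * y                           # gradient of f(y) = ‖y‖²
    y = y + sgdm_U(g, s)

s = sgdm_V(2 * y, s)                    # server refreshes the statistics afterwards
print(y, s["m"])
```

Keeping s frozen during the local loop is exactly what prevents any single client's atypical data from contaminating the tracked statistics; an adaptive method like Adam would carry its second-moment accumulator in s the same way.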


We perform 4 tasks over 3 datasets: i) On the EMNIST62 (extended MNIST) dataset (Caldas et al., 2018b), we run a convex logistic regression model, a fully connected MLP with 2 hidden layers (300u-100), and a convolutional model with two CNN layers and two dense layers with dropout. ii) We also run a ResNet20 on CIFAR100.

