FEDERATED LEARNING VIA POSTERIOR AVERAGING: A NEW PERSPECTIVE AND PRACTICAL ALGORITHMS

Abstract

Federated learning is typically approached as an optimization problem, where the goal is to minimize a global loss function by distributing computation across client devices that possess local data and specify different parts of the global objective. We present an alternative perspective and formulate federated learning as a posterior inference problem, where the goal is to infer a global posterior distribution by having client devices each infer the posterior of their local data. While exact inference is often intractable, this perspective provides a principled way to search for global optima in federated settings. Further, starting with the analysis of federated quadratic objectives, we develop a computation- and communication-efficient approximate posterior inference algorithm, federated posterior averaging (FEDPA). Our algorithm uses MCMC for approximate inference of local posteriors on the clients and efficiently communicates their statistics to the server, where the latter uses them to refine a global estimate of the posterior mode. Finally, we show that FEDPA generalizes federated averaging (FEDAVG), can similarly benefit from adaptive optimizers, and yields state-of-the-art results on four realistic and challenging benchmarks, converging faster to better optima.

1. INTRODUCTION

Federated learning (FL) is a framework for learning statistical models from heterogeneous data scattered across multiple entities (or clients) under the coordination of a central server that has no direct access to the local data (Kairouz et al., 2019). To learn models without any data transfer, clients must process their own data locally and only infrequently communicate some model updates to the server, which aggregates these updates into a global model (McMahan et al., 2017). While this paradigm enables efficient distributed learning from data stored on millions of remote devices (Hard et al., 2018), it comes with many challenges (Li et al., 2020), with the communication cost often being the critical bottleneck and the heterogeneity of client data affecting convergence. Canonically, FL is formulated as a distributed optimization problem with a few distinctive properties, such as unbalanced and non-i.i.d. data distribution across the clients and limited communication. The de facto standard algorithm for solving federated optimization is federated averaging (FEDAVG, McMahan et al., 2017), which proceeds in rounds of communication between the server and a random subset of clients, synchronously updating the server model after each round (Bonawitz et al., 2019). By allowing the clients to perform multiple local SGD steps (or epochs) at each round, FEDAVG can reduce the required communication by orders of magnitude compared to mini-batch (MB) SGD. However, due to the heterogeneity of the client data, more local computation often leads to biased client updates and makes FEDAVG stagnate at inferior optima. As a result, while slow during initial training, MB-SGD ends up dominating FEDAVG at convergence (see the example in Fig. 1). This has been observed in multiple empirical studies (e.g., Charles & Konečnỳ, 2020) and was recently shown theoretically (Woodworth et al., 2020a).
Using stateful clients (Karimireddy et al., 2019; Pathak & Wainwright, 2020) can help remedy the convergence issues in the cross-silo setting, where relatively few clients are queried repeatedly, but is not practical in the cross-device setting (i.e., when clients are mobile devices) for several reasons (Kairouz et al., 2019; Li et al., 2020; Lim et al., 2020). One key issue is that the number of clients in such a setting is extremely large and the average client will only ever participate in a single FL round, so the state of a stateful algorithm is never used. Is it possible to design FL algorithms that exhibit both fast training and consistent convergence with stateless clients? In this work, we answer this question affirmatively by approaching federated learning not as an optimization problem but rather as a posterior inference problem. We show that modes of the global posterior over the model parameters correspond to the desired optima of the federated optimization objective and can be inferred by aggregating information about local posteriors. Starting with an analysis of federated quadratics, we introduce a general class of federated posterior inference algorithms that run local posterior inference on the clients and global posterior inference on the server. In contrast with federated optimization, posterior inference can, with stateless clients, benefit from an increased amount of local computation without stagnating at inferior optima (illustrated in Fig. 1). However, a naïve approach to federated posterior inference is practically infeasible because its computation and communication costs are cubic and quadratic in the number of model parameters, respectively. Apart from the new perspective, our key technical contribution is the design of an efficient algorithm with linear computation and communication costs.

Contributions. The main contributions of this paper can be summarized as follows:
1. We introduce a new perspective on federated learning through the lens of posterior inference, which broadens the design space for FL algorithms beyond purely optimization techniques.
2. With this perspective, we design a computation- and communication-efficient approximate posterior inference algorithm, federated posterior averaging (FEDPA). FEDPA works with stateless clients, and its computational complexity and memory footprint are similar to those of FEDAVG.
3. We show that FEDAVG with many local steps is in fact a special case of FEDPA that estimates local posterior covariances with identities. These biased estimates are the source of inconsistent updates and explain why FEDAVG has suboptimal convergence even in simple quadratic settings.
4. Finally, we compare FEDPA with strong baselines on realistic FL benchmarks introduced by Reddi et al. (2020) and achieve state-of-the-art results with respect to multiple metrics of interest.

2. RELATED WORK

Federated optimization. Starting with the seminal paper by McMahan et al. (2017), a lot of recent effort in federated learning has focused on understanding FEDAVG (also known as local SGD) as an optimization algorithm. Multiple works have provided upper bounds on the convergence rate of FEDAVG in the homogeneous i.i.d. setting (Yu et al., 2019; Karimireddy et al., 2019; Woodworth et al., 2020b) as well as explored various non-i.i.d. settings with different notions of heterogeneity (Zhao et al., 2018; Sahu et al., 2018; Hsieh et al., 2019; Li et al., 2019; Wang et al., 2020; Woodworth et al., 2020a). Reddi et al. (2020) reformulated FEDAVG in a way that enabled adaptive optimization and derived the corresponding convergence rates, noting that FEDAVG requires careful tuning of learning rate schedules in order to converge to the desired optimum, which was further analyzed by Charles & Konečnỳ (2020). To the best of our knowledge, our work is perhaps the first to connect, reinterpret, and analyze federated optimization from the probabilistic inference perspective. Distributed MCMC. Part of our work builds on the idea of sub-posterior aggregation, which was originally proposed for scaling up Markov chain Monte Carlo techniques to large datasets (known as consensus Monte Carlo; Neiswanger et al., 2013; Scott et al., 2016). One of the goals of this paper is to highlight the connection between distributed inference and federated optimization and to develop inference techniques that can be used under FL-specific constraints.

3. A POSTERIOR INFERENCE PERSPECTIVE ON FEDERATED LEARNING

Federated learning is typically formulated as the following optimization problem:

$$\min_{\theta \in \mathbb{R}^d} F(\theta) := \sum_{i=1}^N q_i f_i(\theta), \qquad f_i(\theta) := \frac{1}{n_i} \sum_{j=1}^{n_i} f(\theta; z_{ij}), \qquad (1)$$

where the global objective function $F(\theta)$ is a weighted average of the local objectives $f_i(\theta)$ over $N$ clients; each client's objective is some loss $f(\theta; z)$ computed on the local data $D_i = \{z_{i1}, \ldots, z_{in_i}\}$. In real-world cross-device applications, the total number of clients $N$ can be extremely large, and hence optimization of $F(\theta)$ is done over multiple rounds with only a small subset of $M$ clients participating in each round. The weights $\{q_i\}$ are typically set proportional to the sizes of the local datasets $\{n_i\}$, which makes $F(\theta)$ coincide with the training objective of the centralized setting. Typically, $f(\theta; z)$ is the negative log likelihood of $z$ under some probabilistic model parametrized by $\theta$, i.e., $f(\theta; z) := -\log P(z \mid \theta)$. For example, the least squares loss corresponds to the likelihood under a Gaussian model, the cross-entropy loss corresponds to the likelihood under a categorical model, etc. (Murphy, 2012). Thus, Eq. 1 corresponds to maximum likelihood estimation (MLE) of the model parameters $\theta$. An alternative (Bayesian) approach to maximum likelihood estimation is posterior inference, or estimation of the posterior distribution of the parameters given all the data: $P(\theta \mid D)$, where $D \equiv D_1 \cup \cdots \cup D_N$. The posterior is proportional to the product of the likelihood and a prior, $P(\theta \mid D) \propto P(D \mid \theta)\, P(\theta)$, and, if the prior is uninformative (uniform over all $\theta$), the modes of the global posterior coincide with MLE solutions, or optima of $F(\theta)$ in Eq. 1.
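As a quick numerical sanity check of the formulation in Eq. 1 (a minimal numpy sketch; the random data, the least squares loss, and the three-way client split are illustrative assumptions): with $q_i = n_i / n$, the federated objective $F(\theta)$ coincides exactly with the centralized training loss on the pooled data.

```python
import numpy as np

def centralized_loss(theta, X, y):
    # Mean least-squares loss over the pooled dataset.
    return 0.5 * np.mean((X @ theta - y) ** 2)

def federated_loss(theta, client_data):
    # F(theta) = sum_i q_i f_i(theta), with q_i = n_i / n and
    # f_i the mean loss over client i's local data.
    n = sum(len(y_i) for _, y_i in client_data)
    total = 0.0
    for X_i, y_i in client_data:
        q_i = len(y_i) / n
        f_i = 0.5 * np.mean((X_i @ theta - y_i) ** 2)
        total += q_i * f_i
    return total

rng = np.random.default_rng(0)
X, y = rng.normal(size=(30, 4)), rng.normal(size=30)
# Split the pooled data across three unbalanced clients.
clients = [(X[:10], y[:10]), (X[10:18], y[10:18]), (X[18:], y[18:])]
theta = rng.normal(size=4)
assert np.isclose(federated_loss(theta, clients), centralized_loss(theta, X, y))
```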
While this simple observation establishes an equivalence between the inference of the posterior mode and optimization, the advantage of this perspective comes from the fact that the global posterior exactly decomposes into a product of local posteriors.

Proposition 1 (Global Posterior Decomposition) Under the uniform prior, any global posterior distribution that exists decomposes into a product of local posteriors: $P(\theta \mid D) \propto \prod_{i=1}^N P(\theta \mid D_i)$.

Proposition 1 suggests that as long as we are able to compute local posterior distributions $P(\theta \mid D_i)$ and communicate them to the server, we should be able to solve Eq. 1 by multiplicatively aggregating them to find the mode of the global posterior $P(\theta \mid D)$ on the server. Note that posterior inference via multiplicative averaging has been successfully used to scale Monte Carlo methods to large datasets, where the approach is embarrassingly parallel (Neiswanger et al., 2013; Scott et al., 2016). In the FL context, this means that once all clients have sent their local posteriors to the server, we can construct the global posterior without any additional communication. However, there remains the challenge of making the local and global inference and communication efficient enough for real federated settings. The example below illustrates how this can be difficult even for a simple model and loss function.

Federated least squares. Consider federated least squares regression with a linear model, where $z := (x, y)$ and the loss $f(\theta; x, y) := \frac{1}{2}(x^\top \theta - y)^2$ is quadratic. Then, the client objective becomes:

$$f_i(\theta) = \frac{1}{2} \|X_i \theta - y_i\|^2 = \frac{1}{2} (\theta - \mu_i)^\top \Sigma_i^{-1} (\theta - \mu_i) + \mathrm{const}, \qquad (2)$$

where $X_i \in \mathbb{R}^{n_i \times d}$ is the design matrix, $y_i \in \mathbb{R}^{n_i}$ is the response vector, $\Sigma_i^{-1} := X_i^\top X_i$, and $\mu_i := (X_i^\top X_i)^{-1} X_i^\top y_i$. Note that the expression in Eq. 2 is, up to an additive constant, the negative log likelihood of a multivariate Gaussian distribution with mean $\mu_i$ and covariance $\Sigma_i$.
Therefore, each local posterior (under the uniform prior) is Gaussian, and, as a product of Gaussians, the global posterior is also Gaussian with the following mean (which coincides with the posterior mode):

$$\mu := \left( \sum_{i=1}^N q_i \Sigma_i^{-1} \right)^{-1} \left( \sum_{i=1}^N q_i \Sigma_i^{-1} \mu_i \right). \qquad (3)$$

Concretely, in the case of least squares regression, this suggests that it is sufficient for clients to infer the means $\{\mu_i\}$ and inverse covariances $\{\Sigma_i^{-1}\}$ of their local posteriors and communicate that information to the server for the latter to be able to find the global optimum. However, a straightforward application of Eq. 3 would require $O(d^2)$ memory and $O(d^3)$ computation, both on the clients and on the server, which is very expensive for the typical cross-device FL setting. Similarly, the communication cost would be $O(d^2)$, while standard FL algorithms have a communication cost of $O(d)$. Directly computing and communicating these quantities would be completely infeasible in the realistic setting where models are neural networks with millions of parameters. In the following section, we design a practical algorithm where all costs are linear in the number of model parameters.

4. FEDERATED POSTERIOR AVERAGING: A PRACTICAL ALGORITHM

Proposition 2 states that the global posterior mode $\mu$ given in Eq. 3 is the minimizer of the quadratic objective $Q(\theta) := \frac{1}{2} \theta^\top A \theta - b^\top \theta$, where $A := \sum_{i=1}^N q_i \Sigma_i^{-1}$ and $b := \sum_{i=1}^N q_i \Sigma_i^{-1} \mu_i$ (see Appendix B for the proof). Proposition 2 allows us to obtain a good estimate of $\mu$ by running stochastic optimization of the quadratic objective $Q(\theta)$ on the server. Note that the gradient of $Q(\theta)$ has the following form:

$$\nabla Q(\theta) = \sum_{i=1}^N q_i \Sigma_i^{-1} (\theta - \mu_i), \qquad (4)$$

which suggests that we can obtain $\mu$ by using the same Algorithm 1 as FEDAVG but with different client updates: $\Delta_i := \Sigma_i^{-1}(\theta - \mu_i)$. Importantly, as long as clients are able to compute the $\Delta_i$'s, this approach results in $O(d)$ communication and $O(d)$ server computation cost per round.
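The identity between the averaged client deltas and $\nabla Q(\theta)$, and the resulting server-side optimization, can be sketched as follows (a minimal numpy illustration with synthetic Gaussian local posteriors; full client participation is assumed for simplicity, whereas real cross-device FL samples $M \ll N$ clients per round):

```python
import numpy as np

rng = np.random.default_rng(2)
d, N = 4, 6
Sigma_invs, mus = [], []
for _ in range(N):
    X_i = rng.normal(size=(8, d))
    Sigma_invs.append(X_i.T @ X_i)   # local inverse covariance
    mus.append(rng.normal(size=d))   # local posterior mean
q = np.full(N, 1.0 / N)

A = sum(qi * S for qi, S in zip(q, Sigma_invs))
b = sum(qi * S @ m for qi, S, m in zip(q, Sigma_invs, mus))

# The weighted average of client deltas equals the gradient of
# Q(theta) = 0.5 theta^T A theta - b^T theta.
theta = rng.normal(size=d)
avg_delta = sum(qi * S @ (theta - m) for qi, S, m in zip(q, Sigma_invs, mus))
assert np.allclose(avg_delta, A @ theta - b)

# Server-side gradient descent with these deltas recovers the global mode.
mu = np.linalg.solve(A, b)
for _ in range(5000):
    delta = sum(qi * S @ (theta - m) for qi, S, m in zip(q, Sigma_invs, mus))
    theta = theta - 0.01 * delta
assert np.linalg.norm(theta - mu) < 1e-6
```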

Algorithm 4: IASG Sampling (CLIENTMCMC)

Input: initial $\theta$, loss $f_i(\theta)$, optimizer CLIENTOPT($\alpha$), $B$: burn-in steps, $K$: steps per sample, $\ell$: number of samples.
 // Burn-in
 1: for step $t = 1, \ldots, B$ do
 2:   $\theta \leftarrow$ CLIENTOPT($\theta$, $\nabla f_i(\theta)$)
 3: end for
 // Sampling
 4: for sample $s = 1, \ldots, \ell$ do
 5:   $S_\theta \leftarrow \emptyset$  // Initialize iterates
 6:   for step $t = 1, \ldots, K$ do
 7:     $\theta \leftarrow$ CLIENTOPT($\theta$, $\nabla f_i(\theta)$)
 8:     $S_\theta \leftarrow S_\theta \cup \{\theta\}$
 9:   end for
 10:  $\hat\theta_s \leftarrow$ AVERAGE($S_\theta$)  // Average iterates
 11: end for
Output: samples $\{\hat\theta_1, \ldots, \hat\theta_\ell\}$

(2) Efficient local posterior inference. To compute $\Delta_i$, each client needs to be able to estimate the local posterior means and covariances. We propose to use stochastic gradient Markov chain Monte Carlo (SG-MCMC, Welling & Teh, 2011; Ma et al., 2015) for approximate sampling from local posteriors on the clients, so that these samples can be used to estimate the $\hat\mu_i$'s and $\hat\Sigma_i$'s. Specifically, we use a variant of SG-MCMC with iterate averaging (IASG, Mandt et al., 2017), which involves: (a) running local SGD for some number of steps to mix in the Markov chain, then (b) continuing to run SGD for more steps, periodically producing samples via Polyak averaging (Polyak & Juditsky, 1992) of the intermediate iterates (Algorithm 4). The more computation we can run locally on the clients each round, the more posterior samples can be produced, resulting in better estimates of the local moments.

(3) Efficient computation of the deltas. Even if we can obtain samples $\{\hat\theta_1, \ldots, \hat\theta_\ell\}$ via MCMC and use them to estimate the local moments $\hat\mu_i$ and $\hat\Sigma_i$, computing $\Delta_i$ naïvely would still require inverting a $d \times d$ matrix, i.e., $O(d^3)$ compute and $O(d^2)$ memory. The good news is that we are able to show that clients can compute the $\Delta_i$'s much more efficiently, in time and memory linear in $d$, using a dynamic programming algorithm and appropriate mean and covariance estimators.

Theorem 3 Given approximate posterior samples $\{\hat\theta_1, \ldots, \hat\theta_\ell\}$, let $\hat\mu_\ell$ be the sample mean, $\hat S_\ell$ be the sample covariance, and $\hat\Sigma_\ell := \rho_\ell I + (1 - \rho_\ell)\hat S_\ell$ be a shrinkage estimator (Ledoit & Wolf, 2004b) of the covariance with $\rho_\ell := 1/(1 + (\ell - 1)\rho)$ for some $\rho \in [0, +\infty)$. Then, for any $\theta$, we can compute $\hat\Delta_\ell = \hat\Sigma_\ell^{-1}(\theta - \hat\mu_\ell)$ in $O(\ell^2 d)$ time and using $O(\ell d)$ memory.

Proof [Sketch] We give a constructive proof by designing an efficient algorithm for computing $\hat\Delta_\ell$. Our approach is based on two key ideas: 1. We prove that the specified shrinkage estimator of the covariance has a recursive decomposition into rank-1 updates, i.e., $\hat\Sigma_t = \hat\Sigma_{t-1} + c_t\, x_t x_t^\top$, where $c_t$ is a constant and $x_t$ is some vector. This allows us to leverage the Sherman-Morrison formula for computing the inverse of $\hat\Sigma_\ell$. 2. Further, we design a dynamic programming algorithm for computing $\hat\Delta_\ell$ exactly without storing the covariance matrix or its inverse. Our algorithm is online and allows efficient updates of $\hat\Delta_\ell$ as more posterior samples become available. See Appendix C for the full proof and derivation of the algorithm.

Note that the computational cost of $\hat\Delta_\ell$ consists of two components: (i) the cost of producing approximate local posterior samples using IASG and (ii) the cost of solving a linear system using dynamic programming. How much of an overhead does this add compared to simply running local SGD? It turns out that in practical settings the overhead is almost negligible. Table 1 shows the time it takes a client to compute the updates based on 5 local epochs (100 steps per epoch) using different algorithms (FEDAVG vs. our approach with exact or dynamic programming (DP) matrix inversion) on synthetic linear regression problems. As the dimensionality grows, the computational cost of DP-based estimation of $\hat\Delta_\ell$ becomes nearly identical to that of FEDAVG, which indicates that the majority of the cost in practice comes from the SGD steps rather than from our dynamic programming procedure.

The final algorithm, discussion, and implications.
Putting all the pieces together, we arrive at the federated posterior averaging (FEDPA) algorithm for approximately computing the mode of the global posterior over multiple communication rounds. Our algorithm is a variant of generalized federated optimization (Algorithm 1) with a new client update procedure (Algorithm 3). Importantly, this also implies that FEDAVG can be viewed as a posterior inference algorithm that estimates $\hat\Sigma$ with an identity and, as a result, obtains biased client deltas: $\Delta_{\mathrm{FEDAVG}} := I \cdot (\theta - \hat\mu)$. In Fig. 1 in the introduction, we demonstrate the differences in behavior between FEDAVG and FEDPA that stem from the differences in their client updates. Biased client updates make FEDAVG converge to a suboptimal point; moreover, increasing local computation only pushes its fixed point further away from the global optimum. FEDPA, on the other hand, converges faster and to a better optimum, trading off bias for slightly more variance (which becomes visible only closer to convergence). FEDPA also substantially benefits from more local computation (i.e., more local samples). Since the main difference between FEDAVG and FEDPA is, in fact, the bias-variance trade-off in the server gradient estimates (Eq. 4), we can view both methods as biased SGD (Ajalloeian & Stich, 2020) and reason about their convergence rates, as well as the distances between their fixed points and the correct global optima, as functions of the gradient bias. In Appendix A, we provide further details, discuss convergence, empirically quantify the bias and variance of the client updates for both methods, and analyze the effects of the sampling-based approximations on the behavior of FEDPA.
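To make Theorem 3 concrete, below is a simplified sketch of the Sherman-Morrison-based computation (not the exact online dynamic program of Appendix C: for simplicity it fixes the sample mean from all $\ell$ samples up front rather than updating it incrementally). It applies $\hat\Sigma_\ell^{-1}$ to $\theta - \hat\mu_\ell$ in $O(\ell^2 d)$ time without ever materializing a $d \times d$ matrix, and checks the result against a direct $O(d^3)$ solve.

```python
import numpy as np

def shrinkage_delta(samples, theta, rho=1.0):
    """Compute Delta = Sigma^{-1} (theta - mu), where
    Sigma = rho_l * I + (1 - rho_l) * S is the shrinkage covariance,
    without materializing any d x d matrix. O(l^2 d) time, O(l d) memory."""
    samples = np.asarray(samples, dtype=float)
    l, d = samples.shape
    mu = samples.mean(axis=0)
    rho_l = 1.0 / (1.0 + (l - 1) * rho)
    # Sigma = rho_l * B with B = I + sum_t c * v_t v_t^T, v_t = theta_t - mu.
    c = (1.0 - rho_l) / (rho_l * (l - 1))
    V = samples - mu
    # g[t] stores B_{t-1}^{-1} v_t, built up via the Sherman-Morrison recursion,
    # where B_t = I + sum_{s <= t} c * v_s v_s^T.
    g = []
    for t in range(l):
        u = V[t].copy()   # B_{-1}^{-1} v_t = v_t (B_{-1} = I)
        for s in range(t):
            coef = c * (V[s] @ u) / (1.0 + c * (V[s] @ g[s]))
            u -= coef * g[s]
        g.append(u)
    # Apply B^{-1} to (theta - mu) with the same sequence of rank-1 updates.
    u = theta - mu
    for s in range(l):
        coef = c * (V[s] @ u) / (1.0 + c * (V[s] @ g[s]))
        u -= coef * g[s]
    return u / rho_l   # Sigma^{-1} = B^{-1} / rho_l

rng = np.random.default_rng(3)
d, l = 50, 8
samples = rng.normal(size=(l, d))
theta = rng.normal(size=d)

# Check against the direct computation with an explicit d x d matrix.
mu = samples.mean(axis=0)
rho_l = 1.0 / (1.0 + (l - 1) * 1.0)
S = np.cov(samples, rowvar=False)
Sigma = rho_l * np.eye(d) + (1.0 - rho_l) * S
assert np.allclose(shrinkage_delta(samples, theta), np.linalg.solve(Sigma, theta - mu))
```

Only the $\ell$ centered sample vectors and the $\ell$ auxiliary vectors $g_t$ are kept in memory, which is what makes the approach viable for models with millions of parameters.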

5. EXPERIMENTS

Using a suite of realistic benchmark tasks introduced by Reddi et al. (2020), we evaluate FEDPA against several competitive baselines: the best versions of FEDAVG with adaptive optimizers, as well as MIME (Karimireddy et al., 2020), a recently proposed FEDAVG variant that also works with stateless clients but uses control variates and server-level statistics to mitigate convergence issues. Datasets. EMNIST-62 was partitioned by writer (Caldas et al., 2018), CIFAR-100 was partitioned randomly into 600 clients with a realistic heterogeneous structure (Reddi et al., 2020), and StackOverflow was partitioned by its unique users. All datasets were preprocessed using the code provided by Reddi et al. (2020). Methods and models. We use a generalized framework for federated optimization (Algorithm 1), which admits arbitrary adaptive server optimizers and expects clients to compute model deltas. As a baseline, we use federated averaging with adaptive optimizers (or with momentum) on the server and refer to it as FEDAVG-1E or FEDAVG-ME, which stands for 1 or multiple local epochs performed by clients at each round, respectively. The number of local epochs in the multi-epoch versions is a hyperparameter. We use the same framework for federated posterior averaging and refer to it as FEDPA-ME; our clients use IASG to produce approximate posterior samples. Hyperparameters. For hyperparameter tuning, we first ran small grid searches for FEDAVG-ME using the best server optimizer and the corresponding learning rate grids from Reddi et al. (2020). Then, we used the best FEDAVG-ME configuration and did a small grid search to tune the additional hyperparameters of FEDPA-ME, which turned out not to be very sensitive (i.e., many configurations provided results superior to FEDAVG). More hyperparameter details can be found in Appendix D. Metrics.
Since both the speed of learning and the final performance are important in federated learning, we measure: (i) the number of rounds it takes an algorithm to attain a desired level of an evaluation metric, and (ii) the best performance attained within a specified number of rounds. For EMNIST-62, we measure the number of rounds it takes different methods to achieve 84% and 86% evaluation accuracy, and the best validation accuracy attained within 500 and 1500 rounds. For CIFAR-100, we use the same metrics but with 30% and 40% as evaluation accuracy cutoffs and 1000 and 1500 as round number cutoffs. Finally, for StackOverflow, we measure the number of rounds it takes to reach the best performance, as well as the evaluation accuracy (for the NWP task) and the precision, recall at 5, and macro- and micro-F1 (for the LR task) attained by round 1500. We note that the total number of rounds was selected based on computational considerations (to ensure reproducibility within a reasonable computational budget), and the intermediate cutoffs were selected qualitatively to highlight some performance points of interest. In addition, we provide plots of the evaluation loss and other metrics for all methods over the course of training, which show a much fuller picture of the behavior of the algorithms (most of the plots are given in Appendix E). Implementation and reproducibility. All our experiments on the benchmark tasks were conducted in simulation using TensorFlow Federated (TFF, Ingerman & Ostrowski, 2019). Synthetic experiments were conducted using JAX (Bradbury et al., 2018). The JAX implementation of the algorithms is available at https://github.com/alshedivat/fedpa. The TFF implementation will be released through https://github.com/google-research/federated. Table 3 notes: results from (Reddi et al., 2020); ‡ the best results taken from (Karimireddy et al., 2020); * results were only available for the method trained to 1000 rounds.

5.2. RESULTS ON BENCHMARK TASKS

The effects of posterior correction of client deltas. As we demonstrated in Section 4, FEDPA generalizes FEDAVG and differs only in the computation done on the clients, where client deltas are computed using an estimator of the local posterior inverse covariance matrix, $\Sigma_i^{-1}$, which requires sampling from the posterior. To be able to use SG-MCMC for local sampling, we first run FEDPA in the burn-in regime (which is identical to FEDAVG) for a number of rounds to bring the server state closer to the clients' local optima, after which we "turn on" the local posterior sampling. The effect of switching from FEDAVG to FEDPA for CIFAR-100 (after 400 burn-in rounds) and StackOverflow LR (after 800 burn-in rounds) is presented in Figs. 2a and 2b, respectively. During the burn-in phase, evaluation performance is identical for both methods, but once FEDPA starts computing client deltas using local posterior samples, the loss immediately drops and the convergence trajectory changes, indicating that FEDPA is able to avoid stagnation and make progress towards a better optimum. Similar effects are observed across all other tasks (see Appendix E). While the improvement of FEDPA over FEDAVG on some of the tasks is visually apparent (Fig. 2), we provide a more detailed comparison of the methods in terms of the speed of learning and the attained performance on all four benchmark tasks, summarized in Table 3 and discussed below. Results on EMNIST-62 and CIFAR-100. In Tables 3a and 3b, we present a comparison of FEDPA against: tuned FEDAVG with a fixed client learning rate (denoted FEDAVG-1E and FEDAVG-ME), the best variation of adaptive FEDAVG from Reddi et al. (2020) with exponentially decaying client learning rates (denoted AFO), and MIME of Karimireddy et al. (2020).
With more local epochs, we see a significant improvement in the speed of learning: both FEDPA-ME and FEDAVG-ME achieve 84% accuracy on EMNIST-62 in under 100 rounds (similarly, both methods attain 30% on CIFAR-100 by round 350). However, more local computation eventually hurts FEDAVG, leading to worse optima: on EMNIST-62, FEDAVG-ME is not able to consistently achieve 86% accuracy within 1500 rounds; on CIFAR-100, it takes FEDAVG-ME an extra 350 rounds to get to 40% accuracy. Finally, federated posterior averaging achieves the best performance on both tasks in terms of evaluation accuracy within the specified limit on the number of training rounds. On EMNIST-62 in particular, the final performance of FEDPA-ME after 1500 training rounds is 87.3%, which, while only a 0.5% absolute improvement, bridges 41.7% of the gap between the centralized model accuracy (88%) and the best federated accuracy from previous work (86.8%, Reddi et al., 2020). Results on StackOverflow NWP and LR. Results for StackOverflow are presented in Table 3c. Although not as pronounced as on the image datasets, we observe some improvement of FEDPA over FEDAVG here as well. For NWP, we see an accuracy gain of 0.4% over the best baseline. For the LR task, we compare methods in terms of average precision, recall at 5, and macro-/micro-F1. The first two metrics have appeared in some prior FL work, while the latter two are the primary evaluation metrics typically used in multi-label classification (Gibaja & Ventura, 2015). Interestingly, while FEDPA underperforms in terms of precision and recall, it substantially outperforms in terms of micro- and macro-averaged F1, especially macro-F1. This indicates that while FEDAVG learns a model that better predicts high-frequency labels, FEDPA learns a model that better captures rare labels (Yang, 1999; Yang & Liu, 1999).
Note that while FEDPA improves on the F1 metrics and has almost the same recall at 5, its precision after 1500 rounds is worse than that of FEDAVG. A more detailed discussion, along with training curves for each evaluation metric, is provided in Appendix E.

6. CONCLUSION AND FUTURE DIRECTIONS

In this work, we presented a new perspective on federated learning based on the idea of global posterior inference via averaging of local posteriors. Applying this perspective, we designed a new algorithm that generalizes federated averaging, is similarly practical and efficient, and yields state-of-the-art results on multiple challenging benchmarks. While our algorithm required a number of specific approximations and design choices, we believe that the underlying approach has the potential to significantly broaden the design space for FL algorithms beyond purely optimization techniques. Limitations and future work. As we mentioned throughout the paper, our method has a number of limitations due to its design choices, such as the specific posterior sampling and covariance estimation techniques. While in the appendix we analyzed the effects of some of these design choices, exploration of: (i) other sampling strategies, (ii) more efficient covariance estimators (Hsieh et al., 2013), (iii) alternatives to MCMC (such as variational inference), and (iv) more general connections with Bayesian deep learning are all interesting directions to pursue next. Finally, while there is a known, interesting connection between posterior sampling and differential privacy (Wang et al., 2015), a better understanding of the privacy implications of posterior inference in federated settings remains an open question.

A PRELIMINARY ANALYSIS AND ABLATIONS

In Section 4, we derived federated posterior averaging (FEDPA) starting with the global posterior decomposition (Proposition 1, which is exact) and applying the following three approximations: 1. The Laplace approximation of the local and global posterior distributions. 2. The shrinkage estimation of the local moments. 3. Approximate sampling from the local posteriors using MCMC. We have also observed that FEDAVG is a special case of FEDPA (from the algorithmic point of view), since it can be viewed as also using the Laplace approximation for the posteriors, but estimating the local covariances $\hat\Sigma_i$ with identities and the local means with the final iterates of local SGD. In this section, we analyze the effects of approximations 2 and 3 on the convergence of FEDPA. Specifically, we first discuss the convergence rates of FEDAVG and FEDPA as biased stochastic gradient optimization methods (Ajalloeian & Stich, 2020). We show how the bias and variance of the client deltas behave for FEDAVG and FEDPA as functions of the number of samples. We also analyze the quality of samples produced by IASG (Mandt et al., 2017) and how they depend on the amount of local computation and hyperparameters. Our analyses are conducted empirically on synthetic data.

A.1 DISCUSSION OF THE CONVERGENCE OF FEDPA VS. FEDAVG

First, observe that if each client is able to perfectly estimate its $\Delta_i = \Sigma_i^{-1}(\theta - \mu_i)$, the problem solved by Algorithm 1 simply becomes optimization of a quadratic objective using unbiased stochastic gradients, $\bar\Delta := \frac{1}{M}\sum_{i=1}^M \Delta_i$. The noise in the gradients in this case comes from the fact that the server interacts with only a small subset of $M$ out of $N$ clients in each round. This is a classical stochastic optimization problem with well-known convergence rates under some assumptions on the norm of the stochastic gradients (e.g., Nemirovski et al., 2009). The rate of convergence for SGD with an $O(t^{-1})$ decaying learning rate used on the server is $O(1/\sqrt{t})$.
It can be further improved to $O(1/t)$ using Polyak momentum (Polyak, 1964) or iterate averaging (Polyak & Juditsky, 1992). In reality, both FEDAVG and FEDPA produce biased estimates, $\hat\Delta_{\mathrm{FEDAVG}}$ and $\hat\Delta_{\mathrm{FEDPA}}$, respectively. Thus, we can analyze the problem as SGD with biased stochastic gradient estimates and let $\hat\Delta_t := \nabla F(\theta_t) + b(\theta_t) + n(\theta_t, \xi)$, where $b(\theta_t)$ and $n(\theta_t, \xi)$ are bias and noise terms. Following Ajalloeian & Stich (2020), we can further assume that the bias and noise terms are norm-bounded as follows.

Assumption 4 ((m, $\zeta^2$)-bounded bias) There exist constants $0 \leq m < 1$ and $\zeta^2 \geq 0$ such that $\|b(\theta)\|^2 \leq m \|\nabla F(\theta)\|^2 + \zeta^2$ for all $\theta \in \mathbb{R}^d$. (5)

Assumption 5 ((M, $\sigma^2$)-bounded noise) There exist constants $M \geq 0$ and $\sigma^2 \geq 0$ such that $\mathbb{E}_\xi \|n(\theta, \xi)\|^2 \leq M \|\nabla F(\theta)\|^2 + \sigma^2$ for all $\theta \in \mathbb{R}^d$.

Under these general assumptions, the following convergence result holds.

Theorem 6 (Ajalloeian & Stich (2020), Theorem 2) Let $F(\theta)$ be $L$-smooth. Then SGD with a learning rate $\alpha := \min\left\{\frac{1}{L},\ \frac{1-m}{2ML},\ \left(\frac{LF}{\sigma^2 T}\right)^{1/2}\right\}$ and gradients that satisfy Assumptions 4 and 5 achieves the vicinity of a stationary point, $\mathbb{E}\|\nabla F(\theta)\|^2 = O\left(\varepsilon + \frac{\zeta^2}{1-m}\right)$, in $T$ iterations, where

$$T = O\left(\frac{1}{\varepsilon}\left(1 + \frac{M}{1-m} + \frac{\sigma^2}{\varepsilon(1-m)}\right) \cdot \frac{LF}{1-m}\right).$$

Note that SGD with biased gradients is only able to converge to a vicinity of the optimum determined by the bias term $\zeta^2/(1-m)$. For FEDAVG, since the bias is not countered, this term determines the distance between the stationary point and the true global optimum. For FEDPA, since $\hat\Delta_{\mathrm{FEDPA}} \to \Delta$ with more local samples, the bias should vanish as we increase the amount of local computation. Determining the precise statistical dependence of the gradient bias on the local samples is beyond the scope of this work. However, to gain more intuition about the differences in behavior of FEDPA and FEDAVG, below we conduct an empirical analysis of the bias and variance of the estimated client deltas on synthetic least squares problems, for which exact deltas can be computed analytically.
Published as a conference paper at ICLR 2021

[Fig. 3c caption: FEDPA bias and variance as functions of the shrinkage parameter. For dimensionality 10, 100, and 1000, the number of local steps was fixed to 5,000, 10,000, and 50,000, respectively.]

Quantifying empirically the bias and variance of $\hat\Delta$ for FEDPA and FEDAVG. We measure the empirical bias and variance of the client deltas computed by each of the methods on synthetic least squares linear regression problems generated according to Guyon (2003) using the make_regression function from scikit-learn. The problems were generated as follows: for each dimensionality (10, 100, and 1000 features), we generated 10 random least squares problems, each consisting of 500 synthetic data points. Next, for each of the problems we generated 10 random initial model parameters $\{\theta_1, \ldots, \theta_{10}\}$, and for each of the parameters we computed the exact $\Delta_i$ as well as $\hat\Delta_{\mathrm{FEDAVG},i}$ and $\hat\Delta_{\mathrm{FEDPA},i}$ for different numbers of local steps; for $\hat\Delta_{\mathrm{FEDPA}}$ we also varied the shrinkage hyperparameter. Using these sample estimates, we further computed the $L_2$-norm of the bias and the Frobenius norm of the covariance matrices as functions of the number of local steps. The results are presented in Fig. 3. From Fig. 3a, we see that as the amount of local computation increases, the bias in the FEDAVG delta estimates grows and the variance reduces. For FEDPA (Fig. 3b), the trends turn out to be the opposite: as the number of local steps increases, the bias consistently reduces; the variance initially goes up, but with enough samples joins the downward trend. Note that the initial upward trend in the variance is due to the fact that we used the same fixed shrinkage $\rho$ regardless of the number of local steps. To avoid sharp increases in the variance, $\rho$ must be selected for each number of local steps separately; Fig. 3c demonstrates how the bias and variance depend on the shrinkage hyperparameter for some fixed numbers of local steps.
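The qualitative gap between the two fixed points can also be reproduced with a tiny deterministic example (two one-dimensional quadratic clients with exact local moments; the specific constants are illustrative): identity-covariance FEDAVG-style deltas drive the server to the plain average of the local means, while exact FEDPA deltas drive it to the global posterior mode.

```python
# Two clients with one-dimensional quadratic objectives
#   f_i(theta) = 0.5 * s_i * (theta - m_i)^2,
# i.e., Gaussian local posteriors with precisions s_i and means m_i.
s = [1.0, 9.0]   # local inverse variances (posterior precisions)
m = [0.0, 1.0]   # local posterior means
q = [0.5, 0.5]   # client weights

# Global posterior mode (Eq. 3): the precision-weighted mean, here 0.9.
mu_global = sum(qi * si * mi for qi, si, mi in zip(q, s, m)) \
    / sum(qi * si for qi, si in zip(q, s))

# FEDAVG-style deltas (covariance replaced by identity): delta_i = theta - m_i.
theta_avg = 0.0
for _ in range(100):
    delta = sum(qi * (theta_avg - mi) for qi, mi in zip(q, m))
    theta_avg -= 0.5 * delta
# Fixed point: the plain average of the local means (0.5), biased away from 0.9.

# FEDPA-style deltas with exact moments: delta_i = s_i * (theta - m_i).
theta_pa = 0.0
for _ in range(100):
    delta = sum(qi * si * (theta_pa - mi) for qi, si, mi in zip(q, s, m))
    theta_pa -= 0.2 * delta

assert abs(mu_global - 0.9) < 1e-12
assert abs(theta_avg - 0.5) < 1e-9       # FEDAVG fixed point: biased
assert abs(theta_pa - mu_global) < 1e-9  # FEDPA: recovers the global mode
```

Note that no amount of extra local computation moves the FEDAVG fixed point toward 0.9; only correcting the deltas with the local precisions does.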

A.2 ANALYSIS OF THE QUALITY OF IASG-BASED SAMPLING AND COVARIANCE

The more and better samples we can obtain locally, the lower the bias and variance of the gradients of Q(θ) will be, resulting in faster convergence to a fixed point closer to the global optimum. For local sampling, we proposed to use a variant of SG-MCMC called Iterate Averaged Stochastic Gradient (IASG), developed by Mandt et al. (2017) and given in Algorithm 4. The algorithm generates samples by simply averaging every K intermediate iterates produced by a client optimizer (typically, SGD with a fixed learning rate α) after skipping the first B iterates as a burn-in phase.

How good are the samples produced by IASG, and how do different parameters of the algorithm affect their quality? To answer this question, we ran IASG on synthetic least squares problems, for which we can compute the actual posterior distribution and measure the quality of the samples by evaluating the effective sample size (ESS, Liu, 1996; Owen, 2013). Given ℓ approximate posterior samples {θ_1, ..., θ_ℓ}, the ESS statistic can be computed as follows:

ESS({θ_j}_{j=1}^ℓ) := (∑_{j=1}^ℓ w_j)² / ∑_{j=1}^ℓ w_j²,

where the weights w_j must be proportional to the posterior probabilities, which can be computed from the local loss values.

Effects of the dimensionality, the number of data points, and IASG parameters on ESS. The results of our synthetic experiments are presented below in Fig. 4. The takeaways are as follows:
• More burn-in steps (or epochs) generally improve the quality of samples.
• The larger the number of steps per sample, the better (less correlated) the samples are.
• The learning rate is the most sensitive and important hyperparameter: if too large, IASG might diverge (this happened in the 1000-dimensional case); if too small, the samples become correlated.
• Finally, the quality of the samples deteriorates as the number of dimensions increases.
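A minimal IASG sampler and the ESS computation can be sketched as follows. This is our own illustrative code, not Algorithm 4 verbatim; the weights use an unnormalized exp(−loss) convention (i.e., posterior probabilities up to a temperature constant), and all function names are ours:

```python
import numpy as np

def iasg_sample(grad_fn, theta0, lr, num_samples, steps_per_sample, burnin_steps, rng):
    """Iterate Averaged SG: run constant-lr SGD, discard a burn-in prefix, then
    emit the average of every `steps_per_sample` consecutive iterates as one sample."""
    theta = theta0.copy()
    for _ in range(burnin_steps):                  # burn-in phase
        theta -= lr * grad_fn(theta, rng)
    samples = []
    for _ in range(num_samples):
        acc = np.zeros_like(theta)
        for _ in range(steps_per_sample):
            theta -= lr * grad_fn(theta, rng)
            acc += theta
        samples.append(acc / steps_per_sample)     # iterate average = one sample
    return samples

def ess(weights):
    """Effective sample size: (sum w)^2 / sum w^2."""
    w = np.asarray(weights, dtype=float)
    return w.sum() ** 2 / (w ** 2).sum()

# Toy least squares client: y = X @ theta_true + small noise.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
theta_true = np.ones(5)
y = X @ theta_true + 0.1 * rng.normal(size=200)

def minibatch_grad(theta, rng, batch=10):
    idx = rng.integers(0, len(X), size=batch)
    Xb, yb = X[idx], y[idx]
    return Xb.T @ (Xb @ theta - yb) / batch

samples = iasg_sample(minibatch_grad, np.zeros(5), lr=0.05,
                      num_samples=50, steps_per_sample=20, burnin_steps=500, rng=rng)
losses = [0.5 * np.mean((X @ s - y) ** 2) for s in samples]
w = np.exp(-(np.array(losses) - min(losses)))      # weights prop. to posterior (up to temp.)
print(ess(w))                                      # always between 1 and num_samples
```

By construction, ESS lies between 1 (one dominant sample) and the number of samples (all samples equally weighted), which is what makes it a convenient scalar summary of sample quality.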

B PROOFS

Proposition 1 (Global Posterior Decomposition) Under the uniform prior, any global posterior distribution that exists decomposes into a product of local posteriors: P(θ | D) ∝ ∏_{i=1}^N P(θ | D_i).

Proof. Under the uniform prior, the following equivalence holds for P(θ | D) as a function of θ:

P(θ | D) ∝ P(D | θ) = ∏_{z∈D} P(z | θ) = ∏_{i=1}^N ∏_{z∈D_i} P(z | θ) ∝ ∏_{i=1}^N P(θ | D_i),

where each inner product ∏_{z∈D_i} P(z | θ) is the local likelihood of client i. The proportionality constant between the left- and right-hand sides in Eq. 8 is ∏_{i=1}^N P(D_i) / P(D).

Proposition 2 (Global Posterior Inference) The global posterior mode μ given in Eq. 3 is the minimizer of the quadratic Q(θ) := ½ θᵀAθ − bᵀθ, where A := ∑_{i=1}^N q_i Σ_i⁻¹ and b := ∑_{i=1}^N q_i Σ_i⁻¹ μ_i.

Proof. The statement of the proposition (implicitly) assumes that all matrix inverses exist. Then the quadratic Q(θ) is positive definite (PD), since A is PD as a convex combination of the PD matrices Σ_i⁻¹. Thus, the quadratic has a unique minimizer θ at which the gradient of the objective vanishes:

Aθ − b = 0 ⇒ θ = A⁻¹b = (∑_{i=1}^N q_i Σ_i⁻¹)⁻¹ ∑_{i=1}^N q_i Σ_i⁻¹ μ_i ≡ μ,

which implies that μ is the unique minimizer of Q(θ).
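Proposition 2 is easy to verify numerically. The following sketch (our own sanity check, not from the paper) builds random PD local covariances, forms A and b as defined above, and checks that μ = A⁻¹b is a stationary point and the minimizer of Q:

```python
import numpy as np

# Numerical check of Proposition 2: the minimizer of
# Q(theta) = 0.5*theta^T A theta - b^T theta, with A = sum_i q_i inv(Sigma_i)
# and b = sum_i q_i inv(Sigma_i) mu_i, coincides with mu = inv(A) @ b.
rng = np.random.default_rng(0)
N, d = 3, 4
q = rng.dirichlet(np.ones(N))                  # client weights, sum to 1
mus = [rng.normal(size=d) for _ in range(N)]
covs = []
for _ in range(N):                             # random SPD local covariances
    M = rng.normal(size=(d, d))
    covs.append(M @ M.T + d * np.eye(d))

A = sum(qi * np.linalg.inv(Si) for qi, Si in zip(q, covs))
b = sum(qi * np.linalg.inv(Si) @ mi for qi, mi, Si in zip(q, mus, covs))
mu = np.linalg.solve(A, b)

Q = lambda th: 0.5 * th @ A @ th - b @ th
# Gradient vanishes at mu, and random perturbations only increase Q:
print(np.linalg.norm(A @ mu - b))              # ~0
print(all(Q(mu + rng.normal(size=d)) > Q(mu) for _ in range(100)))
```

Since A is PD, Q(μ + p) − Q(μ) = ½ pᵀAp > 0 for any nonzero perturbation p, which is exactly what the second check exercises.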

C COMPUTATION OF CLIENT DELTAS VIA DYNAMIC PROGRAMMING

In this section, we provide a constructive proof for the following theorem by designing an efficient algorithm for computing ∆ := Σ̂⁻¹(θ − μ̂) on the clients in time and memory linear in the number of dimensions d of the parameter vector θ ∈ R^d.

Theorem 3 Given ℓ approximate posterior samples {θ_1, ..., θ_ℓ}, let μ̂ be the sample mean, Ŝ the sample covariance, and Σ̂ := ρ̄I + (1 − ρ̄)Ŝ a shrinkage estimator (Ledoit & Wolf, 2004b) of the covariance with ρ̄ := 1/(1 + (ℓ − 1)ρ) for some ρ ∈ [0, +∞). Then, for any θ, we can compute ∆ = Σ̂⁻¹(θ − μ̂) in O(ℓ²d) time using O(ℓd) memory.

The naïve computation of the update vectors (i.e., first estimating μ̂ and Σ̂ from posterior samples and then using them to compute deltas) requires O(d²) storage and O(d³) compute on the clients, which is intractable both computationally and in memory. We derive an algorithm that, given posterior samples, computes ∆ using only O(ℓd) memory and O(ℓ²d) compute. The algorithm makes use of the following two components:

1. The shrinkage estimator of the covariance (Ledoit & Wolf, 2004b), which is known to be well-conditioned even in high-dimensional settings (i.e., when the number of samples is smaller than the number of dimensions) and is widely used in econometrics (Ledoit & Wolf, 2004a) and computational biology (Schäfer & Strimmer, 2005).

2. Incremental computation of Σ̂⁻¹(θ − μ̂) that exploits the fact that each new posterior sample adds only a rank-1 component to Σ̂, and applies the Sherman-Morrison formula to derive a dynamic program for updating ∆.

Notation. For the rest of this discussion, we denote θ (i.e., the server state broadcast to the clients at round t) as x_0, drop the client index i, denote the posterior samples by x_j, the sample mean by x̄_ℓ := (1/ℓ) ∑_{j=1}^ℓ x_j, and the sample covariance by Ŝ_ℓ := 1/(ℓ − 1) ∑_{j=1}^ℓ (x_j − x̄_ℓ)(x_j − x̄_ℓ)ᵀ.
C.1 THE SHRINKAGE ESTIMATOR OF THE COVARIANCE

Ledoit & Wolf (2004b) proposed to estimate a high-dimensional covariance matrix using a convex combination of the identity matrix and the sample covariance (known as the LW or shrinkage estimator):

Σ̂(ρ̄) := ρ̄I + (1 − ρ̄)Ŝ_ℓ,

where ρ̄ is a scalar parameter that controls the bias-variance tradeoff of the estimator. As an aside, while ρ̄ can be arbitrary and the optimal ρ̄ requires knowing the true covariance Σ, there are near-optimal ways to estimate ρ̄ from the samples (Chen et al., 2010), which we discuss at the end of this section. Here, we focus on deriving an expression for ρ̄_t as a function of t = 1, ..., ℓ that ensures that the difference between Σ̂_t and Σ̂_{t−1} is a rank-1 matrix (this is not the case for arbitrary ρ̄'s).

Derivation of a shrinkage estimator that admits rank-1 updates. Consider the following matrix:

Σ̃_t := I + β_t Ŝ_t,

where β_t is a scalar function of t = 1, 2, ..., ℓ. We would like to find β_t such that Σ̃_t = Σ̃_{t−1} + γ_t U_t, where U_t is a rank-1 matrix, i.e., the following equality should hold:

β_t Ŝ_t = β_{t−1} Ŝ_{t−1} + γ_t U_t. (12)

To determine the functional form of β_t, we need recurrent relationships for x̄_t and Ŝ_t.
For the former, note that the following relationship holds between two consecutive estimates of the sample mean, x̄_{t−1} and x̄_t:

x̄_t = ((t − 1)x̄_{t−1} + x_t) / t = x̄_{t−1} + (x_t − x̄_{t−1}) / t. (13)

This allows us to expand Ŝ_t as follows:

(t − 1)Ŝ_t = ∑_{j=1}^t (x_j − x̄_t)(x_j − x̄_t)ᵀ
= ∑_{j=1}^t (x_j − x̄_{t−1} − (x_t − x̄_{t−1})/t)(x_j − x̄_{t−1} − (x_t − x̄_{t−1})/t)ᵀ
= ∑_{j=1}^{t−1} (x_j − x̄_{t−1})(x_j − x̄_{t−1})ᵀ   [= (t − 2)Ŝ_{t−1}]
  − (2/t)(x_t − x̄_{t−1}) ∑_{j=1}^{t−1} (x_j − x̄_{t−1})ᵀ   [= 0]
  + ((t − 1)/t²)(x_t − x̄_{t−1})(x_t − x̄_{t−1})ᵀ + ((t − 1)²/t²)(x_t − x̄_{t−1})(x_t − x̄_{t−1})ᵀ
= (t − 2)Ŝ_{t−1} + ((t − 1)/t)(x_t − x̄_{t−1})(x_t − x̄_{t−1})ᵀ.

Thus, we have the following recurrent relationship between Ŝ_t and Ŝ_{t−1}:

Ŝ_t = ((t − 2)/(t − 1)) Ŝ_{t−1} + (1/t)(x_t − x̄_{t−1})(x_t − x̄_{t−1})ᵀ. (15)

Now, we can plug (15) into (12) and obtain the following equation:

β_t ((t − 2)/(t − 1)) Ŝ_{t−1} + (β_t/t)(x_t − x̄_{t−1})(x_t − x̄_{t−1})ᵀ = β_{t−1} Ŝ_{t−1} + γ_t U_t,

which implies U_t := (x_t − x̄_{t−1})(x_t − x̄_{t−1})ᵀ, γ_t := β_t/t, and the following telescoping expression for β_t:

β_t = ((t − 1)/(t − 2)) β_{t−1} = ((t − 1)/(t − 2)) · ((t − 2)/(t − 3)) β_{t−2} = ··· = (t − 1) β_2.

Now, equipped with these two recurrences, given a stream of samples x_1, x_2, ..., x_t, ..., we compute ∆̃_t for t ≥ 2 from x_t, {u_k}_{k=1}^{t−1}, {v_{k−2,k−1}}_{k=1}^{t−1}, and ∆̃_{t−1} using the following two steps:

1. Compute u_t and v_{t−1,t} using the second recurrence.
2. Compute ∆̃_t from u_t, v_{t−1,t}, and ∆̃_{t−1} using the first recurrence.

For each new sample in the sequence, we repeat the two steps to obtain an updated estimate ∆̃_t, until all samples have been processed. Note that the first step requires O(t) vector-vector multiplications, i.e., O(td) compute and O(d) memory, and the second step an O(1) number of vector-vector multiplications. As a result, the computational complexity of estimating ∆ is O(ℓ²d), and the storage needed for the dynamic-programming state, represented by the tuple ({u_k}_{k=1}^{t−1}, {v_{k−2,k−1}}_{k=1}^{t−1}, ∆̃_{t−1}), is O(ℓd).

The any-time property of the resulting algorithm.
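The two recurrences (Eqs. 13 and 15) can be checked numerically. The following snippet (our own sanity check) updates the mean and covariance incrementally and compares them against the batch formulas:

```python
import numpy as np

# Sanity check for the running-mean and covariance recurrences (Eqs. 13, 15):
# the recursively updated mean/covariance match the batch formulas at t = ell.
rng = np.random.default_rng(1)
d, ell = 3, 8
xs = [rng.normal(size=d) for _ in range(ell)]

mean, S = xs[0].copy(), np.zeros((d, d))       # t = 1: sample covariance is 0
for t in range(2, ell + 1):
    u = xs[t - 1] - mean                       # u_t = x_t - mean_{t-1}
    S = (t - 2) / (t - 1) * S + np.outer(u, u) / t   # Eq. 15
    mean += u / t                              # Eq. 13

X = np.array(xs)
print(np.allclose(mean, X.mean(axis=0)))       # True
print(np.allclose(S, np.cov(X, rowvar=False))) # True (1/(ell-1) normalization)
```

Note that `np.cov` with its default normalization uses the same 1/(ℓ − 1) factor as Ŝ_ℓ above, so the comparison is exact up to floating-point error.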
Interestingly, the above algorithm is online as well as any-time in the following sense: as we keep sampling more from the posterior, the estimate of ∆ keeps improving, but if stopped at any time, the algorithm still produces the best possible estimate under the given time constraint. If the posterior sampler is stopped during the burn-in phase or after having produced only 1 posterior sample, the returned delta will be identical to FEDAVG's. By spending more compute on the clients (and a bit of extra memory), with each additional posterior sample x_t, we have ∆̃_t → Σ⁻¹(x_0 − μ) as t → ∞.

Optimal selection of ρ. Note that to be able to run the algorithm described above in an online fashion, we have to select and commit to a ρ before seeing any samples. Alternatively, if the online and any-time properties of the algorithm are unnecessary, we can first obtain ℓ posterior samples {x_k}_{k=1}^ℓ, then infer a near-optimal ρ from these samples (e.g., using the Rao-Blackwellized version of the LW estimator (RBLW) or the oracle approximating shrinkage (OAS), both proposed and analyzed by Chen et al. (2010)), and then use the inferred ρ to compute the corresponding delta using our dynamic programming algorithm.
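For the non-online variant, scikit-learn ships ready-made Ledoit-Wolf and OAS estimators that infer a shrinkage coefficient from the samples themselves. One caveat: scikit-learn shrinks toward the scaled identity (tr(Ŝ)/d)·I rather than I, so its coefficient is related to, but not literally, the ρ̄ of Theorem 3:

```python
import numpy as np
from sklearn.covariance import LedoitWolf, OAS

# Read off the shrinkage inferred for a batch of (synthetic) posterior samples.
rng = np.random.default_rng(0)
samples = rng.normal(size=(50, 20))            # ell = 50 samples, d = 20

lw = LedoitWolf().fit(samples)
oas = OAS().fit(samples)
print(lw.shrinkage_, oas.shrinkage_)           # both lie in [0, 1]
```

Either coefficient can then be mapped to the ρ used by the dynamic program (or simply used to build Σ̂ directly when the memory budget allows it).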

D DETAILS ON THE EXPERIMENTAL SETUP

In this part, we provide additional details on our experimental setup, including a more detailed description of the datasets and tasks, models, methods, and hyperparameters.

D.1 DATASETS, TASKS, AND MODELS

Statistics of the datasets used in our empirical study can be found in Table 2. All the datasets and tasks considered in our study are a subset of the tasks introduced by Reddi et al. (2020).

EMNIST-62. The dataset is comprised of 28 × 28 images of handwritten digits and lower- and upper-case English characters (62 classes in total). The federated version of the dataset was introduced by Caldas et al. (2018) and is partitioned by the author of each character. The heterogeneity of the dataset stems from the distinct writing style of each author. We use this dataset for the character recognition task, termed EMNIST CR in Reddi et al. (2020), and the same model architecture: a 2-layer convolutional network with 3 × 3 kernels, max pooling, and dropout, followed by a 128-unit fully connected layer. The model was adopted from the TensorFlow Federated library: https://bit.ly/3l41LKv.

CIFAR-100.

The federated version of CIFAR-100 was introduced by Reddi et al. (2020). The training set is partitioned among 500 clients, with 100 data points per client. The partitioning was created using a two-step latent Dirichlet allocation (LDA) over the "coarse" and "fine" labels, which created a label distribution resembling a more realistic federated setting. For the model, also following Reddi et al. (2020), we used a modified ResNet-18 with group normalization layers instead of batch normalization, as suggested by Hsieh et al. (2019). The model was adopted from the TensorFlow Federated library: https://bit.ly/33jMv6g.

Table 4 (selected optimizers for each task):

                     EMNIST-62       CIFAR-100       StackOverflow NWP   StackOverflow LR
SERVEROPT            SGD (m = 0.9)   SGD (m = 0.9)   Adam (τ = 10⁻³)     Adagrad (τ = 10⁻⁵)
CLIENTOPT            SGD (m = 0.9)   SGD (m = 0.9)   SGD (m = 0.0)       SGD (m = 0.9)
# clients p/round    100             20              10                  10

StackOverflow. The dataset consists of text (questions and answers) asked and answered by a total of 342,477 unique users, collected from https://stackoverflow.com. The federated version of the dataset partitions it into clients by user. In addition, questions and answers in the dataset have associated metadata, which includes tags. We consider two tasks introduced by Reddi et al. (2020): the next word prediction (NWP) task and the tag prediction task via multi-label logistic regression (LR). The vocabulary of the dataset is restricted to the 10,000 most frequently used words for each task (i.e., the NWP task becomes a multi-class classification problem with 10,000 classes). The tags are similarly restricted to the 500 most frequent ones (i.e., the LR task becomes a multi-label classification problem with 500 labels). For tag prediction, we use a simple logistic regression model where each question or answer is represented by a normalized bag-of-words vector. The model was adopted from the TensorFlow Federated library: https://bit.ly/2EXjAeY.
For the NWP task, we restrict each client to the first 128 sentences in their dataset, perform padding and truncation to ensure that sentences have 20 words, and then represent each sentence as a sequence of indices corresponding to the 10,000 most frequently used words, as well as indices representing padding, out-of-vocabulary (OOV) words, beginning of sentence (BOS), and end of sentence (EOS). We note that the accuracy of next word prediction is measured only on the content words and not on the OOV, BOS, and EOS symbols. We use an RNN model with 96-dimensional word embeddings (trained from scratch) and a 670-dimensional LSTM layer, followed by a fully connected output softmax layer. The model was adopted from the TensorFlow Federated library: https://bit.ly/2SoSi3X.
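The padding/truncation step can be sketched as follows (a hypothetical helper with made-up special-token ids, not the actual TFF preprocessing; note that truncation here may drop the EOS token):

```python
# Special-token ids are illustrative, not the ones used by the TFF pipeline.
PAD, OOV, BOS, EOS = 0, 1, 2, 3

def encode(sentence, vocab, max_len=20):
    """vocab maps the 10,000 most frequent words to ids >= 4."""
    ids = [BOS] + [vocab.get(w, OOV) for w in sentence.split()] + [EOS]
    ids = ids[:max_len]                        # truncate long sentences
    return ids + [PAD] * (max_len - len(ids))  # pad short ones

vocab = {"federated": 4, "learning": 5, "rocks": 6}
print(encode("federated learning is fun", vocab))
# [2, 4, 5, 1, 1, 3, 0, 0, ...]: BOS, two in-vocab ids, two OOVs, EOS, padding
```

A masked accuracy then scores predictions only at positions whose target id is ≥ 4, mirroring the "content words only" evaluation described above.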

D.2 METHODS

As mentioned in the main text, we used FEDAVG with adaptive server optimizers and 1 or multiple local epochs per client as our baselines. For each task, we selected the best server optimizer based on the results reported by Reddi et al. (2020), given in Table 4. We emphasize that, even though we refer to all our baseline methods as FEDAVG, the names of the methods as given by Reddi et al. (2020) would be FEDAVGM for EMNIST-62 and CIFAR-100, FEDADAM for StackOverflow NWP, and FEDADAGRAD for StackOverflow LR. Another difference between our baselines and Reddi et al. (2020) is that we ran SGD with momentum on the clients for EMNIST-62, CIFAR-100, and StackOverflow LR, as that improved the performance of the methods with multiple epochs per client. Our FEDPA methods used the same configurations as the FEDAVG baselines; moreover, FEDPA and FEDAVG were algorithmically identical during the burn-in phase and differed only in the client-side computation during the sampling phase of FEDPA.

D.3 HYPERPARAMETERS AND GRIDS

All hyperparameter grids are given in Table 5 . The best server and client learning rates were selected based on the FEDAVG performance and used for FEDPA. The best selected hyperparameters are given in Table 6 .

E ADDITIONAL EXPERIMENTAL RESULTS

We provide additional experimental results. As mentioned in the main text, the results presented in Table 3 were selected to highlight the differences between the methods with respect to two metrics of interest: (i) the number of rounds until the desired performance, and (ii) the performance achievable within a fixed number of rounds. A much fuller picture is given by the learning curves of each method. Therefore, we plot evaluation losses, accuracies, and metrics of interest over the course of training. On the plots, individual values at each round are indicated with ×-markers, and the 10-round running averages with lines of the corresponding color.

EMNIST-62. Learning curves for FEDAVG and FEDPA on EMNIST-62 are given in Fig. 5. Fig. 5a shows the best FEDAVG-1E, FEDAVG-5E, and FEDPA-5E models, and Fig. 5b shows the best FEDAVG-20E and FEDPA-20E. Apart from the fact that the multi-epoch versions converge significantly faster than the 1-epoch FEDAVG-1E, note that the effect of bias reduction when switching from burn-in to sampling in FEDPA becomes much more pronounced in the 20-epoch version.

CIFAR-100 and StackOverflow. Learning curves for various models on the CIFAR-100 and StackOverflow tasks are presented in Figs. 6 and 7. The takeaways for CIFAR-100 and StackOverflow NWP are essentially the same as for EMNIST-62: much faster convergence with an increased number of local epochs, and a visually noticeable improvement in losses and accuracies due to the sampling-based bias correction in client deltas after the burn-in phase is over. Interestingly, we see that on the StackOverflow LR task, FEDAVG-1E clearly dominates the multi-epoch methods in terms of the loss and recall at 5, while losing in precision and macro-F1. Even more puzzling is the significant drop in the average precision of FEDPA-ME after switching to sampling, alongside a simultaneous jump in the recall and F1 metrics.
This indicates that the global model moves to a different fixed point where it over-predicts positive labels (i.e., is less precise) but is also less likely to miss rare labels (i.e., attains higher recall on rare labels and, as a result, a jump in macro-F1). The reason why this happens, however, is unclear.



Footnotes:
• Note that from the optimization point of view, the global optimum generally cannot be represented as any weighted combination of the local optima, even in simple 2D settings (see Fig. 1, left).
• Gaussian posteriors can be further generalized to the exponential family, for which closed-form expressions can be obtained under appropriate priors (Wainwright & Jordan, 2008). We leave this extension to future work.
• While in this work we use a variant of SG-MCMC, other techniques such as HMC (Neal et al., 2011) or NUTS (Hoffman & Gelman, 2014) can be used, too. We leave the analysis of alternative approaches to future work.
• Reddi et al. (2020) referred to federated averaging with adaptive server optimizers as FEDADAM, FEDYOGI, etc. Instead, we select the best optimizer for each task and refer to the corresponding method simply as FEDAVG.
• Centralized optimization of the CNN model on EMNIST-62 attains an evaluation accuracy of 88%.
• If SGD cannot reach the vicinity of the clients' local optima within the specified number of local steps or epochs, the estimated local means and covariances based on the SGD iterates can be arbitrarily poor.
• The number of burn-in rounds is a hyperparameter and was selected for each task to maximize performance; see more details in Appendix D. We note that running burn-in for a fixed number of rounds before switching to sampling was a design choice; other, more adaptive strategies for determining when to switch from burn-in to sampling are certainly possible (e.g., using local loss values to determine when to start sampling). We leave such alternatives to future work.
• https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_regression.html
• One could also use the posterior samples to estimate the best possible ρ that balances the bias-variance tradeoff (e.g., Chen et al., 2010) and avoids sharp increases in the variance.
Note that in our experiments in Section 5, instead of using B local steps for burn-in at each round, we used several initial rounds as burn-in-only rounds, running FEDPA in the FEDAVG regime.



Figure 1: An illustration of federated learning in a toy 2D setting with two clients and quadratic objectives. Left: Contour plots of the client objectives, their local optima, as well as the corresponding global optimum. Middle: Learning curves for MB-SGD and FEDAVG with 10 and 100 steps per round. FEDAVG makes fast progress initially, but converges to a point far away from the global optimum. Right: Learning curves for FEDPA with 10 and 100 posterior samples per round and shrinkage ρ = 1. More posterior samples (i.e., more local computation) results in faster convergence and allows FEDPA to come closer to the global optimum. Shaded regions denote bootstrapped 95% CI based on 5 runs with different initializations and random seeds. Best viewed in color.

Figure 2: Evaluation metrics for FEDAVG and FEDPA computed at each training round on (a) CIFAR-100 and (b) StackOverflow LR. During the initial rounds (the "burn-in phase"), FEDPA computes deltas the same way as FEDAVG; after that, FEDPA computes deltas using Algorithm 3 and approximate posterior samples.

A single sample per epoch is optimal (Mandt et al., 2017). Thus, FEDPA-ME uses M samples to estimate client deltas and has the same local and global computational complexity as FEDAVG-ME, but with two extra hyperparameters: the number of burn-in rounds and the shrinkage coefficient ρ from Theorem 3. As in Reddi et al. (2020), we use the following model architectures for each task: CNN for EMNIST-62, ResNet-18 for CIFAR-100, LSTM for StackOverflow NWP, and multi-label logistic regression on bag-of-words vectors for StackOverflow LR (for details, see Appendix D).

Hyperparameters. For hyperparameter tuning, we first ran small grid searches for FEDAVG-ME using the best server optimizer and corresponding learning rate grids from Reddi et al. (2020). Then, we used the best FEDAVG-ME configuration and did a small grid search to tune the additional hyperparameters of FEDPA-ME, which turned out not to be very sensitive (i.e., many configurations provided results superior to FEDAVG). More hyperparameter details can be found in Appendix D.

Metrics. Since both the speed of learning and the final performance are important in federated learning, we measure: (i) the number of rounds it takes an algorithm to attain a desired level of an evaluation metric, and (ii) the best performance attained within a specified number of rounds. For EMNIST-62, we measure the number of rounds it takes different methods to achieve 84% and 86% evaluation accuracy, and the best validation accuracy attained within 500 and 1500 rounds. For CIFAR-100, we use the same metrics but with 30% and 40% as evaluation accuracy cutoffs and 1000 and 1500 as round number cutoffs.
Finally, for StackOverflow, we measure the number of rounds it takes to reach the best performance, as well as the evaluation accuracy (for the NWP task) and the precision, recall at 5, and macro- and micro-F1 (for the LR task) attained by round 1500. We note that the total number of rounds was selected based on computational considerations (to ensure reproducibility within a reasonable computational budget), and the intermediate cutoffs were selected qualitatively to highlight performance points of interest. In addition, we provide plots of the evaluation loss and other metrics for all methods over the course of training, which show a much fuller picture of the behavior of the algorithms (most of the plots are given in Appendix E).

Implementation and reproducibility. All our experiments on the benchmark tasks were conducted in simulation using TensorFlow Federated (TFF, Ingerman & Ostrowski, 2019). Synthetic experiments were conducted using JAX (Bradbury et al., 2018). The JAX implementation of the algorithms is available at https://github.com/alshedivat/fedpa. The TFF implementation will be released through https://github.com/google-research/federated.

Figure 3 panels: (a) FEDAVG bias and variance as functions of the number of local steps. (b) FEDPA bias and variance as functions of the number of local steps (burn-in steps not included); for dimensionality 10, 100, and 1000, the shrinkage ρ was fixed to 0.01, 0.005, and 0.001, respectively.

Figure 3: The bias and variance tradeoffs for FEDAVG and FEDPA as functions of the estimation parameters.

Figure 4: The ESS statistics for samples produced by IASG on random synthetic least squares linear regression problems of dimensionality 10, 100, and 1000. Total number of data points per problem: 500; batch size: 10. In (a) and (b), the learning rate was set to 0.1 for 10 and 100 dimensions, and 0.01 for 1000 dimensions.

Computational complexity of the client updates for methods that use 5 local epochs measured in milliseconds (% denotes relative increase).

Statistics on the data and tasks. The number of examples per client are given with one standard deviation across the corresponding set of clients (denoted with ±). See description of the tasks in the text.

Comparison of FEDPA with baselines. All metrics were computed on the evaluation sets and averaged over the last 100 rounds before the round limit was reached. The "number of rounds to accuracy" was determined based on the 10-round running average crossing the threshold for the first time. The arrows indicate whether higher (↑) or lower (↓) is better. The best performance in each column is denoted in bold.

Selected optimizers for each task. For SGD, m denotes momentum. For Adam, β1 = 0.9, β2 = 0.99.

Hyperparameter grids for each task.

The best selected hyperparameters for each task.

ACKNOWLEDGMENTS

The authors would like to thank Zachary Charles for the invaluable feedback that influenced the design of the methods and experiments, and Brendan McMahan, Zachary Garrett, Sean Augenstein, Jakub Konečný, Daniel Ramage, Sanjiv Kumar, Sashank Reddi, Jean-François Kagy for many insightful discussions, and Willie Neiswanger for helpful comments on the early drafts.


where we set β_2 ≡ ρ ∈ [0, +∞) to be a constant. Thus, if we define Σ̃_t := I + ρ(t − 1)Ŝ_t, then the following recurrent relationships hold: Σ̃_1 = I and Σ̃_t = Σ̃_{t−1} + (ρ(t − 1)/t)(x_t − x̄_{t−1})(x_t − x̄_{t−1})ᵀ. Finally, we can obtain a shrinkage estimator of the covariance from Σ̃_ℓ by normalizing the coefficients: Σ̂_ℓ := ρ̄_ℓ Σ̃_ℓ, where ρ̄_ℓ := 1/(1 + (ℓ − 1)ρ). Note that Σ̂_1 ≡ I and Σ̂_t → Ŝ_t as t → ∞.

C.2 COMPUTING DELTAS USING SHERMAN-MORRISON AND DYNAMIC PROGRAMMING

Since Σ̂_ℓ is proportional to Σ̃_ℓ and the latter satisfies the recurrent rank-1 updates given in Eq. 18, denoting u_t := x_t − x̄_{t−1}, we can express Σ̂_ℓ⁻¹ = Σ̃_ℓ⁻¹ / ρ̄_ℓ using the Sherman-Morrison formula:

Σ̃_t⁻¹ = Σ̃_{t−1}⁻¹ − γ_t (Σ̃_{t−1}⁻¹ u_t)(Σ̃_{t−1}⁻¹ u_t)ᵀ / (1 + γ_t u_tᵀ Σ̃_{t−1}⁻¹ u_t).

Note that we would like to estimate ∆_ℓ := Σ̂_ℓ⁻¹(x_0 − x̄_ℓ), which can be done without computing or storing any matrices if we know v_{ℓ−1,ℓ} := Σ̃_{ℓ−1}⁻¹ u_ℓ and Σ̃_{ℓ−1}⁻¹(x_0 − x̄_{ℓ−1}). Denoting ∆̃_t := Σ̃_t⁻¹(x_0 − x̄_t), and knowing that x_0 − x̄_t = (x_0 − x̄_{t−1}) − u_t/t (which follows from Eq. 13), we can compute ∆_ℓ using the following recurrence:

∆̃_t = ∆̃_{t−1} − v_{t−1,t}/t − γ_t (u_tᵀ ∆̃_{t−1} − u_tᵀ v_{t−1,t}/t) / (1 + γ_t u_tᵀ v_{t−1,t}) · v_{t−1,t},
∆_t = ∆̃_t / ρ̄_t. // final step for ∆_t (24)

Remember that our goal is to avoid storing d × d matrices throughout the computation. In the above recursive equations, all expressions depend only on vector-vector products except the one for v_{t−1,t}, which needs a matrix-vector product. To express the latter in terms of vector-vector products, we need another two-index recurrence on v_{i,j} := Σ̃_i⁻¹ u_j:

v_{i,j} = v_{i−1,j} − γ_i (u_iᵀ v_{i−1,j}) / (1 + γ_i u_iᵀ v_{i−1,i}) · v_{i−1,i}.

Figure 5(a): EMNIST-62: Evaluation loss and accuracy for FEDAVG-1E, FEDAVG-5E, and FEDPA-5E.
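The full dynamic program can be sketched as follows. This is a minimal NumPy implementation of the recurrences (our own function and variable names, not the paper's released code), verified against the direct dense computation:

```python
import numpy as np

def shrinkage_delta(x0, samples, rho):
    """Streaming computation of Delta = Sigma_hat^{-1} (x0 - sample_mean), where
    Sigma_hat = rho_bar*I + (1 - rho_bar)*S_hat and rho_bar = 1/(1 + (ell-1)*rho),
    via the rank-1 recurrences and Sherman-Morrison: O(ell^2 d) time, O(ell d) memory."""
    ell = len(samples)
    us, ws, gammas = [], [], []        # u_t, w_t = tilde(Sigma)_{t-1}^{-1} u_t, gamma_t
    mean = samples[0].copy()
    for t in range(2, ell + 1):
        u_t = samples[t - 1] - mean    # u_t = x_t - mean_{t-1}
        v = u_t.copy()                 # apply tilde(Sigma)_{t-1}^{-1} via stored rank-1 terms
        for u_k, w_k, g_k in zip(us, ws, gammas):
            v -= (g_k * (u_k @ v) / (1.0 + g_k * (u_k @ w_k))) * w_k
        us.append(u_t); ws.append(v); gammas.append(rho * (t - 1) / t)
        mean += u_t / t                # running-mean recurrence (Eq. 13)
    z = x0 - mean                      # apply tilde(Sigma)_ell^{-1} to (x0 - mean)
    for u_k, w_k, g_k in zip(us, ws, gammas):
        z -= (g_k * (u_k @ z) / (1.0 + g_k * (u_k @ w_k))) * w_k
    return z * (1.0 + (ell - 1) * rho)   # divide by rho_bar

# Check against the direct dense computation on a small problem.
rng = np.random.default_rng(2)
d, ell, rho = 4, 7, 0.3
samples = [rng.normal(size=d) for _ in range(ell)]
x0 = rng.normal(size=d)
delta = shrinkage_delta(x0, samples, rho)
```

Only vectors (u_k, w_k) and scalars are ever stored, so the memory footprint is O(ℓd), matching Theorem 3; the nested loop over stored rank-1 terms accounts for the O(ℓ²d) compute.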

