FEDERATED LEARNING VIA POSTERIOR AVERAGING: A NEW PERSPECTIVE AND PRACTICAL ALGORITHMS

Abstract

Federated learning is typically approached as an optimization problem, where the goal is to minimize a global loss function by distributing computation across client devices that possess local data and specify different parts of the global objective. We present an alternative perspective and formulate federated learning as a posterior inference problem, where the goal is to infer a global posterior distribution by having client devices each infer the posterior of their local data. While exact inference is often intractable, this perspective provides a principled way to search for global optima in federated settings. Further, starting with the analysis of federated quadratic objectives, we develop a computation- and communication-efficient approximate posterior inference algorithm, federated posterior averaging (FEDPA). Our algorithm uses MCMC for approximate inference of local posteriors on the clients and efficiently communicates their statistics to the server, which uses them to refine a global estimate of the posterior mode. Finally, we show that FEDPA generalizes federated averaging (FEDAVG), can similarly benefit from adaptive optimizers, and yields state-of-the-art results on four realistic and challenging benchmarks, converging faster to better optima.

1. INTRODUCTION

Federated learning (FL) is a framework for learning statistical models from heterogeneous data scattered across multiple entities (or clients) under the coordination of a central server that has no direct access to the local data (Kairouz et al., 2019). To learn models without any data transfer, clients must process their own data locally and only infrequently communicate model updates to the server, which aggregates these updates into a global model (McMahan et al., 2017). While this paradigm enables efficient distributed learning from data stored on millions of remote devices (Hard et al., 2018), it comes with many challenges (Li et al., 2020), with the communication cost often being the critical bottleneck and the heterogeneity of client data affecting convergence.

Canonically, FL is formulated as a distributed optimization problem with a few distinctive properties, such as unbalanced and non-i.i.d. data distribution across the clients and limited communication. The de facto standard algorithm for solving federated optimization is federated averaging (FEDAVG, McMahan et al., 2017), which proceeds in rounds of communication between the server and a random subset of clients, synchronously updating the server model after each round (Bonawitz et al., 2019). By allowing the clients to perform multiple local SGD steps (or epochs) at each round, FEDAVG can reduce the required communication by orders of magnitude compared to mini-batch (MB) SGD. However, due to the heterogeneity of the client data, more local computation often leads to biased client updates and makes FEDAVG stagnate at inferior optima. As a result, while slow during initial training, MB-SGD ends up dominating FEDAVG at convergence (see the example in Fig. 1). This has been observed in multiple empirical studies (e.g., Charles & Konečný, 2020) and was recently shown theoretically (Woodworth et al., 2020a).
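The round structure just described, and the bias it introduces under heterogeneity, can be reproduced in a few lines. Below is a minimal NumPy sketch of FEDAVG on two heterogeneous quadratic client objectives; the objectives, learning rate, round counts, and uniform client weighting are illustrative assumptions for this sketch, not the paper's experimental setup (FEDAVG ordinarily weights clients by data size and samples a subset per round).

```python
import numpy as np

# Toy heterogeneous quadratic client objectives: f_i(x) = 0.5 (x - c_i)^T A_i (x - c_i).
clients = [
    (np.array([[2.0, 0.0], [0.0, 1.0]]), np.array([1.0, 0.0])),  # (A_1, c_1)
    (np.array([[1.0, 0.0], [0.0, 3.0]]), np.array([0.0, 1.0])),  # (A_2, c_2)
]

def local_sgd(x0, A, c, steps, lr=0.05):
    """Run `steps` (full-batch) gradient steps on one client's quadratic."""
    x = x0.copy()
    for _ in range(steps):
        x -= lr * A @ (x - c)  # gradient of 0.5 (x - c)^T A (x - c)
    return x

def fedavg(rounds=50, local_steps=100):
    x = np.zeros(2)  # server model
    for _ in range(rounds):
        # Each client starts from the current server model and trains locally.
        deltas = [local_sgd(x, A, c, local_steps) - x for A, c in clients]
        # Server update: average of client deltas (uniform weights, equal data sizes).
        x = x + np.mean(deltas, axis=0)
    return x

x_fedavg = fedavg()

# True global optimum of sum_i f_i: solves (sum_i A_i) x = sum_i A_i c_i.
A_sum = sum(A for A, _ in clients)
b_sum = sum(A @ c for A, c in clients)
x_star = np.linalg.solve(A_sum, b_sum)

# With many local steps, each client nearly reaches its own optimum c_i, so
# FedAvg's fixed point sits near the mean of client optima, (0.5, 0.5),
# rather than at the global optimum x_star = (2/3, 3/4).
print(x_fedavg, x_star)
```

Raising `local_steps` speeds up early progress but pulls the fixed point toward the plain average of client optima, which is exactly the stagnation at inferior optima described above.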
Using stateful clients (Karimireddy et al., 2019; Pathak & Wainwright, 2020) can help to remedy the convergence issues in the cross-silo setting, where relatively few clients are queried repeatedly, but is not practical in the cross-device setting (i.e., when clients are mobile devices) for several reasons (Kairouz et al., 2019; Li et al., 2020; Lim et al., 2020). One key issue is that the number of clients in such a setting is extremely large, and the average client will only ever participate in a single FL round, so the state of a stateful algorithm is never reused. Is it possible to design FL algorithms that exhibit both fast training and consistent convergence with stateless clients?

In this work, we answer this question affirmatively by approaching federated learning not as an optimization problem but rather as a posterior inference problem. We show that the modes of the global posterior over the model parameters correspond to the desired optima of the federated optimization objective and can be inferred by aggregating information about local posteriors. Starting with an analysis of federated quadratics, we introduce a general class of federated posterior inference algorithms that run local posterior inference on the clients and global posterior inference on the server. In contrast with federated optimization, posterior inference can, with stateless clients, benefit from an increased amount of local computation without stagnating at inferior optima (illustrated in Fig. 1). However, a naïve approach to federated posterior inference is practically infeasible because its computation and communication costs are cubic and quadratic, respectively, in the number of model parameters. Apart from the new perspective, our key technical contribution is the design of an efficient algorithm with linear computation and communication costs.

Contributions. The main contributions of this paper can be summarized as follows:

1. We introduce a new perspective on federated learning through the lens of posterior inference, which broadens the design space for FL algorithms beyond purely optimization techniques.
2. With this perspective, we design a computation- and communication-efficient approximate posterior inference algorithm, federated posterior averaging (FEDPA). FEDPA works with stateless clients, and its computational complexity and memory footprint are similar to those of FEDAVG.
3. We show that FEDAVG with many local steps is in fact a special case of FEDPA that estimates local posterior covariances with identity matrices. These biased estimates are the source of inconsistent updates and explain why FEDAVG converges suboptimally even in simple quadratic settings.
4. Finally, we compare FEDPA with strong baselines on realistic FL benchmarks introduced by Reddi et al. (2020) and achieve state-of-the-art results with respect to multiple metrics of interest.

Figure 1: An illustration of federated learning in a toy 2D setting with two clients and quadratic objectives. Left: contour plots of the client objectives, their local optima, as well as the corresponding global optimum. Middle: learning curves for MB-SGD and FEDAVG with 10 and 100 steps per round. FEDAVG makes fast progress initially but converges to a point far away from the global optimum. Right: learning curves for FEDPA with 10 and 100 posterior samples per round and shrinkage ρ = 1. More posterior samples (i.e., more local computation) result in faster convergence and allow FEDPA to come closer to the global optimum. Shaded regions denote bootstrapped 95% CIs based on 5 runs with different initializations and random seeds. Best viewed in color.

2. RELATED WORK

Distributed MCMC. Part of our work builds on the idea of sub-posterior aggregation, which was originally proposed for scaling up Markov chain Monte Carlo techniques to large datasets (known as consensus Monte Carlo; Neiswanger et al., 2013; Scott et al., 2016). One of the goals of this paper is to highlight the connection between distributed inference and federated optimization and to develop inference techniques that can be used under FL-specific constraints.

Federated optimization. Starting with the seminal paper by McMahan et al. (2017), a lot of recent effort in federated learning has focused on understanding FEDAVG (also known as local SGD) as an optimization algorithm. Multiple works have provided upper bounds on the convergence rate of FEDAVG in the homogeneous i.i.d. setting (Yu et al., 2019; Karimireddy et al., 2019; Woodworth et al., 2020b) as well as explored various non-i.i.d. settings with different notions of heterogeneity (Zhao et al., 2018; Sahu et al., 2018; Hsieh et al., 2019; Li et al., 2019; Wang et al., 2020; Woodworth et al., 2020a). Reddi et al. (2020) reformulated FEDAVG in a way that enabled adaptive optimization and derived corresponding convergence rates, noting that FEDAVG requires careful tuning of learning rate schedules in order to converge to the desired optimum, which was further analyzed by Charles & Konečný (2020). To the best of our knowledge, our work is perhaps the first to connect, reinterpret, and analyze federated optimization from the probabilistic inference perspective.
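The posterior-averaging view from the introduction can be made concrete in the quadratic (Gaussian) case: when each local loss is quadratic, each local posterior is Gaussian, and the mode of the global posterior (the product of the local posteriors) is a precision-weighted average of the local means, which coincides with the minimizer of the summed objective. The following is a minimal NumPy sketch with illustrative toy means and covariances (assumptions for this sketch, not values from the paper):

```python
import numpy as np

# Local "posteriors" of two clients as Gaussians N(mu_i, Sigma_i); for quadratic
# losses f_i(x) = 0.5 (x - mu_i)^T Sigma_i^{-1} (x - mu_i), Sigma_i^{-1} is the Hessian.
mus = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]
Sigmas = [np.diag([0.5, 1.0]), np.diag([1.0, 1.0 / 3.0])]

# Mode of the global posterior (product of local Gaussians): the
# precision-weighted average of local means.
precisions = [np.linalg.inv(S) for S in Sigmas]
global_precision = sum(precisions)
global_mode = np.linalg.solve(
    global_precision, sum(P @ m for P, m in zip(precisions, mus))
)
print(global_mode)  # ≈ [0.667, 0.75] — the minimizer of sum_i f_i, i.e., the federated optimum

# Replacing each Sigma_i with the identity recovers plain parameter averaging
# (the FedAvg-style update), which is biased under heterogeneity:
naive_average = np.mean(mus, axis=0)
print(naive_average)  # [0.5, 0.5]
```

Note that forming and inverting covariance matrices explicitly, as above, is exactly the cubic-cost naïve approach the introduction rules out; this sketch only illustrates the identity that an efficient algorithm such as FEDPA must approximate without ever materializing the covariances.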

