FEDERATED LEARNING VIA POSTERIOR AVERAGING: A NEW PERSPECTIVE AND PRACTICAL ALGORITHMS

Abstract

Federated learning is typically approached as an optimization problem, where the goal is to minimize a global loss function by distributing computation across client devices that possess local data and specify different parts of the global objective. We present an alternative perspective and formulate federated learning as a posterior inference problem, where the goal is to infer a global posterior distribution by having client devices each infer the posterior of their local data. While exact inference is often intractable, this perspective provides a principled way to search for global optima in federated settings. Further, starting with an analysis of federated quadratic objectives, we develop a computation- and communication-efficient approximate posterior inference algorithm, federated posterior averaging (FEDPA). Our algorithm uses MCMC for approximate inference of local posteriors on the clients and efficiently communicates their statistics to the server, which uses them to refine a global estimate of the posterior mode. Finally, we show that FEDPA generalizes federated averaging (FEDAVG), can similarly benefit from adaptive optimizers, and yields state-of-the-art results on four realistic and challenging benchmarks, converging faster and to better optima.

1. INTRODUCTION

Federated learning (FL) is a framework for learning statistical models from heterogeneous data scattered across multiple entities (or clients) under the coordination of a central server that has no direct access to the local data (Kairouz et al., 2019). To learn models without any data transfer, clients must process their own data locally and only infrequently communicate model updates to the server, which aggregates these updates into a global model (McMahan et al., 2017). While this paradigm enables efficient distributed learning from data stored on millions of remote devices (Hard et al., 2018), it comes with many challenges (Li et al., 2020), with the communication cost often being the critical bottleneck and the heterogeneity of client data affecting convergence.

Canonically, FL is formulated as a distributed optimization problem with a few distinctive properties, such as unbalanced and non-i.i.d. data distribution across the clients and limited communication. The de facto standard algorithm for solving federated optimization is federated averaging (FEDAVG, McMahan et al., 2017), which proceeds in rounds of communication between the server and a random subset of clients, synchronously updating the server model after each round (Bonawitz et al., 2019). By allowing the clients to perform multiple local SGD steps (or epochs) at each round, FEDAVG can reduce the required communication by orders of magnitude compared to mini-batch (MB) SGD. However, due to the heterogeneity of the client data, more local computation often leads to biased client updates and makes FEDAVG stagnate at inferior optima. As a result, while slow during initial training, MB-SGD ends up dominating FEDAVG at convergence (see example in Fig. 1). This has been observed in multiple empirical studies (e.g., Charles & Konečný, 2020), and recently was shown theoretically (Woodworth et al., 2020a).
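The round structure described above can be sketched in a few lines. This is a minimal illustration, not the paper's implementation: the helper names (`client_update`, `w_server`) are hypothetical, and each client's data is modeled by a simple quadratic loss f_i(w) = ½‖w − c_i‖² with heterogeneous optima c_i standing in for non-i.i.d. local datasets.

```python
import numpy as np

rng = np.random.default_rng(0)
dim, num_clients, local_steps, lr = 5, 10, 20, 0.1

# Each client i holds a private quadratic loss with optimum c_i;
# heterogeneous c_i's model non-i.i.d. client data.
client_optima = rng.normal(size=(num_clients, dim))

def client_update(w_global, c_i):
    """Run several local SGD steps starting from the current server model."""
    w = w_global.copy()
    for _ in range(local_steps):
        w -= lr * (w - c_i)  # gradient of the local quadratic is (w - c_i)
    return w

w_server = np.zeros(dim)
for _ in range(50):
    # Each round, the server queries a random subset (cohort) of clients...
    cohort = rng.choice(num_clients, size=5, replace=False)
    updates = [client_update(w_server, client_optima[i]) for i in cohort]
    # ...and synchronously aggregates by averaging the returned models.
    w_server = np.mean(updates, axis=0)
```

With many local steps each client's model drifts toward its own optimum c_i, so the averaged server model is pulled toward the mean of the sampled clients' optima rather than the minimizer of the true global objective — the bias under heterogeneity that the paragraph above describes.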
Using stateful clients (Karimireddy et al., 2019; Pathak & Wainwright, 2020) can help remedy the convergence issues in the cross-silo setting, where relatively few clients are queried repeatedly, but is not practical in the cross-device setting (i.e., when clients are mobile devices) for several reasons (Kairouz et al., 2019; Li et al., 2020; Lim et al., 2020). One key issue is that the number of clients in such a setting is extremely large and the average client will only ever participate in a single FL round; thus, the state of a stateful algorithm is never reused.

