FEDERATED AVERAGING AS EXPECTATION-MAXIMIZATION

Abstract

Federated averaging (FedAvg), despite its simplicity, has been the main approach for training neural networks in the federated learning setting. In this work, we show that the algorithmic choices of the FedAvg algorithm correspond to optimizing a single objective function, involving the global model and all of the shard-specific models, with a hard version of the well-known Expectation-Maximization (EM) algorithm. As a result, we gain a better understanding of the behavior and design choices of federated averaging, while also drawing interesting connections to recent literature. Based on this view, we further propose FedSparse, a version of federated averaging that employs prior distributions to promote model sparsity. In this way, we obtain a procedure that reduces both server-to-client and client-to-server communication costs, as well as yielding more efficient models.

1. INTRODUCTION

Smart devices have become ubiquitous in today's world and generate large amounts of potentially sensitive data. Traditionally, such data is transmitted to and stored in a central location for training machine learning models. This rightly raises privacy concerns, and we therefore seek means of training powerful models, such as neural networks, without the need to transmit potentially sensitive data. To this end, Federated Learning (FL) (McMahan et al., 2016) has been proposed to train global machine learning models without participating devices having to transmit their data to the server. The Federated Averaging (FedAvg) algorithm (McMahan et al., 2016) communicates the parameters of the machine learning model instead of the data itself, which is a more private means of communication. FedAvg was originally proposed through empirical observations. While it can be shown to converge (Li et al., 2019), its theoretical underpinnings, in terms of the model assumptions as well as the underlying objective function, remain poorly understood. The first contribution of this work improves our understanding of FedAvg: we show that FedAvg can be derived by applying the general Expectation-Maximization (EM) framework to a simple hierarchical model. This novel view has several interesting consequences: it sheds light on the algorithmic choices of FedAvg, bridges FedAvg with meta-learning, connects several extensions of FedAvg, and provides fruitful ground for future extensions.

Apart from theoretical grounding, the FL scenario poses several practical challenges, especially in the "cross-device" setting (Kairouz et al., 2019) that we consider in this work. In particular, communicating model updates over multiple rounds across a large number of devices can incur significant communication costs; communication via the public internet infrastructure and mobile networks is potentially slow and not free. Equally important, training (and inference) takes place on-device and is therefore restricted by the edge devices' hardware constraints on memory, speed, and heat dissipation. Jointly addressing both of these issues is therefore an important step towards building practical FL systems, as also discussed in Kairouz et al. (2019). Through the novel EM view of FedAvg that we introduce, we develop our second contribution, FedSparse. FedSparse allows for learning sparse models at the client and server via a careful choice of priors within the hierarchical model. As a result, it tackles both of the aforementioned challenges, since it can simultaneously reduce the overall communication and the computation at the client devices. Empirically, FedSparse provides better communication-accuracy trade-offs than both FedAvg and methods proposed for similar reasons (Caldas et al., 2018).

2. FEDAVG THROUGH THE LENS OF EM

The FedAvg algorithm is a simple iterative procedure consisting of four steps. At the beginning of each round $t$, the server communicates the model parameters, denoted $w$, to a subset of the devices. The devices then optimize $w$, e.g. via stochastic gradient descent, on their respective datasets according to a given loss function $\mathcal{L}_s(\mathcal{D}_s, w) := \frac{1}{N_s} \sum_{i=1}^{N_s} L(\mathcal{D}_{si}, w)$, where $s$ indexes the device, $\mathcal{D}_s$ corresponds to the dataset at device $s$, and $N_s$ to its size. After a specified number of epochs of optimization on $\mathcal{L}_s$, denoted $E$, the devices communicate the current state of their parameters, denoted $\phi_s$, to the server. The server then updates its own model by simply averaging the client-specific parameters, $w_t = \frac{1}{S} \sum_{s=1}^{S} \phi_s$.
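The four steps above can be sketched in a few lines of NumPy. The quadratic per-client loss, the data, and the hyperparameter values below are illustrative stand-ins chosen so the example is self-contained, not part of the original algorithm specification:

```python
import numpy as np

rng = np.random.default_rng(0)
num_clients = 5   # S: participating devices this round
num_epochs = 3    # E: local epochs per round
lr = 0.1          # local SGD step size

# Each client s holds its own data; here the local loss is the toy quadratic
# L_s(w) = mean_i (w - d_si)^2, whose gradient is 2 * (w - mean(D_s)).
client_data = [rng.normal(loc=s, scale=0.5, size=20) for s in range(num_clients)]

def local_update(w, data, epochs, lr):
    """Run E epochs of gradient descent on the client's loss, starting from w."""
    phi = w
    for _ in range(epochs):
        grad = 2.0 * (phi - data.mean())
        phi = phi - lr * grad
    return phi

w = 0.0  # server model (a scalar here, for clarity)
for round_t in range(10):
    # 1) server broadcasts w; 2) each selected client optimizes locally
    phis = [local_update(w, d, num_epochs, lr) for d in client_data]
    # 3) clients send phi_s back; 4) server averages the client parameters
    w = sum(phis) / num_clients

# For this toy loss, w converges to the mean of the client-level minimizers.
print(w)
```

Note that each round contracts the distance between $w$ and the average of the clients' loss minimizers, which is why only a handful of rounds suffice in this toy setting.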

2.1. THE CONNECTION TO EM

We now ask the following question: does the overall algorithm correspond to a specific optimization procedure on a given objective function? Let us consider the following objective:

$$\arg\max_{w} \frac{1}{S} \sum_{s=1}^{S} \log p(\mathcal{D}_s | w),$$

where $\mathcal{D}_s$ corresponds to the shard-specific dataset with $N_s$ datapoints and $p(\mathcal{D}_s | w)$ corresponds to the likelihood of $\mathcal{D}_s$ under the server parameters $w$. Now consider decomposing each of the shard-specific likelihoods as

$$p(\mathcal{D}_s | w) = \int p(\mathcal{D}_s | \phi_s)\, p(\phi_s | w)\, d\phi_s, \qquad p(\phi_s | w) \propto \exp\left(-\frac{\lambda}{2} \|\phi_s - w\|^2\right),$$

where we introduced the auxiliary latent variables $\phi_s$, the parameters of the local model at shard $s$. The server parameters $w$ act as "hyperparameters" of the prior over the shard-specific parameters, and $\lambda$ acts as a regularization strength that prevents $\phi_s$ from moving too far from $w$. How can we then optimize the resulting objective in the presence of these latent variables $\phi_s$? The traditional way to optimize such objectives is through Expectation-Maximization (EM). EM consists of two steps: the E-step, where we form the posterior distribution over the latent variables,

$$p(\phi_s | \mathcal{D}_s, w) = \frac{p(\mathcal{D}_s | \phi_s)\, p(\phi_s | w)}{p(\mathcal{D}_s | w)},$$

and the M-step, where we maximize the expected complete-data log-likelihood with respect to $w$. If we perform a single gradient step for $w$ in the M-step, this procedure corresponds to doing gradient ascent on the original objective, a fact we show in Appendix D. When posterior inference is intractable, hard-EM is sometimes employed. In this case we make a "hard" assignment for the latent variables $\phi_s$ in the E-step, approximating $p(\phi_s | \mathcal{D}_s, w)$ by its most probable point, i.e.,

$$\phi^*_s = \arg\max_{\phi_s} \frac{p(\mathcal{D}_s | \phi_s)\, p(\phi_s | w)}{p(\mathcal{D}_s | w)} = \arg\max_{\phi_s} \left[\log p(\mathcal{D}_s | \phi_s) + \log p(\phi_s | w)\right].$$

This is usually easier to do, as we can use techniques such as stochastic gradient ascent. Given these hard assignments, the M-step then corresponds to another simple maximization,

$$w^* = \arg\max_{w} \frac{1}{S} \sum_{s=1}^{S} \log p(\phi^*_s | w),$$

which, for the Gaussian prior above, is solved in closed form by the average $w^* = \frac{1}{S} \sum_{s=1}^{S} \phi^*_s$, i.e., precisely the FedAvg server update.

