FEDERATED AVERAGING AS EXPECTATION-MAXIMIZATION

Abstract

Federated averaging (FedAvg), despite its simplicity, has been the main approach for training neural networks in the federated learning setting. In this work, we show that the algorithmic choices of the FedAvg algorithm correspond to optimizing a single objective function that involves the global and all of the shard-specific models, using a hard version of the well-known Expectation-Maximization (EM) algorithm. As a result, we gain a better understanding of the behavior and design choices of federated averaging while being able to provide interesting connections to recent literature. Based on this view, we further propose FedSparse, a version of federated averaging that employs prior distributions to promote model sparsity. In this way, we obtain a procedure that leads to reductions in both server-client and client-server communication costs, as well as more efficient models.

1. INTRODUCTION

Smart devices have become ubiquitous in today's world and are generating large amounts of potentially sensitive data. Traditionally, such data is transmitted and stored in a central location for training machine learning models. Such methods rightly raise privacy concerns, and we seek the means for training powerful models, such as neural networks, without the need to transmit potentially sensitive data. To this end, Federated Learning (FL) (McMahan et al., 2016) has been proposed to train global machine learning models without the need for participating devices to transmit their data to the server. The Federated Averaging (FedAvg) (McMahan et al., 2016) algorithm communicates the parameters of the machine learning model instead of the data itself, which is a more private means of communication. The FedAvg algorithm was originally proposed through empirical observations. While it can be shown that it converges (Li et al., 2019), the model assumptions and the underlying objective function it corresponds to are still not well understood. The first contribution of this work improves our understanding of FedAvg; we show that FedAvg can be derived by applying the general Expectation-Maximization (EM) framework to a simple hierarchical model. This novel view has several interesting consequences: it sheds light on the algorithmic choices of FedAvg, bridges FedAvg with meta-learning, connects several extensions of FedAvg and provides fruitful ground for future extensions. Apart from theoretical grounding, the FL scenario poses several practical challenges, especially in the "cross-device" setting (Kairouz et al., 2019) that we consider in this work. In particular, communicating model updates over multiple rounds across a large number of devices can incur significant communication costs. Communication via the public internet infrastructure and mobile networks is potentially slow and not free.
Equally important, training (and inference) takes place on-device and is therefore restricted by the edge-devices' hardware constraints on memory, speed and heat dissipation capabilities. Therefore, jointly addressing both of these issues is an important step towards building practical FL systems, as also discussed in Kairouz et al. (2019) . Through the novel EM view of FedAvg that we introduce, we develop our second contribution, FedSparse. FedSparse allows for learning sparse models at the client and server via a careful choice of priors within the hierarchical model. As a result, it tackles the aforementioned challenges, since it can simultaneously reduce the overall communication and computation at the client devices. Empirically, FedSparse provides better communication-accuracy trade-offs compared to both FedAvg as well as methods proposed for similar reasons (Caldas et al., 2018) .

2. FEDAVG THROUGH THE LENS OF EM

The FedAvg algorithm is a simple iterative procedure realized in four steps. At the beginning of each round t, the server communicates the model parameters, let them be w, to a subset of the devices. The devices then proceed to optimize w, e.g., via stochastic gradient descent, on their respective datasets via a given loss function L_s(D_s, w) := (1/N_s) Σ_{i=1}^{N_s} L(D_si, w), where s indexes the device, D_s corresponds to the dataset at device s and N_s corresponds to its size. After a specific number of epochs of optimization on L_s is performed, denoted as E, the devices communicate the current state of their parameters, let it be φ_s, to the server. The server then performs an update to its own model by simply averaging the client-specific parameters: w_t = (1/S) Σ_s φ_s.
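The four steps above can be sketched in a few lines of numpy. This is an illustrative toy (a linear model, full-batch local SGD, all shards participating each round); names such as `local_sgd` and `fedavg_round` are our own and not the authors' code.

```python
# Minimal sketch of FedAvg: broadcast w, train locally for E epochs, average.
import numpy as np

def local_sgd(w, X, y, lr=0.1, epochs=5):
    """Run E epochs of full-batch gradient descent on the local squared loss."""
    phi = w.copy()
    for _ in range(epochs):
        grad = 2 * X.T @ (X @ phi - y) / len(y)
        phi -= lr * grad
    return phi

def fedavg_round(w, shards, lr=0.1, epochs=5):
    """One round: each shard returns its phi_s; the server averages them."""
    phis = [local_sgd(w, X, y, lr, epochs) for X, y in shards]
    return np.mean(phis, axis=0)

rng = np.random.default_rng(0)
w_true = np.array([1.0, -2.0])
shards = []
for _ in range(4):  # 4 shards with slightly noisy local data
    X = rng.normal(size=(32, 2))
    shards.append((X, X @ w_true + 0.01 * rng.normal(size=32)))

w = np.zeros(2)
for t in range(50):
    w = fedavg_round(w, shards)
```

After enough rounds, the averaged server model recovers the shared underlying parameters.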

2.1. THE CONNECTION TO EM

We now ask the following question: does the overall algorithm correspond to a specific optimization procedure on a given objective function? Let us consider the objective

arg max_w (1/S) Σ_{s=1}^S log p(D_s|w),

where D_s corresponds to the shard-specific dataset that has N_s datapoints, and p(D_s|w) corresponds to the likelihood of D_s under the server parameters w. Now consider decomposing each of the shard-specific likelihoods as follows:

p(D_s|w) = ∫ p(D_s|φ_s) p(φ_s|w) dφ_s,   p(φ_s|w) ∝ exp(-(λ/2) ||φ_s - w||²),

where we introduced the auxiliary latent variables φ_s, which are the parameters of the local model at shard s. The server parameters w act as "hyperparameters" for the prior over the shard-specific parameters, and λ acts as a regularization strength that prevents φ_s from moving too far from w. How can we then optimize the resulting objective in the presence of these latent variables φ_s? The traditional way to optimize such objectives is through Expectation-Maximization (EM). EM consists of two steps: the E-step, where we form the posterior distribution over these latent variables,

p(φ_s|D_s, w) = p(D_s|φ_s) p(φ_s|w) / p(D_s|w),

and the M-step, where we maximize the probability of D_s w.r.t. the parameters of the model w by marginalizing over this posterior:

arg max_w (1/S) Σ_s E_{p(φ_s|D_s, w_old)}[log p(D_s|φ_s) + log p(φ_s|w)] = arg max_w (1/S) Σ_s E_{p(φ_s|D_s, w_old)}[log p(φ_s|w)].

If we perform a single gradient step for w in the M-step, this procedure corresponds to doing gradient ascent on the original objective, a fact we show in Appendix D. When posterior inference is intractable, hard-EM is sometimes employed. In this case we make a "hard" assignment for the latent variables φ_s in the E-step by approximating p(φ_s|D_s, w) with its most probable point, i.e.,

φ*_s = arg max_{φ_s} [log p(D_s|φ_s) + log p(φ_s|w)],

where we alternate between optimizing φ_{1:S} and w while keeping the other fixed. How is this framework the same as FedAvg? By letting λ → 0 in the prior, it is clear that the hard assignments in the E-step mimic the process of optimizing a local model on the data of each shard. In fact, even by optimizing the model locally with stochastic gradient ascent for a fixed number of iterations with a given learning rate we implicitly assume a specific prior over the parameters; for linear regression, this prior is a Gaussian centered at the initial value of the parameters (Santos, 1996), whereas for non-linear models it bounds the distance from the initial point. After obtaining φ*_s, the M-step then corresponds to

arg max_w L_r := (1/S) Σ_s -(λ/2) ||φ*_s - w||²,

and we can easily find a closed-form solution by setting the derivative of the objective w.r.t. w to zero and solving for w:

∂L_r/∂w = 0 ⇒ (λ/S) Σ_s (φ*_s - w) = 0 ⇒ w = (1/S) Σ_s φ*_s.

It is easy to see that the optimal solution for w given φ*_{1:S} is the same as the one from FedAvg. Of course, FedAvg does not optimize the local parameters φ_s to convergence at each round, so one might wonder whether the correspondence to EM is still valid. It turns out that the alternating procedure of EM corresponds to block coordinate ascent on a single objective function, the variational lower bound of the marginal log-likelihood (Neal & Hinton, 1998) of a given model. More specifically, for our setting we can see that the EM iterations perform block coordinate ascent to optimize

arg max_{w_{1:S}, w} (1/S) Σ_s E_{q_{w_s}(φ_s)}[log p(D_s|φ_s) + log p(φ_s|w)] + H[q_{w_s}(φ_s)],

where w_s are the parameters of the variational approximation to the posterior distribution p(φ_s|D_s, w) and H[q] corresponds to the entropy of the q distribution. To obtain the procedure of FedAvg we can use a (numerically) deterministic distribution for φ_s, q_{w_s}(φ_s) := N(w_s, εI) with ε ≈ 0. This leads us to the same objective as before, since the expectation concentrates on a single point and the entropy of q_{w_s}(φ_s) becomes a constant independent of the optimization.
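The closed-form M-step solution is easy to verify numerically: for the quadratic objective L_r, the plain average of the local parameters maximizes it. A small sanity check (illustrative names only):

```python
# Numeric check of the M-step: w* = mean(phi_s) maximizes
# L_r(w) = -(lambda/2S) * sum_s ||phi_s - w||^2.
import numpy as np

rng = np.random.default_rng(1)
phis = rng.normal(size=(5, 3))   # 5 shards, 3-dimensional parameters
lam = 0.7

def L_r(w):
    return -(lam / 2) * np.mean(np.sum((phis - w) ** 2, axis=1))

w_star = phis.mean(axis=0)       # closed-form M-step solution

# Any perturbation of w_star should strictly decrease the objective.
for _ in range(100):
    w_perturbed = w_star + 0.01 * rng.normal(size=3)
    assert L_r(w_perturbed) < L_r(w_star)
```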
In this case, the optimized value for φ_s after a fixed number of steps corresponds to the w_s of the variational approximation. It is interesting to contrast recent literature under the lens of this framework. Optimizing the same hierarchical model with hard-EM but with a non-trivial λ results in the same procedure that was proposed by Li et al. (2018). Furthermore, using the difference of the local parameters to the global parameters as a "gradient" (Reddi et al., 2020) is equivalent to hard-EM on the same model where in the M-step we take a single gradient step. In addition, this view makes precise the idea that FedAvg is a meta-learning algorithm (Jiang et al., 2019); the underlying hierarchical model we optimize is similar to the ones used in meta-learning (Grant et al., 2018; Chen et al., 2019). How can we then use this novel view of FedAvg to our advantage? The most straightforward way is to use an alternative prior, which would result in different behaviors in local training and server-side updating. For example, one could use a Laplace prior, which would result in the server selecting the median instead of averaging, or a mixture of Gaussians prior, which would result in training an ensemble of models at the server. In order to tackle the communication and computational costs, which is important for "cross-device" FL, we chose a sparsity-inducing prior, namely the spike and slab prior. We describe the resulting algorithm, FedSparse, in the next section.

3. THE FEDSPARSE ALGORITHM: SPARSITY IN FEDERATED LEARNING

Encouraging sparsity in FL has two main advantages: the model becomes smaller and less resource-intensive to train and, furthermore, it cuts down on communication costs, as the pruned parameters do not need to be communicated. The gold standard for sparsity in probabilistic models is the spike and slab prior (Mitchell & Beauchamp, 1988). It is a mixture of two components, a delta spike at zero, δ(0), and a continuous distribution over the real line, i.e., the slab. More specifically, by adopting a Gaussian slab for each local parameter φ_si we have that

p(φ_si|θ_i, w_i) = (1 - θ_i) δ(0) + θ_i N(φ_si|w_i, 1/λ),

or, equivalently, as a hierarchical model

p(φ_si|θ_i, w_i) = Σ_{z_si} p(z_si|θ_i) p(φ_si|z_si, w_i),
p(z_si|θ_i) = Bern(θ_i),   p(φ_si|z_si = 1, w_i) = N(φ_si|w_i, 1/λ),   p(φ_si|z_si = 0) = δ(0),

where z_si plays the role of a "gating" variable that switches the parameter φ_si on or off. Now consider using this distribution for the prior over the parameters in the federated setting; w, θ will be the server-side model weights and probabilities of the binary gates. In order to stay close to the FedAvg paradigm of simple point estimation, and since approximate inference for complex posteriors, such as those that arise in neural networks, is still an open problem, we will perform hard-EM in order to optimize w, θ. By using approximate distributions q_{w_s}(φ_s|z_s), q_{π_s}(z_s), the variational lower bound for this model becomes

arg max_{w_{1:S}, w, π_{1:S}, θ} (1/S) Σ_s E_{q_{π_s}(z_s) q_{w_s}(φ_s|z_s)}[log p(D_s|φ_s) + log p(φ_s|w, z_s) + log p(z_s|θ) - log q_{w_s}(φ_s|z_s)] + H[q_{π_s}(z_s)].

For the shard-specific weight distributions, as they are continuous, we will use q_{w_s}(φ_si|z_si = 1) := N(w_si, ε²), q(φ_si|z_si = 0) := N(0, ε²) with ε ≈ 0, which will be, numerically speaking, deterministic. For the gating variables, as they are binary, we will use q_{π_si}(z_si) := Bern(π_si), with π_si being the probability of activating the local gate z_si.
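The spike-and-slab prior is easy to sample from: draw a Bernoulli gate, and depending on its value either emit an exact zero (the spike) or a draw from the Gaussian slab. A short illustrative sketch (names are ours):

```python
# Sampling from p(phi | theta, w) = (1 - theta) * delta(0) + theta * N(w, 1/lambda).
import numpy as np

rng = np.random.default_rng(2)

def sample_spike_and_slab(theta, w, lam, size):
    z = rng.random(size) < theta                    # gate z ~ Bern(theta)
    slab = rng.normal(w, 1 / np.sqrt(lam), size)    # slab N(w, 1/lambda)
    return np.where(z, slab, 0.0)                   # spike puts exact zeros

phi = sample_spike_and_slab(theta=0.3, w=1.0, lam=4.0, size=100_000)
# With theta = 0.3, roughly 70% of the draws sit exactly at the spike,
# and the non-zero draws concentrate around the slab mean w = 1.0.
```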
In order to do hard-EM for the binary variables, we will remove the entropy term for q_{π_s}(z_s) from the aforementioned bound, as this will encourage the approximate distribution to move towards the most probable value for z_s. Furthermore, by relaxing the spike at zero to a Gaussian with precision λ_2, i.e., p(φ_si|z_si = 0) = N(0, 1/λ_2), and by plugging the appropriate expressions into the variational bound above, we can show that the local and global objectives become

arg max_{w_s, π_s} L_s(D_s, w, θ, w_s, π_s) := E_{q_{π_s}(z_s)}[Σ_{i=1}^{N_s} L(D_si, w_s ⊙ z_s)] - (λ/2) Σ_j π_sj (w_sj - w_j)² - λ_0 Σ_j π_sj + Σ_j (π_sj log θ_j + (1 - π_sj) log(1 - θ_j)) + C,

arg max_{w, θ} L := (1/S) Σ_{s=1}^S L_s(D_s, w, θ, w_s, π_s),

respectively, where λ_0 = (1/2) log(λ_2/λ) and C is a constant independent of the variables to be optimized. The derivation can be found in Appendix E. It is interesting to see that the final objective at each shard intuitively trades off four things: 1) explaining the local dataset D_s, 2) having the local weights close to the server weights (regulated by λ), 3) having the local gate probabilities close to the server probabilities and 4) reducing the local gate activation probabilities so as to prune away a parameter (regulated by λ_0). The latter is an L_0 regularization term, akin to the one proposed by Louizos et al. (2017). Now let us consider what happens at the server after the local shard, through some procedure, has optimized w_s and π_s. Since the server loss for w, θ is the sum of all local losses, the gradients for each of the parameters will be

∂L/∂w = Σ_s λ π_s ⊙ (w_s - w),   ∂L/∂θ = Σ_s (π_s/θ - (1 - π_s)/(1 - θ)).

Setting these derivatives to zero, we see that the stationary points are

w = (Σ_s π_s ⊙ w_s) / (Σ_s π_s),   θ = (1/S) Σ_s π_s,

i.e., a weighted average of the local weights and an average of the local probabilities of keeping these weights.
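The three regularization terms of the local objective can be written out directly; the following sketch computes them for a toy parameter vector (illustrative names, not the authors' implementation):

```python
# The regularizer part of the local FedSparse objective:
#   -(lambda/2) sum_j pi_j (w_s,j - w_j)^2   keeps local weights near the server's,
#   -lambda0 * sum_j pi_j                    is the expected-L0 sparsity penalty,
#   +sum_j [pi_j log theta_j + (1-pi_j) log(1-theta_j)]  keeps local gates near
#                                            the server's gate probabilities.
import numpy as np

def local_regularizer(w_s, pi_s, w, theta, lam=1.0, lam0=0.1):
    drift = 0.5 * lam * np.sum(pi_s * (w_s - w) ** 2)
    l0 = lam0 * np.sum(pi_s)
    cross_ent = np.sum(pi_s * np.log(theta) + (1 - pi_s) * np.log(1 - theta))
    # The objective is maximized, so drift and l0 enter with a minus sign.
    return -drift - l0 + cross_ent

w = np.array([0.5, -1.0, 2.0])
theta = np.array([0.9, 0.5, 0.2])
# With w_s = w and pi_s = theta, the drift term vanishes and only the
# L0 penalty and the (negative) gate entropy remain.
val = local_regularizer(w_s=w.copy(), pi_s=theta.copy(), w=w, theta=theta)
```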
Therefore, since the π_s are being optimized to be sparse through the L_0 penalty, the server probabilities θ will also become small for the weights that are used by only a small fraction of the shards. As a result, to obtain the final sparse architecture, we can prune the weights whose corresponding server inclusion probabilities θ are less than a threshold, e.g., 0.1. It should be noted that the sums and averages in the stationary points above can easily be approximated by subsampling a small subset of the clients. Therefore we do not have to consider all of the clients at each round, which would be prohibitive for the "cross-device" setting of FL.
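Putting the server side together: the stationary-point update is a π-weighted elementwise average of the local weights plus a plain average of the local gate probabilities, after which low-probability parameters can be pruned. A minimal sketch under our own naming:

```python
# Server update at the stationary points, followed by threshold pruning.
import numpy as np

def server_update(local_ws, local_pis, prune_threshold=0.1):
    local_ws = np.stack(local_ws)    # shape (S, num_params)
    local_pis = np.stack(local_pis)  # shape (S, num_params)
    theta = local_pis.mean(axis=0)   # theta = (1/S) sum_s pi_s
    # Elementwise weighted average; guard against all-zero gate probabilities.
    w = (local_pis * local_ws).sum(axis=0) / np.maximum(local_pis.sum(axis=0), 1e-12)
    w = np.where(theta > prune_threshold, w, 0.0)   # prune dead parameters
    return w, theta

ws = [np.array([1.0, 2.0, 3.0]), np.array([3.0, 2.0, 1.0])]
pis = [np.array([1.0, 0.6, 0.0]), np.array([1.0, 0.4, 0.0])]
w, theta = server_update(ws, pis)
# The third parameter has theta = 0 and is pruned to exactly zero.
```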

3.1. REDUCING THE COMMUNICATION COST

The framework described so far allows us to learn a more efficient model. We now discuss how we can use it in order to cut down both download and upload communication costs during training.

Reducing the client-to-server communication cost. In order to reduce the client-to-server cost we will communicate sparse samples from the local distributions instead of the distributions themselves; in this way we do not have to communicate the zero values of the parameter vector. This leads to large savings, while still keeping the server gradient unbiased. More specifically, we can express the gradients and stationary points for the server weights and probabilities as follows:

∂L/∂w = Σ_s λ E_{q_{π_s}(z_s)}[z_s ⊙ (w_s - w)],   w = E_{q_{π_{1:S}}(z_{1:S})}[(Σ_s z_s ⊙ w_s) / (Σ_s z_s)],
∂L/∂θ = Σ_s E_{q_{π_s}(z_s)}[z_s/θ - (1 - z_s)/(1 - θ)],   θ = (1/S) Σ_s E_{q_{π_s}(z_s)}[z_s].

As a result, we can then communicate from the client only the subset of the local weights ŵ_s that are non-zero in z_s ∼ q_{π_s}(z_s), ŵ_s = w_s ⊙ z_s, and the server can infer the state of z_s by inspecting which parameters were omitted. Having access to those samples, the server can then form 1-sample stochastic estimates of either the gradients or the stationary points for w, θ. Notice that this is a way to reduce communication without adding bias to the gradients of the original objective. In case we are willing to incur extra bias, we can further use techniques such as quantization (Amiri et al., 2020) and top-k gradient selection (Lin et al., 2017) to reduce communication even further. Such approaches are left for future work.

Reducing the server-to-client communication cost. The server needs to communicate the updated distributions to the clients at each round. Unfortunately, for simple unstructured pruning, this doubles the communication cost, as for each weight w_i there is an associated θ_i that needs to be sent to the client.
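The sparse upload can be sketched concretely: the client samples its gates, sends only the surviving (index, value) pairs, and the server reconstructs z_s from which indices arrived. Illustrative code (our own naming):

```python
# Client-to-server sketch: upload only the non-zero entries of w_s * z_s.
import numpy as np

rng = np.random.default_rng(3)

def client_upload(w_s, pi_s):
    z = (rng.random(w_s.shape) < pi_s).astype(float)   # z ~ Bern(pi_s)
    idx = np.flatnonzero(z)
    return idx, w_s[idx]            # sparse message: indices + surviving values

def server_receive(idx, values, num_params):
    w_hat = np.zeros(num_params)
    w_hat[idx] = values
    z = np.zeros(num_params)
    z[idx] = 1.0                    # the server infers z from the omissions
    return w_hat, z

w_s = np.array([0.5, -1.2, 3.0, 0.7])
pi_s = np.array([1.0, 0.0, 1.0, 0.0])   # deterministic gates for illustration
idx, vals = client_upload(w_s, pi_s)
w_hat, z = server_receive(idx, vals, num_params=4)
# Only 2 of the 4 parameters travel over the wire in this example.
```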
To mitigate this effect we will employ structured pruning, which introduces a single additional parameter for each group of weights. For groups of moderate sizes, e.g., the set of weights of a given convolutional filter, the extra overhead is small. We can also take the communication cost reductions one step further if we allow for some bias in the optimization procedure; we can prune the global model during training after every round and thus send to each of the clients only the subset of the model that has survived. Notice that this is easy to do and does not require any data at the server. The inclusion probabilities θ are available at the server, so we can remove the parameters that have θ less than a threshold, e.g. 0.1. This can lead to large reductions in communication costs, especially once the model becomes sufficiently sparse.

3.2. FEDSPARSE IN PRACTICE

Local optimization. While optimizing for w_s locally is straightforward to do with gradient-based optimizers, π_s is trickier, as the expectation over the binary variables z_s in the local objective is intractable to compute in closed form, and Monte Carlo integration does not yield reparametrizable samples. To circumvent these issues, we rewrite the objective in an equivalent form and use the hard-concrete relaxation from Louizos et al. (2017), which allows for the straightforward application of gradient ascent. We provide the details in Appendix F. When the client has to communicate to the server, we propose to form ŵ_s by sampling from the zero-temperature relaxation, which yields exact binary samples. Furthermore, at the beginning of each round, following the practice of FedAvg, the participating clients initialize their approximate posteriors to be equal to the priors that were communicated from the server. Empirically, we found that this resulted in better global accuracy.

Parametrization of the probabilities. Since there is evidence that such optimization-based pruning can be inferior to simple magnitude-based pruning (Gale et al., 2019), we take an approach that combines the two and is reminiscent of the recent work of Azarian et al. (2020). We parametrize the probabilities θ, π_s as a function of the model weights and magnitude-based thresholds that regulate how active a parameter can be. More specifically, we use the following parametrization:

θ_g := σ((||w_g||_2 - τ_g)/T),   π_sg := σ((||w_sg||_2 - τ_sg)/T),

where the subscript g denotes the group, σ(·) is the sigmoid function, τ_g, τ_sg are the global and client-specific thresholds for a given group g, and T is a temperature hyperparameter. Following Azarian et al. (2020) we also "detach" the gradient of the weights through θ, π_s, to avoid decreasing the probabilities by just shrinking the weights.
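The forward computation of this parametrization is simple; with a small temperature T, the sigmoid is nearly a step function of the gap between the group norm and its threshold. A sketch (values and names illustrative; gradient detachment would be handled by the autodiff framework, e.g., a stop-gradient on the weight norm):

```python
# Magnitude-based gate parametrization: theta_g = sigmoid((||w_g||_2 - tau_g)/T),
# with the positive threshold tau obtained via a softplus of a free parameter v.
import numpy as np

def softplus(v):
    return np.log1p(np.exp(v))

def gate_probs(weights, v, T=0.001):
    """weights: (num_groups, group_size); v: (num_groups,) raw threshold params."""
    tau = softplus(v)                          # positive threshold per group
    norms = np.linalg.norm(weights, axis=1)    # ||w_g||_2
    arg = np.clip((norms - tau) / T, -60.0, 60.0)  # avoid overflow in exp
    return 1.0 / (1.0 + np.exp(-arg))

W = np.array([[0.8, 0.6],      # group 0: norm 1.0
              [0.01, 0.02]])   # group 1: norm ~0.022
v = np.array([-0.7, -0.7])     # tau = softplus(-0.7) ~ 0.40
theta = gate_probs(W, v)
# Group 0 sits well above the threshold (theta ~ 1); group 1 below (theta ~ 0).
```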
With this parametrization we lose the ability to get a closed-form solution for the server thresholds, but we can nonetheless still perform gradient-based optimization at the server by using the chain rule. For a positive threshold, we use a parametrization in terms of a softplus function, i.e., τ = log(1 + exp(v)), where v is the learnable parameter.

Algorithm 1: The server-side algorithm for FedSparse (assuming weight sparsity for simplicity). σ(·) is the sigmoid function, ε is the threshold for pruning.

  Initialize v and w
  for round t in 1, ..., T do
    τ ← log(1 + exp(v))
    θ ← σ((|w| - τ)/T)
    w ← I[θ > ε] ⊙ w                            # prune global model
    Initialize ∇_w^t = 0, ∇_v^t = 0
    for s in random subset of the clients do
      ŵ_s^t ← CLIENT(s, w, v)
      z_s ← I[ŵ_s^t ≠ 0]
      ∇_w^t += z_s ⊙ (ŵ_s^t - w)
      ∇_v^t += -(z_s ⊙ (1 - θ) - (1 - z_s) ⊙ θ) ⊙ σ(v)/T
    end for
    w^{t+1}, v^{t+1} ← ADAM(∇_w^t), ADAMAX(∇_v^t)
  end for

Algorithm 2: The client-side algorithm for FedSparse.

  Get w, v from the server
  τ ← log(1 + exp(v))
  θ ← σ((|w| - τ)/T)
  w_s, v_s ← w, v
  for epoch e in 1, ..., E do
    for batch b ∈ B do
      τ_s ← log(1 + exp(v_s))
      π_s ← σ((|w_s| - τ_s)/T)
      L_s ← L_s(b, w, θ, w_s, π_s)
      w_s ← SGD(∇_{w_s} L_s)
      v_s ← ADAMAX(∇_{v_s} L_s)
    end for
  end for
  π_s ← σ((|w_s| - τ_s)/T)
  z_s ∼ q_{π_s}(z_s)
  return z_s ⊙ w_s

4. RELATED WORK

Both contributions of this work share similarities with several recent works in federated learning. FedProx (Li et al., 2018) proposed to add a proximal term to the local objective at each shard, so that it prevents the local models from drifting too far from the global model. Through the EM view of FedAvg we show how such a local objective arises if we use a non-trivial precision for the Gaussian prior over the local parameters. Furthermore, FedAvg has been advocated to be a meta-learning algorithm in Jiang et al. (2019); with the EM view we make this claim precise and show that the underlying hierarchical model that FedAvg optimizes is the same as the models used in several meta-learning works (Grant et al., 2018; Chen et al., 2019). Furthermore, by performing a single gradient step for the M-step in the EM view of FedAvg, we see that we arrive at a procedure that has been previously explored both in a meta-learning context with the Reptile algorithm (Nichol et al., 2018), as well as in the federated learning context with the "generalized" FedAvg (Reddi et al., 2020). One important difference between meta-learning and FedAvg is that the latter maximizes the average marginal likelihood across shards in order to find the server / global parameters, whereas meta-learning methods usually optimize the global parameters such that the finetuned model performs well on the local validation set. Exploring such parameter estimation methods, as, e.g., in Chen et al. (2019), in the federated scenario, and how these relate to existing approaches that merge meta-learning with federated learning, e.g., Fallah et al. (2020), is an interesting avenue for future work. Finally, the EM view we provide here for FedAvg can also provide a novel perspective for optimization works that improve model performance via model replicas (Zhang et al., 2019; Pittorino et al., 2020). Reducing the communication costs is a well-known and explored topic in federated learning.
FedSparse has close connections to federated dropout (Caldas et al., 2018), as the latter can be understood via a similar hierarchical model, where the gates z are global and have a fixed probability θ for both the prior and the approximate posterior. Then, at each round and for each client, the server samples these gates (due to the expectation), collects the submodels and sends those to be optimized at the shard. Compared to federated dropout, FedSparse allows us to optimize the dropout rates on the data, such that they satisfy a given accuracy / sparsity trade-off, dictated by the hyperparameter λ_0. Another benefit we gain from the EM view of FedAvg is that it makes clear that the server can perform gradient-based optimization. As a result, we can leverage the large literature on efficient distributed optimization (Lin et al., 2017; Bernstein et al., 2018; Wangni et al., 2018; Yu et al., 2019), which involves gradient quantization, sparsification and more general compression. On this front, there have also been other works that aim to reduce the communication cost in FL via such approaches (Sattler et al., 2019; Han et al., 2020). In general, such approaches can be orthogonal to FedSparse, and exploring how they can be incorporated is a promising avenue for future research.

5. EXPERIMENTS

We verify on three tasks whether the FedSparse procedure leads to similar or better global models compared to FedAvg, while providing reductions in communication costs and more efficient models. As a baseline that also reduces communication costs by sparsifying the model, we consider the federated dropout procedure from Caldas et al. (2018), which we refer to as FedDrop. For each of the three tasks, we present the results for FedSparse with regularization strengths that target three sparsity levels: low, mid and high. For the FedDrop baseline, we experiment with multiple combinations of dropout probabilities for the convolutional and fully connected layers. For each of these tasks we report the setting that performs best in terms of the accuracy / communication trade-off. The first task we consider is a federated version of CIFAR10 classification; we partition the data among 100 shards in a non-i.i.d. way by following Hsu et al. (2019). For the model we employ a LeNet-5 convolutional architecture (LeCun et al., 1998) with the addition of dropout(0.1) for the second convolutional layer and dropout(0.3) for the first fully connected layer, in order to prevent overfitting locally at the shard. We optimize the model for 1k communication rounds. For the second task we consider the 500-shard federated version of CIFAR100 classification from Reddi et al. (2020). For the model we use a ResNet20, where we replace the batch normalization layers with group normalization, following Reddi et al. (2020), and we optimize for 6k rounds. For the final task we consider non-i.i.d. Femnist classification; we use the same configuration as CIFAR10 but optimize the model for 6k rounds. More details can be found in Appendix A. We evaluate FedSparse and the baselines on two metrics that highlight the trade-offs between accuracy and communication costs.
For both metrics, the x-axis represents the total communication cost incurred up until that point and the y-axis represents one of two distinct model accuracies. The first corresponds to the accuracy of the global model on the union of the shard test sets, whereas the second corresponds to the average accuracy of the shard-specific "local models" on the shard-specific test sets. The "local model" on each shard is the model configuration that the shard last communicated to the server, and serves as a proxy for the personalized model performance on each shard. The latter metric is motivated by the meta-learning (Jiang et al., 2019) and EM views of federated averaging, and corresponds to using the local posteriors for prediction on the local test set instead of the server-side priors.

5.1. EXPERIMENTAL RESULTS

The results from our experiments can be found in the following table and figures. Overall, we observed that the FedSparse models achieve their final sparsity ratios early in training, i.e., after 30-50 rounds, which quickly reduces the communication costs for each round (Appendix B).

Table 1: Average global and local test-set accuracies across clients in %, along with total communication costs in GB and sparsity of the final model for Cifar10, Cifar100 and Femnist. We report the average over the last 10 evaluations.

We can see that for CIFAR 10, the FedSparse models with medium (∼45%) and high (∼62%) sparsity outperform all other methods for small communication budgets on the global accuracy front, but are eventually surpassed by FedDrop on higher budgets. However, on the local accuracy front, we see that the FedSparse models Pareto-dominate both baselines, achieving, e.g., 87% local accuracy with 43% less communication compared to FedAvg. Overall, judging by final performance only, we see that FedDrop reaches the best accuracy on the global model, but FedSparse reaches the best accuracy on the local models. On CIFAR 100 the differences are less pronounced, as the models did not fully converge within the maximum number of rounds we use. Nevertheless, we still observe similar patterns; for small communication budgets, the sparser models are better for both the global and local accuracy as, e.g., they can reach 32% global accuracy while requiring 13% less communication than FedAvg. Finally, for Femnist we observe the largest differences, as the Femnist task is also more communication-intensive due to having 3.5k shards. We see that the FedSparse algorithm Pareto-dominates both FedDrop and FedAvg and, more specifically, in the high-sparsification setting it can reach 84% global accuracy and 89% local accuracy while requiring 41% and 51% less communication compared to FedAvg, respectively.
Judging by the final accuracy, both FedAvg and FedSparse with the low setting reached similar global and local model performance, which is to be expected given that that particular FedSparse setting led to only 1% sparsity.

6. CONCLUSION

In this work, we showed how the FedAvg algorithm, the standard in federated learning, corresponds to applying a variant of the well known EM algorithm to a simple hierarchical model. Through this perspective, we bridge several recent works on federated learning as well as connect FedAvg to meta-learning. As a straightforward extension stemming from this view, we proposed FedSparse, a generalization of FedAvg with sparsity inducing priors. Empirically, we showed that FedSparse can learn sparse neural networks which, besides being more efficient, can also significantly reduce the communication costs without decreasing performance.

APPENDIX

A EXPERIMENTAL DETAILS

For all three tasks we randomly select 10 clients without replacement in a given round, but with replacement across rounds. For the local optimizer of the weights we use stochastic gradient descent with a learning rate of 0.05, whereas for the global optimizer we use Adam (Kingma & Ba, 2014) with its default hyperparameters. For the pruning thresholds in FedSparse we used the Adamax (Kingma & Ba, 2014) optimizer with a 1e-3 learning rate at the shard level and a 1e-2 learning rate at the server. For all three tasks we used E = 1, with a batch size of 64 for CIFAR10 and 20 for CIFAR100 and Femnist. It should be noted that for all methods we performed gradient-based optimization using the difference gradient for the weights (Reddi et al., 2020) instead of averaging. For the FedDrop baseline, we used a very small dropout rate of 0.01 for the input and output layers and tuned the dropout rates for the convolutional and fully connected layers separately in order to optimize the accuracy / communication trade-off. For convolutional layers we considered rates in {0.1, 0.2, 0.3}, whereas for fully connected layers we considered rates in {0.1, 0.2, 0.3, 0.4, 0.5}. For CIFAR10 we did not employ the additional dropout noise at the shard level, since we found that it was detrimental to FedDrop performance. Furthermore, for Resnet20 on CIFAR100 we did not apply federated dropout at the output layer. For CIFAR10 the best-performing dropout rates were 0.1 for the convolutional and 0.5 for the fully connected layers, whereas for CIFAR100 it was 0.1 for the convolutional layers. For Femnist, we saw that a rate of 0.2 for the convolutional and a rate of 0.4 for the fully connected layers performed better. For FedSparse, we initialized v such that the thresholds τ led to θ = 0.99 initially, i.e., we started from a dense model.
The temperature for the sigmoid in the parametrization of the probabilities was set to T = 0.001. Furthermore, we downscaled the cross-entropy term between the client-side probabilities, π_s, and the server-side probabilities, θ, by multiplying it with 1e-4. Since at the beginning of each round we always initialize π_s = θ and we only optimize for a small number of steps before synchronizing, we found that the full strength of the cross-entropy was not necessary. Furthermore, for similar reasons, i.e., since we set w_s = w at the beginning of each round, we also used λ = 0 for the drift term (λ/2) Σ_j π_sj (w_sj - w_j)². The remaining hyperparameter λ_0 dictates how sparse the final model will be. For the LeNet-5 model, the λ_0's we report are {5e-7, 5e-6, 5e-5} for the "low", "mid" and "high" settings respectively, which were optimized for CIFAR10 and used as-is for Femnist. For CIFAR100 and Resnet20, we did not perform any pruning for the output layer, and the λ_0's for the "low", "mid" and "high" settings were likewise {5e-7, 5e-6, 5e-5}. These were chosen so that we obtain models with sparsity ratios comparable to the ones on CIFAR10.

B EVOLUTION OF SPARSITY

We show the evolution of the sparsity ratios for all tasks and configurations in the following plot. We can see that in all settings the model attains its final sparsity quite early in training. 
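For reference, a sparsity ratio of this kind can be computed as the fraction of pruned gates; the cutoff of 0.5 below is an illustrative choice, not necessarily the paper's exact pruning rule:

```python
def sparsity_ratio(gate_probs, cutoff=0.5):
    # Fraction of weights whose keep-probability has dropped below the
    # cutoff and would therefore be pruned; cutoff=0.5 is illustrative.
    pruned = sum(1 for p in gate_probs if p < cutoff)
    return pruned / len(gate_probs)

ratio = sparsity_ratio([0.9, 0.1, 0.2, 0.8])  # two of four gates pruned
```

Tracking this quantity per round produces the curves shown in the plot.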

C ADDITIONAL RESULTS

Convergence plots in terms of communication rounds. In order to understand whether the extra noise is detrimental to the convergence speed of FedSparse, we plot the validation accuracy in terms of communication rounds for all tasks and baselines. As can be seen, there is no inherent difference before FedSparse starts pruning. This happens quite early in training for CIFAR100, thus it is there where we observe the most differences.

Impact of server side pruning. In order to understand whether server side pruning is harmful for convergence, we plot both the global and average local validation accuracy on CIFAR10 for the "mid" setting of FedSparse with and without server side pruning enabled. As we can see, there are no noticeable differences and, in fact, pruning results in a slightly better overall performance.

D HARD-EM AS GRADIENT ASCENT

By performing EM with a single gradient step for w in the M-step (instead of full maximization), we are essentially doing gradient ascent on the original objective of Eq. 21. To see this, we can take the gradient of Eq. 21 w.r.t. w, which yields Eq. 24; to compute Eq. 24 we see that we first have to obtain the posterior distribution of the local variables φ_s and then estimate the gradient for w by marginalizing over this posterior.
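The marginalization step mentioned above is the classic Fisher identity from the EM literature; writing Z_s for the marginal likelihood of shard s, it reads:

```latex
\nabla_w \log Z_s
  = \frac{1}{Z_s} \int p(\mathcal{D}_s \mid \phi_s)\, \nabla_w p(\phi_s \mid w)\, d\phi_s
  = \mathbb{E}_{p(\phi_s \mid \mathcal{D}_s, w)}\!\left[ \nabla_w \log p(\phi_s \mid w) \right],
\qquad
Z_s = \int p(\mathcal{D}_s \mid \phi_s)\, p(\phi_s \mid w)\, d\phi_s .
```

That is, the gradient of the marginal log-likelihood w.r.t. w is the posterior expectation of the gradient of the prior term, which is exactly why the local posterior over φ_s has to be obtained first.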

E DERIVATION OF THE LOCAL LOSS

Let p(φ_si | w_i, z_si = 1) = N(w_i, 1/λ), p(φ_si | w_i, z_si = 0) = N(0, 1/λ₂) and q(φ_si | z_si = 1) = N(w_si, ε²), q(φ_si | z_si = 0) = N(0, ε²). Furthermore, let q(z_si) = Bern(π_si). The local objective that stems from Eq. 13 can be rewritten as

arg max E_{q_πs(z_s) q_ws(φ_s|z_s)}[log p(D_s | φ_s)] − E_{q_πs(z_s)}[KL(q_ws(φ_s | z_s) || p(φ_s | w, z_s))] + E_{q_πs(z_s)}[log p(z_s | θ)],   (25)

where we omitted from the objective the entropy of the distribution over the local gates. One of the quantities that we are after is

E_{q(z_si)}[KL(q(φ_si | z_si) || p(φ_si | z_si))] = π_si KL(N(w_si, ε²) || N(w_i, 1/λ)) + (1 − π_si) KL(N(0, ε²) || N(0, 1/λ₂)).   (26)

The KL term for when z_si = 1 can be written as

KL(N(w_si, ε²) || N(w_i, 1/λ)) = −(1/2) log λ − log ε + λε²/2 − 1/2 + (λ/2)(w_si − w_i)².   (27)

The KL term for when z_si = 0 can be written as

KL(N(0, ε²) || N(0, 1/λ₂)) = −(1/2) log λ₂ − log ε + λ₂ε²/2 − 1/2.   (28)

Taking everything together we thus have

E_{q(z_si)}[KL(q(φ_si | z_si) || p(φ_si | z_si))]
= (λπ_si/2)(w_si − w_i)² + π_si(−(1/2) log λ − log ε + λε²/2 − 1/2) + (1 − π_si)(−(1/2) log λ₂ − log ε + λ₂ε²/2 − 1/2)   (29)
= (λπ_si/2)(w_si − w_i)² + π_si((1/2) log(λ₂/λ) + (ε²/2)(λ − λ₂)) + C   (30)
≈ (λπ_si/2)(w_si − w_i)² + λ₀ π_si + C,   (31)

where λ₀ = (1/2) log(λ₂/λ) and the (ε²/2)(λ − λ₂) term was omitted due to ε² ≈ 0. In the appendix of Louizos et al. (2017), the authors argue about a hypothetical prior that results in needing λ nats to transform that prior to the approximate posterior. Here we make this claim more precise and show that this prior is approximately equivalent to a mixture of Gaussians prior where the precision of the non-zero prior component λ → 0 (in order to avoid the L2 regularization term) and the precision of the zeroth component λ₂ is equivalent to λ exp(2λ₀), where λ₀ is the desired L0 regularization strength.

Furthermore, the cross-entropy from q_πs(z_s) to p(z_s | θ) is straightforward to compute as

E_{q_πs(z_s)}[log p(z_s | θ)] = Σ_j (π_sj log θ_j + (1 − π_sj) log(1 − θ_j)).   (32)

Putting everything together, the local objective becomes

L_s(D_s, w, θ, w_s, π_s) := E_{q_πs(z_s)}[Σ_i^{N_s} L(D_si, w_s ⊙ z_s)] − (λ/2) Σ_j π_sj (w_sj − w_j)² − λ₀ Σ_j π_sj + Σ_j (π_sj log θ_j + (1 − π_sj) log(1 − θ_j)) + C.   (33)

F LOCAL OPTIMIZATION OF THE BINARY GATES

We propose to rewrite the local loss in Eq. 13 to

L_s(D_s, w, θ, φ_s, π_s) := E_{q_πs(z_s)}[Σ_i^{N_s} L(D_si, w_s ⊙ z_s) − (λ/2) Σ_j I[z_sj ≠ 0](w_sj − w_j)² − λ₀ Σ_j I[z_sj ≠ 0] + Σ_j (I[z_sj ≠ 0] log(θ_j/(1 − θ_j)) + log(1 − θ_j))],   (34)

and then replace the Bernoulli distribution q_πs(z_s) with a continuous relaxation, the hard-concrete distribution (Louizos et al., 2017). Let the continuous relaxation be r_us(z_s), where u_s are the parameters of the surrogate distribution. In this case the local objective becomes

L̃_s(D_s, w, θ, φ_s, u_s) := E_{r_us(z_s)}[Σ_i^{N_s} L(D_si, w_s ⊙ z_s)] − (λ/2) Σ_j R_usj(z_sj > 0)(w_sj − w_j)² − λ₀ Σ_j R_usj(z_sj > 0) + Σ_j (R_usj(z_sj > 0) log(θ_j/(1 − θ_j)) + log(1 − θ_j)),   (35)

where R_us(·) is the cumulative distribution function (CDF) of the continuous relaxation r_us(·), so that R_usj(z_sj > 0) denotes the probability of the gate being non-zero. We can now straightforwardly optimize the surrogate objective with gradient ascent.

(a) Femnist global acc. (b) Femnist local acc.

Figure 5: Evolution of the validation accuracy in terms of communication rounds.

(a) CIFAR10 global val. acc. (b) CIFAR100 local val. acc.

Figure 6: Evolution of the validation accuracy in terms of communication rounds with and without server side pruning.
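As a sanity check of the KL term for the z_si = 1 case (Eq. 27), the expanded expression can be compared numerically against the generic univariate Gaussian KL formula; the values below are arbitrary illustrative choices:

```python
import math

def kl_gauss(m1, s1, m2, s2):
    # KL(N(m1, s1^2) || N(m2, s2^2)) for univariate Gaussians
    return math.log(s2 / s1) + (s1 ** 2 + (m1 - m2) ** 2) / (2 * s2 ** 2) - 0.5

def kl_expanded(w_si, w_i, eps, lam):
    # Expanded form of Eq. 27:
    # -1/2 log(λ) - log(ε) + λ ε²/2 - 1/2 + (λ/2)(w_si - w_i)²
    return (-0.5 * math.log(lam) - math.log(eps)
            + lam * eps ** 2 / 2 - 0.5 + lam / 2 * (w_si - w_i) ** 2)

# arbitrary values: variational mean, prior mean, ε, prior precision λ
w_si, w_i, eps, lam = 0.3, -0.2, 1e-3, 4.0
direct = kl_gauss(w_si, eps, w_i, math.sqrt(1.0 / lam))
```

The two expressions are algebraically identical, so `direct` and `kl_expanded(w_si, w_i, eps, lam)` agree to numerical precision.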

As a result, hard-EM corresponds to a block coordinate ascent type of algorithm on a single objective function over the global parameters w and the shard-specific parameters φ_{1:S}, where Z_s = ∫ p(D_s | φ_s) p(φ_s | w) dφ_s denotes the marginal likelihood of shard s.

