FEDERATED LEARNING AS VARIATIONAL INFERENCE: A SCALABLE EXPECTATION PROPAGATION APPROACH

Abstract

The canonical formulation of federated learning treats it as a distributed optimization problem where the model parameters are optimized against a global loss function that decomposes across client loss functions. A recent alternative formulation instead treats federated learning as a distributed inference problem, where the goal is to infer a global posterior from partitioned client data (Al-Shedivat et al., 2021). This paper extends the inference view and describes a variational inference formulation of federated learning where the goal is to find a global variational posterior that well-approximates the true posterior. This naturally motivates an expectation propagation approach to federated learning (FedEP), where approximations to the global posterior are iteratively refined through probabilistic message-passing between the central server and the clients. We conduct an extensive empirical study across various algorithmic considerations and describe practical strategies for scaling up expectation propagation to the modern federated setting. We apply FedEP on standard federated learning benchmarks and find that it outperforms strong baselines in terms of both convergence speed and accuracy.¹

1. INTRODUCTION AND BACKGROUND

Many applications of machine learning require training a centralized model over decentralized, heterogeneous, and potentially private datasets. For example, hospitals may be interested in collaboratively training a model for predictive healthcare, but privacy rules might require each hospital's data to remain local. Federated Learning (FL, McMahan et al., 2017; Kairouz et al., 2021; Wang et al., 2021) has emerged as a privacy-preserving training paradigm that does not require clients' private data to leave their local devices. FL introduces new challenges on top of classic distributed learning: expensive communication, statistical/hardware heterogeneity, and data privacy (Li et al., 2020a).

The canonical formulation of FL treats it as a distributed optimization problem where the model parameters $\theta$ are trained on $K$ (potentially private) datasets $\mathcal{D} = \bigcup_{k \in [K]} \mathcal{D}_k$,

    $\theta = \arg\min_{\theta} L(\theta)$, where $L(\theta) = \sum_{k \in [K]} -\log p(\mathcal{D}_k \mid \theta)$.

Standard distributed optimization algorithms (e.g., data-parallel SGD) are too communication-intensive to be practical under the FL setup. Federated Averaging (FedAvg, McMahan et al., 2017) reduces communication costs by allowing clients to perform multiple local SGD steps/epochs before the parameter updates are sent back to the central server and aggregated. However, due to client data heterogeneity, more local computations could lead to stale or biased client updates, and hence sub-optimal behavior (Charles & Konečný, 2020; Woodworth et al., 2020; Wang et al., 2020a).

An alternative approach is to consider a Bayesian formulation of the FL problem (Al-Shedivat et al., 2021). Here, we are interested in estimating the posterior of parameters $p(\theta \mid \mathcal{D})$ given a prior $p(\theta)$ (such as an improper uniform or a Gaussian prior) and a collection of client likelihoods $p(\mathcal{D}_k \mid \theta)$ that are independent given the model parameters,

    $p(\theta \mid \mathcal{D}) \propto p(\theta) \prod_{k \in [K]} p(\mathcal{D}_k \mid \theta)$.
In this case the posterior naturally factorizes across partitioned client data, wherein the global posterior equates to a multiplicative aggregate of local factors (and the prior). However, exact posterior inference is in general intractable for even modestly-sized models and datasets and requires approximate inference techniques. In this paper we turn to variational inference, in effect transforming the federated optimization problem into a distributed inference problem. Concretely, we view the solution of federated learning as the mode of a variational (posterior) distribution $q \in \mathcal{Q}$ with some divergence function $D(\cdot \,\|\, \cdot)$ (e.g., KL-divergence),

    $\theta = \arg\max_{\theta} q(\theta)$, where $q(\theta) = \arg\min_{q \in \mathcal{Q}} D\big(p(\theta \mid \mathcal{D}) \,\|\, q(\theta)\big)$.    (1)

Under this approach, clients use local computation to perform posterior inference (instead of parameter/gradient estimation) in parallel. In exchange, possibly fewer lockstep synchronization and communication steps are required between clients and servers.

One way to operationalize Eq. 1 is through federated posterior averaging (FedPA, Al-Shedivat et al., 2021), where each client independently runs an approximate inference procedure and then sends the local posterior parameters to the server to be multiplicatively aggregated. However, there is no guarantee that independent approximations to local posteriors will lead to a good global approximate posterior. Motivated by the rich line of work on variational inference on streaming/partitioned data (Broderick et al., 2013; Vehtari et al., 2020), this work instead considers an expectation propagation (EP, Minka, 2001) approach to FL. In EP, each partition of the data maintains its own local contribution to the global posterior that is iteratively refined through probabilistic message-passing.
When applied to FL, this results in an intuitive training scheme where at each round, each client (1) receives the current approximation to the global posterior from the centralized server, (2) carries out local inference to update its local approximation, and (3) sends the refined approximation to the server to be aggregated. Conceptually, this federated learning with expectation propagation (FedEP) approach extends FedPA by taking into account the current global approximation in step (2).

However, scaling up classic expectation propagation to the modern federated setting is challenging due to the high dimensionality of model parameters and the large number of clients. Indeed, while there is some existing work on expectation propagation-based federated learning (Corinzia et al., 2019; Kassab & Simeone, 2022; Ashman et al., 2022), it typically focuses on small models (fewer than 100K parameters) and few clients (at most 100 clients). In this paper we conduct an extensive empirical study across various algorithmic considerations to scale up expectation propagation to contemporary benchmarks (e.g., models with many millions of parameters and datasets with hundreds of thousands of clients). When applied on top of modern FL benchmarks, our approach outperforms strong FedAvg and FedPA baselines.

2. FEDERATED LEARNING WITH EXPECTATION PROPAGATION

The probabilistic view from Eq. 1 motivates an alternative formulation of federated learning based on variational inference. First observe that the global posterior $p(\theta \mid \mathcal{D})$ given a collection of datasets $\mathcal{D} = \bigcup_{k \in [K]} \mathcal{D}_k$ factorizes as,

    $p(\theta \mid \mathcal{D}) \propto p(\theta) \prod_{k=1}^{K} p(\mathcal{D}_k \mid \theta) = \prod_{k=0}^{K} p_k(\theta)$,

where for convenience we define $p_0(\theta) := p(\theta)$ to be the prior and further use $p_k(\theta) := p(\mathcal{D}_k \mid \theta)$ to refer to the local likelihood associated with the k-th data partition. To simplify notation we hereon refer to the global posterior as $p_{\text{global}}(\theta)$ and drop the conditioning on $\mathcal{D}$. Now consider an approximating global posterior $q_{\text{global}}(\theta)$ that admits the same factorization as the above, i.e., $q_{\text{global}}(\theta) \propto \prod_{k=0}^{K} q_k(\theta)$. Plugging these terms into Eq. 1 gives the following objective,

    $\arg\max_{\theta} \prod_{k=0}^{K} q_k(\theta)$, where $\{q_k(\theta)\}_{k=0}^{K} = \arg\min_{q_k \in \mathcal{Q}} D\Big(\prod_{k=0}^{K} p_k(\theta) \,\Big\|\, \prod_{k=0}^{K} q_k(\theta)\Big)$.    (2)

Here $\mathcal{Q}$ is the variational family, which is assumed to be the same for all clients. This global objective is in general intractable; evaluating $\prod_k p_k(\theta)$ requires accessing all clients' data and violates the standard FL assumption. This section presents a probabilistic message-passing algorithm based on expectation propagation (EP, Minka, 2001).

2.1. EXPECTATION PROPAGATION

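EP is an iterative algorithm in which an intractable target density $p_{\text{global}}(\theta)$ is approximated by a tractable density $q_{\text{global}}(\theta)$ using a collection of localized inference procedures. In EP, each local inference problem is a function of just $p_k$ and the current global estimate, making it appropriate for the FL setting.

Algorithm 1
  1: for round $t = 1, \dots, T$ do
  2:   Sample a subset of clients $\mathcal{K}$.
  3:   Broadcast $q_{\text{global}}(\theta)$ to the selected clients.
  4:   for each client $k \in \mathcal{K}$ in parallel do
  5:     $\Delta q_k(\theta) \leftarrow \text{ClientInfer}(q_{\text{global}}(\theta))$
  6:   end for
  7:   Collect $\Delta q_k(\theta)$ from the selected clients.
  8:   $q_{\text{global}}(\theta) \leftarrow \text{ServerInfer}(\{\Delta q_k(\theta)\}_k)$
  9: end for
  10: Return $\mu_{\text{global}}$.

Algorithm 2 Approximate Inference: MCMC
  1: Input: $q_{\setminus k}(\theta; \mathcal{D}_k, \eta_{-k}, \Lambda_{-k})$
  2: $S_k \leftarrow \{\}$
  3: for $i = 1, \dots, N$ do
  4:   $\theta_k^{(i)} \leftarrow \text{SGDEpoch}(-\log q_{\setminus k}, \theta_k^{(i-1)})$
  5:   $S_k \leftarrow S_k \cup \{\theta_k^{(i)}\}$
  6: end for
  7: $\eta_{\setminus k}, \Lambda_{\setminus k} \leftarrow \text{EstimateMoments}(S_k)$
  8: Output: $\hat{q}_{\setminus k}(\theta; \eta_{\setminus k}, \Lambda_{\setminus k})$

Algorithm 3 Gaussian EP: Server Inference
  1: Receive: $\{\Delta q_k(\theta; \Delta\eta_k, \Delta\Lambda_k)\}_k$
  2: $q_{\text{global}}^{\text{new}} \propto q_{\text{global}} \prod_k (\Delta q_k)^{\delta}$    // Sec. 2.2.3
       $\eta_{\text{global}} \leftarrow \eta_{\text{global}} + \delta\, \text{ServerOptim}(\sum_k \Delta\eta_k)$
       $\Lambda_{\text{global}} \leftarrow \Lambda_{\text{global}} + \delta\, \text{ServerOptim}(\sum_k \Delta\Lambda_k)$
  3: Send: $q_{\text{global}}(\theta; \eta_{\text{global}}, \Lambda_{\text{global}})$

Algorithm 4 Gaussian EP: Client Inference
  1: Receive: $q_{\text{global}}(\theta; \eta_{\text{global}}, \Lambda_{\text{global}})$
  2: $q_{-k} \propto q_{\text{global}} / q_k$    // cavity distribution
       $\eta_{-k} \leftarrow \eta_{\text{global}} - \eta_k$,  $\Lambda_{-k} \leftarrow \Lambda_{\text{global}} - \Lambda_k$
  3: $\hat{q}_{\setminus k} \approx q_{\setminus k} \propto p_k\, q_{-k}$    // tilted inference (Sec. 2.2.2)
       $\eta_{\setminus k}, \Lambda_{\setminus k} \leftarrow \text{ApproxInference}(q_{\setminus k} \propto p_k\, q_{-k})$
  4: $\Delta q_k \propto \hat{q}_{\setminus k} / q_{\text{global}}$    // client deltas (Sec. A.1)
       $\Delta\eta_k \leftarrow \eta_{\setminus k} - \eta_{\text{global}}$,  $\Delta\Lambda_k \leftarrow \Lambda_{\setminus k} - \Lambda_{\text{global}}$
  5: $q_k^{\text{new}} \propto q_k (\Delta q_k)^{\delta}$    // local update (Sec. 2.2.3)
       $\eta_k \leftarrow \eta_k + \delta\, \text{ClientOptim}(\Delta\eta_k)$,  $\Lambda_k \leftarrow \Lambda_k + \delta\, \text{ClientOptim}(\Delta\Lambda_k)$
  6: Send: $\Delta q_k(\theta; \Delta\eta_k, \Delta\Lambda_k)$

Concretely, EP iteratively solves the following problem (either in sequence or parallel),

    $q_k^{\text{new}}(\theta) = \arg\min_{q \in \mathcal{Q}} D\big(\underbrace{p_k(\theta)\, q_{-k}(\theta)}_{\propto\, q_{\setminus k}(\theta)} \,\big\|\, \underbrace{q(\theta)\, q_{-k}(\theta)}_{\propto\, \hat{q}_{\setminus k}(\theta)}\big)$, where $q_{-k}(\theta) \propto \frac{q_{\text{global}}(\theta)}{q_k(\theta)}$.    (3)

Here $q_{\text{global}}(\theta)$ and $q_k(\theta)$ are the global/local distributions from the current iteration. (See Sec. A.2 for further details.) In the EP literature, $q_{-k}(\theta)$ is referred to as the cavity distribution and $q_{\setminus k}(\theta)$ and $\hat{q}_{\setminus k}(\theta)$ are referred to as the target/approximate tilted distributions. EP then uses $q_k^{\text{new}}(\theta)$ to derive $q_{\text{global}}^{\text{new}}(\theta)$.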
While the theoretical properties of EP are still not well understood (Minka, 2001; Dehaene & Barthelmé, 2015; 2018), it has empirically been shown to produce good posterior approximations in many cases (Li et al., 2015; Vehtari et al., 2020). When applied to FL, the central server initiates the update by sending the parameters of the current global approximation $q_{\text{global}}(\theta)$ as messages to the subset of clients $\mathcal{K}$. Upon receiving these messages, each client updates its local approximation $q_k^{\text{new}}(\theta)$ and sends back the changes in parameters as messages, which are then aggregated by the server. Algorithms 1-4 illustrate the probabilistic message passing with the Gaussian variational family in more detail.

Remark. Consider the case where we set $q_{-k}(\theta) \propto 1$ (i.e., an improper uniform distribution that ignores the current estimate of the global parameters). Then Eq. 3 reduces to federated learning with posterior averaging (FedPA) from Al-Shedivat et al. (2021), $q_k^{\text{new}}(\theta) = \arg\min_{q \in \mathcal{Q}} D\big(p_k(\theta) \,\|\, q(\theta)\big)$. Hence, FedEP improves upon FedPA by taking into account the global parameters and the previous local estimate while deriving the local posterior.²
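The message-passing loop of Algorithms 1, 3 and 4 can be traced by hand in a special case. The sketch below is only an illustration, not the released implementation: it assumes every client likelihood is itself an unnormalized diagonal Gaussian in natural parameters (so tilted inference is exact), full client participation, and a plain damped sum in place of `ServerOptim`/`ClientOptim`. The function names `client_infer`, `server_infer`, and `run_fedep` are hypothetical.

```python
# Sketch of Gaussian EP message passing, assuming each client likelihood p_k
# is an unnormalized diagonal Gaussian given by natural parameters (eta, lam).
# Tilted inference is then exact, which is NOT true for neural models.

def add(a, b):
    return [x + y for x, y in zip(a, b)]

def sub(a, b):
    return [x - y for x, y in zip(a, b)]

def scale(c, a):
    return [c * x for x in a]

def client_infer(q_global, q_k, p_k, delta=1.0):
    """One client step (cf. Algorithm 4). Returns (new local factor, message)."""
    eta_g, lam_g = q_global
    eta_k, lam_k = q_k
    # Step 2, cavity: q_{-k} is q_global divided by q_k.
    eta_cav, lam_cav = sub(eta_g, eta_k), sub(lam_g, lam_k)
    # Step 3, tilted: p_k times the cavity; exact because p_k is Gaussian here.
    eta_tilt, lam_tilt = add(p_k[0], eta_cav), add(p_k[1], lam_cav)
    # Step 4, client delta: tilted divided by q_global.
    d_eta, d_lam = sub(eta_tilt, eta_g), sub(lam_tilt, lam_g)
    # Step 5, damped local update with step size delta.
    new_q_k = (add(eta_k, scale(delta, d_eta)), add(lam_k, scale(delta, d_lam)))
    return new_q_k, (d_eta, d_lam)

def server_infer(q_global, messages, delta=1.0):
    """Server step (cf. Algorithm 3) using a plain damped sum of the deltas."""
    eta_g, lam_g = q_global
    for d_eta, d_lam in messages:
        eta_g = add(eta_g, scale(delta, d_eta))
        lam_g = add(lam_g, scale(delta, d_lam))
    return eta_g, lam_g

def run_fedep(prior, likelihoods, rounds, delta=1.0):
    """Assumes full participation: every client is selected in every round."""
    dim = len(prior[0])
    q_global = prior
    # Local factors start as improper uniform distributions (all zeros).
    q_locals = [([0.0] * dim, [0.0] * dim) for _ in likelihoods]
    for _ in range(rounds):
        messages = []
        for k, p_k in enumerate(likelihoods):
            q_locals[k], msg = client_infer(q_global, q_locals[k], p_k, delta)
            messages.append(msg)
        q_global = server_infer(q_global, messages, delta)
    return q_global

# Two 1-D clients N(1, 1) and N(3, 0.5), i.e. (eta, lam) = (1, 1) and (6, 2).
prior = ([0.0], [0.0])  # improper uniform prior
clients = [([1.0], [1.0]), ([6.0], [2.0])]
eta, lam = run_fedep(prior, clients, rounds=1)
posterior_mean = eta[0] / lam[0]
```

In this Gaussian special case a single undamped round already recovers the exact posterior, which is consistent with the observation that the first round of FedEP and FedPA coincide under uniform initialization; with δ < 1 the same fixed point is approached geometrically over rounds.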

2.2. SCALABLE EXPECTATION PROPAGATION

While federated learning with expectation propagation is conceptually straightforward, scaling up FedEP to modern models and datasets is challenging. For one, the high dimensionality of the parameter space of contemporary models can make local inference difficult even with simple mean-field Gaussian variational families. This is compounded by the fact that classic expectation propagation is stateful and therefore requires that each client always maintains its local contribution to the global posterior. These factors make classic EP potentially ill-suited to settings where the clients may be resource-constrained and/or the number of clients is large enough that each client is updated only a few times during the course of training. This section discusses various algorithmic considerations when scaling up FedEP to contemporary federated learning benchmarks.

2.2.1. VARIATIONAL FAMILY

Following prior work on variational inference in high-dimensional parameter space (Graves, 2011; Blundell et al., 2015; Zhang et al., 2019; Osawa et al., 2019), we use the mean-field Gaussian variational family for $\mathcal{Q}$, which corresponds to multivariate Gaussian distributions with diagonal covariance. Although non-diagonal extensions are possible (e.g., through shrinkage estimators (Ledoit & Wolf, 2004)), we empirically found the diagonal to work well while being simple and communication-efficient. For notational simplicity, we use the following two parameterizations of a Gaussian distribution interchangeably,

    $q(\theta) = \mathcal{N}(\theta; \mu, \Sigma) = \mathcal{N}(\theta; \eta, \Lambda)$, where $\Lambda := \Sigma^{-1}$, $\eta := \Sigma^{-1}\mu$.

Conveniently, both products and quotients of Gaussian distributions (operations commonly used in EP) result in another Gaussian distribution, which simplifies the calculation of the cavity distribution $q_{-k}(\theta)$ and the global distribution $q_{\text{global}}(\theta)$.³
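To make the two parameterizations concrete, the following sketch converts a diagonal Gaussian between moment form (mean, variance) and natural form (η, Λ), and implements products and quotients as sums and differences of natural parameters. The helper names are hypothetical and the per-coordinate list representation is an assumption made for brevity.

```python
# Diagonal Gaussians in natural parameters: lam = 1 / var, eta = mean / var.
# Helper names are illustrative only.

def to_natural(mean, var):
    lam = [1.0 / v for v in var]
    eta = [l * m for l, m in zip(lam, mean)]
    return eta, lam

def to_moment(eta, lam):
    var = [1.0 / l for l in lam]
    mean = [v * e for v, e in zip(var, eta)]
    return mean, var

def product(q1, q2):
    """Product of two Gaussians: natural parameters add."""
    return ([a + b for a, b in zip(q1[0], q2[0])],
            [a + b for a, b in zip(q1[1], q2[1])])

def quotient(q1, q2):
    """Quotient of two Gaussians: natural parameters subtract.

    The resulting precision can be non-positive, in which case the result
    is not a proper (normalizable) Gaussian.
    """
    return ([a - b for a, b in zip(q1[0], q2[0])],
            [a - b for a, b in zip(q1[1], q2[1])])

q_a = to_natural([0.0], [1.0])      # N(0, 1)
q_b = to_natural([2.0], [1.0])      # N(2, 1)
q_ab = product(q_a, q_b)            # proportional to N(1, 0.5)
mean_ab, var_ab = to_moment(*q_ab)
q_back = quotient(q_ab, q_b)        # recovers N(0, 1)
```

Working in natural parameters is what keeps the cavity and server updates to elementwise additions and subtractions of vectors the size of the model.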

2.2.2. CLIENT INFERENCE

At each round of training, each client must estimate $\hat{q}_{\setminus k}(\theta)$, its own approximation to the tilted distribution $q_{\setminus k}(\theta)$ in Eq. 3. We study various approaches for this estimation procedure.

Stochastic Gradient Markov Chain Monte Carlo (SG-MCMC). SG-MCMC (Welling & Teh, 2011; Ma et al., 2015) uses stochastic gradients to approximately sample from local posteriors. We follow Al-Shedivat et al. (2021) and use a simple variant of SGD-based SG-MCMC, where we collect a single sample per epoch to obtain a set of samples $S_k = \{\theta_k^{(1)}, \dots, \theta_k^{(N)}\}$.⁴ The SGD objective in this case is the negative log of the unnormalized tilted distribution,

    $\underbrace{-\sum_{z \in \mathcal{D}_k} \log p(z \mid \theta)}_{-\log p_k(\theta)} + \underbrace{\tfrac{1}{2} \theta^{\top} \Lambda_{-k} \theta - \eta_{-k}^{\top} \theta}_{-\log q_{-k}(\theta)}$,

which is simply the client negative log likelihood ($-\log p_k(\theta)$) plus a regularizer that penalizes parameters that have low probability under the cavity distribution ($-\log q_{-k}(\theta)$). This connection makes it clear that the additional client computation compared to FedAvg (which just minimizes the client negative log-likelihood) is negligible. Given a set of samples $S_k$ from SG-MCMC, we estimate the parameters of the tilted distribution $q_{\setminus k}(\theta)$ with moment matching, i.e., $\hat{q}_{\setminus k}(\theta) = \mathcal{N}(\theta; \mu_{\setminus k}, \Sigma_{\setminus k})$ where $\mu_{\setminus k}, \Sigma_{\setminus k} \leftarrow \text{MomentEstimator}(S_k)$. While the mean obtained from $S_k$ via averaging empirically worked well, the covariance estimation was sometimes unstable. We next discuss three alternative techniques for estimating the covariance.

SG-MCMC with Scaled Identity Covariance. Our simplest approach approximates the covariance as a scaled identity matrix with a tunable hyper-parameter $\alpha_{\text{cov}}$, i.e., $\Sigma_{\setminus k} \leftarrow \alpha_{\text{cov}} I$. This halves the communication cost since we no longer have to send messages for the covariance parameters. While extremely simple, we found scaled identity covariance to work well in practice.

Laplace Approximation. Laplace's method approximates the covariance as the inverse Hessian of the negative log-likelihood at the (possibly approximate) MAP estimate. Since the exact inverse Hessian is intractable, we follow common practice and approximate it with the diagonal Fisher,

    $\Sigma_{\setminus k} \leftarrow \big(H_k + \Sigma_{-k}^{-1}\big)^{-1}$, where $H_k \approx \underbrace{\mathrm{diag}\Big(\mathbb{E}_{x \sim \mathcal{D}_k,\, y \sim p(y \mid x, \theta)}\big[\nabla_{\theta} \log p(y \mid \theta, x)^2\big]\Big)}_{\text{diagonal Fisher approximation}}$.    (4)

This approach samples the input $x$ from the data $\mathcal{D}_k$ and the output $y$ from the current model $p(y \mid x, \theta)$, as recommended by Kunstner et al. (2019). The Fisher approximation requires additional epochs of backpropagation on top of usual SGD (usually 5 in our case), which requires additional client compute.

³ Specifically, we have the following identities: $\mathcal{N}(\theta; \eta_1, \Lambda_1)\, \mathcal{N}(\theta; \eta_2, \Lambda_2) \propto \mathcal{N}(\theta; \eta_1 + \eta_2, \Lambda_1 + \Lambda_2)$ and $\mathcal{N}(\theta; \eta_1, \Lambda_1) / \mathcal{N}(\theta; \eta_2, \Lambda_2) \propto \mathcal{N}(\theta; \eta_1 - \eta_2, \Lambda_1 - \Lambda_2)$.

Natural Gradient Variational Inference. Our final approach uses natural-gradient variational inference (NGVI, Zhang et al., 2018; Khan et al., 2018; Osawa et al., 2019), which incorporates the geometry of the distribution to enable faster convergence. Most existing work on NGVI assumes a zero-mean isotropic Gaussian prior. We extend NGVI to work with arbitrary Gaussian priors, which is necessary for regularizing towards the cavity distribution in FedEP. Specifically, NGVI iteratively computes the following for $t = 1, \dots, T_{\text{NGVI}}$ and learning rate $\beta_{\text{NGVI}}$,

    $\Sigma_{\setminus k, t} \leftarrow \big(|\mathcal{D}_k|\, s_t + \Sigma_{-k}^{-1}\big)^{-1}$, where $s_t \leftarrow \beta_{\text{NGVI}}\, s_{t-1} + (1 - \beta_{\text{NGVI}})\, \mathbb{E}_{\theta \sim \hat{q}_{\setminus k, t-1}}\Big[\tfrac{1}{|\mathcal{D}_k|} \text{Fisher}(\theta)\Big]$.

Here $\text{Fisher}(\cdot)$ is the diagonal Fisher approximation in Eq. 4 but evaluated at a sample of parameters from $\hat{q}_{\setminus k, t}(\theta)$, the approximate posterior using the current estimate of $\Sigma_{\setminus k, t}$. We give the exact NGVI update (which is algorithmically similar to the Adam optimizer (Kingma & Ba, 2015)) in Algorithm 5 in the appendix.
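The client-side quantities in this subsection reduce to a few elementwise operations under the diagonal family. The sketch below shows the cavity regularizer added to the client negative log-likelihood, a moment estimator over collected samples, and the Laplace-style covariance that combines a diagonal Fisher estimate with the cavity precision. All function names are hypothetical, the negative log-likelihood is passed in as an arbitrary callable, and the use of the biased (divide-by-N) variance is an assumption; the text does not specify the estimator.

```python
# Illustrative client-side helpers for diagonal Gaussians; names are assumptions.

def cavity_penalty(theta, eta_cav, lam_cav):
    """-log q_{-k}(theta) up to an additive constant: 0.5 * lam * t^2 - eta * t."""
    return sum(0.5 * l * t * t - e * t for t, e, l in zip(theta, eta_cav, lam_cav))

def tilted_objective(theta, nll, eta_cav, lam_cav):
    """Client negative log-likelihood plus the cavity regularizer."""
    return nll(theta) + cavity_penalty(theta, eta_cav, lam_cav)

def estimate_moments(samples):
    """Per-coordinate mean and biased variance of a list of parameter vectors."""
    n = len(samples)
    dim = len(samples[0])
    mean = [sum(s[d] for s in samples) / n for d in range(dim)]
    var = [sum((s[d] - mean[d]) ** 2 for s in samples) / n for d in range(dim)]
    return mean, var

def laplace_variance(fisher_diag, lam_cav):
    """Diagonal of (H_k + cavity precision)^{-1}, with H_k a diagonal Fisher."""
    return [1.0 / (h + l) for h, l in zip(fisher_diag, lam_cav)]

# Toy values: a constant stand-in for the data term and a 2-D cavity.
penalty = cavity_penalty([1.0, 2.0], eta_cav=[0.5, 0.0], lam_cav=[1.0, 2.0])
objective = tilted_objective([1.0, 2.0], lambda t: 1.5, [0.5, 0.0], [1.0, 2.0])
mean, var = estimate_moments([[0.0, 0.0], [2.0, 4.0]])
sigma = laplace_variance([3.0], [1.0])
```

Note that the regularizer only needs the cavity's natural parameters, which is why the extra cost relative to a FedAvg client step is a single quadratic term per parameter.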

2.2.3. ADAPTIVE OPTIMIZATION AS DAMPING

Given the approximate tilted distribution $\hat{q}_{\setminus k}(\theta)$ and the corresponding parameters $\mu_{\setminus k}, \Sigma_{\setminus k}$, we can in principle follow the update equation in Eq. 2 to estimate $q_{\text{global}}^{\text{new}}(\theta)$. However, adaptive optimizers have been shown to be crucial for scaling federated learning to practical settings (Reddi et al., 2020), and the vanilla EP update does not immediately lend itself to adaptive updates. This section describes an adaptive extension to EP based on damping, in which we re-interpret a damped EP update as a gradient update on the natural parameters, which allows for the use of adaptive optimizers.

Damping performs client updates only partially with step size $\delta$ and is commonly used in parallel EP settings (Minka & Lafferty, 2002; Vehtari et al., 2020). Letting $\Delta q_k(\theta) \propto \hat{q}_{\setminus k}(\theta) / q_{\text{global}}(\theta)$ denote the client "update" distribution, we can simplify the update and arrive at the following intuitive form (Vehtari et al., 2020) (see Sec. A.1 for derivation),

    Client: $q_k^{\text{new}}(\theta) \propto q_k(\theta)\, \Delta q_k(\theta)^{\delta}$,    Server: $q_{\text{global}}^{\text{new}}(\theta) \propto q_{\text{global}}(\theta) \prod_k \Delta q_k(\theta)^{\delta}$.

Recalling that products of Gaussian distributions yield another Gaussian distribution that simply sums the natural parameters, the damped update for $\eta$ is given by,

    Client: $\eta_k \leftarrow \eta_k + \delta\, \Delta\eta_k$,    Server: $\eta_{\text{global}} \leftarrow \eta_{\text{global}} + \delta \sum_{k \in \mathcal{K}} \Delta\eta_k$.

(The update on the precision $\Lambda$ is analogous.) By re-interpreting the update distribution $\Delta q_k(\theta; \Delta\eta_k, \Delta\Lambda_k)$ as a "gradient", we can apply off-the-shelf adaptive optimizers,

    Client: $\eta_k \leftarrow \eta_k + \delta\, \text{optim}(\Delta\eta_k)$,    Server: $\eta_{\text{global}} \leftarrow \eta_{\text{global}} + \delta\, \text{optim}\big(\sum_{k \in \mathcal{K}} \Delta\eta_k\big)$.

All our FedEP experiments (and the FedAvg and FedPA baselines) employ adaptive optimization.
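The re-interpretation above amounts to passing the summed client deltas through an optimizer-like transformation before the damped step. The sketch below illustrates this for the server update on η. With the identity transformation it is the plain damped EP update; the heavy-ball momentum rule is only a stand-in chosen for brevity and is not claimed to be the optimizer configuration used in the experiments. The names `server_update`, `identity`, and `Momentum` are hypothetical.

```python
# Damped server update on natural parameters, with a pluggable "optimizer"
# applied to the summed client deltas. All names are illustrative.

def identity(direction):
    """No transformation: recovers the plain damped EP update."""
    return list(direction)

class Momentum:
    """Heavy-ball accumulation of past update directions (a stand-in optimizer)."""

    def __init__(self, dim, m=0.9):
        self.m = m
        self.buf = [0.0] * dim

    def __call__(self, direction):
        self.buf = [self.m * b + d for b, d in zip(self.buf, direction)]
        return list(self.buf)

def server_update(eta_global, client_deltas, step, optim=identity):
    """eta_global <- eta_global + step * optim(sum_k delta_eta_k)."""
    summed = [sum(column) for column in zip(*client_deltas)]
    direction = optim(summed)
    return [e + step * d for e, d in zip(eta_global, direction)]

deltas = [[1.0, 2.0], [3.0, 4.0]]
plain = server_update([0.0, 0.0], deltas, step=0.5)

opt = Momentum(dim=2, m=0.5)
first = server_update([0.0, 0.0], deltas, step=0.5, optim=opt)
second = server_update(first, deltas, step=0.5, optim=opt)
```

The same pattern applies to the precision Λ and to the client-side local update, each with its own optimizer state.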

2.2.4. STOCHASTIC EXPECTATION PROPAGATION FOR STATELESS CLIENTS

Clients are typically assumed to be stateful in the classic formulations of expectation propagation. However, there are scenarios in which stateful clients are infeasible (e.g., memory constraints) or even undesirable (e.g., large number of clients who only participate in a few update rounds, leading to stale messages). We thus additionally experiment with a stateless version of FedEP via stochastic expectation propagation (SEP, Li et al., 2015). SEP employs direct iterative refinement of a global approximation comprising the prior $p(\theta)$ and $K$ copies of a single approximating factor $q_k(\theta)$,

    $q_{\text{global}}(\theta) \propto p(\theta)\, q_k(\theta)^{K}$.

That is, clients are assumed to capture the average effect. In practice, FedSEP is implemented in Algorithm 4 via replacing the cavity update (step 2) with $q_{-k}(\theta) \propto q_{\text{global}}(\theta) / q_k(\theta)$ and removing the local update (step 5).
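One way to read the stateless cavity update is that the shared factor is never stored but recovered from the global approximation itself. The sketch below follows that reading, which is an assumption based on the factorization above rather than a transcription of the released code: in natural parameters the shared factor is (global minus prior) divided by K, and the cavity is the global approximation with one such factor removed. The function name `sep_cavity` is hypothetical.

```python
# Stateless (SEP-style) cavity for diagonal Gaussians in natural parameters.
# Assumes q_global ∝ prior * factor^K, so factor = (q_global / prior)^(1/K).

def sep_cavity(q_global, prior, num_clients):
    eta_g, lam_g = q_global
    eta_0, lam_0 = prior
    # Shared factor: the average per-client contribution to the global posterior.
    eta_f = [(g - p) / num_clients for g, p in zip(eta_g, eta_0)]
    lam_f = [(g - p) / num_clients for g, p in zip(lam_g, lam_0)]
    # Cavity: remove one copy of the shared factor from the global approximation.
    eta_cav = [g - f for g, f in zip(eta_g, eta_f)]
    lam_cav = [g - f for g, f in zip(lam_g, lam_f)]
    return eta_cav, lam_cav

cavity = sep_cavity(q_global=([10.0], [5.0]), prior=([0.0], [1.0]), num_clients=4)
```

Because nothing client-specific is stored, the memory cost on each client is independent of how often it participates.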

3. EXPERIMENTS

We empirically study FedEP across various benchmarks. We start with a toy setting in Sec. 3.1 where we examine cases where federated posterior averaging (FedPA, Al-Shedivat et al., 2021), which does not take into account global and other clients' approximations during client inference, performs sub-optimally. We then turn to realistic federated learning benchmarks in Sec. 3.2, where both the size of the model and the number of clients are much larger than had been previously considered in prior EP-based approaches to federated learning (Corinzia et al., 2019; Kassab & Simeone, 2022). Here, we resort to the techniques discussed in Sec. 2.2: approximate inference of the tilted distributions, adaptive optimization, and possibly stateless clients. Finally, we conclude in Sec. 3.3 with an analysis of some of the observations from the benchmark experiments.

3.1. TOY EXPERIMENTS

We start with a simple toy setting to illustrate the differences between FedPA and FedEP. Here the task is to infer the global mean from two clients, each of which is parameterized as a two-dimensional Gaussian, $p_k(\theta) = \mathcal{N}(\theta; \mu_k, \Sigma_k)$ for $k \in \{1, 2\}$. Assuming an improper uniform prior, the global distribution is then also a Gaussian with its posterior mode coinciding with the global mean. We perform exact inference via analytically solving $D_{\text{KL}}(q_{\setminus k} \,\|\, \hat{q}_{\setminus k})$, but restrict the variational family to Gaussians with diagonal covariance (i.e., mean-field family). In this case both the FedAvg and FedPA solutions can be derived in "one-shot". Fig. 1 illustrates a simple case where posterior averaging performs sub-optimally. On the other hand, expectation propagation iteratively refines the approximations toward the globally optimal estimation.

We study this phenomenon more systematically by sampling random client distributions, where the client parameters are sampled from the normal-inverse-Wishart (NIW) distribution,

    $\mu_k \sim \mathcal{N}\big(\mu \mid \mu_0, \tfrac{1}{\lambda} \Sigma_k\big)$,  $\Sigma_k \sim \mathcal{W}^{-1}(\Sigma \mid \Psi, \nu)$.

Here we set the hyper-prior mean $\mu_0 = 0$, degrees of freedom $\nu = 7$, scale $\lambda = 0.2$, and sample a random symmetric positive definite matrix for $\Psi$. Table 4 shows the average Euclidean distances between the estimated and target global mean for FedAvg, FedPA, and FedEP averaged over 200 random samples of client distributions. Experimental results demonstrate that iterative message passing in FedEP consistently improves upon the sub-optimal solution from posterior averaging.
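The gap between the one-shot solutions and the true global mean can be reproduced on a small hand-picked instance. The sketch below is an illustration under stated assumptions, not the experimental code: it assumes the mean-field fit for each client matches the client's mean and marginal variances (the minimizer of KL(p ∥ q) over diagonal Gaussians), and it uses two fixed clients with opposite correlations instead of NIW samples. All function names are hypothetical.

```python
# Two 2-D Gaussian clients: exact global mean vs. one-shot FedAvg / posterior
# averaging under a diagonal (mean-field) family. Names are illustrative.

def inv2(m):
    """Inverse of a 2x2 matrix given as nested lists."""
    (a, b), (c, d) = m
    det = a * d - b * c
    return [[d / det, -b / det], [-c / det, a / det]]

def matvec(m, v):
    return [m[0][0] * v[0] + m[0][1] * v[1], m[1][0] * v[0] + m[1][1] * v[1]]

def exact_global_mean(mus, sigmas):
    """Mean of the product of full-covariance Gaussians (uniform prior)."""
    lam_sum = [[0.0, 0.0], [0.0, 0.0]]
    eta_sum = [0.0, 0.0]
    for mu, sigma in zip(mus, sigmas):
        lam = inv2(sigma)
        eta = matvec(lam, mu)
        lam_sum = [[lam_sum[i][j] + lam[i][j] for j in range(2)] for i in range(2)]
        eta_sum = [eta_sum[i] + eta[i] for i in range(2)]
    return matvec(inv2(lam_sum), eta_sum)

def fedavg_mean(mus):
    """One-shot FedAvg: plain average of the client optima."""
    return [sum(mu[i] for mu in mus) / len(mus) for i in range(2)]

def posterior_averaging_mean(mus, sigmas):
    """Product of diagonal fits that keep only each client's marginal variances."""
    out = []
    for i in range(2):
        precisions = [1.0 / sigma[i][i] for sigma in sigmas]
        weighted = sum(p * mu[i] for p, mu in zip(precisions, mus))
        out.append(weighted / sum(precisions))
    return out

mus = [[0.0, 0.0], [2.0, 2.0]]
sigmas = [[[1.0, 0.8], [0.8, 1.0]], [[1.0, -0.8], [-0.8, 1.0]]]
target = exact_global_mean(mus, sigmas)
one_shot_pa = posterior_averaging_mean(mus, sigmas)
one_shot_avg = fedavg_mean(mus)
```

In this instance both one-shot solutions land at (1, 1) while the exact global mean is (1.8, 1.8), because the diagonal fits discard the correlations that determine how the two clients should be weighted.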

3.2. BENCHMARKS EXPERIMENTS

We next conduct experiments on a suite of realistic benchmark tasks introduced by Reddi et al. (2020). Table 1 summarizes the model and raw dataset statistics, which are the same as in Al-Shedivat et al. (2021). We use the dataset preprocessing provided in TensorFlow Federated (TFF, Authors, 2018), and implement the models in JAX (Bradbury et al., 2018; Hennigan et al., 2020; Ro et al., 2021). We compare against both FedAvg with adaptive optimizers and FedPA.⁵ As in FedPA, we run a few rounds of FedAvg as burn-in before switching to FedEP. We refer the reader to the appendix for the exact experimental setup.

For evaluation we consider both convergence speed and final performance. On CIFAR-100 and EMNIST-62, we measure (1) the number of rounds to reach certain accuracy thresholds (based on 10-round running averages), and (2) the best accuracy attained within specific rounds (based on 100-round running averages). For StackOverflow, we measure the best precision, recall, micro- and macro-F1 attained by round 1500 (based on 100-round running averages).⁶ Due to the size of this dataset, the performance at each round is evaluated on a 10K subsample. The evaluation setup is almost exactly the same as in prior work (Reddi et al., 2020; Al-Shedivat et al., 2021). Due to space we mainly discuss the CIFAR-100 ("CIFAR") and StackOverflow Tag Prediction ("StackOverflow") results in this section and defer the EMNIST-62 ("EMNIST") results (which are qualitatively similar) to the appendix (Sec. A.3).
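The convergence-speed metric described above can be computed from a per-round accuracy trace as follows. This is a sketch of one plausible reading: the text does not say how rounds before a full window are handled, so averaging over the rounds available so far is an assumption, as are the function names.

```python
# Rounds-to-threshold based on a trailing running average of per-round accuracy.

def running_average(values, window):
    """Trailing mean over the last `window` values (fewer at the start)."""
    out = []
    total = 0.0
    for i, v in enumerate(values):
        total += v
        if i >= window:
            total -= values[i - window]
        out.append(total / min(i + 1, window))
    return out

def rounds_to_threshold(accuracies, threshold, window=10):
    """First 1-indexed round whose running average reaches the threshold."""
    for r, avg in enumerate(running_average(accuracies, window), start=1):
        if avg >= threshold:
            return r
    return None  # threshold never reached

smoothed = running_average([1.0, 2.0, 3.0, 4.0], window=2)
first_round = rounds_to_threshold([0.0, 0.25, 0.5, 0.75, 1.0], threshold=0.5, window=2)
```

Smoothing before thresholding avoids crediting a method for a single noisy round that happens to cross the target accuracy.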

CIFAR. In Table 2 and Fig. 2 (left, mid), we compare FedAvg, FedPA, and FedEP with various approaches for approximating the clients' tilted distributions (Sec. 2.2.2). A notable observation is the switch from FedAvg to FedPA/FedEP at the 400th round, where we observe significant increases in performance. Somewhat surprisingly, we find that scaled identity is a simple yet strong baseline. (We conduct further experiments in Sec. 3.3 to analyze this phenomenon in greater detail.) We next experiment with stochastic EP (FedSEP, Sec. 2.2.4), a stateless version of FedEP that is more memory-efficient. We find that FedSEP can almost match the performance of full EP despite being much simpler (Fig. 2, right).

StackOverflow. Experiments on CIFAR study the challenges when scaling FedEP to richly parameterized neural models with millions of parameters. Our StackOverflow experiments are on the other hand intended to investigate whether FedEP can scale to regimes with a large number of clients (hundreds of thousands). Under this setup the number of clients is large enough that the average client will likely only ever participate in a single update round, which renders the stateful version of FedEP meaningless. We thus mainly experiment with the stateless version of FedEP.⁷ Table 3 and Fig. 3 (full figure available in the appendix, Fig. 5) show the results comparing the same set of approximate client inference techniques. These experiments demonstrate the scalability of EP to a large number of clients even when we assume clients are stateless.

3.3. ANALYSIS AND DISCUSSION

The Effectiveness of Scaled Identity. Why does the scaled identity approximation work so well? We investigate this question in the same toy setting as in Sec. 3.1. Fig. 4 (left) compares scaled-identity EP with FedEP, FedPA, and FedAvg. Unsurprisingly, this restriction leads to worse performance initially. However, as clients pass messages between each other, scaled-identity EP eventually converges to nearly the same approximation as diagonal EP.

The toy experiments demonstrate the effectiveness of scaled identity in terms of the final solution. However, this does not fully explain the benchmark experiments where we observed scaled identity EP to match more involved variants in terms of convergence speed. We hypothesize that as models grow more complex, the gap between scaled identity and other techniques might decrease due to the difficulty of obtaining credible estimates of (even diagonal) covariance in high dimensional settings. To test this, we revisit the CIFAR-100 task and compare the following two settings: a "small" setting that uses a smaller linear model on the PCA'ed features and has 10.1K parameters, and a "large" setting that uses a linear model on the raw features and has 172.9K parameters. For each setting, we conduct experiments with EP using scaled-identity and NGVI and plot the results in Fig. 4 (middle, right). We observe that under the "small" setting, a more advanced approximate inference technique converges faster than scaled-identity EP, consistent with the toy experiments. As we increase the model size however ("large" setting), the gap between these two approaches disappears. This indicates that as the model gets more complex, the convergence benefits of more advanced approximate inference decline due to covariance estimation becoming more difficult.

    [...]                                    4 .9 (0.3)    7.9 (0.2)
    FedEP (M)     50.5 (0.5)   50.2 (0.4)    5.9 (0.5)     4.6 (0.4)
    FedEP (L)     47.7 (0.5)   47.8 (0.5)    8.8 (0.4)     6.6 (0.4)
    FedEP (V)     49.7 (0.5)   49.5 (0.3)    5.9 (0.4)     2.2 (0.5)
    FedSEP (I)    49.0 (0.4)   48.5 (0.4)    10.0 (0.4)    3.4 (0.3)
    FedSEP (M)    48.9 (0.4)   48.6 (0.4)    10.1 (0.4)    3.5 (0.3)
    FedSEP (L)    47.7 (0.5)   47.8 (0.5)    9.6 (0.6)     7.2 (0.6)
    FedSEP (V)    48.5 (0.4)   48.7 (0.4)    9.3 (0.4)     3.7 (0.4)

Uncertainty Quantification. One motivation for a Bayesian approach is uncertainty quantification. We thus explore whether a Bayesian treatment of federated learning results in models that have better expected calibration error (ECE, Naeini et al., 2015; Guo et al., 2017), which is defined as

    $\text{ECE} = \sum_{i=1}^{N_{\text{bins}}} b_i\, \big|\text{accuracy}_i - \text{confidence}_i\big|$.
Here $\text{accuracy}_i$ is the top-1 prediction accuracy in the i-th bin, $\text{confidence}_i$ is the average confidence of predictions in the i-th bin, and $b_i$ is the fraction of data points in the i-th bin. Bins are constructed in a uniform way in the [0, 1] range.⁸ We consider accuracy and calibration from the resulting approximate posterior in two ways: (1) point estimation, which uses the final model (i.e., MAP estimate from the approximate posterior) to obtain the output probabilities for each data point, and (2) marginalized estimation, which samples 10 models from the approximate posterior and averages the output probabilities to obtain the final prediction probability. We observe that FedEP/FedSEP improves both the accuracy (higher is better) and the expected calibration error (lower is better), with marginalization sometimes helping.

Hyperparameters. Table 8 shows the robustness of FedEP w.r.t. various hyperparameters.

Limitations. While FedEP outperforms strong baselines in terms of convergence speed and final accuracy, it has several limitations. The stateful variant requires each client to maintain its current contribution to the global posterior, which increases the clients' memory requirements. The non-scaled-identity approaches also impose additional communication overhead due to the need to communicate the diagonal covariance vector. Further, while SG-MCMC/Scaled-Identity approaches have the same compute cost as FedAvg on the client side, Laplace/NGVI approaches require more compute to estimate the Fisher term. Finally, from a theoretical perspective, while the convergence properties of FedAvg under various assumptions have been extensively studied (Li et al., 2018; 2020b), such guarantees for expectation propagation-based approaches remain an open problem.
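As a concrete reading of the ECE definition used in the uncertainty-quantification discussion, the sketch below bins predictions uniformly over [0, 1] by confidence and accumulates the weighted absolute gap between per-bin accuracy and per-bin confidence. The default of 10 bins and the function name are assumptions; the number of bins used in the experiments is not stated here.

```python
# Expected calibration error with uniform-width confidence bins.

def expected_calibration_error(confidences, correct, num_bins=10):
    """confidences: top-1 probabilities in [0, 1]; correct: 0/1 per prediction."""
    n = len(confidences)
    conf_sum = [0.0] * num_bins
    acc_sum = [0.0] * num_bins
    count = [0] * num_bins
    for c, ok in zip(confidences, correct):
        # A confidence of exactly 1.0 falls into the last bin.
        idx = min(int(c * num_bins), num_bins - 1)
        conf_sum[idx] += c
        acc_sum[idx] += ok
        count[idx] += 1
    ece = 0.0
    for i in range(num_bins):
        if count[i] == 0:
            continue  # empty bins contribute nothing
        accuracy = acc_sum[i] / count[i]
        confidence = conf_sum[i] / count[i]
        ece += (count[i] / n) * abs(accuracy - confidence)
    return ece

ece = expected_calibration_error([0.75, 0.75, 0.25, 0.25], [1, 1, 0, 1], num_bins=2)
```

For marginalized estimation, the same function would be applied to confidences obtained from the averaged output probabilities of the sampled models.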

4. RELATED WORK

Federated Learning. FL is a paradigm for collaborative learning with decentralized private data (Konečný et al., 2016; McMahan et al., 2017; Li et al., 2020a; Kairouz et al., 2021; Wang et al., 2021). The standard approach to FL tackles it as a distributed optimization problem where the global objective is defined by a weighted combination of clients' local objectives (Mohri et al., 2019; Li et al., 2020a; Reddi et al., 2020; Wang et al., 2020b). Theoretical analysis has demonstrated that federated optimization exhibits convergence guarantees but only under certain conditions, such as a bounded number of local epochs (Li et al., 2020b). Other work has tried to improve the averaging-based aggregations (Yurochkin et al., 2019; Wang et al., 2020a). Techniques such as secure aggregation (Bonawitz et al., 2017; 2019; He et al., 2020) and differential privacy (Sun et al., 2019; McMahan et al., 2018) have been widely adopted to further improve privacy in FL (Fredrikson et al., 2015). Our proposed method is compatible with secure aggregation because it conducts server-side reductions over $\Delta\eta_k, \Delta\Lambda_k$.

Expectation Propagation and Approximate Inference. This work considers EP as a general technique for passing messages between clients and servers on partitioned data. Here, the cavity distribution "summarizes" the effect of inferences from all other partitions and can be used as a prior in the client's local inference. Historically, EP usually refers to a specific choice of divergence function $D_{\text{KL}}(p \,\|\, q)$ (Minka, 2001). This is also known as Variational Message Passing (VMP, Winn et al., 2005) [...] son et al., 2016) to personalized FL. Finally, some prior works also consider applying EP to federated learning (Corinzia et al., 2019; Kassab & Simeone, 2022; Ashman et al., 2022), but mostly on relatively small-scale tasks. In this work, we instead discuss and empirically study various algorithmic considerations to scale up expectation propagation to contemporary benchmarks.

5. CONCLUSION

This work introduces a probabilistic message-passing algorithm for federated learning based on expectation propagation (FedEP). Messages (probability distributions) are passed to and from clients to iteratively refine global approximations. To scale up classic expectation propagation to the modern FL setting, we discuss and empirically study various algorithmic considerations, such as choice of variational family, approximate inference techniques, adaptive optimization, and stateful/stateless clients. These enable practical EP algorithms for modern-scale federated learning models and data.

Reproducibility Statement. For experiment details such as the dataset, model, and hyperparameters, we provide detailed descriptions in Sec. 3 as well as Sec. A.4. We also include in the Appendix additional derivations related to adaptive optimization and damping (Sec. A.1).

    Server Optimizer      SGD (m = 0.9)    Adagrad (τ = 10^-5)    SGD (m = 0.9)
    Client Optimizer †    SGD (m = 0.9)    SGD (m = 0.9)          SGD (m = 0.9)



Code: https://github.com/HanGuo97/expectation-propagation. This work was completed while Han Guo was a visiting student at MIT.

When the parameters of q_global(θ) and the q_k(θ)'s are initialized as improper uniform distributions, the first round (but only the first round) of FedEP is identical to that of FedPA.

Unlike Al-Shedivat et al. (2021), we do not apply Polyak averaging (Mandt et al., 2017; Maddox et al., 2019), as we did not find it to improve results in our case.

Reddi et al. (2020) refer to federated averaging with adaptive server optimizers as FedAdam etc. We refer to this as FedAvg for simplicity.

TFF by default considers a threshold-based precision and top-5 recall. Our early experiments found that threshold-based metrics correlate better with loss, so we use them in the StackOverflow experiments.

This was also due to the practical difficulty of storing all the clients' distributions.

We also experimented with an alternative binning method that puts an equal number of data points in each bin and observed qualitatively similar results.



Figure 1: FedAvg, FedPA, and FedEP on a toy two-dimensional dataset with two clients.

Figure 3: StackOverflow Experiments. Curves represent loss, micro-F1, and macro-F1 of the global parameter estimation as a function of rounds for FedAvg, FedPA, and (stateless) FedSEP with various inference techniques. The transitions from FedAvg to FedPA and FedSEP happen at round 800. Lines and shaded regions refer to the averages and 2 standard deviations over 5 runs, resp.

Figure 4: Analysis Experiments. Left: the average Euclidean distances between the estimated and target global mean as a function of rounds in the toy setting. Middle and Right: accuracy as a function of rounds in the CIFAR-100 setting, with either a (relatively) small model (middle) or large model (right).

Model and dataset statistics. The ± in "Examples per client" denotes standard deviation.

StackOverflow Experiments.

Toy Experiments. Statistics shown are the averages and standard deviations of Euclidean distances between the estimated and target global mean aggregated over 200 random samples of client distributions.

CIFAR-100 Experiments. Left and Middle: loss and accuracy of the server as a function of rounds for FedAvg, FedPA, and (stateful) FedEP with various inference techniques. Right: accuracy as a function of rounds for FedAvg, FedPA, and (stateless) FedSEP. The transitions from FedAvg to FedPA, FedEP, and FedSEP happen at round 400. Lines and shaded regions refer to the averages and 2 standard deviations over 5 runs, resp.




when D KL (q∥p) is used instead, and Laplace propagation (LP,Smola et al., 2003) when Laplace approximation is used. There have been works that formulate federated learning as a probabilistic inference problem. Most notably, Al-Shedivat et al. (2021) formulate FL as a posterior inference problem.Achituve et al. (2021) apply Gaussian processes with deep kernel learning (Wil-

Figure 7: EMNIST-62 Experiments. Figures show the loss and accuracy of the global parameter estimation as a function of rounds for FedAvg, FedPA, and (stateless) FedSEP with various inference techniques. The transitions from FedAvg to FedPA and FedSEP happen at round 200.

EMNIST-62 Experiments. We measure the number of rounds to reach certain accuracy thresholds (based on 10-round running averages) and the best accuracy attained within specific rounds (based on 100-round running averages). We use I (Scaled Identity Covariance), M (MCMC), L (Laplace), and V (NGVI) to refer to different inference techniques.

Input: D_k, µ_\k, Σ_-k, T_NGVI, N_NGVI, β_NGVI
Initialize s_0, Σ_\k,0
for t = 1, ..., T_NGVI do
|D_k| Fisher(θ, D_k).

Task Hyperparameters from Al-Shedivat et al. (2021)

Client Inference Hyperparameters: Search Space
Scale α_cov: {1, 2, 5, 10} × 10^-2; 1 × {10^-7, 10^-8, 10^-9}; {1, 5} × {10^-2, 10^-3, 10^-4}
MCMC Shrinkage: 1 × {10^-3, 10^-4, 10^-5, 10^-6}

Hyperparameters. † The client has two separate optimizers, one used in local optimization (SG-MCMC) and one used in local state updates (for stateful FedEP). When applied, the client state optimizer reuses the same configuration as the server optimizer. ‡ This is a per-data-point scale, and is also used in other approximate inference techniques. ⋆ The (stateful) FedEP uses 10 Laplace epochs.

CIFAR-100 Hyperparameter Analysis Experiments. FedEP and FedSEP refer to the stateful EP and stateless stochastic EP. We use I (Scaled Identity Covariance), M (MCMC), L (Laplace), and V (NGVI) to refer to different inference techniques.

ACKNOWLEDGMENTS

We thank the anonymous reviewers for their comments, and are grateful to Maruan Al-Shedivat for his feedback. EX is supported by NSF IIS1563887, NSF CCF1629559, NSF IIS1617583, NGA HM04762010002, NIGMS R01GM140467, NSF IIS1955532, NSF CNS2008248, NSF IIS2123952, and NSF BCS2040381. YK acknowledges the support of MIT-IBM Watson AI and Amazon.

A APPENDIX

A.1 DAMPED CLIENT AND SERVER UPDATES

To simplify the notation, observe that Eq. 3 can be rewritten in terms of the projected distribution

$$q_{\backslash k}(\theta) = \arg\min_{q_{\backslash k} \in \mathcal{Q}} D\big(p_k(\theta)\, q_{-k}(\theta) \,\big\|\, q_{\backslash k}(\theta)\big).$$

A partially damped client update can then be carried out in terms of the ratio

$$\Delta q_k(\theta) \propto \frac{q_{\backslash k}(\theta)}{q_{\text{global}}(\theta)},$$

and (damped) server updates can be written similarly.

Expectation propagation (EP; Minka, 2001; Vehtari et al., 2020) constructs a posterior approximation by iterating local computations that refine the factors approximating the posterior contribution from each client. In this spirit, we would ideally like to solve a localized version of Eq. 2 in which one of the factors is replaced with its corresponding approximating factor. Unfortunately, the right-hand side of that divergence is the intractable posterior we would like to approximate in the first place. Instead, EP solves the problem in Eq. 3, which uses the cavity distribution

$$q_{-k}(\theta) \propto \frac{q_{\text{global}}(\theta)}{q_k(\theta)}.$$
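As a worked illustration of the damped updates in Sec. A.1, the block below writes out one common form of geometric damping. It is a sketch under stated assumptions rather than the exact update used in the paper: the damping factor $\delta \in (0, 1]$ and the Gaussian natural-parameter notation $(\eta, \Lambda)$ with $q(\theta) \propto \exp(\eta^\top \theta - \tfrac{1}{2} \theta^\top \Lambda \theta)$ are our own notational choices.

```latex
% Assumed damping factor \delta \in (0, 1]; \delta = 1 recovers the undamped update.
\begin{align*}
\Delta q_k(\theta) &\propto \frac{q_{\backslash k}(\theta)}{q_{\text{global}}(\theta)}, \\
q_k(\theta) &\leftarrow q_k(\theta)\, \big(\Delta q_k(\theta)\big)^{\delta}, \\
q_{\text{global}}(\theta) &\leftarrow q_{\text{global}}(\theta) \prod_{k} \big(\Delta q_k(\theta)\big)^{\delta}.
\end{align*}
% Gaussian case: ratios become differences and powers become scalings.
\begin{align*}
\Delta\eta_k &= \eta_{\backslash k} - \eta_{\text{global}}, &
\Delta\Lambda_k &= \Lambda_{\backslash k} - \Lambda_{\text{global}}, \\
\eta_k &\leftarrow \eta_k + \delta\, \Delta\eta_k, &
\Lambda_k &\leftarrow \Lambda_k + \delta\, \Delta\Lambda_k, \\
\eta_{\text{global}} &\leftarrow \eta_{\text{global}} + \delta \sum_{k} \Delta\eta_k, &
\Lambda_{\text{global}} &\leftarrow \Lambda_{\text{global}} + \delta \sum_{k} \Delta\Lambda_k.
\end{align*}
```

As a consistency check, setting $\delta = 1$ in the client update gives $\eta_k + \eta_{\backslash k} - \eta_{\text{global}} = \eta_{\backslash k} - \eta_{-k}$, since $\eta_{-k} = \eta_{\text{global}} - \eta_k$. This is the undamped EP factor $q_{\backslash k}(\theta) / q_{-k}(\theta)$.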

A.3 ADDITIONAL EXPERIMENTS AND DETAILS

StackOverflow. Please see Fig. 5 for additional visualizations.

EMNIST. Please see Fig. 7 and Table 6 for experimental results.

Analysis. This section extends the experiments (the "small" setting) in Sec. 3.3. It looks at the performance as we increase the complexity (a proxy of quality) of approximate inference techniques. We vary the number of iterations in NGVI from 1 (cheap) to 10 (expensive) epochs. We can observe in Fig. 6 that as we increase NGVI's computations, the performance improves.

A.4 HYPERPARAMETERS

Please see Table 7 for hyperparameter details. In Table 8, we also conduct experiments to understand their influence on the different algorithms.

