FEDERATED LEARNING AS VARIATIONAL INFERENCE: A SCALABLE EXPECTATION PROPAGATION APPROACH

Abstract

The canonical formulation of federated learning treats it as a distributed optimization problem where the model parameters are optimized against a global loss function that decomposes across client loss functions. A recent alternative formulation instead treats federated learning as a distributed inference problem, where the goal is to infer a global posterior from partitioned client data (Al-Shedivat et al., 2021). This paper extends the inference view and describes a variational inference formulation of federated learning where the goal is to find a global variational posterior that well-approximates the true posterior. This naturally motivates an expectation propagation approach to federated learning (FedEP), where approximations to the global posterior are iteratively refined through probabilistic message-passing between the central server and the clients. We conduct an extensive empirical study across various algorithmic considerations and describe practical strategies for scaling up expectation propagation to the modern federated setting. We apply FedEP on standard federated learning benchmarks and find that it outperforms strong baselines in terms of both convergence speed and accuracy.¹

1. INTRODUCTION AND BACKGROUND

Many applications of machine learning require training a centralized model over decentralized, heterogeneous, and potentially private datasets. For example, hospitals may be interested in collaboratively training a model for predictive healthcare, but privacy rules might require each hospital's data to remain local. Federated Learning (FL, McMahan et al., 2017; Kairouz et al., 2021; Wang et al., 2021) has emerged as a privacy-preserving training paradigm that does not require clients' private data to leave their local devices. FL introduces new challenges on top of classic distributed learning: expensive communication, statistical/hardware heterogeneity, and data privacy (Li et al., 2020a).

The canonical formulation of FL treats it as a distributed optimization problem where the model parameters θ are trained on K (potentially private) datasets D = ∪_{k∈[K]} D_k,

θ* = arg min_θ L(θ), where L(θ) = ∑_{k∈[K]} -log p(D_k | θ).

Standard distributed optimization algorithms (e.g., data-parallel SGD) are too communication-intensive to be practical under the FL setup. Federated Averaging (FedAvg, McMahan et al., 2017) reduces communication costs by allowing clients to perform multiple local SGD steps/epochs before the parameter updates are sent back to the central server and aggregated. However, due to client data heterogeneity, more local computation can lead to stale or biased client updates, and hence sub-optimal behavior (Charles & Konečný, 2020; Woodworth et al., 2020; Wang et al., 2020a).

An alternative approach is to consider a Bayesian formulation of the FL problem (Al-Shedivat et al., 2021). Here, we are interested in estimating the posterior of the parameters p(θ | D) given a prior p(θ) (such as an improper uniform or a Gaussian prior) and a collection of client likelihoods p(D_k | θ) that are independent given the model parameters,

p(θ | D) ∝ p(θ) ∏_{k∈[K]} p(D_k | θ).
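The FedAvg scheme described above can be sketched on a toy problem. This is an illustration, not the paper's benchmark setup: the quadratic per-client loss, function names, and hyperparameters are all placeholders for a real model and real client data.

```python
import numpy as np

def local_sgd(theta, client_opt, lr=0.1, steps=10):
    """Several local SGD steps on a toy quadratic client loss.

    Toy loss: L_k(theta) = 0.5 * ||theta - client_opt||^2, a stand-in for
    -log p(D_k | theta); its gradient is (theta - client_opt).
    """
    theta = theta.copy()
    for _ in range(steps):
        theta -= lr * (theta - client_opt)
    return theta

def fedavg_round(theta, client_opts, weights):
    """One FedAvg round: clients run local SGD, the server averages results."""
    updates = [local_sgd(theta, c) for c in client_opts]
    return np.average(updates, axis=0, weights=weights)

# Two heterogeneous clients whose local optima sit at -1 and +3.
theta = np.zeros(2)
client_opts = [np.array([-1.0, -1.0]), np.array([3.0, 3.0])]
for _ in range(50):
    theta = fedavg_round(theta, client_opts, weights=[0.5, 0.5])
# theta approaches the weighted average of the client optima, here [1., 1.]
```

Even in this toy setting one can see the heterogeneity issue: each client's update is pulled toward its own optimum, and only the server-side average recovers a sensible global solution.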
In this case the posterior naturally factorizes across the partitioned client data, wherein the global posterior equates to a multiplicative aggregate of local factors (and the prior). However, exact posterior inference is in general intractable for even modestly-sized models and datasets and requires approximate inference techniques. In this paper we turn to variational inference, in effect transforming the federated optimization problem into a distributed inference problem. Concretely, we view the solution of federated learning as the mode of a variational (posterior) distribution q* ∈ Q under some divergence function D(•∥•) (e.g., the KL-divergence),

θ* = arg max_θ q*(θ), where q*(θ) = arg min_{q∈Q} D(p(θ | D) ∥ q(θ)). (1)

Under this approach, clients use local computation to perform posterior inference (instead of parameter/gradient estimation) in parallel. In exchange, potentially fewer lockstep synchronization and communication steps are required between the clients and the server. One way to operationalize Eq. 1 is through federated posterior averaging (FedPA, Al-Shedivat et al., 2021), where each client independently runs an approximate inference procedure and then sends its local posterior parameters to the server to be multiplicatively aggregated. However, there is no guarantee that independent approximations of the local posteriors will combine into a good global approximate posterior. Motivated by the rich line of work on variational inference over streaming/partitioned data (Broderick et al., 2013; Vehtari et al., 2020), this work instead considers an expectation propagation (EP, Minka, 2001) approach to FL. In EP, each partition of the data maintains its own local contribution to the global posterior, which is iteratively refined through probabilistic message-passing.
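The multiplicative aggregation step has a simple closed form when the local approximate posteriors are Gaussian: an (unnormalized) product of Gaussian densities is again Gaussian, with precisions adding and means combining precision-weighted. A minimal server-side sketch, assuming for simplicity that each client ships a diagonal-Gaussian approximation (the function name and the diagonal restriction are ours, not FedPA's exact implementation):

```python
import numpy as np

def multiply_gaussians(means, precisions):
    """Multiplicatively aggregate diagonal-Gaussian factors.

    For Gaussians, the product rule in natural parameters is
        Λ = Σ_k Λ_k    and    Λ μ = Σ_k Λ_k μ_k,
    i.e., precisions add and means combine precision-weighted.
    `means` and `precisions` are arrays of shape (num_clients, dim).
    """
    means = np.asarray(means)
    precisions = np.asarray(precisions)
    global_precision = precisions.sum(axis=0)
    global_mean = (precisions * means).sum(axis=0) / global_precision
    return global_mean, global_precision

# Two 1-d factors: N(0, prec=1) and N(2, prec=3) combine to N(1.5, prec=4).
mean, prec = multiply_gaussians([[0.0], [2.0]], [[1.0], [3.0]])
```

Note how the aggregate is pulled toward the more confident (higher-precision) client, which is exactly the behavior plain parameter averaging lacks.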
When applied to FL, this results in an intuitive training scheme where at each round, each client (1) receives the current approximation to the global posterior from the central server, (2) carries out local inference to update its local approximation, and (3) sends the refined approximation to the server to be aggregated. Conceptually, this federated learning with expectation propagation (FedEP) approach extends FedPA by taking the current global approximation into account in step (2). However, scaling up classic expectation propagation to the modern federated setting is challenging due to the high dimensionality of model parameters and the large number of clients. Indeed, while there is some existing work on expectation-propagation-based federated learning (Corinzia et al., 2019; Kassab & Simeone, 2022; Ashman et al., 2022), it typically focuses on small models (fewer than 100K parameters) and few clients (at most 100). In this paper we conduct an extensive empirical study across various algorithmic considerations to scale up expectation propagation to contemporary benchmarks (e.g., models with many millions of parameters and datasets with hundreds of thousands of clients). When applied to modern FL benchmarks, our approach outperforms strong FedAvg and FedPA baselines.

2. FEDERATED LEARNING WITH EXPECTATION PROPAGATION

The probabilistic view from Eq. 1 motivates an alternative formulation of federated learning based on variational inference. First observe that the global posterior p(θ | D) given a collection of datasets D = ∪_{k∈[K]} D_k factorizes as

p(θ | D) ∝ p(θ) ∏_{k=1}^K p(D_k | θ) = ∏_{k=0}^K p_k(θ),

where for convenience we define p_0(θ) := p(θ) to be the prior and further use p_k(θ) := p(D_k | θ) to refer to the local likelihood associated with the k-th data partition. To simplify notation we hereon refer to the global posterior as p_global(θ) and drop the conditioning on D. Now consider an approximating global posterior q_global(θ) that admits the same factorization as the above, i.e., q_global(θ) ∝ ∏_{k=0}^K q_k(θ). Plugging these terms into Eq. 1 gives the following objective,

θ* = arg max_θ ∏_{k=0}^K q_k(θ), where {q_k(θ)}_{k=0}^K = arg min_{q_k∈Q} D(∏_{k=0}^K p_k(θ) ∥ ∏_{k=0}^K q_k(θ)). (2)

Here Q is the variational family, which is assumed to be the same for all clients. This global objective is in general intractable: evaluating ∏_k p_k(θ) requires accessing all clients' data, which violates the standard FL assumption. This section presents a probabilistic message-passing algorithm based on expectation propagation (EP, Minka, 2001).

2.1. EXPECTATION PROPAGATION

EP is an iterative algorithm in which an intractable target density p global (θ) is approximated by a tractable density q global (θ) using a collection of localized inference procedures. In EP, each local inference problem is a function of just p k and the current global estimate, making it appropriate for the FL setting.
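The shape of one such local inference step can be sketched for (diagonal) Gaussian factors, where all three EP operations are closed-form in natural parameters (η = Λμ, Λ): dividing out the client's current factor from q_global gives the cavity, local inference targets the tilted distribution cavity × p_k, and the refined factor is the tilted distribution divided by the cavity. This is a toy sketch of the general mechanism, not the paper's algorithm: we take p_k itself to be Gaussian so that "local inference" is exact, whereas real clients would run approximate inference at that step, and all names are illustrative.

```python
def ep_client_update(global_mean, global_prec, local_mean, local_prec,
                     lik_mean, lik_prec):
    """One EP-style client update for (diagonal) Gaussian factors.

    Works in natural parameters (eta = precision * mean, precision), where
    multiplying/dividing Gaussian densities means adding/subtracting
    natural parameters. `lik_mean`/`lik_prec` stand in for the client
    likelihood p_k (assumed Gaussian here, so step (2) is exact).
    """
    # (1) cavity: remove this client's current factor from the global estimate
    cav_prec = global_prec - local_prec
    cav_eta = global_prec * global_mean - local_prec * local_mean
    # (2) local inference on the tilted distribution: cavity * p_k
    tilt_prec = cav_prec + lik_prec
    tilt_eta = cav_eta + lik_prec * lik_mean
    # (3) refined local factor = tilted / cavity (sent back to the server)
    new_prec = tilt_prec - cav_prec
    new_eta = tilt_eta - cav_eta
    return new_eta / new_prec, new_prec

# With an exactly Gaussian p_k the refined factor recovers p_k itself:
mean_k, prec_k = ep_client_update(0.8, 5.0, 0.0, 1.0, 2.0, 3.0)
# mean_k == 2.0, prec_k == 3.0
```

The cavity is what distinguishes this from independent local inference: the client conditions on what the rest of the federation currently believes, rather than approximating its likelihood in isolation.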



¹ Code: https://github.com/HanGuo97/expectation-propagation. This work was completed while Han Guo was a visiting student at MIT.

