FEDERATED LEARNING AS VARIATIONAL INFERENCE: A SCALABLE EXPECTATION PROPAGATION APPROACH

Abstract

The canonical formulation of federated learning treats it as a distributed optimization problem where the model parameters are optimized against a global loss function that decomposes across client loss functions. A recent alternative formulation instead treats federated learning as a distributed inference problem, where the goal is to infer a global posterior from partitioned client data (Al-Shedivat et al., 2021). This paper extends the inference view and describes a variational inference formulation of federated learning where the goal is to find a global variational posterior that well-approximates the true posterior. This naturally motivates an expectation propagation approach to federated learning (FedEP), in which approximations to the global posterior are iteratively refined through probabilistic message-passing between the central server and the clients. We conduct an extensive empirical study across various algorithmic considerations and describe practical strategies for scaling up expectation propagation to the modern federated setting. We apply FedEP to standard federated learning benchmarks and find that it outperforms strong baselines in terms of both convergence speed and accuracy.

1. INTRODUCTION AND BACKGROUND

Many applications of machine learning require training a centralized model over decentralized, heterogeneous, and potentially private datasets. For example, hospitals may be interested in collaboratively training a model for predictive healthcare, but privacy rules might require each hospital's data to remain local. Federated Learning (FL; McMahan et al., 2017; Kairouz et al., 2021; Wang et al., 2021) has emerged as a privacy-preserving training paradigm that does not require clients' private data to leave their local devices. FL introduces new challenges on top of classic distributed learning: expensive communication, statistical/hardware heterogeneity, and data privacy (Li et al., 2020a).

The canonical formulation of FL treats it as a distributed optimization problem in which the model parameters $\theta$ are trained on $K$ (potentially private) datasets $\mathcal{D} = \bigcup_{k \in [K]} \mathcal{D}_k$:

$$\theta^\star = \arg\min_\theta \mathcal{L}(\theta), \quad \text{where } \mathcal{L}(\theta) = \sum_{k \in [K]} -\log p(\mathcal{D}_k \mid \theta).$$

Standard distributed optimization algorithms (e.g., data-parallel SGD) are too communication-intensive to be practical under the FL setup. Federated Averaging (FedAvg; McMahan et al., 2017) reduces communication costs by allowing clients to perform multiple local SGD steps/epochs before the parameter updates are sent back to the central server and aggregated. However, due to client data heterogeneity, more local computation can lead to stale or biased client updates, and hence sub-optimal behavior (Charles & Konečný, 2020; Woodworth et al., 2020; Wang et al., 2020a).

An alternative approach is to consider a Bayesian formulation of the FL problem (Al-Shedivat et al., 2021). Here, we are interested in estimating the posterior over parameters $p(\theta \mid \mathcal{D})$ given a prior $p(\theta)$ (such as an improper uniform or a Gaussian prior) and a collection of client likelihoods $p(\mathcal{D}_k \mid \theta)$ that are independent given the model parameters:

$$p(\theta \mid \mathcal{D}) \propto p(\theta) \prod_{k \in [K]} p(\mathcal{D}_k \mid \theta).$$
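To make the optimization view concrete, the following is a minimal sketch of a FedAvg communication round on a toy linear-regression problem. The client data, learning rate, and round counts are illustrative assumptions for exposition, not the paper's experimental setup.

```python
import numpy as np

rng = np.random.default_rng(0)

def local_sgd(theta, X, y, lr=0.01, steps=20):
    """Run several local full-batch gradient steps on one client's data."""
    theta = theta.copy()
    for _ in range(steps):
        grad = 2.0 * X.T @ (X @ theta - y) / len(y)  # squared-loss gradient
        theta -= lr * grad
    return theta

# K clients with heterogeneous (shifted) feature distributions.
K, d, n = 4, 3, 50
clients = []
for k in range(K):
    X = rng.normal(size=(n, d)) + 0.5 * k  # client-specific shift -> heterogeneity
    w_true = rng.normal(size=d)
    clients.append((X, X @ w_true + 0.1 * rng.normal(size=n)))

theta = np.zeros(d)
for _ in range(10):  # communication rounds
    # Each client starts from the current server model and trains locally.
    updates = [local_sgd(theta, X, y) for X, y in clients]
    # The server aggregates by averaging (uniform client weights here).
    theta = np.mean(np.stack(updates), axis=0)
```

Because each client runs many local steps on its own shifted data before averaging, the individual updates drift toward client-specific optima, which is exactly the staleness/bias issue the text attributes to heterogeneity.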
The posterior thus factorizes naturally across the partitioned client data: the global posterior is a multiplicative aggregate of the local factors (and the prior). However, exact posterior inference is in general intractable for even modestly sized models and datasets and requires approx-
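The multiplicative aggregation of local factors is easiest to see in the one special case where it is tractable in closed form. The 1-D Gaussian example below is an illustrative assumption (not the paper's model): when the prior and every client factor are Gaussian, their product is again Gaussian, with precisions and precision-weighted means adding across clients.

```python
# Toy 1-D product of a Gaussian prior with per-client Gaussian factors
# N(theta; m_k, s_k^2). All values are illustrative assumptions.
prior_mean, prior_var = 0.0, 10.0
client_means = [1.0, 2.0, 3.0]   # per-client factor means m_k
client_vars = [1.0, 0.5, 2.0]    # per-client factor variances s_k^2

# Natural parameters add under multiplication of Gaussian factors.
post_prec = 1.0 / prior_var + sum(1.0 / v for v in client_vars)
post_mean = (prior_mean / prior_var
             + sum(m / v for m, v in zip(client_means, client_vars))) / post_prec
post_var = 1.0 / post_prec
```

For non-Gaussian likelihoods (e.g., neural networks), no such closed form exists, which is what motivates approximating each client factor and iteratively refining the approximations via message-passing, as in expectation propagation.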



Code: https://github.com/HanGuo97/expectation-propagation. This work was completed while Han Guo was a visiting student at MIT.

