FEDERATED GENERALIZED BAYESIAN LEARNING VIA DISTRIBUTED STEIN VARIATIONAL GRADIENT DESCENT

Abstract

This paper introduces Distributed Stein Variational Gradient Descent (DSVGD), a non-parametric generalized Bayesian inference framework for federated learning. DSVGD maintains a number of non-random and interacting particles at a central server to represent the current iterate of the global posterior over the model parameters. The particles are iteratively downloaded and updated by one of the agents, with the end goal of minimizing the global free energy. By varying the number of particles, DSVGD enables a flexible trade-off between per-iteration communication load and number of communication rounds. DSVGD is shown to compare favorably to benchmark frequentist and Bayesian federated learning strategies in terms of accuracy and scalability with respect to the number of agents, while also providing well-calibrated, and hence trustworthy, predictions.

1. INTRODUCTION

Federated learning refers to the collaborative training of a machine learning model across agents with distinct data sets, and it applies at different scales, from industrial data silos to mobile devices (Kairouz et al., 2019). While some common challenges exist, such as the general statistical heterogeneity -- "non-iidness" -- of the distributed data sets, each setting also brings its own distinct problems. In this paper, we are specifically interested in a small-scale federated learning setting consisting of mobile or embedded devices, each having a limited data set and running a small-sized model due to their constrained memory. As an example, consider the deployment of health monitors based on smart-watch ECG data. In this context, we argue that it is essential to tackle the following challenges, which are largely not addressed by existing solutions:

• Trustworthiness: In applications such as personal health assistants, the learning agents' recommendations need to be reliable and trustworthy, e.g., to decide when to contact a doctor in case of a possible emergency.

• Number of communication rounds: When models are small, the payload per communication round may not be the main contributor to the overall latency of the training process. In contrast, accommodating many communication rounds, which requires arbitrating channel access among multiple devices, may yield slow wall-clock convergence (Lin et al., 2020).

Most existing federated learning algorithms, such as Federated Averaging (FedAvg) (McMahan et al., 2017), are based on frequentist principles, relying on the identification of a single model parameter vector. Frequentist learning is known to be unable to capture epistemic uncertainty, yielding overconfident decisions (Guo et al., 2017).
Furthermore, the focus of most existing works is on reducing the load per communication round via compression, rather than on decreasing the number of rounds by providing more informative updates at each round (Kairouz et al., 2019). This paper introduces a trustworthy solution that is able to reduce the number of communication rounds via a non-parametric, variational inference-based implementation of federated Bayesian learning.

Federated Bayesian learning has the general aim of computing the global posterior distribution in the model parameter space. Existing decentralized, or federated, Bayesian learning protocols are based either on Variational Inference (VI) (Angelino et al., 2016; Neiswanger et al., 2015; Broderick et al., 2013; Corinzia & Buhmann, 2019b) or on Monte Carlo (MC) sampling (Ahn et al., 2014; Mesquita et al., 2020; Wei & Conlon, 2019). State-of-the-art methods in either category include Partitioned Variational Inference (PVI), which has been recently introduced as a unifying distributed VI framework that relies on optimization over parametric posteriors; and Distributed Stochastic Gradient Langevin Dynamics (DSGLD), an MC sampling technique that maintains a number of Markov chains updated via local Stochastic Gradient Descent (SGD) with the addition of Gaussian noise (Ahn et al., 2014; Welling & Teh, 2011). The performance of VI-based protocols is generally limited by the bias entailed by the variational approximation, while MC sampling is slow and suffers from the difficulty of assessing convergence (Angelino et al., 2016).

Stein Variational Gradient Descent (SVGD) has been introduced in (Liu & Wang, 2016) as a non-parametric Bayesian framework that approximates a target posterior distribution via non-random and interacting particles. SVGD inherits the flexibility of non-parametric Bayesian inference methods, while improving on the convergence speed of MC sampling (Liu & Wang, 2016).
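To make the particle-based mechanism concrete, the standard SVGD update moves each particle along a kernel-weighted average of the particles' score functions (driving term) plus a kernel-gradient term that repels particles from each other (Liu & Wang, 2016). The following is a minimal sketch with a fixed-bandwidth RBF kernel; the original method typically sets the bandwidth via a median heuristic, and the step size, bandwidth, and Gaussian target here are purely illustrative:

```python
import numpy as np

def svgd_update(particles, score_fn, lr=0.1, h=1.0):
    """One SVGD step: theta_i <- theta_i + lr * phi(theta_i), where
    phi(theta_i) = (1/n) sum_j [ k(theta_j, theta_i) * score(theta_j)
                                 + grad_{theta_j} k(theta_j, theta_i) ]."""
    diffs = particles[:, None, :] - particles[None, :, :]   # diffs[j, i] = theta_j - theta_i
    K = np.exp(-np.sum(diffs**2, axis=-1) / h)              # RBF kernel matrix
    # Repulsive term: sum_j grad_{theta_j} k(theta_j, theta_i)
    #               = -(2/h) * sum_j K[j, i] * (theta_j - theta_i)
    repulsion = -2.0 / h * np.einsum('ji,jid->id', K, diffs)
    phi = (K @ score_fn(particles) + repulsion) / len(particles)
    return particles + lr * phi

# Illustrative example: approximate a Gaussian target N(2, 1)
rng = np.random.default_rng(0)
particles = rng.normal(size=(50, 1))      # 50 particles initialized near 0
score = lambda x: -(x - 2.0)              # grad log-density of N(2, 1)
for _ in range(500):
    particles = svgd_update(particles, score)
```

After these iterations the particle mean approaches the target mean, while the repulsive term prevents the particles from collapsing onto the mode, which is what allows the particle set to represent posterior uncertainty.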
By controlling the number of particles, SVGD can provide flexible performance in terms of bias, convergence speed, and per-iteration complexity.

This paper introduces a novel non-parametric distributed learning algorithm, termed Distributed Stein Variational Gradient Descent (DSVGD), that transfers these benefits of SVGD to federated learning. As illustrated in Fig. 1, DSVGD targets a generalized Bayesian learning formulation with arbitrary loss functions (Knoblauch et al., 2019), and maintains a number of non-random and interacting particles at a central server to represent the current iterate of the global posterior. At each iteration, the particles are downloaded and updated by one of the agents by minimizing a local free energy functional, before being uploaded back to the server. DSVGD is shown (i) to enable a trade-off between per-iteration communication load and number of communication rounds by varying the number of particles, while (ii) being able to make trustworthy decisions through Bayesian inference.
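The download/update/upload round structure just described can be sketched as follows. This is only an illustration of the communication pattern: for brevity, each agent runs SVGD steps against a naive locally tilted target p_0(θ)exp(−L_k(θ)/α), whereas the actual DSVGD update corrects each agent's target using maintained local approximations so that the particles track the global posterior; the round-robin scheduling, function names, and hyperparameters are assumptions for the sketch:

```python
import numpy as np

def svgd_step(particles, score_fn, lr=0.05, h=1.0):
    """One SVGD update: kernel-weighted scores plus a repulsive kernel-gradient term."""
    diffs = particles[:, None, :] - particles[None, :, :]   # diffs[j, i] = theta_j - theta_i
    K = np.exp(-np.sum(diffs**2, axis=-1) / h)               # RBF kernel matrix
    repulsion = -2.0 / h * np.einsum('ji,jid->id', K, diffs)
    phi = (K @ score_fn(particles) + repulsion) / len(particles)
    return particles + lr * phi

def dsvgd_rounds(particles, local_scores, num_rounds, local_steps=100):
    """Server loop: at each round, one agent 'downloads' the particles from the
    server, refines them with a few SVGD steps on its own (naively tilted) local
    target, and 'uploads' them back. local_scores[k] returns the gradient of
    log [ p_0(theta) * exp(-L_k(theta) / alpha) ] at each particle."""
    for i in range(num_rounds):
        k = i % len(local_scores)        # round-robin scheduling (assumption)
        for _ in range(local_steps):
            particles = svgd_step(particles, local_scores[k])
        # upload: the server stores the updated particle set for the next round
    return particles
```

Note how the payload exchanged per round is the particle set itself, so increasing the number of particles raises the per-round communication load while making each round's update more informative.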

2. SYSTEM SET-UP

We consider the federated learning set-up in Fig. 1, where each agent k = 1, ..., K has a distinct local dataset with associated training loss L_k(θ) for model parameter θ. The agents communicate through a central node with the goal of computing the global posterior distribution q(θ) over the shared model parameter θ ∈ R^d for some prior distribution p_0(θ) (Angelino et al., 2016). Specifically, following the generalized Bayesian learning framework (Knoblauch et al., 2019), the agents aim at obtaining the distribution q(θ) that minimizes the global free energy

min_{q(θ)} F(q(θ)) = Σ_{k=1}^{K} E_{θ∼q(θ)}[L_k(θ)] + α D(q(θ) || p_0(θ)),    (1)

where α > 0 is a temperature parameter. The (generalized, or Gibbs) global posterior q_opt(θ) solving problem (1) must strike a balance between minimizing the sum loss function (first term in F(q)) and the model complexity defined by the divergence from a reference prior (second term in F(q)). It is given as

q_opt(θ) = (1/Z) q̃_opt(θ), with q̃_opt(θ) = p_0(θ) exp( −(1/α) Σ_{k=1}^{K} L_k(θ) ),    (2)

where we denoted as Z the normalization constant. It is useful to note that, up to an additive constant independent of q(θ), the global free energy can also be written as the scaled KL divergence F(q(θ)) = α D(q(θ) || q_opt(θ)).
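This variational characterization can be verified numerically on a one-dimensional grid: the Gibbs posterior of (2) attains a lower free energy (1) than any competitor density, and the gap equals the scaled KL divergence α D(q || q_opt). The prior, the two toy quadratic local losses, and the grid below are illustrative assumptions, not quantities from the paper:

```python
import numpy as np

theta = np.linspace(-10.0, 10.0, 2001)
dt = theta[1] - theta[0]

# Prior p_0 = N(0, 2^2) and two toy quadratic local losses (assumptions)
p0 = np.exp(-theta**2 / 8.0)
p0 /= p0.sum() * dt
losses = [0.5 * (theta - 1.0)**2, 0.5 * (theta + 2.0)**2]
alpha = 1.0

def free_energy(q):
    """Global free energy F(q) = sum_k E_q[L_k] + alpha * KL(q || p_0), eq. (1)."""
    expected_loss = sum((q * L).sum() * dt for L in losses)
    kl = (q * np.log(q / p0)).sum() * dt
    return expected_loss + alpha * kl

# Gibbs posterior q_opt proportional to p_0 * exp(-(1/alpha) * sum_k L_k), eq. (2)
q_opt = p0 * np.exp(-sum(losses) / alpha)
q_opt /= q_opt.sum() * dt

# Any other normalized density has a higher free energy,
# with gap exactly alpha * KL(q || q_opt)
q_alt = np.exp(-(theta - 1.0)**2 / 2.0)   # competitor: N(1, 1)
q_alt /= q_alt.sum() * dt
gap = free_energy(q_alt) - free_energy(q_opt)
kl_to_opt = (q_alt * np.log(q_alt / q_opt)).sum() * dt
```

On this grid, `gap` is positive and matches `alpha * kl_to_opt` to numerical precision, confirming that F(q) = α D(q || q_opt) plus a constant independent of q.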



Figure 1: Federated learning across K agents equipped with local datasets and assisted by a central server: (a) in DVI, agents exchange the current model posterior q^(i)(θ) with the server, while (b) in DSVGD, agents exchange particles {θ_n}_{n=1}^N providing a non-parametric estimate of the posterior.

