FEDERATED GENERALIZED BAYESIAN LEARNING VIA DISTRIBUTED STEIN VARIATIONAL GRADIENT DESCENT

Abstract

This paper introduces Distributed Stein Variational Gradient Descent (DSVGD), a non-parametric generalized Bayesian inference framework for federated learning. DSVGD maintains a number of non-random and interacting particles at a central server to represent the current iterate of the model global posterior. The particles are iteratively downloaded and updated by one of the agents with the end goal of minimizing the global free energy. By varying the number of particles, DSVGD enables a flexible trade-off between per-iteration communication load and number of communication rounds. DSVGD is shown to compare favorably to benchmark frequentist and Bayesian federated learning strategies in terms of accuracy and scalability with respect to the number of agents, while also providing well-calibrated, and hence trustworthy, predictions.

1. INTRODUCTION

Federated learning refers to the collaborative training of a machine learning model across agents with distinct data sets, and it applies at different scales, from industrial data silos to mobile devices (Kairouz et al., 2019). While some common challenges exist, such as the general statistical heterogeneity ("non-iidness") of the distributed data sets, each setting also brings its own distinct problems. In this paper, we are specifically interested in a small-scale federated learning setting consisting of mobile or embedded devices, each holding a limited data set and running a small-sized model due to constrained memory. As an example, consider the deployment of health monitors based on smart-watch ECG data. In this context, we argue that it is essential to tackle the following challenges, which are largely not addressed by existing solutions:

• Trustworthiness: In applications such as personal health assistants, the learning agents' recommendations need to be reliable and trustworthy, e.g., to decide when to contact a doctor in case of a possible emergency;

• Number of communication rounds: When models are small, the payload per communication round may not be the main contributor to the overall latency of the training process. In contrast, accommodating many communication rounds, each requiring the arbitration of channel access among multiple devices, may yield slow wall-clock convergence (Lin et al., 2020).

Most existing federated learning algorithms, such as Federated Averaging (FedAvg) (McMahan et al., 2017), are based on frequentist principles, relying on the identification of a single model parameter vector. Frequentist learning is known to be unable to capture epistemic uncertainty, yielding overconfident decisions (Guo et al., 2017).
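To make the Stein variational machinery named in the title concrete, the following is a minimal sketch of a standard (centralized) SVGD particle update with an RBF kernel and median-heuristic bandwidth, in the style of Liu & Wang (2016). It is an illustrative sketch only, not the federated DSVGD protocol itself; the function names and the choice of bandwidth heuristic are our own assumptions.

```python
import numpy as np

def rbf_kernel(particles, bandwidth=None):
    """RBF kernel matrix K[j, i] = k(x_j, x_i) and its gradient w.r.t. x_j."""
    diffs = particles[:, None, :] - particles[None, :, :]   # (n, n, d): x_j - x_i
    sq_dists = np.sum(diffs ** 2, axis=-1)                  # (n, n)
    if bandwidth is None:
        # Median heuristic (an assumed but common default for SVGD).
        bandwidth = np.median(sq_dists) / np.log(len(particles) + 1) + 1e-8
    K = np.exp(-sq_dists / bandwidth)                       # (n, n)
    # grad_K[j, i] = d k(x_j, x_i) / d x_j
    grad_K = (-2.0 / bandwidth) * diffs * K[:, :, None]     # (n, n, d)
    return K, grad_K

def svgd_step(particles, grad_log_post, step_size=0.1):
    """One SVGD update:
    phi(x_i) = (1/n) sum_j [ k(x_j, x_i) * grad log p(x_j) + grad_{x_j} k(x_j, x_i) ].
    The first term drives particles toward high posterior density; the second
    is a repulsive term that keeps them spread out.
    """
    n = len(particles)
    K, grad_K = rbf_kernel(particles)
    scores = grad_log_post(particles)                       # (n, d)
    phi = (K @ scores + grad_K.sum(axis=0)) / n
    return particles + step_size * phi
```

As a sanity check, iterating `svgd_step` with the score function of a standard Gaussian, `lambda x: -x`, transports an initially offset particle set toward that target while the kernel-gradient term prevents the particles from collapsing to the mode.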
Furthermore, the focus of most existing works is on reducing the load per communication round via compression, rather than on decreasing the number of rounds by providing more informative updates at each round (Kairouz et al., 2019).

This paper introduces a trustworthy solution that is able to reduce the number of communication rounds via a non-parametric, variational inference-based implementation of federated Bayesian learning. Federated Bayesian learning has the general aim of computing the global posterior distribution in the model parameter space. Existing decentralized, or federated, Bayesian learning protocols are based either on Variational Inference (VI) (Angelino et al., 2016; Neiswanger et al., 2015; Broderick et al., 2013; Corinzia & Buhmann, 2019b) or on Monte Carlo (MC) sampling (Ahn et al., 2014; Mesquita et al., 2020; Wei & Conlon, 2019). State-of-the-art methods in the two categories include, respectively, Partitioned Variational Inference (PVI), a recently introduced unifying distributed VI framework that relies on optimization over parametric posteriors; and Distributed Stochastic Gradient Langevin Dynamics (DSGLD), an MC sampling technique that maintains a number of Markov chains updated via local Stochastic Gradient Descent (SGD) with the addition of

