FEDERATED LEARNING OF A MIXTURE OF GLOBAL AND LOCAL MODELS

Anonymous authors
Paper under double-blind review

Abstract

We propose a new optimization formulation for training federated learning models. The standard formulation has the form of an empirical risk minimization problem constructed to find a single global model trained from the private data stored across all participating devices. In contrast, our formulation seeks an explicit trade-off between this traditional global model and the local models, which each device can learn from its own private data without any communication. Further, we develop several efficient variants of SGD (with and without partial participation, and with and without variance reduction) for solving the new formulation, and we prove communication complexity guarantees. Notably, our methods are similar but not identical to federated averaging / local SGD, thus shedding some light on the essence of that elusive method. In particular, our methods do not perform full averaging steps; instead, they merely take steps towards averaging. We argue for the benefits of this new paradigm for federated learning.

1. INTRODUCTION

With the proliferation of mobile phones, wearable devices, tablets, and smart home devices comes an increase in the volume of data captured and stored on them. This data contains a wealth of potentially useful information for the owners of these devices, and more so if appropriate machine learning models could be trained on the heterogeneous data stored across the network of such devices. The traditional approach involves moving the relevant data to a data center where centralized machine learning techniques can be efficiently applied (Dean et al., 2012; Reddi et al., 2016). However, this approach is not without issues. First, many device users are increasingly sensitive to privacy concerns and prefer their data to never leave their devices. Second, moving data from its place of origin to a centralized location is very inefficient in terms of energy and time.

1.1. FEDERATED LEARNING

Federated learning (FL) (McMahan et al., 2016; Konečný et al., 2016a;b; McMahan et al., 2017) has emerged as an interdisciplinary field focused on addressing these issues by training machine learning models directly on edge devices. The currently prevalent paradigm (Li et al., 2019; Kairouz et al., 2019) casts supervised FL as an empirical risk minimization problem of the form

$$\min_{x \in \mathbb{R}^d} \frac{1}{n} \sum_{i=1}^n f_i(x), \tag{1}$$

where $n$ is the number of devices participating in training, $x \in \mathbb{R}^d$ encodes the $d$ parameters of a global model (e.g., the weights of a neural network), and $f_i(x) := \mathbb{E}_{\xi \sim \mathcal{D}_i}[f(x, \xi)]$ represents the expected loss of model $x$ on the local data, represented by the distribution $\mathcal{D}_i$ stored on device $i$. One of the defining characteristics of FL is that the data distributions $\mathcal{D}_i$ may possess very different properties across the devices. Hence, any potential FL method is explicitly required to be able to work in the heterogeneous data setting. The most popular method for solving (1) in the context of FL is the FedAvg algorithm (McMahan et al., 2016). In its simplest form, when one does not employ partial participation, model compression, or stochastic approximation, FedAvg reduces to Local Gradient Descent (LGD) (Khaled
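To make the reduction concrete, the following is a minimal sketch (not the paper's method) of Local Gradient Descent for problem (1) on toy quadratic losses $f_i(x) = \tfrac{1}{2}\|x - b_i\|^2$, whose average is minimized at the mean of the $b_i$. All names, the step size, and the round/step counts are illustrative assumptions; each device runs several local gradient steps, then all device models are fully averaged in a communication round.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 5, 3
# Device i holds the quadratic loss f_i(x) = 0.5 * ||x - b_i||^2,
# so grad f_i(x) = x - b_i, and (1) is minimized at mean(b_i).
b = rng.normal(size=(n, d))


def local_gd(x0, rounds=50, local_steps=10, lr=0.1):
    """Local GD: each device takes `local_steps` gradient steps on its
    own loss, then a full averaging (communication) step is performed."""
    x = np.tile(x0, (n, 1))          # one model copy per device
    for _ in range(rounds):
        for _ in range(local_steps):
            x -= lr * (x - b)        # local gradient step on each device
        x[:] = x.mean(axis=0)        # communication round: full averaging
    return x[0]                      # all rows are equal after averaging


x_star = local_gd(np.zeros(d))
# x_star approaches b.mean(axis=0), the minimizer of (1)
```

Note that the full averaging step here is exactly what the paper's abstract contrasts with: the proposed methods only take steps *towards* the average rather than replacing every local model with it.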

