FEDERATED LEARNING OF A MIXTURE OF GLOBAL AND LOCAL MODELS

Anonymous authors
Paper under double-blind review

Abstract

We propose a new optimization formulation for training federated learning models. The standard formulation has the form of an empirical risk minimization problem constructed to find a single global model trained from the private data stored across all participating devices. In contrast, our formulation seeks an explicit trade-off between this traditional global model and the local models, which can be learned by each device from its own private data without any communication. Further, we develop several efficient variants of SGD (with and without partial participation and with and without variance reduction) for solving the new formulation and prove communication complexity guarantees. Notably, our methods are similar but not identical to federated averaging / local SGD, thus shedding some light on the essence of the elusive method. In particular, our methods do not perform full averaging steps and instead merely take steps towards averaging. We argue for the benefits of this new paradigm for federated learning.

1. INTRODUCTION

With the proliferation of mobile phones, wearable devices, tablets, and smart home devices comes an increase in the volume of data captured and stored on them. This data contains a wealth of potentially useful information to the owners of these devices, and more so if appropriate machine learning models could be trained on the heterogeneous data stored across the network of such devices. The traditional approach involves moving the relevant data to a data center where centralized machine learning techniques can be efficiently applied (Dean et al., 2012; Reddi et al., 2016). However, this approach is not without issues. First, many device users are increasingly sensitive to privacy concerns and prefer their data to never leave their devices. Second, moving data from their place of origin to a centralized location is very inefficient in terms of energy and time.

1.1. FEDERATED LEARNING

Federated learning (FL) (McMahan et al., 2016; Konečný et al., 2016b;a; McMahan et al., 2017) has emerged as an interdisciplinary field focused on addressing these issues by training machine learning models directly on edge devices. The currently prevalent paradigm (Li et al., 2019; Kairouz et al., 2019) casts supervised FL as an empirical risk minimization problem of the form

$$\min_{x \in \mathbb{R}^d} \; \frac{1}{n} \sum_{i=1}^n f_i(x), \qquad\qquad (1)$$

where $n$ is the number of devices participating in training, $x \in \mathbb{R}^d$ encodes the $d$ parameters of a global model (e.g., weights of a neural network) and $f_i(x) := \mathbb{E}_{\xi \sim D_i}[f(x, \xi)]$ represents the aggregate loss of model $x$ on the local data represented by distribution $D_i$ stored on device $i$. One of the defining characteristics of FL is that the data distributions $D_i$ may possess very different properties across the devices. Hence, any potential FL method is explicitly required to be able to work under the heterogeneous data setting. The most popular method for solving (1) in the context of FL is the FedAvg algorithm (McMahan et al., 2016). In its most simple form, when one does not employ partial participation, model compression, or stochastic approximation, FedAvg reduces to Local Gradient Descent (LGD) (Khaled et al., 2019; 2020), which is an extension of GD performing more than a single gradient step on each device before aggregation. FedAvg has been shown to work well empirically, particularly for non-convex problems, but comes with poor convergence guarantees compared to the non-local counterparts when data are heterogeneous.
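To make the LGD baseline concrete, the following is a minimal NumPy sketch of local gradient descent in the simplified regime described above (full local gradients, full participation, no compression). The function names and the toy quadratic losses are illustrative assumptions of ours and are not taken from the paper.

```python
import numpy as np

def local_gradient_descent(grads, x0, lr=0.1, local_steps=5, rounds=20):
    """Sketch of Local Gradient Descent (LGD): each device runs several
    full-gradient steps on its own objective f_i, after which the server
    averages the local iterates (one communication per round)."""
    x = np.array(x0, dtype=float)              # current global model
    for _ in range(rounds):
        local_models = []
        for grad_i in grads:                    # loop over devices i = 1..n
            xi = x.copy()                       # broadcast the global model
            for _ in range(local_steps):        # more than one step before aggregation
                xi -= lr * grad_i(xi)
            local_models.append(xi)
        x = np.mean(local_models, axis=0)       # full averaging step on the server
    return x

# Toy heterogeneous example: f_i(x) = 0.5 * ||x - b_i||^2, so grad f_i(x) = x - b_i,
# and the minimizer of (1) is the mean of the b_i.
targets = [np.array([1.0, 0.0]), np.array([0.0, 1.0]), np.array([2.0, 2.0])]
grads = [lambda x, b=b: x - b for b in targets]
print(local_gradient_descent(grads, x0=np.zeros(2)))  # approaches the mean of the b_i
```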

1.2. SOME ISSUES WITH CURRENT APPROACHES TO FL

The first motivation for our research comes from the appreciation that data heterogeneity does not merely present challenges to the design of new provably efficient training methods for solving (1), but also inevitably raises questions about the utility of such a global solution to individual users. Indeed, a global model trained across all the data from all devices might be so removed from the typical data and usage patterns experienced by an individual user as to render it virtually useless. This issue has been observed before, and various approaches have been proposed to address it. For instance, the MOCHA framework (Smith et al., 2017) uses a multi-task learning approach to allow for personalization. Next, Khodak et al. (2019) propose a generic online algorithm for gradient-based parameter-transfer meta-learning and demonstrate improved practical performance over FedAvg (McMahan et al., 2017). Approaches based on variational inference (Corinzia & Buhmann, 2019), cyclic patterns in practical FL data sampling (Eichner et al., 2019), transfer learning (Zhao et al., 2018) and explicit model mixing (Peterson et al., 2019) have also been proposed.

The second motivation for our work is the realization that even very simple variants of FedAvg, such as LGD, which should be easier to analyze, fail to provide theoretical improvements in communication complexity over their non-local cousins, in this case GD (Khaled et al., 2019; 2020).¹ This observation is at odds with the practical success of local methods in FL. This leads us to ask the question: if LGD does not theoretically improve upon GD as a solver for the traditional global problem (1), perhaps LGD should not be seen as a method for solving (1) at all. In such a case, what problem does LGD solve? A good answer to this question would shed light on the workings of LGD, and by analogy, on the role local steps play in more elaborate FL methods such as local SGD (Stich, 2020; Khaled et al., 2020) and FedAvg.

2. CONTRIBUTIONS

In our work we argue that the two motivations mentioned in the introduction point in the same direction, i.e., we show that a single solution can be devised that addresses both problems at the same time. Our main contributions are:

New formulation of FL which seeks an implicit mixture of global and local models. We propose a new optimization formulation of FL. Instead of learning a single global model by solving (1), we propose to learn a mixture of the global model and the purely local models which can be trained by each device $i$ using its data $D_i$ only. Our formulation (see Sec. 3) lifts the problem from $\mathbb{R}^d$ to $\mathbb{R}^{nd}$, allowing each device $i$ to learn a personalized model $x_i \in \mathbb{R}^d$. These personalized models are encouraged not to depart too much from their mean by the inclusion of a quadratic penalty $\psi$ multiplied by a penalty parameter $\lambda \geq 0$ (see the sketch below). Admittedly, the idea of softly-enforced similarity of local models was already introduced in the domain of multi-task relationship learning (Zhang & Yeung, 2010; Liu et al., 2017; Wang et al., 2018) and distributed optimization (Lan et al., 2018; Gorbunov et al., 2019; Zhang et al., 2015). The mixture objective we propose (see (2)) is a special case of their setup, which justifies our approach from the modeling perspective. Note that Zhang et al. (2015); Liu et al. (2017); Wang et al. (2018) already provide efficient algorithms to solve the mixture objective. However, none of the mentioned papers consider the FL application, nor do they shed light on the communication complexity of LGD algorithms, which we do in our work.

Theoretical properties of the new formulation. We study the properties of the optimal solution of our formulation, thus developing an algorithm-free theory. When the penalty parameter is set to zero, then obviously, each device is allowed to train its own model without any dependence on the data stored on other devices. Such purely local models are rarely useful. We prove that the optimal
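The mixture objective (2) referred to above is not reproduced in this excerpt. The LaTeX sketch below reconstructs a plausible form under the penalty structure just described; the exact scaling of $\psi$ (the $1/(2n)$ factor) is our assumption for illustration rather than a quote from the paper.

```latex
% One plausible form of the mixture objective (2): each device i keeps its own
% model x_i, and a quadratic penalty pulls the x_i towards their mean \bar{x}.
% The 1/(2n) scaling of \psi is our assumption, not quoted from the excerpt.
\begin{align*}
  \min_{x_1,\dots,x_n \in \mathbb{R}^d} \; F(x_1,\dots,x_n)
    &:= \frac{1}{n}\sum_{i=1}^n f_i(x_i) \;+\; \lambda\,\psi(x_1,\dots,x_n), \\
  \text{where}\quad \psi(x_1,\dots,x_n)
    &:= \frac{1}{2n}\sum_{i=1}^n \bigl\| x_i - \bar{x} \bigr\|^2,
  \qquad \bar{x} := \frac{1}{n}\sum_{i=1}^n x_i .
\end{align*}
% Note that \nabla_{x_i}\psi = \tfrac{1}{n}(x_i - \bar{x}); the cross terms
% through \bar{x} cancel because \sum_i (x_i - \bar{x}) = 0. Hence a gradient
% step on \lambda\psi moves each x_i only part of the way towards the mean,
% consistent with the "steps towards averaging" mentioned in the abstract.
```

Under this form, $\lambda = 0$ decouples the devices into purely local training, while increasing $\lambda$ pulls the personalized models towards their mean; heuristically, for very large $\lambda$ the models are forced towards consensus and the objective reduces to the global problem (1).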



¹ After our paper was completed, a lower bound on the performance of local SGD was presented that is worse than the known minibatch SGD guarantee (Woodworth et al., 2020a), confirming that local methods do not outperform their non-local counterparts in the heterogeneous setup. Similarly, the benefit of local methods in the non-heterogeneous scenario was questioned in (Woodworth et al., 2020b).




