FEDERATED MIXTURE OF EXPERTS

Abstract

Federated learning (FL) has emerged as the predominant approach for collaborative training of neural network models across multiple users, without the need to gather the data at a central location. One of the important challenges in this setting is data heterogeneity; different users have different data characteristics. For this reason, training and using a single global model might be suboptimal when considering the performance of each of the individual user's data. In this work, we tackle this problem via Federated Mixture of Experts, FedMix, a framework that allows us to train an ensemble of specialized models. FedMix adaptively selects and trains a user-specific selection of the ensemble members. We show that users with similar data characteristics select the same members and therefore share statistical strength while mitigating the effect of non-i.i.d data. Empirically, we show through an extensive experimental evaluation that FedMix improves performance compared to using a single global model while requiring similar or less communication costs.

1. INTRODUCTION

Figure 1: A sliding window of the gradient divergence (defined in Appendix D) on CIFAR-10 in the setup of Section 4, for FedAvg and FedMix (K = 4).

An ever-increasing number of devices are being connected to the internet, sensing their environment and generating vast amounts of data. The term federated learning (FL) has been established to describe the scenario where we aim to learn from the data generated by this "federation" of devices (McMahan et al., 2016). Not only is the number of sensing devices increasing, but their processing power is also growing continuously, to the point that it becomes viable to perform inference and training of machine learning models on device. In federated learning, the goal is to learn from these client devices' data without collecting the data centrally, which naturally allows for a more private exchange of information.

Several challenges arise in the federated scenario. Federated devices are generally resource-constrained, both in computational capacity and in communication bandwidth and latency; a smartphone, for example, has limited heat dissipation capacity and must communicate over Wi-Fi. From a global perspective, devices' processing power and network connectivity can be highly heterogeneous across geographical regions and the socio-economic status of device owners, causing practical issues (Bonawitz et al., 2019) and raising questions of fairness in FL (Li et al., 2019; Mohri et al., 2019).

One of the key challenges in FL, and the one we aim to address in this work, is the non-i.i.d. nature of the shards of data distributed across devices. In non-federated machine learning, assuming independent and identically distributed data is generally justifiable and not detrimental to model performance. In FL, however, each client performs a series of parameter updates on its own data shard to amortize the costs of communication.
Over time, the directions of progress across shards with non-i.i.d. data start diverging (as shown in Figure 1), which can set back training progress, significantly slow down convergence, and decrease model performance (Hsu et al., 2019). To this end, we propose Federated Mixture of Experts (FedMix), an algorithm for FL that trains an ensemble of specialized models instead of a single global model. In FedMix, expert models learn to specialize in regions of the input space such that, for a given expert, each client's progress on that expert is aligned. FedMix allows each client to learn which experts are relevant for its shard, and we show how it can be extended to perform inference on a previously unseen client. FedMix shows competitive performance against the established standard in FL, FedAvg (McMahan et al., 2016; Deng et al., 2020), across a range of visual classification tasks. Code will be released upon publication.

2. FEDERATED MIXTURE OF EXPERTS

Federated learning (McMahan et al., 2016) deals with the problem of learning a server model with parameters $w$, e.g., a neural network, from a dataset $\mathcal{D} = \{(x_1, y_1), \ldots, (x_N, y_N)\}$ of $N$ datapoints that is distributed across $S$ shards, i.e., $\mathcal{D} = \mathcal{D}_1 \cup \cdots \cup \mathcal{D}_S$, without accessing the shard-specific datasets directly. By defining a loss function $L_s(\mathcal{D}_s; w)$ per shard, the total risk can be written as
$$\arg\min_w \sum_{s=1}^{S} \frac{N_s}{N} L_s(\mathcal{D}_s; w), \qquad L_s(\mathcal{D}_s; w) := \frac{1}{N_s} \sum_{i=1}^{N_s} L(\mathcal{D}_{si}; w).$$
It is easy to see that this objective corresponds to empirical risk minimization over the joint dataset $\mathcal{D}$ with a loss $L(\cdot)$ for each datapoint. In federated learning one is interested in reducing communication costs; for this reason, McMahan et al. (2016) propose to perform multiple gradient updates for $w$ in the inner optimization objective for each shard $s$, thus obtaining "local" models with parameters $w_s$. These multiple gradient updates are denoted as "local epochs", i.e., the number of passes through the entire local dataset, abbreviated $E$. Each shard then communicates its local model $w_s$ to the server, and the server updates the global model at round $t$ by averaging the parameters of the local models, $w^t = \sum_s \frac{N_s}{N} w^t_s$. This constitutes federated averaging (FedAvg) (McMahan et al., 2016), the standard in federated learning. One of the main challenges in federated learning is the fact that the data are usually non-i.i.d. across the shards, that is, $p(\mathcal{D}|s_i) \neq p(\mathcal{D}|s_j)$ for $i \neq j$. On the one hand, this can make learning a single global model from all of the data with classical FedAvg problematic. On the other hand, there is one extreme that does not suffer from this issue: learning $S$ individual models, i.e., only optimizing $w_s$ on $\mathcal{D}_s$.
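The FedAvg procedure described above (local epochs on each shard, followed by a data-size-weighted parameter average) can be sketched as follows. This is a minimal illustration, not the paper's implementation: we assume a toy least-squares loss $L(x, y; w) = \frac{1}{2}(w^\top x - y)^2$ and full client participation, purely to make the update rule concrete.

```python
import numpy as np

def local_update(w, shard, epochs, lr):
    """Run E local epochs of SGD on one shard.

    Uses a hypothetical least-squares loss 0.5 * (w @ x - y)^2,
    chosen only so the sketch is runnable.
    """
    w = w.copy()
    for _ in range(epochs):
        for x, y in shard:
            grad = (w @ x - y) * x  # gradient of the toy loss w.r.t. w
            w -= lr * grad
    return w

def fedavg_round(w_global, shards, epochs=1, lr=0.05):
    """One FedAvg round: each shard trains locally from the current
    global model, then the server averages with weights N_s / N,
    i.e., w_t = sum_s (N_s / N) * w_t^s."""
    n_total = sum(len(s) for s in shards)
    local_ws = [local_update(w_global, s, epochs, lr) for s in shards]
    return sum((len(s) / n_total) * w_s
               for s, w_s in zip(shards, local_ws))
```

With i.i.d. shards the averaged local updates all pull in the same direction; the divergence issue discussed above arises precisely when the shards' local gradients disagree.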
Although these individual models by definition do not suffer from non-i.i.d. data, clearly we should aim to do better and exchange meaningful information between clients to learn more robust and expressive models. With FedMix, we propose to strike a balance between the two aforementioned extremes: learning a single global model and learning $S$ individual models. For this reason, we revisit a classic model formulation, the Mixture of Experts (MoE). The classical formulation of a MoE model (Jacobs et al., 1991; Jordan & Jacobs, 1994) contains a set of $K$ experts and a gating mechanism that is responsible for choosing an expert for a given datapoint. A MoE model for a datapoint $(x, y)$ can generally be described by
$$p_{w_{1:K}, \theta}(y|x) = \sum_{z=1}^{K} p_{w_z}(y|x, z)\, p_\theta(z|x),$$
where $z$ is a categorical variable that denotes the expert, $w_k$ are the parameters of expert $k$, and $\theta$ are the parameters of the selection mechanism. The MoE was proposed as a model for datasets where different subsets of the data exhibit different relationships between input $x$ and output $y$. Instead of training a single global model to fit this relationship everywhere, each expert performs well on a different subset of the input space. The gating function models the decision boundary between input regions, assigning datapoints from each subset of the input space to their respective experts.
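The mixture $p(y|x) = \sum_z p_{w_z}(y|x, z)\, p_\theta(z|x)$ can be sketched as follows. This is only an illustrative sketch: the linear-softmax gate parameterized by a matrix `theta` and the representation of experts as per-input class-probability functions are assumptions made here, not the paper's specific architecture.

```python
import numpy as np

def softmax(a):
    """Numerically stable softmax over the last axis."""
    a = a - a.max(axis=-1, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=-1, keepdims=True)

def moe_predict(x, experts, theta):
    """Mixture-of-experts predictive distribution for one input x:
    p(y|x) = sum_z p_{w_z}(y|x, z) * p_theta(z|x).

    `experts` is a list of K functions mapping x to a class-probability
    vector (each plays the role of p_{w_z}(y|x, z)); `theta` is a
    (K, d) gating weight matrix for a linear-softmax gate (hypothetical
    parameterization chosen for this sketch).
    """
    gate = softmax(theta @ x)                   # p_theta(z|x), shape (K,)
    preds = np.stack([f(x) for f in experts])   # (K, num_classes)
    return gate @ preds                         # mixture over experts
```

Because the gate's output depends on $x$, different regions of the input space can be routed to different experts, which is the property FedMix exploits across clients.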



Figure 2: FedMix graphical model. The generative model is depicted with solid lines and the inference model with dashed lines.

