MIXED FEDERATED LEARNING: JOINT DECENTRALIZED AND CENTRALIZED LEARNING

Abstract

Federated learning (FL) enables learning from decentralized privacy-sensitive data, with computations on raw data confined to take place at edge clients. This paper introduces mixed FL, which incorporates an additional loss term calculated at the coordinating server (while maintaining FL's private data restrictions). For example, additional datacenter data can be leveraged to jointly learn from centralized (datacenter) and decentralized (federated) training data and better match an expected inference data distribution. Mixed FL also enables offloading some intensive computations (e.g., embedding regularization) to the server, greatly reducing communication and client computation load. For these and other mixed FL use cases, we present three algorithms: PARALLEL TRAINING, 1-WAY GRADIENT TRANSFER, and 2-WAY GRADIENT TRANSFER. We perform extensive experiments with these algorithms on three tasks, demonstrating that mixed FL can blend training data to achieve an oracle's accuracy on an inference distribution, and can reduce communication and computation overhead by more than 90%. Finally, we state convergence bounds for all algorithms, and give intuition on the mixed FL problems best suited to each. The theory confirms our empirical observations of how the algorithms perform under different mixed FL problem settings.

1. INTRODUCTION

Federated learning (FL) (McMahan et al., 2017) is a machine learning setting where multiple 'clients' (e.g., mobile phones) collaborate to train a model under the coordination of a central server. Clients' raw data are never transferred; instead, focused updates intended for immediate aggregation are used to achieve the learning objective (Kairouz et al., 2019). FL typically delivers model quality improvements because training examples gathered in situ by clients reflect actual inference serving requests. For example, a mobile keyboard next-word prediction model can be trained from actual SMS messages, yielding higher accuracy than a model trained on a proxy document corpus. Because of these benefits, FL has been used to train production models for many applications (Hard et al., 2018; Ramaswamy et al., 2019; Apple, 2019; Ramaswamy et al., 2020; Hartmann, 2021; Hard et al., 2022).

Building on FL, we can gain significant benefits from 'mixed FL': jointly¹ training with an additional centralized objective in conjunction with the decentralized objective of FL. Let x be the model parameters to be optimized. Let f denote a mixed loss, a sum² of a federated loss f_f and a centralized loss f_c:

    f(x) = f_f(x) + f_c(x)    (1)

The mixed loss f might be a more useful training objective than f_f alone for many reasons, including:

Mitigating Distribution Shift by Adding Centralized Data to FL. While FL helps reduce train vs. inference distribution skew, it may not remove it completely. Examples include: training device populations that are subsets of inference device populations (e.g., training on high-end phones, for eventual use also on low-end phones), label-biased example retention on edge clients (e.g., only retaining positive examples of a binary classification task), and infrequent safety-critical events with outsized importance (e.g., automotive hard-braking events needed to train a self-driving AI) (Anonymous, a).
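To make the objective concrete, here is a minimal sketch of the mixed loss with toy quadratic stand-ins for f_f and f_c (the weights w_f, w_c are the relative weights that, per the footnote, are subsumed into the loss terms; both loss functions here are illustrative, not the paper's tasks):

```python
import numpy as np

def mixed_loss(x, f_f, f_c, w_f=1.0, w_c=1.0):
    """Mixed objective f(x) = w_f*f_f(x) + w_c*f_c(x); the paper folds
    the relative weights into the loss terms themselves."""
    return w_f * f_f(x) + w_c * f_c(x)

# Toy quadratics standing in for the decentralized and centralized
# objectives, with deliberately different optima (distribution shift).
federated = lambda x: float(np.sum((x - 1.0) ** 2))
centralized = lambda x: float(np.sum((x + 1.0) ** 2))

x = np.zeros(3)
print(mixed_loss(x, federated, centralized))  # → 6.0
```

Minimizing the sum trades off the two optima; with equal weights, the toy minimizer sits between them.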
The benefits of FL can be achieved while overcoming the remaining distribution skew by incorporating data from an additional datacenter dataset via mixed FL. This affords a composite set of training data that better matches the inference distribution.

Reducing Client Computation and Communication. In representation learning, negative examples are used to push dissimilar items apart in a latent embedding space while keeping positive examples close together (Oord et al., 2018). In federated settings, clients' caches may contain few local negative examples, and recent work (Anonymous, b) showed this significantly degrades performance compared to centralized learning. That work also showed that using an additional loss (a regularization) to push representations apart, instead of negative examples, can close this performance gap. However, done naively, this requires communicating and computing over a large embedding table, introducing massive overhead for large-scale tasks. Applying mixed FL, where the federated loss f_f is the primary 'affinity' loss and the centralized loss f_c is the 'spreadout' regularization, avoids communicating the entire embedding table to clients and greatly reduces client computation.

Though mixed FL can clearly be useful, an actual process to minimize f is not trivial. FL requires that clients' data stay on device, as they contain private information that could reveal personal identity. Moreover, the centralized loss/data is expected to differ significantly³ from the client loss/data.
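As an illustration of a server-side regularizer that f_c could be, here is one common form of a 'spreadout' penalty over an embedding table (the exact regularizer used in the cited work may differ; this form, penalizing pairwise cosine similarities, is an assumption for the sketch):

```python
import numpy as np

def spreadout_regularizer(emb):
    """Server-side 'spreadout' penalty: pushes rows of the embedding
    table apart by penalizing squared pairwise cosine similarities.
    Computed entirely at the server, so the (potentially huge) table
    is never communicated to clients."""
    e = emb / np.linalg.norm(emb, axis=1, keepdims=True)
    sims = e @ e.T                      # pairwise cosine similarities
    off_diag = sims - np.eye(len(emb))  # ignore self-similarity terms
    return float(np.sum(off_diag ** 2))

# Orthogonal embeddings incur no penalty; identical rows are maximally
# penalized.
print(spreadout_regularizer(np.eye(4)))        # → 0.0
print(spreadout_regularizer(np.ones((4, 8))))  # → 12.0
```

Because this term depends only on the embedding table held at the server, evaluating it there instead of on clients is exactly the communication/computation saving that motivates mixed FL here.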

Contributions

• We motivate the mixed FL problem and present three algorithms for addressing it: PARALLEL TRAINING (PT), 1-WAY GRADIENT TRANSFER (1-W GT), and 2-WAY GRADIENT TRANSFER (2-W GT). These algorithms maintain the data privacy protections inherent in FL. [Section 2]
• We experiment with facial attribute classification and language modeling, demonstrating that our algorithms overcome distribution shift. We match the accuracy of hypothetical 'oracle' scenarios in which the entire inference distribution was colocated for training. [Section 4]
• We experiment with user-embedding based movie recommendation, reducing communication overhead by 93.9% and client computation by 99.9% with no degradation in quality. [Section 4]
• We state convergence bounds for each algorithm (in strongly convex, general convex, and non-convex settings), providing theoretical explanations for the convergence behaviors we observe in the experiments. This indicates how the algorithms will perform on new mixed FL tasks. [Section 5]

2. ALGORITHMS

In FL, the loss function f_f is an average of client loss functions f_i. The client loss f_i is an expectation over batches of data examples B_i on client i:

    f_f(x) = (1/N) Σ_{i=1}^{N} f_i(x),    f_i(x) = E_{B_i}[f_i(x; B_i)]    (2)

FEDAVG (McMahan et al., 2017) is a ubiquitous, heuristic FL method designed to minimize Equation 2 w.r.t. the model x in a manner that allows all client data (B_i) to remain at the respective clients i. Providing strong privacy protection is a major motivation for FL. Storing raw data locally on clients rather than replicating it on servers decreases the attack surface of the system. Also, using focused ephemeral updates and early aggregation follows the principles of data minimization (White House Report, 2013).⁴

While training with loss f_f via FEDAVG can yield an effective model x, this paper shows there are scenarios where 'mixing' in an additional 'centralized' loss f_c proves beneficial to the training of x. Such a loss term can make use of batches of centralized data examples B_c from a datacenter dataset:

    f_c(x) = E_{B_c}[f_c(x; B_c)]    (3)
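To fix notation, a minimal FEDAVG-style round on toy least-squares clients can be sketched as follows (the loss, learning rate, and synthetic data are illustrative placeholders, not the paper's setup; real FEDAVG also samples client subsets and weights the average by client dataset size):

```python
import numpy as np

def client_update(x, data, lr=0.1, local_steps=5):
    """One client's local SGD on its private batch B_i (toy
    least-squares loss); raw data never leaves this function."""
    x = x.copy()
    A, b = data                               # client's local examples
    for _ in range(local_steps):
        grad = 2 * A.T @ (A @ x - b) / len(b)  # d/dx mean((Ax - b)^2)
        x -= lr * grad
    return x

def fedavg_round(x, clients_data):
    """One FEDAVG round: broadcast x, run local updates, average."""
    updates = [client_update(x, d) for d in clients_data]
    return np.mean(updates, axis=0)

rng = np.random.default_rng(0)
x = np.zeros(2)
clients = [(rng.normal(size=(8, 2)), rng.normal(size=8)) for _ in range(4)]
for _ in range(20):
    x = fedavg_round(x, clients)
```

The mixed FL algorithms of Section 2 differ in how a server-computed gradient of f_c is combined with rounds of this form.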



Footnotes:
1. We use 'joint' to distinguish our work from sequential 'central-then-FL' use cases, e.g., transfer learning.
2. To simplify, we subsume any relative weights into the loss terms, i.e., this can be f(x) = (w_f f_f(x)) + (w_c f_c(x)).
3. Were they not to differ, one could treat a centralized compute node as an additional client in standard FL, and simply make use of an established FL algorithm like FEDAVG for training x.
4. Even stronger privacy properties are possible when FL is combined with technologies such as differential privacy (DP) and secure multiparty computation (SMPC) (Wang et al., 2021).

