MIXED FEDERATED LEARNING: JOINT DECENTRALIZED AND CENTRALIZED LEARNING

Abstract

Federated learning (FL) enables learning from decentralized privacy-sensitive data, with computations on raw data confined to take place at edge clients. This paper introduces mixed FL, which incorporates an additional loss term calculated at the coordinating server (while maintaining FL's private data restrictions). For example, additional datacenter data can be leveraged to jointly learn from centralized (datacenter) and decentralized (federated) training data and better match an expected inference data distribution. Mixed FL also enables offloading some intensive computations (e.g., embedding regularization) to the server, greatly reducing communication and client computation load. For these and other mixed FL use cases, we present three algorithms: PARALLEL TRAINING, 1-WAY GRADIENT TRANSFER, and 2-WAY GRADIENT TRANSFER. We perform extensive experiments with these algorithms on three tasks, demonstrating that mixed FL can blend training data to achieve an oracle's accuracy on an inference distribution, and can reduce communication and computation overhead by more than 90%. Finally, we state convergence bounds for all three algorithms, and give intuition on the mixed FL problems best suited to each. The theory confirms our empirical observations of how the algorithms perform under different mixed FL problem settings.

1. INTRODUCTION

Federated learning (FL) (McMahan et al., 2017) is a machine learning setting where multiple 'clients' (e.g., mobile phones) collaborate to train a model under the coordination of a central server. Clients' raw data are never transferred; instead, focused updates intended for immediate aggregation are used to achieve the learning objective (Kairouz et al., 2019). FL typically delivers model quality improvements because training examples gathered in situ by clients reflect actual inference serving requests. For example, a mobile keyboard next-word prediction model can be trained on actual SMS messages, yielding higher accuracy than a model trained on a proxy document corpus. Because of these benefits, FL has been used to train production models for many applications (Hard et al., 2018; Ramaswamy et al., 2019; Apple, 2019; Ramaswamy et al., 2020; Hartmann, 2021; Hard et al., 2022).

Building on FL, we can gain significant benefits from 'mixed FL': jointly¹ training with an additional centralized objective in conjunction with the decentralized objective of FL. Let x be the model parameters to be optimized. Let f denote a mixed loss, the sum² of a federated loss f_f and a centralized loss f_c:

f(x) = f_f(x) + f_c(x)

The mixed loss f may be a more useful training objective than f_f alone for many reasons, including:

Mitigating Distribution Shift by Adding Centralized Data to FL. While FL helps reduce train vs. inference distribution skew, it may not remove it completely. Examples include: training device populations that are subsets of inference device populations (e.g., training on high-end phones, for eventual use also on low-end phones), label-biased example retention on edge clients (e.g., only retaining positive examples of a binary classification task), and infrequent safety-critical example



¹ We use 'joint' to distinguish our work from sequential 'central-then-FL' use cases, e.g., transfer learning.
² To simplify, we subsume any relative weights into the loss terms, i.e., this can be f(x) = (w_f f_f(x)) + (w_c f_c(x)).
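The mixed loss above, with the optional per-term weights noted in the footnote, can be sketched as follows. This is a minimal illustration, not the paper's algorithms: the two toy quadratic losses are hypothetical stand-ins for a federated loss (computed from decentralized client data) and a centralized loss (computed on datacenter data).

```python
def federated_loss(x):
    # Stand-in for f_f: a loss aggregated from decentralized client data.
    return (x - 2.0) ** 2

def centralized_loss(x):
    # Stand-in for f_c: a loss computed at the server on datacenter data.
    return (x + 1.0) ** 2

def mixed_loss(x, w_f=1.0, w_c=1.0):
    # f(x) = w_f * f_f(x) + w_c * f_c(x); with w_f = w_c = 1 this is the
    # unweighted sum f(x) = f_f(x) + f_c(x) from the text.
    return w_f * federated_loss(x) + w_c * centralized_loss(x)

# With equal weights, this toy mixed objective is minimized at the
# midpoint x = 0.5 of the two component minima (x = 2 and x = -1).
print(mixed_loss(0.5))  # 2.25 + 2.25 = 4.5
```

Adjusting w_f and w_c trades off how strongly the optimum is pulled toward the federated versus the centralized data, which is the knob the footnote's weighted form exposes.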

