THE BEST OF BOTH WORLDS: ACCURATE GLOBAL AND PERSONALIZED MODELS THROUGH FEDERATED LEARNING WITH DATA-FREE HYPER-KNOWLEDGE DISTILLATION

Abstract

Heterogeneity of data distributed across clients limits the performance of global models trained through federated learning, especially in settings where the class distributions of local datasets are highly imbalanced. In recent years, personalized federated learning (pFL) has emerged as a potential solution to the challenges presented by heterogeneous data. However, existing pFL methods typically enhance the performance of local models at the expense of the global model's accuracy. We propose FedHKD (Federated Hyper-Knowledge Distillation), a novel FL algorithm in which clients rely on knowledge distillation (KD) to train local models. In particular, each client extracts and sends to the server the means of local data representations and the corresponding soft predictions, information that we refer to as "hyper-knowledge". The server aggregates this information and broadcasts it to the clients in support of local training. Notably, unlike other KD-based pFL methods, FedHKD neither relies on a public dataset nor deploys a generative model at the server. We analyze the convergence of FedHKD and conduct extensive experiments on visual datasets in a variety of scenarios, demonstrating that FedHKD provides significant improvements in both personalized and global model performance compared to state-of-the-art FL methods designed for heterogeneous data settings.

1. INTRODUCTION

Federated learning (FL), a communication-efficient and privacy-preserving alternative to training on centrally aggregated data, relies on collaboration between clients who own local data to train a global machine learning model. A central server coordinates the training without violating clients' privacy: the server has no access to the clients' local data. The first such scheme, Federated Averaging (FedAvg) (McMahan et al., 2017), alternates between two steps: (1) randomly selected client devices initialize their local models with the global model received from the server and proceed to train on local data; (2) the server collects local model updates and aggregates them via weighted averaging to form a new global model. As analytically shown in (McMahan et al., 2017), FedAvg is guaranteed to converge when the client data is independent and identically distributed (iid). A major problem in FL systems emerges when the clients' data is heterogeneous (Kairouz et al., 2021). This is a common setting in practice since the data owned by clients participating in federated learning is likely to have originated from different distributions. In such settings, the FL procedure may converge slowly and the resulting global model may perform poorly on the local data of an individual client. To address this challenge, a number of FL methods aiming to enable learning on non-iid data have recently been proposed (Karimireddy et al., 2020; Li et al., 2020; 2021a; Acar et al., 2021; Liu et al., 2021; Yoon et al., 2021; Chen & Vikalo, 2022). Unfortunately, these methods struggle to train a global model that performs well when the clients' data distributions differ significantly.
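The server-side aggregation in step (2) of FedAvg can be sketched as follows; this is a minimal illustration of weighted parameter averaging, with function and variable names chosen for exposition rather than taken from any particular implementation.

```python
import numpy as np

def fedavg_aggregate(local_weights, num_samples):
    """FedAvg server step: weighted average of client model parameters.

    local_weights: list of dicts mapping parameter name -> np.ndarray,
                   one dict per participating client
    num_samples:   list of client dataset sizes, used as aggregation weights
    """
    total = sum(num_samples)
    global_weights = {}
    for name in local_weights[0]:
        # Each client's parameters contribute proportionally to its data size.
        global_weights[name] = sum(
            (n / total) * w[name] for w, n in zip(local_weights, num_samples)
        )
    return global_weights
```

With two clients holding 1 and 3 samples and scalar parameters 1.0 and 3.0, the aggregate is 0.25 * 1.0 + 0.75 * 3.0 = 2.5, illustrating how larger local datasets pull the global model toward their updates.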
Difficulties of learning on non-iid data, as well as the heterogeneity of the clients' resources (e.g., compute, communication, memory, power), motivated a variety of personalized FL (pFL) techniques (Arivazhagan et al., 2019; T Dinh et al., 2020; Zhang et al., 2020; Huang et al., 2021; Collins et al., 2021; Tan et al., 2022). In a pFL system, each client leverages information received from the server and utilizes a customized objective to locally train its personalized model. Instead of focusing on global performance, a pFL client is concerned with improving the model's local performance, empirically evaluated by running the local model on data whose distribution is similar to that of the local training data. Since most personalized FL schemes remain reliant on gradient or model aggregation, they are highly susceptible to 'stragglers' that slow down convergence of the training process. FedProto (Tan et al., 2021) was proposed to address the high communication cost and the homogeneous-model requirement of federated learning. Instead of model parameters, each client in FedProto sends to the server only the class prototypes, i.e., the means of the representations of the samples in each class. Aggregating prototypes rather than model updates significantly reduces communication costs and lifts the requirement of FedAvg that all clients deploy the same model architecture. However, even though FedProto improves local validation accuracy by utilizing aggregated class prototypes, it yields barely any improvement in global performance. Motivated by the success of Knowledge Distillation (KD) (Hinton et al., 2015), which treats soft predictions of samples as the 'knowledge' extracted from a neural network, a number of FL methods that aim to improve the global model's generalization ability have been proposed (Jeong et al., 2018b; Li & Wang, 2019; Lin et al., 2020; Zhang et al., 2021).
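The class prototypes described above, i.e., per-class means of sample representations, can be computed locally with a few lines of code. The sketch below is illustrative of FedProto-style prototype extraction under the assumption that a client's feature extractor has already produced embeddings; the names are hypothetical.

```python
import numpy as np

def compute_class_prototypes(representations, labels):
    """Per-class prototypes: the mean feature representation of each class
    present in a client's local dataset.

    representations: (N, d) array of feature embeddings
    labels:          (N,) array of integer class labels
    """
    prototypes = {}
    for c in np.unique(labels):
        # Average the embeddings of all local samples belonging to class c.
        prototypes[int(c)] = representations[labels == c].mean(axis=0)
    return prototypes
```

Because only one d-dimensional vector per locally present class is uploaded, the communication cost is independent of the model size, which is the source of FedProto's savings over parameter aggregation.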
However, most of the existing KD-based FL methods require that a public dataset be provided to all clients, limiting their feasibility in practical settings. In this paper we propose FedHKD (Federated Hyper-Knowledge Distillation), a novel FL framework that relies on prototype learning and knowledge distillation to facilitate training on heterogeneous data. Specifically, the clients in FedHKD compute mean representations and the corresponding mean soft predictions for the data classes in their local training sets; this information, which we refer to as "hyper-knowledge," is endowed with differential privacy via the Gaussian mechanism and sent to the server for aggregation. The resulting globally aggregated hyper-knowledge is used by clients in the subsequent training epoch, leading to better personalized and global performance. A number of experiments on classification tasks involving the SVHN (Netzer et al., 2011), CIFAR10 and CIFAR100 datasets demonstrate that FedHKD consistently outperforms state-of-the-art approaches in terms of both local and global accuracy.
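The client-side hyper-knowledge extraction can be sketched as below: per-class mean representations and mean soft predictions, with Gaussian noise added to the representation means. This is a simplified illustration, not the paper's exact procedure; in particular, `noise_std` is a placeholder here, whereas the Gaussian mechanism would calibrate the noise scale to the query sensitivity and a privacy budget.

```python
import numpy as np

def extract_hyper_knowledge(representations, soft_predictions, labels,
                            noise_std=0.1, rng=None):
    """Compute a client's "hyper-knowledge": for each local class, the mean
    data representation (perturbed by Gaussian noise for privacy) and the
    mean soft prediction produced by the local model.

    representations:  (N, d) array of feature embeddings
    soft_predictions: (N, K) array of softmax outputs
    labels:           (N,) array of integer class labels
    """
    rng = rng or np.random.default_rng(0)
    hyper_knowledge = {}
    for c in np.unique(labels):
        mask = labels == c
        mean_rep = representations[mask].mean(axis=0)
        # Gaussian mechanism (sketch): perturb the representation mean.
        mean_rep = mean_rep + rng.normal(0.0, noise_std, size=mean_rep.shape)
        mean_soft = soft_predictions[mask].mean(axis=0)
        hyper_knowledge[int(c)] = (mean_rep, mean_soft)
    return hyper_knowledge
```

The server would then average these per-class pairs across clients and broadcast the result, so that local training can distill from globally aggregated representations and soft labels.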

2. RELATED WORK

2.1 HETEROGENEOUS FEDERATED LEARNING

The majority of existing work on federated learning across data-heterogeneous clients can be organized into three categories. The first set of methods aims to reduce the variance of local training by introducing regularization terms into the local objective (Karimireddy et al., 2020; Li et al., 2020; 2021a; Acar et al., 2021). (Mendieta et al., 2022) analyze regularization-based FL algorithms and, motivated by the regularization technique GradAug in centralized learning (Yang et al., 2020), propose FedAlign. A second set of techniques aims to replace the naive model update averaging strategy of FedAvg with more effective aggregation schemes. To this end, PFNM (Yurochkin et al., 2019) applies a Bayesian non-parametric method to select and merge multi-layer perceptron (MLP) layers from local models into a more expressive global model in a layer-wise manner. FedMA (Wang et al., 2020a) proceeds further in this direction and extends the same principle to CNNs and LSTMs. (Wang et al., 2020b) analyze the convergence of heterogeneous federated learning and propose a novel normalized averaging method. Finally, the third set of methods utilizes either the mixup mechanism (Zhang et al., 2017) or generative models to enrich the diversity of local datasets (Yoon et al., 2021; Liu et al., 2021; Chen & Vikalo, 2022). However, these methods introduce additional memory/computation costs and increase the required communication resources.
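The mixup mechanism cited above forms convex combinations of sample pairs and their labels; a minimal sketch (with an illustrative mixing parameter `alpha`, and labels assumed one-hot) is:

```python
import numpy as np

def mixup(x1, y1, x2, y2, alpha=0.2, rng=None):
    """Mixup (Zhang et al., 2017): convex combination of two samples and
    their one-hot labels, drawing the mixing weight from a Beta distribution.
    Some FL methods use such augmentation to enrich local datasets.
    """
    rng = rng or np.random.default_rng(0)
    lam = rng.beta(alpha, alpha)  # mixing weight in [0, 1]
    x_mixed = lam * x1 + (1.0 - lam) * x2
    y_mixed = lam * y1 + (1.0 - lam) * y2
    return x_mixed, y_mixed
```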

2.2 PERSONALIZED FEDERATED LEARNING

Motivated by the observation that a global model collaboratively trained on highly heterogeneous data may not generalize well on clients' local data, a number of personalized federated learning (pFL) techniques aiming to train customized local models have been proposed (Tan et al., 2022). They can be categorized into two groups depending on whether or not they also train a global model. The pFL techniques focused on global model personalization follow a procedure similar to plain vanilla FL: clients still need to upload all or a subset of model parameters to the server to enable global model aggregation. The global model is personalized by each client via local adaptation

