PERSONALIZED FEDERATED LEARNING WITH FEATURE ALIGNMENT AND CLASSIFIER COLLABORATION

Abstract

Data heterogeneity is one of the most challenging issues in federated learning, motivating a variety of approaches that learn personalized models for participating clients. One such approach for deep neural network-based tasks is to employ a shared feature representation and learn a customized classifier head for each client. However, previous works neither utilize global knowledge during local representation learning nor exploit the fine-grained collaboration between local classifier heads, which limits the model generalization ability. In this work, we conduct explicit local-global feature alignment by leveraging global semantic knowledge to learn a better representation. Moreover, we quantify the benefit of classifier combination for each client as a function of the combining weights and derive an optimization problem for estimating the optimal weights. Finally, extensive evaluation on benchmark datasets with various heterogeneous data scenarios demonstrates the effectiveness of our proposed method.

1. INTRODUCTION

Modern learning tasks are usually enabled by deep neural networks (DNNs), which require huge quantities of training data to achieve satisfactory model performance (Lecun et al., 2015; Krizhevsky et al., 2012; Hinton et al., 2012). However, collecting data centrally is often too costly due to the increasingly large volume of data, or even prohibited due to privacy protection. Hence, developing communication-efficient and privacy-preserving learning algorithms is of significant importance for fully exploiting the data held by clients, e.g., data silos and mobile devices (Yang et al., 2019; Li et al., 2020a). To this end, federated learning (FL) emerged as an innovative technique for collaborative model training over decentralized clients without gathering the raw data (McMahan et al., 2017). A typical FL setup employs a central server to maintain a global model and allows partial client participation with infrequent model aggregation, e.g., the popular FedAvg, which has shown good performance when local data across clients are independent and identically distributed (IID). However, in the context of FL, data distributions across clients are usually not identical (non-IID, or heterogeneous), since different devices generate or collect data separately and may have specific preferences; the resulting shifts include feature distribution drift, label distribution skew, and concept shift, which make it hard to learn a single global model that fits all clients (Zhao et al., 2018; Zhu et al., 2021a; Li et al., 2022). To remedy this, personalized federated learning (PFL) has been developed, where the goal is to learn a customized model for each client that performs better on local data while still benefiting from collaborative training (Kulkarni et al., 2020; Tan et al., 2021a; Kairouz et al., 2021). Such settings are motivated by cross-silo FL, where autonomous clients (e.g., hospitals and corporations) may wish to satisfy client-specific target tasks.
A practical FL framework should be aware of the data heterogeneity and flexibly accommodate local objectives during joint training. On the other hand, DNN-based models are usually comprised of a feature extractor, which extracts low-dimensional feature embeddings from data, and a classifier, which makes the classification decision. The success of deep learning in centralized systems and multi-task learning demonstrates that the feature extractor plays the role of common structure while the classifier tends to be highly task-correlated (Bengio et al., 2013; Collins et al., 2021; Caruana, 1993). Moreover, clients in practical FL problems often deal with similar learning tasks, and a clustered structure among clients is assumed in many prior works (Ghosh et al., 2020; Sattler et al., 2021). Hence, learning a better global feature representation and exploiting correlations between local tasks are of significant importance for improving personalized models. In this work, we mainly consider the label distribution shift scenario, where the number of classes is the same across clients while the number of data samples in each class drifts noticeably, i.e., the local tasks have heterogeneous label distributions. We study federated learning from a multi-task learning perspective by leveraging both a shared representation and inter-client classifier collaboration. Specifically, we use the global feature centroid of each class to regularize local training, which can be regarded as explicit feature alignment and reduces the representation diversity across locally trained feature extractors, thus facilitating the global aggregation. We also conduct flexible classifier collaboration through a client-specific linear combination, which encourages similar clients to collaborate more and avoids negative transfer from unrelated clients.
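The centroid-based regularizer described above can be illustrated as follows. This is a minimal NumPy sketch under our own assumptions (a squared-L2 penalty pulling each local embedding toward the global centroid of its ground-truth class), not the authors' implementation; the function name is hypothetical, and in training this term would be added to the local classification loss with a tuning coefficient.

```python
import numpy as np

def feature_alignment_loss(features, labels, global_centroids):
    """Average squared L2 distance between each sample's feature
    embedding and the global centroid of its ground-truth class."""
    diffs = features - global_centroids[labels]   # (batch, dim)
    return float(np.mean(np.sum(diffs ** 2, axis=1)))

# Toy example: 2-D embeddings for two samples of classes 0 and 1.
feats = np.array([[1.0, 0.0], [0.0, 1.0]])
labels = np.array([0, 1])
centroids = np.array([[1.0, 0.0], [0.0, 0.0]])    # global class centroids
loss = feature_alignment_loss(feats, labels, centroids)
```

Because the penalty is computed against server-aggregated centroids rather than local ones, it anchors every client's extractor to the same semantic targets, which is what keeps locally trained representations from drifting apart.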
To estimate the proper combining weights, we utilize local feature statistics and data distribution information to achieve the best bias-variance trade-off, solving a quadratic programming problem that minimizes the expected test loss for each client. Moreover, with a slight modification, our framework still works well under the concept shift scenario, where the same label might have different meanings across clients.

Our contributions. We focus on deep learning-based classification tasks and propose a novel FL framework equipped with personalized aggregation of classifiers (FedPAC) and feature alignment to improve the overall performance on client-specific tasks. The proposed framework is evaluated on benchmark datasets with various levels of data heterogeneity to verify its effectiveness in achieving higher model performance; our evaluation results demonstrate that the proposed method can improve the average model accuracy by 2∼5%. To summarize, this paper makes the following key contributions:

• We quantify the testing loss for each client under classifier combination by characterizing the discrepancy between the learned model and the target data distribution, which illuminates a new bias-variance trade-off.

• A novel personalized federated learning framework with feature representation alignment and optimal classifier combination is proposed for achieving fast convergence and high model performance.

• Through extensive evaluation on real datasets with different levels of data heterogeneity, we demonstrate the high adaptability and robustness of FedPAC.

Benefits of FedPAC. The benefits of our method over current personalized FL approaches include: (i) More local updates for representation learning.
By leveraging feature alignment to control the drift of local representation learning, each client can make many local updates with less local-global parameter divergence at each communication round, which helps learn a better representation in a communication-efficient manner. (ii) Gains from classifier head collaboration. We employ a theoretically guaranteed optimal weighted averaging for combining heads from similar clients, which improves generalization for data-scarce clients while preventing negative knowledge transfer from unrelated clients. Besides, data sharing mechanisms and data augmentation methods have also been investigated to mitigate the non-IID data challenges (Zhao et al., 2018; Yoon et al., 2021).
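The classifier collaboration step, i.e., a client-specific weighted average of heads with weights found by a simplex-constrained quadratic program, can be sketched as below. This is an illustrative NumPy sketch only: the matrix `Q` and vector `p` are hypothetical placeholders for the bias-variance terms built from local feature statistics, and projected gradient descent stands in for a dedicated QP solver; none of these names come from the paper.

```python
import numpy as np

def project_simplex(v):
    """Euclidean projection of v onto the probability simplex."""
    u = np.sort(v)[::-1]
    css = np.cumsum(u)
    idx = np.arange(1, len(v) + 1)
    rho = np.nonzero(u - (css - 1.0) / idx > 0)[0][-1]
    theta = (css[rho] - 1.0) / (rho + 1)
    return np.maximum(v - theta, 0.0)

def solve_combining_weights(Q, p, steps=1000, lr=0.1):
    """Minimize 0.5 * w^T Q w + p^T w over the simplex by projected
    gradient descent (a stand-in for an off-the-shelf QP solver)."""
    w = np.full(len(p), 1.0 / len(p))
    for _ in range(steps):
        w = project_simplex(w - lr * (Q @ w + p))
    return w

def combine_heads(heads, weights):
    """Client-specific linear combination of classifier-head parameters."""
    w = np.asarray(weights).reshape(-1, 1, 1)
    return np.sum(w * np.stack(heads), axis=0)

# Toy example with a hypothetical objective: Q = diag(1, 4), p = 0 has
# the analytic optimum w = [0.8, 0.2] on the simplex.
Q = np.diag([1.0, 4.0])
p = np.zeros(2)
w = solve_combining_weights(Q, p)
heads = [np.array([[2.0, 0.0]]), np.array([[0.0, 2.0]])]  # (classes, dim)
combined = combine_heads(heads, w)
```

Constraining the weights to the simplex is what lets the combination interpolate between pure local training (weight 1 on the client's own head) and uniform averaging, so an unrelated client can simply receive weight zero rather than contribute negative transfer.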

2. RELATED WORK

From the model aggregation perspective, selecting clients that contribute more to global model performance can also speed up convergence and mitigate the influence of non-IID data (Wang et al., 2020; Tang et al., 2021; Wu & Wang, 2021; Fraboni et al., 2021). With the availability of public data, it is possible to employ knowledge distillation techniques to obtain a global model despite the data heterogeneity (Lin et al., 2020; Zhu et al., 2021b). Prototype-based methods have also been utilized in some FL works; for example, Michieli & Ozay (2021) propose a prototype-based weight

