FEDDAR: FEDERATED DOMAIN-AWARE REPRESENTATION LEARNING

Abstract

Cross-silo federated learning (FL) has become a promising tool for machine learning applications in healthcare. It allows hospitals and institutions to train models on sufficient data while keeping the data private. To ensure that the FL model remains robust in the face of heterogeneous data among FL clients, most efforts focus on personalizing models for individual clients. However, the latent relationships between clients' data are ignored. In this work, we focus on a special non-iid FL problem, called Domain-mixed FL, where each client's data distribution is assumed to be a mixture of several predefined domains. Recognizing the diversity across domains and the similarity within domains, we propose a novel method, FedDAR, which learns a domain-shared representation and domain-wise personalized prediction heads in a decoupled manner. For simplified linear regression settings, we theoretically prove that FedDAR enjoys a linear convergence rate. For general settings, we perform extensive empirical studies on both synthetic and real-world medical datasets, demonstrating its superiority over prior FL methods. Our code is available at https://github.com/zlz0414/FedDAR.

1. INTRODUCTION

Federated learning (FL) (McMahan et al., 2017a) is a machine learning approach that allows many clients (e.g., mobile devices or organizations) to collaboratively train a model without sharing their data. It has great potential to resolve the dilemma in real-world machine learning applications, especially in healthcare. A robust and generalizable model for medical applications usually requires a large amount of diverse data to train. However, collecting a large-scale centralized dataset can be expensive or even impractical due to regulatory, ethical, and legal constraints as well as data privacy and protection requirements (Rieke et al., 2020). While promising, applying FL to real-world problems presents many technical challenges. One prominent challenge is data heterogeneity. Many FL algorithms assume that data across clients are independently and identically distributed (iid), but this assumption rarely holds in the real world. It has been shown that non-iid data distributions cause standard FL strategies such as FedAvg to fail (Jiang et al., 2019; Sattler et al., 2020; Kairouz et al., 2019; Li et al., 2020). Since an ideal model that performs well on all clients may not exist, FL algorithms must personalize the model for different data distributions. Prior theoretical work (Marfoq et al., 2021) shows that it is impossible to improve performance on all clients without making assumptions about the clients' data distributions. Past works on personalized FL (Marfoq et al., 2021; Sattler et al., 2020; Ghosh et al., 2020; Mansour et al., 2020; Deng et al., 2020) make their own assumptions and tailor their methods accordingly. In this paper, we propose a new and more realistic assumption in which each client's data distribution is a mixture of several predefined domains. We call this problem setting Domain-mixed FL.
It is inspired by the fact that the diversity of medical data can often be attributed to known domains, e.g., different demographic/ethnic groups of patients (Szczepura, 2005; Ranganathan & Bhopal, 2006; NHS, 2004), different manufacturers or protocols/workflows of image scanners (Mårtensson et al., 2020; Ciompi et al., 2017), and so on. It is necessary to address the ubiquitous issue of domain shifts among ethnic groups (Szczepura, 2005; Ranganathan & Bhopal, 2006; NHS, 2004) or vendors (Yan et al., 2019; Garrucho et al., 2022; Guan & Liu, 2021) in healthcare data. Despite these domain shifts, the same domain at different clients is usually considered to have the same distribution. The data heterogeneity between FL clients thus arises from the distinct mixtures of diverse domains at each client. These observations motivate us to personalize the model for each domain instead of each client. Although our method is inspired by healthcare applications, where the domain shift issue is well-known and domain labels are basic and readily accessible, we believe it can be applied more generally, e.g., to finance or recommendation systems that involve users of different demographics (Ding et al., 2021; Asuncion & Newman, 2007). However, doing so requires a deep understanding of the data and sufficient background knowledge to verify both the data distribution assumption and the accessibility of domain labels. FedEM (Marfoq et al., 2021) and FedMinMax (Papadaki et al., 2021) make assumptions on the data distribution similar to ours. However, FedEM assumes the domains are unknown and tries to learn a linear combination of several shared component models with personalized mixture weights through an EM-like algorithm. FedMinMax does not account for the shift between domains and still aims to learn one shared model across domains by adapting minimax optimization to the FL setting.

Our Contributions. We formulate the proposed problem setting, Domain-mixed FL.
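The Domain-mixed FL assumption can be made concrete with a toy data generator: every client draws samples from the same set of predefined domains, but with client-specific mixture weights. The sketch below is purely illustrative (the function name, dictionary keys, and the Gaussian per-domain feature shift are our assumptions, not the paper's setup):

```python
import random

def make_domain_mixed_clients(num_clients, num_domains, samples_per_client, seed=0):
    """Toy generator for the Domain-mixed FL setting: shared domains,
    client-specific mixture weights over those domains."""
    rng = random.Random(seed)
    clients = []
    for _ in range(num_clients):
        # Client-specific mixture weights over the shared domains.
        raw = [rng.random() for _ in range(num_domains)]
        total = sum(raw)
        weights = [r / total for r in raw]
        # Each sample is (feature, domain label); the feature mean is
        # shifted by its domain to mimic a domain shift.
        data = []
        for _ in range(samples_per_client):
            d = rng.choices(range(num_domains), weights=weights)[0]
            x = rng.gauss(2.0 * d, 1.0)  # domain-dependent shift (assumed)
            data.append((x, d))
        clients.append({"weights": weights, "data": data})
    return clients
```

Under this generator, two clients share the same per-domain conditional distribution, and heterogeneity comes only from their distinct mixture weights, which is exactly the source of non-iid-ness the paper targets.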
Through our analysis, we find that prior FL methods, both generic FL methods like FedAvg (McMahan et al., 2017a) and personalized FL methods like FedRep (Collins et al., 2021), are sub-optimal under our setting. To address this issue, we propose a new algorithm, Federated Domain-Aware Representation Learning (FedDAR). FedDAR learns a shared model for all clients, embedded with domain-wise personalized modules. The model contains two parts: a shared encoder across all domains and a multi-headed predictor whose heads are associated with domains. For an input from a specific domain, the model extracts a representation via the shared encoder and then uses the corresponding head to make the prediction. FedDAR decouples the learning of the encoder and the heads by alternating between their updates. This allows the clients to run many local updates on the heads without overfitting on domains with limited data samples, which also leads to faster convergence and a better-performing model. FedDAR further adopts different aggregation strategies for the two parts. We use a weighted average to aggregate the local updates of the encoder; with additional sample re-weighting, the overall training objective weights each domain equally to encourage fairness among domains. For the heads, we propose a novel second-order aggregation algorithm to improve the optimality of the aggregated heads. We theoretically show that our method enjoys desirable properties such as linear convergence and low sample complexity in a linear case. Through extensive experiments on both synthetic and real-world datasets, we demonstrate that FedDAR significantly improves performance over state-of-the-art personalized FL methods. To the best of our knowledge, our paper is among the first efforts in domain-wise personalized federated learning to achieve such performance.
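The decoupled, alternating scheme described above can be sketched in the simplified scalar linear-regression case: each round, every client runs several local steps on its domain-wise heads with the encoder frozen, then one local step on the shared encoder, and the server averages the encoder by sample count and the heads per domain. This is a minimal hypothetical sketch (the function name, learning rate, and number of head steps are our choices, and the plain per-domain averaging stands in for the paper's second-order head aggregation):

```python
def feddar_round(clients, enc, heads, lr=0.05, head_steps=5):
    """One communication round of a FedDAR-style decoupled update for a
    toy scalar model y ~ heads[d] * (enc * x); illustrative only."""
    enc_updates, enc_weights = [], []
    head_sums = [0.0] * len(heads)
    head_counts = [0] * len(heads)
    for data in clients:  # each client's data: list of (x, domain, y)
        local_enc = enc
        local_heads = list(heads)
        # 1) Many local SGD steps on the domain-wise heads, encoder frozen.
        for _ in range(head_steps):
            for x, d, y in data:
                err = local_heads[d] * local_enc * x - y
                local_heads[d] -= lr * err * local_enc * x
        # 2) One local gradient step on the shared encoder, heads frozen.
        g = sum((local_heads[d] * local_enc * x - y) * local_heads[d] * x
                for x, d, y in data) / len(data)
        local_enc -= lr * g
        enc_updates.append(local_enc)
        enc_weights.append(len(data))
        for d in range(len(heads)):
            n_d = sum(1 for _, dd, _ in data if dd == d)
            if n_d:
                head_sums[d] += local_heads[d] * n_d
                head_counts[d] += n_d
    # Server: sample-weighted average for the encoder; per-domain
    # weighted average for the heads (first-order stand-in).
    new_enc = sum(e * w for e, w in zip(enc_updates, enc_weights)) / sum(enc_weights)
    new_heads = [head_sums[d] / head_counts[d] if head_counts[d] else heads[d]
                 for d in range(len(heads))]
    return new_enc, new_heads
```

Note the design choice mirrored from the text: the inner head loop may take many cheap local steps because each head only fits its own domain's samples, while the shared encoder moves conservatively and is averaged across all clients.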

2. RELATED WORK

Besides the literature discussed above, other works on personalization and fairness in federated learning are also closely related to ours. Personalized Federated Learning. Personalized federated learning has been studied from a variety of perspectives: i) local fine-tuning (Wang et al., 2019; Yu et al., 2020); ii) meta-learning (Chen et al., 2018; Fallah et al., 2020; Jiang et al., 2019; Khodak et al., 2019); iii) local/global model interpolation (Deng et al., 2020; Corinzia et al., 2019; Mansour et al., 2020); iv) clustered FL, which partitions clients into clusters and learns an optimal model for each cluster (Sattler et al., 2020; Mansour et al., 2020; Ghosh et al., 2020; Zhu et al., 2021); v) multi-task learning (MTL) (Vanhaesebrouck et al., 2017; Smith et al., 2017; Zantedeschi et al., 2020; Hanzely & Richtárik, 2020; Hanzely et al., 2020; T Dinh et al., 2020; Huang et al., 2021; Li et al., 2021a); vi) local representations or heads for clients (Arivazhagan et al., 2019; Liang et al., 2020; Collins et al., 2021; Luo et al., 2022); vii) personalized models through hypernetworks or super models (Shamsian et al., 2021; Chen & Chao, 2021; Xu et al., 2022). The personalization module in our approach is similar to vi) and to (Zhu et al., 2021), which uses a multi-branch network. However, we personalize the model for domains instead of clients.

