WEAKLY-SUPERVISED DOMAIN ADAPTATION IN FEDERATED LEARNING FOR HEALTHCARE

Anonymous

Abstract

Federated domain adaptation (FDA) describes the setting where a set of source clients seeks to optimize the performance of a target client. To be effective, FDA must address the distributional challenges of federated learning (FL): FL systems exhibit distribution shifts across clients, and labeled data are not always available on every client. To this end, we propose and compare novel approaches for FDA that combine the few labeled target samples with the source data when auxiliary labels are available to the clients. The in-distribution auxiliary information is included during local training to boost out-of-domain accuracy. In addition, during fine-tuning, we devise a simple yet efficient gradient projection method that extracts the valuable components of each source client's model update in the target direction. Extensive experiments on healthcare datasets show that our proposed framework outperforms state-of-the-art unsupervised FDA methods with limited additional time and space complexity.

1. INTRODUCTION

Federated learning (FL) is a distributed learning paradigm in which an aggregated model is learned from local, decentralized data on edge devices (McMahan et al., 2017). FL systems usually share clients' model weights or gradient updates with the server, which prevents direct exposure of sensitive client data. Because the data remain decentralized, data heterogeneity is an important challenge in FL, and much of the research focuses on mitigating the negative impact of distribution shifts between clients' data (Wang et al., 2019; Karimireddy et al., 2020; Xie et al., 2020b). Further, much of the FL literature has focused on settings where all datasets are fully labeled. In the real world, however, one often encounters settings where labels are scarce on some of the clients. Multi-source domain adaptation (MSDA) (Ben-David et al., 2010; Zhao et al., 2020; Guan & Liu, 2021) is a common solution to this problem: models trained on several labeled, separate source domains are transferred to an unlabeled or sparsely labeled target domain. Here, we consider the more specialized setting of federated domain adaptation (FDA), where a set of source clients seeks to optimize the performance of a target client. As an analogue to MSDA, one may view the clients' data as different domains; the goal is then to learn a good model for the sparsely labeled target client by transferring useful knowledge from multiple source clients. In this work, we consider the FDA problem under weak supervision, where auxiliary labels are available to the clients. In brief, we propose novel approaches for weakly-supervised FDA, focusing on techniques that adapt both the local training and fine-tuning stages.

Motivating Application. Our work is inspired by, and applied to, predictive modeling for healthcare, where there can be significant differences across hospitals, causing transfer errors across sites (Li et al., 2020; Guan et al., 2021; Wolleb et al., 2022).
Unlike many other industries, healthcare in the US is highly heterogeneous (e.g., HCA, the largest consortium of hospitals, covers < 2% of the market (Statista, 2020; Wikipedia contributors, 2022)), so many variables are not standardized (Adnan et al., 2020; Osarogiagbon et al., 2021). Hence, our experiments simulate differences across institutions as a large shift. Specifically, we consider an FL application across several hospitals located in different US states, where FDA is employed to improve performance at a target hospital by leveraging information from all of the source hospitals. The human cost of labeling medical images is high, so the data are sparsely labeled. In addition to the medical images, the data include demographic information such as age, sex, and race. While such auxiliary data are ubiquitous, they are often ignored when improving centralized or federated models. Here, we show how this data can significantly improve out-of-domain (OOD) performance when properly utilized in FL. Moreover, because local models are trained gradually in FL, the importance of the source domains may change dynamically across iterations. Our work seeks to extract the valuable components of model updates, rather than the whole models, in the target direction during each round. We find that the few labeled target samples provide sufficient guidance for adapting quickly to the target domain.

Our contributions are twofold, covering local training and fine-tuning. First, we leverage auxiliary information to reduce the task risk on the target client during local training. Auxiliary information is often cheap and easy to obtain alongside the image inputs, and it may offer useful signal for the unlabeled samples because of underlying correlations between the auxiliary tasks and the main task. Xie et al. (2021) and Wang et al. (2022) both found that, in the centralized setting, using auxiliary information can improve OOD performance, yielding a smaller OOD risk than the baseline. In our work, we set up a cheap and efficient multi-task learning (MTL) framework with auxiliary tasks during the local training of source clients, as shown in Figure 1a. By optimizing the main and auxiliary task losses together, we show empirically that one can boost target-domain accuracy after fine-tuning on labeled target samples (Section 5.2). The gains are more evident when the distribution shifts are large; auxiliary tasks may introduce unexpected noise when domains are too close to each other.

Second, we observe that including auxiliary information alone does not fully account for the importance of the source domains. Thus, during fine-tuning, we propose a simple yet efficient gradient projection (GP) method that utilizes the useful source-domain components and projects them toward the target direction. As shown in Figure 1b, during each communication round we compute the model updates of the source clients and project them onto the target model update obtained by fine-tuning on a small set of labeled target samples. In this way, we greedily approach the optimum of the target domain: the positive cosine similarity between a (source, target) pair of updates serves as the importance of that source domain. Our experimental results indicate that combining the projected gradients with the fine-tuning gradient yields better and more stable FDA performance on the target client. We show the superiority of the gradient projection method through comprehensive experiments on both medical and general-purpose datasets in Section 5.2 and Appendix A.6. Combining the two techniques, our proposed framework outperforms state-of-the-art unsupervised FDA methods with limited additional computational cost. We also show empirically that our framework is resistant to data imbalance on the real-world MIDRC dataset (Section 5.2).
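The joint local-training objective described above (main-task loss plus a weighted auxiliary-task loss, where unlabeled samples still contribute through the auxiliary term) can be sketched as follows. This is a minimal NumPy illustration, not the paper's implementation: the auxiliary weight, the `-1` ignore label, and the function names are our assumptions.

```python
import numpy as np

def softmax_ce(logits, labels, ignore_index=None):
    """Mean cross-entropy over labeled rows; rows whose label equals
    ignore_index (e.g. unlabeled main-task samples) are skipped."""
    if ignore_index is not None:
        keep = labels != ignore_index
        logits, labels = logits[keep], labels[keep]
    shifted = logits - logits.max(axis=1, keepdims=True)          # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

def local_training_loss(main_logits, aux_logits, y_main, y_aux, aux_weight=0.5):
    """Joint objective on a source client: main-task CE over the labeled
    rows only, plus a weighted auxiliary-task CE over all rows.
    aux_weight is an illustrative hyperparameter."""
    main_loss = softmax_ce(main_logits, y_main, ignore_index=-1)
    aux_loss = softmax_ce(aux_logits, y_aux)
    return main_loss + aux_weight * aux_loss
```

In this sketch, an unlabeled sample is marked with `y_main = -1`: it is excluded from the main-task term but still provides signal through the auxiliary head, which is the mechanism by which auxiliary labels help the unlabeled data.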

2. PROBLEM SETUP

In this section, we introduce the framework of weakly-supervised FDA. We first provide background on weakly-supervised MSDA, and then extend the framework to the federated learning setting.



Figure 1: (a) Proposed framework of weakly-supervised FDA: we set up an MTL framework leveraging auxiliary labels during the source clients' local training. (b) Intuition of GP: we project the valuable components of the source gradients toward the target direction to boost FDA performance.
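The gradient-projection intuition of Figure 1b can be sketched as below. This is an illustrative NumPy version under our own assumptions: updates are flattened into vectors, each source update with positive cosine similarity to the target update is projected onto the target direction and weighted by that similarity, and the result is added to the fine-tuning gradient. The exact combination rule in the paper may differ.

```python
import numpy as np

def project_source_updates(source_updates, target_update, eps=1e-12):
    """Combine source-client model updates with the target fine-tuning update.

    For each flattened source update g_s, the cosine similarity with the
    target update g_t acts as that source domain's importance weight;
    g_s is projected onto the direction of g_t, and negatively aligned
    updates are discarded. Returns g_t plus the weighted projections.
    """
    g_t = np.asarray(target_update, dtype=float)
    unit_t = g_t / (np.linalg.norm(g_t) + eps)       # unit vector in target direction
    combined = np.zeros_like(g_t)
    for g_s in source_updates:
        g_s = np.asarray(g_s, dtype=float)
        cos = (g_s @ unit_t) / (np.linalg.norm(g_s) + eps)
        if cos > 0:                                   # positive similarity = importance
            combined += cos * (g_s @ unit_t) * unit_t # projection onto target direction
    return g_t + combined                             # merge with fine-tuning gradient
```

A source update pointing away from the target direction contributes nothing, so a misaligned source domain cannot pull the model away from the target optimum; an aligned one contributes in proportion to how well it agrees with the target update.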

