WEAKLY-SUPERVISED DOMAIN ADAPTATION IN FEDERATED LEARNING FOR HEALTHCARE

Anonymous

Abstract

Federated domain adaptation (FDA) describes the setting where a set of source clients seek to optimize the performance of a target client. To be effective, FDA must address some of the distributional challenges of federated learning (FL). For instance, FL systems exhibit distribution shifts across clients. Further, labeled data are not always available among the clients. To this end, we propose and compare novel approaches for FDA that combine the few labeled target samples with the source data when auxiliary labels are available to the clients. The in-distribution auxiliary information is included during local training to boost out-of-domain accuracy. Also, during fine-tuning, we devise a simple yet efficient gradient projection method to detect the valuable components of each source client model towards the target direction. Extensive experiments on healthcare datasets show that our proposed framework outperforms state-of-the-art unsupervised FDA methods with limited additional time and space complexity.

1. INTRODUCTION

Federated learning (FL) is a distributed learning paradigm where an aggregated model is learned using local decentralized data on edge devices (McMahan et al., 2017). FL systems usually share the model weights or gradient updates of clients with the server, which prevents direct exposure of the sensitive client data. As a result, data heterogeneity remains an important challenge in FL, and much of the research focuses on mitigating the negative impacts of the distribution shifts between clients' data (Wang et al., 2019; Karimireddy et al., 2020; Xie et al., 2020b). Further, much of the FL literature has focused on settings where all datasets are fully labeled. However, in the real world, one often encounters settings where labels are scarce on some of the clients. To this end, multi-source domain adaptation (MSDA) (Ben-David et al., 2010; Zhao et al., 2020; Guan & Liu, 2021) is a common solution to this problem, where models trained on several labeled, separate source domains are transferred to an unlabeled or sparsely labeled target domain. Here, we consider the more specialized setting of federated domain adaptation (FDA), where a set of source clients seek to optimize the performance of a target client. As an analogue to MSDA, one may consider the clients' data as different domains. Thus, the goal is to learn a good model for the few-labeled target client by transferring useful knowledge from multiple source clients. In this work, we consider the FDA problem under weak supervision, where auxiliary labels are available to the clients. In brief, we propose novel approaches to deal with weakly-supervised FDA, focusing on techniques that adapt both the local training and fine-tuning stages.

Motivating Application. Our work is inspired by and applied to applications in predictive modeling for healthcare, where there can be significant differences across hospitals, causing transfer errors across sites (Li et al., 2020; Guan et al., 2021; Wolleb et al., 2022).
Unlike many other industries, healthcare in the US is highly heterogeneous (e.g., HCA, the largest consortium of hospitals, covers < 2% of the market (Statista, 2020; Wikipedia contributors, 2022)), thus many variables are not standardized (Adnan et al., 2020; Osarogiagbon et al., 2021). Hence, we consider experiments that simulate differences across institutions as a large shift. Further, we consider an FL application across several hospitals located in different states in the US. In this setting, FDA is employed to improve performance at a target hospital by leveraging information from all of the source hospitals. The human cost of labeling the images is expensive, thus the data are sparsely labeled. Further, in addition to the medical images, the data also include demographic information such as age, sex, and race, among others. While such auxiliary data is ubiquitous, it is often ignored when working to improve centralized or federated models. Here, we show how this data, when properly utilized in FL, can significantly improve out-of-domain (OOD) performance. Also, in FL, local models are trained gradually, so the importance of source domains may change dynamically during each iteration. Our work seeks to extract the valuable components of model updates, rather than the whole models, towards the target direction during each round. We discover that the few labeled target samples can provide sufficient guidance towards adapting quickly to the target domain.

Our contributions are twofold, in local training and fine-tuning. First, we leverage auxiliary information to reduce the task risk on the target client during local training. Auxiliary information is often cheap and easy to find alongside the image inputs, and may offer useful signals for the unlabeled samples because of the underlying correlations between the auxiliary tasks and the main task. Xie et al. (2021) and Wang et al. (2022) both found that, in the centralized setting, using auxiliary information can improve OOD performance, resulting in a smaller OOD risk compared with the baseline. In our work, we set up a cheap and efficient multi-task learning (MTL) framework with auxiliary tasks during the local training of source clients, as shown in Figure 1a. By optimizing the main and auxiliary task losses together, we show empirically that one can boost target domain accuracy after fine-tuning on labeled target samples (Section 5.2). The gains are more evident when the distribution shifts are large, yet auxiliary tasks may introduce unexpected noise when domains are too close to each other.

Second, we observe that including auxiliary information alone does not fully account for the importance of the source domains. Thus, during fine-tuning, we propose a simple yet efficient gradient projection (GP) method, which utilizes the useful source domain components and projects them towards the target direction. As shown in Figure 1b, during each communication round, we compute the model updates of the source clients and project them onto the target model update obtained by fine-tuning on a small set of labeled target samples. In this way, we greedily approach the optimum of the target domain: the positive cosine similarity between a pair of (source, target) updates serves as the importance of that domain. Our experimental results indicate that this gradient projection method achieves better and more stable FDA performance on the target client by combining the projected gradients with the fine-tuning gradient. We show the superiority of the gradient projection method through comprehensive experiments on both medical and general-purpose datasets in Section 5.2 and Appendix A.6. Combining the two techniques, our proposed framework outperforms state-of-the-art unsupervised FDA methods with limited additional computational cost. We also show empirically that our framework is resistant to data imbalance on the real-world MIDRC dataset (Section 5.2).

2. PROBLEM SETUP

In this section, we introduce the framework of weakly-supervised FDA. We first provide background on weakly-supervised MSDA, and then extend the framework to the federated learning setting.

Weakly-supervised MSDA. Let D_S and D_T denote the source and target domains. In each domain, we have sample pairs with input x ∈ R^d and output label y ∈ R. In MSDA, we have N source domains D_S = {D_Si}_{i=1}^N, where D_Si has n_i labeled samples, as well as a target domain with n_T total samples, consisting of n_l labeled samples and n_T - n_l (a large number of) unlabeled samples. During pretraining, we train a model h_Si on each source domain using the corresponding labeled samples. The goal of fine-tuning is to learn a target model h_T that minimizes the target risk ε_T(h) := Pr_{(x,y)∼D_T}[h(x) ≠ y]. A common approach is to first aggregate the model parameters as Σ_{i=1}^N α_i h_Si, with α_i controlling the importance of each source domain such that Σ_{i=1}^N α_i = 1, and then fine-tune the aggregated parameters using the set of labeled samples from D_T.

Federated problem setting. We now extend weakly-supervised MSDA to the FL setting. As shown in Figure 1a and Algorithm 1, we assume there are N + 1 clients in the system, where N clients {C_Si}_{i=1}^N have labeled data {D_Si}_{i=1}^N and the remaining client C_T has D_T with n_l labeled samples. Different from the centralized setting, weakly-supervised FDA requires that the target client C_T has no direct access to the source data {D_Si}_{i=1}^N. The aggregation of the source domain models is performed by the server S using any federated aggregation rule, such as FedAvg (McMahan et al., 2017) or SCAFFOLD (Karimireddy et al., 2020). After S finishes the aggregation, it sends the model parameters to C_T, which performs fine-tuning, just as in weakly-supervised MSDA. The final model is sent to all the source clients {C_Si}_{i=1}^N, which ends one communication round.
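As a concrete sketch of the aggregate-then-fine-tune procedure above, the weighted aggregation Σ_i α_i h_Si can be written as follows. The function name and the dict-of-arrays model representation are illustrative, not taken from the paper's code:

```python
import numpy as np

def aggregate(source_models, alphas):
    """Weighted aggregation sum_i alpha_i * h_Si of source model parameters.

    source_models: list of models, each a dict mapping layer name -> ndarray.
    alphas: per-domain importance weights, assumed to sum to 1.
    """
    assert abs(sum(alphas) - 1.0) < 1e-8
    agg = {name: np.zeros_like(w) for name, w in source_models[0].items()}
    for alpha, model in zip(alphas, source_models):
        for name, w in model.items():
            agg[name] += alpha * w
    return agg

# Two toy "models" with a single layer each:
h1 = {"fc": np.array([1.0, 2.0])}
h2 = {"fc": np.array([3.0, 4.0])}
print(aggregate([h1, h2], [0.5, 0.5])["fc"])  # [2. 3.]
```

With FedAvg, α_i would be proportional to the sample count n_i; the fine-tuning step then continues training the aggregated parameters on the n_l labeled target samples.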

3. LEVERAGING AUXILIARY INFORMATION DURING LOCAL TRAINING

After defining the framework for weakly-supervised FDA, we now look into ways to boost the target client performance. In this section, we explain the idea of leveraging auxiliary information: optimizing the main and auxiliary tasks together during the source clients' local training. Here, we only derive the loss objective for one source client C_Si (since it is the same for all source clients). Let n be the total local training sample size and ℓ the loss function for the main task. Eq. 1 is the loss objective of the main task (image classification):

L_main = Σ_{j=1}^n ℓ(h(x_j), y_j)    (1)

Let K be the number of auxiliary tasks, and z_k ∈ R^T, b_k be the k-th auxiliary input and output labels, respectively. Thus, z_kj and b_kj denote the j-th sample input and output for the k-th auxiliary task. Each auxiliary task shares the same parameters as h except for the last layer (the shared representation is denoted h_{l-1}), and we define g_k to be the mapping from the feature representation to the k-th auxiliary output. Lastly, let ℓ_aux be the loss function for the auxiliary tasks. Then, we can construct the loss objective for the auxiliary tasks as follows:

L_aux = Σ_{j=1}^n Σ_{k=1}^K ℓ_aux(g_k(h_{l-1}(z_kj)), b_kj)    (2)

In the end, we write the total loss objective for leveraging auxiliary information as Eq. 3, where α controls the weight of the auxiliary task losses:

L_aux-info = L_main + α · L_aux = Σ_{j=1}^n ℓ(h(x_j), y_j) + α · Σ_{j=1}^n Σ_{k=1}^K ℓ_aux(g_k(h_{l-1}(z_kj)), b_kj)    (3)

Including auxiliary information only changes the local training process (Step 1): C_Si optimizes its model parameters using the main task loss summed with the auxiliary task losses. As shown in Figure 1a, we optimize the losses of all tasks together: negative/positive diagnosis as the main task, and race, sex, and age information as auxiliary tasks. When fine-tuning on the target domain D_T, we only optimize the model parameters using the main task loss, without any auxiliary tasks (same as Xie et al. (2021)).
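To make Eq. 3 concrete, here is a minimal sketch of the combined objective with linear heads on a shared feature representation. The squared loss and all helper names are illustrative stand-ins (the experiments use cross-entropy losses and a ResNet-18 backbone):

```python
import numpy as np

def sq_loss(pred, target):
    # illustrative stand-in for the task losses l and l_aux
    return float(np.sum((pred - target) ** 2))

def aux_info_loss(feats, y, w_main, aux_feats, aux_labels, aux_ws, alpha):
    """L_aux-info = L_main + alpha * L_aux (Eq. 3).

    feats:          shared representation h_{l-1}(x) for the main task, (n, d)
    aux_feats:      list of K shared representations h_{l-1}(z_k), each (n_k, d)
    w_main, aux_ws: linear output heads (h's last layer and the g_k's)
    """
    l_main = sq_loss(feats @ w_main, y)                    # Eq. 1
    l_aux = sum(sq_loss(phi_k @ w_k, b_k)                  # Eq. 2
                for phi_k, b_k, w_k in zip(aux_feats, aux_labels, aux_ws))
    return l_main + alpha * l_aux

feats = np.array([[1.0, 0.0], [0.0, 1.0]])
y = np.array([1.0, 0.0])
w_main = np.array([1.0, 0.0])      # predicts [1, 0] -> L_main = 0
aux = [np.array([[1.0, 1.0]])]     # one auxiliary task (e.g., sex)
b = [np.array([2.0])]
w_aux = [np.array([1.0, 0.0])]     # predicts 1, label 2 -> L_aux = 1
print(aux_info_loss(feats, y, w_main, aux, b, w_aux, alpha=0.5))  # 0.5
```

Only the heads differ per task; the shared parameters receive gradients from every term, which is where the auxiliary signal for unlabeled-target generalization comes from.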
Algorithm 2 (local training on C_Si):
1: Initialize h_Si^(r) ← h_T^(r-1).
2: Optimize h_Si^(r) on D_Si with L_aux-info(h_Si^(r), x, y, z, b, ℓ, ℓ_aux, α).
3: Send h_Si^(r) to S.

Algorithm 3 Gradient Projection on C_T
Input: N source domain models h_S = {h_Si}_{i=1}^N; target model h_T^(r-1) from the previous round r-1; global model h_global^(r-1) from the previous round r-1; target domain D_T; target client C_T; server S; GP weight control variable β; numbers of samples of the source domains {n_i}_{i=1}^N.
Output: Target model h_T^(r) at round r.
1: Step 2: S receives all {h_Si^(r)}_{i=1}^N
2: h_global^(r) ← h_global^(r-1)
3: for h_Si in {h_Si^(r)}_{i=1}^N do
4:   G_Si ← h_Si^(r) - h_global^(r)
5: end for
6: Send h_global^(r) and {G_Si^(r)}_{i=1}^N to C_T
7: Step 3: C_T receives {G_Si^(r)}_{i=1}^N
8: h_T^(r) ← h_T^(r-1)
9: Optimize h_T^(r) on D_T using the labeled samples
10: G_T ← h_T^(r) - h_T^(r-1)
11: Update h_T^(r) using Eq. 7 with {G_Si^(r)}_{i=1}^N, G_T, β, {n_i}_{i=1}^N
12: Send h_T^(r) to S

4. UTILIZING SOURCE DOMAIN KNOWLEDGE VIA GRADIENT PROJECTION

Leveraging auxiliary information does not consider the importance of each source client's contribution to the target client, because we only use a simple aggregation rule. In weakly-supervised FDA, how can one better utilize the knowledge from both the labeled target samples and the source client models? Here, we propose a novel Gradient Projection (GP) method.

Algorithm intuition. The small set of labeled target samples provides a useful signal on the direction of the target domain optimum. Thus, during each communication round, the server S does not aggregate the weights (Algorithm 1), but instead computes the model updates {G_Si^(r)}_{i=1}^N, where G_Si^(r) ≈ h_Si^(r) - h_global^(r-1) at round r, from all source clients {C_Si}_{i=1}^N. The target client C_T then performs gradient projection, using cosine similarity, of each G_Si^(r) onto the target direction G_T^(r) ≈ h_T^(r) - h_T^(r-1), which is computed after fine-tuning on the small set of labeled target samples. In this way, we greedily maximize knowledge transfer to the target domain in each round; the projection of G_Si^(r) onto G_T^(r) can be regarded as the weight of D_Si at round r. By combining G_T^(r) with the projected gradients ({G_Si^(r)}_{i=1}^N projected onto G_T^(r)), controlled by a hyper-parameter β, we observe a more steady convergence towards the target direction, as noted in the experimental results in Section 5.2. Figure 1b and Algorithm 3 illustrate the procedure of gradient projection.

Details of the Gradient Projection. We compute the cosine similarity (Eq. 4) between each source client model update G_Si^(r) and the target client update G_T^(r) (for simplicity, we denote them as G_Si and G_T) for each layer of the model (for a finer projection). In addition, we align the magnitude of the model updates according to the numbers of target/source samples, batch sizes, local updates, and learning rates (more details are in Appendix A.2). To prevent negative projection, we set the threshold for the function GP to be 0.
For a certain layer l of G_Si and G_T (denoted G_Si^l and G_T^l), the cosine similarity (Eq. 4) and the corresponding gradient projection result (Eq. 5) are:

cos(G_Si^l, G_T^l) = ⟨G_Si^l, G_T^l⟩ / (∥G_Si^l∥ ∥G_T^l∥)    (4)

GP(G_Si^l, G_T^l) = cos(G_Si^l, G_T^l), if cos(G_Si^l, G_T^l) > 0; 0, if cos(G_Si^l, G_T^l) ≤ 0    (5)

The total gradient projection P_GP from all source updates {G_Si^(r)}_{i=1}^N projected onto the target direction G_T is computed as Eq. 6. We use L to denote all layers of the current model updates. n_i denotes the number of samples trained on source client C_Si; this weighting is adapted from FedAvg (McMahan et al., 2017) to mitigate the data imbalance issue. Hence, we normalize the gradient projections according to the numbers of samples. Also, ⊕_{l∈L} concatenates the projected gradients of all layers:

P_GP = ⊕_{l∈L} Σ_{i=1}^N GP(G_Si^l, G_T^l) · (n_i / Σ_i n_i) · G_Si^l    (6)

Lastly, a hyper-parameter β is used to incorporate the target update G_T into P_GP for a more stable performance. The final target model weight h_T^(r) at round r is thus expressed as:

h_T^(r) = h_T^(r-1) + (1 - β) · P_GP + β · G_T    (7)
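A minimal sketch of Eqs. 4-7 follows. The layer names and the dict-of-arrays representation of updates are illustrative, and the magnitude alignment of Appendix A.2 is omitted for brevity:

```python
import numpy as np

def gp_weight(g_src, g_tgt):
    """Cosine similarity clipped at zero (Eqs. 4-5)."""
    denom = np.linalg.norm(g_src) * np.linalg.norm(g_tgt)
    if denom == 0.0:
        return 0.0
    return max(0.0, float(g_src.ravel() @ g_tgt.ravel()) / denom)

def gradient_projection_step(h_t_prev, source_updates, target_update,
                             n_samples, beta):
    """Eqs. 6-7: per-layer projected source updates, normalized by sample
    counts, combined with the fine-tuning update G_T via beta."""
    n_total = float(sum(n_samples))
    h_t_new = {}
    for layer, g_t in target_update.items():
        p_gp = np.zeros_like(g_t)
        for g_si, n_i in zip(source_updates, n_samples):
            p_gp += gp_weight(g_si[layer], g_t) * (n_i / n_total) * g_si[layer]
        h_t_new[layer] = h_t_prev[layer] + (1 - beta) * p_gp + beta * g_t
    return h_t_new

h_prev = {"fc": np.zeros(2)}
g_t = {"fc": np.array([1.0, 0.0])}            # fine-tuning direction G_T
g_s = [{"fc": np.array([2.0, 0.0])},          # aligned with G_T -> cos = 1
       {"fc": np.array([-1.0, 0.0])}]         # opposed          -> clipped to 0
print(gradient_projection_step(h_prev, g_s, g_t, [1, 1], beta=0.5)["fc"])  # [1. 0.]
```

The opposed source update contributes nothing (the zero threshold in Eq. 5), so only the aligned client transfers knowledge in this round.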

5. EXPERIMENTS

We evaluate our proposed framework on three medical datasets: CheXpert (Irvin et al., 2019), MIMIC (Johnson et al., 2019), and a real-world imbalanced dataset from MIDRC tasked with COVID-19 detection from X-ray images. The data are split to represent three scenarios: (a) CheXpert: balanced data across clients and a large distribution shift among domains. (b) MIMIC: balanced data across clients and a small distribution shift among domains. (c) MIDRC: highly imbalanced data across clients and a large distribution shift across domains. Note that the last is a real-world application with data from hospitals in different locations across the U.S. The first two experiments are designed to highlight the kinds of extreme data shift often noted in healthcare, e.g., where the same features are measured differently across hospitals (Wiens et al., 2014), or where the same variable names refer to very different measurements/diagnoses across health systems. Unfortunately, there are few public data illustrating these kinds of important and understudied shifts. Thus, we designed representative semi-synthetic shifts to illustrate the extent and impact of the issue.

5.1. EXPERIMENTAL SETUP

We provide basic information and the experimental setup for the three medical imaging datasets in Figure 2. More details on data splitting and implementation are discussed in Appendix A.3.

CheXpert is a widely used medical imaging dataset for multi-label classification consisting of 224,316 chest radiographs from 65,240 patients (Irvin et al., 2019). We use sex, age, and frontal/lateral information from patients to construct auxiliary labels, and split domains by labeled condition, i.e., selecting "Lung Lesion", "Edema", "Consolidation", "Pneumonia", and "Atelectasis" (all of which are lung conditions) as source and target domains. Thus, the task is to predict a new lung condition based on labels of existing lung conditions.

MIMIC is a large dataset of 227,835 imaging studies for 65,379 patients between 2011 and 2016 (Johnson et al., 2019). We set domains using the provided race information and merge the result into four main categories: White, Black, Asian, Hispanic/Latino. Thus, the task is to predict conditions for a previously unobserved racial group based on labels collected from the other groups. The distribution shifts between race domains are considered small, as we can obtain a high accuracy simply using FedAvg, as shown in Table 2.

MIDRC includes Computed Radiography (CR) images as the primary input. We evaluate the proposed framework using the NC, CA, IN, and TX states as source clients, and try to adapt to the target client IL, which has a large number of unlabeled samples. The statistics of these five states are shown in Table 1. We collect race, sex, and age data as auxiliary information.

5.2. MAIN RESULTS

Table 2 and Figure 3 present our main results.

CheXpert. For balanced data and large distribution shifts, AuxInfo and GP individually improve accuracy and AUC by around 6% and 18%, respectively. Combining the two, we get a further 2 ∼ 3% gain with close-to-oracle performance. However, we notice that AuxInfo slows down the convergence of the training procedure.

MIMIC. For balanced data and small distribution shifts, the boost over the baseline (FedAvg) becomes small. Yet GP still manages to achieve a 1% boost with close-to-oracle performance. We think that when shifts are small, including auxiliary information may hinder the fine-tuning of the model, introducing extra noise during the training procedure. The variance of AuxInfo is quite large during the first several epochs, which leads to slow convergence when combining the two techniques. When the local signals are not helpful, applying GP on top enlarges the negative effect. We hypothesize that this is why it results in slightly worse performance.

MIDRC.

For imbalanced data and large distribution shifts, AuxInfo and GP achieve significant increases of 6% and 7% in target client accuracy/AUC, respectively. An extra 1 ∼ 2% increase is obtained when combining AuxInfo with GP. Interestingly, when the client data is imbalanced, AuxInfo actually achieves faster convergence. We think data imbalance may require more signal from auxiliary information to converge. In general, data imbalance has little impact on the performance of our proposed framework; we hypothesize that this is because we normalize by the client sample sizes when updating the model weights. In contrast, SOTA methods are not resistant to this issue.

Effects of labeled target sample sizes. We perform an experiment on the MIDRC dataset with target sample sizes of (50, 100, 500, 1000, 2000), testing the impact on AuxInfo and GP individually, as shown in Table 4 and Figure 4a. We report AUC scores for a more accurate comparison. Generally, as we increase the number of labeled target samples, the target domain accuracy improves as well.

Computational complexity analysis. As shown in the computational-efficiency comparison, our framework introduces only limited additional computational cost.

Apart from that, the improvements from AuxInfo and GP over the baselines also increase when we have more labeled target samples.

Effects of the number of auxiliary tasks. To compare the contribution of each auxiliary task, we train FedAvg Finetune AuxInfo with a single auxiliary task branch of race/sex/age. In Figure 4b, we report the performance gain of each kind of auxiliary information compared with FedAvg Finetune. The auxiliary tasks seem to have an additive gain effect (race gain + sex gain + age gain ≈ all gain).

Choice of the hyper-parameter β between the gradient projection P_GP and the fine-tuning gradient update G_T. We set β = 0, 0.2, 0.4, 0.6, 0.8 and use 50 and 20 labeled target samples on the CheXpert and MIMIC datasets, respectively. We use a small number of labeled samples to better evaluate the effectiveness of GP. Figures 4c and 4d present the target domain accuracies with different choices of β. When β = 0, the update relies solely on the set of labeled samples to optimize the parameters; in other words, it can be regarded as not transferring any knowledge from the source clients. For both the large and small shift cases, we observe that around β = 0.4 both the accuracy and AUC reach their highest values. However, when β is large, the performance drops severely for CheXpert but is little affected for MIMIC, for which close-to-oracle performance is more easily attained. For the large shift case, we find that the total gradient updates become too greedy, taking an overly large step when projecting the gradients from the source clients, leading to worse performance, while this causes little harm in the small shift case. Therefore, we choose β = 0.5 for our experiments, though the best β may change slightly as the number of labeled target samples varies.

6. RELATED WORK

Recent works try to tackle the label deficiency problem with self-supervision or semi-supervision for better personalized models (Jeong et al., 2020; He et al., 2021; Yang et al., 2021).
Compared to these works, we explore a new setting with fully-labeled source clients and one few-labeled target client, improving FDA performance under weak supervision.

Federated domain adaptation.

There is a considerable amount of recent work on multi-source domain adaptation in the unsupervised setting, with recent highlights on adversarial training (Saito et al., 2018; Zhao et al., 2018), knowledge distillation (Nguyen et al., 2021), and source-free methods (Liang et al., 2020). Peng et al. (2020) and Li et al. (2020) are the first to extend MSDA to the FL setting; they apply adversarial adaptation techniques to align the representations of nodes. More recently, in KD3A (Feng et al., 2021) and COPA (Wu & Gong, 2021), a server with unlabeled target samples aggregates the local models by learning the importance of each source domain, via knowledge distillation and collaborative optimization, respectively. Our work is in contrast to these methods, which primarily focus on the unsupervised setting, rely heavily on both source and target data, and involve complicated training procedures on the server. Our framework is computationally efficient and explores the FDA problem in a new manner, with auxiliary labels available to the clients.

Auxiliary information in domain generalization. In-N-Out (Xie et al., 2021) investigates both out-of-distribution (OOD) and in-distribution performance when using auxiliary information as inputs and as outputs with self-training. In the medical setting, Wang et al. (2022) find that by pre-training and fine-tuning on auxiliary tasks, one can improve the transfer performance between datasets on the primary task. They consider a single source-target scenario with no labeled target data, while our work focuses on the federated MSDA setting with few labeled target samples. Their frameworks cannot be directly adapted to our setting, since they require training on the source samples again after training on out-of-domain (target) samples. Our MTL framework properly leverages auxiliary labels in the new setting, is cheap and efficient to compute, and yields good improvements in large shift cases.

Using additional gradient information in FL.
The model updates in each communication round can provide valuable insight into clients' convergence directions, which has mostly been explored for Byzantine robustness in FL. Zeno++ (Xie et al., 2020a) and FLTrust (Cao et al., 2021) leverage an additional gradient computed from a small clean training dataset on the server, which helps score candidate gradients to detect malicious adversaries. In our work, we utilize the additional gradient provided by the labeled target samples for the FDA problem. In a simple yet effective way, we project the source gradients towards the target direction. Our results support that this transfers knowledge from the source clients to the target client with better and more stable performance.

7. CONCLUSION AND FUTURE WORK

We show that including auxiliary information during local training and gradient projection during fine-tuning can help address the significant distribution shifts and label deficiency issues existing in current federated learning systems, particularly for real medical applications. Our results on healthcare datasets show that our proposed framework improves FDA performance at a small additional computational cost. Future work includes evaluating on-the-fly/offline fine-tuning schemes, exploring how to select the set of labeled target samples in real-world cases to better align with the distribution of the unlabeled part, analyzing the impact of more factors related to domain discrepancy, and extending the current framework to more general transfer learning settings.

A.2 ALIGNING THE MAGNITUDES OF MODEL UPDATES

In the real training process, because we use different learning rates and numbers of training samples for the source and target clients, we need to align the magnitudes of the model updates. Eq. 8 aligns the model updates from the source clients to the target client, and Eq. 9 combines the projection results with the target update. We use lr_T and lr_S to denote the target and source learning rates; batchsize_T and batchsize_S are the batch sizes for the target and source domains, respectively; n_l is the labeled sample size on the target client and n_i is the sample size of source client C_Si; r_S is the number of local update rounds on the source clients.

P_GP = ⊕_{l∈L} Σ_{i=1}^N [ GP((h_Si^(r) - h_global^(r-1))^l, (h_T^(r) - h_T^(r-1))^l) · (n_i / Σ_i n_i) · ((n_l / batchsize_T) / (n_i / batchsize_S)) · (lr_T / lr_S) · (1 / r_S) · (h_Si^(r) - h_global^(r-1)) ]    (8)

h_T^(r) = h_T^(r-1) + (1 - β) · P_GP + β · (h_T^(r) - h_T^(r-1))    (9)

A.3 IMPLEMENTATION AND DATA SPLITTING DETAILS

CheXpert. We randomly sampled 4,000 labeled source samples (2,000 positive, 2,000 negative) from each domain for local training, and 1,000 labeled target samples for fine-tuning.
To create a larger distribution shift, for the negative samples we randomly sample those labeled "0" for that condition instead of using "No Finding" labels.

MIMIC. We sample 500 samples (250 with findings and 250 without findings) for each source domain and adapt to the target domain with 100 labeled samples. This small sample size is intentionally chosen to increase the difficulty of this federated domain adaptation task. For the auxiliary labels, we use sex and age information from the database.

A.6 SUPPLEMENTARY EXPERIMENTS ON THE GRADIENT PROJECTION METHOD

Apart from the medical imaging datasets, we also test GP on two general-purpose datasets with imbalanced labels: Amazon Review (McAuley et al., 2015) and a self-generated Non-IID MNIST (Deng, 2012). We illustrate the details of the two datasets and present the experiment results in this section.

A.6.1 DATASETS

Amazon Review. This dataset is for a binary sentiment analysis task and includes four domains. Randomly choosing three of these domains as source domains and the remaining one as the target domain, we train a simple CNN model for the classification task. We use 2,000 training samples for each source domain (the same as KD3A (Feng et al., 2021), for a better comparison) and 400 labeled target samples.

Non-IID MNIST. We adapt the Non-IID benchmark (Li et al., 2022) to construct the source and target domains for the MNIST dataset in a non-IID manner. To make the task harder, we choose a data partition with quantity-based label imbalance, with only 3 classes available to each source client (though predicting over 10 classes), and we have 8 source clients in the system. For the target client, we have all 10 digit classes with a noise-based feature imbalance, creating a shift from the source clients. We use a CNN architecture for the digit classification experiment. Also, we use 100 labeled target samples and test the accuracy on 10,000 unlabeled target samples.

A.6.2 RESULTS

Amazon Review.
From the results in Table 5: this dataset includes a comparatively simple task with small distribution shifts between clients, yet GP outperforms the state-of-the-art unsupervised KD3A by 0.38% in average target domain accuracy using 400 labeled target samples. Though our setting is different, GP has comparable performance against the unsupervised state-of-the-art method.

Non-IID MNIST. The distribution shifts between clients are larger for this dataset, and GP obtains a larger performance boost than on the previous dataset (20% over FedAvg and 10% over FedAvg Finetune), as shown in Table 6. Hence, GP improves the target accuracy more significantly when the client shifts are larger.



The MIDRC data is semi-public and is available by request at https://www.midrc.org/



Figure 1: (a) Proposed framework of weakly-supervised FDA: we set up an MTL framework leveraging auxiliary labels during the source clients' local training. (b) Intuition of GP: we project the valuable components of the source gradients towards the target direction to boost FDA performance.

Algorithm 1 Weakly-Supervised FDA
Input: ...; target model h_T^(r-1) from the previous round r-1; aggregation rule aggr; server S.
Output: Target model h_T^(r) at round r.
1: Step 1: Local training on {C_Si}_{i=1}^N
2: for D_Si in D_S do ...

This makes the other steps identical to Algorithm 1. Algorithm 2 displays the local training procedure with auxiliary tasks on one of the source clients C_Si.

Algorithm 2 Domain Adaptation with Auxiliary Information (local training on C_Si)
Input: one source domain D_Si; one source client C_Si; target model h_T^(r-1) from the previous round r-1; input images x, auxiliary inputs z, main task labels y, auxiliary output labels b; main task loss function ℓ and auxiliary task loss function ℓ_aux; loss weight control hyper-parameter α.
Output: source model h_Si^(r) at round r.




Figure 2: Summary of the medical datasets in our experiments.

We compare the target domain accuracy and convergence speed of the following methods on the three datasets: a) FedAvg only aggregates the source clients using FedAvg (McMahan et al., 2017), without fine-tuning on any target samples; b) FedAvg Finetune performs a fine-tuning step after source client aggregation (Algorithm 1); c) FedAvg Finetune AuxInfo includes auxiliary information during local training (Algorithm 2); d) FedAvg Finetune GP performs gradient projection during fine-tuning (Algorithm 3); e) FedAvg Finetune AuxInfo GP combines c) and d).

Figure 3: Target domain accuracy vs. rounds on the (a) CheXpert, (b) MIMIC, and (c) MIDRC datasets. For the large shift scenarios (a) and (c), AuxInfo and GP both significantly improve target domain performance. For the small shift scenario (b), GP still achieves a 1% boost, while AuxInfo may introduce extra noise into the training procedure.


Figure 4: (a) Target AUC with different sample sizes for FedAvg vs. Finetune vs. AuxInfo vs. GP on MIDRC. (b) Performance gain compared with FedAvg Finetune using different auxiliary inputs. (c) and (d) display β vs. target domain ACC/AUC for CheXpert and MIMIC.

Data heterogeneity and label deficiency in federated learning. Distribution shift between clients remains a crucial challenge in FL. Current work often focuses on improving the aggregation rules: Karimireddy et al. (2020) use control variates, and Xie et al. (2020b) cluster the client weights via an EM algorithm to correct the drifts among clients. In the medical setting, Jiang et al. (2022) and Dinsdale et al. (2022) try to mitigate local and global drifts via harmonisation. However, current work usually assumes that local training is fully supervised for all clients; in reality, the label deficiency problem can occur on any of the clients.

Figure 5: ResNet-18 model architecture for training auxiliary tasks.

Statistics of the MIDRC dataset: an extremely imbalanced scenario (IL contains most of the samples).

Target domain accuracy and AUC scores (%) on three medical datasets, compared with SOTA methods. Results are averaged across 3 trials and reported with 95% confidence intervals.



Comparison of computational efficiency. We calculate the average time for running 1 global epoch on CheXpert with N = 4 source clients using the same Quadro RTX 6000 GPU. We do not consider communication and testing cost and assume the clients' training happens sequentially.

Target domain accuracy (%) on Non-IID MNIST dataset.

A.6 SUPPLEMENTARY EXPERIMENTS ON GRADIENT PROJECTION METHOD

Apart from experimenting on medical imaging datasets for GP, we also test GP on two general-purpose datasets: AmazonReview (McAuley et al., 2015) and a self-generated Non-IID MNIST

8. REPRODUCIBILITY STATEMENT

We have provided the details of our dataset preprocessing, hyper-parameters, training scheme, and model architecture in Section 5 and in the Appendix. Also, we have uploaded the source code of our proposed framework as part of the supplementary materials. Because access to the MIDRC data is restricted to approved users, we are unable to include the original data. We will release the code upon acceptance.

APPENDIX

MIDRC. To set up a real-world case with multiple source domains, we split the CR dataset according to zip code and select 5 states (IL, NC, CA, IN, TX) as our source and target domains. We use all labeled samples in the source domains for local training, and 2,000 labeled target samples for fine-tuning. As a real-world dataset, the sample counts are extremely imbalanced across the clients, since the dataset collects data mostly from Chicago, which can introduce additional distributional noise into the clients.

Setup. We used FedAvg as the aggregation rule for the baselines, and set the hyper-parameters α = 1 and β = 0.5, the source learning rate to 10^-3, the target learning rate to 2 × 10^-4, the number of communication rounds to r = 80 for the MIMIC dataset and r = 50 for the MIDRC and CheXpert datasets, and the local training step size to 1. We use cross-entropy losses for the classification tasks and the Adam optimizer (Kingma & Ba, 2014). We select around 20% ∼ 25% of the target domain's labeled samples for fine-tuning under weak supervision. Further, we use a pretrained ResNet-18 (He et al., 2016) model with a last-layer hidden size of 128 for training on all three datasets. When training with auxiliary tasks, we add branches to the output layer of ResNet-18, as shown in Figure 5.
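The auxiliary branches of Figure 5 amount to a shared trunk whose features feed one linear head per task. A minimal NumPy sketch, where the ResNet-18 trunk is abstracted as a fixed feature extractor and the input dimension, head sizes, and random weights are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for the shared ResNet-18 trunk, producing 128-d features
# (the paper's last-layer hidden size).
W_trunk = rng.normal(size=(32, 128))

def shared_features(x):
    return np.tanh(x @ W_trunk)

# Branches added to the trunk output: one head per task.
W_main = rng.normal(size=(128, 2))  # main task head (e.g., disease label)
W_aux = rng.normal(size=(128, 4))   # auxiliary task head (e.g., auxiliary label)

def forward(x):
    h = shared_features(x)          # features are shared across tasks
    return h @ W_main, h @ W_aux    # one logit vector per task

logits_main, logits_aux = forward(rng.normal(size=(8, 32)))
```

Because the trunk is shared, gradients from the auxiliary head regularize the same features used by the main task during local training.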

A.4 DERIVATION OF GRADIENT PROJECTION METHOD'S TIME AND SPACE COMPLEXITY

Time complexity: Assume the model has m parameters in total across l layers and, for simplicity, that each layer has an average of m/l parameters. Computing the cosine similarities for all layers of one source client costs O((m/l)^2 · l) = O(m^2/l). With N source clients, the total time cost of GP is O(N · m^2/l).

Space complexity: The extra memory cost of GP (computing the cosine similarity) is O(1) per client, for storing the current cosine-similarity value.
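The operation being costed above is a per-layer cosine similarity between each source client's update and the target direction. A minimal sketch; the rule of zeroing out layers with non-positive similarity is an illustrative assumption, since Algorithm 3 is not reproduced here:

```python
import numpy as np

def cosine(u, v):
    """Cosine similarity between two flattened layer tensors."""
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v) + 1e-12))

def gradient_project(source_updates, target_grad):
    """For each of the N source clients, keep only the layer components
    aligned with the target direction (cosine similarity > 0)."""
    kept = []
    for client in source_updates:                       # N source clients
        kept.append([
            layer if cosine(layer, tgt) > 0 else np.zeros_like(layer)
            for layer, tgt in zip(client, target_grad)  # l layers each
        ])
    return kept
```

Only a single scalar similarity is held per comparison, matching the O(1) per-client memory claim.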

A.5 CHOICE OF THE LOSS WEIGHT HYPER-PARAMETER α FOR AUXILIARY TASKS

We observe that the convergence speed of AuxInfo is slow in the small-distribution-shift case (MIMIC dataset). Thus, we conduct a further ablation study on the MIMIC dataset for the hyper-parameter α, which controls the loss weighting between the main task and the auxiliary tasks. We set α = 0.2, 0.4, 0.6, 0.8, and Figure 6 shows their target accuracies vs. epochs. A smaller α may lead to faster convergence, while the final performance is almost the same across α values. Thus, we set α = 1 in all experiments for a fair comparison.
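The role of α is simply that of a weight combining the main- and auxiliary-task cross-entropy terms. A minimal sketch of this weighted objective (the exact form used in Algorithm 2 is assumed):

```python
import numpy as np

def cross_entropy(logits, y):
    """Mean cross-entropy from raw logits and integer class labels."""
    z = logits - logits.max(axis=1, keepdims=True)       # numerical stability
    logp = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -logp[np.arange(len(y)), y].mean()

def total_loss(logits_main, y_main, logits_aux, y_aux, alpha=1.0):
    """Main-task loss plus alpha-weighted auxiliary-task loss."""
    return cross_entropy(logits_main, y_main) + alpha * cross_entropy(logits_aux, y_aux)
```

With α = 0 the auxiliary branch contributes nothing; larger α trades main-task gradient signal for in-distribution auxiliary supervision.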

