FEDDAR: FEDERATED DOMAIN-AWARE REPRESENTATION LEARNING

Abstract

Cross-silo federated learning (FL) has become a promising tool for machine learning applications in healthcare. It allows hospitals and institutions to train models with sufficient data while the data is kept private. To make the FL model robust when facing heterogeneous data among FL clients, most efforts focus on personalizing models for clients; however, the latent relationships between clients' data are ignored. In this work, we focus on a special non-iid FL problem, called Domain-mixed FL, where each client's data distribution is assumed to be a mixture of several predefined domains. Recognizing the diversity across domains and the similarity within domains, we propose a novel method, FedDAR, which learns a domain-shared representation and domain-wise personalized prediction heads in a decoupled manner. For simplified linear regression settings, we theoretically prove that FedDAR enjoys a linear convergence rate. For general settings, we perform intensive empirical studies on both synthetic and real-world medical datasets, which demonstrate its superiority over prior FL methods. Our code is available at https://github.com/zlz0414/FedDAR.

1. INTRODUCTION

Federated learning (FL) (McMahan et al., 2017a) is a machine learning approach that allows many clients (e.g., mobile devices or organizations) to collaboratively train a model without sharing their data. It has great potential to resolve a central dilemma in real-world machine learning applications, especially in healthcare: a robust and generalizable model in medical applications usually requires a large amount of diverse data to train, yet collecting a large-scale centralized dataset can be expensive or even impractical due to regulatory, ethical, and legal constraints on data privacy and protection (Rieke et al., 2020). While promising, applying FL to real-world problems faces many technical challenges. One eminent challenge is data heterogeneity. Many FL algorithms assume data across the clients are independently and identically distributed (iid), but this assumption rarely holds in the real world. It has been shown that non-iid data distributions cause the failure of standard FL strategies such as FedAvg (Jiang et al., 2019; Sattler et al., 2020; Kairouz et al., 2019; Li et al., 2020). Since an ideal model that performs well on all clients may not exist, FL algorithms need to personalize the model for different data distributions. Prior theoretical work (Marfoq et al., 2021) shows that it is impossible to improve performance on all clients without making assumptions about the clients' data distributions. Past works on personalized FL (Marfoq et al., 2021; Sattler et al., 2020; Ghosh et al., 2020; Mansour et al., 2020; Deng et al., 2020) make their own assumptions and tailor their methods to them. In this paper, we propose a new and more realistic assumption, where each client's data distribution is a mixture of several predefined domains. We call our problem setting Domain-mixed FL.
It is inspired by the fact that the diversity of medical data can be attributed to some known concept of domains, e.g., different demographic/ethnic groups of patients (Szczepura, 2005; Ranganathan & Bhopal, 2006; NHS, 2004), different manufacturers or protocols/workflows of image scanners (Mårtensson et al., 2020; Ciompi et al., 2017), and so on. It is necessary to address the ubiquitous issue of domain shifts among ethnic groups (Szczepura, 2005; Ranganathan & Bhopal, 2006; NHS, 2004) or vendors (Yan et al., 2019; Garrucho et al., 2022; Guan & Liu, 2021) in healthcare data. Despite the domain shifts, the same domain at different clients is usually considered to have the same distribution; the data heterogeneity between FL clients actually comes from the distinct mixtures of diverse domains at the clients. These factors motivate us to personalize the model for each domain instead of each client. Although our method is inspired by healthcare applications, where the domain-shift issue is well-known and domain labels are basic and accessible, we believe it can be generally applied in other areas such as finance or recommendation systems where users with different demographics are involved (Ding et al., 2021; Asuncion & Newman, 2007). However, this requires a deep understanding of the data and background knowledge to verify the data distribution assumption as well as the accessibility of the domain labels. FedEM (Marfoq et al., 2021) and FedMinMax (Papadaki et al., 2021) make similar assumptions on the data distribution as ours. However, FedEM assumes the domains are unknown and tries to learn a linear combination of several shared component models with personalized mixture weights through an EM-like algorithm, while FedMinMax does not acknowledge the shift between domains and still aims to learn one shared model across domains by adapting minimax optimization to the FL setting. Our Contributions. We formulate the proposed problem setting, Domain-mixed FL.
Through our analysis, we find that prior FL methods, both generic FL methods like FedAvg (McMahan et al., 2017a) and personalized FL methods like FedRep (Collins et al., 2021), are sub-optimal under our setting. To address this issue, we propose a new algorithm, Federated Domain-Aware Representation Learning (FedDAR). FedDAR learns a shared model for all the clients, embedded with domain-wise personalized modules. The model contains two parts: a shared encoder across all domains and a multi-headed predictor whose heads are associated with the domains. For an input from a specific domain, the model extracts a representation via the shared encoder and then uses the corresponding head to make the prediction. FedDAR decouples the learning of the encoder and the heads by alternating between their updates. This allows the clients to run many local updates on the heads without overfitting on domains with limited data samples, and leads to faster convergence and a better-performing model. FedDAR also adopts different aggregation strategies for the two parts. We use a weighted-average operation to aggregate the local updates of the encoder; with additional sample re-weighting, the overall training objective is equally weighted for each domain to encourage fairness among domains. For the heads, we propose a novel second-order aggregation algorithm to improve the optimality of the aggregated heads. Fairness in Federated Learning. There are two commonly used definitions of fairness in existing FL works. One is client fairness, usually formulated as client parity (CP), which requires clients to have similar performance; a few works (Li et al., 2021a; 2019; Mohri et al., 2019; Yue et al., 2021; Zhang et al., 2020) have studied this. Another is group fairness.
In the centralized setting, the fundamental tradeoff between group fairness and accuracy has been studied (Menon & Williamson, 2018; Wick et al., 2019; Zhao & Gordon, 2019), and various fair training algorithms have been proposed (Roh et al., 2020; Jiang & Nachum, 2020; Zafar et al., 2017; Zemel et al., 2013; Hardt et al., 2016). Since the notions of group fairness are the same in the FL setting, most existing FL works adapt methods from the centralized setting (Zeng et al., 2021; Du et al., 2021; Gálvez et al., 2021; Chu et al., 2021; Cui et al., 2021). In this work, our method is not designed for specific group-fairness notions like demographic parity. Instead, we aim to achieve the best possible performance for each domain through personalization, acknowledging the differences between data domains. Moreover, our concept of data domains is not limited to demographic groups; it can be applied to any other mixture of domain data, as long as our assumptions hold.

3. PROBLEM: DOMAIN-MIXED FEDERATED LEARNING

Notations. Federated learning involves multiple clients. We denote the number of clients as n and use i ∈ [n] ≜ {1, 2, ..., n} to index each client. Client i has a local data distribution D_i which induces a local learning objective, i.e., the expected risk R_i(f) = E_{(x_i,y_i)∼D_i}[ℓ(f(x_i), y_i)], where f : X → Y is the model mapping the input x ∈ X to the predicted label f(x) ∈ Y, and ℓ : Y × Y → R is a generic loss function. In practice, client i ∈ [n] has a finite number, say L_i, of data samples, i.e., S_i = {(x_i^j, y_i^j)}_{j=1}^{L_i}. L = Σ_{i=1}^n L_i denotes the total number of data samples. Problem Formulation of Domain-mixed Federated Learning. We introduce a new formulation of the FL problem by assuming each client's local data distribution is a weighted mixture of M domain-specific distributions. Specifically, we use {D̃_m}_{m=1}^M to denote the data distributions of the M predefined domains. For client i, its local data distribution is D_i = Σ_m π_{i,m} D̃_m, where the mixing coefficient π_{i,m} is the probability that client i's data sample comes from domain m. Taking a medical application as an example, different hospitals are clients and different ethnic groups are domains: each ethnic group has different health data, while each hospital's data is a mix of ethnic-group data. Further, the domains of the data samples are assumed to be known. We use a triplet of variables (x, y, z) to represent the input features, label, and domain. The goal of our problem is to learn a model f(x, z) that performs well in every domain, as captured by the following learning objective, min_f R(f) := (1/M) Σ_{m=1}^M R_m(f(·, m)), where R_m(f(·, m)) = E_{(x,y)∼D̃_m}[ℓ(f(x, m), y)]. Our problem focuses on the setting where each domain has a different conditional label distribution, i.e., P_m(y|x) differs across domains m.
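The model structure implied by the objective above — a shared encoder plus one prediction head per domain, with the head selected by the domain label z — can be sketched in plain numpy. This is a minimal illustration; the dimensions, the linear encoder, and all variable names are illustrative choices, not taken from the paper:

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, M = 10, 3, 4              # input dim, representation dim, number of domains

B = rng.normal(size=(d, k))     # shared linear encoder (plays the role of phi)
W = rng.normal(size=(M, k))     # one linear head per domain (h_1, ..., h_M)

def predict(x, z):
    """f(x, z): encode x with the shared encoder, then apply domain z's head."""
    rep = B.T @ x               # shared representation, shape (k,)
    return W[z] @ rep           # domain-specific scalar prediction

x = rng.normal(size=d)
preds = [predict(x, z) for z in range(M)]   # same input, M domain-conditioned outputs
```

The same input yields a different prediction per domain, which is exactly what a domain-unaware model f(x) cannot express.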

3.1. COMPARISON WITH PRIOR DOMAIN-UNAWARE FL PROBLEM FORMULATIONS

Our FL problem introduces the concept of the domain and focuses on the model's performance in each domain. Many prior FL formulations do not recognize the existence of domains. For example, the original federated learning algorithms, such as FedAvg (McMahan et al., 2017a) and FedProx (Li et al., 2020), learn a globally shared model by minimizing the averaged risk, i.e., min_f (1/n) Σ_i R_i(f). Some variants consider fairness across the clients; to do so, they optimize the worst client's performance instead of the averaged performance, i.e., min_f max_i R_i(f). Further, personalized FL algorithms, such as FedRep (Collins et al., 2021), customize the model's prediction for each client, with objective min_{f_i : i ∈ [n]} (1/n) Σ_{i=1}^n R_i(f_i). All the FL algorithms mentioned above lead to sub-optimal solutions to our problem since they do not make domain-specific predictions. We illustrate this point with the following toy example of linear regression. We assume the data in the m-th domain is generated via the following procedure: x ∈ R^d is i.i.d. sampled from a distribution p(x) with mean zero and covariance I_d. The label y ∈ R obeys y = x^⊤ B* w*_m, where B* ∈ R^{d×k} is a ground-truth linear embedding shared by all domains, and w*_m ∈ R^k is the linear head specific to domain m. Under this setting, D̃_m stands for the distribution of data (x, y) generated with B* and w*_m.

FedAvg: learns a single shared representation B and head w for all clients,
min_{B,w} (1/(2n)) Σ_{i∈[n]} E_{(x,y)∼D_i} (y − x^⊤ B w)² = Σ_{i∈[n], m∈[M]} (π_{i,m}/(2n)) E_{(x,y)∼D̃_m} (y − x^⊤ B w)² (2)

FedRep: learns a shared representation B and separate heads w_i for each client i rather than for each domain m,
min_{B, w_1, ..., w_n} (1/(2n)) Σ_{i∈[n]} E_{(x,y)∼D_i} (y − x^⊤ B w_i)² = Σ_{i∈[n], m∈[M]} (π_{i,m}/(2n)) E_{(x,y)∼D̃_m} (y − x^⊤ B w_i)² (3)

FedDAR: In contrast, in the linear case, our proposed method, FedDAR, which will be introduced next, learns a shared representation B and a separate head w_m for each domain m,
min_{B, w_1, ..., w_M} (1/(2M)) Σ_{i∈[n]} Σ_{m∈[M]} (π_{i,m} / Σ_{i'} π_{i',m}) E_{(x,y)∼D̃_m} (y − x^⊤ B w_m)² (4)

In the general case, FedDAR learns a shared encoder ϕ and domain-wise heads h_1, ..., h_M with the objective min_{ϕ, h_1, ..., h_M} R(ϕ, h_1, ..., h_M) := (1/M) Σ_{m=1}^M R_m(h_m ∘ ϕ). We decouple the training between the encoder and the heads. Specifically, we alternate the learning between the encoder and the heads. The learning is done federatedly and has two conventional steps: (1) local updates; (2) aggregation at the server. Algorithm 1 shows the pseudocode.

Algorithm 1 FedDAR
Input: data S_1:n; numbers of local updates τ_h for the heads and τ_ϕ for the representation; number of communication rounds T; learning rate η.
Initialize representation and heads ϕ^0, h^0_1, ..., h^0_M.
for t = 1, 2, ..., T do
    Server sends ϕ^{t−1}, h^{t−1}_1, ..., h^{t−1}_M to the n clients.
    for client i = 1, 2, ..., n in parallel do
        Client i initializes h^{t,0}_{i,m} ← h^{t−1}_m, ∀m ∈ [M].
        for s = 1 to τ_h do
            h^{t,s}_{i,m} ← GRD(R̃_{i,m}(h^{t,s−1}_{i,m} ∘ ϕ^{t−1}), h^{t,s−1}_{i,m}, η), for all m ∈ [M].
        end for
        Client i sends the updated heads h^{t,τ_h}_{i,m} and Hessians H_{R̃_{i,m}}(h^{t,τ_h}_{i,m}) to the server.
    end for
    Server aggregates the heads for each domain:
    for m ∈ [M] do
        h^t_m ← HEADAGG({h^{t,τ_h}_{i,m}, H_{R̃_{i,m}}(h^{t,τ_h}_{i,m})}_{i=1}^n) via Equation 8.
    end for
    Server sends h^t_1, ..., h^t_M to the n clients.
    for client i = 1, 2, ..., n in parallel do
        for s = 1 to τ_ϕ do
            ϕ^{t,s}_i ← GRD(R̃_i(ϕ^{t,s−1}_i, {h^t_m}_{m=1}^M), ϕ^{t,s−1}_i, η).
        end for
        Client i sends the updated representation ϕ^t_i = ϕ^{t,τ_ϕ}_i to the server.
    end for
    Server aggregates the representations: ϕ^t ← Σ_{i=1}^n (L_i/L) ϕ^t_i.
end for

Empirical Objectives with Re-weighting. Empirically, the objectives are estimated via the finite data samples at each client. We use S_{i,m} to denote the set of samples from domain m at client i, with L_{i,m} := |S_{i,m}| denoting its size. Further, L_i := Σ_{m=1}^M L_{i,m} is the number of samples at client i, while L_m := Σ_{i=1}^n L_{i,m} is the total number of samples belonging to domain m across all clients. We denote the empirical risk at client i specific to domain m as R̃_{i,m}(h_m ∘ ϕ) := (1/L_{i,m}) Σ_{(x,y)∈S_{i,m}} ℓ(h_m ∘ ϕ(x), y). The empirical risk at client i is designed as R̃_i(ϕ, h_1, ..., h_M) = Σ_m (L_{i,m}/L_i) u_m R̃_{i,m}(h_m ∘ ϕ), where u_m = L/(L_m M) re-weights the risk for each domain.
Combining with the commonly used weighted-average FL objective R̃(ϕ, h_1, ..., h_M) = Σ_{i=1}^n (L_i/L) R̃_i(ϕ, h_1, ..., h_M), the overall empirical risk is derived as the following, R̃(ϕ, h_1, ..., h_M) := Σ_{i=1}^n (L_i/L) R̃_i(ϕ, h_1, ..., h_M) = (1/M) Σ_{m=1}^M R̃_m(h_m ∘ ϕ), where R̃_m(h_m ∘ ϕ) := Σ_{i=1}^n (L_{i,m}/L_m) R̃_{i,m}(h_m ∘ ϕ). This is consistent with Equation 5.
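As a numerical sanity check on this re-weighting identity, the snippet below draws random per-client, per-domain sample counts and losses and verifies that the weighted client average equals the uniform domain average. The counts and loss values are synthetic, purely for illustration:

```python
import numpy as np

rng = np.random.default_rng(1)
n, M = 5, 3                                   # clients, domains

L_im = rng.integers(1, 20, size=(n, M))       # L_{i,m}: samples of domain m at client i
R_im = rng.random(size=(n, M))                # empirical risks R~_{i,m}

L_i = L_im.sum(axis=1)                        # samples per client
L_m = L_im.sum(axis=0)                        # samples per domain
L = L_im.sum()                                # total samples
u = L / (L_m * M)                             # domain re-weights u_m = L / (L_m * M)

# Client-side risk: R~_i = sum_m (L_{i,m}/L_i) * u_m * R~_{i,m}
R_i = (L_im / L_i[:, None] * u[None, :] * R_im).sum(axis=1)
# Server-side weighted average over clients: sum_i (L_i/L) * R~_i
overall = (L_i / L * R_i).sum()

# Domain-side risk: R~_m = sum_i (L_{i,m}/L_m) * R~_{i,m}, averaged uniformly over domains
R_m = (L_im / L_m[None, :] * R_im).sum(axis=0)
domain_avg = R_m.mean()
```

Both quantities agree exactly (up to floating point), showing that the per-sample re-weighting turns the standard weighted FL average into a uniform average over domains.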

4.2. LOCAL UPDATES AT CLIENTS

In each communication round, clients use gradient-descent methods to optimize the representation ϕ(·; θ) and the heads h_m(·; w_m) for m ∈ [M] alternately. We use t to denote the current round. For a module f, f^{t−1} denotes its optimized version after t−1 rounds. Each round has multiple gradient-descent iterations; we use f^{t,s} to denote the module in round t after s iterations. Since the updates are made locally, clients maintain their own copies of both modules; we use the subscript i to index the local copy at client i, e.g., f^{t,s}_i. We use GRD to denote a generic gradient-based optimization step which takes three inputs — objective function, variables, learning rate — and maps them to a module with updated variables. For example, vanilla gradient descent has the form GRD(L(f_w), f_w, η) = f_{w − η∇_w L(f_w)}. For the heads, client i performs τ_h local gradient-based updates to obtain near-optimal heads given the current shared encoder ϕ^{t−1}: for s ∈ [τ_h], client i updates via h^{t,s}_{i,m} ← GRD(R̃_{i,m}(h^{t,s−1}_{i,m} ∘ ϕ^{t−1}), h^{t,s−1}_{i,m}, η). For the shared encoder, the clients execute τ_ϕ local updates: for s ∈ [τ_ϕ], client i updates the local copy of the encoder via ϕ^{t,s}_i ← GRD(R̃_i(ϕ^{t,s−1}_i, {h^t_m}_{m=1}^M), ϕ^{t,s−1}_i, η). The re-weighting mentioned in the last section is implemented by weighting each sample with u_m when calculating the loss function.

4.3. AGGREGATION AT SERVER

We introduce two strategies: (1) weighted average (WA); (2) second-order aggregation (SA). Weighted average means the aggregated model parameters are the average of the local models' parameters, weighted by the number of data samples. Specifically, for the shared encoder, we have θ^t = Σ_{i=1}^n (L_i/L) θ^t_i. Similarly, for each head, we have w^t_m = Σ_{i=1}^n (L_{i,m}/L_m) w^t_{m,i}. Second-order aggregation is a more elaborate strategy. Ideally, we want the head aggregation to produce the globally optimal model given a set of locally optimal models, as in the following,
w* ∈ arg min_w J(w) ≜ Σ_{i=1}^n α_i J_i(w), given w*_i = arg min_w J_i(w) ∀i ∈ [n], (7)
where J_i is the i-th client's virtual objective, α_i := L_i/L is the importance of the client, and L_i is its number of data samples. We call J_i the virtual objective to distinguish it from the real learning objective R_i. The virtual objective is defined as the objective whose optimum is attained by the local updates; it is introduced because the local updates between two aggregations are not guaranteed to drive the head to the optimum of the real objective. For example, if each local update is a single gradient-descent step with learning rate η, i.e., w^{t+1}_i = w^t − η∇_w R_i(w^t), then the virtual objective becomes J_i(w) = R_i(w^t) + (w − w^t)^⊤ ∇_w R_i(w^t) + (1/(2η))∥w − w^t∥²_2, which satisfies w^{t+1}_i ∈ arg min_w J_i(w). Such a virtual objective leads the solution of problem (7) to w* = Σ_{i=1}^n α_i w^{t+1}_i, which is the simple (weighted) averaging strategy. However, in real practice, the local updates are usually more thorough, which makes the virtual objective closer to the true objective. We consider the case where the virtual objective is the second-order Taylor expansion of the true objective, i.e., J(w) = R(w^t) + (w − w^t)^⊤ ∇_w R(w^t) + (1/2)(w − w^t)^⊤ H_R(w^t)(w − w^t), where H_R is the Hessian matrix. Then each round of local updates is equivalent to a Newton-like step, w^{t+1}_i = w^t − H_{R_i}(w^t)^{−1} ∇_w R_i(w^t).
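To see why single-step local gradient descent makes plain (weighted) averaging the right aggregation rule, one can verify numerically that the minimizer of the summed proximal virtual objectives coincides with the weighted average of the locally updated heads. The quadratic local risks below are an illustrative choice; any differentiable R_i would give the same identity:

```python
import numpy as np

rng = np.random.default_rng(2)
n, dim, eta = 4, 6, 0.1
w_t = rng.normal(size=dim)                      # current global head w^t
alpha = rng.random(n); alpha /= alpha.sum()     # client weights alpha_i, sum to 1

# Illustrative local risks R_i(w) = 0.5 * ||A_i w - b_i||^2, used only via gradients
A = rng.normal(size=(n, dim, dim))
b = rng.normal(size=(n, dim))
grads = np.stack([A[i].T @ (A[i] @ w_t - b[i]) for i in range(n)])

# One local gradient step per client: w_i^{t+1} = w^t - eta * grad_i
w_next_local = w_t - eta * grads

# Minimizer of sum_i alpha_i J_i(w) with the proximal virtual objective:
# setting the gradient sum_i alpha_i grad_i + (w - w^t)/eta to zero gives
# w* = w^t - eta * sum_i alpha_i grad_i
w_star = w_t - eta * (alpha[:, None] * grads).sum(axis=0)

weighted_avg = (alpha[:, None] * w_next_local).sum(axis=0)
```

Because the weights sum to one, `w_star` and `weighted_avg` are identical, which is exactly the FedAvg-style aggregation recovered by the virtual-objective argument.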
Meanwhile, w^{t+1} = w^t − H_R(w^t)^{−1} ∇_w R(w^t) is the desired global Newton step. Leveraging the facts that ∇_w R(w) = Σ_{i∈[n]} α_i ∇_w R_i(w) and H_R(w) = Σ_{i∈[n]} α_i H_{R_i}(w), we can recover w^{t+1} from the w^{t+1}_i via the following equation, which we call second-order aggregation,
w^{t+1} = H_R(w^t)^{−1} Σ_{i∈[n]} α_i H_{R_i}(w^t) w^{t+1}_i. (8)
Specifically, to implement second-order aggregation, in each round the clients first optimize the model locally for several epochs, then compute the Hessian matrices of their local models and send them to the server for aggregation. Note that sending the Hessian incurs a communication cost quadratic in the size of the weights. In practice, the predictive head is usually small, e.g., a linear layer with hundreds of neurons, so it is acceptable to aggregate the Hessian matrix of the head's parameters. In the following, we provide two instances of our second-order aggregation with a linear head. 1. Linear regression, where R_i(w) = (1/L_i) Σ_{j=1}^{L_i} (w^⊤ x^j_i − y^j_i)² is quadratic itself, so the second-order Taylor expansion is the objective itself, i.e., J_i(w) = R_i(w). In this case, H_{R_i}(w) = (2/L_i) X^⊤_i X_i, where X_i = [x^1_i, ..., x^{L_i}_i]^⊤ is the data matrix of client i. 2. Binary classification, where R_i(w) = −(1/L_i) Σ_{j=1}^{L_i} [y^j_i log σ(w^⊤ x^j_i) + (1 − y^j_i) log(1 − σ(w^⊤ x^j_i))], with σ the sigmoid function. Let μ^j_i ≜ σ(w^⊤ x^j_i) denote the model's output. The gradient and the Hessian are ∇_w R_i(w) = (1/L_i) Σ_j (μ^j_i − y^j_i) x^j_i = (1/L_i) X^⊤_i (μ_i − y_i) and H_{R_i}(w) = (1/L_i) X^⊤_i S X_i, where S ≜ diag(μ^1_i(1 − μ^1_i), ..., μ^{L_i}_i(1 − μ^{L_i}_i)). Similar formulas can be derived for multiclass classification; please refer to the textbook (Murphy, 2022) for the exact equations. Remark. In practice, when the dimension of w is larger than the number of samples of a certain domain, the Hessian may have small singular values, causing numerical instability.
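For the linear-regression instance, second-order aggregation of exactly solved local heads recovers the global least-squares solution in one round, because the per-client normalization constants in the Hessians cancel against the client weights. The check below uses randomly generated, noiseless data for illustration only:

```python
import numpy as np

rng = np.random.default_rng(3)
n, dim = 3, 5
L_i = np.array([30, 50, 20])                  # samples per client (illustrative)
w_true = rng.normal(size=dim)

Xs = [rng.normal(size=(Li, dim)) for Li in L_i]
ys = [X @ w_true for X in Xs]                 # noiseless targets for a clean check

# A local Newton step on a quadratic risk lands exactly on the local minimizer
w_local = [np.linalg.solve(X.T @ X, X.T @ y) for X, y in zip(Xs, ys)]

# Second-order aggregation (Eq. 8): with alpha_i = L_i/L and H_i proportional to
# (1/L_i) X_i^T X_i, the products alpha_i * H_i are proportional to X_i^T X_i,
# so the common constants cancel in the weighted sums below.
H_sum = sum(X.T @ X for X in Xs)
w_agg = np.linalg.solve(H_sum, sum((X.T @ X) @ w for X, w in zip(Xs, w_local)))

# Global least-squares solution on the pooled data, for comparison
X_all, y_all = np.vstack(Xs), np.concatenate(ys)
w_global = np.linalg.lstsq(X_all, y_all, rcond=None)[0]
```

Here `w_agg` matches `w_global` (and, since the data is noiseless, `w_true`) to numerical precision, whereas a plain parameter average of the `w_local` generally would not.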
To mitigate this issue, one can either directly set the representation dimension k to some smaller number or add a (fully-connected) projection layer on top of a pretrained encoder to compress the representations to a lower dimensional space.

4.4. THEORETICAL RESULT OF FEDDAR

For a simplified linear regression setting as discussed in domain-mixed FL (4) (cf. details in Appendix A), we give below the sample complexity required for an adapted version of our algorithm (Algorithm 2 in the appendix) to enjoy linear convergence. Due to the space limit, we only provide an informal statement to highlight the result; the formal statement and the proof are deferred to the appendix. Theorem 4.1 (Sample complexity of FedDAR convergence in the linear case (informal)). Consider the linear setting for domain-mixed FL in (4). At each iteration, suppose that the number of samples used by each of the n clients to update the encoder is Ω(dk²/n), and that the aggregate number of samples used in the update for each domain-specific head is Ω(k). Then, for a suitably chosen step size, the distance between the encoder B^t output by Algorithm 2 and the true encoder B* converges at a linear rate. Remark. As our algorithm converges linearly to the true encoder, the per-iteration sample complexity gives a good estimate of the overall sample complexity. Since we expect the output of the encoder to be significantly lower-dimensional than the input (i.e., k ≪ d), our result indicates that Algorithm 2's sample complexity is dominated by Ω(d/n), implying that the complexity reduces significantly as the number of clients n increases. Moreover, a key implication of our result is the capacity of our algorithm to accommodate data imbalance across domains. We note that our approach requires Ω(dk²) samples per iteration for the update of the shared representation B ∈ R^{d×k}, whilst needing only Ω(k) samples per iteration for the update of each domain head. In particular, domains with more data can contribute disproportionately to the Ω(dk²) samples required to learn the common representation, whilst domains with fewer data need only provide Ω(k) samples to update their domain heads during the course of the algorithm.
Whenever k ≪ d, which we believe is a reasonable assumption for many practical applications (e.g. medical imaging), the requirement of Ω(k) samples per domain is relatively mild. Conversely, forgoing the shared representation structure would require each domain to learn a separate d-dimensional classifier, requiring Ω(d) samples per domain, which can pose a challenge in problems with domain data imbalance.

5. EXPERIMENTS

We validate our method's effectiveness on both synthetic and real datasets. We first experiment on the exact synthetic dataset described in our theoretical analysis to verify our theory. We then conduct experiments on a real dataset, FairFace (Kärkkäinen & Joo, 2019), with controlled domain distributions to investigate the robustness of our algorithm under different levels of heterogeneity. Finally, we compare our method with various baselines on a real federated learning benchmark, EXAM (Dayan et al., 2021), with real-world domain distributions, and conduct extensive ablation studies on it to discern the contribution of each component of our method. Full details of the experimental settings can be found in Appendix B.

5.1. SYNTHETIC DATA

We first run experiments on the linear regression problem analyzed in Appendix A. We generate (domain, data, label) samples as follows: z_i ∼ M(π_i), x_i ∼ N(0, I_d), y_i ∼ N(w*_{z_i}^⊤ B*^⊤ x_i, σ), where σ = 10^{−3} controls the label observation noise and M(π_i) is a multinomial domain distribution with parameter π_i = [π_{i,1}, ..., π_{i,M}] ∈ Δ_M. The parameters of the domain distributions π_i are drawn from a Dirichlet distribution, i.e., π_i ∼ Dir(αp), where p ∈ Δ_M is a prior domain distribution over the M domains and α > 0 is a concentration parameter controlling the heterogeneity of domain distributions among clients. The largest heterogeneity is achieved as α → 0, where each client contains data from only a single randomly selected domain; on the other hand, when α → ∞, all clients have identical domain distributions equal to the prior p. We generate the ground-truth representation B* ∈ R^{d×k} and domain-specific heads w*_m, ∀m ∈ [M], by sampling and normalizing Gaussian matrices. Among the baselines we compare with is (4) Separate FedAvg, which trains separate models for each domain using FedAvg. The results demonstrate that our method overcomes the heterogeneity of domain distributions across clients, while FedDAR-WA fails to converge under this setting, confirming the effectiveness of the proposed second-order aggregation.
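The generation procedure described above can be sketched in numpy as follows. The dimensions, the client sample count, the uniform prior p, and the concentration value are illustrative choices, not the paper's exact experimental settings:

```python
import numpy as np

rng = np.random.default_rng(4)
n, M, d, k = 10, 4, 20, 5                     # clients, domains, input dim, rep dim
alpha_conc, sigma = 0.5, 1e-3                 # Dirichlet concentration, label noise
p = np.full(M, 1.0 / M)                       # uniform prior over domains

# Ground-truth shared representation (orthonormalized) and domain-specific heads
B_star, _ = np.linalg.qr(rng.normal(size=(d, k)))
W_star = rng.normal(size=(M, k))

def make_client(num_samples):
    pi = rng.dirichlet(alpha_conc * p)                 # pi_i ~ Dir(alpha * p)
    z = rng.choice(M, size=num_samples, p=pi)          # domain labels z_i ~ M(pi_i)
    x = rng.normal(size=(num_samples, d))              # features x_i ~ N(0, I_d)
    y = np.einsum('ij,ij->i', x @ B_star, W_star[z])   # means w*_{z_i}^T B*^T x_i
    y += rng.normal(scale=sigma, size=num_samples)     # label observation noise
    return x, y, z

clients = [make_client(500) for _ in range(n)]
```

Small `alpha_conc` makes each client's mixture concentrate on few domains; large values drive all clients toward the prior p, matching the α → 0 and α → ∞ regimes described above.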

5.2. REAL DATA WITH CONTROLLED DISTRIBUTION

Dataset and Model. We use FairFace (Kärkkäinen & Joo, 2019), a public face-image dataset containing 7 race groups, which we treat as the domains. Each image is labeled with one of 9 age groups and a gender. We use the age label as the target to build a multi-class age classifier. We create an FL setting by dividing the training data into n clients without duplication. Each client has a domain distribution π_i ∼ Dir(αp) sampled from a Dirichlet distribution. The total number of samples at each client, L_i = 500, is the same in all experiments. We control the heterogeneity of domain distributions by altering α. The label distributions are uniform for all clients. Implementation and Evaluation. We use an ImageNet (Deng et al., 2009) pre-trained ResNet-34 (He et al., 2016) for all experiments on this dataset. All methods are trained for T = 100 communication rounds. We use the Adam optimizer with a learning rate of 1 × 10^{−4} for the first 60 rounds and 1 × 10^{−5} for the last 40 rounds. Metrics and Results. Our evaluation metric is the classification accuracy on the whole validation set of FairFace for each race group. We do not hold out a separate local validation set for each client, since we assume the data distribution within each domain is consistent across clients. In Table 5.2, we report the accuracy averaged over the final 10 communication rounds, following common practice (Collins et al., 2021). The results show that our FedDAR achieves the best performance compared with the baselines. Note that FedAvg + Multi-head also uses Equation 5 as its objective for a fair comparison. Effect of k. A limitation of using FedDAR-SA instead of FedDAR-WA is the need to tune the representation dimension k. Figure 5.2 shows the average domain test accuracy with different k. We can see that FedDAR-SA achieves better accuracy with a properly chosen k; we use k = 8 for all FedDAR-SA results in Table 5.2. Robustness to Varying Levels of Heterogeneity.
From the results with various α, we observe that the performance of FedDAR-SA is very stable no matter how heterogeneous the domain mixtures are, whereas the baselines' accuracy decreases when α becomes smaller.

5.3. REAL DATA WITH REAL-WORLD DATA DISTRIBUTION

Dataset and Model. We use the EXAM dataset (Dayan et al., 2021), a large-scale, real-world healthcare FL study. We use part of the dataset, including 6 clients with a total of 7,681 cases, and use race groups as domains. The dataset was collected from suspected COVID-19 patients at their visit to the emergency department (ED), and includes both chest X-rays (CXR) and electronic medical records (EMR). We adopt the same data preprocessing procedure and model as (Dayan et al., 2021). Our task is to predict whether the patient received oxygen therapy higher than high-flow oxygen within 72 hours, which indicates severe symptoms. Baselines. (1) Methods that learn one global model: FedAvg (McMahan et al., 2017a), FedProx (Li et al., 2020), FedMinMax (Papadaki et al., 2021); (2) personalized FL methods: LG-FedAvg (Liang et al., 2020) and FedBN (Li et al., 2021b). Implementation and Evaluation. We apply 5-fold cross-validation. All models are trained for T = 20 communication rounds with the Adam optimizer and a learning rate of 10^{−4}. The models are evaluated by aggregating predictions on the local validation sets and then calculating the area under the curve (AUC) for each domain. We also report the AUCs averaged over the clients' local validation sets. Average Performance Across Domains and Clients. Table 3 shows the average of AUCs across domains and clients. We can see that our methods, both FedDAR-WA and FedDAR-SA, achieve significantly better performance than all the baselines under both domain-wise and client-wise metrics. The gap between our domain-wise personalized approach and the client-wise personalized baselines shows the validity of learning domain-wise personalized models in the face of diversity across domains. The reason that fine-tuning methods produce worse results is mainly the imbalanced label distribution: each local training dataset does not have enough positive cases for proper fine-tuning. Fairness Across Domains.
The AUCs of each specific domain in Table 3 show that our proposed FedDAR method uniformly increases the AUC for each domain. The column of the minimum AUC among domains also verifies that our method indeed improves fairness across the domains. From the ablation results in Table 2, we see that adding the multi-head alone does not improve results; we conjecture that the alternating update prevents overfitting of the heads with limited samples. This is also shown by the results in Table 5.2, where FedAvg+MH tends to perform badly on certain underrepresented domains, especially when the domain distributions are highly heterogeneous (α is small). Meanwhile, using domain labels directly as an input feature is not as good as the multi-head, and is not compatible with the alternating update. Finally, regarding the projection (Proj) and the aggregation method (AGG), the results in Table 2 show that using second-order aggregation with the projection of the features gives the best result.

6. CONCLUSIONS

We propose a novel personalized federated learning framework that assumes each client's data distribution is a mixture of domain distributions. Our approach, FedDAR, achieves balanced performance across domains by learning a global representation and domain-specific heads, despite the heterogeneity of domain distributions across clients. Our method is effective, as supported by both theoretical and empirical justifications. It has been tested on facial image and medical imaging FL datasets and can be easily extended to other complicated tasks. However, our method has some limitations: i) it requires the domain information for all samples; ii) it does not consider heterogeneity of label distributions; iii) it has a potentially expensive communication cost caused by sending Hessian matrices, especially when the output dimension is big. We plan to address these limitations in future work, along with other research directions such as improving fairness across domains and exploring settings where domains are structured, hierarchical, continuously indexed (Wang et al., 2020; Nasery et al., 2021), or multi-dimensional (characterized by multiple factors).

A FEDDAR FOR LINEAR REPRESENTATION

A.1 SETUP. We retain the setup for linear regression considered at the start of Section 3.1. We additionally define W* ≜ [w*_1, ..., w*_M]^⊤ ∈ R^{M×k} as the concatenation of the domain-specific heads. For notational convenience, we let (x_{i,m}, y_{i,m}) denote an (input, output) sample coming from client i and the m-th domain. To measure the distance between any two matrices A, B of the same dimensions, we use the principal angle distance (Golub & Van Loan, 2013), given by dist(A, B) ≜ ∥A^⊤_⊥ B∥_2, where A_⊥ denotes a matrix whose columns form a basis for the orthogonal complement of the range of A. To simplify the analysis, we further make the following assumptions. Assumption A.1 (Sub-Gaussianity). For each m ∈ [M] and i ∈ [n], the samples x_{i,m} ∈ R^d are independent, mean zero, have covariance I_d, and have sub-Gaussian norm 1, i.e., for every v ∈ R^d, E[exp(v^⊤ x_{i,m})] ≤ exp(∥v∥²/2). Assumption A.2 (Domain diversity). Let σ_{min,*} ≜ σ_min((1/√M) W*), i.e., σ_{min,*} is the minimum singular value of the scaled head matrix. Then σ_{min,*} > 0.
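The principal angle distance can be computed directly from this definition. The helper below is a sketch rather than the authors' code: it builds an orthonormal basis of the orthogonal complement via a full SVD and orthonormalizes B via QR (assuming B has full column rank):

```python
import numpy as np

def principal_angle_dist(A, B):
    """dist(A, B) = || A_perp^T B_hat ||_2, where A_perp spans the orthogonal
    complement of range(A) and B_hat is an orthonormal basis of range(B)."""
    d, k = A.shape
    Ua, _, _ = np.linalg.svd(A, full_matrices=True)
    A_perp = Ua[:, k:]                      # basis of the orthogonal complement
    B_hat, _ = np.linalg.qr(B)              # orthonormalize B's columns
    return np.linalg.norm(A_perp.T @ B_hat, 2)

rng = np.random.default_rng(5)
A, _ = np.linalg.qr(rng.normal(size=(8, 3)))   # a random 3-dim subspace of R^8
C, _ = np.linalg.qr(rng.normal(size=(8, 3)))   # another random subspace
```

The distance is 0 when the two matrices span the same subspace and at most 1 in general (it is the sine of the largest principal angle between the subspaces).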

Assumption A.3 (Ground-truth normalization). The true domain parameters satisfy
$$\frac{\sqrt{k}}{2} \le \|w^*_m\| \le \sqrt{k} \quad \text{for each } m \in [M],$$
and $B^*$ has orthonormal columns.

All the above assumptions aim to simplify the theoretical analysis whilst imposing only mild constraints on the data distribution and the parameters of the target functions. Similar assumptions have also been adopted in prior work (Collins et al., 2021).
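To make the setup concrete, the sketch below (NumPy; all dimensions are illustrative choices, not the paper's experimental settings) samples a ground truth satisfying Assumptions A.2 and A.3 and evaluates the principal angle distance $\mathrm{dist}(A, B) = \|A_\perp^\top B\|_2$ used throughout the analysis.

```python
import numpy as np

def principal_angle_dist(A, B):
    """dist(A, B) = ||A_perp^T B||_2: the sine of the largest principal angle
    between range(A) and range(B); A and B have orthonormal columns."""
    # Projector onto the orthogonal complement of range(A); acting with it
    # is equivalent to multiplying by A_perp A_perp^T.
    P_perp = np.eye(A.shape[0]) - A @ A.T
    return np.linalg.norm(P_perp @ B, 2)

def sample_ground_truth(d, k, M, seed=0):
    rng = np.random.default_rng(seed)
    # B*: orthonormal columns via QR of a Gaussian matrix (Assumption A.3).
    B_star, _ = np.linalg.qr(rng.standard_normal((d, k)))
    # W*: Gaussian rows rescaled so sqrt(k)/2 <= ||w*_m|| <= sqrt(k) (Assumption A.3).
    W_star = rng.standard_normal((M, k))
    target = rng.uniform(np.sqrt(k) / 2, np.sqrt(k), size=(M, 1))
    W_star *= target / np.linalg.norm(W_star, axis=1, keepdims=True)
    return B_star, W_star

B_star, W_star = sample_ground_truth(d=10, k=3, M=5)
# Domain diversity (Assumption A.2): sigma_min(W*/sqrt(M)) > 0.
sigma_min_star = np.linalg.svd(W_star / np.sqrt(5), compute_uv=False)[-1]
```

With $M \ge k$ and generic Gaussian rows, the head matrix has full column rank, so Assumption A.2 holds almost surely.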

A.2 FEDDAR ADAPTED TO LINEAR REGRESSION

We analyze an adapted version of our FedDAR algorithm. Since the linear regression problem has an analytic solution, to ease the analysis we update the heads $\{w_m\}_{m=1}^M$ at the server in closed form using local gradient information, while we update the representation $B$ by taking a step along the averaged local gradients. Algorithm 2 shows the procedure of this adapted version. The local objective for the $i$-th client and the $m$-th domain at the $t$-th iteration, $f^t_{i,m}(w_m, B^t)$, is defined as
$$f^t_{i,m}(w_m, B^t) \triangleq \frac{1}{2} \sum_{j=1}^{L^t_{i,m}} \big(y^j_{i,m} - w_m^\top (B^t)^\top x^j_{i,m}\big)^2,$$
where $L^t_{i,m}$ is the number of samples from domain $m$ at client $i$. We assume that in each iteration the data points $\{x^j_{i,m}, y^j_{i,m}\}_{j \in [L^t_{i,m}]}$ are all newly sampled from the distribution, and we denote $L = \sum_m L^t_{i,m}$. Since the objective function is quadratic, its gradient w.r.t. either $w_m$ or $B$ has the linear form $A_{i,m} w_m - a_{i,m}$ or $C_{i,m} B - c_{i,m}$, which we write down explicitly in Appendix B. After every global update of the representation $B$, we apply an additional QR decomposition to normalize it to be column-wise orthogonal.
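The two decoupled updates above — a server-side closed-form solve for the heads and an averaged gradient step on the representation followed by QR — can be sketched end-to-end as below (NumPy). The dimensions, step size, noiseless labels, and the warm start near $B^*$ are all illustrative choices; Algorithm 2 instead initializes $B^0$ via a rank-$k$ SVD, which is analyzed separately.

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, M, n_clients, L_dom, eta, T = 10, 2, 3, 5, 50, 0.1, 400

# Ground truth: orthonormal B*; heads of norm sqrt(k) with spread-out directions.
B_star, _ = np.linalg.qr(rng.standard_normal((d, k)))
W_star = np.sqrt(k) * np.array([[1.0, 0.0], [0.0, 1.0], [2**-0.5, 2**-0.5]])

def pa_dist(A, C):
    """Principal angle distance between the column spans of A and C."""
    return np.linalg.norm((np.eye(d) - A @ A.T) @ C, 2)

# Illustrative warm start near B* (stands in for the SVD initialization).
B = np.linalg.qr(B_star + 0.1 * rng.standard_normal((d, k)))[0]
d0 = pa_dist(B, B_star)

for _ in range(T):
    # Server-side closed-form head update: pooling the clients' (A_{i,m}, a_{i,m})
    # statistics is equivalent to solving (sum_i A_{i,m}) w_m = sum_i a_{i,m}.
    W = np.zeros((M, k))
    for m in range(M):
        X = rng.standard_normal((n_clients * L_dom, d))  # fresh samples, all clients
        y = X @ B_star @ W_star[m]                       # noiseless labels
        Z = X @ B                                        # features under current B
        W[m] = np.linalg.solve(Z.T @ Z, Z.T @ y)
    # Averaged gradient step on B over fresh samples, then QR renormalization.
    G = np.zeros((d, k))
    for m in range(M):
        X = rng.standard_normal((n_clients * L_dom, d))
        resid = X @ B @ W[m] - X @ B_star @ W_star[m]
        G += np.outer(resid @ X, W[m]) / len(resid)      # (1/L_m) sum_j r_j x_j w_m^T
    B = np.linalg.qr(B - (eta / M) * G)[0]

dT = pa_dist(B, B_star)
```

With noiseless labels the fixed point is exactly $\mathrm{dist}(B, B^*) = 0$, so the iterates contract toward the ground-truth subspace, consistent with the linear rate analyzed below.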

A.3 CONVERGENCE ANALYSIS

We first provide a brief proof sketch. Overall, our approach largely follows that of Collins et al. (2021), with a few differences needed to handle the spreading of a domain's data across different clients. We also tightened the analysis compared with Collins et al. (2021), so that each domain only needs $O(k)$ samples as opposed to $O(k^2)$ samples as in Collins et al. (2021) (where the requirement is for each client to have $O(k^2)$ samples, since they considered the case where each client has a separate head). This can yield a significant improvement when $k$ is moderately large and there is data imbalance.

Algorithm 2 FedDAR for linear regression
Input: step size $\eta$; number of rounds $T$.
Client initialization: each client $i \in [n]$ collects $L^0$ samples and sends $Z_i := \sum_{j=1}^{L^0} (y^{0,j}_i)^2 x^{0,j}_i (x^{0,j}_i)^\top$ to the server.
Server initialization: compute $U D U^\top \leftarrow$ rank-$k$ SVD$\big(\frac{1}{nL^0}\sum_{i=1}^n Z_i\big)$; set $B^0 \leftarrow U$.
for $t = 0, 1, \ldots, T$ do
  Server sends the current $B^t$ to the clients.
  Client computation for $W^{t+1}$: each client $i \in [n]$ selects $L$ new samples $\{(x^j_i, y^j_i)\}$, computes $\nabla_{w_m} f^t_{i,m}(w_m, B^t) = A^t_{i,m} w_m - a^t_{i,m}$ for each domain $m \in [M]$, and sends $(A^t_{i,m}, a^t_{i,m}, L^t_{i,m})$ back to the server.
  Server update for $W^{t+1}$: for each $m \in [M]$, the server chooses $w^{t+1}_m \in \big\{w_m \in \mathbb{R}^k : \nabla_{w_m} \frac{1}{\sum_i L^t_{i,m}} \sum_{i=1}^n f^t_{i,m}(w_m, B^t) = 0\big\}$, i.e., $w^{t+1}_m$ satisfying $(\sum_i A^t_{i,m}) w^{t+1}_m = \sum_i a^t_{i,m}$, and sends $W^{t+1} = [w^{t+1}_1, \cdots, w^{t+1}_M]^\top \in \mathbb{R}^{M \times k}$ to the clients.
  Client computation for $B^{t+1}$: each client $i \in [n]$ selects $L$ new samples $\{(x^j_i, y^j_i)\}$, computes $\nabla_B f^{t'}_{i,m}(w^{t+1}_m, B^t) = C^t_{i,m} B^t - c^t_{i,m}$ for each $m \in [M]$, and sends $(\nabla_B f^{t'}_{i,m}(w^{t+1}_m, B^t), L^{t'}_{i,m})$ back to the server.
  Server update for $B^{t+1}$: the server computes $\tilde B^{t+1} \leftarrow B^t - \eta \frac{1}{M} \sum_{m=1}^M \frac{1}{\sum_i L^{t'}_{i,m}} \sum_{i=1}^n \nabla_B f^{t'}_{i,m}(w^{t+1}_m, B^t)$, performs the QR decomposition $B^{t+1} R^{t+1} = \mathrm{QR}(\tilde B^{t+1})$, and updates $B^{t+1}$.
end for

The proof proceeds in three steps.

1. First, in Lemma A.5, we show that our estimated weight matrix $W^{t+1} \in \mathbb{R}^{M \times k}$ (our estimate at time $t+1$ of the true domain weight matrix $W^*$) satisfies
$$W^{t+1} = W^* (B^*)^\top B^t + F^t,$$
where $B^t$ is our estimate of the true (high-to-low dimensional) representation embedding $B^*$ at time $t$, and $F^t$ is an error term that can be shown to be bounded (sufficiently small) in terms of $\mathrm{dist}(B^t, B^*)$, with the scale of the error depending on the (random) number of samples $L_m$ seen at time $t$ for each domain $m \in [M]$. Bounding $\|F^t\|$ in our setting requires some care, since the samples for each domain are spread over many clients; we refer the reader to Section A.4.2 for the details.

2. Second, we show in Section A.4.3 that the update for $B^t$ satisfies the relationship (see equation 21)
$$\tilde B^{t+1} = B^t - \frac{\eta}{M} (Q^t)^\top W^{t+1} - H_Q, \qquad Q^t := W^{t+1}(B^t)^\top - W^*(B^*)^\top,$$
where $H_Q$ denotes an error term (absorbing the factor $\eta/M$) which can be shown to be bounded (sufficiently small) in terms of $\mathrm{dist}(B^t, B^*)$. Since $(B^*_\perp)^\top B^* = 0$, expanding $(Q^t)^\top = B^t (W^{t+1})^\top - B^*(W^*)^\top$ gives
$$(B^*_\perp)^\top \tilde B^{t+1} = (B^*_\perp)^\top B^t \Big(I - \frac{\eta}{M} (W^{t+1})^\top W^{t+1}\Big) - (B^*_\perp)^\top H_Q,$$
so that
$$\big\|(B^*_\perp)^\top \tilde B^{t+1}\big\|_2 \le \Big(1 - \frac{\eta}{M}\sigma^2_{\min}(W^{t+1})\Big)\, \mathrm{dist}(B^t, B^*) + \|H_Q\|_2.$$
By upper bounding $\|H_Q\|$ in terms of $\mathrm{dist}(B^t, B^*)$, providing an appropriate lower bound on $\sigma^2_{\min}(W^{t+1})$, and accounting for the QR renormalization, we can then show that by picking $\eta$ sufficiently small and under other suitable assumptions, the quantity $\mathrm{dist}(B^t, B^*)$ decays at a linear rate (see Equation 29). Again, a difference from the analysis in Collins et al. (2021) is our handling of the bound on $\|H_Q\|$, since the samples for each domain are spread over different clients.

3. An issue created by the spreading of a domain's samples across different clients is how to pick an appropriate sample size such that the $\|F^t\|$ and $\|H_Q\|$ terms can be suitably bounded. In Lemma A.10, we prescribe a suitable sample size for each client such that each domain receives sufficiently many samples (with high probability) for the purposes of our analysis.

We now proceed with a detailed analysis. We first present a theorem stating that our adapted FedDAR (Algorithm 2) enjoys linear convergence; the theorem is followed by several remarks highlighting key points of the convergence result.

Theorem A.4 (Algorithm 2 convergence). Define $E_0 := 1 - \mathrm{dist}^2(B^0, B^*)$, $\bar\sigma_{\max,*} := \sigma_{\max}\big(\frac{1}{\sqrt{M}} W^*\big)$, $\bar\sigma_{\min,*} := \sigma_{\min}\big(\frac{1}{\sqrt{M}} W^*\big)$, and $\kappa := \bar\sigma_{\max,*}/\bar\sigma_{\min,*}$. Suppose
$$L \ge \Omega\bigg(\max\bigg\{\frac{d k^2 \kappa^4}{n E_0^2},\ \frac{k^2 \kappa^4}{E_0^2 \min_{m \in [M]} \sum_{i=1}^n \pi_{i,m}}\bigg\}\bigg). \qquad (9)$$
Then, for any $T$ and any $\eta \le 1/(4\bar\sigma^2_{\max,*})$, with probability at least $1 - T e^{-80}$,
$$\mathrm{dist}(B^T, B^*) \le \big(1 - \eta E_0 \bar\sigma^2_{\min,*}/2\big)^{T/2}\, \mathrm{dist}(B^0, B^*).$$

Linear convergence speed: the convergence of $B^T$ to $B^*$ is linear, assuming that (1) $\sigma_{\min}\big(\frac{1}{\sqrt{M}} W^*\big) > 0$ and (2) $1 - \eta E_0 \bar\sigma^2_{\min,*} \in (0, 1)$.
Initialization of $B^0$: for our convergence result to be meaningful, we need $\mathrm{dist}(B^0, B^*)$ to be close to 0. We show (Theorem A.13) that our algorithm's choice of the initial $B^0$ ensures that $\mathrm{dist}(B^0, B^*)$ is close enough to 0 whilst preserving privacy. When the number of samples is uniform across the domains, this comes only at the cost of a logarithmic increase in sample complexity.

Sample complexity: the per-iteration sample complexity per client is $L$. In the requirement (9) for $L$, we need $L \ge \Omega(d k^2 \kappa^4 / n)$; this comes from the updates for $B^t \in \mathbb{R}^{d \times k}$. While we expect that $d$ could be large, a large number of clients $n$ helps to mitigate the increase in sample complexity arising from $d$. We also need $L \ge \Omega\big(k^2 \kappa^4 \big/ \sum_{i=1}^n \pi_{i,m}\big)$ for every domain $m \in [M]$; this requirement comes from the updates for $w^t_m$ for each of the $M$ domains.

A.4 PROOF OF THEOREM A.4

A.4.1 ANALYSIS FOR UPDATE OF W t+1

Since we are analyzing the update step for an arbitrary iteration $t$, unless necessary we drop all $t$ superscripts. Let $L_m = \sum_{i=1}^n L_{i,m}$ denote the number of samples from domain $m \in [M]$ across the $n$ clients. Then, we can express
$$\nabla_{w_m} \sum_{i=1}^n f_{i,m}(w_m, B) = \sum_{i=1}^n \sum_{j=1}^{L_{i,m}} \big(w_m^\top B^\top x^j_{i,m} - y^j_{i,m}\big) B^\top x^j_{i,m}.$$
Since $y^j_{i,m} = (w^*_m)^\top (B^*)^\top x^j_{i,m}$, it follows that under Algorithm 2,
$$\underbrace{\Big(\frac{1}{L_m}\sum_{i=1}^n \sum_{j=1}^{L_{i,m}} B^\top x^j_{i,m} (x^j_{i,m})^\top B\Big)}_{G_m}\, w^{t+1}_m = \frac{1}{L_m}\sum_{i=1}^n \sum_{j=1}^{L_{i,m}} B^\top x^j_{i,m} (x^j_{i,m})^\top B^* w^*_m. \qquad (11)$$
Re-expressing, assuming $G_m$ is invertible, we have
$$w^{t+1}_m = B^\top B^* w^*_m + \bigg(G_m^{-1}\Big(\frac{1}{L_m}\sum_{i=1}^n \sum_{j=1}^{L_{i,m}} B^\top x^j_{i,m} (x^j_{i,m})^\top B^* w^*_m\Big) - B^\top B^* w^*_m\bigg). \qquad (12)$$
Intuitively, assuming $L_m$ is large enough, $\frac{1}{L_m}\sum_{i=1}^n \sum_{j=1}^{L_{i,m}} x^j_{i,m}(x^j_{i,m})^\top \approx I_d$, and hence
$$G_m^{-1}\Big(\frac{1}{L_m}\sum_{i=1}^n \sum_{j=1}^{L_{i,m}} B^\top x^j_{i,m} (x^j_{i,m})^\top B^* w^*_m\Big) \approx B^\top B^* w^*_m.$$
This then implies that
$$W^{t+1} = W^*(B^*)^\top B + F,$$
where the $m$-th row of $F$ is
$$F_m^\top := \bigg(G_m^{-1}\Big(\frac{1}{L_m}\sum_{i=1}^n \sum_{j=1}^{L_{i,m}} B^\top x^j_{i,m} (x^j_{i,m})^\top B^* w^*_m\Big) - B^\top B^* w^*_m\bigg)^\top.$$
Note the similarity of equation 18 to (17) in (Collins et al., 2021). Following a similar analysis to (Collins et al., 2021), we should also be able to bound the Frobenius norm of $F$ in terms of $\mathrm{dist}(B, B^*)$. Below, we formalize the argument. First, we have the following lemma.

Lemma A.5 (Update for $W^{t+1}$). For each time $t$, let $L^t_m := \sum_{i=1}^n L^t_{i,m}$ denote the number of samples from domain $m \in [M]$ across the $n$ clients at time $t$. For convenience, we drop the time index unless absolutely necessary. We define the terms
$$X_m := \frac{1}{L_m}\sum_{i=1}^n \sum_{j=1}^{L_{i,m}} x^j_{i,m}(x^j_{i,m})^\top, \qquad G_m := \frac{1}{L_m}\sum_{i=1}^n \sum_{j=1}^{L_{i,m}} B^\top x^j_{i,m}(x^j_{i,m})^\top B.$$
Then, assuming that $G_m$ is invertible, the update for $W$ takes the form
$$W^{t+1} = W^*(B^*)^\top B + F,$$
where the $m$-th row of $F$ is
$$F_m^\top := \big(G_m^{-1} B^\top X_m (I - BB^\top) B^* w^*_m\big)^\top.$$
Proof.
We can express ∇ wm n i=1 f i,m (w m , B) as ∇ wm n i=1 f i,m (w m , B) = n i=1 Li,m j=1 (w ⊤ m B ⊤ x j i,m -y j i,m )B ⊤ x j i,m . Since y j i,m = (w * m ) ⊤ (B * ) ⊤ x j i,m , it follows that following Algorithm 2,   1 L m n i=1 Li,m j=1 B ⊤ x j i,m (x j i,m ) ⊤ B   Gm w t+1 m = 1 L m n i=1 Li,m j=1 B ⊤ x j i,m (x j i,m ) ⊤ B * w * m . (16) Reexpressing, assuming G m is invertible, we have w t+1 m = B ⊤ B * w * m +   G -1 m   1 L m n i=1 Li,m j=1 B ⊤ x j i,m (x j i,m ) ⊤ B * w * m   -B ⊤ B * w * m   . ( ) This then implies that W t+1 = W * (B * ) ⊤ B + F, where the m-th row of F is F ⊤ m :=   G -1 m   1 L m n i=1 Li,m j=1 B ⊤ x j i,m (x j i,m ) ⊤ B * w * m   -B ⊤ B * w * m   ⊤ = G -1 m B ⊤ X m B * w * m -G -1 m G m B ⊤ B * w * m ⊤ = G -1 m B ⊤ X m B * w * m -G -1 m B ⊤ X m BB ⊤ B * w * m ⊤ = G -1 m B ⊤ X m (I -BB ⊤ )B * w * m ⊤ . A.4.2 BOUNDING ∥F ∥ F We will proceed to bound the Frobenius norm of F . We begin by showing that G -1 m exists and (both lower and upper) bounding its spectral norm. Lemma A.6. Let L min := min m∈[M ] L m . Let δ k := 10C √ k √ log(M ) √ Lmin for some absolute constant C. Suppose that 0 ≤ δ k < 1. Then, with probability at least 1 -e -80k log(M ) , G -1 m exists for each m ∈ [M ], and ∥G -1 m ∥ 2 ≤ 1 1 -δ k ∀m ∈ [M ]. Proof. Note that G m := 1 L m n i=1 Li,m j=1 B ⊤ x j i,m (x j i,m ) ⊤ B . Let v j i,m := B ⊤ x j i,m . Since B ⊤ B = I, it follows that each v j i,m is i.i.d 1-subgaussian. Then, applying the same argument in Theorem 4.6.1 of Vershynin 2018, we have (cf. equation (4.22) in Vershynin 2018) σ min (G m ) ≥ 1 -C √ k √ L m + z √ L m δ k,m with probability at least 1-e -z 2 for z ≥ 0 and some absolute constant C, assuming that 0 ≤ δ k,m ≤ 1. Consider the choice z = 9 √ k log(M ). Then, δ k,m = C √ k √ L m + 9 √ k log(M ) √ L m ≤ 10C √ k log M √ L m ≤ 10C √ k √ log M √ L min . Suppose we choose L min ≥ 1 such that δ k,m < 1. 
Then, taking a union bound, with probability at least 1 -M e -z 2 = 1 -M exp(-81k} log(M ))=1 -exp(-80k log(M )), σ min (G m ) ≥ 1 -δ k,m ≥ 1 - 10C √ k √ log M √ L min > 0 ∀m ∈ [M ]. Therefore, with probability at least 1 -exp(-80k log(M )), G -1 m exists for every m ∈ [M ], and in addition, ∥G -1 m ∥ 2 ≤ 1 1 -δ k ∀m ∈ [M ]. We next bound the operator norm of term B ⊤ X m (I -BB ⊤ )B * . Lemma A.7. Let L min := min m∈[M ] L m . Let δ k := 10C √ k √ log M √ Lmin for some absolute constant C. Suppose L min is such that 0 ≤ δ k < 1. Then, with probability at least 1 -e -50k log M , ∥B ⊤ X m (I -BB ⊤ )B * ∥ 2 ≤ dist(B * , B)δ k . Proof. We will use an ϵ-net argument, similar to the proof of Theorem 4.6.1 in (Vershynin, 2018) . First, by Corollary 4.2.13 in (Vershynin, 2018) , there exists an 1/4-net N of the unit sphere S k-1 with cardinality N ≤ 9 k . Using Lemma 4.4.1 in (Vershynin, 2018) , we have that ∥B ⊤ X m (I -BB ⊤ )B * ∥ 2 ≤ 2 max z∈N B ⊤ X m (I -BB ⊤ )B * z, z . To prove our result, by applying a union bound over m ∈ [M ], it suffices to show that with the probability at least 1 -e -100k 2 log M , max z∈N B ⊤ X m (I -BB ⊤ )B * z, z ≤ δ km 2 ∀m ∈ [M ], where we recall that δ k,m = C √ k √ L m + 9 √ k log(M ) √ L m ≤ δ k . We will assume that min m L m := L min ≥ 1 is chosen large enough such that δ k,m ≤ 1. For a fixed z ∈ S k-1 , observe that B ⊤ X m (I -BB ⊤ )B * z, z = 1 L m n i=1 Li,m j=1 B ⊤ x j i,m (x j i,m ) ⊤ (I -BB ⊤ )B * z, z := 1 L m n i=1 Li,m j=1 (z ⊤ u j i,m )((v j i,m ) ⊤ z), where we defined u j i,m := B ⊤ x j i,m , and v j i,m = (B * ) ⊤ (I -BB ⊤ )x j i,m . Since each x j i,m is 1-subgaussian, ∥B∥ 2 = 1, and ∥(I -BB ⊤ )B * ∥ 2 = dist(B * , B), it follows that z ⊤ u j i,m is subgaussian with norm at most 1, and (v j i,m ) ⊤ z is subgaussian with norm at most dist(B * , B). Thus, the random variable α j i,m := (z ⊤ u j i,m )((v j i,m ) ⊤ z) (for a fixed unit z) is subexponential with sub-exponential norm at most dist(B * , B). 
Moreover, note that α j i,m is mean-zero, since E[u j i,m (v j i,m ) ⊤ ] = E[B ⊤ x j i,m (x j i,m ) ⊤ (I -BB ⊤ )B * ] = B ⊤ (I -BB ⊤ )B * = 0, as x j i,m is assumed to have identity covariance. Thus, the α j i,m 's are i.i.d mean-zero subexponential variables each with subexponential norm at most dist(B * , B). Hence, by Bernstein's inequality (cf. Corollary 2.8.3 in (Vershynin, 2018) ), P B ⊤ X m (I -BB ⊤ )B * z, z ≥ δ k,m dist(B * , B) 2 = P   1 L m n i=1 Li,m j=1 α j i,m ≥ δ k,m dist(B * , B) 2   ≤ 2 exp -c min( δ k,m dist(B * , B) dist(B * , B) , δ k,m dist(B * , B) dist(B * , B) 2 )L m = 2 exp(-cδ 2 k,m L m ) ≤ 2 exp(-cC 2 (k + 81k log(M ) log(M ))). Above we used the assumption that δ k,m ≤ 1 to simplify the minimum operator in the exponent. Taking a union bound over each z ∈ N , it follows that P ∥B ⊤ X m (I -BB ⊤ )B * ∥ 2 ≥ δ k,m dist(B * , B) ≤ P 2 max z∈N B ⊤ X m (I -BB ⊤ )B * z, z ≥ δ k,m dist(B * , B) ≤ 2 • 9 k exp(-cC 2 (k + 81k log(M ))) ≤ exp(-51k log M ), where the last inequality follows by picking C large enough (but still it is an absolute constant). (Here the choice of 51 in the exponent is somewhat arbitrary; any choice smaller than 81 should work). By applying a union bound over the domains m ∈ [M ], this then completes our proof. We are now finally ready to bound ∥F ∥ F . Lemma A.8. Let L min := min m∈[M ] L m . Let δ k := 10C √ k √ log(M ) √ Lmin for some absolute constant C. Suppose that 0 ≤ δ k < 1. Then, with probability at least 1 -2e -50k log(M ) , ∥F ∥ F ≤ δ k 1 -δ k dist(B * , B)∥W * ∥ F . Proof. By Lemma A.6 and Lemma A.7, we have that with probability at least 1 -2e -50k log M , G -1 m (B ⊤ X m (I -BB ⊤ )B * ) 2 ≤ G -1 m 2 B ⊤ X m (I -BB ⊤ )B * 2 ≤ 1 1 -δ k (δ k dist(B * , B)) . The proof then follows by recalling that the m-th row, F ⊤ m , takes the form F ⊤ m = G -1 m B ⊤ X m (I -BB ⊤ )B * w * m ⊤ .
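The algebraic identity at the heart of Lemma A.5 — that the closed-form head update satisfies $w^{t+1}_m = B^\top B^* w^*_m + G_m^{-1} B^\top X_m (I - BB^\top) B^* w^*_m$ — is exact (not just asymptotic), and can be checked numerically. The sketch below pools one domain's samples as the server does; the dimensions are illustrative.

```python
import numpy as np

rng = np.random.default_rng(1)
d, k, N = 8, 3, 200

B = np.linalg.qr(rng.standard_normal((d, k)))[0]       # current representation
B_star = np.linalg.qr(rng.standard_normal((d, k)))[0]  # ground-truth representation
w_star = rng.standard_normal(k)

X = rng.standard_normal((N, d))      # pooled samples from one domain, all clients
y = X @ B_star @ w_star              # noiseless labels
X_m = X.T @ X / N                    # empirical second moment
G_m = B.T @ X_m @ B

# Closed-form head update: solve G_m w = (1/N) sum_j (B^T x_j) y_j.
w_next = np.linalg.solve(G_m, B.T @ X.T @ y / N)

# Error term of Lemma A.5: F_m = G_m^{-1} B^T X_m (I - B B^T) B* w*.
F_m = np.linalg.solve(G_m, B.T @ X_m @ (np.eye(d) - B @ B.T) @ B_star @ w_star)
```

Splitting $B^\top X_m B^* w^*_m = G_m B^\top B^* w^*_m + B^\top X_m (I - BB^\top) B^* w^*_m$ shows why the two quantities agree exactly.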

A.4.3 ANALYSIS OF UPDATE FOR B t+1

Similarly to (Collins et al., 2021) , we define Q t = W t+1 (B t ) ⊤ -(W * )(B * ) ⊤ . Below, we drop the time index and use B, Q, W to denote B t , Q t , and W t+1 respectively. Based on algorithm 2, we have that Bt+1 = B - η M M m=1 1 L m n i=1 Li,m j=1 (w ⊤ m B ⊤ x j i,m -y j i,m )x j i,m w ⊤ m = B - η M M m=1 1 L m n i=1 Li,m j=1 A j i,m , W B ⊤ -A j i,m , W * (B * ) ⊤ (A j i,m ) ⊤ W, A j i,m := e m (x j i,m ) ⊤ = B - η M M m=1 1 L m n i=1 Li,m j=1 A j i,m , Q (A j i,m ) ⊤ W = B - η M M m=1 1 L m n i=1 Li,m j=1 x j i,m (x j i,m ) ⊤ q m w ⊤ m = B - η M Q ⊤ W -   η M M m=1 1 L m n i=1 Li,m j=1 x j i,m (x j i,m ) ⊤ q m w ⊤ m - η M Q ⊤ W   H Q . Above, we define q m ∈ R d to denote the m-th row of Q (viewed as a column vector). Note again that since 1 L m n i=1 Li,m j=1 x j i,m (x j i,m ) ⊤ ≈ I d , the term H Q in equation 21 can be appropriately bounded. Note the resemblance of equation 21 to (53) in (Collins et al., 2021) ; the crucial difference is that we will need to lower bound 1 m σ 2 min (W * ), instead of 1 n σ 2 min (W * ) as in (Collins et al., 2021) . Thus we should be able to carry out the rest of the analysis in a similar way to the outline in (Collins et al., 2021) and derive an analogous result to Theorem 1 in (Collins et al., 2021) . We first bound the error term H Q . Lemma A.9. Let H t Q := η M M m=1 1 L m n i=1 Li,m j=1 x j i,m (x j i,m ) ⊤ q m (w t+1 m ) ⊤ - η M (Q t ) ⊤ W t+1 . Let γ k := 20k √ d c √ nL for some absolute constant c. Suppose that 0 ≤ γ k < k. Then, for any t, with probability at least 1 -exp(-90d) -2e -50k log M , ∥H t Q ∥ 2 ≤ ηγ k dist(B * , B t ). Proof. As before, we may omit the time superscript t in cases where it is clear for notational convenience. The proof is based on the argument in Lemma 5 in (Collins et al., 2021) . Again, the main tool is an ϵ-net argument. We first bound ∥q m ∥ 2 and ∥w m ∥ 2 . 
Bounding q m : With probability at least 1 -2e -50k log M , for each m ∈ [M ], we have that ∥q m ∥ 2 = B t ((B t ) ⊤ B * w * m + F m ) -B * w * m 2 ≤ (B t (B t ) ⊤ -I)B * w * m 2 + B t F m 2 ≤ dist(B t , B * )∥w * m ∥ 2 + ∥F m ∥ 2 ≤ √ kdist(B t , B * ) + δ k 1 -δ k dist(B t , B * )∥w * m ∥ 2 ≤ 2 √ kdist(B t , B * ). Above, we utilized the assumption that ∥w * m ∥ 2 ≤ √ k, the orthonormality of B t (which was derived as the orthogonal matrix from a Gram-Schmidt procedure), the assumption that 0 < δ k ≤ 1/2, as well Lemma A.8 which bounds ∥F m ∥ 2 (for all m) with probability at least 1 -2e -50k log M . Bounding w m : Note that for notational convenience, we let w m denote w t+1 m . For each t and every m ∈ [M ], we have that ∥w t+1 m ∥ 2 = (B t ) ⊤ B * w * m + F m 2 ≤ ∥w * m ∥ 2 + ∥F m ∥ 2 ≤ ∥w * m ∥ 2 + δ k 1 -δ k dist(B t , B * )∥w * m ∥ 2 ≤ 3 √ k, with probability at least 1 -2e -50k log M , where again we used Lemma A.8 to handle ∥F m ∥ 2 , the assumption that δ k < 1/2, and the fact that dist(B t , B * ) ≤ 2. For the rest of the proof, we condition on the event E := ∥q m ∥ 2 ≤ 2 √ kdist(B t , B * ) and ∥w m ∥ 2 ≤ 3 √ k ∀m ∈ [M ] , which holds with probability at least 1 -2e -50k log M . ϵ-net argument to bound H Q : Again, note that there exists an 1/4-net N k of the unit sphere S k-1 and an 1/4-net N d of the unit sphere S d-1 with cardinalities less than or equal to 9 k and 9 d respectively. Note now that by Equation 4.13 in (Vershynin, 2018) , we have ∥H Q ∥ 2 = η M M m=1 1 L m n i=1 Li,m j=1 x j i,m (x j i,m ) ⊤ q m w ⊤ m - η M Q ⊤ W 2 ≤ 2η max u∈N d ,v∈N k 1 M M m=1 1 L m n i=1 Li,m j=1 x j i,m (x j i,m ) ⊤ q m w ⊤ m -q m w ⊤ m u, v = 2η max u∈N d ,v∈N k 1 M M m=1 1 L m n i=1 Li,m j=1 u ⊤ x j i,m (x j i,m ) ⊤ q m w ⊤ m v -q m w ⊤ m u, v Fix now a u ∈ N d and v ∈ N k . 
Note now that u ⊤ x j i,m (x j i,m ) ⊤ q m w ⊤ m v is subexponential with norm less than or equal to ∥q m ∥ 2 ∥w m ∥ 2 ≤ 6kdist(B t , B * ), since it is the product of two subgaussian variables u ⊤ x j i,m and (x j i,m ) ⊤ q m w ⊤ m v with subgaussian norms bounded by 1 and ∥q m ∥ 2 ∥w m ∥ 2 respectively. Note also that E u ⊤ x j i,m (x j i,m ) ⊤ q m w ⊤ m v = E q m w ⊤ m u, v . Thus, by Bernstein's inequality, carrying on from equation 22, we have that P   1 M M m=1 1 L m n i=1 Li,m j=1 u ⊤ x j i,m (x j i,m ) ⊤ q m w ⊤ m v -q m w ⊤ m u, v ≥ ρ   ≤ exp -cnL min ρ 6kdist(B t , B * ) , ρ kdist(B t , B * ) 2 ≤ exp -cnL ρ kdist(B t , B * ) 2 , where we will choose ρ such that ρ kdist(B t ,B * ) ≤ 1 to simplify the exponent in the way we did, and c is an absolute constant that may change from line to line. Above, we also used the fact that M m=1 L m = nL (recall that L is the total number of samples per agent and there are n agents). Consider the choice ρ = 10 k √ ddist(B t , B * ) c √ nL . Then, P   1 M M m=1 1 L m n i=1 Li,m j=1 u ⊤ x j i,m (x j i,m ) ⊤ q m w ⊤ m v -q m w ⊤ m u, v ≥ ρ   ≤ exp -cnL ρ kdist(B t , B * ) 2 ≤ exp(-100d). Taking a union bound over all u ∈ N d and v ∈ N k , it follows then that P ∥H Q ∥ 2 η ≥ 2ρ ≤ 9 d+k exp(-100d) ≤ exp(-90d), where above we used the fact that d ≥ k. . Then, with probability at least 1 -exp(-90), L min ≥ α. Proof. Note that L m = n i=1 L j=1 1(domain(x j i ) = m), which is a sum of nL independent random variables bounded between 0 and 1. Moreover, E[L m ] = n i=1 π i,m L, where π i,m is the probability that a datapoint comes from domain m for client i. Note finally that E[ 1(domain(x j i ) = m) 2 ] = π i,m . Hence, by Bernstein's inequality, it follows that for any s > 0, P L i,m ≤ n i=1 π i,m L -s ≤ exp - s 2 /2 n i=1 L j=1 π i,m + s/3 . 
Since we wish to perform union bound over the M domains, we seek to choose s and L such that exp - s 2 /2 n i=1 L j=1 π i,m + s/3 ≤ exp (-91 log M ) , so that M exp - s 2 /2 n i=1 L j=1 π i,m + s/3 ≤ M exp (-91 log M ) ≤ exp (-90 log M ) . To this end, note that we need s 2 /2 n i=1 L j=1 π i,m + s/3 ≥ 91 log M ⇐⇒ s 2 ≥ 2 • 91 log M   n i=1 L j=1 π i,m + s/3   ⇐⇒ s ≥ 182 log M n i=1 π i,m L + 182 log M 3 2 + 182 log M 3 Suppose we pick L such that n i=1 π i,m L ≥ 182 log M, so that 182 log M n i=1 π i,m L + 182 log M 3 2 + 182 log M 3 ≤ 2 n i=1 π i,m L. Then, by picking s = 2 n i=1 π i,m L, it follows that exp - s 2 /2 n i=1 L j=1 π i,m + s/3 ≤ exp (-91 log M ) , such that for each m ∈ [M ], P   L i,m ≤ n i=1 π i,m L -2 n i=1 π i,m L   ≤ exp(-91 log M ). By choosing L such that n i=1 π i,m L ≥ 4, it follows that P L i,m ≤ n i=1 π i,m L 2 ≤ exp(-91 log M ). The result now follows by choosing L such that it also satisfies n i=1 π i,m L 2 ≥ α for each m.  L ≥ 400dk 2 nc 1 min 1 2 , 8E 0 /(25 • 5κ 2 ) 2 , where c > 0 is absolute constant. Suppose also that L ≥ max    182 log M n i=1 π i,m , 16 n i=1 π i,m , 2 (100Ck log M ) 1 (min{ 1 2 ,8E0/(25•5κ 2 )}) 2 n i=1 π i,m    , which by Lemma A.10, ensures that with probability at least 1 -e -90 , L t min ≥ (100Ck log M ) 1 min 1 2 , 8E 0 /(25 • 5κ 2 ) 2 , where L t min = min m∈[M ] L t m denotes the minimum number of samples from any domain at iteration t, and C > 0 is an absolute constant. Then, for any η ≤ 1/(4σ 2 max, * ), we have dist(B t+1 , B * ) ≤ (1 -ηE 0 σmin, * /2) 1/2 dist(B t , B * ), with probability at least 1 -e -80 . Proof. We begin with the observation that W t+1 = W * (B * ) ⊤ B t + F t Bt+1 = B t - η M (Q t ) ⊤ W t+1 -H t Q , where Q t = W t+1 (B t ) ⊤ -(W * )(B * ) ⊤ , H t Q := η M M m=1 1 L m n i=1 Li,m j=1 x j i,m (x j i,m ) ⊤ q m (w t+1 m ) ⊤ - η M (Q t ) ⊤ W t+1 . Above Bt+1 denotes the estimate of B before we perform the QR decomposition. 
We note that the updates for W and B are exactly analogous to the updates for W and B as seen in the proof of Lemma 6 in (Collins et al., 2021) . The only two differences are 1. The definitions of F in our paper and (Collins et al., 2021) are slightly different. However, in both cases, ∥F ∥ F ≤ δ k 1 -δ k dist(B * , B)∥W * ∥ F for some term δ k ≤ 1/2 with high probabilities. In our case, this event holds with probability at least 1 -2e -50k log M , whilst in (Collins et al., 2021) , the event holds with probability at least 1 -exp(-110k 2 log n). 2. The update for Bt+1 in (Collins et al., 2021) takes the form Bt+1 = B t - η rn (Q t ) ⊤ W t+1 - η rn 1 m A † A(Q t ) -Q t ⊤ W t+1 , where 0 ≤ r ≤ 1 is a ratio term used in (Collins et al., 2021) , and m above represents the number of samples used by each learner in (Collins et al., 2021) (which is different from our use of m as an index over the domains). However, we note that with high probabilities, H t Q 2 ≤ ηγ k dist(B t , B * ), η rn 1 m A † A(Q t ) -Q t ⊤ W t+1 2 ≤ ηγ k dist(B t , B * ), where the definition of γ k in both papers differ but both satisfy the assumption that γ k ≤ k. Due to these similarities in the updates for W t+1 and B t+1 with the update in (Collins et al., 2021) , the proof of this lemma follows naturally from the proof of Lemma 6 in (Collins et al., 2021) , by plugging in η M (Q t ) ⊤ W t+1 in the update for Bt+1 in place of η rn (Q t ) ⊤ W t+1 as in (Collins et al., 2021) . In particular, following the same analysis as in (Collins et al., 2021) , we see that on the events in Lemma A.8 and Lemma A.9, following the equation immediately after Equation ( 84) in (Collins et al., 2021) , we have dist(B t , B * ) ≤ 1 1 -4η δk (1-δk ) 2 σmax, * 2 1 -ησ 2 min, * E 0 + 2η δk (1 -δk ) 2 σ2 max, * dist(B t , B * ), where in our case δk = δ k + γ k . 
Then, by choosing δk < 16E 0 /(25 • 5κ 2 ), it follows that δk < 1/5, and so 1 -ησ 2 min, * E 0 + 2η δk (1 -δk ) 2 σ2 max, * ≤ 1 -4η δk (1 -δ2 k ) σ2 max, * ≤ 1 -ηE 0 σ2 min, * /2, as in equation ( 85) in (Collins et al., 2021) , such that dist(B t+1 , B * ) ≤ (1 -ηE 0 σ2 min, * /2) 1/2 dist(B t , B * ). It remains for us to understand what the constraint on δk spelt out in equation 25, and the constraints on δ k and γ k (in Lemmas A.8 and A.9 respectively) mean in our choice of the sample size L for each agent, and the domain size L m at each iteration. Observe that we need δ k = 10C √ k √ log M √ L min ≤ 1 2 , γ k = 20k √ d c √ nL ≤ 1 2 , ( ) δk = δ k + γ k = 10C √ k √ log M √ L min + 20k √ d c √ nL ≤ 16E 0 /(25 • 5κ 2 ), where c, C > 0 are absolute constants. By choosing L min ≥ (100Ck log M ) 1 min 1 2 , 8E 0 /(25 • 5κ 2 ) 2 L ≥ 400dk 2 nc 1 min 1 2 , 8E 0 /(25 • 5κ 2 ) 2 , we ensure that the requirements in equation 26, equation 27 and equation 28 are all satisfied. The final result then follows by applying Lemma A.10. This then yields the following convergence result, which is a more complete statement of A.4.  L ≥ 400dk 2 nc 1 min 1 2 , 8E 0 /(25 • 5κ 2 ) , where c > 0 is absolute constant. Suppose also that L ≥ max 182 log M n i=1 π i,m , 16 n i=1 π i,m , 2 (100Ck log M ) 1 min{1/2,8E0/(25•5κ 2 )} n i=1 π i,m . Then, for any η ≤ 1/(4σ 2 max, * ), we have dist(B t+1 , B * ) ≤ (1 -ηE 0 σmin, * /2) 1/2 dist(B t , B * ), with probability at least 1 -e -80 . Then for any T and any η ≤ 1/(4σ 2 max, * ), we have dist(B t , B * ) ≤ (1 -ηE 0 σ2 min, * /2) T /2 dist(B 0 , B * ), with probability at least 1 -T e -80 . By assuming that σ 2 min, * > 0, the bound in Theorem 1 decays exponentially. We note that the total number of samples required per client scales with L log(1/ϵ). In addition, in order for the result to be meaningful, we implicitly assume that E 0 is close to 1 such that 0 < 1 -ηE 0 σ2 min < 1. 
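Lemma A.10's conversion of a deterministic per-client sample size $L$ into a high-probability lower bound on the per-domain count $L_{\min}$ can be sanity-checked by simulation. In the sketch below, the number of clients, the per-client sample size, and the uniform mixture weights $\pi_{i,m}$ are all illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(2)
n, L, M, trials = 10, 200, 4, 1000

# Illustrative mixture weights: each client draws its domains uniformly.
pi = np.full((n, M), 1.0 / M)
expected = pi.sum(axis=0) * L        # E[L_m] = sum_i pi_{i,m} * L  (= 500 here)

violations = 0
for _ in range(trials):
    counts = np.zeros(M, dtype=int)
    for i in range(n):
        draws = rng.choice(M, size=L, p=pi[i])   # domains of client i's L samples
        counts += np.bincount(draws, minlength=M)
    # Event controlled by Lemma A.10: some L_m falling below half its expectation.
    if counts.min() < expected.min() / 2:
        violations += 1
```

With these sizes, a domain count falling below half its expectation is a many-standard-deviation event, so no violations are expected across the trials, matching the lemma's high-probability guarantee.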
For the convergence result to be meaningful, $E_0$ must be close to 1. To that end, we note it is possible to choose $B^0$ such that $\mathrm{dist}(B^0, B^*)$ is close enough to 0, with only a logarithmic increase in sample complexity when the number of samples is uniform across the domains. The argument follows the proof of Theorem 3 in (Tripuraneni et al., 2021).

Theorem A.13. Suppose Assumptions A.1, A.2, A.3 all hold. Suppose also that $x^{0,j}_i \sim \mathcal{N}(0, I_d)$ independently for all $i \in [n]$. Suppose each client $i$ sends the server $Z_i := \sum_{j=1}^{L^0} (y^{0,j}_i)^2 x^j_i (x^j_i)^\top$, as well as the integer value of $L_i$, such that the server can compute $Z := \frac{1}{nL^0}\sum_{i=1}^n Z_i$, where $m(i,j)$ denotes the domain of the $j$-th sample from the $i$-th client. Let $\tilde\sigma_{\min,*} := \sigma_{\min}(\tilde\Lambda)$ and $\tilde\sigma_{\max,*} := \sigma_{\max}(\tilde\Lambda)$. Suppose that $L^0 \ge c\,\mathrm{polylog}(d, nL^0)\, \tilde\sigma_{\max,*}\, d k^2 / (n \tilde\sigma^2_{\min,*})$. Then, with probability at least $1 - (nL^0)^{-100}$, we have that
$$\mathrm{dist}(B^0, B^*)^2 \le \tilde O\bigg(\frac{\tilde\sigma_{\max,*}\, k^2 d}{\tilde\sigma^2_{\min,*}\, n L^0}\bigg).$$
In particular, when the number of samples is uniform across the domains, some algebra shows that $\mathrm{dist}(B^0, B^*)^2 \le \tilde O\big(\kappa^4 k^2 d/(nL^0)\big)$.

Proof. We omit the proof, since it is a slight variant of Theorem 3 in (Tripuraneni et al., 2021).

B.3.1 SYNTHETIC DATA

For the synthetic data experiments, we adapt the code from (Collins et al., 2021) and follow a similar protocol. The ground-truth matrices $W^* \in \mathbb{R}^{M \times k}$ and $B^* \in \mathbb{R}^{d \times k}$ are generated in the same way as in (Collins et al., 2021), by sampling each element from an i.i.d. standard normal distribution and taking the QR factorization. The same $L$ samples are used for each client during the whole training process. Test samples are generated in the same way as the training samples but without noise. For all methods, models are initialized with random Gaussian samples. We set $\alpha = 0.4$ for the experiments in Figure 5.2.
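The method-of-moments initialization analyzed in Theorem A.13 — aggregate $Z = \frac{1}{nL^0}\sum_j y_j^2 x_j x_j^\top$ and take the top-$k$ eigenspace — can be sketched as follows (NumPy). The dimensions and the deliberately well-conditioned heads $W^* = \sqrt{k}\, I_k$ are illustrative choices made so the eigengap is easy to reason about.

```python
import numpy as np

rng = np.random.default_rng(3)
d, k, n_samples = 20, 3, 50000
M = k

B_star = np.linalg.qr(rng.standard_normal((d, k)))[0]
W_star = np.sqrt(k) * np.eye(M)      # illustrative, well-conditioned heads

X = rng.standard_normal((n_samples, d))
doms = rng.integers(0, M, size=n_samples)
y = np.einsum('ij,ij->i', X @ B_star, W_star[doms])   # y_j = w_{m(j)}^T B*^T x_j

# Server aggregate Z = (1/(n L^0)) sum_j y_j^2 x_j x_j^T; for Gaussian x,
# E[Z] = c I + 2 B* (W*^T W*/M) B*^T, whose top-k eigenspace is range(B*).
Z = (X.T * y**2) @ X / n_samples
eigvecs = np.linalg.eigh(Z)[1]
B0 = eigvecs[:, -k:]                 # top-k eigenspace (eigh sorts ascending)

dist0 = np.linalg.norm((np.eye(d) - B0 @ B0.T) @ B_star, 2)
```

With this many samples, the spectral estimate lands well inside the basin where the local convergence analysis applies (a random $k$-dimensional subspace would instead have distance close to 1).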

B.3.2 REAL DATA WITH CONTROLLED DISTRIBUTION

Implementation details. We use an ImageNet (Deng et al., 2009) pre-trained ResNet-34 (He et al., 2016) for all experiments on this dataset. All methods are trained for T = 100 communication rounds, with 20 rounds of FedAvg as warmup. For FedDAR-WA and FedDAR-SA, 5 epochs of local updates are executed for both the heads and the representation at each round; for the baselines, 5 epochs of local updates are executed at each round for a fair comparison. We use the Adam optimizer with a learning rate of 1 × 10⁻⁴ for the first 60 rounds and 1 × 10⁻⁵ for the last 40 rounds. The images are resized to 224 × 224, with only random horizontal flips for augmentation. The learning rate and the number of local epochs are tuned by grid search with a fixed batch size of 64. We tuned the projection dimension k for FedDAR-SA among {4, 8, 16, 32, 64} with α = 1.0 and used k = 8 for all other α. Our evaluation metric is the classification accuracy on the whole validation set of FairFace for each race group. We do not hold out an extra local validation set for each client, since we assume the data distribution within each domain is consistent across clients. The reported numbers are averaged over the final 10 communication rounds, following the standard practice in (Collins et al., 2021), and over three independent runs with different random seeds.

B.3.3 REAL DATA WITH REAL-WORLD DATA DISTRIBUTION

Dataset details. The detailed statistics of the partial EXAM dataset are summarized in Table 6. The "Other" category includes American Indian or Alaska Native, Native Hawaiian or other Pacific Islander, and patients with more than one race or unknown race. "≥HFO %" denotes the percentage of cases with positive labels (receiving oxygen therapy at or above high-flow oxygen within 72 hours).



$x \sim p(x)$ and $y = x^\top B^* w^*_m$. For each client, the local data $D_i$ is a mixture of data from different domains with mixing coefficients, i.e., $D_i = \sum_m \pi_{i,m} D_m$. FedAvg: learns a single model $B$ and $w$ across all clients via the following objective, $\min_{B,w}$

Figure 5.2 shows the results of our experiments, where we set n = 100 clients, M = 5 domains, and feature dimension k = 2, and vary the number of training samples per client from 5 to 20. FedDAR-SA achieves errors four orders of magnitude smaller than all the baselines: (1) Local-Only, where each client trains a model using only its own data; (2) FedAvg, which learns a single shared model; (3) FedRep, which learns a shared representation and client-specific heads; and (4) Separate FedAvg, which trains a separate model for each domain using FedAvg. The results demonstrate that our method overcomes the heterogeneity of domain distributions across clients. FedDAR-WA fails to converge under this setting, confirming the effectiveness of the proposed second-order aggregation.

Figure 1: Performance as a function of the number of training samples per client; the error bars show the standard error over three independent runs.

Figure 2: Age classification accuracy as a function of representation dimension k.

(Descent lemma). Define $E_0 := 1 - \mathrm{dist}^2(B^0, B^*)$, $\bar\sigma_{\max,*} := \sigma_{\max}\big(\frac{1}{\sqrt{M}}W^*\big)$, and $\bar\sigma_{\min,*} := \sigma_{\min}\big(\frac{1}{\sqrt{M}}W^*\big)$. Let $\kappa := \bar\sigma_{\max,*}/\bar\sigma_{\min,*}$. Consider any iteration $t$. Suppose that

(Convergence result for Algorithm 2). Define $E_0 := 1 - \mathrm{dist}^2(B^0, B^*)$, $\bar\sigma_{\max,*} := \sigma_{\max}\big(\frac{1}{\sqrt{M}}W^*\big)$, and $\bar\sigma_{\min,*} := \sigma_{\min}\big(\frac{1}{\sqrt{M}}W^*\big)$. Let $\kappa := \bar\sigma_{\max,*}/\bar\sigma_{\min,*}$. Suppose that

$_i$. Then, the server computes $UDU^\top \leftarrow$ rank-$k$ SVD$(Z)$, and sets $B^0 := U$. Let

$\mathrm{dist}(B^0, B^*)^2 \le \tilde O\big(\kappa^4 k^2 d / (nL^0)\big)$, where we recall that $\kappa := \bar\sigma_{\max,*}/\bar\sigma_{\min,*}$, and $\bar\sigma_{\max,*} := \sigma_{\max}\big(\frac{1}{\sqrt{M}}$

$\mathrm{dist}(B^0, B^*)^2 \le \tilde O\big(\kappa^2 k^2 d / (\bar\sigma^2_{\min,*}\, nL^0)\big)$. However, since $k/4M \le \|W^*\|^2$

.1 ALGORITHM OVERVIEW

Our model consists of a shared encoder $\phi(\cdot; \theta)$ and $M$ domain-specific heads $h_m(\cdot; w_m)$, which are parameterized by neural networks with weights $\theta$ and $w_m$, $\forall m \in [M]$. According to our problem formulation in Equation 1, our algorithm aims to solve the following optimization,
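As a minimal concrete instantiation of this decoupled parameterization (purely illustrative — the experiments use a ResNet encoder, not a linear one), the shared encoder and per-domain heads can be written as:

```python
import numpy as np

class FedDARModel:
    """Shared encoder phi(x; theta) = B^T x with M domain-specific heads
    h_m(z; w_m) = w_m^T z. A linear sketch of the architecture, for illustration."""

    def __init__(self, d, k, M, seed=0):
        rng = np.random.default_rng(seed)
        self.B = np.linalg.qr(rng.standard_normal((d, k)))[0]  # shared representation
        self.W = rng.standard_normal((M, k))                   # one head per domain

    def predict(self, x, m):
        z = self.B.T @ x          # shared representation phi(x)
        return self.W[m] @ z      # prediction from the domain-m head

model = FedDARModel(d=16, k=4, M=3)
x = np.ones(16)
```

Only `self.W[m]` is personalized per domain; `self.B` is shared across all clients and domains, which is what the decoupled training procedure exploits.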

Min, max, and average test accuracy of age classification across 7 domains (race groups) on FairFace, with the number of clients n = 5 and the number of samples at each client L_i = 500.

AUC results on the EXAM dataset with the domain being race group. Numbers are means and standard deviations of the metrics over 5-fold cross-validation.

Ablation results on the contribution of each component of FedDAR.

A.4.4 COMBINING EARLIER ARGUMENTS: CONVERGENCE OF FEDDAR

As seen in Lemma A.8, we require L_min := min_{m∈[M]} L_m to be lower bounded. However, since L_m is a random variable, we cannot lower bound it directly. Below, we provide a result that converts a lower bound on each client's sample size L (a deterministic quantity we can control) into a high-probability lower bound on L_min.

Lemma A.10. Let L_min := min_{m∈[M]} L_m. For any α > 0, suppose that for each m ∈ [M],
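The gap that the lemma bridges can be seen in a quick simulation (illustrative parameters; not part of the proof): although each L_m is random, the aggregate per-domain counts, and hence L_min, concentrate around nL/M when every client contributes L samples.

```python
import numpy as np

rng = np.random.default_rng(2)
n, M, L = 100, 5, 20  # clients, domains, samples per client (illustrative)

trials = 500
mins = np.empty(trials)
for t in range(trials):
    pis = rng.dirichlet(np.ones(M), size=n)  # per-client domain mixtures
    # L_m: total samples of domain m pooled across all clients.
    counts = np.stack([rng.multinomial(L, pi) for pi in pis]).sum(axis=0)
    mins[t] = counts.min()                   # L_min = min_m L_m

# Empirically L_min stays a constant fraction of the mean n*L/M.
print(mins.mean() / (n * L / M))
```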

Data summary of the partial EXAM dataset used in our study. We apply 5-fold cross-validation. The input to the model is one chest X-ray image resized to 224×224, paired with 22-dimensional electronic health record (EHR) data; the representation dimension is 278 if it is not projected. All models are trained for T = 20 communication rounds with the Adam optimizer and a learning rate of 1 × 10⁻⁴. In each round we run 1 local epoch for all methods. For all methods, the models are initialized with the same pretrained model as in (Dayan et al., 2021) without any warmup. For FedDAR-SA and FedDAR-WA, we execute 5 epochs of head updates in each round, and set the representation dimension to k = 16 for FedDAR-SA. Hyperparameters, including the learning rate, the number of epochs for head updates, and the representation dimension, are tuned through grid search with a fixed batch size of 36 for FedRep, FedDAR and FedPer. For LG-FedAvg, we treat the last fully-connected layer as the global parameters and all other layers as the local representation. For FedMinMax, multiple local iterations are executed in each round instead of one step of GD for a fair comparison.

ACKNOWLEDGMENTS

This work has been supported by NIH 1R01HL159183.

B.1 EXPERIMENTS ON FAIRFACE DATASET FOR GENDER CLASSIFICATION

We also conduct experiments for gender classification on FairFace with the same settings. The best representation dimension for this task is k = 2, probably due to the smaller diversity across the domains.

We perform additional experiments on a digits dataset with five data domains exhibiting feature shift (Li et al., 2021b). Details are described in the following paragraphs. From Table 5, we can see that FedDAR-WA outperforms FedAvg consistently, except when the domain distributions are extremely heterogeneous (α = 0.1). In this case, each client tends to have data from only one domain, and it is difficult for the proposed method to learn a good domain-specific head for the domain with the most distinct data (more obvious feature shift). For other levels of heterogeneity, although the min and max domain accuracies are similar between FedAvg and FedDAR-WA, the average accuracies are improved as a result of the domain-wise personalized model. On the other hand, without alternating updates of the head and representation, FedAvg + Multi-head overfits quickly. We do not include results for FedDAR-SA here because a representation dimension k ≥ 64 causes numerical instability during head aggregation and failure to converge, while k ≤ 32 leads to lower accuracy.

Datasets. We use the same digits dataset containing five data domains as (Li et al., 2021b). Specifically, we use SVHN (Netzer et al., 2011), USPS (Hull, 1994), SynthDigits (Ganin & Lempitsky, 2015), MNIST-M (Ganin & Lempitsky, 2015) and MNIST (LeCun et al., 1998) as the five data domains. Similar to the experiments on the FairFace dataset, the training data is divided among n clients without duplication. Each client has a domain distribution π_i ∼ Dir(αp) sampled from a Dirichlet distribution.

Implementation Details. We adapt the codebase from (Li et al., 2021b). A 6-layer CNN with 3 convolutional layers and 3 fully-connected layers is used, with the last layer as the domain-specific head. We use the SGD optimizer with learning rate 10⁻² and the cross-entropy loss. The batch size is set to 32 and the total number of communication rounds to 100. For each method, we first train the model for 10 rounds with 1 local epoch using FedAvg as warmup. The accuracy shown is the average over the last ten communication rounds. We repeat each setting three times with different random seeds and report the averages.

For FedProx, we tuned µ among {0.05, 0.1, 0.25, 0.5} and used µ = 0.1. For the fine-tuning methods, we only fine-tune the globally trained model locally with the Adam optimizer and a learning rate of 5 × 10⁻⁵ for 1 epoch, since more epochs of fine-tuning lead to worse results. The models are evaluated by aggregating predictions on the local validation sets and then calculating the area under the curve (AUC) for each domain. The average AUCs on the local validation sets of the clients are also reported. The AUC shown is first averaged over the last five communication rounds, and then averaged over five runs of 5-fold cross-validation.
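For reference, the Dirichlet split described above can be sketched as follows. This is a hypothetical re-implementation under our reading of π_i ∼ Dir(αp) with p uniform over domains; the exact splitting code of (Li et al., 2021b) may differ.

```python
import numpy as np

rng = np.random.default_rng(3)

def dirichlet_partition(domains, n_clients, alpha, rng):
    """Partition sample indices over clients so each client's domain
    distribution roughly follows pi_i ~ Dir(alpha * p), p uniform."""
    M = int(domains.max()) + 1
    p = np.ones(M) / M
    pi = rng.dirichlet(alpha * p, size=n_clients)      # pi_i ~ Dir(alpha * p)
    clients = [[] for _ in range(n_clients)]
    for m in range(M):
        idx = np.flatnonzero(domains == m)
        rng.shuffle(idx)
        weights = pi[:, m] / pi[:, m].sum()            # share of domain m per client
        cuts = (np.cumsum(weights)[:-1] * len(idx)).astype(int)
        for c, s in enumerate(np.split(idx, cuts)):    # no sample duplication
            clients[c].extend(s.tolist())
    return clients, pi

# Toy usage: 1000 samples over 5 domains split across 10 clients.
domains = rng.integers(5, size=1000)
clients, pi = dirichlet_partition(domains, 10, 1.0, rng)
```

Smaller α concentrates each π_i on fewer domains, reproducing the extreme-heterogeneity regime (α = 0.1) discussed above.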

C DISCUSSION ON COMMUNICATION AND PRIVACY

Communication. For FedDAR-WA, the only communication overhead comes from the extra parameters of the multiple heads for different domains, which only slightly increases the communication cost. For FedDAR-SA, we need to send a Hessian with k² × N² parameters from each client to the server in each round. This can be costly when both the representation dimension k and the output dimension N are large. However, compared with sending the millions of parameters of a neural network, the extra communication cost is acceptable.

Privacy. For FedDAR-WA, no extra parameters are shared compared to FedAvg, so no additional privacy risk is introduced. Privacy techniques such as homomorphic encryption (Cheon et al., 2017) or differential privacy (McMahan et al., 2017b; Kairouz et al., 2021) that apply to FedAvg also work for FedDAR-WA. In fact, the multi-head design for different domains makes it harder to perform gradient-based attacks (Zhu et al., 2019) against our method, because the attacker first needs to figure out which domain a sample comes from. For FedDAR-SA, the only extra parameters shared are the Hessian matrices, which are aggregated over all the local data. Recovering information about a specific sample from the Hessian is extremely difficult: in the worst case, what the attacker can recover from the Hessian are the label and the last-layer features, which hardly ease the difficulty of recovering the original input.
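A back-of-envelope calculation makes the overhead concrete (the backbone size is an assumed figure for illustration, not a measurement of our model):

```python
# Per-round communication overhead in float32, with illustrative sizes.
k, N, M = 16, 2, 7             # representation dim, output dim, number of domains
model_params = 25_000_000      # assumed ~25M-parameter backbone (hypothetical)

head_overhead = M * k * N      # FedDAR-WA: parameters of the M domain heads
hessian_overhead = (k * N) ** 2  # FedDAR-SA: one (kN x kN) Hessian per client

for name, extra in [("WA heads", head_overhead), ("SA Hessian", hessian_overhead)]:
    print(f"{name}: {4 * extra / 1e6:.4f} MB vs {4 * model_params / 1e6:.0f} MB model")
```

Even the quadratic Hessian term stays negligible at these dimensions, though it grows quickly with k and N, matching the caveat above.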

