A HIERARCHICAL BAYESIAN APPROACH TO FEDERATED LEARNING

Abstract

We propose a novel hierarchical Bayesian approach to Federated Learning (FL), in which the generative process of clients' local data is described via hierarchical Bayesian modeling: the local models of clients are random variables governed by a higher-level global variate. Interestingly, variational inference in our Bayesian model leads to an optimisation problem whose block-coordinate descent solution becomes a distributed algorithm that is separable over clients and allows them not to reveal their own private data at all, and is thus fully compatible with FL. We also highlight that our block-coordinate algorithm has particular forms that subsume the well-known FL algorithms Fed-Avg and Fed-Prox as special cases. That is, we not only justify the previous Fed-Avg and Fed-Prox algorithms, whose learning protocols look intuitive but are theoretically less underpinned, but also generalise them further via principled Bayesian approaches. Beyond introducing novel modeling and derivations, we also offer a convergence analysis showing that our block-coordinate FL algorithm converges to a (local) optimum of the objective at the rate of $O(1/\sqrt{t})$, the same rate as regular (centralised) SGD, as well as a generalisation error analysis in which we prove that the test error of our model on unseen data is guaranteed to vanish as the training data size increases, so that the model is asymptotically optimal.

1. INTRODUCTION

Federated Learning (FL) aims to enable a set of clients to collaboratively train a model in a privacy-preserving manner, without sharing data with each other or a central server. Compared to conventional centralised optimisation problems, FL comes with a host of statistical and systems challenges, such as communication bottlenecks and sporadic participation. The key statistical challenge is non-i.i.d. data distributions across clients, each of which has a different data collection bias and potentially a different data annotation policy or labeling function (for example, in the case of user preference learning). The classic and most widely deployed FL algorithms are Fed-Avg (McMahan et al., 2017) and Fed-Prox (Li et al., 2018). However, even when a global model can be learned, it often underperforms on each client's local data distribution in scenarios of high heterogeneity (Li et al., 2019; Karimireddy et al., 2019; Wang et al., 2020). Studies have attempted to alleviate this by personalising learning at each client, allowing each local model to deviate from the shared global model (Sun et al., 2021). However, this remains challenging given that each client may have a limited amount of local data for personalised learning. These challenges have motivated several attempts to model the FL problem from a Bayesian perspective. Introducing distributions on model parameters $\theta$ has enabled various schemes for estimating a global model posterior $p(\theta|D_{1:N})$ from clients' local posteriors $p(\theta|D_i)$, or for regularising the learning of local models given a prior defined by the global model (Zhang et al., 2022; Al-Shedivat et al., 2021; Chen & Chao, 2021). However, these methods are not complete and principled solutions: they have not yet provided full Bayesian descriptions of the FL problem, and have had to resort to ad-hoc treatments to achieve tractable learning.
The key difference is that they fundamentally treat network weights $\theta$ as a random variable shared across all clients. We introduce a hierarchical Bayesian model that assigns each client its own random variable for model weights $\theta_i$, and these are linked via a higher-level random variable $\phi$ as $p(\theta_{1:N}, \phi) = p(\phi) \prod_{i=1}^N p(\theta_i|\phi)$. This has several crucial benefits. Firstly, given this hierarchy, variational inference in our framework decomposes into separable optimisation problems over the $\theta_i$'s and $\phi$, enabling a practical Bayesian learning algorithm to be derived that is fully compatible with FL constraints, without resorting to ad-hoc treatments or strong assumptions. Secondly, this framework can be instantiated with different assumptions on $p(\theta_i|\phi)$ to deal elegantly and robustly with different kinds of statistical heterogeneity, as well as for principled and effective model personalisation. Our resulting algorithm, termed Federated Hierarchical Bayes (FedHB), is empirically effective, as we demonstrate in a wide range of experiments on established benchmarks. More importantly, it benefits from rigorous theoretical support. In particular, we provide convergence guarantees showing that FedHB has the same $O(1/\sqrt{T})$ convergence rate as centralised SGD algorithms, a guarantee not provided by related prior art (Zhang et al., 2022; Chen & Chao, 2021). We also provide a generalisation bound showing that FedHB is asymptotically optimal, which has not been shown by prior work such as Al-Shedivat et al. (2021). Furthermore, we show that FedHB subsumes the classic methods Fed-Avg (McMahan et al., 2017) and Fed-Prox (Li et al., 2018) as special cases, and ultimately provides additional justification and explanation for these seminal methods.

2. BAYESIAN FL: GENERAL FRAMEWORK

We introduce two types of latent random variables, $\phi$ and $\{\theta_i\}_{i=1}^N$. Each $\theta_i$ is deployed as the network weights for client $i$'s backbone. The variable $\phi$ can be viewed as a globally shared variable that is responsible for linking the individual client parameters $\theta_i$. We assume conditionally independent and identical priors, $p(\theta_{1:N}|\phi) = \prod_{i=1}^N p(\theta_i|\phi)$. Thus the prior for the latent variables $(\phi, \{\theta_i\}_{i=1}^N)$ is formed in a hierarchical manner as in (1). The local data for client $i$, denoted by $D_i$, is generated by $\theta_i$:

$$\text{(Prior)}\quad p(\phi, \theta_{1:N}) = p(\phi) \prod_{i=1}^N p(\theta_i|\phi), \qquad \text{(Likelihood)}\quad p(D_i|\theta_i) = \prod_{(x,y)\in D_i} p(y|x,\theta_i), \tag{1}$$

where $p(y|x, \theta_i)$ is a conventional neural network model (e.g., softmax link for classification tasks). See the graphical model in Fig. 1(a), where the i.i.d. clients are governed by a single random variable $\phi$. Given the data $D_1, \dots, D_N$, we infer the posterior, $p(\phi, \theta_{1:N}|D_{1:N}) \propto p(\phi) \prod_{i=1}^N p(\theta_i|\phi)\, p(D_i|\theta_i)$, which is intractable in general, and we adopt variational inference to approximate it:

$$q(\phi, \theta_{1:N}; L) := q(\phi; L_0) \prod_{i=1}^N q_i(\theta_i; L_i), \tag{2}$$

where the variational parameters $L$ consist of $L_0$ (parameters for $q(\phi)$) and $\{L_i\}_{i=1}^N$ (parameters for the $q_i(\theta_i)$'s from individual clients). Note that although the $\theta_i$'s are independent across clients under (2), they are modeled differently (emphasised by the subscript $i$ in the notation $q_i$), reflecting different posterior beliefs originating from the heterogeneity of the local data $D_i$'s.
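To make the generative story concrete, the following is a minimal sketch of ancestral sampling from the hierarchy $\phi \to \theta_i \to D_i$. It is a toy instance and not the paper's setup: it assumes a Gaussian global variate, an isotropic Gaussian $p(\theta_i|\phi)$, a linear-softmax classifier in place of a neural network, and standard-normal inputs (the model itself only specifies $p(y|x,\theta_i)$). The function name and all its parameters are hypothetical.

```python
import numpy as np

def sample_federated_data(num_clients=4, n_per_client=20, in_dim=3,
                          num_classes=2, sigma=0.5, seed=0):
    """Toy ancestral sampling phi -> theta_i -> D_i (all choices are assumptions)."""
    rng = np.random.default_rng(seed)
    d = in_dim * num_classes
    phi = rng.standard_normal(d)                      # assumed: phi ~ N(0, I)
    thetas, datasets = [], []
    for _ in range(num_clients):
        # assumed: theta_i | phi ~ N(phi, sigma^2 I), independently per client
        theta = phi + sigma * rng.standard_normal(d)
        weights = theta.reshape(in_dim, num_classes)
        x = rng.standard_normal((n_per_client, in_dim))  # inputs are not modeled; toy choice
        logits = x @ weights
        logits -= logits.max(axis=1, keepdims=True)      # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        # y ~ p(y | x, theta_i) with a softmax link
        y = np.array([rng.choice(num_classes, p=pr) for pr in probs])
        thetas.append(theta)
        datasets.append((x, y))
    return phi, thetas, datasets
```

Each client's labels depend only on its own $\theta_i$, while the $\theta_i$'s are tied together through the shared $\phi$, which is the conditional independence structure stated in (1).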

2.1. FROM VARIATIONAL INFERENCE TO FEDERATED LEARNING ALGORITHM

Using standard variational inference techniques (Blei et al., 2017; Kingma & Welling, 2014), we can derive the ELBO objective function (details in Appendix A). We denote the negative ELBO by $\mathcal{L}$ (to be minimised over $L$) as follows:

$$\mathcal{L}(L) := \sum_{i=1}^N \Big( \mathbb{E}_{q_i(\theta_i)}[-\log p(D_i|\theta_i)] + \mathbb{E}_{q(\phi)}\big[\mathrm{KL}(q_i(\theta_i)\,\|\,p(\theta_i|\phi))\big] \Big) + \mathrm{KL}(q(\phi)\,\|\,p(\phi)), \tag{3}$$

where we drop the dependency on $L$ in the notation for simplicity. Instead of optimising (3) over the parameters $L$ jointly, as is usual practice, we consider block-wise optimisation, also known as block-coordinate optimisation (Wright, 2015), specifically alternating two steps: (i) updating all $L_i$'s, $i = 1, \dots, N$, while fixing $L_0$, and (ii) updating $L_0$ with all $L_i$'s fixed. That is,

• Optimisation over $L_1, \dots, L_N$ ($L_0$ fixed).

$$\min_{\{L_i\}_{i=1}^N} \sum_{i=1}^N \Big( \mathbb{E}_{q_i(\theta_i)}[-\log p(D_i|\theta_i)] + \mathbb{E}_{q(\phi)}\big[\mathrm{KL}(q_i(\theta_i)\,\|\,p(\theta_i|\phi))\big] \Big). \tag{4}$$

As (4) is completely separable over $i$, we can optimise each summand independently:

$$\min_{L_i} \mathcal{L}_i(L_i) := \mathbb{E}_{q_i(\theta_i; L_i)}[-\log p(D_i|\theta_i)] + \mathbb{E}_{q(\phi; L_0)}\big[\mathrm{KL}(q_i(\theta_i; L_i)\,\|\,p(\theta_i|\phi))\big]. \tag{5}$$

So (5) constitutes the local update for client $i$. Note that each client $i$ needs to access only its private data $D_i$, without data from others, and the update is thus fully compatible with FL.

• Optimisation over $L_0$ ($L_1, \dots, L_N$ fixed). Dropping the terms of (3) that do not depend on $L_0$ gives

$$\min_{L_0} \mathcal{L}_0(L_0) := \mathrm{KL}(q(\phi; L_0)\,\|\,p(\phi)) - \sum_{i=1}^N \mathbb{E}_{q(\phi; L_0)\, q_i(\theta_i; L_i)}[\log p(\theta_i|\phi)]. \tag{6}$$

This constitutes the server update criterion, with the latest $q_i(\theta_i; L_i)$'s from local clients held fixed. Remarkably, the server need not access any local data at all, which makes it suitable for FL. This nice property originates from the independence assumption in our approximate posterior (2).

Interpretation. First, the server's loss function (6) tells us that the server needs to update $q(\phi; L_0)$ in such a way that (i) it puts mass on those $\phi$ that have high compatibility scores $\log p(\theta_i|\phi)$ with the current local models $\theta_i \sim q_i(\theta_i)$, thus aiming to be aligned with local models, and (ii) it does not deviate from the prior $p(\phi)$.
Clients' loss function (5) indicates that each client i needs to minimise the class prediction error on its own data D i (first term), and at the same time, to stay close to the current global standard ϕ ∼ q(ϕ) by reducing the KL divergence from p(θ i |ϕ) (second term).
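The alternating structure of steps (i) and (ii) can be sketched on a toy problem. The code below is an illustration only, under strong assumptions that are not part of the paper: each client fits a Gaussian mean (so the data term is a squared error), the KL term is replaced by a quadratic pull `lam` towards the server mean, and the prior on the server mean is a unit Gaussian at zero. With these choices both block updates have closed forms. The function names and the `frac` participation parameter are hypothetical.

```python
import numpy as np

def client_update(data, m0, lam):
    """Step (i): minimise 0.5*sum_x ||x - m||^2 + 0.5*lam*||m - m0||^2 over m.
    Only this client's own data is touched."""
    return (data.sum(axis=0) + lam * m0) / (len(data) + lam)

def server_update(client_means, lam):
    """Step (ii): minimise 0.5*||m0||^2 + 0.5*lam*sum_i ||m_i - m0||^2 over m0.
    Only client parameters are used, never client data."""
    client_means = np.asarray(client_means)
    n = client_means.shape[0]
    return lam * client_means.sum(axis=0) / (1.0 + lam * n)

def run_rounds(client_data, lam=1.0, num_rounds=50, frac=1.0, seed=0):
    """Alternate local and server updates, sampling a fraction of clients per round."""
    rng = np.random.default_rng(seed)
    n = len(client_data)
    dim = client_data[0].shape[1]
    m0 = np.zeros(dim)
    client_means = [np.zeros(dim) for _ in range(n)]
    n_active = max(1, int(n * frac))
    for _ in range(num_rounds):
        active = rng.choice(n, size=n_active, replace=False)
        for i in active:                          # server parameters held fixed
            client_means[i] = client_update(client_data[i], m0, lam)
        m0 = server_update(client_means, lam)     # client parameters held fixed
    return m0, client_means
```

Because both updates exactly minimise the same joint objective over their own block, the objective never increases from one step to the next, which is the property that block-coordinate schemes rely on.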

2.2. FORMALISATION OF GLOBAL PREDICTION AND PERSONALISATION TASKS

Two important tasks in FL are global prediction and personalisation. The former evaluates the trained model on novel test data sampled from a distribution possibly different from that of the training data. Personalisation is the task of adapting the trained model to a new dataset, called personalised data. In our Bayesian model, these two tasks can be formally defined as Bayesian inference problems.

Global prediction. The task is to predict the class label of a novel test input $x^*$, which may or may not come from the same distributions as the training data $D_1, \dots, D_N$. Under our Bayesian model, it can be turned into a probabilistic inference problem $p(y^*|x^*, D_{1:N})$. Let $\theta$ be the local model that generates the output $y^*$ given $x^*$. Exploiting conditional independence from Fig. 1(c),

$$p(y^*|x^*, D_{1:N}) = \iint p(y^*|x^*, \theta)\, p(\theta|\phi)\, p(\phi|D_{1:N})\, d\theta\, d\phi \tag{7}$$

$$\approx \iint p(y^*|x^*, \theta)\, p(\theta|\phi)\, q(\phi)\, d\theta\, d\phi = \int p(y^*|x^*, \theta) \Big( \int p(\theta|\phi)\, q(\phi)\, d\phi \Big) d\theta, \tag{8}$$

where in (8) we use $p(\phi|D_{1:N}) \approx q(\phi)$. The inner integral (in parentheses) in (8) either admits a closed form (Sec. 3.1) or can be approximated (e.g., by Monte-Carlo estimation).

Personalisation. This formally refers to the task of learning a prediction model $p(y|x)$ given an unseen (personal) training dataset $D^p$ that comes from some unknown distribution $p^p(x, y)$, so that the personalised model performs well on novel (in-distribution) test points $(x^p, y^p) \sim p^p(x, y)$. Evidently we need to exploit (and benefit from) the model trained in the FL training stage. To this end many existing approaches simply resort to finetuning, that is, training on $D^p$ with a warm start from the FL-trained model. However, a potential issue is the lack of a solid principle for balancing the initial FL-trained model against fitting the personal data, so as to avoid both underfitting and overfitting. In our Bayesian framework, personalisation can be seen as another posterior inference problem with the additional evidence of the personal training data $D^p$.
Prediction on a test point $x^p$ amounts to inferring:

$$p(y^p|x^p, D^p, D_{1:N}) = \int p(y^p|x^p, \theta)\, p(\theta|D^p, D_{1:N})\, d\theta. \tag{9}$$

So it boils down to the task of posterior inference $p(\theta|D^p, D_{1:N})$ given both the personal data $D^p$ and the FL training data $D_{1:N}$. Under our hierarchical model, by exploiting conditional independence from the graphical model (Fig. 1(d)), we can link the posterior to our FL-trained $q(\phi)$ as follows:

$$p(\theta|D^p, D_{1:N}) \approx \int p(\theta|D^p, \phi)\, p(\phi|D_{1:N})\, d\phi \approx \int p(\theta|D^p, \phi)\, q(\phi)\, d\phi \approx p(\theta|D^p, \phi^*), \tag{10}$$

where we disregard the impact of $D^p$ on the higher-level $\phi$ given the joint evidence, $p(\phi|D^p, D_{1:N}) \approx p(\phi|D_{1:N})$, due to the dominance of $D_{1:N}$ over the smaller $D^p$. The last part of (10) approximates using the mode $\phi^*$ of $q(\phi)$, which is reasonable for our two modeling choices for $q(\phi)$ discussed in the next section. We then approximate $p(\theta|D^p, \phi^*)$ variationally, by a density $v(\theta)$ that solves:

$$\min_v\ \mathbb{E}_{v(\theta)}[-\log p(D^p|\theta)] + \mathrm{KL}(v(\theta)\,\|\,p(\theta|\phi^*)). \tag{11}$$

Once we have the optimised model $v$, our predictive distribution becomes:

$$p(y^p|x^p, D^p, D_{1:N}) \approx \frac{1}{S} \sum_{s=1}^S p(y^p|x^p, \theta^{(s)}), \quad \text{where } \theta^{(s)} \sim v(\theta), \tag{12}$$

which simply requires feed-forwarding the test input $x^p$ through the sampled networks $\theta^{(s)}$ and averaging.

Thus far, we have discussed a general framework, deriving how variational inference for our Bayesian model fits gracefully into the FL problem. In the next section, we define specific density families for the prior ($p(\phi)$, $p(\theta_i|\phi)$) and posterior ($q(\phi)$, $q_i(\theta_i)$) as our proposed concrete models.
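The Monte-Carlo averaging in (12) can be sketched as follows. This is a toy illustration and not the paper's implementation: a linear-softmax classifier stands in for the neural network, and an isotropic Gaussian with small scale `eps` stands in for the learned $v(\theta)$. All function names and parameters are hypothetical.

```python
import numpy as np

def softmax(z):
    """Row-wise softmax with the usual max-subtraction for stability."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def sample_spiky_gaussian(mean, num_samples, eps=1e-4, seed=0):
    """Draw theta^(s) ~ N(mean, eps^2 I); an assumed stand-in for v(theta)."""
    rng = np.random.default_rng(seed)
    return mean + eps * rng.standard_normal((num_samples, mean.size))

def mc_predictive(x, theta_samples, in_dim, num_classes):
    """Approximate p(y | x) by averaging p(y | x, theta^(s)) over the S samples."""
    probs = [softmax(x @ th.reshape(in_dim, num_classes)) for th in theta_samples]
    return np.mean(probs, axis=0)
```

Averaging is done over class probabilities, not over logits or weights, so the result is itself a valid distribution over labels for each input.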

3. BAYESIAN FL: TWO CONCRETE MODELS

We propose two different model choices that we find the most interesting: Normal-Inverse-Wishart (Sec. 3.1) and Mixture (Sec. 3.2). To avoid distraction, we make this section concise putting only the final results and discussions, and leaving all mathematical details in Appendix B and C.

3.1. NORMAL-INVERSE-WISHART (NIW) MODEL

We define the prior as a conjugate form of Gaussian and Normal-Inverse-Wishart. With $\phi = (\mu, \Sigma)$,

$$p(\phi) = \mathcal{NIW}(\mu, \Sigma; \Lambda) = \mathcal{N}(\mu; \mu_0, \lambda_0^{-1}\Sigma) \cdot \mathcal{IW}(\Sigma; \Sigma_0, \nu_0), \tag{13}$$

$$p(\theta_i|\phi) = \mathcal{N}(\theta_i; \mu, \Sigma), \quad i = 1, \dots, N, \tag{14}$$

where $\Lambda = \{\mu_0, \Sigma_0, \lambda_0, \nu_0\}$ are the parameters of the NIW. Although $\Lambda$ can be learned via data marginal likelihood maximisation (e.g., empirical Bayes), for simplicity we leave it fixed as $\mu_0 = 0$, $\Sigma_0 = I$, $\lambda_0 = 1$, and $\nu_0 = d + 2$, where $d$ is the number of parameters in $\theta_i$ or $\mu$. Next, our choice of the variational density family for $q(\phi)$ is the NIW, not only because it is the most popular parametric family for a pair of mean vector and covariance matrix $\phi = (\mu, \Sigma)$, but also because it admits closed-form expressions in the ELBO function due to conjugacy, as we derive in Sec. B.1:

$$q(\phi) := \mathcal{NIW}(\phi; \{m_0, V_0, l_0, n_0\}) = \mathcal{N}(\mu; m_0, l_0^{-1}\Sigma) \cdot \mathcal{IW}(\Sigma; V_0, n_0). \tag{15}$$

Although the scalar parameters $l_0, n_0$ can be optimised together with $m_0, V_0$, their impact is less influential and we find that they make the ELBO optimisation somewhat cumbersome. So we fix $l_0, n_0$ at near-optimal values by exploiting the conjugacy of the NIW under the Gaussian likelihood (details in Appendix B), and regard $m_0, V_0$ as the variational parameters, $L_0 = \{m_0, V_0\}$. We restrict $V_0$ to be diagonal for computational tractability. The density family for the $q_i(\theta_i)$'s can be a Gaussian, but we find that it is computationally more attractive and numerically more stable to adopt the mixture of two spiky Gaussians that leads to MC-Dropout (Gal & Ghahramani, 2016).
That is,

$$q_i(\theta_i) = \prod_l \Big( p \cdot \mathcal{N}(\theta_i[l]; m_i[l], \epsilon^2 I) + (1-p) \cdot \mathcal{N}(\theta_i[l]; 0, \epsilon^2 I) \Big), \tag{16}$$

where (i) $m_i$ is the only variational parameter ($L_i = \{m_i\}$), (ii) $\cdot[l]$ indicates a column/layer in the neural network parameters, where $l$ goes over layers and columns of weight matrices, (iii) $p$ is a (user-specified) hyperparameter, where $1-p$ corresponds to the dropout probability, and (iv) $\epsilon$ is a small constant (e.g., $10^{-4}$) that makes the two Gaussians spiky, close to delta functions.

Client update. We apply the general client update optimisation (5) to the NIW model. Following the approximation of Gal & Ghahramani (2016) for the KL divergence between a mixture of Gaussians (16) and a Gaussian (14), we have the client local optimisation (details in Appendix B):

$$\min_{m_i} \mathcal{L}_i(m_i) := -\log p(D_i|\tilde{m}_i) + \frac{p}{2}(n_0 + d + 1)(m_i - m_0)^\top V_0^{-1} (m_i - m_0), \tag{17}$$

where $\tilde{m}_i$ is the dropout version of $m_i$, i.e., a reparametrised sample from (16). Note that $m_0$ and $V_0$ are fixed during the optimisation. Interestingly, (17) generalises Fed-Avg (McMahan et al., 2017) and Fed-Prox (Li et al., 2018): with $p = 1$ (i.e., no dropout) and setting $V_0 = \alpha I$, (17) reduces to the client update formula for Fed-Prox, where the constant $\alpha$ controls the impact of the proximal term.

Server update. The general server optimisation (6) admits the closed-form solution (Appendix B):

$$m_0^* = \frac{p}{N+1} \sum_{i=1}^N m_i, \qquad V_0^* = \frac{n_0}{N + d + 2} \Big( (1 + N\epsilon^2) I + m_0^* (m_0^*)^\top + \sum_{i=1}^N \rho(m_0^*, m_i, p) \Big), \tag{18}$$

where $\rho(m_0, m_i, p) = p\, m_i m_i^\top - p\, m_0 m_i^\top - p\, m_i m_0^\top + m_0 m_0^\top$. Note that the $m_i$'s are fixed at the clients' latest variational parameters. It is interesting to see that $m_0^*$ in (18) generalises the well-known aggregation step of averaging local models in Fed-Avg (McMahan et al., 2017) and related methods: when $p = 1$ (no dropout), it almost equals client model averaging.
Also, since $\rho(m_0^*, m_i, p) = (m_i - m_0^*)(m_i - m_0^*)^\top$ when $p = 1$, $V_0^*$ essentially estimates the sample scatter matrix with $(N+1)$ samples, namely the clients' $m_i$'s and the server's prior $\mu_0 = 0$, measuring how much they deviate from the center $m_0^*$. Dropout is known to help regularise the model and lead to better generalisation (Gal & Ghahramani, 2016), and with $p < 1$ our (18) forms a principled optimal solution.

Global prediction. The inner integral of (8) becomes a multivariate Student-t distribution. The predictive distribution for a new test input $x^*$ can then be estimated as:

$$p(y^*|x^*, D_{1:N}) \approx \frac{1}{S} \sum_{s=1}^S p(y^*|x^*, \theta^{(s)}), \quad \text{where } \theta^{(s)} \sim t_{n_0 - d + 1}\Big(\theta;\ m_0,\ \frac{(l_0 + 1) V_0}{l_0 (n_0 - d + 1)}\Big), \tag{19}$$

where $t_\nu(a, B)$ is the multivariate Student-t with location $a$, scale matrix $B$, and d.o.f. $\nu$.

Personalisation. With the given personalisation training data $D^p$, we follow the general framework in (11) to find $v(\theta) \approx p(\theta|D^p, \phi^*)$ in a variational way, where $\phi^*$ is obtained from (34). We adopt the same spiky mixture form (16) for $v(\theta)$, which leads to a learning objective similar to (17).
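As a small numerical illustration of the two reductions noted above, the sketch below evaluates the regulariser in (17) and the mean update in (18) for a diagonal $V_0$ stored as a vector. It is an illustrative sketch, not the authors' code; the function names are hypothetical, and the data-fit term of (17) is omitted because it depends on the network.

```python
import numpy as np

def niw_client_penalty(m_i, m0, v0_diag, n0, p):
    """Second term of (17) with diagonal V_0:
    (p/2) * (n0 + d + 1) * (m_i - m0)^T V_0^{-1} (m_i - m0)."""
    d = m_i.size
    diff = m_i - m0
    return 0.5 * p * (n0 + d + 1) * np.sum(diff * diff / v0_diag)

def niw_server_mean(client_means, p):
    """m_0^* in (18): p / (N + 1) times the sum of client means."""
    client_means = np.asarray(client_means, dtype=float)
    return p * client_means.sum(axis=0) / (client_means.shape[0] + 1)
```

With `p = 1` and `v0_diag` set to a constant `alpha`, the penalty is a multiple of the squared distance `||m_i - m0||^2`, which is the shape of the Fed-Prox proximal term, and `niw_server_mean` becomes the sum of client means divided by `N + 1` instead of `N`.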

3.2. MIXTURE MODEL

Our motivation for the mixture model is to make the prior $p(\theta, \phi)$ more flexible by having multiple different prototypes, diverse enough to cover the heterogeneity in data distributions across clients. We consider:

$$p(\phi) = \prod_{j=1}^K \mathcal{N}(\mu_j; 0, I), \qquad p(\theta_i|\phi) = \sum_{j=1}^K \frac{1}{K} \mathcal{N}(\theta_i; \mu_j, \sigma^2 I), \tag{20}$$

where $\phi = \{\mu_1, \dots, \mu_K\}$ contains $K$ networks (prototypes) that can broadly cover the clients' data distributions, and $\sigma$ is a hyperparameter that captures the perturbation scale, chosen by users or learned from data. Note that we use equal mixing proportions $1/K$ due to symmetry, a priori. That is, each client is equally likely to take any of the $\mu_j$'s a priori. For the variational densities, we define:

$$q_i(\theta_i) = \mathcal{N}(\theta_i; m_i, \epsilon^2 I), \qquad q(\phi) = \prod_{j=1}^K \mathcal{N}(\mu_j; r_j, \epsilon^2 I), \tag{21}$$

where $\{r_j\}_{j=1}^K$ ($L_0$) and $m_i$ ($L_i$) are the variational parameters, and $\epsilon$ is a small constant (e.g., $10^{-4}$).

Client update. For our model choice, the general client update (5) reduces to (details in Appendix C):

$$\min_{m_i}\ \mathbb{E}_{q_i(\theta_i)}[-\log p(D_i|\theta_i)] - \log \sum_{j=1}^K \exp\Big(-\frac{\|m_i - r_j\|^2}{2\sigma^2}\Big). \tag{22}$$

It is interesting to see that (22) can be seen as a generalisation of Fed-Prox (Li et al., 2018), where the proximal regularisation term in Fed-Prox is extended to multiple global models $r_j$, penalising the local model ($m_i$) for straying away from these prototypes. If we use a single prototype ($K = 1$), the optimisation (22) exactly reduces to the local update objective of Fed-Prox. Since log-sum-exp is approximately equal to max, the regularisation term in (22) effectively focuses on the global prototype $r_j$ closest to the current local model $m_i$, which is intuitively well aligned with our motivation.

Server update. The general form (6) can be approximately turned into (see Appendix C for derivations):

$$\min_{\{r_j\}_{j=1}^K}\ \frac{1}{2} \sum_{j=1}^K \|r_j\|^2 - \sum_{i=1}^N \log \sum_{j=1}^K \exp\Big(-\frac{\|m_i - r_j\|^2}{2\sigma^2}\Big). \tag{23}$$

Interestingly, (23) generalises the well-known aggregation step of averaging local models in Fed-Avg and related methods: when $K = 1$, (23) reduces to a quadratic optimisation, admitting the optimal solution $r_1^* = \frac{1}{N + \sigma^2} \sum_{i=1}^N m_i$. The extra term $\sigma^2$ can be explained as incorporating an extra zero local model originating from the prior (interpreted as a neutral model) with the discounted weight $\sigma^2$ rather than 1. Although (23) for $K > 1$ can be solved by standard gradient descent, we apply the Expectation-Maximisation (EM) algorithm (Dempster et al., 1977) instead:

$$\text{(E-step)}\quad c(j|i) = \frac{e^{-\|m_i - r_j\|^2/(2\sigma^2)}}{\sum_{j'=1}^K e^{-\|m_i - r_{j'}\|^2/(2\sigma^2)}}, \qquad \text{(M-step)}\quad r_j^* = \frac{\frac{1}{N} \sum_{i=1}^N c(j|i) \cdot m_i}{\frac{\sigma^2}{N} + \frac{1}{N} \sum_{i=1}^N c(j|i)}. \tag{24}$$

The M-step (server update) has the intuitive meaning that the new prototype $r_j$ becomes the weighted average of the local models $m_i$, where the weights $c(j|i)$ are determined by the proximity between $m_i$ and $r_j$ (i.e., those $m_i$'s that are closer to $r_j$ contribute more, and vice versa). This can be seen as an extension of the aggregation step in Fed-Avg to the multiple-prototype case.

Global prediction. We slightly modify our general approach so that individual client data are dominantly explained by the most relevant model $r_j$, by introducing a gating function from the mixture of experts (Jacobs et al., 1991; Jordan & Jacobs, 1994). See Appendix C for details.

Personalisation. With $v(\theta)$ of the same form as $q_i(\theta_i)$, the VI learning becomes similar to (22).
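The E-step and M-step in (24) translate directly into array operations. The following is a minimal sketch on flat parameter vectors, not the authors' implementation; the function name, the number of EM iterations, and the log-domain stabilisation of the E-step are assumptions made for illustration.

```python
import numpy as np

def em_server_update(client_means, prototypes, sigma2, num_iters=10):
    """Run a few EM iterations of (24).
    client_means: (N, d) array of m_i; prototypes: (K, d) array of initial r_j."""
    m = np.asarray(client_means, dtype=float)
    r = np.asarray(prototypes, dtype=float).copy()
    n = m.shape[0]
    for _ in range(num_iters):
        # E-step: c(j|i) proportional to exp(-||m_i - r_j||^2 / (2 sigma^2))
        sq_dist = ((m[:, None, :] - r[None, :, :]) ** 2).sum(axis=2)   # (N, K)
        logits = -sq_dist / (2.0 * sigma2)
        logits -= logits.max(axis=1, keepdims=True)   # stabilise before exponentiating
        c = np.exp(logits)
        c /= c.sum(axis=1, keepdims=True)
        # M-step: responsibility-weighted average, shrunk towards zero by sigma^2 / N
        numer = (c.T @ m) / n                         # (K, d)
        denom = sigma2 / n + c.sum(axis=0) / n        # (K,)
        r = numer / denom[:, None]
    return r, c
```

With a single prototype every responsibility equals 1, and the M-step returns $\sum_i m_i / (N + \sigma^2)$, matching the $K = 1$ solution stated above.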

4. THEORETICAL ANALYSIS

We provide two theoretical results for our Bayesian FL algorithm. (Convergence analysis) As a special block-coordinate optimisation algorithm, we show that it converges to a (local) optimum of the training objective (3). (Generalisation error bound) We theoretically show how well this optimal model, trained on empirical data, performs on unseen test data points. Due to the space limit, full details and proofs are given in Appendices D and E, and we only state the theorems and remarks here.

Theorem 4.1 (Convergence analysis). We denote the objective function in (3) by $f(x)$, where $x = [x_0, x_1, \dots, x_N]$ corresponds to the variational parameters $x_0 := L_0$, $x_1 := L_1$, ..., $x_N := L_N$. Let $\eta_t = L + \sqrt{t}$ for some constant $L$, and $\bar{x}_T = \frac{1}{T} \sum_{t=1}^T x_t$, where $t$ is the batch iteration counter, $x_t$ is the iterate at $t$ obtained by following our FL algorithm, and $N_f$ ($\le N$) is the number of participating clients at each round. With Assumptions 1-3 in Appendix D, the following holds for any $T$:

$$\mathbb{E}[f(\bar{x}_T)] - f(x^*) \le \frac{N + N_f}{N_f} \cdot \frac{\frac{\sqrt{T} + L}{2} D^2 + R_f^2 \sqrt{T}}{T} = O\Big(\frac{1}{\sqrt{T}}\Big), \tag{25}$$

where $x^*$ is the (local) optimum, $D$ and $R_f$ are constants, and the expectation is taken over the randomness in minibatches and in the selection of participating clients.

Remark. The theorem says that the averaged iterate $\bar{x}_t$ converges to the optimal point $x^*$ in expectation at the rate of $O(1/\sqrt{t})$. This rate asymptotically equals that of the conventional (non-block-coordinate, holistic) SGD algorithm.

Theorem 4.2 (Generalisation error bound). Assume that the variational density family for $q_i(\theta_i)$ is rich enough to subsume Gaussians. Let $d^2(P_{\theta_i}, P_i)$ be the expected squared Hellinger distance between the true class distribution $P_i(y|x)$ and the model's $P_{\theta_i}(y|x)$ for client $i$'s data.
The optimal solution $(\{q_i^*(\theta_i)\}_{i=1}^N, q^*(\phi))$ of the optimisation problem (3) satisfies:

$$\frac{1}{N} \sum_{i=1}^N \mathbb{E}_{q_i^*(\theta_i)}\big[d^2(P_{\theta_i}, P_i)\big] \le O\Big(\frac{1}{n}\Big) + C \cdot \epsilon_n^2 + C' r_n + \frac{1}{N} \sum_{i=1}^N \lambda_i^*, \tag{26}$$

with high probability, where $C, C' > 0$ are constants, $\lambda_i^* = \min_{\theta \in \Theta} \|f_\theta - f_i\|_\infty^2$ is the best error within our backbone network family $\Theta$, and $r_n, \epsilon_n \to 0$ as the training data size $n \to \infty$.

Remark. The theorem implies that the optimal solution of (3) (attainable by our block-coordinate FL algorithm) is asymptotically optimal, since the RHS of (26) converges to 0 as the training data size $n \to \infty$.

Comparison to existing Bayesian FL approaches. Some recent studies have tried to address the FL problem using Bayesian methods. As mentioned earlier, the key difference is that these methods do not introduce a Bayesian hierarchy, and ultimately treat the network weights $\theta$ as a random variable shared across all clients, while our approach assigns an individual $\theta_i$ to each client $i$, governed by a common prior $p(\theta_i|\phi)$. The non-hierarchical approaches must all resort to ad hoc heuristics or strong assumptions in parts of their algorithms. Due to the lack of space, we leave related references and discussions to Appendix I.

6. EVALUATION

In this section we evaluate the proposed hierarchical Bayesian models on two benchmark datasets: the popular CIFAR-100 and the challenging corrupted version (CIFAR-C-100), which renders the client data more heterogeneous both in input images and in class distributions.

FL settings for CIFAR-100. We follow settings similar to those used in (Oh et al., 2022); in particular, the client data distributions are heterogeneous and non-i.i.d., formed by sharding-based class sampling (McMahan et al., 2017). More specifically, we partition the data instances in each class into non-overlapping equal-sized shards, and assign $s$ randomly sampled shards (over all classes) to each of $N$ clients. Thus the number of shards per user $s$ controls the degree of data heterogeneity: small $s$ leads to more heterogeneity, and vice versa. The number of clients is $N = 100$ (each having 500 training and 100 test samples), and we denote by $f$ the fraction of participating clients. So, $N_f = \lfloor N \cdot f \rfloor$ clients are randomly sampled at each round to participate in training. Smaller $f$ makes FL more challenging, and we test two settings: $f = 1.0$ and $0.1$. Lastly, the number of epochs for the client local update at each round is denoted by $\tau$, where we test $\tau = 1$ and $10$, and the number of total rounds is determined by $\tau$ as $\lfloor 320/\tau \rfloor$ for fairness. Note that smaller $\tau$ incurs more communication cost but often leads to higher accuracy.

FL settings for CIFAR-100-Corrupted (CIFAR-C-100). This setting, based on the dataset of Hendrycks & Dietterich (2019), is more challenging than CIFAR-100 since the data for personalisation are entirely unseen during the FL training stage. We test the $\tau = 1$ and $\tau = 4$ scenarios.

Experimental settings. Our implementation is based on Oh et al. (2022), where we use MobileNet (Howard et al., 2017) as the backbone and follow the body-update strategy: the classification head (the last layer) is randomly initialised and fixed during training, with only the network body updated (and both the body and head updated at the personalisation stage). We report experimental results all based on this body-update strategy, since we observe that it considerably outperforms the full update for our models and all other competing methods. The hyperparameters in our models are: (NIW) $\epsilon = 10^{-4}$ and $p = 1 - 0.001$ (see below for an ablation study of other values); (Mixture) $\sigma^2 = 0.1$, $\epsilon = 10^{-4}$, mixture order $K = 2$ (see Appendix F.1 for results with other values), and the gating network has the same architecture as the main backbone, but with the output cardinality changed to $K$. The other hyperparameters, including the batch size (50), the learning rate (0.1 initially and decayed by 0.1) and the number of epochs in personalisation (5), are the same as those in (Oh et al., 2022).

Main results. In Table 1 (CIFAR-100) and Table 2 (CIFAR-C-100), we compare our methods (NIW and Mixture with $K = 2$) against the popular FL methods Fed-BABU (Oh et al., 2022), Fed-Avg (McMahan et al., 2017), Fed-Prox (Li et al., 2018), and the recent pFedBayes (Zhang et al., 2022). The latter is especially interesting to contrast with, as it is based on variational inference and is most closely related to ours. We run the competing methods (implementations based on their public code) with default hyperparameters (e.g., $\mu = 0.01$ for Fed-Prox) and report the results. First of all, our two models (NIW and Mix.) consistently perform the best (by large margins most of the time) in terms of both global prediction and personalisation for nearly all FL settings on the two datasets.
This is attributed to the principled Bayesian modeling of the underlying FL data generative process in our approaches, which can be seen as a rigorous generalisation and extension of existing intuitive algorithms such as Fed-Avg and Fed-Prox. In particular, the superiority of our methods over the other Bayesian approach, pFedBayes, verifies the effectiveness of modeling client-wise latent variables $\theta_i$ as opposed to the commonly used shared-$\theta$ modeling, especially in scenarios of significant client data heterogeneity (e.g., CIFAR-C-100 personalisation on data with unseen corruption types).

(Ablation) Hyperparameter sensitivity. We test the sensitivity to some key hyperparameters in our models. For NIW, we have $p = 1 - p_{drop}$, where $p_{drop}$ is the MC-dropout probability and we used $p_{drop} = 0.001$ in the main experiments. In Fig. 2(a) we report the performance of NIW for different values ($p_{drop} = 0, 10^{-4}, 10^{-2}$) on CIFAR-100 with the ($s = 100$, $f = 0.1$, $\tau = 1$) setting. We see that the performance is not very sensitive to $p_{drop}$ unless it is too large (e.g., 0.01). For the Mixture model, different mixture orders $K = 2, 5, 10$ are contrasted in Fig. 2(b). As seen, having more mixture components does no harm (no overfitting), but we do not see further improvement over $K = 2$ in our experiments (see also the results on CIFAR-C-100 in Table 4 in the Appendix).

Further analysis. In the Appendix, we provide further empirical results: (i) a comparison between our mixture model and simple ensemble baselines (Fig. 2(b) and F.2) and (ii) actual running times (F.3).
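For readers who want to reproduce the non-i.i.d. client split described in the FL settings above, the sharding-based class sampling can be sketched as below. This is a reading of the textual description only and may differ in detail from the code of Oh et al. (2022) or McMahan et al. (2017): it assumes shards are assigned without replacement, that every class is cut into the same number of shards, and that sizes divide evenly. The function name `shard_partition` is hypothetical.

```python
import numpy as np

def shard_partition(labels, num_clients, shards_per_client, seed=0):
    """Split sample indices into per-class equal-sized shards and hand
    `shards_per_client` randomly chosen shards to each client (no reuse)."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    classes = np.unique(labels)
    total_shards = num_clients * shards_per_client
    if total_shards % len(classes) != 0:
        raise ValueError("num_clients * shards_per_client must be a multiple of the class count")
    shards_per_class = total_shards // len(classes)
    shards = []
    for c in classes:
        idx = rng.permutation(np.flatnonzero(labels == c))
        if len(idx) % shards_per_class != 0:
            raise ValueError("class size must be divisible by the number of shards per class")
        shards.extend(np.split(idx, shards_per_class))   # non-overlapping, equal-sized
    order = rng.permutation(len(shards))                 # random shard-to-client assignment
    return [
        np.concatenate([shards[k] for k in
                        order[i * shards_per_client:(i + 1) * shards_per_client]])
        for i in range(num_clients)
    ]
```

Under these assumptions a client can see at most `shards_per_client` distinct classes, which is why a small number of shards per client yields a more heterogeneous split.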

7. CONCLUSION

We have proposed a novel hierarchical Bayesian approach to FL in which the block-coordinate descent solution to variational inference leads to a viable algorithm for FL. Our method not only justifies previous FL algorithms that look intuitive but are theoretically less underpinned, but also generalises them further via principled Bayesian approaches. With strong theoretical support in terms of convergence rate and generalisation error, our approach is also empirically shown to be superior to recent FL approaches by large margins on several benchmarks with various FL settings.

A ELBO DERIVATION FOR GENERAL FRAMEWORK

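We derive the ELBO objective (3) for the general Bayesian FL framework.

$$\mathrm{KL}\big(q(\phi, \theta_{1:N})\,\|\,p(\phi, \theta_{1:N}|D_{1:N})\big) = \mathbb{E}_q \left[ \log \frac{q(\phi) \cdot \prod_i q_i(\theta_i) \cdot p(D_{1:N})}{p(\phi) \cdot \prod_i p(\theta_i|\phi) \cdot \prod_i p(D_i|\theta_i)} \right] \tag{27}$$

$$= \underbrace{\mathrm{KL}(q(\phi)\,\|\,p(\phi)) + \sum_{i=1}^N \Big( \mathbb{E}_{q_i(\theta_i)}[-\log p(D_i|\theta_i)] + \mathbb{E}_{q(\phi)}\big[\mathrm{KL}(q_i(\theta_i)\,\|\,p(\theta_i|\phi))\big] \Big)}_{=:\ \mathcal{L}(L)} + \log p(D_{1:N}). \tag{28}$$

Since the KL divergence is non-negative, $-\mathcal{L}(L)$ must be a lower bound of the data log-likelihood $\log p(D_{1:N})$, rendering $\mathcal{L}(L)$ our objective function (to be minimised).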

B NORMAL-INVERSE-WISHART (NIW) MODEL (DETAILED VERSION)

We define the prior as a conjugate form of Gaussian and Normal-Inverse-Wishart. More specifically, each local client has the Gaussian prior $p(\theta_i|\phi) = \mathcal{N}(\theta_i; \mu, \Sigma)$ where $\phi = (\mu, \Sigma)$, and the global latent variable $\phi$ is distributed according to the conjugate prior, which is Normal-Inverse-Wishart (NIW):

$$p(\phi) = \mathcal{NIW}(\mu, \Sigma; \Lambda) = \mathcal{N}(\mu; \mu_0, \lambda_0^{-1}\Sigma) \cdot \mathcal{IW}(\Sigma; \Sigma_0, \nu_0), \tag{29}$$

$$p(\theta_i|\phi) = \mathcal{N}(\theta_i; \mu, \Sigma), \quad i = 1, \dots, N, \tag{30}$$

where $\Lambda = \{\mu_0, \Sigma_0, \lambda_0, \nu_0\}$ are the parameters of the NIW. Although $\Lambda$ can be learned via data marginal likelihood maximisation (e.g., empirical Bayes), for simplicity we leave it fixed as $\mu_0 = 0$, $\Sigma_0 = I$, $\lambda_0 = 1$, and $\nu_0 = d + 2$, where $d$ is the number of parameters in $\theta_i$ or $\mu$. This choice ensures that the mean of $\Sigma$ equals $I$, and that $\mu$ is distributed as a zero-mean Gaussian with covariance $\Sigma$. Next, our choice of the variational density family for $q(\phi)$ is the NIW, not only because it is the most popular parametric family for a pair of mean vector and covariance matrix $\phi = (\mu, \Sigma)$, but also because it admits closed-form expressions in the ELBO function due to conjugacy, as we derive in Sec. B.1:

$$q(\phi) := \mathcal{NIW}(\phi; \{m_0, V_0, l_0, n_0\}) = \mathcal{N}(\mu; m_0, l_0^{-1}\Sigma) \cdot \mathcal{IW}(\Sigma; V_0, n_0). \tag{31}$$

Although the scalar parameters $l_0, n_0$ can be optimised together with $m_0, V_0$, their impact is less influential and we find that they make the ELBO optimisation somewhat cumbersome. So we aim to estimate their optimal values in advance with reasonably good quality. To this end, we exploit the conjugacy of the NIW prior-posterior under the Gaussian likelihood. For each $\theta_i$, we pretend that we have instance-wise representative estimates $\theta_i(x, y)$, one for each $(x, y) \in D_i$. For instance, one can view $\theta_i(x, y)$ as the network parameters optimised with the single training instance $(x, y)$. This amounts to observing $|D|$ ($= \sum_{i=1}^N |D_i|$) Gaussian samples $\theta_i(x, y) \sim \mathcal{N}(\theta_i; \mu, \Sigma)$ for $(x, y) \sim D_i$ and $i = 1, \dots, N$.
Then, applying the NIW conjugacy, the posterior is the NIW with $l_0 = \lambda_0 + |D| = |D| + 1$ and $n_0 = \nu_0 + |D| = |D| + d + 2$. This gives us good approximate estimates of the optimal $l_0, n_0$, and we fix them throughout the variational optimisation. Note that this is only a heuristic for estimating the scalar parameters $l_0, n_0$ quickly, and the parameters $m_0, V_0$ are determined by the principled ELBO optimisation (Sec. B.1). That is, $L_0 = \{m_0, V_0\}$. Since the dimension $d$ is large (the number of neural network parameters), we restrict $V_0$ to be diagonal for computational tractability. The density family for the $q_i(\theta_i)$'s can be a Gaussian, but we find that it is computationally more attractive and numerically more stable to adopt the mixture of two spiky Gaussians that leads to MC-Dropout (Gal & Ghahramani, 2016). That is,

$$q_i(\theta_i) = \prod_l \Big( p \cdot \mathcal{N}(\theta_i[l]; m_i[l], \epsilon^2 I) + (1-p) \cdot \mathcal{N}(\theta_i[l]; 0, \epsilon^2 I) \Big), \tag{32}$$

where (i) $m_i$ is the only variational parameter ($L_i = \{m_i\}$), (ii) $\cdot[l]$ indicates a specific column/layer in the neural network parameters, where $l$ goes over layers and columns of weight matrices, (iii) $p$ is a (user-specified) hyperparameter, where $1-p$ corresponds to the dropout probability, and (iv) $\epsilon$ is a tiny constant (e.g., $10^{-6}$) that makes the two Gaussians spiky, close to delta functions. Now we provide more detailed derivations for the client optimisation and server optimisation.
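A reparametrised draw from the spiky mixture (32) can be sketched for a single weight matrix as below: each column is kept with probability $p$ or zeroed with probability $1-p$, and small Gaussian noise of scale $\epsilon$ is added. This is an illustrative sketch under the assumption that one mixture component is chosen per column of one matrix; the function name and arguments are hypothetical.

```python
import numpy as np

def sample_dropout_weights(m, p, eps=1e-6, rng=None):
    """One draw theta ~ q_i for a weight matrix m of shape (rows, cols).
    Column l comes from N(m[:, l], eps^2 I) with probability p,
    and from N(0, eps^2 I) with probability 1 - p."""
    rng = np.random.default_rng() if rng is None else rng
    keep = rng.random(m.shape[1]) < p              # one Bernoulli(p) decision per column
    noise = eps * rng.standard_normal(m.shape)
    return m * keep[None, :] + noise
```

Because $\epsilon$ is tiny, a draw is numerically indistinguishable from applying a column-wise dropout mask to $m_i$, which is why the data term can be evaluated with ordinary dropout at training time.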

B.1 DETAILED DERIVATIONS FOR NIW MODEL

Client update. We work on the objective function in the general client update optimisation (5). We note that $q(\phi)$ is spiky since our pre-estimated NIW parameters $l_0$ and $n_0$ are large (as the entire training data size $|D|$ is added to the initial prior parameters). Due to the spiky $q(\phi)$, we can accurately approximate the second term in (5) as:

$$\mathbb{E}_{q(\phi)}\big[\mathrm{KL}(q_i(\theta_i)\,||\,p(\theta_i|\phi))\big] \approx \mathrm{KL}(q_i(\theta_i)\,||\,p(\theta_i|\phi^*)), \tag{33}$$

where $\phi^* = (\mu^*, \Sigma^*)$ is the mode of $q(\phi)$, which has closed forms for the NIW distribution:

$$\mu^* = m_0, \quad \Sigma^* = \frac{V_0}{n_0 + d + 1}. \tag{34}$$

In (33) we have the KL divergence between a mixture of Gaussians (32) and a Gaussian (30). Similar to (Gal & Ghahramani, 2016), we apply the approximation $\mathrm{KL}(\sum_i \alpha_i \mathcal{N}_i\,||\,\mathcal{N}) \approx \sum_i \alpha_i \mathrm{KL}(\mathcal{N}_i\,||\,\mathcal{N})$ as well as the reparametrised sampling for (32), which allows us to rewrite (5) as:

$$\min_{m_i} \; L_i(m_i) := -\log p(D_i|\tilde{m}_i) + \frac{p}{2}(n_0 + d + 1)(m_i - m_0)^\top V_0^{-1}(m_i - m_0), \tag{35}$$

where $\tilde{m}_i$ is the dropout version of $m_i$, i.e., a reparametrised sample from (32). Also, we use a minibatch version of the first term for a tractable SGD update, which amounts to replacing the first term by the batch average $\mathbb{E}_{(x,y)\sim \mathrm{Batch}}[-\log p(y|x, \tilde{m}_i)]$ while downweighting the second term by the factor of $1/|D_i|$. Note that $m_0$ and $V_0$ are fixed during the optimisation. Interestingly, (35) generalises the well-known Fed-Avg (McMahan et al., 2017) and Fed-Prox (Li et al., 2018): with $p = 1$ (i.e., no dropout) and setting $V_0 = \alpha I$ for some constant $\alpha$, we see that (35) reduces to the client update formula for Fed-Prox, where $\alpha$ controls the impact of the proximal term.

Server update. The server optimisation (6) involves two terms, both of which we will show admit closed-form expressions thanks to the conjugacy. Furthermore, we show that the optimal solution $(m_0, V_0)$ of (6) has an analytic form.
First, the KL term in (6) is decomposed as:

$$\mathrm{KL}\big(\mathcal{IW}(\Sigma; V_0, n_0)\,||\,\mathcal{IW}(\Sigma; \Sigma_0, \nu_0)\big) + \mathbb{E}_{\mathcal{IW}(\Sigma; V_0, n_0)}\big[\mathrm{KL}\big(\mathcal{N}(\mu; m_0, l_0^{-1}\Sigma)\,||\,\mathcal{N}(\mu; \mu_0, \lambda_0^{-1}\Sigma)\big)\big]. \tag{36}$$

By some algebra, (36) becomes identical to the following, up to a constant, after removing those terms that do not depend on $m_0, V_0$ (see Appendix B.2 for derivations):

$$\frac{1}{2}\Big( n_0 \mathrm{Tr}(\Sigma_0 V_0^{-1}) + \nu_0 \log|V_0| + \lambda_0 n_0 (\mu_0 - m_0)^\top V_0^{-1} (\mu_0 - m_0) \Big). \tag{37}$$

Next, the second term of (6) also admits a closed form as follows (see Appendix B.2 for details):

$$-\mathbb{E}_{q(\phi) q_i(\theta_i)}[\log p(\theta_i|\phi)] = \frac{n_0}{2}\Big( p\, m_i^\top V_0^{-1} m_i - p\, m_0^\top V_0^{-1} m_i - p\, m_i^\top V_0^{-1} m_0 + m_0^\top V_0^{-1} m_0 + \frac{1}{n_0}\log|V_0| + \epsilon^2 \mathrm{Tr}(V_0^{-1}) \Big) + \mathrm{const}. \tag{38}$$

That is, the server's loss function $L_0$ is the sum of (37) and (38). We can take the gradients of the loss with respect to $m_0, V_0$ as follows (also plugging in $\mu_0 = 0$, $\Sigma_0 = I$, $\lambda_0 = 1$, $\nu_0 = d + 2$):

$$\frac{\partial L_0}{\partial m_0} = n_0 V_0^{-1}\Big( (N + 1) m_0 - p \sum_{i=1}^N m_i \Big), \tag{39}$$

$$\frac{\partial L_0}{\partial V_0^{-1}} = \frac{1}{2}\Big( n_0 (1 + N\epsilon^2) I - (N + d + 2) V_0 + n_0 m_0 m_0^\top + n_0 \sum_{i=1}^N \rho(m_0, m_i, p) \Big), \tag{40}$$

where $\rho(m_0, m_i, p) = p\, m_i m_i^\top - p\, m_0 m_i^\top - p\, m_i m_0^\top + m_0 m_0^\top$. We set the gradients to zero and solve for them, which yields the optimal solution:

$$m_0^* = \frac{p}{N + 1} \sum_{i=1}^N m_i, \qquad V_0^* = \frac{n_0}{N + d + 2}\Big( (1 + N\epsilon^2) I + m_0^* (m_0^*)^\top + \sum_{i=1}^N \rho(m_0^*, m_i, p) \Big). \tag{41}$$

Note that the $m_i$'s are fixed from the clients' latest variational parameters. It is interesting to see that $m_0^*$ in (41) generalises the well-known aggregation step of averaging local models in Fed-Avg (McMahan et al., 2017) and related methods: when $p = 1$ (i.e., no dropout), it almost equals client model averaging. Also, since $\rho(m_0^*, m_i, p = 1) = (m_i - m_0^*)(m_i - m_0^*)^\top$ when $p = 1$, we can see that $V_0^*$ in (41) essentially estimates the sample scatter matrix with $(N + 1)$ samples, namely the clients' $m_i$'s and the server's prior $\mu_0 = 0$, measuring how much they deviate from the center $m_0^*$.
It is known that dropout can help regularise the model and lead to better generalisation (Gal & Ghahramani, 2016), and with $p < 1$ our (41) forms a principled optimal solution.

Global prediction. In the inner integral of (8) of the general predictive distribution, we plug in $p(\theta|\phi) = \mathcal{N}(\theta; \mu, \Sigma)$ and the NIW $q(\phi)$ of (31). This leads to the multivariate Student-t distribution:

$$\int p(\theta|\phi)\, q(\phi)\, d\phi = \int \mathcal{N}(\theta; \mu, \Sigma) \cdot \mathcal{NIW}(\phi)\, d\phi = t_{n_0 - d + 1}\Big(\theta;\, m_0,\, \frac{(l_0 + 1) V_0}{l_0 (n_0 - d + 1)}\Big), \tag{42}$$

where $t_\nu(a, B)$ is the multivariate Student-t with location $a$, scale matrix $B$, and d.o.f. $\nu$. Then the predictive distribution for a new test input $x^*$ can be estimated as:

$$p(y^*|x^*, D_1, \dots, D_N) = \int p(y^*|x^*, \theta) \cdot t_{n_0 - d + 1}\Big(\theta;\, m_0,\, \frac{(l_0 + 1) V_0}{l_0 (n_0 - d + 1)}\Big)\, d\theta \tag{43}$$

$$\approx \frac{1}{S} \sum_{s=1}^S p(y^*|x^*, \theta^{(s)}), \quad \text{where } \theta^{(s)} \sim t_{n_0 - d + 1}\Big(\theta;\, m_0,\, \frac{(l_0 + 1) V_0}{l_0 (n_0 - d + 1)}\Big). \tag{44}$$

Personalisation. With the given personalisation training data $D^p$, we follow the general framework in (11) to find $v(\theta) \approx p(\theta|D^p, \phi^*)$ in a variational way, where $\phi^*$ is obtained from (34). For the density family for $v(\theta)$ we adopt the same spiky mixture form as (32),

$$v(\theta) = \prod_l \Big( p \cdot \mathcal{N}(\theta[l]; m[l], \epsilon^2 I) + (1 - p) \cdot \mathcal{N}(\theta[l]; 0, \epsilon^2 I) \Big), \tag{45}$$

where $m$ is the variational parameter. This leads to the MC-dropout-like learning objective,

$$\min_m \; -\log p(D^p|\tilde{m}) + \frac{p}{2}(n_0 + d + 1)(m - m_0)^\top V_0^{-1}(m - m_0). \tag{46}$$

Once $v$ is trained, our predictive distribution follows the MC sampling (12).
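As an illustration of the closed-form server step (41) under the diagonal restriction on $V_0$, the sketch below computes $m_0^*$ and the diagonal of $V_0^*$ from a list of client means. It is a hedged toy version, not the authors' code: the function name `server_update`, the plain-list representation, and the example values are assumptions.

```python
def server_update(client_means, p, n0, eps=1e-6):
    """Closed-form (m0*, diag V0*) following (41), with V0 kept diagonal.

    client_means: list of N client vectors m_i, each a list of d floats.
    Only the diagonal of rho(m0, m_i, p) is needed for a diagonal V0:
    rho_kk = p*m_ik^2 - 2*p*m0_k*m_ik + m0_k^2.
    """
    N = len(client_means)
    d = len(client_means[0])
    m0 = [p * sum(m[k] for m in client_means) / (N + 1) for k in range(d)]
    v0 = []
    for k in range(d):
        rho_kk = sum(p * m[k] ** 2 - 2.0 * p * m0[k] * m[k] + m0[k] ** 2
                     for m in client_means)
        v0.append(n0 / (N + d + 2) * ((1.0 + N * eps ** 2) + m0[k] ** 2 + rho_kk))
    return m0, v0

# Two clients with one parameter each, no dropout (p = 1).
m0_star, v0_star = server_update([[1.0], [3.0]], p=1.0, n0=10.0, eps=0.0)
```

With $p = 1$ the returned mean is the client sum divided by $N + 1$ rather than $N$, reflecting the extra pseudo-sample from the zero-mean prior.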

B.2 MATHEMATICAL DETAILS

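The server optimisation (6) in our NIW model involves two terms, both of which we will show admit closed-form expressions thanks to the conjugacy. Furthermore, we show that the optimal solution $(m_0, V_0)$ of (6) has an analytic form. First, the KL term in (6) is decomposed as:

$$\mathrm{KL}(q(\phi)\,||\,p(\phi)) = \mathrm{KL}\big(q(\mu|\Sigma) q(\Sigma)\,||\,p(\mu|\Sigma) p(\Sigma)\big) \tag{47}$$

$$= \mathbb{E}_{q(\Sigma)}\Big[\log \frac{q(\Sigma)}{p(\Sigma)}\Big] + \mathbb{E}_{q(\Sigma)}\Big[\mathbb{E}_{q(\mu|\Sigma)}\Big[\log \frac{q(\mu|\Sigma)}{p(\mu|\Sigma)}\Big]\Big] \tag{48}$$

$$= \underbrace{\mathrm{KL}\big(\mathcal{IW}(\Sigma; V_0, n_0)\,||\,\mathcal{IW}(\Sigma; \Sigma_0, \nu_0)\big)}_{=:\,\mathrm{kl}_a} + \underbrace{\mathbb{E}_{\mathcal{IW}(\Sigma; V_0, n_0)}\big[\mathrm{KL}\big(\mathcal{N}(\mu; m_0, l_0^{-1}\Sigma)\,||\,\mathcal{N}(\mu; \mu_0, \lambda_0^{-1}\Sigma)\big)\big]}_{=:\,\mathrm{kl}_b}. \tag{49}$$

First we work on $\mathrm{kl}_a = \mathbb{E}_{\mathcal{IW}(\Sigma; V_0, n_0)}[\log \mathcal{IW}(\Sigma; V_0, n_0)] - \mathbb{E}_{\mathcal{IW}(\Sigma; V_0, n_0)}[\log \mathcal{IW}(\Sigma; \Sigma_0, \nu_0)]$. From the definition of the Inverse-Wishart (assuming $\Sigma$ is $d \times d$),

$$\log \mathcal{IW}(\Sigma; \Psi, \nu) = \frac{\nu}{2}\log|\Psi| - \frac{\nu + d + 1}{2}\log|\Sigma| - \frac{1}{2}\mathrm{Tr}(\Psi \Sigma^{-1}) - \log \Gamma_d(\nu/2) - \frac{\nu d}{2}\log 2, \tag{50}$$

where $\Gamma_d(\cdot)$ is the multivariate Gamma function. We use the following facts from (Bishop, 2006; Braun & McAuliffe, 2008):

$$\mathbb{E}_{\mathcal{IW}(\Sigma; \Psi, \nu)}\big[\log|\Sigma|\big] = -d \log 2 + \log|\Psi| - \sum_{i=1}^d \psi\big((\nu - i + 1)/2\big), \tag{51}$$

$$\mathbb{E}_{\mathcal{IW}(\Sigma; \Psi, \nu)}\big[\Sigma^{-1}\big] = \nu \Psi^{-1}, \tag{52}$$

where $\psi(\cdot)$ is the digamma function. Applying these to the terms in $\mathrm{kl}_a$ yields:

$$\mathrm{kl}_a = \frac{1}{2}\Big( n_0 \mathrm{Tr}(\Sigma_0 V_0^{-1}) + \nu_0 \log|V_0| \Big) + \mathrm{const} \ (\text{w.r.t. } m_0, V_0). \tag{53}$$

Next, using the closed-form expression for the KL between Gaussians, $\mathrm{kl}_b$ becomes:

$$\mathrm{kl}_b = \frac{1}{2}\mathbb{E}_{\mathcal{IW}(\Sigma; V_0, n_0)}\big[\lambda_0 (\mu_0 - m_0)^\top \Sigma^{-1} (\mu_0 - m_0)\big] + \mathrm{const} \ (\text{w.r.t. } m_0, V_0) \tag{54}$$

$$= \frac{\lambda_0 n_0}{2}(\mu_0 - m_0)^\top V_0^{-1} (\mu_0 - m_0) + \mathrm{const} \ (\text{w.r.t. } m_0, V_0), \tag{55}$$

where in (55) we use the fact (52). Combining (53) and (55), we have:

$$\mathrm{KL}(q(\phi)\,||\,p(\phi)) = \frac{1}{2}\Big( n_0 \mathrm{Tr}(\Sigma_0 V_0^{-1}) + \nu_0 \log|V_0| + \lambda_0 n_0 (\mu_0 - m_0)^\top V_0^{-1} (\mu_0 - m_0) \Big) + \mathrm{const}. \tag{56}$$

Next, we derive the second term of (6), where $\overset{\mathrm{utc}}{=}$ stands for equality up to a constant (w.r.t. $m_0, V_0$):

$$\mathbb{E}_{q(\phi) q_i(\theta_i)}[\log p(\theta_i|\phi)] = -\frac{1}{2}\mathbb{E}\big[\log|\Sigma| + (\theta_i - \mu)^\top \Sigma^{-1} (\theta_i - \mu)\big] + \mathrm{const} \ (\text{w.r.t. } m_0, V_0) \tag{57}$$

$$\overset{\mathrm{utc}}{=} -\frac{1}{2}\mathbb{E}_{\mathcal{IW}(\Sigma; V_0, n_0)}\big[\log|\Sigma|\big] - \frac{1}{2}\mathbb{E}\big[(\theta_i - \mu)^\top \Sigma^{-1} (\theta_i - \mu)\big] \tag{58}$$

$$\overset{\mathrm{utc}}{=} -\frac{1}{2}\log|V_0| - \frac{1}{2}\mathrm{Tr}\Big( \mathbb{E}_{\mathcal{IW}(\Sigma; V_0, n_0)}\Big[ \Sigma^{-1}\, \mathbb{E}_{\mathcal{N}(\mu; m_0, l_0^{-1}\Sigma)\, q_i(\theta_i)}\big[(\theta_i - \mu)(\theta_i - \mu)^\top\big] \Big] \Big) \tag{59}$$

$$\overset{\mathrm{utc}}{=} -\frac{1}{2}\log|V_0| - \frac{1}{2}\mathrm{Tr}\Big( \mathbb{E}_{\mathcal{IW}(\Sigma; V_0, n_0)}\Big[ \Sigma^{-1} \big( \rho(m_0, m_i, p) + \epsilon^2 I + l_0^{-1}\Sigma \big) \Big] \Big) \tag{60}$$

$$\overset{\mathrm{utc}}{=} -\frac{1}{2}\log|V_0| - \frac{1}{2}\mathrm{Tr}\Big( \big( \rho(m_0, m_i, p) + \epsilon^2 I \big)\, n_0 V_0^{-1} \Big) \tag{61}$$

$$\overset{\mathrm{utc}}{=} -\frac{n_0}{2}\Big( p\, m_i^\top V_0^{-1} m_i - p\, m_0^\top V_0^{-1} m_i - p\, m_i^\top V_0^{-1} m_0 + m_0^\top V_0^{-1} m_0 + \frac{\log|V_0|}{n_0} + \epsilon^2 \mathrm{Tr}\, V_0^{-1} \Big), \tag{62}$$

where $\rho(m_0, m_i, p) = p\, m_i m_i^\top - p\, m_0 m_i^\top - p\, m_i m_0^\top + m_0 m_0^\top$, and we use the definition of $q_i(\theta_i)$ in (32) and the fact (52).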
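The fact (52), $\mathbb{E}[\Sigma^{-1}] = \nu \Psi^{-1}$, carries most of the algebra above, so a quick numerical sanity check may be helpful. The sketch below is restricted to the scalar case $d = 1$ (an assumption made to stay within the Python standard library), where $1/\Sigma$ follows a Gamma distribution with shape $\nu/2$ and scale $2/\Psi$; the function name and the chosen values of $\Psi$ and $\nu$ are illustrative.

```python
import random

def mc_inverse_mean(psi, nu, num_samples=100_000, seed=0):
    """Monte Carlo estimate of E[1/Sigma] for Sigma ~ IW(psi, nu) with d = 1.

    In one dimension, 1/Sigma is Gamma(shape = nu/2, scale = 2/psi),
    so we sample 1/Sigma directly rather than Sigma itself.
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        total += rng.gammavariate(nu / 2.0, 2.0 / psi)
    return total / num_samples

estimate = mc_inverse_mean(psi=2.0, nu=5.0)
exact = 5.0 / 2.0  # nu * psi^{-1} from (52)
```

The estimate should agree with $\nu/\Psi$ up to Monte Carlo noise, which is the scalar analogue of replacing $\mathbb{E}[\Sigma^{-1}]$ by $n_0 V_0^{-1}$ in (55) and (61).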

C MIXTURE MODEL (DETAILED VERSION)

Previously, the NIW model expressed our prior belief that each client $i$ acquires its own network parameters $\theta_i$ a priori as a Gaussian-perturbed version of the shared parameters $\mu$, namely $\theta_i|\phi \sim \mathcal{N}(\mu, \Sigma)$, as in (14). This is intuitively appealing, but may not be adequate for capturing more drastic diversity in local data across clients. In situations where clients' local data distributions, as well as their domains and class label semantics, are highly heterogeneous (possibly even set up for adversarial purposes), it would be ideal to consider multiple different prototypes for the network parameters, diverse enough to cover the heterogeneity in data distributions across clients. Motivated by this idea, we introduce a mixture prior model as follows.

First we consider that there are $K$ network parameters (prototypes) that can broadly cover the clients' data distributions. They are denoted as high-level latent variables, $\phi = \{\mu_1, \dots, \mu_K\}$, and we let them be distributed independently as standard normal a priori,

$$p(\phi) = \prod_{j=1}^K \mathcal{N}(\mu_j; 0, I). \tag{63}$$

We note here a clear distinction from the NIW prior. Whereas the NIW prior (13) only controls the mean $\mu$ and covariance $\Sigma$ of the Gaussian from which the local models $\theta_i$ are sampled, the mixture prior (63) is far more flexible in covering highly heterogeneous distributions. Each local model is then assumed to be chosen from one of these $K$ prototypes. Thus the prior distribution for $\theta_i$ can be modeled as a mixture,

$$p(\theta_i|\phi) = \sum_{j=1}^K \frac{1}{K}\, \mathcal{N}(\theta_i; \mu_j, \sigma^2 I), \tag{64}$$

where $\sigma$ is the hyperparameter that captures the perturbation scale, and can be chosen by users or learned from data. Note that we put equal mixing proportions $1/K$ due to the symmetry, a priori. That is, each client can take any of the $\mu_j$'s equally likely a priori. We then describe our choice of the variational density $q(\phi) \prod_i q_i(\theta_i)$ to approximate the posterior $p(\phi, \theta_{1:N}|D_{1:N})$.
First, $q_i(\theta_i)$ is chosen as a Gaussian,

$$q_i(\theta_i) = \mathcal{N}(\theta_i; m_i, \epsilon^2 I), \tag{65}$$

with small $\epsilon$. For $q(\phi)$ we consider a Gaussian factorised over the $\mu_j$'s, but with small variances, that is,

$$q(\phi) = \prod_{j=1}^K \mathcal{N}(\mu_j; r_j, \epsilon^2 I), \tag{66}$$

where $\{r_j\}_{j=1}^K$ are variational parameters ($L_0$) and $\epsilon$ is small (e.g., $10^{-4}$). The main reason why we make $q(\phi)$ spiky is that the resulting near-deterministic $q(\phi)$ allows for computationally efficient and accurate MC sampling during ELBO optimisation as well as test-time (global) prediction, avoiding difficult marginalisation (see Sec. C.1 for details). Although Bayesian inference in general encourages retaining as many plausible latent states as possible under the given evidence (observed data), we aim to model this uncertainty by having many (possibly redundant) prototypes $\mu_j$ rather than imposing a larger variance on a single one (e.g., a finite-sample approximation of a smooth distribution).

C.1 DETAILED DERIVATIONS FOR MIXTURE MODEL

With the full specification of the prior distribution and the variational density family, we are ready to derive the client objective function (5) and the server objective (6).

Client update. Since $q(\phi)$ is spiky, we can accurately approximate the second term of (5) as $\mathrm{KL}(q_i(\theta_i)\,||\,p(\theta_i|\phi^*))$ where $\phi^* = \{\mu_j^* = r_j\}_{j=1}^K$ is the mode of $q(\phi)$, since

$$\mathbb{E}_{q(\phi) q_i(\theta_i)}[\log p(\theta_i|\phi)] \approx \mathbb{E}_{q_i(\theta_i)}[\log p(\theta_i|\phi^*)]. \tag{67}$$

Since $q_i(\theta_i)$ is also spiky, $\mathrm{KL}(q_i(\theta_i)\,||\,p(\theta_i|\phi^*))$, the KL divergence between a Gaussian and a Gaussian mixture, can be approximated accurately using the single mode sample $m_i \sim q_i(\theta_i)$, that is,

$$\mathrm{KL}(q_i(\theta_i)\,||\,p(\theta_i|\phi^*)) \approx \log q_i(m_i) - \log p(m_i|\phi^*) \tag{68}$$

$$= -\log \sum_{j=1}^K \mathcal{N}(m_i; r_j, \sigma^2 I) + \mathrm{const} = -\log \sum_{j=1}^K \exp\Big( -\frac{||m_i - r_j||^2}{2\sigma^2} \Big) + \mathrm{const}. \tag{69}$$

Note here that we use the fact that $m_i$ disappears in $\log q_i(m_i)$. Plugging it into (5) yields the following optimisation for client $i$:

$$\min_{m_i} \; \mathbb{E}_{q_i(\theta_i)}[-\log p(D_i|\theta_i)] - \log \sum_{j=1}^K \exp\Big( -\frac{||m_i - r_j||^2}{2\sigma^2} \Big). \tag{70}$$

It is interesting to see that (70) can be seen as a generalisation of Fed-Prox (Li et al., 2018), where the proximal regularisation term in Fed-Prox is extended to multiple global models $r_j$, penalising the local model ($m_i$) for straying away from these prototypes. And if we use a single prototype ($K = 1$), the optimisation (70) exactly reduces to the local update objective of Fed-Prox. Since log-sum-exp is approximately equal to max, the regularisation term in (70) effectively focuses on the global prototype $r_j$ closest to the current local model $m_i$, which is intuitively well aligned with our initial modeling motivation, namely that each local data distribution is explained by one of the global prototypes.
Lastly, we also note that in the SGD optimisation setting, where we can only access a minibatch $B \sim D_i$ during the optimisation of (70), we follow the conventional practice: replacing the first term (the negative log-likelihood) by the stochastic estimate $\mathbb{E}_{q_i(\theta_i)} \mathbb{E}_{(x,y)\sim B}[-\log p(y|x, \theta_i)]$ and multiplying the second (regularisation) term by $\frac{1}{|D_i|}$.

Server update. First, the KL term in (6) can be easily derived as:

$$\mathrm{KL}(q(\phi)\,||\,p(\phi)) = \frac{1}{2}\sum_{j=1}^K ||r_j||^2 + \mathrm{const}, \tag{71}$$

and the second term of (6) is approximated as follows:

$$\mathbb{E}_{q(\phi) q_i(\theta_i)}[\log p(\theta_i|\phi)] \approx \mathbb{E}_{q(\phi)}[\log p(m_i|\phi)] \approx \log \sum_{j=1}^K \frac{1}{K}\, \mathcal{N}(m_i; r_j, \sigma^2 I) \tag{72}$$

$$= \log \sum_{j=1}^K \exp\Big( -\frac{||m_i - r_j||^2}{2\sigma^2} \Big) + \mathrm{const}, \tag{73}$$

where the approximations in (72) become accurate due to the spiky $q_i(\theta_i)$ and $q(\phi)$, respectively. Combining the two terms leads to the optimisation problem for the server:

$$\min_{\{r_j\}_{j=1}^K} \; \frac{1}{2}\sum_{j=1}^K ||r_j||^2 - \sum_{i=1}^N \log \sum_{j=1}^K \exp\Big( -\frac{||m_i - r_j||^2}{2\sigma^2} \Big). \tag{74}$$

Although (74) for $K > 1$ can be solved by standard gradient descent, the objective function resembles the (regularised) Gaussian mixture log-likelihood, and we can apply the Expectation-Maximisation (EM) algorithm (Dempster et al., 1977) instead. Using Jensen's bound with the convexity of the negative log function, we have the following alternating steps:

• E-step: With the current $\{r_j\}_{j=1}^K$ fixed, compute the prototype assignment probabilities for each local model $m_i$:

$$c(j|i) = \frac{k_{ij}}{\sum_{j=1}^K k_{ij}}, \quad \text{where } k_{ij} = \exp\Big( -\frac{||m_i - r_j||^2}{2\sigma^2} \Big). \tag{75}$$

• M-step: With the current assignments $c(j|i)$ fixed, we solve:

$$\min_{\{r_j\}} \; \frac{1}{2}\sum_j ||r_j||^2 + \frac{1}{2\sigma^2}\sum_{i,j} c(j|i) \cdot ||m_i - r_j||^2, \tag{76}$$

which admits the closed-form solution:

$$r_j^* = \frac{\frac{1}{N}\sum_{i=1}^N c(j|i) \cdot m_i}{\frac{\sigma^2}{N} + \frac{1}{N}\sum_{i=1}^N c(j|i)}, \quad j = 1, \dots, K. \tag{77}$$
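The alternating E-step and M-step can be sketched as follows. This is a minimal illustration under assumed names (`e_step`, `m_step`, plain lists for vectors), not the authors' implementation; the max-shift inside the E-step is a standard numerical-stability device added here, and the M-step uses the algebraically equivalent form of (77) with numerator and denominator multiplied by $N$.

```python
import math

def sq_dist(a, b):
    """Squared Euclidean distance between two equal-length lists."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def e_step(client_means, prototypes, sigma):
    """Assignment probabilities c(j|i) as in (75), with a max-shift for stability."""
    resp = []
    for m in client_means:
        logits = [-sq_dist(m, r) / (2.0 * sigma ** 2) for r in prototypes]
        top = max(logits)
        weights = [math.exp(l - top) for l in logits]
        z = sum(weights)
        resp.append([w / z for w in weights])
    return resp

def m_step(client_means, resp, sigma, num_prototypes):
    """Closed-form prototypes r_j* = sum_i c(j|i) m_i / (sigma^2 + sum_i c(j|i))."""
    d = len(client_means[0])
    new_prototypes = []
    for j in range(num_prototypes):
        mass = sum(c[j] for c in resp)
        numer = [sum(c[j] * m[k] for c, m in zip(resp, client_means))
                 for k in range(d)]
        new_prototypes.append([x / (sigma ** 2 + mass) for x in numer])
    return new_prototypes

# Four one-dimensional client models forming two well-separated groups.
clients = [[0.0], [0.1], [5.0], [5.1]]
resp = e_step(clients, prototypes=[[0.0], [5.0]], sigma=0.5)
updated = m_step(clients, resp, sigma=0.5, num_prototypes=2)
```

The $\sigma^2$ term in the denominator comes from the standard-normal prior on each prototype and shrinks $r_j^*$ slightly towards zero relative to a plain weighted average.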
The server update equation (77) has the intuitive meaning that the new prototype $r_j$ becomes the weighted average of the local models $m_i$, where the weights $c(j|i)$ are determined by the proximity to $r_j$ (i.e., those $m_i$'s that are closer to $r_j$ contribute more, and vice versa). This can be seen as an extension of the aggregation step in Fed-Avg to the multiple prototype case.

Global prediction. By plugging the mixture prior $p(\theta|\phi)$ of (64) and the factorised spiky Gaussian $q(\phi)$ of (66) into the inner integral of (8), $\int p(\theta|\phi)\, q(\phi)\, d\phi$, we obtain a predictive distribution that is approximately an equal-weight average over $\{r_j\}_{j=1}^K$, that is,

$$p(y^*|x^*, D_{1:N}) \approx \frac{1}{K}\sum_{j=1}^K p(y^*|x^*, r_j).$$

Unfortunately this is not ideal for our original intention, where only one specific model $r_j$ out of the $K$ candidates is dominantly responsible for the local data. To meet this intention, we extend our model so that the input point $x^*$ can affect $\theta$ together with $\phi$, and with this modification our predictive probability can be derived as:

$$p(y^*|x^*, D_{1:N}) = \iint p(y^*|x^*, \theta)\, p(\theta|x^*, \phi)\, p(\phi|D_{1:N})\, d\theta\, d\phi \tag{78}$$

$$\approx \iint p(y^*|x^*, \theta)\, p(\theta|x^*, \phi)\, q(\phi)\, d\theta\, d\phi \tag{79}$$

$$\approx \int p(y^*|x^*, \theta)\, p(\theta|x^*, \{r_j\}_{j=1}^K)\, d\theta. \tag{80}$$

To deal with the tricky part of inferring $p(\theta|x^*, \{r_j\}_{j=1}^K)$, we introduce a fairly practical strategy of fitting a gating function. The idea is to regard $p(\theta|x^*, \{r_j\}_{j=1}^K)$ as a mixture of experts (Jacobs et al., 1991; Jordan & Jacobs, 1994) with the prototypes $r_j$ serving as experts,

$$p(\theta|x^*, \{r_j\}_{j=1}^K) := \sum_{j=1}^K g_j(x^*) \cdot \delta(\theta - r_j), \tag{81}$$

where $\delta(\cdot)$ is the Dirac delta function, and $g(x)$ is a gating function that outputs a $K$-dimensional softmax vector. Intuitively, the gating function determines which of the $K$ prototypes $\{r_j\}_{j=1}^K$ the model $\theta$ for the test point $x^*$ belongs to. With (81), the predictive probability in (80) is written as:

$$p(y^*|x^*, D_{1:N}) \approx \sum_{j=1}^K g_j(x^*) \cdot p(y^*|x^*, r_j). \tag{82}$$
However, since we do not have this oracle $g(x)$, we introduce a neural network and fit it to the local training data during the training stage. Let $g(x; \beta)$ be the gating network with parameters $\beta$. To train it, we follow the Fed-Avg strategy. In the client update stage at each round, while we update the local model $m_i$ with a minibatch $B \sim D_i$, we also find the prototype closest to $m_i$, namely $j^* := \arg\min_j ||m_i - r_j||$. Then we form another minibatch of samples $\{(x, j^*)\}_{x \sim B}$ (input $x$ and class label $j^*$), and update $g(x; \beta)$ by SGD. The updated (local) $\beta$'s from the clients are then aggregated (by simple averaging) by the server, and distributed back to the clients as an initial iterate for the next round.

Personalisation. For $p(\theta|D^p, \phi^*)$ in the general framework (10), we define the variational distribution $v(\theta) \approx p(\theta|D^p, \phi^*)$ as:

$$v(\theta) = \mathcal{N}(\theta; m, \epsilon^2 I), \tag{83}$$

where $\epsilon$ is a small positive constant, and $m$ is the only parameter that we learn. Our personalisation training amounts to ELBO optimisation for $v(\theta)$ as in (11), which reduces to:

$$\min_m \; \mathbb{E}_{v(\theta)}[-\log p(D^p|\theta)] - \log \sum_{j=1}^K \exp\Big( -\frac{||m - r_j||^2}{2\sigma^2} \Big). \tag{84}$$

Once we have the optimal $m$ (i.e., $v(\theta)$), our predictive model becomes:

$$p(y^p|x^p, D^p, D_{1:N}) \approx p(y^p|x^p, m), \tag{85}$$

which is done by feed-forwarding the test input $x^p$ through the network deployed with the parameters $m$.
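The multi-prototype regulariser shared by the client objective (70) and the personalisation objective (84), together with the closest-prototype label $j^*$ used to train the gating network, can be sketched as below. The helper names and list-based vectors are assumptions for illustration; the log-sum-exp is evaluated with the usual max-shift, which does not change its value.

```python
import math

def prototype_penalty(m, prototypes, sigma):
    """Regulariser -log sum_j exp(-||m - r_j||^2 / (2 sigma^2)) from (70)/(84)."""
    logits = [-sum((a - b) ** 2 for a, b in zip(m, r)) / (2.0 * sigma ** 2)
              for r in prototypes]
    top = max(logits)
    return -(top + math.log(sum(math.exp(l - top) for l in logits)))

def closest_prototype(m, prototypes):
    """Index j* = argmin_j ||m - r_j||, used as the gating-network class label."""
    dists = [sum((a - b) ** 2 for a, b in zip(m, r)) for r in prototypes]
    return dists.index(min(dists))

# With a single prototype the penalty is the quadratic ||m - r||^2 / (2 sigma^2).
single = prototype_penalty([1.0, 2.0], [[0.0, 0.0]], sigma=1.0)
label = closest_prototype([4.0], [[0.0], [5.0]])
```

With several prototypes the penalty is dominated by the nearest one, which matches the log-sum-exp-as-soft-max reading given in the text.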

D CONVERGENCE ANALYSIS

Our (general) FL algorithm is a special block-coordinate SGD optimisation of the ELBO function (3) with respect to the $(N + 1)$ parameter groups: $L_0$ (of $q(\phi; L_0)$), $L_1$ (of $q_1(\theta_1; L_1)$), ..., and $L_N$ (of $q_N(\theta_N; L_N)$). In this section we provide a theorem that guarantees convergence of the algorithm to a local minimum of the ELBO objective function under some mild assumptions. We also analyse the convergence rate. Note that although our FL algorithm is a special case of general block-coordinate SGD optimisation, we cannot directly apply the existing convergence results for regular block-coordinate SGD methods, since they mostly rely on non-overlapping blocks with cyclic or uniform random block selection strategies (Beck & Tetruashvili, 2013; Wang & Banerjee, 2014). As the block selection strategy in our FL algorithm is unique, with overlapping blocks and non-uniform random block selection, we provide our own analysis here. Promisingly, we show that, in accordance with regular block-coordinate SGD (cyclic/uniform non-overlapping block selection), our FL algorithm has an $O(1/\sqrt{t})$ convergence rate, which is also asymptotically the same as that of conventional (holistic, non-block-coordinate) SGD optimisation. Note that this section is about the convergence of our algorithm to a (local) optimum of the training objective (ELBO). The question of how well this optimal model trained on empirical data performs on unseen data points will be discussed in Sec. E.

First we formally describe our FL algorithm as a block-coordinate SGD optimisation. For ease of exposition, we simplify the notation: the objective function in (3) is denoted as $f(x)$, where $x = [x_0, x_1, \dots, x_N]$ are the optimisation variables corresponding to $x_0 := L_0$, $x_1 := L_1$, ..., $x_N := L_N$. That is, $x_0$ is the server's parameters while $x_i$ ($i = 1, \dots, N$) is worker $i$'s parameters.
We let $x_u$ be the partial vector of $x$ selected by the index set $u \subseteq \{0, 1, \dots, N\}$, and $x_{-u}$ be the vector of the remaining elements. Similarly, $\nabla_u f(x)$ indicates the gradient vector with only the elements at the indices in $u$. Let $x^t$ be the iterate at iteration $t$. Our FL algorithm is formally defined in Alg. 1.

Algorithm 1 Bayesian FL Algorithm as Block-Coordinate Descent.
We define the following hyperparameters:
• $N_f$ ($\le N$) = the number of participating clients at each round.
• $M$ = the number of SGD iterations per round for updating the (participating) clients.
• $M_S$ = the number of SGD iterations per round for updating the server.
• $\eta_t$ = the reciprocal of the SGD learning rate at iteration $t$.
Initialise the global iteration counter $t = 0$.
for each round do
  • Select $N_f$ clients uniformly at random from $\{1, \dots, N\}$ without replacement. Let $u_t \subseteq \{1, \dots, N\}$ be the set of the participants ($|u_t| = N_f$).
  • (Client update) For each of $M$ iterations,
    1. Perform an SGD update for the block $u_t$. That is, $x^{t+1} := [x^{t+1}_{u_t}; x^t_{-u_t}]$, where
       $$x^{t+1}_{u_t} = x^t_{u_t} - \frac{1}{\eta_t}\nabla_{u_t} f_t(x^t), \tag{86}$$
       where $f_t$ is the minibatch version of $f$ defined on the minibatch data $\mathrm{batch}_t$ (so that $\mathbb{E}_{\mathrm{batch}_t}[f_t(x)] = f(x)$). Note that this update is actually done independently over the participating clients $i \in u_t$ due to the separable objective, i.e., from (4) to (5).
    2. $t \leftarrow t + 1$.
  • (Server update) For each of $M_S$ iterations,
    1. Perform an SGD update for the index (singleton block) 0. That is, $x^{t+1} := [x^{t+1}_0; x^t_{-0}]$, where
       $$x^{t+1}_0 = x^t_0 - \frac{1}{\eta_t}\nabla_0 f_t(x^t). \tag{87}$$
    2. $t \leftarrow t + 1$.
end for

Now we state our convergence theorem. We first need the following mild assumptions:

Assumption 1. Our objective function $f(x)$ is locally convex, and $f_t(x)$ is also locally convex for all iterations $t$, where $f_t$ is the minibatch version of $f$ defined on the minibatch data $\mathrm{batch}_t$ (so that $\mathbb{E}_{\mathrm{batch}_t}[f_t(x)] = f(x)$). Actually, the latter implies the former. Although the negative ELBO is in general non-convex globally, we can regard it as a convex function within a local neighborhood, as is usually assumed in non-convex analysis (Bertsekas, 2016) and other FL analyses (Li et al., 2020).

Assumption 2. For all $t$, $f_t(x)$ has Lipschitz continuous gradient with constant $L$.

[...] $x^t$ by following our FL algorithm. With Assumptions 1-3, the following holds for any $T$:

$$\mathbb{E}[f(\bar{x}^T)] - f(x^*) \le \frac{N + N_f}{N_f} \cdot \frac{\frac{\sqrt{T} + L}{2} D^2 + R_f^2 \sqrt{T}}{T} = O\Big(\frac{1}{\sqrt{T}}\Big), \tag{88}$$

where $x^*$ is the (local) optimum, and the expectation is taken over the randomness in minibatches and block selections $\{u_t\}_{t=0}^{T-1}$.

Remark. Theorem D.1 states that $x^t$ converges to the optimal point $x^*$ in expectation at the rate of $O(1/\sqrt{t})$. This convergence rate asymptotically equals that of the conventional (holistic, non-block-coordinate) SGD algorithm.

To prove the theorem, we note that Alg. 1 overall repeats the following three steps per round: i) sample the subset of clients $u_t$ from $\{1, \dots, N\}$ with $|u_t| = N_f$, ii) update $x_{u_t}$ for $M$ iterations, and iii) update $x_0$ for $M_S$ iterations. Thus, in the long-term view, we can see that the algorithm proceeds as follows: at each iteration $t$, we select $u_t$ as

$$u_t = \begin{cases} \{0\} & \text{with prob. } \frac{M_S}{M + M_S}, \\ \text{size-}N_f \text{ subset drawn uniformly from } \{1, \dots, N\} & \text{with prob. } \frac{M}{M + M_S}, \end{cases} \tag{89}$$

and update the iterate as $x^{t+1} := [x^{t+1}_{u_t}; x^t_{-u_t}]$, where

$$x^{t+1}_{u_t} = x^t_{u_t} - \frac{1}{\eta_t}\nabla_{u_t} f_t(x^t). \tag{90}$$

We will use this long-term view strategy in our proof. Next we state the following lemma, motivated by (Wang & Banerjee, 2014), which is useful in our proof.

Lemma D.2. Assume $\eta_t > L$, and $f_t$ is (locally) convex with Lipschitz continuous gradient with constant $L$. For any subset $u \subseteq \{0, 1, \dots, N\}$, we define $x^{t+1}$ as:

$$x^{t+1}_u = x^t_u - \frac{1}{\eta_t}\nabla_u f_t(x^t) \quad \text{and} \quad x^{t+1}_{-u} = x^t_{-u}. \tag{91}$$

[...] establish the following identity:

$$\sum_{u \in 2^{N, N_f}} \langle \nabla_u f_t(x^t), x^t_u - x_u \rangle = \binom{N-1}{N_f-1} \sum_{i=1}^N \langle \nabla_i f_t(x^t), x^t_i - x_i \rangle. \tag{107}$$

We plug (107) into (105) and apply Lemma D.2:

$$\langle \nabla f_t(x^t), x^t - x \rangle = \langle \nabla_0 f_t(x^t), x^t_0 - x_0 \rangle + \frac{1}{\binom{N-1}{N_f-1}} \sum_{u \in 2^{N, N_f}} \langle \nabla_u f_t(x^t), x^t_u - x_u \rangle \tag{108}$$

$$\le \frac{\eta_t}{2}\Big( ||x - x^t||^2 - ||x - x^{t+1}(0, \mathrm{batch}_t)||^2 \Big) + \frac{R_f^2}{2(\eta_t - L)} + \frac{1}{\binom{N-1}{N_f-1}} \sum_{u \in 2^{N, N_f}} \bigg[ \frac{\eta_t}{2}\Big( ||x - x^t||^2 - ||x - x^{t+1}(u, \mathrm{batch}_t)||^2 \Big) + \frac{R_f^2}{2(\eta_t - L)} \bigg], \tag{109}$$

where we define $x^{t+1}(u, \mathrm{batch}_t)$ for any subset $u \subseteq \{0, 1, \dots, N\}$ as: $x^{t+1}_u := x^t_u - (1/\eta_t)\nabla_u f_t(x^t)$ and $x^{t+1}_{-u} := x^t_{-u}$. Although we could simply write $x^{t+1}$, here we use this explicit notation to emphasise the dependency of $x^{t+1}$ on $u$ and $\mathrm{batch}_t$. By letting $g(x, x^t, x^{t+1}) := \frac{\eta_t}{2}\big( ||x - x^t||^2 - ||x - x^{t+1}||^2 \big) + \frac{R_f^2}{2(\eta_t - L)}$, we can express (109) succinctly to yield:

$$\langle \nabla f_t(x^t), x^t - x \rangle \le g(x, x^t, x^{t+1}(0, \mathrm{batch}_t)) + \frac{1}{\binom{N-1}{N_f-1}} \sum_{u \in 2^{N, N_f}} g(x, x^t, x^{t+1}(u, \mathrm{batch}_t)). \tag{110}$$

For the second term, we use the uniform expectation to replace the sum (i.e., $\sum_{u \in 2^{N, N_f}} \psi(u) = \binom{N}{N_f} \mathbb{E}_{u \sim 2^{N, N_f}}[\psi(u)]$ for any function $\psi$). Using $\binom{N}{N_f} / \binom{N-1}{N_f-1} = N/N_f$,

$$\langle \nabla f_t(x^t), x^t - x \rangle \le g(x, x^t, x^{t+1}(0, \mathrm{batch}_t)) + \frac{N}{N_f} \mathbb{E}_{u \sim 2^{N, N_f}}\big[ g(x, x^t, x^{t+1}(u, \mathrm{batch}_t)) \big], \tag{111}$$

and the right hand side of (111) can be written as:

$$\frac{M + M_S}{M_S} \bigg( \frac{M_S}{M + M_S}\, g(x, x^t, x^{t+1}(0, \mathrm{batch}_t)) + \frac{M}{M + M_S}\, \mathbb{E}_{u \sim 2^{N, N_f}}\big[ g(x, x^t, x^{t+1}(u, \mathrm{batch}_t)) \big] \bigg), \tag{112}$$

where we use our specification $M_S = M \cdot \frac{N_f}{N}$. Note that the expression inside the parentheses is exactly the expectation of $g(x, x^t, x^{t+1}(u_t, \mathrm{batch}_t))$ over the random index set $u_t$ following our client-server selection strategy in the long-term view, that is, (89). Then we have the following result:

$$\langle \nabla f_t(x^t), x^t - x \rangle \le \frac{M + M_S}{M_S}\, \mathbb{E}_{u_t}\big[ g(x, x^t, x^{t+1}(u_t, \mathrm{batch}_t)) \big], \tag{113}$$

where $u_t$ follows (89). As we have conditioned all quantities on $\mathrm{batch}_t$, we now take the expectation over $\mathrm{batch}_t$.
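$$\langle \nabla f(x^t), x^t - x \rangle = \mathbb{E}_{\mathrm{batch}_t}\big[ \langle \nabla f_t(x^t), x^t - x \rangle \big] \tag{114}$$

$$\le \frac{M + M_S}{M_S}\, \mathbb{E}_{\mathrm{batch}_t, u_t}\bigg[ \frac{\eta_t}{2}\Big( ||x - x^t||^2 - ||x - x^{t+1}||^2 \Big) + \frac{R_f^2}{2\sqrt{t}} \bigg], \tag{115}$$

where we drop the dependency in $x^{t+1}(u_t, \mathrm{batch}_t)$ from the notation, and use $\eta_t = L + \sqrt{t}$. Since $f$ is convex,

$$f(x^t) - f(x) \le \langle \nabla f(x^t), x^t - x \rangle, \tag{116}$$

and taking the expectation over $x^t$ leads to:

$$\mathbb{E}[f(x^t)] - f(x) \le \frac{M + M_S}{M_S} \bigg( \frac{\eta_t}{2}\Big( \mathbb{E}||x - x^t||^2 - \mathbb{E}||x - x^{t+1}||^2 \Big) + \frac{R_f^2}{2\sqrt{t}} \bigg). \tag{117}$$

By telescoping ($\frac{1}{T}\sum_{t=1}^T$) and using $\frac{M + M_S}{M_S} = \frac{N + N_f}{N_f}$, we have:

$$\mathbb{E}\bigg[ \frac{1}{T}\sum_{t=1}^T f(x^t) \bigg] - f(x) \le \frac{N + N_f}{N_f} \cdot \frac{1}{T} \bigg( \frac{1}{2}\sum_{t=1}^T \eta_t \Big( \mathbb{E}||x - x^t||^2 - \mathbb{E}||x - x^{t+1}||^2 \Big) + R_f^2 \sum_{t=1}^T \frac{1}{2\sqrt{t}} \bigg). \tag{118}$$

There are two sums on the right hand side of (118), and they can be bounded succinctly as follows. First, we use simple calculus to bound the second sum: $\sum_{t=1}^T \frac{1}{2\sqrt{t}} \le \int_1^T \frac{1}{2\sqrt{z}}\, dz + \frac{1}{2} \le \sqrt{T}$. Next, we let $a_t := \mathbb{E}||x - x^t||^2$, and the first sum is written as: $\eta_1(a_1 - a_2) + \cdots + \eta_T(a_T - a_{T+1}) = \sum_{t=1}^T (\eta_t - \eta_{t-1}) a_t - \eta_T a_{T+1}$ by letting $\eta_0 = 0$. Using $a_t \le D^2$ from Assumption 3, this sum is bounded above by $D^2 \sum_{t=1}^T (\eta_t - \eta_{t-1}) = D^2(\eta_T - \eta_0) = (\sqrt{T} + L) D^2$. Plugging these bounds into (118) and applying Jensen's inequality to the left hand side (i.e., $\mathbb{E}[(1/T)\sum_{t=1}^T f(x^t)] \ge \mathbb{E}[f((1/T)\sum_{t=1}^T x^t)] = \mathbb{E}[f(\bar{x}^T)]$) yields:

$$\mathbb{E}[f(\bar{x}^T)] - f(x) \le \frac{N + N_f}{N_f} \cdot \frac{\frac{\sqrt{T} + L}{2} D^2 + R_f^2 \sqrt{T}}{T}, \tag{119}$$

for any $x$, which completes the proof.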
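To see the block selection pattern of Alg. 1 in action, the following toy sketch runs the client and server updates on a separable quadratic, using the step size $1/\eta_t$ with $\eta_t = L + \sqrt{t}$ and $M_S = M \cdot N_f / N$ from the proof. Everything here is an assumption made for illustration: the quadratic objective, the use of exact rather than minibatch gradients, the function names, and starting the counter at $t = 1$ so that $\eta_t > L$.

```python
import random

def run_block_sgd(targets, n_participants, M, rounds, L=1.0, seed=0):
    """Alg. 1 style updates on f(x) = 0.5 * sum_k (x_k - targets[k])^2.

    Index 0 plays the server block; indices 1..N are the client blocks.
    Exact gradients replace minibatch gradients in this toy setting.
    """
    rng = random.Random(seed)
    N = len(targets) - 1
    M_S = max(1, round(M * n_participants / N))  # M_S = M * N_f / N
    x = [0.0] * (N + 1)
    t = 1  # start at 1 so that eta_t = L + sqrt(t) > L
    for _ in range(rounds):
        u = rng.sample(range(1, N + 1), n_participants)
        for _ in range(M):  # client updates on the sampled block u
            eta = L + t ** 0.5
            for i in u:
                x[i] -= (x[i] - targets[i]) / eta
            t += 1
        for _ in range(M_S):  # server updates on the singleton block {0}
            eta = L + t ** 0.5
            x[0] -= (x[0] - targets[0]) / eta
            t += 1
    return x

def objective(x, targets):
    return 0.5 * sum((a - b) ** 2 for a, b in zip(x, targets))

targets = [1.0, 2.0, -1.0, 0.5, 3.0]  # one server block and N = 4 clients
x_final = run_block_sgd(targets, n_participants=2, M=4, rounds=200)
```

On this strongly convex toy problem the objective decays much faster than the worst-case $O(1/\sqrt{T})$ guarantee, so the sketch only illustrates the update schedule, not the tightness of the bound.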

E GENERALISATION ERROR BOUND

In this section we discuss the generalisation performance of our proposed algorithm, answering the question of how well the Bayesian FL model trained on empirical data performs on unseen data points. We aim to provide an upper bound on the generalisation error averaged over the posterior distribution of the model parameters $(\phi, \{\theta_i\}_{i=1}^N)$, by linking it to the expected empirical error with some additional complexity terms.

To this end, we first consider the PAC-Bayes bounds (McAllester, 1999; Langford & Caruana, 2001; Seeger, 2002; Maurer, 2004), naturally because they have similar forms relating the two error terms (generalisation and empirical), expected over the posterior distribution, via the KL divergence between the posterior and the prior distributions. However, the original PAC-Bayes bounds contain the square root of the KL, which deviates from the ELBO objective function, which has the sum of the expected data loss and the KL term as it is (instead of its square root). There are some recent variants of PAC-Bayes bounds, specifically the PAC-Bayes-λ bound, which remove the square root of the KL and are better suited to the ELBO objective function (see (Thiemann et al., 2017) or Eq. (5) of (Rivasplata et al., 2019)).

To discuss this further, the objective function of our FL algorithm (3) can be viewed as a conventional variational inference ELBO objective with the prior $p(\beta)$ and the posterior $q(\beta)$, where $\beta = \{\phi, \theta_1, \dots, \theta_N\}$ indicates the set of all latent variables in our model. More specifically, the negative ELBO (a function of the variational posterior distribution $q$) can be written as:

$$-\mathrm{ELBO}(q) = \mathbb{E}_{q(\beta)}[\hat{l}_n(\beta)] + \frac{1}{n}\mathrm{KL}(q(\beta)\,||\,p(\beta)), \tag{120}$$

where $\hat{l}_n(\beta)$ is the empirical error/loss of the model $\beta$ on the training data of size $n$.
We then apply the PAC-Bayes-λ bound (Thiemann et al., 2017; Rivasplata et al., 2019); for any $\lambda \in (0, 2)$, the following holds with probability at least $1 - \delta$:

$$\mathbb{E}_{q(\beta)}[l(\beta)] \le \frac{1}{1 - \lambda/2}\, \mathbb{E}_{q(\beta)}[\hat{l}_n(\beta)] + \frac{1}{\lambda(1 - \lambda/2)} \cdot \frac{\mathrm{KL}(q(\beta)\,||\,p(\beta)) + \log(2\sqrt{n}/\delta)}{n}, \tag{121}$$

where $l(\beta)$ is the generalisation error/loss of the model $\beta$. Thus, when $\lambda = 1$, the right hand side of (121) reduces to $-2 \cdot \mathrm{ELBO}(q)$ plus a complexity term, justifying why maximising the ELBO with respect to $q$ can be helpful for reducing the generalisation error. Although this argument may look partially sufficient, strictly speaking the extra factor 2 on the ELBO (for the choice $\lambda = 1$) may be problematic, potentially making the bound trivial and less useful. Other choices of $\lambda$ fail to recover the original ELBO, giving slightly deviated coefficients for the expected loss and the KL. In what follows, we state our new generalisation error bound for our FL algorithm, which does not rely on PAC-Bayes but on the recent regression analysis technique for variational Bayes (Pati et al., 2018; Bai et al., 2020). It was also recently adopted in the analysis of a personalised FL algorithm (Zhang et al., 2022).
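The relationship between (120) and (121) at $\lambda = 1$ can be checked numerically. The sketch below evaluates the right hand side of the PAC-Bayes-λ bound and the negative ELBO for given scalar inputs; the function names and the example numbers (empirical loss, KL value, sample size) are made-up illustrations rather than quantities from any experiment.

```python
import math

def pac_bayes_lambda_bound(emp_loss, kl, n, lam=1.0, delta=0.05):
    """Right hand side of (121) for lam in (0, 2)."""
    if not 0.0 < lam < 2.0:
        raise ValueError("lam must lie in (0, 2)")
    shrink = 1.0 - lam / 2.0
    complexity = (kl + math.log(2.0 * math.sqrt(n) / delta)) / n
    return emp_loss / shrink + complexity / (lam * shrink)

def neg_elbo(emp_loss, kl, n):
    """Negative ELBO (120): expected empirical loss plus KL / n."""
    return emp_loss + kl / n

# Hypothetical values: expected empirical loss 0.1, KL = 50, n = 10000.
bound = pac_bayes_lambda_bound(emp_loss=0.1, kl=50.0, n=10000, lam=1.0)
doubled = 2.0 * neg_elbo(emp_loss=0.1, kl=50.0, n=10000)
extra = 2.0 * math.log(2.0 * math.sqrt(10000) / 0.05) / 10000
```

Here `bound` equals `doubled + extra`, i.e., twice the negative ELBO plus a term that depends only on $n$ and $\delta$, which is the factor-of-two issue raised in the text.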

E.1 GENERALISATION ERROR BOUND VIA REGRESSION ANALYSIS TECHNIQUE

We begin with the regression-based data modeling perspective and related assumptions/notations. We denote by $P_i(x, y)$ the true data distribution for client $i$ ($i = 1, \dots, N$). We assume that the target $y$ is real vector-valued ($y \in \mathbb{R}^{S_y}$), and there exists a true regression function $f_i : \mathbb{R}^{S_x} \to \mathbb{R}^{S_y}$ for each $i$. That is,

$$P_i(y|x) = \mathcal{N}(y; f_i(x), \sigma_\epsilon^2 I), \tag{122}$$

where $\sigma_\epsilon^2$ is the constant Gaussian output noise variance. Let $D_i = (X_i, Y_i) \sim P_i(x, y)$ be the i.i.d. training data of size $n$ for each client $i$. In our FL model, we assume that our backbone network is an MLP with $L$ hidden layers of width $M$, and all activation functions $\sigma(\cdot)$ are Lipschitz continuous with constant 1. The parameters $\theta$ of the MLP are also assumed to be bounded; more formally, the parameter space $\Theta$ is defined as:

$$\theta \in \Theta = \{\theta \in \mathbb{R}^G : ||\theta||_\infty \le B, \ \text{MLP with } L \text{ layers of width } M\}. \tag{123}$$

Note that $G = \dim(\theta)$ and $B$ is the maximal norm bound. The MLP defines a regression function $f_\theta : \mathbb{R}^{S_x} \to \mathbb{R}^{S_y}$, and the $\theta$-induced predictive distribution is denoted as:

$$P_\theta(y|x) = \mathcal{N}(y; f_\theta(x), \sigma_\epsilon^2 I), \tag{124}$$

where we assume that the true noise variance is known. For notational convenience, we denote by $f(X_i)$ the concatenated vector of $f(x)$ for all $x \in X_i$, i.e., $f(X_i) = [f(x)]_{x \in X_i}$, where $f(\cdot)$ is either the true $f_i(\cdot)$ or the model $f_\theta(\cdot)$. Extending this notation, simply writing $f_i$ or $f_\theta$ means the infinite-dimensional (population) responses, that is, $f_i = [f_i(x)]_{x \in \mathbb{R}^{S_x}}$ and $f_\theta$ similarly. For instance, $||f_\theta - f_i||_\infty$ stands for the worst-case difference, namely $\max_{x \in \mathbb{R}^{S_x}} ||f_\theta(x) - f_i(x)||$.

As a generalisation error measure, we consider the expected squared Hellinger distance between the true $P_i$ and the model $P_\theta$. Formally,

$$d^2(P_\theta, P_i) = \mathbb{E}_{x \sim P_i(x)}\big[ H^2(P_\theta(y|x), P_i(y|x)) \big] = \mathbb{E}_{x \sim P_i(x)}\bigg[ 1 - \exp\Big( -\frac{||f_\theta(x) - f_i(x)||_2^2}{8\sigma_\epsilon^2} \Big) \bigg]. \tag{125}$$

More specifically, we will bound the posterior-averaged distance $\frac{1}{N}\sum_{i=1}^N \mathbb{E}_{q_i^*(\theta_i)}[d^2(P_{\theta_i}, P_i)]$, where $\{q_i^*(\theta_i)\}_{i=1}^N$ is an optimal solution of our FL-ELBO optimisation problem (we showed in Sec. D that our block-coordinate FL algorithm converges to this optimal solution at an $O(1/\sqrt{t})$ rate).

Theorem E.1 (Generalisation error analysis). Assume that the variational density family for $q_i(\theta_i)$ is rich enough to subsume Gaussians. The optimal solution $(\{q_i^*(\theta_i)\}_{i=1}^N, q^*(\phi))$ of our FL-ELBO optimisation problem (3) satisfies (with high probability):

$$\frac{1}{N}\sum_{i=1}^N \mathbb{E}_{q_i^*(\theta_i)}[d^2(P_{\theta_i}, P_i)] \le O\Big(\frac{1}{n}\Big) + C \cdot \epsilon_n^2 + C'\Big( r_n + \frac{1}{N}\sum_{i=1}^N \lambda_i^* \Big), \tag{126}$$

where $C, C' > 0$ are constants, $\lambda_i^* = \min_{\theta \in \Theta} ||f_\theta - f_i||_\infty^2$ is the best error within our backbone $\Theta$, $r_n = \frac{G(L+1)}{n}\log M + \frac{G}{n}\log\big( S_x \frac{n}{G} \big)$, and $\epsilon_n = \sqrt{r_n}\, \log^\delta(n)$ for a constant $\delta > 1$.

Remark. Theorem E.1 implies that the optimal solution of our FL-ELBO optimisation problem (attainable by our block-coordinate FL algorithm) is asymptotically optimal, since the right hand side of (126) converges to 0 as the training data size $n \to \infty$. This is easy to verify: as $n \to \infty$, $r_n \to 0$, accordingly $\epsilon_n \to 0$, and the last term $\frac{1}{N}\sum_i \lambda_i^*$ can be made arbitrarily close to 0 by increasing the backbone capacity (MLPs as universal function approximators). Practically, however, for fixed $n$, enlarging the backbone capacity (i.e., large $G$, $L$, and $M$) also increases $\epsilon_n$ and $r_n$, so it is important to choose the backbone network architecture properly. Note also that our assumption on the variational density family for $q_i(\theta_i)$ is easily met; for instance, the families of mixtures of Gaussians adopted in the NIW (Sec. 3.1) and mixture models (Sec. 3.2) subsume the single Gaussian family.

Proof of Theorem E.1.
We first link the variational ELBO objective function to the Hellinger distance via the Donsker-Varadhan (DV) theorem (Boucheron et al., 2013), following (Bai et al., 2020; Zhang et al., 2022). The DV theorem allows us to express the log-expectation of any exponential function variationally using the KL divergence. More specifically, the following holds for any distribution $p$ and any (bounded) function $h(z)$:

$$\log E_{p(z)}[e^{h(z)}] = \max_q \big( E_{q(z)}[h(z)] - \mathrm{KL}(q\|p) \big). \tag{127}$$

Here we set

$$p(z) := p(\theta_i|\phi), \quad q(z) := q_i(\theta_i), \quad h(z) := \log \eta_i(\theta_i), \tag{128}$$

where

$$\eta_i(\theta_i) := \exp\big( l_n(P_{\theta_i}(D_i), P_i(D_i)) + n\, d^2(P_{\theta_i}, P_i) \big) \quad \text{and} \tag{129}$$

$$l_n(P_{\theta_i}(D_i), P_i(D_i)) := \log \frac{P_{\theta_i}(D_i)}{P_i(D_i)}, \tag{130}$$

and we have the following inequality, which holds for any $\phi$:

$$\log E_{p(\theta_i|\phi)}[\eta_i(\theta_i)] \ge E_{q_i(\theta_i)}[\log \eta_i(\theta_i)] - \mathrm{KL}(q_i(\theta_i)\|p(\theta_i|\phi)). \tag{131}$$

By taking the expectation with respect to $q(\phi)$ and rearranging terms, we have

$$n \cdot E_{q_i(\theta_i)}[d^2(P_{\theta_i}, P_i)] \le E_{q_i(\theta_i)}[-l_n(P_{\theta_i}(D_i), P_i(D_i))] + E_{q(\phi)}[\mathrm{KL}(q_i(\theta_i)\|p(\theta_i|\phi))] + E_{q(\phi)}\big[\log E_{p(\theta_i|\phi)}[\eta_i(\theta_i)]\big]. \tag{132}$$

For the last term on the right-hand side, we use the bound $E_{s(\theta)}[\eta(\theta)] \le e^{C n \epsilon_n^2}$ from the regression theorem (Pati et al., 2018), which holds for any distribution $s(\theta)$ with high probability. The details and proof of this bound can be found in the proof of Theorem 3.1 in (Pati et al., 2018). Applying this bound and summing over $i = 1, \dots, N$ yields:

$$n \cdot \sum_{i=1}^N E_{q_i(\theta_i)}[d^2(P_{\theta_i}, P_i)] \le \sum_{i=1}^N E_{q_i(\theta_i)}[-l_n(P_{\theta_i}(D_i), P_i(D_i))] + \sum_{i=1}^N E_{q(\phi)}[\mathrm{KL}(q_i(\theta_i)\|p(\theta_i|\phi))] + N C n \epsilon_n^2. \tag{133}$$

We add $\mathrm{KL}(q(\phi)\|p(\phi))$ to the right-hand side, which retains the inequality since the KL divergence is nonnegative.
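Then, after dividing both sides by $N$, we have the following result, which holds for any $q$ with high probability:

$$\frac{n}{N} \cdot \sum_{i=1}^N E_{q_i(\theta_i)}[d^2(P_{\theta_i}, P_i)] \le \mathcal{L}_n(q(\phi), \{q_i(\theta_i)\}_{i=1}^N) + C n \epsilon_n^2, \tag{134}$$

where $\mathcal{L}_n(q(\phi), \{q_i(\theta_i)\}_{i=1}^N)$ equals:

$$\frac{1}{N}\Big( \sum_{i=1}^N \big( E_{q_i(\theta_i)}[-l_n(P_{\theta_i}(D_i), P_i(D_i))] + E_{q(\phi)}[\mathrm{KL}(q_i(\theta_i)\|p(\theta_i|\phi))] \big) + \mathrm{KL}(q(\phi)\|p(\phi)) \Big), \tag{135}$$

which coincides with our FL-ELBO objective (3) up to an additive constant and the factor $1/N$. Next, we define $\tilde{q}_i(\theta_i)$ and $\tilde{q}(\phi)$ as follows:

$$\tilde{q}_i(\theta_i) = \mathcal{N}(\theta_i; \theta_i^*, \sigma_n^2 I) \quad \text{with} \quad \theta_i^* = \arg\min_{\theta \in \Theta} \|f_\theta - f_i\|_\infty^2, \quad \sigma_n^2 = \frac{G}{8n} A, \tag{136}$$

where $A^{-1} = \log(3 S_x M) \cdot (2BM)^{2(L+1)} \cdot \Big( \big(S_x + 1 + \frac{1}{BM - 1}\big)^2 + \frac{1}{(2BM)^2 - 1} + \frac{2}{(2BM - 1)^2} \Big)$, and

$$\tilde{q}(\phi) = \arg\min_{q(\phi)} \sum_{i=1}^N E_{q(\phi)}[\mathrm{KL}(\tilde{q}_i(\theta_i)\|p(\theta_i|\phi))] + \mathrm{KL}(q(\phi)\|p(\phi)). \tag{137}$$

Since $(\{q_i^*(\theta_i)\}_{i=1}^N, q^*(\phi))$ is the optimal solution of the FL-ELBO optimisation problem, we have $\mathcal{L}_n(q^*(\phi), \{q_i^*(\theta_i)\}_{i=1}^N) \le \mathcal{L}_n(\tilde{q}(\phi), \{\tilde{q}_i(\theta_i)\}_{i=1}^N)$ whenever the variational density family for $q_i(\theta_i)$ is rich enough to subsume Gaussian. Now we look closely at $\mathcal{L}_n(\tilde{q}(\phi), \{\tilde{q}_i(\theta_i)\}_{i=1}^N)$, and we note that the last two terms in (135) are constant (i.e., not a function of the data size $n$). That is,

$$\frac{1}{N}\Big( \sum_{i=1}^N E_{\tilde{q}(\phi)}[\mathrm{KL}(\tilde{q}_i(\theta_i)\|p(\theta_i|\phi))] + \mathrm{KL}(\tilde{q}(\phi)\|p(\phi)) \Big) = C'', \tag{138}$$

for some constant $C''$ (written with a separate symbol to distinguish it from the constant $C$ that multiplies $\epsilon_n^2$). Then we can write

$$\mathcal{L}_n(\tilde{q}(\phi), \{\tilde{q}_i(\theta_i)\}_{i=1}^N) = \frac{1}{N}\sum_{i=1}^N E_{\tilde{q}_i(\theta_i)}[-l_n(P_{\theta_i}(D_i), P_i(D_i))] + C''. \tag{139}$$

We bound the expected $-l_n$ term in (139) using Lemma E.2 below, which states that with high probability,

$$E_{\tilde{q}_i(\theta_i)}[-l_n(P_{\theta_i}(D_i), P_i(D_i))] \le C' n (r_n + \lambda_i^*), \tag{140}$$

for some constant $C' > 0$. Plugging this bound into (139), we have the following derivation, where we start from (134) with $q(\phi) = q^*(\phi)$ and $q_i(\theta_i) = q_i^*(\theta_i)$:

$$\frac{n}{N} \cdot \sum_{i=1}^N E_{q_i^*(\theta_i)}[d^2(P_{\theta_i}, P_i)] \le \mathcal{L}_n(q^*(\phi), \{q_i^*(\theta_i)\}_{i=1}^N) + C n \epsilon_n^2 \tag{141}$$

$$\le \mathcal{L}_n(\tilde{q}(\phi), \{\tilde{q}_i(\theta_i)\}_{i=1}^N) + C n \epsilon_n^2 \tag{142}$$

$$\le C'' + C' n \Big( r_n + \frac{1}{N}\sum_{i=1}^N \lambda_i^* \Big) + C n \epsilon_n^2. \tag{143}$$

Dividing both sides by $n$ (so that $C''/n = O(1/n)$) completes the proof.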
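The Donsker-Varadhan identity used at the start of the proof can be checked numerically on a small discrete example. The sketch below is only an illustration under our own assumptions: the three-point distribution `p`, the function values `h`, and all helper names are hypothetical and do not appear in the paper. It uses the standard fact that the maximiser is the tilted distribution $q^*(z) \propto p(z) e^{h(z)}$.

```python
import math

def log_mgf(p, h):
    """Left-hand side: log E_p[exp(h)] for a discrete distribution p."""
    return math.log(sum(pi * math.exp(hi) for pi, hi in zip(p, h)))

def kl(q, p):
    """KL(q || p) for discrete distributions (terms with q = 0 contribute 0)."""
    return sum(qi * math.log(qi / pi) for qi, pi in zip(q, p) if qi > 0)

def dv_objective(q, p, h):
    """Right-hand side objective: E_q[h] - KL(q || p)."""
    return sum(qi * hi for qi, hi in zip(q, h)) - kl(q, p)

def tilted(p, h):
    """Maximiser of the DV objective: q*(z) proportional to p(z) * exp(h(z))."""
    w = [pi * math.exp(hi) for pi, hi in zip(p, h)]
    z = sum(w)
    return [wi / z for wi in w]

# Hypothetical toy inputs, chosen only for illustration.
p = [0.5, 0.3, 0.2]
h = [0.1, -1.0, 2.0]

lhs = log_mgf(p, h)
q_star = tilted(p, h)
print("log E_p[e^h]        :", lhs)
print("objective at q*     :", dv_objective(q_star, p, h))
print("objective at q = p  :", dv_objective(p, p, h))
```

Any other choice of `q` gives a value no larger than `lhs`, which is exactly the inequality (for a fixed $q_i$) that the proof exploits.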
Lemma E.2 (From the proof of Lemma 4.1 in (Bai et al., 2020)). For $\tilde{q}_i(\theta_i)$ defined as in (136) and $r_n$, $\lambda_i^*$ defined as in Theorem E.1, the inequality (140) holds with high probability.

Proof of Lemma E.2. From our regression model assumptions (122) and (124),

$$E_{\tilde{q}_i(\theta_i)}\big[-l_n(P_{\theta_i}(D_i), P_i(D_i))\big] = E_{\tilde{q}_i(\theta_i)}\big[\log P_i(D_i) - \log P_{\theta_i}(D_i)\big] \tag{144}$$

$$= \frac{1}{2\sigma_\epsilon^2}\Big( E_{\tilde{q}_i(\theta_i)}\|Y_i - f_{\theta_i}(X_i)\|_2^2 - E_{\tilde{q}_i(\theta_i)}\|Y_i - f_i(X_i)\|_2^2 \Big) \tag{145}$$

$$= \frac{1}{2\sigma_\epsilon^2}\Big( \underbrace{E_{\tilde{q}_i(\theta_i)}\|f_{\theta_i}(X_i) - f_i(X_i)\|_2^2}_{\triangleq R_1} + 2 \cdot \underbrace{E_{\tilde{q}_i(\theta_i)}\big\langle Y_i - f_i(X_i),\ f_i(X_i) - f_{\theta_i}(X_i) \big\rangle}_{\triangleq R_2} \Big). \tag{146}$$

We first work on $R_1$. Since

$$\|f_{\theta_i}(X_i) - f_i(X_i)\|_2^2 \le n \cdot \|f_{\theta_i} - f_i\|_\infty^2 \le 2n \big( \|f_{\theta_i} - f_{\theta_i^*}\|_\infty^2 + \|f_{\theta_i^*} - f_i\|_\infty^2 \big), \tag{147}$$

we have

$$R_1 \le 2n \cdot E_{\tilde{q}_i(\theta_i)}\|f_{\theta_i} - f_{\theta_i^*}\|_\infty^2 + 2n \cdot \|f_{\theta_i^*} - f_i\|_\infty^2 \tag{148}$$

$$\le 2n (r_n + \lambda_i^*), \tag{149}$$

where in (149) we use the definition of $\lambda_i^*$ and the fact that $E_{\tilde{q}_i(\theta_i)}\|f_{\theta_i} - f_{\theta_i^*}\|_\infty^2 \le r_n$ from Appendix G in (Chérief-Abdellatif, 2020). Next, we bound $R_2$. Since $(Y_i - f_i(X_i)) \sim \mathcal{N}(0, \sigma_\epsilon^2 I)$ is independent of $\theta_i$, we can let $\epsilon := Y_i - f_i(X_i)$ with $\epsilon \sim \mathcal{N}(0, \sigma_\epsilon^2 I)$. Then

$$R_2 = \epsilon^\top \cdot E_{\tilde{q}_i(\theta_i)}\big[f_i(X_i) - f_{\theta_i}(X_i)\big] \sim \mathcal{N}(0, c_f \sigma_\epsilon^2), \tag{150}$$

where $c_f = \big\| E_{\tilde{q}_i(\theta_i)}[f_i(X_i) - f_{\theta_i}(X_i)] \big\|_2^2$. Applying Jensen's inequality to the convex function $\|\cdot\|_2^2$, we have $c_f \le E_{\tilde{q}_i(\theta_i)}\|f_i(X_i) - f_{\theta_i}(X_i)\|_2^2 = R_1$. Due to the property of the Gaussian, there exists some constant $C_0'$ such that $R_2 \le C_0' \cdot c_f \le C_0' \cdot R_1$ with high probability. Plugging these bounds on $R_1$ and $R_2$ back into (146) leads to:

$$E_{\tilde{q}_i(\theta_i)}\big[-l_n(P_{\theta_i}(D_i), P_i(D_i))\big] = \frac{1}{2\sigma_\epsilon^2}(R_1 + 2R_2) \le \frac{1 + 2C_0'}{2\sigma_\epsilon^2} R_1 \tag{151}$$

$$\le \frac{1 + 2C_0'}{\sigma_\epsilon^2}\, n (r_n + \lambda_i^*). \tag{152}$$

Letting $C' := \frac{1 + 2C_0'}{\sigma_\epsilon^2}$ (a constant) completes the proof.

F ADDITIONAL EXPERIMENTAL RESULTS

F.1 MORE RESULTS ON CIFAR-100 & CIFAR-C-100

We test our mixture model with different mixture orders K = 2, 5, 10 on CIFAR-100 (Table 3) and CIFAR-C-100 (Table 4).
In the last columns of the tables, we also report the performance of centralised (non-FL) training, in which the batch sampling follows the corresponding FL settings. That is, at each round, the minibatches for SGD (for conventional cross-entropy loss minimisation) are sampled from the data of the participating clients. Centralised training sometimes outperforms the best FL algorithms (our models), but it can fail completely, especially when data heterogeneity is high (small s) and τ is large. This may be due to overtraining on biased client data for relatively few rounds. Our FL models perform consistently and stably well, and are comparable to centralised training in its ideal settings (small τ and/or large s).

F.2 COMPARISON WITH SIMPLE ENSEMBLE BASELINES.

Our mixture model maintains K backbone networks (specifically $\{r_j\}_{j=1}^K$), where the mixture order K is usually small but greater than 1 (e.g., K = 2). It therefore requires more computational resources than other methods (including our NIW model) that deal with only a single backbone. As a baseline comparison, we devise simple extensions of Fed-Avg (McMahan et al., 2017) that incorporate multiple (the same K) networks. Here are the detailed descriptions of the baseline ensemble extensions: Note that K = 1 exactly reduces to Fed-Avg (or Fed-BABU). In Fig. 3 we visualise the performance of these ensemble baselines, compared with our mixture model for different K = 2, 5, 10 on CIFAR-100 with the (f = 0.1, τ = 1) setting. It clearly shows that these simple ensemble strategies are prone to overfitting. The result signifies the importance of the sophisticated negative log-sum-exp regularisation in the client/server updates, as in (22) and (23), in our mixture model.

$$\min_{L_i} \ E_{q_i(\theta_i; L_i)}[-\log p(D_i|\theta_i)] + E_{q(\phi; L_0)}\big[\mathrm{KL}(q_i(\theta_i; L_i)\|p(\theta_i|\phi))\big],$$

Initial $L_i$ can be either copied from $L_0$ or the last iterate if the client is able to save $L_i$ locally.
4. Each client i ∈ I sends the updated $L_i$ back to the server.
5. Upon receiving $\{L_i\}_{i \in I}$, the server updates $L_0$ by solving (with $\{L_i\}_{i \in I}$ fixed):

$$\min_{L_0} \ \mathrm{KL}(q(\phi; L_0)\|p(\phi)) - \frac{N}{N_f} \sum_{i \in I} E_{q(\phi; L_0) q_i(\theta_i; L_i)}[\log p(\theta_i|\phi)].$$

$$\min_{m_i} \ -\log p(D_i|\tilde{m}_i) + \frac{p}{2}(n_0 + d + 1)(m_i - m_0)^\top V_0^{-1} (m_i - m_0),$$

where $\tilde{m}_i$ is the dropout version (with probability 1 - p) of $m_i$. Initial $m_i$ can be either copied from $m_0$ or the last iterate if the client is able to save $m_i$ locally.
4. Each client i ∈ I sends the updated $L_i = m_i$ back to the server.
5. Upon receiving $\{m_i\}_{i \in I}$, the server updates $L_0 = (m_0, V_0)$ by:

$$m_0^* = \frac{p}{N + 1} \cdot \frac{N}{N_f} \sum_{i \in I} m_i, \qquad V_0^* = \frac{n_0}{N + d + 2}\Big( (1 + N\epsilon^2) I + m_0^* (m_0^*)^\top + \frac{N}{N_f} \sum_{i \in I} \rho(m_0^*, m_i, p) \Big),$$

where $\rho(m_0, m_i, p) = p\, m_i m_i^\top - p\, m_0 m_i^\top - p\, m_i m_0^\top + m_0 m_0^\top$.

Algorithm 4 Training algorithm: Mixture case.
Input: Initial $L_0 = \{r_j\}_{j=1}^K$ in $q(\phi; L_0) = \prod_j \mathcal{N}(\mu_j; r_j, \epsilon^2 I)$ and β in the gating network g(x; β).
Output: Trained parameters $L_0 = \{r_j\}_{j=1}^K$ and β.
For each round r = 1, 2, . . . , R do:
1. Sample a subset I of participating clients ($|I| = N_f \le N$).
2. Server sends $L_0 = \{r_j\}_{j=1}^K$ and β to all clients i ∈ I.
3. For each client i ∈ I in parallel do: Solve (by SGD) with $L_0 = \{r_j\}_{j=1}^K$ fixed:

$$\min_{m_i} \ E_{q_i(\theta_i; m_i)}[-\log p(D_i|\theta_i)] - \log \sum_{j=1}^K \exp\Big( -\frac{\|m_i - r_j\|^2}{2\sigma^2} \Big),$$

where $q_i(\theta_i; m_i) = \mathcal{N}(\theta_i; m_i, \epsilon^2 I)$. Initial $m_i$ can be either the center of $\{r_j\}_{j=1}^K$ or the last iterate if the client is able to save $m_i$ locally.
$\beta_i$ = SGD update of β in g(x; β) with data $\{(x, j^*)\}_{x \sim D_i}$, where $j^* = \arg\min_j \|m_i - r_j\|$.
4. Each client i ∈ I sends the updated $L_i = m_i$ and $\beta_i$ back to the server.
5. Upon receiving $\{m_i\}_{i \in I}$ and $\{\beta_i\}_{i \in I}$, the server updates $L_0 = \{r_j\}_{j=1}^K$ by the one-step EM:

$$\text{(E-step)} \quad c(j|i) = \frac{e^{-\|m_i - r_j\|^2 / (2\sigma^2)}}{\sum_{j'=1}^K e^{-\|m_i - r_{j'}\|^2 / (2\sigma^2)}}, \qquad \text{(M-step)} \quad r_j^* = \frac{\frac{1}{N_f} \sum_{i \in I} c(j|i) \cdot m_i}{\frac{\sigma^2}{N} + \frac{1}{N_f} \sum_{i \in I} c(j|i)},$$

and updates β by aggregation: $\beta^* = \frac{1}{N_f} \sum_{i \in I} \beta_i$.

1. Estimate m in $v(\theta; m) = \prod_l \big( p \cdot \mathcal{N}(\theta[l]; m[l], \epsilon^2 I) + (1 - p) \cdot \mathcal{N}(\theta[l]; 0, \epsilon^2 I) \big)$ by solving (via SGD):

$$\min_m \ -\log p(D^p|\tilde{m}) + \frac{p}{2}(n_0 + d + 1)(m - m_0)^\top V_0^{-1} (m - m_0),$$

where $\tilde{m}$ is the dropout version (with probability 1 - p) of m.

F + B + O(d) (O(d) from quadratic penalty); O(N_f · d); (sent: m_0, V_0); (sent: m_i)
Mixture (Order K): (K + 1)d (sent: $\{r_j\}_{j=1}^K$, β); 2d (sent: m_i, β_i); 2(F + B) + O(K · d) (O(K · d) from log-sum-exp); O(K · N_f · d)
Fed-Avg: d (sent: θ); d (sent: θ_i); F + B; O(N_f · d) (aggregation)

F + B + O(d) (O(d) from quadratic penalty); F
Mixture (Order K): F + B + O(K · d) (O(K · d) from log-sum-exp); F
Fed-Avg: F + B; F

2021): the number of ensemble components 3; FedEM (Marfoq et al., 2021): the number of base models 3.
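The client-side regulariser in step 3 of Algorithm 4 (mixture case) and the gating label $j^*$ can be sketched as follows. This is a minimal pure-Python illustration under our own assumptions, not the authors' implementation: parameters are flat lists of floats, the function names are hypothetical, and the data-fit term $E[-\log p(D_i|\theta_i)]$ is omitted. The max-shift inside the log-sum-exp is a standard numerical-stability device that we add; it does not change the value.

```python
import math

def sq_dist(a, b):
    """Squared Euclidean distance between two flat parameter vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def neg_log_sum_exp_penalty(m, centers, sigma2):
    """-log sum_j exp(-||m - r_j||^2 / (2 sigma^2)), computed stably."""
    scores = [-sq_dist(m, r) / (2.0 * sigma2) for r in centers]
    top = max(scores)  # shift by the largest score to avoid underflow
    return -(top + math.log(sum(math.exp(s - top) for s in scores)))

def nearest_component(m, centers):
    """Gating label j* = argmin_j ||m - r_j|| used to train the gating network."""
    return min(range(len(centers)), key=lambda j: sq_dist(m, centers[j]))

# Hypothetical toy values for illustration only.
m_i = [1.0, 2.0]
centers = [[0.0, 0.0], [1.0, 1.0]]
print(neg_log_sum_exp_penalty(m_i, centers, sigma2=1.0))
print(nearest_component(m_i, centers))
```

With a single component (K = 1) the penalty reduces to the quadratic $\|m_i - r_1\|^2 / (2\sigma^2)$, and with several components it is never larger than the quadratic penalty towards the nearest $r_j$.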

H.2 MNIST AND FASHION-MNIST

The FL setting is as follows: the number of clients N = 100, the number of shards per client s = 5, the fraction of participating clients per round f = 0.

(Li et al., 2018; Acar et al., 2021), which is shown to help the global model converge more reliably. Recent approaches aim to incorporate benefits from existing machine learning approaches including domain adaptation/generalisation, clustering, multi-task learning, transfer learning, and meta-learning. To deal with heterogeneous client data distributions, the works in (Peterson et al., 2019; Zhang et al., 2021; Sun et al., 2021) tackle the FL problem from the perspective of (multi-)domain adaptation/generalisation. Another interesting line of work aims to cluster clients with similar data distributions together (Briggs et al., 2020; Mansour et al., 2020). Along the same line, the shared representations among related or similar clients can be modeled using ideas from general multi-task learning (Smith et al., 2017; Dinh et al., 2021). Motivated by transfer learning, several attempts have been made to exploit the idea of learning/transferring knowledge from related clients (Chen et al., 2020; Yang et al., 2020; Dinh et al., 2020; Li et al., 2021). Last but not least, there have been personalised FL methods based on meta-learning, since finetuning from the globally trained model can be seen as adaptation to new data (Chen et al., 2018; Fallah et al., 2020).

Comparison to existing Bayesian FL approaches. Some recent studies tried to address the FL problem using Bayesian methods. As we mentioned earlier, the key difference is that these methods do not introduce a Bayesian hierarchy, and ultimately treat the network weights θ as a random variable shared across all clients, while our approach assigns an individual $\theta_i$ to each client i, governed by a common prior $p(\theta_i|\phi)$. The non-hierarchical approaches must all resort to ad hoc heuristics or strong assumptions in parts of their algorithms.
More specifically, FedPA (Posterior Averaging) (Al-Shedivat et al., 2021) aims to establish the decomposition $p(\theta|D_{1:N}) \propto \prod_{i=1}^N p(\theta|D_i)$, also known as a product of experts, to allow client-wise inference/optimisation of $p(\theta|D_i)$. Unfortunately, this decomposition does not hold in general unless we make the strong assumption of an uninformative prior $p(\theta) \propto 1$, as they did. FedBE (Bayesian Ensemble) (Chen & Chao, 2021) aims to build the global posterior distribution $p(\theta|D_{1:N})$ from the individual posteriors $p(\theta|D_i)$ in either of two ad hoc ways: SWAG-like (Maddox et al., 2019) model averaging over clients, or a convex combination of the modes of the local posteriors. pFedBayes (Zhang et al., 2022) can be seen as an implicit regularisation-based method to approximate $p(\theta|D_{1:N})$ from the individual posteriors $p(\theta|D_i)$. To combine the individual posteriors, they introduce the so-called global distribution $w(\theta)$, which essentially serves as a regulariser that keeps the local posteriors $p(\theta|D_i)$ from deviating from it, i.e., $p(\theta|D_i) \approx w(\theta)$ for all i. The introduction of $w(\theta)$ and its update strategy appears to be a hybrid treatment rather than a solely Bayesian one. Finally, FedEM (Marfoq et al., 2021) forms a seemingly reasonable hypothesis that local client data distributions can be identified as mixtures of a fixed number of base distributions (with different mixing proportions). However, although it involves probabilistic modeling, mixture estimation, and base distribution learning under this hypothesis, this method is not a Bayesian approach.

FedGP (Achituve et al., 2021b) aims to extend the GP-Tree algorithm (Achituve et al., 2021a) to the FL setting via shared deep kernel learning. To this end, each client performs GP-Tree kernel learning locally on its own data, while the server aggregation simply follows the FedAvg algorithm to learn a global kernel. In this sense, the overall approach is quite different from our hierarchical Bayesian treatment.
FedPop (Kotelevskii et al., 2022): It has a hierarchical Bayesian model structure similar to ours. However, they split the backbone network parameters into those of the feature extractor (denoted by ϕ in their paper) and the linear classification head (β). In their model, the feature extractor weights ϕ are shared across the clients (called fixed effects), and the client-wise classification head parameters $z_i$ are sampled from β, i.e., $z_i \sim p(z|\beta)$. Thus the client data $D_i$ is generated by ϕ and $z_i$. The main differences from our approach are fourfold: 1) The higher-level variables β and the local variables $z_i$ sampled from β are both restricted to the linear classification head of the network, which makes the uncertainty imposed on the model parameters quite limited; 2) Moreover, they do not actually treat β (and ϕ of the feature extractor) as random variables, but as deterministic variables that are optimised by empirical Bayes learning. This hinders the model from benefiting from hierarchical Bayesian modeling (e.g., they do not have a prior distribution p(β) at all); 3) Their optimisation alternates between the feature extractor ϕ and the head prior parameters β, which is entirely different from our block-coordinate optimisation alternating between the higher-level random variables and the individual local variables; 4) They do not use variational inference to infer $p(z_i|D_i, \phi, \beta)$, but MCMC sampling (Langevin dynamics), which is the very reason why they had to restrict the latents $z_i$ to the classification heads, instead of the full network parameters as we did.



Note that we do not deal with generative modeling of input images x. Inputs x are always given, and only conditionals p(y|x) are modeled. See Fig. 1(b) for the in-depth graphical model diagram.

This choice ensures that the mean of Σ equals I, and µ is distributed as 0-mean Gaussian with covariance Σ.

Only the constant 1 added to the denominator, which comes from the prior and has the regularising effect.

In practice we use a single sample (S = 1) for computational efficiency.

Although one can perform several EM steps until convergence, in practice, we find that only one EM step per round is sufficient.

We also follow the Fed-BABU (Oh et al., 2022) strategy by updating only the body of β and fixing/sharing the random classification head across the server and clients.

Note that the optimal posterior on ϕ (i.e., $q^*(\phi)$) does not appear here since it only affects $d^2(P_\theta, P_i)$ implicitly. See our detailed analysis/proof provided below.

Although Lemma E.2 can be found in the proof of Lemma 4.1 of (Bai et al., 2020), we state this lemma more clearly with a separate proof for self-containment.

In (Bai et al., 2020), they defined $\tilde{q}_i(\theta_i)$ as a spike-and-slab model to deal with sparsity. Essentially it is a mixture of two components, selecting $\mathcal{N}(\theta_i; \theta_i^*, \sigma_n^2 I)$ when the entries of $\theta_i^*$ are non-zero and selecting the delta function at 0 when the entries of $\theta_i^*$ equal zero. Without loss of generality (or under mild numerical approximation), we can assume all entries of $\theta_i^*$ are non-zero, which makes $\tilde{q}_i(\theta_i)$ a Gaussian equal to $\mathcal{N}(\theta_i; \theta_i^*, \sigma_n^2 I)$ as in (136).



(a) Overall model (b) Individual client (c) Global prediction (d) Personalisation Figure 1: Graphical models. (a) Plate view of iid clients. (b) Individual client data with input images x given and only p(y|x) modeled. (c) & (d): Global prediction and personalisation as probabilistic inference problems (shaded nodes = evidences, red colored nodes = targets to infer, x * = test input in global prediction, D p = training data for personalisation and x p = test input).

$$\min_{L_0} \ \mathcal{L}_0(L_0) := \mathrm{KL}(q(\phi; L_0)\|p(\phi)) - \sum_{i=1}^N E_{q(\phi; L_0) q_i(\theta_i; L_i)}[\log p(\theta_i|\phi)].$$

Figure 2: Hyperparameter sensitivity analysis and comparison with simple ensemble baselines.

Interestingly, (74) generalises the well-known aggregation step of averaging local models in Fed-Avg (McMahan et al., 2017) and related methods: in particular, when K = 1, (74) reduces to quadratic optimisation, admitting the optimal solution $r_1^* = \frac{1}{N + \sigma^2} \sum_{i=1}^N m_i$. The extra term $\sigma^2$ in the denominator can be explained by incorporating an extra zero local model originating from the prior (interpreted as a neutral model) with the discounted weight $\sigma^2$ rather than 1.
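The one-step EM aggregation in step 5 of Algorithm 4, and its reduction to discounted averaging when K = 1, can be sketched as below. This is an illustrative pure-Python sketch under our own assumptions rather than the authors' code: local models are flat lists of floats, `num_clients` stands for the total number of clients N, the number of participating clients $N_f$ is taken to be `len(local_means)`, and all function names are hypothetical.

```python
import math

def responsibilities(m, centers, sigma2):
    """E-step: c(j|i) proportional to exp(-||m_i - r_j||^2 / (2 sigma^2))."""
    scores = [-sum((a - b) ** 2 for a, b in zip(m, r)) / (2.0 * sigma2)
              for r in centers]
    top = max(scores)  # max-shift for numerical stability
    w = [math.exp(s - top) for s in scores]
    z = sum(w)
    return [wi / z for wi in w]

def server_em_step(local_means, centers, sigma2, num_clients):
    """M-step: r_j* = (1/N_f sum_i c(j|i) m_i) / (sigma^2/N + 1/N_f sum_i c(j|i))."""
    n_f = len(local_means)
    dim = len(centers[0])
    resp = [responsibilities(m, centers, sigma2) for m in local_means]
    new_centers = []
    for j in range(len(centers)):
        mass = sum(resp[i][j] for i in range(n_f)) / n_f
        weighted = [sum(resp[i][j] * local_means[i][d] for i in range(n_f)) / n_f
                    for d in range(dim)]
        denom = sigma2 / num_clients + mass
        new_centers.append([w / denom for w in weighted])
    return new_centers

# Hypothetical toy example: K = 1 and full participation (N_f = N = 3).
local_means = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
updated = server_em_step(local_means, [[0.0, 0.0]], sigma2=0.5, num_clients=3)
print(updated)  # equals sum_i m_i / (N + sigma^2) in each coordinate
```

In the K = 1, full-participation case every responsibility equals 1, so the update is the sum of the local models divided by $N + \sigma^2$, which is plain Fed-Avg averaging with one extra zero model of weight $\sigma^2$.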

For all t, $\|\nabla f_t(x)\| \le R_f$ and $\|x - x'\| \le D$ for any $x, x'$, where $R_f$ and $D$ are some constants.

Theorem D.1 (Convergence analysis). Let $\eta_t = L + \sqrt{t}$

Figure 3: Comparison between our mixture model and ensemble baselines (K varied) on CIFAR-100.

Training algorithm: Normal-Inverse-Wishart case.
Input: Initial $L_0 = (m_0, V_0)$ in $q(\phi; L_0) = \mathcal{NIW}(\phi; \{m_0, V_0, l_0 = |D| + 1, n_0 = |D| + d + 2\})$, where $|D| = \sum_{i=1}^N |D_i|$ and d = the number of parameters in the backbone network p(y|x, θ).
Output: Trained parameters $L_0 = (m_0, V_0)$.
For each round r = 1, 2, . . . , R do:
1. Sample a subset I of participating clients ($|I| = N_f \le N$).
2. Server sends $L_0 = (m_0, V_0)$ to all clients i ∈ I.
3. For each client i ∈ I in parallel do: Solve (by SGD) with $L_0 = (m_0, V_0)$ fixed: min

Return $p(y^p|x^p, D^p, D_{1:N}) \approx p(y^p|x^p, m)$.

Algorithm 10 Personalisation: Mixture case.
Input: Personal training data $D^p$. Test input $x^p$. Learned model $L_0 = \{r_j\}_{j=1}^K$ in $q(\phi; L_0) = \prod_j \mathcal{N}(\mu_j; r_j, \epsilon^2 I)$.
Output: Predictive distribution $p(y^p|x^p, D^p, D_{1:N})$.
1. Estimate m in $v(\theta; m) = \mathcal{N}(\theta; m, \epsilon^2 I)$ by solving (via SGD): $\min_m \ E_{v(\theta; m)}[-\log p(D^p|\theta)] - \log \sum_{j=1}^K$

3.1 and Sec. 3.2.  Since dealing with p(θ|D p , ϕ

Global prediction and personalisation accuracy. (a) Global prediction performance (initial accuracy) ±0.42 40.87 ±0.62 41.49 ±0.75 37.23 ±0.88 10 29.02±0.33  29.02 ±0.33 29.02 ±0.33 29.02 ±0.29 29.02 ±0.29 29.02 ±0.29 27.93 ±0.28 28.26 ±0.19 27.11 ±0.11 28.21 ±1.42 ±0.93 46.43 ±0.82 49.91 ±0.78 45.83 ±1.12 10 36.68 ±0.37 36.68 ±0.37 36.68 ±0.37 36.32 ±0.27 35.45 ±0.34 33.57 ±0.06 33.92 ±0.22 35.74 ±1.36 ±0.91 53.15 ±0.25 55.50 ±0.90 53.00 ±0.48 10 35.92 ±0.17 36.22 ±0.17 36.22 ±0.17 36.22 ±0.17 35.58 ±0.24 33.82 ±1.04 33.70 ±0.42 35.57 ±1.02 ±0.36 70.36 ±1.02 75.06 ±0.67 73.93 ±0.14 10 67.35 ±1.02 67.57 ±0.62 67.57 ±0.62 67.57 ±0.62 66.24 ±0.53 61.39 ±0.27 64.86 ±0.73 65.82 ±0.33 ±0.23 76.98 ±0.66 78.56 ±0.55 78.08 ±0.28 10 67.78 ±1.02 67.78 ±1.02 67.78 ±1.02 66.74 ±0.27 66.25 ±0.46 63.81 ±0.40 63.81 ±0.51 66.15 ±1.29

Global prediction and personalisation accuracy. (a) Global prediction (initial accuracy) on test splits for the 10 training corruption types ±0.71 70.01 ±0.77 78.94 ±0.82 71.14±0.33  4 67.69 ±0.74 67.69 ±0.74 67.69 ±0.74 65.81 ±0.84 63.58 ±1.28 48.47 ±1.26 60.96 ±1.11 57.88 ±1.51 ±0.50 34.40 ±0.31 35.44 ±0.66 35.78 ±0.73 4 30.84 ±0.07 30.84 ±0.07 30.84 ±0.07 30.60 ±0.27 28.31 ±0.28 29.24 ±0.79 28.09 ±0.71 29.12 ±0.30

Global prediction and personalisation accuracy. Mixture order K varied. ±0.16 54.93 ±0.25 55.83 ±0.47 50.43 ±0.93 53.18 ±0.10 10 36.68 ±0.37 36.32 ±0.27 37.30 ±0.64 37.34 ±0.38 35.45 ±0.34 38.20 ±1.58 ±0.19 79.29 ±0.19 77.44 ±0.54 75.44 ±0.36 42.04 ±1.38 10 67.35 ±1.02 67.57 ±0.62 67.84 ±0.40 67.33 ±0.26 66.24 ±0.53 17.40 ±1.05 ±0.35 79.91 ±0.25 80.92 ±0.19 78.92 ±0.23 78.43 ±0.90 10 67.78 ±1.02 66.74 ±0.27 66.50 ±0.24 67.30 ±0.29 66.25 ±0.46 5.13 ±0.19

Running times (in seconds) on CIFAR-100 with (s = 100, f = 1.0, τ = 1) setting.

F.3 RUNNING TIMES

Although our models achieve significant improvement in prediction accuracy, they incur extra computational overhead compared to simpler FL methods like Fed-BABU. To see whether this extra cost is acceptable, we measure and compare wall-clock times in Table 5, where all methods are tested on the same machine: a Xeon 2.20GHz CPU with a single RTX 2080 Ti GPU. For NIW, the extra cost in the local client update and personalisation (training) originates from the penalty term in (17), while squaring the model weights to compute $V_0$ in (18) incurs additional cost in the server update. For Mixture, the increased training time is mainly due to the overhead of computing distances from the K server models in (22) and (23). Overall, however, the extra costs are not prohibitively large, rendering our methods sufficiently practical.

Initial parameters $L_0$ in the variational posterior $q(\phi; L_0)$.
Output: Trained parameters $L_0$.
For each round r = 1, 2, . . . , R do:
1. Sample a subset I of participating clients ($|I| = N_f \le N$).
2. Server sends $L_0$ to all clients i ∈ I.
3. For each client i ∈ I in parallel do: Solve (by SGD) with $L_0$ fixed: min

Return $p(y^p|x^p, D^p, D_{1:N}) \approx p(y^p|x^p, m)$.

Training complexity of the proposed algorithms (NIW and Mixture) and Fed-Avg. All quantities are per-round, per-batch, and per-client costs. In the entries, d = the number of parameters in the backbone network, F = time for feed-forward pass, B = time for backprop, and $N_f$ = the number of participating clients per round.

Global prediction complexity of the proposed algorithms (NIW and Mixture) and Fed-Avg. All quantities are per-test-batch costs. In the entries, d = the number of parameters in the backbone network, F = time for feed-forward pass, and S = the number of samples θ (s) from the Student-t distribution in the NIW case (we use S = 1).

Personalisation complexity of the proposed algorithms (NIW and Mixture) and Fed-Avg. All quantities are per-train/test-batch costs. In the entries, d = the number of parameters in the backbone network, F = time for feed-forward pass, and B = time for backprop.

1, and the number of local training epochs per round τ = 1 (total number of rounds 100) or 5 (total number of rounds 20). FedBE and FedEM use three component models. The backbone is an MLP with a single hidden layer with 256 units. The results are summarized in Table 10 (MNIST) and Table 11 (Fashion-MNIST).

= 200, the fraction of participating clients per round f = 0.2, and the number of local training epochs per round τ = 1 (total number of rounds 300). We follow the standard Dirichlet-based client data splitting. FedBE and FedEM use three component models. The backbone is a standard ConvNet with two hidden layers. The results are summarized in Table 12.

I (REVISION): EXTENDED RELATED WORK

General FL approaches. The seminal work on FL is arguably FedAvg (McMahan et al., 2017), which proposed fairly intuitive local training and global aggregation strategies with minimal training and communication complexity. A potential issue of divergence between the global and local models, caused by the separation of the local training and aggregation steps, was addressed by model regularisation in the follow-up works

Global prediction and personalisation accuracy.

(MNIST) Global prediction and personalisation accuracy. (a) Global prediction performance (initial accuracy)

(Fashion-MNIST) Global prediction and personalisation accuracy. (a) Global prediction performance (initial accuracy)

(EMNIST) Global prediction and personalisation accuracy. (a) Global prediction performance (initial accuracy)


Then the following holds for any x:

Proof of Lemma D.2. By definition, $\nabla_u f_t(x^t) + \eta_t (x_u^{t+1} - x_u^t) = 0$. Then for any x, we have

Note that (95) follows from (94) since $x^t$ and $x^{t+1}$ only differ at indices u. Since $f_t$ has Lipschitz continuous gradient,

where we plugged (95) into (97). We rearrange the terms in (98) as follows:

(99)

Due to the convexity of $f_t$, we can bound the last term in (99) as

Plugging (102) into (99) and choosing $\alpha = \eta_t - L \ (> 0)$ yields:

Applying Assumption 3 of the bounded gradient norm completes the proof.

Now we are ready to prove our convergence theorem (Theorem D.1).

Proof of Theorem D.1. We first aim to bound $\langle \nabla f(x^t), x^t - x \rangle$, as it upper-bounds $f(x^t) - f(x)$ for convex f. Note that by conditioning on $x^t$, we only need to deal with the randomness in the minibatch at t

(104)

Further conditioning on batch t leads to:

Here we aim to rewrite the summation term in (105) in terms of blocks of size $N_f$. To this end, let us consider:

where $2^{N, N_f}$ is defined as the set of all subsets of {1, . . . , N} with size $N_f$ and no repeating elements. For instance, $2^{5,3}$ contains {1, 3, 4} and {2, 4, 5}, among others. Obviously $|2^{N, N_f}| = \binom{N}{N_f}$, and each particular index i ∈ {1, . . . , N} appears exactly $\binom{N-1}{N_f-1}$ times in the sum (106). Thus we can

1. The server maintains K networks (denoted by $\theta_1, \dots, \theta_K$).
2. We partition the clients into K groups with equal proportions. We assign each $\theta_j$ to group j (j = 1, . . . , K).
3. At each round, each participating client i receives the current model $\theta_{j(i)}$ from the server, where j(i) denotes the index of the group to which client i belongs.
4. The clients perform local updates as usual, warm-starting from the received models, and send the updated models back to the server.
5. The server collects the updated local models from the clients, and takes the average within each group j to update $\theta_j$.
6. After training, we have K trained networks. At test time, we can use these K networks in two different ways/options: (Preset option) Each client i uses the network assigned to its group, i.e., $\theta_{j(i)}$, for both prediction and finetuning/personalisation; (Ensemble option) We use all K networks (as an ensemble) for prediction and finetuning.
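The server side of the ensemble baseline described above can be sketched as follows. This is a minimal pure-Python illustration under our own assumptions, not the authors' code: models are flat lists of floats, the round-robin grouping is one hypothetical way to obtain equal proportions, and averaging class probabilities in the ensemble option is our assumption, since the text does not specify how the K predictions are combined. All function names are hypothetical.

```python
def assign_groups(num_clients, k):
    """Partition clients into k groups of (near-)equal size, round-robin (assumption)."""
    return [i % k for i in range(num_clients)]

def average(models):
    """Element-wise mean of a non-empty list of flat parameter vectors."""
    n = len(models)
    return [sum(vals) / n for vals in zip(*models)]

def server_round(server_models, group_of, updates):
    """Average the updated local models within each group j to refresh theta_j.

    updates maps client id -> locally updated parameters. A group with no
    participating client in this round keeps its current model.
    """
    new_models = list(server_models)
    for j in range(len(server_models)):
        members = [w for i, w in updates.items() if group_of[i] == j]
        if members:
            new_models[j] = average(members)
    return new_models

def ensemble_predict(per_model_probs):
    """Ensemble option (assumed form): average the K predictive distributions."""
    return average(per_model_probs)

# Hypothetical toy round: 4 clients, K = 2, clients 0, 1 and 2 participate.
group_of = assign_groups(4, 2)
models = server_round([[0.0], [0.0]], group_of, {0: [1.0], 1: [2.0], 2: [3.0]})
print(group_of, models)
```

With K = 1 every client falls into the same group and `server_round` is ordinary Fed-Avg averaging over the participating clients, in line with the remark that K = 1 reduces to Fed-Avg.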

