FEDERATED LEARNING WITH DECOUPLED PROBABILISTIC-WEIGHTED GRADIENT AGGREGATION

Abstract

In the federated learning paradigm, multiple mobile clients train local models independently on datasets generated by edge devices, and a server aggregates the parameters/gradients of the local models to form a global model. However, existing model aggregation approaches suffer from high bias in both data distribution and parameter distribution for non-IID datasets, which results in a severe accuracy drop as the number of heterogeneous clients increases. In this paper, we propose a novel decoupled probabilistic-weighted gradient aggregation approach for federated learning called FeDEC. The key idea is to optimize gradient parameters and statistical parameters in a decoupled way, and to aggregate the parameters of local models with probabilistic weights in order to deal with client heterogeneity. Since the overall dataset is inaccessible to the central server, we introduce a variational inference method to derive the optimal probabilistic weights that minimize statistical bias. We further prove a convergence bound for the proposed approach. Extensive experiments using mainstream convolutional neural network models on three federated datasets show that FeDEC significantly outperforms the state of the art in terms of model accuracy and training efficiency.

1. INTRODUCTION

Federated learning (FL) has emerged as a novel distributed machine learning paradigm that allows a global machine learning model to be trained collaboratively by multiple mobile clients. In this paradigm, mobile clients train local models on datasets generated by edge devices such as sensors and smartphones, and the server is responsible for aggregating the parameters/gradients of the local models to form a global model, without transferring raw data to a central server. Federated learning has drawn much attention in mobile-edge computing (Konecný et al. (2016); Sun et al. (2017)) thanks to its advantages in preserving data privacy (Zhu & Jin (2020); Jiang et al. (2019); Keller et al. (2018)) and enhancing communication efficiency (Shamir et al. (2014); Smith et al. (2018); Zhang et al. (2013); McMahan et al. (2017); Wang et al. (2020)). Gradient aggregation is the key technology of federated learning, and typically involves the following three steps, repeated periodically during training: (1) the involved clients train the same type of model on their local data independently; (2) when the server sends an aggregation signal to the clients, the clients transmit their parameters or gradients to the server; (3) when the server has received all parameters or gradients, it applies an aggregation method to them to form the global model. The standard aggregation method FedAvg (McMahan et al. (2017)) and its variants, such as FedProx (Li et al. (2020a)), Zeno (Xie et al. (2019)), and q-FedSGD (Li et al. (2020b)), apply synchronous parameter averaging to the entire model indiscriminately. Agnostic federated learning (AFL) (Mohri et al. (2019)) defined an agnostic, risk-averse objective to optimize a mixture of the client distributions. FedMA (Wang et al. (2020)) constructed the shared global model in a layer-wise manner by matching and averaging hidden elements with similar feature-extraction signatures.
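The aggregation in step (3) can be sketched as a size-weighted parameter average in the style of FedAvg. The following is an illustrative sketch (function and variable names are ours, not from any library), with model parameters represented as plain dictionaries of arrays:

```python
import numpy as np

def fedavg_aggregate(client_params, client_sizes):
    """Size-weighted average of client parameter dicts (FedAvg-style).

    client_params: list of {name: array} local model parameters.
    client_sizes:  list of local dataset sizes |x_k|; weights are |x_k|/|x|.
    """
    total = float(sum(client_sizes))
    return {
        name: sum((n / total) * params[name]
                  for params, n in zip(client_params, client_sizes))
        for name in client_params[0]
    }
```

Each communication round, the server would call this on the parameters received in step (2) and broadcast the result back to the clients for the next round of local training.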
The recurrent neural network (RNN) based aggregator (Ji et al. (2019)) learned an aggregation method that is resilient to Byzantine attacks. Despite these efforts, applying the existing parameter aggregation methods to large numbers of heterogeneous clients in federated learning still suffers from performance issues. It was reported in (Zhao et al. (2018)) that the accuracy of a convolutional neural network (CNN) model trained by FedAvg drops by up to 55% on a highly skewed non-IID dataset. The work of (Wang et al. (2020)) showed that the accuracy of FedAvg (McMahan et al. (2017)) and FedProx (Li et al. (2020a)) dropped from 61% to under 50% when the number of clients increased from 5 to 20 under heterogeneous data partitions. A possible explanation for these performance drops in federated learning is the bias, at different levels, caused by inappropriate gradient aggregation, on which we make the following observations. Data Bias: In the federated learning setting, local datasets are accessible only to their owners and are typically non-IID. Conventional approaches aggregate gradients uniformly across the clients, which can deviate greatly from the real data distribution. Fig. 1 shows the distribution of the real dataset and the distributions obtained by uniformly taking samples from different numbers of clients in the CIFAR-10 dataset (Krizhevsky (2009)). It can be observed that there are large differences between the real distribution and the sampled distributions; the more clients involved, the larger the difference. Parameter Bias: A CNN model typically contains two different types of parameters: the gradient parameters of the convolutional (Conv) and fully-connected (FC) layers, and the statistical parameters, such as mean and variance, of the batch normalization (BN) layers.
Existing approaches such as FedAvg average the entire set of model parameters indiscriminately using distributed stochastic gradient descent (SGD), which leads to bias in the means and variances of the BN layers. Fig. 2 compares the distribution of BN-layer means and variances of a centrally-trained CNN model with those of FedAvg-trained models for different numbers of clients on non-IID local datasets. It can be observed that the more clients are involved, the larger the deviation between the central model and the federated learning models. Our contributions: In the context of federated learning, the problems of data bias and parameter bias have not been carefully addressed in the literature. In this paper, we propose a novel gradient aggregation approach called FeDEC. The main contributions of our work are summarized as follows. (1) We propose the key idea of optimizing gradient aggregation with a decoupled probabilistic-weighted method. To the best of our knowledge, we make the first attempt to aggregate gradient parameters and statistical parameters separately, and adopt a probabilistic mixture model to resolve the problem of aggregation bias in federated learning with heterogeneous clients. (2) We propose a variational inference method to derive the optimal probabilistic weights for gradient aggregation, and prove a convergence bound for the proposed approach. (3) We conduct extensive experiments using five mainstream CNN models on three federated datasets under non-IID conditions. It is shown that FeDEC significantly outperforms the state of the art in terms of model accuracy and training efficiency.

2. RELATED WORK

We summarize the related work in two categories: parameter/gradient aggregation for distributed learning and for federated learning. Distributed Learning: In distributed learning, the most prominent parameter aggregation paradigm is the Parameter Server framework (Li et al. (2014)). In this framework, multiple servers each maintain a partition of the globally shared parameters and communicate with each other to replicate and migrate parameters, while the clients compute gradients locally on a portion of the training data and communicate with the servers for model updates. The parameter-server paradigm has motivated the development of numerous distributed optimization methods (Boyd et al. (2011); Dean et al. (2012); Dekel et al. (2012); Richtárik & Takác (2016); Zhang et al. (2015)). Several works focused on improving communication efficiency for distributed learning (Shamir et al. (2014); Smith et al. (2018); Zhang et al. (2013)). To address the issue of model robustness, Zeno (Xie et al. (2019)) was proposed to make distributed machine learning tolerant to an arbitrary number of faulty workers. The RNN-based aggregator (Ji et al. (2019)) adopted a meta-learning approach that utilizes a recurrent neural network (RNN) in the parameter server to learn to aggregate the gradients from the workers, together with a coordinate-wise preprocessing and postprocessing method to improve robustness. Federated Learning: FedMA (Wang et al. (2020)) demonstrated that permutations of layers can affect the gradient aggregation results, and proposed a layer-wise gradient aggregation method to solve the problem. For fair resource allocation, the q-FedSGD (Li et al. (2020b)) method encouraged a more uniform accuracy distribution across devices in federated networks. However, none of these methods differentiates gradient parameters from statistical parameters; they all aggregate the entire model in a coupled manner.
In this paper, we make the first attempt to decouple the aggregation of gradient parameters and statistical parameters with probabilistic weights, so as to optimize the global model for fast convergence and high accuracy under non-IID conditions.

3. FEDEC: A DECOUPLED GRADIENT AGGREGATION METHOD

3.1. OBJECTIVE OF FEDERATED LEARNING WITH NON-IID DATA

Consider a federated learning scenario with K clients that train their local CNN models independently on local datasets x_1, x_2, ..., x_K and report their gradients and model parameters to a central server. The objective of the server is to form an aggregated global CNN model that minimizes the loss over the total dataset x = {x_1, x_2, ..., x_K}. Conventional federated learning optimizes the following loss function:

\[ \min_{W} L(W, x) := \sum_{k=1}^{K} \frac{|x_k|}{|x|} L_k(W_k, x_k), \tag{1} \]

where W denotes the parameters of the global model, W_k (k = 1, 2, ..., K) denotes the parameters of the k-th local model, and L(·) and L_k(·) are the loss functions of the global model and the local models respectively. This objective assumes that training samples are uniformly distributed among the clients, so that the aggregated loss can be represented as the sum of the size-weighted local losses.

As discussed in Section 1, conventional federated learning has two drawbacks. First, local datasets are collected by mobile devices used by particular users and are typically non-IID: the training samples on each client may be drawn from a different distribution, so the data points available locally can be biased with respect to the overall distribution. Second, since a neural network model typically consists of convolutional (Conv) and fully-connected (FC) layers formed by gradient parameters, and batch normalization (BN) layers formed by statistical parameters such as mean and variance, aggregating them without distinction causes severe deviation of the global model parameters.

To address the above issues, we propose a decoupled probabilistic-weighted approach for federated learning that optimizes the following loss function:

\[ \min_{W^t_*} L(\{W^t_{NN}, W^t_{mean}, W^t_{var}\}, x) := \sum_{k=1}^{K} \pi_k L_k(\{W^{t-1,k}_{NN}, W^{t-1,k}_{mean}, W^{t-1,k}_{var}\}, x_k), \tag{2} \]

where * stands for NN, mean, and var; W^t_NN, W^t_mean, and W^t_var are the parameters of the global model after the t-th aggregation epoch; W^{t-1,k}_NN, W^{t-1,k}_mean, and W^{t-1,k}_var are the parameters of the k-th local model trained for several local epochs from the (t-1)-th global model; and π_k (k = 1, ..., K) is the probability that a sample is drawn from the distribution of the k-th client, i.e., π_k ∈ [0, 1] and ∑_{k=1}^{K} π_k = 1. This formulation aims to minimize the expected loss over K clients with non-IID datasets.
Next, we introduce a decoupled method called FeDEC that optimizes the parameters of the different types of layers separately, and derive the probabilistic weights π_k for parameter aggregation.

3.2. DECOUPLED PROBABILISTIC-WEIGHTED GRADIENT AGGREGATION METHOD

In this section, we propose a decoupled method to derive the global model with respect to W^t_NN (the parameters of the Conv and FC layers) and W^t_mean, W^t_var (the statistical parameters of the BN layers).

3.2.1. GRADIENT AGGREGATION FOR CONV AND FC LAYERS

Since the parameters of the Conv and FC layers are neural network weights updated by a distributed gradient descent method (Nesterov (1983)), it is appropriate to aggregate them with an approach that adapts conventional federated averaging to non-IID datasets. Let

\[ g^t_k = W^{t-1,k}_{NN*} - W^{t-1,k}_{NN} \quad (k = 1, \ldots, K), \]

where NN* indicates the NN parameters after full local training, be the gradient of the k-th client in the t-th training epoch. After receiving the gradients from the K clients, the central server updates the parameters of the global model as

\[ W^t_{NN} = W^{t-1}_{NN} - \beta \sum_{k=1}^{K} \pi^t_k g^t_k, \tag{3} \]

where β is the learning rate for the parameter update.

3.2.2. PARAMETER AGGREGATION FOR BN LAYERS

The statistical parameters of the BN layers are aggregated as

\[ W^t_{mean} = \sum_{k=1}^{K} \pi^t_k W^{t,k}_{mean}, \tag{4} \qquad W^t_{var} = \frac{1}{|x| - K} \sum_{k=1}^{K} (|x_k| - 1)\, \pi^t_k W^{t,k}_{var}, \tag{5} \]

where W^{t,k}_mean and W^{t,k}_var denote the means and variances of the BN layers of the k-th client in epoch t, and π^t_k (k = 1, ..., K) are probabilistic weights with ∑_{k=1}^K π^t_k = 1, derived in Section 3.2.3. In the above equations, we update the mean with the weighted average of the local models, and the variance with the weighted pooled variance (Killeen (2005)), which gives an unbiased estimate of the parameters of the whole dataset under non-IID conditions (see Appendix A.2).
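The decoupled rule can be sketched concretely as follows (our illustrative code, operating on scalars or NumPy arrays; `pis` stands for the weights π^t_k and `sizes` for the local dataset sizes |x_k|):

```python
import numpy as np

def fedec_aggregate(w_nn, grads, bn_means, bn_vars, pis, sizes, beta=0.01):
    """Decoupled probabilistic-weighted aggregation (sketch).

    Conv/FC weights: gradient step  W^t = W^{t-1} - beta * sum_k pi_k g_k.
    BN mean:         weighted average  sum_k pi_k mean_k.
    BN variance:     weighted pooled variance
                     (1/(|x|-K)) * sum_k (|x_k|-1) pi_k var_k.
    """
    k_clients = len(grads)
    new_w = w_nn - beta * sum(p * g for p, g in zip(pis, grads))
    new_mean = sum(p * m for p, m in zip(pis, bn_means))
    total = sum(sizes)
    new_var = sum((n - 1) * p * v
                  for n, p, v in zip(sizes, pis, bn_vars)) / (total - k_clients)
    return new_w, new_mean, new_var
```

Note how the gradient parameters take a single server-side SGD step weighted by π^t_k, while the BN statistics bypass gradient descent entirely and are recombined from the reported local statistics.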

3.2.3. DERIVATION OF PROBABILISTIC WEIGHTS

We adopt a probabilistic mixture model to describe the non-IID datasets in federated learning. Without loss of generality, in the t-th training epoch we assume the mini-batch samples of each client follow a Gaussian distribution N_k(μ_k, σ_k) (k = 1, ..., K), where μ_k and σ_k are the mean and standard deviation of the distribution, varying among clients. We omit the superscript t for simplicity hereafter. The overall samples can then be described by a Gaussian Mixture Model (GMM) with the probability function

\[ p(x \mid \lambda) = \sum_{k=1}^{K} \pi_k\, p(x_k \mid \mu_k, \sigma_k), \]

where λ = {π_k, μ_k, σ_k | k = 1, 2, ..., K} are the parameters of the GMM. In federated learning, the local data samples are accessible only to the particular client, and the central server can only observe statistics of the local dataset such as the mean and standard deviation. Without access to the overall samples, the conventional expectation-maximization (EM) algorithm (Dempster et al. (1977)) cannot be applied to derive λ. Instead, we introduce a variational inference method to estimate the parameters λ.

Figure 3: The variational Bayesian generative model using plate notation.

Specifically, we construct a variational Bayesian generative model that generates data as close as possible to the reported statistics of the local models, and use the generated data to estimate the GMM parameters. The plate notation of the generative model is shown in Fig. 3. The notation is explained as follows. • s^t_k = {W^{t,k}_mean, W^{t,k}_var} is the observed statistics from the feature maps of the k-th client. • z^t_k = {z^t_{k,i} | i = 1, 2, ..., C} is a vector of latent variables of length C, where z^t_{k,i} ∈ [0, 1], ∑_{i=1}^C z^t_{k,i} = 1, and C is the number of classes of the classification task. z^t_k can be viewed as a data distribution representing the probability that a sample on client k belongs to each class.
• θ = {θ_k} are the generative model parameters, and ϕ = {ϕ_k} are the variational parameters. The solid lines in Fig. 3 denote the generative model p_{θ_k}(z^t_k) p_{θ_k}(s^t_k | z^t_k), and the dashed lines denote the variational approximation q_{ϕ_k}(z^t_k | s^t_k) to the intractable posterior p_{θ_k}(z^t_k | s^t_k). We approximate p_{θ_k}(z^t_k | s^t_k) with q_{ϕ_k}(z^t_k | s^t_k) by minimizing their divergence:

\[ \phi^*_k, \theta^*_k = \arg\min_{\theta_k, \phi_k} \mathrm{divergence}\big(q_{\phi_k}(z^t_k \mid s^t_k)\,\|\,p_{\theta_k}(z^t_k \mid s^t_k)\big), \quad \text{s.t. } \sum_{i=1}^{C} z^t_{k,i} = 1. \tag{7} \]

To derive the optimal values of the parameters ϕ_k and θ_k, we consider the marginal log-likelihood of s^t_k:

\[ \log p(s^t_k) = D_{KL}\big(q_{\phi_k}(z^t_k \mid s^t_k)\,\|\,p_{\theta_k}(z^t_k \mid s^t_k)\big) + \mathbb{E}_{q_{\phi_k}(z^t_k \mid s^t_k)}\left[\log \frac{p_{\theta_k}(z^t_k, s^t_k)}{q_{\phi_k}(z^t_k \mid s^t_k)}\right]. \tag{8} \]

In Eq. 8, the first term is the KL-divergence (Joyce (2011)) between the approximate distribution and the posterior distribution; the second term is the ELBO (Evidence Lower BOund) on the marginal likelihood of the data on the k-th client. Since the KL-divergence is non-negative, the minimization problem of Eq. 7 can be converted into maximizing the ELBO. To solve the problem, we rewrite the ELBO as

\[ \mathbb{E}_{q_{\phi_k}(z^t_k \mid s^t_k)}\left[\log \frac{p_{\theta_k}(z^t_k, s^t_k)}{q_{\phi_k}(z^t_k \mid s^t_k)}\right] = \underbrace{\mathbb{E}_{q_{\phi_k}(z^t_k \mid s^t_k)}\left[\log \frac{p(z^t_k)}{q_{\phi_k}(z^t_k \mid s^t_k)}\right]}_{\text{encoder}} + \underbrace{\mathbb{E}_{q_{\phi_k}(z^t_k \mid s^t_k)}\left[\log p_{\theta_k}(s^t_k \mid z^t_k)\right]}_{\text{decoder}}. \tag{9} \]

This form is a variational encoder-decoder structure: the model q_{ϕ_k}(z^t_k | s^t_k) can be viewed as a probabilistic encoder that, given observed statistics s^t_k, produces a distribution over the possible values of the latent variables z^t_k; the model p_{θ_k}(s^t_k | z^t_k) can be regarded as a probabilistic decoder that reconstructs s^t_k from the code z^t_k. According to the theory of variational inference (Kingma & Welling (2014)), the problem in Eq. 9 can be solved with stochastic gradient descent (SGD) using a fully-connected neural network to optimize a mean-squared-error loss.
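For intuition, when q and the prior are chosen as Gaussians, both terms of the decomposition have closed forms. The sketch below (our names; standard textbook formulas) makes concrete the two quantities traded off when maximizing the ELBO. Note the paper constrains z to the probability simplex; a one-dimensional Gaussian latent is used here purely for illustration:

```python
import math

def gaussian_kl(mu_q, var_q, mu_p=0.0, var_p=1.0):
    # KL( N(mu_q, var_q) || N(mu_p, var_p) ) in one dimension; this is
    # the "encoder" term that keeps q(z|s) close to the prior p(z).
    return 0.5 * (math.log(var_p / var_q)
                  + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0)

def elbo(expected_recon_log_lik, mu_q, var_q):
    # ELBO = E_q[log p(s|z)] - KL(q(z|s) || p(z)); because log p(s) is
    # fixed, maximizing the ELBO minimizes the KL gap to the posterior.
    return expected_recon_log_lik - gaussian_kl(mu_q, var_q)
```

In practice the reconstruction term is estimated by sampling z from the encoder and scoring the decoder's reconstruction of s, which is what the mean-squared-error training mentioned above amounts to.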
With the derived optimal parameters ϕ^*_k, θ^*_k, we can extract the latent variables z^t_k, which are interpreted as the sample distribution of client k. Therefore z^t_k can be used to infer the parameters (π_k, μ_k, σ_k) of the k-th component of the GMM. Specifically, the probabilistic weights are given by

\[ \pi^t_k = \left\{ \sum_{i=1}^{C} \frac{z^t_{k,i}}{\sum_{j=1}^{K} z^t_{j,i}} \right\} \Big/ \left\{ \sum_{k'=1}^{K} \sum_{i=1}^{C} \frac{z^t_{k',i}}{\sum_{j=1}^{K} z^t_{j,i}} \right\}. \tag{10} \]
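Eq. 10 can be read as: normalize each class column of the K×C matrix of latent vectors across clients, sum the resulting shares per client, and renormalize so the weights sum to one. A sketch (our illustrative code; assumes every class is seen by at least one client so no column sum is zero):

```python
import numpy as np

def mixture_weights(z):
    """z: (K, C) array; row k is client k's inferred class distribution.

    Implements the Eq.-10 rule: divide each entry by its class-column sum,
    sum the shares per client, then renormalize so the pi_k sum to 1.
    """
    shares = z / z.sum(axis=0, keepdims=True)  # z_{k,i} / sum_j z_{j,i}
    scores = shares.sum(axis=1)                # sum over classes i
    return scores / scores.sum()
```

A client that dominates a class column receives that class's full share, so clients holding rare classes are up-weighted rather than drowned out by uniform averaging.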

4. CONVERGENCE ANALYSIS

In this section, we show that the convergence of the proposed FeDEC algorithm is theoretically guaranteed. We use the following assumptions and lemmas; the convergence guarantee is given in Theorem 1.

Assumption 1 (Unbiased Gradient): We assume that the stochastic gradient g^t_i is an unbiased estimator of the true gradient ∇f(w^t_i), i.e., E[g^t_i] = ∇f(w^t_i), where f(·) is a convex objective function and w^t_i are its variables.

Assumption 2 (Gradient Convex Set): We assume that the gradient set G is a convex set: all gradients g_1, g_2, ..., g_K are in G, and any g = ∑_{i=1}^K λ_i g_i (with λ_i > 0 and ∑_{i=1}^K λ_i = 1) is in G.

Lemma 1 (L-Lipschitz Continuity): A function f(·) is Lipschitz continuous if there exists a positive real constant L such that, for all real x_1 and x_2, |f(x_1) − f(x_2)| ≤ L|x_1 − x_2|.

Lemma 2 (Jensen's Inequality): If f(w) is a convex function on W, and E[f(w)] and f(E[w]) are finite, then E[f(w)] ≥ f(E[w]).

Theorem 1 (Convergence Bound): Let w̄_T be the average of w over the T training epochs, β the learning rate, and Γ the diameter of the domain W. Then

\[ f(\bar{w}_T) - \min_{w \in \mathcal{W}} f(w) \le O\!\left( \frac{\Gamma^2}{2\beta T} + \frac{\beta L^2}{2} \right). \tag{11} \]

Taking β = Γ/(L√T), the convergence rate is O(1/√T).

Proof skeleton: We give a brief sketch of the proof of Theorem 1 with the following steps. (1) Since f(·) is convex, f(w^t) − f(w) ≤ ∇f(w^t)(w^t − w). (2) With Assumptions 1 and 2, f(w^t) − f(w) ≤ (1/2β)(||w^t − w||² − ||w^{t+1}_* − w||²) + (β/2)||∇f(w^t)||², where w^{t+1}_* is the intermediate result at update time t + 1. (3) With Lemma 1, projecting w^{t+1}_* onto the domain to obtain w^{t+1}, we have f(w^t) − f(w) ≤ (1/2β)(||w^t − w||² − ||w^{t+1} − w||²) + (β/2)L². (4) Summing from t = 1 to T yields ∑_{t=1}^T f(w^t) − Tf(w) ≤ (1/2β)Γ² + (β/2)L²T. (5) According to Lemma 2, f(w̄_T) − f(w) ≤ Γ²/(2βT) + (β/2)L². (6) Taking β = Γ/(L√T), we obtain the convergence rate in the theorem. The detailed proof of Theorem 1 and explanations are provided in Appendix A.1.
According to Theorem 1, the FeDEC parameter aggregation algorithm is guaranteed to converge, and its convergence rate can be as fast as that of general stochastic gradient descent, depending only on the number of training epochs T up to an associated constant. The constant depends on problem parameters such as the Lipschitz constant L and the diameter Γ of the domain.
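For completeness, the claimed rate follows directly by substituting the step size into the bound of Theorem 1:

```latex
\frac{\Gamma^2}{2\beta T} + \frac{\beta L^2}{2}
\;\xrightarrow{\;\beta = \Gamma/(L\sqrt{T})\;}\;
\frac{\Gamma^2 L\sqrt{T}}{2\Gamma T} + \frac{\Gamma L^2}{2L\sqrt{T}}
= \frac{\Gamma L}{2\sqrt{T}} + \frac{\Gamma L}{2\sqrt{T}}
= \frac{\Gamma L}{\sqrt{T}}
= O\!\left(\frac{1}{\sqrt{T}}\right).
```

Both terms of the bound balance exactly at this choice of β, which is why it is the optimal fixed step size for the bound.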

5. PERFORMANCE EVALUATION

In this section, we evaluate the performance of the proposed FeDEC method for federated learning.

5.1. EXPERIMENTAL SETUP

Implementation. We implement the proposed FeDEC parameter aggregation approach and the baselines under consideration in PyTorch (Paszke et al. (2019)). We train the models in a simulated federated learning environment consisting of one server and a set of mobile clients with wireless network connections. Unless explicitly specified, the default number of clients is 20 and the learning rate is β = 0.01. We conduct the experiments on a GPU-equipped personal computer (CPU: Intel Core i7-8700 3.2GHz, GPU: Nvidia GeForce RTX 2070, Memory: 32GB DDR4 2666MHz, OS: 64-bit Ubuntu 16.04).

Models and datasets.

We conduct experiments with five mainstream neural network models: ResNet18 (He et al. (2016)), LeNet (Lecun et al. (1998)), DenseNet121 (Huang et al. (2017)), MobileNetV2 (Sandler et al. (2018)), and a 4-layer CNN (in which every convolution layer is followed by a BN layer). The detailed structures of the CNN models are provided in Appendix A.3. We use three real-world datasets: MNIST (LeCun et al. (2010)), Fashion-MNIST (Xiao et al. (2017)), and CIFAR-10 (Krizhevsky (2009)). MNIST is a dataset for handwritten digit classification with 60000 samples, each a 28 × 28 greyscale image. Fashion-MNIST is a dataset intended as a drop-in replacement for MNIST for benchmarking machine learning algorithms. CIFAR-10 is a larger dataset with 10 categories; each category has 5000 training images and 1000 validation images of size 32 × 32. For each dataset, we use 80% of the data for training and amalgamate the remaining data into a global test set. We form non-IID local datasets as follows. Assume there are C classes of samples in a dataset.

Each client draws samples from the dataset with probability

\[ pr(x) = \begin{cases} \eta \in [0, 1], & \text{if } x \in \text{class } j, \\ \text{based on } \mathcal{N}(0.5, 1), & \text{otherwise.} \end{cases} \]

That is, a client draws samples from a particular class j with a fixed probability η, and from the other classes with probabilities based on a Gaussian distribution. The larger η is, the more the client's samples concentrate on a particular class, and the more heterogeneous the local datasets are.

5.2. PERFORMANCE COMPARISON

We compare the performance of FeDEC with five state-of-the-art methods: FedAvg (McMahan et al. (2017)), the RNN-based aggregator (Ji et al. (2019)), FedProx (Li et al. (2020a)), q-FedSGD (Li et al. (2020b)), and FedMA (Wang et al. (2020)). The results are analyzed as follows.

Convergence: In this experiment we study the convergence of the baselines and our algorithm by plotting training loss against communication epochs. Fig. 4 shows the results of ResNet18 on CIFAR-10. The loss of all algorithms stabilizes after a number of epochs, and FeDEC clearly attains the lowest loss, meaning it converges faster than the baselines. Results for more CNN models on different datasets are shown in Appendix A.4. Training Efficiency: In this experiment we study test accuracy versus wall-clock training time. Fig. 5 shows the results of training ResNet18 on CIFAR-10. FeDEC reaches 0.8 accuracy after 18 minutes, while FedMA, FedProx, and FedAvg take 36 to 63 minutes to reach the same accuracy. FeDEC approaches 0.9 accuracy after 54 minutes, while the accuracy of the other algorithms remains below 0.85. Results for more CNN models on different datasets are shown in Appendix A.5. This suggests that FeDEC trains much faster than the baselines and reaches high accuracy within a short time. Parameter Bias: In this experiment we study the parameter bias of the federated learning algorithms. Fig. 6 compares the KL-divergence between the BN means and variances of the global models aggregated by the different algorithms and those of the central model. FedAvg, FedProx, and q-FedSGD exhibit exceptionally high parameter bias, while FeDEC has significantly lower KL-divergence than the baselines across CNN models and datasets.
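The skewed partition rule used to form the non-IID local datasets can be sketched as follows. This is our illustrative simplification, not the authors' code: each client's favoured class is fixed as j = k mod C, and the non-favoured classes share the remaining probability mass uniformly rather than via draws from N(0.5, 1):

```python
import numpy as np

def noniid_partition(labels, num_clients, eta, samples_per_client, seed=0):
    """Assign client k a favoured class j = k mod C sampled with
    probability eta; the remaining mass is spread uniformly over the
    other classes. Returns one array of sample indices per client."""
    rng = np.random.default_rng(seed)
    classes = np.unique(labels)
    C = len(classes)
    by_class = {c: np.flatnonzero(labels == c) for c in classes}
    parts = []
    for k in range(num_clients):
        probs = np.full(C, (1.0 - eta) / (C - 1))
        probs[k % C] = eta                      # favoured class gets eta
        drawn = rng.choice(C, size=samples_per_client, p=probs)
        parts.append(np.array([rng.choice(by_class[classes[c]])
                               for c in drawn]))
    return parts
```

Setting eta close to 1 reproduces the highly heterogeneous regime studied in the hyperparameter analysis below; eta = 1/C recovers a near-uniform partition.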
Global Model Accuracy: In this experiment, we compare the global model accuracy of the different federated parameter aggregation algorithms after training to convergence. We repeat the experiment for 20 rounds and report the average results in Table 1. As shown in the table, the central method yields the highest accuracy. Among the federated learning methods, FeDEC significantly outperforms the other algorithms in global model accuracy. It improves over the state-of-the-art method FedMA by 2.87%, 3.17%, 2.58%, and 3.09% for ResNet18, DenseNet121, MobileNetV2, and the 4-layer CNN respectively on CIFAR-10, by 1.09% for LeNet on F-MNIST, and by 0.33% for LeNet on MNIST. FeDEC achieves the highest accuracy among all baselines, and its accuracy drop relative to the centralized method is less than 3% in all cases. Hyperparameter Analysis: We further analyze the influence of two hyperparameters in federated learning: the number of clients involved and the heterogeneity of the local datasets. Fig. 7 compares the test accuracy of the global model for different numbers of involved clients. According to the figure, the performance of FeDEC is stable: when the number of mobile clients increases from 5 to 20, its test accuracy decreases only slightly, from 0.909 to 0.893, whereas the baseline algorithms suffer significant performance drops. FeDEC achieves the highest test accuracy among all federated learning algorithms in all cases, and performs very close to the central model. In the experiments, the heterogeneity of the local datasets is represented by η, the probability that a client samples from a particular class; the closer η is to 1, the more heterogeneous the local datasets are. Fig. 8 shows the test accuracy under different levels of heterogeneity. As η increases, the test accuracy of all models decreases. FeDEC yields the highest test accuracy among all algorithms, and its performance degrades much more slowly than that of the baselines. This verifies the effectiveness of the proposed probabilistic-weighted gradient aggregation approach under non-IID conditions.

6. CONCLUSION

Gradient aggregation plays an important role in federated learning to form a global model. To address the problem of data and parameter bias in federated learning with non-IID datasets, we proposed a novel probabilistic parameter aggregation method called FeDEC that decouples gradient parameters and statistical parameters and aggregates them separately. The probabilistic weights are optimized with variational inference, and the proposed method is proved to have a convergence guarantee. Extensive experiments showed that FeDEC significantly outperforms the state of the art on a variety of performance metrics.

A APPENDIX A.1 PROOF OF CONVERGENCE GUARANTEE (THEOREM 1 IN SECTION 4)

We provide the detailed proof of Theorem 1 in Section 4. We first restate the necessary equations and the theorem. In Section 3, we use the following gradient descent update:

\[ W^t_{NN} = W^{t-1}_{NN} - \beta \sum_{k=1}^{K} \pi^t_k g^t_k, \tag{3} \]

where β is the learning rate for the parameter update, and the following batch normalization updates:

\[ W^t_{mean} = \sum_{k=1}^{K} \pi^t_k W^{t,k}_{mean}, \tag{4} \qquad W^t_{var} = \frac{1}{|x| - K} \sum_{k=1}^{K} (|x_k| - 1)\, \pi^t_k W^{t,k}_{var}. \tag{5} \]

We restate the assumptions and lemmas of Section 4.

Assumption 1 (Unbiased Gradient): The stochastic gradient g^t_i is an unbiased estimator of the true gradient ∇f(w^t_i), i.e., E[g^t_i] = ∇f(w^t_i), where f(·) is a convex objective function and w^t_i are its variables.

Assumption 2 (Gradient Convex Set): The gradient set G is a convex set: all gradients g_1, g_2, ..., g_K are in G, and any g = ∑_{i=1}^K λ_i g_i (with λ_i > 0 and ∑_{i=1}^K λ_i = 1) is in G.

Lemma 1 (L-Lipschitz Continuity): A function f(·) is Lipschitz continuous if there exists a positive real constant L such that, for all real x_1 and x_2, |f(x_1) − f(x_2)| ≤ L|x_1 − x_2|.

Lemma 2 (Jensen's Inequality): If f(w) is a convex function on W, and E[f(w)] and f(E[w]) are finite, then E[f(w)] ≥ f(E[w]).

Theorem 1: Let w̄_T be the average of w over the T training epochs and β the learning rate in Eq. (3). Then f(w̄_T) − min_{w∈W} f(w) ≤ O(Γ²/(2βT) + βL²/2). Taking β = Γ/(L√T), the convergence rate is O(1/√T).

Proof: To simplify the analysis, we consider a fixed learning rate β. The proof consists of the following steps.

(1) By the definition of a convex function, f(w^t) − f(w) ≤ ∇f(w^t)(w^t − w).

(2) Define G(w) = ∇f^T(w^t)(w^t − w) and g^t = ∑_{k=1}^K π_k g^t_k, and let w^{t+1}_* = w^t − βg^t denote the intermediate result at update time t + 1. With Assumptions 1 and 2,

\[ G(w) = \frac{1}{\beta}(w^t - w^{t+1}_*)(w^t - w) = \frac{1}{2\beta}\big(\|w^t - w\|^2 - \|w^{t+1}_* - w\|^2 + \|w^t - w^{t+1}_*\|^2\big), \]

and since ||w^t − w^{t+1}_*||² = β²||∇f(w^t)||², we obtain f(w^t) − f(w) ≤ (1/2β)(||w^t − w||² − ||w^{t+1}_* − w||²) + (β/2)||∇f(w^t)||².

(3) We project w^{t+1}_* onto the domain to obtain w^{t+1}. Using the non-expansive property of the projection onto a convex set and Lemma 1, f(w^t) − f(w) ≤ (1/2β)(||w^t − w||² − ||w^{t+1} − w||²) + (β/2)L².

(4) Summing from t = 1 to T and telescoping, with Γ the diameter of the domain W:

\[ \sum_{t=1}^{T} f(w^t) - T f(w) \le \frac{1}{2\beta}\big(\|w^1 - w\|^2 - \|w^{T+1} - w\|^2\big) + \frac{\beta}{2}L^2 T \le \frac{1}{2\beta}\|w^1 - w\|^2 + \frac{\beta}{2}L^2 T \le \frac{1}{2\beta}\Gamma^2 + \frac{\beta}{2}L^2 T. \]

(5) According to Jensen's inequality (Lemma 2),

\[ f(\bar{w}_T) - f(w) = f\Big(\frac{1}{T}\sum_{t=1}^{T} w^t\Big) - f(w) \le \frac{1}{T}\sum_{t=1}^{T} f(w^t) - f(w) \le \frac{\Gamma^2}{2\beta T} + \frac{\beta}{2}L^2. \]
We thus obtain

\[ f(\bar{w}_T) - \min_{w \in \mathcal{W}} f(w) \le O\!\left( \frac{\Gamma^2}{2\beta T} + \frac{\beta L^2}{2} \right). \]

Taking β = Γ/(L√T), the right-hand side becomes

\[ \frac{\Gamma^2}{2\beta T} + \frac{\beta L^2}{2} = \frac{\Gamma^2 L\sqrt{T}}{2\Gamma T} + \frac{\Gamma L^2}{2L\sqrt{T}} = \frac{\Gamma L}{\sqrt{T}}. \]

Therefore we obtain the simplified expression of the convergence bound, O(1/√T).

A.2 EXPLANATION OF UNBIASED PARAMETER AGGREGATION IN SECTION 3.2.2

We compute the expectation of the aggregated parameters W^t_mean and W^t_var in Section 3.2.2 as follows:

\[ \mathbb{E}[W^t_{mean}] = \mathbb{E}\Big[\sum_{k=1}^{K} \pi_k W^{t,k}_{mean}\Big] = \sum_{k=1}^{K} \pi_k \mathbb{E}\big[W^{t,k}_{mean}\big], \]

\[ \mathbb{E}[W^t_{var}] = \mathbb{E}\Big[\frac{1}{|x| - K}\sum_{k=1}^{K} (|x_k| - 1)\pi_k W^{t,k}_{var}\Big] = \frac{1}{|x| - K}\sum_{k=1}^{K} (|x_k| - 1)\pi_k \mathbb{E}\big[W^{t,k}_{var}\big]. \]

According to the above equations, if the local model parameters W^{t,k}_mean and W^{t,k}_var are unbiased, then the aggregated model parameters are unbiased as well.



Note that the proposed variational inference method can be applied to other non-Gaussian distributions with slight modification.
https://github.com/kuangliu/pytorch-cifar/blob/master/models/densenet.py
https://github.com/kuangliu/pytorch-cifar/blob/master/models/mobilenetv2.py



Figure 1: The differences between real data and sampled datasets (CIFAR-10).





Figure 4: Convergence of different algorithms (ResNet18 on CIFAR-10).

Figure 6: KL-divergence of different algorithms.

Figure 8: Test accuracy with different levels of heterogeneity (ResNet18 on CIFAR-10).




Table 1: Average test accuracy on non-IID datasets. The "Central" method trains the CNN model on the central server with the global dataset. The "FeDEC(w/o)" method uses the proposed probabilistic-weighted aggregation without distinguishing the NN parameters from the mean and var parameters. The "FeDEC" method is the proposed decoupled probabilistic-weighted aggregation approach.


A.3 STRUCTURE OF THE NEURAL NETWORK MODELS IN SECTION 5

Here we report the detailed model structures used in the experiments. We use the LeNet shown in Table 2 and the 4-layer CNN model shown in Table 3. We adopt a slim ResNet18 as shown in Table 4, where "Conv2d" is a convolution layer, "BatchNorm2d" is a batch normalization layer, and "Linear" is a fully-connected layer. Every convolution layer is followed by a batch normalization (BN) layer, and for all models we use a ReLU layer after every Conv2d layer. The structures of DenseNet121 and MobileNetV2 can be found on GitHub (see the footnote links above). For the language model, we consider the sentiment analysis task on tweets from Sentiment140 with a 2-layer BiLSTM. The BiLSTM binary classifier contains 256 hidden units with pretrained 100-dimensional GloVe embeddings. Each Twitter account corresponds to a device.

