HARNESSING CLIENT DRIFT WITH DECOUPLED GRADIENT DISSIMILARITY

Abstract

The performance of Federated Learning (FL) typically suffers from client drift caused by heterogeneous data, where data distributions vary across clients. Recent studies show that the gradient dissimilarity between clients, induced by the discrepancy in data distributions, causes client drift. Thus, existing methods mainly focus on correcting the gradients. However, it is challenging to identify which clients should (or should not) be corrected. This challenge raises a series of questions: will local training, without gradient correction, contribute to the server model's generalization on other clients' distributions? When does this generalization contribution hold? How can the challenge be addressed when it fails? To answer these questions, we analyze the generalization contribution of local training and conclude that it is bounded by the conditional Wasserstein distance between clients' distributions. Thus, the key to promoting generalization contribution is to leverage similar conditional distributions for local training. As collecting data distributions can cause privacy leakage, we propose decoupling deep models, i.e., splitting a model into a high-level model and a low-level one, to harness client drift. High-level models are trained on shared feature distributions, which promotes generalization contribution and alleviates gradient dissimilarity. Experimental results demonstrate that FL with decoupled gradient dissimilarity is robust to data heterogeneity.

1. INTRODUCTION

To protect data privacy while cooperatively training machine learning models among personal users and organizations, Federated Learning (FL) (Brendan McMahan et al., 2016) has been widely adopted as a powerful framework in recent years. In the FL framework, many clients train models without communicating private data. Federated Averaging (FedAvg) was proposed to make FL practical in low-bandwidth and low-compute environments. However, when data distributions between clients are severely heterogeneous (Non-Independent and Identically Distributed, Non-IID), the convergence rate and generalization performance of FL are much worse than those of centralized training, which collects all the data (Li et al., 2020a; Karimireddy et al., 2020; Kairouz et al., 2019). The FL community has theoretically and empirically found that the "client drift" caused by heterogeneous data is the main bottleneck of FedAvg (Li et al., 2020a; Karimireddy et al., 2020; Kairouz et al., 2019; Wang et al., 2020a): after several or more training epochs on private datasets, the local models on different clients drift far away from each other. Recent convergence analyses of FedAvg (Li et al., 2020a; Reddi et al., 2021; Woodworth et al., 2020) show that the degree of client drift is linearly upper bounded by gradient dissimilarity. Therefore, most existing works (Karimireddy et al., 2020; Wang et al., 2020a) focus on gradient correction techniques to accelerate the convergence of local training. However, how to correct the gradients during local training remains an open problem (Kairouz et al., 2019; Woodworth et al., 2020; Karimireddy et al., 2020), especially for achieving better generalization. The challenge lies in the lack of a criterion for identifying which clients should (or should not) be corrected.
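To make the FedAvg baseline concrete, its server-side aggregation step is a sample-size-weighted average of client model parameters. The following is a minimal sketch with illustrative names and toy parameters, not a full FL system:

```python
import numpy as np

def fedavg_aggregate(client_params, client_sizes):
    """Weighted average of client parameters (FedAvg aggregation step).

    client_params: list of dicts mapping parameter name -> np.ndarray
    client_sizes:  number of local samples per client (aggregation weights)
    """
    total = float(sum(client_sizes))
    agg = {}
    for name in client_params[0]:
        agg[name] = sum(
            (n / total) * params[name]
            for params, n in zip(client_params, client_sizes)
        )
    return agg

# Two toy clients holding a single parameter tensor "w"
clients = [{"w": np.array([1.0, 2.0])}, {"w": np.array([3.0, 6.0])}]
server = fedavg_aggregate(clients, client_sizes=[1, 3])
# Weights 0.25 and 0.75 give server["w"] == [2.5, 5.0]
```

Under heterogeneous data, the client updates being averaged here point in dissimilar directions, which is exactly the client drift discussed above.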
This challenge raises a fundamental question in FL systems: can the local training on a specific client m contribute to the generalization performance of the server model when evaluated on other clients' distributions? Moreover, it is also unclear under which conditions local training yields such a generalization contribution. A deeper question is how to handle the cases where local training cannot contribute to the server model's generalizability to other clients. To answer these questions, we formulate the objective of local training in FL systems as a generalization contribution problem. The generalization contribution measures how much local training on one client improves the server model's generalization performance on other clients' distributions. Specifically, we evaluate a server model locally trained on one client using other clients' data distributions. Our theoretical analysis shows that the generalization contribution of local training is bounded by the conditional Wasserstein distance between clients' distributions. This implies that even if the marginal distributions on different clients are the same, this is insufficient to guarantee the generalization performance of local training. Therefore, the key to promoting generalization contribution is to leverage the same or similar conditional distributions for local training. However, collecting data to construct identical distributions shared across clients is forbidden due to privacy concerns. To avoid privacy leakage, we propose decoupling a deep neural network into a low-level model and a high-level one, i.e., a feature extractor network and a classifier network. Consequently, we can construct a shared identical distribution in the feature space. Namely, on each client, we estimate the feature distribution induced by the low-level network and send the estimated distribution to the server.
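The decoupling described above can be sketched as follows. The shapes and the ReLU/linear layers are placeholders, not the paper's architecture; the point is only that raw inputs stay local while the high-level model consumes features from a shared space:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical parameter shapes for illustration only.
W_low = rng.standard_normal((32, 16))   # low-level model (feature extractor)
W_high = rng.standard_normal((16, 10))  # high-level model (classifier)

def low_level(x):
    """Map raw inputs to features; features live in a shared space."""
    return np.maximum(x @ W_low, 0.0)

def high_level(h):
    """Map features to class logits."""
    return h @ W_high

x = rng.standard_normal((4, 32))        # a batch of raw (private) inputs
features = low_level(x)                 # shape (4, 16)
logits = high_level(features)           # shape (4, 10)
# Only statistics of `features` (e.g., an estimated feature distribution)
# would be communicated to the server, never the raw inputs x.
```

Because every client's high-level model can then be trained on samples from the same feature distribution, their gradients on that part of the network become more similar.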
After aggregating the received distributions, the server sends the aggregated distribution and the server model to the clients simultaneously. Theoretically, we show that introducing this simple decoupling strategy promotes the generalization contribution and alleviates gradient dissimilarity. Our extensive experimental results demonstrate the effectiveness of our method, where we consider the global test accuracy on four datasets under various FL settings following previous works (He et al., 2020b; Li et al., 2020a; Wang et al., 2020a). Our main contributions include: (1) We theoretically show that the generalization contribution from clients during training is bounded by the conditional Wasserstein distance between clients' distributions, answering the question of when local training on one client can contribute to the generalization performance of server models on other clients' distributions. (2) We are the first to theoretically show that sharing similar features between clients can improve the generalization contribution of local training and significantly reduce the gradient dissimilarity. (3) We experimentally validate the gradient dissimilarity reduction and the generalization benefits of our method.
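One simple way to realize the "estimate and aggregate feature distributions" step is to summarize each client's features with a Gaussian (per-dimension mean and variance) and pool these statistics on the server, weighted by sample counts. This is a hedged sketch under that Gaussian assumption, not necessarily the paper's exact estimator:

```python
import numpy as np

def aggregate_gaussians(means, vars_, counts):
    """Merge per-client Gaussian feature summaries into one global
    Gaussian via the standard mixture (pooled) mean/variance formulas.

    means, vars_: per-client feature mean and variance vectors
    counts:       per-client sample counts (aggregation weights)
    """
    w = np.asarray(counts, dtype=float)
    w /= w.sum()
    global_mean = sum(wi * m for wi, m in zip(w, means))
    # Law of total variance: E[Var] + Var[E]
    global_var = sum(
        wi * (v + (m - global_mean) ** 2)
        for wi, v, m in zip(w, vars_, means)
    )
    return global_mean, global_var

# Two toy clients with 1-D feature summaries
m1, v1 = np.array([0.0]), np.array([1.0])
m2, v2 = np.array([2.0]), np.array([1.0])
mean, var = aggregate_gaussians([m1, m2], [v1, v2], counts=[10, 10])
# mean == [1.0]; var == [2.0] (within-client variance plus between-client spread)
```

The server would then broadcast these pooled statistics alongside the model, so that every client can draw features from the same distribution when training its high-level model.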

2. RELATED WORKS

We review FL algorithms aiming to address the Non-IID problem and introduce other works related to measuring client contribution and decoupled training. Due to limited space, we leave a more detailed discussion of the literature review in Appendix C.

2.1. ADDRESSING NON-IID PROBLEM IN FL

Model Regularization focuses on calibrating the local models so that they do not move excessively far away from the server model. Works such as FedProx (Li et al., 2020a), FedDyn (Acar et al., 2021), SCAFFOLD (Karimireddy et al., 2020) and FedIR (Hsu et al., 2020) add a regularizer on the local-global model difference. MOON (Li et al., 2021b) adds a local-global contrastive loss to learn similar representations across clients. Reducing Gradient Variance tries to correct the directions of local updates at clients using other gradient information. This kind of method aims to accelerate and stabilize convergence, e.g., FedNova (Wang et al., 2020a), FedAvgM (Hsu et al., 2019), FedAdaGrad, FedYogi, and FedAdam (Reddi et al., 2021). Our Theorem 4.2 provides a new angle for reducing gradient variance.

