WHEN TO TRUST AGGREGATED GRADIENTS: ADDRESSING NEGATIVE CLIENT SAMPLING IN FEDERATED LEARNING

Abstract

Federated Learning has become a widely used framework that allows learning a global model from decentralized local datasets while protecting local data privacy. However, federated learning faces severe optimization difficulties when training samples are not independently and identically distributed (non-i.i.d.). In this paper, we point out that the client sampling practice plays a decisive role in these optimization difficulties. We find that negative client sampling makes the merged data distribution of the currently sampled clients heavily inconsistent with that of all available clients, and thereby renders the aggregated gradient unreliable. To address this issue, we propose a novel learning rate adaptation mechanism that adaptively adjusts the server learning rate for the aggregated gradient in each round, according to the consistency between the merged data distribution of the currently sampled clients and that of all available clients. Specifically, through theoretical deduction we find a meaningful and robust indicator that is positively related to the optimal server learning rate and effectively reflects the merged data distribution of the sampled clients, and we utilize it for server learning rate adaptation. Extensive experiments on multiple image and text classification tasks validate the great effectiveness of our method.

1. INTRODUCTION

As tremendous amounts of data are produced by edge devices (e.g., mobile phones) every day, it has become important to study how to effectively utilize these data without revealing personal information and compromising privacy. Federated Learning (Konečnỳ et al., 2016; McMahan et al., 2017) was proposed to allow many clients to jointly train a well-behaved global model without exposing their private data. In each communication round, clients receive the global model from a server and train it locally on their own data for multiple steps. They then upload only the accumulated gradients to the server, which aggregates (averages) the collected gradients and updates the global model. By doing so, the training data never leaves the local devices.

It has been shown that federated learning algorithms perform poorly when training samples are not independently and identically distributed (non-i.i.d.) across clients (McMahan et al., 2017; Li et al., 2021), which is the common case in practice. Previous studies (Zhao et al., 2018; Karimireddy et al., 2020) mainly attribute this problem to the fact that the non-i.i.d. data distribution causes the directions of the local gradients to diverge. Thus, they aim to solve this issue by making the local gradients' directions more consistent (Li et al., 2018; Sattler et al., 2021; Acar et al., 2020). However, we point out that these studies overlook the negative impact of the client sampling procedure (McMahan et al., 2017; Fraboni et al., 2021b), which we argue is the main cause of the optimization difficulty of federated learning on non-i.i.d. data. Client sampling is widely applied on the server side to cope with the limited communication capacity between the server and the large number of total clients, by sampling only a small fraction of clients to participate in each round. We find that it is client sampling that induces the negative effect of the non-i.i.d.
data distribution on federated learning. For example, assume each client performs one full-batch gradient descent step and uploads the local full-batch gradient immediately (i.e., FedSGD in McMahan et al. (2017)): (1) if all clients participate in the current round, this is equivalent to performing one global full-batch gradient descent step on all training samples, and the aggregated server gradient is accurate regardless of whether the data is i.i.d. or not; (2) if only a part of the clients is sampled for participation, the aggregated server gradient will deviate from the above global full-batch gradient, and its direction depends on the merged data distribution of the currently sampled clients (i.e., the label distribution of the dataset constructed from the data points of all sampled clients). The analysis is similar in the scenario where clients perform multiple local updates before uploading their gradients; we provide a detailed discussion in Appendix B. The above analysis indicates that the reliability of the averaged gradient depends on the consistency between the merged data distribution of the currently sampled clients and that of all available clients. Specifically, take the image classification task as an example (refer to Figure 1): (1) if the local samples from all selected clients are almost all cat images, the averaged gradient's direction deviates far from the ideal server gradient averaged over all clients' local gradients. In this case, we should not trust the aggregated gradient, and should decrease the server learning rate. (2) However, if the merged data distribution of the selected clients matches that of all clients well, the averaged gradient's direction is more reliable. Thus, we should relatively enlarge the server learning rate. This example motivates us to set dynamic server learning rates across rounds instead of a fixed one as previous methods do.
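The deviation described above is easy to reproduce numerically. The following sketch (a hypothetical setup, not an experiment from this paper) models each client's full-batch gradient as a noisy draw around one of two class-conditional directions ("cat" vs. "dog"), and compares the full-participation average with a biased sample and a representative sample:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical setup: 20 clients, each holding a local full-batch gradient.
# Clients 0-9 hold mostly "cat" data, clients 10-19 mostly "dog" data.
n_clients, dim = 20, 50
cat_dir = rng.normal(size=dim)
dog_dir = rng.normal(size=dim)
local_grads = np.stack([
    (cat_dir if i < 10 else dog_dir) + 0.1 * rng.normal(size=dim)
    for i in range(n_clients)
])

# Full participation: the "ideal" server gradient averaged over all clients.
ideal = local_grads.mean(axis=0)

def cos(a, b):
    """Cosine similarity between two gradient vectors."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

# Negative sampling: both sampled clients come from the "cat" group.
biased = local_grads[[0, 1]].mean(axis=0)
# Representative sampling: one client from each group.
matched = local_grads[[0, 15]].mean(axis=0)

print(f"cosine(ideal, biased sample)  = {cos(ideal, biased):.3f}")
print(f"cosine(ideal, matched sample) = {cos(ideal, matched):.3f}")
```

The biased sample's averaged gradient points noticeably away from the ideal direction, while the matched sample's average nearly coincides with it, mirroring the two cases in Figure 1.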
In this paper, we first analyze the impact of client sampling and are motivated to mitigate its negative effect by dynamically adjusting the server learning rate based on the reliability of the aggregated server gradient. We theoretically show that the optimal server learning rate in each round is positively related to an indicator we call the Gradient Similarity-aware Indicator (GSI), which reflects the merged data distribution of the sampled clients by measuring the dissimilarity between the uploaded gradients. Based on this indicator, we propose a gradient similarity-aware learning rate adaptation mechanism to adaptively adjust the server learning rate. Furthermore, our method adjusts the learning rate for each parameter group (i.e., weight matrix) individually based on its own GSI, in order to handle the inconsistent fluctuation patterns of GSI across parameter groups and achieve more precise adjustment. Extensive experiments on multiple benchmarks show that our method consistently brings improvements to various state-of-the-art federated optimization methods in various settings.
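The overall mechanism can be illustrated with a simplified, hypothetical coherence score in the same spirit as GSI (the exact GSI is derived theoretically later in the paper; the names `coherence`, `server_update`, and the normalization by a first-round reference are illustrative assumptions, not the paper's formulation). The score is the ratio of the norm of the averaged gradient to the average of the individual gradient norms, computed per parameter group:

```python
import torch

def coherence(grads: list) -> float:
    """Hypothetical coherence score in (0, 1]: norm of the averaged gradient
    divided by the average of the individual gradient norms. Near 1 when
    clients' gradients point the same way; near 0 when they conflict
    (a proxy for a skewed merged data distribution of sampled clients)."""
    stacked = torch.stack(grads)                      # (m, ...) per-client grads
    mean_norm = stacked.mean(dim=0).norm()
    avg_of_norms = stacked.flatten(1).norm(dim=1).mean()
    return float(mean_norm / (avg_of_norms + 1e-12))

def server_update(params: dict, client_grads: list, base_lr=1.0, ref=None):
    """One server round: scale each parameter group's learning rate by its
    own coherence, normalized by a per-group reference (here, the score
    observed in the first round), then apply the averaged gradient."""
    ref = ref if ref is not None else {}
    for name, p in params.items():
        grads = [g[name] for g in client_grads]       # this group's gradients
        c = coherence(grads)
        ref.setdefault(name, c)                       # first-round reference
        lr = base_lr * c / ref[name]                  # trust-aware scaling
        p -= lr * torch.stack(grads).mean(dim=0)      # in-place server step
    return ref
```

When the sampled clients' gradients agree, the score stays near its reference and the effective learning rate is essentially unchanged; when they conflict, the score drops and the step is shrunk, matching the intuition from Figure 1.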

2. RELATED WORK

Federated Averaging (FedAvg) was first proposed by McMahan et al. (2017) for federated learning, and its convergence properties have since been widely studied (Li et al., 2019; Karimireddy et al., 2020). While FedAvg behaves well on i.i.d. data with client sampling, further studies (Karimireddy et al., 2020) reveal that its performance degrades greatly on non-i.i.d. data. Therefore, existing studies focus on improving the model's performance when the training samples are non-i.i.d. under the partial client participation setting. They can be divided into the following categories: Advancing the optimizer used in the server-side update: compared with FedAvg, which is equivalent to applying SGD in the server-side update, other studies go further and use advanced server optimizers, such as SGDM (Hsu et al., 2019; Wang et al., 2019b) and adaptive optimizers.
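The "server optimizer" view can be made concrete: the negated averaged client update is treated as a pseudo-gradient and fed to any first-order optimizer on the server. Below is a minimal sketch contrasting plain averaging (FedAvg as server-side SGD) with server-side momentum in the spirit of Hsu et al. (2019); function and variable names are illustrative, not from any specific codebase:

```python
import torch

def server_step_sgd(global_w, client_deltas, lr=1.0):
    """FedAvg view: applying the averaged client update equals one SGD step
    where the pseudo-gradient is the negated average delta."""
    pseudo_grad = -torch.stack(client_deltas).mean(dim=0)
    return global_w - lr * pseudo_grad

def server_step_sgdm(global_w, client_deltas, momentum_buf, lr=1.0, beta=0.9):
    """Server-side momentum: the pseudo-gradient is first accumulated into a
    momentum buffer, smoothing the update direction across rounds."""
    pseudo_grad = -torch.stack(client_deltas).mean(dim=0)
    momentum_buf.mul_(beta).add_(pseudo_grad)
    return global_w - lr * momentum_buf, momentum_buf
```

With `lr=1.0` and a zero momentum buffer, both steps reduce to the plain FedAvg update in the first round; they diverge in later rounds as the buffer accumulates history.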



Though in our example in Figure 1 we assume, for simplicity, that the merged data distribution of all available clients is balanced, we do not make this assumption in our later analysis.



Figure 1: An illustration of the impact of client sampling through a binary (cats vs. dogs) image classification task. (Right Top): If Client 1 and Client 2 are selected, their merged local samples are almost all cat images. Then the averaged gradient deviates far from the ideal gradient that would be averaged over all clients' local gradients. (Right Bottom): If Client 1 and Client N are selected, their merged data distribution matches well with the global data distribution merged from all clients' data, and the averaged gradient has a more reliable direction.

