ADAPTIVE CLIENT SAMPLING IN FEDERATED LEARNING VIA ONLINE LEARNING WITH BANDIT FEEDBACK

Abstract

Due to the high cost of communication, federated learning (FL) systems need to sample a subset of clients that are involved in each round of training. As a result, client sampling plays an important role in FL systems, as it affects the convergence rate of the optimization algorithms used to train machine learning models. Despite its importance, there is limited work on how to sample clients effectively. In this paper, we cast client sampling as an online learning task with bandit feedback, which we solve with an online stochastic mirror descent (OSMD) algorithm designed to minimize the sampling variance. We then theoretically show how our sampling method can improve the convergence speed of optimization algorithms. To handle the tuning parameters of OSMD, which depend on unknown problem parameters, we use an online ensemble method together with the doubling trick. We prove a dynamic regret bound relative to any sampling sequence. The regret bound depends on the total variation of the comparator sequence, which naturally captures the intrinsic difficulty of the problem. To the best of our knowledge, these theoretical contributions are new and the proof technique is of independent interest. Through both synthetic and real data experiments, we illustrate the advantages of the proposed client sampling algorithm over widely used uniform sampling and existing online-learning-based sampling strategies. The proposed adaptive sampling procedure is applicable beyond the FL problem studied here and can be used to improve the performance of stochastic optimization procedures such as stochastic gradient descent and stochastic coordinate descent.

1. INTRODUCTION

Modern edge devices, such as personal mobile phones, wearable devices, and sensor systems in vehicles, collect large amounts of data that are valuable for training machine learning models. If each device uses only its local data to train a model, the resulting generalization performance will be limited by the number of samples available on that device. Traditional approaches, in which data are transferred to a central server that trains a model on all available data, have fallen out of fashion due to privacy concerns and high communication costs. Federated Learning (FL) has emerged as a paradigm that allows different devices (clients) to collaborate in training a global model while keeping data local and exchanging only model updates (McMahan et al., 2017). In a typical FL process, we have clients that hold data and a central server that orchestrates the training process (Kairouz et al., 2021). The following steps are repeated until the model is trained: (i) the server selects a subset of available clients; (ii) the server broadcasts the current model parameters and sometimes also a training program (e.g., a TensorFlow graph (Abadi et al., 2016)); (iii) the selected clients update the model parameters based on their local data; (iv) the local model updates are uploaded to the server; (v) the server aggregates the local updates and makes a global update of the shared model. In this paper, we focus on the first step and develop a practical strategy for selecting clients with provable guarantees.

To train a machine learning model in an FL setting with M clients, we would like to minimize the following objective:

    min_w F(w) := Σ_{m∈[M]} λ_m ϕ(w; D_m),    (1)

where ϕ(w; D_m) is the loss function used to assess the quality of a machine learning model parameterized by the vector w based on the local data D_m on client m ∈ [M].¹ The parameter λ_m denotes the weight of client m. Typically, λ_m = n_m/n, where n_m = |D_m| is the number of samples on client m and n = Σ_{m=1}^{M} n_m is the total number of samples. At the beginning of the t-th communication round, the server uses the sampling distribution p^t = (p^t_1, ..., p^t_M)^⊤ to choose K clients by sampling with replacement from [M].² Let S^t ⊆ [M] denote the set of chosen clients with |S^t| = K. The server transmits the current model parameter vector w^t to each client m ∈ S^t. Client m computes the local update g^t_m and sends it back to the server.³ After receiving the local updates from the clients in S^t, the server constructs a stochastic estimate of the global gradient as

    g^t = (1/K) Σ_{m∈S^t} (λ_m / p^t_m) g^t_m,

and makes the global update of the parameter w^t using g^t. For example, w^{t+1} = w^t − µ_t g^t if the server is using stochastic gradient descent (SGD) with the stepsize sequence {µ_t}_{t≥1} (Bottou et al., 2018). However, the global update can be obtained using other procedures as well.

The sampling distribution in FL is typically uniform over clients, p^t = p_unif = (1/M, ..., 1/M)^⊤. However, nonuniform sampling (also called importance sampling) can lead to faster convergence, both in theory and in practice, as has been illustrated in stochastic optimization (Zhao & Zhang, 2015; Needell et al., 2016). While the sampling distribution can be designed from prior knowledge (Zhao & Zhang, 2015; Johnson & Guestrin, 2018; Needell et al., 2016; Stich et al., 2017), we cast the choice of sampling distribution as an online learning task and require no prior knowledge about equation (1). Existing approaches that design a sampling distribution via online learning focus on estimating the best sampling distribution under the assumption that it does not change during training. However, the best sampling distribution changes across iterations, and a stationary target distribution does not capture the best sampling distribution in each round. In the existing literature, the best fixed distribution in hindsight is used as the comparator when measuring the performance of the algorithm that designs the sampling distribution. Here, we instead measure the performance of the proposed algorithm against the best dynamic sampling distribution. We use an online stochastic mirror descent (OSMD) algorithm to generate a sequence of sampling distributions and prove a regret bound, relative to any dynamic comparator sequence, that involves a total variation term characterizing the intrinsic difficulty of the problem. To the best of our knowledge, this is the first dynamic regret bound in importance sampling that characterizes the intrinsic difficulty of the problem. Moreover, we theoretically show how our sampling method improves the convergence guarantees of optimization methods by reducing their dependence on the heterogeneity of the problem.

1.1. CONTRIBUTIONS

We develop an algorithm based on OSMD that generates a sequence of sampling distributions {p^t}_{t≥1} based on the partial feedback available to the server from the sampled clients. We prove a bound on the regret relative to any dynamic comparator sequence, which allows us to consider the best sequence of sampling distributions as they change over iterations. The bound includes a total variation term that characterizes the intrinsic difficulty of the problem.

¹ We use [M] to denote the set {1, ..., M}.

² In this paper, we assume that all clients are available in each round and that the purpose of client sampling is to reduce the communication cost, which is also the setting considered in some previous research (Chen et al., 2020). In practice, however, only a subset of clients may be available at the beginning of each round due to physical constraints. In Appendix H.2, we discuss how to extend the proposed methods to handle such situations; analyzing this extension is highly non-trivial and we leave it for future study.

³ Throughout the paper we do not discuss how g^t_m is obtained. One possibility the reader may keep in mind for concreteness is the LocalUpdate algorithm of Charles & Konečný (2020), which covers well-known algorithms such as mini-batch SGD and FedAvg (McMahan et al., 2017).
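To make the sampling and aggregation step concrete, the following sketch checks numerically that the importance-sampled estimator g^t = (1/K) Σ_{m∈S^t} (λ_m/p^t_m) g^t_m is unbiased for the full gradient Σ_m λ_m g^t_m under an arbitrary sampling distribution. All quantities here (client count, weights, local updates, and the particular p) are synthetic stand-ins chosen for illustration, not part of the paper's setup.

```python
import numpy as np

rng = np.random.default_rng(0)

M, K, d, T = 5, 2, 3, 200_000          # clients, draws per round, dim, rounds
lam = np.full(M, 1.0 / M)              # client weights lambda_m, summing to 1
local_grads = rng.normal(size=(M, d))  # synthetic stand-ins for g_m^t
p = np.array([0.40, 0.25, 0.15, 0.12, 0.08])  # an arbitrary distribution p^t

full_grad = lam @ local_grads          # exact gradient: sum_m lam_m * g_m

# T rounds: draw K clients with replacement from p, reweight each local
# update by lam_m / p_m, and average the K reweighted updates.
S = rng.choice(M, size=(T, K), replace=True, p=p)
weights = lam[S] / p[S]                                     # shape (T, K)
g_hat = (weights[..., None] * local_grads[S]).mean(axis=1)  # per-round g^t

# Averaging over many rounds recovers the full gradient up to Monte Carlo
# error; this holds for any p with strictly positive entries.
est = g_hat.mean(axis=0)
print(np.abs(est - full_grad).max())   # small Monte Carlo error
```

Sampling with replacement is what makes the unbiasedness argument simple: the K draws in a round are i.i.d. from p^t, so each reweighted term has expectation Σ_m p^t_m (λ_m/p^t_m) g^t_m = Σ_m λ_m g^t_m.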


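As a rough illustration of how a mirror-descent update with bandit feedback can adapt the sampling distribution, the sketch below performs a multiplicative update (the negative-entropy mirror map) using only the squared update norms of the sampled clients. This is a simplified stand-in under our own assumptions (a variance objective Σ_m λ_m² |g_m|²/p_m, exponent clipping, and uniform mixing for exploration), not the paper's exact OSMD algorithm or its tuning.

```python
import numpy as np

rng = np.random.default_rng(1)

def osmd_step(p, sampled, sq_norms, lam, eta, mix=0.1):
    """One multiplicative-update step of the sampling distribution.

    sampled  : indices of the K clients drawn (with replacement) from p
    sq_norms : squared norms |g_m|^2 observed for those clients only
               (the bandit feedback: nothing is seen for unsampled clients)

    The variance of the gradient estimator scales as
    sum_m lam_m^2 |g_m|^2 / p_m, whose partial derivative w.r.t. p_m is
    -lam_m^2 |g_m|^2 / p_m^2.  Dividing the observed coordinates once more
    by p_m yields an unbiased estimate of that gradient.
    """
    grad_hat = np.zeros_like(p)
    for m, a2 in zip(sampled, sq_norms):
        grad_hat[m] -= (lam[m] ** 2) * a2 / p[m] ** 3 / len(sampled)
    # Negative-entropy mirror step = multiplicative weights; clip the
    # exponent so one lucky draw cannot blow up a low-probability client.
    q = p * np.exp(np.clip(-eta * grad_hat, 0.0, 2.0))
    q /= q.sum()                         # back onto the probability simplex
    return (1 - mix) * q + mix / len(p)  # mix with uniform for exploration

# Demo: client 0 produces much larger updates, so sampling mass should
# drift toward it (the variance-optimal p_m is proportional to lam_m |g_m|).
M, K = 4, 2
lam = np.full(M, 1.0 / M)
true_sq = np.array([9.0, 1.0, 1.0, 1.0])   # synthetic squared norms
p = np.full(M, 1.0 / M)
hist = []
for t in range(600):
    S = rng.choice(M, size=K, replace=True, p=p)
    p = osmd_step(p, S, true_sq[S], lam, eta=0.05)
    if t >= 500:
        hist.append(p)
avg_p = np.mean(hist, axis=0)
print(avg_p)  # client 0 ends with the largest average sampling probability
```

The uniform mixing keeps every p_m bounded away from zero, which both controls the variance of the bandit gradient estimate and guarantees that every client is eventually observed; the step size eta and mixing weight are hypothetical choices for this toy example.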