SPARSE RANDOM NETWORKS FOR COMMUNICATION-EFFICIENT FEDERATED LEARNING

Abstract

A main challenge in federated learning is the large communication cost of exchanging weight updates between clients and the server at every round. While prior work has made great progress in compressing these weight updates through gradient compression methods, we propose a radically different approach that does not update the weights at all. Instead, our method freezes the weights at their initial random values and learns how to sparsify the random network for the best performance. To this end, the clients collaborate in training a stochastic binary mask to find the optimal sparse random network within the original dense one. At the end of training, the final model is a sparse network with random weights, i.e., a subnetwork of the dense random network. In the low-bitrate regime, we show improvements over relevant baselines on the MNIST, EMNIST, CIFAR-10, and CIFAR-100 datasets in accuracy, communication cost (less than 1 bit per parameter, or bpp), convergence speed, and final model size (also less than 1 bpp).

1. INTRODUCTION

Federated learning (FL) is a distributed learning framework in which clients collaboratively train a model by performing local training on their own data and sharing their local updates with a server every few iterations; the server aggregates the local updates into a global model, which is then transmitted back to the clients for the next round of training. While FL is an appealing approach for enabling model training without collecting client data at the server, uplink communication of local updates is a significant bottleneck (Kairouz et al., 2021). This has motivated research on communication-efficient FL strategies (McMahan et al., 2017a) and various gradient compression schemes based on sparsification (Lin et al., 2018; Wang et al., 2018; Barnes et al., 2020; Ozfatura et al., 2021; Isik et al., 2022), quantization (Alistarh et al., 2017; Wen et al., 2017; Bernstein et al., 2018; Mitchell et al., 2022), and low-rank approximation (Konečnỳ et al., 2016; Vargaftik et al., 2021; 2022; Basat et al., 2022).

In this work, while still aiming for communication efficiency in FL, we take a radically different approach from prior work and propose a strategy that does not require communicating weight updates at all. More precisely, instead of training the weights: (1) the server initializes a dense random network with d weights, denoted by the weight vector w^init = (w^init_1, w^init_2, . . . , w^init_d), using a random seed SEED, and broadcasts SEED to the clients so that they can reproduce the same w^init locally; (2) both the server and the clients keep the weights frozen at their initial values w^init at all times; (3) the clients collaboratively train a probability mask of d parameters θ = (θ_1, θ_2, . . . , θ_d) ∈ [0, 1]^d; (4) the server samples a binary mask from the trained probability mask and generates a sparse network with random weights, i.e., a subnetwork of the initial dense random network, as w^final = Bern(θ) ⊙ w^init, where Bern(·) denotes Bernoulli sampling and ⊙ element-wise multiplication. We call the proposed framework Federated Probabilistic Mask Training (FedPM) and summarize it in Figure 1.

Figure 1: Extracting a randomly weighted sparse network using the trainable probability mask θ_t in the forward pass of round t (for both the clients and the server). In practice, the clients collaboratively train continuous scores s ∈ R^d, and at inference time the clients (or the server) compute θ_t = Sigmoid(s_t) ∈ [0, 1]^d. We omit this step in the figure for simplicity.

In addition to the accuracy and communication gains, our framework provides an efficient representation of the final model after training: less than 1 bpp suffices to represent (i) the random seed that generates the initial weights w^init and (ii) a binary vector Bern(θ) sampled from the trained θ. The final model therefore enjoys memory-efficient deployment, a crucial feature for machine learning on power-constrained edge devices. Another advantage of our framework is privacy amplification under some settings, thanks to the stochastic nature of our training strategy.

At first glance, it may seem surprising that randomly initialized networks contain subnetworks that perform well without any modification of the weight values. This phenomenon has been explored to some extent in prior work (Zhou et al., 2019; Ramanujan et al., 2020; Pensia et al., 2020; Diffenderfer & Kailkhura, 2020; Aladago & Torresani, 2021) with different strategies for finding such subnetworks. However, finding these subnetworks in an FL setting has attracted little attention so far. Some exceptions are the works of Li et al. (2021), Vallapuram et al. (2022), and Mozaffari et al. (2021), which improve on other FL challenges, such as personalization and robustness to poisoning attacks, but are not competitive with existing (dense) compression methods such as QSGD (Alistarh et al., 2017), DRIVE (Vargaftik et al., 2021), and SignSGD (Bernstein et al., 2018) in terms of accuracy under the same communication budget. In this work, we propose a stochastic way of finding such subnetworks that reaches higher accuracy at a reduced communication cost of less than 1 bit per parameter (bpp).

Our contributions can be summarized as follows: (1) We propose an FL framework in which the clients do not train the model weights, but instead train a stochastic binary mask used to sparsify the dense network with random weights. This differs from the standard training approaches in the literature. (2) Our framework provides efficient communication from the clients to the server, requiring (less than) 1 bpp per client while yielding faster convergence and higher accuracy than the baselines. (3) We propose a Bayesian aggregation strategy at the server to better handle partial client participation and non-IID data splits. (4) The final model (a sparse network with random weights) can be efficiently represented with a random seed and a binary mask requiring (less than) 1 bpp, which makes storage and communication of the final model at least 32× more efficient than in standard FL strategies.


(5) We demonstrate the efficacy of our strategy on the MNIST, EMNIST, CIFAR-10, and CIFAR-100 datasets under both IID and non-IID data splits, and show improvements in accuracy, bitrate, convergence speed, and final model size over relevant baselines under various system configurations.
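The masked forward pass at the core of this scheme can be sketched in a few lines. The snippet below is a hypothetical single-linear-layer illustration, not the paper's implementation: the seed reproduces the frozen weights w^init on every node, the trainable scores s are squashed to a probability mask θ = Sigmoid(s), and a Bernoulli sample of θ selects the random subnetwork w^final = Bern(θ) ⊙ w^init.

```python
import numpy as np

def sigmoid(s):
    return 1.0 / (1.0 + np.exp(-s))

def fedpm_forward(seed, scores, x):
    """One FedPM-style forward pass for a single linear layer (illustrative).

    seed   -- shared SEED; reproduces the same frozen w_init on every node
    scores -- trainable continuous scores s, shape (d_in, d_out)
    x      -- input batch, shape (batch, d_in)
    """
    d_in, d_out = scores.shape
    # Frozen random weights: identical on server and clients given SEED.
    w_init = np.random.default_rng(seed).standard_normal((d_in, d_out))
    theta = sigmoid(scores)                        # probability mask in [0, 1]
    mask = np.random.rand(d_in, d_out) < theta     # Bern(theta) sample
    w_final = mask * w_init                        # sparse random subnetwork
    return x @ w_final
```

In a full model this masking would be applied layer by layer, and only the binary masks (not the frozen weights) would ever travel over the uplink.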


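The sub-1-bpp claim for the final model admits a simple back-of-envelope check: a raw binary mask costs exactly 1 bit per parameter (plus a constant-size seed), and when the trained θ is biased away from 0.5, an ideal entropy coder approaches the Bernoulli entropy H(θ) < 1 bpp. The sketch below is illustrative arithmetic only; the model size and θ = 0.3 are hypothetical values, not figures from the paper.

```python
import math

def bernoulli_entropy_bpp(p):
    """Shannon entropy of one Bernoulli(p) mask entry, in bits per parameter.
    An ideal entropy coder approaches this rate for an i.i.d. binary mask."""
    if p in (0.0, 1.0):
        return 0.0
    return -p * math.log2(p) - (1 - p) * math.log2(1 - p)

d = 1_000_000                                  # hypothetical parameter count
dense_bpp = 32.0                               # 32-bit floats per weight
raw_mask_bpp = 1.0                             # uncompressed binary mask
coded_bpp = bernoulli_entropy_bpp(0.3)         # biased mask, entropy-coded

print(f"dense weights:        {dense_bpp:.1f} bpp")
print(f"raw binary mask:      {raw_mask_bpp:.1f} bpp")
print(f"entropy-coded mask:   {coded_bpp:.3f} bpp")
```

For θ = 0.3 this gives roughly 0.88 bpp, and the gap to the 32 bpp of dense float weights is the source of the "at least 32×" storage factor quoted above.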