DOES FEDERATED LEARNING REALLY NEED BACKPROPAGATION?

Abstract

Federated learning (FL) provides general principles for decentralized clients to train a server model collectively without sharing local data. FL is a promising framework with practical applications, but its standard training paradigm requires the clients to backpropagate through the model to compute gradients. Since these clients are typically edge devices and not fully trusted, executing backpropagation on them incurs computational and storage overhead as well as white-box vulnerability. In light of this, we develop backpropagation-free federated learning, dubbed BAFFLE, in which backpropagation is replaced by multiple forward processes to estimate gradients. BAFFLE is 1) memory-efficient and easily fits uploading bandwidth; 2) compatible with inference-only hardware optimization and model quantization or pruning; and 3) well-suited to trusted execution environments, because the clients in BAFFLE only execute forward propagation and return a set of scalars to the server. In experiments, we use BAFFLE to train models from scratch or to finetune pretrained models, achieving empirically acceptable results.

1. INTRODUCTION

Federated learning (FL) allows decentralized clients to collaboratively train a server model (Konečnỳ et al., 2016; McMahan et al., 2017). In each training round, the selected clients compute model gradients or updates on their local private datasets, without explicitly sharing sample points with the server. While FL describes a promising blueprint and has several applications (Yang et al., 2018; Hard et al., 2018; Li et al., 2020b), its mainstream training paradigm is still gradient-based and requires the clients to locally execute backpropagation, which leads to two practical limitations: (i) Overhead for edge devices. The clients in FL are usually edge devices, such as mobile phones and IoT sensors, whose hardware is primarily optimized for inference-only purposes (Sharma et al., 2018; Umuroglu et al., 2018), rather than for backpropagation. Due to limited resources, computationally affordable models running on edge devices are typically quantized and pruned (Wang et al., 2019a), making exact backpropagation difficult. In addition, standard implementations of backpropagation rely on either forward-mode or reverse-mode auto-differentiation in contemporary machine learning packages (Bradbury et al., 2018; Paszke et al., 2019b), which increases storage requirements. (ii) White-box vulnerability. To facilitate gradient computation, the server regularly distributes its model status to the clients, but this white-box exposure of the model renders the server vulnerable to, e.g., poisoning or inversion attacks from malicious clients (Shokri et al., 2017; Xie et al., 2020; Zhang et al., 2020; Geiping et al., 2020). Consequently, recent attempts have been made to exploit trusted execution environments (TEEs) in FL, which can isolate the model status within a black-box secure area and significantly reduce the success rate of malicious evasion (Chen et al., 2020; Mo et al., 2021; Zhang et al., 2021; Mondal et al., 2021).
However, TEEs are highly memory-constrained (Truong et al., 2021), while backpropagation is memory-consuming because it must store intermediate states. While numerous solutions have been proposed to alleviate these limitations (discussed in Appendix B), in this paper we raise an essential question: does FL really need backpropagation? Inspired by the literature on zero-order optimization (Stein, 1981), we intend to substitute backpropagation with multiple forward or inference processes to estimate the gradients. Technically speaking, we propose the framework of BAckpropagation-Free Federated LEarning (BAFFLE). As illustrated in Figure 1, in each training round of BAFFLE: (1) each client downloads the global model update ∆W and random seeds from the server and uses them to locally generate perturbed models; (2) each client executes forward processes on the perturbed models using its private dataset D_c and obtains K loss differences {∆L(W, δ_k; D_c)}_{k=1}^{K}; (3) the server aggregates the loss differences to estimate gradients. BAFFLE's defining characteristic is that it only utilizes forward propagation, which is memory-efficient and does not require auto-differentiation. It is well-adapted to model quantization and pruning as well as inference-only hardware optimization on edge devices. Compared to backpropagation, the computation graph of forward propagation in BAFFLE may be more easily optimized, for example by slicing it into per-layer calculation (Kim et al., 2020). Since each loss difference ∆L(W, δ_k; D_c) is a scalar, BAFFLE can easily accommodate the uploading bandwidth of clients by adjusting the value of K, as opposed to using, e.g., gradient compression (Suresh et al., 2017). BAFFLE is also compatible with recent advances in inference approaches for TEEs (Tramer & Boneh, 2019; Truong et al., 2021), providing an efficient solution for combining TEEs with FL and preventing white-box evasion. Based on our convergence analyses, we adapt secure aggregation (Bonawitz et al., 2017a) to zero-order optimization and investigate ways to improve gradient estimation in BAFFLE.
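The forward-only gradient estimation above can be illustrated with a standard two-point (central-difference) zeroth-order estimator from the zero-order optimization literature. The sketch below is our own minimal illustration, not the paper's exact algorithm: the function names and the toy quadratic loss are assumptions, with each perturbation δ_k drawn from a Gaussian with standard deviation σ.

```python
import numpy as np

def loss(w, X, y):
    """Toy quadratic loss standing in for a model's forward pass."""
    return np.mean((X @ w - y) ** 2)

def estimate_gradient(w, X, y, K=50, sigma=1e-3, rng=None):
    """Zeroth-order gradient estimate built from 2K forward passes.

    Uses the central-difference Gaussian-smoothing estimator
        g ≈ (1/K) Σ_k δ_k [L(w + δ_k) − L(w − δ_k)] / (2 σ²),
    with δ_k ~ N(0, σ² I). Only forward evaluations of `loss` are
    needed; no backpropagation is performed anywhere.
    """
    rng = rng or np.random.default_rng(0)
    g = np.zeros_like(w)
    for _ in range(K):
        delta = sigma * rng.standard_normal(w.shape)
        # In BAFFLE this scalar difference is what a client would upload.
        dL = loss(w + delta, X, y) - loss(w - delta, X, y)
        g += delta * dL / (2 * sigma**2)
    return g / K
```

Because each round contributes only the scalar dL, the communication cost scales with K rather than with the number of parameters.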
In our experiments, BAFFLE is used to train models from scratch on MNIST (LeCun et al., 1998) and CIFAR-10/100 (Krizhevsky & Hinton, 2009) , and finetune ImageNet-pretrained models to transfer to OfficeHome (Venkateswara et al., 2017) . Compared to conventional FL, BAFFLE achieves suboptimal but acceptable performance. These results shed light on the potential of BAFFLE and the effectiveness of backpropagation-free methods in FL.

2. PRELIMINARIES

In this section, we introduce the basic concepts of federated learning (FL) (Kairouz et al., 2021) and the finite-difference formulas that will serve as the foundation for our methods.



Figure 1: A sketch map of BAFFLE. In addition to the global parameter update ∆W, each client downloads random seeds to locally generate perturbations ±δ_{1:K} and performs 2K forward propagations (i.e., inferences) to compute loss differences. The server can recover these perturbations using the same random seeds and obtain ∆L(W, δ_k) by secure aggregation. Each loss difference ∆L(W, δ_k; D_c) is a floating-point number, so K can be easily adjusted to fit the uploading bandwidth.
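The seed-sharing exchange in the caption can be sketched as follows. This is a minimal illustration under our own assumptions (parameters as a flat NumPy vector, `local_loss` standing in for a full forward pass on private data); the names `client_round` and `server_recover` are ours, not the paper's.

```python
import numpy as np

def client_round(w, seed, K, sigma, local_loss):
    """Client side (sketch): regenerate perturbations from the shared seed
    and return K scalar loss differences from 2K forward passes.
    Only these K floats leave the device -- never gradients or data."""
    rng = np.random.default_rng(seed)
    diffs = []
    for _ in range(K):
        delta = sigma * rng.standard_normal(w.shape)
        diffs.append(local_loss(w + delta) - local_loss(w - delta))
    return diffs

def server_recover(w_shape, seed, K, sigma, diffs):
    """Server side (sketch): replay the same seed to recover each delta_k
    and combine it with the uploaded scalars into a gradient estimate."""
    rng = np.random.default_rng(seed)
    g = np.zeros(w_shape)
    for dL in diffs:
        delta = sigma * rng.standard_normal(w_shape)
        g += delta * dL / (2 * sigma**2)
    return g / K
```

Because both sides seed the same pseudo-random generator, the perturbations never need to be transmitted, which is what keeps the uplink payload at K scalars.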

FEDERATED LEARNING
Suppose we have C clients, and the c-th client's private dataset is defined as D_c := {(X_i^c, y_i^c)}_{i=1}^{N_c} with N_c input-label pairs. Let L(W; D_c) represent the loss function calculated on the dataset D_c, where W ∈ R^n denotes the server model's global parameters. The training objective of FL is to find W that minimizes the total loss function

L(W) := Σ_{c=1}^{C} (N_c / N) L(W; D_c), where N = Σ_{c=1}^{C} N_c.    (1)

In the conventional FL framework, clients locally compute gradients {∇_W L(W; D_c)}_{c=1}^{C} or model updates through backpropagation and then upload them to the server. Federated averaging (McMahan et al., 2017) performs global aggregation using ∆W := Σ_{c=1}^{C} (N_c / N) ∆W_c, where ∆W_c is the local update obtained by executing W_c ← W_c − η ∇_{W_c} L(W_c; D_c) multiple times and η is the learning rate.
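The weighted aggregation in Eq. (1) and the local-update loop can be made concrete with a short sketch. This is a generic federated-averaging illustration under our own assumptions, not the paper's code; `grad_fn` and the helper names are ours.

```python
import numpy as np

def federated_average(updates, sizes):
    """Global aggregation: weight each client's update ∆W_c by its
    dataset share N_c / N, matching the weighting in Eq. (1)."""
    N = sum(sizes)
    return sum((n / N) * u for u, n in zip(updates, sizes))

def local_update(w, grad_fn, data, eta=0.1, steps=5):
    """One client's contribution: run W_c ← W_c − η ∇ L(W_c; D_c)
    several times and return the net update ∆W_c = W_c − W."""
    w_c = w.copy()
    for _ in range(steps):
        w_c -= eta * grad_fn(w_c, data)
    return w_c - w
```

In conventional FL, `grad_fn` would be computed by backpropagation on the client; BAFFLE's contribution is to replace exactly that call with a forward-only estimate.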

