UNDERSTANDING THE TRAINING DYNAMICS IN FEDERATED DEEP LEARNING VIA AGGREGATION WEIGHT OPTIMIZATION

Abstract

From the server's perspective, federated learning (FL) learns a global model by iteratively sampling a cohort of clients and updating the global model with the summed local gradients of the cohort. We find this process analogous to mini-batch SGD in centralized training, where a model is learned by iteratively sampling a batch of data and updating the model with the sum of the batch's gradients. In this paper, we delve into the training dynamics of FL by drawing on the experience of optimization and generalization in mini-batch SGD. Specifically, we focus on two aspects: client coherence (analogous to sample coherence in mini-batch SGD) and global weight shrinking regularization (analogous to weight decay in mini-batch SGD). We find that the roles of both aspects are determined by the aggregation weights assigned to each client during global model updating. Thus, we use aggregation weight optimization on the server as a tool to study how client heterogeneity and the number of local epochs affect the global training dynamics in FL. In addition, we propose an effective method for Federated Aggregation Weight Optimization, named FEDAWO. Extensive experiments verify that our method improves the generalization of the global model by a large margin across different datasets and models.

1. INTRODUCTION

Federated Learning (FL) (McMahan et al., 2017; Li et al., 2020a; Wang et al., 2021; Lin et al., 2020; Li et al., 2022b) is a promising distributed optimization paradigm in which clients' data are kept local and a central server aggregates clients' local gradients for collaborative training. Although many FL algorithms with deep neural networks (DNNs) have emerged in recent years (Lin et al., 2020; Chen & Chao, 2021a; Li et al., 2020b; Acar et al., 2020; Chen & Chao, 2021b), there are few works on the underlying training dynamics of FL with DNNs (Yan et al., 2021; Yuan et al., 2021), which hinders us from probing further into the link between generalization and optimization in FL. Meanwhile, an interesting analogy exists between centralized mini-batch SGD and FL. The server-client training framework of FL (from the server's perspective) learns a global model by iteratively sampling a cohort of clients and updating the global model with the summed local gradients of the cohort, while in centralized mini-batch SGD, a model is learned by iteratively sampling a mini-batch of data and updating the model with the sum of the corresponding gradients. In this analogy, the clients in FL correspond to the data samples in mini-batch SGD, the cohort of clients to the mini-batch of data samples, and the communication round to the iteration step. This analogy makes us wonder: can we leverage the insights of mini-batch SGD to better understand the training dynamics in FL? Following this question, and considering the key techniques in mini-batch SGD (as well as its generalization), we focus in this paper on two aspects of training dynamics in FL: client coherence (analogous to sample coherence in mini-batch SGD) and global weight shrinking (GWS) regularization (analogous to weight decay in mini-batch SGD).
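To make the analogy concrete, the following toy sketch (our illustration, not the paper's algorithm) places a server-side FL update next to a mini-batch SGD step on a simple quadratic objective; the function names and the uniform aggregation weights are assumptions for illustration. With uniform weights and a single local gradient step per client, the two updates coincide exactly.

```python
import numpy as np

rng = np.random.default_rng(0)

def minibatch_sgd_step(w, batch_grads, lr):
    """Centralized mini-batch SGD: update the model with the
    mean of the per-sample gradients in the batch."""
    return w - lr * np.mean(batch_grads, axis=0)

def fl_server_step(w_global, client_updates, agg_weights):
    """Server-side FL update: combine the cohort's local model
    deltas with (normalized) per-client aggregation weights,
    mirroring the summed per-sample gradients above."""
    agg_weights = np.asarray(agg_weights, dtype=float)
    agg_weights = agg_weights / agg_weights.sum()
    delta = sum(a * u for a, u in zip(agg_weights, client_updates))
    return w_global + delta

# Toy quadratic objective per client/sample: gradient = w - target.
w = np.zeros(3)
targets = rng.normal(size=(4, 3))            # 4 clients (or 4 samples)
grads = np.stack([w - t for t in targets])   # per-sample gradients
w_sgd = minibatch_sgd_step(w, grads, lr=0.5)

# Each client performs one local SGD step and reports its delta.
updates = [-0.5 * (w - t) for t in targets]
w_fl = fl_server_step(w, updates, agg_weights=[1, 1, 1, 1])

print(np.allclose(w_sgd, w_fl))  # True
```

Non-uniform aggregation weights break this equivalence, which is precisely why the paper treats them as the lever controlling both client coherence and global weight shrinking.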
Firstly, sample coherence explains how the relations between data samples affect the generalization of DNN models (Chatterjee, 2019; Chatterjee & Zielinski, 2020; Fort et al., 2019). By analogy, we extend the concept of sample coherence to clients in FL with partial participation, in order to study the effects and training dynamics jointly caused by heterogeneous client data and local updates. Secondly, in a different line of work, weight decay methods (Lewkowycz & Gur-Ari, 2020; Zhang et al., 2018; Loshchilov & Hutter, 2018; Xie et al., 2020)-which decay the model parameters at each iteration step-are the key
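The weight-decay side of the analogy can be sketched as follows. This is a minimal illustration under our own assumptions, not the paper's formulation: `wd` is the standard decoupled weight-decay coefficient, and `gamma` is a hypothetical server-side shrinking factor standing in for the global weight shrinking effect.

```python
import numpy as np

def sgd_step_with_weight_decay(w, grad, lr, wd):
    """Centralized weight decay: shrink the parameters slightly
    at every iteration, alongside the gradient step."""
    return (1.0 - lr * wd) * w - lr * grad

def server_step_with_gws(w_global, client_updates, agg_weights, gamma):
    """Illustrative global weight shrinking (GWS) on the server:
    scale the aggregated global model by a factor gamma <= 1,
    mirroring per-step weight decay in mini-batch SGD.
    `gamma` is an assumed hyperparameter for this sketch."""
    agg_weights = np.asarray(agg_weights, dtype=float)
    agg_weights = agg_weights / agg_weights.sum()
    delta = sum(a * u for a, u in zip(agg_weights, client_updates))
    return gamma * (w_global + delta)

# With a zero gradient, weight decay pulls the parameters toward zero.
w = np.ones(3)
print(sgd_step_with_weight_decay(w, grad=np.zeros(3), lr=0.1, wd=0.5))
# -> [0.95 0.95 0.95]
```

In this reading, the effective shrinking factor on the server is again governed by how the aggregation weights are set, which is what allows the paper to study both regularization effects through aggregation weight optimization.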

