UNDERSTANDING THE TRAINING DYNAMICS IN FEDERATED DEEP LEARNING VIA AGGREGATION WEIGHT OPTIMIZATION

Abstract

From the server's perspective, federated learning (FL) learns a global model by iteratively sampling a cohort of clients and updating the global model with the summed local gradients of the cohort. We find this process analogous to mini-batch SGD in centralized training, where a model is learned by iteratively sampling a batch of data and updating the model with the summed gradients of the batch. In this paper, we delve into the training dynamics of FL by drawing on the experience of optimization and generalization in mini-batch SGD. Specifically, we focus on two aspects: client coherence (analogous to sample coherence in mini-batch SGD) and global weight shrinking regularization (analogous to weight decay in mini-batch SGD). We find that the roles of both aspects are determined by the aggregation weights assigned to each client during global model updating. Thus, we use server-side aggregation weight optimization as a tool to study how client heterogeneity and the number of local epochs affect the global training dynamics in FL. Furthermore, we propose an effective method for Federated Aggregation Weight Optimization, named FEDAWO. Extensive experiments verify that our method improves the generalization of the global model by a large margin on different datasets and models.

1. INTRODUCTION

Federated Learning (FL) (McMahan et al., 2017; Li et al., 2020a; Wang et al., 2021; Lin et al., 2020; Li et al., 2022b) is a promising distributed optimization paradigm in which clients' data are kept local and a central server aggregates clients' local gradients for collaborative training. Although many FL algorithms with deep neural networks (DNNs) have emerged in recent years (Lin et al., 2020; Chen & Chao, 2021a; Li et al., 2020b; Acar et al., 2020; Chen & Chao, 2021b), there are few works on the underlying training dynamics of FL with DNNs (Yan et al., 2021; Yuan et al., 2021), which hinders a deeper investigation of the link between generalization and optimization in FL. Meanwhile, an interesting analogy exists between centralized mini-batch SGD and FL. The server-client training framework of FL (from the server's perspective) learns a global model by iteratively sampling a cohort of clients and updating the global model with the summed local gradients of the cohort, while in centralized mini-batch SGD, a model is learned by iteratively sampling a mini-batch of data and updating it with the sum of the corresponding gradients. In this analogy, the clients in FL correspond to the data samples in mini-batch SGD, the cohort of clients corresponds to the mini-batch of data samples, and the communication round corresponds to the iteration step. This interesting analogy makes us wonder: can we leverage the insights of mini-batch SGD to better understand the training dynamics in FL? Following this question, and considering the key techniques in mini-batch SGD (as well as their generalization effects), in this paper we focus on two aspects of training dynamics in FL: client coherence (analogous to sample coherence in mini-batch SGD) and global weight shrinking (GWS) regularization (analogous to weight decay in mini-batch SGD).
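To make the analogy concrete, the two update rules share the same algebraic form; the following minimal NumPy sketch (with hypothetical function names, not the paper's code) places them side by side:

```python
import numpy as np

def sgd_step(w, sample_grads, lr=0.1):
    """Centralized mini-batch SGD: step against the mean gradient of a batch."""
    return w - lr * np.mean(sample_grads, axis=0)

def fl_round(w_global, client_updates, lr_global=1.0):
    """FL from the server's view: step against the mean of the cohort's
    accumulated local updates (samples <-> clients, batch <-> cohort,
    iteration <-> communication round)."""
    return w_global - lr_global * np.mean(client_updates, axis=0)
```

Both functions reduce to `w <- w - lr * (weighted sum of gradients)`; only the unit of sampling (data sample vs. client) differs.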
Firstly, sample coherence explains how the relations between data samples affect the generalization of DNN models (Chatterjee, 2019; Chatterjee & Zielinski, 2020; Fort et al., 2019). As an analogy, we extend the concept of sample coherence to clients in FL with partial participation, in order to study the effects and training dynamics jointly caused by heterogeneous client data and local updates. Secondly, in a different line of work, weight decay methods (Lewkowycz & Gur-Ari, 2020; Zhang et al., 2018; Loshchilov & Hutter, 2018; Xie et al., 2020), which decay the model parameters at each iteration step, are key techniques in mini-batch SGD based optimization for guarding the generalization performance of deep learning tasks. We similarly examine the effect of weight decay in FL, where we shrink the aggregated global model on the server in each communication round (i.e., global weight shrinking). Note that we take server-side aggregation weight optimization as a tool framework to derive insights into the training dynamics of FL. Although the idea of aggregation weight optimization appeared in previous FL works, to match similar peers in decentralized FL (Li et al., 2022a) or to improve performance in FL for medical tasks (Xia et al., 2021), all prior works assume normalized aggregation weights of clients' models (i.e., γ = 1 in Equation 1), failing to understand FL's dynamics from the learned weights for further insights, e.g., identifying the significance of adaptive global weight shrinking. Specifically, our contributions are threefold.
• We first make an analogy between centralized mini-batch SGD and FL, which enables us to derive a principled tool framework for understanding the training dynamics in FL by leveraging aggregation weights learned on a global-objective-consistent proxy dataset.
• As our main contribution, we identify several interesting findings (see the take-away messages below) that unveil the training dynamics of FL, from the aspects of client coherence (cf. section 3) and global weight shrinking (cf. section 4). These insights are crucial to the FL community and can inspire better practical algorithm design in the future.
• We showcase the effectiveness of these insights and devise a simple yet effective method, FEDAWO, for server-side aggregation weight optimization (cf. section 5). It performs adaptive global weight shrinking and optimizes attentive aggregation weights simultaneously to improve the performance of the global model.
We summarize the key take-away messages of our understandings as follows.
• Our novel concept of client coherence unveils the training dynamics of FL, from the aspects of local gradient coherence and heterogeneity coherence. The update step of FL can be viewed as a manipulation of the received local models:

w_g^{t+1} = γ · (w_g^t − η_g Σ_{i=1}^m λ_i g_i^t),  s.t.  γ > 0, λ_i ≥ 0, ∥λ∥_1 = 1,   (1)

where w_g^{t+1} denotes the global model of round t + 1, η_g is the global learning rate, m is the cohort size (i.e., the number of sampled clients), γ is the global weight shrinking factor, λ_i is the aggregation weight of client i, and g_i^t denotes the accumulated local model updates of client i.
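Equation 1 can be sketched directly in code. The following minimal NumPy implementation (an illustration with assumed names such as `aggregate` and `lam`, not the paper's code) shows how γ and λ enter the server-side update:

```python
import numpy as np

def aggregate(w_global, local_updates, lam, gamma=1.0, eta_g=1.0):
    """Server-side FL update of Equation 1:
        w^{t+1} = gamma * (w^t - eta_g * sum_i lam_i * g_i^t).
    gamma < 1 corresponds to global weight shrinking; gamma = 1 with
    uniform lam recovers plain (normalized-weight) aggregation."""
    lam = np.asarray(lam, dtype=float)
    assert np.all(lam >= 0) and np.isclose(lam.sum(), 1.0), "lam must lie on the simplex"
    # weighted sum of the clients' accumulated local updates
    g = np.tensordot(lam, np.asarray(local_updates), axes=1)
    return gamma * (w_global - eta_g * g)
```

Setting `gamma=1.0` reproduces the normalized-weight assumption of prior work discussed above, while learning `gamma` and `lam` jointly is the degree of freedom this paper exploits.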



For concision, in section 3 and section 4, unless mentioned otherwise, we use CIFAR-10 as the dataset and SimpleCNN as the model; experiments on more datasets and models are shown in section 5 and the Appendix. Different from previous observations (which do not affect the training dynamics), applying global weight shrinking results in a positive local gradient coherence after the critical point, and the learning can benefit from it. We recommend that readers check Appendix A for the preliminaries of federated learning.
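For reference, local gradient coherence, the averaged pairwise cosine similarity of clients' local gradients, can be computed as follows. This is an illustrative sketch under our reading of the definition (the function name is ours, not the paper's):

```python
import numpy as np
from itertools import combinations

def local_gradient_coherence(client_grads):
    """Average pairwise cosine similarity over a cohort's local gradients,
    each flattened to a 1-D vector. Positive values mean the clients'
    updates point in broadly agreeing directions."""
    sims = []
    for gi, gj in combinations(client_grads, 2):
        sims.append(np.dot(gi, gj) / (np.linalg.norm(gi) * np.linalg.norm(gj)))
    return float(np.mean(sims))
```

Tracking this quantity per round is what reveals the critical point (the sign change from positive to negative) discussed in the take-aways below.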



- Local gradient coherence refers to the averaged cosine similarity of clients' local gradients. A critical point (from positive to negative) exists in the curves of local gradient coherence during training. The optimization quality of the initial phase (before encountering this point) matters: assigning larger weights to more coherent clients in this period boosts the final performance.
- Heterogeneity coherence refers to the distribution consistency between the global data and the sampled data (i.e., the data distribution of a cohort of sampled clients) in each round. The value of heterogeneity coherence is proportional to the IID-ness of clients as well as the client participation ratio; the higher, the better. Increasing the heterogeneity coherence by reweighting the sampled clients can also improve the training performance.
• Global weight shrinking regularization effectively improves the generalization performance of the global model.
- When the number of local epochs is larger, or the clients' data are more IID, stronger global weight shrinking is necessary.
- The magnitude of the global gradient (i.e., the uniform average of local updates) determines the optimal weight shrinking factor: a larger norm of the global gradient requires stronger regularization.
- In the late stage of FL training, where the global model is near convergence, the effect of global weight shrinking gradually saturates.
- The effectiveness of global weight shrinking stems from flatter loss landscapes of the global model as well as the improved local gradient coherence after the critical point.

2. UNDERSTANDING FL VIA AN ANALOGY

