WHERE TO BEGIN? ON THE IMPACT OF PRE-TRAINING AND INITIALIZATION IN FEDERATED LEARNING

Abstract

An oft-cited challenge of federated learning is the presence of heterogeneity. Data heterogeneity refers to the fact that data from different clients may follow very different distributions. System heterogeneity refers to client devices having different system capabilities. A considerable number of federated optimization methods address this challenge. In the literature, empirical evaluations usually start federated training from random initialization. However, in many practical applications of federated learning, the server has access to proxy data for the training task that can be used to pre-train a model before starting federated training. Using four standard federated learning benchmark datasets, we empirically study the impact of starting from a pre-trained model in federated learning. Unsurprisingly, starting from a pre-trained model reduces the training time required to reach a target error rate and enables the training of more accurate models (up to 40%) than is possible when starting from random initialization. Surprisingly, we also find that starting federated learning from a pre-trained initialization reduces the effect of both data and system heterogeneity. We recommend future work proposing and evaluating federated optimization methods to evaluate the performance when starting from random and pre-trained initializations. This study raises several questions for further work on understanding the role of heterogeneity in federated optimization.

1. INTRODUCTION

Federated learning (FL) has emerged as a popular distributed machine learning paradigm for privately training a shared model across many participants. At the same time, the training data never leaves the participant's devices. This paper empirically investigates the impact of model initialization on federated optimization methods. Previous empirical evaluations of FL methods start federated training from a randomly initialized model. Transfer learning from pre-trained models has become common practice in natural language processing Radford et al. (2019) ; Devlin et al. (2018) and computer vision He et al. (2019) ; Dosovitskiy et al. (2020) , yielding state-of-the-art results on many tasks and enabling faster model convergence in the centralized setting. Although public proxy data is available at the server in many applications, few prior works studied the impact of starting federated training from a pre-trained model. In cross-device FL (Kairouz et al., 2019) , the primary setting considered in this paper, a central server coordinates many client devices (possibly in hundreds of millions). Each device possesses a local dataset, and the data at different devices follow different distributions, leading to the data heterogeneity challenge (Kairouz et al., 2019) . Moreover, client devices have different system capabilities, leading to system heterogeneity. Finally, devices communicate with the server over low-bandwidth communication links making the performance bottleneck. The predominant approach to federated training builds on local update methods such as FE-DAVG (McMahan et al., 2016) , where a device performs several local updates (e.g., one epoch of SGD on their local training set) before transmitting an update to the server. Although this reduces communication overhead, it can also exacerbate data heterogeneity. Several approaches have been proposed to address this challenge (Li et al., 2018; Hsu et al., 2019; Reddi et al., 2020; Wang et al 

Reddit

Figure 1 : We tested the accuracy of four datasets using random and pre-trained weights. We used solid lines for SGD, dashed lines for PROXIMAL, and dotted lines for MIMELITE. The graph shows how the algorithm rankings changed between the random and pre-trained models. Although no single method is the best for all tasks, FEDADAM with SGD for CLIENTOPT performed consistently well when starting from a pre-trained model, especially for the larger language modeling tasks of Stack Overflow and Reddit. 2020; Karimireddy et al., 2020; 2021; Zhang et al., 2021; Acar et al., 2021) . However, few prior works examine the impact of initialization on federated training.

Contributions.

In this work, we consider the question: How does model initialization (random or pre-trained) impact the behavior of federated optimization methods? We perform an extensive empirical study, comparing 15 variations of federated optimization methods on four commonly-used FL benchmark datasets. Our study reveals three key findings: 1. Although optimizers designed to address heterogeneity typically lead to better performance when starting from a random initialization, when starting from a pre-trained model, we observe that (cf. Fig. 1 ): (i) there is not as big a difference between optimizers in terms of accuracy after a fixed number of rounds, and (ii) using an adaptive optimizer at the server, such as FEDADAM, is more important than using any method for addressing heterogeneity. 2. Starting from a pre-trained model significantly reduces the difference between having non-IID vs. IID data for clients. Furthermore, when starting from a pre-trained model, the number of local epochs per round can be significantly increased without degrading the final accuracy. 3. The initial loss is sometimes lower when starting from a random model. However, the largest Hessian eigenvalue (i.e., local Lipshitz constant or smoothness) is consistently smaller at initialization when starting from a pre-trained model compared to when starting from random initialization. Some of our empirical observations are consistent with existing FL theoretical convergence guarantees. Our findings also highlight that some aspects of FL are not captured with the existing theory, suggesting directions for future work. Initializing 

2. PROBLEM FORMULATION AND THE FEDOPT FRAMEWORK

We consider the following standard optimization formulation of federated training. We seek to find model parameters w that solve the problem, min w∈R d f (w) := m i=1 p i F i (w) (1) where m is the total number of clients, the function F i measures the average loss of a model with parameters w on the ith client's training data, and p i > 0 is the weight given to client i. Usually p i is taken to be proportional to the number of samples at client i so that the optimization problem gives equal weight to all training samples. The goal is to find a model that fits all clients' data well on (weighted) average. In FL, only client i can evaluate F i and its gradient. All of the methods we consider in this study can be expressed in the general FEDOPT framework introduced in Reddi et al. (2020) ; see Algorithm 1. At round t, the server sends its last model x t-1 to a cohort of clients. Each client in the cohort performs E epochs of local training starting from x t-1 using CLIENTOPT with client learning rate η c , producing a local model y E i . Then each client communicates the difference ∆ t i between their local model and the server model, where ∆ t i := x t-1 -y E i . The server computes a weighted average ∆ t of the client updates (line 9 in Alg. 1) and updates its own model via x t+1 = SERVEROPT(x t , ∆ t , η s , t), where SERVEROPT(x t , ∆ t , η s , t) is a first-order optimizer, η s is the server learning rate, and t is the round number.

3.1. DATASETS, MODELS, AND TASKS

We experiment on four FL benchmark datasets: CIFAR-10, FEMNIST, Stack Overflow, and Reddit. All datasets have a natural non-IID partitioning of the data except for CIFAR-10, which we partition using a Dirichlet distribution with α ∈ [0.1, 1.0, 10]. We train Squeezenet and ResNet18 models for CIFAR-10 and FEMNIST, and DistilGPT2 and CharLM models for Stackoverflow and Reddit pushift.io. See Appendix A.1 for additional information. We run each experiment with three different seeds and report the average. We tune client and server learning rates η ℓ and η g , and the proximal penalty parameter µ for FEDPROX with a hyperparameter sweep. Each client update runs one local epoch with fixed batch size per task, and we perform 1050 rounds of training for Stackoverflow, 1000 rounds of training for CIFAR-10, and 1082 training rounds for FEMNIST. For additional implementation details, see Appendix A.2. We use the open-source federated learning simulation framework FLSim FLSim Authors (2022).

3.2. INITIALIZATION STRATEGIES

We consider two initialization strategies: random initialization and supervised pre-training. Random initialization. Most prior federated optimization works use random weights to initialize the model. We can use the same random initialization strategies used in the standard (centralized) training of deep networks for each model (Iandola et al., 2016; HuggingFace, 2019; He et al., 2016; Kim et al., 2016) . Supervised pre-training. In many FL applications, pre-training can be done on a large non-private proxy dataset available at the server. To facilitate easily reproducing our results, we use publicly available pre-trained models or pre-train on public data. For tasks using Squeezenet and ResNet18, we use the version of the model pre-trained on ImageNet, available in the PyTorch Torchvision library.foot_0 For tasks using DistilGPT2, we use the model weights provided in the HuggingFace library that has been distilled from a pre-trained GPT2,foot_1 and for tasks using CharLM, we pre-train the model on WikiText-103 (Merity et al., 2016) (see Appendix B.1 for details).

3.3. ALGORITHMS

We compare federated training with five different CLIENTOPT strategies: SGD clients perform standard stochastic gradient descent updates; Proximal (Li et al., 2018) GD clients perform full-batch gradient updates; in this case, the update ∆ t i returned to the server is a full-batch gradient on client i's local training set evaluated at model parameters x t-1 . At the server, we consider three strategies for SERVEROPT. In all strategies, the server treats the averaged update ∆ t as a gradient. SGD the server updates the global model using stochastic gradient descent; when CLIENTOPT is also SGD, this is equivalent to FEDAVG (McMahan et al., 2016) . SGD with momentum the server updates the global model using SGD with momentum; when CLIENTOPT is SGD, this is equivalent to FEDAVGM (Hsu et al., 2019) . Adam the server updates the global model using the Adam optimizer; when CLIENTOPT is SGD, this is equivalent to FEDADAM (Reddi et al., 2020) . The method commonly referred to as FEDSGD (McMahan et al., 2017) is obtained when CLIENTOPT is full-batch gradient descent (GD) and SERVEROPT is SGD, with η c = 1 and E = 1. We focus on the above choices for CLIENTOPT and SERVEROPT because they are reflective of the most widely-cited federated optimization methods, and they also represent a diverse set of possible choices available to the practitioner seeking to deploy cross-device federated training at scale.

3.4. IMPLEMENTATION AND TUNING

We use three different seeds and report the average of each experiment. Hyperparameter tuning is done for each algorithm, model, and dataset, with parameters including client and server learning 

4. THE IMPACT OF PRE-TRAINING IN FL

In this section, we illustrate the benefits of pre-training in the federated setting and how pre-training can impact federated optimization algorithms behavior. Pre-training changes the ranking of federated optimization algorithms. If one sorts federated optimization methods based on their performance when starting from a random initialization, the order is substantially different from when using a pre-trained initialization. We focus on nine combinations of SERVEROPT and CLIENTOPT, only using local update methods for CLIENTOPT and excluding full-batch gradient descent. We show the change in performance in Figure 1 . First, observe that the span of final accuracies is much smaller when starting from a pre-trained model. Second, all methods starting with pre-trained models achieve a better accuracy after the same number of steps compared to random models. Lastly, observe that the order of methods changes depending on the initialization. Although no particular method dominates across all workloads in Figure 1 , FEDADAM with SGD for CLIENTOPT performs consistently well when starting from a pre-trained model, especially on the two larger language modeling workloads, Stack Overflow and Reddit, and so we focus on studying FEDADAM-SGD below. Faster convergence to better accuracy when starting from a pre-trained model. Figure 2 shows that, as one would hope, when starting from a pre-trained model, it is possible to achieve much better accuracy after a fixed number of steps than when starting federated training from a random initialization. Note that the initial accuracy is not always substantially higher than a random initialization (See Table 2 ). Pre-training closes the accuracy gap between non-IID and IID. We study how pre-training and data heterogeneity affect convergence without system heterogeneity by fixing the number of local epochs to E i = 1. We compare FEDADAM-SGD under IID and Non-IID data splits. In Figure 4 , we report the average accuracy for FedAdam (Reddi et al., 2020) on the four datasets. As expected, randomly initialized models perform much worse than their pre-trained counterparts, and IID partitions yield better quality than non-IID. Surprisingly, the gap between models trained on IID data and models trained on non-IID data is significantly smaller when starting with pre-trained weights. Moreover, shows that taking local SGD steps before server averaging reduces communication by 10-100× compared to taking a full batch gradient step. To understand how pre-training impacts this comparison, we compare FEDADAM with SGD and FEDADAM with GD. While local SGD can reduce communication, the saving is much less when the models are initialized with pre-trained weights compared to random weights. Figure 9 in the Appendix shows that with pre-trained initialization, using GD at the client can yield almost the same result as taking local SGD steps. Pre-training reduces the impact of system heterogeneity. To study the impact of pre-training on system heterogeneity, we follow the setup described in (Li et al., 2018; Wang et al., 2020) . We sample 30% of clients uniformly at random, and client i performs E i local epochs where E i ∼ U (1, 5) while the remaining 70% of the clients perform E i = 1 epochs; this models the setting where clients have different processing capabilities and they perform as much work as they can within a given time limit. Figure 5 shows that FEDADAM-SGD consistently outperforms other methods specifically designed for system heterogeneity (NOVA, PROXIMAL, MIMELITE) when starting from a pre-trained model. Apparently using an adaptive optimizer at the server is sufficient to correct for the negative effects of systems heterogeneity when starting from a pre-trained model. On the other hand, when starting from a random initialization, optimizers specifically designed for system heterogeneity (i.e., FEDNOVA) outperform SGD (Figure 5 right). Moreover, the accuracy gap between algorithms is more pronounced in the random initialization setting, whereas in the pre-trained setting, all algorithms converge to more similar accuracies. Our results suggest that pre-training may reduce the need for algorithms that try to correct system heterogeneity. 

5. UNDERSTANDING WHY PRE-TRAINING HELPS FEDERATED OPTIMIZATION

While pre-training unsurprisingly speeds up convergence, the reason for the speedup is less apparent. In this section, we examine why pre-training is beneficial to federated learning. Pre-training helps align client updates. To better understand why pre-training alleviates the heterogeneity challenge, we first investigate the gradient diversity of the updates received from different clients. We adopt the notion introduced in Yin et al. (2017) , adapted here to apply to client updates ∆ i (whereas Yin et al. (2017) focus specifically on gradients g i ): GradientDiversity({∆ i : i ∈ S t }) = i∈S t ||∆ i || 2 || i∈S t ∆ 2 i || . In Figure 6 , we plot the gradient diversity of client updates ∆ i at each round for FEDADAM. In the pre-trained setting, client updates have significantly lower gradient diversity (see the bottom left plot in Figure 6 ). This suggests that when starting from a pre-trained model, the client local model changes are more similar to each other. On the other hand, clients local model changes from randomly initialized weights are almost orthogonal, suffering more from the client drift problem. In addition, when looking at the cosine similarity of consecutive aggregated update vectors in time (bottom middle), we see that consecutive updates point more consistently in a similar direction at the beginning of training when starting from a pre-trained model. From the top middle and bottom right plots in Figure 6 , we see that the pre-trained model starts closer to the final result. We also examine the largest eigenvalue of the Hessian matrix (i.e., local Lipshitz constant or smoothness) at the beginning of training, a larger value of which suggests a harder-to-optimizer loss surface. To compute the Hessian matrix, We examine the largest eigenvalue of the Hessian matrix at initialization, round 0, using Power Iteration from PyHessian by Yao et al. (2020) . In Table 1 , one can observe that pre-trained models always lead to smaller eigenvalues on different datasets. Initial loss for pre-trained versus random models. Connection to theory. Here, we present the existing optimization theory for FEDAVG and discuss how pre-training helps to improve the model convergence. Following the formulation in Section 3, suppose that there are total m clients, jointly optimizing a global objective function f (w) = m i=1 p i F i (w), and that each client's local loss function F i (w) is L-Lipschitz smooth. For ease of presentation, we assume that all clients participate in training and perform K local SGD updates at each round. Then, under standard assumptions, one can show that after R communication rounds, the expected gradient norm satisfies (see Theorem V in Karimireddy et al. (2020) ): E ∇f (x R ) 2 ≤ O √ F √ RKm + F 2/3 ζ 2/3 R 2/3 , where wR represents a weighted sampled model from all previous rounds, ζ is a measure of data heterogeneity, and F = f (x 0 ) -f * denotes the gap between the initial loss value and the optimal loss value. In addition, in order for FEDAVG to achieve the O(1/ √ RKm) asymptotic convergence rate, previous works (Wang et al., 2021; Woodworth et al., 2020; Karimireddy et al., 2020; Wang and Joshi, 2021) showed that the number of local updates K should be upper bounded as follows: K ≤ O R 1/3 F 1/3 ζ 4/3 m . Clearly, if F becomes smaller starting from a pre-trained model, one can use a larger number of local updates. This corroborates our empirical observations in Figure 3 . When starting from a pre-trained model, the initial gap F is sometimes reduced, as observed in Table 2 . As a result, the optimization error upper bound (2) will be smaller, i.e., we get better worst-case performance. However a lower initial loss is not always observed in our experiments, so this does not fully explain our observations, suggesting that we may need to re-think the convergence theory of local update methods.

6. RECOMMENDATIONS

In this work, we study the effects of pre-training on federated optimization methods. Our results inform the following recommendations: 1. When evaluating FL algorithms, researchers should experiment with both pre-trained (if available) and random weights, as the initialization can clearly impact the relative performance of different methods, and both initialization schemes may be relevant to practice. 2. When deploying FL to a production environment, using adaptive server optimizers such as FEDADAM together with SGD at the client is a simple and competitive approach when it is possible to start from a pre-trained model. 3. When there is public data to pre-train a model, the impact of heterogeneity can be reduced. Thus, when focusing on heterogeneity, it may be worth considering whether or not proxy data is available for pre-training to motivate the application considered.

7. RELATED WORK

Transfer learning. Model initialization can significantly impact training and final performance. Previous work studying the loss landscape of deep networks observed significant differences between the landscape around a random initialization and the landscape later in training. In particular, later in training, the loss can be much more "well-behaved" (Li et al., 2017; Frankle et al., 2020; 2019) . Fine-tuning from pre-trained models is common practice in natural language processing and computer vision, yielding strong performance on many tasks (Radford et al., 2019; Dosovitskiy et al., 2020; Devlin et al., 2018; He et al., 2019) . Limitations. Depending on the application, it may not be possible to get public data, in which case random initialization may be the only option. Nevertheless, we believe there is sufficient prevalence and importance of applications where public data is available for this study to be of broad interest. When public data is available, it may not necessarily reflect the distribution of all users in the population. Consequently, pre-training using public data may introduce bias, which warrants further study, including methods to detect and mitigate such bias. Moreover, we only consider one warm-start initialization strategy: supervised pre-training. Several other possibilities are worth investigating, including meta-learning the warm-start initialization and self-supervised pre-training.

Conclusion.

In this paper, we present a thorough empirical analysis of initialization on federated learning by evaluating it on twelve federated learning algorithms across four vision and text tasks. We find that pre-training on public data can recover most of the accuracy drop from heterogeneity. We show that client updates starting from pre-trained weights have higher cosine similarity, which explains why initializing with pre-trained weights can speed up convergence and achieve high accuracy even in heterogeneous settings. We further show that using simple SGD locally can be as good as other local optimizers.



https://github.com/pytorch/vision https://huggingface.co/distilgpt2



Figure2: While prior works ignore the importance of initialization, using pre-trained models should be the first step for any practical deployment to save on communication bandwidth and achieve high model accuracy. This figure shows the advantage of using a pre-trained model for four tasks. For Stack Overflow and Reddit, we use DistilGPT2. For CIFAR-10 and FEMNIST, we use SqueezeNet.

Figure 4: The average accuracy on 3 different seeds for FEDADAM trained on IID and non-IID data.For CIFAR-10 Non-IID, we generate 100 non-IID clients using a Dirichlet(0.1). For other three datasets, we use the natural non-IID client partitions.

Figure 6: Training and gradient statistics of a Resnet18 on CIFAR-10 with Dirichlet distribution with parameter 0.1. Top row: Train loss of global model; train accuracy of global model; evaluation accuracy of global model; evaluation loss of global model. Bottom row: Gradient diversity of client updates; cosine similarity between client updates; L2 distance of server weights from their final values at the end of training.

Optimization. While a significant amount of research focused on various aspects of FL, including communication-efficiencyMcMahan et al. (2016), data and systems heterogeneityLi et al. (2018);Wang et al. (2020), and faster convergence rate . Nearly all previous work in this field neglect the importance of initialization. In our work, we study the impact of initialization on federated optimization in the cross-device setting. We defer the interested reader to surveys ofKairouz et al. (2019) andWang et al. (2021)  for additional background. Pre-training in Federated Learning. Few studies have investigated pre-trained models in federated learning, including Pillutla et al. (2022); Hsu et al. (2020); Zhao et al. (2018); Lin et al. (2021); Stremmel and Singh (2021); Tan et al. (2022). While Zhao et al. (2018) found pre-training does not alleviate the effect of heterogeneity, Tan et al. (2022) focuses on effectively learning from pre-trained models, and Chen et al. (2022); Weller et al. (2022); Chen et al. (2022) focused on synthetically partitioned data and proposed pre-training methods with synthetic data. Our work fills a gap by comparing random initialization and pre-training, which other studies Pillutla et al. (2022); Hsu et al. (2020); Lin et al. (2021);Stremmel and Singh (2021) did not address. We systematically investigate both forms of heterogeneity and evaluate 15 SOTA federated optimization algorithms across visual and language tasks. Our study offers theoretical and empirical evidence supporting the benefits of pre-training in FL 8 CONCLUSION AND LIMITATIONS

.,

Input: initial global model x 0 , server and client step sizes η s , η c , local epochs E, rounds T 2: for each round t = 1, . . . , T do

Table2shows that pre-training does not always lead to lower initial loss. For Squeezenet 1.0 on CIFAR-10 and ResNet-18 on FEMNIST, the initial loss of the randomly initialized models are lower pre-trained models. However, pre-trained model still converges faster as illustrated in Figure2. The top eigenvalue of the Hessian matrix for each dataset between the pre-trained and random initialized models.

Loss at beginning of training for various model architectures and datasets. The initial loss of the pretrained model is not always lower than that of a random initialization.

