ADAPTIVE PERSONALIZED FEDERATED LEARNING

Abstract

Investigation of the degree of personalization in federated learning algorithms has shown that maximizing the performance of the global model alone confines the capacity of the local models to personalize. In this paper, we advocate an adaptive personalized federated learning (APFL) algorithm, in which each client trains a local model while contributing to the global model. We derive the generalization bound for the mixture of local and global models and find the optimal mixing parameter. We also propose a communication-efficient optimization method to collaboratively learn the personalized models, and analyze its convergence in both smooth strongly convex and nonconvex settings. Extensive experiments demonstrate the effectiveness of our personalization schema, as well as the correctness of the established generalization theories.

1. INTRODUCTION

With the massive amount of data generated by the proliferation of mobile devices and the internet of things (IoT), coupled with concerns over sharing private information, collaborative machine learning and the use of federated optimization (FO) are often crucial for the deployment of large-scale machine learning (McMahan et al., 2017; Kairouz et al., 2019; Li et al., 2020b). In FO, the ultimate goal is to learn a global model that achieves uniformly good performance over almost all participating clients without sharing raw data. To achieve this goal, most existing methods follow this procedure to learn a global model: (i) at each round, a subset of clients is chosen to participate in training and receives the current copy of the global model; (ii) each chosen client updates its local version of the global model using its own local data; (iii) the server aggregates the obtained local models to update the global model, and this process continues until convergence (McMahan et al., 2017; Mohri et al., 2019; Karimireddy et al., 2019; Pillutla et al., 2019). Most notably, FedAvg by McMahan et al. (2017) uses averaging as its aggregation method over local models. Due to the inherent diversity among local data shards and the highly non-IID distribution of data across clients, FedAvg is highly sensitive to its hyperparameters and, as a result, does not enjoy a favorable convergence guarantee (Li et al., 2020c). Karimireddy et al. (2019) argue that if these hyperparameters are not carefully tuned, FedAvg may diverge, as local models can drift significantly from each other. Therefore, in the presence of statistical data heterogeneity, the global model might not generalize well on the local data of each individual client (Jiang et al., 2019).
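The three-step procedure above can be sketched in a few lines. The function below is our own illustrative sketch, not the actual FedAvg implementation; `local_update` is a placeholder for whatever local training routine (e.g., a few SGD steps) a client runs:

```python
import numpy as np

def fedavg_round(global_model, client_datasets, local_update, num_selected, seed=0):
    """One round of the generic FO procedure: (i) sample clients,
    (ii) run local updates, (iii) aggregate local models on the server."""
    rng = np.random.default_rng(seed)
    # (i) choose a subset of clients and send them the current global model
    chosen = rng.choice(len(client_datasets), size=num_selected, replace=False)
    # (ii) each chosen client updates its own copy of the global model locally
    local_models = [local_update(global_model.copy(), client_datasets[i])
                    for i in chosen]
    # (iii) FedAvg-style aggregation: average the returned local models
    return np.mean(local_models, axis=0)
```

With `local_update` performing one or more gradient steps on the client's local loss, iterating this round recovers the FedAvg scheme; swapping the aggregation step yields other FO variants.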
This is even more crucial in fairness-critical systems such as medical diagnosis (Li & Wang, 2019), where poor performance on local clients could have damaging consequences. The problem is exacerbated further as the diversity among the local data of different clients grows. To illustrate this, we ran a simple experiment on the MNIST dataset, where each client's local training data is sampled from a subset of classes to simulate heterogeneity. When each client has samples from fewer classes, the heterogeneity among clients is high; if every client has samples from all classes, the local training distributions become nearly identical, and the heterogeneity is low. The results of this experiment are depicted in Figure 1: the generalization and training losses of the global models of FedAvg (McMahan et al., 2017) and SCAFFOLD (Karimireddy et al., 2019) on local data diverge as the diversity among clients' data increases. This observation illustrates that solely optimizing the global model's accuracy leads to poor generalization on local clients. To embrace statistical heterogeneity and mitigate the effect of negative transfer, it is necessary to integrate personalization into learning instead of finding a single consensus predictor. This pluralistic view of FO has recently spurred significant research on personalized learning schemes (Eichner et al., 2019; Smith et al., 2017; Dinh et al., 2020; Mansour et al., 2020; Fallah et al., 2020; Li et al., 2020a). To balance the trade-off between the benefit of collaboration with other users and the disadvantage of statistical heterogeneity among different users' domains, we propose in this paper an adaptive personalized federated learning (APFL) algorithm, which aims to learn for each device a personalized model that is a mixture of optimal local and global models.
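The class-subset sampling used in this experiment can be mimicked as follows. The function name and signature are our own, a minimal sketch rather than the exact data pipeline used for Figure 1:

```python
import numpy as np

def partition_by_classes(labels, num_clients, classes_per_client, rng=None):
    """Assign sample indices to clients so that each client only sees
    `classes_per_client` classes; fewer classes per client means
    higher heterogeneity across clients."""
    rng = rng or np.random.default_rng(0)
    classes = np.unique(labels)
    shards = []
    for _ in range(num_clients):
        # each client draws its own subset of classes without replacement
        own = rng.choice(classes, size=classes_per_client, replace=False)
        # the client's shard is every sample whose label falls in that subset
        shards.append(np.flatnonzero(np.isin(labels, own)))
    return shards
```

Setting `classes_per_client` to the total number of classes makes all local distributions nearly identical (low heterogeneity), while `classes_per_client = 1` or `2` produces the highly non-IID regime where Figure 1 shows the global models degrading.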
We theoretically analyze the generalization ability of the personalized model on local distributions, with dependence on the mixing parameter, the divergence between local and global distributions, and the number of local and global training samples. To learn the personalized model, we propose a communication-efficient optimization algorithm that adaptively learns the model by leveraging the relatedness between local and global models as learning proceeds. As shown in Figure 1, as the diversity progressively increases, the personalized model found by the proposed algorithm generalizes better than the global models learned by FedAvg and SCAFFOLD. We supplement our theoretical findings with extensive corroborating experimental results that demonstrate the superiority of the proposed personalization schema over the global and localized models of commonly used federated learning algorithms.

2. PERSONALIZED FEDERATED LEARNING

In this section, we propose a personalization approach for federated learning and analyze its statistical properties. Following statistical learning theory, in a federated learning setting each client has access to its own data distribution D_i on domain Ξ := X × Y, where X ⊆ R^d is the input domain and Y is the label domain. For any hypothesis h ∈ H, the loss function is ℓ : H × Ξ → R_+. The true risk at the local distribution is denoted by L_{D_i}(h) = E_{(x,y)∼D_i}[ℓ(h(x), y)], and we use L̂_{D_i}(h) to denote the empirical risk of h on distribution D_i. We use D̄ = (1/n) ∑_{i=1}^{n} D_i to denote the average distribution over all clients.
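As a concrete illustration of these definitions (a sketch with assumed names, not part of the paper's method), the empirical risk L̂_{D_i}(h) is simply the average loss of h over client i's sample:

```python
import numpy as np

def empirical_risk(h, sample, loss):
    """Empirical risk L-hat_{D_i}(h): the average of loss(h(x), y)
    over the finite sample {(x, y)} drawn from D_i."""
    return float(np.mean([loss(h(x), y) for x, y in sample]))
```

The true risk L_{D_i}(h) is the expectation of the same quantity under the full distribution D_i, which the empirical risk approximates as the local sample size grows.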

2.1. PERSONALIZED MODEL

In a standard federated learning scenario, where the goal is to learn a global model for all devices cooperatively, the global model is obtained by minimizing the empirical risk on the joint distribution D̄, i.e., min_{h∈H} L̂_{D̄}(h), with proper weighting. However, as alluded to before, a single consensus predictor may not generalize well on local distributions when the heterogeneity among local data shards is high (i.e., when the global and local optimal models drift significantly). Meanwhile, from the local user's perspective, the key incentive to participate in "federated" learning is the desire to reduce the local generalization error with the help of other users' data. The ideal situation would be for each user to utilize the information in the global model to compensate for its small number of local training samples, while minimizing the negative transfer induced by heterogeneity among distributions. This motivates us to mix the global and local models with a controllable weight into a joint prediction model, namely the personalized model. We now formally introduce our proposed adaptive personalized learning schema, where the goal is to find the optimal combination of the global and local models in order to achieve a better client-specific model. In this setting, the global server still trains the global model by minimizing the empirical risk on the aggregated domain D̄, i.e., h* = arg min_{h∈H} L̂_{D̄}(h), while each user trains a local model that partially incorporates the global model with some mixing weight α_i, i.e., ĥ*_{loc,i} = arg min_{h∈H} L̂_{D_i}(α_i h + (1 − α_i) h*). Finally, the personalized model for the ith client is the mixture α_i ĥ*_{loc,i} + (1 − α_i) h*.
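A minimal sketch of this mixing step for linear hypotheses h(x) = ⟨θ, x⟩, with illustrative names (`v_i` for client i's local model, `w` for the global model); the paper's full algorithm also optimizes these parameters jointly, which this sketch does not show:

```python
import numpy as np

def personalized_predict(x, v_i, w, alpha_i):
    """Prediction of the personalized model for client i:
    a convex combination alpha_i * local + (1 - alpha_i) * global,
    applied here to linear hypotheses h(x) = <theta, x>."""
    mixed = alpha_i * v_i + (1.0 - alpha_i) * w
    return float(mixed @ x)
```

At the extremes, `alpha_i = 1` recovers a purely local model (no collaboration) and `alpha_i = 0` recovers the shared global model; intermediate values trade local fit against the information borrowed from other clients.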



Figure 1: Comparing the generalization and training losses of our proposed personalized model with the global models of FedAvg and SCAFFOLD while increasing the diversity among clients' data, on the MNIST dataset with a logistic regression model.

