ADAPTIVE FEDERATED OPTIMIZATION

Abstract

Federated learning is a distributed machine learning paradigm in which a large number of clients coordinate with a central server to learn a model without sharing their own training data. Standard federated optimization methods such as Federated Averaging (FEDAVG) are often difficult to tune and exhibit unfavorable convergence behavior. In non-federated settings, adaptive optimization methods have had notable success in combating such issues. In this work, we propose federated versions of adaptive optimizers, including ADAGRAD, ADAM, and YOGI, and analyze their convergence in the presence of heterogeneous data for general nonconvex settings. Our results highlight the interplay between client heterogeneity and communication efficiency. We also perform extensive experiments on these methods and show that the use of adaptive optimizers can significantly improve the performance of federated learning.

1. INTRODUCTION

Federated learning (FL) is a machine learning paradigm in which multiple clients cooperate to learn a model under the orchestration of a central server (McMahan et al., 2017). In FL, raw client data is never shared with the server or other clients. This distinguishes FL from traditional distributed optimization and requires contending with heterogeneous data. FL has two primary settings: cross-silo (e.g., FL between large institutions) and cross-device (e.g., FL across edge devices) (Kairouz et al., 2019, Table 1). In cross-silo FL, most clients participate in every round and can maintain state between rounds. In the more challenging cross-device FL, our primary focus, only a small fraction of clients participate in each round, and clients cannot maintain state across rounds. For a more in-depth discussion of FL and the challenges involved, we defer to Kairouz et al. (2019) and Li et al. (2019a).

Standard optimization methods, such as distributed SGD, are often unsuitable in FL and can incur high communication costs. To remedy this, many federated optimization methods use local client updates, in which clients update their models multiple times before communicating with the server. This can greatly reduce the amount of communication required to train a model. One such method is FEDAVG (McMahan et al., 2017), in which clients perform multiple epochs of SGD on their local datasets. The clients communicate their models to the server, which averages them to form a new global model. While FEDAVG has seen great success, recent works have highlighted its convergence issues in some settings (Karimireddy et al., 2019; Hsu et al., 2019). This is due to a variety of factors, including (1) client drift (Karimireddy et al., 2019), where local client models move away from globally optimal models, and (2) a lack of adaptivity. FEDAVG is similar in spirit to SGD, and may be unsuitable for settings with heavy-tailed stochastic gradient noise distributions, which often arise when training language models (Zhang et al., 2019a). Such settings benefit from adaptive learning rates, which incorporate knowledge of past iterations to perform more informed optimization.

In this paper, we focus on the second issue and present a simple framework for incorporating adaptivity in FL. In particular, we propose a general optimization framework in which (1) clients perform multiple epochs of training using a client optimizer to minimize loss on their local data, and (2) the server updates its global model by applying a gradient-based server optimizer to the average of the clients' model updates. We show that FEDAVG is the special case where SGD is used as both the client and server optimizer and the server learning rate is 1. This framework can also seamlessly incorporate adaptivity by using adaptive optimizers as client or server optimizers. Building upon this, we develop novel adaptive optimization techniques for FL by using per-coordinate methods as server optimizers. By focusing on adaptive server optimization, we enable the use of adaptive learning rates without increasing client storage or communication costs, and ensure compatibility with cross-device FL.
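To make the framework concrete, the following is a minimal sketch in NumPy (the helper names and default hyperparameters are ours, not the paper's reference implementation) of adaptive server optimization: each sampled client runs a few epochs of local SGD and returns its model delta, and the server applies an ADAM-style per-coordinate update to the average of those deltas. Replacing the adaptive server step with x = x + delta (i.e., an SGD server optimizer with learning rate 1) recovers FEDAVG.

```python
import numpy as np

def client_update(x, client_batches, grad_fn, client_lr=0.01, epochs=1):
    """Run local SGD from the current server model x and return the model delta."""
    x_i = x.copy()
    for _ in range(epochs):
        for batch in client_batches:
            x_i -= client_lr * grad_fn(x_i, batch)
    return x_i - x  # Delta_i = x_i - x

def server_round(x, m, v, sampled_clients, grad_fn,
                 server_lr=0.1, beta1=0.9, beta2=0.99, tau=1e-3):
    """One communication round with an ADAM-style server optimizer."""
    deltas = [client_update(x, batches, grad_fn) for batches in sampled_clients]
    delta = np.mean(deltas, axis=0)             # average of client model updates
    m = beta1 * m + (1 - beta1) * delta         # server first moment
    v = beta2 * v + (1 - beta2) * delta ** 2    # server second moment
    x = x + server_lr * m / (np.sqrt(v) + tau)  # per-coordinate adaptive server step
    return x, m, v
```

Swapping in a different second-moment rule (e.g., an ADAGRAD- or YOGI-style update) changes only the server step; the client computation and communication are unchanged.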
Main contributions. In light of the above, we highlight the main contributions of the paper.

• We study a general framework for federated optimization using server and client optimizers. This framework generalizes many existing federated optimization methods, including FEDAVG.
• We use this framework to design novel, cross-device compatible, adaptive federated optimization methods, and provide convergence analysis in general nonconvex settings. To the best of our knowledge, these are the first methods for FL using adaptive server optimization. We show an important interplay between the number of local steps and the heterogeneity among clients.
• We introduce comprehensive and reproducible empirical benchmarks for comparing federated optimization methods. These benchmarks consist of seven diverse and representative FL tasks involving both image and text data, with varying amounts of heterogeneity and numbers of clients.
• We demonstrate strong empirical performance of our adaptive optimizers throughout, improving upon commonly used baselines. Our results show that our methods can be easier to tune, and highlight their utility in cross-device settings.

Related work. FEDAVG was first introduced by McMahan et al. (2017), who showed it can dramatically reduce communication costs. Many variants have since been proposed to tackle issues such as convergence and client drift. Examples include adding a regularization term in the client objectives towards the broadcast model (Li et al., 2018) and server momentum (Hsu et al., 2019). When clients are homogeneous, FEDAVG reduces to local SGD (Zinkevich et al., 2010), which has been analyzed by many works (Stich, 2019; Yu et al., 2019; Wang & Joshi, 2018; Stich & Karimireddy, 2019; Basu et al., 2019). In order to analyze FEDAVG in heterogeneous settings, many works derive convergence rates depending on the amount of heterogeneity (Li et al., 2018; Wang et al., 2019; Khaled et al., 2019; Li et al., 2019b). Typically, the convergence rate of FEDAVG gets worse with client heterogeneity. By using control variates to reduce client drift, the SCAFFOLD method (Karimireddy et al., 2019) achieves convergence rates that are independent of the amount of heterogeneity. While effective in cross-silo FL, the method is incompatible with cross-device FL, as it requires clients to maintain state across rounds. For more detailed comparisons, we defer to Kairouz et al. (2019).

Adaptive methods have been the subject of significant theoretical and empirical study, in both convex (McMahan & Streeter, 2010b; Duchi et al., 2011; Kingma & Ba, 2015) and non-convex settings (Li & Orabona, 2018; Ward et al., 2018; Wu et al., 2019). Reddi et al. (2019) and Zaheer et al. (2018) study convergence failures of ADAM in certain non-convex settings, and develop an adaptive optimizer, YOGI, designed to improve convergence (a sketch contrasting the ADAM and YOGI updates appears at the end of this section). While most work on adaptive methods focuses on non-FL settings, Xie et al. (2019) propose ADAALTER, a method for FL using adaptive client optimization. Conceptually, our approach is also related to the LOOKAHEAD optimizer (Zhang et al., 2019b), which was designed for non-FL settings. Similar to ADAALTER, an adaptive FL variant of LOOKAHEAD entails adaptive client optimization (see Appendix B.3 for more details). We note that both ADAALTER and LOOKAHEAD are, in fact, special cases of our framework (see Algorithm 1), and the primary novelty of our work comes from focusing on adaptive server optimization. This allows us to avoid aggregating optimizer states across clients, making our methods require at most half as much communication and client memory usage per round (see Appendix B.3 for details).

Notation. For a, b ∈ R^d, we let √a, a^2, and a/b denote the element-wise square root, square, and division of the vectors. For θ_i ∈ R^d, we use both θ_{i,j} and [θ_i]_j to denote its j-th coordinate.
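To illustrate how the server optimizers referenced above differ, the sketch below (our simplification, not the paper's Algorithm 1) contrasts the ADAM and YOGI second-moment updates applied to the averaged client update delta; either rule can be plugged into the server step from the earlier sketch in place of its ADAM-style second moment.

```python
import numpy as np

def update_second_moment(v, delta, beta2=0.99, rule="adam"):
    """Server second-moment update on the pseudo-gradient delta (averaged client update)."""
    if rule == "adam":
        # ADAM: exponential moving average of squared pseudo-gradients.
        return beta2 * v + (1 - beta2) * delta ** 2
    if rule == "yogi":
        # YOGI: sign-controlled additive update. Unlike ADAM, the magnitude of the
        # change in v never exceeds (1 - beta2) * delta ** 2, so the effective
        # learning rate changes at a more controlled rate.
        return v - (1 - beta2) * np.sign(v - delta ** 2) * delta ** 2
    raise ValueError(f"unknown rule: {rule}")
```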

2. FEDERATED LEARNING AND FEDAVG

In federated learning, we solve an optimization problem of the form
$$\min_{x \in \mathbb{R}^d} f(x) = \frac{1}{m} \sum_{i=1}^{m} F_i(x),$$
where $F_i(x) = \mathbb{E}_{z \sim \mathcal{D}_i}[f_i(x, z)]$ is the loss function of the $i$-th client, $z \in \mathcal{Z}$, and $\mathcal{D}_i$ is the data distribution for the $i$-th client. For $i \neq j$, $\mathcal{D}_i$ and $\mathcal{D}_j$ may be very different. The functions $F_i$ (and