FEDERATED LEARNING BASED ON DYNAMIC REGULARIZATION

Abstract

We propose a novel federated learning method for distributively training neural network models, where the server orchestrates cooperation between a subset of randomly chosen devices in each round. We view the federated learning problem primarily from a communication perspective and allow more device-level computation to save transmission costs. We point out a fundamental dilemma: the minima of the local device-level empirical loss are inconsistent with those of the global empirical loss. Unlike recent prior works, which either attempt inexact minimization or use devices to parallelize gradient computation, we propose a dynamic regularizer for each device at each round, so that in the limit the global and device solutions are aligned. We demonstrate, through empirical results on real and synthetic data as well as analytical results, that our scheme leads to efficient training in both convex and non-convex settings, while being fully agnostic to device heterogeneity and robust to a large number of devices, partial participation and unbalanced data.

1. INTRODUCTION

In (McMahan et al., 2017), the authors proposed federated learning (FL), a concept that leverages data spread across many devices to learn classification tasks distributively without recourse to data sharing. The authors identified four principal characteristics of FL based on several use cases. First, the communication links between the server and devices are unreliable, and at any time only a small subset of devices may be active. Second, data is massively distributed, namely the number of devices is large, while the amount of data per device is small. Third, device data is heterogeneous, in that data on different devices are sampled from different parts of the sample space. Finally, data is unbalanced, in that the amount of data per device is highly variable. The basic FL problem can be cast as empirical minimization of a global loss objective, which decomposes as a sum of device-level empirical loss objectives. The number of communication rounds, along with the number of bits communicated per round, has emerged as the standard measure of cost for FL problems. Many mobile and IoT devices are bandwidth constrained, and wireless transmission and reception are significantly more power hungry than computation (Halgamuge et al., 2009). As such, schemes that reduce communication are warranted. While distributed SGD is a viable method in this context, it is nevertheless communication inefficient.

A Fundamental Dilemma. Motivated by these ideas, recent work has proposed to push the optimization burden onto the devices in order to minimize the amount of communication. Much of the work in this context proposes to optimize the local risk objective by running SGD over mini-batches of device data, analogous to what one would do in a centralized scenario.
On the one hand, training models on local data that minimize the local empirical loss appears to be meaningful; yet doing so is fundamentally inconsistent with minimizing the global empirical loss (Malinovsky et al., 2020; Khaled et al., 2020a). Prior works (McMahan et al., 2017; Karimireddy et al., 2019; Reddi et al., 2020) attempt to overcome this issue by running fewer epochs or rounds of SGD on the devices, or by stabilizing server-side updates, so that the resulting fused models correspond to inexact minimizations and can have globally desirable properties.

Dynamic Regularization. To overcome these issues, we revisit the FL problem and view it primarily from a communication perspective, with the goal of minimizing communication. We therefore allow significantly more processing and optimization at the device level, since communication is the main source of energy consumption (Yadav & Yadav, 2016; Latré et al., 2011). This approach, while increasing computation on the devices, leads to a substantial improvement in communication efficiency over existing state-of-the-art methods, uniformly across the four FL scenarios (unreliable links, massive distribution, substantial heterogeneity, and unbalanced data). Specifically, in each round we dynamically modify the device objective with a penalty term so that, in the limit, when model parameters converge, they do so to stationary points of the global empirical loss. Concretely, we add linear and quadratic penalty terms whose minimizer is consistent with the global stationary point. We then analyze the proposed FL algorithm and show that the local device models converge to models that satisfy the conditions for local minima of the global empirical loss at a rate of $O(1/T)$, where $T$ is the number of communication rounds.
For convex smooth functions, with $m$ devices and $P$ devices active per round, our convergence rate for the average loss with balanced data scales as $O\left(\frac{1}{T}\sqrt{\frac{m}{P}}\right)$, substantially improving over the state of the art (SCAFFOLD, $O\left(\frac{1}{T}\frac{m}{P}\right)$). For non-convex smooth functions, we establish a rate of $O\left(\frac{1}{T}\frac{m}{P}\right)$. We perform experiments on both visual and language real-world datasets, including MNIST, EMNIST, CIFAR-10, CIFAR-100 and Shakespeare. We tabulate performance in cases that reflect FL scenarios, namely (i) varying device participation levels, (ii) massively distributed data, (iii) various levels of heterogeneity, and (iv) unbalanced local data. Our proposed algorithm, FedDyn, has similar overhead to competing approaches but converges significantly faster. This results in a substantial reduction in the communication needed to reach a target accuracy compared to baseline approaches such as conventional FedAvg (McMahan et al., 2017), FedProx (Li et al., 2020a) and SCAFFOLD (Karimireddy et al., 2019). Furthermore, our approach is simple to implement and requires far less hyperparameter tuning than competing methods.

Contributions. We summarize our main results here.
• We present FedDyn, a novel dynamic regularization method for FL. Key to FedDyn is a new concept: in each round, the risk objective of each device is dynamically updated so that the device optima are asymptotically consistent with stationary points of the global empirical loss.
• We prove convergence results for FedDyn in both convex and non-convex settings, and obtain sharp bounds on the number of communication rounds required to achieve a target accuracy. Our results for the convex case improve significantly over state-of-the-art prior works.
FedDyn is, in theory, unaffected by heterogeneity, massively distributed data, and the quality of communication links.
• On benchmark examples, FedDyn achieves significant communication savings over competing methods, uniformly across various choices of device heterogeneity and device participation, on massively distributed large-scale text and visual datasets.

Related Work. FL is a fast-evolving topic, and we only describe closely related approaches here. Comprehensive field studies have appeared in (Kairouz et al., 2019; Li et al., 2020). The general FL setup involves two types of updates, server and device, and each of these updates is associated with minimizing some local loss function, which itself could be updated dynamically over different rounds. At any round, some methods attempt to fully optimize, while others propose inexact optimization. We specifically focus on relevant works that address the four FL scenarios (massive distribution, heterogeneity, unreliable links, and unbalanced data). One line of work proposes local SGD (Stich, 2019) based updates, wherein each participating device performs a single local SGD step. The server then averages the received models. In contrast to local SGD, our method minimizes a local penalized empirical loss. FedAvg (McMahan et al., 2017) is a generalization of local SGD that performs a larger number of local SGD steps per round. Still, FedAvg solves the device-side optimization inexactly. Identifying when to stop minimizing, so as to obtain a good accuracy-communication trade-off, relies on tuning the number of epochs and the learning rate (McMahan et al., 2017; Li et al., 2020b). Despite the strong empirical performance of FedAvg in IID settings, its performance degrades in non-IID scenarios (Zhao et al., 2018). Several modifications of FedAvg have been proposed to handle non-IID settings.
These variants include using a decreasing learning rate (Li et al., 2020b), modifying the device empirical loss dynamically (Li et al., 2020a), or modifying server-side updates (Hsu et al., 2019; Reddi et al., 2020). Methods that use a decreasing learning rate or customized server-side updates still rely on local SGD updates within devices. While these works do recognize the incompatibility of local and global stationary points, their proposed fix is based on inexact minimization. Additionally, to establish convergence in non-IID situations, these works impose additional "bounded non-IID" conditions. FedProx (Li et al., 2020a) is related to our method. Like us, its authors propose a dynamic regularizer, which is modified based on server-supplied models. This regularizer penalizes updates that are far from the server model. Nevertheless, the resulting regularizer does not align the global and local stationary points, so inexact minimization is still required; FedProx achieves it by carefully choosing learning rates and epochs. Furthermore, tuning requires some knowledge of statistical heterogeneity. In a similar vein, some works augment updates with extra device variables that are transmitted along with the models (Karimireddy et al., 2019; Shamir et al., 2014). These works prove convergence guarantees by adding device-dependent regularizers. Nevertheless, they incur additional communication costs, and they have not been extensively evaluated with deep neural networks. Among them, SCAFFOLD (Karimireddy et al., 2019) is closely related, even though it transmits extra variables; a more detailed comparison is given in Section 2. Another line of distributed optimization methods (Konečný et al., 2016; Makhdoumi & Ozdaglar, 2017; Shamir et al., 2014; Yuan & Ma, 2020; Pathak & Wainwright, 2020; Liang et al., 2019; Li et al., 2020c; Condat et al., 2020) could also be considered in this setting.
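To make the FedProx remark concrete, the sketch below is our own toy, not taken from the FedProx or FedDyn papers. It runs FedProx-style rounds with exact local minimization and full participation on two scalar quadratic losses $L_k(\theta) = \frac{a_k}{2}(\theta - c_k)^2$, with hypothetical curvatures `a` and local minimizers `c`. The proximal device step then has the closed form $(a_k c_k + \mu\theta_s)/(a_k + \mu)$. Under these assumptions the server iterates settle at a fixed point that differs from the global minimizer $\sum_k a_k c_k / \sum_k a_k$, which is the misalignment described above.

```python
# Toy illustration (hypothetical data): exact FedProx rounds on scalar quadratics
# L_k(theta) = 0.5 * a[k] * (theta - c[k])**2, with full participation.
a = [1.0, 4.0]   # heterogeneous curvatures
c = [0.0, 1.0]   # heterogeneous local minimizers

# Minimizer of the global loss (1/m) * sum_k L_k.
theta_star = sum(ak * ck for ak, ck in zip(a, c)) / sum(a)


def fedprox_round(theta_server, mu):
    """One round: each device exactly minimizes L_k(theta) + mu/2 (theta - theta_server)^2,
    then the server averages the device models."""
    local = [(ak * ck + mu * theta_server) / (ak + mu) for ak, ck in zip(a, c)]
    return sum(local) / len(local)


def run_fedprox(mu, rounds=200):
    theta = 0.0
    for _ in range(rounds):
        theta = fedprox_round(theta, mu)
    return theta


print(theta_star)        # 0.8
print(run_fedprox(1.0))  # about 0.615 (= 8/13), not 0.8
```

For this toy the fixed point is $4(1+\mu)/(8+5\mu)$, which approaches $\theta^* = 0.8$ only as $\mu \to \infty$, where each round also makes vanishingly small progress.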
Moreover, there are works that extend the analysis of SGD-type methods to FL settings (Gorbunov et al., 2020; Khaled et al., 2020b; Li & Richtárik, 2020). However, these algorithms are proposed for the full device participation case, which fails to satisfy one important aspect of FL. FedSVRG (Konečný et al., 2016) and DANE (Shamir et al., 2014) need gradient information from all devices at each round, so they are not directly applicable to partial participation settings. FedDANE (Li et al., 2019) is a version of DANE that works with partial participation; however, FedDANE performs worse than FedAvg empirically with partial participation (Li et al., 2019). Similar to these works, the FedPD method (Zhang et al., 2020) is proposed for distributed optimization with a different notion of participation: FedPD activates either all devices or no devices per round, which again fails to satisfy partial participation in FL. Lastly, another set of works aims to decrease communication costs by compressing the transmitted models (Dutta et al., 2019; Mishchenko et al., 2019; Alistarh et al., 2017). They save communication costs by decreasing the bit-rate of transmission. These ideas are complementary to our work and can be integrated into our proposed solution.

2. METHOD

We assume there is a cloud server that can transmit messages to and receive messages from $m$ client devices. Each device $k \in [m]$ has $N_k$ training instances in the form of features $x \in \mathcal{X}$ and corresponding labels $y \in \mathcal{Y}$, drawn IID from a device-indexed joint distribution, $(x, y) \sim P_k$. Our objective is to solve

$$\arg\min_{\theta \in \mathbb{R}^d} \left[ \ell(\theta) \triangleq \frac{1}{m} \sum_{k \in [m]} L_k(\theta) \right],$$

where $L_k(\theta) = \mathbb{E}_{(x,y) \sim \mathcal{D}_k}[\ell_k(\theta; (x, y))]$ is the empirical loss of the $k$th device, and $\theta$ are the parameters of our neural network, whose structure is assumed to be identical across the devices and the server. We denote by $\theta^*$ a local minimum of the global empirical loss function.

Algorithm 1: Federated Dynamic Regularizer (FedDyn)
Input: $T$, $\theta^0$, $\alpha > 0$, $\nabla L_k(\theta_k^0) = 0$.
for $t = 1, 2, \ldots, T$ do
  Sample devices $\mathcal{P}_t \subseteq [m]$ and transmit $\theta^{t-1}$ to each selected device
  for each device $k \in \mathcal{P}_t$, in parallel do
    Set $\theta_k^t = \arg\min_\theta \; L_k(\theta) - \langle \nabla L_k(\theta_k^{t-1}), \theta \rangle + \frac{\alpha}{2}\|\theta - \theta^{t-1}\|^2$
    Set $\nabla L_k(\theta_k^t) = \nabla L_k(\theta_k^{t-1}) - \alpha(\theta_k^t - \theta^{t-1})$
    Transmit device model $\theta_k^t$ to server
  end for
  for each device $k \notin \mathcal{P}_t$, in parallel do
    Set $\theta_k^t = \theta_k^{t-1}$, $\nabla L_k(\theta_k^t) = \nabla L_k(\theta_k^{t-1})$
  end for
  Set $h^t = h^{t-1} - \alpha \frac{1}{m} \sum_{k \in \mathcal{P}_t} (\theta_k^t - \theta^{t-1})$
  Set $\theta^t = \frac{1}{|\mathcal{P}_t|} \sum_{k \in \mathcal{P}_t} \theta_k^t - \frac{1}{\alpha} h^t$
end for

FedDyn Method. Our proposed method, FedDyn, is displayed in Algorithm 1. In each round $t \in [T]$, a subset of devices $\mathcal{P}_t \subset [m]$ is active, and the server transmits its current model $\theta^{t-1}$ to these devices. Each active device then optimizes a local empirical risk objective, which is the sum of its local empirical loss and a penalized risk function. The penalized risk, which is dynamically updated, is based on the current local device model and the received server model:

$$\theta_k^t = \arg\min_\theta \left[ \mathcal{R}_k(\theta; \theta_k^{t-1}, \theta^{t-1}) \triangleq L_k(\theta) - \langle \nabla L_k(\theta_k^{t-1}), \theta \rangle + \frac{\alpha}{2}\|\theta - \theta^{t-1}\|^2 \right]. \quad (1)$$

Devices compute their local gradient $\nabla L_k(\theta_k^{t-1})$ recursively, by noting that the first-order condition for local optima must satisfy

$$\nabla L_k(\theta_k^t) - \nabla L_k(\theta_k^{t-1}) + \alpha(\theta_k^t - \theta^{t-1}) = 0. \quad (2)$$

Stale devices do not update their models. Updated device models $\theta_k^t$, $k \in \mathcal{P}_t$, are then transmitted to the server, which updates its model to $\theta^t$ as displayed in Algorithm 1.

Intuitive Justification. To build intuition for our method, we first highlight a fundamental issue with the federated learning setup: stationary points of the device losses, in general, do not conform to those of the global loss. Indeed, a global stationary point $\theta^*$ must necessarily satisfy

$$\nabla \ell(\theta^*) \triangleq \frac{1}{m}\sum_{k \in [m]} \nabla L_k(\theta^*) = \frac{1}{m}\sum_{k \in [m]} \mathbb{E}_{(x,y) \sim \mathcal{D}_k} \nabla \ell_k(\theta^*; (x, y)) = 0.$$

In contrast, a device's stationary point $\theta_k^*$ satisfies $\nabla L_k(\theta_k^*) = 0$, and in general, due to heterogeneity of data ($P_k \neq P_j$ for $k \neq j$), the individual device-wise gradients are non-zero, $\nabla L_k(\theta^*) \neq 0$. This means that the dual goals of (i) seeking model convergence to a consensus, namely $\theta_k^t \to \theta^t \to \theta^*$, and (ii) updating models by optimizing local empirical losses are inconsistent.

Dynamic Regularization. Our proposed risk objective in Eq. 1 dynamically modifies the local loss functions so that, if local models in fact converge to a consensus, the consensus point is consistent with a stationary point of the global loss. To see this, first note that if we initialize at a consensus point, namely $\theta_k^{t-1} = \theta^{t-1}$, we have $\nabla \mathcal{R}_k(\theta; \theta_k^{t-1}, \theta^{t-1}) = 0$ for $\theta = \theta^{t-1}$. Thus our choice can be seen as modifying the device loss so that the stationary points of the device risk are consistent with the server model.

Key Property of Algorithm 1. If local device models converge, they converge to the server model, and the convergence point is a stationary point of the global loss.
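To see this, observe from Eq. 2 that if $\theta_k^t \to \theta_k^\infty$, it generally follows that $\nabla L_k(\theta_k^t) \to \nabla L_k(\theta_k^\infty)$; Eq. 2 then forces $\theta_k^t - \theta^{t-1} \to 0$, and as a consequence $\theta^t \to \theta_k^\infty$. In turn, this implies that $\theta_k^\infty = \theta^\infty$, i.e., the limit is independent of $k$. Combining this with the server update equations, convergence of $\theta^t$ implies $h^t \to 0$. Now, summing Eq. 2 over the devices shows that the server state equals the average stored gradient, $h^t = \frac{1}{m}\sum_k \nabla L_k(\theta_k^t)$ (given $h^0 = 0$ and the initialization $\nabla L_k(\theta_k^0) = 0$), so in the limit we are left with $\frac{1}{m}\sum_k \nabla L_k(\theta_k^t) \to \frac{1}{m}\sum_k \nabla L_k(\theta^\infty) = 0$. Hence the limit point is a stationary point of the global risk.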
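A small simulation can help check the Key Property numerically. The sketch below is a minimal illustration of Algorithm 1, not the authors' implementation. It assumes hypothetical scalar quadratic device losses $L_k(\theta) = \frac{a_k}{2}(\theta - c_k)^2$, for which the device step in Eq. 1 has the closed form $\theta_k^t = (a_k c_k + \nabla L_k(\theta_k^{t-1}) + \alpha\theta^{t-1})/(a_k + \alpha)$, uniform sampling of $P$ out of $m$ devices, and $h^0 = 0$. The values of `m`, `P`, `alpha` and `T` are illustrative choices, not settings from the paper.

```python
import random

# Hypothetical toy problem: scalar model, device k has
# L_k(theta) = 0.5 * a[k] * (theta - c[k])**2 (heterogeneous a and c).
rng = random.Random(0)
m, P, alpha, T = 20, 5, 10.0, 3000
a = [rng.uniform(0.5, 2.0) for _ in range(m)]
c = [rng.gauss(0.0, 3.0) for _ in range(m)]
theta_star = sum(ak * ck for ak, ck in zip(a, c)) / sum(a)  # argmin of (1/m) sum_k L_k


def grad_L(k, theta):
    return a[k] * (theta - c[k])


def feddyn(rounds):
    theta = 0.0            # server model theta^0
    theta_dev = [0.0] * m  # device models theta_k^0 (assumed equal to theta^0)
    g = [0.0] * m          # stored gradients, initialised to 0 as in Algorithm 1
    h = 0.0                # server state h^0 (assumed to be 0)
    for _ in range(rounds):
        active = rng.sample(range(m), P)  # |P_t| = P, uniform without replacement
        for k in active:
            # Exact minimiser of Eq. (1): a_k(x - c_k) - g_k + alpha (x - theta) = 0.
            x = (a[k] * c[k] + g[k] + alpha * theta) / (a[k] + alpha)
            g[k] -= alpha * (x - theta)   # gradient recursion, Eq. (2)
            theta_dev[k] = x
        # Inactive devices keep theta_dev[k] and g[k] unchanged.
        h -= alpha / m * sum(theta_dev[k] - theta for k in active)  # uses theta^{t-1}
        theta = sum(theta_dev[k] for k in active) / P - h / alpha
    return theta, theta_dev, g, h


theta, theta_dev, g, h = feddyn(T)
print(f"server theta = {theta:.6f}, global minimiser = {theta_star:.6f}")
print(f"h - mean(g) = {h - sum(g) / m:.2e}")
```

In this toy setting the server and device models should approach $\theta^*$ even though no device's own minimizer equals it, the stored gradient of each sampled device matches its true gradient, and $h^t$ matches the mean stored gradient up to rounding, as argued above. The sketch uses exact minimization; the experiments instead run SGD on the device objective.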

2.1. CONVERGENCE ANALYSIS OF FEDDYN.

The properties outlined in the previous section motivate our convergence analysis of the FedDyn device and server models. We present theoretical results for strongly convex, convex and non-convex functions.

Theorem 1. Assume a constant number of devices is selected uniformly at random in each round, $|\mathcal{P}_t| = P$. For a suitably chosen $\alpha > 0$, Algorithm 1 satisfies:
• For $\mu$-strongly convex and $L$-smooth $\{L_k\}_{k=1}^m$,
$$\mathbb{E}\left[\ell\left(\frac{1}{R}\sum_{t=0}^{T-1} r^t \gamma^t\right) - \ell^*\right] = O\left(\frac{1}{r^T}\left(\beta\|\theta^0 - \theta^*\|^2 + \frac{m}{P}\frac{1}{\beta}\frac{1}{m}\sum_{k\in[m]}\|\nabla L_k(\theta^*)\|^2\right)\right)$$
• For convex and $L$-smooth $\{L_k\}_{k=1}^m$,
$$\mathbb{E}\left[\ell\left(\frac{1}{T}\sum_{t=0}^{T-1} \gamma^t\right) - \ell^*\right] = O\left(\frac{1}{T}\sqrt{\frac{m}{P}}\left(L\|\theta^0 - \theta^*\|^2 + \frac{1}{L}\frac{1}{m}\sum_{k\in[m]}\|\nabla L_k(\theta^*)\|^2\right)\right)$$
• For non-convex and $L$-smooth $\{L_k\}_{k=1}^m$,
$$\mathbb{E}\left\|\nabla \ell(\bar{\gamma}^T)\right\|^2 = O\left(\frac{1}{T}\left(L\frac{m}{P}\left(\ell(\theta^0) - \ell^*\right) + L^2\frac{1}{m}\sum_{k\in[m]}\|\theta_k^0 - \theta^0\|^2\right)\right)$$
where $\gamma^t = \frac{1}{P}\sum_{k\in\mathcal{P}_t}\theta_k^t$, $\theta^* = \arg\min_\theta \ell(\theta)$, $\ell^* = \ell(\theta^*)$, $r = 1 + \frac{\mu}{\alpha}$, $R = \sum_{t=0}^{T-1} r^t$, $\beta = \max\left(5\frac{m}{P}\mu, 30L\right)$, and $\bar{\gamma}^T$ is a random variable that takes values $\{\gamma^s\}_{s=0}^{T-1}$ with equal probability.

Theorem 1 gives rates for strongly convex, convex and non-convex local losses. For strongly convex and smooth functions, in expectation, a weighted average of the active device averages converges at a linear rate. For convex and smooth functions, in expectation, the global loss of the active device averages converges at a rate $O\left(\frac{1}{T}\sqrt{\frac{m}{P}}\right)$. Following convention, this rate is for the empirical loss averaged across devices. As such, this rate would hold under moderate data imbalance. In situations with significant imbalance, which scales with data size, these results would have to account for the variance in the amount of data per device. Furthermore, the $\sqrt{m/P}$ factor might appear surprising, but note that our bounds hold in expectation, namely, the error reflects the average over all random choices of devices. Similarly, for non-convex and smooth functions, in expectation, the average of active device models converges to a stationary point at an $O\left(\frac{1}{T}\frac{m}{P}\right)$ rate.
The expectation is taken over the randomness of the active device set in each round. As in known convergence theorems, the problem-dependent constants depend on how well the algorithm is initialized. We refer to Appendix B for a detailed proof.

FedDyn vs. SCAFFOLD (Karimireddy et al., 2019). While SCAFFOLD appears similar to our method, there are fundamental differences. Practically, SCAFFOLD communicates twice as many bits as FedDyn, transmitting back and forth, both ) is unknown and must be transmitted by the server, leading to increased bit-rate. Note that this is unavoidable, since ignoring this term leads to freezing device updates (optimizing $L_k(\theta) - \langle \nabla L_k(\theta^t), \theta - \theta^t \rangle + \frac{\alpha}{2}\|\theta - \theta^t\|^2$ results in $\theta = \theta^t$). This extra term is a surrogate for $\nabla \ell(\theta^t)$, which is unavailable. We believe these differences are responsible for FedDyn's improved rate (in rounds), in theory as well as in practice. Finally, apart from conceptual differences, there are also implementation differences. SCAFFOLD runs SGD and tunes hyperparameters for a given number of rounds to maximize accuracy. In contrast, our approach, based on exact minimization, is agnostic to the specific implementation, and as such requires significantly less tuning.
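The parenthetical claim that dropping the extra term freezes the device update can be checked on a toy problem. The sketch below assumes a made-up one-dimensional logistic loss and solves the local objective by plain gradient descent; `data`, `theta_t` and `alpha` are hypothetical values chosen for illustration. For convex $L_k$, the objective $L_k(\theta) - \langle \nabla L_k(\theta^t), \theta - \theta^t\rangle + \frac{\alpha}{2}\|\theta - \theta^t\|^2$ has zero gradient at $\theta = \theta^t$, so the solver returns the received model regardless of its starting point, whereas the plain proximal objective without the linear correction moves away from it.

```python
import math

# Hypothetical convex local loss: mean logistic loss of a scalar model on toy data.
data = [(1.0, 1), (2.0, 1), (-1.5, -1), (0.5, -1)]


def grad_L(theta):
    # d/dtheta log(1 + exp(-y*theta*x)) = -y*x / (1 + exp(y*theta*x))
    return sum(-y * x / (1.0 + math.exp(y * theta * x)) for x, y in data) / len(data)


def minimize(grad, theta0, lr=0.1, steps=5000):
    """Plain gradient descent; adequate for this smooth, strongly convex 1-D objective."""
    theta = theta0
    for _ in range(steps):
        theta -= lr * grad(theta)
    return theta


theta_t, alpha = 0.3, 1.0
g_t = grad_L(theta_t)  # gradient at the received server model

# L(theta) - <grad L(theta_t), theta - theta_t> + alpha/2 |theta - theta_t|^2
frozen = minimize(lambda th: grad_L(th) - g_t + alpha * (th - theta_t), theta0=-2.0)
# Same objective without the linear term (a plain proximal step) does move.
moved = minimize(lambda th: grad_L(th) + alpha * (th - theta_t), theta0=-2.0)
print(frozen, moved)
```

Only the gradient of each objective is passed to the solver, since the constant terms do not affect the minimizer.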

3. EXPERIMENTS

Our goal in this section is to evaluate FedDyn against competing methods on benchmark datasets for various FL scenarios. Our results highlight the trade-offs and benefits of our exact minimization relative to prior inexact minimization methods. To ensure a fair comparison, the device update of FedDyn uses the same SGD procedure as FedAvg rather than an off-the-shelf optimization solver. We provide a brief description of the datasets and models used in the experiments; a detailed description of our setup can be found in Appendix A.1. Partial participation is handled by sampling devices at random in each round, independently of previous rounds.

Datasets. We use benchmark datasets with the same train/test splits as in previous works (McMahan et al., 2017; Li et al., 2020a): MNIST (LeCun et al., 1998), CIFAR-10, CIFAR-100 (Krizhevsky et al., 2009), a subset of EMNIST (Cohen et al., 2017) (EMNIST-L), Shakespeare (Shakespeare, 1994), as well as a synthetic dataset. The IID split is generated by randomly assigning datapoints to the devices. For non-IID splits, as in (Yurochkin et al., 2019), the label ratios are drawn from a Dirichlet distribution to ensure uneven label distributions among devices. For example, in the 100-device MNIST experiments, about 5 and 3 classes make up 80% of each device's local data at Dirichlet parameters 0.6 and 0.3, respectively. To generate unbalanced data, we sample the number of datapoints per device from a lognormal distribution; controlling the variance of the lognormal distribution controls the degree of imbalance. For instance, in the 100-device CIFAR-10 experiments, the balanced and unbalanced settings have a standard deviation of device sample size of 0 and 0.3, respectively.

Models. We use fully connected neural network architectures with 2 hidden layers for MNIST and EMNIST-L.
The hidden layers have 200 neurons for MNIST and 100 neurons for EMNIST-L, and the models achieve 98.4% and 95.0% test accuracy on MNIST and EMNIST-L, respectively. The model used for MNIST is the same as in (McMahan et al., 2017). For CIFAR-10 and CIFAR-100, we use a CNN model, similar to (McMahan et al., 2017), consisting of 2 convolutional layers with 64 5 × 5 filters followed by 2 fully connected layers with 394 and 192 neurons, and a softmax layer. The model achieves 85.2% and 55.3% test accuracy on CIFAR-10 and CIFAR-100, respectively. For the next-character prediction task (Shakespeare), we use a stacked LSTM, similar to (Li et al., 2020a). This architecture achieves a test accuracy of 50.8% and 51.2% in the IID and non-IID settings, respectively. Both IID and non-IID performances are reported because the splits are randomly regenerated from the entire Shakespeare corpus; hence the centralized data, and therefore the centralized model performance, differ between the two settings. In passing, we note that while the accuracies reported are state of the art for our chosen models, higher-capacity models can achieve higher performance on these datasets. Our aim is to compare the relative performance of these models in FL using FedDyn and other strong baselines.

Comparison of Methods. We report the performance of FedDyn, SCAFFOLD, FedAvg and FedProx on synthetic and real datasets. We also experimented with distributed SGD, where devices in each round compute gradients of the server-supplied model on local data and communicate these gradients. Its performance was not competitive with the other methods, so we do not tabulate it here. We cover synthetic data generation and its results in Appendix A.1. The standard goal in FL is to minimize the number of bits transferred. For this reason, we adopt the number of models transmitted to achieve a target accuracy as the metric in our comparisons.
This metric differs from comparing communication rounds, since not all methods communicate the same amount of information per round. FedDyn, FedAvg and FedProx transmit and receive the same number of models for a fixed number of rounds, whereas SCAFFOLD costs twice as much due to the transmission of its states. We compare algorithms at two accuracy levels, which we pick to be close to the performance obtained by centralizing the data. Along with the transmission cost of each method, we report in parentheses the communication savings of FedDyn relative to each baseline. For methods that do not reach the target accuracy within the communication budget, we append a + sign to the transmission cost. We observe that FedDyn yields communication savings over all baselines in reaching a target accuracy. We test FedDyn under the four characteristic properties of FL: partial participation, a large number of devices, heterogeneous data, and unbalanced data.

Moderate vs. Large Number of Devices. FedDyn significantly outperforms competing methods in the practically relevant massively distributed scenario. We report the performance of FedDyn on CIFAR-10 and CIFAR-100 with moderate and large numbers of devices in Table 1, while keeping the participation level constant (10%) and the data amounts balanced. Specifically, the moderately distributed setting has 100 devices with 500 images per device. The massively distributed setting has 1000 devices with 50 images per device for CIFAR-10, and 500 devices with 100 images per device for CIFAR-100. In each distributed setting, the data is partitioned in both IID and non-IID (Dirichlet 0.3) fashion. FedDyn leads to a substantial transmission reduction in each of the regimes. First, the communication saving in the massive setting is significantly larger than in the moderate setting. Compared to SCAFFOLD, FedDyn leads to a 4.8× gain in the massive setting and a 2.9× gain in the moderate setting on CIFAR-10 IID.
SCAFFOLD is not able to achieve 80% within 2000 rounds in the massive setting (shown in Figure 4a), so the actual saving is more than 4.8×. A similar trend is observed in the non-IID setting of CIFAR-10 and for CIFAR-100. Second, all methods require more communication to achieve a reasonable accuracy in the massive setting, as the dataset is more decentralized. For instance, FedDyn takes 637 rounds to achieve 84.5% with 100 devices, while it takes 840 rounds to achieve 80.0% with 1000 devices. A similar trend is observed for CIFAR-100 and for the other methods. FedDyn always achieves the target accuracy in fewer rounds and thus leads to significant savings. Third, a higher target accuracy may result in a greater saving. For instance, the saving relative to SCAFFOLD increases from 3× to 4.8× in the CIFAR-10 IID massive setting. We may attribute this to the fact that FedDyn aligns device functions with the global loss and optimizes the problem efficiently.

Full vs. Partial Participation Levels. FedDyn outperforms baseline methods across different device participation levels. We consider different device participation levels with 100 devices and balanced data in Table 2, where part of the CIFAR-10 and CIFAR-100 results are omitted since they are reported in the moderate-number-of-devices part of Table 1. The Shakespeare non-IID results are shown separately, since this dataset has a natural non-IID split that does not conform to the Dirichlet distribution. The communication gain with respect to the best baseline increases with the participation level: from 2.9× to 9.4×, 4.0× to 12.8× and 4.2× to 7.9× for CIFAR-10 in the different device distribution settings. We observe a similar improvement with full participation for most of the datasets. This supports our hypothesis that FedDyn incorporates information from all devices more efficiently than other methods, resulting in larger savings with full participation.
As in previous results, a greater target accuracy yields greater savings in most settings. We also report results for the 1% participation regime with different device distribution settings (see Table 5 in Appendix A.1).

Balanced vs. Unbalanced Data. FedDyn is more robust to unbalanced data than competing methods. We fix the number of devices (100) and the participation level (10%) and consider the effect of unbalanced data (Table 4, Appendix A.1). FedDyn achieves 4.3× gains over the best competitor, SCAFFOLD, in reaching the target accuracy. As before, gains increase with the target accuracy.

IID vs. non-IID Device Distribution. FedDyn outperforms baseline methods across different levels of device heterogeneity. We consider heterogeneous device distributions in the context of varying device numbers, participation levels and balanced/unbalanced settings in Tables 1, 2 and 4 (Appendix A.1), respectively. Device distributions become more non-IID as we go from IID to Dirichlet 0.6 to Dirichlet 0.3 splits, which makes the global optimization problem harder. We see a clear effect of this change in Table 2 for the 10% participation level and in Table 4 for the unbalanced setting. For instance, increasing the non-IID level yields greater communication savings, from 2.9× to 4.0× to 4.2× in CIFAR-10 with 10% participation. A similar statement holds for MNIST, EMNIST-L and Shakespeare in Table 2 and for the CIFAR-10 unbalanced setting in Table 4. We do not observe a significant difference in savings for the full participation setting in Table 2.

Summary. Overall, FedDyn consistently leads to substantial communication savings compared to baseline methods, uniformly across the FL regimes of interest. We realize large gains in the practically relevant massively distributed data setting.
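The transmission-cost metric used throughout this section can be written down in a few lines. This is a sketch of our reading of the metric, with hypothetical round counts rather than numbers from the tables: the cost is the number of rounds needed to reach the target multiplied by the models exchanged per round, normalised to one FedAvg round, and the saving is the ratio of a baseline's cost to FedDyn's cost. The helper names are ours.

```python
# Models exchanged per round, normalised to one FedAvg round. Per the text,
# FedDyn, FedAvg and FedProx exchange one model per round, while SCAFFOLD also
# exchanges its state, doubling the cost. (The number of active devices is the
# same for every method, so it cancels and is omitted.)
MODELS_PER_ROUND = {"FedDyn": 1, "FedAvg": 1, "FedProx": 1, "SCAFFOLD": 2}


def transmission_cost(method: str, rounds_to_target: int) -> int:
    """Models transmitted to reach the target accuracy, relative to one FedAvg round."""
    return MODELS_PER_ROUND[method] * rounds_to_target


def format_cost(cost: int, reached: bool) -> str:
    """Append '+' when the run hit the communication budget before the target."""
    return str(cost) if reached else f"{cost}+"


def saving(method: str, rounds: int, feddyn_rounds: int) -> float:
    """Saving of FedDyn over a baseline (a lower bound if the baseline missed the target)."""
    return transmission_cost(method, rounds) / transmission_cost("FedDyn", feddyn_rounds)


# Hypothetical round counts, for illustration only.
print(format_cost(transmission_cost("SCAFFOLD", 300), reached=True))  # 600
print(f"{saving('SCAFFOLD', 300, 200):.1f}x")                         # 3.0x
print(format_cost(transmission_cost("FedAvg", 1000), reached=False))  # 1000+
```

Counting models rather than rounds is what makes SCAFFOLD's per-round state transmission visible in the comparison.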

4. CONCLUSION

We proposed FedDyn, a novel FL method for distributively training neural network models. FedDyn is based on exact minimization: at each round, each participating device dynamically updates its regularizer so that the optimal model for the regularized loss is in conformity with the global empirical loss. Our approach differs from prior works that attempt to parallelize gradient computation and, in doing so, trade off target accuracy against communication and necessitate inexact minimization. We investigated different characteristic FL settings to validate our method. We demonstrated, through empirical results on real and synthetic data as well as analytical results, that our scheme leads to efficient training, with a convergence rate of $O(1/T)$, where $T$ is the number of rounds, in both convex and non-convex settings, and a linear rate in the strongly convex setting, while being fully agnostic to device heterogeneity and robust to a large number of devices, partial participation and unbalanced data.

Dataset. We introduce a synthetic dataset that reflects different properties of FL, using a process similar to that of (Li et al., 2020a). The datapoints $(x_j, y_j)$ of device $i$ are generated as $y_j = \arg\max(\theta_i^* x_j + b_i^*)$, where $x_j \in \mathbb{R}^{30\times1}$, $y_j \in \{1, 2, \ldots, 5\}$, $\theta_i^* \in \mathbb{R}^{5\times30}$, and $b_i^* \in \mathbb{R}^{5\times1}$. The tuple $(\theta_i^*, b_i^*)$ represents the optimal parameter set for device $i$, and each element of these tuples is drawn from $\mathcal{N}(\mu_i, 1)$, where $\mu_i \sim \mathcal{N}(0, \gamma_1)$. The features are modeled as $x_j \sim \mathcal{N}(\nu_i, \sigma)$, where $\sigma$ is a diagonal covariance matrix with elements $\sigma_{k,k} = k^{-1.2}$, and each element of $\nu_i$ is drawn from $\mathcal{N}(\beta_i, 1)$, where $\beta_i \sim \mathcal{N}(0, \gamma_2)$. The number of datapoints on device $i$ follows a lognormal distribution with variance $\gamma_3$. In this generation process, $\gamma_1$, $\gamma_2$ and $\gamma_3$ regulate, respectively, the relation between the optimal models of the devices, the distribution of the features on each device, and the amount of data per device.
We simulate different settings by allowing only one type of heterogeneity at a time and disabling the randomness of the other two. For instance, to disable type 1 heterogeneity, we draw a single set of optimal parameters $(\theta^*, b^*) \sim \mathcal{N}(0, 1)$ and use it to generate datapoints for all devices. Similarly, $\nu_i$ is set to 0 to disable type 2 heterogeneity, and $\gamma_3$ is set to 0 to disable type 3 heterogeneity. We consider four settings in total: type 1, 2 and 3 heterogeneity, as well as a homogeneous setting. The number of devices is set to 20, and the number of datapoints per device is 200 on average.

Models. We test FedDyn, SCAFFOLD, FedAvg and FedProx using a multiclass logistic regression model with cross-entropy loss. We set the batch size to 10 and the weight decay to $10^{-5}$. We test learning rates in [1, .1] and epochs in [1, 10, 50] for all three algorithms. The $\alpha$ parameter of FedDyn is chosen from [.1, .01, .001]; the $K$ parameter of SCAFFOLD is searched in [20, 200, 1000], which corresponds to the same amount of computation as the epoch list above; and the $\mu$ regularization hyperparameter of FedProx is chosen from [0.01, .0001]. Table 6 reports the number of models transmitted, relative to one round of FedAvg, to achieve the target training loss under the best hyperparameter selection in various settings with 10% device participation. As shown, FedDyn leads to communication savings in each of the settings, ranging from 1.1× to 7.6×.
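The generation process above can be sketched in plain Python as follows. This is our reading of the description, not the authors' generator, and several details are assumptions: we read $\mathcal{N}(\cdot, \gamma)$ as mean and variance, take the lognormal's underlying normal to have variance $\gamma_3$ with its mean tuned so that devices hold about 200 points on average, use labels $0, \ldots, 4$ instead of $1, \ldots, 5$, and expose hypothetical flags `hetero_model` and `hetero_features` to switch off type 1 and type 2 heterogeneity as described.

```python
import math
import random

D_IN, N_CLASSES = 30, 5
# Diagonal feature covariance sigma_kk = k^(-1.2), k = 1..30, stored as standard deviations.
FEATURE_STD = [math.sqrt(k ** -1.2) for k in range(1, D_IN + 1)]


def synthetic_devices(n_devices=20, gamma1=1.0, gamma2=1.0, gamma3=1.0,
                      hetero_model=True, hetero_features=True,
                      mean_size=200, seed=0):
    """Return a list of (X, Y) pairs, one per device; labels are 0..4 instead of 1..5.

    Assumptions: N(m, v) is read as mean m and variance v; the device size is
    lognormal whose underlying normal has variance gamma3 and whose mean is mean_size.
    """
    rng = random.Random(seed)
    shared = None
    if not hetero_model:  # type 1 disabled: one (theta*, b*) ~ N(0, 1) for every device
        shared = ([[rng.gauss(0.0, 1.0) for _ in range(D_IN)] for _ in range(N_CLASSES)],
                  [rng.gauss(0.0, 1.0) for _ in range(N_CLASSES)])
    devices = []
    for _ in range(n_devices):
        if hetero_model:  # type 1: entries ~ N(mu_i, 1), mu_i ~ N(0, gamma1)
            mu_i = rng.gauss(0.0, math.sqrt(gamma1))
            W = [[rng.gauss(mu_i, 1.0) for _ in range(D_IN)] for _ in range(N_CLASSES)]
            b = [rng.gauss(mu_i, 1.0) for _ in range(N_CLASSES)]
        else:
            W, b = shared
        if hetero_features:  # type 2: nu_i entries ~ N(beta_i, 1), beta_i ~ N(0, gamma2)
            beta_i = rng.gauss(0.0, math.sqrt(gamma2))
            nu = [rng.gauss(beta_i, 1.0) for _ in range(D_IN)]
        else:
            nu = [0.0] * D_IN
        # Type 3: lognormal device size; gamma3 = 0 gives mean_size points per device.
        s = math.sqrt(gamma3)
        n_i = max(1, round(rng.lognormvariate(math.log(mean_size) - s * s / 2, s)))
        X, Y = [], []
        for _ in range(n_i):
            x = [rng.gauss(nu[j], FEATURE_STD[j]) for j in range(D_IN)]
            scores = [sum(w * xj for w, xj in zip(W[c], x)) + b[c] for c in range(N_CLASSES)]
            X.append(x)
            Y.append(max(range(N_CLASSES), key=scores.__getitem__))  # arg max
        devices.append((X, Y))
    return devices
```

Under these assumptions, the type 2 only setting would be `synthetic_devices(hetero_model=False, gamma3=0.0)`, and the homogeneous setting additionally passes `hetero_features=False`.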

A.2 REAL DATA

Datasets. MNIST, EMNIST-L, CIFAR-10 and CIFAR-100 are used for image classification tasks, and the Shakespeare dataset is used for a next-character prediction task. The image size is (1 × 28 × 28) in MNIST and EMNIST, and (3 × 32 × 32) in CIFAR-10 and CIFAR-100. There are 10 classes in MNIST and CIFAR-10, 62 classes in EMNIST and 100 classes in CIFAR-100. Similar to (Li et al., 2020a), we take the first 10 letters from the letters split of EMNIST and call the result EMNIST-L. In the Shakespeare dataset, each input is a sequence of 80 characters and the label is the character that follows it. Overall, there are 80 different labels. We use the standard train and test splits for MNIST, EMNIST-L, CIFAR-10 and CIFAR-100. The numbers of training and test samples of the benchmark datasets are summarized in Table 3.

To generate IID splits, we randomly divide the training datapoints and assign them to devices. For non-IID splits, we use the Dirichlet distribution as in (Yurochkin et al., 2019). First, for each device, we draw a vector whose size equals the number of classes from a Dirichlet distribution. These vectors are the class priors of the devices. Then, for each device, a label is sampled according to its vector, and an image with that label is sampled without replacement. This process is repeated until all datapoints are assigned to devices. The procedure makes the label ratios of each device follow a Dirichlet distribution. The hyperparameter of the Dirichlet distribution controls the statistical heterogeneity of the device data; smaller values give more heterogeneous devices. Overall, in a 100-device experiment, each device has 600, 480, 500 and 500 datapoints in MNIST, EMNIST-L, CIFAR-10 and CIFAR-100, respectively. For these datasets, three federated settings are generated: an IID setting and two non-IID Dirichlet settings with parameters .6 and .3. Figure 3 shows the heterogeneity levels of the MNIST dataset in these settings.
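The Dirichlet splitting procedure above can be sketched as follows, together with the per-device statistic plotted in Figure 3: the number of most frequent classes needed to cover a given fraction of a device's data. This is a minimal NumPy sketch. It fills in details the text leaves open with assumptions. Every device gets an equal quota. At each step, a device that still needs data is picked uniformly at random. A class that has run out of images is dropped from the prior, which is then renormalized, before sampling. Both function names are hypothetical.

```python
import numpy as np

def dirichlet_split(labels, n_devices, concentration, seed=0):
    """Hypothetical sketch: assign datapoint indices to devices via per-device Dirichlet class priors."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    n_classes = int(labels.max()) + 1
    # Unassigned indices of each class, shuffled so that pop() samples without replacement.
    pools = [list(rng.permutation(np.flatnonzero(labels == c))) for c in range(n_classes)]
    priors = rng.dirichlet(np.full(n_classes, concentration), size=n_devices)  # class prior per device
    quota = np.full(n_devices, len(labels) // n_devices)  # balanced: equal amount of data per device
    parts = [[] for _ in range(n_devices)]
    while quota.sum() > 0:
        k = rng.choice(np.flatnonzero(quota > 0))            # a device that still needs data
        available = np.array([len(p) > 0 for p in pools], dtype=float)
        prior = priors[k] * available                         # drop classes with no images left
        if prior.sum() == 0:                                  # guard against numerical underflow
            prior = available
        c = rng.choice(n_classes, p=prior / prior.sum())      # sample a label from the prior
        parts[k].append(int(pools[c].pop()))                  # take one unassigned image of that label
        quota[k] -= 1
    return parts

def classes_to_cover(device_labels, fraction):
    """Number of most frequent classes needed to cover `fraction` of one device's datapoints."""
    counts = np.sort(np.bincount(device_labels))[::-1]
    covered = np.cumsum(counts) / counts.sum()
    return int(np.searchsorted(covered, fraction - 1e-12)) + 1
```

For a device whose 10 classes are equally represented, `classes_to_cover` returns 4, 6 and 8 for the fractions 0.4, 0.6 and 0.8. This matches the IID case discussed for Figure 3.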
The histograms report, for each device, how many of its most frequent class labels are needed to cover 40%, 60% and 80% of its data. For example, every class label is equally represented in the IID setting, so 4, 6 and 8 classes cover 40%, 60% and 80% of the local datapoints of each device. In the non-IID settings, 80% of the local data belongs mostly to 4 or 5 classes for Dirichlet .6, and to 3 or 4 classes for Dirichlet .3. To generate unbalanced data, we sample the number of datapoints per device from a lognormal distribution. The variance of the lognormal distribution controls how unbalanced the devices are. For instance, in CIFAR-10, the standard deviation of the data amounts among devices is 0 in the balanced setting and 0.3 in the unbalanced setting. LEAF (Caldas et al., 2018) is used to generate the Shakespeare dataset used in this work. The LEAF framework can generate both IID and non-IID federated settings. The non-IID dataset is the natural split of Shakespeare, where each device corresponds to a role and the local dataset contains that role's sentences. The IID dataset is generated by combining the sentences of all roles and randomly dividing them among devices. In this work, we consider 100 devices and restrict the number of datapoints per device to 2000.

Models. We use fully connected neural networks for MNIST and EMNIST-L. Both models take the input image as a 784-dimensional vector, followed by 2 hidden layers and a final softmax layer. The hidden layers have 200 neurons for MNIST and 100 neurons for EMNIST-L. These models achieve 98.4% test accuracy on MNIST and 95.0% on EMNIST-L when trained on the datapoints of all devices. The MNIST model is the same as the one used in the original FedAvg work (McMahan et al., 2017).
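Because communication cost is measured in transmitted models, the size of these models matters. The short sketch below counts the weights and biases of the fully connected models described above. It assumes dense layers with biases, both hidden layers having the stated width, and 10 output classes for both MNIST and EMNIST-L. The helper name is hypothetical.

```python
def mlp_param_count(layer_sizes):
    """Weights plus biases of a fully connected network with the given layer widths."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))

print(mlp_param_count([784, 200, 200, 10]))  # MNIST model: 199210
print(mlp_param_count([784, 100, 100, 10]))  # EMNIST-L model: 89610
```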
For CIFAR-10 and CIFAR-100, we use a CNN consisting of two convolutional layers with 64 5 × 5 filters each, two 2 × 2 max pooling layers, two fully connected layers with 394 and 192 neurons, and a final softmax layer. The models achieve 85.2% test accuracy on CIFAR-10 and 55.3% on CIFAR-100. Our CNN is similar to the one used for CIFAR-10 in the original FedAvg work (McMahan et al., 2017), except that we do not use Batch Normalization layers. For the next-character prediction task (Shakespeare), we use an LSTM. The model maps an 80-character input sequence to an 80 × 8 sequence using an embedding layer. This sequence is fed to a two-layer LSTM with a hidden size of 100 units, and the output of the stacked LSTM is passed to a softmax layer. Overall, when trained on the data of all devices, this architecture achieves a test accuracy of 50.8% in the IID setting and 51.2% in the non-IID setting. We report both numbers because the two datasets are generated separately from the whole Shakespeare corpus, so the train and test splits differ between the two cases. This model is the same as the one used in the original FedProx study (Li et al., 2020a). We do not aim for state-of-the-art performance on these datasets. Our goal is to compare FedDyn and the other baselines when training these models in the federated setting.

Hyperparameters. We consider different hyperparameter configurations for different setups and datasets. For all experiments, the batch size is 50 for MNIST, CIFAR-10, CIFAR-100 and EMNIST-L, and 100 for Shakespeare. The hyperparameters $\mu$, $\alpha$ and $K$ are used only in FedProx, FedDyn and SCAFFOLD, respectively. $K$, the number of local SGD steps in SCAFFOLD, plays the role of epochs in the other methods. We searched $K$ values that give the same amount of local computation as in the other methods.
For example, suppose each device has 500 datapoints, the batch size is 50 and the number of epochs is 10. Then each device applies 100 SGD steps, which is equivalent to $K = 100$. For MNIST, the partial participation, 100 devices, balanced data setup uses the following configurations in all IID and Dirichlet settings:
- FedAvg: .1 learning rate and 10 epochs.
- FedProx: .1 learning rate and .0001 $\mu$.
- FedDyn: .1 learning rate, 50 epochs and .01 $\alpha$, except that $\alpha$ is .03 in the 10% IID setting.
- SCAFFOLD: .1 learning rate and $K = 600$.

All methods use a learning rate decay of 0.998 per communication round and a weight decay of $10^{-4}$ to prevent overfitting. For the centralized model, we use a learning rate of .1 and 150 epochs, and the learning rate is halved every 50 epochs.

EMNIST-L. We use hyperparameters similar to those for MNIST. For all IID and Dirichlet full participation settings, the configurations are:
- FedAvg: .1 learning rate and 20 epochs.
- FedProx: .1 learning rate and $10^{-4}$ $\mu$.
- FedDyn: .1 learning rate, 50 epochs and 0.005 $\alpha$.
- SCAFFOLD: .1 learning rate and $K = 500$.

For all IID and Dirichlet partial participation settings, the selected configurations are:
- FedAvg: .1 learning rate and 10 epochs.
- FedProx: .1 learning rate and .0001 $\mu$.
- FedDyn: .1 learning rate and 50 epochs.
- SCAFFOLD: .1 learning rate and $K = 500$.

For FedDyn in the partial participation settings, $\alpha$ is:
- .003 for the 10% and 1% IID settings;
- .005 for the 10% Dirichlet .6 and 1% Dirichlet .3 settings;
- .001 for the 1% Dirichlet .6 setting;
- .01 for the 10% Dirichlet .3 setting.

All methods use a learning rate decay of 0.998 per communication round and a weight decay of $10^{-4}$ to prevent overfitting. For the centralized model, we use a learning rate of .1 and 150 epochs, and the learning rate is halved every 50 epochs.

CIFAR-10. The same hyperparameters are applied to all CIFAR-10 experiments: 0.1 learning rate, 5 epochs and $10^{-3}$ weight decay. The learning rate decay is selected from [0.992, 0.998, 1.0].
The $\alpha$ value of FedDyn is selected from [$10^{-3}$, $10^{-2}$, $10^{-1}$], and the $\mu$ value of FedProx from [$10^{-2}$, $10^{-3}$, $10^{-4}$]. For the centralized model, we use a learning rate of .1, 500 epochs and a learning rate decay of .992.

CIFAR-100. The CIFAR-100 experiments with 100 devices use the same hyperparameters: 0.1 learning rate, 5 epochs and $10^{-3}$ weight decay. The learning rate decay is selected from [0.992, 0.998, 1.0], the $\alpha$ value of FedDyn from [$10^{-3}$, $10^{-2}$, $10^{-1}$], and the $\mu$ value of FedProx from [$10^{-2}$, $10^{-3}$, $10^{-4}$].

For the 500 devices, balanced data, 10% participation, IID setup, we apply .1 learning rate, .0001 $\mu$ and $10^{-3}$ weight decay. We search epochs in [2, 5] with the corresponding $K$ values in [4, 10], and $\alpha$ in [.1, .01, .001] for FedDyn. The selected values are 2 epochs for FedDyn, FedAvg and FedProx, $K = 4$ for SCAFFOLD, and $\alpha = .01$ for FedDyn. The same parameters are chosen for the 500 devices, balanced data, 10% participation, Dirichlet .3 setup.

For the 100 devices, unbalanced data, 10% participation, IID and Dirichlet .3 settings, the selected values are 2 epochs for FedDyn, FedAvg and FedProx, $K = 20$ for SCAFFOLD, $\alpha = .1$ for FedDyn and $\mu = .0001$ for FedProx. For the centralized model, we use a learning rate of .1, 500 epochs and a learning rate decay of .992.

Shakespeare. For the 100 devices, balanced data, full participation setup, the hyperparameters are searched over all combinations of learning rate in [1], epochs in [1, 5], $K$ in [20, 100], $\mu$ in [.01, .0001] and $\alpha$ in [.001, .009, .01, .015]. A weight decay of $10^{-4}$ is applied to prevent overfitting, and no learning rate decay across communication rounds is used.
In both the IID and non-IID settings, the selected configurations are:
- FedAvg: 1 learning rate and 5 epochs.
- FedProx: 1 learning rate, 5 epochs and .0001 $\mu$.
- FedDyn: 1 learning rate, 5 epochs and .009 $\alpha$.
- SCAFFOLD: 1 learning rate and $K = 100$.

For the partial participation, 100 devices, balanced data setup, we choose the following in all cases:
- FedAvg: 1 learning rate and 5 epochs.
- FedProx: 1 learning rate, 5 epochs and .0001 $\mu$.
- SCAFFOLD: 1 learning rate and $K = 100$.
- FedDyn: 1 learning rate and 5 epochs, with $\alpha$ = .015 in the 10% settings and .001 in the 1% settings.

No learning rate decay is applied in the 10% settings, and a decay of .998 is applied in the 1% settings. A weight decay of $10^{-4}$ is applied to prevent overfitting. For the centralized model, we use a learning rate of 1 and 150 epochs, and the learning rate is halved every 50 epochs. Additionally, we apply gradient clipping for all methods to prevent the weights from overflowing, and we found that this increases the stability of the algorithms.

Convergence Plots. We give convergence plots of the experiments. Figures 4 and 5 show convergence plots for a moderate and a large number of devices under different device distributions, for CIFAR-10 and CIFAR-100 respectively. Similarly, Figures 6, 7, 8, 9 and 10 plot convergence curves for different participation levels and distributions for all datasets. Finally, Figures 12 and 13 show convergence plots for balanced and unbalanced data under different device distributions. The convergence curves show the accuracy achieved with respect to the number of communication rounds. However, the metric we want to minimize, the amount of information transmitted, is not the same as the number of communication rounds. For instance, SCAFFOLD transmits two models per communication round: the model and the device state. The tables account for this difference.
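The tables count communication in units of one FedAvg round rather than in rounds. Below is a minimal sketch of that bookkeeping. It assumes that a per-round accuracy curve is available and that SCAFFOLD sends two model-sized vectors per round while the other methods send one. The curves in the example are made-up placeholders, not results from the paper, and the function name is hypothetical.

```python
def cost_to_target(accuracy_per_round, target, models_per_round=1):
    """Models transmitted (in units of one FedAvg round) until `target` accuracy is first reached.

    Returns None if the target is never reached."""
    for rounds, accuracy in enumerate(accuracy_per_round, start=1):
        if accuracy >= target:
            return rounds * models_per_round
    return None

# Placeholder curves, not experimental results.
fedavg_curve = [0.40, 0.55, 0.65, 0.72, 0.80]
scaffold_curve = [0.50, 0.70, 0.81]
print(cost_to_target(fedavg_curve, 0.8))                        # 5
print(cost_to_target(scaffold_curve, 0.8, models_per_round=2))  # 6
```

In this toy example SCAFFOLD reaches the target in fewer rounds (3 versus 5) yet transmits more (6 versus 5). This is why the tables report transmitted models rather than rounds.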
We observed that averaging all device models gives more stable convergence curves. Hence, in each communication round we report the performance of the average of all device models. This does not modify the algorithms and is done only for reporting. In addition to the experiments above, we test our algorithm with a more complex model: ResNet18 (He et al., 2016) in the CIFAR-10 IID, 100 devices, balanced data, 10% participation setting. Batch normalization layers keep internal statistics, which can be problematic in FL. Therefore, we replace Batch normalization with group normalization (Wu & He, 2018) in ResNet18. The convergence curves are shown in Figure 11. FedDyn still outperforms the baseline methods with this higher-capacity model.

A.3 α SENSITIVITY ANALYSIS OF FEDDYN

$\alpha$ is an important parameter of FedDyn. Indeed, it is the only hyperparameter of the algorithm when devices have access to an optimization solver. In theory, $\alpha$ balances two problem-dependent constants, as shown in Theorem 2, Theorem 3 and Theorem 4, so its optimal value depends on these constants. Since these constants are independent of $T$, the value of $\alpha$ does not affect the asymptotic convergence rate. To test sensitivity, we consider the CIFAR-10, IID, 100 devices, 10% participation setting. Figure 1a shows convergence plots for different values of $\alpha$ while all other parameters of FedDyn are kept constant. Figure 1b presents the best test accuracy achieved for each value of $\alpha$. The best test performance is obtained with $\alpha = 10^{-1}$. All configurations converge, but some converge to better stationary points. This is consistent with the theory, which guarantees convergence to a stationary point.

In this work, we aim to solve the FL problem with four principal characteristics:
- partial participation due to unreliable communication links;
- a massive number of devices;
- heterogeneous device data;
- unbalanced data amounts per device.

Partial participation is critical because it is unrealistic to expect all devices to participate in every round. FedSplit, however, does not support partial participation, so we adapt it to the partial participation setting with the following changes. If a device is not active in the current round, its model $z_k^{t+1} = z_k^t$ and its intermediate state $z_k^{t+\frac{1}{2}} = z_k^{t-\frac{1}{2}}$ are frozen. For the server model, we have two options. The first option, named FedSplit All, keeps the server model as the average of all device models, $x^t = \frac{1}{m}\sum_{k\in[m]} z_k^t$. The second option, named FedSplit Act, sets the server model to the average of only the current round's active devices, $x^t = \frac{1}{|\mathcal{P}_t|}\sum_{k\in\mathcal{P}_t} z_k^t$.
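The two server-side options of the partial participation FedSplit adaptation can be sketched as follows. The sketch shows only the freezing of inactive device models and the two averaging rules. The FedSplit device update itself is abstracted as an already computed array of candidate models, and freezing the intermediate state is omitted. All names are hypothetical. The example also shows why FedSplit All moves slowly: with one active device out of four, its server model moves only a quarter of the way.

```python
import numpy as np

def fedsplit_server_step(z_prev, z_candidate, active, mode):
    """Freeze inactive device models, then form the server model x^t.

    z_prev, z_candidate: (m, d) device models before/after the local FedSplit update.
    active: boolean mask of length m marking the devices in P_t.
    mode: "all" (FedSplit All) averages every device; "act" (FedSplit Act) only the active ones."""
    z = np.where(active[:, None], z_candidate, z_prev)  # inactive rows stay frozen
    if mode == "all":
        x = z.mean(axis=0)
    elif mode == "act":
        x = z[active].mean(axis=0)
    else:
        raise ValueError("mode must be 'all' or 'act'")
    return z, x

z_prev = np.zeros((4, 2))
z_new = np.ones((4, 2))
active = np.array([True, False, False, False])
print(fedsplit_server_step(z_prev, z_new, active, "all")[1])  # [0.25 0.25]
print(fedsplit_server_step(z_prev, z_new, active, "act")[1])  # [1. 1.]
```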
We do not claim that these modifications are optimal. For empirical evaluation, we consider the CIFAR-10, 100 devices, 100% and 10% participation settings. Figures 2a and 2b compare FedSplit and FedDyn at the 100% and 10% participation levels, respectively. FedSplit All and FedSplit Act coincide in the full participation setting and are therefore shown as one method. FedDyn performs better than FedSplit in both cases. FedSplit All, where the server averages all device models, performs significantly worse than FedSplit Act, where the server averages only the active devices. When all devices are averaged, the server model changes too slowly, because most device models are unchanged across consecutive rounds. We further note that it may not be easy to establish a convergence theory for FedSplit in the partial participation setting.

1. $L_k$ is $L$ smooth if
$$\|\nabla L_k(x) - \nabla L_k(y)\| \le L\|x-y\| \quad \forall x, y.$$
Smoothness implies the following quadratic bound:
$$L_k(y) \le L_k(x) + \langle\nabla L_k(x), y-x\rangle + \frac{L}{2}\|y-x\|^2 \quad \forall x,y. \tag{4}$$
If the $\{L_k\}_{k=1}^m$ are convex and $L$ smooth, we have
$$\frac{1}{2Lm}\sum_{k\in[m]}\|\nabla L_k(x) - \nabla L_k(x^*)\|^2 \le \ell(x) - \ell(x^*) \quad \forall x, \tag{5}$$
$$-\langle\nabla L_k(x), z-y\rangle \le -L_k(z) + L_k(y) + \frac{L}{2}\|z-x\|^2 \quad \forall x,y,z, \tag{6}$$
where $\ell(x) = \frac{1}{m}\sum_{k=1}^m L_k(x)$ and $\nabla\ell(x^*) = 0$. We state convergence as follows.

Theorem 2. For convex and $L$ smooth functions $\{L_k\}_{k=1}^m$ and $\alpha \ge 25L$, Algorithm 1 satisfies
$$\mathbb{E}\left[\ell\left(\frac{1}{T}\sum_{t=0}^{T-1}\gamma^t\right) - \ell(\theta^*)\right] \le \frac{1}{T}\left(10\alpha\|\theta^0-\theta^*\|^2 + 100\frac{m}{P}\frac{1}{\alpha}\frac{1}{m}\sum_{k\in[m]}\|\nabla L_k(\theta^*)\|^2\right) = O\left(\frac{1}{T}\right),$$
where $\gamma^t = \frac{1}{P}\sum_{k\in\mathcal{P}_t}\theta_k^t$ and $\theta^* = \arg\min_\theta \ell(\theta)$. Setting $\alpha = 30L\frac{m}{P}$ gives the statement in Theorem 1.

Throughout the proof, we use techniques similar to those in the convergence analysis of SCAFFOLD (Karimireddy et al., 2019). We define a set of variables that are useful in the analysis. Algorithm 1 freezes $\theta_k$ and its gradient if device $k$ is not active.
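We define virtual variables $\{\tilde\theta_k^t\}$ as
$$\tilde\theta_k^t = \arg\min_\theta\ L_k(\theta) - \langle\nabla L_k(\theta_k^{t-1}), \theta\rangle + \frac{\alpha}{2}\|\theta - \theta^{t-1}\|^2 \quad \forall k\in[m],\ t>0. \tag{7}$$
Note that $\tilde\theta_k^t = \theta_k^t$ if $k\in\mathcal{P}_t$, and that $\tilde\theta_k^t$ does not depend on $\mathcal{P}_t$. The first-order conditions of Eq. 7 and of the device optimization give
$$\tilde\theta_k^t - \theta^{t-1} = \frac{1}{\alpha}\left(\nabla L_k(\theta_k^{t-1}) - \nabla L_k(\tilde\theta_k^t)\right) \ \forall k\in[m]; \qquad \theta_k^t - \theta^{t-1} = \frac{1}{\alpha}\left(\nabla L_k(\theta_k^{t-1}) - \nabla L_k(\theta_k^t)\right) \ \forall k\in\mathcal{P}_t. \tag{8}$$
The server model $\theta^t$ consists of the active device average and a gradient correction. We express the active device average and its relation to the server model as
$$\gamma^t = \frac{1}{P}\sum_{k\in\mathcal{P}_t}\theta_k^t; \qquad \gamma^t = \theta^t + \frac{1}{\alpha}h^t. \tag{9}$$
Because $\nabla L_k$ is updated linearly, the server state becomes $h^t = \frac{1}{m}\sum_{k\in[m]}\nabla L_k(\theta_k^t)$. We define the quantities we would like to control:
$$C_t = \frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\nabla L_k(\theta_k^t) - \nabla L_k(\theta^*)\|^2, \qquad \epsilon_t = \frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\tilde\theta_k^t - \gamma^{t-1}\|^2.$$
$C_t$ tracks how well the local gradients of the device models approximate the local gradients at the optimal model. If the models converge to $\theta^*$, $C_t$ becomes 0. $\epsilon_t$ tracks how much the local models change relative to the average of the device models from the previous round. Again, upon convergence $\epsilon_t$ becomes 0. With these definitions, Theorem 2 is a direct consequence of the following lemma.

Lemma 1. For convex and $L$ smooth functions $\{L_k\}_{k=1}^m$, if $\alpha \ge 25L$, Algorithm 1 satisfies
$$\mathbb{E}\|\gamma^t-\theta^*\|^2 + \kappa C_t \le \mathbb{E}\|\gamma^{t-1}-\theta^*\|^2 + \kappa C_{t-1} - \kappa_0\,\mathbb{E}\left[\ell(\gamma^{t-1}) - \ell(\theta^*)\right],$$
where
$$\kappa = 8\frac{m}{P}\frac{1}{\alpha}\frac{L+\alpha}{\alpha^2-20L^2}, \qquad \kappa_0 = \frac{2}{\alpha}\frac{\alpha^2-20\alpha L-40L^2}{\alpha^2-20L^2}.$$
Lemma 1 can be telescoped as follows:
$$\kappa_0\,\mathbb{E}\left[\ell(\gamma^{t-1})-\ell(\theta^*)\right] \le \left(\mathbb{E}\|\gamma^{t-1}-\theta^*\|^2 + \kappa C_{t-1}\right) - \left(\mathbb{E}\|\gamma^t-\theta^*\|^2 + \kappa C_t\right),$$
$$\kappa_0\sum_{t=1}^T\mathbb{E}\left[\ell(\gamma^{t-1})-\ell(\theta^*)\right] \le \left(\mathbb{E}\|\gamma^0-\theta^*\|^2 + \kappa C_0\right) - \left(\mathbb{E}\|\gamma^T-\theta^*\|^2 + \kappa C_T\right).$$
If $\alpha \ge 25L$, then $\kappa_0$ and $\kappa$ are positive, and by definition $C_t \ge 0$. Eliminating the negative terms on the RHS gives
$$\kappa_0\sum_{t=1}^T\mathbb{E}\left[\ell(\gamma^{t-1})-\ell(\theta^*)\right] \le \mathbb{E}\|\gamma^0-\theta^*\|^2 + \kappa C_0.$$
Applying Jensen's inequality on the LHS gives
$$\mathbb{E}\left[\ell\left(\frac{1}{T}\sum_{t=1}^T\gamma^{t-1}\right)-\ell(\theta^*)\right] \le \frac{1}{T}\frac{1}{\kappa_0}\left(\|\gamma^0-\theta^*\|^2 + \kappa C_0\right) = O\left(\frac{1}{T}\right),$$
which proves the statement in Theorem 2.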
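To make the updates behind Eqs. 7, 8 and 9 concrete, here is a minimal NumPy sketch of the server/device loop on toy quadratic losses $L_k(\theta) = \frac{a_k}{2}\|\theta - c_k\|^2$. For these losses, the device problem in Eq. 7 has a closed-form solution. The quadratic losses, the function name, the zero initialization of the device gradient states and the uniform sampling of $\mathcal{P}_t$ are illustrative assumptions; with neural networks, the local argmin would instead be approximated, e.g. by SGD. The device gradient state follows Eq. 8. The server state is updated so that $h^t = \frac{1}{m}\sum_k \nabla L_k(\theta_k^t)$, and the server model is $\theta^t = \gamma^t - \frac{1}{\alpha}h^t$ from Eq. 9.

```python
import numpy as np

def feddyn_quadratic(a, c, alpha, rounds, n_active=None, seed=0):
    """Sketch of FedDyn with exact local solves on toy losses L_k(th) = a_k / 2 * ||th - c_k||^2.

    a: (m,) curvatures, c: (m, d) device minimizers. Returns the final server model theta^T."""
    rng = np.random.default_rng(seed)
    m, d = c.shape
    n_active = m if n_active is None else n_active
    theta = np.zeros(d)      # server model theta^{t-1}
    grad = np.zeros((m, d))  # device states standing in for nabla L_k(theta_k^{t-1}); start at 0
    h = np.zeros(d)          # server state h^{t-1}
    for _ in range(rounds):
        active = rng.choice(m, size=n_active, replace=False)  # P_t, sampled uniformly (assumption)
        new = np.empty((n_active, d))
        for i, k in enumerate(active):
            # Eq. 7: argmin_th L_k(th) - <grad_k, th> + alpha/2 ||th - theta||^2 (closed form here).
            new[i] = (a[k] * c[k] + grad[k] + alpha * theta) / (a[k] + alpha)
            # Eq. 8: the first-order condition gives the new gradient state of device k.
            grad[k] = grad[k] - alpha * (new[i] - theta)
        # Keeps h^t equal to the average of all device gradient states (inactive ones are frozen).
        h = h - alpha / m * (new - theta).sum(axis=0)
        # Eq. 9: theta^t = gamma^t - h^t / alpha, with gamma^t the active-device average.
        theta = new.mean(axis=0) - h / alpha
    return theta

a = np.array([1.0, 2.0, 3.0, 4.0])
c = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])
theta_star = (a[:, None] * c).sum(axis=0) / a.sum()  # minimizer of (1/m) sum_k L_k
print(theta_star)                                    # [-0.2 -0.2]
print(np.round(feddyn_quadratic(a, c, alpha=10.0, rounds=500), 6))  # [-0.2 -0.2]
```

With full participation the server model reaches the global minimizer $\theta^* = (-0.2, -0.2)$. By contrast, simply averaging the device minimizers $c_k$ would give $(0, 0)$, the inconsistency that the dynamic regularizer removes.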
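As in the standard gradient descent analysis, the proof of Lemma 1 writes $\|\gamma^t-\theta^*\|^2$ as $\|\gamma^t-\gamma^{t-1}+\gamma^{t-1}-\theta^*\|^2$ and expands it. The resulting expression contains $(\gamma^t-\gamma^{t-1})$ and $\|\gamma^t-\gamma^{t-1}\|^2$ terms. To handle these extra terms, we state the following lemmas and prove the longer ones at the end.

Lemma 2. Algorithm 1 satisfies
$$\mathbb{E}\left[\gamma^t-\gamma^{t-1}\right] = \frac{1}{\alpha m}\sum_{k\in[m]}\mathbb{E}\left[-\nabla L_k(\tilde\theta_k^t)\right].$$
Proof.
$$\begin{aligned}
\mathbb{E}\left[\gamma^t-\gamma^{t-1}\right] &= \mathbb{E}\left[\frac{1}{P}\sum_{k\in\mathcal{P}_t}\theta_k^t - \theta^{t-1} - \frac{1}{\alpha}h^{t-1}\right] \\
&= \mathbb{E}\left[\frac{1}{P}\sum_{k\in\mathcal{P}_t}\left(\theta_k^t - \theta^{t-1}\right) - \frac{1}{\alpha}h^{t-1}\right] \\
&= \mathbb{E}\left[\frac{1}{\alpha P}\sum_{k\in\mathcal{P}_t}\left(\nabla L_k(\theta_k^{t-1}) - \nabla L_k(\theta_k^t)\right) - \frac{1}{\alpha}h^{t-1}\right] \\
&= \mathbb{E}\left[\frac{1}{\alpha P}\sum_{k\in\mathcal{P}_t}\left(\nabla L_k(\theta_k^{t-1}) - \nabla L_k(\tilde\theta_k^t)\right) - \frac{1}{\alpha}h^{t-1}\right] \\
&= \mathbb{E}\left[\frac{1}{\alpha m}\sum_{k\in[m]}\left(\nabla L_k(\theta_k^{t-1}) - \nabla L_k(\tilde\theta_k^t)\right) - \frac{1}{\alpha}h^{t-1}\right] \\
&= \frac{1}{\alpha m}\sum_{k\in[m]}\mathbb{E}\left[-\nabla L_k(\tilde\theta_k^t)\right].
\end{aligned}$$
The first equality follows from the definition in Eq. 9, and the second regroups terms. The third equality follows from Eq. 8, and the fourth from $\tilde\theta_k^t = \theta_k^t$ for $k\in\mathcal{P}_t$. The fifth equality takes the expectation conditioned on the randomness before time $t$. Given that randomness, every variable except $\mathcal{P}_t$ is revealed, and each device is selected with probability $\frac{P}{m}$. The last equality uses the definition $h^{t-1} = \frac{1}{m}\sum_{k\in[m]}\nabla L_k(\theta_k^{t-1})$.

Similarly, $\|\gamma^t-\gamma^{t-1}\|^2$ is bounded as follows.

Lemma 3. Algorithm 1 satisfies $\mathbb{E}\|\gamma^t-\gamma^{t-1}\|^2 \le \epsilon_t$.

Proof.
$$\mathbb{E}\|\gamma^t-\gamma^{t-1}\|^2 = \mathbb{E}\left\|\frac{1}{P}\sum_{k\in\mathcal{P}_t}\left(\theta_k^t-\gamma^{t-1}\right)\right\|^2 \le \frac{1}{P}\mathbb{E}\sum_{k\in\mathcal{P}_t}\|\theta_k^t-\gamma^{t-1}\|^2 = \frac{1}{P}\mathbb{E}\sum_{k\in\mathcal{P}_t}\|\tilde\theta_k^t-\gamma^{t-1}\|^2 = \frac{1}{P}\frac{P}{m}\sum_{k\in[m]}\mathbb{E}\|\tilde\theta_k^t-\gamma^{t-1}\|^2 = \epsilon_t.$$
The first equality comes from Eq. 9, and the inequality applies Jensen's inequality. The remaining relations follow from $\tilde\theta_k^t = \theta_k^t$ for $k\in\mathcal{P}_t$, from taking the expectation conditioned on the randomness before time $t$, and from the definition of $\epsilon_t$.

We need to further bound the excess $\epsilon_t$ term arising in Lemma 3. We introduce two more lemmas to handle it.

Lemma 4. For convex and $L$ smooth functions $\{L_k\}_{k=1}^m$, Algorithm 1 satisfies
$$\left(1-4L^2\frac{1}{\alpha^2}\right)\epsilon_t \le 8\frac{1}{\alpha^2}C_{t-1} + 8L\frac{1}{\alpha^2}\mathbb{E}\left[\ell(\gamma^{t-1})-\ell(\theta^*)\right].$$
Lemma 5. For convex and $L$ smooth functions $\{L_k\}_{k=1}^m$, Algorithm 1 satisfies
$$C_t \le \left(1-\frac{P}{m}\right)C_{t-1} + 2L^2\frac{P}{m}\epsilon_t + 4L\frac{P}{m}\mathbb{E}\left[\ell(\gamma^{t-1})-\ell(\theta^*)\right].$$
The $\mathbb{E}\left[\ell(\gamma^{t-1})-\ell(\theta^*)\right]$ terms constitute the LHS of the telescoping sum.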
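We expand the $\|\gamma^t-\theta^*\|^2$ term as
$$\begin{aligned}
\mathbb{E}\|\gamma^t-\theta^*\|^2 &= \mathbb{E}\|\gamma^{t-1}-\theta^* + \gamma^t-\gamma^{t-1}\|^2 \\
&= \mathbb{E}\|\gamma^{t-1}-\theta^*\|^2 + 2\mathbb{E}\langle\gamma^{t-1}-\theta^*, \gamma^t-\gamma^{t-1}\rangle + \mathbb{E}\|\gamma^t-\gamma^{t-1}\|^2 \\
&= \mathbb{E}\|\gamma^{t-1}-\theta^*\|^2 + \frac{2}{\alpha m}\sum_{k\in[m]}\mathbb{E}\langle\gamma^{t-1}-\theta^*, -\nabla L_k(\tilde\theta_k^t)\rangle + \mathbb{E}\|\gamma^t-\gamma^{t-1}\|^2 \\
&\le \mathbb{E}\|\gamma^{t-1}-\theta^*\|^2 + \frac{2}{\alpha m}\sum_{k\in[m]}\mathbb{E}\left[L_k(\theta^*) - L_k(\gamma^{t-1}) + \frac{L}{2}\|\tilde\theta_k^t-\gamma^{t-1}\|^2\right] + \mathbb{E}\|\gamma^t-\gamma^{t-1}\|^2 \\
&= \mathbb{E}\|\gamma^{t-1}-\theta^*\|^2 - \frac{2}{\alpha}\mathbb{E}\left[\ell(\gamma^{t-1})-\ell(\theta^*)\right] + \frac{L}{\alpha}\epsilon_t + \mathbb{E}\|\gamma^t-\gamma^{t-1}\|^2,
\end{aligned}\tag{10}$$
where we first expand the square and use Lemma 2, and the inequality is due to Inq. 6. We scale Lemma 4 by $\frac{\alpha(L+\alpha)}{\alpha^2-20L^2}$ and Lemma 5 by $8\frac{m}{P}\frac{1}{\alpha}\frac{L+\alpha}{\alpha^2-20L^2}$. These coefficients are positive due to the condition on $\alpha$. Summing Inq. 10, Lemma 3 and the scaled versions of Lemmas 5 and 4 gives the statement in Lemma 1.

We give the omitted proofs here.

Lemma 6. For any vectors $\{v_j\}_{j=1}^n$ with $v_j\in\mathbb{R}^d$, the following relaxed triangle inequality holds:
$$\left\|\sum_{j=1}^n v_j\right\|^2 \le n\sum_{j=1}^n\|v_j\|^2.$$
Proof. By Jensen's inequality, $\left\|\frac{1}{n}\sum_{j=1}^n v_j\right\|^2 \le \frac{1}{n}\sum_{j=1}^n\|v_j\|^2$. Multiplying both sides by $n^2$ gives the inequality.

Lemma 7. Algorithm 1 satisfies $\mathbb{E}\|h^t\|^2 \le C_t$.

Proof.
$$\mathbb{E}\|h^t\|^2 = \mathbb{E}\left\|\frac{1}{m}\sum_{k\in[m]}\nabla L_k(\theta_k^t)\right\|^2 = \mathbb{E}\left\|\frac{1}{m}\sum_{k\in[m]}\left(\nabla L_k(\theta_k^t)-\nabla L_k(\theta^*)\right)\right\|^2 \le \frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\nabla L_k(\theta_k^t)-\nabla L_k(\theta^*)\|^2 = C_t.$$
The first equality is due to the server update rule of $h$, and the second subtracts $\nabla\ell(\theta^*) = 0$. The inequality applies Jensen's inequality, and the last equality is the definition of $C_t$.

Proof of Lemma 4.
$$\begin{aligned}
\epsilon_t &= \frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\tilde\theta_k^t-\gamma^{t-1}\|^2 = \frac{1}{m}\sum_{k\in[m]}\mathbb{E}\left\|\tilde\theta_k^t-\theta^{t-1}-\frac{1}{\alpha}h^{t-1}\right\|^2 = \frac{1}{\alpha^2}\frac{1}{m}\sum_{k\in[m]}\mathbb{E}\left\|\nabla L_k(\theta_k^{t-1})-\nabla L_k(\tilde\theta_k^t)-h^{t-1}\right\|^2 \\
&= \frac{1}{\alpha^2}\frac{1}{m}\sum_{k\in[m]}\mathbb{E}\left\|\nabla L_k(\theta_k^{t-1})-\nabla L_k(\theta^*)+\nabla L_k(\theta^*)-\nabla L_k(\gamma^{t-1})+\nabla L_k(\gamma^{t-1})-\nabla L_k(\tilde\theta_k^t)-h^{t-1}\right\|^2 \\
&\le \frac{4}{\alpha^2}\frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\nabla L_k(\theta_k^{t-1})-\nabla L_k(\theta^*)\|^2 + \frac{4}{\alpha^2}\frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\nabla L_k(\gamma^{t-1})-\nabla L_k(\theta^*)\|^2 \\
&\quad + \frac{4}{\alpha^2}\frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\nabla L_k(\tilde\theta_k^t)-\nabla L_k(\gamma^{t-1})\|^2 + \frac{4}{\alpha^2}\mathbb{E}\|h^{t-1}\|^2 \\
&\le \frac{4}{\alpha^2}\frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\nabla L_k(\theta_k^{t-1})-\nabla L_k(\theta^*)\|^2 + \frac{4}{\alpha^2}\frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\nabla L_k(\gamma^{t-1})-\nabla L_k(\theta^*)\|^2 \\
&\quad + \frac{4}{\alpha^2}\frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\nabla L_k(\tilde\theta_k^t)-\nabla L_k(\gamma^{t-1})\|^2 + \frac{4}{\alpha^2}C_{t-1} \\
&\le \frac{8}{\alpha^2}C_{t-1} + \frac{4L^2}{\alpha^2}\epsilon_t + \frac{8L}{\alpha^2}\mathbb{E}\left[\ell(\gamma^{t-1})-\ell(\theta^*)\right].
\end{aligned}$$
The equalities follow from the definition of $\epsilon_t$, Eq. 9 and Eq. 8. The inequalities follow from Lemma 6, Lemma 7, smoothness and Inq. 5. Rearranging terms gives the lemma.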

Proof of Lemma 5

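$$\begin{aligned}
C_t &= \frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\nabla L_k(\theta_k^t)-\nabla L_k(\theta^*)\|^2 \\
&= \left(1-\frac{P}{m}\right)\frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\nabla L_k(\theta_k^{t-1})-\nabla L_k(\theta^*)\|^2 + \frac{P}{m}\frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\nabla L_k(\tilde\theta_k^t)-\nabla L_k(\theta^*)\|^2 \\
&= \left(1-\frac{P}{m}\right)C_{t-1} + \frac{P}{m}\frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\nabla L_k(\tilde\theta_k^t)-\nabla L_k(\gamma^{t-1})+\nabla L_k(\gamma^{t-1})-\nabla L_k(\theta^*)\|^2 \\
&\le \left(1-\frac{P}{m}\right)C_{t-1} + \frac{2P}{m}\frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\nabla L_k(\tilde\theta_k^t)-\nabla L_k(\gamma^{t-1})\|^2 + \frac{2P}{m}\frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\nabla L_k(\gamma^{t-1})-\nabla L_k(\theta^*)\|^2 \\
&\le \left(1-\frac{P}{m}\right)C_{t-1} + 2L^2\frac{P}{m}\epsilon_t + \frac{2P}{m}\frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\nabla L_k(\gamma^{t-1})-\nabla L_k(\theta^*)\|^2 \\
&\le \left(1-\frac{P}{m}\right)C_{t-1} + 2L^2\frac{P}{m}\epsilon_t + 4L\frac{P}{m}\mathbb{E}\left[\ell(\gamma^{t-1})-\ell(\theta^*)\right].
\end{aligned}$$
The second equality takes the expectation with respect to $\mathcal{P}_t$, and the third uses the definition of $C_t$. The inequalities follow from Lemma 6, smoothness and Inq. 5, respectively.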
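The constants in Lemma 1 and the scalings used in its proof can be checked mechanically. The sketch below uses exact rational arithmetic (Python's `fractions.Fraction`). For given $\alpha, L, m, P$, it confirms that scaling Lemma 4 by $\frac{\alpha(L+\alpha)}{\alpha^2-20L^2}$ and Lemma 5 by $\kappa$ does three things:
- it cancels the $\epsilon_t$ terms;
- it leaves exactly $\kappa$ in front of $C_{t-1}$;
- it leaves $-\kappa_0$ in front of $\mathbb{E}[\ell(\gamma^{t-1})-\ell(\theta^*)]$.

The function name and test values are illustrative.

```python
from fractions import Fraction

def lemma1_check(alpha, L, m, P):
    """Return (kappa, kappa0) of Lemma 1 after checking the coefficient bookkeeping exactly."""
    alpha, L, m, P = (Fraction(v) for v in (alpha, L, m, P))
    den = alpha**2 - 20 * L**2
    kappa = 8 * (m / P) * (1 / alpha) * (L + alpha) / den
    kappa0 = 2 * (1 / alpha) * (alpha**2 - 20 * alpha * L - 40 * L**2) / den
    c4, c5 = alpha * (L + alpha) / den, kappa  # scales applied to Lemma 4 and Lemma 5
    # C_{t-1}: Lemma 5 keeps c5 * (1 - P/m), Lemma 4 adds c4 * 8 / alpha^2; together exactly kappa.
    assert c5 * (1 - P / m) + 8 * c4 / alpha**2 == kappa
    # eps_t: the LHS of scaled Lemma 4 absorbs eps_t from Inq. 10 + Lemma 3 and from scaled Lemma 5.
    assert c4 * (1 - 4 * L**2 / alpha**2) == (1 + L / alpha) + c5 * 2 * L**2 * P / m
    # Function-gap term: -2/alpha from Inq. 10 plus the scaled lemmas' contributions equals -kappa0.
    assert -2 / alpha + c5 * 4 * L * P / m + c4 * 8 * L / alpha**2 == -kappa0
    return kappa, kappa0
```

The identities hold for any $\alpha$ with $\alpha^2 \neq 20L^2$. The condition on $\alpha$ matters only for the signs. For example, with $L = 1$, $\kappa_0 > 0$ at $\alpha = 25$ but $\kappa_0 < 0$ at $\alpha = 5$.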

B.2 STRONGLY CONVEX ANALYSIS

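We state convergence for $\mu$ strongly convex and $L$ smooth functions $\{L_k\}_{k=1}^m$ as follows.

Theorem 3. For $\mu$ strongly convex and $L$ smooth functions $\{L_k\}_{k=1}^m$ and $\alpha \ge \max\left(5\frac{m}{P}\mu, 30L\right)$, Algorithm 1 satisfies
$$\mathbb{E}\left[\ell\left(\frac{1}{R}\sum_{t=0}^{T-1}r^t\gamma^t\right)-\ell(\theta^*)\right] \le \frac{1}{r^T-1}\left(20\alpha\|\theta^0-\theta^*\|^2 + 400\frac{m}{P}\frac{1}{\alpha}\frac{1}{m}\sum_{k\in[m]}\|\nabla L_k(\theta^*)\|^2\right),$$
where $\gamma^t = \frac{1}{P}\sum_{k\in\mathcal{P}_t}\theta_k^t$, $r = 1+\frac{\mu}{\alpha}$, $R = \sum_{t=0}^{T-1}r^t$ and $\theta^* = \arg\min_\theta\ell(\theta)$. Setting $\alpha = \max\left(5\frac{m}{P}\mu, 30L\right)$ gives the statement in Theorem 1.

We use the same $\{\tilde\theta_k^t\}$, $\gamma^t$, $C_t$ and $\epsilon_t$ variables defined around Eqs. 7, 8 and 9. With these definitions, Theorem 3 is a direct consequence of the following lemma.

Lemma 8. For $\mu$ strongly convex and $L$ smooth functions $\{L_k\}_{k=1}^m$, if $\alpha \ge \max\left(5\frac{m}{P}\mu, 30L\right)$, Algorithm 1 satisfies
$$r\left(\mathbb{E}\|\gamma^t-\theta^*\|^2 + \kappa C_t\right) \le \mathbb{E}\|\gamma^{t-1}-\theta^*\|^2 + \kappa C_{t-1} - \kappa_0\,\mathbb{E}\left[\ell(\gamma^{t-1})-\ell(\theta^*)\right],$$
where
$$\kappa = \frac{8m(L+\alpha)}{z}, \qquad \kappa_0 = \frac{2\alpha^3P + 2\alpha^2P\mu - 2\alpha^2m\mu - 40\alpha^2LP - 80\alpha L^2P - 40\alpha LP\mu + 8\alpha Lm\mu + 16L^2m\mu - 80L^2P\mu}{\alpha z},$$
$$z = \alpha^3P + \alpha^2P\mu - \alpha^2m\mu - 20\alpha L^2P + 4L^2m\mu - 20L^2P\mu, \qquad r = 1+\frac{\mu}{\alpha}.$$
Multiplying Lemma 8 by $r^{t-1}$ and telescoping gives
$$\kappa_0 r^{t-1}\mathbb{E}\left[\ell(\gamma^{t-1})-\ell(\theta^*)\right] \le r^{t-1}\left(\mathbb{E}\|\gamma^{t-1}-\theta^*\|^2 + \kappa C_{t-1}\right) - r^t\left(\mathbb{E}\|\gamma^t-\theta^*\|^2 + \kappa C_t\right),$$
$$\kappa_0\sum_{t=1}^T r^{t-1}\mathbb{E}\left[\ell(\gamma^{t-1})-\ell(\theta^*)\right] \le \left(\mathbb{E}\|\gamma^0-\theta^*\|^2 + \kappa C_0\right) - r^T\left(\mathbb{E}\|\gamma^T-\theta^*\|^2 + \kappa C_T\right).$$
If $\alpha \ge \max\left(5\frac{m}{P}\mu, 30L\right)$, then $\kappa_0$ and $\kappa$ are positive. Dividing both sides by $R = \sum_{t=0}^{T-1}r^t$ and eliminating the negative terms on the RHS gives
$$\kappa_0\frac{1}{R}\sum_{t=1}^T r^{t-1}\mathbb{E}\left[\ell(\gamma^{t-1})-\ell(\theta^*)\right] \le \frac{1}{R}\left(\mathbb{E}\|\gamma^0-\theta^*\|^2 + \kappa C_0\right).$$
Applying Jensen's inequality on the LHS gives
$$\mathbb{E}\left[\ell\left(\frac{1}{R}\sum_{t=1}^T r^{t-1}\gamma^{t-1}\right)-\ell(\theta^*)\right] \le \frac{1}{R}\frac{1}{\kappa_0}\left(\|\gamma^0-\theta^*\|^2 + \kappa C_0\right).$$
We have $\frac{1}{R} = \frac{r-1}{r^T-1} \le \frac{1}{r^T-1}$, since $r-1 = \frac{\mu}{\alpha} \le 1$ under the condition on $\alpha$. Combining the two inequalities, we get
$$\mathbb{E}\left[\ell\left(\frac{1}{R}\sum_{t=1}^T r^{t-1}\gamma^{t-1}\right)-\ell(\theta^*)\right] \le \frac{1}{r^T-1}\frac{1}{\kappa_0}\left(\|\gamma^0-\theta^*\|^2 + \kappa C_0\right),$$
which proves the statement in Theorem 3.

The proof of Lemma 8 is similar to the convex analysis. We generalize Inq. 6 to the case where the $\{L_k\}_{k=1}^m$ are $\mu$ strongly convex and $L$ smooth:
$$-\langle\nabla L_k(x), z-y\rangle \le -L_k(z) + L_k(y) + \frac{L}{2}\|z-x\|^2 - \frac{\mu}{2}\|x-y\|^2 \quad \forall x,y,z. \tag{11}$$
Since strongly convex functions are convex and we only change Inq. 6, we can directly use Lemmas 2, 3, 4 and 5. We rewrite the $\|\gamma^t-\theta^*\|^2$ expression as
$$\begin{aligned}
\mathbb{E}\|\gamma^t-\theta^*\|^2 &= \mathbb{E}\|\gamma^{t-1}-\theta^* + \gamma^t-\gamma^{t-1}\|^2 \\
&= \mathbb{E}\|\gamma^{t-1}-\theta^*\|^2 + 2\mathbb{E}\langle\gamma^{t-1}-\theta^*, \gamma^t-\gamma^{t-1}\rangle + \mathbb{E}\|\gamma^t-\gamma^{t-1}\|^2 \\
&= \mathbb{E}\|\gamma^{t-1}-\theta^*\|^2 + \frac{2}{\alpha m}\sum_{k\in[m]}\mathbb{E}\langle\gamma^{t-1}-\theta^*, -\nabla L_k(\tilde\theta_k^t)\rangle + \mathbb{E}\|\gamma^t-\gamma^{t-1}\|^2 \\
&\le \mathbb{E}\|\gamma^{t-1}-\theta^*\|^2 + \frac{2}{\alpha m}\sum_{k\in[m]}\mathbb{E}\left[L_k(\theta^*)-L_k(\gamma^{t-1})+\frac{L}{2}\|\tilde\theta_k^t-\gamma^{t-1}\|^2-\frac{\mu}{2}\|\tilde\theta_k^t-\theta^*\|^2\right] + \mathbb{E}\|\gamma^t-\gamma^{t-1}\|^2 \\
&= \mathbb{E}\|\gamma^{t-1}-\theta^*\|^2 - \frac{2}{\alpha}\mathbb{E}\left[\ell(\gamma^{t-1})-\ell(\theta^*)\right] + \frac{L}{\alpha}\epsilon_t - \frac{\mu}{\alpha}\frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\tilde\theta_k^t-\theta^*\|^2 + \mathbb{E}\|\gamma^t-\gamma^{t-1}\|^2 \\
&\le \mathbb{E}\|\gamma^{t-1}-\theta^*\|^2 - \frac{2}{\alpha}\mathbb{E}\left[\ell(\gamma^{t-1})-\ell(\theta^*)\right] + \frac{L}{\alpha}\epsilon_t - \frac{\mu}{\alpha}\mathbb{E}\|\gamma^t-\theta^*\|^2 + \mathbb{E}\|\gamma^t-\gamma^{t-1}\|^2,
\end{aligned}\tag{12}$$
where we first expand the square and use Lemma 2, and the inequalities use Inq. 11 and Lemma 9. Rearranging Inq. 12 gives
$$\left(1+\frac{\mu}{\alpha}\right)\mathbb{E}\|\gamma^t-\theta^*\|^2 \le \mathbb{E}\|\gamma^{t-1}-\theta^*\|^2 - \frac{2}{\alpha}\mathbb{E}\left[\ell(\gamma^{t-1})-\ell(\theta^*)\right] + \frac{L}{\alpha}\epsilon_t + \mathbb{E}\|\gamma^t-\gamma^{t-1}\|^2. \tag{13}$$
With $z$ as defined in Lemma 8, we scale Lemma 4 by $\frac{\alpha(L+\alpha)(P\alpha+P\mu-m\mu)}{z}$ and Lemma 5 by $\frac{8m(L+\alpha)(\alpha+\mu)}{\alpha z}$. These coefficients are positive due to the condition on $\alpha$. Summing Inq. 13, Lemma 3 and the scaled versions of Lemmas 5 and 4 gives the statement in Lemma 8. We give Lemma 9 and its proof here.

Lemma 9. Algorithm 1 satisfies
$$-\frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\tilde\theta_k^t-\theta^*\|^2 \le -\mathbb{E}\|\gamma^t-\theta^*\|^2.$$
Proof.
$$\mathbb{E}\|\gamma^t-\theta^*\|^2 = \mathbb{E}\left\|\frac{1}{P}\sum_{k\in\mathcal{P}_t}\theta_k^t-\theta^*\right\|^2 \le \frac{1}{P}\mathbb{E}\sum_{k\in\mathcal{P}_t}\|\theta_k^t-\theta^*\|^2 = \frac{1}{P}\mathbb{E}\sum_{k\in\mathcal{P}_t}\|\tilde\theta_k^t-\theta^*\|^2 = \frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\tilde\theta_k^t-\theta^*\|^2.$$
The first equality comes from Eq. 9, and the inequality applies Jensen's inequality. The remaining relations follow from $\tilde\theta_k^t = \theta_k^t$ for $k\in\mathcal{P}_t$ and from taking the expectation conditioned on the randomness before time $t$. Rearranging terms gives the statement of the lemma.

If $\alpha = 30L\frac{m}{P}$, we get the statement in Theorem 1. We use the $\{\tilde\theta_k^t\}$ and $\gamma^t$ variables as defined in Eqs. 7, 8 and 9.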
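Since we aim to find a stationary point in the nonconvex case, we define a new $C_t$ and keep $\epsilon_t$ the same:
$$C_t = \frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\theta_k^t-\gamma^t\|^2, \qquad \epsilon_t = \frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\tilde\theta_k^t-\gamma^{t-1}\|^2.$$
Here, $C_t$ tracks how well the local models approximate the current active device average. Upon convergence, $C_t$ and $\epsilon_t$ become 0. Theorem 4 is a direct consequence of the following lemma.

Lemma 10. For $L$ smooth functions $\{L_k\}_{k=1}^m$, if $\alpha \ge 20L\frac{m}{P}$, Algorithm 1 satisfies
$$\mathbb{E}\left[\ell(\gamma^t)\right] + \kappa C_t \le \mathbb{E}\left[\ell(\gamma^{t-1})\right] + \kappa C_{t-1} - \kappa_0\,\mathbb{E}\|\nabla\ell(\gamma^{t-1})\|^2,$$
where
$$\kappa = \frac{4L^3P(\alpha+L)(2m-P)}{\alpha z}, \qquad \kappa_0 = \frac{1}{2\alpha}\,\frac{\alpha^2P^2 - 4\alpha LP^2 - 32L^2m^2 - 16L^2Pm - 24L^2P^2}{z}, \qquad z = \alpha^2P^2 - 32L^2m^2 + 16L^2Pm - 20L^2P^2.$$
Lemma 10 can be telescoped as
$$\kappa_0\,\mathbb{E}\|\nabla\ell(\gamma^{t-1})\|^2 \le \left(\mathbb{E}\left[\ell(\gamma^{t-1})\right] - \ell^* + \kappa C_{t-1}\right) - \left(\mathbb{E}\left[\ell(\gamma^t)\right] - \ell^* + \kappa C_t\right),$$
$$\kappa_0\sum_{t=1}^T\mathbb{E}\|\nabla\ell(\gamma^{t-1})\|^2 \le \left(\mathbb{E}\left[\ell(\gamma^0)\right] - \ell^* + \kappa C_0\right) - \left(\mathbb{E}\left[\ell(\gamma^T)\right] - \ell^* + \kappa C_T\right),$$
where $\ell^*$ denotes the minimum value of $\ell$. If $\alpha \ge 20L\frac{m}{P}$, then $\kappa_0$ and $\kappa$ are positive, and by definition $C_t \ge 0$. Eliminating the negative terms on the RHS and averaging over time gives
$$\mathbb{E}\left[\frac{1}{T}\sum_{t=1}^T\|\nabla\ell(\gamma^{t-1})\|^2\right] \le \frac{1}{T}\frac{1}{\kappa_0}\left(\ell(\theta^0) - \ell^* + \kappa\frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\theta_k^0-\theta^0\|^2\right),$$
which proves the statement in Theorem 4.

The proof of Lemma 10 builds on Inq. 4, with which we upper bound $\ell(\gamma^t)$ by $\ell(\gamma^{t-1})$. Inq. 4 yields $(\gamma^t-\gamma^{t-1})$ and $\nabla\ell(\gamma^{t-1})$ terms on the RHS. We state a set of lemmas to handle these terms. Lemmas 2 and 3 still hold, since $\epsilon_t$ is the same as in the convex case. To bound the excess $\epsilon_t$ term, we introduce two more lemmas.

Lemma 11. For $L$ smooth functions $\{L_k\}_{k=1}^m$, Algorithm 1 satisfies
$$\left(1-4L^2\frac{1}{\alpha^2}\right)\epsilon_t \le 8L^2\frac{1}{\alpha^2}C_{t-1} + 4\frac{1}{\alpha^2}\mathbb{E}\|\nabla\ell(\gamma^{t-1})\|^2.$$
Lemma 12. Algorithm 1 satisfies
$$C_t \le \frac{2(m-P)}{2m-P}C_{t-1} + \frac{2P}{2m-P}\epsilon_t + \frac{2m}{P}\mathbb{E}\|\gamma^t-\gamma^{t-1}\|^2.$$
Using Inq. 4, we get
$$\begin{aligned}
\mathbb{E}\left[\ell(\gamma^t)\right] - \mathbb{E}\left[\ell(\gamma^{t-1})\right] - \frac{L}{2}\mathbb{E}\|\gamma^t-\gamma^{t-1}\|^2 &\le \mathbb{E}\langle\nabla\ell(\gamma^{t-1}), \gamma^t-\gamma^{t-1}\rangle \\
&= \frac{1}{\alpha}\mathbb{E}\left\langle\nabla\ell(\gamma^{t-1}), \frac{1}{m}\sum_{k\in[m]}-\nabla L_k(\tilde\theta_k^t)\right\rangle \\
&\le \frac{1}{2\alpha}\mathbb{E}\left\|\frac{1}{m}\sum_{k\in[m]}\left(\nabla L_k(\tilde\theta_k^t)-\nabla L_k(\gamma^{t-1})\right)\right\|^2 - \frac{1}{2\alpha}\mathbb{E}\|\nabla\ell(\gamma^{t-1})\|^2 \\
&\le \frac{1}{2\alpha}\frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\nabla L_k(\tilde\theta_k^t)-\nabla L_k(\gamma^{t-1})\|^2 - \frac{1}{2\alpha}\mathbb{E}\|\nabla\ell(\gamma^{t-1})\|^2 \\
&\le \frac{L^2}{2\alpha}\epsilon_t - \frac{1}{2\alpha}\mathbb{E}\|\nabla\ell(\gamma^{t-1})\|^2,
\end{aligned}\tag{14}$$
where the equality uses Lemma 2. The inequalities that follow are due to $\langle a,b\rangle \le \frac{1}{2}\|b+a\|^2 - \frac{1}{2}\|a\|^2$, Jensen's inequality and smoothness. Let $z = \alpha^2P^2 - 32L^2m^2 + 16L^2Pm - 20L^2P^2$, and scale Lemmas 12, 3 and 11 by
$$z_0 = \frac{4L^3P(\alpha+L)(2m-P)}{\alpha z}, \qquad z_1 = \frac{L}{2} + z_0\frac{2m}{P}, \qquad z_2 = \frac{LP^2\alpha(L+\alpha)}{2z},$$
respectively. These coefficients are positive due to the condition on $\alpha$. Summing Inq. 14 and the scaled versions of Lemmas 3, 11 and 12 gives the statement in Lemma 10.

Lastly, the convergence analysis is given with respect to the L2 norm of the gradients. The L2 norm arises in the analysis because our definition of smoothness leads to the L2 norm in Inq. 4. The analysis can be extended to other norms. To do so, smoothness needs to be defined with respect to primal and dual norms, as in Eq. 3 of Nesterov et al. (2020). We give the omitted proofs here.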

Proof of Lemma 11

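$$\begin{aligned}
\epsilon_t &= \frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\tilde\theta_k^t-\gamma^{t-1}\|^2 = \frac{1}{m}\sum_{k\in[m]}\mathbb{E}\left\|\tilde\theta_k^t-\theta^{t-1}-\frac{1}{\alpha}h^{t-1}\right\|^2 = \frac{1}{\alpha^2}\frac{1}{m}\sum_{k\in[m]}\mathbb{E}\left\|\nabla L_k(\theta_k^{t-1})-\nabla L_k(\tilde\theta_k^t)-h^{t-1}\right\|^2 \\
&= \frac{1}{\alpha^2}\frac{1}{m}\sum_{k\in[m]}\mathbb{E}\left\|\nabla L_k(\theta_k^{t-1})-\nabla L_k(\gamma^{t-1})+\nabla L_k(\gamma^{t-1})-\nabla L_k(\tilde\theta_k^t)-\nabla\ell(\gamma^{t-1})+\nabla\ell(\gamma^{t-1})-h^{t-1}\right\|^2 \\
&\le \frac{4}{\alpha^2}\frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\nabla L_k(\theta_k^{t-1})-\nabla L_k(\gamma^{t-1})\|^2 + \frac{4}{\alpha^2}\frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\nabla L_k(\gamma^{t-1})-\nabla L_k(\tilde\theta_k^t)\|^2 \\
&\quad + \frac{4}{\alpha^2}\mathbb{E}\|\nabla\ell(\gamma^{t-1})\|^2 + \frac{4}{\alpha^2}\mathbb{E}\|\nabla\ell(\gamma^{t-1})-h^{t-1}\|^2 \\
&\le \frac{4}{\alpha^2}\frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\nabla L_k(\theta_k^{t-1})-\nabla L_k(\gamma^{t-1})\|^2 + \frac{4}{\alpha^2}\frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\nabla L_k(\gamma^{t-1})-\nabla L_k(\tilde\theta_k^t)\|^2 \\
&\quad + \frac{4}{\alpha^2}\mathbb{E}\|\nabla\ell(\gamma^{t-1})\|^2 + \frac{4}{\alpha^2}\frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\nabla L_k(\gamma^{t-1})-\nabla L_k(\theta_k^{t-1})\|^2 \\
&\le \frac{8L^2}{\alpha^2}C_{t-1} + \frac{4L^2}{\alpha^2}\epsilon_t + \frac{4}{\alpha^2}\mathbb{E}\|\nabla\ell(\gamma^{t-1})\|^2.
\end{aligned}$$
The first, second and third equalities come from the definition of $\epsilon_t$, Eq. 9 and Eq. 8, respectively. The inequalities that follow are due to Lemma 6, Jensen's inequality and smoothness. Rearranging terms gives the lemma.

Proof of Lemma 12
$$\begin{aligned}
C_t &= \frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\theta_k^t-\gamma^t\|^2 = \frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\theta_k^t-\gamma^{t-1}+\gamma^{t-1}-\gamma^t\|^2 \\
&\le \left(1+\frac{P}{2m-P}\right)\frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\theta_k^t-\gamma^{t-1}\|^2 + \left(1+\frac{2m-P}{P}\right)\frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\gamma^t-\gamma^{t-1}\|^2 \\
&= \frac{P}{m}\left(1+\frac{P}{2m-P}\right)\frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\tilde\theta_k^t-\gamma^{t-1}\|^2 + \left(1-\frac{P}{m}\right)\left(1+\frac{P}{2m-P}\right)\frac{1}{m}\sum_{k\in[m]}\mathbb{E}\|\theta_k^{t-1}-\gamma^{t-1}\|^2 + \frac{2m}{P}\mathbb{E}\|\gamma^t-\gamma^{t-1}\|^2 \\
&= \frac{2P}{2m-P}\epsilon_t + \frac{2(m-P)}{2m-P}C_{t-1} + \frac{2m}{P}\mathbb{E}\|\gamma^t-\gamma^{t-1}\|^2.
\end{aligned}$$
We start with the definition of $C_t$. The first inequality is due to $\|a+b\|^2 \le (1+\eta)\|a\|^2 + \left(1+\frac{1}{\eta}\right)\|b\|^2$ for $\eta > 0$, here with $\eta = \frac{P}{2m-P}$. The following equality takes the expectation conditioned on the randomness before time $t$. Since each device is selected with probability $\frac{P}{m}$, $\theta_k^t$ is a random variable that equals $\tilde\theta_k^t$ with probability $\frac{P}{m}$ and $\theta_k^{t-1}$ otherwise. The final equality is due to the definitions of $\epsilon_t$ and $C_t$.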
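The scalings used for Lemma 10 can be checked the same way, with exact arithmetic. Lemma 12 bounds $C_t$ by $\frac{2(m-P)}{2m-P}C_{t-1} + \frac{2P}{2m-P}\epsilon_t + \frac{2m}{P}\mathbb{E}\|\gamma^t-\gamma^{t-1}\|^2$, as its proof yields. With $z_0$, $z_1$ and $z_2$ as defined for Lemma 10, the $C_{t-1}$ and $\epsilon_t$ coefficients from Inq. 14 and the scaled Lemmas 3, 11 and 12 must balance, leaving $\kappa = z_0$ in front of $C_{t-1}$. The $\mathbb{E}\|\gamma^t-\gamma^{t-1}\|^2$ terms cancel by the definition of $z_1$. The sketch also evaluates $\kappa_0$ as stated in Lemma 10 so that its sign can be checked at $\alpha = 20L\frac{m}{P}$. Names and test values are illustrative.

```python
from fractions import Fraction

def lemma10_check(alpha, L, m, P):
    """Return (kappa, kappa0) of Lemma 10 after checking the scaling bookkeeping exactly."""
    alpha, L, m, P = (Fraction(v) for v in (alpha, L, m, P))
    z = alpha**2 * P**2 - 32 * L**2 * m**2 + 16 * L**2 * P * m - 20 * L**2 * P**2
    z0 = 4 * L**3 * P * (alpha + L) * (2 * m - P) / (alpha * z)  # scale for Lemma 12 (= kappa)
    z1 = L / 2 + z0 * 2 * m / P                                  # scale for Lemma 3
    z2 = L * P**2 * alpha * (L + alpha) / (2 * z)                # scale for Lemma 11
    # C_{t-1}: Lemma 12 keeps 2(m-P)/(2m-P) of z0, Lemma 11 adds 8 L^2 / alpha^2 per unit of z2.
    assert z0 * 2 * (m - P) / (2 * m - P) + z2 * 8 * L**2 / alpha**2 == z0
    # eps_t: the LHS of scaled Lemma 11 absorbs eps_t from Inq. 14, scaled Lemma 12 and scaled Lemma 3.
    assert z2 * (1 - 4 * L**2 / alpha**2) == L**2 / (2 * alpha) + z0 * 2 * P / (2 * m - P) + z1
    kappa0 = (alpha**2 * P**2 - 4 * alpha * L * P**2 - 32 * L**2 * m**2
              - 16 * L**2 * P * m - 24 * L**2 * P**2) / (2 * alpha * z)  # as stated in Lemma 10
    return z0, kappa0
```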



To see this, consider the situation where the losses are differentiable. A stationary point of the global empirical loss only requires the sum of the gradients of the device empirical losses to be zero, not each individual device gradient. Indeed, in statistically heterogeneous situations, such as when different devices are dominated by different classes, the stationary points of the local empirical losses do not coincide. As pointed out in the related work, prior works based on SGD implicitly account for this inconsistency through inexact minimization and additional hyperparameter tuning.
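A two-device toy example makes the dilemma concrete. It is our own illustration, not taken from the experiments. Take scalar losses with different curvatures:

$$L_1(\theta) = \tfrac{1}{2}(\theta-1)^2, \qquad L_2(\theta) = (\theta+1)^2, \qquad \ell(\theta) = \tfrac{1}{2}\left(L_1(\theta)+L_2(\theta)\right),$$
$$\nabla\ell(\theta) = \tfrac{1}{2}\left[(\theta-1) + 2(\theta+1)\right] = \tfrac{1}{2}(3\theta+1) = 0 \;\Rightarrow\; \theta^* = -\tfrac{1}{3},$$
$$\nabla L_1(\theta^*) = -\tfrac{4}{3}, \qquad \nabla L_2(\theta^*) = \tfrac{4}{3}, \qquad \nabla L_1(\theta^*) + \nabla L_2(\theta^*) = 0.$$

At the global stationary point, only the sum of the device gradients vanishes. The device minimizers $\theta = 1$ and $\theta = -1$ both differ from $\theta^*$. Even averaging them gives $0 \neq -\frac{1}{3}$.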



MNIST. For the 100 devices, balanced data, full participation setup, we search hyperparameters for all algorithms in all IID and Dirichlet settings with a fixed budget of 100 communication rounds. The search space consists of learning rates in [.1, .01], epochs in [10, 20, 50], $K$ in [120, 240, 600], $\mu$ in [1, .01, .0001] and $\alpha$ in [.001, .01, .03, .1]. A weight decay of $10^{-4}$ is applied to prevent overfitting, and no learning rate decay across communication rounds is used. For all IID and Dirichlet settings, the selected configurations are:
- FedAvg: .1 learning rate and 20 epochs.
- FedProx: .1 learning rate and .0001 $\mu$.
- FedDyn: .1 learning rate, 50 epochs and .01 $\alpha$.
- SCAFFOLD: .1 learning rate and $K = 600$.

These configurations are then fixed, and their performance is measured over 500 communication rounds.

Figure 1: CIFAR-10 - α sensitivity analysis of FedDyn.

Figure 2: CIFAR-10 - Comparison of FedSplit and FedDyn in the full and 10% participation settings.

Figure 3: MNIST - Histograms of device counts for which 40% (3a), 60% (3b), and 80% (3c) of the data points belong to k classes.

Figure 4: CIFAR-10 - Convergence curves for 100 and 1000 devices in the IID and Dirichlet (.3) settings with a 10% participation level and balanced data.

Figure 5: CIFAR-100 - Convergence curves for 100 and 500 devices in the IID and Dirichlet (.3) settings with a 10% participation level and balanced data.

Figure 6: CIFAR-10 - Convergence curves for participation fractions of 100%, 10%, and 1% in the IID, Dirichlet (.6), and Dirichlet (.3) settings with 100 devices and balanced data.

Figure 9: EMNIST-L - Convergence curves for participation fractions of 100%, 10%, and 1% in the IID, Dirichlet (.6), and Dirichlet (.3) settings with 100 devices and balanced data.

Figure 10: Shakespeare - Convergence curves for participation fractions of 100%, 10%, and 1% in the IID and non-IID settings with 100 devices and balanced data.

Figure 11: Convergence curves for ResNet18 with 1000 devices and balanced data.

Figure 12: CIFAR-10 - Convergence curves for balanced and unbalanced data distributions with a 10% participation level and 100 devices in the IID and Dirichlet (.3) settings.
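Figure 3 summarizes label skew per device. One way to compute such a histogram is sketched below, under our reading that a device "needs k classes" when its k most frequent classes hold at least the stated fraction of its data points. The Dirichlet(0.3) label split, the 100 devices with 600 points each, and the function names are hypothetical stand-ins, not the paper's MNIST partition.

```python
import numpy as np


def classes_to_cover(labels, n_classes, frac):
    """Smallest k such that the k most frequent classes on a device
    account for at least `frac` of its data points."""
    counts = np.sort(np.bincount(labels, minlength=n_classes))[::-1]
    cum = np.cumsum(counts) / counts.sum()
    return int(np.searchsorted(cum, frac)) + 1


def coverage_histogram(device_labels, n_classes, frac):
    """Entry j-1 is the number of devices that need j classes to cover `frac`."""
    ks = [classes_to_cover(y, n_classes, frac) for y in device_labels]
    return np.bincount(ks, minlength=n_classes + 1)[1:]


# Hypothetical label split: each device draws class proportions from Dir(0.3).
rng = np.random.default_rng(0)
n_devices, n_classes, per_device = 100, 10, 600
device_labels = [
    rng.choice(n_classes, size=per_device, p=rng.dirichlet(0.3 * np.ones(n_classes)))
    for _ in range(n_devices)
]
for frac in (0.4, 0.6, 0.8):
    print(frac, coverage_histogram(device_labels, n_classes, frac))
```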

NONCONVEX ANALYSIS

We state convergence for nonconvex and $L$-smooth $\{L_k\}_{k=1}^m$ as follows.

Theorem 4. For nonconvex and $L$-smooth $\{L_k\}_{k=1}^m$ and $\alpha \ge 20L\frac{m}{P}$, Algorithm 1 satisfies

Number of parameters transmitted, relative to one round of FedAvg, to reach the target test accuracy for moderate and large numbers of devices in the IID and Dirichlet (.3) settings. SCAFFOLD communicates the current model and its associated gradient each round, while the other methods communicate only the current model. As such, the number of rounds for SCAFFOLD is one half of the reported value.

Number of parameters transmitted, relative to one round of FedAvg, to reach the target test accuracy for the 100% and 10% participation regimes in the IID and non-IID settings. SCAFFOLD communicates the current model and its associated gradient each round, while the other methods communicate only the current model. As such, the number of rounds for SCAFFOLD is one half of the reported value.

Number of parameters transmitted, relative to one round of FedAvg, to reach the target test accuracy for balanced and unbalanced data in the IID and Dirichlet (.3) settings with 10% participation. SCAFFOLD communicates the current model and its associated gradient each round, while the other methods communicate only the current model. As such, the number of rounds for SCAFFOLD is one half of the reported value.

Number of parameters transmitted, relative to one round of FedAvg, to reach the target test accuracy for the 1% participation regime in the IID and non-IID settings. SCAFFOLD communicates the current model and its associated gradient each round, while the other methods communicate only the current model. As such, the number of rounds for SCAFFOLD is one half of the reported value.

Number of parameters transmitted, relative to one round of FedAvg, to reach the target test accuracy for the convex synthetic problem under different types of heterogeneity. SCAFFOLD communicates the current model and its associated gradient each round, while the other methods communicate only the current model. As such, the number of rounds for SCAFFOLD is one half of the reported value.
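The accounting behind these captions is simple arithmetic: table entries count payloads in units of one FedAvg model per round, and SCAFFOLD's per-round payload is twice as large. A minimal sketch, with hypothetical function names and under the assumption that the gradient SCAFFOLD sends has the same size as the model:

```python
def fedavg_round_units(rounds, method):
    """Parameters transmitted, in units of one FedAvg round (one model)."""
    # Assumption from the captions: SCAFFOLD sends the model plus an equally
    # sized gradient each round; the other methods send only the model.
    payloads_per_round = 2 if method == "SCAFFOLD" else 1
    return rounds * payloads_per_round


def rounds_from_units(units, method):
    """Convert a reported table entry back into communication rounds."""
    payloads_per_round = 2 if method == "SCAFFOLD" else 1
    return units / payloads_per_round


print(fedavg_round_units(100, "SCAFFOLD"), fedavg_round_units(100, "FedDyn"))
print(rounds_from_units(200, "SCAFFOLD"))
```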

Lemma 12. For $L$-smooth $\{L_k\}_{k=1}^m$, Algorithm 1 satisfies

ACKNOWLEDGEMENTS

This research was supported by a gift from ARM Corporation (DA), by CCF-2007350 (VS), CCF-2022446 (VS), and CCF-1955981 (VS), and by the Data Science Faculty Fellowship from the Rafik B. Hariri Institute.

