LOCAL COEFFICIENT OPTIMIZATION IN FEDERATED LEARNING

Abstract

Federated learning has emerged as a promising approach to building large-scale cooperative learning systems among multiple clients without sharing their raw data. However, given a specific global objective, finding the optimal sampling weights for each client remains largely unexplored. This is particularly challenging when clients' data distributions are non-i.i.d. and clients participate only partially. In this paper, we model the above task as a bi-level optimization problem that takes the correlations among different clients into account. We present a double-loop primal-dual-based algorithm to solve the bi-level optimization problem. We further provide a rigorous convergence analysis for our algorithm under mild assumptions. Finally, we perform extensive empirical studies on both toy examples and learning models trained on real datasets to verify the effectiveness of the proposed method.

1. INTRODUCTION

Federated learning has been highly successful at building large-scale cooperative learning systems without sharing raw data. However, because such systems involve a large number of devices, it is hard to check the data quality (e.g., the noise level) of each individual device, and training on low-quality data degrades the model. A natural way to limit the influence of such 'bad' devices is to reduce their weights. In most popular federated training algorithms (e.g., FedAvg (Li et al., 2019)), all devices are weighted equally or in proportion to the number of data points they hold. Borrowing the formulation of federated algorithms, we introduce a new variable $x$ that controls the weight of each device, i.e., the coefficient of each local objective, and we place a validation set on the server to check whether the coefficients improve the model. We formulate the whole problem as the following bi-level optimization problem:

$$\min_{x} \; f_0(w^*(x)) \quad \text{s.t.} \quad w^*(x) \in \arg\min_{w} \sum_{i=1}^{N} x^{(i)} f_i(w), \quad x \in \mathcal{X} = \{x \mid x \ge 0, \ \|x\|_1 = 1\}. \tag{1}$$

To solve problem (1), Kolstad & Lasdon (1990) propose an algorithm that calculates the gradient with respect to $x$ directly, i.e.,

$$\frac{\partial f_0(w^*(x))}{\partial x^{(i)}} = -\nabla_w f_0(w^*(x))^\top \Big( \sum_{j=1}^{N} x^{(j)} \nabla_w^2 f_j(w^*(x)) \Big)^{-1} \nabla_w f_i(w^*(x)).$$

However, because the parameter dimension of $w$ is large, it is impractical to invert the Hessian or to solve the linear system associated with it. Moreover, because each local device holds a large amount of data, it is hard to compute the exact gradient or Hessian of the local function $f_i$; only stochastic gradients and stochastic Hessians can be accessed. Thus, Ghadimi & Wang (2018) propose the BSA algorithm, where the inverse of the Hessian is approximated by a power series of the Hessian (using $\sum_{k=0}^{K} (I - \eta H)^k$ to approximate $\frac{1}{\eta} H^{-1}$ for a suitable $\eta$). Khanduri et al. (2021) propose the SUSTAIN algorithm for solving stochastic bi-level optimization problems with smaller sample complexity.
Similar to Ghadimi & Wang (2018), they need an extra loop of Hessian-vector products to approximate the product of the Hessian inverse with a vector. However, it is known that for constrained optimization, following an inexact descent direction (one that merely has a positive inner product with the negative gradient) does not guarantee convergence to the optimal point, or even to a first-order stationary point (Bertsekas, 2009). Therefore, an accurate approximation of the Hessian inverse is essential. The power series has to start from $k = 0$ and run for several iterations to become accurate, which increases both computation and communication in federated learning. Fortunately, the KKT conditions show that the information of the Hessian inverse can be embedded into dual variables. Based on the smoothness of the objective, we can give the dual variables a good initialization rather than restarting from the same initialization in each iteration (like $I$ in the series approximation). Thus, we propose a primal-dual-based algorithm to solve problem (1). Further, when solving a constrained optimization problem with non-linear equality constraints, adding the squared norm of the equality constraint as an augmented term may not make the augmented Lagrange function convex. As a result, it is hard for a min-max optimization algorithm to find the stationary point. Instead, under the assumption in Ghadimi & Wang (2018) that the functions $f_i$ are strongly convex, adding the functions $f_i$ themselves as the augmented term introduces convexity without changing the stationary point of the min-max problem. Based on this new augmented Lagrange function, we prove that with stochastic gradient descent and ascent, $w$ and $\lambda$ converge to the KKT point. Meanwhile, by the implicit function theorem, when $w$ and $\lambda$ are close to the stationary point of the min-max problem, the bias in estimating the gradient with respect to $x$ can be reduced to 0.
Thus, with the primal-dual algorithm on $w$ and $\lambda$ and stochastic projected gradient descent on $x$, we show the convergence of our algorithm. Finally, we compare our algorithm with other algorithms on a toy example and on real datasets (MNIST and F-MNIST with the LeNet-5 network). The experimental results show that the proposed algorithm performs well in strongly convex cases and even in some non-convex cases (neural networks). We summarize our contributions as follows:

• In federated learning, we formulate the local coefficient learning problem as a bi-level optimization problem, which gives a way to identify the dataset quality of each local client for a specific task (for which a small validation set is given).
• In bi-level optimization, we introduce a primal-dual framework and show the convergence of the whole algorithm in the constrained and stochastic setting.
• For some specific optimization problems with non-linear constraints, we give a new augmented term. With the new augmented term, the primal and dual variables converge to the KKT point of the original problem.

2.1. PERSONALIZED FEDERATED LEARNING

The work most closely related to ours in federated learning is personalized federated learning. Different from these works, we explicitly formulate a bi-level optimization problem. Adding a validation set makes it possible to identify more clearly how the information from other devices correlates with the information a device holds itself.

2.2. STOCHASTIC BI-LEVEL OPTIMIZATION

Bi-level optimization problems have been studied for a long time. One of the simplest cases is the singleton case, where the lower-level problem has a unique global optimal point. There are two major approaches that avoid computing the inverse of the Hessian of the lower-level problem. Franceschi et al. (2017) approximate $\frac{\partial w^*(x)}{\partial x}$ by $\frac{\partial w_T}{\partial x}$, where $w_T$ is the iterate after $T$ steps of gradient descent on the lower-level problem. With this method, each iteration requires communicating $N$ vectors (where $N$ is the number of local devices) between the server and the local devices, which is not communication-efficient. The other approach, by Ghadimi & Wang (2018), approximates $(\nabla_w^2 g(w))^{-1}$ by $\eta \sum_{i=0}^{K} (I - \eta \nabla^2 g(w))^i$, where $g(w)$ is the objective function of the lower-level problem. Khanduri et al. (2021) point out that, to approximate the gradient of the upper-level function, one need not solve the lower-level problem to optimality at every upper-level update, which seems to remove the double-loop approximation; however, their method still needs a loop to approximate the Hessian inverse with the series. Guo & Yang (2021) use SVRG to reduce the noise in the stochastic gradient and Hessian estimates and obtain better performance. All of the above works assume smoothness of the local Hessian, but none of them exploits this property directly in the algorithm. Different from the above works, we introduce a primal-dual framework into bi-level optimization, where the dual variable records the Hessian information. Shi et al. (2005) and Hansen et al. (1992) also introduce a primal-dual framework, but they stay in the quadratic regime or in mixed-integer programming, and extending their results to federated learning settings is non-trivial.
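To make the series approximation concrete, the following is a minimal NumPy sketch, not code from any of the cited works. It applies $\eta \sum_{k=0}^{K} (I - \eta H)^k$ to a vector using only Hessian-vector products. The matrix `H`, the step `eta`, and the helper name `neumann_inverse_apply` are hypothetical choices made for illustration.

```python
import numpy as np

def neumann_inverse_apply(H, v, eta, K):
    """Approximate H^{-1} v by eta * sum_{k=0}^{K} (I - eta*H)^k v.

    Only products H @ vector are used, so H never has to be inverted.
    """
    term = v.copy()          # (I - eta*H)^0 v
    total = v.copy()
    for _ in range(K):
        term = term - eta * (H @ term)   # multiply the current term by (I - eta*H)
        total += term
    return eta * total

# Illustrative symmetric positive definite "Hessian" with eigenvalues in [1, 2].
rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.standard_normal((5, 5)))
H = Q @ np.diag(np.linspace(1.0, 2.0, 5)) @ Q.T
v = rng.standard_normal(5)

exact = np.linalg.solve(H, v)
eta = 0.5   # the series converges when 0 < eta < 2 / lambda_max(H)
errors = {K: np.linalg.norm(neumann_inverse_apply(H, v, eta, K) - exact)
          for K in (1, 5, 20, 80)}
```

The error shrinks geometrically in `K`, which illustrates why a series restarted from $k = 0$ needs several Hessian-vector products (and, in the federated setting, several communication rounds) before it is accurate.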

3. ALGORITHM DESIGN

Assume that each function $f_i$ is strongly convex. Then the lower-level problem has a unique optimal solution, and the implicit function theorem gives the gradient of $f_0(w^*(x))$ with respect to $x$ as follows.

Proposition 1. Suppose the $f_i$'s are strongly convex functions. Then for each $x \in \mathcal{X}$, it holds that

$$\frac{\partial f_0(w^*(x))}{\partial x^{(i)}} = -\nabla_w f_0(w^*(x))^\top \Big( \sum_{j=1}^{N} x^{(j)} \nabla_w^2 f_j(w^*(x)) \Big)^{-1} \nabla_w f_i(w^*(x)).$$

With Proposition 1, one can calculate the gradient with respect to $x$ when $w^*(x)$ and the inverse of the Hessian are given. However, for large-scale problems, neither can be obtained easily. Fortunately, because each $f_i$ is convex, we can replace the constraint $w^*(x) \in \arg\min_w \sum_{i=1}^{N} x^{(i)} f_i(w)$ with the first-order condition $\sum_{i=1}^{N} x^{(i)} \nabla_w f_i(w) = 0$. For a given $x$, we can formulate the following constrained optimization problem:

$$\min_{w} \; f_0(w) \quad \text{s.t.} \quad \sum_{i=1}^{N} x^{(i)} \nabla_w f_i(w) = 0. \tag{2}$$

By introducing the dual variable $\lambda$, we obtain the Lagrange function. To solve it efficiently, we propose the following augmented Lagrange function:

$$L_x(w, \lambda) = f_0(w) + \lambda^\top \sum_{i=1}^{N} x^{(i)} \nabla_w f_i(w) + \Gamma \sum_{i=1}^{N} x^{(i)} f_i(w). \tag{3}$$

Different from standard augmented terms, where the squared norm of the equality constraints is added to make the primal problem strongly convex, we add the sum of the $f_i$'s weighted by the coefficients $x^{(i)}$. For general strongly convex functions, the squared norm of the gradient constraint is not necessarily strongly convex, so with the standard term we could not directly adopt the gradient descent ascent algorithm. With this definition, we obtain the following two propositions directly.

Proposition 2. Suppose the $f_i$'s are strongly convex functions for $i = 1, 2, \dots, N$, $x^{(i)} \ge 0$ for all $i$, and $\|x\|_1 = 1$. Then problem (2) satisfies the Linear Independence Constraint Qualification, and its KKT conditions can be written as follows:

$$\nabla_w f_0(w) + \sum_{i=1}^{N} x^{(i)} \nabla^2 f_i(w) \lambda = 0, \qquad \sum_{i=1}^{N} x^{(i)} \nabla f_i(w) = 0.$$

Proposition 3. Suppose the $f_i$'s are strongly convex functions for $i = 1, 2, \dots, N$, $x^{(i)} \ge 0$ for all $i$, and $\|x\|_1 = 1$. Then the stationary point of $\min_w \max_\lambda L_x(w, \lambda)$ is unique and satisfies the KKT conditions of problem (2).

Let $(\hat{w}^*(x), \lambda^*(x))$ be the stationary point of $\min_w \max_\lambda L_x(w, \lambda)$. From Proposition 2, it holds that $\hat{w}^*(x) = w^*(x)$ and

$$\frac{\partial f_0(w^*(x))}{\partial x^{(i)}} = \lambda^*(x)^\top \nabla_w f_i(w^*(x)). \tag{4}$$

Indeed, solving the first KKT equation for $\lambda$ gives $\lambda^*(x) = -\big( \sum_{j=1}^{N} x^{(j)} \nabla^2 f_j(w^*(x)) \big)^{-1} \nabla_w f_0(w^*(x))$, and substituting this into Proposition 1 yields (4). Thus, with the KKT point $w^*(x)$ and $\lambda^*(x)$, we can estimate the gradient with respect to $x$ without estimating the inverse of the Hessian. However, $\lambda^\top \sum_{i=1}^{N} x^{(i)} \nabla_w f_i(w)$ can be a highly non-convex function of $w$, which can harm the optimization process. We therefore add a constraint on the norm of $\lambda$ and define the constraint set $\Lambda$. Problem (2) then becomes

$$\min_{w} \max_{\lambda \in \Lambda} \; L_x(w, \lambda) = f_0(w) + \lambda^\top \sum_{i=1}^{N} x^{(i)} \nabla_w f_i(w) + \Gamma \sum_{i=1}^{N} x^{(i)} f_i(w). \tag{5}$$

We propose a double-loop algorithm for solving problem (1), shown in Algorithms 1 and 2. In the inner loop, we solve the augmented Lagrange problem for $K$ steps. In each step, every local client receives the iterates $w_{t,k}$ and $\lambda_{t,k}$. Each local client then calculates two stochastic gradient estimates of $\nabla f_i(w_{t,k})$, written here as $\hat\nabla f_i(w_{t,k})$ and $\tilde\nabla f_i(w_{t,k})$, by back propagation on two independent mini-batches. The term $\hat\nabla^2 f_i(w_{t,k}) \lambda_{t,k}$ is calculated with an automatic differentiation framework (e.g., PyTorch or TensorFlow) or with a closed-form multiplication. Then the local device sends the gradient estimate and the estimated Hessian-vector product $\hat\nabla^2 f_i(w_{t,k}) \lambda_{t,k}$ to the server. On the server side, in each step the server first sends the primal variable $w_{t,k}$ and the dual variable $\lambda_{t,k}$ to all local clients. Then the server receives the estimated gradients and estimated products from some of the local clients. Because not all devices stay online in every step, we define a set $\mathrm{Active}_{t,k}$ that records the clients participating in the optimization at step $(t, k)$.
With the vectors collected from the local clients, the server calculates the gradient estimators of $L_{x_t}(w_{t,k}, \lambda_{t,k})$ with respect to $w_{t,k}$ and $\lambda_{t,k}$. Then $w_{t,k}$ is updated by a gradient descent step and $\lambda_{t,k}$ by a gradient ascent step. Unlike the local devices, after $K$ inner-loop update steps the server uses $\lambda_{t,K}$ and the gradients estimated by each local client to calculate the gradient with respect to $x$ based on equation (4), and performs a projected gradient descent step on $x$. In addition, if the $i$-th client is not in $\mathrm{Active}_{t,K}$, we set the gradient of $x^{(i)}$ to zero.

Algorithm 1 The bi-level primal-dual algorithm on local device $i$
1: for $t = 1, 2, \dots, T$ do
2:   for $k = 1, 2, \dots, K$ do
3:     Receive $w_{t,k}$, $\lambda_{t,k}$ from the server;
4:     Sample a mini-batch and calculate $\hat\nabla f_i(w_{t,k})$;
5:     Sample an independent mini-batch and calculate $\tilde\nabla f_i(w_{t,k})$;
6:     Calculate $\hat\nabla_w^2 f_i(w_{t,k}) \lambda_{t,k}$ with back propagation on the scalar $\hat\nabla_w f_i(w_{t,k})^\top \lambda_{t,k}$;
7:     Send $\hat\nabla^2 f_i(w_{t,k}) \lambda_{t,k}$ and $\tilde\nabla f_i(w_{t,k})$ to the server;
8:   end for
9: end for

Remark 1. $g_x^{(i)}$ can be calculated on the $i$-th device and sent to the server, which reduces the computation on the server at the cost of one extra communication round carrying a single real number between the server and each device. The rest of the analysis remains the same.
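The inner loop and the gradient identity (4) can be checked numerically on a small problem. The sketch below is only an illustration under simplifying assumptions that are not part of the paper: the losses are deterministic quadratics, all clients are active in every step, and the sizes, step sizes, and the radius of $\Lambda$ are hypothetical. It runs the simultaneous descent step on $w$ and ascent step on $\lambda$ from Algorithm 2, then compares $\lambda^\top \nabla f_i(w)$ with the closed form of Proposition 1.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 4, 3                                   # hypothetical: 4 clients, 3 parameters
Gamma, eta_w, eta_lam, K = 1.0, 0.1, 0.1, 2000
radius = 1e3                                  # assumed bound D_lambda defining Lambda

def random_spd(rng, d):
    """Symmetric matrix with eigenvalues in [1, 2] (keeps every f_i strongly convex)."""
    Q, _ = np.linalg.qr(rng.standard_normal((d, d)))
    return Q @ np.diag(rng.uniform(1.0, 2.0, d)) @ Q.T

# Client losses f_i(w) = 0.5 (w - c_i)^T H_i (w - c_i); server loss f_0(w) = 0.5 ||w - c_0||^2.
H = [random_spd(rng, d) for _ in range(N)]
c = [rng.standard_normal(d) for _ in range(N)]
c0 = rng.standard_normal(d)
x = np.full(N, 1.0 / N)                       # fixed coefficients on the simplex

def grad_f(i, w):
    return H[i] @ (w - c[i])

def grad_f0(w):
    return w - c0

def project_ball(lam, radius):
    """Projection onto Lambda = {lam : ||lam|| <= radius}."""
    norm = np.linalg.norm(lam)
    return lam if norm <= radius else lam * (radius / norm)

w, lam = np.zeros(d), np.zeros(d)
for _ in range(K):
    # Both gradients are evaluated at the current (w, lam), as in Algorithm 2.
    g_w = grad_f0(w) + sum(x[i] * (H[i] @ lam + Gamma * grad_f(i, w)) for i in range(N))
    g_lam = sum(x[i] * grad_f(i, w) for i in range(N))
    w = w - eta_w * g_w                                   # descent step on w
    lam = project_ball(lam + eta_lam * g_lam, radius)     # projected ascent step on lam

# Gradient with respect to x from equation (4): lambda^T grad f_i(w).
hypergrad_pd = np.array([lam @ grad_f(i, w) for i in range(N)])

# Reference value from Proposition 1, using an explicit linear solve.
H_bar = sum(x[i] * H[i] for i in range(N))
w_star = np.linalg.solve(H_bar, sum(x[i] * (H[i] @ c[i]) for i in range(N)))
hypergrad_ref = np.array([-grad_f0(w_star) @ np.linalg.solve(H_bar, grad_f(i, w_star))
                          for i in range(N)])
```

In this noise-free setting the two vectors agree to numerical precision, without any Hessian inverse or series approximation inside the loop. In the stochastic, partially participating setting of the paper the agreement only holds approximately and in expectation.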

4. THEORETICAL ANALYSIS

In this section, we analyze the convergence of the proposed algorithm. First, we state the assumptions used in the analysis.

(A1) $f_0, f_1, \dots, f_N$ are lower bounded by $\underline{f}$ and have $L_1$-Lipschitz gradients.
(A2) $f_1, \dots, f_N$ are $\mu$-strongly convex functions.
(A3) $f_1, \dots, f_N$ have $L_2$-Lipschitz Hessians.
(A4) $\max_{i \in \{0, 1, \dots, N\}} \max_{x \in \mathcal{X}} \|\nabla f_i(w^*(x))\| \le D_w$.
(A5) Each local estimation is unbiased with variance bounded by $\sigma^2$.
(A6) $\mathrm{Active}_{t,k}$ is independent and sampled from the set of nonempty subsets of $\{1, 2, \dots, N\}$, where $P(i \in \mathrm{Active}_{t,k}) = p$ for all $i \in \{1, 2, \dots, N\}$.

Algorithm 2 The bi-level primal-dual algorithm on the server
1: Input: initial $x_1$, $w_{1,1}$, $\lambda_{1,1}$; total iterations $K$, $T$; step sizes $\eta_w$, $\eta_\lambda$, $\eta_x$.
2: for $t = 1, 2, \dots, T$ do
3:   for $k = 1, 2, \dots, K$ do
4:     Send $w_{t,k}$, $\lambda_{t,k}$ to each local device;
5:     Receive $\tilde\nabla f_i(w_{t,k})$ and $\hat\nabla_w^2 f_i(w_{t,k}) \lambda_{t,k}$ from $\mathrm{Active}_{t,k}$;
6:     $g_w = \hat\nabla f_0(w_{t,k}) + \frac{N}{|\mathrm{Active}_{t,k}|} \sum_{i \in \mathrm{Active}_{t,k}} x_t^{(i)} \big( \hat\nabla_w^2 f_i(w_{t,k}) \lambda_{t,k} + \Gamma \hat\nabla f_i(w_{t,k}) \big)$;
7:     $w_{t,k+1} = w_{t,k} - \eta_w g_w$;
8:     $g_\lambda = \frac{N}{|\mathrm{Active}_{t,k}|} \sum_{i \in \mathrm{Active}_{t,k}} x_t^{(i)} \tilde\nabla_w f_i(w_{t,k})$;
9:     $\lambda_{t,k+1} = \Pi_\Lambda(\lambda_{t,k} + \eta_\lambda g_\lambda)$;
10:   end for
11:   $g_x^{(i)} = \frac{N}{|\mathrm{Active}_{t,K}|} \lambda_{t,K}^\top \tilde\nabla_w f_i(w_{t,K})$ for $i \in \mathrm{Active}_{t,K}$;
12:   $g_x^{(i)} = 0$ for $i \notin \mathrm{Active}_{t,K}$;
13:   $x_{t+1} = P_{\mathcal{X}}(x_t - \eta_x g_x)$;
14:   $\lambda_{t+1,1} = \lambda_{t,K+1}$;
15:   $w_{t+1,1} = w_{t,K+1}$
16: end for
17: Output: $x_T$, $w_{T,K+1}$.

Remark 2. (A1), (A2), and (A3) are commonly used in the convergence analysis of bi-level optimization problems (Ji et al., 2021; Chen et al., 2021; Khanduri et al., 2021). Unlike Ji et al. (2021) and Chen et al. (2021), who need to assume that $f_0, f_1, \dots, f_N$ are $L_0$-Lipschitz, we only assume that the gradient norms are bounded at the optimal solution. For machine learning models, regularization is usually added to the objective function, which keeps the norm of the optimal solution from being large; when $w^*(x)$ is bounded by some constant, (A4) is reasonable in practice. Moreover, the Lipschitz assumption on the functions directly implies (A4) with $D_w = L_0$. (A5) is a common assumption for stochastic gradient methods (Ghadimi et al., 2016), and (A6) extends the assumption in Karimireddy et al. (2020) by specifying the probability that a local device is chosen instead of sampling uniformly.

Remark 3. With (A4), $D_\lambda = \max_{x \in \mathcal{X}} \|\lambda^*(x)\|$ is upper bounded by $D_w / \mu$.

Proposition 4. When $\Lambda = \{\lambda \mid \|\lambda\| \le D_\lambda\}$, the stationary point of problem (5) is the KKT point of problem (2).

With Propositions 3 and 4, the stationary point of problem (5) is unique, and we denote it by $(w^*(x), \lambda^*(x))$. To establish the convergence of the whole algorithm, we first give a convergence guarantee for the inner loop.

Theorem 1. For a given $x \in \mathcal{X}$, suppose (A1) to (A6) hold, $\Gamma > \frac{D_\lambda L_2 + L_1}{\mu}$, and $\eta_w, \eta_\lambda = \Theta(1/\sqrt{K})$. When $k \in \{1, 2, \dots, K\}$ is chosen uniformly at random, it holds that

$$\mathbb{E} \Big( \lambda_k^\top \nabla_w f_i(w_k) - \frac{\partial f_0(w^*(x))}{\partial x^{(i)}} \Big)^2 = O(1/\sqrt{K}).$$

Thus, by Theorem 1, the gradient with respect to $x$ can be estimated "well" by the inner gradient descent ascent method when the number of inner-loop steps is large enough. We then obtain the following convergence result for the outer loop.

Theorem 2. Suppose (A1) to (A6) hold, $\Gamma > \frac{D_\lambda L_2 + L_1}{\mu}$, $\eta_w, \eta_\lambda = \Theta(1/\sqrt{K})$, $\eta_x = \Theta(1/\sqrt{T})$, and $k \in \{1, 2, \dots, K\}$ is chosen uniformly at random to approximate the gradient with respect to $x$. Define $\hat{x} = \arg\min_{y \in \mathcal{X}} \big( f_0(w^*(y)) + \frac{\rho}{2} \|y - x\|^2 \big)$ and $\hat\nabla_\rho f_0(w^*(x)) = \rho (x - \hat{x})$. For large $\rho$, it holds that

$$\frac{1}{T} \sum_{t=1}^{T} \mathbb{E} \|\hat\nabla_\rho f_0(w^*(x_t))\|^2 = O(1/\sqrt{T} + 1/\sqrt{K}).$$

Remark 4. To achieve an $\epsilon$-stationary point ($\mathbb{E} \|\hat\nabla_\rho f_0(w^*(x_t))\|^2 \le \epsilon$), $O(1/\epsilon^4)$ samples are needed in each local client and in the server. Different from previous works on bi-level optimization (e.g., Ghadimi & Wang (2018), Khanduri et al. (2021), and Franceschi et al. (2017)), we prove convergence when the optimization variable $x$ has a convex constraint.

4.1. PROOF SKETCH OF THEOREM 1

To show the convergence of the inner loop, we first construct a potential function for the inner-loop objective. Define $\Phi_x(w, \lambda) = L_x(w, \lambda) - 2 d(\lambda)$, where $d(\lambda) = \min_w L_x(w, \lambda)$ for the given $x$. The intuition behind this potential function is that $L_x(w, \lambda)$ does not necessarily decrease in each iteration, because $\lambda$ takes a gradient ascent step. Meanwhile, the gradient used to update $\lambda$ is an approximation of the gradient of $d(\lambda)$. Thus, by subtracting $2 d(\lambda)$, we obtain a function $\Phi$ that decreases along the iterations. The first step is therefore to lower bound $\Phi$.

Lemma 1 (Lower bound of $\Phi$). Suppose (A1)-(A4) hold. Then $\Phi_x(w, \lambda)$ is bounded below by $\underline{f}$.

The proof of this lemma follows from the definitions of $\Phi_x(w, \lambda)$ and $d(\lambda)$. Then, similar to the analysis of gradient descent, we give a lemma that shows the descent of the potential function under suitable choices of hyperparameters.

Lemma 2 (Potential function descent; the proof can be found in Lemma 11 in the Appendix). Suppose (A1)-(A6) hold and $\Gamma > \frac{D_\lambda L_2 + L_1}{\mu}$. Then it holds that

$$\mathbb{E} [\Phi_x(w_{t,k+1}, \lambda_{t,k+1}) - \Phi_x(w_{t,k}, \lambda_{t,k})] \le -C_1 \mathbb{E} \|\nabla_w L_x(w_{t,k}, \lambda_{t,k})\|^2 - C_2 \mathbb{E} [\|\lambda_t - \lambda_t^*\|^2] + C_3 \sigma^2,$$

where $\lambda_t^+ = \Pi_\Lambda(\lambda_t + \eta_\lambda \nabla d(\lambda_t))$, $C_1 = \Theta(\eta_w - \eta_w^2 - \eta_\lambda^2 - \eta_\lambda)$, $C_2 = \Theta(\eta_\lambda)$, and $C_3 = O(\eta_w^2 + \eta_\lambda^2)$.

Thus, by choosing sufficiently small $\eta_w$ and $\eta_\lambda$, we can make $C_1$ and $C_2$ positive. Together with the lower bound on $\Phi$, this shows the convergence of the inner algorithm. Because the KKT point is unique, choosing $\eta_w$ and $\eta_\lambda$ of order $1/\sqrt{K}$ gives

$$\frac{1}{K} \sum_{k=1}^{K} \mathbb{E} \|w_{t,k} - w^*(x_t)\|^2 = O(1/\sqrt{K}), \qquad \frac{1}{K} \sum_{k=1}^{K} \mathbb{E} \|\lambda_{t,k} - \lambda^*(x_t)\|^2 = O(1/\sqrt{K}).$$

Therefore, Theorem 1 follows from the convergence rates of $w_{t,k}$ and $\lambda_{t,k}$ together with equation (4).

4.2. PROOF SKETCH OF THEOREM 2

To apply the stochastic gradient descent analysis to $x$, it is not enough that $f_0, f_1, \dots, f_N$ are smooth in $w$; we also need to verify the smoothness of $f_0(w^*(x))$ with respect to $x$.

Lemma 3 (Convergence of stochastic gradient descent with biased gradient estimation; the proof can be found in Lemma 14 in the Appendix). Suppose the function $f(x)$ is lower bounded by $\underline{f}$ and has an $L$-Lipschitz gradient, and $g(x)$ is an unbiased estimator of $\nabla f(x)$ whose expected norm is bounded by $G$ on the domain $\mathcal{X}$. Consider the update rule $x_{t+1} = \Pi_{\mathcal{X}}(x_t - \eta_x (g(x_t) + \xi_t))$, where $\eta_x = \Theta(1/\sqrt{T})$, $\mathcal{X}$ is a convex set, and $\mathbb{E} \|\xi_t\|^2 \le \epsilon^2$. Define $\hat{x} = \arg\min_{y \in \mathcal{X}} \big( f(y) + \frac{\rho}{2} \|y - x\|^2 \big)$ and $\hat\nabla_\rho f(x) = \rho (x - \hat{x})$ with $\rho = 2L$. Then it holds that

$$\frac{1}{T} \mathbb{E} \sum_{t=1}^{T} \|\hat\nabla_\rho f(x_t)\|^2 = O(1/\sqrt{T} + \epsilon^2).$$

As Lemma 3 suggests, when $f_0(w^*(x))$ has an $L$-Lipschitz gradient, a bounded estimation error, and a bounded gradient norm, the convergence rate is $O(1/\sqrt{T})$ plus an error term determined by the estimation error. Theorem 1 shows that the estimation error is bounded by $O(1/\sqrt{K})$. Combining these two results proves Theorem 2.
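The way the two error terms combine into the sample count of Remark 4 can be sketched as follows. This is an illustrative calculation with constants (including the dependence on $N$) suppressed, and the balanced choice $T = K$ is an assumption made here for simplicity rather than a choice stated in the text.

```latex
\begin{align*}
\text{Theorem 1 (bias after $K$ inner steps):}\quad
  & \mathbb{E}\|\xi_t\|^2 = O\!\left(1/\sqrt{K}\right), \\
\text{Lemma 3 with this bias:}\quad
  & \frac{1}{T}\sum_{t=1}^{T}\mathbb{E}\big\|\hat\nabla_\rho f_0(w^*(x_t))\big\|^2
    = O\!\left(\frac{1}{\sqrt{T}} + \frac{1}{\sqrt{K}}\right), \\
\text{Target accuracy } \epsilon:\quad
  & T = K = \Theta(\epsilon^{-2})
    \;\Longrightarrow\; \frac{1}{\sqrt{T}} + \frac{1}{\sqrt{K}} = \Theta(\epsilon), \\
\text{Total inner steps:}\quad
  & T \cdot K = \Theta(\epsilon^{-4}).
\end{align*}
```

With a constant mini-batch size per inner step, $\Theta(\epsilon^{-4})$ inner steps correspond to the $O(1/\epsilon^4)$ samples per client and on the server stated in Remark 4.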

5. EXPERIMENTAL RESULTS

In this section, we compare our algorithm with other bi-level optimization algorithms (BSA (Ghadimi & Wang, 2018), SUSTAIN (Khanduri et al., 2021), and RFHO (Franceschi et al., 2017)) in two cases: a toy example and two vision tasks. In the vision tasks, we additionally test agnostic federated learning (AFL) (Mohri et al., 2019). When $k$ local steps are used in each algorithm, BSA, RFHO, and our algorithm transmit $2kd$ real numbers, where $d$ is the dimension of the optimization variable, while SUSTAIN transmits $(k + 1)d$ real numbers. In the vision tasks, all algorithms transmit the same number of real numbers because $k = 1$.

5.1. TOY EXAMPLE

In this section, we apply the algorithms to solve problem (1) with $f_i$ defined as

$$f_i(w) = \frac{1}{2} \|A_i w - B_i\|^2 + \cos(a_i^\top w - b_i),$$

where $A_i \in \mathbb{R}^{30 \times 20}$, $B_i \in \mathbb{R}^{30}$, $a_i \in \mathbb{R}^{20}$, and $b_i \in \mathbb{R}$ are all generated from Gaussian distributions. The variance of each component of $A_i$ and $a_i$ is $1/\sqrt{20}$, the variance of each component of $B_i$ is $1/\sqrt{30}$, and the variance of $b_i$ is 1. When a generated function $f_i$ is not 0.1-strongly convex, we randomly generate a new one until we obtain a strongly convex $f_i$ whose modulus is at least 0.1. Three numbers of local steps ($K = 1, 5, 10$) are tested. Here, the local steps are used for the $w$ update in BSA, RFHO, and our algorithm, and for the Hessian estimation in BSA and SUSTAIN. Because the Hessian matrix and its inverse are easy to compute for this toy example, we also test an algorithm, named GD, that uses the inverse of the estimated Hessian to compute the gradient with respect to $x$. We test two settings of the toy example. One is the deterministic setting, where there is no estimation noise and no client disconnection. In the other setting, we add white Gaussian noise with a noise level of 0.5 to each estimation (including gradient estimation and Hessian estimation), and each client connects to the server with probability 0.5. To evaluate the performance of the different algorithms, we calculate the function value $f_0(w^*(x))$ and the stationarity of $x$, i.e., $x - \Pi_{\mathcal{X}}(x - 0.001 \nabla_x f_0(w^*(x)))$, where $w^*(x)$ is approximated by 200 gradient steps. We take $N = 15$ and run each algorithm 20 times. The results of the deterministic setting are shown in Figure 1, and the results of the noisy setting are shown in Figure 2. As shown in Figure 1, as the number of local steps grows, the performance of BSA, RFHO, and SUSTAIN gets closer to that of GD, while the performance of the primal-dual method is similar to that of GD regardless of the number of local steps, even with a single step.
When noise is added to the Hessian, directly inverting it may yield a biased estimate. Thus, as shown in Figure 2, GD performs much worse than in the deterministic setting. Figure 2 also shows that our algorithm performs better than the other algorithms when the number of local steps is small. When the number of local steps increases to 10, BSA and our algorithm achieve competitive results.
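The projection $\Pi_{\mathcal{X}}$ onto the simplex and the stationarity measure used above can be sketched as follows. This is a minimal illustration, not the authors' code: it uses the standard sort-based simplex projection, and the function names and the example gradients are hypothetical.

```python
import numpy as np

def project_simplex(v):
    """Euclidean projection of v onto {x : x >= 0, sum(x) = 1} (sort-based method)."""
    v = np.asarray(v, dtype=float)
    u = np.sort(v)[::-1]                      # entries in decreasing order
    css = np.cumsum(u) - 1.0
    ks = np.arange(1, len(v) + 1)
    k = ks[u - css / ks > 0][-1]              # number of entries that stay positive
    theta = css[k - 1] / k                    # shift that makes the result sum to 1
    return np.maximum(v - theta, 0.0)

def stationarity(x, grad, step=0.001):
    """Norm of x - Pi_X(x - step * grad), the stationarity measure of the toy example."""
    return np.linalg.norm(x - project_simplex(x - step * grad))

# One projected gradient step on x, as in the outer loop (illustrative numbers).
x = np.array([0.5, 0.3, 0.2])
g_x = np.array([0.4, -0.1, 0.2])
x_next = project_simplex(x - 0.1 * g_x)

# At a vertex, a gradient that pushes further outside the simplex gives zero stationarity.
vertex = np.array([1.0, 0.0, 0.0])
s_stationary = stationarity(vertex, np.array([-1.0, 0.0, 0.0]))
s_not_stationary = stationarity(vertex, np.array([1.0, 0.0, 0.0]))
```

The measure is zero exactly when no feasible direction decreases the objective to first order, which is why it is a natural stationarity measure for the constrained variable $x$.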

5.2. VISION TASKS

In this section, we apply the algorithms to train LeNet-5 (LeCun et al., 1998) on the MNIST (LeCun et al., 1998) and Fashion-MNIST (Xiao et al., 2017) datasets. To construct non-i.i.d. datasets on the local clients and the validation set on the global server, we randomly pick 20 samples per label from the whole training dataset to form the validation set. The rest of the training data is then divided into 3 sets, and each set is assigned to a local client. For both datasets, the first client contains samples labeled 0, 1, 2, 3, 4, the second client contains samples labeled 5, 6, 7, and the third client contains samples labeled 8, 9. To test the algorithm's ability to choose proper coefficients for the local clients, we add 7 noise nodes, each containing 5000 samples with random labels. For all training methods, the learning rate of $w$ is a constant (no decay) selected from {0.1, 0.01, 0.001}, and the learning rate of $x$ is selected from {0.1, 0.01, 0.001, 0.0001}. The batch size for all three training cases is set to 64. $\Gamma$ in the proposed algorithm is set to 1. For simplicity, we set the number of local steps to 1. We run 2000 iterations for MNIST and 6000 iterations for Fashion-MNIST. The active probability is chosen from {0.5, 0.9, 1}. We compare the test accuracy of the different methods. As baselines, we report the test accuracy of training with the validation set only (named val), training with the average loss of all clients (named avg), and training with $x = (0.5, 0.3, 0.2, 0, \dots, 0)$ (named opt). All experiments run on a V100 with PyTorch (Paszke et al., 2019). Results are shown in Figure 3, Figure 4, and Table 1. Figure 3 shows the test accuracy on the MNIST dataset with different active probabilities.
Although SUSTAIN works better than the primal-dual algorithm when all local devices participate in the optimization process, it works worse than our method when the clients' participation rate decreases to 0.5. The primal-dual algorithm may be slower than SUSTAIN under full participation because of the initialization of the dual variable: when the dual variable starts far from its true value, it needs more time to reach a good enough point. Compared with the methods other than SUSTAIN, our algorithm converges faster and more stably to a high-accuracy point. Further, Table 1 lists the output $x$ and the standard error of the test accuracy over 5 different experiments for the different algorithms. According to Table 1, our algorithm achieves a more stable output $x$, and the output $x$ is closer to (0.5, 0.3, 0.2), which corresponds to the number of labels held by the first three clients. Figure 4 gives the test accuracy of training LeNet-5 on the Fashion-MNIST dataset. Similar to the results on MNIST, when the clients' participation rate is high (0.9, 1), SUSTAIN works slightly better than the primal-dual algorithm. But when more local devices disconnect from the server, the performance of SUSTAIN drops, while the primal-dual algorithm retains its fast convergence and high test accuracy.
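The data split described in this section can be sketched with NumPy on a label array alone. This is a hedged illustration rather than the authors' pipeline: the function name `partition_labels` is hypothetical, synthetic labels stand in for the MNIST training labels, and the way noise-node inputs are drawn (sampling 5000 training indices and pairing them with uniformly random labels) is an assumption, since the text only states that the noise nodes hold 5000 samples with random labels.

```python
import numpy as np

def partition_labels(labels, rng, val_per_label=20,
                     client_labels=((0, 1, 2, 3, 4), (5, 6, 7), (8, 9)),
                     n_noise_nodes=7, noise_size=5000, n_classes=10):
    """Return validation indices, per-client indices, and (indices, random labels) per noise node."""
    labels = np.asarray(labels)
    # Validation set: a fixed number of samples per label, drawn without replacement.
    val_idx = np.concatenate([
        rng.choice(np.flatnonzero(labels == c), size=val_per_label, replace=False)
        for c in range(n_classes)
    ])
    # Remaining training samples are split by label group across the clean clients.
    rest = np.setdiff1d(np.arange(len(labels)), val_idx)
    clients = [rest[np.isin(labels[rest], group)] for group in client_labels]
    # Noise nodes: sampled inputs paired with uniformly random labels (assumed scheme).
    noise = []
    for _ in range(n_noise_nodes):
        idx = rng.choice(len(labels), size=noise_size, replace=False)
        noise.append((idx, rng.integers(0, n_classes, size=noise_size)))
    return val_idx, clients, noise

rng = np.random.default_rng(0)
fake_labels = rng.integers(0, 10, size=60000)   # synthetic stand-in for the training labels
val_idx, clients, noise = partition_labels(fake_labels, rng)
```

With this split, the three clean clients hold 5, 3, and 2 of the 10 labels, which is the proportion behind the reference coefficients (0.5, 0.3, 0.2) used for the opt baseline.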

6. CONCLUSION

In this paper, we proposed a primal-dual-based method for solving a bi-level optimization problem arising from a federated learning task (local coefficient learning). We gave a theoretical analysis that shows the convergence of the proposed algorithm. Although the analysis suggests that the algorithm needs more iterations to converge to an $\epsilon$-stationary point, it works well with a small number of local steps in both the toy case and neural network training. Two directions are left as future work: improving the convergence rate (perhaps to the order of $O(1/\sqrt{T})$ instead of $O(1/\sqrt{T} + 1/\sqrt{K})$), and understanding how the initialization of the dual variable affects the convergence speed.

A CONVERGENCE OF INNER LOOP

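For simplicity, in this section we write the inner-loop problem as

$$\min_{w} \; g(w) \quad \text{s.t.} \quad \nabla h(w) = 0.$$

The algorithm for solving the inner-loop problem can then be written as

$$w_{t+1} = w_t - \eta_w \big( \hat\nabla g(w_t) + \hat\nabla^2 h(w_t) \lambda_t + \Gamma \hat\nabla h(w_t) \big) = w_t - \eta_w \hat\nabla_w L(w_t, \lambda_t),$$
$$\lambda_{t+1} = \Pi_\Lambda \big( \lambda_t + \eta_\lambda \tilde\nabla h(w_t) \big).$$

Furthermore, the assumptions are as follows:

(A1) $g$ and $h$ are differentiable strongly convex functions with modulus $\mu$, have $L_1$-Lipschitz gradients, and are lower bounded by $\underline{f}$.
(A2) $h$ has an $L_2$-Lipschitz Hessian.
(A3) $\Lambda = \{\lambda \mid \|\lambda\| \le D_\lambda\}$, where $D_\lambda \ge \|\nabla^2 h(w^*)^{-1} \nabla g(w^*)\|$ and $w^* = \arg\min_w h(w)$.
(A4) Given $w_t$, the estimators $\hat\nabla g(w_t)$, $\hat\nabla^2 h(w_t)$, $\hat\nabla h(w_t)$, and $\tilde\nabla h(w_t)$ are independent of each other. In addition, all of them are unbiased with variance bounded in terms of the mean value and a constant $\sigma$, i.e.,

$$\mathbb{E}[\hat\nabla g(w_t) \mid w_t] = \nabla g(w_t), \qquad \mathbb{E}[\|\hat\nabla g(w_t) - \nabla g(w_t)\|^2 \mid w_t] \le p \|\nabla g(w_t)\|^2 + \sigma^2;$$
$$\mathbb{E}[\hat\nabla^2 h(w_t) \mid w_t] = \nabla^2 h(w_t), \qquad \mathbb{E}[\|\hat\nabla^2 h(w_t) - \nabla^2 h(w_t)\|_F^2 \mid w_t] \le p \|\nabla^2 h(w_t)\|^2 + \sigma^2;$$
$$\mathbb{E}[\hat\nabla h(w_t) \mid w_t] = \nabla h(w_t), \qquad \mathbb{E}[\|\hat\nabla h(w_t) - \nabla h(w_t)\|^2 \mid w_t] \le p \|\nabla h(w_t)\|^2 + \sigma^2,$$

and the same holds for $\tilde\nabla h(w_t)$. Thus, it is easy to show that

$$\mathbb{E}[\hat\nabla_w L(w_t, \lambda_t) \mid w_t, \lambda_t] = \nabla_w L(w_t, \lambda_t)$$

and

$$\mathbb{E}[\|\hat\nabla_w L(w_t, \lambda_t) - \nabla_w L(w_t, \lambda_t)\|^2 \mid w_t, \lambda_t] \le (1 + \Gamma^2 + D_\lambda^2)(p \|\nabla_w L(w_t, \lambda_t)\|^2 + \sigma^2).$$

First, we introduce some notation used in this section.

Definition 1. Let $L(w, \lambda) = g(w) + \lambda^\top \nabla h(w) + \Gamma h(w)$, $d(\lambda) = \min_w L(w, \lambda)$, $w^*(\lambda) = \arg\min_w L(w, \lambda)$, $\lambda_t^+ = \Pi_\Lambda(\lambda_t + \eta_\lambda \nabla h(w_t))$, and $L_L = (1 + \Gamma) L_1 + D_\lambda L_2$. Define the potential function $\Phi(w_t, \lambda_t) = L(w_t, \lambda_t) - 2 d(\lambda_t)$.

Lemma 4 (Descent of Lagrange function). For the function $L$, it holds that

$$\mathbb{E}[L(w_{t+1}, \lambda_{t+1}) - L(w_t, \lambda_t)] \le \mathbb{E}\Big[ C_0(\eta_w) \|\nabla_w L(w_t, \lambda_t)\|^2 + \frac{L_L \eta_w^2 (1 + \Gamma^2 + D_\lambda^2) \sigma^2}{2} + (\lambda_{t+1} - \lambda_t)^\top \nabla h(w_{t+1}) \Big],$$

where $C_0(\eta_w) = \frac{L_L \eta_w^2}{2} + \frac{L_L \eta_w^2 p (1 + \Gamma^2 + D_\lambda^2)}{2} - \eta_w$.

Proof. Because $L(w, \lambda)$ has an $L_L$-Lipschitz gradient in $w$, it holds that

$$L(w_{t+1}, \lambda_t) \le L(w_t, \lambda_t) + \langle \nabla_w L(w_t, \lambda_t), w_{t+1} - w_t \rangle + \frac{L_L}{2} \|w_{t+1} - w_t\|^2.$$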
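Taking expectations conditioned on $w_t$ and $\lambda_t$ on both sides of the above inequality, we obtain

$$\begin{aligned}
\mathbb{E}[L(w_{t+1}, \lambda_t) - L(w_t, \lambda_t) \mid w_t, \lambda_t]
&\le \mathbb{E}\Big[ \langle \nabla_w L(w_t, \lambda_t), w_{t+1} - w_t \rangle + \frac{L_L}{2} \|w_{t+1} - w_t\|^2 \,\Big|\, w_t, \lambda_t \Big] \\
&= \mathbb{E}\Big[ \langle \nabla_w L(w_t, \lambda_t), -\eta_w \hat\nabla_w L(w_t, \lambda_t) \rangle + \frac{L_L \eta_w^2}{2} \|\nabla_w L(w_t, \lambda_t)\|^2 \\
&\qquad + \frac{L_L \eta_w^2}{2} \|\hat\nabla_w L(w_t, \lambda_t) - \nabla_w L(w_t, \lambda_t)\|^2 \,\Big|\, w_t, \lambda_t \Big] \\
&\le \mathbb{E}\Big[ \Big( \frac{L_L \eta_w^2}{2} + \frac{L_L \eta_w^2 p (1 + \Gamma^2 + D_\lambda^2)}{2} - \eta_w \Big) \|\nabla_w L(w_t, \lambda_t)\|^2 \\
&\qquad + \frac{L_L \eta_w^2 (1 + \Gamma^2 + D_\lambda^2) \sigma^2}{2} \,\Big|\, w_t, \lambda_t \Big].
\end{aligned} \tag{6}$$

Meanwhile, it holds that

$$\mathbb{E}[L(w_{t+1}, \lambda_{t+1}) - L(w_{t+1}, \lambda_t) \mid w_t, \lambda_t] = \mathbb{E}[(\lambda_{t+1} - \lambda_t)^\top \nabla h(w_{t+1}) \mid w_t, \lambda_t]. \tag{7}$$

Combining (6) and (7) and taking the expectation of the conditional expectation, we obtain the desired result.

Lemma 5. Suppose $\Gamma \mu > D_\lambda L_2 + L_1$. Then it holds that $\|w^*(\lambda_1) - w^*(\lambda_2)\| \le \beta_1 \|\lambda_1 - \lambda_2\|$ for all $\lambda_1, \lambda_2 \in \Lambda$, where $\beta_1 = \frac{L_1}{\Gamma \mu - D_\lambda L_2 - L_1}$.

Proof. Note that $L(w, \lambda)$ is strongly convex with respect to $w$ for $\lambda \in \Lambda$ with modulus $(1 + \Gamma) \mu - D_\lambda L_2$. Thus, it holds that

$$L(w^*(\lambda_1), \lambda_2) - L(w^*(\lambda_2), \lambda_2) \ge \frac{\Gamma \mu - D_\lambda L_2 - L_1}{2} \|w^*(\lambda_1) - w^*(\lambda_2)\|^2.$$

On the other hand, we have

$$\begin{aligned}
L(w^*(\lambda_1), \lambda_2) - L(w^*(\lambda_2), \lambda_2)
&= L(w^*(\lambda_1), \lambda_2) - L(w^*(\lambda_1), \lambda_1) + L(w^*(\lambda_1), \lambda_1) - L(w^*(\lambda_2), \lambda_1) \\
&\quad + L(w^*(\lambda_2), \lambda_1) - L(w^*(\lambda_2), \lambda_2) \\
&\le (\lambda_1 - \lambda_2)^\top (\nabla h(w^*(\lambda_2)) - \nabla h(w^*(\lambda_1))) - \frac{\Gamma \mu - D_\lambda L_2 - L_1}{2} \|w^*(\lambda_1) - w^*(\lambda_2)\|^2.
\end{aligned}$$

Thus, combining the above two inequalities with the Cauchy-Schwarz inequality, we obtain

$$(\Gamma \mu - D_\lambda L_2 - L_1) \|w^*(\lambda_1) - w^*(\lambda_2)\|^2 \le \|\lambda_1 - \lambda_2\| \|\nabla h(w^*(\lambda_1)) - \nabla h(w^*(\lambda_2))\| \le L_1 \|\lambda_1 - \lambda_2\| \|w^*(\lambda_1) - w^*(\lambda_2)\|.$$

Hence, we get the desired result.

Lemma 6 (Ascent of dual function). It holds that

$$\mathbb{E}[d(\lambda_{t+1}) - d(\lambda_t)] \ge \mathbb{E}\Big[ \langle \lambda_{t+1} - \lambda_t, \nabla h(w^*(\lambda_t)) \rangle - \frac{L_d}{2} \|\lambda_{t+1} - \lambda_t\|^2 \Big].$$

Proof. It can be calculated by the implicit function theorem that $\nabla_\lambda d(\lambda) = \nabla h(w^*(\lambda))$.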
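Thus, we obtain that for all $\lambda_1, \lambda_2 \in \Lambda$,

$$\|\nabla d(\lambda_1) - \nabla d(\lambda_2)\| = \|\nabla h(w^*(\lambda_1)) - \nabla h(w^*(\lambda_2))\| \le L_1 \beta_1 \|\lambda_1 - \lambda_2\|.$$

Therefore, $d(\lambda)$ is a differentiable function with an $L_d = L_1 \beta_1$-Lipschitz gradient. With the definition of $d(\lambda)$, we have

$$\mathbb{E}[d(\lambda_{t+1}) - d(\lambda_t)] \ge \mathbb{E}\Big[ \langle \lambda_{t+1} - \lambda_t, \nabla h(w^*(\lambda_t)) \rangle - \frac{L_d}{2} \|\lambda_{t+1} - \lambda_t\|^2 \Big].$$

Lemma 7. By the strong convexity of $L$ and $h$, it holds that

$$\|w^*(\lambda_t) - w^*\| \le \frac{1}{\mu} \|\nabla h(w^*(\lambda_t))\| \quad \text{and} \quad \|w_t - w^*(\lambda_t)\| \le \frac{1}{\Gamma \mu - D_\lambda L_2 - L_1} \|\nabla_w L(w_t, \lambda_t)\|.$$

Proof. Because $h$ is $\mu$-strongly convex, the inequality $\mu \|w_1 - w_2\| \le \|\nabla h(w_1) - \nabla h(w_2)\|$ holds for all $w_1, w_2$. With $\nabla h(w^*) = 0$, we get the first result. Similarly, because $L$ is $(\Gamma \mu - D_\lambda L_2 - L_1)$-strongly convex in $w$, we can prove the second inequality.

Lemma 8 (Descent of one step of local SGD). It holds that

$$\mathbb{E} \|w_t - w_{t+1}\|^2 \le \mathbb{E}\big[ \eta_w^2 (1 + p) \|\nabla_w L(w_t, \lambda_t)\|^2 + \eta_w^2 \sigma^2 \big].$$

Proof. $\mathbb{E} \|w_t - w_{t+1}\|^2 = \mathbb{E}\big[ \eta_w^2 \|\hat\nabla_w L(w_t, \lambda_t)\|^2 \big] \le \mathbb{E}\big[ \eta_w^2 (1 + p) \|\nabla_w L(w_t, \lambda_t)\|^2 + \eta_w^2 \sigma^2 \big]$.

Lemma 9. $d(\lambda)$ is a $\frac{\mu^2}{L_L}$-strongly concave function. We define $\mu_d = \frac{\mu^2}{L_L}$.

Proof. Let $\gamma_0$ be the largest eigenvalue of $\nabla^2 d(\lambda)$ and $\gamma_1$ be the largest eigenvalue of $\frac{\partial w^*(\lambda)}{\partial \lambda}$. Then it holds that $\gamma_0 \le \gamma_1 \mu$. Meanwhile, $\frac{\partial w^*(\lambda)}{\partial \lambda} = -\nabla_w^2 L(w^*(\lambda), \lambda)^{-1} \nabla_w^2 h(w^*(\lambda))$. Thus, $\gamma_1 \le -\frac{\mu}{L_L}$. Therefore, $d(\lambda)$ is a $\frac{\mu^2}{L_L}$-strongly concave function.

Lemma 10. It holds that

$$\mathbb{E} \|\lambda_{t+1} - \lambda_t\|^2 \ge \frac{\mu_d^2 \eta_\lambda^2}{16} \mathbb{E} \|\lambda_t - \lambda^*\|^2 - \frac{\mu_d \eta_\lambda}{4} \Big( \eta_\lambda^2 \sigma^2 + \Big( 2 \eta_\lambda^2 (1 + p) L_1^2 + \frac{L_1^2 \eta_\lambda}{\mu_d} \Big) \mathbb{E} \|w_t - w^*(\lambda_t)\|^2 \Big).$$

Proof. With the update rule of $\lambda_{t+1}$, it holds that

$$\begin{aligned}
\mathbb{E} \|\lambda_{t+1} - \lambda^*\|^2
&= \mathbb{E} \|\Pi_\Lambda(\lambda_t + \eta_\lambda \tilde\nabla h(w_t)) - \lambda^*\|^2 \le \mathbb{E} \|\lambda_t + \eta_\lambda \tilde\nabla h(w_t) - \lambda^*\|^2 \\
&= \mathbb{E}\big[ \|\lambda_t - \lambda^*\|^2 + 2 \eta_\lambda \langle \lambda_t - \lambda^*, \tilde\nabla h(w_t) \rangle + \eta_\lambda^2 \|\tilde\nabla h(w_t)\|^2 \big] \\
&= \mathbb{E}\big[ \|\lambda_t - \lambda^*\|^2 + 2 \eta_\lambda \langle \lambda_t - \lambda^*, \nabla h(w^*(\lambda_t)) \rangle \\
&\qquad + 2 \eta_\lambda \langle \lambda_t - \lambda^*, \nabla h(w_t) - \nabla h(w^*(\lambda_t)) \rangle + \eta_\lambda^2 \|\tilde\nabla h(w_t)\|^2 \big],
\end{aligned} \tag{8}$$

where the last equality holds because $\tilde\nabla h(w_t)$ is an unbiased estimator.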
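Meanwhile, because $d(\lambda)$ is a strongly concave function, it holds that

$$
\eta_\lambda\langle \lambda_t-\lambda^*,\, \nabla h(w^*(\lambda_t))\rangle \le -\eta_\lambda\mu_d\|\lambda_t-\lambda^*\|^2 . \tag{9}
$$

Further, it holds that

$$
\begin{aligned}
\mathbb{E}\|\tilde{\nabla} h(w_t)\|^2
&= \mathbb{E}\|\nabla h(w_t) + \tilde{\nabla} h(w_t) - \nabla h(w_t)\|^2 \\
&\le \sigma^2 + \mathbb{E}\,(1+p)\|\nabla h(w_t)\|^2 \\
&\le \sigma^2 + \mathbb{E}\left[2(1+p)\|\nabla h(w^*(\lambda_t))\|^2 + 2(1+p)\|\nabla h(w^*(\lambda_t)) - \nabla h(w_t)\|^2\right] \\
&\le \sigma^2 + \mathbb{E}\left[2(1+p)\|\nabla h(w^*(\lambda_t))\|^2 + 2(1+p)L_1^2\|w^*(\lambda_t) - w_t\|^2\right] \\
&\le \sigma^2 + \mathbb{E}\left[2(1+p)L_d^2\|\lambda_t-\lambda^*\|^2 + 2(1+p)L_1^2\|w^*(\lambda_t) - w_t\|^2\right].
\end{aligned}
\tag{10}
$$

For $\langle \lambda_t-\lambda^*,\, \nabla h(w_t) - \nabla h(w^*(\lambda_t))\rangle$, it holds that

$$
\begin{aligned}
\langle \lambda_t-\lambda^*,\, \nabla h(w_t) - \nabla h(w^*(\lambda_t))\rangle
&\le \frac{\mu_d}{2}\|\lambda_t-\lambda^*\|^2 + \frac{1}{2\mu_d}\|\nabla h(w_t) - \nabla h(w^*(\lambda_t))\|^2 \\
&\le \frac{\mu_d}{2}\|\lambda_t-\lambda^*\|^2 + \frac{L_1^2}{2\mu_d}\|w_t - w^*(\lambda_t)\|^2 .
\end{aligned}
\tag{11}
$$

Combining (8), (9), (10) and (11), it holds that

$$
\mathbb{E}\|\lambda_{t+1}-\lambda^*\|^2 \le \left(1 - \mu_d\eta_\lambda + 2(1+p)L_d^2\eta_\lambda^2\right)\mathbb{E}\|\lambda_t-\lambda^*\|^2 + \eta_\lambda^2\sigma^2 + \left(2\eta_\lambda^2(1+p)L_1^2 + \frac{L_1^2\eta_\lambda}{\mu_d}\right)\mathbb{E}\|w_t - w^*(\lambda_t)\|^2 .
$$

When $\eta_\lambda \le \frac{\mu_d}{4(1+p)L_d^2}$, it holds that

$$
\mathbb{E}\|\lambda_{t+1}-\lambda^*\|^2 \le \left(1 - \frac{\mu_d\eta_\lambda}{2}\right)\mathbb{E}\|\lambda^*-\lambda_t\|^2 + \eta_\lambda^2\sigma^2 + \left(2\eta_\lambda^2(1+p)L_1^2 + \frac{L_1^2\eta_\lambda}{\mu_d}\right)\mathbb{E}\|w_t - w^*(\lambda_t)\|^2 . \tag{12}
$$

It holds that

$$
\begin{aligned}
\mathbb{E}\|\lambda_{t+1}-\lambda_t\|^2
&= \mathbb{E}\|\lambda_{t+1}-\lambda^* + \lambda^*-\lambda_t\|^2 \\
&= \mathbb{E}\left[\|\lambda_{t+1}-\lambda^*\|^2 + \|\lambda^*-\lambda_t\|^2 + 2\langle \lambda_{t+1}-\lambda^*,\, \lambda^*-\lambda_t\rangle\right] \\
&\ge \mathbb{E}\left[(1-\xi)\|\lambda_{t+1}-\lambda^*\|^2 + \left(1-\frac{1}{\xi}\right)\|\lambda^*-\lambda_t\|^2\right], \quad \forall \xi > 1 .
\end{aligned}
\tag{13}
$$

Let $\xi = 1 + \mu_d\eta_\lambda/4$; then $(1-\xi)\left(1-\frac{\mu_d\eta_\lambda}{2}\right) + 1 - \frac{1}{\xi} \ge \mu_d^2\eta_\lambda^2/16$. Combining (12) and (13), it holds that

$$
\begin{aligned}
\mathbb{E}\|\lambda_{t+1}-\lambda_t\|^2
&\ge \left((1-\xi)\left(1-\frac{\mu_d\eta_\lambda}{2}\right) + 1 - \frac{1}{\xi}\right)\mathbb{E}\|\lambda^*-\lambda_t\|^2 + (1-\xi)\left(\eta_\lambda^2\sigma^2 + \left(2\eta_\lambda^2(1+p)L_1^2 + \frac{L_1^2\eta_\lambda}{\mu_d}\right)\mathbb{E}\|w_t - w^*(\lambda_t)\|^2\right) \\
&\ge \frac{\mu_d^2\eta_\lambda^2}{16}\,\mathbb{E}\|\lambda_t-\lambda^*\|^2 - \frac{\mu_d\eta_\lambda}{4}\left(\eta_\lambda^2\sigma^2 + \left(\frac{2\eta_\lambda^2(1+p)L_1^2}{\Gamma\mu - D_\lambda L_2 - L_1} + \frac{L_1^2\eta_\lambda}{\mu_d(\Gamma\mu - D_\lambda L_2 - L_1)}\right)\mathbb{E}\|w_t - w^*(\lambda_t)\|^2\right).
\end{aligned}
$$

Thus, we get the desired result.

Define a potential function $\Phi(w_t,\lambda_t) = L(w_t,\lambda_t) - 2d(\lambda_t)$. According to the definition, we have $L(w_t,\lambda_t) > d(\lambda_t)$. Then, $\Phi(w_t,\lambda_t) \ge -d(\lambda_t) \ge -\min_w \{g(w) + \Gamma h(w)\} \ge -(1+\Gamma)f$ for all $w_t$ and $\lambda_t \in \Lambda$.

Lemma 11 (Descent of potential function).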
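It holds that

$$
\mathbb{E}\left[\Phi(w_{t+1},\lambda_{t+1}) - \Phi(w_t,\lambda_t)\right] \le \mathbb{E}\left[C_1\|\nabla_w L(w_t,\lambda_t)\|^2 + C_2\|\lambda_t-\lambda_t^*\|^2 + C_3\right],
$$

where

$$
\begin{aligned}
C_1 &= -\eta_w + \frac{L_L\eta_w^2}{2} + \frac{L_L\eta_w^2\, p\,(1+\Gamma^2+D_\lambda^2)}{2} + 4L_1^2\eta_\lambda\eta_w^2(1+p) + \frac{4L_1^2\eta_\lambda}{(\Gamma\mu - D_\lambda L_2 - L_1)^2} + \frac{\mu_d}{16}\left(2\eta_\lambda^2(1+p)L_1^2 + \frac{L_1^2\eta_\lambda}{\mu_d}\right), \\
C_2 &= -\frac{\eta_\lambda\mu_d^2}{64}, \\
C_3 &= \frac{L_L\eta_w^2(1+\Gamma^2+D_\lambda^2)\sigma^2}{2} + 2L_1^2\eta_\lambda\eta_w^2\sigma^2 + 2\eta_\lambda^2\sigma^2 + \frac{\mu_d\eta_\lambda^2\sigma^2}{16}.
\end{aligned}
$$

Proof. With Lemma 4 and Lemma 6, it holds that

$$
\mathbb{E}\left[\Phi(w_{t+1},\lambda_{t+1}) - \Phi(w_t,\lambda_t)\right] \le \mathbb{E}\left[C_0(\eta_w)\|\nabla_w L(w_t,\lambda_t)\|^2 + (\lambda_{t+1}-\lambda_t)^\top\nabla h(w_{t+1}) - 2\langle \lambda_{t+1}-\lambda_t,\, \nabla h(w^*(\lambda_t))\rangle + L_d\|\lambda_{t+1}-\lambda_t\|^2 + \frac{L_L\eta_w^2(1+\Gamma^2+D_\lambda^2)\sigma^2}{2}\right].
$$

We deal with each term as follows. For the second term and the third term, it holds that

$$
\begin{aligned}
&(\lambda_{t+1}-\lambda_t)^\top\nabla h(w_{t+1}) - 2\langle \lambda_{t+1}-\lambda_t,\, \nabla h(w^*(\lambda_t))\rangle \\
&\quad= 2\langle \lambda_{t+1}-\lambda_t,\, \nabla h(w_{t+1}) - \nabla h(w^*(\lambda_t))\rangle - (\lambda_{t+1}-\lambda_t)^\top\nabla h(w_{t+1}) \\
&\quad\le 2L_1\|\lambda_{t+1}-\lambda_t\|\,\|w_{t+1} - w^*(\lambda_t)\| - \frac{1}{\eta_\lambda}\|\lambda_{t+1}-\lambda_t\|^2 \\
&\quad\le 2L_1\|\lambda_{t+1}-\lambda_t\|\left(\|w_t - w^*(\lambda_t)\| + \|w_{t+1}-w_t\|\right) - \frac{1}{\eta_\lambda}\|\lambda_{t+1}-\lambda_t\|^2 \\
&\quad\le \left(\frac{1}{2\eta_\lambda} - \frac{1}{\eta_\lambda}\right)\|\lambda_{t+1}-\lambda_t\|^2 + 4L_1^2\eta_\lambda\|w_t - w^*(\lambda_t)\|^2 + 4L_1^2\eta_\lambda\|w_{t+1}-w_t\|^2 .
\end{aligned}
$$

By taking the expectation on both sides of the inequality, it holds that

$$
\begin{aligned}
&\mathbb{E}\left[(\lambda_{t+1}-\lambda_t)^\top\nabla h(w_{t+1}) - 2\langle \lambda_{t+1}-\lambda_t,\, \nabla h(w^*(\lambda_t))\rangle\right] \\
&\quad\le \mathbb{E}\left[-\frac{1}{2\eta_\lambda}\|\lambda_{t+1}-\lambda_t\|^2 + 4L_1^2\eta_\lambda\|w_t - w^*(\lambda_t)\|^2 + 4L_1^2\eta_\lambda\|w_{t+1}-w_t\|^2\right] \\
&\quad\le \mathbb{E}\left[-\frac{1}{2\eta_\lambda}\|\lambda_{t+1}-\lambda_t\|^2 + 4L_1^2\eta_\lambda\eta_w^2(1+p)\|\nabla_w L(w_t,\lambda_t)\|^2 + \frac{4L_1^2\eta_\lambda}{(\Gamma\mu - D_\lambda L_2 - L_1)^2}\|\nabla_w L(w_t,\lambda_t)\|^2 + 2L_1^2\eta_\lambda\eta_w^2\sigma^2\right].
\end{aligned}
$$

Meanwhile, with Lemma 10 and Lemma 7 it holds that

$$
\begin{aligned}
\mathbb{E}\|\lambda_{t+1}-\lambda_t\|^2
&\ge \frac{\mu_d^2\eta_\lambda^2}{16}\,\mathbb{E}\|\lambda_t-\lambda^*\|^2 - \frac{\mu_d\eta_\lambda}{4}\left(\eta_\lambda^2\sigma^2 + \left(2\eta_\lambda^2(1+p)L_1^2 + \frac{L_1^2\eta_\lambda}{\mu_d}\right)\mathbb{E}\|w_t - w^*(\lambda_t)\|^2\right) \\
&\ge \frac{\mu_d^2\eta_\lambda^2}{16}\,\mathbb{E}\|\lambda_t-\lambda^*\|^2 - \frac{\mu_d\eta_\lambda}{4}\left(\eta_\lambda^2\sigma^2 + \left(2\eta_\lambda^2(1+p)L_1^2 + \frac{L_1^2\eta_\lambda}{\mu_d}\right)\mathbb{E}\|\nabla_w L(w_t,\lambda_t)\|^2\right).
\end{aligned}
$$

Further, when $\eta_\lambda \le 1/(4L_d)$, it holds that

$$
-\frac{1}{2\eta_\lambda}\|\lambda_{t+1}-\lambda_t\|^2 + L_d\|\lambda_{t+1}-\lambda_t\|^2 \le -\frac{1}{4\eta_\lambda}\|\lambda_{t+1}-\lambda_t\|^2 .
$$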
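Thus, it holds that

$$
\begin{aligned}
\mathbb{E}\left[\Phi(w_{t+1},\lambda_{t+1}) - \Phi(w_t,\lambda_t)\right]
&\le \mathbb{E}\left[C_0(\eta_w)\|\nabla_w L(w_t,\lambda_t)\|^2 + (\lambda_{t+1}-\lambda_t)^\top\nabla h(w_{t+1}) - 2\langle \lambda_{t+1}-\lambda_t,\, \nabla h(w^*(\lambda_t))\rangle + L_d\|\lambda_{t+1}-\lambda_t\|^2 + \frac{L_L\eta_w^2(1+\Gamma^2+D_\lambda^2)\sigma^2}{2}\right] \\
&\le \mathbb{E}\left[C_1\|\nabla_w L(w_t,\lambda_t)\|^2 + C_2\|\lambda_t-\lambda_t^*\|^2 + C_3\right],
\end{aligned}
$$

where

$$
\begin{aligned}
C_1 &= -\eta_w + \frac{L_L\eta_w^2}{2} + \frac{L_L\eta_w^2\, p\,(1+\Gamma^2+D_\lambda^2)}{2} + 4L_1^2\eta_\lambda\eta_w^2(1+p) + \frac{4L_1^2\eta_\lambda}{(\Gamma\mu - D_\lambda L_2 - L_1)^2} + \frac{\mu_d}{16}\left(2\eta_\lambda^2(1+p)L_1^2 + \frac{L_1^2\eta_\lambda}{\mu_d}\right), \\
C_2 &= -\frac{\eta_\lambda\mu_d^2}{64}, \\
C_3 &= \frac{L_L\eta_w^2(1+\Gamma^2+D_\lambda^2)\sigma^2}{2} + 2L_1^2\eta_\lambda\eta_w^2\sigma^2 + 2\eta_\lambda^2\sigma^2 + \frac{\mu_d\eta_\lambda^2\sigma^2}{16}.
\end{aligned}
$$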
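As a quick check of the step in the proof of Lemma 10 that picks $\xi = 1 + \mu_d\eta_\lambda/4$, the claimed lower bound can be verified directly. The shorthand $a = \mu_d\eta_\lambda \ge 0$ is introduced here only for this calculation and is not part of the original notation.

$$
\begin{aligned}
(1-\xi)\left(1-\frac{a}{2}\right) + 1 - \frac{1}{\xi}
&= -\frac{a}{4}\left(1-\frac{a}{2}\right) + \frac{a/4}{1+a/4}
 = \frac{a^2}{8} - \frac{a}{4} + \frac{a}{4+a} \\
&= \frac{a^2}{8} - \frac{a^2}{4(4+a)}
 \;\ge\; \frac{a^2}{8} - \frac{a^2}{16}
 = \frac{a^2}{16},
\end{aligned}
$$

where the inequality uses $4(4+a) \ge 16$ for $a \ge 0$. Substituting $a = \mu_d\eta_\lambda$ gives the factor $\mu_d^2\eta_\lambda^2/16$ that appears in Lemma 10.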

Proof of the inner-loop convergence.

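By the definition of $\Phi$, it holds that $\Phi(w,\lambda) \ge -d(\lambda) \ge f$. Thus, with Lemma 11, by summing up $K$ terms, it holds that

$$
\frac{1}{K}\sum_{t=1}^{K}\mathbb{E}\left[\frac{\eta_w}{2}\|\nabla_w L(w_t,\lambda_t)\|^2 + \frac{\eta_\lambda L_L^2}{8\mu^4}\|\lambda_t-\lambda^*\|^2\right] \le \frac{\Phi(w_1,\lambda_1) - f}{K} + C_3 .
$$

Then, letting $\eta_w = \Theta(1/\sqrt{K})$ and $\eta_\lambda = \Theta(1/\sqrt{K})$, it holds that

$$
\frac{1}{K}\,\mathbb{E}\sum_{t=1}^{K}\|\lambda_t-\lambda^*\|^2 = O(1/\sqrt{K}),
\qquad
\frac{1}{K}\,\mathbb{E}\sum_{t=1}^{K}\|w_t-w^*\|^2 = O(1/\sqrt{K}).
$$

On the other hand, it holds that

$$
\begin{aligned}
\|\lambda_t^\top\nabla_w f_i(w_t) - \lambda^{*\top}\nabla_w f_i(w^*)\|^2
&\le 2\|\nabla_w f_i(w^*)\|^2\|\lambda_t-\lambda^*\|^2 + 2\|\lambda_t\|^2\|\nabla f_i(w_t) - \nabla f_i(w^*)\|^2 \\
&\le 2D_w^2\|\lambda_t-\lambda^*\|^2 + 2D_\lambda^2L_1^2\|w_t-w^*\|^2 .
\end{aligned}
$$

Therefore, it holds that

$$
\frac{1}{K}\,\mathbb{E}\sum_{t=1}^{K}\|\lambda_t^\top\nabla_w f_i(w_t) - \lambda^{*\top}\nabla_w f_i(w^*)\|^2 \le \frac{2}{K}\,\mathbb{E}\sum_{t=1}^{K}\left(D_w^2\|\lambda_t-\lambda^*\|^2 + D_\lambda^2L_1^2\|w_t-w^*\|^2\right) = O(1/\sqrt{K}).
$$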
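The inner loop analyzed above alternates a gradient step on w with a projected ascent step on λ. The following is a minimal Python sketch of that iteration on a toy quadratic, offered as an illustration under stated assumptions rather than as the authors' implementation. The Lagrangian form L(w, λ) = g(w) + λ^T ∇h(w) + Γ h(w) is inferred from the identities used in the proofs, the gradients are exact (no sampling noise and no partial participation), Λ is taken to be a Euclidean ball, and the names `primal_dual_gda`, `project_ball`, and `radius` are hypothetical.

```python
import math


def project_ball(v, radius):
    """Euclidean projection of v onto the ball {u : ||u|| <= radius}."""
    norm = math.sqrt(sum(x * x for x in v))
    if norm <= radius:
        return list(v)
    return [x * radius / norm for x in v]


def primal_dual_gda(a, b, c, gamma=1.0, eta_w=0.1, eta_lam=0.1,
                    radius=10.0, steps=2000):
    """Deterministic primal-dual iteration on a toy problem (illustrative).

    Assumed toy functions (not from the paper):
      h(w) = sum_j (0.5 * a[j] * w[j]**2 - b[j] * w[j])
      g(w) = 0.5 * ||w - c||^2
      L(w, lam) = g(w) + lam^T grad_h(w) + gamma * h(w)
    """
    n = len(a)
    w = [0.0] * n
    lam = [0.0] * n
    for _ in range(steps):
        grad_h = [a[j] * w[j] - b[j] for j in range(n)]
        # grad_w L = (w - c) + Hess_h(w) lam + gamma * grad_h, Hess_h = diag(a)
        grad_w = [(w[j] - c[j]) + a[j] * lam[j] + gamma * grad_h[j]
                  for j in range(n)]
        # Descent step on w and projected ascent step on lam, both using w_t
        w_next = [w[j] - eta_w * grad_w[j] for j in range(n)]
        lam_next = project_ball(
            [lam[j] + eta_lam * grad_h[j] for j in range(n)], radius)
        w, lam = w_next, lam_next
    return w, lam


if __name__ == "__main__":
    w, lam = primal_dual_gda([1.0, 2.0], [1.0, 2.0], [0.0, 0.0])
    print("w   =", w)
    print("lam =", lam)
```

With a = [1, 2], b = [1, 2], and c = [0, 0], the fixed point of this iteration is w = (1, 1), the minimizer of the toy h, with λ = (-1, -0.5). The step sizes here are constants chosen for the toy problem, not the Θ(1/√K) schedule used in the analysis.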

B PROOF OF OUTER LOOP CONVERGENCE

In this section, we give the proof of the outer loop convergence.

Lemma 12. Suppose (A1)-(A4) hold. Then, for all $x_1, x_2 \in \mathcal{X}$, it holds that

$$
\|w^*(x_1) - w^*(x_2)\| \le \left(\frac{\sqrt{N}D_w}{\mu} + \frac{\sqrt{N}D_wL_1}{\mu^2}\right)\|x_1 - x_2\| .
$$

Proof. Because $\sum_{i=1}^{N} x_2^{(i)} f_i(w)$ is a $\mu$-strongly convex function, it holds that

$$
\sum_{i=1}^{N} x_2^{(i)}\big(f_i(w^*(x_1)) - f_i(w^*(x_2))\big) \ge \frac{\mu}{2}\|w^*(x_1) - w^*(x_2)\|^2 .
$$

On the other hand, we have

$$
\begin{aligned}
\sum_{i=1}^{N} x_2^{(i)}\big(f_i(w^*(x_1)) - f_i(w^*(x_2))\big)
&= \sum_{i=1}^{N} x_2^{(i)} f_i(w^*(x_1)) - \sum_{i=1}^{N} x_1^{(i)} f_i(w^*(x_1)) + \sum_{i=1}^{N} x_1^{(i)} f_i(w^*(x_1)) - \sum_{i=1}^{N} x_1^{(i)} f_i(w^*(x_2)) \\
&\quad + \sum_{i=1}^{N} x_1^{(i)} f_i(w^*(x_2)) - \sum_{i=1}^{N} x_2^{(i)} f_i(w^*(x_2)) \\
&\le \sum_{i=1}^{N} (x_1^{(i)} - x_2^{(i)})\big(f_i(w^*(x_2)) - f_i(w^*(x_1))\big) - \frac{\mu}{2}\|w^*(x_2) - w^*(x_1)\|^2 .
\end{aligned}
$$

Meanwhile, we have

$$
|f_i(w^*(x_1)) - f_i(w^*(x_2))| \le \max\big(\|\nabla_w f_i(w^*(x_1))\|, \|\nabla_w f_i(w^*(x_2))\|\big)\|w^*(x_1) - w^*(x_2)\| + \frac{L_1}{2}\|w^*(x_2) - w^*(x_1)\|^2 .
$$

With strong convexity, it holds that

$$
\|w^*(x_2) - w^*(x_1)\| \le \frac{1}{\mu}\|\nabla f(w^*(x_1)) - \nabla f(w^*(x_2))\| \le \frac{2D_w}{\mu}.
$$

Thus, it holds that

$$
f_i(w^*(x_2)) - f_i(w^*(x_1)) \le D_w(1 + L_1/\mu)\|w^*(x_2) - w^*(x_1)\| .
$$

Combining the above inequalities, it holds that

$$
\mu\|w^*(x_1) - w^*(x_2)\|^2 \le \|x_1 - x_2\|_1 \max_i |f_i(w^*(x_1)) - f_i(w^*(x_2))| \le \sqrt{N}D_w(1 + L_1/\mu)\|x_1 - x_2\|\,\|w^*(x_1) - w^*(x_2)\| .
$$

Lemma 13. Suppose (A1)-(A4) hold. Then $f_0(w^*(x))$ has Lipschitz gradient with Lipschitz constant $\left(\frac{D_w^2}{\mu^2} + \frac{2D_wL_1}{\mu}\right)\frac{\sqrt{N}D_w\mu + \sqrt{N}D_wL_1}{\mu^2}$.

Proof.
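$$
\frac{\partial f_0(w^*(x))}{\partial x^{(i)}} = -\nabla_w f_0(w^*(x))^\top\left(\sum_{j=1}^{N} x^{(j)}\nabla_w^2 f_j(w^*(x))\right)^{-1}\nabla_w f_i(w^*(x)).
$$

By smoothness and strong convexity, it holds that

$$
\begin{aligned}
\left\|\left(\sum_{j=1}^{N} x_1^{(j)}\nabla_w^2 f_j(w^*(x_1))\right)^{-1} - \left(\sum_{j=1}^{N} x_2^{(j)}\nabla_w^2 f_j(w^*(x_2))\right)^{-1}\right\|
&\le \frac{1}{\mu^2}\left\|\sum_{j=1}^{N} x_1^{(j)}\nabla_w^2 f_j(w^*(x_1)) - \sum_{j=1}^{N} x_2^{(j)}\nabla_w^2 f_j(w^*(x_2))\right\| \\
&\le \frac{1}{\mu^2}\max_j\|\nabla_w^2 f_j(w^*(x_1)) - \nabla_w^2 f_j(w^*(x_2))\| \\
&\le \frac{L_2}{\mu^2}\|w^*(x_1) - w^*(x_2)\|,
\end{aligned}
$$

where the first inequality is due to $\|A(A^{-1} - B^{-1})B\| = \|A - B\|$ for all invertible matrices $A$, $B$. Then, it holds that

$$
\begin{aligned}
&\left\|\left(\sum_{j=1}^{N} x_1^{(j)}\nabla_w^2 f_j(w^*(x_1))\right)^{-1}\nabla_w f_i(w^*(x_1)) - \left(\sum_{j=1}^{N} x_2^{(j)}\nabla_w^2 f_j(w^*(x_2))\right)^{-1}\nabla_w f_i(w^*(x_2))\right\| \\
&\quad\le \left\|\left(\left(\sum_{j=1}^{N} x_1^{(j)}\nabla_w^2 f_j(w^*(x_1))\right)^{-1} - \left(\sum_{j=1}^{N} x_2^{(j)}\nabla_w^2 f_j(w^*(x_2))\right)^{-1}\right)\nabla_w f_i(w^*(x_1))\right\| \\
&\qquad + \left\|\left(\sum_{j=1}^{N} x_2^{(j)}\nabla_w^2 f_j(w^*(x_2))\right)^{-1}\big(\nabla f_i(w^*(x_1)) - \nabla f_i(w^*(x_2))\big)\right\| \\
&\quad\le \frac{D_w}{\mu^2}\|w^*(x_1) - w^*(x_2)\| + \frac{L_1}{\mu}\|w^*(x_1) - w^*(x_2)\| .
\end{aligned}
$$

Thus, combining with the definition of the gradient, it holds that

$$
\left|\frac{\partial f_0(w^*(x_1))}{\partial x^{(i)}} - \frac{\partial f_0(w^*(x_2))}{\partial x^{(i)}}\right| \le D_w\left(\frac{D_w}{\mu^2}\|w^*(x_1) - w^*(x_2)\| + \frac{L_1}{\mu}\|w^*(x_1) - w^*(x_2)\|\right) + \frac{D_wL_1}{\mu}\|w^*(x_1) - w^*(x_2)\| .
$$

Therefore, combining with Lemma 12, we obtain the result.

Lemma 14. Suppose the function $f(x)$ is lower bounded by $\underline{f}$ and has $L$-Lipschitz gradient. Let $g(x)$ be an unbiased estimator of $\nabla f(x)$ whose expected norm is bounded by $G$ on the domain $\mathcal{X}$. Consider the update rule $x_{t+1} = \Pi_{\mathcal{X}}(x_t - \eta_x(g(x_t) + \xi_t))$, where $\eta_x = \Theta(1/\sqrt{T})$, $\mathcal{X}$ is a convex set, and $\mathbb{E}\|\xi_t\|^2 \le \epsilon^2$. Define $\hat{x} = \arg\min_{y\in\mathcal{X}}\left(f(y) + \frac{\rho}{2}\|y - x\|^2\right)$ and $\tilde{\nabla}_\rho f(x) = \rho(x - \hat{x})$, where $\rho = 2L$. Then it holds that

$$
\frac{1}{T}\,\mathbb{E}\sum_{t=1}^{T}\|\tilde{\nabla}_\rho f(x_t)\|^2 = O(1/\sqrt{T} + \epsilon^2).
$$

Proof. It holds that

$$
\begin{aligned}
\mathbb{E}\left[f(\hat{x}_{t+1}) + \frac{\rho}{2}\|x_{t+1} - \hat{x}_{t+1}\|^2\right]
&\le \mathbb{E}\left[f(\hat{x}_t) + \frac{\rho}{2}\|x_{t+1} - \hat{x}_t\|^2\right] \\
&= \mathbb{E}\left[f(\hat{x}_t) + \frac{\rho}{2}\|\Pi_{\mathcal{X}}(x_t - \eta_x(g(x_t) + \xi_t)) - \hat{x}_t\|^2\right] \\
&\le \mathbb{E}\left[f(\hat{x}_t) + \frac{\rho}{2}\|x_t - \eta_x(g(x_t) + \xi_t) - \hat{x}_t\|^2\right] \\
&= \mathbb{E}\left[f(\hat{x}_t) + \frac{\rho}{2}\|x_t - \hat{x}_t\|^2 + \rho\eta_x^2\|g(x_t) + \xi_t\|^2 - \eta_x\rho\langle x_t - \hat{x}_t,\, g(x_t)\rangle - \eta_x\rho\langle x_t - \hat{x}_t,\, \xi_t\rangle\right] \\
&\le \mathbb{E}\left[f(\hat{x}_t) + \frac{\rho}{2}\|x_t - \hat{x}_t\|^2\right] + \rho\eta_x^2(G^2 + \epsilon^2) - \rho\eta_x\,\mathbb{E}\left[f(x_t) - f(\hat{x}_t) - \frac{L}{2}\|x_t - \hat{x}_t\|^2\right] + \rho\eta_x\,\mathbb{E}\left[\frac{L}{2}\|x_t - \hat{x}_t\|^2 + \frac{1}{2L}\epsilon^2\right].
\end{aligned}
$$

Then, summing up the above inequality, it holds that

$$
\begin{aligned}
\rho\eta_x\sum_{t=1}^{T}\mathbb{E}\left[f(x_t) - f(\hat{x}_t) - L\|x_t - \hat{x}_t\|^2\right]
&\le \mathbb{E}\left[f(\hat{x}_1) + \frac{\rho}{2}\|x_1 - \hat{x}_1\|^2 - \left(f(\hat{x}_{T+1}) + \frac{\rho}{2}\|x_{T+1} - \hat{x}_{T+1}\|^2\right)\right] + T\rho\eta_x^2(G^2 + \epsilon^2) + \frac{T\rho\eta_x}{2L}\epsilon^2 \\
&\le f(\hat{x}_1) + \frac{\rho}{2}\|x_1 - \hat{x}_1\|^2 - \underline{f} + T\rho\eta_x^2(G^2 + \epsilon^2) + \frac{T\rho\eta_x}{2L}\epsilon^2 .
\end{aligned}
$$

On the other hand, because $\hat{x}_t = \arg\min_{y\in\mathcal{X}}\left(f(y) + \frac{\rho}{2}\|x_t - y\|^2\right)$ and the function $f(y) + \frac{\rho}{2}\|x_t - y\|^2$ is strongly convex with modulus $\rho - L$, it holds that

$$
\begin{aligned}
f(x_t) - f(\hat{x}_t) - L\|x_t - \hat{x}_t\|^2
&= f(x_t) + \frac{\rho}{2}\|x_t - x_t\|^2 - \left(f(\hat{x}_t) + \frac{\rho}{2}\|x_t - \hat{x}_t\|^2\right) + \frac{\rho - 2L}{2}\|x_t - \hat{x}_t\|^2 \\
&\ge \frac{2\rho - 3L}{2}\|x_t - \hat{x}_t\|^2 = \frac{L}{2}\|x_t - \hat{x}_t\|^2 .
\end{aligned}
$$

Thus, combining the above inequalities, it holds that

$$
\begin{aligned}
\frac{\rho^2}{T}\sum_{t=1}^{T}\mathbb{E}\|x_t - \hat{x}_t\|^2
&\le \frac{2\rho^2}{TL}\sum_{t=1}^{T}\mathbb{E}\left[f(x_t) - f(\hat{x}_t) - L\|x_t - \hat{x}_t\|^2\right] \\
&\le \frac{\rho f(\hat{x}_1) + \frac{\rho^2}{2}\|x_1 - \hat{x}_1\|^2 - \rho\underline{f}}{T\eta_x} + \rho^2\eta_x(G^2 + \epsilon^2) + \frac{\rho^2}{2L}\epsilon^2 .
\end{aligned}
$$

Therefore, when $\eta_x = O(1/\sqrt{T})$,

$$
\frac{1}{T}\,\mathbb{E}\sum_{t=1}^{T}\|\tilde{\nabla}_\rho f(x_t)\|^2 = \frac{\rho^2}{T}\sum_{t=1}^{T}\mathbb{E}\|x_t - \hat{x}_t\|^2 = O(1/\sqrt{T} + \epsilon^2).
$$

Proof of Theorem 2. With Lemma 13, we obtain that $f_0(w^*(x))$ has Lipschitz gradient on the domain $\mathcal{X}$. Define

$$
\xi_t^{(i)} = \lambda_t^\top\nabla_w f_i(w_{t,k}) - \lambda^*(x_t)^\top\nabla_w f_i(w^*(x_t)).
$$

When we use $\lambda_{t,k}$ and $w_{t,k}$ with random $k$, as suggested by Theorem 1, the expected norm of $\xi_t^{(i)}$ can be bounded by $O(1/\sqrt{K})$.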
Let

$$
x_{t+1/2}^{(i)} = x_t^{(i)} - \eta_x\left(g(x_t)^{(i)} + \lambda_t^\top\nabla_w f_i(w_{t,k}) - \lambda^*(x_t)^\top\nabla_w f_i(w^*(x_t))\right).
$$

Then, it holds that

$$
\mathbb{E}\,g(x_t)^{(i)} = \mathbb{E}\left[\frac{1}{\eta_x}\left(x_t^{(i)} - x_{t+1/2}^{(i)}\right) - \lambda_t^\top\nabla_w f_i(w_{t,k}) + \lambda^*(x_t)^\top\nabla_w f_i(w^*(x_t))\right] = \lambda^*(x_t)^\top\nabla_w f_i(w^*(x_t))
$$

and

$$
\mathbb{E}\left(g(x_t)^{(i)}\right)^2 \le (1+p)\|\lambda^*(x_t)^\top\nabla_w f_i(w^*(x_t))\|^2 + \epsilon^2 + \sigma^2 \le (1+p)D_\lambda^2D_w^2 + \epsilon^2 + \sigma^2 .
$$

Thus, together with Lemma 14, we directly get the result.
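To make the outer-loop update concrete, the sketch below evaluates the hypergradient formula for ∂f_0(w*(x))/∂x^(i) on scalar quadratics and applies the projected step x_{t+1} = Π_X(x_t - η_x g(x_t)) on the simplex X. The assumptions are ours, not the paper's: f_i(w) = a_i (w - c_i)^2 / 2 and f_0(w) = (w - c_0)^2 / 2 so that w*(x) has a closed form, the hypergradient is exact (ξ_t = 0, no inner-loop error), and a standard sort-based simplex projection is used for Π_X. All function names are hypothetical.

```python
def project_simplex(v):
    """Euclidean projection onto {x : x >= 0, sum(x) = 1} (sort-based)."""
    u = sorted(v, reverse=True)
    css = 0.0
    theta = 0.0
    for k, u_k in enumerate(u, start=1):
        css += u_k
        t = (css - 1.0) / k
        if u_k - t > 0:
            theta = t  # keep the threshold of the largest valid k
    return [max(v_i - theta, 0.0) for v_i in v]


def w_star(x, a, c):
    """Minimizer of sum_i x_i * 0.5 * a_i * (w - c_i)^2 over scalar w."""
    return sum(xi * ai * ci for xi, ai, ci in zip(x, a, c)) / \
        sum(xi * ai for xi, ai in zip(x, a))


def upper_value(x, a, c, c0):
    """Validation objective f_0(w*(x)) with f_0(w) = 0.5 * (w - c0)^2."""
    return 0.5 * (w_star(x, a, c) - c0) ** 2


def hypergradient(x, a, c, c0):
    """-f_0'(w*) * (sum_j x_j f_j''(w*))^{-1} * f_i'(w*), for each i."""
    w = w_star(x, a, c)
    hess = sum(xi * ai for xi, ai in zip(x, a))
    return [-(w - c0) * ai * (w - ci) / hess for ai, ci in zip(a, c)]


def outer_loop(a, c, c0, eta_x=0.05, steps=200):
    """Projected gradient descent on the client coefficients x."""
    n = len(a)
    x = [1.0 / n] * n
    for _ in range(steps):
        g = hypergradient(x, a, c, c0)
        x = project_simplex([xi - eta_x * gi for xi, gi in zip(x, g)])
    return x


if __name__ == "__main__":
    # Two "good" clients near the validation target and one outlier.
    a, c, c0 = [1.0, 1.0, 1.0], [0.0, 0.1, 5.0], 0.05
    x = outer_loop(a, c, c0)
    print("coefficients:", x)
    print("validation loss:", upper_value(x, a, c, c0))
```

In this toy setting the coefficient of the outlier client (c_i = 5) is driven close to zero, which mirrors the motivation of down-weighting low-quality devices. In the algorithm analyzed above, the exact hypergradient would be replaced by the estimate built from the inner-loop iterates.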

C ADDITIONAL EXPERIMENTAL RESULTS ON MNIST AND FASHION MNIST WITH LENET5

We use 10 clients in this experiment. The first 5 clients contain 9000 i.i.d. samples, and the last 5 clients contain 9000 samples with random labels. The remaining settings are the same as in the main text. The results are shown in Figure 5 and Figure 6.
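One possible way to build such a client split is sketched below in Python. This is a hypothetical helper, not the authors' data pipeline: the function name, the dictionary layout, and the choice to draw disjoint chunks from a random permutation are all assumptions. The text does not say whether the 9000 samples are per client or per group, so the per-client count is left as an explicit argument.

```python
import random


def make_client_splits(labels, samples_per_client, num_clients=10,
                       num_clean=5, num_classes=10, seed=0):
    """Split a labeled dataset into clean and label-noised clients (sketch).

    The first `num_clean` clients keep their true labels; the remaining
    clients get labels drawn uniformly at random (assumed noise model).
    """
    needed = num_clients * samples_per_client
    if len(labels) < needed:
        raise ValueError(f"need at least {needed} samples, got {len(labels)}")
    rng = random.Random(seed)
    order = list(range(len(labels)))
    rng.shuffle(order)  # random permutation, then disjoint chunks per client
    clients = []
    for c in range(num_clients):
        chunk = order[c * samples_per_client:(c + 1) * samples_per_client]
        clean = c < num_clean
        if clean:
            client_labels = [labels[i] for i in chunk]
        else:
            # Replace every label with a uniformly random class
            client_labels = [rng.randrange(num_classes) for _ in chunk]
        clients.append({"indices": chunk, "labels": client_labels,
                        "clean": clean})
    return clients


if __name__ == "__main__":
    toy_labels = [i % 10 for i in range(1000)]  # stand-in for real labels
    clients = make_client_splits(toy_labels, samples_per_client=100)
    print([c["clean"] for c in clients])
```

Sampling with replacement, or overlapping client shards, would be an equally plausible reading of the setup; the disjoint split here is only the simplest choice.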



Personalized federated learning requires a well-trained personalized model for each local device. Jiang et al. (2019) and Deng et al. (2020) train a global model and then fine-tune it to obtain the local model. T Dinh et al. (2020) and Fallah et al. (2020) change the local objective function so that each local model can differ and handle its individual local task. Li et al. (2021) introduce a two-level optimization problem that seeks the best local model from good global models. None of these works uses a validation set as a reference; instead, they rely on a few gradient steps or simple modifications, hoping that the local model both fits the local training data and uses information from the global model (i.e., from other local devices).

Figure 1: Results of the toy example where all clients participate in the optimization process in each iteration and all gradients and Hessians are estimated without noise. The top row shows the stationarity of x at each iteration, and the bottom row shows the function value at x (f (w * (x))). The left column shows the results when the number of local steps is 1, the middle column shows the results for 5 local steps, and the right column shows the results for 10 local steps. The shaded region around the function value corresponds to 0.1 standard error, and the shaded region around the stationarity corresponds to 0.5 standard error.

Figure 2: Results of the toy example where the active rate is 0.5 in each iteration and all gradients and Hessians are estimated with white Gaussian noise at a noise level of 0.5. The top row shows the stationarity of x at each iteration, and the bottom row shows the function value at x (f (w * (x))). The left column shows the results when the number of local steps is 1, the middle column shows the results for 5 local steps, and the right column shows the results for 10 local steps. The shaded region around the function value corresponds to 0.1 standard error, and the shaded region around the stationarity corresponds to 0.5 standard error.

Figure 3: Test accuracy of training LeNet 5 on MNIST dataset. The left curve shows the result when the active probability is 1; the middle curve shows the result when the active probability is 0.9, and the right curve shows the result with the active probability of 0.5.

Figure 4: Test accuracy of training LeNet 5 on the Fashion-MNIST dataset. The left curve shows the result when the active probability is 1; the middle curve shows the result when the active probability is 0.9, and the right curve shows the result with 0.5 active probability.

* (x1) -w * (x2)∥ + L1 µ ∥w * (x1) -w * (x2)∥) + DwL1 µ ∥w * (x1) -w * (x2)∥

Figure 5: Test accuracy of training LeNet5 on the MNIST dataset in the i.i.d. case. The left curve shows the result when the active probability is 1, and the right curve shows the result when the active probability is 0.5.

Figure 6: Test accuracy of training LeNet5 on the Fashion-MNIST dataset in the i.i.d. case. The left curve shows the result when the active probability is 1, and the right curve shows the result when the active probability is 0.5.

Test Accuracy and x output of Training LeNet 5 on MNIST. "AP" represents Active Probability, and Accuracy stands for Test Accuracy.

Chenggen Shi, Jie Lu, and Guangquan Zhang. An extended Kuhn-Tucker approach for linear bilevel programming. Applied Mathematics and Computation, 162(1):51-63, 2005.

Canh T Dinh, Nguyen Tran, and Josh Nguyen. Personalized federated learning with Moreau envelopes. Advances in Neural Information Processing Systems, 33:21394-21405, 2020.

