DASHA: DISTRIBUTED NONCONVEX OPTIMIZATION WITH COMMUNICATION COMPRESSION AND OPTIMAL ORACLE COMPLEXITY

Abstract

We develop and analyze DASHA: a new family of methods for nonconvex distributed optimization problems. When the local functions at the nodes have a finite-sum or an expectation form, our new methods, DASHA-PAGE, DASHA-MVR and DASHA-SYNC-MVR, improve the theoretical oracle and communication complexity of the previous state-of-the-art method MARINA by Gorbunov et al. (2021). In particular, to achieve an ε-stationary point, and considering the random sparsifier RandK as an example, our methods compute the optimal number of gradients O(√m/ε) and O(σ/ε^{3/2}) in the finite-sum and expectation form cases, respectively, while maintaining the SOTA communication complexity O(d/(ε√n)). Furthermore, unlike MARINA, the new methods DASHA, DASHA-PAGE and DASHA-MVR send compressed vectors only, which makes them more practical for federated learning. We extend our results to the case when the functions satisfy the Polyak-Łojasiewicz condition. Finally, our theory is corroborated in practice: we observe a significant improvement in experiments with nonconvex classification and with training of deep learning models.

1. INTRODUCTION

Nonconvex optimization problems are widespread in modern machine learning tasks, especially with the rising popularity of deep neural networks (Goodfellow et al., 2016). In the past years, the dimensionality of such problems has increased because this leads to better quality (Brown et al., 2020) and robustness (Bubeck & Sellke, 2021) of the deep neural networks trained this way. Such huge-dimensional nonconvex problems need special treatment and efficient optimization methods (Danilova et al., 2020). Because of their high dimensionality, training such models is a computationally intensive undertaking that requires massive training datasets (Hestness et al., 2017) and parallelization among several compute nodes¹ (Ramesh et al., 2021). Moreover, the distributed learning paradigm is a necessity in federated learning (Konečný et al., 2016), where, among other things, there is an explicit desire to secure the private data of each client. Unlike in the case of classical optimization problems, where the performance of algorithms is measured by their computational complexity (Nesterov, 2018), distributed optimization algorithms are typically measured in terms of the communication overhead between the nodes, since such communication is often the bottleneck in practice (Konečný et al., 2016; Wang et al., 2021). Many approaches tackle this problem, including managing communication delays (Vogels et al., 2021), mitigating stragglers (Li et al., 2020a), and optimization over time-varying directed graphs (Nedić & Olshevsky, 2014). Another popular way to alleviate the communication bottleneck is to use lossy compression of the communicated messages (Alistarh et al., 2017; Mishchenko et al., 2019; Gorbunov et al., 2021; Szlendak et al., 2021). In this paper, we focus on this last approach.

1.1. PROBLEM FORMULATION

In this work, we consider the optimization problem

min_{x ∈ ℝ^d} { f(x) := (1/n) Σ_{i=1}^{n} f_i(x) },     (1)

where f_i : ℝ^d → ℝ is a smooth nonconvex function for all i ∈ [n] := {1, . . . , n}. Moreover, we assume that the problem is solved by n compute nodes, with the i-th node having access to the function f_i only, via an oracle. Communication is facilitated by an orchestrating server able to communicate with all nodes. Our goal is to find an ε-solution (ε-stationary point) of (1): a (possibly random) point x̂ ∈ ℝ^d such that E[‖∇f(x̂)‖²] ≤ ε.
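As a toy illustration of this setup (the quadratic local functions below are hypothetical and only serve to make the definitions concrete), suppose f_i(x) = ½‖x − a_i‖², so that ∇f(x) = x − (1/n) Σ_i a_i, and a point x̂ is an ε-solution once ‖∇f(x̂)‖² ≤ ε:

```python
# Toy illustration (hypothetical local functions, not from the paper):
# f_i(x) = 0.5 * ||x - a_i||^2, hence grad f(x) = x - mean_i(a_i).
# A point x is an eps-solution when ||grad f(x)||^2 <= eps.

def grad_f(x, anchors):
    """Gradient of f(x) = (1/n) * sum_i 0.5 * ||x - a_i||^2."""
    n = len(anchors)
    mean = [sum(a[j] for a in anchors) / n for j in range(len(x))]
    return [x[j] - mean[j] for j in range(len(x))]

def is_eps_solution(x, anchors, eps):
    """Check the (deterministic) eps-stationarity condition ||grad f(x)||^2 <= eps."""
    g = grad_f(x, anchors)
    return sum(gj * gj for gj in g) <= eps

anchors = [[1.0, 0.0], [0.0, 1.0]]   # n = 2 nodes, d = 2
minimizer = [0.5, 0.5]               # mean of the anchors
print(is_eps_solution(minimizer, anchors, 1e-8))  # the exact minimizer is stationary
```

In the stochastic methods of the paper the guarantee holds only in expectation over the randomness of the algorithm, which is why the definition above takes an expectation over the output point.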

1.2. GRADIENT ORACLES

We consider all of the following structural assumptions about the functions {f_i}_{i=1}^n, each with its own natural gradient oracle:

1. Gradient Setting. The i-th node has access to the gradient ∇f_i : ℝ^d → ℝ^d of the function f_i.

2. Finite-Sum Setting. The functions {f_i}_{i=1}^n have the finite-sum form f_i(x) = (1/m) Σ_{j=1}^{m} f_{ij}(x) for all i ∈ [n], where f_{ij} : ℝ^d → ℝ is a smooth nonconvex function for all j ∈ [m]. For all i ∈ [n], the i-th node has access to a mini-batch of B gradients, (1/B) Σ_{j∈I_i} ∇f_{ij}(·), where I_i is a multiset of i.i.d. samples from the set [m] and |I_i| = B.

3. Stochastic Setting. The function f_i is an expectation of a stochastic function: f_i(x) = E_ξ[f_i(x; ξ)] for all i ∈ [n], where f_i : ℝ^d × Ω_ξ → ℝ. For a fixed x ∈ ℝ^d, f_i(x; ξ) is a random variable over some distribution D_i, and, for a fixed ξ ∈ Ω_ξ, f_i(x; ξ) is a smooth nonconvex function. The i-th node has access to a mini-batch of B stochastic gradients, (1/B) Σ_{j=1}^{B} ∇f_i(·; ξ_{ij}), of the function f_i through the distribution D_i, where {ξ_{ij}}_{j=1}^{B} is a collection of i.i.d. samples from D_i.
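The finite-sum oracle above can be sketched as follows; this is a minimal illustration, assuming the component gradients ∇f_{ij} are available as callables (the names `component_grads` and `finite_sum_oracle` are ours, not the paper's):

```python
import random

def finite_sum_oracle(x, component_grads, batch_size, rng=random):
    """Mini-batch oracle for f_i(x) = (1/m) * sum_j f_ij(x):
    returns (1/B) * sum_{j in I_i} grad f_ij(x), where I_i is a
    multiset of B i.i.d. uniform samples from {0, ..., m-1}."""
    m = len(component_grads)
    batch = [rng.randrange(m) for _ in range(batch_size)]  # sampled with replacement
    d = len(x)
    g = [0.0] * d
    for j in batch:
        gj = component_grads[j](x)
        for k in range(d):
            g[k] += gj[k] / batch_size
    return g

# Usage with toy identical components grad f_ij(x) = x (hypothetical),
# so the mini-batch average equals x regardless of which indices are drawn:
grads = [lambda x: list(x) for _ in range(10)]  # m = 10 components
print(finite_sum_oracle([1.0, 2.0], grads, batch_size=4))  # -> [1.0, 2.0]
```

The stochastic-setting oracle is analogous, with the B indices replaced by B i.i.d. draws ξ_{ij} from D_i.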

1.3. ORACLE COMPLEXITY

In this paper, the oracle complexity of a method is the number of (stochastic) gradient calculations per node needed to achieve an ε-solution. Every considered method performs some number T of communication rounds to get an ε-solution; thus, if every node (on average) calculates B gradients in each communication round, then the oracle complexity equals O(B_init + BT), where B_init is the number of gradient calculations in the initialization phase of the method.
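The count B_init + BT can be made concrete with a one-line helper; the numbers below are illustrative only and do not come from the paper:

```python
def oracle_complexity(b_init, b_per_round, num_rounds):
    """Total (stochastic) gradient computations per node:
    B_init at initialization plus B per each of the T communication rounds."""
    return b_init + b_per_round * num_rounds

# Illustrative numbers: a full pass over m = 10_000 component gradients at
# initialization, then mini-batches of B = 100 over T = 500 rounds.
print(oracle_complexity(10_000, 100, 500))  # -> 60000
```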

1.4. UNBIASED COMPRESSORS

The methods proposed in this paper are based on unbiased compressors, a family of stochastic mappings with special properties that we define now.

Definition 1.1. A stochastic mapping C : ℝ^d → ℝ^d is an unbiased compressor if there exists ω ≥ 0 such that

E[C(x)] = x,  E[‖C(x) − x‖²] ≤ ω‖x‖²,  ∀x ∈ ℝ^d.

We denote this class of unbiased compressors as U(ω). One can find more information about unbiased compressors in (Beznosikov et al., 2020; Horváth et al., 2019). The purpose of such compressors is to quantize or sparsify the communicated vectors in order to increase the communication speed between the nodes and the server. Our methods will work with a collection of stochastic mappings {C_i}_{i=1}^n satisfying the following assumption.

Assumption 1.2. C_i ∈ U(ω) for all i ∈ [n], and the compressors are independent.
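A minimal sketch of the RandK sparsifier mentioned in the abstract, under its standard construction: keep K coordinates chosen uniformly at random and rescale them by d/K, which yields an unbiased compressor with ω = d/K − 1 (a well-known property of RandK; the helper names are ours):

```python
import random

def rand_k(x, k, rng=random):
    """RandK sparsifier: keep k random coordinates, scale them by d/k.
    Unbiased: E[C(x)] = x, and E||C(x) - x||^2 = (d/k - 1) * ||x||^2,
    so RandK belongs to U(omega) with omega = d/k - 1."""
    d = len(x)
    kept = rng.sample(range(d), k)  # k distinct coordinates, chosen uniformly
    out = [0.0] * d
    for i in kept:
        out[i] = (d / k) * x[i]
    return out

# Empirical check of unbiasedness: average many compressions of one vector.
rng = random.Random(0)
x = [1.0, -2.0, 3.0, 4.0]
trials = 200_000
avg = [0.0] * len(x)
for _ in range(trials):
    c = rand_k(x, 2, rng)
    for i in range(len(x)):
        avg[i] += c[i] / trials
print(avg)  # close to x = [1.0, -2.0, 3.0, 4.0]
```

Only K of the d coordinates (plus their indices) need to be transmitted per vector, which is the source of the communication savings; the price is the variance ω‖x‖², which the DASHA family controls via variance reduction.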



¹ Alternatively, we sometimes use the terms machines, workers, and clients.




