New Bounds For Distributed Mean Estimation and Variance Reduction

Abstract

We consider the problem of distributed mean estimation (DME), in which n machines are each given a local d-dimensional vector x_v ∈ ℝ^d, and must cooperate to estimate the mean of their inputs µ = (1/n) Σ_{v=1}^n x_v, while minimizing total communication cost. DME is a fundamental construct in distributed machine learning, and there has been considerable work on variants of this problem, especially in the context of distributed variance reduction for stochastic gradients in parallel SGD. Previous work typically assumes an upper bound on the norm of the input vectors, and achieves an error bound in terms of this norm. However, in many real applications, the input vectors are concentrated around the correct output µ, but µ itself has large norm. In such cases, previous output error bounds perform poorly. In this paper, we show that output error bounds need not depend on input norm. We provide a method of quantization which allows distributed mean estimation to be performed with solution quality dependent only on the distance between inputs, not on input norm, and show an analogous result for distributed variance reduction. The technique is based on a new connection with lattice theory. We also provide lower bounds showing that the communication-to-error trade-off of our algorithms is asymptotically optimal. As the lattices achieving optimal bounds under the ℓ_2-norm can be computationally impractical, we also present an extension which leverages easy-to-use cubic lattices, and is loose only up to a logarithmic factor in d. We show experimentally that our method yields practical improvements for common applications, relative to prior approaches.

1. Introduction

Several problems in distributed machine learning and optimization can be reduced to variants of the distributed mean estimation problem, in which n machines must cooperate to jointly estimate the mean of their d-dimensional inputs µ = (1/n) Σ_{v=1}^n x_v as closely as possible, while minimizing communication. In particular, this construct is often used for distributed variance reduction: here, each machine receives as input an independent probabilistic estimate of a d-dimensional vector ∇, and the aim is for all machines to output a common estimate of ∇ with lower variance than the individual inputs, while minimizing communication. Without any communication restrictions, the ideal output would be the mean of all machines' inputs. While variants of these fundamental problems have been considered since seminal work by Tsitsiklis & Luo (1987), the task has seen renewed attention recently in the context of distributed machine learning. In particular, variance reduction is a key component in data-parallel distributed stochastic gradient descent (SGD), the standard way to parallelize the training of deep neural networks, e.g. Bottou (2010); Abadi et al. (2016), where it is used to estimate the average of the gradient updates obtained in parallel at the nodes. Thus, several prior works have proposed efficient compression schemes for variance reduction or mean estimation, see e.g. Suresh et al. (2017); Alistarh et al. (2017); Ramezani-Kebrya et al. (2019); Gandikota et al. (2019), and Ben-Nun & Hoefler (2019) for a general survey of practical distribution schemes. These schemes seek to quantize nodes' inputs coordinatewise to one of a limited collection of values, in order to then efficiently encode and transmit the quantized values. A trade-off then arises between the number of bits sent and the added variance due to quantization. Since the measure of output quality is variance, it appears most natural to evaluate it with respect to the input variance, in order to show that variance reduction is indeed achieved.
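To make the norm dependence concrete, the following is a minimal sketch (our own illustration, not any specific published scheme) of coordinatewise unbiased stochastic quantization onto a uniform grid covering [−R, R], where R is an assumed norm bound; the function and parameter names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

def stochastic_quantize(x, levels, R):
    """Quantize each coordinate of x onto a uniform grid of `levels`
    points spanning [-R, R], rounding up or down at random so that
    E[Q(x)] = x coordinatewise (an unbiased quantizer)."""
    step = 2.0 * R / (levels - 1)        # grid spacing
    t = (x + R) / step                   # position on the grid
    lower = np.floor(t)
    round_up = rng.random(x.shape) < (t - lower)  # round up w.p. frac part
    return (lower + round_up) * step - R

x = rng.normal(size=100_000)
R = np.abs(x).max()
q = stochastic_quantize(x, levels=16, R=R)
# Per-coordinate quantization variance is at most step^2 / 4, so the added
# variance scales with the norm bound R, not with the spread of the inputs.
```

Note the coupling this exhibits: the grid must cover [−R, R], so the step size, and hence the added variance, grows with the norm bound rather than with the distance between inputs.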
Surprisingly, however, we are aware of no previous works which do so; all existing methods give bounds on output variance in terms of the squared input norm. This is clearly suboptimal when the squared norm is higher than the variance, i.e., when the inputs are not centered around the origin. In some practical scenarios this causes output variance to be higher than input variance, as we demonstrate in Section 4.

Contributions. In this paper, we provide the first bounds for distributed mean estimation and variance reduction which remain tight when inputs are not centered around the origin. Our results are based on new lattice-based quantization techniques, which may be of independent interest, and come with matching lower bounds and practical extensions. More precisely, our contributions are as follows:

• For distributed mean estimation, we show that, to achieve a reduction by a factor q in the input 'variance' (which we define to be the maximum squared distance between inputs), it is necessary and sufficient for machines to communicate Θ(d log q) bits.

• For variance reduction, we show tight Θ(d log n) bounds on the worst-case number of communication bits required to achieve the optimal Θ(n)-factor variance reduction by n nodes over d-dimensional inputs, and indeed to achieve any variance reduction at all. We then show how, by incorporating error detection into our quantization scheme, we can also obtain tight bounds on the number of bits required in expectation.

• We show how to efficiently instantiate our lattice-based quantization framework in practice, with guarantees. In particular, we devise a variant of the scheme which ensures close-to-optimal communication-variance bounds even for the standard cubic lattice, and use it to obtain improvements relative to the best previously known methods for distributed mean estimation, on both synthetic and real-world tasks.
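The phenomenon described above can be reproduced in a few lines. The following toy example (our own, not the paper's algorithm; all names and constants are illustrative) draws inputs tightly clustered around a large-norm mean and quantizes them with an unbiased norm-bounded uniform grid; the added quantization variance then far exceeds the inputs' own variance around the mean.

```python
import numpy as np

rng = np.random.default_rng(1)
d, n = 1000, 8
mu = 100.0 * np.ones(d)                          # large-norm mean
xs = mu + rng.normal(scale=0.1, size=(n, d))     # inputs tightly clustered

# Inputs' variance around the mean: small, since they are concentrated.
input_var = np.mean([np.sum((x - mu) ** 2) for x in xs])

# Norm-based unbiased stochastic quantization with s grid levels per
# coordinate: the grid spans [-R, R] for a norm bound R, so the grid step
# (and hence the added variance) grows with the norm, not with the spread.
R = np.max(np.abs(xs))
s = 16
step = 2 * R / (s - 1)

def quantize(x):
    t = (x + R) / step
    lo = np.floor(t)
    return (lo + (rng.random(d) < t - lo)) * step - R

quant_var = np.mean([np.sum((quantize(x) - x) ** 2) for x in xs])
# quant_var dwarfs input_var: quantization added far more variance than
# the inputs carried, so naive norm-based 'variance reduction' increases it.
```

This is exactly the regime the contributions above target: bounds depending on the maximum distance between inputs remain small here, while norm-dependent bounds blow up with ‖µ‖.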

1.1. Problem Definitions and Discussion

MeanEstimation is defined as follows: we have n machines v, each of which receives as input a vector x_v ∈ ℝ^d. We also assume that all machines receive a common value y, with the guarantee that for any machines u, v, ‖x_u − x_v‖ ≤ y. Our goal is for all machines to output the same value EST ∈ ℝ^d, which is an unbiased estimator of the mean µ = (1/n) Σ_{v∈M} x_v, i.e., E[EST] = µ, with variance as low as possible. Notice that the input specification is entirely deterministic; any randomness in the output arises only from the algorithm used.

In the VarianceReduction variant, we again have a set of n machines, and now an unknown true vector ∇. Each machine v receives as input an independent unbiased estimator x_v of ∇ (i.e., E[x_v] = ∇) with variance E‖x_v − ∇‖² ≤ σ². Machines are assumed to have knowledge of σ. Our goal is for all machines to output the same value EST ∈ ℝ^d, which is an unbiased estimator of ∇, i.e., E[EST] = ∇, with low variance. Since the input is random, output randomness now stems from this input randomness as well as any randomness in the algorithm.

VarianceReduction is common, for instance, in the context of gradient-based optimization of machine learning models, where we assume that each machine v processes local samples in order to obtain a stochastic gradient g_v, which is an unbiased estimator of the true gradient ∇, with variance bound σ². If we directly averaged the local stochastic gradients g_v, we


