LEARNING TO GENERATE WASSERSTEIN BARYCENTERS

Abstract

Optimal transport is a notoriously difficult problem to solve numerically, and current approaches often remain intractable for very large scale applications such as those encountered in machine learning. Computing Wasserstein barycenters (the problem of finding measures in-between given input measures in the optimal transport sense) is even more computationally demanding. By training a deep convolutional neural network, we improve by a factor of 60 the computational speed of Wasserstein barycenters over the fastest state-of-the-art approach on the GPU, resulting in millisecond computation times on 512 × 512 regular grids. We show that our network, trained on Wasserstein barycenters of pairs of measures, generalizes well to the problem of finding Wasserstein barycenters of more than two measures. We validate our approach on synthetic shapes generated via Constructive Solid Geometry as well as on the "Quick, Draw" sketches dataset.

1. INTRODUCTION

Optimal transport is becoming widespread in machine learning, but also in computer graphics, vision and many other disciplines. Its framework allows for comparing probability distributions, shapes or images, as well as producing interpolations of these data. As a result, it has been used in machine learning as a loss for training neural networks (Arjovsky et al., 2017), as a manifold for dictionary learning (Schmitz et al., 2018), clustering (Mi et al., 2018) and metric learning (Heitz et al., 2019), as a way to sample an embedding (Liutkus et al., 2019), for transfer learning (Courty et al., 2014), and in many other applications (see Sec. 2.3). However, despite recent progress in computational optimal transport, many of these applications have remained limited to small datasets due to the substantial cost of optimal transport, in terms of both speed and memory.

We tackle the problem of efficiently computing Wasserstein barycenters of measures discretized on regular grids, a setting common to several of these machine learning applications. Wasserstein barycenters are interpolations of two or more probability distributions under optimal transport distances. As such, a common way to obtain them is to minimize a functional involving optimal transport distances or transport plans, which is a very costly process. Instead, we directly predict Wasserstein barycenters by training a Deep Convolutional Neural Network (DCNN) specific to this task. An important challenge behind our work is to build an architecture that can handle a variable number of input measures with associated weights without needing to retrain a specific network. To achieve this, we adapt an architecture designed for and trained with two input measures, and show that the modified network can compute barycenters of more than two measures with no retraining.
Directly predicting Wasserstein barycenters avoids the need to compute a Wasserstein embedding (Courty et al., 2017), and our experiments suggest that this results in better Wasserstein barycenter approximations. Our implementation is publicly available¹.

Contributions. This paper introduces a method to compute Wasserstein barycenters in milliseconds. It shows that this can be done by learning Wasserstein barycenters of only two measures on a dataset of random shapes using a DCNN, and by adapting this DCNN to handle multiple input measures without retraining. The proposed approach is 60x faster than the fastest state-of-the-art GPU library, and performs better than Wasserstein embeddings.

2.1. WASSERSTEIN DISTANCES AND APPROXIMATIONS

Optimal transport seeks the best way to warp a given probability measure μ₀ into another given probability measure μ₁ by minimizing the total cost of moving individual "particles of earth". We restrict our description to discrete distributions. In this setting, finding the optimal transport between two probability measures is often achieved by solving a large linear program (Kantorovich, 1942); more details on this theory and on numerical tools can be found in the book of Peyré et al. (2019). This minimization results in the so-called Wasserstein distance, the mathematical distance defined by the total cost of reshaping μ₀ into μ₁. This distance can be used to compare probability distributions, in particular in a machine learning context. The minimization also yields a transport plan, a matrix P(x, y) representing the amount of mass of μ₀ traveling from location x in μ₀ towards location y in μ₁. However, the Wasserstein distance is notoriously difficult to compute: the corresponding linear program is huge, and dedicated solvers typically solve this problem in O(N³ log N), with N the size of the discretization of the input measures. Recently, numerous approaches have attempted to approximate Wasserstein distances. One of the most efficient methods, the so-called Sinkhorn algorithm, introduces an entropic regularization that allows computing such distances by iteratively performing fast matrix-vector multiplications (Cuturi, 2013) or convolutions in the case of regular grids (Solomon et al., 2015). However, this comes at the expense of smoothing the transport plan and losing the guarantees of a mathematical distance (in particular, the regularized cost W_ε(μ₀, μ₀) ≠ 0). These issues are addressed by Sinkhorn divergences (Feydy et al., 2018; Genevay et al., 2017).
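As a minimal illustration of this matrix-scaling scheme (an illustrative numpy sketch with our own function and variable names, not the implementation of any of the cited works), the regularized plan is obtained by alternating two diagonal-scaling updates on the Gibbs kernel K = exp(−C/ε):

```python
import numpy as np

def sinkhorn(mu0, mu1, C, eps=0.05, n_iters=500):
    """Entropy-regularized OT in the style of Cuturi (2013): alternate
    scaling updates so that P = diag(u) K diag(v) matches both marginals."""
    K = np.exp(-C / eps)              # Gibbs kernel
    u = np.ones_like(mu0)
    for _ in range(n_iters):
        v = mu1 / (K.T @ u)           # enforce the second marginal
        u = mu0 / (K @ v)             # enforce the first marginal
    P = u[:, None] * K * v[None, :]   # regularized transport plan
    return P, np.sum(P * C)           # plan and regularized cost

# Toy example: two bumps on a 1-d grid of 50 points.
x = np.linspace(0, 1, 50)
mu0 = np.exp(-(x - 0.3) ** 2 / 0.01); mu0 /= mu0.sum()
mu1 = np.exp(-(x - 0.7) ** 2 / 0.01); mu1 /= mu1.sum()
C = (x[:, None] - x[None, :]) ** 2    # squared-distance cost matrix
P, cost = sinkhorn(mu0, mu1, C)
```

Note that, as discussed above, this plan is blurred by the regularization, and the returned cost does not vanish when mu0 equals mu1.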
This approach symmetrizes the entropy-regularized optimal transport cost, restoring guarantees on the resulting divergence (now S_ε(μ₀, μ₀) = 0 by construction, though the triangle inequality still does not hold) while also effectively reducing blur and maintaining a relatively fast numerical algorithm. The authors show that this divergence interpolates between optimal transport distances and Maximum Mean Discrepancies. Sinkhorn divergences are implemented in the GeomLoss library (Feydy, 2019), which relies on a specific computational scheme on the GPU (Feydy et al., 2019; 2018; Schmitzer, 2019) and constitutes the state of the art in terms of speed and quality of approximation of optimal transport-like distances.
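Concretely, writing OT_ε for the entropy-regularized cost, the symmetrization subtracts the two self-transport terms (Genevay et al., 2017; Feydy et al., 2018):

```latex
S_\varepsilon(\mu_0, \mu_1)
  = \mathrm{OT}_\varepsilon(\mu_0, \mu_1)
  - \tfrac{1}{2}\,\mathrm{OT}_\varepsilon(\mu_0, \mu_0)
  - \tfrac{1}{2}\,\mathrm{OT}_\varepsilon(\mu_1, \mu_1)
```

so that S_ε(μ₀, μ₀) = 0 holds by construction, with S_ε recovering optimal transport as ε → 0 and a Maximum Mean Discrepancy as ε → ∞.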

2.2. WASSERSTEIN BARYCENTERS

The Wasserstein barycenter of a set of probability measures corresponds to the Fréchet mean of these measures under the Wasserstein distance (i.e., a weighted mean under the Wasserstein metric). Wasserstein barycenters interpolate between two or more probability measures by warping them (unlike Euclidean barycenters, which blend them). Similarly to Wasserstein distances, Wasserstein barycenters are very expensive to compute. An entropy-regularized approach based on Sinkhorn-like iterations allows efficiently computing blurred Wasserstein barycenters. Reducing blur via Sinkhorn divergences is also doable, but does not benefit from a very fast Sinkhorn-like algorithm: a weighted sum of Sinkhorn divergences needs to be iteratively minimized, which adds significant computational cost. In our approach, we rely on Sinkhorn divergence-based barycenters to generate training data for a Deep Convolutional Neural Network, and aim at speeding up this approach. Other fast transport-based barycenters include sliced and Radon Wasserstein barycenters, obtained via Wasserstein barycenters of 1-d projections (Bonneel et al., 2015), which we compare to.

A recent trend seeks linearizations or Euclidean embeddings of optimal transport problems. Notably, Nader & Guennebaud (2018) approximate Wasserstein barycenters by first solving for optimal transport maps from a uniform measure towards each of the n input measures, and then linearly combining these Monge maps. This allows for efficient computations, typically of the order of half a second for 512 × 512 images. A similar approach is taken within the documentation of the GeomLoss library (Feydy, 2019)², where a single step of a gradient descent initialized with a uniform distribution is used,
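For reference, the Sinkhorn-like barycenter iterations mentioned above can be sketched as iterative Bregman projections on a fixed grid; this numpy sketch is only illustrative (names and parameters are ours) and is distinct from the Sinkhorn divergence-based barycenters we use to generate training data:

```python
import numpy as np

def sinkhorn_barycenter(measures, weights, C, eps=0.05, n_iters=300):
    """Entropy-regularized Wasserstein barycenter on a fixed grid,
    via Sinkhorn-like iterations (iterative Bregman projections)."""
    K = np.exp(-C / eps)                      # shared Gibbs kernel
    v = [np.ones_like(m) for m in measures]   # one scaling per input measure
    for _ in range(n_iters):
        u = [m / (K @ vk) for m, vk in zip(measures, v)]
        # Weighted geometric mean of the K^T u terms: current barycenter.
        log_b = sum(w * np.log(K.T @ uk) for w, uk in zip(weights, u))
        b = np.exp(log_b)
        v = [b / (K.T @ uk) for uk in u]      # re-target each plan onto b
    return b

# Equal-weight barycenter of two bumps: a (blurred) bump near the middle.
x = np.linspace(0, 1, 50)
m0 = np.exp(-(x - 0.2) ** 2 / 0.005); m0 /= m0.sum()
m1 = np.exp(-(x - 0.8) ** 2 / 0.005); m1 /= m1.sum()
C = (x[:, None] - x[None, :]) ** 2
bary = sinkhorn_barycenter([m0, m1], [0.5, 0.5], C)
```

As noted above, the entropic term blurs the result; debiasing it with Sinkhorn divergences requires an outer minimization and is substantially slower.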



¹ https://github.com/iclr2021-anonymous-author/learning-to-generate-wasserstein-barycenters
² See https://www.kernel-operations.io/geomloss/_auto_examples/optimal_transport/plot_wasserstein_barycenters_2D.html

