LEARNING TO GENERATE WASSERSTEIN BARYCENTERS

Abstract

Optimal transport is a notoriously difficult problem to solve numerically, and current approaches often remain intractable for very large scale applications such as those encountered in machine learning. Wasserstein barycenters (the problem of finding measures in-between given input measures in the optimal transport sense) are even more computationally demanding. By training a deep convolutional neural network, we improve the computational speed of Wasserstein barycenters by a factor of 60 over the fastest state-of-the-art GPU approach, resulting in millisecond computation times on 512 × 512 regular grids. We show that our network, trained on Wasserstein barycenters of pairs of measures, generalizes well to the problem of finding Wasserstein barycenters of more than two measures. We validate our approach on synthetic shapes generated via Constructive Solid Geometry as well as on the "Quick, Draw" sketches dataset.

1. INTRODUCTION

Optimal transport is becoming widespread in machine learning, as well as in computer graphics, vision and many other disciplines. Its framework allows for comparing probability distributions, shapes or images, and for producing interpolations of these data. As a result, it has been used in machine learning as a loss for training neural networks (Arjovsky et al., 2017), as a manifold for dictionary learning (Schmitz et al., 2018), for clustering (Mi et al., 2018) and metric learning (Heitz et al., 2019), as a way to sample an embedding (Liutkus et al., 2019), for transfer learning (Courty et al., 2014), and in many other applications (see Sec. 2.3). However, despite recent progress in computational optimal transport, many of these applications have remained limited to small datasets due to the substantial cost of optimal transport, in terms of both speed and memory.

We tackle the problem of efficiently computing Wasserstein barycenters of measures discretized on regular grids, a setting common to several of these machine learning applications. Wasserstein barycenters are interpolations of two or more probability distributions under optimal transport distances. As such, a common way to obtain them is to minimize a functional involving optimal transport distances or transport plans, which is a very costly process. Instead, we directly predict Wasserstein barycenters by training a Deep Convolutional Neural Network (DCNN) specific to this task.

An important challenge behind our work is to build an architecture that can handle a variable number of input measures with associated weights without retraining a specific network. To achieve this, we adapt an architecture designed for and trained with two input measures, and show that the modified network can compute barycenters of more than two measures without retraining.
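To make the classical baseline concrete: the "minimization of a functional involving optimal transport distances" mentioned above is commonly solved with entropic regularization and iterative Bregman projections (Benamou et al., 2015). Below is a minimal NumPy sketch of that baseline (not the network proposed in this paper); the grid size, the Gaussian inputs, the regularization strength and the iteration count are illustrative choices, not values from the paper.

```python
import numpy as np

def entropic_barycenter(P, M, reg, weights, n_iter=1000):
    """Entropy-regularized Wasserstein barycenter via iterative
    Bregman projections (Benamou et al., 2015).

    P       : (n, k) array whose columns are probability vectors
    M       : (n, n) ground-cost matrix
    reg     : entropic regularization strength
    weights : (k,) barycentric weights summing to 1
    """
    K = np.exp(-M / reg)                       # Gibbs kernel
    v = np.ones_like(P)
    for _ in range(n_iter):
        u = P / (K @ v)                        # match the input marginals
        Ku = K.T @ u
        # shared marginal = weighted geometric mean of the plans' marginals
        b = np.exp((weights * np.log(v * Ku)).sum(axis=1))
        v = b[:, None] / Ku
    return b

# Toy example: interpolate two 1-D Gaussians on a regular grid.
x = np.linspace(0.0, 1.0, 100)
M = (x[:, None] - x[None, :]) ** 2             # squared-distance cost
gauss = lambda m: np.exp(-((x - m) ** 2) / (2 * 0.05 ** 2))
P = np.stack([gauss(0.3) / gauss(0.3).sum(),
              gauss(0.7) / gauss(0.7).sum()], axis=1)
b = entropic_barycenter(P, M, reg=1e-2, weights=np.array([0.5, 0.5]))
# With equal weights, b peaks near x = 0.5, the midpoint of the two inputs.
```

Note how the per-iteration cost is dominated by the dense kernel products `K @ v`; on an n × n image this matrix has (n²)² entries, which is exactly the kind of scaling that motivates replacing the iterative solver with a single network forward pass.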
Directly predicting Wasserstein barycenters avoids the need to compute a Wasserstein embedding (Courty et al., 2017), and our experiments suggest that this results in better approximations of Wasserstein barycenters. Our implementation is publicly available¹.

Contributions. This paper introduces a method to compute Wasserstein barycenters in milliseconds. It shows that this can be done by learning Wasserstein barycenters of only two measures on a dataset of random shapes using a DCNN, and by adapting this DCNN to handle multiple input



¹ https://github.com/iclr2021-anonymous-author/learning-to-generate-wasserstein-barycenters

