THE UNBALANCED GROMOV WASSERSTEIN DISTANCE: CONIC FORMULATION AND RELAXATION

Abstract

Comparing metric measure spaces (i.e. metric spaces endowed with a probability distribution) is at the heart of many machine learning problems. This includes for instance predicting properties of molecules in quantum chemistry or generating graphs with varying connectivity. The most popular distance between such metric measure spaces is the Gromov-Wasserstein (GW) distance, which is the solution of a quadratic assignment problem. This distance has been successfully applied to supervised learning and generative modeling, for applications as diverse as quantum chemistry or natural language processing. The GW distance is however limited to the comparison of metric measure spaces endowed with a probability distribution. This strong limitation is problematic for many applications in ML where there is no a priori natural normalization of the total mass of the data. Furthermore, imposing an exact conservation of mass across spaces is not robust to outliers and often leads to irregular matchings. To alleviate these issues, we introduce two Unbalanced Gromov-Wasserstein formulations: a distance and a more tractable upper-bounding relaxation. They both allow the comparison of metric spaces equipped with arbitrary positive measures up to isometries. The first formulation is a positive and definite divergence based on a relaxation of the mass conservation constraint using a novel type of quadratically-homogeneous divergence. This divergence works hand in hand with the entropic regularization approach which is popular for solving large-scale optimal transport problems. We show that the underlying non-convex optimization problem can be efficiently tackled using a highly parallelizable and GPU-friendly iterative scheme. The second formulation is a distance between mm-spaces up to isometries based on a conic lifting. Lastly, we provide numerical simulations to highlight the salient features of the unbalanced divergence and its potential applications in ML.

1. INTRODUCTION

Comparing data distributions on different metric spaces is a basic problem in machine learning. This class of problems is for instance at the heart of surface matching (Bronstein et al., 2006) and graph matching (Xu et al., 2019) (equipping the surface or graph with its associated geodesic distance), regression problems in quantum chemistry (Gilmer et al., 2017) (viewing the molecules as distributions of points in $\mathbb{R}^3$) and natural language processing (Grave et al., 2019; Alvarez-Melis & Jaakkola, 2018) (where texts in different languages are embedded as point distributions in different vector spaces).

Metric measure spaces. The mathematical way to formalize these problems is to model the data as metric measure spaces (mm-spaces). An mm-space is denoted $\mathcal{X} = (X, d, \mu)$, where $X$ is a complete separable set endowed with a distance $d$ and a positive Borel measure $\mu \in \mathcal{M}_+(X)$. For instance, if $X = (x_i)_i$ is a finite set of points, then $\mu = \sum_i m_i \delta_{x_i}$ (here $\delta_{x_i}$ is the Dirac mass at $x_i$) is simply a set of positive weights $m_i = \mu(\{x_i\}) \geq 0$ associated to each point $x_i$, which account for its mass or importance. In particular, setting some $m_i$ to 0 is equivalent to removing the point $x_i$. We refer to Sturm (2012) for a mathematical account of the theory of mm-spaces. In all the applications highlighted above, it makes sense to perform the comparisons up to isometric transformations of the data. Two mm-spaces $\mathcal{X} = (X, d_X, \mu)$ and $\mathcal{Y} = (Y, d_Y, \nu)$ are considered to be equal (denoted $\mathcal{X} \sim \mathcal{Y}$) if they are isometric, meaning that there is a bijection $\psi : \operatorname{spt}(\mu) \to \operatorname{spt}(\nu)$ (where $\operatorname{spt}(\mu)$ is the support of $\mu$) such that $d_X(x, y) = d_Y(\psi(x), \psi(y))$ and $\psi_\sharp \mu = \nu$. Here $\psi_\sharp$ is the push-forward operator, so that $\psi_\sharp \mu = \nu$ is equivalent to imposing $\nu(A) = \mu(\psi^{-1}(A))$ for any set $A \subset Y$. For discrete spaces where $\mu = \sum_i m_i \delta_{x_i}$, one should then have $\nu = \psi_\sharp \mu = \sum_i m_i \delta_{\psi(x_i)}$.
As highlighted by Mémoli (2011), considering mm-spaces up to isometry is a powerful way to formalize and analyze a wide variety of problems, such as matching, regression and classification of distributions of points belonging to different spaces. The key to unlocking all these problems is the computation of a distance between mm-spaces up to isometry. So far, existing distances (reviewed below) assume that $\mu$ is a probability distribution, i.e. $\mu(X) = 1$. This constraint is not natural and is problematic for many practical applications in machine learning. The goal of this paper is to alleviate this restriction. We define for the first time a class of distances between unbalanced metric measure spaces, these distances being upper-bounded by divergences which can be approximated by an efficient numerical scheme.

Csiszár divergences. The simplest case is when $X = Y$ and one simply ignores the underlying metric. One can then use Csiszár divergences (or $\varphi$-divergences), which perform a pointwise comparison (in contrast with optimal transport distances, which perform a displacement comparison). Such a divergence is defined using an entropy function $\varphi : \mathbb{R}_+ \to [0, +\infty]$, which is a convex, lower semi-continuous, positive function with $\varphi(1) = 0$. The Csiszár $\varphi$-divergence reads
$$D_\varphi(\mu|\nu) \overset{\text{def.}}{=} \int_X \varphi\Big(\tfrac{\mathrm{d}\mu}{\mathrm{d}\nu}\Big)\,\mathrm{d}\nu + \varphi_\infty \int_X \mathrm{d}\mu^{\perp},$$
where $\mu = \tfrac{\mathrm{d}\mu}{\mathrm{d}\nu}\nu + \mu^{\perp}$ is the Lebesgue decomposition of $\mu$ with respect to $\nu$, and $\varphi_\infty \overset{\text{def.}}{=} \lim_{r \to \infty} \varphi(r)/r \in \mathbb{R} \cup \{+\infty\}$ is called the recession constant. The divergence $D_\varphi$ is convex, positive, 1-homogeneous and weak* lower semi-continuous; see Liero et al. (2015) for details. Particular instances of $\varphi$-divergences are the Kullback-Leibler divergence (KL), for $\varphi(r) = r \log(r) - r + 1$ (note that $\varphi_\infty = +\infty$), and the Total Variation distance (TV), for $\varphi(r) = |r - 1|$.

Balanced and unbalanced optimal transport. If the common embedding space $X$ is equipped with a distance $d(x, y)$, one can use more elaborate methods such as optimal transport (OT) distances, which are computed by solving convex optimization problems.
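For discrete measures supported on the same finite set, the two instances of $D_\varphi$ above can be evaluated directly from the definition. Here is a sketch under that discrete assumption (function names are ours); for KL, $\varphi_\infty = +\infty$ means the divergence blows up as soon as $\mu$ puts mass where $\nu$ does not:

```python
import numpy as np

def kl_divergence(mu, nu):
    """KL: phi(r) = r*log(r) - r + 1, with phi_inf = +inf, so any mass of mu
    singular w.r.t. nu (mu_i > 0 where nu_i = 0) yields an infinite value."""
    mu, nu = np.asarray(mu, float), np.asarray(nu, float)
    if np.any((nu == 0) & (mu > 0)):
        return np.inf
    mask = mu > 0  # terms with mu_i = 0 contribute nu_i, absorbed below
    return np.sum(mu[mask] * np.log(mu[mask] / nu[mask])) - mu.sum() + nu.sum()

def tv_divergence(mu, nu):
    """TV: phi(r) = |r - 1|, with phi_inf = 1, so the singular part of mu
    simply contributes its total mass."""
    return np.sum(np.abs(np.asarray(mu, float) - np.asarray(nu, float)))
```

Both functions accept arbitrary nonnegative weight vectors, not just probability vectors, in line with the 1-homogeneity of $D_\varphi$.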
This type of method has proven useful for ML problems as diverse as domain adaptation (Courty et al., 2014), supervised learning over histograms (Frogner et al., 2015) and unsupervised learning of generative models (Arjovsky et al., 2017). In this case, the extension from probability distributions to arbitrary positive measures $(\mu, \nu) \in \mathcal{M}_+(X)^2$ is now well understood and corresponds to the theory of unbalanced OT. Following Liero et al. (2015); Chizat et al. (2018c), a family of unbalanced Wasserstein distances is defined by solving
$$\mathrm{UW}(\mu, \nu)^q \overset{\text{def.}}{=} \inf_{\pi \in \mathcal{M}_+(X \times X)} \int \lambda(d(x, y))\,\mathrm{d}\pi(x, y) + D_\varphi(\pi_1|\mu) + D_\varphi(\pi_2|\nu).$$
Here $(\pi_1, \pi_2)$ are the two marginals of the joint distribution $\pi$, defined by $\pi_1(A) = \pi(A \times X)$ for $A \subset X$. The mapping $\lambda : \mathbb{R}_+ \to \mathbb{R}$ and the exponent $q \geq 1$ should be chosen wisely to ensure, for instance, that UW defines a distance (see Section 2.2.1). It is frequent to take $\rho D_\varphi$ instead of $D_\varphi$ (i.e. to replace $\varphi$ by $\rho\varphi$) to adjust the strength of the marginal penalization. Balanced OT is retrieved with the convex indicator $\varphi = \iota_{\{1\}}$ or by taking the limit $\rho \to +\infty$, which enforces $\pi_1 = \mu$ and $\pi_2 = \nu$. When $0 < \rho < +\infty$, unbalanced OT operates a trade-off between transportation and creation of mass, which is crucial to be robust to outliers in the data and to cope with mass variations in the modes of the distributions. For supervised tasks, the value of $\rho$ should be cross-validated to obtain the best performance. Its use is gaining popularity in applications such as medical image registration (Feydy et al., 2019a), videos (Lee et al., 2019), generative learning (Balaji et al., 2020) and gradient flows to train neural networks (Chizat & Bach, 2018; Rotskoff et al., 2019). Furthermore, existing efficient algorithms for balanced OT extend to this unbalanced problem. In particular, Sinkhorn's iterations, introduced in ML for balanced OT by Cuturi (2013), extend to unbalanced OT (Chizat et al., 2018a; Séjourné et al., 2019), as detailed in Section 3.
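As a rough illustration of how Sinkhorn's iterations adapt to the unbalanced setting with KL marginal penalties, here is a log-domain sketch in NumPy. The damping factor $\rho/(\rho + \varepsilon)$ in the dual updates follows the scaling algorithms of Chizat et al. (2018a); all names, defaults and the stopping rule (a fixed iteration count) are illustrative choices of ours:

```python
import numpy as np

def _lse(A, axis):
    """Numerically stable log-sum-exp along one axis."""
    m = np.max(A, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(A - m), axis=axis))

def unbalanced_sinkhorn(mu, nu, C, eps=0.01, rho=10.0, n_iter=500):
    """Entropic unbalanced OT with KL marginal penalties of strength rho.
    Returns the transport plan pi; rho -> +inf recovers balanced Sinkhorn."""
    mu, nu, C = np.asarray(mu, float), np.asarray(nu, float), np.asarray(C, float)
    f, g = np.zeros(len(mu)), np.zeros(len(nu))
    scale = rho / (rho + eps)  # damping of the dual updates due to the KL prox
    log_mu, log_nu = np.log(mu), np.log(nu)
    for _ in range(n_iter):
        f = -scale * eps * _lse((g[None, :] - C) / eps + log_nu[None, :], axis=1)
        g = -scale * eps * _lse((f[:, None] - C) / eps + log_mu[:, None], axis=0)
    return mu[:, None] * nu[None, :] * np.exp((f[:, None] + g[None, :] - C) / eps)
```

The updates are pure tensor operations (one log-sum-exp per half-step), which is what makes this family of schemes parallelizable and GPU-friendly, as the abstract emphasizes for the GW setting.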
The Gromov-Wasserstein distance and its applications. The Gromov-Wasserstein (GW) distance (Mémoli, 2011; Sturm, 2012) generalizes the notion of OT to the setting of mm-spaces up to isometries. It corresponds to replacing the linear cost $\int \lambda(d)\,\mathrm{d}\pi$ of OT by a quadratic function of the coupling:
$$\mathrm{GW}(\mathcal{X}, \mathcal{Y})^q \overset{\text{def.}}{=} \min_{\pi \in \mathcal{M}_+(X \times Y)} \Big\{ \int \lambda\big(|d_X(x, x') - d_Y(y, y')|\big)\,\mathrm{d}\pi(x, y)\,\mathrm{d}\pi(x', y') \,:\, \pi_1 = \mu,\ \pi_2 = \nu \Big\}.$$
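For discrete spaces with distance matrices $D_X$ and $D_Y$, the quadratic objective above can be evaluated for a fixed coupling $\pi$ by a tensor contraction. The sketch below (names are ours) only evaluates the cost, it does not perform the minimization over $\pi$; it also materializes the full 4-way tensor, which is only viable for small problems (for the squared loss, a well-known factorization avoids forming this tensor):

```python
import numpy as np

def gw_cost(Dx, Dy, pi, lam=lambda t: t ** 2):
    """Quadratic GW objective for a fixed discrete coupling pi:
        sum_{i,j,k,l} lam(|Dx[i,k] - Dy[j,l]|) * pi[i,j] * pi[k,l].
    Naive O(n^2 m^2) evaluation via a 4-way tensor of distance gaps."""
    gap = np.abs(Dx[:, None, :, None] - Dy[None, :, None, :])  # gap[i,j,k,l]
    return np.einsum('ijkl,ij,kl->', lam(gap), pi, pi)
```

For instance, two isometric two-point spaces matched by the identity coupling have zero cost, while stretching one distance matrix makes the cost positive.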

