CONTINUOUS WASSERSTEIN-2 BARYCENTER ESTIMATION WITHOUT MINIMAX OPTIMIZATION

Abstract

Wasserstein barycenters provide a geometric notion of the weighted average of probability measures based on optimal transport. In this paper, we present a scalable algorithm to compute Wasserstein-2 barycenters given sample access to the input measures, which are not restricted to being discrete. While past approaches rely on entropic or quadratic regularization, we employ input convex neural networks and cycle-consistency regularization to avoid introducing bias. As a result, our approach does not resort to minimax optimization. We provide a theoretical analysis of error bounds as well as empirical evidence of the effectiveness of the proposed approach in low-dimensional qualitative scenarios and high-dimensional quantitative experiments.

1. INTRODUCTION

Wasserstein barycenters have become popular due to their ability to represent the average of probability measures in a geometrically meaningful way. Techniques for computing Wasserstein barycenters have been successfully applied to many computational problems. In image processing, Wasserstein barycenters are used for color and style transfer (Rabin et al., 2014; Mroueh, 2019) and texture synthesis (Rabin et al., 2011). In geometry processing, shape interpolation can be performed by computing barycenters (Solomon et al., 2015). In online machine learning, barycenters are used for aggregating probabilistic predictions of experts (Korotin et al., 2019b). Within the context of Bayesian inference, the barycenter of subset posteriors converges to the full-data posterior, enabling efficient computational methods based on finding barycenters (Srivastava et al., 2015; 2018). Fast and accurate barycenter algorithms exist for discrete distributions (see Peyré et al. (2019) for a survey), while the continuous case is more difficult and remained largely unexplored until recently (Li et al., 2020; Fan et al., 2020; Cohen et al., 2020). The discrete methods scale poorly with the number of support points of the barycenter and thus cannot approximate continuous barycenters well, especially in high dimensions.

In this paper, we present a method to compute Wasserstein-2 barycenters of continuous distributions based on a novel regularized dual formulation in which the convex potentials are parameterized by input convex neural networks (Amos et al., 2017). Our algorithm is straightforward and neither introduces bias (e.g., Li et al. (2020)) nor requires minimax optimization (e.g., Fan et al. (2020)). This is made possible by combining a new congruence regularization term with cycle-consistency regularization (Korotin et al., 2019a).
As we will show in the analysis, thanks to the properties of the Wasserstein-2 distance, the gradients of the resulting convex potentials "push" the input distributions close to the true barycenter, yielding a good approximation of the barycenter.
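Since the potentials are parameterized by input convex neural networks, it may help to see that construction concretely. Below is a minimal forward-pass sketch of an ICNN (Amos et al., 2017); the `ICNN` class name, layer sizes, and random initialization are illustrative, not the architecture used in our experiments. Convexity in the input follows because the z-path weights are non-negative and the activation (softplus) is convex and non-decreasing; we check this numerically via midpoint convexity.

```python
import numpy as np

rng = np.random.default_rng(0)

class ICNN:
    """Forward pass of a minimal input convex neural network: psi(x) is
    convex in x because the z-path weights are non-negative and softplus
    is convex and non-decreasing. Sizes are illustrative."""
    def __init__(self, dim, hidden=16, layers=3):
        self.Wx = [rng.normal(size=(dim, hidden)) for _ in range(layers)]
        self.b = [rng.normal(size=hidden) for _ in range(layers)]
        # non-negative z-path weights preserve convexity
        self.Wz = [np.abs(rng.normal(size=(hidden, hidden)))
                   for _ in range(layers - 1)]
        self.wout = np.abs(rng.normal(size=hidden))

    @staticmethod
    def softplus(t):
        return np.logaddexp(0.0, t)  # log(1 + e^t), numerically stable

    def __call__(self, x):
        z = self.softplus(x @ self.Wx[0] + self.b[0])
        for Wx, b, Wz in zip(self.Wx[1:], self.b[1:], self.Wz):
            z = self.softplus(x @ Wx + b + z @ Wz)
        return z @ self.wout  # one scalar psi(x) per input row

psi = ICNN(dim=2)
x, y = rng.normal(size=(8, 2)), rng.normal(size=(8, 2))
# midpoint convexity: psi((x+y)/2) <= (psi(x) + psi(y)) / 2
assert np.all(psi((x + y) / 2) <= (psi(x) + psi(y)) / 2 + 1e-9)
```

In a trainable implementation the non-negativity constraint is typically enforced by clipping or re-parameterizing the z-path weights after each gradient step, and the transport map is recovered as the gradient of psi via automatic differentiation.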

2. PRELIMINARIES

We denote the set of all Borel probability measures on $\mathbb{R}^D$ with finite second moment by $\mathcal{P}_2(\mathbb{R}^D)$. We use $\mathcal{P}_{2,ac}(\mathbb{R}^D) \subset \mathcal{P}_2(\mathbb{R}^D)$ to denote the subset of all absolutely continuous measures (w.r.t. the Lebesgue measure).

Wasserstein-2 distance. For $\mathbb{P}, \mathbb{Q} \in \mathcal{P}_2(\mathbb{R}^D)$, the Wasserstein-2 distance is defined by
$$W_2^2(\mathbb{P}, \mathbb{Q}) \overset{\text{def}}{=} \min_{\pi \in \Pi(\mathbb{P}, \mathbb{Q})} \int_{\mathbb{R}^D \times \mathbb{R}^D} \|x - y\|_2^2 \, d\pi(x, y), \qquad (1)$$
where $\Pi(\mathbb{P}, \mathbb{Q})$ is the set of probability measures on $\mathbb{R}^D \times \mathbb{R}^D$ whose marginals are $\mathbb{P}$ and $\mathbb{Q}$, respectively. This definition is known as Kantorovich's primal form of the transport distance (Kantorovitch, 1958). The Wasserstein-2 distance $W_2$ is well-studied in the theory of optimal transport (Brenier, 1991; McCann et al., 1995). In particular, it has a dual formulation (Villani, 2003):
$$W_2^2(\mathbb{P}, \mathbb{Q}) = \int_{\mathbb{R}^D} \|x\|_2^2 \, d\mathbb{P}(x) + \int_{\mathbb{R}^D} \|y\|_2^2 \, d\mathbb{Q}(y) - 2 \min_{\psi \in \mathrm{Conv}} \left[ \int_{\mathbb{R}^D} \psi(x) \, d\mathbb{P}(x) + \int_{\mathbb{R}^D} \overline{\psi}(y) \, d\mathbb{Q}(y) \right], \qquad (2)$$
where the minimum is taken over all convex functions (potentials) $\psi : \mathbb{R}^D \to \mathbb{R} \cup \{\infty\}$, and $\overline{\psi}(y) \overset{\text{def}}{=} \max_{x \in \mathbb{R}^D} \left[ \langle x, y \rangle - \psi(x) \right]$ is the convex conjugate of $\psi$ (Fenchel, 1949), which is also a convex function. The optimal potential $\psi^*$ is defined up to an additive constant.

Brenier (1991) shows that if $\mathbb{P}$ does not give mass to sets of dimension at most $D - 1$, then the optimal plan is uniquely determined by $\pi^* = [\mathrm{id}_{\mathbb{R}^D}, T^*] \sharp \mathbb{P}$, where $T^* : \mathbb{R}^D \to \mathbb{R}^D$ is the unique solution to Monge's problem
$$T^* = \operatorname*{arg\,min}_{T \sharp \mathbb{P} = \mathbb{Q}} \int_{\mathbb{R}^D} \|x - T(x)\|_2^2 \, d\mathbb{P}(x). \qquad (3)$$
The connection between $T^*$ and the dual formulation (2) is that $T^* = \nabla \psi^*$, where $\psi^*$ is the optimal solution of (2). Additionally, if $\mathbb{Q}$ does not give mass to sets of dimension at most $D - 1$, then $T^*$ is invertible and
$$T^*(x) = \nabla \psi^*(x) = (\nabla \overline{\psi^*})^{-1}(x), \qquad (T^*)^{-1}(y) = \nabla \overline{\psi^*}(y) = (\nabla \psi^*)^{-1}(y).$$
In particular, the above discussion applies to the case where $\mathbb{P}, \mathbb{Q} \in \mathcal{P}_{2,ac}(\mathbb{R}^D)$.

Wasserstein-2 barycenter. Let $\mathbb{P}_1, \dots, \mathbb{P}_N \in \mathcal{P}_{2,ac}(\mathbb{R}^D)$. Then, their barycenter w.r.t. weights $\alpha_1, \dots, \alpha_N$ ($\alpha_n > 0$ and $\sum_{n=1}^N \alpha_n = 1$) is
$$\overline{\mathbb{P}} \overset{\text{def}}{=} \operatorname*{arg\,min}_{\mathbb{P} \in \mathcal{P}_2(\mathbb{R}^D)} \sum_{n=1}^N \alpha_n W_2^2(\mathbb{P}_n, \mathbb{P}). \qquad (4)$$
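For intuition, Brenier's theorem can be verified in the 1-D Gaussian special case, where the optimal map, the convex potential, and $W_2^2$ are all available in closed form; the numeric values below are illustrative.

```python
import numpy as np

# Closed-form check of Brenier's theorem for 1-D Gaussians: for
# P = N(m1, s1^2) and Q = N(m2, s2^2), the optimal (Monge) map is
#   T*(x) = m2 + (s2 / s1) * (x - m1),
# the gradient of the convex potential
#   psi*(x) = m2 * x + (s2 / (2 * s1)) * (x - m1)**2,
# and W2^2(P, Q) = (m1 - m2)**2 + (s1 - s2)**2.
m1, s1, m2, s2 = 0.0, 1.0, 3.0, 2.0

rng = np.random.default_rng(0)
x = rng.normal(m1, s1, size=200_000)   # samples from P
y = m2 + (s2 / s1) * (x - m1)          # T* pushes P forward to Q

print(y.mean(), y.std())               # ~ (3.0, 2.0): T* sharp P = Q
w2_sq = (m1 - m2) ** 2 + (s1 - s2) ** 2
print(w2_sq)                           # exact W2^2 = 10.0
# Monge cost of T* matches W2^2 up to sampling error
print(np.mean((x - y) ** 2))           # ~ 10.0
```

The same construction works in $\mathbb{R}^D$ with the standard deviations replaced by covariance matrices; the 1-D case keeps the check readable.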
Throughout this paper, we assume that at least one of $\mathbb{P}_1, \dots, \mathbb{P}_N \in \mathcal{P}_{2,ac}(\mathbb{R}^D)$ has bounded density. Under this assumption, $\overline{\mathbb{P}}$ is unique and absolutely continuous, i.e., $\overline{\mathbb{P}} \in \mathcal{P}_{2,ac}(\mathbb{R}^D)$, and it has bounded density (Agueh & Carlier, 2011, Definition 3.6 & Theorem 5.1). For $n \in \{1, 2, \dots, N\}$, let $(\psi_n^*, \overline{\psi_n^*})$ be the optimal pair of (mutually) conjugate potentials that transport $\mathbb{P}_n$ to $\overline{\mathbb{P}}$, i.e., $\nabla \psi_n^* \sharp \mathbb{P}_n = \overline{\mathbb{P}}$ and $\nabla \overline{\psi_n^*} \sharp \overline{\mathbb{P}} = \mathbb{P}_n$. Then $\{\psi_n^*\}$ satisfy
$$\sum_{n=1}^N \alpha_n \nabla \overline{\psi_n^*}(x) = x \quad \text{and} \quad \sum_{n=1}^N \alpha_n \overline{\psi_n^*}(x) = \frac{\|x\|_2^2}{2} + c \qquad (5)$$
for all $x \in \mathbb{R}^D$ (Agueh & Carlier, 2011; Álvarez-Esteban et al., 2016). Since optimal potentials are defined up to an additive constant, for convenience we set $c = 0$. The condition (5) serves as the basis for our algorithm for computing Wasserstein-2 barycenters. We say that potentials $\psi_1, \dots, \psi_N$ are congruent w.r.t. weights $\alpha_1, \dots, \alpha_N$ if their conjugate potentials satisfy (5), i.e., $\sum_{n=1}^N \alpha_n \overline{\psi_n}(x) = \frac{\|x\|_2^2}{2}$ for all $x \in \mathbb{R}^D$.
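The congruence condition (5) can be checked numerically in the 1-D Gaussian case, where both the barycenter and the conjugate potentials are available in closed form; the weights, means, and standard deviations below are illustrative.

```python
import numpy as np

# In 1-D, the W2 barycenter of N(m_n, s_n^2) with weights a_n is
# N(mb, sb^2) with mb = sum_n a_n * m_n and sb = sum_n a_n * s_n.
# The optimal conjugate potentials
#   psi_bar_n(x) = m_n * x + (s_n / (2 * sb)) * (x - mb)**2
# have gradients T_n(x) = m_n + (s_n / sb) * (x - mb), which push the
# barycenter to P_n; congruence (5) holds with c = mb**2 / 2.
a = np.array([0.5, 0.3, 0.2])   # weights, sum to 1
m = np.array([-1.0, 0.0, 4.0])  # means
s = np.array([0.5, 1.0, 2.0])   # standard deviations
mb, sb = a @ m, a @ s           # barycenter mean and std

x = np.linspace(-5.0, 5.0, 101)  # evaluation grid
T = m[:, None] + (s[:, None] / sb) * (x - mb)                 # grad psi_bar_n
psi_bar = m[:, None] * x + (s[:, None] / (2 * sb)) * (x - mb) ** 2

# sum_n a_n * grad psi_bar_n(x) = x
assert np.allclose(a @ T, x)
# sum_n a_n * psi_bar_n(x) = x^2 / 2 + c, here with c = mb^2 / 2
assert np.allclose(a @ psi_bar, x ** 2 / 2 + mb ** 2 / 2)
print("congruence verified on the grid")
```

Our algorithm enforces exactly this relation on the learned potentials: rather than computing the barycenter first and checking (5), it penalizes deviations from congruence during training.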

