CONTINUOUS WASSERSTEIN-2 BARYCENTER ESTIMATION WITHOUT MINIMAX OPTIMIZATION

Abstract

Wasserstein barycenters provide a geometric notion of the weighted average of probability measures based on optimal transport. In this paper, we present a scalable algorithm to compute Wasserstein-2 barycenters given sample access to the input measures, which are not restricted to being discrete. While past approaches rely on entropic or quadratic regularization, we employ input convex neural networks and cycle-consistency regularization to avoid introducing bias. As a result, our approach does not resort to minimax optimization. We provide theoretical analysis on error bounds as well as empirical evidence of the effectiveness of the proposed approach in low-dimensional qualitative scenarios and high-dimensional quantitative experiments.

1. INTRODUCTION

Wasserstein barycenters have become popular due to their ability to represent the average of probability measures in a geometrically meaningful way. Techniques for computing Wasserstein barycenters have been successfully applied to many computational problems. In image processing, Wasserstein barycenters are used for color and style transfer (Rabin et al., 2014; Mroueh, 2019) and texture synthesis (Rabin et al., 2011). In geometry processing, shape interpolation can be performed by computing barycenters (Solomon et al., 2015). In online machine learning, barycenters are used for aggregating probabilistic predictions of experts (Korotin et al., 2019b). Within the context of Bayesian inference, the barycenter of subset posteriors converges to the full-data posterior, enabling efficient computational methods based on finding barycenters (Srivastava et al., 2015; 2018). Fast and accurate barycenter algorithms exist for discrete distributions (see Peyré et al. (2019) for a survey), while for continuous distributions the situation is more difficult and remained unexplored until recently (Li et al., 2020; Fan et al., 2020; Cohen et al., 2020). The discrete methods scale poorly with the number of support points of the barycenter and thus cannot approximate continuous barycenters well, especially in high dimensions. In this paper, we present a method to compute Wasserstein-2 barycenters of continuous distributions based on a novel regularized dual formulation in which the convex potentials are parameterized by input convex neural networks (Amos et al., 2017). Our algorithm is straightforward, introducing neither bias (cf. Li et al. (2020)) nor minimax optimization (cf. Fan et al. (2020)). This is made possible by combining a new congruence regularizing term with cycle-consistency regularization (Korotin et al., 2019a).
As we will show in the analysis, thanks to the properties of Wasserstein-2 distances, the gradients of the resulting convex potentials "push" the input distributions close to the true barycenter, allowing good approximation of the barycenter.

2. PRELIMINARIES

We denote the set of all Borel probability measures on R^D with finite second moment by P₂(R^D), and its subset of absolutely continuous measures (w.r.t. the Lebesgue measure) by P_{2,ac}(R^D) ⊂ P₂(R^D).

Wasserstein-2 distance. For P, Q ∈ P₂(R^D), the Wasserstein-2 distance is defined by

W₂²(P, Q) := min_{π ∈ Π(P,Q)} ∫_{R^D×R^D} ½‖x − y‖₂² dπ(x, y),   (1)

where Π(P, Q) is the set of probability measures on R^D × R^D whose marginals are P and Q, respectively. This definition is known as Kantorovich's primal form of the transport distance (Kantorovitch, 1958). The Wasserstein-2 distance W₂ is well studied in the theory of optimal transport (Brenier, 1991; McCann, 1995). In particular, it admits a dual formulation (Villani, 2003):

W₂²(P, Q) = ∫_{R^D} ½‖x‖₂² dP(x) + ∫_{R^D} ½‖y‖₂² dQ(y) − min_{ψ ∈ Conv} [∫_{R^D} ψ(x) dP(x) + ∫_{R^D} ψ̄(y) dQ(y)],   (2)

where the minimum is taken over all convex functions (potentials) ψ: R^D → R ∪ {∞}, and ψ̄(y) := max_{x ∈ R^D} [⟨x, y⟩ − ψ(x)] is the convex conjugate of ψ (Fenchel, 1949), which is itself a convex function R^D → R ∪ {∞}. The optimal potential ψ* is defined up to an additive constant.

Brenier (1991) shows that if P does not give mass to sets of dimension at most D − 1, then the optimal plan is uniquely determined by π* = [id_{R^D}, T*]♯P, where T*: R^D → R^D is the unique solution to Monge's problem

T* = argmin_{T♯P=Q} ∫_{R^D} ½‖x − T(x)‖₂² dP(x).   (3)

The connection between T* and the dual formulation (2) is that T* = ∇ψ*, where ψ* is the optimal solution of (2). Additionally, if Q also does not give mass to sets of dimension at most D − 1, then T* is invertible, with T*(x) = ∇ψ*(x) = (∇ψ̄*)⁻¹(x) and (T*)⁻¹(y) = ∇ψ̄*(y) = (∇ψ*)⁻¹(y). In particular, the above discussion applies to the case P, Q ∈ P_{2,ac}(R^D).

Wasserstein-2 barycenter. Let P₁, ..., P_N ∈ P_{2,ac}(R^D). Their barycenter w.r.t. weights α₁, ..., α_N (α_n > 0 and Σ_{n=1}^N α_n = 1) is

P̄ := argmin_{P ∈ P₂(R^D)} Σ_{n=1}^N α_n W₂²(P_n, P).   (4)
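Thanks to the monotonicity of the optimal one-dimensional plan, W₂ between 1D measures can be estimated simply by sorting samples, which gives a quick sanity check of the primal definition against the Gaussian closed form. A minimal pure-Python sketch (not part of the paper's code; we use the ½‖x − y‖² cost convention):

```python
import random

def w2_gauss_1d(m1, s1, m2, s2):
    # Closed form: W2^2(N(m1,s1^2), N(m2,s2^2)) = ((m1-m2)^2 + (s1-s2)^2) / 2,
    # the 1/2 factor coming from the |x-y|^2/2 cost.
    return 0.5 * ((m1 - m2) ** 2 + (s1 - s2) ** 2)

def w2_empirical_1d(xs, ys):
    # In 1D the optimal (monotone) plan matches sorted samples.
    xs, ys = sorted(xs), sorted(ys)
    return sum(0.5 * (x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

random.seed(0)
n = 50_000
xs = [random.gauss(0.0, 1.0) for _ in range(n)]
ys = [random.gauss(2.0, 3.0) for _ in range(n)]
print(w2_gauss_1d(0, 1, 2, 3))  # exact: (4 + 4)/2 = 4.0
print(w2_empirical_1d(xs, ys))  # close to 4.0 for large n
```

Sorting realizes the monotone coupling, which is optimal for any convex cost in 1D; in higher dimensions no such shortcut exists, which is the sample-complexity issue discussed in §5.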
Throughout this paper, we assume that at least one of P₁, ..., P_N ∈ P_{2,ac}(R^D) has bounded density. Under this assumption, P̄ is unique and absolutely continuous, i.e., P̄ ∈ P_{2,ac}(R^D), and it has bounded density (Agueh & Carlier, 2011, Definition 3.6 & Theorem 5.1). For n ∈ {1, 2, ..., N}, let (ψ*_n, ψ̄*_n) be the optimal pair of (mutually) conjugate potentials that transport P_n to P̄, i.e., ∇ψ*_n♯P_n = P̄ and ∇ψ̄*_n♯P̄ = P_n. Then {ψ*_n} satisfy

Σ_{n=1}^N α_n ∇ψ̄*_n(x) = x   and   Σ_{n=1}^N α_n ψ̄*_n(x) = ½‖x‖₂² + c   (5)

for all x ∈ R^D (Agueh & Carlier, 2011; Álvarez-Esteban et al., 2016). Since optimal potentials are defined up to an additive constant, for convenience we set c = 0. Condition (5) serves as the basis for our algorithm for computing Wasserstein-2 barycenters. We say that potentials ψ₁, ..., ψ_N are congruent w.r.t. weights α₁, ..., α_N if their conjugate potentials satisfy (5), i.e., Σ_{n=1}^N α_n ψ̄_n(x) = ½‖x‖₂² for all x ∈ R^D.
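The congruence condition (5) can be verified in closed form for one-dimensional Gaussians: the barycenter of N(m_n, s_n²) w.r.t. weights {α_n} is N(m̄, s̄²) with m̄ = Σ α_n m_n and s̄ = Σ α_n s_n, and the gradients of the conjugate potentials, ∇ψ̄*_n(y) = (s_n/s̄)(y − m̄) + m_n, average exactly to the identity. A toy numeric check (all parameters are illustrative):

```python
alphas = [0.1, 0.2, 0.3, 0.4]
means = [-2.0, 0.0, 1.0, 3.0]
stds = [0.5, 1.0, 1.5, 2.0]

# 1D Gaussian barycenter: mean and std are the weighted averages.
m_bar = sum(a * m for a, m in zip(alphas, means))
s_bar = sum(a * s for a, s in zip(alphas, stds))

def grad_conj_potential(n, y):
    # Gradient of the conjugate potential: the optimal map sending the
    # barycenter back to P_n.
    return stds[n] / s_bar * (y - m_bar) + means[n]

# Congruence: sum_n alpha_n * grad psi_bar*_n(y) == y for every y.
for y in [-3.0, 0.0, 0.7, 5.0]:
    avg = sum(a * grad_conj_potential(n, y) for n, a in enumerate(alphas))
    assert abs(avg - y) < 1e-12
print("congruence holds")
```

The check works for any weights because Σ α_n (s_n/s̄) = 1 and Σ α_n m_n = m̄, so the linear and constant parts of the averaged maps reassemble the identity.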

3. RELATED WORK

Most algorithms in the field of computational optimal transport are designed for the discrete setting, where the input distributions have finite support; see the recent survey by Peyré et al. (2019) for discussion. A particularly popular line of algorithms is based on entropic regularization, which gives rise to the famous Sinkhorn iteration (Cuturi, 2013; Cuturi & Doucet, 2014). These methods are typically limited to supports of 10^5-10^6 points before the problem becomes computationally infeasible. Similarly, discrete barycenter methods (Cuturi & Doucet, 2014), particularly those that rely on a fixed support for the barycenter (Dvurechenskii et al., 2018; Staib et al., 2017), cannot provide precise approximations of continuous barycenters in high dimensions, since a large number of support points is needed; see the experiments in Fan et al. (2020, §4.3) for an example. We therefore focus on the existing literature in the continuous setting.

Computation of Wasserstein-2 distances and maps. Genevay et al. (2016) demonstrate the possibility of computing Wasserstein distances given only sample access to the distributions by parameterizing the dual potentials as functions in reproducing kernel Hilbert spaces. Building on this idea, Seguy et al. (2017) propose a similar method but use neural networks to parameterize the potentials, with entropic or L₂ regularization w.r.t. P × Q to keep the potentials approximately conjugate. The transport map is recovered from the optimized potentials via barycentric projection. As noted in §2, W₂ enjoys many useful theoretical properties: for example, the optimal potential ψ* is convex, and the corresponding optimal transport map is given by ∇ψ*. Exploiting these properties, Makkuva et al. (2019) propose a minimax optimization algorithm for recovering transport maps, using input convex neural networks (ICNNs) (Amos et al., 2017) to approximate the potentials.
An alternative to entropic regularization is the cycle-consistency regularization proposed by Korotin et al. (2019a). It uses the property that the gradients of optimal dual potentials are inverses of each other. The resulting regularizer requires integration only over the marginal measures P and Q, instead of over P × Q as required by entropy-based alternatives. Their method converges faster than the minimax method since it has no inner optimization cycle. Xie et al. (2019) propose using two generative models with a shared latent space to implicitly compute the optimal transport correspondence between P and Q; from the obtained correspondence, they compute the optimal transport distance between the distributions.

Computation of Wasserstein-2 barycenters. A few recent techniques tackle the barycenter problem (4) using continuous rather than discrete approximations of the barycenter:

• MEASURE-BASED (GENERATIVE) OPTIMIZATION: Problem (4) optimizes over probability measures. This can be done using the generic algorithm of Cohen et al. (2020), who employ generative networks to compute barycenters w.r.t. arbitrary discrepancies; they test their method with the maximum mean discrepancy (MMD) and the Sinkhorn divergence. This approach suffers from the usual limitations of generative models, such as mode collapse. Applying it to W₂ barycenters requires estimation of W₂²(P_n, P). Fan et al. (2020) test this approach using the minimax method of Makkuva et al. (2019), but end up with a challenging min-max-min problem.

• POTENTIAL-BASED OPTIMIZATION: Li et al. (2020) recover the optimal potentials {ψ*_n} via a non-minimax regularized dual formulation. No generative model is needed: the barycenter is recovered by pushing forward the input measures through the gradients of the potentials or by barycentric projection.

4. METHODS

Inspired by Li et al. (2020), we use a potential-based approach and recover the barycenter by using gradients of the potentials as pushforward maps. The main differences are: (1) we restrict the potentials to be convex; (2) we enforce congruence via a regularizing term; and (3) our formulation does not introduce bias, meaning its optimal solution gives the true barycenter.

4.1. DERIVING THE DUAL PROBLEM

Let P̄ be the true barycenter. Our goal is to recover the optimal potentials {ψ*_n, ψ̄*_n} mapping the input measures P_n to P̄. To start, we expand the barycenter objective (4) by substituting the dual formulation (2):

Σ_{n=1}^N α_n W₂²(P_n, P̄) = Σ_{n=1}^N α_n [∫_{R^D} ½‖x‖₂² dP_n(x) + ∫_{R^D} ½‖y‖₂² dP̄(y)] − min_{{ψ_n} ⊂ Conv} Σ_{n=1}^N α_n [∫_{R^D} ψ_n(x) dP_n(x) + ∫_{R^D} ψ̄_n(y) dP̄(y)].   (6)

The minimum in (6) is attained not just among convex potentials {ψ_n} but among congruent ones (see the discussion after (5)); thus, we may add to (6) the constraint that {ψ_n} are congruent. Hence,

Σ_{n=1}^N α_n W₂²(P_n, P̄) = Σ_{n=1}^N α_n ∫_{R^D} ½‖x‖₂² dP_n(x) − min_{{ψ_n} congruent} Σ_{n=1}^N α_n ∫_{R^D} ψ_n(x) dP_n(x).   (7)

To transition from (6) to (7), we used the fact that congruent {ψ_n} satisfy Σ_{n=1}^N α_n ψ̄_n(y) = ½‖y‖₂², so Σ_{n=1}^N α_n ∫_{R^D} ψ̄_n(y) dP̄(y) = ∫_{R^D} ½‖y‖₂² dP̄(y), which cancels the corresponding term of (6). We call the value inside the minimum of (7) the multiple correlation of {P_n} with weights {α_n} w.r.t. potentials {ψ_n}, denoted MultiCorr({α_n, P_n} | {ψ_n}). Notice that the true barycenter P̄ appears nowhere on the right-hand side of (7). Thus, the optimal potentials {ψ*_n} can be recovered by solving

min_{{ψ_n} congruent} MultiCorr({α_n, P_n} | {ψ_n}) = min_{{ψ_n} congruent} Σ_{n=1}^N α_n ∫_{R^D} ψ_n(x) dP_n(x).   (8)
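Identity (7) can be sanity-checked in closed form for zero-mean 1D Gaussians P_n = N(0, s_n²): the optimal potentials are ψ*_n(x) = (s̄/s_n)x²/2 with s̄ = Σ α_n s_n (their conjugates are congruent with c = 0), and W₂²(P_n, P̄) = ½(s_n − s̄)². A toy verification with made-up parameters:

```python
# Zero-mean 1D Gaussians P_n = N(0, s_n^2) with potentials psi*_n(x) = (s_bar/s_n) x^2/2.
alphas = [0.1, 0.2, 0.3, 0.4]
stds = [0.5, 1.0, 1.5, 2.0]
s_bar = sum(a * s for a, s in zip(alphas, stds))

# E_{P_n}[psi*_n] = (s_bar/s_n) * s_n^2 / 2 = s_bar * s_n / 2.
multicorr = sum(a * s_bar * s / 2 for a, s in zip(alphas, stds))

# sum_n alpha_n * E_{P_n}[x^2/2], the first term of the right side of (7).
second_moments = sum(a * s * s / 2 for a, s in zip(alphas, stds))

# Left side of (7): sum_n alpha_n W2^2(P_n, Pbar) with the 1/2-scaled cost.
lhs = sum(a * 0.5 * (s - s_bar) ** 2 for a, s in zip(alphas, stds))

assert abs(lhs - (second_moments - multicorr)) < 1e-12
print("identity (7) verified")
```

At the optimum, MultiCorr equals s̄²/2 here, and the barycenter itself never appears in the computation, mirroring the key point of (7).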

4.2. IMPOSING THE CONGRUENCE CONDITION

It is challenging to impose the congruence condition on convex potentials exactly. What happens if we relax it? The following theorem bounds how far a set of convex potentials {ψ_n} is from {ψ*_n} in terms of the difference of multiple correlations.

Theorem 4.1. Let P̄ ∈ P_{2,ac}(R^D) be the barycenter of P₁, ..., P_N ∈ P_{2,ac}(R^D) w.r.t. weights α₁, ..., α_N, and let {ψ*_n} be the optimal congruent potentials of the barycenter problem. Suppose we have B-smooth¹ convex potentials {ψ_n} for some B ∈ (0, +∞], and denote ∆ = MultiCorr({α_n, P_n} | {ψ_n}) − MultiCorr({α_n, P_n} | {ψ*_n}). Then

∆ + ∫_{R^D} [Σ_{n=1}^N α_n ψ̄_n(y) − ½‖y‖₂²] dP̄(y) ≥ (1/2B) Σ_{n=1}^N α_n ‖∇ψ*_n − ∇ψ_n‖²_{P_n},   (9)

where ‖·‖_μ denotes the norm induced by the inner product of the Hilbert space L²(R^D → R^D, μ). We call the second term on the left-hand side of (9) the congruence mismatch; for the optimal potentials, it is zero. We prove the theorem in Appendix B.

Note that if the congruence mismatch is non-positive, then

∆ ≥ (1/2B) Σ_{n=1}^N α_n ‖∇ψ*_n − ∇ψ_n‖²_{P_n} ≥ (1/B) Σ_{n=1}^N α_n W₂²(∇ψ_n♯P_n, P̄),   (10)

where the last inequality of (10) follows from (Korotin et al., 2019a, Lemma A.2). From (10), we conclude that W₂²(∇ψ_n♯P_n, P̄) ≤ B∆/α_n for all n ∈ {1, ..., N}. This shows that if the congruence mismatch is non-positive, then ∆, the difference in multiple correlations, provides an upper bound on the Wasserstein-2 distance between the true barycenter and each pushforward ∇ψ_n♯P_n, which justifies using ∇ψ_n♯P_n to recover the barycenter. To penalize a positive congruence mismatch, we introduce the regularizing term

R₁^P̄({α_n}, {ψ̄_n}) := ∫_{R^D} [Σ_{n=1}^N α_n ψ̄_n(y) − ½‖y‖₂²]₊ dP̄(y).   (11)
Because we take the positive part of the integrand of (9) to obtain (11), and because the right-hand side of (9) is non-negative, we have

MultiCorr({α_n, P_n} | {ψ_n}) + 1·R₁^P̄({α_n}, {ψ̄_n}) − MultiCorr({α_n, P_n} | {ψ*_n}) ≥ 0

for all convex potentials {ψ_n}. On the other hand, for the optimal potentials {ψ_n} = {ψ*_n}, the inequality turns into an equality, implying that adding the regularizing term 1·R₁^P̄({α_n}, {ψ̄_n}) to (8) does not introduce bias: the optimal solution still yields {ψ*_n}. However, evaluating (11) exactly requires knowing the true barycenter P̄ a priori. To remedy this issue, one may replace P̄ with another absolutely continuous measure τ·P̂ (τ ≥ 1 and P̂ a probability measure) whose density bounds that of P̄ from above almost everywhere. In this case,

τ·R₁^P̂({α_n}, {ψ̄_n}) = τ ∫_{R^D} [Σ_{n=1}^N α_n ψ̄_n(y) − ½‖y‖₂²]₊ dP̂(y) ≥ R₁^P̄({α_n}, {ψ̄_n}).   (12)

Hence we obtain the following regularized version of (8), for which {ψ*_n} remains the optimal solution:

min_{{ψ_n} ⊂ Conv} [MultiCorr({α_n, P_n} | {ψ_n}) + τ·R₁^P̂({α_n}, {ψ̄_n})].   (13)

Selecting the measure τ·P̂ is not obvious. Consider the case where {P_n} are supported on compact sets X₁, ..., X_N ⊂ R^D and P₁ has density upper-bounded by h < ∞. In this scenario, the barycenter density is upper-bounded by h·α₁^{−D} (Álvarez-Esteban et al., 2016, Remark 3.2). Thus, the measure τ·P̂ supported on ConvexHull(X₁, ..., X_N) with this density is an upper bound for P̄. We address how to choose τ and P̂ properly in practice in §4.4.
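The behavior of the congruence regularizer (11) can be illustrated on a 1D Gaussian toy family: with the true conjugate potentials ψ̄*_n(y) = (s_n/s̄)y²/2, the positive part vanishes, while shifting the potentials upward makes the penalty strictly positive. A Monte-Carlo sketch (all names and parameters are illustrative, and a Gaussian stands in for P̂):

```python
import random

alphas = [0.1, 0.2, 0.3, 0.4]
stds = [0.5, 1.0, 1.5, 2.0]
s_bar = sum(a * s for a, s in zip(alphas, stds))

def r1(conj_potentials, samples):
    # Monte-Carlo estimate of R1 = E_{y~Phat}[(sum_n alpha_n psi_bar_n(y) - |y|^2/2)_+].
    total = 0.0
    for y in samples:
        gap = sum(a * p(y) for a, p in zip(alphas, conj_potentials)) - 0.5 * y * y
        total += max(gap, 0.0)
    return total / len(samples)

random.seed(0)
ys = [random.gauss(0.0, s_bar) for _ in range(1000)]  # stand-in samples for Phat

exact = [lambda y, s=s: (s / s_bar) * y * y / 2 for s in stds]       # congruent family
biased = [lambda y, s=s: (s / s_bar) * y * y / 2 + 0.1 for s in stds]  # shifted upward

print(r1(exact, ys), r1(biased, ys))  # approx. 0 and approx. 0.1
```

The one-sided positive part is what keeps the regularized objective (13) unbiased: congruent potentials pay no penalty, so the optimum of (8) is preserved.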

4.3. ENFORCING CONJUGACY OF POTENTIALS PAIRS

Throughout this subsection, we assume that an upper-bounding finite measure τ·P̂ of the barycenter P̄ is known. The optimization problem (13) involves not only the potentials {ψ_n} but also their conjugates {ψ̄_n}. This brings practical difficulty, since evaluating conjugate potentials is hard (Korotin et al., 2019a). Instead, we parameterize the potentials ψ_n and ψ̄_n separately by input convex neural networks (ICNNs), as ψ†_n and ψ‡_n respectively. Following Korotin et al. (2019a), we add a cycle-consistency regularizer to enforce the conjugacy of ψ†_n and ψ‡_n:

R₂^{P_n}(ψ†_n, ψ‡_n) := ∫_{R^D} ‖∇ψ‡_n ∘ ∇ψ†_n(x) − x‖₂² dP_n(x) = ‖∇ψ‡_n ∘ ∇ψ†_n − id_{R^D}‖²_{P_n}.

The condition R₂^{P_n}(ψ†_n, ψ‡_n) = 0 is necessary for ψ†_n and ψ‡_n to be conjugate to each other; for convex functions, it is also sufficient for conjugacy up to an additive constant. We use one-sided regularization: computing the regularizer in the other direction, ‖∇ψ†_n ∘ ∇ψ‡_n − id_{R^D}‖²_P̄, is infeasible since P̄ is unknown. In fact, Korotin et al. (2019a) demonstrate that such a one-sided condition is sufficient. In this way, we use 2N input convex neural networks for {ψ†_n, ψ‡_n}. Adding the new cycle-consistency regularizer to (13), we obtain our final objective:

min_{{ψ†_n, ψ‡_n}} Σ_{n=1}^N α_n ∫_{R^D} [⟨x, ∇ψ†_n(x)⟩ − ψ‡_n(∇ψ†_n(x))] dP_n(x)  [approximate multiple correlation; the integrand ≈ ψ̄‡_n(x)]
  + τ·R₁^P̂({α_n}, {ψ‡_n})  [congruence regularizer]
  + λ Σ_{n=1}^N α_n R₂^{P_n}(ψ†_n, ψ‡_n)  [cycle regularizer].   (14)

Note that we express the approximate multiple correlation using both potentials {ψ†_n} and {ψ‡_n}. This eliminates the freedom of an additive constant on {ψ†_n} that is not fixed by the cycle regularization. We denote the entire objective by MultiCorr({α_n, P_n} | {ψ†_n}, {ψ‡_n}; τ, P̂, λ). Analogous to Theorem 4.1, the following result shows that this new objective enjoys the same properties as the unregularized version (8).
Theorem 4.2. Let P̄ ∈ P_{2,ac}(R^D) be the barycenter of P₁, ..., P_N ∈ P_{2,ac}(R^D) w.r.t. weights α₁, ..., α_N, and let {ψ*_n} be the optimal congruent potentials of the barycenter problem. Suppose we have τ ≥ 1 and P̂ such that τ·P̂ ≥ P̄. Suppose further that we have convex potentials {ψ†_n} and β‡-strongly convex, B‡-smooth convex potentials {ψ‡_n} with 0 < β‡ ≤ B‡ < ∞ and λ > B‡/(2(β‡)²). Then

MultiCorr({α_n, P_n} | {ψ†_n}, {ψ‡_n}; τ, P̂, λ) ≥ MultiCorr({α_n, P_n} | {ψ*_n}).   (15)

Denote ∆ = MultiCorr({α_n, P_n} | {ψ†_n}, {ψ‡_n}; τ, P̂, λ) − MultiCorr({α_n, P_n} | {ψ*_n}). Then, for all n ∈ {1, ..., N},

W₂²(∇ψ†_n♯P_n, P̄) ≤ (2∆/α_n)·[1/β‡ + 1/(2λ(β‡)² − B‡)] = O(∆).   (16)

Informally, Theorem 4.2 states that the better we solve the regularized dual problem (14), the closer we expect each ∇ψ†_n♯P_n to be to the true barycenter P̄ in W₂. It follows from (15) that our final objective (14) is unbiased: the optimum is attained by {ψ*_n, ψ̄*_n}.
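The cycle-consistency regularizer R₂ of §4.3 can be illustrated with a conjugate pair of 1D quadratics: ψ†(x) = a·x²/2 has convex conjugate ψ‡(y) = y²/(2a), so their gradients invert each other and R₂ vanishes, while a mismatched pair is penalized. A minimal sketch with made-up values:

```python
import random

def r2(grad_fwd, grad_bwd, samples):
    # Monte-Carlo estimate of R2 = E_{x~P}[|grad_bwd(grad_fwd(x)) - x|^2].
    return sum((grad_bwd(grad_fwd(x)) - x) ** 2 for x in samples) / len(samples)

random.seed(1)
xs = [random.gauss(0.0, 1.0) for _ in range(1000)]

a = 2.0
fwd = lambda x: a * x        # gradient of psi_dagger(x) = a x^2 / 2
bwd_good = lambda y: y / a   # gradient of the true conjugate psi(y) = y^2 / (2a)
bwd_bad = lambda y: y / 3.0  # gradient of a non-conjugate potential

print(r2(fwd, bwd_good, xs))  # 0: the two maps invert each other
print(r2(fwd, bwd_bad, xs))   # strictly positive
```

Only samples from the input measure are needed, which is exactly why the one-sided version of the regularizer is usable when the barycenter is unknown.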

4.4. PRACTICAL ASPECTS AND OPTIMIZATION PROCEDURE

In practice, even if the choice of τ and P̂ does not satisfy τ·P̂ ≥ P̄, we observe that the pushforward measures ∇ψ†_n♯P_n often converge to P̄. To partially bridge the gap between theory and practice, we dynamically update the measure P̂: after each optimization step we set (for γ ∈ [0, 1])

P̂ := γ·P̂ + (1 − γ)·Σ_{n=1}^N α_n·∇ψ†_n♯P_n,

i.e., the probability measure P̂ is a mixture of the given initial measure P̂ and the current barycenter estimates {∇ψ†_n♯P_n}. For the initial P̂, one may use the barycenter of {N(μ_{P_n}, Σ_{P_n})}, which can be computed efficiently via an iterative fixed-point algorithm (Álvarez-Esteban et al., 2016; Chewi et al., 2020). During optimization, these estimates approach the true barycenter and can thus improve the congruence regularizer (12). We use mini-batch stochastic gradient descent to solve (14), where the integration is done by Monte-Carlo sampling from the input measures {P_n} and the regularization measure P̂, similar to Li et al. (2020). We provide the detailed optimization procedure (Algorithm 1) and discuss its computational complexity in Appendix A. In Appendix C.3, we demonstrate the impact of the considered regularization on our model: we show that cycle consistency and the congruence condition of the potentials are well satisfied.
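The mixture update of P̂ above can be implemented as a sampler: with probability γ, draw from the initial measure; otherwise pick a component n with probability α_n and push a sample of P_n through the current ∇ψ†_n. A sketch with hypothetical affine stand-ins for the learned maps:

```python
import random

alphas = [0.1, 0.2, 0.3, 0.4]

def sample_p_hat(gamma, sample_init, sample_inputs, grad_potentials):
    # One draw from gamma * Phat_0 + (1 - gamma) * sum_n alpha_n * (grad psi_n)#P_n.
    if random.random() < gamma:
        return sample_init()
    n = random.choices(range(len(alphas)), weights=alphas)[0]
    return grad_potentials[n](sample_inputs[n]())

# Stand-ins: initial measure N(0,1); inputs N(m_n,1) recentred by affine maps.
random.seed(0)
means = [-2.0, -1.0, 1.0, 2.0]
sample_init = lambda: random.gauss(0.0, 1.0)
sample_inputs = [lambda m=m: random.gauss(m, 1.0) for m in means]
grad_potentials = [lambda x, m=m: x - m for m in means]  # each pushes P_n to N(0,1)

ys = [sample_p_hat(0.25, sample_init, sample_inputs, grad_potentials) for _ in range(5000)]
print(sum(ys) / len(ys))  # close to 0: every mixture component is centered at zero
```

Sampling the mixture lazily, rather than materializing P̂, is what makes the per-step cost of the congruence regularizer independent of how P̂ evolves.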

5. EXPERIMENTS

The code is written in the PyTorch framework and is publicly available at https://github.com/iamalexkorotin/Wasserstein2Barycenters. We compare our method [CW2B] with the potential-based method [CRWB] by Li et al. (2020) (with the Wasserstein-2 distance and L₂ regularization) and with the measure-based generative method [SCW2B] by Fan et al. (2020). All considered methods recover 2N potentials {ψ†_n, ψ‡_n} ≈ {ψ*_n, ψ̄*_n} and approximate the barycenter by the pushforward measures {∇ψ†_n♯P_n}. The regularization in [CRWB] additionally gives access to the joint density of the transport plan, a feature of their method that we do not consider here. The method [SCW2B] additionally outputs a generated barycenter g♯S ≈ P̄, where g is the generative network and S is the input noise distribution.

To assess the quality of a computed barycenter P̂, we consider the unexplained variance percentage

UVP(P̂) := 100 · W₂²(P̂, P̄) / (½Var(P̄)) %.

When UVP ≈ 0%, P̂ is a good approximation of P̄. Values ≥ 100% indicate that P̂ is undesirable: the trivial baseline P₀ = δ_{E_{P̄}[y]}, a point mass at the barycenter's mean, achieves UVP(P₀) = 100%. Evaluating UVP in high dimensions is infeasible: empirical estimates of W₂² are unreliable due to high sample complexity (Weed et al., 2019). To overcome this issue, for barycenters given by ∇ψ†_n♯P_n we use L₂-UVP, defined by

L₂-UVP(∇ψ†_n, P_n) := 100 · ‖∇ψ†_n − ∇ψ*_n‖²_{P_n} / Var(P̄) %  [≥ UVP(∇ψ†_n♯P_n)],   (17)

where the inequality in brackets follows from (Korotin et al., 2019a, Lemma A.2). We report the weighted average of the L₂-UVP of all pushforward measures w.r.t. the weights α_n. For barycenters given in the implicit form g♯S, we compute the Bures-Wasserstein UVP, defined by

BW₂²-UVP(g♯S) := 100 · BW₂²(g♯S, P̄) / (½Var(P̄)) %  [≤ UVP(g♯S)],   (18)

where BW₂²(P, Q) = W₂²(N(μ_P, Σ_P), N(μ_Q, Σ_Q)) is the Bures-Wasserstein metric, and μ_P, Σ_P denote the mean and covariance of a distribution P (Chewi et al., 2020).
It is known that BW₂² lower-bounds W₂² (Dowson & Landau, 1982), so the inequality in the brackets of (18) follows. A detailed discussion of the adopted metrics is given in Appendix C.2.

In this section, we consider N = 4 input measures with weights (α₁, ..., α₄) = (0.1, 0.2, 0.3, 0.4). We consider the location-scatter family of distributions (Álvarez-Esteban et al., 2016, §4), whose true barycenter can be computed exactly. Let P₀ ∈ P_{2,ac}(R^D) and define the following location-scatter family of distributions

5.1. HIGH-DIMENSIONAL LOCATION-SCATTER EXPERIMENTS

F(P₀) := {f_{S,u}♯P₀ | S ∈ M⁺_{D×D}, u ∈ R^D}, where f_{S,u}: R^D → R^D is the affine map f_{S,u}(x) = Sx + u with a positive definite matrix S ∈ M⁺_{D×D}. When {P_n} ⊂ F(P₀), their barycenter P̄ is also an element of F(P₀) and can be computed via fixed-point iterations (Álvarez-Esteban et al., 2016).

Figure 1a shows a 2-dimensional location-scatter family generated by using the Swiss roll distribution as P₀. The true barycenter is shown in Figure 1b, and the generated barycenter g♯S of [SCW2B] in Figure 1c. The pushforward measures ∇ψ†_n♯P_n of the three methods are provided in Figures 1d, 1e, and 1f, respectively. In this example, the pushforward measures of all methods reasonably approximate P̄, whereas the generated barycenter g♯S of [SCW2B] (Figure 1c) visibly underfits.

For quantitative comparison, we consider two choices of P₀: the D-dimensional standard Gaussian distribution and the uniform distribution on [−√3, +√3]^D. Each P_n is constructed as f_{S_nᵀΛS_n, 0}♯P₀ ∈ F(P₀), where S_n is a random rotation matrix and Λ is diagonal with entries [½b⁰, ½b¹, ..., 2], where b = 4^{1/(D−1)}. We consider only centered (i.e., zero-mean) distributions because the barycenter of non-centered {P_n} ⊂ P_{2,ac}(R^D) is the barycenter of their centered copies {P′_n} shifted by Σ_{n=1}^N α_n μ_{P_n} (Álvarez-Esteban et al., 2016). Results are shown in Tables 1 and 2. In these experiments, our method outperforms [CRWB] and [SCW2B]. For [CRWB], dimension ~16 is the breaking point: the method does not scale well to higher dimensions. [SCW2B] scales better with increasing dimension, but its L₂-UVP and BW₂²-UVP errors are twice as high as ours, likely due to the generative approximation and the difficult min-max-min optimization in [SCW2B].
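The spectrum used to build each P_n is a geometric ladder from ½ up to 2. A quick sketch of this construction (the random rotations are omitted):

```python
def scatter_eigenvalues(D):
    # Diagonal of Lambda: [b^0/2, b^1/2, ..., b^(D-1)/2] with b = 4**(1/(D-1)),
    # a geometric progression running from 1/2 up to 2.
    b = 4.0 ** (1.0 / (D - 1))
    return [0.5 * b ** k for k in range(D)]

eigs = scatter_eigenvalues(5)
print(eigs[0], eigs[-1])  # first entry is 0.5, last entry is (up to roundoff) 2.0
```

Fixing the condition number of every Λ at 4 regardless of D keeps the difficulty of the transport maps comparable across dimensions.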
For completeness, we also compare our algorithm to the free-support method proposed in Cuturi & Doucet (2014), which approximates the barycenter by a discrete distribution on a fixed number of free-support points. In our experiment, similar to Li et al. (2020), we set the support size to 5000. As expected, the BW₂²-UVP error of this method increases drastically as the dimension grows, and the method is outperformed by our approach.

To show the scalability of our method with the number of input distributions N, we conduct an analogous experiment with a high-dimensional location-scatter family for N = 20. We set α_n = 2n/(N(N+1)) for n = 1, 2, ..., 20, choose the uniform distribution on [−√3, +√3]^D as P₀, and construct the distributions P_n ∈ F(P₀) as before. The results for dimensions 32, 64, and 128 are provided in Table 3. As in Tables 1 and 2, our method outperforms the alternatives.

5.2. SUBSET POSTERIOR AGGREGATION

We apply our method to aggregate subset posterior distributions. The barycenter of subset posteriors converges to the true posterior (Srivastava et al., 2018); thus, computing the barycenter of subset posteriors is an efficient alternative to obtaining the full posterior in the big-data setting (Srivastava et al., 2015; Staib et al., 2017; Li et al., 2020). Analogous to Li et al. (2020), we consider Poisson and negative binomial regressions for predicting the hourly number of bike rentals from features such as the day of the week and weather conditions.² We consider the posterior on the 8-dimensional regression coefficients for both Poisson and negative binomial regressions. We randomly split the data into N = 5 equally-sized subsets and obtain 10⁵ samples from each subset posterior using the Stan library (Carpenter et al., 2017). This gives discrete uniform distributions {P_n} supported on the samples. As the ground-truth barycenter P̄, we take the full-dataset posterior, also consisting of 10⁵ points. We use BW₂²-UVP to compare the estimated barycenter (pushforward measure ∇ψ†_n♯P_n or generated measure g♯S) with the true barycenter. The results are given in Table 4. All considered methods perform well (UVP < 2%), but our method outperforms the alternatives.

5.3. COLOR PALETTE AVERAGING

For a qualitative study, we apply our method to aggregating color palettes of images. For an RGB image I, its color palette is defined as the discrete uniform distribution P(I) over all its pixels in [0, 1]³. For 3 images {I_n}, we compute the barycenter P̄ of the color palettes P_n = P(I_n) w.r.t. uniform weights α_n = 1/3. We apply each computed map ∇ψ†_n pixel-wise to I_n to obtain the "pushforward" image ∇ψ†_n(I_n); the palettes of these images should be close to the barycenter P̄ of {P_n}. The results are provided in Figure 2. Note that the image ∇ψ†₁(I₁) inherits certain attributes of images I₂ and I₃: the sky becomes bluer and the trees become greener. On the other hand, the sunlight in images ∇ψ†₂(I₂) and ∇ψ†₃(I₃) acquires an orange tint, thanks to the dominance of orange in I₁.

B PROOFS

In this section, we prove our main Theorems 4.1 and 4.2. We use L²(R^D → R^D, μ) to denote the Hilbert space of functions f: R^D → R^D with square integrable w.r.t. a probability measure μ. The corresponding inner product for f₁, f₂ ∈ L²(R^D → R^D, μ) is ⟨f₁, f₂⟩_μ := ∫_{R^D} ⟨f₁(x), f₂(x)⟩ dμ(x), where ⟨f₁(x), f₂(x)⟩ is the Euclidean dot product, and ‖·‖_μ = √⟨·, ·⟩_μ denotes the norm induced by this inner product. We also recall a useful property of a lower semi-continuous convex function ψ: R^D → R:

∇ψ̄(x) = argmax_{y ∈ R^D} [⟨y, x⟩ − ψ(y)],   (19)

which follows from the fact that ŷ = argmax_{y} [⟨y, x⟩ − ψ(y)] ⟺ x − ∇ψ(ŷ) = 0 ⟺ ŷ = ∇ψ̄(x).

We begin with the proof of Theorem 4.1.

Proof. We consider the difference between the estimated multiple correlation and the optimal one:

∆ = Σ_{n=1}^N α_n ∫_{R^D} ψ_n(x) dP_n(x) − Σ_{n=1}^N α_n ∫_{R^D} ψ*_n(x) dP_n(x)
  = Σ_{n=1}^N α_n ∫_{R^D} [⟨∇ψ_n(x), x⟩ − ψ̄_n(∇ψ_n(x))] dP_n(x) − Σ_{n=1}^N α_n ∫_{R^D} [⟨∇ψ*_n(x), x⟩ − ψ̄*_n(∇ψ*_n(x))] dP_n(x),

where we twice applied (19), to ψ̄_n and to ψ̄*_n. We note that

Σ_{n=1}^N α_n ∫_{R^D} ⟨∇ψ*_n(x), x⟩ dP_n(x) = Σ_{n=1}^N α_n ∫_{R^D} ⟨y, ∇ψ̄*_n(y)⟩ dP̄(y) = ∫_{R^D} ⟨y, Σ_{n=1}^N α_n ∇ψ̄*_n(y)⟩ dP̄(y) = ∫_{R^D} ⟨y, y⟩ dP̄(y) = ‖id_{R^D}‖²_P̄,

where we use the change-of-variables formula for ∇ψ*_n♯P_n = P̄ together with (5). Analogously,

Σ_{n=1}^N α_n ∫_{R^D} ψ̄*_n(∇ψ*_n(x)) dP_n(x) = Σ_{n=1}^N α_n ∫_{R^D} ψ̄*_n(y) dP̄(y) = ∫_{R^D} Σ_{n=1}^N α_n ψ̄*_n(y) dP̄(y) = ∫_{R^D} ½‖y‖₂² dP̄(y) = ½‖id_{R^D}‖²_P̄.

C.2 METRICS

The unexplained variance percentage (UVP, introduced in §5) is a natural and straightforward metric for assessing the quality of a computed barycenter. However, it is difficult to compute in high dimensions, since it requires evaluating the Wasserstein-2 distance. Thus, we use the closely related metrics L₂-UVP and BW₂²-UVP. To assess the quality of the recovered potentials {ψ†_n}, we use L₂-UVP, defined in (17). L₂-UVP compares not just the pushforward distribution ∇ψ†_n♯P_n with the barycenter P̄, but also the resulting transport map with the optimal transport map ∇ψ*_n. It bounds UVP(∇ψ†_n♯P_n) from above, thanks to (Korotin et al., 2019a, Lemma A.2). Besides, L₂-UVP naturally admits unbiased Monte-Carlo estimates using random samples from P_n. For the measure-based optimization method, we also evaluate the quality of the generated measure g♯S using the Bures-Wasserstein UVP, defined in (18). For measures P, Q whose covariance matrices are non-degenerate, BW₂² is given by

BW₂²(P, Q) = ½‖μ_P − μ_Q‖² + ½Tr Σ_P + ½Tr Σ_Q − Tr(Σ_P^{1/2} Σ_Q Σ_P^{1/2})^{1/2}.

The Bures-Wasserstein metric compares P and Q through their first and second moments only. It is known that BW₂²(P, Q) is a lower bound for W₂²(P, Q); see (Dowson & Landau, 1982). Thus, we have BW₂²-UVP(g♯S) ≤ UVP(g♯S). In practice, to compute BW₂²-UVP(g♯S), we estimate the means and covariance matrices of the distributions using 10⁵ random samples.
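For distributions with diagonal covariances, the matrix square roots in the Bures-Wasserstein formula reduce to coordinate-wise square roots, which gives a compact reference implementation of the ½-scaled BW₂² above and of BW₂²-UVP (diagonal case only; the function names are ours):

```python
def bw2_diagonal(mu_p, var_p, mu_q, var_q):
    # 1/2-scaled Bures-Wasserstein metric between Gaussians with diagonal
    # covariances: the matrix square roots become elementwise square roots,
    # so the trace terms collapse to sum_i (sqrt(vp_i) - sqrt(vq_i))^2 / 2.
    loc = 0.5 * sum((a - b) ** 2 for a, b in zip(mu_p, mu_q))
    spread = 0.5 * sum((vp ** 0.5 - vq ** 0.5) ** 2 for vp, vq in zip(var_p, var_q))
    return loc + spread

def bw2_uvp(mu_p, var_p, mu_q, var_q):
    # BW2-UVP = 100 * BW2^2 / (Var(Q)/2) %, with Q playing the barycenter's role.
    half_var = 0.5 * sum(var_q)
    return 100.0 * bw2_diagonal(mu_p, var_p, mu_q, var_q) / half_var

# Sanity check: a point mass at the barycenter's mean scores exactly 100%.
mu_bar, var_bar = [1.0, -1.0], [2.0, 3.0]
print(bw2_uvp(mu_bar, [0.0, 0.0], mu_bar, var_bar))  # close to 100.0
```

For non-diagonal covariances, the middle trace term requires a genuine matrix square root (e.g., via an eigendecomposition), which is why means and covariances are estimated from samples only once per evaluation.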

C.3 CYCLE CONSISTENCY AND CONGRUENCE IN PRACTICE

To assess the effect of the cycle-consistency regularization and the congruence condition in practice, we run the following sanity checks. For cycle consistency, for each input distribution P_n we estimate (by drawing samples from P_n) the value ‖∇ψ‡_n ∘ ∇ψ†_n − id‖²_{P_n} / Var(P_n). This metric can be viewed as an analog of the L₂-UVP that we used for assessing the resulting transport maps. In all the experiments, this value does not exceed 2%, which means that cycle consistency, and hence conjugacy, is satisfied well. For the congruence condition, we would need to check that Σ_{n=1}^N α_n ψ‡_n(x) = ½‖x‖². However, we do not know a straightforward metric for this exact condition that is scaled properly by the variance of the distributions. Thus, we use an alternative metric that checks the slightly weaker condition on gradients, namely Σ_{n=1}^N α_n ∇ψ‡_n(x) = x; it is weaker due to the ambiguity of the additive constants. For this, we compute ‖Σ_{n=1}^N α_n ∇ψ‡_n − id‖²_P̂ / Var(P̂), where the denominator is the variance of the regularization measure estimated from samples.



¹We say that a differentiable function f: R^D → R is B-smooth if its gradient ∇f is B-Lipschitz.
²http://archive.ics.uci.edu/ml/datasets/Bike+Sharing+Dataset



Figure 1: Barycenter of location-scatter Swiss roll population computed by three methods.

(a) Original images {In}. (b) Color palettes {Pn} of original images. (c) Images with averaged color palette {∇ψ † n In}. (d) Barycenter palettes {∇ψ † n Pn}.

Figure 2: Results of our method applied to averaging color palettes of images.

Figure 4: Barycenter of two 2D Gaussian mixtures.

Comparison of UVP for the case {P_n} ⊂ F(P_0), P_0 = N(0, I_D), N = 4.

Comparison of UVP for the case {P_n} ⊂ F(P_0), P_0 = Uniform[-

Comparison of UVP for the case {P_n} ⊂ F(P_0), P_0

Comparison of UVP for recovered barycenters in our subset posterior aggregation task.

ACKNOWLEDGMENTS

The Skoltech Advanced Data Analytics in Science and Engineering Group acknowledges the support of Russian Foundation for Basic Research grant 20-01-00203, Skoltech-MIT NGP initiative and thanks the Skoltech CDISE HPC Zhores cluster staff for computing cluster provision. The MIT Geometric Data Processing group acknowledges the generous support of Army Research Office grant W911NF2010168, of Air Force Office of Scientific Research award FA9550-19-1-031, of National Science Foundation grant IIS-1838071, from the CSAIL Systems that Learn program, from the MIT-IBM Watson AI Laboratory, from the Toyota-CSAIL Joint Research Center, from a gift from Adobe Systems, from an MIT.nano Immersion Lab/NCSOFT Gaming Program seed grant, and from the Skoltech-MIT Next Generation Program.

A THE ALGORITHM

The numerical procedure for solving our final objective (14) is given below.

Algorithm 1: Numerical procedure for optimizing multiple correlations (14)

Input: distributions P_1, ..., P_N with sample access; weights α_1, ..., α_N ≥ 0 with Σ_{n=1}^N α_n = 1; regularization distribution P given by a sampler; congruence regularizer coefficient τ ≥ 1; balancing coefficient γ ∈ [0, 1]; cycle-consistency regularizer coefficient λ > 0; 2N ICNNs {ψ_θn, ψ_ωn}; batch size K > 0.

for t = 1, 2, ... do
1. Sample batches X_n ∼ P_n for all n = 1, ..., N;
2. Compute the pushforwards Y_n = ∇ψ_θn(X_n) for all n = 1, ..., N;
3. Sample a batch Y_0 ∼ P;
4. Compute the Monte-Carlo estimate of the congruence regularizer, using weights γ_0 = γ and γ_n = α_n · (1 − γ) for n = 1, 2, ..., N;
5. Compute the Monte-Carlo estimate of the cycle-consistency regularizer;
6. Compute the Monte-Carlo estimate of the multiple correlations;
7. Compute the total loss L_Total;
8. Perform a gradient step over {θ_n, ω_n} using ∂L_Total/∂{θ_n, ω_n};
end

Parametrization of the potentials. To parametrize the potentials {ψ_θn, ψ_ωn}, we use DenseICNN (a dense input convex neural network) with quadratic skip connections; see (Korotin et al., 2019a, Appendix B.2). As an initialization step, we pre-train the potentials so that each gradient map is close to the identity. Such pre-training provides a good start for the networks: each ψ_θn is approximately conjugate to the corresponding ψ_ωn, and the initial networks {ψ_θn} are approximately congruent according to (5).

Computational complexity. For a single training iteration, the time complexity of both the forward (evaluation) and backward (computing the gradient with respect to the parameters) passes through the objective (14) is O(NT). Here N is the number of input distributions and T is the time taken to evaluate an individual potential (parameterized as a neural network) on a batch of points sampled from either P_n or P.
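To make steps 1-4 concrete, here is a toy numpy sketch using quadratic potentials ψ_n(x) = ½xᵀA_n x + b_nᵀx, whose gradient maps A_n x + b_n are available in closed form (in the actual algorithm they are ICNN gradients computed by autograd). The congruence penalty shown, a weighted deviation of Σ_n α_n ∇ψ_n from the identity on the mixed batches, is our illustrative reading and not necessarily the exact formula in (14):

```python
import numpy as np

rng = np.random.default_rng(0)
D, K, N = 2, 512, 3                        # dimension, batch size, number of inputs
alphas = np.full(N, 1.0 / N)
gamma = 0.2

# Toy convex potentials psi_n(x) = 0.5 x^T A_n x + b_n^T x (A_n SPD), so the
# pushforward map is available in closed form: grad psi_n(x) = A_n x + b_n.
As = [np.eye(D) * (1.0 + 0.5 * n) for n in range(N)]
bs = [np.full(D, 0.1 * n) for n in range(N)]
grad_psi = lambda n, x: x @ As[n].T + bs[n]

# Steps 1-3: sample batches, push them forward, sample from P.
Xs = [rng.normal(size=(K, D)) for _ in range(N)]   # X_n ~ P_n (toy Gaussians)
Ys = [grad_psi(n, Xs[n]) for n in range(N)]        # Y_n = grad psi_n(X_n)
Y0 = rng.normal(size=(K, D))                       # Y_0 ~ P

# Step 4 (illustrative penalty only): deviation of the weighted gradient sum
# from the identity, weighted by gamma_0 = gamma, gamma_n = alpha_n * (1 - gamma).
gammas = np.concatenate([[gamma], alphas * (1.0 - gamma)])
def congruence_penalty(batch):
    mix = sum(alphas[m] * grad_psi(m, batch) for m in range(N))
    return np.mean(np.sum((mix - batch) ** 2, axis=1))
reg1 = sum(g * congruence_penalty(B) for g, B in zip(gammas, [Y0] + Ys))
```

With identity potentials (A_n = I, b_n = 0) the penalty is exactly zero, which is why the identity pre-training described above yields approximately congruent networks.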
B PROOFS

This claim follows from the well-known fact that the gradient evaluation ∇_θ h_θ(x) of h_θ : R^D → R, when parameterized as a neural network, requires time proportional to a forward evaluation of h_θ.

First, we prove Theorem 4.1. Since each ψ_n is B-smooth, its conjugate is (1/B)-strongly convex; see (Kakade et al., 2009). Thus, we have (23), or equivalently (24). We integrate (24) w.r.t. P_n and sum over n = 1, 2, ..., N with weights α_n to obtain (25). We also note (26). Now we substitute (25), (26), (21) and (22) into (20) to obtain (9).

Next, we prove Theorem 4.2. By (Kakade et al., 2009), for all x, x′ ∈ R^D we have (27). Since the function ψ‡_n is B‡-smooth, we have for all x ∈ R^D the bound (28), which we combine with (27) to obtain (29). For every n = 1, 2, ..., N, we integrate (29) w.r.t. P_n and add the corresponding cycle-consistency regularization term to get (30). We sum (30) over n = 1, 2, ..., N with weights α_n to obtain (31). We add τ·R_P^1({ψ‡_n}) to both sides of (31) to get (32). We subtract MultiCorr({α_n, P_n} | {ψ‡_n}) from both sides and use Theorem 4.1 to obtain (33)-(34); in the transition from (33) to (34), we combine the first term of (32) with the regularizer τ·R_P^1({ψ‡_n}). From (34) we immediately conclude that ∆ ≥ 0; i.e., the multiple correlations upper bound (15) holds true. On the other hand, for every n = 1, 2, ..., N we have (35). We combine the second part of (35) with (27) integrated w.r.t. P_n to obtain (36). Finally, we use the triangle inequality for ‖·‖_{P_n} and conclude (37), where the first inequality follows from (Korotin et al., 2019a, Lemma A.2).

Published as a conference paper at ICLR 2021
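The duality between smoothness and strong convexity invoked at the start of the proof is a standard fact and can be written out as follows:

```latex
% Standard fact (see, e.g., Kakade et al., 2009): for a convex function f,
% f is B-smooth iff its convex conjugate f^* is (1/B)-strongly convex:
\|\nabla f(x) - \nabla f(x')\| \le B\,\|x - x'\|
\quad\Longleftrightarrow\quad
f^*(y') \ge f^*(y) + \langle \nabla f^*(y),\, y' - y\rangle
          + \tfrac{1}{2B}\,\|y' - y\|^2 .
```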

C EXPERIMENTAL DETAILS AND EXTRA RESULTS

In this section, we provide experimental details and additional results. In Subsection C.1, we demonstrate qualitative results of computed barycenters in 2-dimensional space. In Subsection C.2, we discuss the metrics we use in more detail. In Subsection C.3, we check cycle consistency and congruence in practice. In Subsection C.4, we list the hyperparameters used for our method (CW2B) and for the methods [SCW2B] and [CRWB].

C.1 ADDITIONAL TOY EXPERIMENTS IN 2D

We provide additional qualitative examples of computed barycenters of probability measures on R². In Figure 3, we consider the location-scatter family F(P_0). In principle, all the methods capture the true barycenter. However, the generated distribution g♯S of [SCW2B] (Figure 3c) produces samples that lie outside the actual barycenter's support (Figure 3b). Also, for the [CRWB] method, one of the potentials' pushforward measures (top-right in Figure 3e) has visual artifacts. In Figure 4, we consider the Gaussian mixture example of (Fan et al., 2020). The barycenter computed by the [SCW2B] method (Figure 4b) suffers from behavior similar to mode collapse.

C.4 TRAINING HYPERPARAMETERS

The code is written using the PyTorch framework. The networks are trained on a single GTX 1080Ti.

C.4.1 WASSERSTEIN-2 CONTINUOUS BARYCENTERS (CW2B, OUR METHOD)

Regularization. We use τ = 5 and P = N(0, I_D) in our congruence regularizer τ·R_P^1. We use λ = 10 for the cycle regularizer λ·R_{P_n}^2 for all n = 1, 2, ..., N.

Neural networks (potentials). To approximate the potentials {ψ†_n, ψ‡_n} in dimension D, we use DenseICNN[2; max(64, 2D), max(64, 2D), max(32, D)] with the CELU activation function. DenseICNN is an input-convex dense architecture with additional convex quadratic skip connections. Here 2 is the rank of each input-quadratic skip connection's Hessian matrix, and each subsequent number max(·, ·) is the size of a hidden dense layer in the sequential part of the network. For a detailed discussion of the architecture, see (Korotin et al., 2019a, Section B.2).

Training process. We perform training according to Algorithm 1 of Appendix A. We set the batch size K = 1024 and the balancing coefficient γ = 0.2. We use the Adam optimizer (Kingma & Ba, 2014) with a fixed learning rate of 10^-3. The total number of iterations is set to 50000.
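To make the input-convexity constraint concrete, here is a minimal numpy sketch of an ICNN-style scalar network in the spirit of (though much simpler than) DenseICNN: the hidden-to-hidden weights are kept non-negative and CELU is convex and non-decreasing, so the output is convex in the input; a convex quadratic term plays the role of a quadratic skip connection. All layer sizes and initializations here are illustrative assumptions, not the paper's configuration:

```python
import numpy as np

def celu(x):
    """CELU with alpha = 1: convex and non-decreasing."""
    return np.maximum(x, 0.0) + np.minimum(0.0, np.expm1(x))

class ToyICNN:
    """Minimal input-convex network: z_1 = celu(W_0 x + b_0);
    z_{k+1} = celu(U_k z_k + W_k x + b_k) with U_k >= 0 elementwise;
    output adds a convex quadratic skip 0.5*||x||^2."""
    def __init__(self, d, widths=(16, 16), seed=0):
        rng = np.random.default_rng(seed)
        dims = list(widths) + [1]
        self.Wx = [rng.normal(scale=0.3, size=(w, d)) for w in dims]
        self.b = [rng.normal(scale=0.1, size=w) for w in dims]
        # z-to-z weights constrained non-negative to preserve convexity
        self.U = [np.abs(rng.normal(scale=0.3, size=(dims[i + 1], dims[i])))
                  for i in range(len(dims) - 1)]

    def __call__(self, x):
        z = celu(x @ self.Wx[0].T + self.b[0])
        for U, Wx, b in zip(self.U, self.Wx[1:], self.b[1:]):
            z = celu(z @ U.T + x @ Wx.T + b)
        return z[:, 0] + 0.5 * np.sum(x ** 2, axis=1)  # convex quadratic skip
```

Convexity can be spot-checked numerically via midpoint convexity, f((x+y)/2) ≤ (f(x)+f(y))/2, which holds for this construction by design.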

C.4.2 SCALABLE COMPUTATION OF WASSERSTEIN BARYCENTERS (SCW 2 B)

Generator neural network. For the input noise distribution of the generative model we use S = N(0, I_D). For the generative network g : R^D → R^D we use a fully-connected sequential ReLU network with hidden layer sizes [max(100, 2D), max(100, 2D), max(100, 2D)]. Before the main optimization, we pre-train the network to satisfy g(z) ≈ z for all z ∈ R^D; empirically, this is a better option than random initialization of the network's weights.

Neural networks (potentials). We use exactly the same networks as in Subsection C.4.1.

Training process. We perform training according to the min-max-min procedure described in (Fan et al., 2020, Algorithm 1). The batch size is set to 1024. We use the Adam optimizer (Kingma & Ba, 2014) with a fixed learning rate of 10^-3 for the potentials and 10^-4 for the generative network g. The number of iterations of the outer cycle (min-max-min) is set to 15000. Following (Fan et al., 2020), we use 10 iterations for the middle cycle (min-max-min) and 6 iterations for the inner cycle (min-max-min).
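The nested schedule amounts to the following loop structure (a counting sketch only; the actual optimizer updates of (Fan et al., 2020) are omitted, and each counter stands in for one step of the corresponding player):

```python
# Counting sketch of the min-max-min schedule as configured above.
OUTER, MIDDLE, INNER = 15000, 10, 6

generator_steps = 0      # outer min: generative network g
potential_max_steps = 0  # middle max: potentials
potential_min_steps = 0  # inner min: approximate conjugate potentials

for _ in range(OUTER):
    generator_steps += 1
    for _ in range(MIDDLE):
        potential_max_steps += 1
        for _ in range(INNER):
            potential_min_steps += 1
```

The schedule thus spends 60 inner-problem steps per generator update, which is the main source of the extra cost of minimax-based training compared with our method.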

C.4.3 CONTINUOUS REGULARIZED WASSERSTEIN BARYCENTERS (CRWB)

Regularization. The [CRWB] method uses regularization to keep the potentials conjugate. The authors impose entropic or L2 regularization w.r.t. some proposal measure P; see (Li et al., 2020, Section 3) for details. Following the source code provided by the authors, we use L2 regularization (empirically a more stable option than entropic regularization). The regularization measure P is set to be the uniform measure on a box containing the support of all the source distributions, estimated by sampling. The regularization parameter is set to 10^-4.

Neural networks (potentials). To approximate the potentials {ψ†_n, ψ‡_n} in dimension D, we use fully-connected sequential ReLU neural networks with layer sizes [max(128, 4D), max(128, 4D), max(128, 4D)]. We also tried the DenseICNN architecture but did not observe any performance gain.

Training process. We perform training according to (Li et al., 2020, Algorithm 1). We set the batch size to 1024. We use the Adam optimizer (Kingma & Ba, 2014) with a fixed learning rate of 10^-3. The total number of iterations is set to 50000.

