

Abstract

We study the use of amortized optimization to predict optimal transport (OT) maps from the input measures, which we call Meta OT. This helps repeatedly solve similar OT problems between different measures by leveraging knowledge from past problems to rapidly predict and solve new ones; standard methods instead ignore past solutions and suboptimally re-solve each problem from scratch. We instantiate Meta OT models in discrete and continuous (Wasserstein-2) settings between images, spherical data, and color palettes, and use them to improve the computational time of standard OT solvers by multiple orders of magnitude.

1. INTRODUCTION

Optimal transportation (Villani, 2009; Ambrosio, 2003; Santambrogio, 2015; Peyré et al., 2019; Mérigot and Thibert, 2021) is thriving in domains including economics (Galichon, 2016), reinforcement learning (Dadashi et al., 2021; Fickinger et al., 2021), style transfer (Kolkin et al., 2019), generative modeling (Arjovsky et al., 2017; Seguy et al., 2018; Huang et al., 2020; Rout et al., 2021), geometry (Solomon et al., 2015; Cohen et al., 2021), domain adaptation (Courty et al., 2017; Redko et al., 2019), signal processing (Kolouri et al., 2017), fairness (Jiang et al., 2020), and cell reprogramming (Schiebinger et al., 2019). A core component in these settings is to couple two measures (α, β) supported on domains (X, Y) by solving a transport optimization problem such as the primal Kantorovich problem, which is defined by

π⋆(α, β, c) ∈ arg min_{π ∈ U(α,β)} ∫_{X×Y} c(x, y) dπ(x, y),   (1)

where the optimal coupling π⋆ is a joint distribution over the product space, U(α, β) is the set of admissible couplings between α and β, and c : X × Y → R is the ground cost, which represents a notion of distance between elements in X and elements in Y.

Challenges. Unfortunately, solving eq. (1) even once is computationally expensive between general measures, and computationally cheaper alternatives are an active research topic: entropic optimal transport (Cuturi, 2013) smooths the transport problem with an entropy penalty, and sliced distances (Kolouri et al., 2016; 2018; 2019; Deshpande et al., 2019) solve OT between 1-dimensional projections of the measures, where eq. (1) can be solved easily. Furthermore, when an optimal transport method is deployed in practice, eq. (1) is not solved just a single time, but is repeatedly solved for new scenarios between different input measures (α, β). For example, the measures could be representations of images we care about optimally transporting between, and in deployment we would receive a stream of new images to couple.
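As a side note on why the 1-dimensional case is easy: for a convex cost such as |x − y|², the optimal coupling between two empirical measures on the line with equally many uniformly weighted points is the monotone (sorted) matching, so computing the transport cost reduces to sorting. A minimal sketch, with illustrative Gaussian samples:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1000
x = rng.normal(size=n)           # samples of the first measure
y = rng.normal(loc=2.0, size=n)  # samples of the second, shifted by 2

# Monotone matching: pair the i-th smallest x with the i-th smallest y.
# The mean squared difference is the (squared) Wasserstein-2 distance
# between the two empirical measures.
w2_sq = np.mean((np.sort(x) - np.sort(y)) ** 2)
```

For these two Gaussians the squared W2 distance is the squared mean shift (here 4), which the empirical estimate approaches as n grows.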
Repeatedly solving optimal transport problems also comes up in the context of comparing seismic signals (Engquist and Froese, 2013) and in single-cell perturbations (Bunne et al., 2021; 2022a;b). Standard optimal transport solvers deployed in this setting would re-solve the optimization problems from scratch, but this ignores the shared structure and information present between different coupling problems.

Overview and outline. We study the use of amortized optimization and machine learning methods to rapidly solve multiple optimal transport problems and predict the solution from the input measures (α, β). This setting involves learning a meta model to predict the solution to the optimal transport problem, which we refer to as Meta Optimal Transport. We learn Meta OT models to predict the solutions to optimal transport problems and significantly improve the computational time and number of iterations needed to solve eq. (1) between discrete (sect. 3.1) and continuous (sect. 3.2) measures. The paper is organized as follows: sect. 2 recalls the main concepts needed for the rest of the paper, in particular the formulations of the entropy-regularized and unregularized optimal transport problems and the basic notions of amortized optimization; sect. 3 presents the Meta OT models and algorithms; and sect. 4 empirically demonstrates the effectiveness of Meta OT.

Settings that are not Meta OT. Meta OT is not useful in OT settings that do not involve repeatedly solving OT problems over a fixed distribution, including 1) standard generative modeling settings, such as Arjovsky et al. (2017), that estimate the OT distance between the data and model distributions, and 2) the out-of-sample setting of Seguy et al. (2018); Perrot et al. (2016), which couples measures and then extrapolates the map to larger measures containing the original ones.

2.1. DUAL OPTIMAL TRANSPORT SOLVERS

We review the foundations of optimal transportation, following the notation of Peyré et al. (2019) in most places. The discrete setting often favors the entropy-regularized formulation, since it can be computed efficiently and in a parallelized way using the Sinkhorn algorithm; the continuous setting, on the other hand, is often solved from samples using convex potentials. While the primal Kantorovich formulation in eq. (1) provides an intuitive problem description, optimal transport problems are rarely solved directly in this form due to the high dimensionality of the couplings π and the difficulty of satisfying the coupling constraints U(α, β). Instead, most computational OT solvers use the dual of eq. (1), which we build our Meta OT solvers on top of in the discrete and continuous settings.
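For intuition on the primal problem's scale, a tiny discrete instance of eq. (1) can still be solved directly as an ordinary linear program over the m × n coupling matrix. The sketch below uses SciPy's LP solver; the measure sizes, support points, and squared-Euclidean cost are illustrative assumptions, not a setup from the paper:

```python
import numpy as np
from scipy.optimize import linprog

rng = np.random.default_rng(0)

# Tiny problem: uniform weights on random support points in R^2.
m, n = 5, 6
x, y = rng.normal(size=(m, 2)), rng.normal(size=(n, 2))
a, b = np.full(m, 1 / m), np.full(n, 1 / n)

# Squared-Euclidean ground cost C[i, j] = ||x_i - y_j||^2.
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)

# Flatten the m x n coupling P row-major and encode the marginal
# constraints P 1_n = a and P^T 1_m = b as linear equalities.
A_eq = np.vstack([
    np.kron(np.eye(m), np.ones(n)),  # row sums of P equal a
    np.kron(np.ones(m), np.eye(n)),  # column sums of P equal b
])
b_eq = np.concatenate([a, b])

res = linprog(C.reshape(-1), A_eq=A_eq, b_eq=b_eq, bounds=(0, None))
P = res.x.reshape(m, n)
```

Even at this toy size the primal has m·n = 30 variables against only m + n = 11 dual potentials, which is one reason computational solvers prefer the dual at scale.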

2.1.1. ENTROPIC OT BETWEEN DISCRETE MEASURES WITH THE SINKHORN ALGORITHM

Algorithm 1 Sinkhorn(α, β, c, ε, f₀ = 0)
  for iteration i = 1 to N do
    g_i ← ε log b − ε log(Kᵀ exp{f_{i−1}/ε})
    f_i ← ε log a − ε log(K exp{g_i/ε})
  end for
  Compute P_N from f_N, g_N using eq. (6)
  return P_N ≈ P⋆

Let α := Σ_{i=1}^m a_i δ_{x_i} and β := Σ_{j=1}^n b_j δ_{y_j} be discrete measures, where δ_z is a Dirac at point z, and a ∈ Δ^{m−1} and b ∈ Δ^{n−1} are in the probability simplex defined by

Δ^{k−1} := {x ∈ R^k : x ≥ 0 and Σ_i x_i = 1}.   (2)

Discrete OT. In the discrete setting, eq. (1) simplifies to the linear program

P⋆(α, β, c) ∈ arg min_{P ∈ U(a,b)} ⟨C, P⟩,   U(a, b) := {P ∈ R_+^{m×n} : P 1_n = a, Pᵀ 1_m = b},   (3)

where P is a coupling matrix, P⋆(α, β) is the optimal coupling, the cost is discretized as a matrix C ∈ R^{m×n} with entries C_{i,j} := c(x_i, y_j), and ⟨C, P⟩ := Σ_{i,j} C_{i,j} P_{i,j}.

Entropic OT. The linear program above can be regularized by adding the entropy of the coupling to smooth the objective, as in Cominetti and San Martín (1994); Cuturi (2013), resulting in

P⋆(α, β, c, ε) ∈ arg min_{P ∈ U(a,b)} ⟨C, P⟩ − εH(P),   (4)

where H(P) := −Σ_{i,j} P_{i,j}(log(P_{i,j}) − 1) is the discrete entropy of a coupling matrix P.

Entropic OT dual. As presented in Peyré et al. (2019, Prop. 4.4), the dual of eq. (4) is

f⋆, g⋆ ∈ arg max_{f ∈ R^m, g ∈ R^n} ⟨f, a⟩ + ⟨g, b⟩ − ε⟨exp{f/ε}, K exp{g/ε}⟩,   K_{i,j} := exp{−C_{i,j}/ε},   (5)

where K ∈ R^{m×n} is the Gibbs kernel and the dual variables or potentials f ∈ R^m and g ∈ R^n are associated, respectively, with the marginal constraints P 1_n = a and Pᵀ 1_m = b. The optimal duals depend on the problem, e.g. f⋆(α, β, c, ε), but we omit this dependence for notational simplicity.

Recovering the primal solution from the duals. Given optimal duals f⋆, g⋆ that solve eq. (5), the optimal coupling P⋆ of the primal problem in eq. (4) can be obtained by

P⋆_{i,j}(α, β, c, ε) := exp{f⋆_i/ε} K_{i,j} exp{g⋆_j/ε},   (6)

where K is defined in eq. (5).
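The log-domain Sinkhorn updates and the primal recovery from the duals can be sketched in a few lines of NumPy. This is a minimal illustrative implementation, not the solver configuration used in the paper; the toy measures, random cost, ε, and iteration count are all assumptions:

```python
import numpy as np
from scipy.special import logsumexp

def sinkhorn_log(a, b, C, eps, n_iters=1000):
    """Log-domain Sinkhorn iterations for entropic OT (cf. Algorithm 1)."""
    f = np.zeros_like(a)
    for _ in range(n_iters):
        # g <- eps log b - eps log(K^T exp{f/eps}), evaluated stably in
        # log space using log K_ij = -C_ij / eps.
        g = eps * np.log(b) - eps * logsumexp((f[:, None] - C) / eps, axis=0)
        # f <- eps log a - eps log(K exp{g/eps}).
        f = eps * np.log(a) - eps * logsumexp((g[None, :] - C) / eps, axis=1)
    # Recover the primal coupling: P_ij = exp{(f_i + g_j - C_ij)/eps}.
    P = np.exp((f[:, None] + g[None, :] - C) / eps)
    return f, g, P

# Toy problem: random cost between uniform measures.
rng = np.random.default_rng(0)
m, n = 4, 5
a, b = np.full(m, 1 / m), np.full(n, 1 / n)
C = rng.random((m, n))
f, g, P = sinkhorn_log(a, b, C, eps=0.1)
# P approximately satisfies the marginals P 1_n = a and P^T 1_m = b.
```

Note that after the final f update the row marginal holds exactly by construction, while the column marginal is only satisfied in the limit of many iterations.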

