

Abstract

We study the use of amortized optimization to predict optimal transport (OT) maps from the input measures, which we call Meta OT. This helps repeatedly solve similar OT problems between different measures by leveraging the knowledge and information from past problems to rapidly predict and solve new ones. Standard methods instead ignore past solutions and suboptimally re-solve each problem from scratch. We instantiate Meta OT models in discrete and continuous (Wasserstein-2) settings between images, spherical data, and color palettes, and use them to improve the computational time of standard OT solvers by multiple orders of magnitude.

1. INTRODUCTION

Optimal transportation (Villani, 2009; Ambrosio, 2003; Santambrogio, 2015; Peyré et al., 2019; Mérigot and Thibert, 2021) is thriving in domains including economics (Galichon, 2016), reinforcement learning (Dadashi et al., 2021; Fickinger et al., 2021), style transfer (Kolkin et al., 2019), generative modeling (Arjovsky et al., 2017; Seguy et al., 2018; Huang et al., 2020; Rout et al., 2021), geometry (Solomon et al., 2015; Cohen et al., 2021), domain adaptation (Courty et al., 2017; Redko et al., 2019), signal processing (Kolouri et al., 2017), fairness (Jiang et al., 2020), and cell reprogramming (Schiebinger et al., 2019). A core component in these settings is to couple two measures (α, β) supported on domains (X, Y) by solving a transport optimization problem such as the primal Kantorovich problem, which is defined by

    π⋆(α, β, c) ∈ arg min_{π ∈ U(α,β)} ∫_{X×Y} c(x, y) dπ(x, y),    (1)

where the optimal coupling π⋆ is a joint distribution over the product space, U(α, β) is the set of admissible couplings between α and β, and c : X × Y → R is the ground cost, which represents a notion of distance between elements in X and elements in Y.

Challenges. Unfortunately, solving eq. (1) even once is computationally expensive between general measures, and computationally cheaper alternatives are an active research topic: entropic optimal transport (Cuturi, 2013) smooths the transport problem with an entropy penalty, and sliced distances (Kolouri et al., 2016; 2018; 2019; Deshpande et al., 2019) solve OT between 1-dimensional projections of the measures, where eq. (1) can be solved easily. Furthermore, when an optimal transport method is deployed in practice, eq. (1) is not solved just a single time, but repeatedly for new scenarios between different input measures (α, β). For example, the measures could be representations of images we care about optimally transporting between, and in deployment we would receive a stream of new images to couple.
Repeatedly solving optimal transport problems also comes up in the context of comparing seismic signals (Engquist and Froese, 2013) and in single-cell perturbations (Bunne et al., 2021; 2022a;b). Standard optimal transport solvers deployed in this setting would re-solve the optimization problems from scratch, ignoring the shared structure and information present between different coupling problems.

Overview and outline. We study the use of amortized optimization and machine learning methods to rapidly solve multiple optimal transport problems and predict the solution from the input measures (α, β). This setting involves learning a meta model to predict the solution to the optimal transport problem, which we refer to as Meta Optimal Transport (Meta OT). We learn Meta OT models to predict the solutions to optimal transport problems and significantly improve the computational time and number of iterations needed to solve eq. (1) between discrete (sect. 3.1) and continuous (sect. 3.2) measures. The paper is organized as follows: sect. 2 recalls the main concepts needed for the rest of the paper, in particular the formulations of the entropy-regularized and unregularized optimal transport problems and the basic notions of amortized optimization; sect. 3 presents the Meta OT models and algorithms; and sect. 4 empirically demonstrates the effectiveness of Meta OT.

Settings that are not Meta OT. Meta OT is not useful in OT settings that do not involve repeatedly solving OT problems over a fixed distribution, including 1) standard generative modeling settings, such as Arjovsky et al. (2017), that estimate the OT distance between the data and model distributions, and 2) the out-of-sample setting of Seguy et al. (2018); Perrot et al. (2016) that couples measures and then extrapolates the map to larger measures containing the original ones.

2.1. DUAL OPTIMAL TRANSPORT SOLVERS

We review the foundations of optimal transportation, following the notation of Peyré et al. (2019) in most places. The discrete setting often favors the entropy-regularized version, since it can be computed efficiently and in a parallelized way with the Sinkhorn algorithm, while the continuous setting is often solved from samples using convex potentials. While the primal Kantorovich formulation in eq. (1) provides an intuitive problem description, optimal transport problems are rarely solved directly in this form due to the high dimensionality of the couplings π and the difficulty of satisfying the coupling constraints U(α, β). Instead, most computational OT solvers use the dual of eq. (1), which we build our Meta OT solvers on top of in both discrete and continuous settings.

2.1.1. ENTROPIC OT BETWEEN DISCRETE MEASURES WITH THE SINKHORN ALGORITHM

Algorithm 1 Sinkhorn(α, β, c, ε, f0 = 0)
  for iteration i = 1 to N do
    g_i ← ε(log b − log(Kᵀ exp{f_{i−1}/ε}))
    f_i ← ε(log a − log(K exp{g_i/ε}))
  end for
  Compute P_N from f_N, g_N using eq. (6)
  return P_N ≈ P⋆

Let α := Σ_{i=1}^m a_i δ_{x_i} and β := Σ_{j=1}^n b_j δ_{y_j} be discrete measures, where δ_z is a Dirac at point z and a ∈ Δ^{m−1} and b ∈ Δ^{n−1} are in the probability simplex defined by

    Δ^{k−1} := {x ∈ R^k : x ≥ 0 and Σ_i x_i = 1}.    (2)

Discrete OT. In the discrete setting, eq. (1) simplifies to the linear program

    P⋆(α, β, c) ∈ arg min_{P ∈ U(a,b)} ⟨C, P⟩,   U(a, b) := {P ∈ R_+^{m×n} : P 1_n = a, Pᵀ 1_m = b},    (3)

where P is a coupling matrix, P⋆(α, β) is the optimal coupling, the cost is discretized as a matrix C ∈ R^{m×n} with entries C_{i,j} := c(x_i, y_j), and ⟨C, P⟩ := Σ_{i,j} C_{i,j} P_{i,j}.

Entropic OT. The linear program above can be regularized by adding the entropy of the coupling to smooth the objective, as in Cominetti and San Martín (1994); Cuturi (2013), resulting in

    P⋆(α, β, c, ε) ∈ arg min_{P ∈ U(a,b)} ⟨C, P⟩ − εH(P),    (4)

where H(P) := −Σ_{i,j} P_{i,j}(log(P_{i,j}) − 1) is the discrete entropy of a coupling matrix P.

Entropic OT dual. As presented in Peyré et al. (2019, Prop. 4.4), the dual of eq. (4) is

    f⋆, g⋆ ∈ arg max_{f ∈ R^m, g ∈ R^n} ⟨f, a⟩ + ⟨g, b⟩ − ε⟨exp{f/ε}, K exp{g/ε}⟩,   K_{i,j} := exp{−C_{i,j}/ε},    (5)

where K ∈ R^{m×n} is the Gibbs kernel, and the dual variables or potentials f ∈ R^m and g ∈ R^n are associated, respectively, with the marginal constraints P 1_n = a and Pᵀ 1_m = b. The optimal duals depend on the problem, e.g. f⋆(α, β, c, ε), but we omit this dependence for notational simplicity.

Recovering the primal solution from the duals. Given optimal duals f⋆, g⋆ that solve eq. (5), the optimal coupling P⋆ to the primal problem in eq. (4) can be obtained by

    P⋆_{i,j}(α, β, c, ε) := exp{f⋆_i/ε} K_{i,j} exp{g⋆_j/ε}   (K is defined in eq. (5)).    (6)

The Sinkhorn algorithm. Algorithm 1 summarizes the log-space version, which takes closed-form block coordinate ascent updates on eq. (5) obtained from the first-order optimality conditions (Peyré et al., 2019, Remark 4.21). We will use it to fine-tune predictions made by our Meta OT models.

Computing the error. Standard implementations of the Sinkhorn algorithm, such as Flamary et al. (2021); Cuturi et al. (2022), measure the error of a candidate dual solution (f, g) by computing the deviation from the marginal constraints, which we also use to compare solution quality:

    err(f, g; α, β, c) := ||P 1_n − a||_1 + ||Pᵀ 1_m − b||_1   (compute P from eq. (6)).    (7)

Mapping between the duals. The first-order optimality conditions of eq. (5) also provide an equivalence between the optimal dual potentials that we will make use of:

    g⋆(f; b, c) := ε(log b − log(Kᵀ exp{f/ε})).    (8)

2.1.2. WASSERSTEIN-2 OT BETWEEN CONTINUOUS MEASURES

Algorithm 2 W2GN(α, β, ϕ0)
  for iteration i = 1 to N do
    Update ϕ_i with a gradient step on L(ϕ_{i−1}; α, β) from eq. (12)
  end for
  return T(·) := ∇_x ψ_{ϕ_N}(·) ≈ T⋆(·)

Let α and β be continuous measures in Euclidean space X = Y = R^d (with α absolutely continuous with respect to the Lebesgue measure) and let the ground cost be the squared Euclidean distance c(x, y) := ||x − y||²₂. Then the minimum of eq. (1) defines the square of the Wasserstein-2 distance

    W²₂(α, β) := min_{π ∈ U(α,β)} ∫_{X×Y} ||x − y||²₂ dπ(x, y) = min_T ∫_X ||x − T(x)||²₂ dα(x),    (9)

where T is a transport map pushing α to β, i.e. T_#α = β, with the pushforward operator defined by T_#α(B) := α(T^{−1}(B)) for any measurable set B.

Convex dual potentials. The primal form in eq. (9) is difficult to solve, as in the discrete setting, due to the difficulty of representing the coupling and satisfying the constraints. Makkuva et al. (2020); Taghvaei and Jalali (2019); Korotin et al. (2019; 2021b; 2022) propose to instead solve the dual

    ψ⋆(·; α, β) ∈ arg min_{ψ ∈ convex} ∫_X ψ(x) dα(x) + ∫_Y ψ̄(y) dβ(y),    (10)

where ψ is a convex function referred to as a convex potential, and ψ̄(y) := max_{x ∈ X} ⟨x, y⟩ − ψ(x) is the Legendre-Fenchel transform or convex conjugate of ψ (Fenchel, 1949; Rockafellar, 2015). The potential ψ is often approximated with an input-convex neural network (ICNN) (Amos et al., 2017).

Recovering the primal solution from the dual.
Given an optimal dual ψ⋆ for eq. (10), Brenier (1991) remarkably shows that an optimal map T⋆ for eq. (9) can be obtained by differentiation: T⋆(x) = ∇_x ψ⋆(x).

Wasserstein-2 Generative Networks (W2GNs). Korotin et al. (2019) model ψ_ϕ and ψ̄_ϕ in eq. (10) with two separate ICNNs parameterized by ϕ. The separate model for ψ̄_ϕ is useful because the conjugate operation in eq. (10) is computationally expensive. They optimize the loss

    L(ϕ) := E_{x∼α}[ψ_ϕ(x)] + E_{y∼β}[⟨∇ψ̄_ϕ(y), y⟩ − ψ_ϕ(∇ψ̄_ϕ(y))]   (cyclically monotone correlations; dual objective)
            + γ E_{y∼β} ||∇ψ_ϕ̄ ∘ ∇ψ̄_ϕ(y) − y||²₂   (cycle-consistency regularizer),    (12)

where ϕ̄ is a detached copy of the parameters and γ is a hyper-parameter. The first term gives the cyclically monotone correlations (Chartrand et al., 2009; Taghvaei and Jalali, 2019) that optimize the dual objective in eq. (10), and the second term provides cycle consistency (Zhu et al., 2017) to estimate the conjugate ψ̄. Algorithm 2 shows how L is typically optimized using samples from the measures, which we use to fine-tune Meta OT predictions.
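As a concrete reference, the log-domain Sinkhorn updates of Algorithm 1, together with the primal recovery of eq. (6) and the marginal error of eq. (7), can be sketched in NumPy as follows. This is a minimal illustration, not the paper's JAX implementation; the problem sizes and ε are illustrative.

```python
import numpy as np

def lse(A, axis):
    # numerically stable log-sum-exp along an axis
    M = A.max(axis=axis, keepdims=True)
    return np.squeeze(M + np.log(np.exp(A - M).sum(axis=axis, keepdims=True)), axis)

def sinkhorn_log(a, b, C, eps, f0=None, n_iter=500):
    """Log-domain Sinkhorn: block coordinate ascent on the entropic dual (eq. 5),
    warm-started from f0 (zeros by default; a Meta OT prediction otherwise)."""
    m, n = C.shape
    f = np.zeros(m) if f0 is None else f0.copy()
    logK = -C / eps  # Gibbs kernel in log space
    for _ in range(n_iter):
        g = eps * (np.log(b) - lse(f[:, None] / eps + logK, axis=0))
        f = eps * (np.log(a) - lse(g[None, :] / eps + logK, axis=1))
    P = np.exp((f[:, None] + g[None, :]) / eps + logK)  # primal recovery, eq. (6)
    return f, g, P

a, b = np.array([0.5, 0.5]), np.array([0.3, 0.7])
C = np.array([[0.0, 1.0], [1.0, 0.0]])
f, g, P = sinkhorn_log(a, b, C, eps=1.0)
err = np.abs(P.sum(1) - a).sum() + np.abs(P.sum(0) - b).sum()  # eq. (7)
```

At convergence the coupling `P` matches both marginals, so `err` is near machine precision for this small problem.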

Input measures and cost Dual potentials Couplings

Figure 1 : Meta OT uses objective-based amortization for optimal transport. In the general formulation, the parameters θ capture shared structure in the optimal couplings π between multiple input measures and costs over some distribution D. In practice, we learn this shared structure over the dual potentials which map back to the coupling: f in discrete settings and ψ in continuous ones.

2.2. AMORTIZED OPTIMIZATION AND LEARNING TO OPTIMIZE

Our paper is an application of amortized optimization methods that predict the solutions of optimization problems, as surveyed in, e.g., Chen et al. (2021); Amos (2022). We use the basic setup from Amos (2022), which considers unconstrained continuous optimization problems of the form

    z⋆(φ) ∈ arg min_z J(z; φ),    (13)

where J is the objective, z ∈ Z is the domain, and φ ∈ Φ is some context or parameterization. In other words, the context conditions the objective but is not optimized over. Given a distribution over contexts P(φ), we learn a model ẑ_θ parameterized by θ to approximate eq. (13), i.e. ẑ_θ(φ) ≈ z⋆(φ). J will be differentiable for us, so we optimize the parameters using objective-based learning with

    min_θ E_{φ∼P(φ)} J(ẑ_θ(φ); φ),    (14)

which does not require ground-truth solutions z⋆ and can be optimized with a gradient-based solver. While we focus on optimizing eq. (14) because we do not assume easy access to ground-truth solutions z⋆(φ), one alternative is regression-based learning if the solutions are easily available:

    min_θ E_{φ∼P(φ)} ||z⋆(φ) − ẑ_θ(φ)||²₂.    (15)

3. META OPTIMAL TRANSPORT

Figure 1 illustrates our key contribution of connecting objective-based amortization in eq. (14) to optimal transport. We consider solving multiple OT problems and learning shared structure and correlations between them. We denote a joint meta-distribution over the input measures and costs with D(α, β, c), which we call meta to distinguish it from the measures α, β. In general, we could introduce a model that directly predicts the primal solution to eq. (1), i.e. π̂_θ(α, β, c) ≈ π⋆(α, β, c) for (α, β, c) ∼ D. This is difficult for the same reason why most computational methods do not operate directly in the primal space: the optimal coupling is often a high-dimensional joint distribution with non-trivial marginal constraints. We instead turn to predicting the dual variables used by today's solvers.
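The objective-based amortization of eq. (14) can be illustrated on a toy family of problems, which is not part of the paper's experiments: take J(z; φ) = ½z² − φz, whose minimizer is z⋆(φ) = φ, and a one-parameter model ẑ_θ(φ) = θφ trained by descending E_φ[J(ẑ_θ(φ); φ)]. All names and constants here are illustrative.

```python
import numpy as np

# Toy family of problems: J(z; phi) = 0.5*z**2 - phi*z, minimized at z*(phi) = phi.
# Amortization model z_hat(phi) = theta*phi, trained by objective-based learning
# (eq. 14): minimize E_phi[ J(z_hat(phi); phi) ] -- no ground-truth z* is needed.
rng = np.random.default_rng(0)
theta, lr = 0.0, 0.1
for _ in range(200):
    phi = rng.normal(size=32)          # sampled contexts
    z = theta * phi                    # model predictions
    # dJ/dz = z - phi; chain rule through the model: dJ/dtheta = (z - phi)*phi
    grad_theta = np.mean((z - phi) * phi)
    theta -= lr * grad_theta
# theta converges toward 1, i.e. z_hat(phi) ~= z*(phi) = phi
```

The key point mirrors eq. (14): the model parameters are trained on the objective values alone, never on the solutions z⋆.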

3.1. META OT BETWEEN DISCRETE MEASURES

We build on standard methods for entropic OT reviewed in sect. 2.1.1.

Amortization objective. We seek to predict the optimal potential. At optimality, the pair of potentials are related to each other via eq. (8), i.e. g⋆(f; α, β, c) := ε(log b − log(Kᵀ exp{f/ε})), where K ∈ R^{m×n} is the Gibbs kernel from eq. (5). Hence, it is sufficient to predict one of the potentials, e.g. f⋆, and recover the other. We thus re-formulate eq. (5) to optimize over f alone with

    f⋆(α, β, c, ε) ∈ arg min_{f ∈ R^m} J(f; α, β, c),    (16)

where −J(f; α, β, c) := ⟨f, a⟩ + ⟨g(f), b⟩ − ε⟨exp{f/ε}, K exp{g(f)/ε}⟩ is the (negated) dual objective. Even though most solvers optimize over f and g jointly as in eq. (5), amortizing over both would likely need 1) a higher-capacity model than one that predicts only f, and 2) to learn how f and g are connected through eq. (8), whereas in eq. (16) we provide this knowledge explicitly.

Amortization model. We predict the solution to eq. (16) with f̂_θ(α, β, c) parameterized by θ, resulting in a computationally efficient approximation f̂_θ ≈ f⋆. Here the notation f̂_θ(α, β, c) means that the model f̂_θ depends on representations of the input measures and cost. In our settings, we define f̂_θ as a fully-connected MLP mapping from the atoms of the measures to the duals.

Amortization loss. Applying objective-based amortization from eq. (14) to the dual in eq. (16) completes our learning setup. Our model should best-optimize the expectation of the dual objective

    min_θ E_{(α,β,c)∼D} J(f̂_θ(α, β, c); α, β, c),    (17)

which is appealing as it does not require ground-truth solutions f⋆. Algorithm 3 shows a basic training loop for eq. (17) using a gradient-based optimizer such as Adam (Kingma and Ba, 2014).

Sinkhorn fine-tuning. The dual prediction made by f̂_θ, with the associated ĝ, can easily be used as the initialization to a standard Sinkhorn solver, as shown in algorithm 4.
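To make the effect of warm-starting concrete, the following NumPy sketch counts log-Sinkhorn iterations to reach a marginal-error tolerance (eq. (7)) from a zero initialization versus a near-optimal initialization standing in for a Meta OT prediction. This is an illustration, not the paper's implementation; the sizes, ε, noise scale, and tolerance are all assumptions.

```python
import numpy as np

def lse(A, axis):
    # numerically stable log-sum-exp along an axis
    M = A.max(axis=axis, keepdims=True)
    return np.squeeze(M + np.log(np.exp(A - M).sum(axis=axis, keepdims=True)), axis)

def iters_to_tol(f0, a, b, C, eps, tol=1e-6, max_iter=10000):
    """Run log-Sinkhorn from initialization f0 and count iterations until
    the marginal error of eq. (7) drops below tol."""
    logK, f = -C / eps, f0.copy()
    for it in range(1, max_iter + 1):
        g = eps * (np.log(b) - lse(f[:, None] / eps + logK, axis=0))
        f = eps * (np.log(a) - lse(g[None, :] / eps + logK, axis=1))
        P = np.exp((f[:, None] + g[None, :]) / eps + logK)
        err = np.abs(P.sum(1) - a).sum() + np.abs(P.sum(0) - b).sum()
        if err < tol:
            return it, f
    return max_iter, f

rng = np.random.default_rng(0)
a, b = rng.dirichlet(np.ones(20)), rng.dirichlet(np.ones(20))
C, eps = rng.random((20, 20)), 0.2
n_cold, f_star = iters_to_tol(np.zeros(20), a, b, C, eps)
# stand-in for a Meta OT prediction: the optimal potential plus small noise
n_warm, _ = iters_to_tol(f_star + 0.001 * rng.normal(size=20), a, b, C, eps)
```

A near-optimal initialization reaches the tolerance in fewer iterations than the zero initialization, which is the mechanism algorithm 4 exploits.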
This allows us to deploy the predicted potential with Sinkhorn to obtain the optimal potentials with only a few extra iterations.

On accelerated solvers. Here we have only considered fine-tuning the Meta OT prediction with a log-Sinkhorn solver. Meta OT can also be combined with accelerated variants of entropic OT solvers, such as Thibault et al. (2017); Altschuler et al. (2017); Alaya et al. (2019); Lin et al. (2019), that would otherwise solve every problem from scratch.

3.2. META OT BETWEEN CONTINUOUS (WASSERSTEIN-2) MEASURES

We take an analogous approach to predicting the map between continuous measures for Wasserstein-2, as reviewed in sect. 2.1.2. Here the measures α, β are supported in continuous space X = Y = R^d, and we focus on computing Wasserstein-2 couplings for instances sampled from a meta-distribution (α, β) ∼ D(α, β). The cost c is not included in D as it remains fixed to the squared Euclidean cost throughout. One challenge here is that the optimal dual potential ψ⋆(·; α, β) in eq. (10) is a convex function and not simply a finite-dimensional real vector. The dual potentials in this setting are approximated by, e.g., an ICNN. We thus propose a Meta ICNN that predicts the parameters ϕ of an ICNN ψ_ϕ approximating the optimal dual potential, which can be seen as a hypernetwork (Stanley et al., 2009; Ha et al., 2016). The dual prediction made by φ̂_θ can easily be used as the initial value for a standard W2GN solver, as shown in algorithm 5. App. B discusses other modeling choices we considered: models based on MAML (Finn et al., 2017) and neural processes (Garnelo et al., 2018b;a). (Reported results are the mean and standard deviation across 10 test instances.)

Amortization objective. We build on the W2GN formulation (Korotin et al., 2019) and seek parameters ϕ optimizing the dual ICNN potentials ψ_ϕ and ψ̄_ϕ with L(ϕ; α, β) from eq. (12). We chose W2GN for its stability, but could also easily use other losses that optimize ICNN potentials.

Amortization model: the Meta ICNN. We predict the solution to eq. (12) with φ̂_θ(α, β) parameterized by θ, resulting in a computationally efficient approximation to the optimum, φ̂_θ ≈ ϕ⋆. Figure 3 instantiates a convolutional Meta ICNN model using a ResNet-18 (He et al., 2016) architecture for coupling image-based measures. We again emphasize that the α, β used with the model here are representations of measures, which in our cases are simply images.

Amortization loss. Applying objective-based amortization from eq. (14) to the W2GN loss in eq. (12) completes our learning setup. Our model should best-optimize the expectation of the loss

    min_θ E_{(α,β)∼D} L(φ̂_θ(α, β); α, β).    (18)

As in the discrete setting, it does not require ground-truth solutions ϕ⋆, and we learn it with Adam.
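The map-recovery step the predicted potential feeds into is T⋆ = ∇ψ⋆ from sect. 2.1.2. A tiny NumPy illustration with a hand-picked convex potential, purely illustrative and standing in for autodiff through an ICNN:

```python
import numpy as np

def num_grad(psi, x, h=1e-5):
    # finite-difference gradient of a scalar potential
    # (a stand-in for autodiff, e.g. jax.grad, applied to an ICNN)
    g = np.zeros_like(x)
    for i in range(x.size):
        e = np.zeros_like(x)
        e[i] = h
        g[i] = (psi(x + e) - psi(x - e)) / (2 * h)
    return g

# Convex potential psi(x) = 0.5*||x||^2 + <m, x>, so T(x) = grad psi(x) = x + m:
# the Brenier map pushing any measure alpha to its translation by m.
m = np.array([2.0, -1.0])
psi = lambda x: 0.5 * x @ x + m @ x
x = np.array([0.3, 0.7])
T_x = num_grad(psi, x)  # equals x + m up to finite-difference error
```

The same recipe applies with a learned ψ_ϕ: differentiating the predicted potential yields the transport map without any further optimization.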

4. EXPERIMENTS

We demonstrate how Meta OT models improve the convergence of state-of-the-art solvers in settings where solving multiple OT problems naturally arises. We implemented our code in JAX (Bradbury et al., 2018) as an extension to the Optimal Transport Tools (OTT) package (Cuturi et al., 2022). App. C covers further experimental and implementation details, and shows that all of our experiments take a few hours to run on a single Quadro GP100 GPU. We will open-source the code to reproduce all of our experiments.

Figure 4: Meta OT successfully predicts warm-start initializations that significantly improve the convergence of Sinkhorn iterations on test data. The error is the marginal error defined in eq. (7).

4.1. DISCRETE OT BETWEEN MNIST DIGITS

Images provide a natural setting for Meta OT, where the distribution over images provides the meta-distribution D over OT problems. Given a pair of images α0 and α1, each grayscale image is cast as a discrete measure in 2-dimensional space where the intensities define the probabilities of the atoms. The goal is to compute the optimal transport interpolation between the two measures as in, e.g., Peyré et al. (2019, §7). Formally, this means computing the optimal coupling P⋆ by solving the entropic optimal transport problem between α0 and α1, and computing the interpolates as

    α_t = (t proj_y + (1 − t) proj_x)_# P⋆,   for t ∈ [0, 1],

where proj_x(x, y) := x and proj_y(x, y) := y. We selected ε = 10⁻², as app. A shows that it gives interpolations that are neither too blurry nor too sharp. Our Meta OT model f̂_θ (sect. 3.1) is an MLP that predicts the transport map between pairs of MNIST digits, and we train on every pair from the standard training dataset. Figure 2 shows that even without fine-tuning, Meta OT's predicted Wasserstein interpolations between the measures are close to the ground-truth interpolations obtained by running the Sinkhorn algorithm to convergence. We then fine-tune Meta OT's prediction with Sinkhorn as in algorithm 4. Figure 4 shows that the near-optimal predictions can be quickly refined in fewer iterations than running Sinkhorn with the default initialization, and table 1 shows the runtime required to reach an error threshold of 10⁻², demonstrating that the Meta OT initialization helps solve the problems faster by an order of magnitude. We compare our learned initialization to the standard zero initialization, as well as to the Gaussian initialization proposed in Thornton and Cuturi (2022), which takes a continuous Gaussian approximation of the measures and initializes the potentials to the known coupling between the Gaussians.
4.2. DISCRETE OT ON SPHERICAL TRANSPORT PROBLEMS

We next set up a synthetic transport problem between supply and demand locations, where the supplies and demands may change locations or quantities frequently, creating another Meta OT setting that calls for rapidly solving new instances. We specifically consider measures living on the 2-sphere S² := {x ∈ R³ : ||x|| = 1}, i.e. X = Y = S², with the transport cost given by the spherical distance c(x, y) = arccos(⟨x, y⟩). We randomly sample supply locations uniformly from Earth's landmass and demand locations from Earth's population density, obtained from the CC-licensed dataset of Doxsey-Whitfield et al. (2015), to induce a class of transport problems on the sphere. The Gaussian initialization of Thornton and Cuturi (2022) assumes the squared Euclidean cost, which does not hold for our spherical transport cost, but we find it is still helpful over the zero initialization. Figure 5 shows that the predicted transport maps on test instances are close to the optimal maps obtained by running Sinkhorn to convergence. As in the MNIST setting, fig. 4 and table 1 show improved convergence and runtime.

4.3. CONTINUOUS OT FOR COLOR TRANSFER

The problem of color transfer between two images consists of mapping the color palette of one image onto the other. The images are required to have the same number of channels, for example RGB images. We use the continuous formulation from Korotin et al. (2019), i.e. X = Y = [0, 1]³ with c being the squared Euclidean distance. We collected ≈200 public-domain images from WikiArt and trained a Meta ICNN model from sect. 3.2 to predict the color transfer maps between every pair of them. Figure 6 shows the predictions on test pairs, and fig. 7 shows the convergence in comparison to standard W2GN learning. Table 2 reports runtimes, and app. E shows additional results.
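Returning to the MNIST setting, the interpolation α_t = (t proj_y + (1 − t) proj_x)_# P can be computed directly from any discrete coupling; a NumPy sketch, where the atoms and coupling are illustrative:

```python
import numpy as np

def interpolate(P, X, Y, t):
    """McCann-style interpolation alpha_t = (t*proj_y + (1-t)*proj_x)_# P:
    coupling entry P_ij places its mass at the point (1-t)*x_i + t*y_j."""
    pts = (1 - t) * X[:, None, :] + t * Y[None, :, :]  # (m, n, d) interpolated atoms
    return pts.reshape(-1, X.shape[1]), P.reshape(-1)  # atoms and their weights

X = np.array([[0.0, 0.0], [1.0, 0.0]])  # atoms of alpha_0
Y = np.array([[0.0, 1.0], [1.0, 1.0]])  # atoms of alpha_1
P = np.eye(2) / 2                       # a coupling that transports straight up
pts, w = interpolate(P, X, Y, 0.5)
```

At t = 0.5 each unit of mass sits halfway along its transport line, which is what the MNIST interpolation figures visualize for the couplings produced by Sinkhorn or Meta OT.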

5. RELATED WORK

Efficiently estimating OT maps. To efficiently compute OT maps with a fixed cost between pairs of measures, neural OT models (Korotin et al., 2019; 2021a; Mokrov et al., 2021; Korotin et al., 2021b) leverage ICNNs to estimate maps between continuous high-dimensional measures given samples from them, and Litvinenko et al. (2021); Scetbon et al. (2021); Forrow et al. (2019); Sommerfeld et al. (2019); Scetbon et al. (2022); Muzellec and Cuturi (2019); Bonet et al. leverage structural assumptions on coupling and cost matrices to reduce the computational and memory complexity. In the Meta OT setting, we consider learning to rapidly compute OT mappings between new pairs of measures; all of these works could hence potentially benefit from a similar acceleration by leveraging amortization.

Embedding measures where OT distances are discriminative. Effort has been invested in learning encodings/projections of measures through a nested optimization problem that aims to find discriminative embeddings of the measures to be compared (Genevay et al., 2018; Deshpande et al., 2019; Nguyen and Ho, 2022). While these works share an encoder and/or a projection across tasks with the aim of leveraging more discriminative alignments (and hence an OT distance with a metric different from the Euclidean one), our work differs in that we find good initializations for solving the OT problem itself, with a fixed cost, more efficiently across tasks.

Optimal transport and amortization. Few previous works in the OT literature leverage amortization. Courty et al. (2018) learn a latent space in which the Euclidean distance between the measures' embeddings matches the Wasserstein distance between the measures. Concurrent work (Nguyen and Ho, 2022) amortizes the estimation of the optimal projection in the max-sliced objective, which differs from our work where we instead amortize the estimation of the optimal coupling directly. Also, Lacombe et al. (2021) learn to predict Wasserstein barycenters of pixel images by training a convolutional network that, given images as input, outputs their barycenter.
Our work is hence a generalization of this pixel-based work to general measures, both discrete and continuous. A limitation of Lacombe et al. (2021) is that it does not provide alignments, as the amortization network predicts the barycenter directly rather than the individual couplings.

6. CONCLUSIONS, FUTURE DIRECTIONS, AND LIMITATIONS

We have presented foundations for modeling and learning to solve OT problems with Meta OT, using amortized optimization to predict optimal transport plans. This works best in applications that require solving multiple OT problems with shared structure. We instantiated it to speed up entropic regularized optimal transport and unregularized optimal transport with the squared Euclidean cost by multiple orders of magnitude. We envision extensions of this work in:

1. Meta OT models. While we mostly consider models based on hypernetworks, other meta-learning paradigms can be connected in. In the discrete setting, we only considered settings where the cost remains fixed, but the Meta OT model can also be conditioned on the cost, either by taking the entire cost matrix as an input (which may be too large for most models to handle) or by considering a lower-dimensional parameterization of the cost that changes between the Meta OT problem instances.

2. OT algorithms. While we instantiated models on top of log-Sinkhorn and W2GN, Meta OT could be built on top of other methods.

3. OT applications that are computationally expensive and repeatedly solved, e.g. in multi-marginal and barycentric settings, or for Gromov-Wasserstein distances between metric-measure spaces.

Limitations. While we have illustrated successful applications of Meta OT, it is also important to understand its limitations: 1) Meta OT does not make previously intractable problems tractable; all of the baseline OT solvers we consider solve our problems within milliseconds or seconds. 2) Out-of-distribution generalization: Meta OT may not give good predictions on instances that are far from the training OT problems drawn from the meta-distribution D over the measures and cost. If the model makes a bad prediction, one fallback option is to re-solve the instance from scratch.

REPRODUCIBILITY STATEMENT

We have tried to clearly articulate all of the relevant details so that this paper can be completely reproduced from its contents alone. We will also open-source the full source code behind every experimental result, table, and figure in this paper, and will anonymously include it with this submission.

B OTHER MODELS FOR CONTINUOUS OT

While developing the hypernetwork, or Meta ICNN, in sect. 3.2 for predicting couplings between continuous measures, we considered the alternative modeling formulations briefly documented in this section. We settled on the hypernetwork model because it is conceptually the closest to predicting the optimal dual variables in the continuous setting and results in rapid predictions.

B.1 OPTIMIZATION-BASED META-LEARNING (MAML-INSPIRED)

The model-agnostic meta-learning setup proposed in MAML (Finn et al., 2017) could also be applied in the Meta OT setting to learn an adaptable initial parameterization. In the continuous setting, one initial version would take a parameterized dual potential model ψ_ϕ(x) and seek to learn an initial parameterization ϕ0 so that optimizing a loss such as the W2GN loss L from eq. (12) results in a minimal L(ϕ_K) after adapting the model for K steps. Formally, this would optimize

    arg min_{ϕ0} L(ϕ_K)   where   ϕ_{t+1} = ϕ_t − ∇_ϕ L(ϕ_t).    (19)

Challenges for Meta OT. The transport maps given by T = ∇ψ can vary significantly depending on the input measures α, β. We found it difficult to learn an initialization that can be rapidly adapted, and optimizing eq. (19) is more computationally expensive than eq. (18) as it requires unrolling through many evaluations of the transport loss L. In contrast, we found that learning to predict the optimal parameters with eq. (18), conditional on the input measures, and then fine-tuning with W2GN, is stable.

Advantages for Meta OT. Exploring MAML-inspired methods could further incorporate into the learning process the knowledge that the model's prediction will be fine-tuned. One promising direction we did not try is to integrate ideas from LEO (Rusu et al., 2018) and CAVIA (Zintgraf et al., 2019), which propose to learn a latent space for the parameters where the initialization is also conditional on the input.
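For intuition, a scalar sketch of the bilevel objective in eq. (19) with K = 1 on a toy task family; the quadratic loss, step sizes, and task distribution are all illustrative stand-ins, not the W2GN setting:

```python
import numpy as np

# One-step MAML sketch (eq. 19 with K=1) on the toy loss
# L(phi; c) = 0.5*(phi - c)**2 for tasks c ~ D: learn an initialization phi0
# such that one inner gradient step phi1 = phi0 - a*(phi0 - c) minimizes L(phi1; c).
rng = np.random.default_rng(0)
phi0, inner_lr, outer_lr = 5.0, 0.3, 0.1
for _ in range(300):
    c = rng.normal(loc=2.0)                 # a sampled task
    phi1 = phi0 - inner_lr * (phi0 - c)     # inner adaptation step
    # unrolled chain rule: dL(phi1)/dphi0 = (phi1 - c)*(1 - inner_lr)
    phi0 -= outer_lr * (phi1 - c) * (1 - inner_lr)
# phi0 drifts toward E[c] = 2, the best shared initialization for this family
```

The outer gradient differentiates through the inner update, which is exactly the unrolling that makes eq. (19) more expensive than the direct prediction of eq. (18).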

B.2 NEURAL PROCESS AND CONDITIONAL MONGE MAPS

The (conditional) neural process models considered in Garnelo et al. (2018b;a) can also be adapted to the Meta OT setting, and are similar to the model proposed in Bunne et al. (2022a). In the continuous setting, this would result in a dual potential that is also conditioned on a representation of the input measures, e.g. ψ_ϕ(x; z), where z := f_ϕ^emb(α, β) is an embedding of the input measures learned jointly with the parameters of ψ. This could be formulated as

    arg min_ϕ E_{(α,β)∼D} L(ϕ, f_ϕ^emb(α, β)),

where L modifies the model used in the loss of eq. (12) to also be conditioned on the context extracted from the measures.

Challenges for Meta OT. This raises the question of how to best formulate the model to be conditional on the context. One way could be to append z to the input point x in the domain. Bunne et al. (2022a) propose to use the Partially Input-Convex Neural Network (PICNN) from Amos et al. (2017) to make the model convex with respect to x but not z.

Advantages for Meta OT. A large advantage is that the representation z of the measures α, β would be significantly lower-dimensional than the parameters ϕ that our Meta OT models predict.

C ADDITIONAL EXPERIMENTAL AND IMPLEMENTATION DETAILS

We have attached the JAX source code necessary to run and reproduce all of the experiments in our paper and will open-source all of it. A basic overview of the files is included later in this appendix.

App. D tests the ability of Meta OT to predict potentials for out-of-distribution input data. We consider pairwise training and evaluation on the following datasets: 1) MNIST; 2) USPS (Hull, 1994), upscaled to the same size as MNIST; 3) the Google Doodles dataset* with classes Crab, Cat, and Faces; and 4) sparsified random uniform data in [0, 1], where sparsity (zeroing values below 0.95) is used to mimic the sparse signal in black-and-white images. For each pair, e.g. MNIST-USPS, we train on one dataset and use the other to predict the potentials. The comparison uses the same metric as before, i.e., the deviation from the marginal constraints defined in eq. (7).



In the discrete setting of sect. 2.1.1, the measures α := Σ_{i=1}^m a_i δ_{x_i} and β := Σ_{j=1}^n b_j δ_{y_j} with a ∈ Δ^{m−1} and b ∈ Δ^{n−1} are coupled using a cost c. In the Meta OT setting, the measures and cost are the contexts for amortization and are sampled from a meta-distribution, i.e. (α, β, c) ∼ D(α, β, c). For example, sects. 4.1 and 4.2 consider meta-distributions over the weights of the atoms, i.e. (a, b) ∼ D, where D is a distribution over Δ^{m−1} × Δ^{n−1}.

Algorithm 3 Training Meta OT
  Initialize the amortization model with θ0
  for iteration i do
    Sample (α, β, c) ∼ D
    Predict the duals f̂_θ or φ̂_θ on the sample
    Estimate the loss in eq. (17) or eq. (18)
    Update θ_{i+1} with a gradient step
  end for

Algorithm 4 Fine-tuning with Sinkhorn
  Predict the duals f̂_θ(α, β, c)
  return Sinkhorn(α, β, c, ε, f̂_θ)

Algorithm 5 Fine-tuning with W2GN
  Predict the dual ICNN parameters φ̂_θ(α, β, c)
  return W2GN(α, β, c, T, φ̂_θ)

Figure 2: Interpolations between MNIST test digits using couplings obtained from (left) solving the problem with Sinkhorn, and (right) Meta OT model's initial prediction, which is ≈100 times computationally cheaper and produces a nearly identical coupling.

Figure 5: Test set coupling predictions of the spherical transport problem. Meta OT's initial prediction is ≈37500 times faster than solving Sinkhorn to optimality. Supply locations are shown as black dots and the blue lines show the spherical transport maps T going to demand locations at the end. The sphere is visualized with the Mercator projection.

Figure 6: Color transfers with a Meta ICNN on test pairs of images. The objective is to optimally transport the continuous RGB measure of the first image α to the second β, producing an invertible transport map T . Meta OT's prediction is ≈1000 times faster than training W2GN from scratch. The image generating α is Market in Algiers by August Macke (1914) and β is Argenteuil, The Seine by Claude Monet (1872), obtained from WikiArt.

Figure 7: Convergence on color transfer test instances using W2GN. Meta ICNNs predicts warm-start initializations that significantly improve the (normalized) dual objective values.


Jun-Yan Zhu, Taesung Park, Phillip Isola, and Alexei A. Efros. Unpaired image-to-image translation using cycle-consistent adversarial networks. In Proceedings of the IEEE International Conference on Computer Vision, pages 2223-2232, 2017.

Luisa Zintgraf, Kyriacos Shiarli, Vitaly Kurin, Katja Hofmann, and Shimon Whiteson. Fast context adaptation via meta-learning. In International Conference on Machine Learning, pages 7693-7702. PMLR, 2019.

Figure 8: We selected ε = 10⁻² for our MNIST coupling experiments as it results in transport maps that are neither too blurry nor too sharp.

Meta OT Python library code:
  conjugate.py: exact conjugate solver for the continuous setting
  Hydra configuration for the experiments (containing the hyper-parameters)
  train_discrete.py: train Meta OT models for discrete OT
  train_color_single.py: train a single ICNN with W2GN between 2 images (for debugging)
  train_color_meta.py: train a Meta ICNN with W2GN
  plot_mnist.py: visualize the MNIST couplings
  plot_world_pair.py: visualize the spherical couplings
  eval_color.py: evaluate the Meta ICNN in the continuous setting
  eval_discrete.py: evaluate the models for the discrete tasks

D OUT-OF-DISTRIBUTION GENERALIZATION

Figure 10: Cross-domain experiments.

Figure 12: Meta ICNN + W2GN fine-tuning. The sources are given in the beginning of app. E.

Figure 13: W2GN (final). The sources are given in the beginning of app. E.


Table 2: Color transfer runtimes and values.


Connecting to the data is one difficulty in running the experiments. The easiest experiment to re-run is the MNIST one, which automatically downloads the dataset:

    ./train_discrete.py            # Train the model, outputting to <exp_dir>
    ./eval_discrete.py <exp_dir>   # Evaluate the learned models
    ./plot_mnist.py <exp_dir>      # Produce further visualizations

C.1 HYPER-PARAMETERS

We briefly summarize the hyper-parameters we used for training, which we did not extensively tune. In the discrete setting, we use the same hyper-parameters for the MNIST and spherical settings.

In the main paper, table 1 reports the runtime of Sinkhorn to reach a convergence threshold of the marginal error being below a tolerance of 10⁻². Tables 5 and 6 report the results from sweeping over other thresholds and show that Meta OT's initialization consistently helps.

Table 5: Sinkhorn runtime to reach a thresholded marginal error on MNIST.

Table 6: Sinkhorn runtime to reach a thresholded marginal error on the spherical transport problem.

    Initialization  Threshold=10⁻²  Threshold=10⁻³  Threshold=10⁻⁴  Threshold=10⁻⁵
    Zeros           8.8

