ON AMORTIZING CONVEX CONJUGATES FOR OPTIMAL TRANSPORT

Abstract

This paper focuses on computing the convex conjugate operation that arises when solving Euclidean Wasserstein-2 optimal transport problems. This conjugation, which is also referred to as the Legendre-Fenchel conjugate or c-transform, is considered difficult to compute, and in practice Wasserstein-2 methods are limited by not being able to exactly conjugate the dual potentials in continuous space. To overcome this, the computation of the conjugate can be approximated with amortized optimization, which learns a model to predict the conjugate. I show that combining amortized approximations to the conjugate with a solver for fine-tuning significantly improves the quality of transport maps learned for the Wasserstein-2 benchmark by Korotin et al. (2021a) and is able to model many 2-dimensional couplings and flows considered in the literature. All of the baselines, methods, and solvers in this paper are available at http://github.com/facebookresearch/w2ot.

1. INTRODUCTION

Optimal transportation (Villani, 2009; Ambrosio, 2003; Santambrogio, 2015; Peyré et al., 2019) is a thriving area of research that provides a way of connecting and transporting between probability measures. While optimal transport between discrete measures is well-understood, e.g. with Sinkhorn distances (Cuturi, 2013), optimal transport between continuous measures is an open research topic actively being investigated (Genevay et al., 2016; Seguy et al., 2017; Taghvaei and Jalali, 2019; Korotin et al., 2019; Makkuva et al., 2020; Fan et al., 2021; Asadulaev et al., 2022). Continuous OT has applications in generative modeling (Arjovsky et al., 2017; Petzka et al., 2017; Wu et al., 2018; Liu et al., 2019; Cao et al., 2019; Leygonie et al., 2019), domain adaptation (Luo et al., 2018; Shen et al., 2018; Xie et al., 2019), barycenter computation (Li et al., 2020; Fan et al., 2020; Korotin et al., 2021b), and biology (Bunne et al., 2021; 2022; Lübeck et al., 2022). This paper focuses on estimating the Wasserstein-2 transport map between measures α and β in Euclidean space, i.e. supp(α) = supp(β) = R^n with the Euclidean distance as the transport cost. The Wasserstein-2 transport map, T⋆ : R^n → R^n, is the solution to Monge's primal formulation:

$$T^\star \in \operatorname*{arg\,inf}_{T \in \mathcal{T}(\alpha,\beta)} \; \mathbb{E}_{x\sim\alpha}\,\|x - T(x)\|_2^2, \tag{1}$$

where T(α, β) := {T : T_#α = β} is the set of admissible couplings and the push-forward operator # is defined by T_#α(B) := α(T^{-1}(B)) for a measure α, measurable map T, and all measurable sets B. T⋆ exists and is unique under general settings, e.g. as in Santambrogio (2015, Theorem 1.17), and eq. (1) is often difficult to solve because of the coupling constraints in T(α, β). Almost every computational method instead solves the Kantorovich dual, e.g. as formulated in Villani (2009, §5) and Peyré et al. (2019, §2.5). This paper focuses on the dual associated with the negative inner product cost (Villani, 2009, eq.
5.12), which introduces a dual potential function f : R^n → R and solves:

$$f^\star \in \operatorname*{arg\,sup}_{f \in L^1(\alpha)} \; -\mathbb{E}_{x\sim\alpha}[f(x)] - \mathbb{E}_{y\sim\beta}[\bar f(y)], \tag{2}$$

where L^1(α) is the space of measurable functions that are Lebesgue-integrable over α and f̄ is the convex conjugate, or Legendre-Fenchel transform, of a function f, defined by

$$\bar f(y) := -\inf_{x \in \mathcal{X}} J_f(x; y) \quad \text{with objective} \quad J_f(x; y) := f(x) - \langle x, y\rangle. \tag{3}$$

x⋆(y) denotes an optimal solution to eq. (3). Even though eq. (2) searches over functions in L^1(α), the optimal dual potential f⋆ is convex (Villani, 2009, theorem 5.10). When one of the measures has a density, Brenier (1991, theorem 3.1) and McCann (1995) relate f⋆ to an optimal transport map T⋆ for the primal problem in eq. (1) with T⋆(x) = ∇_x f⋆(x), and the inverse of the transport map is given by T⋆^{-1}(y) = ∇_y f̄⋆(y). A stream of foundational papers have proposed methods to approximate the dual potential f with a neural network and learn it by optimizing eq. (2): Taghvaei and Jalali (2019); Korotin et al. (2019); Makkuva et al. (2020) parameterize f as an input-convex neural network (Amos et al., 2017), which can universally represent any convex function with enough capacity (Huang et al., 2020). Other works explore parameterizing f as a non-convex neural network (Nhan Dam et al., 2019; Korotin et al., 2021a; Rout et al., 2021). Efficiently solving the conjugation operation in eq. (3) is the key computational challenge in solving the Kantorovich dual in eq. (2) and is an important design choice. Exactly computing the conjugate as done in Taghvaei and Jalali (2019) is considered computationally challenging, and approximating it as in Korotin et al. (2019); Makkuva et al. (2020); Nhan Dam et al. (2019); Korotin et al. (2021a); Rout et al. (2021) may be unstable. Korotin et al. (2021a) reinforce this observation: The [exact conjugate] solver is slow since each optimization step solves a hard subproblem for computing [the conjugate].
[Solvers that approximate the conjugate] are also hard to optimize: they either diverge from the start or diverge after converging to a nearly-optimal saddle point.

In contrast to these statements on the difficulty of exactly estimating the conjugate operation, I will show in this paper that computing the (near-)exact conjugate is easy. My key insight is that the approximate, i.e. amortized, conjugation methods can be combined with a fine-tuning procedure that uses the approximate solution as a starting point. Sect. 3 discusses the amortization design choices, and sect. 3.2.2 presents a new amortization perspective on the cycle consistency term used in Wasserstein-2 generative networks (Korotin et al., 2019), which was previously not seen in this way. Sect. 5 shows that amortizing and fine-tuning the conjugate results in state-of-the-art performance on all of the tasks proposed in the Wasserstein-2 benchmark by Korotin et al. (2021a). Amortization with fine-tuning also nicely models synthetic settings (sect. 6), including learning a single-block potential flow without using the likelihood.
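To make the conjugation in eq. (3) concrete, here is a minimal numerical sketch in plain Python (illustrative only, using an assumed toy potential rather than a neural network): for f(x) = 0.5·‖x‖², the conjugate is f̄(y) = 0.5·‖y‖² with minimizer x⋆(y) = y, so gradient descent on J_f(x; y) should recover both.

```python
# Minimal numerical conjugation sketch (plain Python, illustrative only):
# for the assumed toy potential f(x) = 0.5*||x||^2, the conjugate is
# f_bar(y) = 0.5*||y||^2 with minimizer x*(y) = y.

def grad_J(x, y):
    # For f(x) = 0.5*||x||^2: grad_x J(x; y) = grad f(x) - y = x - y.
    return [xi - yi for xi, yi in zip(x, y)]

def conjugate(y, steps=200, lr=0.1):
    x = [0.0] * len(y)  # cold-start iterate
    for _ in range(steps):
        x = [xi - lr * gi for xi, gi in zip(x, grad_J(x, y))]
    # f_bar(y) = -J_f(x*(y); y)
    J = 0.5 * sum(xi * xi for xi in x) - sum(xi * yi for xi, yi in zip(x, y))
    return x, -J

x_star, f_bar = conjugate([1.0, 2.0])
# x_star ≈ [1.0, 2.0] and f_bar ≈ 0.5 * (1 + 4) = 2.5
```

The rest of the paper is about making this inner loop fast and reliable when f is a learned network and the conjugation is solved for a batch of samples at every training step.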

2. LEARNING DUAL POTENTIALS: A CONJUGATION PERSPECTIVE

This section reviews the standard methods of learning parameterized dual potentials to solve eq. (2). The first step is to re-cast the Kantorovich dual problem in eq. (2) as being over a parametric family of potentials f_θ with parameter θ, e.g. an input-convex neural network (Amos et al., 2017) or a more general non-convex neural network. Taghvaei and Jalali (2019); Makkuva et al. (2020) have laid the foundations for optimizing the parametric potentials for the dual objective with max_θ V(θ), where

$$\mathcal{V}(\theta) := -\mathbb{E}_{x\sim\alpha}[f_\theta(x)] - \mathbb{E}_{y\sim\beta}[\bar f_\theta(y)] = -\mathbb{E}_{x\sim\alpha}[f_\theta(x)] + \mathbb{E}_{y\sim\beta}[J_{f_\theta}(x^\star(y); y)], \tag{4}$$

where J is the objective of the conjugate optimization problem in eq. (3), x⋆(y) is the solution to the convex conjugate, and eq. (4) assumes a finite solution to eq. (2) exists and replaces the sup with a max. Taghvaei and Jalali (2019) show that the model can be learned, i.e. the optimal parameters can be found, by taking gradient steps of the dual with respect to the parameters of the potential, i.e. using ∇_θ V. This derivative through the loss and conjugation operation can be obtained by applying Danskin's envelope theorem (Danskin, 1966; Bertsekas, 1971) and results in only needing derivatives of the potential:

$$\nabla_\theta \mathcal{V}(\theta) = \nabla_\theta\left(-\mathbb{E}_{x\sim\alpha}[f_\theta(x)] + \mathbb{E}_{y\sim\beta}[J_{f_\theta}(x^\star(y); y)]\right) = -\mathbb{E}_{x\sim\alpha}[\nabla_\theta f_\theta(x)] + \mathbb{E}_{y\sim\beta}[\nabla_\theta f_\theta(x^\star(y))], \tag{5}$$

where x⋆(y) is not differentiated through.

Assumption 1. A standard assumption is that the conjugate is smooth with a well-defined arg min. This has been shown to hold when f is strongly convex, e.g. in Kakade et al. (2009), or when f is essentially strictly convex (Rockafellar, 2015, theorem 26.3). In practice, assumption 1 is not guaranteed, e.g. non-convex potentials may have a parameterization that results in the conjugate taking infinite values in some regions. The dual objectives in eq. (2) and eq. (4) discourage the conjugate from diverging, as the supremum involves the negation of the conjugate.

Remark 1. In eq. (4), the dual potential f associated with the α measure's constraints is the central object that is parameterized and learned, while the dual potential associated with the β measure is given by the conjugate f̄ and does not need to be learned separately. Because of the symmetry of eq. (1), the order can also be reversed as in Korotin et al. (2021b) so that the dual associated with the β measure is the one directly parameterized, but we will not consider doing this. Potentials associated with both measures can also be parameterized, and we will next see that it is most natural to think about the model associated with the conjugate as an amortization model.
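The envelope-theorem step behind eq. (5) can be checked numerically on a toy problem (an illustrative assumption, not the paper's setup): with f_θ(x) = 0.5·θ·x², the inner minimizer is x⋆(y) = y/θ, and the total derivative of J at the minimizer matches the partial derivative with x⋆ held fixed.

```python
# Illustrative check of Danskin's envelope theorem on an assumed toy setup:
# f_theta(x) = 0.5 * theta * x^2, so J(x; y) = f_theta(x) - x*y,
# x*(y) = y / theta, and the inner value is J(x*(y); y) = -y^2 / (2*theta).
# Danskin: d/dtheta J(x*(y); y) equals the partial derivative in theta
# evaluated at the (fixed) minimizer, i.e. 0.5 * x*(y)^2.

def inner_value(theta, y):
    x_star = y / theta
    return 0.5 * theta * x_star**2 - x_star * y

def envelope_grad(theta, y):
    x_star = y / theta  # treated as a constant (not differentiated through)
    return 0.5 * x_star**2

theta, y, eps = 2.0, 3.0, 1e-5
fd = (inner_value(theta + eps, y) - inner_value(theta - eps, y)) / (2 * eps)
# fd ≈ envelope_grad(theta, y) = 0.5 * (3/2)^2 = 1.125
```

This is why eq. (5) only needs ∇_θ f_θ at the conjugate solutions and never d x⋆/dθ.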

Remark 2

The dual objective V can be upper-bounded by replacing x⋆ with any approximation, because any sub-optimal solution to the conjugation objective upper-bounds the optimal value, i.e. J_f(x⋆(y); y) ≤ J_f(x; y) for all x. In practice, maximizing a loose upper bound can cause significant divergence issues, as the potential can start over-optimizing the approximate objective. Computing the updates to the dual potential's parameters in eq. (5) is a well-defined machine learning setup given a parameterization of the potential f_θ, but is often computationally bottlenecked by the conjugate operation. Because of this bottleneck, many existing works resort to amortizing the conjugate by predicting the solution with a model x̂_ϕ(y). I overview the design choices behind amortizing the conjugate in sect. 3, and then go on in sect. 4 to show that it is reasonable to fine-tune the amortized predictions with an explicit solver CONJUGATE(f, y, x_init = x̂_ϕ(y)). Algorithm 1 summarizes how to learn a dual potential with an amortized and fine-tuned conjugate.
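The bound in Remark 2 can be illustrated numerically (with an assumed toy potential): every candidate x upper-bounds the conjugate objective's minimum.

```python
# Numeric illustration of Remark 2's bound under an assumed toy potential:
# any x upper-bounds the conjugate objective's minimum,
# J_f(x*(y); y) <= J_f(x; y). For f(x) = 0.5*x^2 the minimizer is
# x*(y) = y with minimum value -0.5*y^2.
J = lambda x, y: 0.5 * x * x - x * y

y = 3.0
exact = J(y, y)  # J at the minimizer x*(y) = y, i.e. -4.5
approx = [J(x, y) for x in (0.0, 1.0, 2.5, 4.0)]
# every approximate value is >= the exact minimum
```

Since f̄(y) = −min_x J_f(x; y), any such looseness inflates the conjugate term in eq. (4), which is the source of the divergence issues described above.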

3. AMORTIZING CONVEX CONJUGATES: MODELING AND LOSSES

This section scopes to predicting an approximate solution to the conjugate optimization problem in eq. (3). This is an instance of amortized optimization, which predicts the solutions to a family of optimization problems that are repeatedly solved (Shu, 2017; Chen et al., 2021; Amos, 2022). Amortization is sensible here because the conjugate is repeatedly solved for y ∼ β every time the dual V from eq. (4) is evaluated across a batch. Using the basic setup from Amos (2022), I call a prediction of the solution of eq. (3) the amortization model x̂_ϕ(y), which is parameterized by some ϕ. The goal is to make the amortization model's prediction match the true conjugate solution, i.e. x̂_ϕ(y) ≈ x⋆(y), for samples y ∼ β. In other words, amortization uses a model to simultaneously solve all of the conjugate optimization problems. There are two main design choices: sect. 3.1 discusses parameterizing the amortization model and sect. 3.2 overviews amortization losses.

3.1. PARAMETERIZING A CONJUGATE AMORTIZATION MODEL

The amortization model x̂_ϕ(y) maps a point y ∼ β to a solution of the conjugate in eq. (3), i.e. x̂_ϕ : R^n → R^n, and the goal is for x̂_ϕ(y) ≈ x⋆(y). In this paper, I take standard potential models further described in app. B and keep them fixed to ablate across the amortization loss and fine-tuning choices. The main categories are:

1. x̂_ϕ : R^n → R^n directly maps to the solution of eq. (3) with a multilayer perceptron (MLP) as in Nhan Dam et al. (2019), or a U-Net (Ronneberger et al., 2015) for image-based transport. These are also used in parts of Korotin et al. (2021a).
2. x̂_ϕ = ∇_y g_ϕ is the gradient of a function g_ϕ : R^n → R. Korotin et al. (2019); Makkuva et al. (2020) parameterize g_ϕ as an input-convex neural network, and some methods of Korotin et al. (2021a) parameterize g_ϕ as a ResNet (He et al., 2016). This is well-motivated because the arg min of a convex conjugate is its derivative, i.e. x⋆(y) = ∇_y f̄(y).

Algorithm 1: Learning Wasserstein-2 dual potentials with amortized and fine-tuned conjugation
  Inputs: Measures α and β to couple, initial dual potential f_θ, and initial amortization model x̂_ϕ
  while unconverged do
    Sample batches {x_j} ∼ α and {y_j} ∼ β indexed by j ∈ [N]
    Obtain the amortized prediction of the conjugate x̂_ϕ(y_j)
    Fine-tune the prediction by numerically solving x⋆(y_j) = CONJUGATE(f, y_j, x_init = x̂_ϕ(y_j))
    Update the potential with a gradient estimate of the dual in eq. (5), i.e. ∇_θ V
    Update the amortization model with a gradient estimate of a loss from sect. 3, i.e. ∇_ϕ L
  end while
  return optimal dual potential f_θ and conjugate amortization model x̂_ϕ

3.2. CONJUGATE AMORTIZATION LOSS CHOICES

We now turn to the design choice of what loss to optimize so that the conjugate amortization model x̂_ϕ best predicts the solution to the conjugate. In all cases, the loss is differentiable and ϕ is optimized with a gradient-based optimizer. I present an amortization perspective of methods not previously presented as amortization methods, which is useful for thinking about improving the amortized predictions with the fine-tuning and exact solvers in sect. 4. Figure 1 illustrates the main loss choices. [Figure 1: the conjugation objective J_f(x; y), the cycle loss ∝ ‖∇J_f(x)‖²₂, the regression loss ‖x − x⋆(y)‖²₂, and the solution x⋆(y).]
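As a concrete, runnable picture of Algorithm 1, here is a one-dimensional toy (all parameterizations are illustrative assumptions, not the paper's models): the potential is f_θ(x) = 0.5·θ·x², the amortization model is x̂_ϕ(y) = ϕ·y, and "fine-tuning" uses the closed-form conjugate solution x⋆(y) = y/θ.

```python
# Runnable 1-D toy of Algorithm 1 (illustrative assumptions throughout):
# for alpha = N(0,1) and beta = N(0,2), the optimal potential parameter is
# theta ≈ 2 (so T(x) = grad f(x) = 2x), and phi tracks 1/theta.
import random

random.seed(0)
theta, phi, lr = 1.0, 0.0, 0.05
alpha = [random.gauss(0, 1) for _ in range(256)]  # samples x ~ alpha
beta = [random.gauss(0, 2) for _ in range(256)]   # samples y ~ beta

for _ in range(400):
    xs = random.sample(alpha, 32)
    ys = random.sample(beta, 32)
    x_star = [y / theta for y in ys]  # fine-tuned conjugate solutions
    # Dual gradient (eq. 5): only partial derivatives of f_theta appear,
    # with d f_theta / d theta = 0.5*x^2.
    g_theta = (-sum(0.5 * x * x for x in xs) / len(xs)
               + sum(0.5 * x * x for x in x_star) / len(x_star))
    theta += lr * g_theta  # gradient ascent on the dual
    # Regression amortization: d/dphi of mean (phi*y - x*(y))^2.
    g_phi = sum(2 * (phi * y - x) * y
                for y, x in zip(ys, x_star)) / len(ys)
    phi -= lr * g_phi      # gradient descent on the amortization loss
# theta ≈ 2 and phi ≈ 1/theta ≈ 0.5, up to sampling noise
```

The same two interleaved updates appear in Algorithm 1, with the closed-form conjugate replaced by the CONJUGATE solver warm-started at the amortized prediction.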

3.2.1. OBJECTIVE-BASED AMORTIZATION

Nhan Dam et al. (2019) propose to make the amortized prediction optimal for the conjugation objective J_f from eq. (3) across samples from β, i.e.:

$$\min_\phi \mathcal{L}_{\text{obj}}(\phi) \quad \text{where} \quad \mathcal{L}_{\text{obj}}(\phi) := \mathbb{E}_{y\sim\beta}\, J_f(\hat x_\phi(y); y). \tag{6}$$

We refer to L_obj as objective-based amortization and solve eq. (6) by taking gradient steps ∇_ϕ L_obj using a Monte-Carlo estimate of the expectation.

Remark 3. The maximin method proposed in Makkuva et al. (2020, theorem 3.3) is equivalent to maximizing an upper bound to the dual loss V with respect to the parameters θ of a potential f_θ and minimizing the objective-based amortization loss L_obj with respect to the parameters ϕ of an amortization model x̂_ϕ := ∇g_ϕ. Their formulation replaces the exact conjugate solution x⋆ in eq. (4) with an approximation x̂_ϕ, i.e.:

$$\max_\theta \min_\phi \mathcal{V}_{\text{MM}}(\theta, \phi) \quad \text{where} \quad \mathcal{V}_{\text{MM}}(\theta, \phi) := -\mathbb{E}_{x\sim\alpha}[f_\theta(x)] + \mathbb{E}_{y\sim\beta}[J_{f_\theta}(\hat x_\phi(y); y)], \tag{7}$$

and for the inner problem, ∇_ϕ V_MM(θ, ϕ) = ∇_ϕ L_obj(ϕ) = ∇_ϕ E_{y∼β} J_{f_θ}(x̂_ϕ(y); y).

Remark 4. Suboptimal predictions of the conjugate often lead to a divergent upper bound on V(θ). Makkuva et al. (2020, algorithm 1) propose to fix this by running more updates on the amortization model. In sect. 4, I propose fine-tuning as an alternative to obtain near-exact conjugates.

3.2.2. FIRST-ORDER OPTIMALITY AMORTIZATION: CYCLE CONSISTENCY AND W2GN

An alternative to optimizing the dual objective directly as in eq. (6) is to optimize the first-order optimality condition. Eq. (3) is an unconstrained minimization problem, so the first-order optimality condition is that the derivative of the objective is zero, i.e. ∇_x J_f(x; y) = ∇_x f(x) − y = 0. The conjugate amortization model can be optimized for the residual norm of this condition with

$$\min_\phi \mathcal{L}_{\text{cycle}}(\phi) \quad \text{where} \quad \mathcal{L}_{\text{cycle}}(\phi) := \mathbb{E}_{y\sim\beta}\, \|\nabla_x J_f(\hat x_\phi(y); y)\|_2^2 = \mathbb{E}_{y\sim\beta}\, \|\nabla_x f(\hat x_\phi(y)) - y\|_2^2. \tag{8}$$

Remark 5. W2GN (Korotin et al., 2019) is equivalent to maximizing an upper bound to the dual loss V with respect to the parameters θ of a potential f_θ and minimizing the first-order amortization loss L_cycle with respect to the parameters ϕ of a conjugate amortization model x̂_ϕ := ∇g_ϕ. Korotin et al. (2019) originally motivated the cycle consistency term from its use in cross-domain generative modeling (Zhu et al., 2017), and eq. (8) shows an alternative way of deriving the cycle consistency term by amortizing the first-order optimality conditions of the conjugate.

Remark 6. The formulation in Korotin et al. (2019) does not disconnect f_θ when optimizing the cycle loss in eq. (8). From an amortization perspective, this performs amortization by updating f_θ to have a solution closer to x̂_ϕ, rather than the usual amortization setting of updating x̂_ϕ to make a prediction closer to the solution of f_θ. In my experiments, updating f_θ with the amortization term seems to help when not fine-tuning the conjugate to be exact, but not when using exact conjugates.

Remark 7. Korotin et al. (2019) and followup papers such as Korotin et al. (2021b) state that they do not perform maximin optimization as in eq. (7) from Makkuva et al. (2020) because they replace the inner optimization of the conjugate with an approximation. I disagree that the main distinction between these methods should be based on their formulation as a maximin optimization problem.
I instead propose that the main difference between their losses is how they amortize the convex conjugate: Makkuva et al. (2020) use the objective-based loss in eq. (6) while Korotin et al. (2019) use the first-order optimality condition in eq. (8). Sect. 5 shows that adding fine-tuning and exact conjugates to both of these methods makes their performance match in most cases.

Remark 8. Optimizing for the first-order optimality conditions may not be ideal for non-convex conjugate objectives, as inflection points with a near-zero derivative may not be a global minimum of eq. (3). The left and right regions of fig. 1 illustrate this.
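A tiny sketch of minimizing the first-order (cycle) loss in eq. (8), under an assumed toy setup: f(x) = 0.5·x² (so ∇f(x) = x and the exact conjugate solution is x⋆(y) = y) with amortization model x̂_ϕ(y) = ϕ·y.

```python
# Minimizing the cycle loss E ||grad f(x_hat_phi(y)) - y||^2 on an assumed
# toy problem drives phi -> 1, recovering x_hat_phi(y) = x*(y) = y.
ys = [0.5, -1.0, 2.0, 3.5]
phi, lr = 0.0, 0.05
for _ in range(300):
    # d/dphi of mean (phi*y - y)^2 = mean 2*(phi*y - y)*y
    g = sum(2 * (phi * y - y) * y for y in ys) / len(ys)
    phi -= lr * g
# phi ≈ 1
```

For this convex f the stationarity residual has a unique zero, so the cycle loss is well-behaved; as Remark 8 notes, with a non-convex f the same loss can instead latch onto inflection points.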

3.2.3. REGRESSION-BASED AMORTIZATION

The previous objective-based and first-order amortization methods locally refine the model's prediction using local derivative information. The conjugate amortization model can also be trained by regressing onto ground-truth solutions when they are available, i.e.

$$\min_\phi \mathcal{L}_{\text{reg}}(\phi) \quad \text{where} \quad \mathcal{L}_{\text{reg}}(\phi) := \mathbb{E}_{y\sim\beta}\, \|\hat x_\phi(y) - x^\star(y)\|_2^2. \tag{9}$$

This regression loss is most useful when approximations to the conjugate are computationally easy to obtain, e.g. with a method described in sect. 4. L_reg gives the amortization model information about where the globally optimal solution is, rather than requiring it to only locally search over the conjugate's objective J.

4. NUMERICAL SOLVERS FOR EXACT CONJUGATES AND FINE TUNING

Algorithm 2: CONJUGATE(f, y, x_init)
  x ← x_init
  while unconverged do
    Update x with ∇_x J_f(x; y)
  end while
  return optimal x⋆(y) = x

In the Euclidean Wasserstein-2 setting, the conjugation operation in eq. (3) is a continuous and unconstrained optimization problem over a possibly non-convex potential f. It is usually implemented with a method using first-order information for the update in algorithm 2, such as:

1. Adam (Kingma and Ba, 2014) is an adaptive first-order optimizer for high-dimensional optimization problems and is used for the exact conjugations in Korotin et al. (2021a). Note: Adam here is used for algorithm 2 and is not performing parameter optimization.
2. L-BFGS (Liu and Nocedal, 1989) is a quasi-Newton method for optimizing unconstrained convex functions. App. A discusses more implementation details behind setting up L-BFGS to run efficiently on the batches of optimization problems considered here. Choosing the line search method is the most crucial part, as the conditional nature of some line searches may be prohibitive over batches. Table 3 shows that an Armijo search often works well to obtain approximate solutions.
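The value of warm-starting algorithm 2 from an amortized prediction can be seen on a small example (illustrative assumptions, plain Python): f(x) = 0.25·x⁴ + 0.5·x² is convex with ∇f(x) = x³ + x, so the conjugate's first-order condition is x³ + x = y.

```python
# Gradient-descent instance of algorithm 2 with warm-starting (a sketch,
# not the paper's Adam/L-BFGS solvers). Starting the solver from a rough
# amortized guess cuts down the number of iterations.

def conjugate_solve(y, x_init, lr=0.05, tol=1e-8, max_iter=10_000):
    x, it = x_init, 0
    while abs(x**3 + x - y) > tol and it < max_iter:  # |grad_x J(x; y)|
        x -= lr * (x**3 + x - y)  # gradient step on J(x; y)
        it += 1
    return x, it

y = 10.0  # the true solution of x^3 + x = 10 is x = 2
x_cold, it_cold = conjugate_solve(y, x_init=0.0)
x_warm, it_warm = conjugate_solve(y, x_init=1.95)  # e.g. an amortized guess
# both find x ≈ 2; the warm start needs fewer iterations
```

This is the combination the paper advocates: the amortization model supplies x_init and the solver does the remaining local refinement to near-exactness.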

5. EXPERIMENTAL RESULTS ON THE WASSERSTEIN-2 BENCHMARK

I have focused most of the experimental investigations on the Wasserstein-2 benchmark (Korotin et al., 2021a) because it provides a concrete evaluation setting with established baselines for learning potentials for Euclidean Wasserstein-2 optimal transport. The tasks in the benchmark have known (ground-truth) optimal transport maps and include transporting between: 1) high-dimensional (HD) mixtures of Gaussians, and 2) samples from generative models trained on CelebA (Liu et al., 2015). The main evaluation metric is the unexplained variance percentage (L²-UVP) metric from Korotin et al. (2019), which compares a candidate map T̂ to the ground-truth map T⋆ with:

$$\mathcal{L}^2\text{-UVP}(\hat T; \alpha, \beta) := 100 \cdot \|\hat T - T^\star\|^2_{L^2(\alpha)} / \mathrm{Var}(\beta)\,\%. \tag{10}$$

In all of the experimental results, I report the final L²-UVP evaluated with 16384 samples at the end of training, and average the results over 10 trials. App. C further details the experimental setup. My original motivation for running these experiments was to understand how ablating the amortization losses and fine-tuning options impacts the final L²-UVP performance of the learned potential. The main experimental takeaway of this paper is that fine-tuning the amortized conjugate with a solver significantly improves the learned transport maps. Tables 1 and 2 report these results.

Remark 9. With fine-tuning, the choice of regression or objective-based amortization doesn't significantly impact the L²-UVP of the final potential. This is because fine-tuning is usually able to find the optimal conjugates from the predicted starting points.

Remark 10. My re-implementation of W2GN (Korotin et al., 2019), which uses cycle consistency amortization with no fine-tuning, often outperforms the results reported in Korotin et al. (2021a). This is likely due to differences in the base potential and conjugate amortization models.

Remark 11. Cycle consistency sometimes provides difficult starting points for the fine-tuning methods, especially for L-BFGS.
When learning non-convex potentials, this poor performance is likely related to the fact that Newton methods are known to have difficulty with saddle points (Dauphin et al., 2014). Combining cycle consistency, which tries to find a point where the derivative is zero, with L-BFGS, which also tries to find a point where the derivative is zero, results in finding suboptimal inflection points of the potential rather than the true minimizer.

Remark 12. The methods using objective-based amortization without fine-tuning, as done in Taghvaei and Jalali (2019), perform worse than reported in Korotin et al. (2021a). This is because I do not run multiple inner updates on the conjugate amortization model. I instead advocate for fine-tuning the conjugate predictions with a known solver, eliminating the need for a hyper-parameter for the number of inner iterations that needs to be delicately tuned to make sure the amortized prediction alone does not diverge too much from the true conjugate.

Related works amortize and learn the solutions to OT and matching problems by predicting the optimal duals given the input measures. These approaches are complementary to this paper, as they amortize the solution to the dual in eq. (2) while this paper amortizes the conjugate subproblem in eq. (3) that is repeatedly computed when solving a single OT problem.
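The L²-UVP metric in eq. (10) can be estimated from samples; a hypothetical sketch (the maps and measures here are illustrative assumptions, not benchmark tasks):

```python
# Sample-based estimate of L2-UVP: candidate T_hat vs. ground truth T over
# x ~ alpha, normalized by the variance of beta = T_# alpha.
import random

random.seed(1)
xs = [random.gauss(0.0, 1.0) for _ in range(10_000)]  # x ~ alpha = N(0,1)
T = lambda x: 2.0 * x            # assumed ground-truth map, beta = N(0,4)
T_hat = lambda x: 2.0 * x + 0.1  # candidate map with a small offset error

ys = [T(x) for x in xs]
mean_y = sum(ys) / len(ys)
var_beta = sum((y - mean_y) ** 2 for y in ys) / len(ys)
sq_err = sum((T_hat(x) - T(x)) ** 2 for x in xs) / len(xs)
uvp = 100.0 * sq_err / var_beta
# the constant 0.1 offset against Var(beta) ≈ 4 gives uvp ≈ 0.25 (%)
```

Normalizing by Var(β) makes the metric comparable across tasks whose target measures have very different scales.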

8. CONCLUSIONS, FUTURE DIRECTIONS, AND LIMITATIONS

This paper explores the use of amortization and fine-tuning for computing convex conjugates. The methodological insights and amortization perspective may directly transfer to many other applications and extensions of Euclidean Wasserstein-2 optimal transport, including computing barycenters (Korotin et al., 2021b), Wasserstein gradient flows (Alvarez-Melis et al., 2021; Mokrov et al., 2021), or cellular trajectories (Bunne et al., 2021). Many of the key amortization and fine-tuning concepts from here will transfer beyond the Euclidean Wasserstein-2 setting, e.g. to the more general c-transform arising in non-Euclidean optimal transport (Sei, 2013; Cohen et al., 2021; Rezende and Racanière, 2021) or to the Moreau envelope computation, which can be decomposed into a term that involves the convex conjugate as described in Rockafellar and Wets (2009, ex. 11.26) and Lucet (2006, sect. 2).

Limitations. The most significant limitation in the field of estimating Euclidean Wasserstein-2 optimal transport maps is the lack of convergence guarantees. The parameter optimization problem in eq. (4) is always non-convex, even when using input-convex neural networks. I have shown that improved conjugate estimates significantly improve stability when the base potential model is properly set up, but all methods are sensitive to the potential model's hyper-parameters. I found that small changes to the activation type or initial learning rate can cause no method to converge.

Algorithm 3: The Broyden-Fletcher-Goldfarb-Shanno (BFGS) method to solve eq. (11), as presented in Nocedal and Wright (1999, alg. 6.1)
  Inputs: Function J to optimize, initial iterate x_0 and Hessian approximation B_0
  k ← 0
  while unconverged do
    Compute the search direction p_k = −B_k^{-1} ∇_x J(x_k)
    Set x_{k+1} = x_k + α_k p_k, where α_k is computed from a line search from app. A.2
    Compute B_{k+1} with the update in eq. (12)
    k ← k + 1
  end while
  return optimal solution x_k ≈ x⋆

A L-BFGS OVERVIEW AND LINE SEARCH DETAILS

The conjugate optimization problem in eq. (3) is an unconstrained convex optimization problem for convex potentials, which is a setting the BFGS (Broyden, 1970; Fletcher, 1970; Goldfarb, 1970; Shanno, 1970) and L-BFGS (Liu and Nocedal, 1989) methods thrive in. The default strong Wolfe line search methods in the Jax and JaxOpt L-BFGS implementations may take a long time to solve a batch of optimization problems. Without an efficient line search method, some of the Wasserstein-2 benchmark experiments in app. C that ran in a few hours would have otherwise taken a month to run. This section provides a brief overview of BFGS methods and shows that an Armijo line search can be the most efficient at computing the conjugate.

A.1 BACKGROUND ON BFGS METHODS

Nocedal and Wright (1999, alg. 6.1) is a standard reference for BFGS methods and extensions, and the key steps are summarized in algorithm 3 for solving an optimization problem of the form

$$x^\star \in \operatorname*{arg\,min}_x J(x), \tag{11}$$

where J : R^n → R is possibly non-convex and twice continuously differentiable. The method iteratively finds a solution x⋆ by 1) maintaining an approximate Hessian around the current iterate, i.e. B_k ≈ ∇²J(x_k), 2) computing an approximate Newton step p_k = −B_k^{-1} ∇_x J(x_k) using the approximate Hessian, 3) updating the iterate with x_{k+1} = x_k + α_k p_k, where α_k is found with a line search, and 4) updating the Hessian approximation with the Sherman-Morrison-Woodbury formula (Woodbury, 1950; Sherman and Morrison, 1950):

$$B_{k+1} = B_k - \frac{B_k s_k s_k^\top B_k}{s_k^\top B_k s_k} + \frac{y_k y_k^\top}{y_k^\top s_k}, \tag{12}$$

where y_k = ∇_x J(x_{k+1}) − ∇_x J(x_k) and s_k = x_{k+1} − x_k. Instead of estimating B_k and inverting it in every iteration, most implementations maintain a direct approximation to the inverse Hessian H_k := B_k^{-1}.
The limited-memory version of BFGS (L-BFGS) in Liu and Nocedal (1989) replaces the inverse Hessian approximation, as a matrix H_k, with the sequence of vectors [y_k, s_k] defining the updates to H_k, and never requires instantiating the full n × n approximation.

Algorithm 4: Backtracking Armijo line search to solve eq. (17)
  Inputs: Iterate x_k, search direction p_k, decay τ, control parameter c_1, initial α_0
  α ← α_0
  while J(x_k + α p_k) > J(x_k) + c_1 α p_k^⊤ ∇_x J(x_k) do
    α ← τ α
  end while
  return α satisfying the Armijo condition in eq. (17)

Algorithm 5: Parallel Armijo line search to solve eq. (17)
  Inputs: Iterate x_k, search direction p_k, decay τ, control parameter c_1, initial α_0, #evaluations M
  Compute candidate step lengths α_m = τ^{−m} for m ∈ [M]
  Evaluate the line search condition g(α_m) from eq. (16) in parallel
  if all g(α_m) < 0 then
    Error: No acceptable step found
  else
    return the largest α satisfying eq. (17), i.e. max α_m subject to g(α_m) ≥ 0
  end if

A.2 LINE SEARCHES

The line search to find the step size α_k for the iterate update x_{k+1} = x_k + α_k p_k is often done with:

1. a Wolfe line search (Wolfe, 1969) to satisfy the conditions

$$J(x_k + \alpha_k p_k) \le J(x_k) + c_1 \alpha_k p_k^\top \nabla_x J(x_k), \qquad -p_k^\top \nabla_x J(x_k + \alpha_k p_k) \le -c_2\, p_k^\top \nabla_x J(x_k), \tag{13}$$

where 0 < c_1 < c_2 < 1,

2. a strong Wolfe line search to satisfy the conditions

$$J(x_k + \alpha_k p_k) \le J(x_k) + c_1 \alpha_k p_k^\top \nabla_x J(x_k), \qquad |p_k^\top \nabla_x J(x_k + \alpha_k p_k)| \le c_2\, |p_k^\top \nabla_x J(x_k)|, \tag{14}$$

often found via the zoom procedure from Nocedal and Wright (1999, algorithm 3.5),

3. an Armijo line search (Armijo, 1966) to satisfy only the first condition:

$$J(x_k + \alpha_k p_k) \le J(x_k) + c_1 \alpha_k p_k^\top \nabla_x J(x_k). \tag{15}$$

For notational simplicity, we can also write the Armijo condition as

$$g(\alpha) := J(x_k) + c_1 \alpha\, p_k^\top \nabla_x J(x_k) - J(x_k + \alpha p_k) \ge 0. \tag{16}$$

Remark 14. The strong Wolfe line search is the most commonly used line search for L-BFGS, as it guarantees that the resulting update to the Hessian in eq. (12) stays positive definite, but the Armijo line search may be more efficient as it does not involve re-evaluating the derivative of the objective. Unfortunately, an iterate obtained by an Armijo line search may not satisfy the curvature condition y_k^⊤ s_k > 0 that ensures the Hessian update stays positive definite, while a step satisfying the strong Wolfe conditions provably does (Nocedal and Wright, 1999, page 143). Nonetheless, Armijo searches are still combined with BFGS and can be guarded by only updating the Hessian approximation if y_k^⊤ s_k > 0, or a modification thereof, as described in Li and Fukushima (2001); Wan et al. (2012); Fridovich-Keil and Recht (2020) and Berahas et al. (2016, sect. 3.2).

Table 3: Runtime and number of L-BFGS iterations for line search methods to converge to a solution x⋆ of eq. (11) for conjugating the trained ICNN potential on the 256-dimensional HD benchmark from sect. 5 with a batch of 1024 samples and a tolerance of ‖∇J(x)‖_∞ ≤ 0.1, starting from the amortized prediction. The Wolfe and Armijo line search methods use standard values of c_1 = 10^{-4} and c_2 = 0.9, all backtracking options use a decay factor of τ = 2/3 with M = 15 evaluations, and the runtimes are averaged over 10 trials on an NVIDIA Tesla V100 GPU. In this setting, Armijo line searches without many conditionals or gradient evaluations consistently take the shortest time.

The Armijo line search can be written as the optimization problem

$$\alpha_k(x_k, p_k) = \max \alpha \quad \text{subject to} \quad J(x_k + \alpha p_k) \le J(x_k) + c_1 \alpha\, p_k^\top \nabla_x J(x_k). \tag{17}$$

Eq. (17) is typically solved as shown in algorithm 4 and fig. 8 by setting a decay factor τ and iteratively decreasing a candidate step length α until the condition is satisfied. When the objective J can be efficiently evaluated in parallel on a GPU, and when solving many batches of optimization problems concurrently, e.g. with vmap, the backtracking Armijo line search described in algorithm 4, and the Wolfe line search described in Nocedal and Wright (1999, alg. 7.5), are computationally slowed down by serial and conditional operations. These issues arise from: 1) the sequential nature of the line search, and 2) the fact that the line search may run for a different number of iterations for every optimization problem in the batch. Wolfe line searches such as Nocedal and Wright (1999, alg. 7.5) have other conditionals scoping the search interval that cause the line search to perform potentially different operations for every optimization problem in the batch. I propose a parallel Armijo line search in algorithm 5, which is also visualized in fig. 8, to remove serial and conditional operations and improve the computation of the line search on the GPU for solving batches of optimization problems.
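A minimal scalar version of the backtracking search in algorithm 4 (plain Python; the constants c_1 = 10⁻⁴ and τ = 0.5 here are assumed standard defaults, not values from the paper's experiments):

```python
# Backtracking Armijo line search, scalar sketch of algorithm 4.

def armijo(J, grad_J, x, p, c1=1e-4, tau=0.5, alpha=1.0, max_backtracks=50):
    Jx = J(x)
    slope = c1 * p * grad_J(x)  # scalar case of c1 * p' grad J(x)
    for _ in range(max_backtracks):
        if J(x + alpha * p) <= Jx + alpha * slope:
            return alpha  # eq. (15) is satisfied
        alpha *= tau
    return alpha

# Example on J(x) = x^2 at x = 3 with descent direction p = -grad J(x):
J = lambda x: x * x
grad_J = lambda x: 2 * x
alpha = armijo(J, grad_J, 3.0, -6.0)
# alpha = 1 overshoots past the minimum (J(-3) = J(3)), so one backtrack
# to alpha = 0.5 lands exactly on the minimizer x = 0
```

The data-dependent number of `while` iterations is exactly the serial, conditional behavior that becomes a bottleneck when vmapping over a batch of problems.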
The key idea is to instantiate many possible step sizes, evaluate them all at once, and then select the largest α_m satisfying the Armijo condition g(α_m) ≥ 0.

Remark 15. The parallel line search may unnecessarily evaluate more candidate step sizes than the sequential line search, but on GPU architectures this may not be very detrimental to performance because additional parallel function evaluations are computationally cheap. Furthermore, when solving a batch of N optimization problems with M line search evaluations, i.e. when using vmap on the line search or optimizer, the parallel line search in algorithm 5 can efficiently evaluate N·M candidate step lengths in tandem on a GPU and then select the best for each element in the batch.

Remark 16. A potential concern with this parallel line search is that it may not find a step size satisfying the Armijo condition if M is not set high enough. While this may be a significant issue when high-precision solves are needed, I have found in practice for the Euclidean Wasserstein-2 conjugates that taking M = 10 line search evaluations frequently finds a solution.
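A scalar sketch of the parallel search in algorithm 5 (the candidate schedule α_m = τ^m for m = 0..M−1 is one plausible reading of the algorithm; the list comprehension stands in for a batched GPU evaluation):

```python
# Parallel Armijo: evaluate all candidate step lengths at once and keep
# the largest acceptable one.

def parallel_armijo(J, grad_J, x, p, c1=1e-4, tau=2 / 3, M=15):
    slope = c1 * p * grad_J(x)  # scalar case of c1 * p' grad J(x)
    alphas = [tau**m for m in range(M)]
    # g(alpha) >= 0 is the Armijo condition from eq. (16)
    ok = [a for a in alphas if J(x) + a * slope - J(x + a * p) >= 0]
    if not ok:
        raise ValueError("no acceptable step among the M candidates")
    return max(ok)

J = lambda x: x * x
grad_J = lambda x: 2 * x
alpha = parallel_armijo(J, grad_J, 3.0, -6.0)
# alpha = 1 fails as before, so the search returns alpha = 2/3
```

Unlike the backtracking version, every problem in a batch performs the identical set of M evaluations, so the whole search maps cleanly onto vmap.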

A.2.2 COMPARING LINE SEARCH METHODS

Table 3 takes a trained ICNN potential and isolates the comparison of L-BFGS runtimes to changing only the line search method. This is the same optimization procedure and batch size used for all of the training runs on the Wasserstein-2 benchmark. Despite the concerns in remark 14 that the Armijo line search can result in slower convergence and indefinite Hessian approximations, the Armijo line searches are consistently able to solve the batch of optimization problems the fastest.

B MODEL DEFINITIONS AND PRETRAINING

All of the potential and conjugate amortization models in this paper can be implemented in ≈30-50 lines of readable JAX code with Flax (Heek et al., 2020). They are included in the w2ot/models directory of the code and are reproduced here to precisely define them.

B.1 PRETRAINING AND INITIALIZATION

Following Korotin et al. (2021a), every experimental setting has a pre-training phase so that the potentials and amortization maps approximate the identity mapping, i.e. ∇_x f_θ(x) ≈ x and x̂_ϕ(y) ≈ y.
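What identity pre-training amounts to can be seen in a minimal sketch, under illustrative assumptions not from w2ot: the potential's gradient is taken to be a linear map A x, and plain gradient descent on a squared-error regression loss drives A toward the identity, so that ∇_x f_θ(x) ≈ x.

```python
import numpy as np

rng = np.random.default_rng(0)

def pretrain_identity(dim=3, lr=0.05, steps=500, batch=256):
    """Regress a linear map A x onto x so the map starts near the identity,
    mimicking the identity pre-training of the potential's gradient."""
    A = rng.normal(size=(dim, dim))
    for _ in range(steps):
        x = rng.normal(size=(batch, dim))
        resid = x @ A.T - x               # (A x - x) for each sample in the batch
        grad_A = 2 * resid.T @ x / batch  # gradient of mean ||A x - x||^2 w.r.t. A
        A -= lr * grad_A
    return A
```

In the paper the same idea is applied to the nonlinear Flax models; the regression target for the amortization map x̂_ϕ is simply its own input y.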

B.2 InitNN: NON-CONVEX NEURAL NETWORK AMORTIZATION MODEL x̂_ϕ

Remark 17 The passthrough on line 18 is helpful for learning an identity initialization.

Remark 18 Applying an activation to the output on line 41 is helpful to lower-bound the otherwise unconstrained potential and adds stability to the training.

Remark 19 The final quadratic on line 46 makes it easy to initialize the potential to the identity.

Remark 20 This ICNN does not use the quadratic activations proposed in Korotin et al. (2019, Appendix B.1). While I did not heavily experiment with them, table 1 shows that this ICNN architecture without the quadratic activations performs better than the results originally reported in Korotin et al. (2021a), which use an ICNN architecture with the quadratic activations.

Figure 9 shows that with a non-convex potential, many of the initial amortized predictions are suboptimal and difficult for L-BFGS to improve upon. This indicates that the amortized predictions may be in parts of the space that are difficult to recover from, and suggests a future avenue of work better characterizing and recovering from this behavior. L-BFGS converges quickly to an optimal solution in fig. 10, while Adam often gets stuck at suboptimal solutions.

These results give an idea of how much additional time is spent fine-tuning. On the HD benchmark, fine-tuning takes between ≈10-50ms per batch. The overall wall-clock time may be ≈2-3 times longer than the training runs without fine-tuning, but the runs find significantly better solutions. On the CelebA64 benchmarks, the conjugation time impacts the overall runtime even less, especially in the "Mid" and "Late" settings, as the transport maps there are close to the identity mapping and are easy to conjugate.

Remark 23 Some settings immediately diverged to an irrecoverable state with an L_2-UVP of 10^9, including runs using the objective-based and cycle amortization losses without fine-tuning. I early-stopped those experiments and do not report their runtimes or conjugation times here, as the few minutes that the objective-based amortization experiments took to diverge are not comparable to the times of the experiments that converged.
Remark 24 I found leaky ReLU activations on the potential model to work better in these low-dimensional settings than ELU activations, which work better in the HD benchmark settings. I do not have a strong explanation for this, but found the LReLU capable of producing sharper transitions in the space, e.g. the sharp boundaries shown in fig. 4. One reason that the ELU potentials could perform better on the benchmark settings is that the ground-truth transport maps in the benchmark, described in Korotin et al. (2021a, Appendix B.1), use an ICNN with CELU activations (Barron, 2017), which may be easier to recover with potential models that use ELU activations.

I trained convex and non-convex potentials on every synthetic setting and show the results from the best-performing potential model, which are:

• Makkuva et al. (2020): an ICNN. This setting originally considered convex potentials, and the non-convex potentials I tried training on these settings diverged.
• Rout et al. (2021): a non-convex potential (an MLP). This setting also originally considered an MLP, and I could not find an ICNN that accurately transports between the highly curved and concentrated parts of the measures.
• Huang et al. (2020): a non-convex potential (an MLP). In contrast to the ICNNs originally used, I found that an MLP works better when learned with the OT dual. Almost every setting in Huang et al. (2020) requires composing multiple blocks of ICNNs, which means the flow will not necessarily be the optimal transport flow, while the non-convex MLP potential I am using here estimates the optimal transport map between the measures.

All of the synthetic settings use the L-BFGS conjugate solver set to obtain slightly higher-precision solves than in the Wasserstein-2 benchmark. The conjugate solver stops early if all dimensions of the iterates change by less than 0.001, and otherwise runs for a maximum of 100 iterations.
The line search parameters for the parallel Armijo search in algorithm 5 for L-BFGS are to decay the steps with a base of τ = 1.5 and to search M = 30 step sizes.

Brenier's theorem (Brenier, 1991) shows that the Wasserstein-2 optimal transport map associated with the negative inner product cost is the gradient of a convex function, i.e. T^*(x) = ∇_x f^*(x). Because of this, optimizing over convex potentials is theoretically appealing and also results in a convex and easy conjugate optimization problem in eq. (3) to compute the conjugate. The input-convex property is usually enforced by constraining all of the weights of the network to be positive in every layer except the first. Unfortunately, in practice, the positivity constraints of a convex potential may be prohibitive, not easy to optimize over, and result in sub-optimal transport maps. In other words, the parameter optimization problem over the input-convex model is still non-convex and may be exacerbated by the input-convex constraints. Due to these limitations, non-convex potentials are appealing as their parameter space is less constrained and may therefore be easier to search over. In practice, this has been shown to be true, e.g. the main results in table 1 show that a non-convex potential significantly outperforms the convex potential. However, non-convex potentials can result in non-convex conjugate optimization problems in eq. (3) that can cause significant numerical instabilities and an exploding upper bound on the dual objective. Figure 12 illustrates a small non-convex region arising in a learned non-convex potential. While the non-convex region mostly does not impact the transport map in this case, such regions can easily blow up and make the dual optimization problem challenging. In contrast, the ICNN-based convex potential provably retains convexity and keeps this region nicely flat, but the constraints on the parameter space may hinder the performance.
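The conjugate solve itself can be illustrated with a minimal gradient-descent version of the solver (a sketch under stated assumptions, not the L-BFGS solver used in the paper): minimize J(x; y) = f(x) - <x, y> over x, stopping early once every coordinate of the iterate moves by less than the tolerance, mirroring the early-stopping rule described above. For the convex quadratic f(x) = 0.5||x||^2, the conjugate optimum is x̂(y) = y, which the sketch recovers.

```python
import numpy as np

def conjugate_gd(grad_f, y, x0, lr=0.5, tol=1e-3, max_iter=100):
    """Solve x_hat(y) = argmin_x f(x) - <x, y> by gradient descent,
    stopping early when all coordinates move by less than tol."""
    x = x0.astype(float).copy()
    for _ in range(max_iter):
        step = lr * (grad_f(x) - y)   # gradient of J(x; y) = f(x) - <x, y>
        x -= step
        if np.all(np.abs(step) < tol):
            break                     # per-dimension early-stopping rule
    return x

# For f(x) = 0.5 ||x||^2 we have grad_f(x) = x, so x_hat(y) = y.
```

When f is non-convex, the same iteration can get stuck at suboptimal stationary points, which is the instability discussed in the surrounding paragraph.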



z = x + Wx(z)  # Encourage identity initialization.
return z
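The effect of this passthrough can be checked with a small sketch (illustrative NumPy, not the Flax module; the weight names are hypothetical): when the final weight matrix is initialized to zero, the output of z = x + Wx(z) is exactly x, so the amortization model starts at the identity map.

```python
import numpy as np

def initnn_forward(x, hidden_W, out_W):
    """Toy forward pass with a final passthrough: z = x + out_W @ h."""
    h = np.maximum(0.0, hidden_W @ x)  # hidden layer with a ReLU activation
    return x + out_W @ h               # passthrough encourages identity init

rng = np.random.default_rng(0)
x = rng.normal(size=4)
hidden_W = rng.normal(size=(8, 4))
out_W = np.zeros((4, 8))               # zero-initialized final layer
```

With `out_W = 0` the model is exactly the identity regardless of the hidden weights, which is why pre-training to the identity is easy with this parameterization.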



Figure 1: Conjugate amortization losses.

min_ϕ L_obj(ϕ), where L_obj(ϕ) := E_{y∼β} J(x̂_ϕ(y); y).

Figure 2: Conjugate solver convergence on the HD benchmarks with an ICNN potential.

Figure 3: Learned transport maps on synthetic settings from Rout et al. (2021).

Figure 4: Learned potentials on settings considered in Makkuva et al. (2020).

Figure 5: Mesh grid G warped by the conjugate potential flow ∇f̄ from the top setting of fig. 4.

Figure 8: Visualization of backtracking and parallel line searches to solve eq. (17).

B.3 ICNN: INPUT-CONVEX NEURAL NETWORK POTENTIAL f_θ

actnorm is the activation normalization layer from Kingma and Dhariwal (2018), which was also used in the ICNN potentials in Huang et al. (2020) and normalizes the activations at initialization to follow a normal distribution.

B.5 ConvPotential: NON-CONVEX CONVOLUTIONAL POTENTIAL f_θ

Remark 22 I was not able to easily add batch normalization (Ioffe and Szegedy, 2015) to this potential. In contrast to standard use cases of batch normalization that only call into a batch-normalized model once over samples from a single distribution, the dual objective in eq. (4) calls into the potential multiple times to estimate E_{x∼α} f_θ(x) and E_{y∼β} f_θ(x̂(y)), which also involves internally solving the conjugate optimization problem in eq. (3) to obtain x̂. This makes it unclear what training and evaluation statistics batch normalization should use when computing the dual objective. One choice could be to only use the statistics induced from the samples x ∼ α.


The hyper-parameter tables below detail the main choices for the Wasserstein-2 benchmark experiments. I tried to keep these consistent with the choices from Korotin et al. (2021a), e.g. using the same batch sizes, number of training iterations, and hidden layer sizes for the potential. All experiments use the same settings for the conjugate solvers: the conjugate solvers stop early if all dimensions of the iterates change by less than 0.1, and otherwise run for a maximum of 100 iterations. The line search parameters for the parallel Armijo search in algorithm 5 for L-BFGS are to decay the steps with a base of τ = 1.5 and to search M = 10 step sizes. With the Adam conjugate solver, I use the default β = [0.9, 0.999] with an initial learning rate of 0.1 and a cosine annealing schedule that decreases it to 10^-5.
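The cosine annealing schedule for the Adam conjugate solver can be written as a standard cosine decay (a standalone sketch that makes the endpoints explicit; in the JAX ecosystem, `optax.cosine_decay_schedule` provides similar functionality):

```python
import math

def cosine_annealed_lr(step, total_steps, lr_init=0.1, lr_final=1e-5):
    """Cosine anneal the learning rate from lr_init at step 0 down to
    lr_final at step total_steps."""
    t = min(step, total_steps) / total_steps
    return lr_final + 0.5 * (lr_init - lr_final) * (1.0 + math.cos(math.pi * t))
```

The schedule starts at the initial learning rate of 0.1 and reaches 10^-5 at the final solver iteration.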

Figure 9: Conjugate solver convergence on the HD benchmarks with a NN potential.

Figure 11: Sample conjugation landscapes J(x; y) of the bottom setting of fig. 4. The inverse transport map ∇_y f̄(y) = x̂(y) is obtained by minimizing J, which is a convex optimization problem. The contour shows J(x; y) filtered to not display a color for values above J(y; y).

Table 8: Hyper-parameters for the synthetic experiments

The benchmark results report that amortizing and fine-tuning the conjugate improves the L_2-UVP performance by a factor of 1.8 to 4.4 over the previously best-known results on the benchmark. App. C.3 shows that the conjugate can often be fine-tuned within 100ms per batch of 1024 examples on an NVIDIA Tesla V100 GPU, and fig. 2 and app. C.2 compare Adam and L-BFGS for solving the conjugation. The following remarks further summarize the results from these experiments:

Comparison of L_2-UVP on the high-dimensional tasks from the Wasserstein-2 benchmark by Korotin et al. (2021a), where *[the gray tags] denote their results. I report the mean and standard deviation across 10 trials. Fine-tuning the amortized prediction with L-BFGS or Adam consistently improves the quality of the learned potential.

… | … | …±0.09 0.77±0.11 1.63±0.28 1.15±0.14 2.02±0.10 4.48±0.89 1.65±0.10 5.93±9.43
Objective | L-BFGS | 0.26±0.09 0.79±0.12 1.63±0.30 1.12±0.11 1.92±0.19 4.40±0.79 1.64±0.11 2.24±0.13
Regression | L-BFGS | 0.26±0.09 0.78±0.12 1.64±0.29 1.14±0.12 1.93±0.20 4.41±0.74 1.69±0.11 2.21±0.15
… | … | …±0.09 0.79±0.14 1.62±0.31 1.08±0.14 1.89±0.19 4.23±0.76 1.59±0.12 1.99±0.15
Regression | Adam | 0.35±0.07 0.81±0.12 1.61±0.32 1.09±0.11 1.85±0.20 4.42±0.68 1.63±0.08 1.99±0.16

Potential model: the non-convex neural network (MLP) described in app. B.4; amortization model: the MLP described in app. B.2

… | … | …±0.00 0.22±0.01 0.60±0.03 0.80±0.11 2.09±0.31 2.08±0.40 0.67±0.05 0.59±0.04
Regression | L-BFGS | 0.03±0.00 0.22±0.01 0.61±0.04 0.77±0.10 1.97±0.38 2.08±0.39 0.67±0.05 0.65±0.07
… | … | …±0.01 0.26±0.02 0.63±0.07 0.81±0.10 1.99±0.32 2.21±0.32 0.77±0.05 0.66±0.07
Regression | Adam | 0.22±0.01 0.28±0.02 0.61±0.07 0.80±0.10 2.07±0.38 2.37±0.46 0.77±0.06 0.75±0.09

Comparison of L_2-UVP on the CelebA64 tasks from the Wasserstein-2 benchmark by Korotin et al. (2021a), where *[the gray tags] denote their results. I report the mean and standard deviation across 10 trials. Fine-tuning the amortized prediction with L-BFGS or Adam consistently improves the quality of the learned potential. The ConvICNN64 and ResNet potential models are from Korotin et al. (2021a), and app. B.5 describes the (non-convex) ConvNet model. † the reversed direction from Korotin et al. (2021a), i.e. the potential model is associated with the β measure

Hyper-parameters for the D-dimensional Wasserstein-2 benchmark experiments

Hyper-parameters for the CelebA64 Wasserstein-2 benchmark experiments

Additional runtime and conjugation information for the HD benchmark. These report the median time from the converged runs. Columns: Amortization loss | Conjugate solver | n = 64 | n = 128 | n = 256 (repeated for each reported quantity).

Additional runtime and conjugation information for the CelebA64 benchmark

Table 8 details the main hyper-parameters for the synthetic benchmark experiments, and fig. 11 shows additional conjugation landscapes.

ACKNOWLEDGMENTS

I would like to thank Max Balandat, Ricky Chen, Samuel Cohen, Marco Cuturi, Carles Domingo-Enrich, Yaron Lipman, Max Nickel, Misha Khodak, Aram-Alexandre Pooladian, Mike Rabbat, Adriana Romero Soriano, Mark Tygert, and Lin Xiao for insightful comments and discussions. The core set of tools in Python (Van Rossum and Drake Jr, 1995; Oliphant, 2007) enabled this work, including Hydra (Yadan, 2019), JAX (Bradbury et al., 2018), Flax (Heek et al., 2020), Matplotlib (Hunter, 2007), numpy (Oliphant, 2006; Van Der Walt et al., 2011), and pandas (McKinney, 2012).

