ON AMORTIZING CONVEX CONJUGATES FOR OPTIMAL TRANSPORT

Abstract

This paper focuses on computing the convex conjugate operation that arises when solving Euclidean Wasserstein-2 optimal transport problems. This conjugation, which is also referred to as the Legendre-Fenchel conjugate or c-transform, is considered difficult to compute, and in practice Wasserstein-2 methods are limited by not being able to exactly conjugate the dual potentials in continuous space. To overcome this, the computation of the conjugate can be approximated with amortized optimization, which learns a model to predict the conjugate. I show that combining amortized approximations to the conjugate with a solver for fine-tuning significantly improves the quality of transport maps learned for the Wasserstein-2 benchmark by Korotin et al. (2021a) and is able to model many 2-dimensional couplings and flows considered in the literature. All of the baselines, methods, and solvers in this paper are available at http://github.com/facebookresearch/w2ot.

1. INTRODUCTION

Optimal transportation (Villani, 2009; Ambrosio, 2003; Santambrogio, 2015; Peyré et al., 2019) is a thriving area of research that provides a way of connecting and transporting between probability measures. While optimal transport between discrete measures is well-understood, e.g. with Sinkhorn distances (Cuturi, 2013), optimal transport between continuous measures is an open research topic actively being investigated (Genevay et al., 2016; Seguy et al., 2017; Taghvaei and Jalali, 2019; Korotin et al., 2019; Makkuva et al., 2020; Fan et al., 2021; Asadulaev et al., 2022). Continuous OT has applications in generative modeling (Arjovsky et al., 2017; Petzka et al., 2017; Wu et al., 2018; Liu et al., 2019; Cao et al., 2019; Leygonie et al., 2019), domain adaptation (Luo et al., 2018; Shen et al., 2018; Xie et al., 2019), barycenter computation (Li et al., 2020; Fan et al., 2020; Korotin et al., 2021b), and biology (Bunne et al., 2021; 2022; Lübeck et al., 2022).

This paper focuses on estimating the Wasserstein-2 transport map between measures $\alpha$ and $\beta$ in Euclidean space, i.e. $\mathrm{supp}(\alpha) = \mathrm{supp}(\beta) = \mathbb{R}^n$ with the Euclidean distance as the transport cost. The Wasserstein-2 transport map $T^\star : \mathbb{R}^n \to \mathbb{R}^n$ is the solution to Monge's primal formulation:

$$T^\star \in \operatorname*{arg\,inf}_{T \in \mathcal{T}(\alpha,\beta)} \mathbb{E}_{x\sim\alpha}\,\|x - T(x)\|_2^2, \tag{1}$$

where $\mathcal{T}(\alpha,\beta) := \{T : T_\#\alpha = \beta\}$ is the set of admissible couplings and the push-forward operator $\#$ is defined by $T_\#\alpha(B) := \alpha(T^{-1}(B))$ for a measure $\alpha$, measurable map $T$, and all measurable sets $B$. $T^\star$ exists and is unique under general settings, e.g. as in Santambrogio (2015, Theorem 1.17), but eq. (1) is often difficult to solve directly because of the coupling constraints $\mathcal{T}$. Almost every computational method instead solves the Kantorovich dual, e.g. as formulated in Villani (2009, §5) and Peyré et al. (2019, §2.5). This paper focuses on the dual associated with the negative inner product cost (Villani, 2009, eq. 5.12), which introduces a dual potential function $f : \mathbb{R}^n \to \mathbb{R}$ and solves

$$\hat f \in \operatorname*{arg\,sup}_{f \in L^1(\alpha)} \; -\mathbb{E}_{x\sim\alpha}[f(x)] - \mathbb{E}_{y\sim\beta}[f^*(y)], \tag{2}$$

where $L^1(\alpha)$ is the space of measurable functions that are Lebesgue-integrable over $\alpha$ and $f^*$ is the convex conjugate, or Legendre-Fenchel transform, of $f$, defined by

$$f^*(y) := -\inf_{x \in \mathcal{X}} J_f(x; y) \quad \text{with objective} \quad J_f(x; y) := f(x) - \langle x, y\rangle, \tag{3}$$

and $\hat x(y)$ denotes an optimal solution to eq. (3). Even though eq. (2) searches over functions in $L^1(\alpha)$, the optimal dual potential $\hat f$ is convex (Villani, 2009, theorem 5.10). When one of the measures has a density, Brenier (1991, theorem 3.1) and McCann (1995) relate $\hat f$ to an optimal transport map $T^\star$ for the primal problem in eq. (1) with $T^\star(x) = \nabla_x \hat f(x)$, and the inverse of the transport map is given by $T^{\star\,-1}(y) = \nabla_y \hat f^*(y)$. A stream of foundational papers have proposed methods to approximate the dual potential $f$ with a neural network and learn it by optimizing eq. (2), e.g. Taghvaei and Jalali (2019); Korotin et al. (2019); Makkuva et al. (2020).

Efficiently solving the conjugation operation in eq. (3) is the key computational challenge in solving the Kantorovich dual in eq. (2) and is an important design choice. Exactly computing the conjugate, as done in Taghvaei and Jalali (2019), is widely considered too computationally expensive for practical use, and even approximations are reported to be unstable; Korotin et al. (2021a) observe that "[solvers that approximate the conjugate] are also hard to optimize: they either diverge from the start or diverge after converging to a nearly-optimal saddle point." In contrast to these statements on the difficulty of exactly estimating the conjugation operation, I will show in this paper that computing the (near-)exact conjugate is easy. My key insight is that approximate, i.e. amortized, conjugation methods can be combined with a fine-tuning procedure that uses the approximate solution as a starting point. Sect. 3 discusses the amortization design choices and sect. 3.2.2 presents a new amortization perspective on the cycle-consistency term used in Wasserstein-2 generative networks (Korotin et al., 2019), which was previously not seen in this way. Sect. 5 shows that amortizing and fine-tuning the conjugate results in state-of-the-art performance on all of the tasks in the Wasserstein-2 benchmark by Korotin et al. (2021a). Amortization with fine-tuning also nicely models synthetic settings (sect. 6), including learning a single-block potential flow without using the likelihood.
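To make the conjugation in eq. (3) and the amortize-then-fine-tune idea concrete, the following self-contained sketch (illustrative only, not the paper's w2ot implementation) fine-tunes a crude stand-in for an amortized prediction of $\hat x(y)$ by running gradient descent on $J_f(\cdot\,; y)$. A quadratic $f$ is used because its conjugate solution is known in closed form, $\hat x(y) = A^{-1} y$, so the result can be checked; the step size, iteration count, and the warm-start rule are all assumptions made for the example.

```python
import numpy as np

# Toy potential f(x) = 0.5 * x @ A @ x with A symmetric positive definite.
# Its conjugate solution is x_hat(y) = inv(A) @ y, which we use to verify
# the fine-tuning solver. (Names and constants here are illustrative.)
rng = np.random.default_rng(0)
M = rng.normal(size=(3, 3))
A = M @ M.T + 3.0 * np.eye(3)  # symmetric positive definite

def grad_J(x, y):
    # Gradient in x of J_f(x; y) = f(x) - <x, y> for the quadratic f above.
    return A @ x - y

def conjugate_solve(y, x_init, lr=0.05, num_iter=500):
    """Fine-tune a warm-started guess of x_hat(y) by gradient descent on J_f."""
    x = x_init.copy()
    for _ in range(num_iter):
        x -= lr * grad_J(x, y)
    return x

y = rng.normal(size=3)
x_amortized = 0.2 * y  # stand-in for a learned amortized prediction of x_hat(y)
x_hat = conjugate_solve(y, x_amortized)
x_exact = np.linalg.solve(A, y)
print(np.allclose(x_hat, x_exact, atol=1e-5))
```

In the paper's setting the warm start would come from a learned amortization model rather than the fixed linear guess used here, but the fine-tuning step is the same: descend on $J_f(\cdot\,; y)$ starting from the prediction.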

2. LEARNING DUAL POTENTIALS: A CONJUGATION PERSPECTIVE

This section reviews the standard methods of learning parameterized dual potentials to solve eq. (2). The first step is to re-cast the Kantorovich dual problem in eq. (2) as a problem over a parametric family of potentials $f_\theta$ with parameters $\theta$, e.g. an input-convex neural network (Amos et al., 2017) or a more general non-convex neural network. Taghvaei and Jalali (2019); Makkuva et al. (2020) have laid the foundations for optimizing the parametric potentials for the dual objective with $\max_\theta V(\theta)$, where

$$V(\theta) := -\mathbb{E}_{x\sim\alpha}[f_\theta(x)] - \mathbb{E}_{y\sim\beta}[f_\theta^*(y)] = -\mathbb{E}_{x\sim\alpha}[f_\theta(x)] + \mathbb{E}_{y\sim\beta}[J_{f_\theta}(\hat x(y); y)], \tag{4}$$

where $J$ is the objective of the conjugate optimization problem in eq. (3), $\hat x(y)$ is the solution to the convex conjugate, and eq. (4) assumes a finite solution to eq. (2) exists and replaces the sup with a max. Taghvaei and Jalali (2019) show that the model can be learned, i.e. the optimal parameters can be found, by taking gradient steps of the dual with respect to the parameters of the potential, i.e. using $\nabla_\theta V$. This derivative through the loss and conjugation operation can be obtained by applying Danskin's envelope theorem (Danskin, 1966; Bertsekas, 1971) and results in only needing derivatives of the potential:

$$\nabla_\theta V(\theta) = \nabla_\theta\!\left(-\mathbb{E}_{x\sim\alpha}[f_\theta(x)] + \mathbb{E}_{y\sim\beta}[J_{f_\theta}(\hat x(y); y)]\right) = -\mathbb{E}_{x\sim\alpha}[\nabla_\theta f_\theta(x)] + \mathbb{E}_{y\sim\beta}[\nabla_\theta f_\theta(\hat x(y))],$$

where $\hat x(y)$ is not differentiated through.

Assumption 1. A standard assumption is that the conjugate is smooth with a well-defined arg min. This has been shown to hold when $f$ is strongly convex, e.g. in Kakade et al. (2009), or when $f$ is essentially strictly convex (Rockafellar, 2015, theorem 26.3).
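The envelope-theorem step can be checked numerically on a toy potential. In this sketch (again illustrative, not the paper's code), a scalar parameter $\theta$ defines $f_\theta(x) = (\theta/2)\|x\|_2^2$, whose conjugate solution is $\hat x(y) = y/\theta$. Danskin's theorem says the derivative of $J_{f_\theta}(\hat x(y); y)$ in $\theta$ can be taken with $\hat x(y)$ held fixed, and it should match a finite-difference derivative that re-solves the inner minimization at each perturbed $\theta$:

```python
import numpy as np

def J(theta, x, y):
    # Conjugation objective J_{f_theta}(x; y) = f_theta(x) - <x, y>
    # for the toy potential f_theta(x) = (theta / 2) * ||x||^2.
    return 0.5 * theta * np.dot(x, x) - np.dot(x, y)

theta = 1.7
y = np.array([0.3, -1.2, 2.0])
x_hat = y / theta  # exact conjugate solution for this f_theta

# (a) Danskin/envelope gradient: differentiate J in theta with x_hat fixed.
danskin_grad = 0.5 * np.dot(x_hat, x_hat)

# (b) Full derivative of theta -> inf_x J(theta, x, y) = -f*_theta(y),
#     via central finite differences, re-solving the minimizer each time.
eps = 1e-6
full_grad = (
    J(theta + eps, y / (theta + eps), y) - J(theta - eps, y / (theta - eps), y)
) / (2 * eps)

print(danskin_grad, full_grad)  # the two agree up to finite-difference error
```

This is why eq. (4)'s gradient only needs $\nabla_\theta f_\theta$ evaluated at fixed points: the dependence of $\hat x(y)$ on $\theta$ contributes nothing at the minimizer.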



Taghvaei and Jalali (2019); Korotin et al. (2019); Makkuva et al. (2020) parameterize $f$ as an input-convex neural network (Amos et al., 2017), which can universally represent any convex function with enough capacity (Huang et al., 2020). Other works explore parameterizing $f$ as a non-convex neural network (Nhan Dam et al., 2019; Korotin et al., 2021a; Rout et al., 2021).
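A minimal sketch of the input-convex parameterization, in the style of Amos et al. (2017): hidden-to-hidden weights are constrained nonnegative and the activations are convex and nondecreasing, which makes the output convex in the input $x$. The layer sizes, initialization, and omitted biases here are illustrative assumptions, not the architecture of any particular solver; the convexity property is checked numerically with a midpoint test.

```python
import numpy as np

# Illustrative 2-hidden-layer ICNN. Wz* are elementwise nonnegative and the
# ReLU activation is convex and nondecreasing, so x -> icnn(x) is convex.
rng = np.random.default_rng(0)
n, h = 2, 16
Wx0 = rng.normal(size=(h, n))
Wx1 = rng.normal(size=(h, n))
Wx2 = rng.normal(size=(1, n))
Wz1 = np.abs(rng.normal(size=(h, h)))  # nonnegative hidden-path weights
Wz2 = np.abs(rng.normal(size=(1, h)))

def icnn(x):
    z = np.maximum(Wx0 @ x, 0.0)            # convex in x
    z = np.maximum(Wz1 @ z + Wx1 @ x, 0.0)  # nonnegative Wz1 preserves convexity
    return (Wz2 @ z + Wx2 @ x).item()       # affine terms keep convexity

# Numerical midpoint-convexity check: f((a + b) / 2) <= (f(a) + f(b)) / 2.
convex_ok = all(
    icnn(0.5 * (a + b)) <= 0.5 * (icnn(a) + icnn(b)) + 1e-9
    for a, b in (rng.normal(size=(2, n)) for _ in range(100))
)
print(convex_ok)
```

The nonnegativity constraint on the hidden-path weights is what distinguishes this from an ordinary MLP; dropping it (or using a non-monotone activation) breaks the convexity guarantee, which is why the non-convex parameterizations cited above trade this guarantee for added flexibility.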

