NEURAL OPTIMAL TRANSPORT WITH GENERAL COST FUNCTIONALS

Abstract

We present a novel neural-networks-based algorithm to compute optimal transport (OT) plans and maps for general cost functionals. The algorithm is based on a saddle point reformulation of the OT problem and generalizes prior OT methods for weak and strong cost functionals. As an application, we construct a functional to map data distributions while preserving the class-wise structure of the data.

1. INTRODUCTION

Optimal transport (OT) is a powerful framework for solving mass-moving and generative modeling problems for data distributions. Recent works (Korotin et al., 2022c; Rout et al., 2022; Korotin et al., 2021b; Fan et al., 2021a; Daniels et al., 2021) propose scalable neural methods to compute OT plans (or maps). They show that the learned transport plan (or map) can be used directly as the generative model in data synthesis (Rout et al., 2022) and unpaired learning (Korotin et al., 2022c; Rout et al., 2022; Daniels et al., 2021; Gazdieva et al., 2022). Compared to WGANs (Arjovsky et al., 2017), which employ the OT cost as the loss for the generator (Rout et al., 2022, §3), these methods provide better flexibility: the properties of the learned model can be controlled by the transport cost function. Existing neural OT plan (or map) methods consider distance-based cost functions, e.g., weak or strong quadratic costs (Korotin et al., 2022c; Fan et al., 2021a; Gazdieva et al., 2022). Such costs are suitable for the tasks of unpaired image-to-image style translation (Zhu et al., 2017, Figures 1, 2) and image restoration (Lugmayr et al., 2020). However, they do not take into account the class-wise structure of the data or available side information, e.g., the class labels. As a result, such costs are hardly applicable to certain tasks, such as dataset transfer, where preservation of the class-wise structure is needed (Figure 1). We tackle this issue.

Contributions. We propose an extension of neural OT which allows applying it to previously unaddressed problems. For this, we develop a novel neural-networks-based algorithm to compute optimal transport plans for general cost functionals (§3). As an example, we construct (§4) and test (§6) a functional for mapping data distributions with preservation of the class-wise structure.

Let X, Y be compact Hausdorff spaces and P ∈ P(X), Q ∈ P(Y).
We use Π(P) ⊂ P(X × Y) to denote the subset of probability measures on X × Y whose projection onto the first marginal is P. We use Π(P, Q) ⊂ Π(P) to denote the subset of probability measures (transport plans) on X × Y with marginals P, Q. For u ∈ C(X) and v ∈ C(Y), we write u ⊕ v ∈ C(X × Y) to denote the function u ⊕ v : (x, y) → u(x) + v(y). For a functional F : M(X × Y) → R, we say that it is separably *-increasing if for all functions u ∈ C(X), v ∈ C(Y) and any function c ∈ C(X × Y), u ⊕ v ≤ c (point-wise) implies F*(u ⊕ v) ≤ F*(c). For a measurable map T : X × Z → Y, we denote the associated push-forward operator by T#. For Q1, Q2 ∈ P(Y) with Y ⊂ R^D, the (square of the) energy distance E (Rizzo & Székely, 2016) between them is

E²(Q1, Q2) = E∥Y1 − Y2∥ − ½ E∥Y1 − Y1′∥ − ½ E∥Y2 − Y2′∥,   (1)

where Y1, Y1′ ∼ Q1 and Y2, Y2′ ∼ Q2 are independent random vectors. The energy distance (1) is a particular case of the Maximum Mean Discrepancy (Sejdinovic et al., 2013). It equals zero only when Q1 = Q2.
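Given finite samples, the energy distance (1) admits a straightforward Monte Carlo estimate. The sketch below uses the V-statistic form, which keeps the zero diagonal terms in the within-sample averages and is therefore slightly biased; it is an illustration, not code from this paper:

```python
import numpy as np

def energy_distance_sq(y1, y2):
    """V-statistic estimate of the squared energy distance E^2(Q1, Q2)
    from samples y1 ~ Q1 (shape [n, D]) and y2 ~ Q2 (shape [m, D])."""
    # E||Y1 - Y2||: average pairwise distance across the two samples
    cross = np.linalg.norm(y1[:, None, :] - y2[None, :, :], axis=-1).mean()
    # E||Y1 - Y1'|| and E||Y2 - Y2'||: average within-sample distances
    # (diagonal zeros included, hence the slight bias)
    within1 = np.linalg.norm(y1[:, None, :] - y1[None, :, :], axis=-1).mean()
    within2 = np.linalg.norm(y2[:, None, :] - y2[None, :, :], axis=-1).mean()
    return cross - 0.5 * within1 - 0.5 * within2

rng = np.random.default_rng(0)
a = rng.normal(0.0, 1.0, size=(500, 2))
b = rng.normal(0.0, 1.0, size=(500, 2))   # same distribution as a
c = rng.normal(3.0, 1.0, size=(500, 2))   # shifted distribution
print(energy_distance_sq(a, b))           # near zero
print(energy_distance_sq(a, c))           # clearly positive
```

For identically distributed samples the estimate concentrates near zero, while a mean shift produces a clearly positive value, matching the property that E²(Q1, Q2) = 0 only when Q1 = Q2.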

2. PRELIMINARIES

In this section, we provide key concepts of optimal transport theory. Throughout the paper, we consider compact X = Y ⊂ R^D and P ∈ P(X), Q ∈ P(Y).

Strong OT. For a cost function c ∈ C(X × Y), the optimal transport cost between P and Q is

Cost(P, Q) := inf_{π∈Π(P,Q)} ∫_{X×Y} c(x, y) dπ(x, y),   (2)

see (Villani, 2008, §1). Problem (2) admits a minimizer π* ∈ Π(P, Q), which is called an OT plan (Santambrogio, 2015, Theorem 1.4). It may be not unique (Peyré et al., 2019, Remark 2.3). Intuitively, the cost function c(x, y) measures how hard it is to move a mass piece between points x ∈ X and y ∈ Y. That is, π* shows how to distribute the mass of P to Q optimally, i.e., with minimal effort. For cost functions c(x, y) = ∥x − y∥ and c(x, y) = ½∥x − y∥², the OT cost (2) is called the Wasserstein-1 (W1) distance and the (square of the) Wasserstein-2 (W2) distance, respectively, see (Villani, 2008, §1) or (Santambrogio, 2015, §1, §2).

Weak OT. Consider a weak cost function C : X × P(Y) → R, whose inputs are a point x ∈ X and a distribution of y ∈ Y. The weak OT cost is (Gozlan et al., 2017; Backhoff-Veraguas et al., 2019)

Cost(P, Q) := inf_{π∈Π(P,Q)} ∫_X C(x, π(·|x)) dP(x),   (3)

where π(·|x) denotes the conditional distribution. Weak formulation (3) reduces to strong formulation (2) when C(x, μ) = ∫_Y c(x, y) dμ(y). Another example of a weak cost function is the γ-weak quadratic cost

C(x, μ) = ∫_Y ½∥x − y∥² dμ(y) − (γ/2) Var(μ),

where γ ≥ 0 and Var(μ) is the variance of μ, see (Korotin et al., 2022c, Eq. 5), (Alibert et al., 2019, §5.2), (Gozlan & Juillet, 2020, §5.2) for details. For this cost, we denote the optimal value of (3) by W²_{2,γ} and call it the γ-weak Wasserstein-2.

Regularized OT. The expression inside (2) is a linear functional of π. It is common to add a lower semi-continuous convex regularizer R : M(X × Y) → R ∪ {∞} with weight γ > 0:

Cost(P, Q) := inf_{π∈Π(P,Q)} [ ∫_{X×Y} c(x, y) dπ(x, y) + γ R(π) ].   (4)

Regularized formulation (4) typically provides several advantages over original formulation (2).
For example, if R(π) is strictly convex, the expression inside (4) is a strictly convex functional of π and yields a unique OT plan π*. Besides, regularized OT typically has better sample complexity (Genevay, 2019; Mena & Niles-Weed, 2019; Genevay et al., 2019). Common regularizers are the entropic (Cuturi, 2013b), quadratic (Essid & Solomon, 2018), and Lasso (Courty et al., 2016) regularizers, among others.
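For discrete measures, the entropic-regularized variant of (4) can be solved with Sinkhorn iterations (Cuturi, 2013b). The following is a minimal NumPy sketch of that discrete case, included only as an illustration of regularized OT; it is not the neural algorithm developed in this paper:

```python
import numpy as np

def sinkhorn(p, q, C, gamma, n_iters=500):
    """Entropic-regularized discrete OT: minimize <C, pi> + gamma * KL-term
    over couplings pi with marginals p, q, via Sinkhorn iterations."""
    K = np.exp(-C / gamma)              # Gibbs kernel
    u = np.ones_like(p)
    for _ in range(n_iters):
        v = q / (K.T @ u)               # scale to match second marginal
        u = p / (K @ v)                 # scale to match first marginal
    return u[:, None] * K * v[None, :]  # regularized OT plan

# Toy problem: uniform marginals on a grid, quadratic cost
x = np.linspace(0, 1, 5)
C = 0.5 * (x[:, None] - x[None, :]) ** 2
p = np.full(5, 0.2)
q = np.full(5, 0.2)
pi = sinkhorn(p, q, C, gamma=0.01)
print(pi.round(3))       # nearly diagonal coupling for small gamma
print((pi * C).sum())    # approximate OT cost, here close to zero
```

Small γ recovers a plan close to the unregularized one; larger γ spreads the plan out, trading optimality of the cost for the strict convexity and stability discussed above.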



Figure 1: The setup of class-guided dataset transfer. The input P = Σn αn Pn and target Q = Σn βn Qn distributions are mixtures of N classes. The task is to learn a transport map T preserving the class. The learner has access to labeled input data ∼ P and only partially labeled target data ∼ Q.

Notation. The notation of our paper is based on that of (Paty & Cuturi, 2020; Korotin et al., 2022c). For a compact Hausdorff space S, we use P(S) to denote the set of Borel probability distributions on S. We denote the space of continuous R-valued functions on S endowed with the supremum norm by C(S). Its dual space is the space M(S) ⊃ P(S) of finite signed Borel measures over S.
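As a toy illustration of the dataset-transfer setup in Figure 1, the snippet below builds a two-class input mixture P = Σn αn Pn and pushes it forward through a hand-picked shift map T. Here T is a hypothetical stand-in for a learned class-preserving map, not the OT map of this paper:

```python
import numpy as np

rng = np.random.default_rng(0)

# Input P = sum_n alpha_n P_n: a two-class Gaussian mixture in R^2,
# a toy stand-in for the labeled input data.
alphas = np.array([0.5, 0.5])
means = np.array([[-2.0, 0.0], [2.0, 0.0]])   # class-conditional means
labels = rng.choice(2, size=1000, p=alphas)
x = means[labels] + rng.normal(size=(1000, 2))

# A hypothetical map T; T#P is estimated by pushing samples of P through T.
def T(x):
    return x + np.array([0.0, 3.0])

y = T(x)
# The shift acts identically on both classes, so class membership is
# preserved by construction; a purely distance-based OT map need not
# preserve it, which motivates the class-guided cost functional.
print(y[labels == 0].mean(axis=0))   # near (-2, 3)
print(y[labels == 1].mean(axis=0))   # near (2, 3)
```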

