NEURAL OPTIMAL TRANSPORT

Abstract

We present a novel neural-networks-based algorithm to compute optimal transport maps and plans for strong and weak transport costs. To justify the usage of neural networks, we prove that they are universal approximators of transport plans between probability distributions. We evaluate the performance of our optimal transport algorithm on toy examples and on unpaired image-to-image translation.

1. INTRODUCTION

Solving optimal transport (OT) problems with neural networks has become widespread in machine learning, tentatively starting with the introduction of large-scale OT (Seguy et al., 2017) and Wasserstein GANs (Arjovsky et al., 2017). The majority of existing methods compute the OT cost and use it as the loss function to update the generator in generative models (Gulrajani et al., 2017; Liu et al., 2019; Sanjabi et al., 2018; Petzka et al., 2017). Recently, (Rout et al., 2022; Daniels et al., 2021) have demonstrated that the OT plan itself can be used as a generative model providing comparable performance in practical tasks. In this paper, we focus on the methods which compute the OT plan. Most recent methods (Korotin et al., 2021b; Rout et al., 2022) consider OT for the quadratic transport cost (the Wasserstein-2 distance, W2) and recover a nonstochastic OT plan, i.e., a deterministic OT map. In general, such a deterministic map may not exist. (Daniels et al., 2021) recover the entropy-regularized stochastic plan, but the procedures for learning the plan and sampling from it are extremely time-consuming due to using score-based models and the Langevin dynamics (Daniels et al., 2021, §6).

Contributions. We propose a novel algorithm to compute deterministic and stochastic OT plans with deep neural networks (§4.1, §4.2). Our algorithm is designed for weak and strong optimal transport costs (§2) and generalizes previously known scalable approaches (§3, §4.3). To reinforce the usage of neural nets, we prove that they are universal approximators of transport plans (§4.4). We show that our algorithm can be applied to large-scale computer vision tasks (§5).

Notations. We use X, Y, Z to denote Polish spaces and P(X), P(Y), P(Z) to denote the respective sets of probability distributions on them. We denote the set of probability distributions on X × Y with marginals P and Q by Π(P, Q). For a measurable map T : X × Z → Y (or T : X → Y), we denote the associated push-forward operator by T#.
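The push-forward operator T# acts on a distribution by mapping its samples through T. As a minimal sketch (the affine map T below is an illustrative choice, not from the paper), one can verify a push-forward empirically:

```python
# Illustrative push-forward: for T(x) = 2x + 1 and P = N(0, 1),
# the push-forward T#P is N(1, 4), i.e., mean 1 and std 2.
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(0.0, 1.0, size=100_000)  # samples from P
y = 2.0 * x + 1.0                       # samples from T#P
print(y.mean(), y.std())                # close to 1.0 and 2.0
```

The same sample-based view underlies the neural setting: a network T is trained so that the empirical distribution of T(x), x ~ P, matches Q.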

2. PRELIMINARIES

In this section, we provide key concepts of the OT theory (Villani, 2008; Santambrogio, 2015; Gozlan et al., 2017; Backhoff-Veraguas et al., 2019) that we use in our paper.

Strong OT formulation (Monge's problem). For a cost function c : X × Y → R, the strong OT cost between P ∈ P(X) and Q ∈ P(Y) is

$$\mathrm{Cost}(P,Q) \stackrel{\mathrm{def}}{=} \inf_{T\# P = Q} \int_{X} c\big(x, T(x)\big)\, dP(x), \tag{1}$$

where the infimum is taken over measurable functions (transport maps) T : X → Y that map P to Q (Figure 2). The optimal T* is called the OT map. Note that (1) is not symmetric and does not allow mass splitting, i.e., for some P ∈ P(X), Q ∈ P(Y), there may be no T satisfying T#P = Q. Thus, (Kantorovitch, 1958) proposed the following relaxation:

$$\mathrm{Cost}(P,Q) \stackrel{\mathrm{def}}{=} \inf_{\pi \in \Pi(P,Q)} \int_{X \times Y} c(x, y)\, d\pi(x, y), \tag{2}$$

where the infimum is taken over all transport plans π (Figure 3a), i.e., distributions on X × Y whose marginals are P and Q. The optimal π* ∈ Π(P, Q) is called the optimal transport plan. If π* is of the form [id, T*]#P ∈ Π(P, Q) for some T*, then T* minimizes (1). In this case, the plan is called deterministic. Otherwise, it is called stochastic (nondeterministic).

An example of OT cost for X = Y = R^D is the (p-th power of the) Wasserstein-p distance W_p, i.e., formulation (2) with c(x, y) = ∥x − y∥^p. Its two most popular cases are p = 1, 2 (W_1, W_2).

Weak OT formulation (Gozlan et al., 2017). Let C : X × P(Y) → R be a weak cost, i.e., a function which takes a point x ∈ X and a distribution of y ∈ Y as input. The weak OT cost between P and Q is

$$\mathrm{Cost}(P,Q) \stackrel{\mathrm{def}}{=} \inf_{\pi \in \Pi(P,Q)} \int_{X} C\big(x, \pi(\cdot \mid x)\big)\, dP(x), \tag{3}$$

where π(·|x) denotes the conditional distribution (Figure 3b). Note that (3) is a generalization of (2). Indeed, for the cost C(x, µ) = ∫_Y c(x, y) dµ(y), the weak formulation (3) becomes the strong one (2).

Figure 1: Unpaired translation with our Neural Optimal Transport (NOT) Algorithm 1. (a) Celeba (female) → anime and outdoor → church: deterministic (one-to-one, W2). (b) Handbags → shoes: stochastic (one-to-many, W2,1).

Figure 2: Monge's OT formulation.

Figure 3: Strong (Kantorovich's) and weak (Gozlan et al., 2017) optimal transport formulations.
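To make formulations (1) and (2) concrete: between two discrete distributions, the Kantorovich problem (2) is a small linear program over couplings with fixed marginals. A hedged sketch using scipy.optimize.linprog (the marginals and cost matrix below are toy values, purely illustrative):

```python
# Discrete Kantorovich OT: minimize <C, pi> over couplings pi
# whose row sums equal p and column sums equal q.
import numpy as np
from scipy.optimize import linprog

p = np.array([0.5, 0.5])      # source marginal P
q = np.array([0.25, 0.75])    # target marginal Q
C = np.array([[0.0, 1.0],
              [1.0, 0.0]])    # cost matrix c(x_i, y_j)

n, m = C.shape
A_eq = []
for i in range(n):            # constraint: sum_j pi[i, j] = p[i]
    row = np.zeros((n, m)); row[i, :] = 1.0
    A_eq.append(row.ravel())
for j in range(m):            # constraint: sum_i pi[i, j] = q[j]
    col = np.zeros((n, m)); col[:, j] = 1.0
    A_eq.append(col.ravel())
b_eq = np.concatenate([p, q])

res = linprog(C.ravel(), A_eq=np.array(A_eq), b_eq=b_eq, bounds=(0, None))
pi = res.x.reshape(n, m)      # optimal transport plan pi*
print(pi)                     # mass at x_0 splits between y_0 and y_1
print(res.fun)                # OT cost = 0.25
```

The example also illustrates why the Monge formulation (1) may be infeasible: here the atom at x_0 must split its mass between two targets, so no deterministic map T pushes p to q, yet the Kantorovich plan exists.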


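For the Wasserstein-p example, the one-dimensional case with p = 1 admits a closed form (the difference of the CDFs), which SciPy exposes directly; a quick illustrative check on toy samples:

```python
# W_1 between two empirical distributions on R: c(x, y) = |x - y|.
import numpy as np
from scipy.stats import wasserstein_distance

x = np.array([0.0, 1.0, 2.0])  # samples from P
y = np.array([0.5, 1.5, 2.5])  # samples from Q
w1 = wasserstein_distance(x, y)
print(w1)  # 0.5: the optimal plan shifts each point by 0.5
```

Such closed forms exist only in 1-D; in higher dimensions one must solve (2) directly, which motivates scalable neural solvers.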