NEURAL OPTIMAL TRANSPORT

Abstract

We present a novel neural-networks-based algorithm to compute optimal transport maps and plans for strong and weak transport costs. To justify the usage of neural networks, we prove that they are universal approximators of transport plans between probability distributions. We evaluate the performance of our optimal transport algorithm on toy examples and on unpaired image-to-image translation.

Figure 1: (a) Celeba (female) → anime, outdoor → church, deterministic (one-to-one, $W_2$). (b) Handbags → shoes, stochastic (one-to-many, $W_{2,1}$).

1. INTRODUCTION

Solving optimal transport (OT) problems with neural networks has become widespread in machine learning, tentatively starting with the introduction of large-scale OT (Seguy et al., 2017) and Wasserstein GANs (Arjovsky et al., 2017). The majority of existing methods compute the OT cost and use it as the loss function to update the generator in generative models (Gulrajani et al., 2017; Liu et al., 2019; Sanjabi et al., 2018; Petzka et al., 2017). Recently, (Rout et al., 2022; Daniels et al., 2021) have demonstrated that the OT plan itself can be used as a generative model providing comparable performance in practical tasks.

In this paper, we focus on the methods which compute the OT plan. Most recent methods (Korotin et al., 2021b; Rout et al., 2022) consider OT for the quadratic transport cost (the Wasserstein-2 distance, $W_2$) and recover a nonstochastic OT plan, i.e., a deterministic OT map. In general, such a map may not exist. (Daniels et al., 2021) recover the entropy-regularized stochastic plan, but the procedures for learning the plan and sampling from it are extremely time-consuming due to using score-based models and the Langevin dynamics (Daniels et al., 2021, §6).

Contributions. We propose a novel algorithm to compute deterministic and stochastic OT plans with deep neural networks (§4.1, §4.2). Our algorithm is designed for weak and strong optimal transport costs (§2) and generalizes previously known scalable approaches (§3, §4.3). To reinforce the usage of neural nets, we prove that they are universal approximators of transport plans (§4.4). We show that our algorithm can be applied to large-scale computer vision tasks (§5).

Notations. We use $\mathcal{X}, \mathcal{Y}, \mathcal{Z}$ to denote Polish spaces and $\mathcal{P}(\mathcal{X}), \mathcal{P}(\mathcal{Y}), \mathcal{P}(\mathcal{Z})$ to denote the respective sets of probability distributions on them. We denote the set of probability distributions on $\mathcal{X} \times \mathcal{Y}$ with marginals $\mathbb{P}$ and $\mathbb{Q}$ by $\Pi(\mathbb{P}, \mathbb{Q})$. For a measurable map $T : \mathcal{X} \times \mathcal{Z} \to \mathcal{Y}$ (or $T : \mathcal{X} \to \mathcal{Y}$), we denote the associated push-forward operator by $T_\#$.

2. PRELIMINARIES

In this section, we provide key concepts of the OT theory (Villani, 2008; Santambrogio, 2015; Gozlan et al., 2017; Backhoff-Veraguas et al., 2019) that we use in our paper.

Strong OT formulation. For $\mathbb{P} \in \mathcal{P}(\mathcal{X})$, $\mathbb{Q} \in \mathcal{P}(\mathcal{Y})$ and a cost function $c : \mathcal{X} \times \mathcal{Y} \to \mathbb{R}$, Monge's primal formulation of the OT cost is

$$\text{Cost}(\mathbb{P}, \mathbb{Q}) \stackrel{\text{def}}{=} \inf_{T_\# \mathbb{P} = \mathbb{Q}} \int_{\mathcal{X}} c\big(x, T(x)\big)\, d\mathbb{P}(x), \tag{1}$$

where the minimum is taken over measurable functions (transport maps) $T : \mathcal{X} \to \mathcal{Y}$ that map $\mathbb{P}$ to $\mathbb{Q}$ (Figure 2). The optimal $T^*$ is called the OT map. Note that (1) is not symmetric and does not allow mass splitting, i.e., for some $\mathbb{P} \in \mathcal{P}(\mathcal{X})$, $\mathbb{Q} \in \mathcal{P}(\mathcal{Y})$, there may be no $T$ satisfying $T_\# \mathbb{P} = \mathbb{Q}$. Thus, (Kantorovitch, 1958) proposed the following relaxation:

$$\text{Cost}(\mathbb{P}, \mathbb{Q}) \stackrel{\text{def}}{=} \inf_{\pi \in \Pi(\mathbb{P}, \mathbb{Q})} \int_{\mathcal{X} \times \mathcal{Y}} c(x, y)\, d\pi(x, y), \tag{2}$$

where the minimum is taken over all transport plans $\pi$ (Figure 3a), i.e., distributions on $\mathcal{X} \times \mathcal{Y}$ whose marginals are $\mathbb{P}$ and $\mathbb{Q}$. The optimal $\pi^* \in \Pi(\mathbb{P}, \mathbb{Q})$ is called the optimal transport plan. If $\pi^*$ is of the form $[\text{id}, T^*]_\# \mathbb{P} \in \Pi(\mathbb{P}, \mathbb{Q})$ for some $T^*$, then $T^*$ minimizes (1). In this case, the plan is called deterministic. Otherwise, it is called stochastic (nondeterministic).

An example of OT cost for $\mathcal{X} = \mathcal{Y} = \mathbb{R}^D$ is the ($p$-th power of) Wasserstein-$p$ distance $W_p$, i.e., formulation (2) with $c(x, y) = \|x - y\|^p$. Its two most popular cases are $p = 1, 2$ ($W_1$, $W_2^2$).

Weak OT formulation (Gozlan et al., 2017). Let $C : \mathcal{X} \times \mathcal{P}(\mathcal{Y}) \to \mathbb{R}$ be a weak cost, i.e., a function which takes a point $x \in \mathcal{X}$ and a distribution of $y \in \mathcal{Y}$ as input. The weak OT cost between $\mathbb{P}, \mathbb{Q}$ is

$$\text{Cost}(\mathbb{P}, \mathbb{Q}) \stackrel{\text{def}}{=} \inf_{\pi \in \Pi(\mathbb{P}, \mathbb{Q})} \int_{\mathcal{X}} C\big(x, \pi(\cdot|x)\big)\, d\pi(x), \tag{3}$$

where $\pi(\cdot|x)$ denotes the conditional distribution (Figure 3b). Note that (3) is a generalization of (2). Indeed, for the cost $C(x, \mu) = \int_{\mathcal{Y}} c(x, y)\, d\mu(y)$, the weak formulation (3) becomes the strong one (2). An example of a weak OT cost for $\mathcal{X} = \mathcal{Y} = \mathbb{R}^D$ is the $\gamma$-weak ($\gamma \geq 0$) Wasserstein-2 ($W_{2,\gamma}$):

$$C(x, \mu) = \int_{\mathcal{Y}} \frac{1}{2}\|x - y\|^2\, d\mu(y) - \frac{\gamma}{2} \text{Var}(\mu). \tag{4}$$

Existence and duality. Throughout the paper, we consider weak costs $C(x, \mu)$ which are lower bounded, convex in $\mu$ and jointly lower semicontinuous in an appropriate sense. Under these assumptions, (Backhoff-Veraguas et al., 2019) prove that the minimizer $\pi^*$ of (3) always exists. With mild assumptions on $c$, strong costs satisfy these assumptions. In particular, they are linear w.r.t. $\mu$, and, consequently, convex. The $\gamma$-weak quadratic cost (4) is lower-bounded (for $\gamma \leq 1$) and is also convex since the functional $\text{Var}(\mu)$ is concave in $\mu$. For the costs in view, the dual form of (3) is

$$\text{Cost}(\mathbb{P}, \mathbb{Q}) = \sup_f \int_{\mathcal{X}} f^C(x)\, d\mathbb{P}(x) + \int_{\mathcal{Y}} f(y)\, d\mathbb{Q}(y), \tag{5}$$

where $f$ are the upper-bounded continuous functions with not very rapid growth (Backhoff-Veraguas et al., 2019, Equation 1.2) and $f^C$ is the weak $C$-transform of $f$, i.e.,

$$f^C(x) \stackrel{\text{def}}{=} \inf_{\mu \in \mathcal{P}(\mathcal{Y})} \Big\{ C(x, \mu) - \int_{\mathcal{Y}} f(y)\, d\mu(y) \Big\}. \tag{6}$$

Note that for strong costs $C$, the infimum is attained at any $\mu \in \mathcal{P}(\mathcal{Y})$ supported on the set $\arg\inf_{y \in \mathcal{Y}} \{c(x, y) - f(y)\}$. Therefore, it suffices to use the strong $c$-transform:

$$f^C(x) = f^c(x) \stackrel{\text{def}}{=} \inf_{y \in \mathcal{Y}} \{c(x, y) - f(y)\}. \tag{7}$$

For strong costs (2), formula (5) with (7) is the well-known Kantorovich duality (Villani, 2008, §5).

Nonuniqueness. In general, an OT plan $\pi^*$ is not unique, e.g., see (Peyré et al., 2019, Remark 2.3).
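To make the γ-weak quadratic cost (4) concrete, the following minimal sketch evaluates it for a discrete distribution µ given by atoms and weights. The function name `weak_quadratic_cost` and the toy inputs are illustrative assumptions, and Var(µ) is assumed here to mean the expected squared distance of y to the mean of µ.

```python
def weak_quadratic_cost(x, ys, weights, gamma):
    """gamma-weak quadratic cost (4) for a discrete mu = sum_i w_i * delta_{y_i}.

    Assumption: Var(mu) is the expected squared Euclidean distance to the mean.
    """
    dim = len(x)
    # Strong part: integral of 0.5 * ||x - y||^2 with respect to mu.
    transport = sum(
        w * 0.5 * sum((x[d] - y[d]) ** 2 for d in range(dim))
        for y, w in zip(ys, weights)
    )
    # Mean and variance of mu.
    mean = [sum(w * y[d] for y, w in zip(ys, weights)) for d in range(dim)]
    variance = sum(
        w * sum((y[d] - mean[d]) ** 2 for d in range(dim))
        for y, w in zip(ys, weights)
    )
    return transport - 0.5 * gamma * variance


# Hypothetical toy input: x at the origin, mu uniform on two atoms.
x = [0.0, 0.0]
ys = [[1.0, 0.0], [3.0, 0.0]]
weights = [0.5, 0.5]

strong = weak_quadratic_cost(x, ys, weights, gamma=0.0)  # plain quadratic cost
weak = weak_quadratic_cost(x, ys, weights, gamma=1.0)    # variance is rewarded
```

With γ = 0 the function reduces to the strong quadratic cost averaged over µ; a larger γ lowers the cost of conditional distributions with higher variance.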

3. RELATED WORK

In large-scale machine learning, OT costs are primarily used as the loss to learn generative models. Wasserstein GANs introduced by (Arjovsky et al., 2017; Gulrajani et al., 2017) are the most popular examples of this approach. We refer to (Korotin et al., 2022b; 2021b) for recent surveys of principles of WGANs. However, these models are out of scope of our paper since they only compute the OT cost but not OT plans or maps (§4.3). Computing OT plans (or maps) is a more challenging problem, and only a limited number of scalable methods to solve it have been developed.

We overview methods to compute OT plans (or maps) below. We emphasize that existing methods are designed only for the strong OT formulation (2). Most of them search for a deterministic solution (1), i.e., for a map $T^*$ rather than a stochastic plan $\pi^*$, although $T^*$ might not always exist.

To compute the OT plan (map), (Lu et al., 2020; Xie et al., 2019) approach the primal formulation (1) or (2). Their methods imply using generative models and yield complex optimization objectives with several adversarial regularizers, e.g., they are used to enforce the boundary condition ($T_\# \mathbb{P} = \mathbb{Q}$). As a result, the methods are hard to set up since they require careful selection of hyperparameters.

In contrast, methods based on the dual formulation (5) have simpler optimization procedures. Most of such methods are designed for OT with the quadratic cost, i.e., the Wasserstein-2 distance ($W_2^2$). An evaluation of these methods is provided in (Korotin et al., 2021b). Below we mention their issues.

Methods by (Taghvaei & Jalali, 2019; Makkuva et al., 2020; Korotin et al., 2021a; c) based on input-convex neural networks (ICNNs, see (Amos et al., 2017)) have solid theoretical justification, but do not provide sufficient performance in practical large-scale problems. Methods based on entropy-regularized OT (Genevay et al., 2016; Seguy et al., 2017; Daniels et al., 2021) recover a regularized OT plan that is biased from the true one; moreover, it is hard to sample from it or compute its density.

According to (Korotin et al., 2021b), the best performing approach is ⌈MM:R⌋, which is based on the maximin reformulation of (5). It recovers OT maps fairly well and has a good generative performance. The follow-up papers (Rout et al., 2022; Fan et al., 2022) test extensions of this approach for more general strong transport costs $c(\cdot, \cdot)$ and apply it to compute $W_2$ barycenters (Korotin et al., 2022a). Their key limitation is that they aim to recover a deterministic OT map $T^*$ which might not exist.

In this section, we develop a novel neural algorithm to recover a solution $\pi^*$ of the OT problem (3). The following lemma will play an important role in our derivations.

Lemma 1 (Existence of transport maps). Let $\mu$ and $\nu$ be probability distributions on $\mathbb{R}^M$ and $\mathbb{R}^N$. Assume that $\mu$ is atomless. Then there exists a measurable $t : \mathbb{R}^M \to \mathbb{R}^N$ satisfying $t_\# \mu = \nu$.

Proof. (Santambrogio, 2015, Cor. 1.29) proves the fact for $M = N$. The proof works for $M \neq N$.

Throughout the paper we assume that $\mathbb{P}, \mathbb{Q}$ are supported on subsets $\mathcal{X} \subset \mathbb{R}^P$, $\mathcal{Y} \subset \mathbb{R}^Q$, respectively.

4.1. REFORMULATION OF THE DUAL PROBLEM

First, we reformulate the optimization in the $C$-transform (6). For this, we introduce a subset $\mathcal{Z} \subset \mathbb{R}^S$ with an atomless distribution $\mathbb{S}$ on it, e.g., $\mathbb{S} = \text{Uniform}([0, 1])$ or $\mathcal{N}(0, 1)$.

Lemma 2 (Reformulation of the C-transform). The following equality holds:

$$f^C(x) = \inf_t \Big\{ C(x, t_\# \mathbb{S}) - \int_{\mathcal{Z}} f\big(t(z)\big)\, d\mathbb{S}(z) \Big\}, \tag{8}$$

where the infimum is taken over all measurable $t : \mathcal{Z} \to \mathcal{Y}$.

Proof. For all $x \in \mathcal{X}$ and $t : \mathcal{Z} \to \mathcal{Y}$ we have $f^C(x) \leq C(x, t_\# \mathbb{S}) - \int_{\mathcal{Z}} f\big(t(z)\big)\, d\mathbb{S}(z)$. The inequality is straightforward: we substitute $\mu = t_\# \mathbb{S}$ into (6) to upper bound $f^C(x)$ and use the change of variables. Taking the infimum over $t$, we obtain

$$f^C(x) \leq \inf_t \Big\{ C(x, t_\# \mathbb{S}) - \int_{\mathcal{Z}} f\big(t(z)\big)\, d\mathbb{S}(z) \Big\}. \tag{9}$$

Now let us turn (9) into an equality. We need to show that for every $\epsilon > 0$ there exists $t^\epsilon : \mathcal{Z} \to \mathcal{Y}$ satisfying

$$f^C(x) + \epsilon \geq C(x, t^\epsilon_\# \mathbb{S}) - \int_{\mathcal{Z}} f\big(t^\epsilon(z)\big)\, d\mathbb{S}(z). \tag{10}$$

By (6) and the definition of $\inf$, there exists $\mu^\epsilon \in \mathcal{P}(\mathcal{Y})$ such that $f^C(x) + \epsilon \geq C(x, \mu^\epsilon) - \int_{\mathcal{Y}} f(y)\, d\mu^\epsilon(y)$. Thanks to Lemma 1, there exists $t^\epsilon : \mathcal{Z} \to \mathcal{Y}$ such that $\mu^\epsilon = t^\epsilon_\# \mathbb{S}$, i.e., (10) holds true.

Now we use Lemma 2 to get an analogous reformulation of the integral of $f^C$ in the dual form (5).

Lemma 3 (Reformulation of the integrated C-transform). The following equality holds:

$$\int_{\mathcal{X}} f^C(x)\, d\mathbb{P}(x) = \inf_T \int_{\mathcal{X}} \Big( C\big(x, T(x, \cdot)_\# \mathbb{S}\big) - \int_{\mathcal{Z}} f\big(T(x, z)\big)\, d\mathbb{S}(z) \Big)\, d\mathbb{P}(x), \tag{11}$$

where the inner minimization is performed over all measurable functions $T : \mathcal{X} \times \mathcal{Z} \to \mathcal{Y}$.

Proof. The lemma follows from the interchange between the infimum and the integral provided by Rockafellar's interchange theorem (Rockafellar, 1976, Theorem 3A). The theorem states that for a function $F : A \times B \to \mathbb{R}$ and a distribution $\nu$ on $A$,

$$\int_A \inf_{b \in B} F(a, b)\, d\nu(a) = \inf_{H : A \to B} \int_A F\big(a, H(a)\big)\, d\nu(a). \tag{12}$$

We apply (12) with $A = \mathcal{X}$, $\nu = \mathbb{P}$, $B$ being the space of measurable functions $\mathcal{Z} \to \mathcal{Y}$, and $F(a, b) = C(a, b_\# \mathbb{S}) - \int_{\mathcal{Y}} f(y)\, d[b_\# \mathbb{S}](y)$. Consequently, we obtain that $\int_{\mathcal{X}} f^C(x)\, d\mathbb{P}(x)$ equals

$$\inf_H \int_{\mathcal{X}} \Big( C\big(x, H(x)_\# \mathbb{S}\big) - \int_{\mathcal{Y}} f(y)\, d[H(x)_\# \mathbb{S}](y) \Big)\, d\mathbb{P}(x). \tag{13}$$

Finally, we note that the optimization over functions $H : \mathcal{X} \to \{t : \mathcal{Z} \to \mathcal{Y}\}$ equals the optimization over functions $T : \mathcal{X} \times \mathcal{Z} \to \mathcal{Y}$. We put $T(x, z) = [H(x)](z)$, use the change of variables for $y = T(x, z)$ and derive (11) from (13).

Lemma 3 provides the way to represent the dual form (5) as a saddle point optimization problem.

Corollary 1 (Maximin reformulation of the dual problem). The following holds:

$$\text{Cost}(\mathbb{P}, \mathbb{Q}) = \sup_f \inf_T \mathcal{L}(f, T), \tag{14}$$

where the functional $\mathcal{L}$ is defined by

$$\mathcal{L}(f, T) \stackrel{\text{def}}{=} \int_{\mathcal{Y}} f(y)\, d\mathbb{Q}(y) + \int_{\mathcal{X}} \Big( C\big(x, T(x, \cdot)_\# \mathbb{S}\big) - \int_{\mathcal{Z}} f\big(T(x, z)\big)\, d\mathbb{S}(z) \Big)\, d\mathbb{P}(x). \tag{15}$$

Proof. It suffices to substitute (11) into (5).

We say that functions $T : \mathcal{X} \times \mathcal{Z} \to \mathcal{Y}$ are stochastic maps. If a map $T$ is independent of $z$, i.e., for all $(x, z) \in \mathcal{X} \times \mathcal{Z}$ we have $T(x, z) \equiv T(x)$, we say the map is deterministic.

The idea behind the introduced notation is the following. An optimal transport plan $\pi^*$ might be nondeterministic, i.e., there might not exist a deterministic function $T : \mathcal{X} \to \mathcal{Y}$ which satisfies $\pi^* = [\text{id}_{\mathcal{X}}, T]_\# \mathbb{P}$. However, each transport plan $\pi \in \Pi(\mathbb{P}, \mathbb{Q})$ can be represented implicitly through a stochastic function $T : \mathcal{X} \times \mathcal{Z} \to \mathcal{Y}$. This fact is known as noise outsourcing (Kallenberg, 1997, Theorem 5.10) for $\mathcal{Z} = [0, 1] \subset \mathbb{R}^1$ and $\mathbb{S} = \text{Uniform}([0, 1])$. Combined with Lemma 1, the noise outsourcing also holds for a general $\mathcal{Z} \subset \mathbb{R}^S$ and atomless $\mathbb{S} \in \mathcal{P}(\mathcal{Z})$. We visualize the idea in Figure 4. For a plan $\pi$, there might exist multiple maps $T$ which represent it.

For a pair of probability distributions $\mathbb{P}, \mathbb{Q}$, we say that $T^*$ is a stochastic optimal transport map if it realizes some optimal transport plan $\pi^*$. Such maps solve the inner problem in (14) for optimal $f^*$.

Lemma 4 (Optimal maps solve the maximin problem). For any maximizer $f^*$ of (5) and for any stochastic map $T^*$ which realizes some optimal transport plan $\pi^*$, it holds that

$$T^* \in \arg\inf_T \mathcal{L}(f^*, T). \tag{16}$$

Proof. Let $\pi^*$ be the OT plan realized by $T^*$. We derive

$$\int_{\mathcal{X}} \int_{\mathcal{Z}} f^*\big(T^*(x, z)\big)\, d\mathbb{S}(z)\, d\mathbb{P}(x) = \int_{\mathcal{X}} \int_{\mathcal{Y}} f^*(y)\, d\pi^*(y|x)\, d\pi^*(x) = \int_{\mathcal{X} \times \mathcal{Y}} f^*(y)\, d\pi^*(x, y) = \int_{\mathcal{Y}} f^*(y)\, d\mathbb{Q}(y), \tag{17}$$

where we change the variables for $y = T^*(x, z)$ and use the property $d\pi^*(x) = d\mathbb{P}(x)$. Now assume that $T^* \notin \arg\inf_T \mathcal{L}(f^*, T)$. In this case, from the definition (15) we conclude that $\mathcal{L}(f^*, T^*) > \text{Cost}(\mathbb{P}, \mathbb{Q})$. However, substituting (17) into (15), we see that

$$\mathcal{L}(f^*, T^*) = \int_{\mathcal{X}} C\big(x, \underbrace{T^*(x, \cdot)_\# \mathbb{S}}_{\pi^*(\cdot|x)}\big)\, \underbrace{d\mathbb{P}(x)}_{d\pi^*(x)} = \text{Cost}(\mathbb{P}, \mathbb{Q}),$$

which is a contradiction. Thus, (16) holds true.

For the $\gamma$-weak quadratic cost (4) which we use in the experiments (§5), a maximizer $f^*$ of (5) indeed exists, see (Alibert et al., 2019, §5.22) or (Gozlan & Juillet, 2020). Thanks to our Lemma 4, one may solve the saddle point problem (14) and extract an optimal stochastic transport map $T^*$ from its solution $(f^*, T^*)$. In general, the $\arg\inf$ set for $f^*$ may contain not only the optimal stochastic transport maps but other stochastic functions as well. In Appendix F, we show that for strictly convex (in $\mu$) costs $C(x, \mu)$, all the solutions of (14) provide stochastic OT maps.
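The noise outsourcing idea can be illustrated with a tiny sketch: a conditional distribution with two atoms is realized by a function of x and a uniform noise variable z. The map `stochastic_map`, its atoms and its weights are hypothetical choices made only for this illustration.

```python
def stochastic_map(x, z):
    """Toy stochastic map T(x, z) with z in [0, 1).

    Assumption (illustrative only): the conditional plan is
    pi(.|x) = 0.25 * delta_{x - 1} + 0.75 * delta_{x + 1},
    realized by thresholding the uniform noise z (an inverse-CDF construction).
    """
    return x - 1.0 if z < 0.25 else x + 1.0


# A deterministic midpoint grid on [0, 1) stands in for samples z ~ Uniform([0, 1]).
n = 1000
zs = [(i + 0.5) / n for i in range(n)]

outputs = [stochastic_map(2.0, z) for z in zs]
frac_left = outputs.count(1.0) / n   # mass sent to x - 1
frac_right = outputs.count(3.0) / n  # mass sent to x + 1
```

A deterministic map is the special case in which the output ignores z; here the dependence on z is exactly what lets a single input x be split between two targets.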

4.2. PRACTICAL OPTIMIZATION PROCEDURE

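To approach the problem (14) in practice, we use neural networks $T_\theta : \mathbb{R}^P \times \mathbb{R}^S \to \mathbb{R}^Q$ and $f_\omega : \mathbb{R}^Q \to \mathbb{R}$ to parameterize $T$ and $f$, respectively. We train their parameters with the stochastic gradient ascent-descent (SGAD) by using random batches from $\mathbb{P}, \mathbb{Q}, \mathbb{S}$, see Algorithm 1.

Algorithm 1: Neural optimal transport (NOT)
Input: distributions $\mathbb{P}, \mathbb{Q}, \mathbb{S}$ accessible by samples; mapping network $T_\theta : \mathbb{R}^P \times \mathbb{R}^S \to \mathbb{R}^Q$; potential network $f_\omega : \mathbb{R}^Q \to \mathbb{R}$; number of inner iterations $K_T$; (weak) cost $C : \mathcal{X} \times \mathcal{P}(\mathcal{Y}) \to \mathbb{R}$; empirical estimator $\hat{C}\big(x, T(x, Z)\big)$ for the cost;
Output: learned stochastic OT map $T_\theta$ representing an OT plan between distributions $\mathbb{P}, \mathbb{Q}$;
repeat
    Sample batches $Y \sim \mathbb{Q}$, $X \sim \mathbb{P}$; for each $x \in X$ sample batch $Z_x \sim \mathbb{S}$;
    $\mathcal{L}_f \leftarrow \frac{1}{|X|} \sum_{x \in X} \frac{1}{|Z_x|} \sum_{z \in Z_x} f_\omega\big(T_\theta(x, z)\big) - \frac{1}{|Y|} \sum_{y \in Y} f_\omega(y)$;
    Update $\omega$ by using $\frac{\partial \mathcal{L}_f}{\partial \omega}$;
    for $k_T = 1, 2, \dots, K_T$ do
        Sample batch $X \sim \mathbb{P}$; for each $x \in X$ sample batch $Z_x \sim \mathbb{S}$;
        $\mathcal{L}_T \leftarrow \frac{1}{|X|} \sum_{x \in X} \Big[ \hat{C}\big(x, T_\theta(x, Z_x)\big) - \frac{1}{|Z_x|} \sum_{z \in Z_x} f_\omega\big(T_\theta(x, z)\big) \Big]$;
        Update $\theta$ by using $\frac{\partial \mathcal{L}_T}{\partial \theta}$;
until not converged;

Our Algorithm 1 requires an empirical estimator $\hat{C}$ for $C\big(x, T(x, \cdot)_\# \mathbb{S}\big)$. If the cost is strong, it is straightforward to use the following unbiased Monte-Carlo estimator from a random batch $Z \sim \mathbb{S}$:

$$C\big(x, T(x, \cdot)_\# \mathbb{S}\big) = \int_{\mathcal{Z}} c\big(x, T(x, z)\big)\, d\mathbb{S}(z) \approx \frac{1}{|Z|} \sum_{z \in Z} c\big(x, T(x, z)\big) \stackrel{\text{def}}{=} \hat{C}\big(x, T(x, Z)\big). \tag{18}$$

For general costs $C$, providing an estimator might be nontrivial. For the $\gamma$-weak quadratic cost (4), such an unbiased Monte-Carlo estimator is straightforward to derive:

$$\hat{C}\big(x, T(x, Z)\big) \stackrel{\text{def}}{=} \frac{1}{2|Z|} \sum_{z \in Z} \|x - T(x, z)\|^2 - \frac{\gamma}{2} \hat{\sigma}^2, \tag{19}$$

where $\hat{\sigma}^2$ is the (corrected) batch variance

$$\hat{\sigma}^2 = \frac{1}{|Z| - 1} \sum_{z \in Z} \Big\| T(x, z) - \frac{1}{|Z|} \sum_{z' \in Z} T(x, z') \Big\|^2.$$

To estimate strong costs (18), it is enough to sample a single noise vector ($|Z| = 1$). To estimate the $\gamma$-weak quadratic cost (19), one needs $|Z| \geq 2$ since the estimation of the variance $\hat{\sigma}^2$ is needed.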
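The two Monte-Carlo estimators (18) and (19) can be sketched in plain Python as follows. This is only an illustration for a single input x, assuming the quadratic strong cost; the function names and the list-based representation of the batch of outputs T(x, z) are assumptions, not the authors' implementation.

```python
def sq_dist(a, b):
    """Squared Euclidean distance between two vectors given as lists."""
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))


def strong_cost_estimate(x, outputs):
    """Estimator (18) for c(x, y) = 0.5 * ||x - y||^2.

    `outputs` holds T(x, z) for every z in the noise batch Z.
    """
    return sum(0.5 * sq_dist(x, y) for y in outputs) / len(outputs)


def weak_cost_estimate(x, outputs, gamma):
    """Estimator (19): quadratic term minus gamma/2 times the corrected variance."""
    n = len(outputs)
    if n < 2:
        raise ValueError("the corrected batch variance needs |Z| >= 2")
    dim = len(x)
    mean = [sum(y[d] for y in outputs) / n for d in range(dim)]
    # Corrected (unbiased) batch variance with the 1 / (|Z| - 1) factor.
    variance = sum(sq_dist(y, mean) for y in outputs) / (n - 1)
    return strong_cost_estimate(x, outputs) - 0.5 * gamma * variance


# Hypothetical batch: two outputs T(x, z1), T(x, z2) for x at the origin.
x = [0.0, 0.0]
outputs = [[1.0, 0.0], [3.0, 0.0]]
strong_value = strong_cost_estimate(x, outputs)
weak_value = weak_cost_estimate(x, outputs, gamma=1.0)
```

In a training loop these values would be averaged over the batch X and combined with the potential term to form the loss for the map, as in Algorithm 1.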

4.3. RELATION TO PRIOR WORKS

Generative adversarial learning. Our Algorithm 1 is a novel approach to learn stochastic OT plans; it is not a GAN or WGAN-based solution endowed with additional losses such as the OT cost. WGANs (Arjovsky et al., 2017) do not learn an OT plan but use the (strong) OT cost as the loss to learn the generator network. Their problem is $\inf_T \sup_f \mathcal{V}(T, f)$. The generator $T^*$ solves the outer $\inf_T$ problem and is the first coordinate of an optimal saddle point $(T^*, f^*)$. In our Algorithm 1, problem (15) is $\sup_f \inf_T \mathcal{L}(f, T)$; the generator (transport map) $T^*$ solves the inner $\inf_T$ problem and is the second coordinate of an optimal saddle point $(f^*, T^*)$. Intuitively, in our case the generator $T$ is adversarial to the potential $f$ (discriminator), not vice versa as in GANs. Theoretically, the problem is also significantly different: swapping $\inf_T$ and $\sup_f$, in general, yields a different problem with different solutions, e.g., $1 = \inf_x \sup_y \sin(x + y) \neq \sup_y \inf_x \sin(x + y) = -1$. Practically, we do $K_T > 1$ updates of $T$ per one step of $f$, which again differs from common GAN practices, where multiple updates of $f$ are done per step of $T$. Finally, in contrast to WGANs, we do not need to enforce any constraints on $f$, e.g., the 1-Lipschitz continuity.

Stochastic generator parameterization. We add an additional noise input $z$ to the transport map $T(x, z)$ to make it stochastic. This approach is a common technical instrument to parameterize one-to-many mappings in generative modeling, see (Almahairi et al., 2018, §3.1) or (Zhu et al., 2017b, §3). In the context of OT, (Yang & Uhler, 2019) employ a stochastic generator to learn a transport plan $\pi$ in the unbalanced OT problem (Chizat, 2017). Due to this, their optimization objective slightly resembles ours (15). However, this similarity is deceptive, see Appendix G.

Dual OT solvers. Our Algorithm 1 recovers stochastic plans for weak costs (3). It subsumes previously known approaches which learn deterministic OT maps for strong costs (2). When the cost is strong (2) and the transport map $T$ is restricted to be deterministic, $T(x, z) \equiv T(x)$, our Algorithm 1 yields the maximin method ⌈MM:R⌋, which was discussed in (Korotin et al., 2021b, §2) for the quadratic cost $\frac{1}{2}\|x - y\|^2$ and further developed by (Rout et al., 2022) for the $Q$-embedded cost $-\langle Q(x), y \rangle$ and by (Fan et al., 2022) for other strong costs $c(x, y)$. These works are the most related to our study.
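The claim that swapping inf and sup changes the problem can be checked numerically for the sin(x + y) example. The sketch below evaluates both orders on a uniform grid over one period; the grid size is an arbitrary illustrative choice.

```python
import math

# Uniform grid over one period [0, 2*pi); 360 points is an arbitrary choice.
n = 360
grid = [2.0 * math.pi * i / n for i in range(n)]

# inf over x of sup over y: for every x some y makes sin(x + y) equal to 1.
inf_sup = min(max(math.sin(x + y) for y in grid) for x in grid)

# sup over y of inf over x: for every y some x makes sin(x + y) equal to -1.
sup_inf = max(min(math.sin(x + y) for x in grid) for y in grid)
```

On this grid the two values come out close to 1 and -1 respectively, so the order of the optimization matters even for this very simple objective.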

4.4. UNIVERSAL APPROXIMATION WITH NEURAL NETWORKS

In this section, we show that it is possible to approximate transport maps with neural nets.

Theorem 1 (Neural networks are universal approximators of stochastic transport maps). Assume that $\mathcal{X}, \mathcal{Z}$ are compact and $\mathbb{Q}$ has finite second moment. Let $T$ be a stochastic map from $\mathbb{P}$ to $\mathbb{Q}$ (not necessarily optimal). Then for any nonaffine continuous activation function which is continuously differentiable at at least one point (with nonzero derivative at that point) and for any $\epsilon > 0$, there exists a neural network $T_\theta : \mathbb{R}^P \times \mathbb{R}^S \to \mathbb{R}^Q$ satisfying

$$\|T_\theta - T\|^2_{L^2} \leq \epsilon \quad \text{and} \quad W_2^2\big((T_\theta)_\#(\mathbb{P} \times \mathbb{S}), \mathbb{Q}\big) \leq \epsilon, \tag{20}$$

where $L^2 = L^2(\mathbb{P} \times \mathbb{S}, \mathcal{X} \times \mathcal{Z} \to \mathbb{R}^Q)$ is the space of functions $\mathcal{X} \times \mathcal{Z} \to \mathbb{R}^Q$ that are quadratically integrable w.r.t. $\mathbb{P} \times \mathbb{S}$. That is, the network $T_\theta$ generates a distribution which is $\epsilon$-close to $\mathbb{Q}$ in $W_2^2$.

Proof. The squared norm $\|T\|^2_{L^2}$ is equal to the second moment of $\mathbb{Q}$ since $T$ pushes $\mathbb{P} \times \mathbb{S}$ to $\mathbb{Q}$. The distribution $\mathbb{Q}$ has finite second moment, and, consequently, $T \in L^2$. Thanks to (Folland, 1999, Proposition 7.9), the continuous functions $C^0(\mathcal{X} \times \mathcal{Z} \to \mathbb{R}^Q)$ are dense in $L^2$. According to (Kidger & Lyons, 2020, Theorem 3.2), the neural networks $\mathbb{R}^P \times \mathbb{R}^S \to \mathbb{R}^Q$ with the above-mentioned activations are dense in $C^0(\mathcal{X} \times \mathcal{Z} \to \mathbb{R}^Q)$ w.r.t. the $L^\infty$ norm and, consequently, w.r.t. the $L^2$ norm. Combining these results yields that neural nets are dense in $L^2$, and for every $\epsilon > 0$ there necessarily exists a network $T_\theta$ satisfying the left inequality in (20). For $T_\theta$, the right inequality follows from (Korotin et al., 2021a, Lemma A.2).

Our Theorem 1 states that neural nets can approximate stochastic maps in the $L^2$ norm. It should be taken into account that such continuous nets $T_\theta$ may be highly irregular and hard to learn in practice.

5. EVALUATION

We compare our method with the weak discrete OT (considered as the ground truth) on toy 2D and 1D distributions in Appendices B and C, respectively. In this section, we test our algorithm on an unpaired image-to-image translation task. We compare it with popular existing translation methods in Appendix D. The code is written in the PyTorch framework and is publicly available at https://github.com/iamalexkorotin/NeuralOptimalTransport.

Image datasets. We use the following publicly available datasets as $\mathbb{P}, \mathbb{Q}$: aligned anime faces, celebrity faces (Liu et al., 2015), shoes (Yu & Grauman, 2014), Amazon handbags, churches from the LSUN dataset (Yu et al., 2015), outdoor images from the MIT places database (Zhou et al., 2014). The size of the datasets varies from 50K to 500K images.

Train-test split. We pick 90% of each dataset for unpaired training. The remaining 10% are used as the test set. All the results presented here are exclusively for test images, i.e., unseen data.

Transport costs. We experiment with the strong ($\gamma = 0$) and $\gamma$-weak ($\gamma > 0$) quadratic costs. Testing other costs, e.g., perceptual (Johnson et al., 2016) or semantic (Cherian & Sullivan, 2019), might be interesting practically, but these two quadratic costs already provide promising performance. The other training details are given in Appendix E.

5.1. PRELIMINARY EVALUATION

In the preliminary experiments with the strong cost ($\gamma = 0$), we noted that $T(x, z)$ becomes independent of $z$. For a fixed potential $f$ and a point $x$, the map $T(x, \cdot)$ learns to be the map pushing distribution $\mathbb{S}$ to some $\arg\inf$ distribution $\mu$ of (6). For strong costs, there are suitable degenerate distributions $\mu$, see the discussion around (7). Thus, for $T$ it becomes unnecessary to keep any dependence on $z$; it simply learns a deterministic map $T(x, z) = T(x)$. We call this behavior a conditional collapse.

Importantly, for the $\gamma$-weak cost ($\gamma > 0$), we noted a different behavior: the stochastic map $T(x, z)$ did not collapse conditionally. To explain this, we substitute (4) into (3) to obtain

$$W_{2,\gamma}^2(\mathbb{P}, \mathbb{Q}) = \inf_{\pi \in \Pi(\mathbb{P}, \mathbb{Q})} \Big[ \int_{\mathcal{X} \times \mathcal{Y}} \frac{1}{2}\|x - y\|^2\, d\pi(x, y) - \gamma \cdot \int_{\mathcal{X}} \frac{1}{2} \text{Var}\big(\pi(y|x)\big)\, \underbrace{d\pi(x)}_{d\mathbb{P}(x)} \Big].$$

The first term is analogous to the strong cost ($W_2 = W_{2,0}$), while the additional second term stimulates the OT plan to be stochastic, i.e., to have high conditional variance.

Taking into account our preliminary findings, we perform two types of experiments. In §5.2, we learn deterministic (one-to-one) translation maps $T(x)$ for the strong cost ($\gamma = 0$), i.e., we do not add the $z$-channel. In §5.3, we learn stochastic (one-to-many) maps $T(x, z)$ for the $\gamma$-weak cost ($\gamma > 0$). For completeness, in Appendix A, we study how varying $\gamma$ affects the diversity of samples.

5.2. ONE-TO-ONE TRANSLATION WITH OPTIMAL MAPS

We learn deterministic OT maps between various pairs of datasets. We provide the results in Figures 1a and 5. Extra results for all the dataset pairs that we consider are given in Appendix H.

Being optimal, our translation map $T(x)$ tries to minimally change the image content $x$ in the $L^2$ pixel space. This results in preserving certain features during translation. In shoes ↔ handbags (Figures 5b, 5a), the image color and texture of the pushforward samples reflect those of the input samples. In celeba (female) ↔ anime (Figures 1a, 5c, 5d), head forms and hairstyles are mostly similar for input and output images. The hair in anime is usually bigger than that in celeba. Thus, when translating celeba (female) ↔ anime, the anime hair inherits the color from the celebrity image background. In outdoor → churches (Figure 1a), the ground and the sky are preserved, and in celeba (male) → celeba (female) (Figure 5e), the face does not change. We also provide results for translation in the case when the input and output domains are significantly different, see anime → shoes (Figure 5f).

Related work. Existing unpaired translation models, e.g., CycleGAN (Zhu et al., 2017a) or UNIT (Liu et al., 2017), typically have complex adversarial optimization objectives endowed with additional losses. These models require simultaneous optimization of several neural networks. Importantly, vanilla CycleGAN searches for a random translation map and is not capable of preserving certain attributes, e.g., the color, see (Lu et al., 2019, Figure 5b). To handle this issue, imposing extra losses is required (Benaim & Wolf, 2017; Kim et al., 2017), which further complicates the hyperparameter selection. In contrast, our approach has a straightforward objective (14); we use only 2 networks (potential $f$, map $T$), see Table 2 for the comparison of hyperparameters.
While the majority of existing unpaired translation models are based on GANs, recent work (Su et al., 2023) proposes a diffusion model (DDIBs) and relates it to Schrödinger Bridge (Léonard, 2014) , i.e., entropic OT.

5.3. ONE-TO-MANY TRANSLATION WITH OPTIMAL PLANS

We learn stochastic OT maps between various pairs of datasets for the $\gamma$-weak quadratic cost. The parameter $\gamma$ equals $\frac{2}{3}$ or $1$ in the experiments. We provide the results in Figures 1b and 6.

Related work. Transforming a one-to-one learning pipeline to one-to-many is nontrivial. Simply adding additional noise input leads to conditional collapse (Zhang, 2018). This is resolved by AugCycleGAN (Almahairi et al., 2018) and M-UNIT (Huang et al., 2018), but their optimization objectives are much more complicated than those of the vanilla versions. Our method optimizes only 2 nets $f, T$ in the straightforward objective (14). It offers a single parameter $\gamma$ to control the amount of variability in the learned maps. We refer to Table 2 for the comparison of hyperparameters of the methods.

6. DISCUSSION

Potential impact. Our method is a novel generic tool to align probability distributions with deterministic and stochastic transport maps. Besides unpaired translation, we expect our approach to be applied to other one-to-one and one-to-many unpaired learning tasks as well (image restoration, domain adaptation, etc.) and to improve existing models in those fields. Compared to the popular models based on GANs (Goodfellow et al., 2014) or diffusion models (Ho et al., 2020), our method provides better interpretability of the learned map and allows controlling the amount of diversity in generated samples (Appendix A). It should be taken into account that the OT maps we learn might not be suitable for all unpaired tasks. We mark designing task-specific transport costs as a promising research direction.

Limitations. Our method searches for a solution $(f^*, T^*)$ of a saddle point problem (14) and extracts the stochastic OT map $T^*$ from it. We highlight after Lemma 4 and in §5.1 that not all $T^*$ are optimal stochastic OT maps. For strong costs, the issue leads to the conditional collapse. Studying saddle points of (14) and $\arg\inf$ sets (16) is an important challenge to address in further research.

Potential societal impact. Our developed method is at the junction of optimal transport and generative learning. In practice, generative models and optimal transport are widely used in entertainment (image-manipulation applications like adding masks to images, hair coloring, etc.), design, computer graphics, rendering, etc. Our method is potentially applicable to many problems appearing in the mentioned industries. While the mentioned applications allow making image processing methods publicly available, a potential negative is that they might transform some jobs in the graphics industry.

Reproducibility. We provide the source code for all experiments and release the checkpoints for all models of §5. The details are given in README.MD in the official repository.
situation is even worsened by the nonuniqueness of $\pi^*$. To cope with this issue, we consider the weak quadratic cost with $\gamma = 1$. For this cost, one may derive

$$C(x, \mu) = \int_{\mathcal{Y}} \frac{1}{2}\|x - y\|^2\, d\mu(y) - \frac{1}{2} \text{Var}(\mu) = \frac{1}{2} \Big\| x - \int_{\mathcal{Y}} y\, d\mu(y) \Big\|^2. \tag{21}$$

For cost (21) and a pair $\mathbb{P}, \mathbb{Q}$, (Gozlan & Juillet, 2020, Theorem 1.2) states that there exists a $\mathbb{P}$-unique (up to a constant) convex $\psi : \mathbb{R}^P \to \mathbb{R}$ such that every OT plan $\pi^*$ satisfies $\nabla\psi(x) = \int_{\mathcal{Y}} y\, d\pi^*(y|x)$. Besides, $\nabla\psi : \mathbb{R}^P \to \mathbb{R}^P$ is 1-Lipschitz. Let $\hat{T}(x, z)$ be the stochastic map recovered by our Algorithm 1, and let $\hat{\pi}$ be the corresponding plan. Let

$$\overline{T}(x) \stackrel{\text{def}}{=} \int_{\mathcal{Y}} y\, d\hat{\pi}(y|x) = \int_{\mathcal{Z}} \hat{T}(x, z)\, d\mathbb{S}(z). \tag{22}$$

Due to the above-mentioned characterization of OT plans, $\overline{T}(x)$ should look like a gradient $\nabla\psi(x)$ of some convex function $\psi(x)$ and should nearly be a contraction. Since here we work in the 2D space, we are able to get sufficiently many samples from $\mathbb{P}$ and $\mathbb{Q}$ and obtain a fine approximation of an OT plan $\pi^*$ and $\nabla\psi$ by a discrete weak OT solver. We sample random batches $X \sim \mathbb{P}$ and $Y \sim \mathbb{Q}$ of size $2^{10}$ and use ot.weak from the POT library to get some optimal $\pi^*$ and $\nabla\psi = \int_{\mathcal{Y}} y\, d\pi^*(y|x)$. We are going to compare our recovered average map $\overline{T}$ with $\nabla\psi$.

Datasets. We test 2 pairs $\mathbb{P}, \mathbb{Q}$: Gaussian → Mixture of 8 Gaussians; Gaussian → Swiss roll.

Neural Networks. We use multi-layer perceptrons as $f_\omega$, $T_\theta$ with 3 hidden layers of 100 neurons and ReLU nonlinearity. The input of the stochastic map $T_\theta(x, z)$ is $2 + 2 = 4$ dimensional. The first two dimensions represent the input $x \in \mathbb{R}^2$ while the other dimensions represent the noise $z \sim \mathbb{S}$. We employ Gaussian noise with $\sigma = 0.1$.

Discussion. We provide qualitative results in Figures 8 and 9. In both cases, the pushforward distribution $\hat{T}_\#(\mathbb{P} \times \mathbb{S})$ matches the desired target distribution $\mathbb{Q}$ (Figures 8c and 9c). Note that $\overline{T}$ indeed roughly equals a gradient of a convex function. The gradients of convex functions are cycle monotone (Rockafellar, 1966). Cycle monotonicity yields that for $x_1 \neq x_2$ the segments $[x_1, \nabla\psi(x_1)]$ and $[x_2, \nabla\psi(x_2)]$ do not intersect in the inner points (Villani, 2008, §8). Visually, we see that in Figures 8d and 9d the segments $[x, \overline{T}(x)]$ do not intersect for different $x$, which is the expected behavior.
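The average map (22) can be approximated by averaging the outputs of a stochastic map over noise samples, and for γ = 1 the cost (21) then depends only on this average. The sketch below illustrates this with a hypothetical hand-written map `toy_map`; it is not a trained network and the noise values are arbitrary.

```python
def average_map(stochastic_map, x, noise_samples):
    """Monte-Carlo estimate of the average map (22) at a single point x."""
    outputs = [stochastic_map(x, z) for z in noise_samples]
    dim = len(outputs[0])
    return [sum(o[d] for o in outputs) / len(outputs) for d in range(dim)]


def toy_map(x, z):
    """Hypothetical stochastic map: contracts x by one half and adds the noise z."""
    return [0.5 * xi + zi for xi, zi in zip(x, z)]


# Two symmetric noise vectors (arbitrary), so the noise averages out to zero.
noise = [[-0.1, 0.2], [0.1, -0.2]]
x = [2.0, 4.0]

t_bar = average_map(toy_map, x, noise)
# Cost (21) for gamma = 1: half the squared distance from x to the average output.
cost_gamma_one = 0.5 * sum((xi - ti) ** 2 for xi, ti in zip(x, t_bar))
```

For a learned map one would evaluate `average_map` on a grid of inputs and compare the result with the barycentric projection returned by a discrete weak OT solver.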

C TOY 1D EXPERIMENTS

In this section, we additionally test our Algorithm 1 on toy 1D distributions $\mathbb{P}, \mathbb{Q}$, i.e., the dimensions are $P = Q = 1$. In this case, transport plans are 2D distributions and can be conveniently visualized. We experiment with the 1-weak quadratic cost (21). Following the discussion in the previous section, we recall that an OT plan $\pi^*$ may be not unique. However, all OT plans satisfy $\nabla\psi(x) = \int_{\mathcal{Y}} y\, d\pi^*(y|x)$ for some 1-smooth convex function $\psi : \mathbb{R} \to \mathbb{R}$. This simply means that $x \mapsto \nabla\psi(x) = \int_{\mathcal{Y}} y\, d\pi^*(y|x)$ is a monotone increasing 1-Lipschitz function $\nabla\psi : \mathbb{R} \to \mathbb{R}$. Below we check that this necessary condition holds for $\overline{T}$ (22), where $\hat{T}$ is our learned stochastic map.

Datasets. We test 2 pairs $\mathbb{P}, \mathbb{Q}$: Gaussian → Mix of 2 Gaussians; Gaussian → Mix of 3 Gaussians.

Neural Networks. We use the same networks as in Appendix B. This time, the input of the stochastic map $T_\theta(x, z)$ is $1 + 1 = 2$ dimensional, and the input to $f_\omega$ is 1-dimensional.

Discussion. We provide qualitative results in Figures 10 and 11. For each case, we plot the results of 3 random restarts of our method ($\hat{\pi}$ denotes our learned OT plan). Similarly to Appendix B, we plot the results obtained by a discrete weak OT solver (ot.weak from the POT library). Namely, in Figures 10e, 11e we show its results obtained for 4 restarts with differing seeds. Note that the average maps $\overline{T}$ computed by our algorithm in both cases nearly match those computed by the discrete weak OT. This indicates that the transport cost of our computed plan $\hat{\pi}$ satisfies

$$[\text{Cost of } \hat{\pi}] = \int_{\mathcal{X}} \frac{1}{2} \big\| x - \underbrace{\overline{T}(x)}_{\approx \nabla\psi(x)} \big\|^2\, d\mathbb{P}(x) \approx \int_{\mathcal{X}} \frac{1}{2} \|x - \nabla\psi(x)\|^2\, d\mathbb{P}(x) = \text{Cost}(\mathbb{P}, \mathbb{Q}),$$

i.e., it nearly equals the optimal cost. Here we use $\overline{T}(x) \approx \nabla\psi(x)$ observed from the experiments. To conclude, we see that the recovered plans are close to the discrete OT (DOT) considered as the ground truth.
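The necessary condition discussed above, that the average map is monotone increasing and 1-Lipschitz in 1D, can be checked on sampled points with a short helper. The function name, the tolerance and the toy maps are illustrative assumptions; checking consecutive sorted points suffices in 1D because the increments add up along the line.

```python
def is_monotone_1_lipschitz(xs, ts, tol=1e-9):
    """Check 0 <= t(x1) - t(x0) <= x1 - x0 for consecutive sorted points.

    `xs` are 1D inputs and `ts` the corresponding values of the average map.
    """
    pairs = sorted(zip(xs, ts))
    for (x0, t0), (x1, t1) in zip(pairs, pairs[1:]):
        increment = t1 - t0
        if increment < -tol or increment > (x1 - x0) + tol:
            return False
    return True


# Hypothetical evaluation grid on [-2, 2] and three toy maps.
xs = [i / 10.0 for i in range(-20, 21)]
contraction_ok = is_monotone_1_lipschitz(xs, [0.5 * x for x in xs])
expansion_ok = is_monotone_1_lipschitz(xs, [2.0 * x for x in xs])
decreasing_ok = is_monotone_1_lipschitz(xs, [-x for x in xs])
```

The contraction passes the check, while the expanding and the decreasing maps violate the Lipschitz and the monotonicity condition, respectively.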

D COMPARISON WITH PRINCIPAL UNPAIRED TRANSLATION METHODS

We compare our Algorithm 1 with popular models for unpaired translation. We consider handbags → shoes (64 × 64), celeba male → female (64 × 64), outdoor → church (128 × 128) translation. For quantitative comparison, we compute the Frechet Inception Distance (Heusel et al., 2017, FID).

Methods. We compare our method with one-to-one CycleGAN (Zhu et al., 2017a), DiscoGAN (Kim et al., 2017) and with one-to-many AugCycleGAN (Almahairi et al., 2018) and MUNIT (Huang et al., 2018). We use the official or community implementations with the hyperparameters from the respective papers. We choose the above-mentioned methods for comparison because they are principal methods for one-to-one and one-to-many translation. Recent methods (GMM-UNIT (Liu et al., 2020), COCO-FUNIT (Saito et al., 2020), StarGAN (Choi et al., 2020)) are based on them and focus on specific details/setups such as style/content separation, few-shot learning, disentanglement, multi-domain transfer, which are out of scope of our paper.

Discussion. Existing one-to-one methods visually preserve the style during translation comparably to our method. Alternative one-to-many methods do not preserve the style at all. When the input and output domains are similar (handbags → shoes, celeba male → female), the FID scores of all the models are comparable. However, most models are outperformed by NOT when the domains are distant (outdoor → church), see Figure 14 and the last row in Table 1. For completeness, in Table 2 we compare the number of hyperparameters of the translation methods in view. Note that in contrast to the other methods, we optimize only 2 neural networks: the transport map and the potential.

Table 2: Comparison of the number of hyperparameters of the optimization objectives, the number of networks and their parameters for the considered unpaired translation methods for 64 × 64 images.

E EXPERIMENTAL DETAILS

Pre-processing. We first rescale the anime face images to 512 × 512 and take a 256 × 256 crop centered 14 pixels above the image center to get the face. Next, for all the datasets, we rescale the RGB channels to [-1, 1] and resize the images to the required size (64 × 64 or 128 × 128). We do not apply any augmentations to the data.

Neural networks. We use the ResNet architecture of the WGAN-QC discriminator (Liu et al., 2019) for the potential f. We use UNet (Ronneberger et al., 2015) as the stochastic transport map T(x, z). The noise z is simply an additional 4th input channel (RGBZ), i.e., the dimension of the noise equals the image size (64 × 64 or 128 × 128). We use high-dimensional Gaussian noise with axis-wise σ = 0.1.

Dynamic weak cost. In §5.3, we train the algorithm with a gradually changing γ. Starting from γ = 0, we linearly increase it to the desired value (2/3 or 1) during the first 25K iterations of f_ω.

Stability of training. In several cases, we noted that the optimization fluctuates around the saddle points or diverges. Analogous behavior of saddle point methods for OT has been observed in (Korotin et al., 2021b). For the γ-weak quadratic cost (γ > 0), we sometimes experienced instabilities when the input P is notably less dispersed than Q or when the parameter γ is high. Studying this behavior and improving the stability/convergence of the optimization is a promising research direction.

Computational complexity. The time and memory complexity of training deterministic OT maps T(x) is comparable to that of training usual generative models for unpaired translation. Our networks converge in 1-3 days on a Tesla V100 GPU (16 GB); wall-clock times depend on the datasets and the image sizes. Training stochastic T(x, z) is harder since we sample multiple random z per x (we use |Z| = 4). Thus, we learn stochastic maps on 4 × Tesla V100 GPUs.
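The pre-processing arithmetic and the dynamic weak cost schedule described above can be sketched as follows. This is an illustrative sketch under stated assumptions, not the authors' pipeline: the helper names are hypothetical, pixel coordinates are assumed to have their origin at the top-left corner (so "above" means a smaller row index), and the channels are assumed to be 8-bit values in [0, 255].

```python
def center_crop_box(width, height, crop=256, shift_up=14):
    """(left, top, right, bottom) of a crop centered `shift_up` px above the image center."""
    cx = width // 2
    cy = height // 2 - shift_up  # assumption: row index grows downwards
    half = crop // 2
    return (cx - half, cy - half, cx + half, cy + half)

def to_unit_range(pixel):
    """Map an 8-bit channel value in [0, 255] to [-1, 1]."""
    return pixel / 127.5 - 1.0

def gamma_schedule(step, gamma_target, warmup_steps=25_000):
    """Linearly increase gamma from 0 to gamma_target over the first warmup_steps updates of f."""
    if step >= warmup_steps:
        return gamma_target
    return gamma_target * step / warmup_steps

print(center_crop_box(512, 512))    # (128, 114, 384, 370)
print(to_unit_range(255))           # 1.0
print(gamma_schedule(12_500, 1.0))  # 0.5
```

The returned box follows the (left, top, right, bottom) convention that common image libraries use for cropping; whether the original pipeline uses exactly this convention is not stated in the text.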

F OPTIMALITY OF SOLUTIONS FOR STRICTLY CONVEX COSTS

Our Lemma 4 proves that optimal maps T* are contained in the arg inf_T sets of optimal potentials f*, but it leaves open the question of what else may be contained in these arg inf_T sets. Our following result shows that for strictly convex costs, nothing besides OT maps is contained there.

Lemma 5 (Solutions of the maximin problem are OT maps). Let C(x, µ) be a weak cost which is strictly convex in µ. Assume that there exists at least one potential f* which maximizes the dual form (5). Consider any such optimal potential $f^* \in \arg\sup_f \inf_T L(f, T)$. It holds that

$$\hat{T} \in \arg\inf_T L(f^*, T) \;\Rightarrow\; \hat{T} \text{ is a stochastic OT map.}$$

Proof of Lemma 5. By the definition of f*, we have $L(f^*, \hat{T}) = \sup_f \inf_T L(f, T) = \text{Cost}(P, Q)$, i.e., $\hat{T}$ attains the optimal cost. It remains to check that it satisfies $\hat{T}_\sharp(P \times S) = Q$, i.e., $\hat{T}$ generates Q from P. Let T* be any true stochastic OT map. We denote $\mu^*_x = T^*(x, \cdot)_\sharp S$ and $\hat{\mu}_x = \hat{T}(x, \cdot)_\sharp S$ for all x ∈ X and define $\mu^1_x = \frac{1}{2}(\mu^*_x + \hat{\mu}_x)$. Let $T^1 : X \times Z \to Y$ be any stochastic map which satisfies $T^1(x, \cdot)_\sharp S = \mu^1_x$ for all x ∈ X (§4.1). Since f* is optimal, $\inf_T L(f^*, T) = \text{Cost}(P, Q)$. By using the change of variables, we derive

$$\text{Cost}(P, Q) \le L(f^*, T^1) = \int_X C(x, \mu^1_x)\, dP(x) - \int_X \int_Y f^*(y)\, d\mu^1_x(y)\, dP(x) + \int_Y f^*(y)\, dQ(y). \tag{23}$$

Since C is convex in the second argument, we have

$$C(x, \mu^1_x) \le \frac{1}{2} C(x, \mu^*_x) + \frac{1}{2} C(x, \hat{\mu}_x). \tag{24}$$

Since C is strictly convex, the equality in (24) is possible only when $\mu^*_x = \hat{\mu}_x$. We also note that $\int_Y f^*(y)\, d\mu^1_x(y) = \frac{1}{2}\int_Y f^*(y)\, d\mu^*_x(y) + \frac{1}{2}\int_Y f^*(y)\, d\hat{\mu}_x(y)$. We substitute these findings into $L(f^*, T^1)$ and get

$$\text{Cost}(P, Q) \le L(f^*, T^1) \le \frac{1}{2} L(f^*, T^*) + \frac{1}{2} L(f^*, \hat{T}) = \frac{1}{2}\text{Cost}(P, Q) + \frac{1}{2}\text{Cost}(P, Q) = \text{Cost}(P, Q).$$

Thus, (24) is an equality P-almost surely for all x ∈ X, and $\mu^*_x = \hat{\mu}_x$ holds P-almost surely. This means that T* and $\hat{T}$ generate the same distribution from P × S, i.e., $\hat{T}$ is a stochastic OT map.

Our generic framework allows learning stochastic transport maps (Lemma 4). For strictly convex costs, all the solutions of our objective (14) are guaranteed to be stochastic OT maps (Lemma 5). In the experiments, we focus on strong and weak quadratic costs, which are not strictly convex but still provide promising performance in the downstream task of unpaired image-to-image translation (§5). Developing strictly convex costs is a promising research avenue for future work.

G RELATION TO PRIOR WORKS IN UNBALANCED OPTIMAL TRANSPORT

In the context of OT, (Yang & Uhler, 2019) employ a stochastic generator to learn a transport plan π in the unbalanced OT problem (Chizat, 2017). Due to this, their optimization objective slightly resembles our objective (15). However, this similarity is deceptive. Unlike strong (2) or weak (3) OT, unbalanced OT is an unconstrained problem, i.e., there is no need to satisfy π ∈ Π(P, Q). This makes unbalanced OT easier to handle: to optimize it, one just has to parametrize the plan π and backpropagate through the loss. The challenging part that the authors deal with is the estimation of the ϕ-divergence terms in the unbalanced OT objective. These terms can be interpreted as a soft relaxation of the constraints π ∈ Π(P, Q), i.e., a penalization for disobeying the constraints. The authors compute these terms by employing the variational (dual) formula from f-GAN (Nowozin et al., 2016). This yields a GAN-style optimization problem min_T max_f which is similar to other problems in the generative adversarial framework.

The problem we tackle is strong (2) and weak (3) OT, which requires enforcing the constraint π ∈ Π(P, Q). We reformulate the dual (weak) OT problem (5) into the maximin problem (15), which can be used to recover the OT plan (via the stochastic map T). Our approach can be viewed as a hard enforcement of the constraints. Our max_f min_T saddle point problem (15) is atypical for the traditional generative adversarial framework.



Backhoff-Veraguas et al. (2019) work with the subset P_p(Y) ⊂ P(Y) of distributions whose p-th moment is finite. Henceforth, we also work in P_p(Y) equipped with the Wasserstein-p topology. Since this detail is not principal for our subsequent analysis, to keep the exposition simple, we still write P(Y) but actually mean P_p(Y).
The proposition considers scalar-valued functions (Q = 1), but is analogous for vector-valued functions.
kaggle.com/reitanaka/alignedanimefaces
https://pythonot.github.io/
For the sake of clarity, we slightly reformulated the property of the cycle monotone maps (Villani, 2008).
github.com/mseitzer/pytorch-fid
github.com/eriklindernoren/PyTorch-GAN/tree/master/implementations/cyclegan
github.com/eriklindernoren/PyTorch-GAN/tree/master/implementations/discogan
github.com/aalmah/augmented_cyclegan
github.com/NVlabs/MUNIT



Figure 1: Unpaired translation with our Neural Optimal Transport (NOT) Algorithm 1.

Figure 2: Monge's OT formulation.

Figure 3: Strong (Kantorovich's) and weak (Gozlan et al., 2017) optimal transport formulations.

Figure 4: Stochastic function T (x, z) representing a transport plan. The function's input is x ∈ X and z ∼ S.

(a) Handbags → shoes, 128 × 128. (b) Shoes → handbags, 128 × 128. (c) Celeba (female) → anime, 64 × 64. (d) Anime → celeba (female), 64 × 64. (e) Celeba (male) → celeba (female), 64 × 64. (f) Anime → shoes, 64 × 64.

Figure 5: Unpaired translation with deterministic OT maps (W 2 ).

(a) Celeba (female) → anime, 128 × 128 ($W_{2,\frac{2}{3}}$). (b) Outdoor → church, 128 × 128 ($W_{2,}$ … (d) Shoes → handbags, 64 × 64 ($W_{2,1}$). (e) Anime → shoes, 64 × 64 ($W_{2,1}$).

Figure 6: Unpaired translation with stochastic OT maps ($W_{2,\gamma}$).

The random noise inputs z ∼ S are not synchronized for different inputs x. The examples with the synchronized noise inputs z are given in Appendix I. Extended results and examples of interpolation in the conditional latent space are given in Appendix H. The stochastic map T(x, z) preserves the attributes of the input image and produces multiple outputs.

The work was supported by the Analytical center under the RF Government (subsidy agreement 000000D730321P5Q0002, Grant No. 70-2021-00145 02.11.2021).

(a) Input distribution P. (b) Target distribution Q. (c) Fitted $T_\#(P \times S) \approx Q$. (d) Map $\overline{T}(x) = \int_Z T(x, z)\, dS(z)$. (e) Learned stochastic map T(x, z). (f) Map $x \mapsto \int_Y y\, d\pi^*(y|x)$.

Figure 8: Gaussian → Mixture of 8 Gaussians, learned stochastic map for the 1-weak quadratic cost.

(a) Input distribution P. (b) Target distribution Q. (c) Fitted $T_\#(P \times S) \approx Q$. (d) Map $\overline{T}(x) = \int_Z T(x, z)\, dS(z)$. (e) Learned stochastic map T(x, z). (f) Map $x \mapsto \int_Y y\, d\pi^*(y|x)$.

Figure 9: Gaussian → Swiss Roll, learned stochastic OT map for the 1-weak quadratic cost.

(a) Input and output distributions. (b) Learned plan $\hat{\pi}$ and marginal $T_\#(P \times S)$, test 1. (c) Learned plan $\hat{\pi}$ and marginal $T_\#(P \times S)$, test 2. (d) Learned plan $\hat{\pi}$ and marginal $T_\#(P \times S)$, test 3. (e) Various optimal plans π* learned by discrete OT (considered here as the ground truth).

Figure 10: Stochastic plans between toy 1D distributions (Figure 10a) learned by NOT (Figures 10b, 10c, 10d) and discrete OT (Figure 10e) with the 1-weak quadratic cost. The figures with the 2D transport plans also demonstrate the average map $x \mapsto \int_Y y\, d\pi(y|x)$ (conditional expectation).

(a) Input and output distributions. (b) Learned plan $\hat{\pi}$ and marginal $T_\#(P \times S)$, test 1. (c) Learned plan $\hat{\pi}$ and marginal $T_\#(P \times S)$, test 2. (d) Learned plan $\hat{\pi}$ and marginal $T_\#(P \times S)$, test 3. (e) Various optimal plans π* learned by discrete OT (considered here as the ground truth).

Figure 11: Stochastic plans between toy 1D distributions (Figure 11a) learned by NOT (Figures 11b, 11c, 11d) and discrete OT (Figure 11e) with the 1-weak quadratic cost. The figures with the 2D transport plans also demonstrate the average map $x \mapsto \int_Y y\, d\pi(y|x)$ (conditional expectation).

(a) NOT (ours, $W_2$), one-to-one. (b) DiscoGAN, one-to-one. (c) CycleGAN, one-to-one. (d) NOT (ours, $W_{2,\frac{2}{3}}$), one-to-many. (e) MUNIT, one-to-many. (f) AugCycleGAN, one-to-many.

Figure 12: Handbags → shoes translation (64 × 64) by the methods in view.

(a) NOT (ours, $W_2$), one-to-one. (b) DiscoGAN, one-to-one. (c) CycleGAN, one-to-one. (d) NOT (ours, $W_{2,\frac{2}{3}}$), one-to-many. (e) MUNIT, one-to-many. (f) AugCycleGAN, one-to-many.

Figure 13: Celeba (male) → Celeba (female) translation (64 × 64) by the methods in view.

(a) DiscoGAN, one-to-one. (b) CycleGAN, one-to-one. (c) NOT (ours, $W_2$), one-to-one. (d) AugCycleGAN, one-to-many. (e) MUNIT, one-to-many. (f) NOT (ours, $W_{2,\frac{2}{3}}$), one-to-many.

Figure 14: Outdoor → church (128 × 128) translation with various methods.

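$$\dots - \int_X \int_Y f^*(y)\, d\mu^1_x(y)\, dP(x) + \int_Y f^*(y)\, dQ(y). \tag{23}$$

github.com/milesial/Pytorch-UNet

Since C is convex in the second argument, we have

$$C(x, \mu^1_x) \le \frac{1}{2} C(x, \mu^*_x) + \frac{1}{2} C(x, \hat{\mu}_x). \tag{24}$$

Since C is strictly convex, the equality in (24) is possible only when $\mu^*_x = \hat{\mu}_x$. We also note that $\int_Y f^*(y)\, d\mu^1_x(y) = \frac{1}{2}\int_Y f^*(y)\, d\mu^*_x(y) + \frac{1}{2}\int_Y f^*(y)\, d\hat{\mu}_x(y)$.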

(a) Celeba (female) → anime translation, 64 × 64. (b) Anime → celeba (female) translation, 64 × 64.

(c) Celeba (male) → celeba (female) translation, 64 × 64. (d) Anime → shoes translation, 64 × 64.

Figure 16: Unpaired translation with OT maps (W 2 ). Additional examples (part 2).

(a) Input images x and random translated examples T(x, z). (b) Interpolation in the conditional latent space, $z = (1 - \alpha) z_1 + \alpha z_2$.

Figure 17: Celeba (female) → anime, 128 × 128, stochastic. Additional examples.

Table 1: Test FID↓ of the considered unpaired translation methods.

Optimization. We use the Adam optimizer (Kingma & Ba, 2014) with the default betas for both T_θ and f_ω. The learning rate is lr = 1 · 10^{-4}. The batch size is |X| = 64. The number of inner iterations is k_T = 10. When training with the weak cost (4), we sample |Z_x| = 4 noise samples for each image x in the batch. In toy experiments, we do 10K total iterations of the f_ω update. In the experiments with unpaired translation, our Algorithm 1 converges in ≈ 40K iterations for most datasets.
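The settings listed above can be gathered in one place, as in the sketch below. It is only an illustrative container together with a generator of the alternating update order (k_T updates of the map per update of the potential). The names `NOTConfig` and `update_schedule` are hypothetical, the betas (0.9, 0.999) are the usual Adam defaults, and the ordering of the updates within one iteration is an assumption rather than a restatement of Algorithm 1.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class NOTConfig:
    lr: float = 1e-4
    betas: tuple = (0.9, 0.999)  # usual Adam defaults
    batch_size: int = 64         # |X|
    inner_iters: int = 10        # k_T, updates of T per update of f
    noise_per_image: int = 4     # |Z_x|, used with the weak cost
    total_f_iters: int = 10_000  # value reported for the toy experiments

def update_schedule(cfg, f_iters):
    """Yield ('T', step) k_T times and then ('f', step) for each outer step."""
    for step in range(f_iters):
        for _ in range(cfg.inner_iters):
            yield ("T", step)
        yield ("f", step)

cfg = NOTConfig()
events = list(update_schedule(cfg, f_iters=3))
print(sum(1 for name, _ in events if name == "T"))  # 30
print(cfg.batch_size * cfg.noise_per_image)         # 256 evaluations of T(x, z) per batch
```

In a real training script each yielded event would trigger one optimizer step for the corresponding network; here the generator only makes the ratio of updates explicit.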

A VARIANCE-SIMILARITY TRADE-OFF

In this section, we study the effect of the parameter γ on the structure of the learned stochastic map for the γ-weak quadratic cost. We consider handbags → shoes translation (64 × 64) and test γ ∈ {0, 1/3, 2/3, 1}. The results are shown in Figure 7.

Discussion. For γ = 0, there is no variety in the produced samples (Figure 7a), i.e., conditional collapse happens. As γ increases (Figures 7b, 7c), the variety of samples increases and the style of the input images is mostly preserved. For γ = 1 (Figure 7d), the variety of samples is very high, but many of them do not preserve the style of the input image. The parameter γ can be viewed as a trade-off parameter balancing the variance of samples and their similarity to the input.
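To make the role of γ concrete, the sketch below evaluates a sample-based γ-weak quadratic cost for one scalar input x and a few outputs y_i = T(x, z_i). It assumes the form C(x, µ) = ∫ ½|x − y|² dµ(y) − (γ/2) Var(µ), which is consistent with the γ = 1 cost ∫ ½‖x − T̄(x)‖² dP(x) used in Appendix C. The function name and the use of the population (biased) variance are choices made for this illustration only.

```python
def weak_quadratic_cost(x, ys, gamma):
    """Sample estimate of 0.5 * E|x - y|^2 - 0.5 * gamma * Var(y) for scalar x and ys."""
    n = len(ys)
    mean = sum(ys) / n
    transport = sum(0.5 * (x - y) ** 2 for y in ys) / n
    variance = sum((y - mean) ** 2 for y in ys) / n  # population variance (assumption)
    return transport - 0.5 * gamma * variance

x, ys = 0.0, [1.0, 3.0]  # two outputs generated for the same input
print(weak_quadratic_cost(x, ys, gamma=0.0))  # 2.5, the strong quadratic cost
print(weak_quadratic_cost(x, ys, gamma=1.0))  # 2.0, equal to 0.5 * (x - mean)^2
```

With γ = 0, every deviation of an output from x is penalized, which favours collapsed outputs. As γ grows, the variance term offsets the penalty for spreading the outputs around their mean, in line with the trade-off discussed above.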

B TOY 2D EXPERIMENTS

In this section, we test our Algorithm 1 on toy 2D distributions P, Q, i.e., the input and output spaces have dimensions P = Q = 2.

Strong quadratic cost (γ = 0). As we noted in §5.1 and Appendix A, for the strong quadratic cost, our method tends to learn deterministic maps T(x, z) = T(x) which are independent of the noise input z. For deterministic maps T(x), our method reduces to the ⌊MM:R⌉ method, which has been evaluated in the recent Wasserstein-2 benchmark by (Korotin et al., 2021b). The authors show that the method recovers OT maps well on synthetic high-dimensional pairs P, Q with known ground truth OT maps. Thus, for brevity, we do not include toy experiments with our method for the strong quadratic cost.

Weak quadratic cost (γ > 0). To our knowledge, our method is the first to solve weak OT, i.e., there are no approaches to compare with. The analysis of computed transport plans for weak costs is challenging due to the lack of nontrivial pairs P, Q with known ground truth OT plan π*. The

I EXAMPLES WITH THE SYNCHRONIZED NOISE

In this section, for the handbags → shoes (64 × 64) and outdoor → church (128 × 128) datasets, we pick a batch of input data x_1, . . . , x_N ∼ P and noise z_1, . . . , z_K ∼ S to plot the N × K matrix of generated images T_θ(x_n, z_k). Our goal is to assess whether using the same z_k for different x_n leads to some shared effects such as the same form of a generated shoe or church. The resulting images are given in Figures 23 and 24. Qualitatively, we do not find any close relation between images produced with the same noise vectors for different input images x_n.
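The construction of the N × K matrix with synchronized noise can be sketched as follows. This is a minimal illustration with scalar stand-ins for images: `toy_map` is a hypothetical placeholder for the learned map T_θ, and the sampling distributions are assumptions made only to keep the example self-contained.

```python
import random

def synchronized_grid(stochastic_map, inputs, noises):
    """Return the N x K nested list [T(x_n, z_k)], reusing the same z_k in every row."""
    return [[stochastic_map(x, z) for z in noises] for x in inputs]

# Hypothetical stand-in for the learned map T_theta (illustration only).
def toy_map(x, z):
    return x + 0.1 * z

rng = random.Random(0)
inputs = [rng.uniform(-1.0, 1.0) for _ in range(3)]  # x_1, ..., x_N ~ P
noises = [rng.gauss(0.0, 1.0) for _ in range(4)]     # z_1, ..., z_K ~ S, drawn once

grid = synchronized_grid(toy_map, inputs, noises)
print(len(grid), len(grid[0]))  # 3 4
```

Column k of the grid then shows the effect of one fixed noise vector z_k across all inputs, which is the comparison that the section describes.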

