NEURAL OPTIMAL TRANSPORT WITH GENERAL COST FUNCTIONALS

Abstract

We present a novel neural-network-based algorithm to compute optimal transport (OT) plans and maps for general cost functionals. The algorithm is based on a saddle-point reformulation of the OT problem and generalizes prior OT methods for weak and strong cost functionals. As an application, we construct a cost functional to map data distributions while preserving the class-wise structure of the data.

1. INTRODUCTION

Optimal transport (OT) is a powerful framework for mass-moving and generative modeling problems on data distributions. Recent works (Korotin et al., 2022c; Rout et al., 2022; Korotin et al., 2021b; Fan et al., 2021a; Daniels et al., 2021) propose scalable neural methods to compute OT plans (or maps). They show that the learned transport plan (or map) can be used directly as the generative model in data synthesis (Rout et al., 2022) and unpaired learning (Korotin et al., 2022c; Rout et al., 2022; Daniels et al., 2021; Gazdieva et al., 2022). Compared to WGANs (Arjovsky et al., 2017), which employ the OT cost only as the loss for the generator (Rout et al., 2022, §3), these methods provide better flexibility: the properties of the learned model can be controlled by the choice of the transport cost function. Existing neural OT plan (or map) methods consider distance-based cost functions, e.g., weak or strong quadratic costs (Korotin et al., 2022c; Fan et al., 2021a; Gazdieva et al., 2022). Such costs are suitable for unpaired image-to-image style translation (Zhu et al., 2017, Figures 1, 2) and image restoration (Lugmayr et al., 2020). However, they do not take into account the class-wise structure of the data or available side information, e.g., class labels. As a result, such costs are hardly applicable to tasks such as dataset transfer, where the preservation of the class-wise structure is needed (Figure 1). We tackle this issue.
Contributions. We propose an extension of neural OT which makes it applicable to previously unexplored problems. To this end, we develop a novel neural-network-based algorithm to compute optimal transport plans for general cost functionals (§3). As an example, we construct (§4) and test (§6) a functional for mapping data distributions while preserving the class-wise structure. The learner has access to labeled input data ∼ P and only partially labeled target data ∼ Q.
Notation.
The notation of our paper is based on that of (Paty & Cuturi, 2020; Korotin et al., 2022c). For a compact Hausdorff space S, we use P(S) to denote the set of Borel probability distributions on S. We denote the space of continuous R-valued functions on S endowed with the supremum norm by C(S). Its dual space is the space M(S) ⊃ P(S) of finite signed Borel measures over S. For a functional F : M(S) → R ∪ {∞}, we use F*(h) := sup_{π∈M(S)} [∫_S h(s) dπ(s) − F(π)] to denote its convex conjugate F* : C(S) → R ∪ {∞}. Let X, Y be compact Hausdorff spaces and P ∈ P(X), Q ∈ P(Y). We use Π(P) ⊂ P(X × Y) to denote the subset of probability measures on X × Y whose projection onto the first marginal is P. We use Π(P, Q) ⊂ Π(P) to denote the subset of probability measures (transport plans) on X × Y with marginals P, Q. For u ∈ C(X), v ∈ C(Y), we write u ⊕ v ∈ C(X × Y) to denote the function u ⊕ v : (x, y) ↦ u(x) + v(y). For a functional F : M(X × Y) → R, we say that F is separately *-increasing if for all u ∈ C(X), v ∈ C(Y) and any c ∈ C(X × Y), the point-wise inequality u ⊕ v ≤ c implies F*(u ⊕ v) ≤ F*(c). For a measurable map T : X × Z → Y, we denote the associated push-forward operator by T♯. For Q₁, Q₂ ∈ P(Y) with Y ⊂ R^D, the (square of the) energy distance E (Rizzo & Székely, 2016) between them is

E²(Q₁, Q₂) = E∥Y₁ − Y₂∥ − ½E∥Y₁ − Y₁′∥ − ½E∥Y₂ − Y₂′∥,  (1)

where Y₁ ∼ Q₁, Y₁′ ∼ Q₁, Y₂ ∼ Q₂, Y₂′ ∼ Q₂ are independent random vectors. Energy distance (1) is a particular case of the Maximum Mean Discrepancy (Sejdinovic et al., 2013). It equals zero only when Q₁ = Q₂.
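The squared energy distance (1) admits a straightforward plug-in estimate from finite samples. A minimal NumPy sketch (the helper name is ours; this V-statistic version keeps the zero diagonal terms and is therefore slightly biased):

```python
import numpy as np

def energy_distance_sq(Y1, Y2):
    """Plug-in estimate of E^2(Q1, Q2) from samples Y1 ~ Q1, Y2 ~ Q2.

    Y1, Y2: arrays of shape (n1, D) and (n2, D)."""
    pdist = lambda A, B: np.linalg.norm(A[:, None, :] - B[None, :, :], axis=-1)
    # E||Y1 - Y2|| - 1/2 E||Y1 - Y1'|| - 1/2 E||Y2 - Y2'||
    return (pdist(Y1, Y2).mean()
            - 0.5 * pdist(Y1, Y1).mean()
            - 0.5 * pdist(Y2, Y2).mean())
```

For two identical samples the estimate is exactly zero, and it grows as the samples separate.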

2. PRELIMINARIES

In this section, we provide key concepts of optimal transport theory. Throughout the paper, we consider compact X = Y ⊂ R^D and P ∈ P(X), Q ∈ P(Y).
Strong OT. For a cost function c ∈ C(X × Y), the optimal transport cost between P, Q is

Cost(P, Q) := inf_{π∈Π(P,Q)} ∫_{X×Y} c(x, y) dπ(x, y),  (2)

see (Villani, 2008, §1). Problem (2) admits a minimizer π* ∈ Π(P, Q), which is called an OT plan (Santambrogio, 2015, Theorem 1.4). It may not be unique (Peyré et al., 2019, Remark 2.3). Intuitively, the cost function c(x, y) measures how hard it is to move a mass piece between points x ∈ X and y ∈ Y. That is, π* shows how to distribute the mass of P to Q optimally, i.e., with minimal effort. For cost functions c(x, y) = ∥x − y∥ and c(x, y) = ½∥x − y∥², the OT cost (2) is called the Wasserstein-1 (W₁) and the (square of the) Wasserstein-2 (W₂) distance, respectively, see (Villani, 2008, §1) or (Santambrogio, 2015, §1, §2).
Weak OT. Consider a weak cost function C : X × P(Y) → R. Its inputs are a point x ∈ X and a distribution of y ∈ Y. The weak OT cost is (Gozlan et al., 2017; Backhoff-Veraguas et al., 2019)

Cost(P, Q) := inf_{π∈Π(P,Q)} ∫_X C(x, π(·|x)) dπ(x),  (3)

where π(·|x) denotes the conditional distribution. Weak formulation (3) reduces to strong formulation (2) when C(x, µ) = ∫_Y c(x, y) dµ(y). Another example of a weak cost function is the γ-weak quadratic cost C(x, µ) = ∫_Y ½∥x − y∥² dµ(y) − (γ/2)Var(µ), where γ ≥ 0 and Var(µ) is the variance of µ; see (Korotin et al., 2022c, Eq. 5), (Alibert et al., 2019, §5.2), (Gozlan & Juillet, 2020, §5.2) for details. For this cost, we denote the optimal value of (3) by W²₂,γ and call it the γ-weak Wasserstein-2.
Regularized OT. The expression inside (2) is a linear functional. It is common to add a lower semi-continuous convex regularizer R : M(X × Y) → R ∪ {∞} with weight γ > 0:

Cost(P, Q) := inf_{π∈Π(P,Q)} [∫_{X×Y} c(x, y) dπ(x, y) + γR(π)].  (4)

Regularized formulation (4) typically provides several advantages over original formulation (2).
For example, if R(π) is strictly convex, the expression inside (4) is a strictly convex functional of π and yields a unique OT plan π*. Besides, regularized OT typically has better sample complexity (Genevay, 2019; Mena & Niles-Weed, 2019; Genevay et al., 2019). Common regularizers are the entropic (Cuturi, 2013b), quadratic (Essid & Solomon, 2018), Lasso (Courty et al., 2016), etc.
General OT. Let F : M(X × Y) → R ∪ {+∞} be a convex lower semi-continuous functional. Assume that there exists π ∈ Π(P, Q) for which F(π) < ∞. Consider the problem

Cost(P, Q) := inf_{π∈Π(P,Q)} F(π).  (5)

The problem is a generalization of strong OT (2), weak OT (3), and regularized OT (4); following (Paty & Cuturi, 2020), we call problem (5) the general OT problem. Surprisingly, regularized OT (4) represents the same problem: it is enough to put c(x, y) ≡ 0, γ = 1 and R(π) = F(π) to obtain (5) from (4). That is, regularized OT (4) and general OT (5) can be viewed as equivalent formulations.
Existence and duality. Under mild assumptions on F, problem (5) admits a minimizer π* (Paty & Cuturi, 2020, Lemma 1). If F is separately *-increasing, the dual problem is

Cost(P, Q) = sup_{u,v} [∫_X u(x) dP(x) + ∫_Y v(y) dQ(y) − F*(u ⊕ v)],  (6)

where the optimization is performed over the potentials u ∈ C(X), v ∈ C(Y), see (Paty & Cuturi, 2020).
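To make the regularized formulation (4) concrete on discrete distributions, here is a minimal sketch with the entropic regularizer of (Cuturi, 2013b), solved by Sinkhorn iterations; the function name and the fixed iteration count are our own choices, not part of the paper's method:

```python
import numpy as np

def sinkhorn_plan(p, q, C, gamma, n_iters=1000):
    """Discrete entropic-regularized OT (4): minimize <C, pi> + gamma * entropy
    term over plans pi with row marginal p and column marginal q."""
    K = np.exp(-C / gamma)          # Gibbs kernel
    u = np.ones_like(p)
    for _ in range(n_iters):        # alternating marginal projections
        v = q / (K.T @ u)
        u = p / (K @ v)
    return u[:, None] * K * v[None, :]
```

The returned plan approximately satisfies both marginal constraints; as gamma → 0 it approaches an unregularized OT plan (2).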

3. ALGORITHM FOR LEARNING OPTIMAL TRANSPORT PLANS

In this section, we derive an algorithm to solve the general OT problem (5) with neural networks. We prove that (5) can be reformulated as a saddle-point optimization problem (§3.1), from the solution of which one may implicitly recover the OT plan π* (§3.2). We give the proofs in Appendix A.

3.1. MAXIMIN REFORMULATION OF THE DUAL PROBLEM

In this subsection, we derive a dual form which is alternative to (6) and can be used to recover the OT plan π*. The two following theorems constitute the main theoretical idea of our approach.
Theorem 1 (Maximin reformulation of the dual problem). For a separately *-increasing, convex and lower semi-continuous functional F : M(X × Y) → R ∪ {+∞} it holds

Cost(P, Q) = sup_v inf_{π∈Π(P)} L(v, π) = sup_v inf_{π∈Π(P)} [F(π) − ∫_Y v(y) dπ(y)] + ∫_Y v(y) dQ(y),  (7)

where the sup is taken over v ∈ C(Y) and π(y) is the marginal distribution over y of the plan π.
From (7) we also see that it is enough to consider values of F on π ∈ Π(P) ⊂ M(X × Y). By convention, in further derivations we always set F(π) = +∞ for π ∈ M(X × Y) \ Π(P).
Theorem 2 (Optimal saddle points provide optimal plans). Let v* ∈ arg sup_v inf_{π∈Π(P)} L(v, π) be any optimal potential. Then for every optimal transport plan π* ∈ Π(P, Q) it holds:

π* ∈ arg inf_{π∈Π(P)} L(v*, π).  (8)

If F is strictly convex on Π(P), then L(v*, π) is strictly convex as a functional of π. Consequently, it has a unique minimizer, and the inclusion (8) is an equality. We have the following corollary.
Corollary 1 (Every optimal saddle point provides the optimal transport plan). Assume additionally that F is strictly convex. Then the unique OT plan satisfies π* = arg inf_{π∈Π(P)} L(v*, π).
Thanks to the results above, one may solve (7) and obtain the OT plan π* from the solution (v*, π*) of the saddle-point problem. We propose an algorithm to do this in the next subsections.

3.2. REPLACING MEASURES WITH STOCHASTIC MAPS

Formulation (7) requires optimization over probability measures π ∈ Π(P). To make it practically feasible, we reformulate it as optimization over functions T which generate these measures π. We introduce a latent space Z = R^Z and an atomless measure S ∈ P(Z) on it, e.g., S = N(0, I_Z).
For every π ∈ P(X × Y), there exists a measurable function T = T_π : X × Z → Y which implicitly represents it (Korotin et al., 2022c, §4.1). Such T_π satisfies T_π(x, ·)♯S = π(y|x) for all x ∈ X. That is, given x ∈ X and a random latent vector z ∼ S, the function T produces a sample T_π(x, z) ∼ π(y|x). In particular, if x ∼ P, the random vector [x, T_π(x, z)] is distributed as π. Thus, every π ∈ Π(P) can be implicitly represented as a function T_π : X × Z → Y. Note that, in general, there might exist several suitable functions T_π. Conversely, every measurable function T : X × Z → Y is an implicit representation of the measure π_T which is the joint distribution of a random vector [x, T(x, z)] with x ∼ P, z ∼ S. Consequently, the optimization over π ∈ Π(P) is equivalent to the optimization over measurable functions T : X × Z → Y. From our Theorem 1, we have the following corollary.
Corollary 2. For a separately *-increasing, lower semi-continuous and convex F it holds

Cost(P, Q) = sup_v inf_T L(v, T) = sup_v inf_T [F(T) − ∫_{X×Z} v(T(x, z)) dP(x) dS(z) + ∫_Y v(y) dQ(y)],  (9)

where the sup is taken over potentials v ∈ C(Y) and the inf over measurable functions T : X × Z → Y. Here we identify F(T) := F(π_T) and L(v, T) := L(v, π_T).

Algorithm 1: Neural optimal transport (NOT) for general cost functionals
Input: distributions P, Q, S accessible by samples; mapping network T_θ : R^P × R^S → R^Q; potential network v_ω : R^Q → R; number of inner iterations K_T; empirical estimator F̂(X, T(X, Z)) for the cost F(T);
Output: learned stochastic OT map T_θ representing an OT plan between distributions P, Q;
repeat
  Sample batches Y ∼ Q, X ∼ P and for each x ∈ X sample a batch Z[x] ∼ S;
  L_v ← (1/|X|) Σ_{x∈X} (1/|Z[x]|) Σ_{z∈Z[x]} v_ω(T_θ(x, z)) − (1/|Y|) Σ_{y∈Y} v_ω(y);
  Update ω by using ∂L_v/∂ω;
  for k_T = 1, 2, . . . , K_T do
    Sample a batch X ∼ P and for each x ∈ X sample a batch Z[x] ∼ S;
    L_T ← F̂(X, T_θ(X, Z)) − (1/|X|) Σ_{x∈X} (1/|Z[x]|) Σ_{z∈Z[x]} v_ω(T_θ(x, z));
    Update θ by using ∂L_T/∂θ;
until not converged;
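The two empirical losses of Algorithm 1 are easy to state in code. Below is a minimal NumPy sketch (function names are ours) that only evaluates L_v and L_T for given callables T and v; in the actual method these are neural networks updated via autograd with Adam, not plain functions:

```python
import numpy as np

def loss_v(v, T, X, Z, Y):
    """L_v from Algorithm 1: mean of v(T(x, z)) over the mapped batch
    minus mean of v(y) over the target batch."""
    mapped = [v(T(x, z)) for x, zs in zip(X, Z) for z in zs]
    return float(np.mean(mapped) - np.mean([v(y) for y in Y]))

def loss_T(F_hat, v, T, X, Z):
    """L_T from Algorithm 1: empirical cost estimate minus the mean
    potential value of the mapped samples."""
    mapped = [v(T(x, z)) for x, zs in zip(X, Z) for z in zs]
    return float(F_hat(T, X, Z) - np.mean(mapped))
```

At an (approximate) saddle point, the maximizing v and the minimizing T balance these two objectives, cf. (9).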
Following the notation of (Korotin et al., 2022c), we say that functions T are stochastic maps (Figure 6). We say that T* is a stochastic OT map if it represents some OT plan π*, i.e., T*(x, ·)♯S = π*(·|x) holds P-almost surely for x ∈ X. From Theorem 2 and Corollary 1, we obtain the following result.
Corollary 3 (Optimal saddle points provide stochastic OT maps). Let v* ∈ arg sup_v inf_T L(v, T) be any optimal potential. Then for every stochastic OT map T* it holds:

T* ∈ arg inf_T L(v*, T).  (10)

If F is strictly convex in π, we have T* ∈ arg inf_T L(v*, T) ⇔ T* is a stochastic OT map.
From our results it follows that by solving (9) and obtaining an optimal saddle point (v*, T*), one gets a stochastic OT map T*. To guarantee that all the solutions are OT maps, one may add strictly convex regularizers to F with some small weight, e.g., the conditional kernel variance (Korotin et al., 2022b). Problem (9) replaces the constrained optimization over measures π ∈ Π(P) in (7) with optimization over stochastic maps T, making the problem practically feasible.

3.3. GENERIC PRACTICAL OPTIMIZATION PROCEDURE

To approach problem (9) in practice, we use neural networks T_θ : R^D × R^S → R^D and v_ω : R^D → R to parametrize T and v, respectively. We train them with stochastic gradient ascent-descent (SGAD) using random batches from P, Q, S. The optimization procedure is given in Algorithm 1. In the implementation, to update the networks T_θ and v_ω, we use the Adam optimizer (Kingma & Ba, 2014).
Our Algorithm 1 requires an empirical estimator F̂ for F(T). Providing such an estimator might be non-trivial for general F. If F(π) = ∫_X C(x, π(·|x)) dP(x), i.e., the cost is weak (3), one may use the following unbiased Monte-Carlo estimator: F̂(X, T(X, Z)) := |X|⁻¹ Σ_{x∈X} Ĉ(x, T(x, Z[x])), where Ĉ is the respective estimator for the weak cost C and Z[x] denotes a random batch of latent vectors z ∼ S for a given x ∈ X. For strong costs and the γ-weak quadratic cost, the estimator Ĉ is given by (Korotin et al., 2022c, Eq. 22 and 23), and our Algorithm 1 for general OT (5) reduces to the NOT algorithm (Korotin et al., 2022c, Algorithm 1) for weak (3) or strong (2) OT. Unlike its predecessor, our algorithm is suitable for the general OT formulation (5).

Algorithm 2: Neural optimal transport (NOT) with the class-guided cost functional F_G
Input: distributions P = Σ_n α_n P_n, Q = Σ_n β_n Q_n, S accessible by samples; mapping network T_θ : R^P × R^S → R^Q; potential network v_ω : R^Q → R; number of inner iterations K_T;
Output: learned stochastic OT map T_θ representing an OT plan between distributions P, Q;
repeat
  Sample (unlabeled) batches Y ∼ Q, X ∼ P and for each x ∈ X sample a batch Z[x] ∼ S;
  L_v ← (1/|X|) Σ_{x∈X} (1/|Z[x]|) Σ_{z∈Z[x]} v_ω(T_θ(x, z)) − (1/|Y|) Σ_{y∈Y} v_ω(y);
  Update ω by using ∂L_v/∂ω;
  for k_T = 1, 2, . . . , K_T do
    Pick n ∈ {1, 2, . . . , N} at random with probabilities (α₁, . . . , α_N);
    Sample (labeled) batches X_n ∼ P_n, Y_n ∼ Q_n; for each x ∈ X_n sample a batch Z_n[x] ∼ S;
    L_T ← ΔE²(X_n, T_θ(X_n, Z_n), Y_n) − (1/|X_n|) Σ_{x∈X_n} (1/|Z_n[x]|) Σ_{z∈Z_n[x]} v_ω(T_θ(x, z));
    Update θ by using ∂L_T/∂θ;
until not converged;
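As an illustration, the estimator Ĉ for the γ-weak quadratic cost can be sketched as follows (a simplified version with the biased batch variance; the exact form is given in (Korotin et al., 2022c, Eq. 23), and the function names here are ours):

```python
import numpy as np

def gamma_weak_quadratic(x, Tx, gamma=0.1):
    """Estimate C(x, mu) = int 1/2 ||x - y||^2 dmu(y) - gamma/2 * Var(mu)
    from the mapped latent batch Tx = T(x, Z[x]) of shape (K_Z, D)."""
    sq_term = 0.5 * ((Tx - x) ** 2).sum(axis=1).mean()
    var_term = ((Tx - Tx.mean(axis=0)) ** 2).sum(axis=1).mean()
    return sq_term - 0.5 * gamma * var_term

def F_hat(C_hat, X, TXZ):
    """Monte-Carlo estimator F^ = |X|^{-1} sum_x C_hat(x, T(x, Z[x]))."""
    return float(np.mean([C_hat(x, Tx) for x, Tx in zip(X, TXZ)]))
```

With gamma = 0 this reduces to the estimator of the strong quadratic cost.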
In the next section, we propose a cost functional F_G (and provide an estimator for it) to solve the class-guided dataset transfer task.

4. CLASS-GUIDED DATASET TRANSFER WITH NEURAL OPTIMAL TRANSPORT

In this section, we show that general cost functionals (5) are useful, for example, for class-guided dataset transfer. To begin with, we formalize the problem setup. Let the input and target distributions be mixtures P = Σ_n α_n P_n and Q = Σ_n β_n Q_n of N class components; the goal is to map each component P_n to the corresponding component Q_n. This task is related to domain adaptation and transfer learning problems. It does not always have a solution with each P_n exactly mapped to Q_n due to possible prior/posterior shift (Kouw & Loog, 2018). We aim to find a stochastic map T between P and Q satisfying T♯(P_n × S) ≈ Q_n for all n = 1, . . . , N.
To solve the problem discussed above, we propose the following functional:

F_G(π) = F_G(T_π) := Σ_{n=1}^{N} α_n E²(T_π♯(P_n × S), Q_n),  (11)

where E denotes the energy distance (1). Functional (11) is non-negative and attains zero value when the components of P are correctly mapped to the respective components of Q (if this is possible).
Theorem 3. Functional F_G(π) is convex in π ∈ Π(P) and separately *-increasing.
In practice, each of the terms E²(T_π♯(P_n × S), Q_n) in (11) admits estimation from samples from π.
Proposition 1 (Estimator for E²). Let X_n ∼ P_n be a batch of K_X samples from class n. For each x ∈ X_n, let Z_n[x] ∼ S be a latent batch of size K_Z. Consider a batch Y_n ∼ Q_n of size K_Y. Then

ΔE²(X_n, T(X_n, Z_n), Y_n) := (1/(K_Y·K_X·K_Z)) Σ_{y∈Y_n} Σ_{x∈X_n} Σ_{z∈Z_n[x]} ∥y − T(x, z)∥ − (1/(2·(K_X² − K_X)·K_Z²)) Σ_{x∈X_n} Σ_{z∈Z_n[x]} Σ_{x′∈X_n\{x}} Σ_{z′∈Z_n[x′]} ∥T(x, z) − T(x′, z′)∥  (12)

is an estimator of E²(T♯(P_n × S), Q_n) up to a constant T-independent shift.
To estimate F_G(T), one may separately estimate the terms E²(T♯(P_n × S), Q_n) for each n and sum them up with weights α_n. At each iteration, we estimate only the n-th term, picked with probability α_n. We highlight two key details of the estimation of (11) which significantly differ from the estimation of weak OT costs (3) appearing in related works (Korotin et al., 2022c; Fan et al., 2021a; Korotin et al., 2021b).
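A direct NumPy transcription of estimator (12) may look as follows (the function name and tensor layout are our own choices):

```python
import numpy as np

def delta_E2(TXZ, Yn):
    """Estimator (12) of E^2(T#(P_n x S), Q_n) up to a T-independent shift.

    TXZ: mapped samples T(x, z), shape (K_X, K_Z, D); Yn: target batch (K_Y, D)."""
    K_X, K_Z, D = TXZ.shape
    flat = TXZ.reshape(K_X * K_Z, D)
    # first term: mean distance between target samples and mapped samples
    cross = np.linalg.norm(Yn[:, None, :] - flat[None, :, :], axis=-1).mean()
    # second term: distances between mapped samples with x' != x only
    pair = np.linalg.norm(flat[:, None, :] - flat[None, :, :], axis=-1)
    for i in range(K_X):                      # zero out same-x blocks
        pair[i * K_Z:(i + 1) * K_Z, i * K_Z:(i + 1) * K_Z] = 0.0
    intra = pair.sum() / (2 * (K_X ** 2 - K_X) * K_Z ** 2)
    return cross - intra
```

Averaging such terms over classes n (picked with probabilities α_n, as in Algorithm 2) gives a stochastic estimate of F_G(T).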
First, one has to sample not just from the input distribution P, but separately from each of its components (classes) P_n. Moreover, one also has to be able to sample separately from the components Q_n of the target distribution Q. This is the part where the guidance (semi-supervision) happens. We note that to estimate strong (2) or weak (3) costs, no target samples from Q are needed at all, i.e., such costs can be viewed as unsupervised. In practice, we assume that the learner is given a labeled empirical sample from P for training. In contrast, we assume that the available samples from Q are only partially labeled (with ≥ 1 labeled data point per class). That is, we know the class label only for a limited amount of data (Figure 1). In this case, all N cost terms (12) can still be stochastically estimated. These cost terms are used to learn the transport map T_θ in Algorithm 1. The remaining (unlabeled) samples are used when training the potential v_ω, as labels are not needed to update the potential in (9). We provide the detailed procedure for learning with the functional F_G (11) in Algorithm 2.

5. RELATED WORK

Neural networks for OT. To the best of our knowledge, our method (§3) is the first to compute OT plans for general cost functionals (5). Our duality formula (9) subsumes previously known formulas for weak (3) and strong (2) functionals (Korotin et al., 2022a, Eq. 7), (Korotin et al., 2021b, Eq. 9), (Rout et al., 2022, Eq. 14), (Fan et al., 2021a, Eq. 11), (Henry-Labordere, 2019, Eq. 11), (Gazdieva et al., 2022, Eq. 10). For the strong quadratic cost, (Makkuva et al., 2019) and (Korotin et al., 2021a, Eq. 10) consider formulations analogous to (9) restricted to convex potentials; they use Input Convex Neural Networks (Amos et al., 2017) to approximate them. These nets are popular in OT (Korotin et al., 2021c; Mokrov et al., 2021; Fan et al., 2021a; Bunne et al., 2021; Alvarez-Melis et al., 2021), but OT algorithms based on them are outperformed (Korotin et al., 2021b) by the above-mentioned unrestricted formulations. In (Genevay et al., 2016; Seguy et al., 2017; Daniels et al., 2021; Fan et al., 2021b), the authors propose methods for f-divergence regularized functionals (4). The first two methods recover biased plans, which is a notable issue in high dimensions (Korotin et al., 2021b, §4.2). The method of (Daniels et al., 2021) is computationally heavy due to using Langevin dynamics. Many approaches in generative learning use the OT cost as the loss function to update generative models (WGANs (Arjovsky & Bottou, 2017; Petzka et al., 2017; Liu et al., 2019)). They are not related to our work as they do not compute OT plans (or maps). Importantly, saddle-point problems such as (9) differ significantly from GANs, see (Gazdieva et al., 2022, §6.2).
Dataset transfer and domain adaptation. Deep distance-based algorithms (Gretton et al., 2012; Long et al., 2015; 2017) and adversarial algorithms (Ganin & Lempitsky, 2015; Long et al., 2018) are common solutions to the domain adaptation problem.
Using neural networks, these methods align probability distributions while maintaining the discriminativity between classes (Wang & Deng, 2018). Mostly they perform domain adaptation for image data at the feature level and are typically not used at the pixel level (data space). Pixel-level adaptation is typically performed by common unsupervised image-to-image translation techniques such as CycleGAN (Zhu et al., 2017; Hoffman et al., 2018; Almahairi et al., 2018) and UNIT (Huang et al., 2018; Liu et al., 2017).
OT and domain adaptation. Discrete OT solvers (EMD (Nash, 2000; Courty et al., 2016), Sinkhorn (Cuturi, 2013a), etc.) are usually employed to map labeled input samples to unlabeled or partially labeled target samples, possibly in combination with neural feature extractors (Courty et al., 2016). The available labels can be used to reconstruct the cost function and capture the data structure (Courty et al., 2016; Stuart & Wolfram, 2020; Liu et al., 2020; Li et al., 2019). Discrete OT performs a matching between the given empirical samples and does not provide out-of-sample estimates. In contrast, our method generalizes to unseen source data and generates new target data.
Conditional generative models. Conditional models, e.g., GAN (Mirza & Osindero, 2014) and Adversarial Autoencoder (Makhzani et al., 2015), use labels to perform conditional generation. They are not directly relevant to our work since we do not use any label information during inference: our learned mapping is based only on the input content.

6. EXPERIMENTS

In this section, we test NOT with our cost functional F_G on toy cases (§6.1) and image data (§6.2). For comparison, we consider NOT with Euclidean costs (Korotin et al., 2022c; Rout et al., 2022; Korotin et al., 2021b; Fan et al., 2021a) and image-to-image translation models (Isola et al., 2017; Huang et al., 2018; Liu et al., 2017). The code is written in the PyTorch framework and will be made public along with the trained networks. On the image data, our method converges in 5-15 hours on a Tesla V100 (16 GB). We give the training details (architectures, pre-processing, etc.) in Appendix B. Our Algorithm 2 learns stochastic (one-to-many) transport maps T(x, z). Following (Korotin et al., 2021b, §5), we also test the deterministic variant T(x, z) ≡ T(x), i.e., we do not add random noise z to the input. This disables stochasticity and yields deterministic (one-to-one) transport maps x ↦ T(x). In §6.1 (toy examples), we test only the deterministic variant of our method; in §6.2, we test both variants.

6.1. TOY EXAMPLES

The moons. The task is to map two balanced classes of moons (red and green) between P and Q (circles and crosses in Figure 2a, respectively). The target distribution Q is P rotated by 90 degrees. The number of randomly picked labeled samples in each target moon is 10. The maps learned by NOT with the quadratic cost (W₂, (Korotin et al., 2022c; Rout et al., 2022; Fan et al., 2021a)) and with our functional F_G are given in Figures 2c and 2d, respectively. For completeness, in Figure 2b we show the matching performed by the discrete OT-SI algorithm, which learns the transport cost with a neural net from a known correspondence of classes (Liu et al., 2020). The map for W₂ does not preserve the classes (Figure 2c), while our map solves the task (Figure 2d).
The Gaussians. Here both P, Q are balanced mixtures of 16 Gaussians, and each color denotes a unique class. The goal is to map the Gaussians in P (Figure 3a) to the respective Gaussians in Q (Figure 3b) which have the same color. The result of our method (10 known target labels per class) is given in Figure 3c. It correctly maps the classes. NOT with the quadratic cost is not shown, as it results in the identity map (the same image as Figure 3a), which completely mismatches the classes.

6.2. IMAGE DATA EXPERIMENTS

Datasets. We use MNIST (LeCun & Cortes, 2010), USPS, MNIST-M (Ganin & Lempitsky, 2015), Fashion-MNIST (Xiao et al., 2017) and KMNIST (Clanuwat et al., 2018) as the datasets P, Q. Each dataset has 10 (balanced) classes and a pre-defined train-test split. We consider two cases. In the related domains case, the source and target are close: MNIST → USPS, MNIST → MNIST-M. In the unrelated domains case, they notably differ: MNIST → KMNIST and FashionMNIST → MNIST. In all cases, we use the default class correspondence between the datasets. For completeness, we provide examples with imbalanced classes and a non-default correspondence in Appendices E and G, respectively.
Baselines. We compare our method to principal unsupervised translation models. We consider (one-to-many) AugCycleGAN (Almahairi et al., 2018) and MUNIT (Huang et al., 2018). We use the official implementations with the hyperparameters from the respective papers. We also test NOT with Euclidean cost functions: the quadratic cost ½∥x − y∥² (W₂) and the γ-weak (one-to-many) quadratic cost (W₂,γ, γ = 1/10). The above-mentioned methods are unsupervised, i.e., they do not use label information. For completeness, we add the (one-to-one) OTDD flow (Alvarez-Melis & Fusi, 2021).
Metrics. All the models are fitted on the train parts of the datasets; all the provided qualitative and quantitative results are exclusively for test (unseen) data. To evaluate the visual quality, we compute the FID (Heusel et al., 2017) of the mapped source test set w.r.t. the target test set. To estimate the accuracy of the mapping, we pre-train a ResNet18 (He et al., 2016) classifier on the target dataset. We consider the mapping T correct if the predicted label for the mapped sample T(x, z) matches the corresponding label of x.
Results. Qualitative results are shown in Figures 4 and 5; FID and accuracy are reported in Tables 2 and 1, respectively. To keep the figures simple, for all the models (one-to-one and one-to-many), we plot a single output per input.
For completeness, we show multiple outputs per input for our method in Appendix C. Note that in our method and OTDD, we use 10 labeled samples per class in the target distribution. The other methods under consideration do not use label information. In the related domains case (Figure 4), GAN-based methods and NOT with our guided cost F_G show high accuracy (≥ 90%). However, NOT with strong and weak quadratic costs provides low accuracy (35-50%). We presume that this is because, for these dataset pairs, the ground truth OT map for the (pixel-wise) quadratic cost simply does not preserve the class. This agrees with (Daniels et al., 2021, Figure 3), which tests an entropy-regularized quadratic cost in a similar MNIST → USPS setup. For our method with the guided cost F_G, an ablation study on the size of Z is presented in Appendix D. The OTDD gradient flows method provides reasonable accuracy on MNIST → USPS. However, OTDD has a much higher FID than the other methods. Visually, the OTDD results are comparable to (Alvarez-Melis & Fusi, 2021, Figure 3). In the unrelated domains case (Figure 5), which is of our main interest, all the methods except ours fail to preserve the class structure since they do not use label information. Consequently, their accuracy is around 10% (random guess). OTDD does not preserve the class structure either and generates samples with worse FID. Only NOT with our cost F_G preserves the class labels accurately.

Conclusion.

Our approach provides FID scores which are better than or comparable to those of principal image-to-image translation methods. In the related domains case (Figure 4), we provide comparable accuracy of class preservation. When the domains are unrelated, we notably outperform existing approaches.

6.3. DISCUSSION

Potential impact. Our method is a generic tool to learn transport maps between data distributions. In contrast to many other generative models, it allows controlling the properties of the learned map via the choice of a task-specific cost functional F. Our method could be used for data generation and editing purposes and, analogously to GANs, has promising positive real-world applications, such as digital content creation and artistic expression. At the same time, generative models can also be used for negative data-manipulation purposes such as deepfakes. In general, the impact of our work on society depends on the scope of its application and the task at hand.
Limitations. To apply our method, one has to provide an estimator F̂(T) for the functional F, which may be non-trivial. Besides, the construction of a cost functional F for a particular downstream task may not be straightforward. This should be taken into account when using the method in practice. Constructing task-specific functionals F and estimators F̂ is a promising avenue for future research.

7. REPRODUCIBILITY

To ensure the reproducibility of our experiments, we provide the source code in the supplementary material. For the toy experiments (§6.1), run twomoons_toy.ipynb and gaussian_toy.ipynb. For the dataset transfer experiments (§6.2), run dataset_transfer.ipynb and dataset_transfer_no_z.ipynb. The detailed information about the data preprocessing and training hyperparameters is presented in §6 and Appendix B.

A PROOFS

Proof of Theorem 1. We use the dual form (6) to derive

Cost(P, Q) = sup_v sup_u [∫_X u(x) dP(x) − F*(u ⊕ v) + ∫_Y v(y) dQ(y)]  (13)
= sup_v sup_u [∫_X u(x) dP(x) − sup_π (∫_{X×Y} (u ⊕ v) dπ(x, y) − F(π)) + ∫_Y v(y) dQ(y)]  (14)
= sup_v sup_u [∫_X u(x) dP(x) + inf_π (F(π) − ∫_{X×Y} (u ⊕ v) dπ(x, y)) + ∫_Y v(y) dQ(y)]  (15)
= sup_v sup_u inf_π [F(π) − ∫_X u(x) d(π − P)(x) − ∫_Y v(y) dπ(y) + ∫_Y v(y) dQ(y)]  (16)
≤ sup_v sup_u inf_{π∈Π(P)} [F(π) − ∫_X u(x) d(π − P)(x) − ∫_Y v(y) dπ(y) + ∫_Y v(y) dQ(y)]  (17)
= sup_v sup_u inf_{π∈Π(P)} [F(π) − ∫_Y v(y) dπ(y) + ∫_Y v(y) dQ(y)]  (18)
= sup_v inf_{π∈Π(P)} [F(π) − ∫_Y v(y) dπ(y) + ∫_Y v(y) dQ(y)]  (19)
≤ sup_v [F(π*) − ∫_Y v(y) dπ*(y) + ∫_Y v(y) dQ(y)] = F(π*) = Cost(P, Q).  (20)

In line (13), we group the terms involving the potential u. In line (14), we express the conjugate functional F* by using its definition. In transition to line (15), we replace the sup_π operator with the equivalent inf_π operator with the changed sign. In transition to (16), we put the term ∫_X u(x) dP(x) under the inf_π operator; we use the definition (u ⊕ v)(x, y) = u(x) + v(y) to split the integral over π(x, y) into two separate integrals over π(x) and π(y), respectively. In transition to (17), we restrict the inner inf_π to probability measures π ∈ Π(P), which have P as the first marginal, i.e., dπ(x) = dP(x); this provides an upper bound on (16). In particular, all u-dependent terms vanish, see (18). As a result, we remove the sup_u operator in line (19). In transition to line (20), we substitute an optimal plan π* ∈ Π(P, Q) ⊂ Π(P) to upper bound (19); since its second marginal is Q, we have dπ*(y) = dQ(y) and the v-terms cancel. Since Cost(P, Q) turns out to be both an upper bound (20) and a lower bound (13) for (19), we conclude that (7) holds.

Proof of Theorem 2. Assume that π* ∉ arg inf_{π∈Π(P)} L(v*, π), i.e., L(v*, π*) > inf_{π∈Π(P)} L(v*, π) = Cost(P, Q). We substitute v* and π* into L and, using dπ*(y) = dQ(y), see that

L(v*, π*) = F(π*) − ∫_Y v*(y) dπ*(y) + ∫_Y v*(y) dQ(y) = F(π*) = Cost(P, Q),

which is a contradiction. Thus, the assumption is wrong and (8) holds.

Proof of Theorem 3.
First, we prove that F_G is separately *-increasing. Let u ∈ C(X), v ∈ C(Y) and c ∈ C(X × Y) satisfy u ⊕ v ≤ c. For π ∈ M(X × Y) \ Π(P), it holds that F(π) = +∞. Consequently,

∫_{X×Y} c(x, y) dπ(x, y) − F(π) = ∫_{X×Y} (u(x) + v(y)) dπ(x, y) − F(π) = −∞.  (21)

When π ∈ Π(P), π is a probability measure. We integrate u(x) + v(y) ≤ c(x, y) w.r.t. π, subtract F(π) and obtain

∫_{X×Y} c(x, y) dπ(x, y) − F(π) ≥ ∫_{X×Y} (u(x) + v(y)) dπ(x, y) − F(π).  (22)

By taking the sup of (21) and (22) w.r.t. π ∈ M(X × Y), we obtain F*(c) ≥ F*(u ⊕ v).
Next, we prove that F is convex. We prove that every term E²(T_π♯(P_n × S), Q_n) is convex in π. First, we show that π ↦ f_n(π) := T_π♯(P_n × S) is linear in π ∈ Π(P). Pick any π₁, π₂, π₃ ∈ Π(P) which lie on the same line. Without loss of generality, we assume that π₃ ∈ [π₁, π₂], i.e., π₃ = απ₁ + (1 − α)π₂ for some α ∈ [0, 1]. We need to show that

f_n(π₃) = αf_n(π₁) + (1 − α)f_n(π₂).  (23)

In what follows, for a random variable U, we denote its distribution by Law(U). The first marginal distribution of each π_i is P. From the gluing lemma (Villani, 2008, §1), it follows that there exists a triplet of (dependent) random variables (X, Y₁, Y₂) such that Law(X, Y_i) = π_i for i = 1, 2. We define Y₃ = Y_r, where r is an independent random variable which takes values in {1, 2} with probabilities {α, 1 − α}. From the construction of Y₃, it follows that Law(X, Y₃) is a mixture of Law(X, Y₁) = π₁ and Law(X, Y₂) = π₂ with weights α and 1 − α. Thus, Law(X, Y₃) = απ₁ + (1 − α)π₂ = π₃. We conclude that Law(Y₃|X = x) = π₃(·|x) for P-almost all x ∈ X (recall that Law(X) = P). On the other hand, again by construction, the conditional Law(Y₃|X = x) is a mixture of Law(Y₁|X = x) = π₁(·|x) and Law(Y₂|X = x) = π₂(·|x) with weights α and 1 − α. Thus, π₃(·|x) = απ₁(·|x) + (1 − α)π₂(·|x) holds true for P-almost all x ∈ X. Consider independent random variables X_n ∼ P_n and Z ∼ S.
From the definition of T_{π_i} we conclude that Law(T_{π_i}(x, Z)) = π_i(·|x) for P-almost all x ∈ X and, since P_n is a component of P, for P_n-almost all x ∈ X as well. As a result, we define T_i = T_{π_i}(X_n, Z) and derive

$$\text{Law}(T_3 \mid X_n = x) = \pi_3(\cdot \mid x) = \alpha \pi_1(\cdot \mid x) + (1 - \alpha)\pi_2(\cdot \mid x) = \alpha\, \text{Law}(T_1 \mid X_n = x) + (1 - \alpha)\, \text{Law}(T_2 \mid X_n = x)$$

for P_n-almost all x ∈ X. Thus, Law(X_n, T_3) is also a mixture of Law(X_n, T_1) and Law(X_n, T_2) with weights α and 1 − α. In particular, Law(T_3) = α Law(T_1) + (1 − α) Law(T_2). We note that Law(T_i) = f_n(π_i) by the definition of f_n and obtain (23).

Second, we highlight that for every ν ∈ P(Y), the functional P(Y) ∋ μ ↦ E²(μ, ν) is convex in μ. Indeed, E² is a particular case of (the square of) the Maximum Mean Discrepancy (MMD, (Sejdinovic et al., 2013)). Therefore, there exists a Hilbert space H and a function φ : Y → H (feature map) such that

$$\mathcal{E}^{2}(\mu, \nu) = \Big\| \int_{\mathcal{Y}} \varphi(y)\, d\mu(y) - \int_{\mathcal{Y}} \varphi(y)\, d\nu(y) \Big\|^{2}_{\mathcal{H}}.$$

Since the kernel mean embedding μ ↦ ∫_Y φ(y)dμ(y) is linear in μ and ∥·∥²_H is convex, we conclude that E²(μ, ν) is convex in μ. To finish the proof, it remains to combine the fact that π ↦ T_π♯(P_n × S) is linear and E²(·, Q_n) is convex in its first argument.

Proof of Proposition 1. Direct calculation of the expectation of (12) yields the value

$$\begin{aligned}
&\mathbb{E}\|Y - T(X, Z)\| - \tfrac{1}{2}\mathbb{E}\|T(X, Z) - T(X', Z')\| \\
&\quad= \mathbb{E}\|Y - T(X, Z)\| - \tfrac{1}{2}\mathbb{E}\|T(X, Z) - T(X', Z')\| - \tfrac{1}{2}\mathbb{E}\|Y - Y'\| + \tfrac{1}{2}\mathbb{E}\|Y - Y'\| \\
&\quad= \mathcal{E}^{2}\big( T\sharp(\mathbb{P}_n \times \mathbb{S}), \mathbb{Q}_n \big) + \tfrac{1}{2}\mathbb{E}\|Y - Y'\|,
\end{aligned}$$

where Y, Y′ ∼ Q_n and (X, Z), (X′, Z′) ∼ P_n × S are independent random variables. It remains to note that ½ E∥Y − Y′∥ is a T-independent constant.
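To make Proposition 1 concrete, here is a minimal NumPy sketch (our own illustration, not the paper's code) of the Monte-Carlo estimator behind (12): the cross term E∥Y − T(X, Z)∥ minus half the within term E∥T(X, Z) − T(X′, Z′)∥, estimated from finite batches. Up to the T-independent constant ½ E∥Y − Y′∥, its expectation is E²(T♯(P_n × S), Q_n), so the value should be smaller when the mapped samples match Q_n in distribution.

```python
import numpy as np

def energy_loss(y, t, t_prime):
    """Estimate of E||Y - T(X,Z)|| - 0.5 * E||T(X,Z) - T(X',Z')||, cf. (12).
    y: (K_Y, d) samples from Q_n; t, t_prime: independently mapped batches."""
    cross = np.linalg.norm(y[:, None, :] - t[None, :, :], axis=-1).mean()
    within = np.linalg.norm(t[:, None, :] - t_prime[None, :, :], axis=-1).mean()
    return cross - 0.5 * within

rng = np.random.default_rng(0)
d, K = 2, 1000
y = rng.normal(size=(K, d))             # batch from Q_n (standard Gaussian)
good = rng.normal(size=(K, d))          # mapped batch matching Q_n
bad = rng.normal(size=(K, d)) + 5.0     # mapped batch far off-target

good_loss = energy_loss(y, good, rng.normal(size=(K, d)))
bad_loss = energy_loss(y, bad, rng.normal(size=(K, d)) + 5.0)
# Matching the target distribution yields the smaller loss value.
assert good_loss < bad_loss
```

With these sample sizes the gap between the two values is large, so the comparison is robust to the Monte-Carlo noise.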

B EXPERIMENTS DETAILS

DATA PREPROCESSING. We rescale the images to size 32×32 and normalize their channels to [-1, 1]. For grayscale images, we repeat the channel 3 times and work with 3-channel images. We do not apply any augmentations to the data. We use the default train-test splits for all the datasets.

TRAINING DETAILS. In our Algorithm 2, we use the Adam (Kingma & Ba, 2014) optimizer with lr = 10^{-4} for both T_θ and v_ω. The number of inner iterations for T_θ is K_T = 10. In preliminary experiments, we noted that it is sufficient to use small mini-batch sizes K_X, K_Y, K_Z in (12). Therefore, we decided to average the loss values over K_B small independent mini-batches (each drawn from class n with probability α_n) rather than use a single large batch from one class. This is done in parallel with tensor operations.

Two moons. We use 500 train and 150 test samples for each moon. We use a fully-connected net with 2 ReLU hidden layers of size 128 for both T_θ and v_ω. We train the model for 10k iterations of v_ω with K_B = 32, K_X = K_Y = 2 (K_Z plays no role as we do not use z here).

Gaussians. We use a fully-connected network with 2 ReLU hidden layers of size 256 for both T_θ and v_ω. There are 10000 train and 500 test samples in each Gaussian. We train the model for 10k iterations of v_ω with K_B = 32, K_X = K_Y = 2 (K_Z plays no role here as well).

Images. We use the WGAN-QC discriminator's ResNet architecture (He et al., 2016) for the potential v_ω. We use UNet (Ronneberger et al., 2015) as the stochastic transport map T_θ(x, z). To condition it on z, we insert conditional instance normalization (CondIN) layers after each UNet upscaling block. We use CondIN from AugCycleGAN (Almahairi et al., 2018). In the experiments, z is 128-dimensional standard Gaussian noise. The batch sizes are K_B = 32, K_X = K_Y = 2, K_Z = 2 for training with z. When training without z, we use the original UNet without conditioning; the batch parameters are the same (K_Z does not matter).
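The batching scheme described above can be sketched as follows (our own illustration; the function names and the dummy loss are hypothetical, not from the paper's code): K_B small mini-batches are drawn, each from a class n chosen with probability α_n, and the per-batch loss values are averaged.

```python
import numpy as np

rng = np.random.default_rng(0)

def averaged_batch_loss(datasets, alphas, loss_fn, K_B=32, K=2):
    """datasets[n]: (N_n, d) labeled samples of class n; alphas: class priors.
    Draws K_B independent K-sized class-conditional mini-batches
    (class n chosen with probability alphas[n]) and averages loss_fn."""
    classes = rng.choice(len(datasets), size=K_B, p=alphas)
    losses = [loss_fn(datasets[n][rng.integers(0, len(datasets[n]), size=K)], n)
              for n in classes]
    return float(np.mean(losses))

# Toy usage: three 2-D classes, dummy loss = mean norm of the batch.
data = [rng.normal(loc=c, size=(100, 2)) for c in (0.0, 3.0, 6.0)]
val = averaged_batch_loss(data, alphas=[0.5, 0.3, 0.2],
                          loss_fn=lambda batch, n: np.linalg.norm(batch, axis=1).mean())
assert np.isfinite(val)
```

In an actual training loop the per-batch computations would be fused into one tensor operation rather than a Python loop; the loop here is only for clarity.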
Our method converges in ≈60k iterations of v_ω. For the comparison in the image domain, we use the official implementations with the hyperparameters from the respective papers: AugCycleGAN (Almahairi et al., 2018), MUNIT (Huang et al., 2018). For the comparison with NOT (W_2, W_{2,γ}), we use the code shared by the authors of (Korotin et al., 2022c).

OTDD flow details. As in our method, the number of labeled samples in each class is 10. We learn the OTDD flow between the labeled source dataset and the labeled target samples. Note that the OTDD method does not use the unlabeled target samples. As the OTDD method does not produce out-of-sample estimates, we train a UNet to map the source data to the data produced by the OTDD flow via regression. Then we compute the test metrics (FID, accuracy) for this mapping network.

D ABLATION STUDY OF THE LATENT SPACE DIMENSION

In this subsection, we study the structure of the learned stochastic map for F_G with different latent space dimensions Z. We consider the MNIST → USPS transfer task (10 classes). The results are shown in Figures 8, 9 and Table 3. As can be seen, our model performs comparably for different Z.

F ICNN-BASED DATASET TRANSFER

For completeness, we show the performance of the ICNN-based method for the strong (2) quadratic transport cost c(x, y) = ½∥x − y∥² on the dataset transfer task. We use the non-minimax version (Korotin et al., 2021a) of the ICNN-based method by (Makkuva et al., 2019). We employ the publicly available code and dense ICNN architectures from the Wasserstein-2 benchmark repository. The batch size is K_B = 32, the total number of iterations is 100k, lr = 3·10^{-3}, and the Adam optimizer is used. The datasets are preprocessed as in the other experiments, see Appendix B. The qualitative results for MNIST→USPS and FashionMNIST→MNIST transfer are given in Figure 12. The results are reasonable in the first case (related domains). However, they are visually unpleasant in the second case (unrelated domains). This is expected, as the second case is notably harder. More generally, as derived in the Wasserstein-2 benchmark (Korotin et al., 2021b), the ICNN models do not work well in the pixel space due to the poor expressiveness of ICNN architectures. The ICNN method achieved 18.8% accuracy and ≫100 FID in the FMNIST→MNIST transfer, and 35.6% accuracy and 13.9 FID in the MNIST→USPS case. All the metrics are much worse than those achieved by our general OT method with the class-guided functional F_G; see Tables 1, 2 for comparison.

G NON-DEFAULT CLASS CORRESPONDENCE

To show that our method can work with an arbitrary correspondence between datasets, we also consider the FMNIST→MNIST dataset transfer with the following non-default correspondence between the dataset classes: 0→9, 1→0, 2→1, 3→2, 4→3, 5→4, 6→5, 7→6, 8→7, 9→8. In this experiment, we use the same architectures and data preprocessing as in the dataset transfer tasks; see Appendix B. We use our F_G (11) as the cost functional and learn a deterministic transport map T (no z). In this setting, our method produces results comparable to those previously reported in Section 6: accuracy 83.1 and FID 6.69. The qualitative results are given in Figure 13.
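The non-default correspondence above is exactly the cyclic shift n ↦ (n − 1) mod 10. A tiny sketch of how such a correspondence could be encoded (our own hypothetical illustration, not the paper's code):

```python
# Non-default FMNIST -> MNIST class correspondence from the text.
sigma = {0: 9, 1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6, 8: 7, 9: 8}

# It coincides with the cyclic shift n -> (n - 1) mod 10 ...
assert all(sigma[n] == (n - 1) % 10 for n in range(10))

# ... so a class-guided functional would simply pair each source class n
# with the target class sigma[n] instead of n (i.e., P_n with Q_{sigma[n]}).
```

Any bijection between source and target classes can be plugged in the same way.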



Footnotes: (1) The functional F_G(π) is not necessarily strictly convex. (2) The proof is generic and works for any functional which equals +∞ outside π ∈ P(X × Y). (3) github.com/milesial/Pytorch-UNet (4) github.com/kgkgzrtk/cUNet-Pytorch (5) github.com/aalmah/augmented_cyclegan (6) github.com/NVlabs/MUNIT (7) We use only 15k source samples since OTDD is computationally heavy (the authors use 2k samples).



Figure 1: The setup of class-guided dataset transfer. The input P = Σ_n α_n P_n and target Q = Σ_n β_n Q_n distributions are mixtures of N classes. The task is to learn a transport map T preserving the class. The learner has access to labeled input data ∼ P and only partially labeled target data ∼ Q.

Let the input P and output Q distributions each be a mixture of N distributions (classes) {P_n}_{n=1}^N and {Q_n}_{n=1}^N, respectively. That is, P = Σ_{n=1}^N α_n P_n and Q = Σ_{n=1}^N β_n Q_n, where α_n, β_n ≥ 0 are the respective weights (class prior probabilities) satisfying Σ_{n=1}^N α_n = 1 and Σ_{n=1}^N β_n = 1. In this general setup, we aim to find the transport plan π(x, y) ∈ Π(P, Q) for which the classes of x ∈ X and y ∈ Y are the same for as many pairs (x, y) ∼ π as possible. That is, its respective stochastic map T should map each component P_n (class) of P to the respective component Q_n (class) of Q.
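The goal above can be quantified directly on labeled samples: the fraction of inputs whose mapped image lands in the matching target class. A minimal NumPy sketch (our own illustration with hypothetical 1-D Gaussian mixtures and a nearest-center "classifier"; none of it is from the paper):

```python
import numpy as np

rng = np.random.default_rng(0)
N = 3
centers_p = np.array([[0.0], [4.0], [8.0]])   # class centers of P_n
centers_q = centers_p + 20.0                  # class centers of Q_n (shifted)

# Labeled source samples from P = sum_n alpha_n P_n (balanced classes here).
x = np.concatenate([c + rng.normal(size=(200, 1)) for c in centers_p])
labels = np.repeat(np.arange(N), 200)

def class_of(y):
    """Nearest-center 'classifier' for the target mixture Q."""
    return np.argmin(np.abs(y - centers_q.T), axis=1)

T_good = lambda x: x + 20.0           # sends each P_n onto the matching Q_n
T_bad = lambda x: 20.0 + (8.0 - x)    # reverses the class order

acc = lambda T: (class_of(T(x)) == labels).mean()
assert acc(T_good) > 0.9 > acc(T_bad)
```

Both maps push P onto (roughly) the same target mixture Q, yet only T_good preserves the classes; this is exactly the distinction that a distance-based cost cannot see and the class-guided functional is built to capture.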

Figure 2: The results of mapping two moons using OT with different cost functionals.

Figure 4: The results of mapping between two similar images datasets (related domains).

Figure 5: The results of mapping between two notably different datasets (unrelated domains).

net from a known classes' correspondence (Liu et al., 2020). The map for W_2 does not preserve the classes (Figure 2c), while our map solves the task (Figure 2d).

The Gaussians. Here both P, Q are balanced mixtures of 16 Gaussians, and each color denotes a unique class. The goal is to map the Gaussians in P (Figure 3a) to the respective Gaussians in Q (Figure 3b) which have the same color. The result of our method (10 known target labels per class) is given in Figure 3c. It correctly maps the classes. NOT for the quadratic cost is not shown, as it results in the identity map (the same image as Figure 3a), which is completely mistaken in classes.

Figure 8: MNIST → USPS translation with functional F G and varying Z = 1, 4, 8, 16, 32, 64.

Figure 9: Stochastic transport maps T θ (x, z) learned by our Algorithm 2 with different sizes of Z.

Figure 12: Results of ICNN-based method applied to the dataset transfer task.

Figure 13: FMNIST→MNIST mapping with F G no z cost, classes are permuted.

, Theorem 2) for details. Problem (6) also admits a pair of minimizers u * , v * .

accessible by samples (unlabeled); weights α n are known and samples from each P n , Q n are accessible (labeled); mapping network T θ : R P

Table 1: Accuracy↑ of the maps learned by the translation methods in view.

Table 2: FID↓ of the samples generated by the translation methods in view.

OTDD flow (Alvarez-Melis & Fusi, 2021; 2020): the method employs gradient flows to perform the transfer preserving the class label. Additionally, we show the results of the ICNN-based W_2 OT method (Makkuva et al., 2019; Korotin et al., 2021a) in Appendix F.

C ADDITIONAL EXAMPLES OF STOCHASTIC MAPS

In this subsection, we provide additional examples of the learned map for F_G (with z). We consider all the image datasets from the main experiments (Section 6). The results are shown in Figure 7 and demonstrate that, for a fixed x and different z, our model generates diverse samples.

E IMBALANCED CLASSES

In this subsection, we study the behaviour of the learned map for F_G when the classes are imbalanced between the input and target domains. Since our method learns a transport map from P to Q, it should capture the class balance of Q regardless of the class balance of P. We check this below.

We consider the MNIST → USPS transfer with N = 3 classes (digits 0, 1, 2) in MNIST and N = 3 classes in USPS. We assume that the class probabilities are α_1 = α_2 = 1/2, α_3 = 0 and β_1 = β_2 = β_3 = 1/3. That is, the third class is absent from the source dataset and is not used anywhere during training. In turn, the third target class is not used when training T_θ but is used when training v_ω. All the other hyperparameters are the same as in the previous MNIST → USPS experiments with 10 known labels per target class. The results are shown in Figures 10a and 11a. We present deterministic (no z) and stochastic (with z) maps.

Our cost functional F_G stimulates the map to maximally preserve the input class. However, to transport P to Q, the model must change the class balance. We show the confusion matrices for the learned maps T_θ in Figures 10b, 11b. They illustrate that the model maximally preserves the input classes 0 and 1 and uniformly distributes their excess mass into class 2, as suggested by our cost functional.
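The class-level behaviour described above can be checked with a tiny transport computation on class masses (our own illustrative sketch, assuming SciPy; not from the paper): keeping a sample in its class costs 0 and changing its class costs 1, the rows of the class-transition matrix must sum to the source priors α = (1/2, 1/2, 0), and the columns to the target priors β = (1/3, 1/3, 1/3). The minimal amount of mass that must change class is then 1/3.

```python
import numpy as np
from scipy.optimize import linprog

alpha = np.array([0.5, 0.5, 0.0])     # source class priors (third class absent)
beta = np.array([1/3, 1/3, 1/3])      # target class priors
C = 1.0 - np.eye(3)                   # cost 1 for changing class, 0 otherwise

A_eq = np.vstack([np.kron(np.eye(3), np.ones((1, 3))),   # row sums = alpha
                  np.kron(np.ones((1, 3)), np.eye(3))])  # col sums = beta
res = linprog(C.ravel(), A_eq=A_eq, b_eq=np.concatenate([alpha, beta]),
              bounds=(0, None), method="highs")

# At least 1/3 of the total mass must leave its class to reach beta.
assert abs(res.fun - 1/3) < 1e-9
```

This matches the confusion matrices discussed above: the off-diagonal mass is exactly the 1/3 that has to be rerouted into the empty third class; how that 1/3 is split between the two source classes is not determined by the marginals alone, and the learned map distributes it uniformly.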

