

Abstract

We study the use of amortized optimization to predict optimal transport (OT) maps from the input measures, which we call Meta OT. This helps repeatedly solve similar OT problems between different measures by leveraging knowledge from past problems to rapidly predict and solve new ones; standard methods, in contrast, ignore past solutions and suboptimally re-solve each problem from scratch. We instantiate Meta OT models in discrete and continuous (Wasserstein-2) settings between images, spherical data, and color palettes, and use them to improve the computational time of standard OT solvers by multiple orders of magnitude.

1. INTRODUCTION

Optimal transportation (Villani, 2009; Ambrosio, 2003; Santambrogio, 2015; Peyré et al., 2019; Merigot and Thibert, 2021) is thriving in domains including economics (Galichon, 2016), reinforcement learning (Dadashi et al., 2021; Fickinger et al., 2021), style transfer (Kolkin et al., 2019), generative modeling (Arjovsky et al., 2017; Seguy et al., 2018; Huang et al., 2020; Rout et al., 2021), geometry (Solomon et al., 2015; Cohen et al., 2021), domain adaptation (Courty et al., 2017; Redko et al., 2019), signal processing (Kolouri et al., 2017), fairness (Jiang et al., 2020), and cell reprogramming (Schiebinger et al., 2019). A core component in these settings is to couple two measures (α, β) supported on domains (X, Y) by solving a transport optimization problem such as the primal Kantorovich problem:

π⋆(α, β, c) ∈ arg min_{π ∈ U(α,β)} ∫_{X×Y} c(x, y) dπ(x, y),   (1)

where the optimal coupling π⋆ is a joint distribution over the product space X × Y, U(α, β) is the set of admissible couplings between α and β, and c : X × Y → R is the ground cost, which represents a notion of distance between elements of X and elements of Y.

Challenges. Unfortunately, solving eq. (1) even once is computationally expensive between general measures, and computationally cheaper alternatives are an active research topic: entropic optimal transport (Cuturi, 2013) smooths the transport problem with an entropy penalty, and sliced distances (Kolouri et al., 2016; 2018; 2019; Deshpande et al., 2019) solve OT between 1-dimensional projections of the measures, where eq. (1) can be solved easily. Furthermore, when an optimal transport method is deployed in practice, eq. (1) is not solved just a single time, but is repeatedly solved for new scenarios between different input measures (α, β). For example, the measures could be representations of images that we wish to optimally transport between, and in deployment we would receive a stream of new images to couple.
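To make the entropic relaxation concrete, the following is a minimal NumPy sketch of Sinkhorn iterations for a discrete instance of eq. (1) between two toy grid measures with a squared-distance ground cost. The function and variable names here are illustrative, not from the paper, and this is an unoptimized sketch rather than a production solver:

```python
import numpy as np

def sinkhorn(a, b, C, eps=0.1, n_iters=1000):
    """Entropic OT (Cuturi, 2013): alternately rescale the Gibbs kernel
    K = exp(-C/eps) so the coupling's marginals match the measures a and b."""
    K = np.exp(-C / eps)
    u = np.ones_like(a)
    for _ in range(n_iters):
        v = b / (K.T @ u)  # match the column marginal to b
        u = a / (K @ v)    # match the row marginal to a
    return u[:, None] * K * v[None, :]  # approximate coupling pi

# Toy instance: uniform measures on 1-D grids, squared-distance cost.
x = np.linspace(0, 1, 50)
y = np.linspace(0, 1, 60)
a = np.full(50, 1 / 50)
b = np.full(60, 1 / 60)
C = (x[:, None] - y[None, :]) ** 2
pi = sinkhorn(a, b, C)
```

The entropy penalty makes pi dense and the marginal constraints only approximately satisfied, but each iteration is a cheap matrix-vector product; smaller eps approaches the unregularized problem (1) at the cost of slower convergence.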
Repeatedly solving optimal transport problems also comes up in the context of comparing seismic signals (Engquist and Froese, 2013) and in single-cell perturbations (Bunne et al., 2021; 2022a;b). Standard optimal transport solvers deployed in these settings re-solve the optimization problems from scratch, ignoring the shared structure and information present between different coupling problems.

Overview and outline. We study the use of amortized optimization and machine learning methods to rapidly solve multiple optimal transport problems and predict the solution from the input measures (α, β). This setting involves learning a meta model to predict the solution to the optimal transport problem, which we refer to as Meta Optimal Transport (Meta OT). We learn Meta OT models to predict the solutions to optimal transport problems and significantly improve the computational time and number of iterations needed to solve eq. (1) in discrete (sect. 3.1) and continuous (sect. 3.2) settings.
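The value of reusing past solutions can be seen even without a learned model: warm-starting Sinkhorn from a near-optimal dual scaling converges in far fewer iterations than the default all-ones initialization. The sketch below, with names of our own choosing, uses an oracle warm start computed offline as a stand-in for the learned predictor; it illustrates the amortization gap, not the paper's actual model:

```python
import numpy as np

def sinkhorn_until(a, b, K, u0, tol=1e-6, max_iters=5000):
    """Run Sinkhorn updates from the scaling u0 until the coupling's
    column marginal matches b to within tol; return the iteration
    count and the final scaling vector."""
    u = u0.copy()
    for it in range(1, max_iters + 1):
        v = b / (K.T @ u)
        u = a / (K @ v)
        if np.abs(v * (K.T @ u) - b).max() < tol:
            return it, u
    return max_iters, u

rng = np.random.default_rng(0)
n = 64
a = rng.dirichlet(np.ones(n))
b = rng.dirichlet(np.ones(n))
grid = np.linspace(0, 1, n)
K = np.exp(-((grid[:, None] - grid[None, :]) ** 2) / 0.1)  # Gibbs kernel

# Cold start: solving from scratch with the default initialization.
cold_iters, u_star = sinkhorn_until(a, b, K, np.ones(n))

# Warm start: a near-optimal scaling, standing in for a prediction
# from the input measures; the solver only needs to fine-tune it.
warm_iters, _ = sinkhorn_until(a, b, K, u_star)
```

In the setting studied here, the warm start would instead come from a model mapping the input measures (α, β) to a predicted solution, which is what allows a stream of related problems to be solved much faster than from scratch.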

