RIEMANNIAN METRIC LEARNING VIA OPTIMAL TRANSPORT

Abstract

We introduce an optimal transport-based model for learning a metric tensor from cross-sectional samples of evolving probability measures on a common Riemannian manifold. We neurally parametrize the metric as a spatially-varying matrix field and efficiently optimize our model's objective using a simple alternating scheme. Using this learned metric, we can nonlinearly interpolate between probability measures and compute geodesics on the manifold. We show that metrics learned using our method improve the quality of trajectory inference on scRNA and bird migration data at the cost of little additional cross-sectional data.

1. INTRODUCTION

In settings like single-cell RNA sequencing (scRNA-seq) (Tanay & Regev, 2017), we often encounter pooled cross-sectional data: time-indexed samples {x_t^i}_{i=1}^{N_t} from an evolving population X_t with no correspondence between samples x_s^i and x_t^i at times s ≠ t. Such data may arise when technical constraints impede the repeated observation of a given population member. For example, because scRNA-seq is a destructive process, any given cell's gene expression profile can be measured only once, after which the cell is destroyed. This data is often sampled sparsely in time, leading to interest in trajectory inference: inferring the distribution of the population, or the positions of individual particles, between the times {t_i}_{i=1}^T at which samples are drawn.

A fruitful approach has been to model the evolving population as a time-varying probability measure P_t on R^D and to infer the distribution of the population between observed times by interpolating between successive pairs of measures P_{t_i}, P_{t_{i+1}}. Some existing approaches to this problem use dynamical optimal transport to interpolate between probability measures, which implicitly encodes a prior that particles travel along straight lines between observations. This prior is often implausible, especially when the evolving population is sampled sparsely in time. One can straightforwardly extend optimal transport-based methods by allowing the user to specify a spatially-varying metric tensor that biases the inferred trajectories away from straight lines. This approach is theoretically well-founded and amounts to redefining what counts as a straight line by altering the manifold on which trajectories are inferred. Such a metric tensor, however, is typically unavailable in real-world applications. We resolve this problem by introducing an optimal transport-based model in which a metric tensor may be recovered from cross-sectional samples of evolving probability measures on a common manifold.
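To fix ideas, one standard way to formalize this extension (written here in our own notation, with A(x) denoting the user-specified field of symmetric positive-definite matrices) is to weight the kinetic energy in the Benamou–Brenier dynamical formulation of optimal transport by the metric tensor:

```latex
W_A^2(P_0, P_1) \;=\; \inf_{(\rho_t, v_t)} \int_0^1 \!\!\int_{\mathbb{R}^D} v_t(x)^\top A(x)\, v_t(x)\; \mathrm{d}\rho_t(x)\, \mathrm{d}t
\quad \text{s.t.} \quad \partial_t \rho_t + \nabla \cdot (\rho_t v_t) = 0, \qquad \rho_0 = P_0, \;\; \rho_1 = P_1.
```

When A(x) ≡ I, this reduces to the classical squared 2-Wasserstein distance, whose optimal particle paths are straight lines; a non-constant A(x) makes motion cheaper in some directions than in others, curving the inferred trajectories accordingly.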
We derive a tractable optimization problem using the theory of optimal transport on Riemannian manifolds, neurally parametrize its variables, and solve it using gradient-based optimization. We demonstrate our algorithm's ability to recover a known metric tensor from cross-sectional samples in synthetic examples. We also show that our learned metric tensor improves the quality of trajectory inference on scRNA-seq data and allows us to infer curved trajectories for individual birds from cross-sectional samples of a migrating population. Our method is both computationally efficient, requiring few computational resources relative to the downstream trajectory inference task, and data-efficient, requiring little data per time point to recover a useful metric tensor.
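A minimal sketch of one way to parametrize a spatially-varying metric tensor as a neural matrix field (the architecture and names here are illustrative assumptions, not the paper's implementation): a small network maps a point x to an unconstrained square factor L(x), and A(x) = L(x)L(x)^T + εI is symmetric positive-definite by construction.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 2   # ambient dimension
H = 16  # hidden width

# Stand-in weights for a tiny two-layer MLP (in practice these would be learned).
W1, b1 = rng.normal(size=(H, D)), np.zeros(H)
W2, b2 = 0.1 * rng.normal(size=(D * D, H)), np.zeros(D * D)

def metric(x, eps=1e-3):
    """Return an SPD matrix A(x) = L(x) L(x)^T + eps * I at the point x."""
    h = np.tanh(W1 @ x + b1)           # hidden features of x
    L = (W2 @ h + b2).reshape(D, D)    # unconstrained square factor L(x)
    return L @ L.T + eps * np.eye(D)   # SPD: L L^T is PSD, eps*I bounds it below

A = metric(np.array([0.5, -1.0]))
```

The Cholesky-style factorization guarantees positive-definiteness for any network output, so the parametrization can be optimized by unconstrained gradient descent.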

2. RELATED WORK

Measure interpolation. An emerging literature considers the problem of smooth interpolation between probability measures. Using the theory of optimal transport, we may construct a displacement interpolation (McCann, 1997) between successive measures (P_i, P_{i+1}), yielding a sequence of geodesics between pairs (P_i, P_{i+1}) in the space of probability measures equipped with the 2-Wasserstein distance W_2 (Villani, 2008). This generalizes piecewise-linear interpolation to probability measures. Schiebinger et al. (2019) use this method to infer the developmental trajectories of cells based on static measurements of their gene expression profiles. Recent works such as De Bortoli et al. (2021) and Vargas et al. (2021) provide numerical schemes for computing Schrödinger bridges, an entropically-regularized analog of the displacement interpolation between measures. Chen et al. (2018b) and Benamou et al. (2018) leverage the variational characterization of cubic splines as minimizers of mean-square acceleration over the set of interpolating curves and develop a generalization in the space of probability measures. Chewi et al. (2021) extend these works by providing computationally efficient algorithms for computing measure-valued splines. Hug et al. (2015) modify the usual displacement interpolation between probability measures by introducing anisotropy into the domain on which the measures are defined. This change corresponds to imposing preferred directions for the local displacement of probability mass. Whereas Hug et al. hard-code the domain's anisotropy, our method learns this anisotropy from snapshots of a probability measure evolving in time. Zhang et al. (2022) apply similar techniques to unstructured animation problems. Ding et al. (2020) propose a non-convex inverse problem for recovering the ground metric and interaction kernel in a class of mean-field games and supply a primal-dual algorithm for solving a grid discretization of the problem. Whereas their Eulerian approach inputs a grid discretization of observed densities and velocity fields, our approach is Lagrangian, operating directly on temporal observations of particle positions.

Trajectory inference from population-level data. In domains such as scRNA-seq, we study an evolving population from which it is impossible or prohibitively costly to collect longitudinal data. Instead, we observe distinct cross-sectional samples from the population at a collection of times and wish to infer the dynamics of the latent population from these observations. This problem is called trajectory inference. Hashimoto et al. (2016) study the conditions under which it is possible to recover a potential function from population-level observations of a system evolving according to a Fokker-Planck equation. They provide an RNN-based algorithm for this learning task and investigate their model's ability to recover differentiation dynamics from scRNA-seq data sampled sparsely in time. A recent work of Bunne et al. (2022) presents a proximal analog to this approach, modeling population dynamics as a JKO flow with a learnable energy function and describing a numerical scheme for computing this flow using input-convex neural networks. This method addresses initial value problems, where one is given an initial measure ρ_0 and seeks to predict future measures ρ_t for t > 0. In contrast, our method is primarily applicable to boundary value problems, where we are given initial and final measures ρ_0 and ρ_1 and seek an interpolant ρ_t for 0 < t < 1. Schiebinger et al. (2019) use optimal transport to infer future and prior gene expression profiles from a single observation in time of a cell's gene expression profile. Yang & Uhler (2019) propose a GAN-like solver for unbalanced optimal transport and investigate the effectiveness of their method for the inference of future gene expression states in zebrafish single-cell gene expression data.

As they learn transport maps between probability measures defined at discrete time points, none of the above OT-based methods is suitable for inferring continuous trajectories. Tong et al. (2020) show that a regularized continuous normalizing flow can provide an efficient approximation to the displacement interpolations arising from dynamical optimal transport. These displacement interpolations often do not yield plausible paths through gene expression space, so the authors propose several application-specific regularizers to bias the inferred trajectories toward more realistic paths. Rather than relying on bespoke regularizers, our method supplies a natural approach for learning local directions of particle motion in a way that is amenable to integration with algorithms like that of Tong et al. (2020).

Riemannian metric learning. Lebanon (2002) develops a parametric method for learning a Riemannian metric from sparse high-dimensional data and uses this metric for nearest-neighbor classification. Hauberg et al. (2012) construct a metric tensor as a weighted average of a set of learned metric tensors and show how to compute geodesics and exponential and logarithmic maps on the resulting Riemannian manifold. Arvanitidis et al. (2016) learn a Riemannian metric that encourages geodesics to move towards regions of high data density, and define a Riemannian normal distribution with respect to this metric.
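To make the straight-line prior behind displacement interpolation concrete, the following is a minimal sketch for two equal-size empirical measures (toy data; all names are illustrative). Samples are matched by an optimal assignment under squared Euclidean cost, and each matched particle then travels along a straight segment; under a learned metric, these segments would be replaced by geodesics.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

rng = np.random.default_rng(1)
# Two empirical measures with equal numbers of samples (toy data).
X0 = rng.normal(loc=0.0, size=(50, 2))  # samples of P_0
X1 = rng.normal(loc=4.0, size=(50, 2))  # samples of P_1

# Optimal one-to-one matching under squared Euclidean cost.
C = ((X0[:, None, :] - X1[None, :, :]) ** 2).sum(-1)
rows, cols = linear_sum_assignment(C)

def interpolate(t):
    """McCann displacement interpolation at time t in [0, 1]:
    each matched particle moves along a straight line."""
    return (1 - t) * X0[rows] + t * X1[cols]

midpoint = interpolate(0.5)  # empirical version of the Wasserstein geodesic's midpoint
```

For equal-size uniform empirical measures, the optimal transport plan is exactly such an assignment, so this sketch coincides with the displacement interpolation between the two point clouds.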

