RIEMANNIAN METRIC LEARNING VIA OPTIMAL TRANSPORT

Abstract

We introduce an optimal transport-based model for learning a metric tensor from cross-sectional samples of evolving probability measures on a common Riemannian manifold. We neurally parametrize the metric as a spatially varying matrix field and efficiently optimize our model's objective using a simple alternating scheme. Using this learned metric, we can nonlinearly interpolate between probability measures and compute geodesics on the manifold. We show that metrics learned using our method improve the quality of trajectory inference on scRNA and bird migration data while requiring little additional cross-sectional data.

1. INTRODUCTION

In settings like single-cell RNA sequencing (scRNA-seq) (Tanay & Regev, 2017), we often encounter pooled cross-sectional data: time-indexed samples $\{x_t^i\}_{i=1}^{N_t}$ from an evolving population $X_t$, with no correspondence between samples $x_s^i$ and $x_t^i$ at times $s \neq t$. Such data may arise when technical constraints prevent repeated observation of the same population member. For example, because scRNA-seq is a destructive process, a given cell's gene expression profile can be measured only once before the cell is destroyed. This data is often sampled sparsely in time, which motivates trajectory inference: inferring the distribution of the population, or the positions of individual particles, between the times $\{t_i\}_{i=1}^{T}$ at which samples are drawn.

A fruitful approach has been to model the evolving population as a time-varying probability measure $P_t$ on $\mathbb{R}^D$ and to infer the distribution of the population between observed times by interpolating between successive pairs of measures $P_{t_i}, P_{t_{i+1}}$. Some existing approaches to this problem use dynamical optimal transport to interpolate between probability measures, which implicitly encodes the prior that particles travel along straight lines between observations. This prior is often implausible, especially when the evolving population is sampled sparsely in time.

One can straightforwardly extend optimal transport-based methods by letting the user specify a spatially varying metric tensor that biases the inferred trajectories away from straight lines. This approach is theoretically well founded and amounts to redefining a straight line by altering the manifold on which trajectories are inferred. Such a metric tensor, however, is unavailable in most real-world applications. We resolve this problem by introducing an optimal transport-based model in which a metric tensor may be recovered from cross-sectional samples of evolving probability measures on a common manifold.
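The straight-line prior of dynamical optimal transport can be made concrete in a small case. For two empirical measures with the same number of uniformly weighted samples, an optimal coupling under squared Euclidean cost can be taken to be a permutation, and the displacement interpolation (McCann, 1997) moves each matched sample along the straight segment joining its endpoints. The sketch below assumes equal sample counts and uniform weights and uses SciPy's `linear_sum_assignment` as the discrete OT solver; the function name `displacement_interpolation` is illustrative and is not the method proposed in this paper.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def displacement_interpolation(x0, x1, t):
    """Euclidean displacement interpolation between two uniform empirical measures.

    x0, x1: arrays of shape (n, D) holding samples at consecutive time points
            (equal sample counts and uniform weights are assumed).
    t:      interpolation time in [0, 1].
    """
    # Squared Euclidean cost between every pair of samples.
    cost = ((x0[:, None, :] - x1[None, :, :]) ** 2).sum(axis=-1)
    # With uniform weights and equal counts, an optimal coupling is a permutation.
    rows, cols = linear_sum_assignment(cost)
    # Each matched pair moves along the straight segment between its endpoints.
    return (1.0 - t) * x0[rows] + t * x1[cols]


# Usage: interpolate between two synthetic point clouds in R^2.
rng = np.random.default_rng(0)
x0 = rng.normal(size=(100, 2))
x1 = rng.normal(loc=[4.0, 0.0], size=(100, 2))
frames = [displacement_interpolation(x0, x1, t) for t in np.linspace(0.0, 1.0, 5)]
```

Every intermediate frame is an affine function of $t$ for each matched pair, which is exactly the straight-line assumption that a learned metric tensor is meant to relax.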
We derive a tractable optimization problem using the theory of optimal transport on Riemannian manifolds, neurally parametrize its variables, and solve it using gradient-based optimization. We demonstrate our algorithm's ability to recover a known metric tensor from cross-sectional samples on synthetic examples. We also show that our learned metric tensor improves the quality of trajectory inference on scRNA data and allows us to infer curved trajectories for individual birds from cross-sectional samples of a migrating population. Our method is both computationally efficient, requiring few computational resources relative to the downstream trajectory inference task, and data efficient, requiring little data per time point to recover a useful metric tensor.
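To illustrate what a neurally parametrized, spatially varying metric tensor can look like, and how such a metric changes which path between two points is cheapest, here is a minimal sketch under assumptions that do not come from this section. A hypothetical two-layer MLP (`init_mlp`, `neural_metric`) outputs the entries of a lower-triangular factor $L(x)$, and the metric is $A(x) = L(x)L(x)^\top + \epsilon I$, which is symmetric positive definite by construction. Path energy $\int_0^1 \dot\gamma(t)^\top A(\gamma(t))\,\dot\gamma(t)\,dt$ is discretized with a midpoint rule, and `bump_metric` is a hand-made conformal metric used only for demonstration. The sketch does not implement the paper's objective, its training loop, or its alternating optimization scheme.

```python
import numpy as np


def spd_from_factor(raw, dim, eps=1e-3):
    """Map an unconstrained vector of length dim*(dim+1)//2 to L L^T + eps I (SPD)."""
    L = np.zeros((dim, dim))
    L[np.tril_indices(dim)] = raw
    return L @ L.T + eps * np.eye(dim)


def init_mlp(dim, hidden, rng):
    """Hypothetical two-layer MLP mapping x in R^dim to the entries of L(x)."""
    out = dim * (dim + 1) // 2
    return {
        "W1": rng.normal(scale=0.5, size=(hidden, dim)),
        "b1": np.zeros(hidden),
        "W2": rng.normal(scale=0.5, size=(out, hidden)),
        "b2": np.zeros(out),
    }


def neural_metric(params, x):
    """Spatially varying metric tensor A(x), positive definite by construction."""
    h = np.tanh(params["W1"] @ x + params["b1"])
    raw = params["W2"] @ h + params["b2"]
    return spd_from_factor(raw, x.shape[0])


def bump_metric(x, height=20.0, width=0.1):
    """Hand-made conformal metric that makes travel near the origin expensive."""
    return (1.0 + height * np.exp(-(x @ x) / (2.0 * width**2))) * np.eye(2)


def path_energy(path, metric):
    """Midpoint-rule Riemannian energy of a path sampled at uniform times on [0, 1].

    path:   array of shape (n + 1, D).
    metric: callable mapping a point of shape (D,) to a (D, D) matrix.
    """
    n = path.shape[0] - 1
    dt = 1.0 / n
    energy = 0.0
    for a, b in zip(path[:-1], path[1:]):
        v = (b - a) / dt              # finite-difference velocity on this segment
        A = metric(0.5 * (a + b))     # metric evaluated at the segment midpoint
        energy += v @ A @ v * dt
    return energy


# Two candidate paths from (-1, 0) to (1, 0): a straight segment through the
# origin and a semicircular arc of radius 1 that avoids it.
t = np.linspace(0.0, 1.0, 201)
straight = np.stack([-1.0 + 2.0 * t, np.zeros_like(t)], axis=1)
arc = np.stack([-np.cos(np.pi * t), np.sin(np.pi * t)], axis=1)

euclidean_energies = (path_energy(straight, lambda x: np.eye(2)),
                      path_energy(arc, lambda x: np.eye(2)))
bump_energies = (path_energy(straight, bump_metric), path_energy(arc, bump_metric))
```

Under the Euclidean metric the straight segment has lower energy than the arc, while under `bump_metric` the arc becomes cheaper; this is the sense in which a metric tensor biases inferred trajectories away from straight lines. Fitting the MLP weights by gradient-based optimization would additionally require an autodiff framework and the paper's actual objective, neither of which is shown here.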

2. RELATED WORK

Measure interpolation. An emerging literature considers the problem of smooth interpolation between probability measures. Using the theory of optimal transport, we may construct a displacement interpolation (McCann, 1997) between successive measures $(P_i, P_{i+1})$, yielding a sequence

