RIEMANNIAN METRIC LEARNING VIA OPTIMAL TRANSPORT

Abstract

We introduce an optimal transport-based model for learning a metric tensor from cross-sectional samples of evolving probability measures on a common Riemannian manifold. We neurally parametrize the metric as a spatially-varying matrix field and efficiently optimize our model's objective using a simple alternating scheme. Using this learned metric, we can nonlinearly interpolate between probability measures and compute geodesics on the manifold. We show that metrics learned using our method improve the quality of trajectory inference on scRNA and bird migration data at the cost of little additional cross-sectional data.

1. INTRODUCTION

In settings like single-cell RNA sequencing (scRNA-seq) (Tanay & Regev, 2017), we often encounter pooled cross-sectional data: time-indexed samples {x_t^i}_{i=1}^{N_t} from an evolving population X_t, with no correspondence between samples x_s^i and x_t^i at times s ≠ t. Such data may arise when technical constraints impede the repeated observation of a given population member. For example, as scRNA-seq is a destructive process, any given cell's gene expression profile can only be measured once before the cell is destroyed. This data is often sampled sparsely in time, leading to interest in trajectory inference: inferring the distribution of the population or the positions of individual particles between the times {t_i}_{i=1}^T at which samples are drawn. A fruitful approach has been to model the evolving population as a time-varying probability measure P_t on R^D and to infer the distribution of the population between observed times by interpolating between subsequent pairs of measures P_{t_i}, P_{t_{i+1}}. Some existing approaches to this problem use dynamical optimal transport to interpolate between probability measures, which implicitly encodes a prior that particles travel along straight lines between observations. This prior is often implausible, especially when the evolving population is sampled sparsely in time. One can straightforwardly extend optimal transport-based methods by allowing the user to specify a spatially-varying metric tensor that biases the inferred trajectories away from straight lines. This approach is theoretically well-founded and amounts to redefining what a straight line is by altering the manifold on which trajectories are inferred. Such a metric tensor, however, is unavailable in most real-world applications. We resolve this problem by introducing an optimal transport-based model in which a metric tensor is recovered from cross-sectional samples of evolving probability measures on a common manifold.
We derive a tractable optimization problem using the theory of optimal transport on Riemannian manifolds, neurally parametrize its variables, and solve it using gradient-based optimization. We demonstrate our algorithm's ability to recover a known metric tensor from cross-sectional samples on synthetic examples. We also show that our learned metric tensor improves the quality of trajectory inference on scRNA data and allows us to infer curved trajectories for individual birds from cross-sectional samples of a migrating population. Our method is both computationally efficient, requiring few computational resources relative to the downstream trajectory inference task, and data-efficient, requiring little data per time point to recover a useful metric tensor.

2. RELATED WORK

Measure interpolation. An emerging literature considers the problem of smooth interpolation between probability measures. Using the theory of optimal transport, we may construct a displacement interpolation (McCann, 1997) between successive measures (P_i, P_{i+1}), yielding a sequence of geodesics between pairs (P_i, P_{i+1}) in the space of probability measures equipped with the 2-Wasserstein distance W_2 (Villani, 2008). This generalizes piecewise-linear interpolation to probability measures. Schiebinger et al. (2019) use this method to infer the developmental trajectories of cells based on static measurements of their gene expression profiles. Recent works such as De Bortoli et al. (2021) and Vargas et al. (2021) provide numerical schemes for computing Schrödinger bridges; these are an entropically-regularized analog of the displacement interpolation between measures. Chen et al. (2018b) and Benamou et al. (2018) leverage the variational characterization of cubic splines as minimizers of mean-square acceleration over the set of interpolating curves and develop a generalization in the space of probability measures. Chewi et al. (2021) extend these works by providing computationally efficient algorithms for computing measure-valued splines. Hug et al. (2015) modify the usual displacement interpolation between probability measures by introducing anisotropy to the domain on which the measures are defined. This change corresponds to imposing preferred directions for the local displacement of probability mass. Whereas Hug et al. hard-code the domain's anisotropy, our method learns this anisotropy from snapshots of a probability measure evolving in time. Zhang et al. (2022) apply similar techniques to unstructured animation problems. Ding et al.
(2020) propose a non-convex inverse problem for recovering the ground metric and interaction kernel in a class of mean-field games and supply a primal-dual algorithm for solving a grid discretization of the problem. Whereas their Eulerian approach inputs a grid discretization of observed densities and velocity fields, our approach is Lagrangian, operating directly on temporal observations of particle positions. Trajectory inference from population-level data. In domains such as scRNA-seq, we study an evolving population from which it is impossible or prohibitively costly to collect longitudinal data. Instead, we observe distinct cross-sectional samples from the population at a collection of times and wish to infer the dynamics of the latent population from these observations. This problem is called trajectory inference. Hashimoto et al. (2016) study the conditions under which it is possible to recover a potential function from population-level observations of a system evolving according to a Fokker-Planck equation. They provide an RNN-based algorithm for this learning task and investigate their model's ability to recover differentiation dynamics from scRNA-seq data sampled sparsely in time. A recent work of Bunne et al. (2022) presents a proximal analog of this approach, modeling population dynamics as a JKO flow with a learnable energy function and describing a numerical scheme for computing this flow using input-convex neural networks. This method addresses initial value problems, where one is given an initial measure ρ_0 and seeks to predict future measures ρ_t for t > 0. In contrast, our method is primarily applicable to boundary value problems, where we are given initial and final measures ρ_0 and ρ_1 and seek an interpolant ρ_t for 0 < t < 1. Schiebinger et al. (2019) use optimal transport to infer future and prior gene expression profiles from a single observation in time of a cell's gene expression profile.
Yang & Uhler (2019) propose a GAN-like solver for unbalanced optimal transport and investigate the effectiveness of their method for the inference of future gene expression states in zebrafish single-cell gene expression data. As they learn transport maps between probability measures defined at discrete time points, none of the above OT-based methods is suitable for inferring continuous trajectories. Tong et al. (2020) show that a regularized continuous normalizing flow can provide an efficient approximation to displacement interpolations arising from dynamical optimal transport. These displacement interpolations often do not yield plausible paths through gene expression space, so the authors propose several application-specific regularizers to bias the inferred trajectories toward more realistic paths. Rather than relying on bespoke regularizers, our method supplies a natural approach for learning local directions of particle motion in a way that is amenable to integration with algorithms like that of Tong et al. (2020). Riemannian metric learning. Lebanon (2002) develops a parametric method for learning a Riemannian metric from sparse high-dimensional data and uses this metric for nearest-neighbor classification. Hauberg et al. (2012) construct a metric tensor as a weighted average of a set of learned metric tensors and show how to compute geodesics and exponential and logarithmic maps on the resulting Riemannian manifold. Arvanitidis et al. (2016) learn a Riemannian metric that encourages geodesics to move toward regions of high data density, define a Riemannian normal distribution with respect to this metric, and show how to learn the parameters of this model via maximum likelihood estimation. Whereas these methods learn a Riemannian metric from static data, our model learns a metric from snapshots of probability distributions evolving in time on a common manifold.

3. METHOD

We now describe our method for learning a Riemannian metric from cross-sectional samples of populations evolving in time on a common manifold. We learn a metric that minimizes the average 1-Wasserstein distance on the manifold between pairs of subsequent time samples from each population. We derive a dual formulation of our problem, parametrize its variables by neural networks, and solve for the dual variables and the metric via alternating optimization.

3.1. MODEL

Suppose we have K populations evolving according to unknown continuous dynamics over a common Riemannian manifold M = (R^D, g) with unknown metric g. The metric g is defined at any x ∈ R^D by the inner product ⟨u, v⟩_x = u^T A(x) v for A(x) ≻ 0. We model each population as a compactly-supported probability distribution P^k with density ρ^k on R^D being pushed through an unobserved velocity field v^k(x). We will learn the metric tensor A(x) from temporal snapshots of the populations P^k during their evolution on the manifold. Depending on the nature of our data, we may have the ability to repeatedly sample from P^k at a pair of initial and final times t = 0 and t = T_k, or we may have fixed sets of samples {x_i^{k,0}}_{i=1}^{S_k} and {x_i^{k,T_k}}_{i=1}^{S_k} drawn from the populations P^k at their respective times. For convenience, we denote both the densities and the empirical distributions over samples from P^k at times t ∈ {0, T_k} by ρ_0^k and ρ_1^k, respectively. As we only observe the initial and final spatial distributions ρ_0^k and ρ_1^k of the populations and do not observe their dynamics, we assume that probability mass travels from initial to final positions along A-geodesics. Geodesics are paths that minimize the action (or average kinetic energy) of a particle traveling between points x, y ∈ M; this least-action interpretation of a geodesic makes it a natural prior on paths in the absence of further information. We learn a field of positive definite matrices A(x) that minimizes the average A-geodesic distance between the initial and final positions of each unit of probability mass from each P^k. Formally, let r^k be the map sending a point x ∈ M to its final position r^k(x) after flowing through the latent velocity field v^k from time t = 0 to t = T_k.
If we had access to such a solution map, we would ideally solve the following problem:

$$\inf_{A : \mathbb{R}^D \to S_{++}^D} \; \frac{1}{K} \sum_{k=1}^K \int_{\mathcal{M}} d_A\big(x, r^k(x)\big) \, d\rho_0^k(x) + \lambda R(A), \qquad (1)$$

where R(A) is a regularizer that excludes the trivial solution A ≡ 0. Problem (1) optimizes for a non-trivial metric A(x) that minimizes the average A-geodesic distance d_A(x, r^k(x)) traveled by particles x in the population ρ_0^k at time t = 0 to their final positions r^k(x) at time t = T_k. Since we do not know the velocity fields v^k that encode the particle dynamics, however, we also do not know the maps r^k in (1) that specify the correspondence between particle positions x at t = 0 and positions r^k(x) at final times T_k. Furthermore, as noted in Section 1, we often encounter data for which it is impossible to observe contiguous particle trajectories: a particle that we observe at t = 0 may not be in the sample at t = T_k. This issue is unavoidable for destructive measurement processes such as scRNA sequencing; in this setting, a cell whose scRNA profile is observed at t = 0 would be destroyed at this time and hence unobservable at a future time t = T_k. To accommodate these limitations, we replace the true matchings of initial and final positions r^k with the Monge map s^k, defined as the solution to the following problem:

$$W_{1,A}(\rho_0^k, \rho_1^k) = \inf_{s^k : \, \rho_1^k = s^k_{\#} \rho_0^k} \int_{\mathcal{M}} d_A\big(x, s^k(x)\big) \, d\rho_0^k(x). \qquad (2)$$

Here s^k pushes ρ_0^k forward onto ρ_1^k; we write ρ_1^k = s^k_# ρ_0^k to denote this relationship. This map matches units of mass from the initial and final distributions ρ_0^k, ρ_1^k so as to minimize their average A-geodesic distance. Substituting solutions to (2) for the maps r^k in our idealized objective (1), we obtain the following lower bound on (1):

$$\inf_{A : \mathbb{R}^D \to S_{++}^D} \; \underbrace{\inf_{s^k : \, \rho_1^k = s^k_{\#} \rho_0^k} \frac{1}{K} \sum_{k=1}^K \int_{\mathcal{M}} d_A\big(x, s^k(x)\big) \, d\rho_0^k(x)}_{= \, \frac{1}{K} \sum_{k=1}^K W_{1,A}(\rho_0^k, \rho_1^k)} + \lambda R(A). \qquad (3)$$
Problem (3) is challenging as written: it requires the ability to compute and differentiate geodesic distances d_A with respect to an arbitrary metric. However, the inner optimization problem over maps s^k is a collection of decoupled Monge problems (2) whose optimal value is the average 1-Wasserstein distance between the pairs (ρ_0^k, ρ_1^k) on the manifold with metric A(x). As shown in Appendix B, these problems can be expressed in a dual form (10) which is amenable to gradient-based optimization. Replacing the inner Monge problems with their dual formulations (10), we may equivalently write Problem (3) as a minimax problem:

$$\inf_{A : \mathbb{R}^D \to S_{++}^D} \; \sup_{\substack{\phi^k : \mathcal{M} \to \mathbb{R} \\ \|\nabla \phi^k(x)\|_{A^{-1}(x)} \le 1}} \; \frac{1}{K} \sum_{k=1}^K \left[ \int_{\mathcal{M}} \phi^k(x) \, d\rho_0^k(x) - \int_{\mathcal{M}} \phi^k(x) \, d\rho_1^k(x) \right] + \lambda R(A). \qquad (4)$$

In Problem (4), we learn a metric A(x) that minimizes the average 1-Wasserstein distance on the manifold between pairs of subsequent time samples from each population. Problem (4) requires neither the computation of geodesic distances on M nor the solution of an assignment problem. As such, it is substantially more tractable than the initial objective (3). We provide the details of our implementation of (4) in Section 3.2.
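To make the inner maximization in (4) concrete, the following minimal sketch estimates the dual objective ∫ϕ dρ_0^k − ∫ϕ dρ_1^k from samples for a fixed candidate potential. The linear potential `phi` and Gaussian toy populations are illustrative assumptions, not the neural potentials or data used in our method.

```python
import numpy as np

def dual_objective(phi, x0, x1):
    """Sample estimate of the Kantorovich dual objective for one population k:
    E_{rho_0}[phi] - E_{rho_1}[phi]."""
    return float(phi(x0).mean() - phi(x1).mean())

# Toy populations: rho_1 is rho_0 translated by c along the first axis.
rng = np.random.default_rng(0)
c = 2.0
x0 = rng.normal(size=(500, 2))
x1 = x0 + np.array([c, 0.0])

# For a pure translation under the Euclidean metric, phi(x) = -x[0] is an
# optimal 1-Lipschitz potential, and the dual value recovers W_1 = c.
phi = lambda x: -x[:, 0]
print(round(dual_objective(phi, x0, x1), 6))  # -> 2.0
```

In the full method, `phi` is a neural network trained to maximize this quantity subject to the Riemannian Lipschitz constraint, while the metric is trained to minimize it.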

3.2. IMPLEMENTATION

Enforcing the Lipschitz constraint. Problem (4) includes global constraints of the form ‖∇ϕ^k‖_{A^{-1}} ≤ 1. These constraints are the Riemannian analog of the Lipschitz constraint in the dual formulation of the 1-Wasserstein distance on R^D. Constraints of this type are challenging to enforce in gradient-based optimization, and the Wasserstein GAN literature has explored approximations (Arjovsky et al., 2017; Gulrajani et al., 2017; Miyato et al., 2018). We follow the standard technique introduced by Gulrajani et al. (2017) and replace the global constraints ‖∇ϕ^k‖_{A^{-1}} ≤ 1 with soft penalties of the following form:

$$\mathbb{E}_{\substack{x_0 \sim \rho_0^k, \; x_1 \sim \rho_1^k \\ t \sim U(0,1)}} \; \mathrm{SoftPlus}\!\left( \big\| \nabla \phi^k\big(\sigma_{x_0}^{x_1}(t)\big) \big\|^2_{A^{-1}(\sigma_{x_0}^{x_1}(t))} - 1 \right), \qquad (5)$$

where σ_{x_0}^{x_1}(t) := (1 − t) x_0 + t x_1 is the line segment between x_0 and x_1 parametrized by t ∈ [0, 1]. Intuitively, (5) penalizes violation of the Lipschitz constraint ‖∇ϕ^k‖_{A^{-1}} ≤ 1 along line segments connecting randomly-paired points in X_0^k and X_1^k. Gulrajani et al. (2017, Prop. 1) justify this choice via a standard result in W_1 theory showing that the constraint ‖∇ϕ^k‖ ≤ 1 binds on line segments connecting pairs of points that are matched by the Monge map s^k. Korotin et al. (2022) show that this method results in accurate approximations to the directions of the gradients ∇ϕ^k of the true Kantorovich potentials. This observation is sufficient for our purposes; as noted below, our optimization scheme encourages the low-energy eigenvectors of A(x) to be well-aligned with the solutions ∇ϕ^k to the inner problem in (4).
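A minimal sketch of the soft penalty (5), assuming analytic access to ∇ϕ and A^{-1}; in practice both come from neural networks and automatic differentiation. The linear test potential and identity metric below are illustrative assumptions used only to sanity-check the estimator.

```python
import numpy as np

def softplus(z):
    return np.log1p(np.exp(z))

def lipschitz_penalty(grad_phi, A_inv, x0, x1, t):
    """Monte Carlo estimate of penalty (5): SoftPlus(||grad phi||^2_{A^{-1}} - 1)
    at points sigma(t) = (1 - t) x0 + t x1 on randomly-paired line segments."""
    sigma = (1.0 - t[:, None]) * x0 + t[:, None] * x1   # (B, D) interpolated points
    g = grad_phi(sigma)                                  # (B, D) gradients of phi
    M = A_inv(sigma)                                     # (B, D, D) inverse metrics
    sq_norm = np.einsum('bi,bij,bj->b', g, M, g)         # ||grad phi||^2_{A^{-1}}
    return softplus(sq_norm - 1.0).mean()

# Sanity check: a linear potential phi(x) = w . x with ||w|| = 1 under the
# Euclidean metric sits exactly on the constraint boundary, so the penalty
# equals SoftPlus(0) = log 2.
rng = np.random.default_rng(0)
B, D = 64, 2
x0, x1 = rng.normal(size=(B, D)), rng.normal(size=(B, D))
t = rng.uniform(size=B)
w = np.array([1.0, 0.0])
grad_phi = lambda s: np.tile(w, (len(s), 1))
A_inv = lambda s: np.tile(np.eye(D), (len(s), 1, 1))
print(np.isclose(lipschitz_penalty(grad_phi, A_inv, x0, x1, t), np.log(2.0)))  # -> True
```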

Choice of regularizer R(A).

Without a regularizer R(A), the objective in (4) can be driven to 0 by choosing A(x) ≡ αI for arbitrarily small α > 0, thereby making all pairs of measures arbitrarily close. This trivial solution would incorporate no useful information from the observed samples from ρ_0^k and ρ_1^k, as it is simply a rescaling of the standard Euclidean metric. We opt for the following regularizer in our method:

$$R(A) = \frac{1}{K} \sum_{k=1}^K \mathbb{E}_{\substack{x_0 \sim \rho_0^k, \; x_1 \sim \rho_1^k \\ t \sim U(0,1)}} \big\| A^{-1}\big(\sigma_{x_0}^{x_1}(t)\big) \big\|_F^2, \qquad (6)$$

where σ_{x_0}^{x_1}(t) is as in the previous paragraph and ‖·‖_F^2 denotes the squared Frobenius norm. Regularizer (6) penalizes ‖A^{-1}‖_F^2 at points drawn using a sampling scheme analogous to that in (5). This is a natural way to exclude trivial solutions of the form A(x) ≡ αI for arbitrarily small α > 0, as ‖A^{-1}‖_F^2 is large for such metrics. We investigate the impact of the regularization coefficient λ on the learned metric in Appendix E. In Appendix F, we demonstrate the effect of removing the optimal transport term from (4), which is equivalent to setting λ = +∞. Optimization scheme. We parametrize the scalar potentials ϕ^k by neural networks to enable the use of gradient-based optimization to solve Problem (4). As A appears in Equations (4), (5), and (6) only via its inverse A^{-1}, we directly parametrize the matrix field A^{-1} by a neural network; where we require the evaluation of A(x) in downstream applications, we evaluate and then invert A^{-1}(x). We enforce the positive definiteness of A^{-1} (and hence of A) by parametrizing it as A^{-1}(x) = Q(x)^T Q(x) + ηI for a matrix-valued function Q : M → R^{D×D} and η > 0. After parametrizing our problem variables with neural networks, we optimize the objective in (4) via alternation. In the first phase of our scheme, we solve the inner problem by holding A fixed (initializing it as A(x) = A^{-1}(x) ≡ I) and solving for the optimal ϕ^k.
This step decouples over the potentials ϕ^k, and each of the resulting problems is an instance of the dual problem for the 1-Wasserstein distance on M. We approximate the integrals in (4) as sample means (1/N) Σ_{i=1}^N ϕ^k(x_i), where the {x_i} are samples from the distributions ρ_0^k and ρ_1^k. We likewise approximate the Lipschitz penalty (5) and the regularizer (6) as sample means over draws from ρ_0^k and ρ_1^k and over t drawn uniformly from [0, 1]. In the second phase, we solve the outer problem by fixing the optimal ϕ^k from the previous step and solving for the optimal matrix field A(x). We optimize both problems using AdamW (Loshchilov & Hutter, 2019). Few alternations are needed in practice to obtain high-quality results. The results in Appendix B show that given a fixed metric defined by A(x), the optimal ∇ϕ^k from the first phase of our scheme point along A-geodesics joining pairs of points (x, s^k(x)), where x ∼ ρ_0^k and s^k solves (2). At initialization, A ≡ I, and these geodesics are line segments in R^D connecting the matched points. Given fixed Kantorovich potentials ϕ^k from the first phase, the second phase of our scheme solves for a matrix field A(x) that minimizes the regularizer R(A) while also making ‖∇ϕ^k‖_{A^{-1}} large. Intuitively, this encourages the unit eigenvector u_1(x) corresponding to the minimal eigenvalue λ_1(x) of A(x) to be aligned with ∇ϕ^k(x) wherever the constraint is enforced.
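The positive definite parametrization A^{-1}(x) = Q(x)^T Q(x) + ηI described above can be sketched directly; here `Q` is a stand-in for the neural network mapping M → R^{D×D}, and the specific functions and the value of η are illustrative assumptions.

```python
import numpy as np

ETA = 1e-2  # the eta > 0 offset guaranteeing positive definiteness

def A_inv(x, Q, eta=ETA):
    """A^{-1}(x) = Q(x)^T Q(x) + eta * I. Since Q(x)^T Q(x) is positive
    semidefinite for any Q(x), every eigenvalue of A^{-1}(x) is >= eta > 0,
    so A(x) can be recovered by inversion wherever downstream tasks need it."""
    Qx = Q(x)
    D = Qx.shape[-1]
    return Qx.T @ Qx + eta * np.eye(D)

# Even for a degenerate (here: zero) Q, A^{-1} remains positive definite.
Q_zero = lambda x: np.zeros((3, 3))
M = A_inv(np.zeros(3), Q_zero)
print(np.linalg.eigvalsh(M).min() > 0)  # -> True
```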

4. EXPERIMENTS

In this section, we first use synthetic data to demonstrate that our algorithm successfully recovers the correct eigenspaces of a known metric A(x) from cross-sectional samples satisfying our model. We then use our method to learn a metric from cross-sectional scRNA data and show that this metric improves the accuracy of trajectory inference for scRNA data that is sampled sparsely in time. We finally show that by learning a metric from time-stamped bird sightings, we can infer curved migratory trajectories for individual birds given the initial and final points of their trajectories. Details for all experiments are provided in Appendix D.

4.1. METRIC RECOVERY

We first show that our method recovers the correct eigenspaces of a known metric A(x) from cross-sectional samples X_0^k and X_1^k satisfying the model in Section 3.1. We fix initial positions x_n(0) and final positions x_n(1) for a set of N particles x_n ∈ X and also fix a spatially-varying metric A(x). We compute A-geodesics between each pair (x_n(0), x_n(1)) and define populations X_{t_i} = {x_n(t_i) : n = 1, ..., N} for times t_i ∈ [0, 1], i = 0, ..., T. We use our method to learn the latent metric A(x) from the pairs of samples (X_{t_i}, X_{t_{i+1}}). For each example, we plot the eigenvectors of the true metric A(x) and those of the learned metric Â(x) on a P × P grid (first row, Figure 1). We also plot the eigenvectors of Â(x) along with the log-condition number log(λ_2(x)/λ_1(x)) of Â(x) (second row, Figure 1); this value is large when the learned metric is highly anisotropic. We finally report a measure of the alignment of the eigenspaces of the true metric A and its learned counterpart Â:

$$\ell(A, \hat{A}) = \frac{1}{D |\mathcal{X}|} \sum_{x \in \mathcal{X}} \sum_{d=1}^{D} \big| \langle u_d(x), \hat{u}_d(x) \rangle \big|.$$

Here u_d(x) is the unit eigenvector of A(x) with eigenvalue λ_d, û_d(x) is the corresponding eigenvector of Â(x), and X is the set of grid points at which we plot the eigenvectors of A and Â. ℓ(A, Â) = 1 when each eigenvector of A(x) is parallel to the corresponding eigenvector of Â(x) at all grid points x ∈ X. Our method accurately recovers the eigenvectors of the true metric, and the learned metric is highly anisotropic in regions that overlap the observed data. In particular, our method accurately recovers the eigenspaces of the "Circular" (ℓ(A, Â) = 0.995) and "X Paths" (ℓ(A, Â) = 0.916) metrics. It achieves lower accuracy on the "Mass Splitting" metric (ℓ(A, Â) = 0.839), struggling to capture the discontinuity in its eigenvectors at the x-axis and exhibiting numerical instability to the left of the y-axis, where the training trajectories diverge.
Note, however, that the alignment score ℓ(A, Â) is in part measured at points that do not lie on the trajectory of the training data. We would not expect our method to accurately recover the metric in these regions; the fact that it largely does so for the "Circular" and "X Paths" examples reflects desirable smoothness properties of the neural parametrization of the Kantorovich potentials and the metric tensor. Row 2 of Figure 1 shows that our learned metric is highly anisotropic in regions that overlap with the training data. In the "Mass Splitting" and "X Paths" examples, Â(x) also has small condition number near the origin where the two paths diverge and cross, respectively; this behavior reflects expected uncertainty in low-energy directions of motion in this region.
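The alignment score ℓ(A, Â) can be computed directly from eigendecompositions. The sketch below is one way to implement it; the randomly generated SPD matrices stand in for the true and learned metric fields evaluated on the grid.

```python
import numpy as np

def alignment_score(A_grid, Ahat_grid):
    """l(A, Ahat) = (1 / (D |X|)) sum_x sum_d |<u_d(x), uhat_d(x)>|,
    where u_d / uhat_d are unit eigenvectors sorted by ascending eigenvalue."""
    scores = []
    for A, Ahat in zip(A_grid, Ahat_grid):
        _, U = np.linalg.eigh(A)        # columns: eigenvectors of A
        _, Uhat = np.linalg.eigh(Ahat)  # columns: eigenvectors of Ahat
        # |<u_d, uhat_d>| per d; the absolute value ignores sign ambiguity
        scores.append(np.abs((U * Uhat).sum(axis=0)).mean())
    return float(np.mean(scores))

# A metric field compared with itself is perfectly aligned: l(A, A) = 1.
rng = np.random.default_rng(0)
grid = []
for _ in range(5):
    B = rng.normal(size=(2, 2))
    grid.append(B @ B.T + 0.1 * np.eye(2))  # random SPD matrix per grid point
print(round(alignment_score(grid, grid), 6))  # -> 1.0
```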

4.2. POPULATION-LEVEL TRAJECTORY INFERENCE WITH A LEARNED METRIC

scRNA sequencing (scRNA-seq) allows biologists to study the set of mRNA molecules (the "transcriptome") of individual cells at high resolution (Haque et al., 2017). As scRNA-seq is a destructive process, any individual cell's RNA can be sequenced only once, impeding the use of this technology to study the change in a cell's transcriptome over time. This has led to interest in methods that use population-level scRNA-seq data to infer the temporal evolution of an individual cell's scRNA-seq profile. Optimal transport-based techniques are well-established tools for solving these trajectory inference problems. Schiebinger et al. (2019) and Tong et al. (2020) both use optimal transport to infer cellular development trajectories but assume a Euclidean metric on gene expression space. In this section, we use our method to relax this strong assumption by learning a metric for the ground space from scRNA data. We incorporate the learned metric Â(x) into a downstream trajectory inference task and show that this strategy yields more accurate trajectories than both a baseline without a learned metric and a baseline using the Euclidean metric. scRNA data. We perform trajectory inference experiments with the scRNA data of Schiebinger et al. (2019). This data consists of force-directed layout embedding coordinates of gene expression data collected over 18 days of reprogramming (39 time points in total). We construct populations X_{t_i} for i = 1, ..., 39 by drawing 500 samples per time point from the original data; this sampling uses 8.25% of the available data on average. We follow the same procedure as in Section 4.1 to learn a metric tensor Â(x) from subsequent pairs of samples (X_{t_i}, X_{t_{i+1}}). Learning the tensor takes 16 minutes on a single V100 GPU. For the downstream trajectory inference task, we keep one out of every k time points in the original data for k = 2, ..., 19 to obtain a new collection of time points t̃_ℓ with ℓ ∈ {nk : n ∈ N, n < 39/k}.
We then perform trajectory inference between subsequent retained time points (t̃_ℓ, t̃_{ℓ+1}) (using all of the available data for these time points) by optimizing the following objective:

$$\min_{\theta} \; S_{\epsilon}\big( X_{\tilde{t}_{\ell+1}}, \; \phi_{t=1}(v_{\theta})[X_{\tilde{t}_{\ell}}] \big) + \lambda \sum_{x \in X_{\tilde{t}_{\ell}}} \sum_{j=1}^{m} \big\| v_{t,\theta}(x(t_j)) \big\|^2_{\hat{A}(x(t_j))}. \qquad (7)$$

Here S_ε is the Sinkhorn divergence (Feydy et al., 2019) between the target data X_{t̃_{ℓ+1}} and the advected samples ϕ_{t=1}(v_θ)[X_{t̃_ℓ}]; ϕ_{t=1}(v_θ) denotes the solution map resulting from advecting a particle through a neurally-parametrized time-varying velocity field v_{t,θ} for one unit of time; and t_j ∈ [0, 1]. The Sinkhorn divergence used here is S_ε(α, β) = W_ε(α, β) − ½ W_ε(α, α) − ½ W_ε(β, β); it is a debiased variant of the entropically-regularized 2-Wasserstein distance W_ε. This objective is similar to the method of Tong et al. (2020), with the log-likelihood fitting loss replaced by the Sinkhorn divergence; we found that the Sinkhorn divergence led to stabler training. We use GeomLoss (Feydy et al., 2019) to compute the Sinkhorn divergence between the target and advected samples efficiently. We follow Tong et al. (2020) and assess the quality of inferred trajectories by measuring the W_1 distance (EMD) between left-out time points in the ground truth data and advected samples at corresponding times in the inferred trajectories. These distances are of the form W_1(X_{t_i}, ϕ_{t=t̃_i}(v_θ)[X_{t̃_ℓ}]), so we compare ground truth samples at each left-out time t_i = nk + h (h = 1, ..., k − 1) with samples from X_{t̃_ℓ} advected through v_θ for time t̃_i = h/k. For each value of k, we record the average EMD between left-out samples and their inferred counterparts and compare our method to a baseline where λ = 0 in (7) ("Without A" in Figure 2) and a Euclidean baseline where A(x) ≡ I ("A = Id" in Figure 2). Our "Without A" and "A = Id" baselines are comparable to the "Base" and "Base + E" models, respectively, from Tong et al. (2020).
Both our baseline models and their models learn a velocity field v_{t,θ} that pushes samples at some time t̃_ℓ onto samples at a future time t̃_{ℓ+1} and use the path followed by these samples as they flow through v_{t,θ} as an inferred trajectory. Our "Without A" baseline and their "Base" model do not regularize v_{t,θ}, whereas our "A = Id" baseline and their "Base + E" model regularize v_{t,θ} by penalizing its squared norm, which encourages samples to flow along straight paths. We also compare against interpolants obtained via the method of Schiebinger et al. (2019). As their Waddington OT method solves a static optimal transport problem and pushes cells through transport maps between fixed time steps, a direct comparison against our dynamical method is not possible. However, we attempt to replicate their validation by geodesic interpolation as closely as possible. We compute static optimal transport couplings between data at subsequent retained time points (t̃_ℓ, t̃_{ℓ+1}), linearly interpolate between coupled data points, and measure the W_1 distance (EMD) between left-out time points in the ground truth data and advected samples at corresponding times in the interpolated trajectories. Figure 2 shows that our learned metric improves on the "Without A" baseline (analogous to the "Base" model of Tong et al. (2020)) and the approach based on Schiebinger et al. (2019) for nearly all values of k, whereas the Euclidean baseline A(x) ≡ I (analogous to the "Base + E" model of Tong et al. (2020)) generally results in less accurate trajectories than the non-regularized baseline. (Figure 3: Trajectories inferred using our learned metric tensor (second column) more closely follow the manifold structure of the ground truth data than the non-regularized baseline trajectories (first column), where particles follow nearly straight-line paths between observed time points.) As expected, the gap between our results and both baselines increases for large values of k; for example,
using our learned metric results in a 19.7% mean reduction in EMD between left-out samples and the corresponding advected samples for k ≥ 10. This observation indicates that including a learned metric has a larger impact on the inferred trajectories when the ground truth data is sampled sparsely in time. Figure 3 compares the trajectories inferred using our method to the non-regularized baseline and ground truth trajectories for k = 3, where the time sampling is sufficiently dense that the baseline performs well, and k = 15, where our learned metric tensor substantially improves the quality of the inferred trajectories. Whereas the non-regularized baseline simply advects particles from the base measures to the corresponding targets along nearly-straight paths, the learned metric biases the trajectory to follow the highly-curved paths taken by the ground truth data. These results suggest that in settings where data collection is expensive and samples collected at a small subset of times are of primary interest, our method enables the plausible inference of particle positions at intermediate time points at the cost of little additional data and computation.
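The debiased Sinkhorn divergence S_ε used in the fitting term of (7) can be stated in a self-contained way. The sketch below runs log-domain Sinkhorn iterations over uniform empirical measures with a squared Euclidean ground cost and returns the transport cost ⟨P, C⟩ under the approximate entropic plan; it is a didactic stand-in for the GeomLoss implementation we actually use, and the convention for W_ε (transport cost under the entropic plan, without the entropy term) is an assumption of this sketch.

```python
import numpy as np

def _logsumexp(Z, axis):
    m = Z.max(axis=axis, keepdims=True)
    return np.squeeze(m, axis) + np.log(np.exp(Z - m).sum(axis=axis))

def entropic_cost(x, y, eps=0.5, iters=300):
    """Approximate W_eps between uniform empirical measures: run log-domain
    Sinkhorn on the squared Euclidean cost C, return <P, C> for the plan P."""
    C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
    n, m = C.shape
    loga, logb = -np.log(n), -np.log(m)
    f, g = np.zeros(n), np.zeros(m)
    for _ in range(iters):  # alternating dual (potential) updates
        f = -eps * _logsumexp((g[None, :] - C) / eps + logb, axis=1)
        g = -eps * _logsumexp((f[:, None] - C) / eps + loga, axis=0)
    P = np.exp((f[:, None] + g[None, :] - C) / eps + loga + logb)
    return float((P * C).sum())

def sinkhorn_divergence(x, y, eps=0.5):
    """Debiased divergence: S_eps(a, b) = W_eps(a, b) - W_eps(a, a)/2 - W_eps(b, b)/2."""
    return (entropic_cost(x, y, eps)
            - 0.5 * entropic_cost(x, x, eps)
            - 0.5 * entropic_cost(y, y, eps))

rng = np.random.default_rng(0)
x = rng.normal(size=(50, 2))
y = x + np.array([3.0, 0.0])  # a well-separated translate of x
print(abs(sinkhorn_divergence(x, x)) < 1e-9, sinkhorn_divergence(x, y) > 1.0)  # -> True True
```

The debiasing terms remove the entropic bias that would otherwise make W_ε(α, α) nonzero, so S_ε(α, α) = 0 by construction.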

4.3. INDIVIDUAL-LEVEL TRAJECTORY INFERENCE WITH A LEARNED METRIC

In this section, we learn a metric Â(x) from time-stamped sightings of snow geese during their spring migration. We then compute Â-geodesics between the initial and final points of GPS-tagged snow goose trajectories and show that these provide a reasonable coarse approximation to the geese's ground truth trajectories. Migratory paths as geodesics on a latent manifold. Many migratory bird species return to their summer breeding sites and wintering grounds with high spatial fidelity, sometimes returning to within 500 meters of their usual breeding location (Mowbray et al., 2020). Given this behavior, a boundary value problem that fixes the endpoints of the birds' migratory trajectories is a reasonable model for bird migration. Once the endpoints are fixed, we must specify an objective that birds plausibly optimize to determine their trajectories. Absent further information, minimizing total kinetic energy over the trajectory is a reasonable choice. However, measuring kinetic energy with respect to the Euclidean metric leads to straight-line paths, which typically do not agree with observed migration trajectories. We therefore use our method to learn a metric from untagged snow goose sightings, yielding a notion of kinetic energy that agrees with the observed data for this species. This metric summarizes the factors that birds may use to modify their migratory paths locally, such as local weather conditions, food availability, and predatory pressures.

Snow goose data.

The training data for this experiment consists of time-stamped sightings of untagged snow geese (Anser caerulescens) across the U.S. and Canada during their spring migration. We treat the sightings at time t as samples from a time-indexed spatial distribution of birds ρ_t and train our metric on subsequent pairs of bird distributions (ρ_t, ρ_{t+1}). This implies that our method does not have access to any complete goose trajectories when learning the metric. We do not expect untagged goose sightings to contain enough information to predict high-frequency detail in migratory trajectories, but this data is far cheaper to obtain than complete migratory trajectories, which are typically recorded via GPS trackers attached to individual geese. The widespread availability and low cost of obtaining untagged bird sighting data motivate the use of our method for bird trajectory inference. (Figure 4: By using a metric Â(x) learned from time-stamped bird sightings, we obtain inferred trajectories (center) that capture the curved structure of the ground truth migratory paths (left). Our method results in a 26.9% reduction in mean DTW distance between the inferred and ground truth trajectories relative to the Euclidean baseline (right).) The training data is drawn from the eBird basic dataset (Sullivan et al., 2009), current as of February 2022. We bin the sightings by month of observation and use our algorithm to learn a metric tensor Â(x) from populations X_i consisting of the spatial coordinates of snow goose sightings in month i for i = 0, ..., 5 (i.e., January to June). This training data is depicted in Figure 5 in Appendix D.3. We use this learned metric tensor to compute an Â-geodesic between the initial and final observations of several snow geese along their migratory paths. This data is drawn from the Banks Island Snow Goose study as hosted on Movebank (Kays et al., 2021).
We estimate an Â-geodesic between the initial and final points on each path by learning a time-varying velocity field v_{t,θ} optimizing the following objective:

$$\min_\theta \; \|x_1 - x(1)\|^2 + \mu \sum_{j=1}^{m} \|v_{t,\theta}(x(t_j))\|^2_{\hat{A}(x(t_j))}. \tag{8}$$

As in Section 4.2, the velocity field v_{t,θ} is neurally parametrized, 0 = t_0 ≤ ⋯ ≤ t_j ≤ ⋯ ≤ t_m = 1, and we optimize (8) using AdamW. We set the initial condition x(0) = x_0, where x_0 is the goose's initial position on its migratory trajectory, and use a Euclidean norm penalty to force its final position x(1) along the inferred trajectory to match the true final position x_1. Figure 4 compares the Â-geodesics obtained using the method described above (center plot) to the ground truth goose trajectories (left plot). While population-level data (in the form of untagged snow goose sightings; see Figure 5 in Appendix D.3) does not provide sufficient information to reconstruct the migratory paths of individual geese perfectly, the inferred trajectories accurately capture the curved structure of the ground truth trajectories. We evaluate our method's performance by computing the dynamic time warping (DTW) distance (Berndt & Clifford, 1994) between each inferred goose trajectory and the corresponding ground truth trajectory. We also generate straight-line paths (i.e., Euclidean geodesics) between the initial and final points of each ground truth goose trajectory and compute DTW distances between these paths and the ground truth. Our method results in a 26.9% reduction in mean DTW distance between the inferred and ground truth trajectories relative to the Euclidean baseline.
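The objective above can be optimized directly in PyTorch. The following is a minimal sketch under assumed interfaces — `v_net` (the neural velocity field) and `A_field` (the learned metric) are hypothetical stand-ins — and it discretizes the trajectory with an explicit Euler scheme rather than the ODE solvers used in the experiments.

```python
import torch

def metric_norm_sq(A_field, x, v):
    """Squared A(x)-norm v^T A(x) v; A_field maps a point (d,) to a (d, d) matrix."""
    return v @ (A_field(x) @ v)

def geodesic_loss(v_net, A_field, x0, x1, n_steps=32, mu=1.0):
    """Endpoint penalty ||x1 - x(1)||^2 plus mu times the discretized
    metric energy of the velocity field along the rolled-out path.
    Illustrative sketch, not the authors' implementation."""
    ts = torch.linspace(0.0, 1.0, n_steps + 1)
    x, energy = x0, 0.0
    for j in range(n_steps):
        t = ts[j].expand(1)
        v = v_net(torch.cat([t, x]))      # time-varying velocity at x(t_j)
        energy = energy + metric_norm_sq(A_field, x, v)
        x = x + (ts[j + 1] - ts[j]) * v   # explicit Euler step
    return (x - x1).pow(2).sum() + mu * energy / n_steps
```

In training, one would repeatedly backpropagate this loss and step an AdamW optimizer over `v_net`'s parameters, as in the paper's experiments.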

5. DISCUSSION AND CONCLUSION

We have introduced an optimal transport-based method for learning a Riemannian metric from cross-sectional samples of populations evolving over time on a latent Riemannian manifold. Our method accurately recovers metrics from cross-sections of populations moving along geodesics on a manifold, improves the quality of trajectory inference on sparsely-sampled scRNA data at low data and compute cost, and allows us to approximate individual trajectories of migrating birds using information from untagged sightings. One key limitation of our work is that it learns a Riemannian metric on R^D, whereas large swaths of data naturally lie on graphs. Future work might consider extending this algorithm to learn from data defined on graphs; recent work by Heitz et al. (2021) presents an optimal transport-based direction toward solving this problem. Future work may also extend our methods to learn a metric from temporal snapshots of Schrödinger bridges.

B BACKGROUND ON OPTIMAL TRANSPORT ON RIEMANNIAN MANIFOLDS

Given compactly-supported densities ρ_0, ρ_1 on M, Monge's problem seeks a pushforward s of ρ_0 onto ρ_1 solving the following problem:

$$W_{1,A}(\rho_0, \rho_1) = \inf_{s:\, \rho_1 = s_\#\rho_0} \int_M d_A(x, s(x))\, d\rho_0(x). \tag{9}$$

Feldman & McCann (2002) show that under mild technical conditions, (9) has a possibly non-unique solution s. Intuitively, this map is a pushforward of ρ_0 onto ρ_1 minimizing the average geodesic distance between matched units of probability mass. The optimal value of (9) is the 1-Wasserstein distance between ρ_0 and ρ_1, which we denote by W_{1,A}(ρ_0, ρ_1). If ϕ is continuously differentiable on M, then it is 1-Lipschitz on M if and only if ∥∇ϕ(x)∥_{A^{-1}(x)} ≤ 1 for all x ∈ M. Here, we use ∇ϕ(x) to denote the vector of partial derivatives of ϕ at x. Armed with this local characterization of Lipschitz continuity, we may define the following dual problem to (9):

$$\sup_{\substack{\phi: M \to \mathbb{R} \\ \|\nabla\phi(x)\|_{A^{-1}(x)} \le 1}} \int_M \phi(x)\, d\rho_0(x) - \int_M \phi(x)\, d\rho_1(x). \tag{10}$$

If the supremum in (10) is attained by some Kantorovich potential ϕ, the optimal values of (9) and (10) coincide, and the Lipschitz bound for ϕ is tight at pairs of points (x, y) arising from a Monge map s solving (9) (Feldman & McCann, 2002, Lemma 4). Feldman & McCann (2002, Lemma 10) also show that given a geodesic γ_x : [0, 1] → M between a pair of matched points (x, s(x)) and t ∈ (0, 1), we have

$$\nabla\phi(\gamma_x(t)) = -\frac{\dot\gamma_x(t)}{\|\dot\gamma_x(t)\|_{A(\gamma_x(t))}}.$$

Intuitively, ∇ϕ points along geodesics γ_x joining pairs of points in the support of ρ_0 and ρ_1 that are matched by the Monge map s.
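The anisotropic Lipschitz condition above is easy to evaluate pointwise with automatic differentiation. Below is a small sketch assuming hypothetical callables `phi` (a scalar potential) and `A_inv_field` (the inverse metric at a point); this is one way to monitor or penalize the constraint ∥∇ϕ(x)∥_{A^{-1}(x)} ≤ 1.

```python
import torch

def dual_grad_norm(phi, A_inv_field, x):
    """Compute ||grad phi(x)||_{A^{-1}(x)} = sqrt(grad^T A^{-1}(x) grad)
    for a scalar-valued potential phi at a single point x of shape (d,).
    Sketch under assumed interfaces, not the authors' code."""
    x = x.detach().clone().requires_grad_(True)
    (grad,) = torch.autograd.grad(phi(x), x)
    A_inv = A_inv_field(x.detach())
    return torch.sqrt(grad @ (A_inv @ grad))
```

A soft gradient penalty could then be formed, for example, from `torch.relu(dual_grad_norm(...) - 1) ** 2` averaged over samples.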

C BACKGROUND ON DYNAMICAL OPTIMAL TRANSPORT

Let α and β be probability measures defined on Ω ⊆ R^D. The Benamou-Brenier formulation of the 2-Wasserstein distance W_2(α, β) defines it as the solution to a fluid-dynamical problem (Benamou & Brenier, 2000):

$$W_2^2(\alpha, \beta) = \min_{\rho_t, v_t} \int_0^1 \int_\Omega \rho_t(x) \|v_t(x)\|_2^2 \, dx\, dt, \tag{11}$$

subject to the constraints ρ_0 = α, ρ_1 = β, and the continuity equation ∂ρ_t/∂t + ∇·(ρ_t v_t) = 0. Solving this problem yields a time-varying velocity field v_t(x) that transports α's mass to β's along a curve ρ_t of probability measures. The kinetic energy in the integrand of (11) encourages probability mass to travel in straight lines. This assumption is sometimes undesirable, but it is straightforward to modify Eq. (11) to encourage v to point in specified directions (Hug et al., 2015):

$$\min_{\rho_t, v_t} \int_0^1 \int_\Omega v_t(x)^T A(x) v_t(x)\, \rho_t(x)\, dx\, dt. \tag{12}$$

Using the language developed in Appendix A, A(x) ≻ 0 specifies a metric g on the Riemannian manifold M = (R^D, g), and v_t(x)^T A(x) v_t(x) = ∥v_t(x)∥²_{A(x)}. Eq. (12) encourages v to be aligned with the eigenvectors u_1 corresponding to the minimal eigenvalues λ_1 of the matrices A(x). While works like Hug et al. (2015); Zhang et al. (2022) investigate the modeling applications of anisotropic optimal transport, they only consider the case where the Riemannian metric A(x) is available a priori. This assumption is unrealistic for many problem domains, motivating our model, which learns a metric from cross-sectional samples of populations evolving over time.
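The integrand of Eq. (12) replaces the Euclidean squared speed with the metric quadratic form; over a batch of samples this reduces to a single einsum. A minimal sketch (the batched `A_field` interface is an assumption for illustration, not the paper's API):

```python
import torch

def anisotropic_kinetic_energy(A_field, xs, vs):
    """Monte Carlo estimate of the inner integral of Eq. (12):
    E_{x ~ rho}[v(x)^T A(x) v(x)] from samples xs (N, d), velocities
    vs (N, d), and a metric field A_field mapping (N, d) -> (N, d, d)."""
    quad = torch.einsum('ni,nij,nj->n', vs, A_field(xs), vs)  # per-sample v^T A v
    return quad.mean()
```

With A(x) ≡ I this recovers the ordinary kinetic energy in Eq. (11), which is a quick sanity check when plugging in a learned metric.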

To generate the training data, we begin by drawing 100 samples each from 4 isotropic normal distributions with standard deviation σ = 0.1 whose means are μ_i ∈ {(1, 0), (−1, 0), (−1, −1), (0, 1)}. We randomly pair samples from subsequent distributions and compute A-geodesics between each pair by solving problem (8). We implement (8) in PyTorch using a time-invariant vector field v_θ parametrized by a fully connected two-layer neural network with ELU nonlinearities and 64 hidden dimensions. We set λ = 1 and solve the initial value problem ẋ(t) = v_θ(x(t)), x(0) = x_0 using the explicit Adams solver in torchdiffeq's odeint with default hyperparameters (Chen et al., 2018a). We optimize the objective using AdamW with learning rate 10^{-3} and weight decay 10^{-3} and train for 100 epochs per pair of samples. We then draw 24 points at equispaced times t_i ∈ [0, 1] from each resulting geodesic and aggregate across geodesics to form the observed populations X_{t_i}. We then use our method to recover Â^{-1}(x) from the X_{t_i}. We parametrize the scalar potentials in (4) as a single-hidden-layer neural net with 32 hidden dimensions and Softplus activation. We parametrize the matrix field Â^{-1}(x) as Â^{-1}(x) = Q(x)^T Q(x) + 10^{-3} I, where Q(x) is a two-layer neural network with Softplus activations and 32 hidden dimensions. The strength of the gradient penalty (5) is 10^{-3} when training the potentials ϕ in the first step of alternation and 10^{-4} in the next step; it is 1 when training Â^{-1}. The strength of the regularization (6) is 10^3. We carry out two steps of the alternating scheme, training with AdamW with learning rate 10^{-2} and weight decay 5 × 10^{-1}. We train for 300 epochs for the ϕ step and 1,000 epochs for the A step. We evaluate the alignment score ℓ(A, Â) on a 100 × 100 grid overlaid on a box of radius 1.5.

Mass splitting data.
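The matrix-field parametrization Â^{-1}(x) = Q(x)^T Q(x) + 10^{-3} I described above can be written as a small PyTorch module. This is an illustrative sketch, not the authors' code; the layer sizes follow the text, and the two Softplus layers are an assumption about the exact architecture.

```python
import torch
from torch import nn

class PSDMatrixField(nn.Module):
    """Positive definite matrix field A(x) = Q(x)^T Q(x) + eps * I,
    mirroring the parametrization described in the text (sketch)."""

    def __init__(self, dim=2, hidden=32, eps=1e-3):
        super().__init__()
        self.dim, self.eps = dim, eps
        self.q_net = nn.Sequential(
            nn.Linear(dim, hidden), nn.Softplus(),
            nn.Linear(hidden, hidden), nn.Softplus(),
            nn.Linear(hidden, dim * dim),
        )

    def forward(self, x):
        # x: (N, dim) -> A: (N, dim, dim), symmetric positive definite
        Q = self.q_net(x).view(-1, self.dim, self.dim)
        eye = torch.eye(self.dim, device=x.device)
        return Q.transpose(1, 2) @ Q + self.eps * eye
```

The additive eps * I term guarantees positive definiteness even where Q(x) is rank-deficient, which keeps the metric norms and their inverses well-defined during training.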
The ground truth metric tensor for this example is A(x, y) = I − v(x, y)v(x, y)^T, where v(x, y) = (1/√2, 1/√2) for y ≥ 0 and v(x, y) = (1/√2, −1/√2) for y < 0. To generate the training data, we begin by drawing 100 samples each from a standard normal distribution and from a mixture of two isotropic normal distributions with unit variance and mixture components centered at (10, 10) and (10, −10). We randomly pair samples from subsequent distributions and compute A-geodesics between each pair by solving problem (8) using the method described in the circular data section. We then draw 10 points at equispaced times t_i ∈ [0, 1] from each resulting geodesic and aggregate across geodesics to form the observed populations X_{t_i}. We then use our method to recover Â^{-1}(x) from the X_{t_i}. We parametrize the scalar potentials in (4) as a single-hidden-layer neural net with 32 hidden dimensions and Softplus activation. We parametrize the matrix field Â^{-1}(x) as Â^{-1}(x) = Q(x)^T Q(x) + 10^{-3} I, where Q(x) is a two-layer neural network with Softplus activations and 32 hidden dimensions. The strength of the gradient penalty (5) is 2 when training the potentials ϕ and 1 when training Â. The strength of the regularization (6) is 10^6. We carry out a single step of the alternating scheme, training with AdamW with learning rate 5 × 10^{-3} and weight decay 10^{-3}. We train for 600 epochs for the ϕ step and 20,000 epochs for the A step. We evaluate the alignment score ℓ(A, Â) on a 100 × 100 grid overlaid on the rectangular region [−2.5, 15] × [−15, 15].

X-Paths data. The ground truth metric tensor for this example is A(x, y) = I − v(x, y)v(x, y)^T. Here we define v(x, y) = α(x, y)v_1(x, y) + β(x, y)v_2(x, y), where v_1(x, y) = (1/√2, 1/√2) and v_2(x, y) = (1/√2, −1/√2). We then define α(x, y) = 1.25 tanh(ReLU(x · y)) and β(x, y) = −1.25 tanh(ReLU(−x · y)). Intuitively, α should be large in quadrants 1 and 3 and β should be large in quadrants 2 and 4.
We bin the sightings by month of observation, keep 1,000 records per month, and convert the sighting locations from latitude/longitude to xy coordinates using the Mercator projection implemented in Matplotlib's Basemap module, with rsphere set to 5 and the latitude/longitude of the lower-left and upper-right corners of the projection set to the minimum latitude and longitude and maximum latitude and longitude in the training data, respectively. The ground truth goose trajectories are drawn from the "Banks Island SNGO" study on Movebank. This data consists of time-stamped GPS measurements of the locations of 8 snow geese from 2019 to 2022. We use the same Mercator projection to convert the goose location measurements from latitude/longitude to xy coordinates. We then estimate the initial and final time point of a single migration for each goose. For geese with IDs {82901, 82902, 82905, 82906, 82907, 82908, 82909, 82910}, the respective initial time indices are {20000, 0, 0, 20000, 0, 0, 0, 0} and the respective final time indices are {26000, 9100, 15057, 26500, 9000, 13037, 7201, 10000}. The initial location of each goose is the initial condition x_0 in (8), and the final location of each goose is x_1.

Learning the metric. We use our method to learn a metric Â^{-1}(x) from the X_{t_i}. We parametrize the scalar potentials in (4) as a single-hidden-layer neural net with 32 hidden dimensions and a Softplus activation. We parametrize the matrix field Â^{-1}(x) as Â^{-1}(x) = Q(x)^T Q(x) + 10^{-3} I, where Q(x) is a two-layer neural network with Softplus activations and 32 hidden dimensions. The strength of the gradient penalty (5) is 10^{-6} when training the potentials ϕ and 1 when training Â^{-1}. The strength of the regularization (6) is 10^9. We carry out a single step of the alternating scheme, training with AdamW with learning rate 10^{-2} and weight decay 10^{-3}.
We train for 2,000 epochs for the ϕ step and 10,000 epochs for the A step.

Figure 5: The untagged goose sighting data used to learn the metric for the bird trajectory inference experiments. Each subplot depicts goose sightings in the U.S. and Canada during one month of the spring migration, beginning in January (upper left) and ending in June (bottom right). Note that there is no correspondence between points in subsequent time points; any given goose is likely to have been sighted only once.

Trajectory inference task. We parametrize the time-varying velocity field v_{t,θ} as a fully connected three-hidden-layer neural network with 64 hidden dimensions and ELU nonlinearities. We compute particle trajectories x(t) by solving the initial value problem ẋ(t) = v_θ(x(t)), x(0) = x_0 for initial positions x_0 ∈ X_{t_i} using the dopri5 solver in torchdiffeq's odeint with default hyperparameters (Chen et al., 2018a).
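The experiments integrate ẋ(t) = v_θ(x(t)) with torchdiffeq's odeint; for illustration, here is a dependency-free explicit-Euler version of the same rollout. The `velocity_net` interface (taking time concatenated with position) is an assumption, and the fixed-step Euler scheme is a coarse stand-in for the adaptive dopri5 solver.

```python
import torch

def rollout_euler(velocity_net, x0, n_steps=32):
    """Integrate x'(t) = v(t, x) from x(0) = x0 over [0, 1] with explicit
    Euler. Returns the whole trajectory of shape (n_steps + 1, N, d).
    Sketch only; the paper's experiments use torchdiffeq's odeint."""
    ts = torch.linspace(0.0, 1.0, n_steps + 1)
    traj, x = [x0], x0
    for j in range(n_steps):
        t_col = ts[j].expand(x.shape[0], 1)           # broadcast time to batch
        v = velocity_net(torch.cat([t_col, x], dim=1))
        x = x + (ts[j + 1] - ts[j]) * v               # Euler step
        traj.append(x)
    return torch.stack(traj)
```

Swapping this for `odeint(func, x0, ts, method='dopri5')` changes only the integrator, not the learned velocity field.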



Figure 1: Row 1: Eigenvectors of the true metric A(x) (purple) and the learned metric Â(x) (orange). Row 2: Log-condition number of the learned metric Â(x); yellow indicates highly anisotropic Â(x). The points are time samples from which Â(x) is recovered. Teal points indicate earlier times and yellow points indicate later times, so the pairs of temporal samples (X_{t_i}, X_{t_{i+1}}) follow the color gradient. Our method accurately recovers the eigenvectors of the true metric, and the learned metric is highly anisotropic in regions that overlap the observed data.

Figure 2: Mean EMD between left-out samples and corresponding advected samples versus k, with our learned metric (orange), the Euclidean metric A(x) ≡ I (green), no regularizer (blue), and interpolants obtained using Schiebinger et al. (2019)'s method (red).

Figure 3: Comparison of inferred trajectories for k = 3 (first row) and k = 15 (second row). Trajectories inferred using our learned metric tensor (second column) more closely follow the manifold structure of the ground truth data than the non-regularized baseline trajectories (first column), where particles follow nearly straight-line paths between observed time points.

Circular data. The ground truth metric tensor for this example is A(x, y) = I − v(x, y)v(x, y)^T, where v(x, y) = (−y/√(x² + y²), x/√(x² + y²)).

Figure 6: Results of an ablation study of the impact of the regularization coefficient λ in (4). The eigenvectors of the learned metric Â(x) are robust to the value of λ (rows 1 and 3). The log-condition number of Â(x) increases with λ (rows 2 and 4), indicating that large values of λ lead to increased anisotropy.

The MIT Geometric Data Processing group acknowledges the generous support of Army Research Office grants W911NF2010168 and W911NF2110293, of Air Force Office of Scientific Research award FA9550-19-1-031, of National Science Foundation grants IIS-1838071 and CHS-1955697, from the CSAIL Systems that Learn program, from the MIT-IBM Watson AI Laboratory, from the Toyota-CSAIL Joint Research Center, from a gift from Adobe Systems, and from a Google Research Scholar award.

A BACKGROUND ON RIEMANNIAN GEOMETRY

A Riemannian manifold M = (M, g) is a differentiable manifold M equipped with a Riemannian metric g. Throughout this paper, we take M = R^D and focus our attention on the metric g. A Riemannian metric assigns an inner product ⟨•, •⟩_p to the tangent space T_p M of each p ∈ M in a way that varies smoothly with p. In the case where M = R^D, we may simply identify all tangent spaces T_p M with R^D; here a metric g amounts to a spatially-varying inner product on R^D. Since any inner product on R^D may be computed as ⟨u, v⟩_A = u^T Av for some A ≻ 0, a Riemannian metric g on M = R^D is specified by a smooth field of positive definite matrices A(x) : R^D → S^D_{++} (we use S^D_{++} to denote the set of positive definite D × D matrices). We write ⟨u, v⟩_{A(x)} = u^T A(x)v for this inner product at x and ∥u∥_{A(x)} = √(u^T A(x)u) for the induced norm. Given this norm, we define the length of a continuously differentiable curve γ : [0, 1] → M to be

$$\ell(\gamma) = \int_0^1 \|\dot\gamma(t)\|_{A(\gamma(t))}\, dt.$$

The training data for this example consists of two trajectories. To generate it, we begin by drawing 100 samples each from isotropic normal distributions with standard deviation σ = 0.1 centered at (−1, −1), (1, 1) for the first trajectory and (−1, 1), (1, −1) for the second trajectory. We randomly pair samples from subsequent distributions along each trajectory and compute A-geodesics between each pair by solving problem (8) using the method described in the circular data section. We then draw 10 points at equispaced times t_i ∈ [0, 1] from each resulting geodesic and aggregate across geodesics to form the observed populations X_{t_i}. We then use our method to recover Â(x) from the X_{t_i}. We parametrize the scalar potentials in (4) as a single-hidden-layer neural net with 32 hidden dimensions and Softplus activation. We parametrize the matrix field Â(x) as Â(x) = Q(x)^T Q(x) + 10^{-3} I, where Q(x) is a two-layer neural network with Softplus activations and 32 hidden dimensions. The strength of the gradient penalty (5) is 10^{-3} when training the potentials ϕ in the first alternating step and 10^{-4} for subsequent steps; it is 1 when training Â. The strength of the regularization (6) is 10^3. We carry out three steps of the alternating scheme, training with AdamW with learning rate 10^{-2} and weight decay 5 × 10^{-3}. We train for 300 epochs for the ϕ step and 40,000 epochs for the A step. We evaluate the alignment score ℓ(A, Â) on a 100 × 100 grid overlaid on a box of radius 1.5.
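The curve-length definition from Appendix A discretizes naturally: approximate ℓ(γ) by summing segment lengths measured in the local metric. A minimal sketch (the batched `A_field` signature mapping (N, d) points to (N, d, d) matrices is an assumption for illustration):

```python
import torch

def curve_length(A_field, gamma):
    """Discrete Riemannian length of a polyline gamma of shape (T, d):
    sum_j ||gamma_{j+1} - gamma_j||_{A(gamma_j)}, a first-order
    approximation of l(gamma) = int ||gamma'(t)||_{A(gamma(t))} dt."""
    diffs = gamma[1:] - gamma[:-1]                   # (T-1, d) segment vectors
    A = A_field(gamma[:-1])                          # metric at left endpoints
    sq = torch.einsum('ni,nij,nj->n', diffs, A, diffs)
    return torch.sqrt(sq.clamp_min(0)).sum()
```

With A(x) ≡ I this reduces to ordinary polyline arc length, which gives a quick correctness check.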

D.2 SCRNA TRAJECTORY INFERENCE

Data pre-processing. The data for this experiment consists of force-directed layout embedding coordinates of gene expression data from Schiebinger et al. (2019), collected over 18 days of reprogramming (39 time points total), which we rescale by a factor of 10^{-3} for increased stability in training. We construct populations X_{t_i} for i = 1, ..., 39 by randomly drawing 500 samples per time point in the original data; this corresponds to using 8.25% of the available data on average.

Learning the metric. We use our method to learn a metric Â^{-1}(x) from the X_{t_i}. We parametrize the scalar potentials in (4) as a single-hidden-layer neural net with 128 hidden dimensions and Softplus activation. We parametrize the matrix field Â^{-1}(x) as Â^{-1}(x) = Q(x)^T Q(x) + 10^{-9} I, where Q(x) is a single-hidden-layer neural network with Softplus activation and 2048 hidden dimensions. We omit the gradient penalty (5) when training the potentials ϕ and set the penalty strength to 10 when training Â. The strength of the regularization (6) is 5 × 10^2. We carry out a single step of the alternating scheme, training with AdamW with learning rate 5 × 10^{-3} and weight decay 1.5 × 10^{-2}. We train for 100 epochs for the ϕ step and 5,000 epochs for the A step.

Trajectory inference task. We parametrize the time-varying velocity field v_{t,θ} as a fully connected three-hidden-layer neural network with 64 hidden dimensions and Softplus activations. We follow Grathwohl et al. (2019) and concatenate the time variable to the input to each layer of the neural net. We compute particle trajectories x(t) by solving the initial value problem ẋ(t) = v_θ(x(t)), x(0) = x_0 for initial positions x_0 ∈ X_{t_i} using the midpoint method solver in torchdiffeq's odeint with default hyperparameters (Chen et al., 2018a). The fitting loss for this task is GeomLoss's Sinkhorn divergence with p = 2 and blur parameter fixed to 5 × 10^{-2}.
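The time-conditioning construction attributed to Grathwohl et al. (2019) above, concatenating t to the input of every layer, looks roughly as follows in PyTorch. This is a sketch with illustrative sizes, not the authors' implementation.

```python
import torch
from torch import nn

class TimeConditionedVelocity(nn.Module):
    """Velocity field v_{t,theta}(x) that concatenates the scalar time t
    to the input of every layer (sketch; layer sizes are illustrative)."""

    def __init__(self, dim=2, hidden=64, n_hidden=3):
        super().__init__()
        sizes = [dim] + [hidden] * n_hidden
        self.layers = nn.ModuleList(
            nn.Linear(s + 1, h) for s, h in zip(sizes[:-1], sizes[1:])
        )
        self.out = nn.Linear(hidden + 1, dim)
        self.act = nn.Softplus()

    def forward(self, t, x):
        t_col = t.expand(x.shape[0], 1)   # broadcast scalar time to the batch
        h = x
        for layer in self.layers:
            h = self.act(layer(torch.cat([t_col, h], dim=1)))
        return self.out(torch.cat([t_col, h], dim=1))
```

Re-injecting t at every layer keeps the time dependence available even deep in the network, which is the motivation given for this trick in the continuous normalizing flow literature.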
We fix λ = 10^{-1} for both our learned metric Â and the identity baseline A = I, and choose the intermediate time points t_j to be an equispaced sampling of [0, 1] with step size 1/60 for all experiments. We solve problem (7) for each pair of subsequent retained time points (t_i, t_{i+1}). In each case, we optimize the objective using AdamW with learning rate 10^{-3} and weight decay 10^{-3} and train for 10,000 iterations. We evaluate the inferred trajectories by approximately computing the W_1 distance (using GeomLoss with p = 1 and blur 10^{-6}) between left-out time points in the ground truth data and advected samples at corresponding time points. For left-out time points of the form t = ki + ℓ for some integer ℓ ∈ {1, ..., k − 1}, the corresponding advected sample in the equispaced step-size-1/60 sampling of [0, 1] has index ⌊(j/k) · 60⌋.

D.3 SNOW GOOSE TRAJECTORY INFERENCE

Data pre-processing. The training data for learning the metric in this experiment consists of time-stamped sightings of untagged snow geese (Anser caerulescens) across the U.S. and Canada during their spring migration. This data is drawn from the February 2022 version of the eBird basic dataset (Sullivan et al., 2009). We fix λ = 1.25 × 10^2 for geese 82901, 82902, 82906, 82908, and 82909 and use λ = 2.5 × 10^2 for the remaining geese 82905, 82907, and 82910. In each case, we optimize the objective using AdamW with learning rate 10^{-3} and no weight decay, set the times t_j in (8) to be an equispaced sampling of [0, 1] with 32 time points, and train for 500 iterations.

E ABLATION STUDY: IMPACT OF THE REGULARIZATION COEFFICIENT

In this section, we carry out an ablation study of the impact of the coefficient λ on the regularization term R(A) in (4). We repeat the metric recovery experiment on the "Circular" example. We follow the procedure in Appendix D.1 but set λ ∈ {5 × 10^2, 10^3, 5 × 10^3, 10^4, 5 × 10^4, 10^5}. (Note that our reported results in Section 4.1 use λ = 10^3.) The results of this experiment are recorded in Figure 6. The eigenvectors of the learned metric Â(x) are robust to the value of λ except at the largest value λ = 10^5, where the alignment score falls to 0.857. As expected, the log-condition number of Â(x) increases somewhat with λ, indicating that larger values of λ favor high levels of anisotropy.

F ABLATION STUDY: REMOVING THE OT TERM

In this section, we briefly demonstrate that if we remove the optimal transport term from (4), the resulting metric is not informed by the training data. We repeat the metric recovery experiment on the "Circular" example. We follow the procedure in Appendix D.1 but exclude the OT term from our implementation of (4). The results of this experiment are presented in Figure 7. Note that the learned metric is no longer informed by the data; the low-energy eigenvectors point in the direction of the vertical axis everywhere, and the log-condition number is simply an increasing function of distance from the origin.

Published as a conference paper at ICLR 2023

