TRANSPORT WITH SUPPORT: DATA-CONDITIONAL DIFFUSION BRIDGES

Abstract

The dynamic Schrödinger bridge problem provides an appealing setting for posing optimal transport problems as learning non-linear diffusion processes and enables efficient iterative solvers. Recent works have demonstrated state-of-the-art results (e.g., in modelling single-cell embryo RNA sequences or sampling from complex posteriors) but are typically limited to learning bridges with only initial and terminal constraints. Our work extends this paradigm by proposing the Iterative Smoothing Bridge (ISB). We combine learning diffusion models with Bayesian filtering and optimal control, allowing for constrained stochastic processes governed by sparse observations at intermediate stages and terminal constraints. We assess the effectiveness of our method on synthetic and real-world data and show that the ISB generalises well to high-dimensional data, is computationally efficient, and provides accurate estimates of the marginals at intermediate and terminal times.

1. INTRODUCTION

Generative diffusion models have gained increasing popularity and achieved impressive results in a variety of challenging application domains, such as computer vision (e.g., Ho et al., 2020; Song et al., 2021a; Dhariwal & Nichol, 2021), reinforcement learning (e.g., Janner et al., 2022), and time series modelling (e.g., Rasul et al., 2021; Vargas et al., 2021; Tashiro et al., 2021; Park et al., 2022). Recent works have explored connections between denoising diffusion models and the dynamic Schrödinger bridge problem (SBP, e.g., Vargas et al., 2021; De Bortoli et al., 2021; Shi et al., 2022) to adopt iterative schemes for solving the dynamic optimal transport problem more efficiently. The solution of the SBP that corresponds to a denoising diffusion model is the finite-time process that is closest in Kullback-Leibler (KL) divergence to the forward noising process of the diffusion model under marginal constraints. Data is then generated by time-reversing the process. In many applications, however, the interest is not purely in modelling transport between an initial and a terminal state distribution. In naturally occurring generative processes, we typically observe snapshots of realizations along intermediate stages of individual sample trajectories (see Fig. 1). Such problems arise in medical diagnosis (e.g., tissue changes and cell growth), demographic modelling, environmental dynamics, and animal movement modelling; see Fig. 4 for modelling bird migration and wintering patterns. Recently, constrained optimal control problems have been explored by adding additional fixed path constraints (Maoutsa et al., 2020; Maoutsa & Opper, 2021) or modifying the prior processes (Fernandes et al., 2021). However, defining meaningful fixed path constraints or prior processes for the optimal control problems can be challenging, while sparse observational data are accessible in many real-world applications.
In this work, we propose the Iterative Smoothing Bridge (ISB), an iterative method for solving control problems under sparse observational data constraints and constraints on the initial and terminal distribution. We perform the conditioning by leveraging the iterative pass idea from the Iterative Proportional Fitting procedure (IPFP) (Kullback, 1968; De Bortoli et al., 2021) and applying differentiable particle filtering (Reich, 2013; Corenflos et al., 2021) within the outer loop. Integrating sequential Monte Carlo methods (e.g., Doucet et al., 2001; Chopin & Papaspiliopoulos, 2020) into the IPFP framework in such a way is non-trivial and can be understood as a novel iterative version of the algorithm by Maoutsa & Opper (2021), but with more general terminal constraints and path constraints defined by data. We summarize the contributions as follows. (i) We propose a novel method for solving constrained optimal transport as a bridge problem under sparse observational data constraints. (ii) To this end, we utilize the strong connections between the constrained bridging problem and particle filtering in sequential Monte Carlo, extending them from pure inference to learning. (iii) We demonstrate practical efficiency and show that the iterative smoothing bridge approach scales to high-dimensional data.

Schrödinger bridges

The problem of learning a stochastic process moving samples from one distribution to another can be posed as a type of transport problem known as a dynamic Schrödinger bridge problem (SBP, e.g., Schrödinger, 1932; Léonard, 2014), where the resulting marginal densities are desired to resemble a given reference measure. In machine learning literature, the problem has been studied through learning the drift function of the dynamical system (De Bortoli et al., 2021; Wang et al., 2021; Vargas et al., 2021; Bunne et al., 2022). When an SDE system also defines the reference measure, the bridge problem becomes synonymous with a constrained optimal control problem (e.g., Caluya & Halder, 2022; 2021; Chen et al., 2021), which has been leveraged in learning Schrödinger bridges by Chen et al. (2022) through forward-backward SDEs. An optimal control problem with both constraints on the initial and terminal distribution and a fixed path constraint has been studied in Maoutsa et al. (2020) and Maoutsa & Opper (2021), where particle filtering is applied to continuous path constraints but the boundary constraints are defined by a single point. Furthermore, the combination of Schrödinger bridges and state-space models has been studied by Reich (2019), in a setting where Schrödinger bridges are applied to the transport problem between filtering distributions.

Diffusion models in machine learning

Recent advances in diffusion models in machine learning literature have focused on generating samples from complex distributions defined by data by transforming samples from an easy-to-sample distribution through a dynamical system (e.g., Ho et al., 2020; Song et al., 2021b; a; Nichol & Dhariwal, 2021). The concept of reversing SDE trajectories via score-based learning (Hyvärinen & Dayan, 2005; Vincent, 2011) has allowed for models scalable enough to be applied to high-dimensional data sets directly in the data space. In earlier work, score-based diffusion models have been applied to problems where the dynamical system itself is of interest, for example, to time series imputation in Tashiro et al. (2021) and to inverse problems in imaging in Song et al. (2022). Other dynamical models parametrized by neural networks have been applied to modelling latent time-series based on observed snapshots of dynamics (Rubanova et al., 2019; Li et al., 2020), but without further constraints on the initial or terminal distributions.

State-space models

In their general form, state-space models combine a latent space dynamical system with an observation (likelihood) model. Evaluating the latent state distribution based on observational data can be performed by applying particle filtering and smoothing (Doucet et al., 2000) or by approximations of the underlying state distribution of a non-linear state-space model by a specific model family, for instance, a Gaussian (see Särkkä, 2013, for an overview). Speeding up parameter inference and learning in state-space models has been widely studied (e.g., Schön et al., 2011; Svensson & Schön, 2017; Kokkala et al., 2014). Particle smoothing can be connected to Schrödinger bridges via the two-filter smoother (e.g., Bresler, 1986; Briers et al., 2009; Hostettler, 2015), where the smoothing distribution is estimated by performing filtering both forward from the initial constraint and backward from the terminal constraint. We refer to Mitter (1996) and Todorov (2008) for a more detailed discussion on the connection of stochastic control and filtering and to Chopin & Papaspiliopoulos (2020) for an introduction to particle filters.

2. BACKGROUND

Let $\mathcal{C} = C([0, T], \mathbb{R}^d)$ denote the space of continuous functions from $[0, T]$ to $\mathbb{R}^d$ and let $\mathcal{B}(\mathcal{C})$ denote the Borel $\sigma$-algebra on $\mathcal{C}$. Let $\mathcal{P}(\pi_0, \pi_T)$ denote the space of probability measures on $(\mathcal{C}, \mathcal{B}(\mathcal{C}))$ such that the marginals at times $0$ and $T$ coincide with the probability densities $\pi_0$ and $\pi_T$, respectively. The KL divergence from measure $\mathbb{Q}$ to measure $\mathbb{P}$ is written as $D_{\mathrm{KL}}[\mathbb{Q} \,\|\, \mathbb{P}]$, where we assume that $\mathbb{Q} \ll \mathbb{P}$. For modelling the time dynamics, we assume a (continuous-time) state-space model consisting of a non-linear latent Itô SDE (see, e.g., Øksendal, 2003; Särkkä & Solin, 2019) in $[0, T] \times \mathbb{R}^d$ with drift function $f_\theta(\cdot)$ and diffusion function $g(\cdot)$, and a Gaussian observation model, i.e.,

$$x_0 \sim \pi_0, \qquad dx_t = f_\theta(x_t, t)\,dt + g(t)\,d\beta_t, \qquad y_k \sim \mathcal{N}(y_k \mid x_t, \sigma^2 I_d)\big|_{t=t_k}, \tag{1}$$

where the drift $f_\theta : \mathbb{R}^d \times [0, T] \to \mathbb{R}^d$ is a mapping modelled by a neural network (NN) parameterized by $\theta \in \Theta$, the diffusion is $g : [0, T] \to \mathbb{R}$, and $\beta_t$ denotes standard $d$-dimensional Brownian motion. $x_t$ denotes the latent stochastic process and $y_t$ denotes the observation-space process. In practice, we consider the continuous-discrete time setting, where the process is observed at discrete time instances $t_k$ such that observational data can be given in terms of a collection of input-output pairs $\{(t_k, y_k)\}_{k=1}^{M}$.

2.1. SCHRÖDINGER BRIDGES AND OPTIMAL CONTROL

The Schrödinger bridge problem (SBP, Schrödinger, 1932; Léonard, 2014) can be described as an entropy-regularized optimal transport problem where the optimality is measured through the KL divergence from a reference measure $\mathbb{P}$ to the posterior measure $\mathbb{Q}$, with fixed initial and final densities $\pi_0$ and $\pi_T$, i.e.,

$$\min_{\mathbb{Q} \in \mathcal{P}(\pi_0, \pi_T)} D_{\mathrm{KL}}[\mathbb{Q} \,\|\, \mathbb{P}]. \tag{2}$$

In this work, we consider only the case where the measures $\mathbb{P}$ and $\mathbb{Q}$ are constructed as the marginals of an SDE, i.e., $\mathbb{Q}_t$ is the probability measure of the marginal of the SDE in Eq. (1) at time $t$, whereas $\mathbb{P}_t$ corresponds to the probability measure of the marginal of a reference SDE $dx_t = f(x_t, t)\,dt + g(t)\,d\beta_t$ at time $t$, where we call $f$ the reference drift. Under the optimal control formulation of the SBP (Caluya & Halder, 2021) the KL divergence in Eq. (2) reduces to

$$\mathbb{E}\left[\int_0^T \frac{1}{2 g(t)^2} \left\| f_\theta(x_t, t) - f(x_t, t) \right\|^2 dt \right], \tag{3}$$

where the expectation is over paths from Eq. (1). Rüschendorf & Thomsen (1993) and Ruschendorf (1995) showed that a solution to the SBP can be obtained by iteratively solving two half-bridge problems using the Iterative Proportional Fitting procedure (IPFP) for steps $l = 0, 1, \ldots, L$:

$$\mathbb{Q}_{2l+1} = \arg\min_{\mathbb{Q} \in \mathcal{P}(\cdot, \pi_T)} D_{\mathrm{KL}}[\mathbb{Q} \,\|\, \mathbb{Q}_{2l}] \quad \text{and} \quad \mathbb{Q}_{2l+2} = \arg\min_{\mathbb{Q} \in \mathcal{P}(\pi_0, \cdot)} D_{\mathrm{KL}}[\mathbb{Q} \,\|\, \mathbb{Q}_{2l+1}], \tag{4}$$

where $\mathbb{Q}_0$ is set as the reference measure, and $\mathcal{P}(\pi_0, \cdot)$ and $\mathcal{P}(\cdot, \pi_T)$ denote the sets of probability measures with only either the marginal at time $0$ or time $T$ coinciding with $\pi_0$ or $\pi_T$, respectively. Recently, the IPFP approach to solving Schrödinger bridges has been adapted to a machine learning setting (Bernton et al., 2019; Vargas et al., 2021; De Bortoli et al., 2021). In practice, the interval $[0, T]$ is discretized and the forward drift $f_\theta$ and the backward drift $b_\phi$ of the corresponding reverse-time process (Haussmann & Pardoux, 1986; Föllmer, 1988) are modelled by NNs. Under Gaussian transition approximations, the resulting discrete-time diffusion model can be reversed by applying a mean-matching based objective.
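The alternating half-bridge projections of the IPFP have a simple static, discrete analogue that may help build intuition: rescaling a joint probability table until its two marginals match prescribed ones. The sketch below is only that analogue, not the path-space procedure used in this work; the function name `ipfp_discrete` and the toy marginals are assumptions made for illustration.

```python
import numpy as np

def ipfp_discrete(reference, pi_0, pi_T, num_iters=200):
    """Iterative proportional fitting on a joint table (rows: time 0, columns: time T)."""
    q = reference.copy()
    for _ in range(num_iters):
        # Half-bridge with the terminal marginal fixed: rescale columns to match pi_T.
        q = q * (pi_T / q.sum(axis=0))[None, :]
        # Half-bridge with the initial marginal fixed: rescale rows to match pi_0.
        q = q * (pi_0 / q.sum(axis=1))[:, None]
    return q

rng = np.random.default_rng(0)
reference = rng.random((4, 5)) + 0.1      # strictly positive toy reference coupling
reference /= reference.sum()
pi_0 = np.array([0.1, 0.2, 0.3, 0.4])     # toy initial marginal
pi_T = np.full(5, 0.2)                    # toy terminal marginal
q = ipfp_discrete(reference, pi_0, pi_T)
```

Each rescaling is the KL projection of the current table onto the set of tables with one marginal fixed, which is why alternating the two steps mirrors the two half-bridge problems.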
The forward diffusion at the first iteration (ISB 1) learns to account for the sparse observations but does not converge to the correct terminal distribution (t = T ), and the backward diffusion vice versa. After iterating (ISB 6), the forward and backward diffusions converge to the correct targets and are able to account for the sparse observational data.

3. METHODS

Given an initial and terminal distribution $\pi_0$ and $\pi_T$, we are interested in learning a data-conditional bridge between $\pi_0$ and $\pi_T$. Let $\mathcal{D} = \{(t_j, y_j)\}_{j=1}^{M}$ be a set of $M$ sparsely observed values, i.e., only a few or no observations are made at each point in time, and let the state-space model of interest be given by Eq. (1). Note that we deliberately use $(t_j, y_j)$ (instead of $(t_k, y_k)$) to highlight that we allow for multiple observations at the same time point $t_k$. Our aim is to find a parameterization of the drift function $f_\theta$ such that evolving $N$ particles $x^i_t$, with $x^i_0 \sim \pi_0$ ($i = 1, 2, \ldots, N$), according to Eq. (1) will result in samples $x^i_T$ from the terminal distribution $\pi_T$. Inspired by the IPFP by De Bortoli et al. (2021), which decomposes the SBP into finding two half-bridges, we propose to iteratively solve the two half-bridge problems while accounting for the additional sparse observations simultaneously. For this, let

$$dx_t = f_{l,\theta}(x_t, t)\,dt + g(t)\,d\beta_t, \qquad x_0 \sim \pi_0, \tag{5}$$

$$dz_t = b_{l,\phi}(z_t, t)\,dt + g(t)\,d\hat{\beta}_t, \qquad z_0 \sim \pi_T, \tag{6}$$

denote the forward and backward SDE at iteration $l = 1, 2, \ldots, L$, where $\hat{\beta}_t$ is the reverse-time Brownian motion. For simplicity, we denote $\beta_t = \hat{\beta}_t$ when the direction of the SDE is clear. To learn the data-conditioned bridge, we iteratively employ the following steps: (1) evolve and filter forward particle trajectories according to Eq. (5) with drift $f_{l-1,\theta}$ and observations $\{(t_k, y_k)\}_{k=1}^{M}$, (2) learn the drift function $b_{l,\phi}$ for the reverse-time SDE, (3) evolve and filter backward particle trajectories according to Eq. (6) with the drift $b_{l,\phi}$ learned in step 2 and observations $\{(t_k, y_k)\}_{k=1}^{M}$, and (4) learn the drift function $f_{l,\theta}$ for the forward SDE based on the backward particles. Fig. 2 illustrates the forward and backward process of our iterative scheme for a data-conditioned denoising diffusion bridge.
Next, we will go through steps 1-4 in detail and introduce the Iterative Smoothing Bridge method for solving data-conditional diffusion bridges.

3.1. THE ITERATIVE SMOOTHING BRIDGE

The Iterative Smoothing Bridge (ISB) method iteratively generates particle filtering trajectories (steps 1 and 3 in Fig. 2) and learns the parameterizations of the forward and backward drift functions $f_{l,\theta}$ and $b_{l,\phi}$ (steps 2 and 4) by minimizing a modified version of the mean-matching objective presented by De Bortoli et al. (2021). Note that steps 2 and 4 depend on applying differentiable resampling in the particle filtering steps 1 and 3 for reversing the generated trajectories. We will now describe the forward trajectory generating step 1 and the backward drift learning step 2 in detail. Steps 3 and 4 are given by application of steps 1 and 2 to their reverse-time counterparts.

Step 1 (and 3): Given a fixed discretization of the time interval $[0, T]$ denoted as $\{t_k\}_{k=1}^{K}$ with $t_1 = 0$ and $t_K = T$, denote the time step lengths as $\Delta_k = t_{k+1} - t_k$. By truncating the Itô-Taylor series of the SDE, we can consider an Euler-Maruyama (e.g., Ch. 8 in Särkkä & Solin, 2019) type of discretization for the continuous-time problem. We give the time-update of the $i$-th particle at time $t_k$ evolved according to Eq. (5) before conditioning on the observational data as

$$\tilde{x}^i_{t_k} = x^i_{t_{k-1}} + f_{l-1,\theta}(x^i_{t_{k-1}}, t_{k-1})\,\Delta_k + g(t_{k-1})\sqrt{\Delta_k}\,\xi^i_k, \tag{7}$$

where $\xi^i_k \sim \mathcal{N}(0, I)$. In step 3, the particles $\tilde{z}^i_{t_k}$ of the backward SDE Eq. (6) are similarly obtained. The SDE dynamics sampled in steps 1 and 3 apply the learned drift functions $f_{l-1,\theta}$ and $b_{l,\phi}$ from the previous step and do not require sampling from the underlying SDE model. For times $t_k$ at which no observations are available, we set $x^i_{t_k} = \tilde{x}^i_{t_k}$ (and $z^i_{t_k} = \tilde{z}^i_{t_k}$, respectively) and otherwise compute the particle filtering weights $w^i_{t_k}$ based on the observations $\{(t_j, y_j) \in \mathcal{D} \mid t_j = t_k\}$ for resampling. See Sec. 3.2 for details on the particle filtering proposal density and calculation of the particle weights.
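The Euler-Maruyama time-update used in steps 1 and 3 can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions: the drift is passed as a plain callable (here a hypothetical zero drift standing in for the neural network), and the helper names `euler_maruyama_step` and `propagate` are not from the paper.

```python
import numpy as np

def euler_maruyama_step(x, t, dt, drift, diffusion, rng):
    """One Euler-Maruyama update for all particles; x has shape (N, d)."""
    xi = rng.standard_normal(x.shape)                  # xi ~ N(0, I)
    return x + drift(x, t) * dt + diffusion(t) * np.sqrt(dt) * xi

def propagate(x0, ts, drift, diffusion, rng):
    """Evolve particles over the time grid ts; returns an array of shape (K, N, d)."""
    trajectory = [x0]
    for k in range(len(ts) - 1):
        dt = ts[k + 1] - ts[k]
        trajectory.append(
            euler_maruyama_step(trajectory[-1], ts[k], dt, drift, diffusion, rng)
        )
    return np.stack(trajectory)

rng = np.random.default_rng(0)
ts = np.linspace(0.0, 1.0, 51)
x0 = np.zeros((20000, 2))
zero_drift = lambda x, t: np.zeros_like(x)             # hypothetical stand-in for a learned drift
unit_diffusion = lambda t: 1.0                         # assumed constant g(t) = 1
paths = propagate(x0, ts, zero_drift, unit_diffusion, rng)
```

With zero drift and unit diffusion the particles follow Brownian motion, so the empirical variance at $T = 1$ should be close to one, which gives a quick sanity check of the $\sqrt{\Delta_k}$ noise scaling.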
For resampling, we employ a differentiable resampling procedure, where the particles and weights $(\tilde{x}^i_{t_k}, w^i_{t_k})$ are transported to uniformly weighted particles $x^i_{t_k}$ by solving an entropy-regularized optimal transport problem (Cuturi, 2013; Peyré & Cuturi, 2019; Corenflos et al., 2021); see App. D for further details. Through application of the $\varepsilon$-regularized optimal transport map $T_{(\varepsilon)} \in \mathbb{R}^{N \times N}$ (see Corenflos et al., 2021, for details), the particles are resampled via the map to $x^i_{t_k} = \tilde{X}_{t_k}^\top T_{(\varepsilon),i}$, where $\tilde{X}_{t_k} \in \mathbb{R}^{N \times d}$ denotes the stacked particles $\{\tilde{x}^i_{t_k}\}_{i=1}^{N}$ at time $t_k$ before resampling. The resampled particles for the backward process are given similarly.

Step 2 (and 4): Given the particles $\{x^i_{t_k}\}_{k=1,i=1}^{K,N}$, we now aim to learn the drift function for the respective reverse-time process. In case no observation is available at time $t_k$, we apply the mean-matching loss based on a Gaussian transition approximation proposed in De Bortoli et al. (2021):

$$\ell^i_{k+1,\text{no obs}} = \left\| b_{l,\phi}(x^i_{t_{k+1}}, t_{k+1})\,\Delta_k - x^i_{t_{k+1}} - f_{l-1,\theta}(x^i_{t_{k+1}}, t_k)\,\Delta_k + x^i_{t_k} + f_{l-1,\theta}(x^i_{t_k}, t_k)\,\Delta_k \right\|^2. \tag{8}$$

In case an observation is available at time $t_k$, the particle values $\tilde{X}_{t_k}$ will be coupled through the optimal transport map. Therefore, the transition density is a sum of Gaussian variables (see App. A for details and a derivation), and the mean-matching loss is given by

$$\ell^i_{k+1,\text{obs}} = \left\| b_{l,\phi}(x^i_{t_{k+1}}, t_{k+1})\,\Delta_k - x^i_{t_{k+1}} - f_{l-1,\theta}(x^i_{t_{k+1}}, t_k)\,\Delta_k + \sum_{n=1}^{N} T_{(\varepsilon),i,n} \left( x^n_{t_k} + f_{l-1,\theta}(x^n_{t_k}, t_k)\,\Delta_k \right) \right\|^2. \tag{9}$$

The overall objective function is a combination of both loss functions, with the respective mean-matching loss depending on whether $t_k$ is an observation time. The final loss function is written as

$$\ell(\phi) = \sum_{i=1}^{N} \left[ \sum_{k=1}^{K} \ell^i_{k,\text{obs}}(\phi)\, \mathbb{I}_{y_{t_k} \neq \emptyset} + \sum_{k=1}^{K} \ell^i_{k,\text{no obs}}(\phi)\, \mathbb{I}_{y_{t_k} = \emptyset} \right], \tag{10}$$

where $\mathbb{I}_{\text{cond.}}$ denotes an indicator function that returns '1' iff the condition is true, and '0' otherwise.
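To make the resampling step more concrete, the following sketch computes an entropy-regularized transport plan between the weighted particles and a uniformly weighted copy using plain log-domain Sinkhorn iterations, and then forms the resampled particles as weighted averages. It is a simplified stand-in, not the exact scheme of Corenflos et al. (2021); the function name `sinkhorn_resample`, the squared-Euclidean cost, and the values of `eps` and `num_iters` are assumptions chosen for illustration.

```python
import numpy as np

def _logsumexp(a, axis):
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True)), axis=axis)

def sinkhorn_resample(x_tilde, log_w, eps=1.0, num_iters=300):
    """Map weighted particles (x_tilde, w) to uniformly weighted particles.

    Returns the resampled particles (N, d) and the transport matrix (N, N),
    whose row i holds the mixing coefficients T[i, n] for the new particle i.
    """
    n = x_tilde.shape[0]
    log_w = log_w - _logsumexp(log_w, axis=0)          # normalise the weights
    log_uniform = np.full(n, -np.log(n))               # target: uniform weights 1/N
    cost = np.sum((x_tilde[:, None, :] - x_tilde[None, :, :]) ** 2, axis=-1)
    f = np.zeros(n)                                    # dual potential, uniform side
    g = np.zeros(n)                                    # dual potential, weighted side
    for _ in range(num_iters):
        f = eps * (log_uniform - _logsumexp((g[None, :] - cost) / eps, axis=1))
        g = eps * (log_w - _logsumexp((f[:, None] - cost) / eps, axis=0))
    plan = np.exp((f[:, None] + g[None, :] - cost) / eps)
    transport = n * plan                               # rows sum to ~1 once converged
    return transport @ x_tilde, transport

rng = np.random.default_rng(0)
x_tilde = rng.standard_normal((32, 2))                 # toy particles before resampling
y = np.array([1.0, 0.0])                               # a single toy observation
log_w = -0.5 * np.sum((x_tilde - y) ** 2, axis=1)      # Gaussian log-weights, sigma = 1
x_new, transport = sinkhorn_resample(x_tilde, log_w)
```

Because every operation is smooth in the particles and weights, gradients can pass through the resampling step, unlike multinomial resampling. The column sums of the plan reproduce the weights, so the mean of the resampled particles equals the weighted mean of the originals.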
Consequently, the parameters $\phi$ of $b_{l,\phi}$ are learned by minimizing Eq. (10) through gradient descent. In practice, a cache of trajectories $\{x^i_{t_k}\}_{k=1,i=1}^{K,N}$ is maintained throughout training of the drift functions and refreshed at a fixed number of inner loop iterations, as in De Bortoli et al. (2021), avoiding differentiation over the computational graph of the SDE generation. The calculations for step 4 follow similarly. The learned backward drift $b_{l,\phi}$ can be interpreted as an analogue of the backward drift in Maoutsa & Opper (2021), connecting our approach to solving optimal control problems through Hamilton-Jacobi equations; see App. A.2 for an analysis of the backward SDE and the control objective. While we are generally considering problem settings where the number of observations is low, we propose that letting $M \to \infty$ yields the underlying marginal distribution, see Prop. 2 in App. A.3.
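The two mean-matching losses differ only in how the forward-pushed particles at $t_k$ are combined, which a short NumPy sketch can make explicit. The helper names, the callable `forward_drift`, and the toy linear drift below are assumptions for illustration; in the method itself the drifts are neural networks and the loss is minimized by gradient descent.

```python
import numpy as np

def mean_matching_target(x_prev, x_next, t_prev, dt, forward_drift, transport=None):
    """Regression target for b(x_next, t_next) * dt in the mean-matching losses."""
    pushed = x_prev + forward_drift(x_prev, t_prev) * dt   # x_{t_k} + f(x_{t_k}, t_k) * dt
    if transport is not None:
        # Observation time: mix the pushed particles with the rows of the transport matrix.
        pushed = transport @ pushed
    return x_next + forward_drift(x_next, t_prev) * dt - pushed

def mean_matching_loss(b_values, x_prev, x_next, t_prev, dt, forward_drift, transport=None):
    """Squared error between b * dt and its target, summed over particles."""
    target = mean_matching_target(x_prev, x_next, t_prev, dt, forward_drift, transport)
    return np.sum((b_values * dt - target) ** 2)

rng = np.random.default_rng(0)
linear_drift = lambda x, t: -x                             # hypothetical toy forward drift
x_prev = rng.standard_normal((5, 2))
x_next = rng.standard_normal((5, 2))
dt = 0.1
target_no_obs = mean_matching_target(x_prev, x_next, 0.0, dt, linear_drift)
target_identity = mean_matching_target(x_prev, x_next, 0.0, dt, linear_drift, np.eye(5))
loss_at_optimum = mean_matching_loss(target_no_obs / dt, x_prev, x_next, 0.0, dt, linear_drift)
```

With an identity transport matrix the observation-time target collapses to the no-observation target, mirroring the remark in App. A.1 about uniform weights.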

3.2. COMPUTATIONAL CONSIDERATIONS

The ISB algorithm is a generic approach to learn data-conditional diffusion bridges under various choices of, e.g., the particle filter proposal density or the reference drift. Next, we cover practical considerations for the implementation of the method and highlight the model choices in the experiments.

Multiple observations per time step

Naturally, we can make more than one observation at a single point in time $t_k$, denoted as $\mathcal{D}_{t_k} = \{(t_j, y_j) \in \mathcal{D} \mid t_j = t_k\}$. To compute particle weights $w^i_{t_k}$ for the $i$-th particle, we consider only the $H$ nearest neighbours of $x^i_{t_k}$ in $\mathcal{D}_{t_k}$ instead of all observations in $\mathcal{D}_{t_k}$. By restricting to the $H$ nearest neighbours, denoted as $\mathcal{D}^H_{t_k}$, we introduce an additional locality to the proposal density computation, which can be helpful in case of multimodality. On the other hand, letting $H > 1$ results in weights which take into account the local density of the observations, not only the distance to the nearest neighbour. In experiments with few observations, we set $H = 1$; the choice of $H$ is discussed wherever a higher value is used.

Particle filtering proposal

The proposal density chosen for the ISB is the bootstrap filter, where the proposal matches the Gaussian transition density $p(x_{t_k} \mid x_{t_{k-1}})$. Assuming a Gaussian noise model $\mathcal{N}(0, \sigma^2 I)$, the unnormalized log-weights for the $i$-th particle at time $t_k$ are given by

$$\log w^i_{t_k} = -\frac{1}{2\sigma^2} \sum_{y_j \in \mathcal{D}^H_{t_k}} \left\| x^i_{t_k} - y_j \right\|^2. \tag{11}$$

Observational noise schedule

In practice, using a constant observation noise variance $\sigma^2$ can result in an iterative scheme which does not have a stationary point as $L \to \infty$. Even if the learned drift function $f_{l,\theta}$ was optimal, the filtering steps 1 and 3 would alter the trajectories unless all particles had uniform weights. Thus, we introduce a noise schedule $\kappa(l)$ which ensures that the observation noise increases in the number of ISB iterations, causing ISB to converge to the IPFP (De Bortoli et al., 2021) as $L \to \infty$. We found that letting the observation noise first decrease and then increase (in the spirit of simulated annealing) often outperformed a strictly increasing observation noise schedule. The noise schedule is studied in App. C, where we derive the property that letting $L \to \infty$ yields IPFP.

Drift initialization

Depending on the application, one may choose to incorporate additional information by selecting an appropriate initial drift. A possible choice includes a pre-trained neural network drift learned to transport $\pi_0$ to $\pi_T$ without accounting for observations. However, starting from a drift for the unconstrained SBP can be problematic in cases where the observations are far away from the unconstrained bridge. To encourage exploration, one may choose $f_0 = 0$ for the initial drift. In various problem settings, we found both a zero drift and a drift initialized from the SBP to be successful in the experiments, see App. C for discussion.
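The bootstrap-filter log-weights restricted to the $H$ nearest observations can be computed directly from pairwise squared distances. The sketch below is one possible NumPy realization; the function names and the toy particle and observation values are assumptions for illustration.

```python
import numpy as np

def log_weights(particles, observations, sigma, num_neighbours=1):
    """Unnormalized log-weights using only the H nearest observations per particle.

    particles: (N, d), observations: (M, d); returns an array of shape (N,).
    """
    sq_dist = np.sum((particles[:, None, :] - observations[None, :, :]) ** 2, axis=-1)
    h = min(num_neighbours, observations.shape[0])
    # np.partition places the h smallest squared distances in the first h columns.
    nearest = np.partition(sq_dist, h - 1, axis=1)[:, :h]
    return -nearest.sum(axis=1) / (2.0 * sigma ** 2)

def normalise(log_w):
    """Turn log-weights into weights that sum to one, in a numerically stable way."""
    w = np.exp(log_w - np.max(log_w))
    return w / w.sum()

particles = np.array([[0.0, 0.0], [1.0, 0.0]])
observations = np.array([[0.0, 0.0], [3.0, 0.0]])
lw_h1 = log_weights(particles, observations, sigma=1.0, num_neighbours=1)
lw_h2 = log_weights(particles, observations, sigma=1.0, num_neighbours=2)
weights = normalise(lw_h1)
```

With $H = 1$ each particle is scored only against its closest observation, whereas $H = 2$ here sums over both observations and therefore penalizes the particle that sits on one observation but far from the other.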

4. EXPERIMENTS

To assess the properties and performance of the ISB, we present a range of experiments that demonstrate how the iterative learning procedure can incorporate both observational data and terminal constraints. We start with simple examples that build intuition (cf. Fig. 1 and Fig. 2) and show standard ML benchmark tasks. For quantitative assessment, we design an experiment with a non-linear SDE for which the marginal distributions are available in closed-form. All experiment settings include a number of hyperparameter choices, some of them typical to all diffusion problems and some specific to particle filtering and smoothing. The diffusion $g(t)$ is a pre-determined function not optimized in the training. We divide the experiments into two main subsets: problems of 'sharpening to achieve a data distribution' and 'optimal transport problems'. In the former, the initial distribution has a support overlapping with the terminal distribution and the process noise level $g(t)$ goes from high to low as time progresses. Conversely, in the latter setting, the particles sampled from the initial distribution must travel to reach the support of the terminal distribution, and we chose to use a constant process noise level. Perhaps the most significant choice of hyperparameter is the observational noise level, as it imposes a preference on how closely the observational points should be followed, see App. C.1 for details.

2D toy examples

We show illustrative results for the TWO MOONS and CIRCLES data sets from scikit-learn. We add artificial observation data to bias the processes. For the circles, the observational data consists of 10 points, spaced evenly on the circle. The points are all observed simultaneously, halfway through the process, forcing the marginal density of the generating SDE to collapse to the small circle, and then to expand.
For the two moons, the observational data is collected from 10 trajectories of a diffusion model which generates the two moons from noise, and these 10 trajectories are then observed at three points in time. Results are visualized in Fig. 3 (see videos in supplement). For reference, we have included plots of the IPFP dynamics in the supplement, see Fig. 6.

Quantitative comparison on the Beneš SDE

In order to quantify how observing a process in between its initial and terminal states steers the ISB model to areas with higher likelihood, we test its performance on a Beneš SDE model (see, e.g., Särkkä & Solin, 2019). The Beneš SDE is a non-linear one-dimensional SDE of the form $dx_t = \tanh(x_t)\,dt + d\beta_t$ with $x_0 = 0$, whose marginal density is nevertheless available in closed-form, allowing for negative log-likelihood evaluation. We simulate trajectories from the Beneš SDE and from the reverse drift and stack the reversed trajectories. The terminal distribution is shifted and scaled so that the Beneš SDE itself does not solve the transport problem from $\pi_0$ to $\pi_T$, see App. B.2 for details and visualizations of the processes. As a baseline, we fit a Schrödinger bridge model with no observational data, using the Beneš SDE drift as the reference model. The ISB model is initialized with a zero-drift model (not with the Beneš as reference), thus making learning more challenging. We compare the models in terms of negative log predictive density in Table 1, where we see that the ISB model captures the process well on average (over the entire time-horizon) and at selected marginal times.

Bird migration

Bird migration can be seen as a regular seasonal transport problem, where birds move (typically North-South) along a flyway, between breeding and wintering grounds. We take this as a motivating example of constrained optimal transport, where the geographical constraints and preferred routes are accounted for by bird sighting data (see Fig. 4 top). By adapting data from Ambrosini et al. (2014) and Pellegrino et al. (2015), we propose a simplified data set for geese migration in Europe (OIBMD: ornithologically implausible bird migration data; available in the supplement). We applied the ISB for 12 iterations, with a linear observation noise schedule from 1 to 0.2, and constant diffusion noise 0.05.
The drift function was initialized as a zero-function, and thus the method did not rely on a separately fit model optimized for generating the wintering distribution based on the breeding distribution. For comparison, we include the Schrödinger bridge results in App. B.3.
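For the Beneš experiment above, the closed-form marginal is what makes negative log-likelihood evaluation possible. The sketch below uses the standard transition density of $dx_t = \tanh(x_t)\,dt + d\beta_t$ started at $x_0$ (see, e.g., Särkkä & Solin, 2019) and checks it numerically on a grid; the function names and the grid are assumptions for illustration, and the shifted and scaled terminal distribution used in the experiment is not modelled here.

```python
import numpy as np

def benes_density(x, t, x0=0.0):
    """Marginal density at time t of the Benes SDE started at x0 (unit diffusion)."""
    gaussian = np.exp(-(x - x0) ** 2 / (2.0 * t)) / np.sqrt(2.0 * np.pi * t)
    return gaussian * np.cosh(x) / np.cosh(x0) * np.exp(-t / 2.0)

def negative_log_likelihood(samples, t, x0=0.0):
    """Average negative log-density of samples under the Benes marginal at time t."""
    return -np.mean(np.log(benes_density(samples, t, x0)))

# Numerical sanity check on a grid (simple Riemann sums).
grid = np.linspace(-30.0, 30.0, 60001)
dx = grid[1] - grid[0]
density = benes_density(grid, t=2.0)
total_mass = np.sum(density) * dx
second_moment = np.sum(grid ** 2 * density) * dx
nll = negative_log_likelihood(np.array([-1.0, 0.0, 1.5]), t=2.0)
```

For $x_0 = 0$ the density is an equal mixture of $\mathcal{N}(t, t)$ and $\mathcal{N}(-t, t)$, so the second moment at time $t$ is $t + t^2$, which the grid check reproduces.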

Constraining an image generation process

We demonstrate that the ISB approach scales well to high-dimensional inputs by studying a proof-of-concept image generation task. We modify the diffusion generative process of the MNIST (LeCun et al., 1998) digit 8 by artificial observations steering the dynamical system in the middle of the generation process. While the concept of observations in the case of image generation is somewhat unnatural, it showcases the scalability of the method to high-dimensional data spaces. Here, the drift is initialized using a pre-trained neural network obtained by first running a Schrödinger bridge model for image generation. The process is then given an observation in the form of the bottom half of an MNIST digit 8 in the middle of the dynamical process. As the learned model uses information from the observation both before and after the observation time, the lower half of the image is sharper than the upper half. We provide further details on this experiment and sampled trajectories in App. B.4.

Single-cell embryo RNA-seq

Lastly, we evaluated our approach on an Embryoid body scRNA-seq time course (Tong et al., 2020). The data consists of RNA measurements collected over five time ranges from a developing human embryo system. No trajectory information is available; instead, we only have access to snapshots of RNA data. This leads to a data set over 5 time ranges, the first from days 0-3 and the last from days 15-18. In the experiment, we followed the protocol by Tong et al. (2020), reduced the data dimensionality to $d = 5$ using PCA, and used the first and last time ranges as the initial and terminal constraints. All other time ranges are considered observational data. Contrary to the other experiments, intermediate data are imprecise (only a time range of multiple days is known) but abundant. We learned the ISB using a zero drift and compared it against an unconditional bridge obtained through the IPFP (De Bortoli et al., 2021), see Fig.
with marginals closer to the observed data while performing comparably to the IPFP at the initial and terminal stages. This improvement is also verified numerically in Table 2 , showing that the ISB obtains a lower Earth mover's distance between the generated marginals and the observational data than IPFP. Additionally, Table 2 lists 

5. DISCUSSION AND CONCLUSION

The dynamic Schrödinger bridge problem provides an appealing setting for posing optimal transport problems as learning non-linear diffusion processes and enables efficient iterative solvers. However, while recent works have state-of-the-art performance in many complex application domains, they are typically limited to learning bridges with only initial and terminal constraints dependent on observed data. In this work, we have extended this paradigm and introduced the Iterative Smoothing Bridge (ISB), an iterative algorithm for learning data-conditional smoothing bridges. For this, we leveraged the strong connections between the constrained bridging problem and particle filtering in sequential Monte Carlo, extending them from pure inference to learning. We thoroughly assessed the applicability and flexibility of our approach in various experimental settings, including synthetic data sets and complex real-world scenarios (e.g., bird migration, conditional image generation, and modelling single-cell RNA-sequencing time-series). In our experiments, we showed that ISB generalizes well to high-dimensional data, is computationally efficient, and provides accurate estimates of the marginals at intermediate and terminal times. Accurately modelling the dynamics of complex systems under both path constraints induced by sparse observations and initial and terminal constraints is a key challenge in many application domains. These include biomedical applications, demographic modelling, and environmental dynamics, but also machine learning specific applications such as reinforcement learning, planning, and time-series modelling. All these applications have in common that the dynamic nature of the problem is driven by progression of time, and not only progression of a generative process as often is the case in, e.g., generative image models. 
Thus, constraints over intermediate stages have a natural role and interpretation in this wider set of dynamic diffusion modelling applications. We believe the proposed ISB algorithm opens up new avenues for diffusion models in relevant real-world modelling tasks and will be stimulating for future work. For example, more sophisticated observational models, alternative strategies to account for multiple observations, and different noise schedules could be explored. Furthermore, the proposed approach could be extended to other types of optimal transport problems, such as the Wasserstein barycenter, a frequently employed case of the multi-marginal optimal transport problem. 

A METHOD DETAILS

We present the details of the objective function derivation in App. A.1 and explain the connection of the backward drift function to Hamilton-Jacobi equations in App. A.2. In App. A.3, we discuss the behaviour of our model at the limit $M \to \infty$, that is, when the observations fully represent the marginal densities of the stochastic process.

A.1 DERIVING THE MEAN-MATCHING LOSS AT OBSERVATION TIMES

Proposition 1. Define the forward SDE as

$$dx_t = f_{l,\theta}(x_t, t)\,dt + g(t)\,d\beta_t, \qquad x_0 \sim \pi_0, \tag{12}$$

and a backward SDE drift as

$$b_{l,\phi}(x_{t_{k+1}}, t_{k+1}) = f_{l-1,\theta}(x_{t_{k+1}}, t_k) - g(t_{k+1})^2\, \nabla \ln p_{t_{k+1}}, \tag{13}$$

where $p_{t_{k+1}}$ is the particle filtering density after differentiable resampling at time $t_{k+1}$. Then $b_{l,\phi}(x_{t_{k+1}}, t_{k+1})$ minimizes the loss function

$$\ell^i_{k+1,\text{obs}} = \left\| b_{l,\phi}(x^i_{t_{k+1}}, t_{k+1})\,\Delta_k - \frac{x^i_{t_{k+1}}}{C_{\varepsilon,i}} - f_{l-1,\theta}(x^i_{t_{k+1}}, t_k)\,\Delta_k + \frac{1}{C_{\varepsilon,i}} \sum_{n=1}^{N} T_{(\varepsilon),i,n} \left( x^n_{t_k} + f_{l-1,\theta}(x^n_{t_k}, t_k)\,\Delta_k \right) \right\|^2, \tag{14}$$

where we denote $C_{\varepsilon,i} = \frac{1}{g(t_{k+1})^2 \Delta_k} \mathrm{Var}\left( \sum_{n=1}^{N} T_{(\varepsilon),i,n}\, \tilde{x}^n_{t_{k+1}} \right)$, and $\{\tilde{x}^i_{t_{k+1}}\}_{i=1}^{N}$ are the particles before resampling.

Proof sketch. Our objective is to find a backward drift function $b_{l,\phi}(x_{t_{k+1}}, t_{k+1})$ as in Eq. (13). Notice that at observation times $t_k$, this is not equivalent to finding the reverse drift of the SDE forward transition and differentiable resampling combined, since the drift function $f_{l-1,\theta}$ alone does not map the particles $\{x^i_{t_k}\}_{i=1}^{N}$ to the particles $\{x^i_{t_{k+1}}\}_{i=1}^{N}$. We will derive a loss function for learning the backward drift as in Eq. (13) below, leaving the discussion on why it is a meaningful choice of a backward drift to App. A.2. Our derivation closely follows the proof of Proposition 3 in De Bortoli et al. (2021), but we provide the details here for the sake of completeness. First, we give the transition density $p_{x_{t_k} \mid x^i_{t_{k-1}}}(x_{t_k})$ and apply it to derive the observation time loss $\ell^i_{k,\text{obs}}$.
The derivation for the loss $\ell^i_{k,\mathrm{no\,obs}}$ is skipped, since it is as in the proof of Proposition 3 in De Bortoli et al. (2021). Suppose that at $t_k$, there are observations. By definition, the particles before resampling $\{x^i_{t_{k+1}}\}_{i=1}^N$ are generated by the Gaussian transition density

$$p(x_{t_{k+1}} \mid x^i_{t_k}) = \mathcal{N}\big(x_{t_{k+1}} \mid x^i_{t_k} + \Delta_k f_{l-1,\theta}(x^i_{t_k}, t_k),\; g(t_{k+1})^2 \Delta_k I\big). \tag{15}$$

Recall that the resampled particles are defined as a weighted average of all the particles, $x^i_{t_k} = \sum_{n=1}^N \tilde{x}^n_{t_k} T_{(\varepsilon),i,n}$. Thus, the transition density from $\{x^i_{t_k}\}_{i=1}^N$ to the particles $\{x^i_{t_{k+1}}\}_{i=1}^N$ is also a Gaussian,

$$p(x^i_{t_{k+1}} \mid x^i_{t_k}) = \mathcal{N}\Big(x_{t_{k+1}} \,\Big|\, \sum_{n=1}^N T_{(\varepsilon),i,n}\big(x^n_{t_k} + \Delta_k f_{l-1,\theta}(x^n_{t_k}, t_k)\big),\; g(t_{k+1})^2 \Delta_k C_{\varepsilon,i} I_d\Big). \tag{16}$$

We will derive the loss function Eq. (9) by modifying the mean matching proof in De Bortoli et al. (2021) by the transition mean Eq. (16) and the backward drift definition Eq. (13). Using the particle filtering approximation, the marginal density can be decomposed as $p_{t_{k+1}}(x_{k+1}) = \sum_{i=1}^N p_{t_k}(x^i_k)\, p_{x_{k+1} \mid x^i_k}(x_{k+1})$. By substituting the transition density Eq. (16) it follows that

$$p_{t_{k+1}}(x_{t_{k+1}}) = \frac{1}{Z} \sum_{i=1}^N p_{t_k}(x^i_{t_k}) \exp\left( - \frac{\big\| \sum_{n=1}^N T_{(\varepsilon),i,n}\big(x^n_{t_k} + \Delta_k f_{l-1,\theta}(x^n_{t_k}, t_k)\big) - x_{t_{k+1}} \big\|^2}{2 g(t_{k+1})^2 C_{\varepsilon,i} \Delta_k} \right), \tag{17}$$

where $Z$ is the normalization constant of Eq. (16). As in the proof of Proposition 3 of De Bortoli et al. (2021), we first manipulate $\nabla_{x_{t_{k+1}}} p_{t_{k+1}}(x_{t_{k+1}})$. Writing $m_i = \sum_{n=1}^N T_{(\varepsilon),i,n}\big(x^n_{t_k} + \Delta_k f_{l-1,\theta}(x^n_{t_k}, t_k)\big)$ for the transition mean in Eq. (16),

\begin{align}
& \nabla_{x_{t_{k+1}}} p_{t_{k+1}}(x_{t_{k+1}}) \tag{18} \\
&= \frac{1}{Z} \sum_{i=1}^N \nabla_{x_{t_{k+1}}}\, p(x^i_{t_k}) \exp\left( - \frac{\| m_i - x_{t_{k+1}} \|^2}{2 g(t_{k+1})^2 C_{\varepsilon,i} \Delta_k} \right) \tag{19} \\
&= \frac{1}{Z} \sum_{i=1}^N p(x^i_{t_k})\, \frac{1}{g(t_{k+1})^2 \Delta_k C_{\varepsilon,i}} \big( m_i - x_{t_{k+1}} \big) \tag{20} \\
&\qquad \times \exp\left( - \frac{\| m_i - x_{t_{k+1}} \|^2}{2 g(t_{k+1})^2 C_{\varepsilon,i} \Delta_k} \right). \tag{21}
\end{align}
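Substituting

$$p_{t_k}(x^i_k) = \frac{p_{t_{k+1}}(x_{t_{k+1}})\, p_{x^i_k \mid x_{k+1}}(x^i_k)}{p_{x_{k+1} \mid x^i_k}(x_{k+1})}$$

to the equation above gives

$$\nabla_{x_{t_{k+1}}} p_{t_{k+1}}(x_{t_{k+1}}) = p_{t_{k+1}}(x_{t_{k+1}}) \sum_{i=1}^N p_{x^i_{t_k} \mid x_{t_{k+1}}}(x^i_{t_k})\, \frac{\sum_{n=1}^N T_{(\varepsilon),i,n}\big(x^n_{t_k} + \Delta_k f_{l-1,\theta}(x^n_{t_k}, t_k)\big) - x_{t_{k+1}}}{g(t_{k+1})^2 \Delta_k C_{\varepsilon,i}}, \tag{22}$$

and dividing by $p_{t_{k+1}}(x_{t_{k+1}})$ yields

$$\nabla \ln p_{t_{k+1}}(x_{t_{k+1}}) = \sum_{i=1}^N p_{x^i_{t_k} \mid x_{t_{k+1}}}(x^i_{t_k})\, \frac{\sum_{n=1}^N T_{(\varepsilon),i,n}\big(x^n_{t_k} + \Delta_k f_{l-1,\theta}(x^n_{t_k}, t_k)\big) - x_{t_{k+1}}}{g(t_{k+1})^2 \Delta_k C_{\varepsilon,i}}. \tag{23}$$

Substituting Eq. (23) to the definition of the optimal backward drift Eq. (13) gives

$$b_{l,\phi}(x_{t_{k+1}}, t_{k+1}) = f_{l-1,\theta}(x_{t_{k+1}}, t_k) - g(t_{k+1})^2 \nabla \ln p_{t_{k+1}}(x_{t_{k+1}})$$

$$= f_{l-1,\theta}(x_{t_{k+1}}, t_k) - g(t_{k+1})^2 \sum_{i=1}^N p_{x^i_{t_k} \mid x_{t_{k+1}}}(x^i_{t_k})\, \frac{\sum_{n=1}^N T_{(\varepsilon),i,n}\big(x^n_{t_k} + \Delta_k f_{l-1,\theta}(x^n_{t_k}, t_k)\big) - x_{t_{k+1}}}{g(t_{k+1})^2 \Delta_k C_{\varepsilon,i}},$$

where taking $f_{l-1,\theta}(x_{t_{k+1}}, t_k)$ inside the sum (the posterior weights sum to one) yields

$$b_{l,\phi}(x_{t_{k+1}}, t_{k+1}) = \sum_{i=1}^N p_{x^i_{t_k} \mid x_{t_{k+1}}}(x^i_{t_k})\, \frac{1}{\Delta_k} \left( \Delta_k f_{l-1,\theta}(x_{t_{k+1}}, t_k) + \frac{x_{t_{k+1}}}{C_{\varepsilon,i}} - \frac{1}{C_{\varepsilon,i}} \sum_{n=1}^N T_{(\varepsilon),i,n}\big(x^n_{t_k} + \Delta_k f_{l-1,\theta}(x^n_{t_k}, t_k)\big) \right).$$

Multiplying the equation above by $\Delta_k$ gives

$$\Delta_k b_{l,\phi}(x^i_{t_{k+1}}, t_{k+1}) = \Delta_k f_{l-1,\theta}(x^i_{t_{k+1}}, t_k) + \frac{x^i_{t_{k+1}}}{C_{\varepsilon,i}} - \frac{1}{C_{\varepsilon,i}} \sum_{n=1}^N T_{(\varepsilon),i,n}\big(x^n_{t_k} + \Delta_k f_{l-1,\theta}(x^n_{t_k}, t_k)\big).$$

Thus we may set the objective for finding the optimal backward drift $b_{l,\phi}$ as

$$\ell^i_{k+1,\mathrm{obs}} = \left\| b_{l,\phi}(x^i_{t_{k+1}}, t_{k+1})\Delta_k - \frac{x^i_{t_{k+1}}}{C_{\varepsilon,i}} - f_{l-1,\theta}(x^i_{t_{k+1}}, t_k)\Delta_k + \frac{1}{C_{\varepsilon,i}} \sum_{n=1}^N T_{(\varepsilon),i,n}\big( x^n_{t_k} + f_{l-1,\theta}(x^n_{t_k}, t_k)\Delta_k \big) \right\|^2.$$

Notice that if the weights before resampling are uniform, then $T_{(\varepsilon)} = I_N$, and for all $i \in \{1, 2, \ldots, N\}$ it holds that $C_{\varepsilon,i} = 1$, since all but one of the terms in the sum $\frac{1}{g(t_{k+1})^2 \Delta_k} \operatorname{Var}\big( \sum_{n=1}^N T_{(\varepsilon),i,n}\, \tilde{x}^n_{t_{k+1}} \big)$ vanish. Similarly, for one-hot weights $C_{\varepsilon,i} = 1$. In practice, we set the constant $C_{\varepsilon,i} = 1$ as in Eq. (9) and observe good empirical performance with the simplified loss function.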
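To make the simplified objective with $C_{\varepsilon,i} = 1$ concrete, the following is a minimal NumPy sketch of the regression target for $b_{l,\phi}\Delta_k$ at an observation time. The function names, the array layout, and the use of a plain callable for the forward drift are assumptions made for illustration; they are not taken from the authors' implementation.

```python
import numpy as np

def backward_drift_target(x_next, x_prev, T, f, t_k, dt):
    """Target for b(x_next, t_{k+1}) * dt under the simplification C = 1.

    x_next: (N, d) particles at t_{k+1}
    x_prev: (N, d) particles at t_k that the transport matrix averages over
    T:      (N, N) ensemble transport matrix (hypothetical input)
    f:      callable f(x, t) -> (N, d), the forward drift
    """
    propagated = x_prev + f(x_prev, t_k) * dt  # mean of one forward step
    mixed = T @ propagated                     # sum_n T[i, n] * propagated[n]
    return x_next + f(x_next, t_k) * dt - mixed

def mean_matching_loss(b_pred, x_next, x_prev, T, f, t_k, dt):
    """Average over particles of || b * dt - target ||^2."""
    target = backward_drift_target(x_next, x_prev, T, f, t_k, dt)
    return float(np.mean(np.sum((b_pred * dt - target) ** 2, axis=1)))
```

With $T_{(\varepsilon)} = I_N$ the target reduces to the difference of the one-step forward maps evaluated at $x^i_{t_{k+1}}$ and $x^i_{t_k}$, which is the form used at times without observations.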

A.2 CONNECTION TO HAMILTON-JACOBI EQUATIONS

We connect the backward drift function $b_{l,\phi}(x_{t_{k+1}}, t_{k+1}) = f_{l-1,\theta}(x_{t_{k+1}}, t_k) - g(t_{k+1})^2 \nabla \ln p_{t_{k+1}}(x_{t_{k+1}})$ to the Hamilton-Jacobi equations for stochastic control by following the setting of Maoutsa & Opper (2021), which applies the drift $f_{l-1,\theta}(x_t, t) - g(t)^2 \nabla \ln p_t(x_t)$ for a backwards SDE initialized at $\pi_T$. Consider a stochastic control problem with a path constraint $U(x_t, t)$, optimizing the following loss function,

$$J = \frac{1}{N} \sum_{i=1}^N \left[ \int_{t=0}^T \left( \frac{1}{2 g(t)^2} \big\| f_\theta(x^i_t, t) - f(x^i_t, t) \big\|^2 + U(x^i_t, t) \right) \mathrm{d}t - \ln \chi(x^i_T) \right],$$

with the paths $x^i_t$ sampled as trajectories from the SDE

$$x_0 \sim \pi_0, \qquad \mathrm{d}x_t = f_{l-1,\theta}(x_t, t)\,\mathrm{d}t + g(t)\,\mathrm{d}\beta_t,$$

and the loss $\ln \chi(x^i_T)$ measures distance from the distribution $\pi_T$. Since we set the path constraint via observational data, our method resembles setting $U(x^i_t, t) = 0$ when $t$ is not an observation time, and $U(x^i_t, t) = -\log p(y \mid x^i_t)$ otherwise, where $p(y \mid x^i_t)$ is the observation model.

Let $q_t(x)$ denote the marginal density of the controlled (drift $f_\theta$) SDE at time $t$. In Maoutsa & Opper (2021), the marginal density is decomposed as $q_t(x) = \phi_t(x)\, p_t(x)$, where $\phi_t(x)$ is a solution to a backwards Fokker-Planck-Kolmogorov (FPK) partial differential equation starting from $\phi_T(x) = \pi_T$, and the density evolves as in

$$\frac{\mathrm{d}\phi_t(x)}{\mathrm{d}t} = -\mathcal{L}_f^\dagger \phi_t(x) + U(x, t)\,\phi_t(x),$$

where $\mathcal{L}_f^\dagger$ is the adjoint FPK operator to the uncontrolled system. The density $p_t(x)$ corresponds to the forward filtering problem, initialized with $\pi_0$,

$$\frac{\mathrm{d}p_t(x)}{\mathrm{d}t} = \mathcal{L}_f\big(p_t(x)\big) - U(x, t)\,p_t(x), \tag{32}$$

where $\mathcal{L}_f$ is the FPK operator of the uncontrolled SDE (with drift $f$). The particle filtering trajectories $\{x^i_{t_k}\}$ generated in our method are samples from the density defined by Eq. (32). In the context of our method, the path constraint matches the log-weights of particle filtering at observation times and is zero elsewhere.
In Maoutsa & Opper (2021), a backward evolution for $q_t$ is applied, using the backwards time $\tilde{q}_{T-\tau}(x) = q_\tau(x)$, yielding a backwards SDE starting from $\tilde{q}_0(x) = \{x^i_T\}_{i=1}^N$, reweighted according to $\pi_T$. The backward samples from $\tilde{q}$ are generated following the SDE dynamics

$$\mathrm{d}x^i_\tau = \big( f(x^i_\tau, T - \tau) + g(t)^2 \nabla \ln p_{T-\tau}(x^i_\tau) \big)\,\mathrm{d}t + g(t)\,\mathrm{d}\beta_\tau.$$

We have thus selected the backward drift $b_{l,\phi}$ to match the drift of $\tilde{q}_t(x)$, the backward controlled density. Intuitively, our choice of $b_{l,\phi}$ is a drift which generates the smoothed particles when initialized at $\{x^i_T\}_{i=1}^N$, the terminal state of the forward SDE. The discrepancy between $\pi_T$ and the distribution induced by $\{x^i_T\}_{i=1}^N$ then motivates the use of an iterative scheme after learning to simulate from $q_t(x)$.

A.3 OBSERVING THE FULL MARGINAL DENSITY

Suppose that at time $t_k$, we let the number of observations grow unbounded. We analyse the behaviour of our model at the resampling step, at the limit $M \to \infty$ for the number of observations and $\sigma \to 0$ for the observation noise. When applying the bootstrap proposal, recall that we combined the multiple observations to compute the log-weights as

$$\log w^i_{t_k} = -\frac{1}{2\sigma^2} \sum_{y_j \in \mathcal{D}^H_{i,t_k}} \big\| x^i_{t_k} - y_j \big\|^2,$$

which works well in practice for the sparse-data settings we have considered. Below we analyse the behaviour of an alternative way to combine the weights and show that given an infinite number of observations, it creates samples from the true underlying distribution.

Proposition 2. Let $\{x^i_{t_k}\}_{i=1}^N$ be a set of particles and $\{y_j\}_{j=1}^M$ the observations at time $t_k$. Assume that the observations have been sampled from a density $\rho_{t_k}$ and that for all $i$ it holds that $x^i_{t_k} \in \operatorname{supp}(\rho_{t_k})$. Define the particle weights as

$$\log w^i_{t_k,\sigma,M} = \log \left( \frac{1}{Z\, |\mathcal{D}^{H(M)}_{i,t_k}|} \sum_{y_j \in \mathcal{D}^{H(M)}_{i,t_k}} \exp\big( -\| x^i_{t_k} - y_j \|^2 / 2\sigma^2 \big) \right), \tag{35}$$

where $Z$ is the normalization constant of the observation model Gaussian $p(y \mid x^i_{t_k})$. Then for each particle $x^i_{t_k}$, its weight satisfies

$$\lim_{\sigma \to 0} \lim_{M \to \infty} w^i_{t_k,\sigma,M} = \rho_{t_k}(x^i_{t_k}).$$

Proof sketch. We drop $\sigma$ and $H(M)$ from the weight notation for simplicity, but remark that the particle filtering weights depend on both quantities. Consider the number of particles $N$ fixed, and denote the $d$-dimensional ball centered at $x^i_{t_k}$ with radius $r$ as $B(x^i_{t_k}, r)$. Since each particle $x^i_{t_k}$ lies in the support of the true underlying marginal density $\rho_{t_k}$, for any radius $r > 0$ such that $B(x^i_{t_k}, r) \subset \operatorname{supp}(\rho_{t_k})$, and any $H > 0$, we may choose $M$ high enough so that the points $y_j \in \mathcal{D}^H_{i,t_k}$ satisfy $y_j \in B(x^i_{t_k}, r)$. It follows from Eq. (35) that

$$w^i_{t_k} = \frac{1}{Z\, |\mathcal{D}^{H(M)}_{i,t_k}|} \sum_{y_j \in \mathcal{D}^{H(M)}_{i,t_k}} \exp\big( -\| x^i_{t_k} - y_j \|^2 / 2\sigma^2 \big).$$
For any $r > 0$ and with observation noise $\sigma = cr$, we may set $c$ and $H(M)$ so that the sum above approximates the integral

$$w^i_{t_k,r} \approx \frac{1}{|B(x^i_{t_k}, r)|} \int_{B(x^i_{t_k}, r)} p(y \mid x^i_{t_k})\, \rho_{t_k}(y)\,\mathrm{d}y.$$

By applying the Lebesgue differentiation theorem, we obtain that for almost every $x^i_{t_k}$, we have $\lim_{r \to 0} w^i_{t_k,r} = \rho_{t_k}(x^i_{t_k})$, since as $\sigma \to 0$, the density $p(y \mid x^i_{t_k})$ collapses to the Dirac delta at $x^i_{t_k}$.

Prop. 2 can be interpreted as the infinite limit of a kernel density estimate of the true underlying distribution. Resampling accurately reweights the particles so that the probability of resampling particle $x^i_{t_k}$ is proportional to the density $\rho_{t_k}$ at that particle, relative to the other particles. Notice that the result does not guarantee that the particles will cover the support of $\rho_{t_k}$, since we did not assume that the drift initialization generates a marginal density at time $t_k$ covering its support.
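The two ways of combining observations discussed in this section can be compared numerically. The sketch below assumes that $\mathcal{D}^H_{i,t_k}$ is the set of the $H$ observations nearest to particle $i$ (with $M \geq H$) and that the observation model is an isotropic Gaussian; the function names are hypothetical and the code is illustrative only.

```python
import numpy as np

def knn_sq_dists(x, ys, H):
    """Squared distances from each particle to its H nearest observations."""
    d2 = np.sum((x[:, None, :] - ys[None, :, :]) ** 2, axis=-1)  # (N, M)
    return np.sort(d2, axis=1)[:, :H]                            # (N, H)

def log_weights_sum(x, ys, H, sigma):
    """Sum-of-squares log-weights, as used with the bootstrap proposal."""
    return -knn_sq_dists(x, ys, H).sum(axis=1) / (2.0 * sigma**2)

def log_weights_kde(x, ys, H, sigma):
    """Kernel-density-style log-weights in the spirit of Proposition 2."""
    d = x.shape[1]
    log_Z = 0.5 * d * np.log(2.0 * np.pi * sigma**2)  # Gaussian normalizer
    a = -knn_sq_dists(x, ys, H) / (2.0 * sigma**2)
    m = a.max(axis=1, keepdims=True)                  # stable log-sum-exp
    lse = m[:, 0] + np.log(np.exp(a - m).sum(axis=1))
    return lse - np.log(H) - log_Z

def normalize(log_w):
    """Exponentiate and normalize log-weights so that they sum to one."""
    w = np.exp(log_w - log_w.max())
    return w / w.sum()
```

Both variants assign a larger normalized weight to particles that sit close to the observations; they differ in how quickly the weight decays with distance.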

B EXPERIMENTAL DETAILS

B.1 2D TOY DATA SETS

For the constrained transport problem on the two-dimensional scikit-learn data sets, the observational data we chose to use was different for each of the three data sets presented: two moons, two circles, and the S-shape.

Two moons

The observational data consists of 10 points selected from the Schrödinger bridge trajectories, all observed at t ∈ [0.25, 0.5, 0.75] with an exponential observation noise schedule $\kappa(l) = 1.25^{l-1}$. The ISB was run for 6 epochs, and initialized with a drift from the pre-trained Schrödinger bridge model from the unconstrained problem.

Two circles

The observational data consists of 10 points which lie evenly distributed on a circle, observed at t = 0.5 with an exponential observational noise schedule $\kappa(l) = 0.5 \cdot 1.25^{l-1}$. The ISB was run for 6 epochs, and initialized with a drift from the pre-trained Schrödinger bridge model from the unconstrained problem.

S-shape

The observational data consists of 6 points, with pairs being observed at times t ∈ [0.4, 0.5, 0.6]. We used a bilinear observational noise schedule with a linear decay for the first half of the iterations from $\kappa(0)^2 = 4$ to $\kappa(L/2)^2 = 1$ and a linear ascent for the second half of the iterations from $\kappa(L/2)^2 = 1$ to $\kappa(L)^2 = 4$. The ISB ran for 6 epochs, with a zero drift initialization.
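The observation noise schedules used for the toy data sets are simple functions of the ISB iteration index $l$. A minimal sketch, assuming $l$ may be treated as a real number on $[0, L]$ for the bilinear case; the function names and default arguments are hypothetical:

```python
import numpy as np

def exponential_schedule(l, base=1.25, scale=1.0):
    """kappa(l) = scale * base**(l - 1); scale=0.5 matches the two-circles setup."""
    return scale * base ** (l - 1)

def bilinear_sq_schedule(l, L, high=4.0, low=1.0):
    """kappa(l)^2: linear decay from high to low on [0, L/2], then back up to high."""
    return float(np.interp(l, [0.0, L / 2.0, L], [high, low, high]))
```

For example, with `L = 6` the squared noise level is 4 at the first iteration boundary, 1 at the midpoint, and 4 again at the end.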

B.2 THE BENEŠ SDE

In the Beneš SDE experiment, we obtain the sparse observational data from sampled Beneš SDE trajectories while the terminal state is a shifted and scaled ($3 + 5x_T$) version of a Beneš marginal density. As the Beneš trajectories were first generated by simulating the SDE until t = 6 and then in reverse from t = 6 to t = 0, we set T = 11.97. We apply the analytical expression for the Beneš marginal density for computing $\log p_t(x)$,

$$p_t(x) = \frac{1}{\sqrt{2\pi t}}\, \frac{\cosh(x)}{\cosh(x_0)} \exp\Big( -\frac{1}{2} t \Big) \exp\Big( -\frac{1}{2t} (x - x_0)^2 \Big).$$

See the Beneš SDE trajectories in Fig. 8a. As expected, the transport model with no observations performs well in the generative task, but its trajectories also cover some low-likelihood space around t = 6 (in the middle part in Fig. 8b). The observations for the ISB model were sampled from the generated trajectories, 10 observations at 10 random time instances (see Fig. 8c).

Both the unconstrained Schrödinger bridge model and the ISB model were run for 3 iterations, using a learning rate of 0.001 for the neural networks. Likely because the problem was only one-dimensional, convergence of the Schrödinger bridge to a process which matches the desired terminal state was fast, and we chose not to run the model for a higher number of ISB iterations; see Fig. 7 for a comparison of the trained model marginal densities and the true terminal distribution $\pi_T$. We set the observation noise schedule to the constant 0.7, and at each iteration of the ISB or the unconstrained Schrödinger bridge the drift neural networks were trained for 5000 iterations each with the batch size 256, and the trajectories were refreshed every 500 iterations with a cache size of 1000 particles. The number of nearest neighbours to compare to was H = 10.
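The analytical Beneš density above is straightforward to evaluate in log space. The following sketch is one way to do so with NumPy; the function name and the default $x_0 = 0$ are assumptions for illustration, and $\log\cosh$ is computed through `np.logaddexp` to avoid overflow for large $|x|$.

```python
import numpy as np

def benes_log_density(x, t, x0=0.0):
    """log p_t(x) for the Benes SDE started at x0, using the closed form."""
    x = np.asarray(x, dtype=float)

    def log_cosh(z):
        # log cosh(z) = log(exp(z) + exp(-z)) - log 2
        return np.logaddexp(z, -z) - np.log(2.0)

    return (-0.5 * np.log(2.0 * np.pi * t)
            + log_cosh(x) - log_cosh(x0)
            - 0.5 * t
            - (x - x0) ** 2 / (2.0 * t))
```

A quick sanity check is that `np.exp(benes_log_density(grid, t))` integrates to approximately one over a sufficiently wide grid.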

B.3 THE BIRD MIGRATION DATA SET

The ISB model learned bird migration trajectories which transport the particles from the Northern Europe summer habitats to the southern winter habitats, see Fig. 10 for a comparison of a Schrödinger bridge and ISB. Since the problem lies on a sphere, Schrödinger bridge methods adjusted for learning on Riemannian manifolds could have been applied here. For simplicity we mapped the problem to a two-dimensional plane using a Mercator projection, and solved the problem on a [0, 5] × [0, 5] square. The SDE had the discretization t ∈ [0, 0.99], $\Delta_k = 0.01$ and a constant process noise $g(t)^2 = 0.05$. The model was trained for 12 iterations, and initialized with a zero drift, while the observational data was chosen by the authors to promote learning trajectories clearly different from the unconstrained transport trajectories. The observation noise schedule was piecewise linear (starting at 2, going to 0.1 at iteration 6, then rising linearly to reach 2 at iteration 12). At each ISB iteration, the neural networks were trained for 5000 iterations each, and the trajectories refreshed every 1000 iterations. We used a batch size of 256 and learning rate 0.001.
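The discretization described above corresponds to a standard Euler-Maruyama scheme. A minimal sketch follows, assuming 100 steps of size 0.01 so that the drift is evaluated on the grid t ∈ [0, 0.99]; the function name and signature are hypothetical, and the drift is passed in as a plain callable rather than a neural network.

```python
import numpy as np

def euler_maruyama(x0, drift, g2=0.05, dt=0.01, n_steps=100, rng=None):
    """Simulate dx = drift(x, t) dt + sqrt(g2) dbeta and return the full path.

    x0:    (N, d) initial particles
    drift: callable drift(x, t) -> (N, d)
    Returns an array of shape (n_steps + 1, N, d).
    """
    rng = np.random.default_rng() if rng is None else rng
    x = np.array(x0, dtype=float)
    path = [x.copy()]
    for k in range(n_steps):
        t = k * dt  # drift is evaluated at t = 0, 0.01, ..., 0.99
        noise = rng.standard_normal(x.shape)
        x = x + drift(x, t) * dt + np.sqrt(g2 * dt) * noise
        path.append(x.copy())
    return np.stack(path)
```

With a zero drift, which is the initialization used in this experiment, the particles simply diffuse and their per-coordinate variance grows to roughly $g^2 \cdot 1 = 0.05$ over the full time horizon.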

B.4 THE MNIST GENERATION TASK

Applying state-space model approaches such as particle filtering and smoothing to generative diffusion models directly in the observation space (that is, not in a lower-dimensional latent space) has to our knowledge not been explored before. Some experimental design choices had a great impact on how sensible the training objective was, as the observational data is completely artificial and its timing during the process modifies the filtering distribution significantly. As the MNIST conditional generative model was trained to display the scalability of our method beyond low-dimensional toy examples, we did not further explore optimizing the hyperparameters or the observation model. To avoid the background noise in MNIST images in the middle of the generative process impacting the particle filtering weights excessively, the observation model is a Gaussian with masked inputs equal to zero in pixels where the observation image is black, see Fig. 9 for sampled trajectories. The figure shows the progression of seven samples, where the lower half of the eights resemble the observation target. The SDE was run for time t ∈ [0, 0.5], with the digit eight observed at t = 0.38. The ISB method was applied for 10 iterations, with a discretization t ∈ [0, 0.495], $\Delta_k = 0.005$, and the process noise $g(t)^2$ followed a linear schedule from 0.0001 to 1. At each iteration of the method, the forward and backward drift neural networks were trained for 5000 iterations with a batch size of 256, and the trajectory cache regenerated every 1000 iterations. The observational data consisted of a single sample of a lower half of the digit eight, observed at time t = 0.38. The observation noise schedule was a constant $\kappa(l) = 0.3$.
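One possible reading of the masked Gaussian observation model is sketched below: pixels where the observation image is black are excluded from the squared error, so background noise in those pixels does not affect the particle weights. The function name, the flattened image layout, and the thresholding at zero are assumptions made for illustration and may differ from the authors' implementation.

```python
import numpy as np

def masked_gaussian_log_weight(x, y, kappa=0.3):
    """Unnormalized log-weights of particle images under a masked Gaussian.

    x: (N, P) particle images, flattened to P pixels
    y: (P,)   observation image; pixels equal to zero are treated as black
    """
    mask = (y > 0).astype(float)              # 0 where the observation is black
    resid = (x - y[None, :]) * mask[None, :]  # masked pixels contribute nothing
    return -np.sum(resid**2, axis=1) / (2.0 * kappa**2)
```

Under this reading, two particles that differ only in masked-out pixels receive identical weights.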

B.5 SINGLE-CELL DATA SET

We directly use the preprocessed data from the TrajectoryNet (Tong et al., 2020) repository. A major difference between our implementation and Vargas et al. (2021) is the reference drift. We set the reference drift to zero, which means that we utilize the intermediate data only as observations in the state-space model. In contrast, Vargas et al. (2021) fit a mixture model of 15 Gaussians on the combined data set (across all measurement times) and set the reference drift to the gradient of the log-likelihood of the mixture model. Effectively, such a reference drift aids in keeping the SDE trajectories within the support of the combined data set. We remark that if the intermediate observed marginals had clearly disjoint support, combining all the data would cause the mixture model to have 'gaps' and could cause an unstable reference model drift. Thus we consider our approach of setting the reference drift to zero as more generally applicable. As in Vargas et al. (2021), we set the process noise to g(t) = 1 and model the SDE between time t ∈ [0, 4]. The learning rate is set to 0.001 with batch size 256 and number of neural network training iterations 5000, and we apply the ISB for 6 iterations. We filter using 1000 points from the intermediate data sets, but compute the Earth mover's distance by a comparison to all available data. As the observational data at t = 1, 2, 3 consists of a high number of data points, the parameters H (number of nearest neighbours) and σ (observation noise) need to be carefully set. We set H = 10 to only include the close neighbourhood of each particle, and set the observation noise schedule as constant 0.7.

C COMPUTATIONAL CONSIDERATIONS

In Sec. 3.2, we raised a number of important computational considerations for the constrained transport problem. Below we discuss them in detail, analyzing the limit L → ∞ from the perspective of setting the observation noise schedule in App. C.1, and presenting ablation results on modifying the initial drift in the bird migration experiment in App. C.2.

C.1 DISCUSSION ON OBSERVATION NOISE

We briefly mentioned in Sec. 3.2 that when letting $L \to \infty$, the choice of observation noise should be carefully planned in order for the ISB procedure to have a stationary point. Here we explain why an unbounded observation noise schedule $\kappa(l)$ implies convergence to the IPF method for uncontrolled Schrödinger bridges (De Bortoli et al., 2021), when using a nearest neighbour bootstrap filter as the proposal density.

Proposition 3. Let $\Omega \subset \mathbb{R}^d$ be a bounded domain where both the observations and SDE trajectories lie, and let the particle filtering weights $\{w^i_{l,t_k}\}_{i=1}^N$ be as in Eq. (11), but after normalization. If the schedule $\kappa(l)$ is unbounded with respect to $l$, then for any $\delta$ there exists $l'$ such that for the normalized weights it holds

$$\Big| \hat{w}^i_{l',t_k} - \frac{1}{N} \Big| \leq \delta. \tag{40}$$

Proof sketch. Since we set the proposal density to be the bootstrap filter, the observation weights at ISB iteration $l$ are equal to

$$\log w^i_{l,t_k} = -\frac{1}{2\kappa(l)^2} \sum_{y_j \in \mathcal{D}^H_{t_k}} \big\| x^i_{t_k} - y_j \big\|^2. \tag{41}$$

Since $\kappa(l)$ is unbounded, for any $S > 0$ there exists $l'$ such that $\kappa(l') \geq S$. We choose the value of $S$ so that the following derivation yields Eq. (40). Let $S^2 = 0.5\, R^{-1} |\mathcal{D}^H_{t_k}| \operatorname{diam}(\Omega)^2$, and apply the property that $\| x^i_{t_k} - y_j \|^2 \leq \operatorname{diam}(\Omega)^2$ to Eq. (41),

$$\log w^i_{l',t_k} \geq -\frac{1}{2S^2} \sum_{y_j \in \mathcal{D}^H_{t_k}} \big\| x^i_{t_k} - y_j \big\|^2 \geq -\sum_{y_j \in \mathcal{D}^H_{t_k}} \frac{\| x^i_{t_k} - y_j \|^2}{R^{-1} |\mathcal{D}^H_{t_k}| \operatorname{diam}(\Omega)^2} \geq -\sum_{y_j \in \mathcal{D}^H_{t_k}} \frac{\operatorname{diam}(\Omega)^2}{R^{-1} |\mathcal{D}^H_{t_k}| \operatorname{diam}(\Omega)^2} \geq -R. \tag{42}$$

The bound above is for the unnormalized weights, and the normalized log-weights are defined as

$$\log \hat{w}^i_{l',t_k} = \log w^i_{l',t_k} - \log \sum_{j=1}^N \exp\big( \log w^j_{l',t_k} \big), \tag{43}$$

where for the normalizing constant it holds that

$$\log \sum_{j=1}^N \exp\big( \log w^j_{l',t_k} \big) \leq \log N, \tag{44}$$

since $w^j_{l',t_k}$ is the value of a probability density and thus always $w^j_{l',t_k} \leq 1$. Combining Eq. (43), Eq. (42) and Eq. (44), it follows that

$$\log \hat{w}^i_{l',t_k} - (-\log N) \geq -R,$$

where taking exponentials on both sides gives

$$\hat{w}^i_{l',t_k} - \frac{1}{N} \geq -\big(1 - \exp(-R)\big) \frac{1}{N}.$$



Figure 1: Illustrative example transport between an initial unit Gaussian and a shifted unit Gaussian at the terminal time T. Unconstrained transport on the left and the solution constrained by sparse observations on the right. Colour coding of the initial points is only for distinguishing the paths.

Figure 2: Sketch of a diffusion bridge between a 2D data distribution (π 0 ) and an isotropic Gaussian (π T ) constrained by sparse observations. The forward diffusion at the first iteration (ISB 1) learns to account for the sparse observations but does not converge to the correct terminal distribution (t = T ), and the backward diffusion vice versa. After iterating (ISB 6), the forward and backward diffusions converge to the correct targets and are able to account for the sparse observational data.

Figure 3: 2D toy experiments from scikit-learn with both cases starting from a Gaussian: The TWO CIRCLES (top) and TWO MOONS (bottom) data sets, with observations (red markers) constraining the problem. For the circles, the 10 circular observations at t = 0.5 first force the method to create a circle that then splits into two; in the lower plot the observations at t ∈ [0.25, 0.5, 0.75] split the data into clusters before joining them into two moons. See Fig. 6 in the Appendix for the IPFP result.

Figure 4: Bird migration example. The top row describes nesting and wintering areas for the birds as well as example sightings during migration. At the bottom, we show the marginal densities of the ISB model from the initial to terminal distribution that match the sightings along the migration.

Figure 5: Illustration of the trajectories of the high-dimensional single-cell experiment for the Schrödinger bridge (a) and the ISB (b), projected onto the first two principal components. The first five trajectories are highlighted in colour, and intermediate observation densities visualized as slices.

the performance of previous works that do not use the intermediate data during training (Tong et al., 2020) or only use it to construct an informative reference drift (Vargas et al., 2021), see App. B.5 for details. In both cases, ISB outperforms the other approaches w.r.t. the intermediate marginal distributions (t = 1, 2, 3), while IPML (Vargas et al., 2021) outperforms ISB at the initial and terminal stages due to its data-driven reference drift. Notice that while we reduced the dimensionality via PCA to 5 for fair comparisons to Vargas et al. (2021), the ISB model would also allow modelling the full state-space model, with observations in the high-dimensional gene space and a latent SDE.

Figure 6: The IPFP result for the experiment in Fig. 3 in the main paper. 2D toy experiments, where the observations (red markers) were not used during training but are included in the figure for reference. The dynamics learned by IPFP are clearly different from the ISB learned dynamics.

Figure 7: A kernel density estimate of the Beneš SDE terminal state. We compare π T to the Schrödinger bridge and ISB terminal states. Both unconstrained Schrödinger bridge and ISB terminal states succeed in representing π T well, with the Schrödinger bridge terminal state more closely matching π T near its mean.

Figure 8: Comparison of the solution for the SBP (with Beneš SDE reference drift) and the ISB (with zero initial drift) on the Beneš SDE under sparse observations. The target distribution π T is slightly shifted and scaled from the Beneš SDE. Even if the SBP has the true model as reference drift, its trajectories degenerate into a unimodal distribution, while the ISB manages to cover both modes even if only sparse observations are available.

Figure 9: Model trajectories for MNIST digit '8' conditioned on a lower-loop of a single '8' at t = 0.38 to bias the lower half of the digits to look alike, with the effect still visible at terminal time T .

Finally, we demonstrate our model in a highly multimodal bird migration task, in conditioned image generation, and in a single-cell embryo RNA modelling problem. Ablation studies are found in App. C.

Experiment setup

In all experiments, the forward and backward drift functions f θ and b φ are parametrized as neural networks. For low-dimensional experiments we apply the MLP block design as in De Bortoli et al. (2021), and for the image experiment a U-Net as in Nichol & Dhariwal (2021). The latent state SDE was simulated by Euler-Maruyama with a fixed time-step of 0.01 over 100 steps.

Results for the Beneš experiment.

Results for single-cell embryo RNA experiment.


Proposition 3 of De Bortoli et al. (2021), we derive an expression for the score function. Since ∇ ln p t k+1 (x t k+1 ) =


Marginal densities of our model, using both initial and terminal distributions and observational data and a zero drift initialization. Bottom row: Same as third row, but with the second row dynamics as initialization.

Since the weights are normalized, even the largest particle weight $\hat{w}^j_{l',t_k}$ can exceed $\frac{1}{N}$ by at most as much as the smaller weights in total lie under $\frac{1}{N}$, implying that for any weight $\hat{w}^j_{l',t_k}$, it holds that

$$\Big| \hat{w}^j_{l',t_k} - \frac{1}{N} \Big| \leq 1 - \exp(-R),$$

and selecting $R = -\log(1 - \delta)$ is sufficient for $\delta < 1$.

Effectively, the above derivation implies that for an unbounded observation noise schedule $\kappa(l)$, the particle weights will converge to uniform weights. Since performing differentiable resampling on uniform weights implies that $T_{(\varepsilon)} = I_N$, the ISB method trajectory generation step and the objective in training the backward drift converge to those of the IPF method for solving unconstrained Schrödinger bridges. Intuitively, this means that at the limit $L \to \infty$, our method will focus on reversing the trajectories and matching the terminal distribution while not further utilizing information from the observations.
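The convergence of the normalized weights to uniform weights can be checked numerically. The sketch below is illustrative only: it assumes all observations are used for every particle (so that $\mathcal{D}^H_{t_k}$ is the full observation set), and the function names are hypothetical.

```python
import numpy as np

def normalized_weights(x, ys, kappa):
    """Normalized bootstrap weights with observation noise kappa."""
    d2 = np.sum((x[:, None, :] - ys[None, :, :]) ** 2, axis=-1).sum(axis=1)
    log_w = -d2 / (2.0 * kappa**2)
    w = np.exp(log_w - log_w.max())  # subtract the max for numerical stability
    return w / w.sum()

def max_deviation_from_uniform(x, ys, kappa):
    """max_i | w_i - 1/N |, the quantity bounded in Proposition 3."""
    w = normalized_weights(x, ys, kappa)
    return float(np.max(np.abs(w - 1.0 / len(w))))
```

Evaluating `max_deviation_from_uniform` for increasing values of `kappa` on particles and observations in a bounded domain shows the deviation shrinking towards zero, in line with the argument above.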

C.2 ABLATION ON INITIAL DRIFT

We conducted an ablation study on drift initialization for the bird migration problem. As the distributions π 0 and π T (as pictured in Fig. 10 ) are complex, we consider the problem setting to be interesting for setting f 0 as the unconstrained transport problem drift. To this end, we trained a Schrödinger bridge model for 10 epochs, and trained an ISB model with the same hyperparameter selections as explained in App. B.3, using the Schrödinger bridge as the initialization. Compare the two bottom rows of Fig. 10 to see a selection of marginal densities of the two processes. Based on a visual analysis of the densities, it seems that the zero drift and pre-trained diffusion model initializations produce similar results around the observations, although the Schrödinger bridge initialization gave slightly sharper results at terminal time.

D DIFFERENTIABLE RESAMPLING

In the ISB model steps 1 and 3 presented in Sec. 3.1, we applied differentiable resampling (see Corenflos et al., 2021). Resampling itself is a basic block of particle filtering. A differentiable resampling step transports the particles and weights $(\tilde{x}^i_{t_k}, w^i_{t_k})$ to a uniform distribution over a set of particles through applying the differentiable ensemble transport map $T_{(\varepsilon)}$, that is,

$$x^i_{t_k} = \sum_{n=1}^N T_{(\varepsilon),i,n}\, \tilde{x}^n_{t_k}, \qquad \text{the $i$-th row of } T_{(\varepsilon)} \tilde{X}_{t_k},$$

where $\tilde{X}_{t_k} \in \mathbb{R}^{N \times d}$ denotes the stacked particles $\{\tilde{x}^i_{t_k}\}_{i=1}^N$ at time $t_k$ before resampling and $x^i_{t_k}$ denotes the particles post resampling. Here we give the definition of the map $T_{(\varepsilon)}$ and review the regularized optimal transport problem which has to be solved to compute it. We partly follow the presentation in Sections 2 and 3 of Corenflos et al. (2021), but directly apply the notation we use for particles and weights and focus on explaining the transport problem rather than the algorithm used to solve it.

The standard particle filtering resampling step consists of sampling $N$ particles from the categorical distribution defined by the weights $\{w^i_{t_k}\}_{i=1}^N$, resulting in the particles with large weights being most likely to be repeated multiple times. A result from Reich (2013) gives the property that the random resampling step can be approximated by a deterministic ensemble transform $T$. In heuristic terms, the ensemble transform map will be selected so that the particles $\{\tilde{x}^i_{t_k}\}_{i=1}^N$ will be transported with minimal cost, while allowing all the weights to be uniform.

Let $\mu$ and $\nu$ be atomic measures,

$$\mu = \sum_{n=1}^N w^n_{t_k}\, \delta_{\tilde{x}^n_{t_k}}, \qquad \nu = \frac{1}{N} \sum_{i=1}^N \delta_{\tilde{x}^i_{t_k}},$$

where $\delta_x$ is the Dirac delta at $x$. Then $\mu$ is the particle filtering distribution before resampling. Define the elements of a cost matrix $C \in \mathbb{R}^{N \times N}$ as $C_{i,n} = \| \tilde{x}^i_{t_k} - \tilde{x}^n_{t_k} \|^2$, and the 2-Wasserstein distance between two atomic measures as

$$\mathcal{W}_2^2(\mu, \nu) = \min_{P \in S(\mu, \nu)} \sum_{i=1}^N \sum_{n=1}^N C_{i,n} P_{i,n}.$$

Above the optimal matrix $P$ is to be found within $S(\mu, \nu)$, which is a space consisting of mixtures of $N$ particles to $N$ particles such that the marginals coincide with the weights of $\mu$ and $\nu$, formally

$$S(\mu, \nu) = \Big\{ P \in [0, 1]^{N \times N} : \sum_{n=1}^N P_{i,n} = \frac{1}{N}, \; \sum_{i=1}^N P_{i,n} = w^n_{t_k} \Big\}.$$

The entropy-regularized Wasserstein distance with regularization parameter $\varepsilon$ is then

$$\mathcal{W}_{2,\varepsilon}^2(\mu, \nu) = \min_{P \in S(\mu, \nu)} \sum_{i=1}^N \sum_{n=1}^N P_{i,n} \left( C_{i,n} + \varepsilon \log \frac{P_{i,n}}{\frac{1}{N}\, w^n_{t_k}} \right).$$

The unique minimizing transport map of the above Wasserstein distance is denoted by $P^{\mathrm{OPT}}_\varepsilon$, and the ensemble transport map is then set as $T_{(\varepsilon)} = N P^{\mathrm{OPT}}_\varepsilon$. This means that we can find the matrix $T_{(\varepsilon)}$ via minimizing the regularized Wasserstein distance, which is done by applying the iterative Sinkhorn algorithm for entropy-regularized optimal transport (Cuturi, 2013).
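As an illustration of how the ensemble transport map can be computed, the following is a minimal NumPy sketch of Sinkhorn iterations for the entropy-regularized problem. It is a plain (non-log-domain) variant written for clarity, not the algorithm of Corenflos et al. (2021); the function names, the fixed iteration count, and the convention that rows carry the uniform marginal $1/N$ while columns carry the weights before resampling are assumptions. The weights are assumed to be strictly positive and to sum to one.

```python
import numpy as np

def sinkhorn_transport(x, w, eps=0.5, n_iter=500):
    """Return T = N * P, where P approximately solves the regularized OT problem.

    x: (N, d) particles before resampling
    w: (N,)   normalized, strictly positive weights before resampling
    """
    N = len(w)
    C = np.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)  # squared distances
    K = np.exp(-C / eps)                                       # Gibbs kernel
    a = np.full(N, 1.0 / N)       # row marginals: uniform weights after resampling
    b = np.asarray(w, dtype=float)  # column marginals: weights before resampling
    u = np.ones(N)
    for _ in range(n_iter):
        v = b / (K.T @ u)  # match the column marginals
        u = a / (K @ v)    # match the row marginals (done last, so rows are exact)
    P = u[:, None] * K * v[None, :]
    return N * P

def resample(x, w, eps=0.5, n_iter=500):
    """Deterministic resampling: each new particle is a weighted average of x."""
    return sinkhorn_transport(x, w, eps, n_iter) @ x
```

Because each row of the returned matrix sums to one, every resampled particle is a convex combination of the original particles, and once the iterations have converged the uniform average of the resampled particles matches the weighted mean of the originals. For small `eps` or widely spread particles, a log-domain implementation would be the more robust choice.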

