TRANSPORT WITH SUPPORT: DATA-CONDITIONAL DIFFUSION BRIDGES

Abstract

The dynamic Schrödinger bridge problem provides an appealing setting for posing optimal transport problems as learning non-linear diffusion processes and enables efficient iterative solvers. Recent works have demonstrated state-of-the-art results (e.g., in modelling single-cell embryo RNA sequences or sampling from complex posteriors) but are typically limited to learning bridges with only initial and terminal constraints. Our work extends this paradigm by proposing the Iterative Smoothing Bridge (ISB). We combine learning diffusion models with Bayesian filtering and optimal control, allowing for constrained stochastic processes governed by sparse observations at intermediate stages and terminal constraints. We assess the effectiveness of our method on synthetic and real-world data and show that the ISB generalises well to high-dimensional data, is computationally efficient, and provides accurate estimates of the marginals at intermediate and terminal times.

1. INTRODUCTION

Generative diffusion models have gained increasing popularity and achieved impressive results in a variety of challenging application domains, such as computer vision (e.g., Ho et al., 2020; Song et al., 2021a; Dhariwal & Nichol, 2021), reinforcement learning (e.g., Janner et al., 2022), and time series modelling (e.g., Rasul et al., 2021; Vargas et al., 2021; Tashiro et al., 2021; Park et al., 2022). Recent works have explored connections between denoising diffusion models and the dynamic Schrödinger bridge problem (SBP, e.g., Vargas et al., 2021; De Bortoli et al., 2021; Shi et al., 2022) to adopt iterative schemes that solve the dynamic optimal transport problem more efficiently. The solution of the SBP that corresponds to denoising diffusion models is then given by the finite-time process that is closest in Kullback-Leibler (KL) divergence to the forward noising process of the diffusion model under marginal constraints. Data are then generated by time-reversing the process. In many applications, however, the interest is not purely in modelling transport between an initial and a terminal state distribution. In naturally occurring generative processes, we typically observe snapshots of realizations along intermediate stages of individual sample trajectories (see Fig. 1). Such problems arise in medical diagnosis (e.g., tissue changes and cell growth), demographic modelling, environmental dynamics, and animal movement modelling (see Fig. 4 for modelling bird migration and wintering patterns). Recently, constrained optimal control problems have been explored by adding fixed path constraints (Maoutsa et al., 2020; Maoutsa & Opper, 2021) or by modifying the prior processes (Fernandes et al., 2021). However, defining meaningful fixed path constraints or prior processes for such optimal control problems can be challenging, whereas sparse observational data are available in many real-world applications.
In this work, we propose the Iterative Smoothing Bridge (ISB), an iterative method for solving control problems under sparse observational data constraints together with constraints on the initial and terminal distributions. We perform the conditioning by leveraging the iterative-pass idea of the Iterative Proportional Fitting procedure (IPFP, Kullback, 1968; De Bortoli et al., 2021) and applying differentiable particle filtering (Reich, 2013; Corenflos et al., 2021). We summarize the contributions as follows. (i) We propose a novel method for solving constrained optimal transport as a bridge problem under sparse observational data constraints. (ii) Thereby, we exploit the strong connections between the constrained bridging problem and particle filtering in sequential Monte Carlo, extending them from pure inference to learning. Additionally, (iii) we demonstrate practical efficiency and show that the Iterative Smoothing Bridge approach scales to high-dimensional data.
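To give a concrete feel for the IPFP idea referenced above, the sketch below shows its classical static, discrete-state form: alternately rescaling a positive kernel so that the coupling matches each prescribed marginal in turn (equivalently, Sinkhorn iterations). This is only an illustration of the fitting principle, not the ISB method itself, which alternates between learning forward and backward drifts of an SDE rather than rescaling a matrix; the kernel bandwidth and iteration count here are arbitrary choices.

```python
import numpy as np

def ipfp_discrete(K, mu, nu, iters=500):
    """Alternately rescale a positive kernel K so that the coupling
    u[:, None] * K * v[None, :] matches the marginals mu and nu."""
    u = np.ones_like(mu)
    v = np.ones_like(nu)
    for _ in range(iters):
        u = mu / (K @ v)        # projection onto the first-marginal constraint
        v = nu / (K.T @ u)      # projection onto the second-marginal constraint
    return u[:, None] * K * v[None, :]

# Toy example: two discrete Gaussians coupled through a Gaussian kernel.
x = np.linspace(-2.0, 2.0, 50)
K = np.exp(-((x[:, None] - x[None, :]) ** 2) / 2.0)
mu = np.exp(-((x + 1.0) ** 2)); mu /= mu.sum()
nu = np.exp(-((x - 1.0) ** 2)); nu /= nu.sum()
P = ipfp_discrete(K, mu, nu)
print(np.abs(P.sum(axis=1) - mu).max())  # marginal mismatch after fitting
```

In the dynamic setting of the paper, each "projection" is instead a half-bridge: a full learning pass that fits one boundary distribution while keeping the other end fixed.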

Schrödinger bridges

The problem of learning a stochastic process moving samples from one distribution to another can be posed as a type of transport problem known as the dynamic Schrödinger bridge problem (SBP, e.g., Schrödinger, 1932; Léonard, 2014), where the resulting marginal densities are desired to resemble a given reference measure. In the machine learning literature, the problem has been studied through learning the drift function of the dynamical system (De Bortoli et al., 2021; Wang et al., 2021; Vargas et al., 2021; Bunne et al., 2022). When an SDE system also defines the reference measure, the bridge problem becomes synonymous with a constrained optimal control problem (e.g., Caluya & Halder, 2022; 2021; Chen et al., 2021), which has been leveraged in learning Schrödinger bridges by Chen et al. (2022) through forward-backward SDEs. An optimal control problem with both constraints on the initial and terminal distribution and a fixed path constraint has been studied in Maoutsa et al. (2020) and Maoutsa & Opper (2021), where particle filtering is applied to continuous path constraints but the boundary constraints are defined by a single point. Furthermore, the combination of Schrödinger bridges and state-space models has been studied by Reich (2019), in a setting where Schrödinger bridges are applied to the transport problem between filtering distributions.
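The variational form behind this discussion can be stated compactly. In a common formulation (notation ours, chosen for illustration rather than taken from the paper), with boundary marginals $\pi_0$ and $\pi_T$ and a reference path measure $\mathbb{Q}$, the dynamic SBP reads

```latex
\mathbb{P}^\star \;=\; \operatorname*{arg\,min}_{\mathbb{P}\,:\;\mathbb{P}_0=\pi_0,\;\mathbb{P}_T=\pi_T} \operatorname{KL}\!\left(\mathbb{P}\,\middle\|\,\mathbb{Q}\right),
```

and when $\mathbb{Q}$ is induced by $\mathrm{d}x_t = f(x_t, t)\,\mathrm{d}t + \sigma\,\mathrm{d}\beta_t$ while $\mathbb{P}$ is induced by a controlled drift $u$, Girsanov's theorem reduces the objective to the quadratic control cost

```latex
\operatorname{KL}\!\left(\mathbb{P}\,\middle\|\,\mathbb{Q}\right) \;=\; \mathbb{E}_{\mathbb{P}}\!\left[\frac{1}{2\sigma^2}\int_0^T \bigl\|u(x_t, t) - f(x_t, t)\bigr\|^2\,\mathrm{d}t\right],
```

which makes the equivalence to a constrained optimal control problem explicit.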

Diffusion models in machine learning

Recent advances in diffusion models in the machine learning literature have focused on generating samples from complex distributions defined by data, by transforming samples from an easy-to-sample distribution through a dynamical system (e.g., Ho et al., 2020; Song et al., 2021b; a; Nichol & Dhariwal, 2021). The concept of reversing SDE trajectories via score-based learning (Hyvärinen & Dayan, 2005; Vincent, 2011) has allowed for models scalable enough to be applied to high-dimensional data sets directly in the data space. In earlier work, score-based diffusion models have been applied to problems where the dynamical system itself is of interest, for example, time series imputation in Tashiro et al. (2021) and inverse problems in imaging in Song et al. (2022). Other dynamical models parametrized by neural networks have been applied to modelling latent time series based on observed snapshots of dynamics (Rubanova et al., 2019; Li et al., 2020), but without further constraints on the initial or terminal distributions.
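The SDE-reversal concept mentioned above can be sketched numerically. The snippet below integrates the reverse-time SDE dx = [f − g²∇log pₜ(x)] dt + g dW̄ (Anderson's result, as used in score-based diffusion) with Euler–Maruyama steps. For the example to be self-contained we use a toy setting where the score is known in closed form: data x₀ ~ N(0, 1) diffused by pure Brownian motion (f = 0, g = 1) has marginals pₜ = N(0, 1 + t); in practice a learned score network would replace the `score` function.

```python
import numpy as np

rng = np.random.default_rng(0)

def score(x, t):
    # Exact score of p_t = N(0, 1 + t); stands in for a learned score model.
    return -x / (1.0 + t)

def reverse_sample(n, T=1.0, n_steps=100):
    """Euler--Maruyama integration of the reverse-time SDE from t = T to 0."""
    dt = T / n_steps
    x = rng.normal(0.0, np.sqrt(1.0 + T), size=n)      # start from p_T
    for k in range(n_steps):
        t = T - k * dt
        drift = -score(x, t)                           # f - g^2 * score, f = 0, g = 1
        x = x - drift * dt + np.sqrt(dt) * rng.normal(size=n)  # backward-in-time step
    return x

samples = reverse_sample(20000)
print(samples.mean(), samples.std())  # should be near 0 and 1
```

Running the reverse dynamics thus carries samples from the noised marginal p_T back to (approximately) the data distribution, which is exactly the generative direction of a diffusion model.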

State-space models

In their general form, state-space models combine a latent-space dynamical system with an observation (likelihood) model. Evaluating the latent state distribution based on observational data can be performed by applying particle filtering and smoothing (Doucet et al., 2000) or by approximating the underlying state distribution of a non-linear state-space model with a specific model family, for instance, a Gaussian (see Särkkä, 2013, for an overview). Speeding up parameter inference and learning in state-space models has been widely studied (e.g., Schön
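As a minimal illustration of the particle filtering referenced above, the sketch below runs a bootstrap particle filter (propagate, weight by the likelihood, resample) on a hypothetical linear-Gaussian random-walk model; the model, noise levels, and particle count are our own assumptions for the example, not taken from the paper.

```python
import numpy as np

rng = np.random.default_rng(1)

def bootstrap_filter(ys, n_particles=2000, q=0.1, r=0.5):
    """Bootstrap particle filter for the state-space model
    x_t = x_{t-1} + N(0, q^2),  y_t = x_t + N(0, r^2).
    Returns the filtering means E[x_t | y_{1:t}]."""
    x = rng.normal(0.0, 1.0, n_particles)              # prior particles
    means = []
    for y in ys:
        x = x + q * rng.normal(size=n_particles)       # propagate through dynamics
        logw = -0.5 * ((y - x) / r) ** 2               # Gaussian log-likelihood
        w = np.exp(logw - logw.max()); w /= w.sum()    # normalized weights
        means.append(np.sum(w * x))
        idx = rng.choice(n_particles, n_particles, p=w)  # multinomial resampling
        x = x[idx]
    return np.array(means)

# Synthetic trajectory and observations drawn from the same model.
T = 50
true_x = np.cumsum(0.1 * rng.normal(size=T))
ys = true_x + 0.5 * rng.normal(size=T)
est = bootstrap_filter(ys)
print(np.mean((est - true_x) ** 2))  # filtering MSE, below the observation noise variance
```

The multinomial resampling step here is not differentiable; the differentiable particle filtering used in the ISB (Corenflos et al., 2021) replaces it with an optimal-transport-based resampling so that gradients can flow through the filter during learning.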



Figure 1: Illustrative example of transport between an initial unit Gaussian and a shifted unit Gaussian at the terminal time T. Unconstrained transport on the left and the solution constrained by sparse observations ( ) on the right. Colour coding of the initial points is only for distinguishing the paths.

