LATENT NEURAL ODES WITH SPARSE BAYESIAN MULTIPLE SHOOTING

Abstract

Training dynamic models, such as neural ODEs, on long trajectories is a hard problem that requires various tricks, such as trajectory splitting, to make training work in practice. These methods are often heuristics with poor theoretical justification, and require iterative manual tuning. We propose a principled multiple shooting technique for neural ODEs that splits a trajectory into manageable short segments, which are optimised in parallel, while ensuring probabilistic control of continuity over consecutive segments. We derive variational inference for our shooting-based latent neural ODE models and propose amortized encodings of irregularly sampled trajectories with a transformer-based recognition network with temporal attention and relative positional encodings. We demonstrate efficient and stable training, and state-of-the-art performance on multiple large-scale benchmark datasets.

1. INTRODUCTION

Dynamical systems, from biological cells to the weather, evolve according to their underlying mechanisms, often described by differential equations. In data-driven system identification we aim to learn the rules governing a dynamical system by observing the system over a time interval [0, T] and fitting a model of the underlying dynamics to the observations by gradient descent. Such optimisation suffers from the curse of length: the complexity of the loss function grows with the length of the observed trajectory (Ribeiro et al., 2020). For even moderate T the loss landscape can become highly complex and gradient descent fails to produce a good fit (Metz et al., 2021). To alleviate this problem, previous works resort to cumbersome heuristics, such as iterative training and trajectory splitting (Yildiz et al., 2019; Kochkov et al., 2021; Han et al., 2022; Lienen & Günnemann, 2022). The optimal control literature has a long history of multiple shooting methods, where trajectory fitting is split into piecewise segments that are easy to optimise, with constraints ensuring continuity across the segments (van Domselaar & Hemker, 1975; Bock & Plitt, 1984; Baake et al., 1992). Multiple-shooting-based models have simpler loss landscapes and are practical to fit by gradient descent (Voss et al., 2004; Heiden et al., 2022; Turan & Jäschke, 2022; Hegde et al., 2022). Inspired by this line of work, we develop a shooting-based latent neural ODE model (Chen et al., 2018; Rubanova et al., 2019; Yildiz et al., 2019; Massaroli et al., 2020). Our multiple shooting formulation generalizes standard approaches by sparsifying the shooting variables in a probabilistic setting to account for irregularly sampled time grids and redundant shooting variables. We furthermore introduce an attention-based (Vaswani et al., 2017) encoder architecture for latent neural ODEs that is compatible with our sparse shooting formulation and can handle noisy and partially observed high-dimensional data.
Consequently, our model produces state-of-the-art results, naturally handles long observation intervals, and is stable and quick to train. Our contributions are:

• We introduce a latent neural ODE model with quick and stable training on long trajectories.
• We derive sparse Bayesian multiple shooting, a Bayesian version of multiple shooting with efficient utilization of shooting variables and a continuity-inducing prior.
• We introduce a transformer-based encoder with novel time-aware attention and relative positional encodings, which efficiently handles data observed at arbitrary time points.

Figure 2: Method overview with two blocks (see Section 3.1). The encoder maps the input sequence y_1:5, observed at arbitrary time points t_1:5, to two distributions q_ψ1(s_1), q_ψ2(s_2) from which we sample shooting variables s_1, s_2. Then, s_1, s_2 are used to compute two sub-trajectories that define the latent trajectory x_1:5, from which the decoder reconstructs the input sequence.
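As a concrete illustration of the multiple shooting idea described above, the segment-wise objective can be sketched as follows. This is a minimal state-space toy in NumPy, not the paper's latent Bayesian formulation: `multiple_shooting_loss`, the penalty weight `lam`, and the RK4 helpers are illustrative names we introduce here, and the soft quadratic penalty merely stands in for the continuity-inducing prior derived later in the paper.

```python
import numpy as np

def rk4_step(f, t, x, h):
    """One classical RK4 step for dx/dt = f(t, x)."""
    k1 = f(t, x)
    k2 = f(t + h / 2, x + h / 2 * k1)
    k3 = f(t + h / 2, x + h / 2 * k2)
    k4 = f(t + h, x + h * k3)
    return x + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

def integrate(f, x0, t0, t1, n_steps=20):
    """Solve the IVP from t0 to t1 starting at x0."""
    x, t = x0, t0
    h = (t1 - t0) / n_steps
    for _ in range(n_steps):
        x = rk4_step(f, t, x, h)
        t += h
    return x

def multiple_shooting_loss(f, shooting_vars, seg_times, y_segments, lam=10.0):
    """Sum of per-segment data-fit terms plus continuity penalties.

    shooting_vars[b] is the initial state of segment b. Each segment is
    integrated independently of the others (so segments could be solved in
    parallel), while lam * ||s_{b+1} - x_b(end)||^2 softly ties consecutive
    segments together.
    """
    loss = 0.0
    for b, (times, y_seg) in enumerate(zip(seg_times, y_segments)):
        x = shooting_vars[b]
        loss += np.sum((x - y_seg[0]) ** 2)           # fit at segment start
        for t0, t1, y_obs in zip(times[:-1], times[1:], y_seg[1:]):
            x = integrate(f, x, t0, t1)
            loss += np.sum((x - y_obs) ** 2)          # fit at later points
        if b + 1 < len(shooting_vars):                # continuity penalty
            loss += lam * np.sum((shooting_vars[b + 1] - x) ** 2)
    return loss
```

Because each segment is short, each data-fit term depends on only a short simulation interval, which is exactly what keeps the loss landscape benign.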

2. PROBLEM SETTING AND BACKGROUND

Data. We observe a dynamical system at arbitrary consecutive time points t_1:N = (t_1, ..., t_N), which generates an observed trajectory y_1:N = (y_1, ..., y_N), where y_i := y(t_i) ∈ R^D. Our goal is to model the observations and forecast the future states. For brevity we present our methodology for a single trajectory, but the extension to many trajectories is straightforward.

Latent neural ODE (L-NODE) models. An L-NODE model with latent states x_i := x(t_i) ∈ R^d is defined as

    x_i = ODEsolve(x_1, t_1, t_i, f_{θ_dyn}),    i = 2, ..., N,
    y_i | x_i ~ p(y_i | g_{θ_dec}(x_i)),         i = 1, ..., N.    (2)

Variable x_1 is the initial state at time t_1. The dynamics function f_{θ_dyn} is the time derivative of x(t), and ODEsolve(x_1, t_1, t_i, f_{θ_dyn}) is defined as the solution of the following initial value problem at time t_i:

    dx(t)/dt = f_{θ_dyn}(t, x(t)),    x(t_1) = x_1,    t ∈ [t_1, t_i].    (3)

The decoder g_{θ_dec} maps the latent state x_i to the parameters of p(y_i | g_{θ_dec}(x_i)). The dynamics and decoder functions are neural networks with parameters θ_dyn and θ_dec. In typical applications the data is high-dimensional whereas the dynamics are modeled in a low-dimensional latent space, i.e., d ≪ D.

L-NODE models are commonly trained by minimizing a loss function, e.g., the negative evidence lower bound (ELBO), via gradient descent (Chen et al., 2018; Yildiz et al., 2019). In gradient-based optimization the complexity of the loss landscape plays a crucial role in the success of the optimization. However, it has been shown empirically that the loss landscape of L-NODE-like models (i.e., models that compute the latent trajectory x_1:N from the initial state x_1) is strongly affected by the length of the simulation interval [t_1, t_N] (Voss et al., 2004; Metz et al., 2021; Heiden et al., 2022). Furthermore, Ribeiro et al. (2020) show that the loss complexity, in terms of the Lipschitz constant, can grow exponentially with the length of [t_1, t_N]. Figure 1 shows an example of this phenomenon (details in Appendix A).
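A minimal numerical sketch of the generative model in equations (2)–(3), assuming a toy linear vector field and a linear map in place of the neural networks f_θdyn and g_θdec; `ode_solve`, `f_dyn`, and `W_dec` are illustrative names introduced here, and fixed-step RK4 stands in for a generic ODE solver.

```python
import numpy as np

rng = np.random.default_rng(0)

d, D = 2, 5                      # latent and data dimensions (d << D in practice)
W_dec = rng.normal(size=(D, d))  # toy linear "decoder" standing in for g_θdec

def f_dyn(t, x):
    """Toy dynamics standing in for f_θdyn: a harmonic oscillator."""
    A = np.array([[0.0, 1.0], [-1.0, 0.0]])
    return A @ x

def ode_solve(x1, t1, ti, f, n_steps=50):
    """ODEsolve(x1, t1, ti, f): RK4 integration of the IVP in eq. (3)."""
    x, t = x1.copy(), t1
    h = (ti - t1) / n_steps
    for _ in range(n_steps):
        k1 = f(t, x)
        k2 = f(t + h / 2, x + h / 2 * k1)
        k3 = f(t + h / 2, x + h / 2 * k2)
        k4 = f(t + h, x + h * k3)
        x = x + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    return x

# Generative pass of eq. (2): roll the latent state forward to each
# observation time, then decode each latent to the observation-space mean.
t_grid = np.array([0.0, 0.3, 0.9, 1.4, 2.0])   # arbitrary time points t_1:N
x1 = np.array([1.0, 0.0])                      # initial latent state
latents = [x1] + [ode_solve(x1, t_grid[0], ti, f_dyn) for ti in t_grid[1:]]
y_means = [W_dec @ x for x in latents]         # parameters of p(y_i | ...)
```

Note that every latent state is computed from the single initial state x_1; this is precisely the structure that makes the loss landscape deteriorate as the interval [t_1, t_N] grows, and that multiple shooting breaks apart.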



Figure 1: Top: Training loss of an L-NODE model using the iterative training heuristic. We start training on a short trajectory (N = 10) and double its length every 3000 iterations. Training fails for the longest trajectory. Bottom: 1-D projection of the loss landscape around the parameters to which the optimizer converged for a given trajectory length. The complexity of the loss grows dramatically with N.
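The bottom panel of Figure 1 can be reproduced in spirit with a standard 1-D slice of the loss surface: evaluate the loss along a random unit direction through the converged parameters. The sketch below, using a toy quadratic loss, is our assumption about the visualisation recipe, not the authors' exact code; `loss_1d_projection` is an illustrative name.

```python
import numpy as np

def loss_1d_projection(loss_fn, theta_star, alphas, rng=None):
    """Evaluate loss_fn along a random 1-D slice through theta_star:
    L(alpha) = loss_fn(theta_star + alpha * u), with u a random unit direction.
    """
    rng = rng or np.random.default_rng(0)
    u = rng.normal(size=theta_star.shape)
    u /= np.linalg.norm(u)               # unit direction in parameter space
    return np.array([loss_fn(theta_star + a * u) for a in alphas])

# Toy quadratic loss with its minimum at theta_star, so the slice is a parabola.
theta_star = np.zeros(10)
loss_fn = lambda th: float(np.sum(th ** 2))
alphas = np.linspace(-1.0, 1.0, 21)
curve = loss_1d_projection(loss_fn, theta_star, alphas)
```

For a real L-NODE, `loss_fn` would run the full forward simulation and return the training loss; the jaggedness of the resulting curve is what grows with trajectory length N.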

