NEURAL JUMP ORDINARY DIFFERENTIAL EQUATIONS: CONSISTENT CONTINUOUS-TIME PREDICTION AND FILTERING

Abstract

Combinations of neural ODEs with recurrent neural networks (RNNs), such as GRU-ODE-Bayes or ODE-RNN, are well suited to model irregularly observed time series. While these models outperform existing discrete-time approaches, no theoretical guarantees for their predictive capabilities are available. Assuming that the irregularly-sampled time series data originates from a continuous stochastic process, the L^2-optimal online prediction is the conditional expectation given the currently available information. We introduce the Neural Jump ODE (NJ-ODE), which provides a data-driven approach to learn, continuously in time, the conditional expectation of a stochastic process. Our approach models the conditional expectation between two observations with a neural ODE and jumps whenever a new observation is made. We define a novel training framework, which allows us to prove theoretical guarantees for the first time. In particular, we show that the output of our model converges to the L^2-optimal prediction. This can be interpreted as the solution to a special filtering problem. We provide experiments showing that the theoretical results also hold empirically. Moreover, we show experimentally that our model outperforms the baselines in more complex learning tasks, and we give comparisons on real-world datasets.

1. INTRODUCTION

Stochastic processes are widely used in many fields to model time series that exhibit random behaviour. In this work, we focus on processes that can be expressed as solutions of stochastic differential equations (SDEs) of the form dX_t = µ(t, X_t) dt + σ(t, X_t) dW_t, under certain assumptions on the drift µ and the diffusion σ. With respect to the L^2-norm, the best prediction of a future value of the process is the conditional expectation given the current value. If the drift and diffusion are known, or a good estimate is available, the conditional expectation can be approximated by Monte Carlo (MC) simulation. However, since µ and σ are usually unknown, this approach strongly depends on the assumptions made on their parametric form. A more flexible approach is given by neural SDEs, where the drift µ and the diffusion σ are modelled by neural networks (Tzen & Raginsky, 2019; Li et al., 2020; Jia & Benson, 2019). Nevertheless, modelling the diffusion can be avoided if one is only interested in forecasting the behaviour instead of sampling new paths.

An alternative, widely used approach is Recurrent Neural Networks (RNNs), where a neural network dynamically updates a latent variable with the observations of a discrete input time series. RNNs are successfully applied to tasks for which time series are regularly sampled, such as speech or text recognition. However, observations are often irregularly spaced in time. The standard approach of dividing the time-line into equally-sized intervals and imputing or aggregating observations might lead to a significant loss of information (Rubanova et al., 2019). Frameworks that overcome this issue are GRU-ODE-Bayes (Brouwer et al., 2019) and the ODE-RNN (Rubanova et al., 2019), which combine an RNN with a neural ODE (Chen et al., 2018). In standard RNNs, the hidden state is updated at each observation and remains constant in between.
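As a minimal sketch of the MC approach mentioned above (for the idealized case where µ and σ are known), the following snippet estimates the conditional expectation E[X_{t+s} | X_t = x_t] by simulating Euler–Maruyama paths of the SDE forward from the last observation and averaging. The Ornstein–Uhlenbeck choice of drift and diffusion and all numerical parameters are illustrative assumptions, not taken from the paper.

```python
import numpy as np

def mc_conditional_expectation(mu, sigma, x_t, t, s,
                               n_paths=10_000, n_steps=100, seed=0):
    """Monte Carlo estimate of E[X_{t+s} | X_t = x_t] via Euler-Maruyama.

    mu, sigma: callables (time, state) -> drift / diffusion of
    dX = mu(t, X) dt + sigma(t, X) dW, assumed known here.
    """
    rng = np.random.default_rng(seed)
    dt = s / n_steps
    x = np.full(n_paths, x_t, dtype=float)   # all paths start at the observation
    for k in range(n_steps):
        tk = t + k * dt
        dW = rng.normal(0.0, np.sqrt(dt), size=n_paths)  # Brownian increments
        x = x + mu(tk, x) * dt + sigma(tk, x) * dW       # Euler-Maruyama step
    return x.mean()                          # MC average approximates E[X_{t+s} | X_t]

# Ornstein-Uhlenbeck example: dX = -theta * X dt + sig dW, for which the
# conditional expectation is known in closed form: x * exp(-theta * s).
theta, sig = 2.0, 0.5
est = mc_conditional_expectation(lambda t, x: -theta * x,
                                 lambda t, x: sig,
                                 x_t=1.0, t=0.0, s=0.5)
```

The closed-form value x * exp(-theta * s) = exp(-1) ≈ 0.368 lets one check the estimate; for general, unknown µ and σ this benchmark is unavailable, which is exactly the situation the data-driven NJ-ODE addresses.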
Conversely, in the GRU-ODE-Bayes and ODE-RNN frameworks, a neural ODE is trained to model the continuous evolution of the hidden state of the RNN between two observations. While GRU-ODE-Bayes and ODE-RNN both provide convincing empirical results, they lack thorough theoretical guarantees.

Contribution. In this paper, we introduce a mathematical framework to precisely describe the problem of online prediction and filtering of a stochastic process with temporally irregular observations. Based on this rigorous mathematical description, we introduce the Neural Jump ODE (NJ-ODE). The model architecture is very similar to those of GRU-ODE-Bayes and ODE-RNN; however, we introduce a novel training framework which, in contrast to theirs, allows us to prove convergence guarantees for the first time. Moreover, we demonstrate the capabilities of our model empirically.

Precise problem formulation. We emphasize that a precise definition of all ingredients is needed to show theoretical convergence guarantees, which is the main purpose of this work. Since the objects of interest are stochastic processes, we use tools from probability theory and stochastic calculus. To make the paper readable and comprehensible also for readers without a background in these fields, the precise formulations and proofs of all claims are given in the appendix, while the main part of the paper focuses on well-understandable heuristics.

2. PROBLEM STATEMENT

The problem we consider in this work is the online forecasting of temporal data. We assume that we observe, at irregularly-sampled time points, a Markovian stochastic process described by the stochastic differential equation (SDE) dX_t = µ(t, X_t) dt + σ(t, X_t) dW_t. Between those observation times, we want to predict the stochastic process based only on the observations made previously in time, excluding the possibility of interpolating observations. Due to the Markov property, only the last observation is needed for an optimal prediction. Hence, after each observation we extrapolate the current observation into the future until the next observation is made. The time at which the next observation is made is random and assumed to be independent of the stochastic process itself.

More precisely, we suppose that we have a training set of N independent realisations of the R^{d_X}-dimensional stochastic process X defined by the SDE above. Each realisation j is observed at n_j random observation times t_1^{(j)}, ..., t_{n_j}^{(j)} ∈ [0, T] with values x_1^{(j)}, ..., x_{n_j}^{(j)} ∈ R^{d_X}. We assume that all coordinates of the vector x_i^{(j)} are observed. We are interested in forecasting how a new independent realisation evolves in time, such that our predictions of X minimize the expected squared distance (L^2-metric) to the true unknown path. The optimal prediction, i.e. the L^2-minimizer, is the conditional expectation. Given that the value of the new realisation at time t is x_t, we are therefore interested in estimating the function f(x_t, t, s) := E[X_{t+s} | X_t = x_t], s ≥ 0, which is the L^2-optimal prediction until the next observation is made. To learn an approximation f̂ of f, we make use of the N realisations of the training set. After training, f̂ is applied to the new realisation. Hence, this can be interpreted as a special type of filtering problem. The following example illustrates the considered problem.

Example.
A vital parameter of patients in a hospital that is complicated to measure is recorded multiple times during the first 48 hours of their stay. For each patient, this happens at different times depending on the available resources, hence the observation dates are irregular and exhibit some randomness. Patient 1 has n_1 = 4 measurements at hours (t_1^{(1)}, t_2^{(1)}, t_3^{(1)}, t_4^{(1)}) = (1, 14, 27, 34), where the values (x_1^{(1)}, x_2^{(1)}, x_3^{(1)}, x_4^{(1)}) = (0.74, 0.65, 0.78, 0.81) are measured. Patient 2 has only n_2 = 2 measurements at hours (t_1^{(2)}, t_2^{(2)}) = (3, 28), where the values (x_1^{(2)}, x_2^{(2)}) = (0.56, 0.63) are measured. Similarly, the j-th patient has n_j measurements at times (t_1^{(j)}, ..., t_{n_j}^{(j)}) with measured values (x_1^{(j)}, ..., x_{n_j}^{(j)}). Based on this data, we want to forecast the vital parameter of new patients arriving at the hospital. In particular, for a patient with measured value x_1 at time t_1, we want to predict the value of the vital parameter at any later time, until the next measurement is made.
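To make the prediction pipeline concrete, the following snippet mimics the NJ-ODE structure on patient 1's observations: a latent state evolves by an ODE (Euler steps) between measurements, jumps via an RNN-style update whenever a measurement arrives, and a readout maps the latent state to the prediction at every time point. This is only a structural sketch: the weights are random placeholders for trained networks, and the network shapes, tanh nonlinearities, and step size are illustrative assumptions, not the architecture specified in the paper.

```python
import numpy as np

rng = np.random.default_rng(1)
d_h, d_x = 8, 1
# Random placeholder weights standing in for trained neural networks.
W_ode = rng.normal(0, 0.1, (d_h, d_h))         # drift of the latent ODE
W_jump = rng.normal(0, 0.1, (d_h, d_h + d_x))  # jump (RNN-style) update
W_out = rng.normal(0, 0.1, (d_x, d_h))         # readout to the prediction

def predict_path(obs_times, obs_values, T, dt=0.5):
    """Latent state with ODE evolution between observations and jumps at them.

    Between observations, h follows dh/dt = tanh(W_ode @ h) (Euler steps);
    at each observation, h jumps to tanh(W_jump @ [h; x]).  The continuous-time
    prediction is W_out @ h(t), mimicking the NJ-ODE structure.
    """
    h = np.zeros(d_h)
    t, out_t, out_y = 0.0, [], []
    obs = list(zip(obs_times, obs_values))
    while t < T:
        if obs and abs(t - obs[0][0]) < dt / 2:   # new observation: jump
            _, x = obs.pop(0)
            h = np.tanh(W_jump @ np.concatenate([h, np.atleast_1d(x)]))
        h = h + dt * np.tanh(W_ode @ h)           # neural-ODE Euler step
        t += dt
        out_t.append(t)
        out_y.append(float(W_out @ h))
    return np.array(out_t), np.array(out_y)

# Patient 1 from the example: 4 measurements within the first 48 hours.
times, preds = predict_path([1, 14, 27, 34], [0.74, 0.65, 0.78, 0.81], T=48.0)
```

With random weights the output is of course meaningless; training (the novel framework of Section 3) is what drives this output towards the conditional expectation f.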

