S-SOLVER: NUMERICALLY STABLE ADAPTIVE STEP SIZE SOLVER FOR NEURAL ODES

Abstract

A neural ordinary differential equation (ODE) is a relation between an unknown function and its derivatives, where the ODE is parameterized by a neural network. Obtaining a solution to a neural ODE therefore requires a solver that performs numerical integration. Dopri5 is one of the most popular neural ODE solvers and the default solver in torchdiffeq, a PyTorch library of ODE solvers. It is an adaptive step size solver based on the Runge-Kutta (RK) numerical methods. These methods rely on an estimate of the local truncation error to select and adjust the integration step size, which determines the numerical stability of the solution. A step size that is too large leads to numerical instability, while a step size that is too small may cause the solver to take unnecessarily many steps, which is computationally expensive and may even cause rounding errors to build up. Accurate local truncation error estimation is therefore paramount for choosing a step size that yields an accurate, numerically stable, and fast solution to the ODE. In this paper we propose a novel local truncation error approximation that is the first to consider solutions of four different RK orders to obtain a more reliable error estimate. This leads to a novel solver, S-SOLVER (Stable Solver), which is more numerically stable and therefore more accurate. We demonstrate S-SOLVER's competitive performance in experiments on image recognition with ODE-Net, learning Hamiltonian dynamics with Symplectic ODE-Net, and continuous normalizing flows (CNF).

1. INTRODUCTION

Neural ODEs are continuous-depth deep learning models that combine neural networks and ODEs. Since their introduction in (Chen et al., 2018), they have been used in many applications, such as stochastic differential equations (Li et al., 2020), physically informed modeling (Sanchez-Gonzalez et al., 2019; Zhong et al., 2020), free-form continuous generative models (Grathwohl et al., 2019; Finlay et al., 2020), mean-field games (Ruthotto et al., 2020), and irregularly sampled time series (Rubanova et al., 2019). Neural ODEs parameterize the derivative of the hidden state with a neural network, and therefore learn non-linear mappings via differential equations. A differential equation is a relation between an unknown function and its derivatives. Ordinary differential equations describe the change of a single variable (as opposed to multiple variables) with respect to time, i.e., dx/dt = f(t, x). Typically, an ODE is formulated as an initial value problem (IVP): given a function derivative dx/dt, a time interval [a, b], and an initial value (i.e., x at time t = a), the solution to the IVP yields x evaluated at time t = b. Approximating x(b) requires numerical integration; the various ODE solvers therefore differ in how they perform this integration. Adaptive step size solvers are among the most popular solvers for neural ODEs. In fact, the default solver in torchdiffeq (a library of ODE solvers implemented in PyTorch) is Dopri5, the Dormand-Prince 5(4) embedded adaptive step size method of the Runge-Kutta (RK) family. Adaptive step size RK solvers compute two approximations, one of order p and another of order p - 1, and compare them to obtain the local truncation error, which is used to determine the integration step size. Specifically, the error is used to decide whether to accept or reject the solution step under the current step size and to decide how to modify the step size for the next step.
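The accept/reject pattern of an embedded adaptive step size RK solver can be illustrated with a minimal sketch. The example below uses a low-order embedded pair (Heun's order-2 method with an embedded order-1 Euler solution) rather than the Dormand-Prince coefficients; the function names, tolerance, and step-size clamping factors are illustrative choices, not the exact controller used by any particular library.

```python
import math

def heun_euler_step(f, t, x, h):
    # Embedded RK pair: Heun's method (order 2) with an embedded
    # forward Euler solution (order 1) sharing the same stage k1.
    k1 = f(t, x)
    k2 = f(t + h, x + h * k1)
    x_high = x + h * (k1 + k2) / 2.0  # order-2 solution (propagated)
    x_low = x + h * k1                # order-1 solution (for error only)
    err = abs(x_high - x_low)         # local truncation error estimate
    return x_high, err

def integrate_adaptive(f, t0, t1, x0, h0=0.1, tol=1e-6):
    t, x, h = t0, x0, h0
    while t < t1:
        h = min(h, t1 - t)
        x_new, err = heun_euler_step(f, t, x, h)
        if err <= tol:                # accept the step
            t, x = t + h, x_new
        # Otherwise reject: keep (t, x) and retry with a smaller h.
        # Standard controller for an order p / p-1 pair with p = 2.
        factor = 0.9 * (tol / max(err, 1e-16)) ** 0.5
        h *= min(5.0, max(0.1, factor))
    return x

# dx/dt = x with x(0) = 1 has the exact solution x(1) = e.
approx = integrate_adaptive(lambda t, x: x, 0.0, 1.0, 1.0)
```

Dopri5 follows the same accept/reject logic, with fifth- and fourth-order solutions in place of the pair above.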
A step size that is too large leads to numerical instability, while a step size that is too small may cause the solver to take unnecessarily many steps, which is computationally expensive and may even cause rounding errors to build up. Accurate local error estimation is therefore paramount for choosing a step size that yields an accurate, numerically stable, and fast solution to the ODE. The local truncation error is defined as the difference between the exact and approximate solutions at a given time step. All currently available adaptive step neural ODE solvers estimate the local error as the difference between the order p and order p - 1 solutions, which assumes that the order p solution is exact. This is not necessarily true, and if the order p solution is far from the exact one, the local error estimate is inaccurate, causing the solver to make poor step size decisions. In this paper we propose a novel local truncation error estimation that takes into account multiple orders of the RK method, as opposed to just orders p and p - 1, to obtain a more accurate estimate of the local truncation error that guides the integration step size. Specifically, we modify the local truncation error estimation of Dopri8, the Dormand-Prince 8(7) embedded adaptive step size method. Dopri8 calculates the local truncation error as the difference between its 8th and 7th order solutions. Our modification computes this error as the average of the difference between the 8th and 7th order solutions and the difference between the 5th and 4th order solutions. This leads to a new ODE solver, S-SOLVER (Stable Solver): a modified Dopri8 integrator whose more accurate local truncation error estimation provides more reliable information for step size calculations, and therefore a more numerically stable solution. To the best of our knowledge, S-SOLVER is the first solver that uses multiple solution orders to estimate the local truncation error for adjusting its step size.
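The modified estimate described above can be sketched as follows. This is a scalar sketch based purely on the verbal description (an average of the two embedded differences); the function names are hypothetical, and the exact norm and error weighting used by S-SOLVER may differ.

```python
def dopri8_error(x8, x7):
    # Standard Dopri8 estimate: difference between the
    # 8th and 7th order embedded solutions.
    return abs(x8 - x7)

def s_solver_error(x8, x7, x5, x4):
    # Modified estimate: average of the 8th/7th and 5th/4th order
    # embedded solution differences (scalar case; for vector states,
    # apply a norm to each difference instead of abs).
    return 0.5 * (abs(x8 - x7) + abs(x5 - x4))

# If the two embedded pairs disagree about the error magnitude,
# the averaged estimate lands between them.
e = s_solver_error(1.00, 0.99, 1.02, 0.98)
```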

2.1. NEURAL ORDINARY DIFFERENTIAL EQUATIONS

Traditional neural networks are discrete models with a discrete sequence of hidden layers, where the depth of the network corresponds to the number of layers. Neural ODEs (Chen et al., 2018) are continuous-depth deep learning models that parameterize the derivative of the hidden state with a neural network. Specifically, they are ODEs parameterized by a neural network, which brings benefits such as memory efficiency, adaptive computation, and parameter efficiency. Neural ODEs are inspired by the dynamical systems interpretation of residual and other networks (Haber et al., 2018; Weinan, 2017). These networks perform a sequence of transformations to a hidden state, state_{t+1} = state_t + f(state_t, θ_t), which can be viewed as a discretized forward Euler method applied to a continuous transformation. Given this interpretation, the transformation of the hidden state can be formulated as an ODE, d state(t)/dt = f(state(t), t, θ), where state(t = 0) is the input layer and state(t = T) is the output layer. Therefore, the neural ODE is an IVP: dx(t)/dt = f(t, x(t), θ) for t_0 ≤ t ≤ t_1, subject to x(t_0) = x_{t_0}, where f(·, ·, θ) is the deep neural network, x_{t_0} is the input, and x_{t_1} is the output. Neural ODEs are trainable through loss minimization, but due to their continuous nature the optimization process differs slightly from that of classical discrete deep learning models. The forward pass solves the ODE with an ODE solver, and the backward pass computes the gradients either by backpropagating through the ODE solver or with the adjoint method (Chen et al., 2018). In this work we focus on the forward pass, which outputs a solution to the ODE.
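The residual-network/forward-Euler correspondence above can be made concrete with a small sketch. The "network" here is a hypothetical fixed map f(x) = -x standing in for a learned f(x, t, θ); as the number of layers grows (with step size h = T / n_layers), the discrete residual updates approach the continuous ODE solution x(T) = x(0) · exp(-T).

```python
import math

def f(state):
    # Hypothetical fixed dynamics standing in for a learned network.
    return -state

def residual_forward(x0, n_layers, h):
    # Discrete residual-style updates: state_{t+1} = state_t + h * f(state_t),
    # i.e. forward Euler applied to dx/dt = f(x).
    x = x0
    for _ in range(n_layers):
        x = x + h * f(x)
    return x

T, x0 = 1.0, 1.0
deep_output = residual_forward(x0, 1000, T / 1000)  # 1000 "layers"
shallow_output = residual_forward(x0, 10, T / 10)   # 10 "layers"
exact = x0 * math.exp(-T)                           # continuous solution
```

The deeper (finer-stepped) network lands closer to the continuous solution, which is exactly the limit that motivates replacing the layer stack with an ODE solver.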

