"HEY, THAT'S NOT AN ODE": FASTER ODE ADJOINTS WITH 12 LINES OF CODE

Abstract

Neural differential equations may be trained by backpropagating gradients via the adjoint method, which is another differential equation typically solved using an adaptive-step-size numerical differential equation solver. A proposed step is accepted if its error, relative to some norm, is sufficiently small; else it is rejected, the step is shrunk, and the process is repeated. Here, we demonstrate that the particular structure of the adjoint equations makes the usual choices of norm (such as L2) unnecessarily stringent. By replacing it with a more appropriate (semi)norm, fewer steps are unnecessarily rejected and the backpropagation is made faster. This requires only minor code modifications. Experiments on a wide range of tasks, including time series, generative modeling, and physical control, demonstrate a median improvement of 40% fewer function evaluations. On some problems we see as much as 62% fewer function evaluations, so that the overall training time is roughly halved.
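The accept/reject mechanism described above can be sketched in a few lines. This is an illustrative toy, not the paper's implementation: the RMS norm, the `mask`-based seminorm, and the tolerance value are all assumptions modeled on how common adaptive solvers measure step error.

```python
import math

def rms_norm(x):
    # Conventional RMS norm over every component of the error estimate.
    return math.sqrt(sum(v * v for v in x) / len(x))

def make_seminorm(mask):
    # A seminorm that simply ignores components flagged False in `mask`
    # (e.g. channels of the adjoint system whose accuracy is unimportant).
    def seminorm(x):
        kept = [v for v, keep in zip(x, mask) if keep]
        return math.sqrt(sum(v * v for v in kept) / len(kept))
    return seminorm

def accept_step(error_estimate, norm, tol=1e-3):
    # An adaptive solver accepts a proposed step when the error,
    # measured in the chosen (semi)norm, falls below the tolerance.
    return norm(error_estimate) < tol

# Small error in the state channels, large error in the last two channels:
error = [1e-4, 1e-4, 5e-3, 5e-3]
print(accept_step(error, rms_norm))  # → False: rejected under the full norm
print(accept_step(error, make_seminorm([True, True, False, False])))  # → True
```

Under the full norm the large components force a rejection and a smaller step; under the seminorm the same proposed step is accepted, which is the source of the speed-up.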

1. INTRODUCTION

We begin by recalling the usual set-up for neural differential equations.

1.1. NEURAL ORDINARY DIFFERENTIAL EQUATIONS

The general approach of neural ordinary differential equations (E, 2017; Chen et al., 2018) is to use ODEs as a learnable component of a differentiable framework. Typically the goal is to approximate a map x → y by learning functions ℓ1(·, φ), ℓ2(·, ψ) and f(·, ·, θ), which are composed such that

z(τ) = ℓ1(x, φ),    z(t) = z(τ) + ∫_τ^t f(s, z(s), θ) ds,    y ≈ ℓ2(z(T), ψ).

The variables φ, θ, ψ denote learnable parameters and the ODE is solved over the interval [τ, T]. We include the (often linear) maps ℓ1(·, φ) and ℓ2(·, ψ) for generality, as in many contexts they are important for the expressiveness of the model (Dupont et al., 2019; Zhang et al., 2020), though our contributions are focused on the ODE component and do not depend on these maps. Here we will consider neural differential equations that may be interpreted as a neural ODE.
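The composition above, input map, ODE solve, output map, can be sketched with a fixed-step Euler solver. Everything here is an illustrative stand-in: the scalar "parameters" `phi`, `theta`, `psi` and the simple linear maps are assumptions chosen so the sketch stays self-contained, not any particular trained model.

```python
def euler_odeint(f, z0, t0, t1, steps=100):
    # Fixed-step Euler integration of dz/dt = f(t, z) from t0 to t1.
    z, t = list(z0), t0
    h = (t1 - t0) / steps
    for _ in range(steps):
        dz = f(t, z)
        z = [zi + h * dzi for zi, dzi in zip(z, dz)]
        t += h
    return z

def l1(x, phi):       # ℓ1: lift the input into the ODE's state space
    return [phi * xi for xi in x]

def f(t, z, theta):   # f: the (here, linear) learned vector field
    return [-theta * zi for zi in z]

def l2(z, psi):       # ℓ2: read the output off the final state
    return psi * sum(z)

x = [1.0, 2.0]
z_tau = l1(x, phi=0.5)                                       # z(τ) = ℓ1(x, φ)
z_T = euler_odeint(lambda t, z: f(t, z, theta=1.0),
                   z_tau, t0=0.0, t1=1.0)                    # solve the ODE
y = l2(z_T, psi=1.0)                                         # y ≈ ℓ2(z(T), ψ)
```

In practice the Euler loop is replaced by an adaptive-step-size solver, and it is that solver's accept/reject logic during the adjoint (backward) pass that this paper modifies.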

1.2. APPLICATIONS

Neural differential equations have, to the best of our knowledge, three main applications:

1. Time series modeling. Rubanova et al. (2019) interleave neural ODEs with RNNs to produce ODEs with jumps. Kidger et al. (2020) take f(t, z, θ) = g(z, θ) dX/dt (t), dependent on some time-varying input X, to produce a neural controlled differential equation.

2. Continuous normalising flows, as in Chen et al. (2018); Grathwohl et al. (2019), in which the overall model acts as a coupling or transformation between probability distributions.

3. Modeling or controlling physical environments, for which a differential equation based model may be explicitly desired; see for example Zhong et al. (2020).
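The controlled form f(t, z, θ) = g(z, θ) dX/dt (t) in the first application can be illustrated as follows. The helper names `make_cde_field` and `path_derivative`, the finite-difference derivative, and the toy path X(t) = sin(t) are all assumptions for the sketch; in a real neural CDE, g would be a learned network and X an interpolation of observed data.

```python
import math

def path_derivative(X, t, eps=1e-5):
    # Central finite-difference estimate of dX/dt for a path X: t -> R^d.
    return [(a - b) / (2 * eps) for a, b in zip(X(t + eps), X(t - eps))]

def make_cde_field(g, X):
    # Build f(t, z) = g(z) dX/dt: the input path's rate of change
    # steers the dynamics through the (state-dependent) matrix g(z).
    def f(t, z):
        dX = path_derivative(X, t)
        return [sum(gij * dXj for gij, dXj in zip(row, dX)) for row in g(z)]
    return f

X = lambda t: [math.sin(t)]      # toy 1-d input path
g = lambda z: [[0.5 * z[0]]]     # toy 1x1 "matrix" depending on the state
f = make_cde_field(g, X)
# At t = 0, dX/dt = cos(0) = 1, so f(0, [2.0]) is approximately [1.0].
```

The key point is that once g and X are fixed, f is an ordinary vector field, so the model can be trained and solved exactly like a neural ODE.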

