"HEY, THAT'S NOT AN ODE": FASTER ODE ADJOINTS WITH 12 LINES OF CODE

Abstract

Neural differential equations may be trained by backpropagating gradients via the adjoint method, which is itself another differential equation, typically solved using an adaptive-step-size numerical differential equation solver. A proposed step is accepted if its error, relative to some norm, is sufficiently small; otherwise it is rejected, the step size is shrunk, and the process is repeated. Here, we demonstrate that the particular structure of the adjoint equations makes the usual choices of norm (such as $L^2$) unnecessarily stringent. By replacing the norm with a more appropriate (semi)norm, fewer steps are unnecessarily rejected and the backpropagation is made faster. This requires only minor code modifications. Experiments on a wide range of tasks, including time series, generative modeling, and physical control, demonstrate a median improvement of 40% fewer function evaluations. On some problems we see as much as 62% fewer function evaluations, so that the overall training time is roughly halved.

1. INTRODUCTION

We begin by recalling the usual set-up for neural differential equations.

1.1. NEURAL ORDINARY DIFFERENTIAL EQUATIONS

The general approach of neural ordinary differential equations (E, 2017; Chen et al., 2018) is to use ODEs as a learnable component of a differentiable framework. Typically the goal is to approximate a map $x \mapsto y$ by learning functions $\ell_1(\,\cdot\,, \phi)$, $\ell_2(\,\cdot\,, \psi)$ and $f(\,\cdot\,, \,\cdot\,, \theta)$, which are composed such that

$$z(\tau) = \ell_1(x, \phi), \qquad z(t) = z(\tau) + \int_\tau^t f(s, z(s), \theta)\, ds, \qquad y \approx \ell_2(z(T), \psi). \tag{1}$$

The variables $\phi, \theta, \psi$ denote learnable parameters and the ODE is solved over the interval $[\tau, T]$. We include the (often linear) maps $\ell_1(\,\cdot\,, \phi)$, $\ell_2(\,\cdot\,, \psi)$ for generality, as in many contexts they are important for the expressiveness of the model (Dupont et al., 2019; Zhang et al., 2020), though our contributions will be focused on the ODE component and will not depend on these maps. Here we will consider neural differential equations that may be interpreted as a neural ODE.
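The composition in equation (1) can be illustrated with a minimal, dependency-free sketch: a scalar state, linear maps for $\ell_1$ and $\ell_2$, and a hand-rolled fixed-step RK4 solver standing in for the adaptive solvers discussed later (e.g. torchdiffeq's dopri5). The vector field $f(s, z, \theta) = \theta z$ and all parameter values here are illustrative assumptions, chosen so the model has a closed form to check against.

```python
import math

def rk4_solve(f, z, tau, T, n=100):
    """Integrate dz/dt = f(t, z) from tau to T with n fixed RK4 steps."""
    h = (T - tau) / n
    t = tau
    for _ in range(n):
        k1 = f(t, z)
        k2 = f(t + h / 2, z + h / 2 * k1)
        k3 = f(t + h / 2, z + h / 2 * k2)
        k4 = f(t + h, z + h * k3)
        z += h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    return z

def model(x, w1=2.0, theta=0.5, w2=3.0, tau=0.0, T=1.0):
    z_tau = w1 * x                                          # z(tau) = l1(x, phi)
    z_T = rk4_solve(lambda t, z: theta * z, z_tau, tau, T)  # the ODE in (1)
    return w2 * z_T                                         # y ~= l2(z(T), psi)
```

For the linear field used here the exact answer is $y = w_2\, w_1\, x\, e^{\theta (T - \tau)}$, which the numerical solve reproduces to high accuracy.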

1.2. APPLICATIONS

Neural differential equations have, to the best of our knowledge, three main applications: 1. Time series modeling. Rubanova et al. (2019) interleave Neural ODEs with RNNs to produce ODEs with jumps. Kidger et al. (2020) take $f(t, z, \theta) = g(z, \theta) \frac{dX}{dt}(t)$, dependent on some time-varying input $X$, to produce a neural controlled differential equation. 2. Generative modeling, as in continuous normalizing flows (Chen et al., 2018; Grathwohl et al., 2019), in which the overall model acts as a coupling or transformation between probability distributions. 3. Modeling or controlling physical environments, for which a differential-equation-based model may be explicitly desired; see for example Zhong et al. (2020).
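The controlled-differential-equation form $f(t, z, \theta) = g(z, \theta)\, dX/dt(t)$ mentioned in application 1 reduces to an ordinary ODE once the control path $X$ is fixed, so it can be solved with the same machinery. A dependency-free sketch, with the illustrative (not the paper's) choices $g(z) = z$ and $X(t) = \sin t$, for which the solution has a closed form:

```python
import math

def rk4_solve(f, z, tau, T, n=200):
    """Integrate dz/dt = f(t, z) from tau to T with n fixed RK4 steps."""
    h = (T - tau) / n
    t = tau
    for _ in range(n):
        k1 = f(t, z)
        k2 = f(t + h / 2, z + h / 2 * k1)
        k3 = f(t + h / 2, z + h / 2 * k2)
        k4 = f(t + h, z + h * k3)
        z += h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    return z

def cde_field(t, z):
    dXdt = math.cos(t)   # dX/dt for the control path X(t) = sin(t)
    return z * dXdt      # g(z) * dX/dt, with g(z) = z for illustration
```

For this choice the exact solution is $z(T) = z(\tau)\, e^{\sin T - \sin \tau}$; in a neural CDE, $g$ would be a learned network and $X$ an interpolation of observed data.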

1.3. ADJOINT EQUATIONS

The integral in equation (1) may be backpropagated through either by backpropagating through the internal operations of a numerical solver, or by solving the backwards-in-time adjoint equations with respect to some (scalar) loss $L$:

$$a_z(T) = \frac{dL}{dz(T)}, \quad a_z(t) = a_z(T) - \int_T^t a_z(s) \cdot \frac{\partial f}{\partial z}(s, z(s), \theta)\, ds \quad \text{and} \quad \frac{dL}{dz(\tau)} = a_z(\tau),$$
$$a_\theta(T) = 0, \quad a_\theta(t) = a_\theta(T) - \int_T^t a_z(s) \cdot \frac{\partial f}{\partial \theta}(s, z(s), \theta)\, ds \quad \text{and} \quad \frac{dL}{d\theta} = a_\theta(\tau),$$
$$a_t(T) = \frac{dL}{dT}, \quad a_t(t) = a_t(T) - \int_T^t a_z(s) \cdot \frac{\partial f}{\partial s}(s, z(s), \theta)\, ds \quad \text{and} \quad \frac{dL}{d\tau} = a_t(\tau). \tag{2}$$

These equations are typically solved together as a joint system $a(t) = [a_z(t), a_\theta(t), a_t(t)]$. (They are already coupled; the latter two equations depend on $a_z$.) As their integrands additionally require $z(s)$, and as the results of the forward computation of equation (1) are usually not stored, the adjoint equations are typically additionally augmented by recovering $z$ by solving backwards in time

$$z(t) = z(T) + \int_T^t f(s, z(s), \theta)\, ds. \tag{3}$$
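The backwards solve of equations (2) and (3) can be made concrete with a dependency-free sketch: for the scalar field $f(s, z, \theta) = \theta z$ and the illustrative loss $L = z(T)$, the augmented state $[z, a_z, a_\theta]$ is integrated backwards from $T$ to $\tau = 0$, and both resulting gradients have closed forms to check against. The field, loss, and step count are assumptions for illustration, not the paper's experimental setup.

```python
import math

def rk4_step(f, t, u, h):
    """One RK4 step for a list-valued state u (h may be negative)."""
    k1 = f(t, u)
    k2 = f(t + h / 2, [ui + h / 2 * ki for ui, ki in zip(u, k1)])
    k3 = f(t + h / 2, [ui + h / 2 * ki for ui, ki in zip(u, k2)])
    k4 = f(t + h, [ui + h * ki for ui, ki in zip(u, k3)])
    return [ui + h / 6 * (a + 2 * b + 2 * c + d)
            for ui, a, b, c, d in zip(u, k1, k2, k3, k4)]

theta, z0, T = 0.5, 1.2, 1.0

def aug(t, u):
    # d/dt [z, a_z, a_theta] for f(s, z, theta) = theta * z:
    # dz/dt = theta*z, da_z/dt = -a_z * df/dz, da_theta/dt = -a_z * df/dtheta
    z, a_z, a_th = u
    return [theta * z, -a_z * theta, -a_z * z]

# Terminal conditions: z(T) (known in closed form here), a_z(T) = dL/dz(T) = 1
# for L = z(T), and a_theta(T) = 0, per equations (2).
u = [z0 * math.exp(theta * T), 1.0, 0.0]
n = 200
h = -T / n   # negative step: integrate backwards in time
t = T
for _ in range(n):
    u = rk4_step(aug, t, u, h)
    t += h
dL_dz_tau, dL_dtheta = u[1], u[2]
```

The analytic values are $dL/dz(\tau) = e^{\theta T}$ and $dL/d\theta = z_0 T e^{\theta T}$, which the backwards solve recovers.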

1.4. CONTRIBUTIONS

We demonstrate that the particular structure of the adjoint equations means that numerical equation solvers will typically take too many steps, each too small, wasting time during backpropagation. Specifically, the accept/reject step of adaptive-step-size solvers is too stringent. By applying a correction to account for this, we demonstrate that the number of steps needed to solve the adjoint equations may be reduced by typically about 40%; on some problems we observe improvements of as much as 62%. Factoring in the forward pass (which is unchanged), the overall training time is roughly halved. Our method is hyperparameter-free and requires no tuning. We do not observe any change in model performance, and at least with the torchdiffeq package (our chosen differential equation package), this correction may be applied with only 12 lines of code.
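The essence of the correction can be sketched in plain Python: the norm used for step acceptance is replaced by a seminorm that ignores the parameter-adjoint channels $a_\theta$ of the augmented state, since (as noted above) nothing in the system depends on $a_\theta$, so local error there never feeds back into the rest of the solve. The state layout and helper names below are illustrative assumptions, not the paper's actual implementation.

```python
import math

def rms_norm(xs):
    """The usual root-mean-square (scaled L2) norm over all channels."""
    return math.sqrt(sum(x * x for x in xs) / len(xs))

def make_seminorm(n_state):
    """Build a seminorm for an augmented state laid out as
    [z (n_state channels), a_z (n_state channels), a_theta (the rest)].
    Only z and a_z control step acceptance; errors in a_theta are ignored
    because a_theta never appears on the right-hand side of the system."""
    def seminorm(xs):
        return rms_norm(xs[: 2 * n_state])
    return seminorm
```

A state with a large spurious error estimate in its $a_\theta$ channel, e.g. `[1.0, 1.0, 100.0]` with `n_state=1`, is measured identically to `[1.0, 1.0]` under the seminorm, so the step is not rejected on account of $a_\theta$. In recent versions of torchdiffeq the same effect is exposed by passing `adjoint_options=dict(norm="seminorm")` to `odeint_adjoint`.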

2. METHOD

2.1. NUMERICAL SOLVERS

Both the forward pass given by equation (1), and the backward pass given by equations (2) and (3), are solved by invoking a numerical differential equation solver. Our interest here is in adaptive-step-size solvers. Indeed, the default choice for solving many equations is the adaptive-step-size Runge-Kutta 5(4) scheme of Dormand-Prince (Dormand & Prince, 1980), for example as implemented by dopri5 in the torchdiffeq package or ode45 in MATLAB. A full discussion of the internal operations of these solvers is beyond our scope here; the part of interest to us is the accept/reject scheme. Consider the case of solving the general ODE

$$y(t) = y(\tau) + \int_\tau^t f(s, y(s))\, ds, \quad \text{with } y(t) \in \mathbb{R}^d.$$

Suppose for some fixed $t$ the solver has computed some estimate $\hat{y}(t) \approx y(t)$, and it now seeks to take a step $\Delta > 0$ to compute $\hat{y}(t + \Delta) \approx y(t + \Delta)$. A step is made, and some candidate $\hat{y}_{\text{candidate}}(t + \Delta)$ is generated. The solver additionally produces $y_{\text{err}} \in \mathbb{R}^d$, representing an estimate of the numerical error made in each channel during that step.
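A sketch of the accept/reject logic just described, in the common form used by embedded Runge-Kutta methods: each channel's error estimate is scaled by a mixed absolute/relative tolerance, reduced with a norm, and the step is accepted when the resulting ratio is at most 1. The function name, the clipping constants, and the simple $\text{ratio}^{-1/\text{order}}$ controller are illustrative assumptions; production solvers such as dopri5 differ in details (e.g. PI step-size control).

```python
import math

def rms_norm(xs):
    return math.sqrt(sum(x * x for x in xs) / len(xs))

def accept_and_rescale(y0, y1, y_err, dt, rtol=1e-7, atol=1e-9, order=5,
                       norm=rms_norm):
    """Decide whether a candidate step from y0 to y1 with error estimate
    y_err is accepted, and propose the next step size."""
    # Scale each channel's error by its tolerance before taking the norm.
    scaled = [e / (atol + rtol * max(abs(a), abs(b)))
              for e, a, b in zip(y_err, y0, y1)]
    ratio = norm(scaled)
    accept = ratio <= 1.0
    # Grow/shrink the step as ratio^(-1/order) with a 0.9 safety factor,
    # clipped to the range [0.2x, 10x].
    factor = 10.0 if ratio == 0.0 else 0.9 * ratio ** (-1.0 / order)
    return accept, dt * min(10.0, max(0.2, factor))
```

Note that `norm` is a parameter: substituting a seminorm here, so that some channels simply do not contribute to `ratio`, is exactly the kind of one-line change on which this paper's correction relies.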




