MALI: A MEMORY EFFICIENT AND REVERSE ACCURATE INTEGRATOR FOR NEURAL ODES

Abstract

Neural ordinary differential equations (Neural ODEs) are a new family of deep-learning models with continuous depth. However, the numerical estimation of the gradient in the continuous case is not well solved: existing implementations of the adjoint method suffer from inaccuracy in the reverse-time trajectory, while the naive method and the adaptive checkpoint adjoint method (ACA) have a memory cost that grows with integration time. In this project, based on the asynchronous leapfrog (ALF) solver, we propose the Memory-efficient ALF Integrator (MALI), which has a constant memory cost w.r.t. the number of solver steps in integration, similar to the adjoint method, and guarantees accuracy in the reverse-time trajectory (hence accuracy in gradient estimation). We validate MALI in various tasks: on image recognition tasks, to our knowledge, MALI is the first to enable feasible training of a Neural ODE on ImageNet and outperform a well-tuned ResNet, while existing methods fail due to either heavy memory burden or inaccuracy; for time series modeling, MALI significantly outperforms the adjoint method; and for continuous generative models, MALI achieves new state-of-the-art performance. We provide a pypi package: https://jzkay12.github.io/TorchDiffEqPack

1. INTRODUCTION

Recent research builds the connection between continuous models and neural networks. The theory of dynamical systems has been applied to analyze the properties of neural networks or guide the design of networks (Weinan, 2017; Ruthotto & Haber, 2019; Lu et al., 2018). In these works, a residual block (He et al., 2016) is typically viewed as a one-step Euler discretization of an ODE; instead of directly analyzing the discretized neural network, it might be easier to analyze the ODE. Another direction is the neural ordinary differential equation (Neural ODE) (Chen et al., 2018), which takes a continuous depth instead of discretized depth. The dynamics of a Neural ODE are typically approximated by numerical integration with adaptive ODE solvers. Neural ODEs have been applied in irregularly sampled time-series (Rubanova et al., 2019), free-form continuous generative models (Grathwohl et al., 2018; Finlay et al., 2020), mean-field games (Ruthotto et al., 2020), stochastic differential equations (Li et al., 2020) and physically informed modeling (Sanchez-Gonzalez et al., 2019; Zhong et al., 2019).

Though the Neural ODE has been widely applied in practice, how to train it is not extensively studied. The naive method directly backpropagates through an ODE solver, but tracking a continuous trajectory requires a huge amount of memory. Chen et al. (2018) proposed to use the adjoint method to determine the gradient in continuous cases, which achieves constant memory cost w.r.t. integration time; however, as pointed out by Zhuang et al. (2020), the adjoint method suffers from numerical errors due to the inaccuracy in the reverse-time trajectory. Zhuang et al. (2020) proposed the adaptive checkpoint adjoint (ACA) method to achieve accuracy in gradient estimation at a much smaller memory cost compared to the naive method, yet the memory consumption of ACA still grows linearly with integration time.
Due to the non-constant memory cost, neither ACA nor the naive method is suitable for large-scale datasets (e.g. ImageNet) or high-dimensional Neural ODEs (e.g. FFJORD (Grathwohl et al., 2018)). In this project, we propose the Memory-efficient Asynchronous Leapfrog Integrator (MALI) to achieve the advantages of both the adjoint method and ACA: constant memory cost w.r.t. integration time and accuracy in the reverse-time trajectory. MALI is based on the asynchronous leapfrog (ALF) integrator (Mutze, 2013). With the ALF integrator, each numerical step forward in time is reversible. Therefore, with MALI, we delete the trajectory and only keep the end-time states, hence achieving constant memory cost w.r.t. integration time; using the reversibility, we can accurately reconstruct the trajectory from the end-time value, hence achieving accuracy in the gradient.

Our contributions are:

1. We propose a new method (MALI) to solve Neural ODEs, which achieves constant memory cost w.r.t. the number of solver steps in integration and accuracy in gradient estimation. We provide theoretical analysis.
2. We validate our method with extensive experiments:
   (a) For image classification tasks, MALI enables a Neural ODE to achieve better accuracy than a well-tuned ResNet with the same number of parameters; to our knowledge, MALI is the first method to enable training of Neural ODEs on a large-scale dataset such as ImageNet, while existing methods fail due to either heavy memory burden or inaccuracy.
   (b) In time-series modeling, MALI achieves comparable or better results than other methods.
   (c) For generative modeling, a FFJORD model trained with MALI achieves new state-of-the-art results on MNIST and Cifar10.

2. PRELIMINARIES

2.1. NUMERICAL INTEGRATION METHODS

An ordinary differential equation (ODE) typically takes the form

$$\frac{dz(t)}{dt} = f_\theta(t, z(t)) \quad \text{s.t.} \quad z(t_0) = x,\ t \in [t_0, T], \qquad \text{Loss} = L(z(T), y) \tag{1}$$

where $z(t)$ is the hidden state evolving with time, $T$ is the end time, $t_0$ is the start time (typically 0), and $x$ is the initial state. The derivative of $z(t)$ w.r.t. $t$ is defined by a function $f$, and $f$ is defined as a sequence of layers parameterized by $\theta$. The loss function is $L(z(T), y)$, where $y$ is the target variable. Eq. 1 is called the initial value problem (IVP) because only $z(t_0)$ is specified.

Algorithm 1: Numerical Integration

    Input: initial state x, start time t_0, end time T, error tolerance etol, initial stepsize h.
    Initialize z(t_0) = x, t = t_0
    While t < T
        error_est = ∞
        While error_est > etol
            h ← h × DecayFactor
            ẑ, error_est = ψ_h(t, z)
        If error_est < etol
            h ← h × IncreaseFactor
        t ← t + h, z ← ẑ

Notations

We summarize the notations following Zhuang et al. (2020).

• $z_i(t_i)$ / $\overline{z}(\tau_i)$: hidden state in the forward / reverse time trajectory at time $t_i$ / $\tau_i$.
• $\psi_h(t_i, z_i)$: the numerical solution at time $t_i + h$, starting from $(t_i, z_i)$ with a stepsize $h$.
• $N_f$, $N_z$: $N_f$ is the number of layers in $f$ in Eq. 1, $N_z$ is the dimension of $z$.
• $N_t$ / $N_r$: number of discretized points (outer iterations in Algo. 1) in forward / reverse integration.
• $m$: average number of inner iterations in Algo. 1 to find an acceptable stepsize.

Numerical Integration

The algorithm for general adaptive-stepsize numerical ODE solvers is summarized in Algo. 1 (Wanner & Hairer, 1996). The solver repeatedly advances in time by a step, which is the outer loop in Algo. 1 (blue curve in Fig. 1). For each step, the solver decreases the stepsize until the estimate of error is lower than the tolerance, which is the inner loop in Algo. 1 (green curve in Fig. 1). For fixed-stepsize solvers, the inner loop is replaced with a single evaluation of $\psi_h(t, z)$ using a predefined stepsize $h$. Different methods typically use different $\psi$, for example different orders of the Runge-Kutta method (Runge, 1895).
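The two nested loops of Algo. 1 can be sketched in a few lines of Python. This is a minimal illustration and not the paper's implementation: the Heun-Euler pair used as ψ, the function names, the default decay/increase factors and the choice to try the current stepsize before shrinking it are all assumptions made for the example.

```python
import math

def heun_euler_step(f, t, z, h):
    """One Heun-Euler step (an assumed choice of psi): 2nd-order estimate plus error estimate."""
    k1 = f(t, z)
    k2 = f(t + h, z + h * k1)
    z_euler = z + h * k1                  # 1st-order solution
    z_heun = z + 0.5 * h * (k1 + k2)      # 2nd-order solution
    return z_heun, abs(z_heun - z_euler)  # difference serves as the error estimate

def integrate_adaptive(f, x, t0, T, etol=1e-6, h=0.1, decay=0.5, increase=1.5):
    """Adaptive-stepsize integration of dz/dt = f(t, z) from t0 to T (scalar state)."""
    z, t = x, t0
    n_outer, n_inner = 0, 0
    while t < T:                          # outer loop: advance in time
        h = min(h, T - t)                 # do not step past the end time
        while True:                       # inner loop: search for an acceptable stepsize
            z_hat, err = heun_euler_step(f, t, z, h)
            n_inner += 1
            if err <= etol:
                break
            h *= decay
        t, z = t + h, z_hat               # accept the step
        h *= increase                     # try a larger stepsize next time
        n_outer += 1
    return z, n_outer, n_inner

z_end, n_t, n_evals = integrate_adaptive(lambda t, z: -z, 1.0, 0.0, 1.0)
print(z_end, math.exp(-1.0), n_t, n_evals / n_t)
```

In the notation of this section, `n_t` plays the role of $N_t$ and `n_evals / n_t` the role of $m$.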

2.2. ANALYTICAL FORM OF GRADIENT IN CONTINUOUS CASE

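We first briefly introduce the analytical form of the gradient in the continuous case, then we compare different numerical implementations in the literature to estimate the gradient.

Table 1: Comparison of the naive method, the adjoint method, ACA and MALI.

| | Naive | Adjoint | ACA | MALI |
| Computation | $N_z N_f \times N_t \times m \times 2$ | $N_z N_f \times (N_t + N_r) \times m$ | $N_z N_f \times N_t \times (m + 1)$ | $N_z N_f \times N_t \times (m + 2)$ |
| Memory | $N_z N_f \times N_t \times m$ | $N_z N_f$ | $N_z (N_f + N_t)$ | $N_z (N_f + 1)$ |
| Computation graph depth | $N_f \times N_t \times m$ | $N_f \times N_r$ | $N_f \times N_t$ | $N_f \times N_t$ |
| Reverse accuracy | ✓ | ✗ | ✓ | ✓ |

[Figure legend: search for optimal stepsize; step forward with optimal stepsize; ground-truth trajectory]

The analytical form of the gradient in the continuous case is

$$\frac{dL}{d\theta} = -\int_T^0 a(t)^\top \frac{\partial f(z(t), t, \theta)}{\partial \theta}\, dt \tag{2}$$

$$\frac{da(t)}{dt} + \left(\frac{\partial f(z(t), t, \theta)}{\partial z(t)}\right)^{\top} a(t) = 0 \quad \forall t \in (0, T), \qquad a(T) = \frac{\partial L}{\partial z(T)} \tag{3}$$

where $a(t)$ is the "adjoint state". Detailed proof is given in (Pontryagin, 1962). In the next section we compare different numerical implementations of this analytical form.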
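As a sanity check of Eq. 2 and Eq. 3, the adjoint equations can be solved by hand for a scalar linear ODE. The setup below ($f = \alpha z$, loss $L = z(T)^2$, start time 0) is chosen here purely for illustration; every quantity follows from Eq. 2 and Eq. 3 alone.

```latex
% Setup (illustrative): dz/dt = \alpha z, z(0) = z_0, L = z(T)^2, so z(t) = z_0 e^{\alpha t}.
\begin{align*}
a(T) &= \frac{\partial L}{\partial z(T)} = 2 z(T), \qquad
\frac{da}{dt} = -\alpha\, a
\;\Rightarrow\; a(t) = 2 z(T)\, e^{\alpha (T - t)} \\
\frac{dL}{d\alpha}
  &= -\int_T^0 a(t)\, \frac{\partial f}{\partial \alpha}\, dt
   = \int_0^T 2 z(T)\, e^{\alpha (T - t)}\, z_0 e^{\alpha t}\, dt
   = 2 T z_0^2 e^{2 \alpha T} \\
\frac{dL}{dz_0} &= a(0) = 2 z(T)\, e^{\alpha T} = 2 z_0 e^{2 \alpha T}
\end{align*}
```

Both results agree with differentiating $L = z_0^2 e^{2\alpha T}$ directly, which is what a numerical implementation of Eq. 2 and Eq. 3 should reproduce up to discretization error.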

2.3. NUMERICAL IMPLEMENTATIONS IN THE LITERATURE FOR THE ANALYTICAL FORM

We compare different numerical implementations of the analytical form in this section. The forward-pass and backward-pass of different methods are demonstrated in Fig. 1 and Fig. 2 respectively. The forward-pass is similar for different methods. The comparison of the backward-pass among different methods is summarized in Table 1. We explain methods in the literature below.

Naive method

The naive method saves all of the computation graph (including the search for optimal stepsize, green curve in Fig. 2) in memory, and backpropagates through it. Hence the memory cost is $N_z N_f \times N_t \times m$ and the depth of the computation graph is $N_f \times N_t \times m$, and the computation is doubled considering both forward and backward passes. Besides the large memory and computation, the deep computation graph might cause vanishing or exploding gradients (Pascanu et al., 2013).

Adjoint method

Note that we use "adjoint state equation" to refer to the analytical form in Eq. 2 and 3, while we use "adjoint method" to refer to the numerical implementation by Chen et al. (2018). As in Fig. 1 and 2, the adjoint method forgets the forward-time trajectory (blue curve) to achieve memory cost $N_z N_f$, which is constant w.r.t. integration time; it takes the end-time state (derived from forward-time integration) as the initial state, and solves a separate IVP (red curve) in reverse-time.

Theorem 2.1. (Zhuang et al., 2020) For an ODE solver of order $p$, the error of the reconstructed initial value by the adjoint method is

$$\sum_{k=0}^{N-1} \left[ h_k^{p+1}\, D\Phi_{t_k}^{T}(z_k)\, l(t_k, z_k) + (-h_k)^{p+1}\, D\Phi_{T}^{t_k}(\overline{z}_k)\, \overline{l}(t_k, \overline{z}_k) \right] + O(h^{p+1}),$$

where $\Phi$ is the ideal solution, $D\Phi$ is the Jacobian of $\Phi$, and $l(t, z)$ and $\overline{l}(t, \overline{z})$ are the local errors in forward-time and reverse-time integration respectively.

Theorem 2.1 is stated as Theorem 3.2 in Zhuang et al. (2020); please see the reference paper for a detailed proof.
To summarize, due to inevitable errors with numerical ODE solvers, the reverse-time trajectory (red curve, $\overline{z}(\tau)$) cannot match the forward-time trajectory (blue curve, $z(t)$) accurately. The error in $z$ propagates to $\frac{dL}{d\theta}$ by Eq. 2, hence affects the accuracy of gradient estimation.

Adaptive checkpoint adjoint (ACA)

To solve the inaccuracy of the adjoint method, Zhuang et al. (2020) proposed ACA: ACA stores the forward-time trajectory in memory for the backward-pass, hence guarantees accuracy; ACA deletes the search process (green curve in Fig. 2), and only backpropagates through the accepted step (blue curve in Fig. 2), hence has a shallower computation graph ($N_f \times N_t$ for ACA vs. $N_f \times N_t \times m$ for the naive method). ACA only stores $\{z(t_i)\}_{i=1}^{N_t}$, and deletes the computation graph for $\{f(z(t_i), t_i)\}_{i=1}^{N_t}$, hence the memory cost is $N_z (N_f + N_t)$. Though the memory cost is much smaller than that of the naive method, it grows linearly with $N_t$, and ACA cannot handle very high-dimensional models. In the following sections, we propose a method to overcome all these disadvantages of existing methods.
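The memory expressions quoted above for the three existing methods can be compared directly. The helper below is only a bookkeeping sketch: the function name and the example sizes are hypothetical, and the returned numbers are in the same abstract units as the formulas (not bytes).

```python
def memory_cost(method, n_z, n_f, n_t, m):
    """Memory cost of the backward pass, following the formulas quoted in Sec. 2.3."""
    if method == "naive":
        return n_z * n_f * n_t * m       # full graph, including the stepsize search
    if method == "adjoint":
        return n_z * n_f                 # constant w.r.t. the number of steps
    if method == "aca":
        return n_z * (n_f + n_t)         # one checkpoint per accepted step
    raise ValueError(f"unknown method: {method}")

# Hypothetical sizes: state dimension 1000, 4 layers in f, 3 inner iterations on average.
for n_t in (10, 100):
    costs = {name: memory_cost(name, 1000, 4, n_t, 3) for name in ("naive", "adjoint", "aca")}
    print(n_t, costs)
```

With these example sizes, going from 10 to 100 solver steps multiplies the naive cost by ten, adds 90 checkpoints of size $N_z$ to ACA, and leaves the adjoint cost unchanged.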

3.1. ASYNCHRONOUS LEAPFROG INTEGRATOR

In this section we give a brief introduction to the asynchronous leapfrog (ALF) method (Mutze, 2013), and we provide theoretical analysis which is missing in Mutze (2013). For general first-order ODEs in the form of Eq. 1, the tuple $(z, t)$ is sufficient for most ODE solvers to take a step numerically. For ALF, the required tuple is $(z, v, t)$, where $v$ is the "approximated derivative". Most numerical ODE solvers such as the Runge-Kutta method (Runge, 1895) track the state $z$ evolving with time, while ALF tracks the "augmented state" $(z, v)$. We explain the details of ALF below.

Algorithm 2: Forward of ψ in ALF

    Input: (z_in, v_in, s_in, h), where s_in is the current time, z_in and v_in are the corresponding values at time s_in, and h is the stepsize.
    Forward:
        s_1 = s_in + h/2
        k_1 = z_in + v_in × h/2
        u_1 = f(k_1, s_1)
        v_out = v_in + 2(u_1 − v_in)
        z_out = k_1 + v_out × h/2
        s_out = s_1 + h/2
    Output: (z_out, v_out, s_out, h)

Algorithm 3: ψ⁻¹ (Inverse of ψ) in ALF

    Input: (z_out, v_out, s_out, h), where s_out is the current time, z_out and v_out are the corresponding values at time s_out, and h is the stepsize.
    Inverse:
        s_1 = s_out − h/2
        k_1 = z_out − v_out × h/2
        u_1 = f(k_1, s_1)
        v_in = 2u_1 − v_out
        z_in = k_1 − v_in × h/2
        s_in = s_1 − h/2
    Output: (z_in, v_in, s_in, h)

Figure 3: With the ALF method, given any tuple $(z_j, v_j, t_j)$ and discretized time points $\{t_i\}_{i=1}^{N_t}$, we can reconstruct the entire trajectory accurately due to the reversibility of ALF.

Procedure of ALF

Different ODE solvers have different $\psi$ in Algo. 1, hence we only summarize $\psi$ for ALF in Algo. 2. Note that for a complete algorithm of integration for ALF, we need to plug Algo. 2 into Algo. 1. The forward-pass is summarized in Algo. 2. Given stepsize $h$, with input $(z_{in}, v_{in}, s_{in})$, a single step of ALF outputs $(z_{out}, v_{out}, s_{out})$. As in Fig. 3, given $(z_0, v_0, t_0)$, the numerical forward-time integration calls Algo. 2 iteratively:

$$(z_i, v_i, t_i, h_i) = \psi(z_{i-1}, v_{i-1}, t_{i-1}, h_i) \quad \text{s.t.} \quad h_i = t_i - t_{i-1},\ i = 1, 2, \ldots, N_t \tag{4}$$

Invertibility of ALF

An interesting property of ALF is that $\psi$ defines a bijective mapping; therefore, we can reconstruct $(z_{in}, v_{in}, s_{in}, h)$ from $(z_{out}, v_{out}, s_{out}, h)$, as demonstrated in Algo. 3. As in Fig. 3, we can reconstruct the entire trajectory given the state $(z_j, v_j)$ at time $t_j$, and the discretized time points $\{t_0, \ldots, t_{N_t}\}$. For example, given $(z_{N_t}, v_{N_t})$ and $\{t_i\}_{i=0}^{N_t}$, the trajectory for Eq. 4 is reconstructed:

$$(z_{i-1}, v_{i-1}, t_{i-1}, h_i) = \psi^{-1}(z_i, v_i, t_i, h_i) \quad \text{s.t.} \quad h_i = t_i - t_{i-1},\ i = N_t, N_t - 1, \ldots, 1 \tag{5}$$

In the following sections, we will show that the invertibility of ALF is the key to maintaining accuracy at a constant memory cost when training Neural ODEs. Note that "inverse" refers to reconstructing the input from the output without computing the gradient, hence is different from "back-propagation".
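The forward step and its inverse translate almost line by line into Python. The sketch below works on scalars for brevity; the function names, the test dynamics `f` and the step count are assumptions made for illustration. It integrates forward while keeping only the end state, then walks back with the inverse step and recovers the initial augmented state up to floating-point rounding.

```python
import math

def alf_step(f, z, v, s, h):
    """One forward ALF step psi (forward algorithm above)."""
    s1 = s + h / 2
    k1 = z + v * h / 2
    u1 = f(k1, s1)
    v_out = v + 2 * (u1 - v)
    z_out = k1 + v_out * h / 2
    return z_out, v_out, s1 + h / 2

def alf_inverse(f, z_out, v_out, s_out, h):
    """Inverse ALF step psi^{-1} (inverse algorithm above)."""
    s1 = s_out - h / 2
    k1 = z_out - v_out * h / 2
    u1 = f(k1, s1)
    v_in = 2 * u1 - v_out
    z_in = k1 - v_in * h / 2
    return z_in, v_in, s1 - h / 2

f = lambda z, t: -z + math.sin(t)        # illustrative dynamics, not from the paper
z0, t0, h, n_steps = 1.0, 0.0, 0.01, 100
v0 = f(z0, t0)                           # initial approximated derivative

z, v, t = z0, v0, t0
for _ in range(n_steps):                 # forward: only the latest state is kept
    z, v, t = alf_step(f, z, v, t, h)

z_rec, v_rec, t_rec = z, v, t
for _ in range(n_steps):                 # reverse: reconstruct from the end state
    z_rec, v_rec, t_rec = alf_inverse(f, z_rec, v_rec, t_rec, h)

print(abs(z_rec - z0), abs(v_rec - v0))
```

The inverse uses exactly one evaluation of `f` per step, the same as the forward step.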

Initial value

For an initial value problem (IVP) such as Eq. 1, typically $z_0 = z(t_0)$ is given while $v_0$ is undetermined. We can construct $v_0 = f(z(t_0), t_0)$, so the initial augmented state is $(z_0, v_0)$.

Difference from midpoint integrator

The midpoint integrator (Süli & Mayers, 2003) is similar to Algo. 2, except that it recomputes $v_{in} = f(z_{in}, s_{in})$ for every step, while ALF directly uses the input $v_{in}$. Therefore, the midpoint method does not have an explicit form of inverse.

Local truncation error

Theorem 3.1 indicates that the local truncation error of ALF is of order $O(h^3)$; this implies the global error is $O(h^2)$. Detailed proof is in Appendix A.3.

Theorem 3.1. For a single step in ALF with stepsize $h$, the local truncation error of $z$ is $O(h^3)$, and the local truncation error of $v$ is $O(h^2)$.
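A global error of $O(h^2)$ means that halving the stepsize should cut the end-time error by roughly a factor of four. The following sketch checks this empirically with a fixed-stepsize ALF on $dz/dt = z$, $z(0) = 1$, $T = 1$; the test problem, the function name and the step counts are assumptions chosen for illustration.

```python
import math

def alf_integrate(f, z0, t0, T, n_steps):
    """Fixed-stepsize ALF integration; v is initialized as f(z0, t0)."""
    h = (T - t0) / n_steps
    z, v, t = z0, f(z0, t0), t0
    for _ in range(n_steps):
        k1 = z + v * h / 2
        u1 = f(k1, t + h / 2)
        v = 2 * u1 - v
        z = k1 + v * h / 2
        t += h
    return z

f = lambda z, t: z                       # exact solution at T = 1 is e
exact = math.e
err_coarse = abs(alf_integrate(f, 1.0, 0.0, 1.0, 50) - exact)
err_fine = abs(alf_integrate(f, 1.0, 0.0, 1.0, 100) - exact)
observed_order = math.log2(err_coarse / err_fine)
print(err_coarse, err_fine, observed_order)
```

An observed order close to 2 is what second-order convergence predicts for this problem.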

A-Stability

The ALF solver has a limited stability region, but this can be addressed with damping. The damped ALF replaces the update of $v_{out}$ in Algo. 2 with $v_{out} = v_{in} + 2\eta(u_1 - v_{in})$, where $\eta$ is the "damping coefficient" between 0 and 1. We have the following theorem on its numerical stability.

Theorem 3.2. For the damped ALF integrator with stepsize $h$, let $\sigma_i$ be the $i$-th eigenvalue of the Jacobian $\frac{\partial f}{\partial z}$. The solver is A-stable if

$$\left| 1 + \eta(h\sigma_i - 1) \pm \sqrt{\eta\left[2h\sigma_i + \eta(h\sigma_i - 1)^2\right]} \right| < 1, \quad \forall i$$

Proof is in Appendix A.4 and A.5. Theorem 3.2 implies the following: when $\eta = 1$, the damped ALF reduces to ALF, and the stability region is empty; when $0 < \eta < 1$, the stability region is non-empty. However, stability describes the behaviour when $T$ goes to infinity; in practice we always use a bounded $T$ and ALF performs well. The inverse of damped ALF is in Appendix A.5.
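The condition in Theorem 3.2 is easy to evaluate numerically for a single eigenvalue. The sketch below assumes the expression inside the absolute value reads $1 + \eta(h\sigma - 1) \pm \sqrt{\eta[2h\sigma + \eta(h\sigma - 1)^2]}$, which reduces to $h\sigma \pm \sqrt{h^2\sigma^2 + 1}$ at $\eta = 1$; the function names and the sample values of $h\sigma$ and $\eta$ are illustrative.

```python
import cmath

def damped_alf_eigenvalues(h_sigma, eta):
    """The two eigenvalues of one damped ALF step for a single eigenvalue sigma of df/dz."""
    center = 1 + eta * (h_sigma - 1)
    radius = cmath.sqrt(eta * (2 * h_sigma + eta * (h_sigma - 1) ** 2))
    return center + radius, center - radius

def is_stable(h_sigma, eta):
    """True if both eigenvalues lie strictly inside the unit circle."""
    return all(abs(lam) < 1 for lam in damped_alf_eigenvalues(h_sigma, eta))

h_sigma = -0.5                           # a decaying mode, h * sigma < 0
for eta in (1.0, 0.75):
    lam_plus, lam_minus = damped_alf_eigenvalues(h_sigma, eta)
    print(eta, abs(lam_plus), abs(lam_minus), is_stable(h_sigma, eta))
```

For this sample point the undamped solver ($\eta = 1$) has one eigenvalue outside the unit circle, while $\eta = 0.75$ brings both inside, in line with the statement that the stability region is non-empty for $0 < \eta < 1$.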

3.2. MEMORY-EFFICIENT ALF INTEGRATOR (MALI) FOR GRADIENT ESTIMATION

An ideal solver for Neural ODEs should achieve two goals: accuracy in gradient estimation and constant memory cost w.r.t. integration time. Yet none of the existing methods can achieve both goals. We propose a method based on the ALF solver, which to our knowledge is the first method to achieve the two goals simultaneously.

Algorithm 4 (the beginning of the listing is truncated in the source):

    [...] = 0
    For i in {N_t, N_t − 1, ..., 2, 1}:
        Reconstruct (z_{i−1}, v_{i−1}) from (z_i, v_i) by Algo. 3
        Local forward: (z_i, v_i, t_i, h_i) = ψ(z_{i−1}, v_{i−1}, t_{i−1}, h_i)
        Local backward: get ∂f(z_{i−1}, t_{i−1}, θ)/∂z_{i−1} and ∂f(z_{i−1}, t_{i−1}, θ)/∂θ
        Update a(t) and dL/dθ by Eq. 2 and Eq. 3 discretized at time points t_{i−1} and t_i
        Delete local computation graph
    Output: the adjoint state a(t_0) (gradient w.r.t. input z_0) and parameter gradient dL/dθ

Procedure of MALI

Details of MALI are summarized in Algo. 4. For the forward-pass, we only keep the end-time state $(z_{N_t}, v_{N_t})$ and the accepted discretized time points (blue curves in Fig. 1 and 2). We ignore the search process for the optimal stepsize (green curve in Fig. 1 and 2), and delete other variables to save memory. During the backward pass, we can reconstruct the forward-time trajectory as in Eq. 5, then calculate the gradient by numerical discretization of Eq. 2 and Eq. 3.

Constant memory cost w.r.t. number of solver steps in integration

We delete the computation graph and only keep the end-time state to save memory. The memory cost is $N_z (N_f + 1)$ [...].

Accuracy

Our method guarantees the accuracy of the reverse-time trajectory (e.g. the blue curve in Fig. 2 matches the blue curve in Fig. 1), because ALF is explicitly invertible for free-form $f$ (see Algo. 3). Therefore, the gradient estimation in MALI is more accurate compared to the adjoint method.

Computation cost

Recall that on average it takes $m$ steps to find an acceptable stepsize, whose error estimate is below tolerance. Therefore, the forward-pass with search process has computation burden $N_z \times N_f \times N_t \times m$. Note that we only reconstruct and backpropagate through the accepted step and ignore the search process, hence it takes another $N_z \times N_f \times N_t \times 2$ computation. The overall computation burden is $N_z N_f \times N_t \times (m + 2)$ as in Table 1.

Shallow computation graph

Similar to ACA, MALI only backpropagates through the accepted step (blue curve in Fig. 2) and ignores the search process (green curve in Fig. 2), hence the depth of the computation graph is $N_f \times N_t$. The computation graph of MALI is much shallower than that of the naive method, hence is more robust to vanishing and exploding gradients (Pascanu et al., 2013).

Summary

The adjoint method suffers from inaccuracy in the reverse-time trajectory, the naive method suffers from exploding or vanishing gradients caused by a deep computation graph, and ACA finds a balance but its memory grows linearly with integration time. MALI achieves accuracy in the reverse-time trajectory, constant memory w.r.t. integration time, and a shallow computation graph.
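The backward loop described above (reconstruct, local forward, local backward, discard) can be sketched on a scalar problem without any autograd library. The example assumes $f(z, t) = \alpha z$, loss $L = z(T)^2$ and a fixed stepsize, and it replaces automatic differentiation with a hand-written reverse-mode pass through one ALF step; all function names are hypothetical, and this is an illustration of the idea rather than the released implementation.

```python
import math

def alf_step(z, v, h, alpha):
    k1 = z + v * h / 2
    v_out = 2 * (alpha * k1) - v
    return k1 + v_out * h / 2, v_out

def alf_inverse(z_out, v_out, h, alpha):
    k1 = z_out - v_out * h / 2
    v_in = 2 * (alpha * k1) - v_out
    return k1 - v_in * h / 2, v_in

def mali_forward(z0, alpha, T, n_steps):
    h = T / n_steps
    z, v = z0, alpha * z0                 # v0 = f(z0)
    for _ in range(n_steps):
        z, v = alf_step(z, v, h, alpha)
    return z, v                           # only the end-time state is kept

def mali_backward(z_end, v_end, alpha, T, n_steps):
    h = T / n_steps
    z, v = z_end, v_end
    a_z, a_v, g_alpha = 2 * z_end, 0.0, 0.0   # dL/dz_N and dL/dv_N for L = z_N^2
    for _ in range(n_steps):
        z_prev, v_prev = alf_inverse(z, v, h, alpha)   # reconstruct previous state
        k1 = z_prev + v_prev * h / 2                   # local forward (intermediate value)
        a_vout = a_v + a_z * h / 2                     # local backward through one step
        a_u1 = 2 * a_vout
        a_k1 = a_z + alpha * a_u1
        g_alpha += a_u1 * k1
        a_z, a_v = a_k1, -a_vout + a_k1 * h / 2
        z, v = z_prev, v_prev                          # nothing else is stored
    # account for the initialization v0 = alpha * z0 (z now holds the reconstructed z0)
    return a_z + alpha * a_v, g_alpha + a_v * z

z0, alpha, T, n = 1.0, 0.5, 1.0, 200
z_end, v_end = mali_forward(z0, alpha, T, n)
g_z0, g_alpha = mali_backward(z_end, v_end, alpha, T, n)
print(g_z0, 2 * z0 * math.exp(2 * alpha * T))
print(g_alpha, 2 * T * z0 ** 2 * math.exp(2 * alpha * T))
```

The working memory of `mali_backward` is a handful of scalars regardless of `n`, and the returned gradients can be compared against the closed-form values $2 z_0 e^{2\alpha T}$ and $2 T z_0^2 e^{2\alpha T}$ for this linear problem.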

4.1. VALIDATION ON A TOY EXAMPLE

We compare the performance of different methods on a toy example, defined as

$$L(z(T)) = z(T)^2 \quad \text{s.t.} \quad z(0) = z_0,\ \frac{dz(t)}{dt} = \alpha z(t)$$

The analytical solution is

$$z(t) = z_0 e^{\alpha t}, \quad L = z_0^2 e^{2\alpha T}, \quad \frac{dL}{dz_0} = 2 z_0 e^{2\alpha T}, \quad \frac{dL}{d\alpha} = 2 T z_0^2 e^{2\alpha T}$$

We plot the amplitude of the error between the numerical solution and the analytical solution varying with $T$ (integrated under the same error tolerance, $rtol = 10^{-5}$, $atol = 10^{-6}$) in Fig. 4. ACA and MALI have similar errors, both outperforming other methods. We also plot the memory consumption for different methods on a Neural ODE with the same input in Fig. 4. As the error tolerance decreases, the solver evaluates more steps, hence the naive method and ACA increase memory consumption, while MALI and the adjoint method have a constant memory cost. These results validate our analysis in Sec. 3.2 and Table 1, and show that MALI achieves accuracy at a constant memory cost.

4.2. IMAGE RECOGNITION WITH NEURAL ODE

We validate MALI on image recognition tasks using the Cifar10 and ImageNet datasets. Similar to Zhuang et al. (2020), we modify a ResNet18 into its corresponding Neural ODE: the forward function is $y = x + f_\theta(x)$ and $y = x + \int_0^T f_\theta(z)\,dt$ for the residual block and Neural ODE respectively, where the same $f_\theta$ is shared. We compare MALI with the naive method, adjoint method and ACA.
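The relation between the two forward functions can be made concrete with a scalar stand-in for $f_\theta$: a residual block coincides with a single Euler step of size $T = 1$ on the ODE block, while a finer discretization of the same integral gives a different output. The `tanh` dynamics and all names below are assumptions for illustration only.

```python
import math

def f_theta(z):
    """Scalar stand-in for the shared network f_theta (illustrative only)."""
    return math.tanh(z)

def residual_block(x):
    return x + f_theta(x)                # y = x + f(x)

def ode_block_euler(x, T=1.0, n_steps=1):
    """Euler approximation of y = x + integral_0^T f(z) dt."""
    h = T / n_steps
    z = x
    for _ in range(n_steps):
        z = z + h * f_theta(z)
    return z

x = 0.3
print(residual_block(x), ode_block_euler(x, n_steps=1), ode_block_euler(x, n_steps=100))
```

With one step the two outputs are identical; with 100 steps the ODE block approaches the continuous solution, which is a different function of `x` than the residual block.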

Results on Cifar10

Results of 5 independent runs on Cifar10 are summarized in Fig. 5. MALI achieves comparable accuracy to ACA, and both significantly outperform the naive and the adjoint method. Furthermore, the training speed of MALI is similar to ACA, and both are almost two times faster than the adjoint method, and three times faster than the naive method. This validates our analysis on accuracy and computation burden in Table 1.

Accuracy on ImageNet

Due to the heavy memory burden caused by large images, the naive method and ACA are unable to train a Neural ODE on ImageNet with 4 GPUs; only MALI and the adjoint method are feasible due to the constant memory. We also compare the Neural ODE to a standard ResNet. As shown in Fig. 6, the accuracy of the Neural ODE trained with MALI closely follows ResNet, and significantly outperforms the adjoint method (top-1 validation: 70% vs. 63%).

Invariance to discretization scheme

A continuous model should be invariant to discretization schemes (e.g. different types of ODE solvers) as long as the discretization is sufficiently accurate. We test the Neural ODE using different solvers without re-training; since ResNet is often viewed as a one-step Euler discretization of an ODE (Haber & Ruthotto, 2017), we perform similar experiments. As shown in Table 2, Neural ODE consistently achieves high accuracy (∼70%), while ResNet drops to random guessing (∼0.1%) because ResNet as a one-step Euler discretization fails to be a meaningful dynamical system (Queiruga et al., 2020).

Robustness to adversarial attack

Hanshu et al. (2019) demonstrated that Neural ODE is more robust to adversarial attack than ResNet on small-scale datasets such as Cifar10. We validate this result on the large-scale ImageNet dataset. The top-1 accuracy of Neural ODE and ResNet under FGSM attack (Goodfellow et al., 2014) is summarized in Table 3. For Neural ODE, due to its invariance to discretization scheme, we derive the gradient for the attack using a certain solver (row in Table 3), and run inference on the perturbed images using various solvers. For different combinations of solvers and perturbation amplitudes, Neural ODE consistently outperforms ResNet.

Summary

In image recognition tasks, we demonstrate that Neural ODE is accurate, invariant to discretization scheme, and more robust to adversarial attack than ResNet. Note that a detailed explanation of the robustness of Neural ODE is beyond the scope of this paper, but to our knowledge, MALI is the first method to enable training of Neural ODE on large datasets due to its constant memory cost.

4.3. TIME-SERIES MODELING

We apply MALI to latent-ODE (Rubanova et al., 2019) and Neural Controlled Differential Equation (Neural CDE) (Kidger et al., 2020a; b). Our experiment is based on the official implementation from the literature. We report the mean squared error (MSE) on the Mujoco test set in Table 4, which is generated from the "Hopper" model using the DeepMind control suite (Tassa et al., 2018); for all experiments with different ratios of training data, MALI achieves similar MSE to ACA, and both outperform the adjoint and naive methods. We report the test accuracy on the Speech Command dataset for Neural CDE in Table 5; MALI achieves a higher accuracy than competing methods.

4.4. CONTINUOUS GENERATIVE MODELS

We apply MALI on FFJORD (Grathwohl et al., 2018), a free-form continuous generative model, and compare with several variants in the literature (Finlay et al., 2020; Kidger et al., 2020a). Our experiment is based on the official implementation of Finlay et al. (2020); for a fair comparison, we train with MALI, and test with the same solver as in the literature (Grathwohl et al., 2018; Finlay et al., 2020), the Dopri5 solver with $rtol = atol = 10^{-5}$ from the torchdiffeq package (Chen et al., 2018). Bits per dim (BPD, lower is better) on the validation set for various datasets are reported in Table 6. For continuous models, MALI consistently generates the lowest BPD, and outperforms the Vanilla FFJORD (trained with adjoint), RNODE (regularized FFJORD) and the SemiNorm Adjoint (Kidger et al., 2020a). Furthermore, FFJORD trained with MALI achieves comparable BPD to state-of-the-art discrete-layer flow models in the literature. Please see Sec. B.3 for generated samples.

5. RELATED WORKS

Besides ALF, the symplectic integrator (Verlet, 1967; Yoshida, 1990) is also able to reconstruct the trajectory accurately, yet it is typically restricted to second-order Hamiltonian systems (De Almeida, 1990), and is unsuitable for general ODEs. Besides the aforementioned methods, there are other methods for gradient estimation such as the interpolated adjoint (Daulbaev et al., 2020) and the spectral method (Quaglino et al., 2019), yet the implementations are involved and not publicly available. Other works focus on the theoretical properties of Neural ODEs (Dupont et al., 2019; Tabuada & Gharesifard, 2020; Massaroli et al., 2020). Neural ODEs have recently been applied to stochastic differential equations (Li et al., 2020), jump differential equations (Jia & Benson, 2019) and auto-regressive models (Wehenkel & Louppe, 2019).

6. CONCLUSION

Based on the asynchronous leapfrog integrator, we propose MALI to estimate the gradient for Neural ODEs. To our knowledge, our method is the first to achieve accuracy, fast speed and a constant memory cost. We provide comprehensive theoretical analysis on its properties. We validate MALI 

7. ACKNOWLEDGEMENT

This research was funded by the National Institutes of Health (NINDS-R01NS035193)

A THEORETICAL PROPERTIES OF ALF INTEGRATOR

A.1 ALGORITHM OF ALF

For ease of reading, we write the algorithm for ψ in ALF below, which is the same as Algo. 2 in the main paper, but uses slightly different notation for ease of analysis.

Algorithm 1: Forward of ψ in ALF

    Input: (z_in, v_in, s_in, h) = (z_0, v_0, s_0, h), where s_0 is the current time, z_0 and v_0 are the corresponding values at time s_0, and h is the stepsize.
    Forward:
        s_1 = s_0 + h/2             (1)
        z_1 = z_0 + v_0 × h/2       (2)
        v_1 = f(z_1, s_1)           (3)
        v_2 = v_1 + (v_1 − v_0)     (4)
        z_2 = z_1 + v_2 × h/2       (5)
        s_2 = s_1 + h/2             (6)
    Output: (z_out, v_out, s_out, h) = (z_2, v_2, s_2, h)

For simplicity, we can re-write the forward of ALF as

$$\begin{pmatrix} z_2 \\ v_2 \end{pmatrix} = \begin{pmatrix} z_0 + h f\left(z_0 + \frac{h}{2} v_0,\ s_0 + \frac{h}{2}\right) \\ 2 f\left(z_0 + \frac{h}{2} v_0,\ s_0 + \frac{h}{2}\right) - v_0 \end{pmatrix} \tag{7}$$

Similarly, the inverse of ALF can be written as

$$\begin{pmatrix} z_0 \\ v_0 \end{pmatrix} = \begin{pmatrix} z_2 - h f\left(z_2 - \frac{h}{2} v_2,\ s_2 - \frac{h}{2}\right) \\ 2 f\left(z_2 - \frac{h}{2} v_2,\ s_2 - \frac{h}{2}\right) - v_2 \end{pmatrix} \tag{8}$$

A.2 PRELIMINARIES

For an ODE of the form

$$\frac{dz(t)}{dt} = f(z(t), t) \tag{9}$$

we have:

$$\frac{d^2 z(t)}{dt^2} = \frac{d}{dt} f(z(t), t) = \frac{\partial f(z(t), t)}{\partial t} + \frac{\partial f(z(t), t)}{\partial z} \frac{dz(t)}{dt} \tag{10}$$

For ease of notation, we re-write Eq. 10 as

$$\frac{d^2 z(t)}{dt^2} = f_t + f_z f \tag{11}$$

where $f_t$ and $f_z$ represent the partial derivatives of $f$ w.r.t. $t$ and $z$ respectively.

A.3 LOCAL TRUNCATION ERROR OF ALF

Theorem A.1 (Theorem 3.1 in the main paper). For a single step in ALF with stepsize $h$, the local truncation error of $z$ is $O(h^3)$, and the local truncation error of $v$ is $O(h^2)$.

Proof. Under the same notation as Algo. 1, denote the ground-truth states starting from $(z_0, s_0)$ as $z(\cdot)$ and $v(\cdot)$, and the numerical outputs as $z_2$ and $v_2$. Then the local truncation error is

$$L_z = z(s_0 + h) - z_2, \qquad L_v = v(s_0 + h) - v_2 \tag{12}$$

Published as a conference paper at ICLR 2021

We estimate $L_z$ and $L_v$ as polynomials in $h$. We assume that $f$ is smooth up to 2nd order almost everywhere (this is typically satisfied by neural networks with bounded weights), so that the Taylor expansion of $f$ is meaningful. By Eq. 11, the Taylor expansion of $z$ around the point $(z_0, v_0, s_0)$ is

$$z(s_0 + h) = z_0 + h \frac{dz}{dt} + \frac{h^2}{2} \frac{d^2 z}{dt^2} + O(h^3) \tag{13}$$

$$= z_0 + h f(z_0, s_0) + \frac{h^2}{2} \left[ f_t(z_0, s_0) + f_z(z_0, s_0) f(z_0, s_0) \right] + O(h^3) \tag{14}$$

Next, we analyze the accuracy of the numerical approximation. For simplicity, we directly analyze Eq. 7 by performing a Taylor expansion of $f$:

$$f\left(z_0 + \frac{h}{2} v_0,\ s_0 + \frac{h}{2}\right) = f(z_0, s_0) + \frac{h}{2} f_t(z_0, s_0) + \frac{h v_0}{2} f_z(z_0, s_0) + O(h^2) \tag{15}$$

$$z_2 = z_0 + h f\left(z_0 + \frac{h}{2} v_0,\ s_0 + \frac{h}{2}\right) \tag{16}$$

Plugging Eq. 14, Eq. 15 and Eq. 16 into the definition of $L_z$, we get

$$L_z = z(s_0 + h) - z_2 \tag{17}$$

$$= \left[ z_0 + h f(z_0, s_0) + \frac{h^2}{2} \left( f_t(z_0, s_0) + f_z(z_0, s_0) f(z_0, s_0) \right) \right] - \left[ z_0 + h \left( f(z_0, s_0) + \frac{h}{2} f_t(z_0, s_0) + \frac{h v_0}{2} f_z(z_0, s_0) \right) \right] + O(h^3) \tag{18}$$

$$= \frac{h^2}{2} f_z(z_0, s_0) \left[ f(z_0, s_0) - v_0 \right] + O(h^3) \tag{19}$$

Therefore, if $f(z_0, s_0) - v_0$ is of order $O(1)$, $L_z$ is of order $O(h^2)$; if $f(z_0, s_0) - v_0$ is of order $O(h)$ or smaller, then $L_z$ is of order $O(h^3)$. Specifically, at the start time of integration, we have $f(z_0, s_0) - v_0 = 0$; by induction, $L_z$ at the end time is $O(h^3)$.

Next we analyze the local truncation error in $v$, denoted as $L_v$.
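Denote the ground truth as $v(s_0 + h)$; we have

$$v(s_0 + h) = f\left(z(s_0 + h),\ s_0 + h\right) \tag{20}$$

$$= f(z_0, s_0) + h f_t(z_0, s_0) + \left[ z(s_0 + h) - z_0 \right] f_z(z_0, s_0) + O(h^2) \tag{21}$$

Next we analyze the error in the numerical approximation. Plugging Eq. 15 into Eq. 7,

$$v_2 = 2 f\left(z_0 + \frac{h}{2} v_0,\ s_0 + \frac{h}{2}\right) - v_0 \tag{22}$$

$$= f(z_0, s_0) + \left[ f(z_0, s_0) - v_0 \right] + h f_t(z_0, s_0) + h v_0 f_z(z_0, s_0) + O(h^2) \tag{23}$$

From Eq. 14, Eq. 21 and Eq. 23, subtracting Eq. 23 from Eq. 21 term by term, we have

$$L_v = v(s_0 + h) - v_2 \tag{24}$$

$$= -\left[ f(z_0, s_0) - v_0 \right] + \left[ z(s_0 + h) - z_0 - h v_0 \right] f_z(z_0, s_0) + O(h^2) \tag{25}$$

$$= -\left[ f(z_0, s_0) - v_0 \right] + h \left[ f(z_0, s_0) - v_0 \right] f_z(z_0, s_0) + O(h^2) \tag{26}$$

The last equation is derived by plugging in Eq. 14. Note that Eq. 26 holds for every single step forward in time, and at the start time of integration, we have $f(z_0, s_0) - v_0 = 0$ due to our initialization as in Sec. 3.1 of the main paper. Therefore, by induction, $L_v$ is of order $O(h^2)$ for consecutive steps.

Proof. See (Silvester, 2000) for a detailed proof.

Theorem A.2. For the ALF integrator with stepsize $h$, if $h\sigma_i$ is 0 or is imaginary with norm no larger than 1, where $\sigma_i$ is the $i$-th eigenvalue of the Jacobian $\frac{\partial f}{\partial z}$, then the solver is on the critical boundary of A-stability; otherwise, the solver is not A-stable.

Proof. A solver is A-stable if and only if the eigenvalues of the numerical forward map have norm below 1. We calculate the eigenvalues of $\psi$ below. For the function defined by Eq. 7, the Jacobian is

$$J = \begin{pmatrix} \frac{\partial z_2}{\partial z_0} & \frac{\partial z_2}{\partial v_0} \\ \frac{\partial v_2}{\partial z_0} & \frac{\partial v_2}{\partial v_0} \end{pmatrix} = \begin{pmatrix} I + h \frac{\partial f}{\partial z} & \frac{h^2}{2} \frac{\partial f}{\partial z} \\ 2 \frac{\partial f}{\partial z} & h \frac{\partial f}{\partial z} - I \end{pmatrix} \tag{27}$$

We determine the eigenvalues of $J$ by solving the equation

$$\det(J - \lambda I) = \det \begin{pmatrix} h \frac{\partial f}{\partial z} + (1 - \lambda) I & \frac{h^2}{2} \frac{\partial f}{\partial z} \\ 2 \frac{\partial f}{\partial z} & h \frac{\partial f}{\partial z} - (1 + \lambda) I \end{pmatrix} = 0 \tag{28}$$

It is trivial to check that $J$ satisfies the conditions for Lemma A.1.1. Therefore, we have

$$\det(J - \lambda I) = \det \left[ \left( h \frac{\partial f}{\partial z} + (1 - \lambda) I \right) \left( h \frac{\partial f}{\partial z} - (1 + \lambda) I \right) - \frac{h^2}{2} \frac{\partial f}{\partial z} \times 2 \frac{\partial f}{\partial z} \right] \tag{29}$$

$$= \det \left[ -2 \lambda h \frac{\partial f}{\partial z} + (\lambda^2 - 1) I \right] \tag{30}$$

Suppose the eigen-decomposition of $\frac{\partial f}{\partial z}$ can be written as

$$\frac{\partial f}{\partial z} = \Lambda\, \mathrm{diag}(\sigma_1, \sigma_2, \ldots, \sigma_N)\, \Lambda^{-1} \tag{31}$$

Note that $I = \Lambda I \Lambda^{-1}$, hence we have

$$\det(J - \lambda I) = \det \left( \Lambda \left[ -2 \lambda h\, \mathrm{diag}(\sigma_1, \sigma_2, \ldots, \sigma_N) + (\lambda^2 - 1) I \right] \Lambda^{-1} \right) \tag{32}$$

$$= \prod_{i=1}^{N} \left( \lambda^2 - 2 h \sigma_i \lambda - 1 \right) \tag{33}$$

Hence the eigenvalues are

$$\lambda_{i\pm} = h \sigma_i \pm \sqrt{h^2 \sigma_i^2 + 1}$$

A-stability requires $|\lambda_{i\pm}| < 1,\ \forall i$, which has no solution. The critical boundary is $|\lambda_{i\pm}| = 1$, whose solution is: $h\sigma_i$ is 0 or on the imaginary line with norm no larger than 1.

Proof. The proof is similar to Thm. A.3. By similar calculations using the Taylor expansions in Eq. 15 and Eq. 14, we have

$$z_2 - z(s_0 + h) = (1 - \eta) h v_0 + h \eta \left[ f(z_0, s_0) + \frac{h}{2} f_t(z_0, s_0) + \frac{h v_0}{2} f_z(z_0, s_0) \right] - h \left[ f(z_0, s_0) + \frac{h}{2} f_t(z_0, s_0) + \frac{h}{2} f_z(z_0, s_0) f(z_0, s_0) \right] + O(h^2) \tag{50}$$

$$= (1 - \eta) h \left[ v_0 - f(z_0, s_0) \right] + \frac{\eta - 1}{2} h^2 f_t(z_0, s_0) + \frac{h^2}{2} \left[ \eta v_0 - f(z_0, s_0) \right] f_z(z_0, s_0) + O(h^2) \tag{51}$$

Using Eq. 21, Eq. 15 and Eq. 14, and the damped update $v_2 = (1 - 2\eta) v_0 + 2\eta f\left(z_0 + \frac{h}{2} v_0, s_0 + \frac{h}{2}\right)$, we have

$$v(s_0 + h) - v_2 = (2\eta - 1) v_0 + (1 - 2\eta) f(z_0, s_0) + (1 - \eta) h f_t(z_0, s_0) + \left[ z(s_0 + h) - z_0 - \eta h v_0 \right] f_z(z_0, s_0) + O(h^2) \tag{52}$$

$$= (1 - 2\eta) \left[ f(z_0, s_0) - v_0 \right] + (1 - \eta) h f_t(z_0, s_0) + h \left[ f(z_0, s_0) - \eta v_0 \right] f_z(z_0, s_0) + O(h^2) \tag{53}$$

Note that when $\eta = 1$, Eq. 51 reduces to Eq. 19 (up to sign, since Eq. 51 is written as numerical value minus ground truth), and Eq. 53 reduces to Eq. 26. By initialization, we have $|f(z_0, s_0) - v_0| = 0$ at the initial time, hence by induction, the local truncation error for $z$ is $O(h^2)$; the local truncation error for $v$ is $O(h)$ when $\eta < 1$, and is $O(h^2)$ when $\eta = 1$.

Theorem A.4 (Theorem 3.2 in the main paper). For the damped ALF integrator with stepsize $h$, let $\sigma_i$ be the $i$-th eigenvalue of the Jacobian $\frac{\partial f}{\partial z}$. The solver is A-stable if

$$\left| 1 + \eta(h\sigma_i - 1) \pm \sqrt{\eta\left[2h\sigma_i + \eta(h\sigma_i - 1)^2\right]} \right| < 1, \quad \forall i.$$

Proof. The Jacobian of the forward-pass of a single step of damped ALF is

$$J = \begin{pmatrix} I + \eta h \frac{\partial f}{\partial z} & (1 - \eta) h I + \eta \frac{h^2}{2} \frac{\partial f}{\partial z} \\ 2 \eta \frac{\partial f}{\partial z} & \eta h \frac{\partial f}{\partial z} + (1 - 2\eta) I \end{pmatrix} \tag{54}$$

When $\eta = 1$, $J$ reduces to Eq. 27. We can determine the eigenvalues of $J$ using similar techniques.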
Assume the eigenvalues of $\frac{\partial f}{\partial z}$ are $\{\sigma_i\}$; then we have

\begin{align}
\det(J - \lambda I) &= \det \begin{bmatrix} (1 - \lambda) I + \eta h \frac{\partial f}{\partial z} & (1 - \eta) h I + \eta \frac{h^2}{2} \frac{\partial f}{\partial z} \\ 2 \eta \frac{\partial f}{\partial z} & \eta h \frac{\partial f}{\partial z} + (1 - 2\eta - \lambda) I \end{bmatrix} \tag{55} \\
&= \det\Big[ \big((1 - \lambda) I + \eta h \tfrac{\partial f}{\partial z}\big)\big(\eta h \tfrac{\partial f}{\partial z} + (1 - 2\eta - \lambda) I\big) - \big((1 - \eta) h I + \eta \tfrac{h^2}{2} \tfrac{\partial f}{\partial z}\big)\, 2 \eta \tfrac{\partial f}{\partial z} \Big] \tag{56} \\
&= \prod_{i=1}^{N} (\lambda - \lambda_{i+})(\lambda - \lambda_{i-}), \qquad \lambda_{i\pm} = 1 + \eta (h \sigma_i - 1) \pm \sqrt{\eta \big[ 2 h \sigma_i + \eta (h \sigma_i - 1)^2 \big]} \tag{57}
\end{align}

When $\eta < 1$, it is easy to check that $|\lambda_{i\pm}| < 1$ has a non-empty solution set for $h\sigma_i$. For a quick validation, we plot the region of A-stability on the complex plane for a single eigenvalue in Fig. 1. As $\eta$ increases, the area of stability decreases. When $\eta = 1$, the system is nowhere A-stable, and the boundary for A-stability is the segment $[-i, i]$ of the imaginary axis, where $i$ is the imaginary unit.

We directly modify a ResNet18 into a Neural ODE, where the forward of a residual block ($y = x + f(x)$) and the forward of an ODE block ($y = x + \int_0^T f(z, t)\,dt$ with $T = 1$) share the same parameterization $f$; hence they have the same number of parameters. Our experiment is based on the official implementation by Zhuang et al. (2020) and an open-source repository (Liu, 2017). All models are trained with the SGD optimizer for 90 epochs, with an initial learning rate of 0.01, decayed by a factor of 10 at the 30th and 60th epochs. The training scheme is the same for all models (ResNet, and Neural ODE trained with the adjoint, naive, ACA and MALI methods). For ACA, we follow the settings in (Zhuang et al., 2020), use the official implementation torch_ACA (https://github.com/juntang-zhuang/torch_ACA), and use a Heun-Euler solver with rtol = $10^{-1}$, atol = $10^{-2}$ during training. For MALI, we use an adaptive version and set rtol = $10^{-1}$, atol = $10^{-2}$. For the naive and adjoint methods, we use the default Dopri5 solver from the torchdiffeq (https://github.com/rtqichen/torchdiffeq) package with rtol = atol = $10^{-5}$. We train all models for 5 independent runs, and report the mean and standard deviation across runs.
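The step-decay schedule described above (initial learning rate 0.01, divided by 10 at epochs 30 and 60, 90 epochs in total) can be written out as a small helper. This is a minimal sketch in plain Python rather than the training code used for the experiments; the function name `step_decay_lr` and its arguments are illustrative, and epochs are assumed to be counted from 0.

```python
def step_decay_lr(epoch, base_lr=0.01, milestones=(30, 60), gamma=0.1):
    """Learning rate used during `epoch` (0-indexed) under step decay."""
    # Count how many decay milestones have already been reached.
    num_decays = sum(1 for m in milestones if epoch >= m)
    return base_lr * (gamma ** num_decays)


if __name__ == "__main__":
    for epoch in (0, 29, 30, 59, 60, 89):
        print(epoch, step_decay_lr(epoch))
```

In PyTorch, the same schedule is usually expressed with `torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 60], gamma=0.1)`.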

B.1.2 EXPERIMENTS ON IMAGENET

Training scheme. We conduct experiments on ImageNet with ResNet18 and Neural-ODE18. All models are trained on 4 GTX-1080Ti GPUs with a batch size of 256. All models are trained for 80 epochs, with an initial learning rate of 0.1, decayed by a factor of 10 at the 30th and 60th epochs. Note that, due to the large input size (256 × 256), the naive method and ACA require a huge amount of memory and are infeasible to train. MALI and the adjoint method require constant memory and are hence suitable for large-scale experiments. For both MALI and the adjoint method, we use a fixed stepsize of 0.25 and integrate from 0 to T = 1. As shown in Table 2 in the main paper, a stepsize of 0.25 is sufficiently small to train a meaningful continuous model that is robust to the discretization scheme.

Invariance to discretization scheme. To test the influence of the discretization scheme, we test our Neural ODE with different solvers without re-training. For fixed-stepsize solvers, we tested various step sizes including {0.1, 0.15, 0.25, 0.5, 1.0}; for adaptive solvers, we set rtol = 0.1, atol = 0.01 for MALI and the Heun-Euler method, rtol = $10^{-2}$, atol = $10^{-3}$ for the RK23 solver, and rtol = $10^{-4}$, atol = $10^{-5}$ for the Dopri5 solver. As shown in Table 2, the Neural ODE trained with MALI is robust to the discretization scheme, and MALI significantly outperforms the adjoint method in terms of accuracy (70% vs. 63% top-1 accuracy on the validation dataset). An interesting finding is that when trained with MALI which is a second-order solver, and tested with higher-order solver (e.g. Furthermore, many papers claim ResNet to be an approximation for an ODE (Lu et al., 2018). However, Queiruga et al. (2020) argue that many numerical discretizations fail to be meaningful dynamical systems, while our experiments demonstrate that our model is continuous and hence invariant to discretization schemes.
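The idea of testing one model under several step sizes can be illustrated on a toy problem. The sketch below integrates dz/dt = -z from 0 to T = 1 with a fixed-stepsize ALF solver and compares the result across step sizes. It is not the ImageNet experiment: the dynamics are a hand-picked scalar ODE, the names `alf_step`, `integrate_alf` and `decay` are hypothetical, and the step is written from the update $v_2 = 2 f(z_0 + \frac{h}{2} v_0, s_0 + \frac{h}{2}) - v_0$ of Eq. 22 with half-step position updates, with $v_0 = f(z_0, s_0)$ at the start. The step size 0.15 is left out because it does not divide T = 1 evenly.

```python
import math


def alf_step(f, z, v, s, h):
    """One ALF step: half-step in z, reflect v around f, half-step in z again."""
    s1 = s + h / 2
    z1 = z + v * h / 2
    v1 = f(z1, s1)
    v2 = 2 * v1 - v
    z2 = z1 + v2 * h / 2
    return z2, v2, s1 + h / 2


def integrate_alf(f, z0, t0, t1, h):
    """Integrate from t0 to t1 with fixed stepsize h (t1 - t0 assumed divisible by h)."""
    n_steps = round((t1 - t0) / h)
    z, v, s = z0, f(z0, t0), t0  # initialize v0 = f(z0, t0)
    for _ in range(n_steps):
        z, v, s = alf_step(f, z, v, s, h)
    return z


def decay(z, t):
    """Toy dynamics dz/dt = -z, whose exact solution is z0 * exp(-t)."""
    return -z


if __name__ == "__main__":
    exact = math.exp(-1.0)
    for h in (1.0, 0.5, 0.25, 0.1):
        z_T = integrate_alf(decay, 1.0, 0.0, 1.0, h)
        print(f"h={h:<5} z(T)={z_T:.6f} abs. error={abs(z_T - exact):.2e}")
```

On this toy problem the output moves toward the exact value as the step size shrinks, which is the behaviour one expects from a model that represents a continuous dynamical system rather than a fixed discretization.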
Adversarial robustness. Besides the high accuracy and robustness to the discretization scheme, another advantage of Neural ODE is its robustness to adversarial attack. The adversarial robustness of Neural ODE is extensively studied in (Hanshu et al., 2019), but only validated on small-scale datasets such as Cifar10. To our knowledge, our method is the first to enable effective training of Neural ODE on large-scale datasets such as ImageNet and achieve a high accuracy, and we are the first to validate the robustness of Neural ODE on ImageNet. We use the advertorch (https://github.com/BorealisAI/advertorch) toolbox to perform adversarial attacks. We test the performance of ResNet and Neural ODE under FGSM attack. To be more convincing, we conduct the experiment on the pretrained ResNet18 provided by the official PyTorch website (https://pytorch.org/docs/stable/torchvision/models.html). Since Neural ODE is invariant to the discretization scheme, it is possible to derive the gradient for the attack using one ODE solver, and run inference on the perturbed image using another solver. As summarized in Table 3, Neural ODE consistently achieves a higher accuracy than ResNet under the same attack.
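For reference, FGSM perturbs an input by a step of size ε along the sign of the loss gradient, $x_{adv} = x + \epsilon \cdot \mathrm{sign}(\nabla_x L)$. The sketch below shows this on a hand-written logistic-regression model with an analytic gradient. It does not use advertorch or a Neural ODE, and every name in it is hypothetical; in the setting above, the gradient would instead come from backpropagating through one ODE solver, and the prediction on the perturbed input could be computed with a different solver.

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def predict(w, b, x):
    """Probability of class 1 under a logistic-regression model."""
    return sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)


def loss_grad_wrt_input(w, b, x, y):
    """Analytic gradient of the binary cross-entropy loss w.r.t. the input x."""
    p = predict(w, b, x)
    return [(p - y) * wi for wi in w]


def sign(a):
    return (a > 0) - (a < 0)


def fgsm(w, b, x, y, eps):
    """x_adv = x + eps * sign(dL/dx): one signed-gradient step of size eps."""
    grad = loss_grad_wrt_input(w, b, x, y)
    return [xi + eps * sign(gi) for xi, gi in zip(x, grad)]


if __name__ == "__main__":
    w, b = [2.0, -1.0], 0.0
    x, y = [0.5, -0.5], 1  # clean input with true label 1
    x_adv = fgsm(w, b, x, y, eps=0.25)
    print("clean prob of true class:", predict(w, b, x))
    print("adversarial prob of true class:", predict(w, b, x_adv))
```

The perturbation lowers the predicted probability of the true class, which is what an accuracy-under-attack table measures in aggregate over a dataset.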

B.2 TIME SERIES MODELING

We conduct experiments on Latent-ODE models (Rubanova et al., 2019) and Neural CDE (controlled differential equation) (Kidger et al., 2020a). For all experiments, we use the official implementation, and only replace the solver with MALI. The latent-ODE model is trained on the Mujoco dataset processed with code provided by the official implementation, and we experiment with different ratios (10%, 20%, 50%) of training data as described in (Rubanova et al., 2019). All models are trained for 300 epochs with the Adamax optimizer, with an initial learning rate of 0.01 that is scaled by 0.999 after each epoch. For the Neural CDE model, for the naive method, ACA and MALI, we perform 5 independent runs and report the mean value and standard deviation; results for the adjoint and seminorm adjoint are from (Kidger et al., 2020a). For Neural CDE, we use MALI with the ALF solver with a fixed stepsize of 0.25, and train the model for 100 epochs with an initial learning rate of 0.004.

On the MNIST and CIFAR datasets, we set the regularization coefficients for the kinetic energy and the Frobenius norm of the derivative function to 0.05. We train the model for 50 epochs with an initial learning rate of 0.001.

B.3.2 ADDITIONAL RESULTS

We show generated examples on MNIST dataset in Fig. 3 , results for Cifar10 dataset in Fig. 4 , and results for ImageNet64 in Fig. 5 .

B.4 ERROR IN GRADIENT ESTIMATION FOR TOY EXAMPLES WHEN t < 1

We plot the error in gradient estimation for the toy example defined by Eq. 6 in the main paper in Fig. 6. Note that here the integration time T is set smaller than 1, whereas in the main paper it is larger than 20. We observe the same result: MALI and ACA produce smaller errors than the adjoint and the naive method.

B.5 RESULTS OF DAMPED MALI

For all experiments in the main paper, we set η = 1 and did not use damping. For completeness, we experimented with damped MALI using different values of η. As shown in Table 7, MALI is robust to different η values.



1. Rubanova et al. (2019); 2. Zhuang et al. (2020); 3. Kidger et al. (2020a); 4. Chen et al. (2018); 5. Finlay et al. (2020); 6. Dinh et al. (2016); 7. Behrmann et al. (2019); 8. Kingma & Dhariwal (2018); 9. Ho et al. (2019); 10. Chen et al. (2019)

https://github.com/juntang-zhuang/torch_ACA
https://github.com/rtqichen/torchdiffeq
https://github.com/BorealisAI/advertorch
https://pytorch.org/docs/stable/torchvision/models.html



Figure 1: Illustration of a numerical solver in the forward pass. For adaptive solvers, at each step forward in time, the stepsize is recursively adjusted until the estimated error is below a predefined tolerance; the search process is represented by the green curve, and the accepted step (ignoring the search process) is represented by the blue curve.

Figure 4: Comparison of error in gradient in Eq. 6. (a) error in dL dz 0 . (b) error in dL dα . (c) memory cost.

Figure 5: Results on Cifar10. From left to right: (1) box plot of test accuracy (first 4 columns are Neural

Figure 6: Top-1 accuracy on ImageNet validation dataset.

STABILITY ANALYSIS

Lemma A.1.1. For a matrix of the form $\begin{bmatrix} A & B \\ C & D \end{bmatrix}$, if $A, B, C, D$ are square matrices of the same shape, and $CD = DC$, then we have $\det \begin{bmatrix} A & B \\ C & D \end{bmatrix} = \det(AD - BC)$.

Figure 1: Region of A-stability for a single eigenvalue on the complex plane for damped ALF. From left to right, the region of stability for η = 0.25, η = 0.7, η = 0.8 respectively. As η increases to 1, the area of the stability region decreases.
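The stability condition of Theorem A.4 can be probed numerically for a single mode $h\sigma$. The sketch below evaluates the closed-form eigenvalues $\lambda_{\pm} = 1 + \eta(h\sigma - 1) \pm \sqrt{\eta[2h\sigma + \eta(h\sigma - 1)^2]}$ with Python's `cmath` and checks whether both lie strictly inside the unit circle. The function names are hypothetical, and the sample points are chosen for illustration only; they are not the grid used to draw the figure.

```python
import cmath


def damped_alf_eigenvalues(h_sigma, eta):
    """Eigenvalues lambda_{+/-} of one damped-ALF step for a single mode h*sigma."""
    center = 1 + eta * (h_sigma - 1)
    radius = cmath.sqrt(eta * (2 * h_sigma + eta * (h_sigma - 1) ** 2))
    return center + radius, center - radius


def is_a_stable(h_sigma, eta):
    """True if both eigenvalues lie strictly inside the unit circle."""
    return all(abs(lam) < 1 for lam in damped_alf_eigenvalues(h_sigma, eta))


if __name__ == "__main__":
    # eta = 1 (plain ALF): on the imaginary segment both moduli are 1 up to rounding.
    print([abs(lam) for lam in damped_alf_eigenvalues(0.5j, 1.0)])
    # eta < 1: some values of h*sigma are strictly stable, others are not.
    print(is_a_stable(-0.5, 0.25), is_a_stable(0.5, 0.25))
```

Sweeping `h_sigma` over a grid of complex values and shading the points where `is_a_stable` returns `True` would give a picture of the stability region for a chosen η.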

(a) Training curve on ImageNet. (b) Validation curve on ImageNet.

Figure 2: Results on ImageNet.

(a) Real samples from MNIST dataset. (b) Generated samples from FFJORD.

Figure 3: Results on MNIST dataset.

(a) Real samples from CIFAR10 dataset. (b) Generated samples from FFJORD.

Figure 4: Results on Cifar10 dataset.

(a) Real samples from ImageNet64 dataset. (b) Generated samples from FFJORD.

Figure 5: Results on ImageNet64 dataset.

Figure 6: Comparison of error in gradient estimation for the toy example by Eq. 6 of the main paper, when t < 1.

Comparison between different methods for gradient estimation in continuous case. MALI achieves reverse accuracy, constant memory w.r.t number of solver steps in integration, shallow computation graph and low computation cost.

Top-1 test accuracy of Neural ODE and ResNet on ImageNet. Neural ODE is trained with MALI, and ResNet is trained as the original model; Neural ODE is tested using different solvers without retraining.

Top-1 accuracy under FGSM attack. $\epsilon$ is the perturbation amplitude. For Neural ODE models, row names represent the solvers used to derive the gradient for the attack, and column names represent the solvers used for inference on the perturbed image.

Test MSE (×0.01) on Mujoco dataset (lower is better). Results marked with superscript numbers correspond to literature in the footnote.

Bits per dim (BPD) of generative models, lower is better. Results marked with superscript numbers correspond to literature in the footnote. Methods compared: RNODE [5], SemiNorm [3], MALI, RealNVP [6], i-ResNet [7], Glow [8], Flow++ [9], Residual Flow [10].

Results of damped MALI with different η values. We report the test accuracy of Neural CDE on the Speech Command dataset, and the test MSE of latent-ODE on Mujoco data.

(a) Error in the estimation of gradient w.r.t. initial condition. (b) Error in the estimation of gradient w.r.t. parameter α.

A.5 DAMPED ALF

Algorithm 2: Forward of ψ in Damped ALF (η ∈ (0, 1])
Input: $(z_{in}, v_{in}, s_{in}, h) = (z_0, v_0, s_0, h)$, where $s_0$ is the current time, $z_0$ and $v_0$ are the corresponding values at time $s_0$, and $h$ is the stepsize.
Output: $(z_{out}, v_{out}, s_{out}, h)$, where $s_{out}$ is the current time, $z_{out}$ and $v_{out}$ are the corresponding values at $s_{out}$, and $h$ is the stepsize.

The main difference between ALF and Damped ALF is marked in blue in Algo. 2. In ALF, the update of $v_2$ is $v_2 = 2(v_1 - v_0) + v_0$, while in Damped ALF, the update is scaled by a factor η between 0 and 1, so the update is $v_2 = 2\eta(v_1 - v_0) + v_0$. When η = 1, Damped ALF reduces to ALF.

Similar to Sec. A.1, we can write the forward as For simplicity, we can re-write the forward of ALF as Similarly, the inverse of ALF can be written as

Theorem A.3. For a single step in Damped ALF with stepsize $h$, the local truncation error of $z$ is $O(h^2)$, and the local truncation error of $v$ is $O(h)$.
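A damped ALF step and its inverse can be sketched as follows. This is a reconstruction under stated assumptions, not the released implementation: the forward combines the update $v_2 = 2\eta(v_1 - v_0) + v_0$ given above with half-step position updates, which is consistent with the Jacobian used in the proof of Theorem A.4. The inverse is obtained here by solving those same equations backwards, and as written it requires η ≠ 0.5 because it divides by 1 - 2η. The function names are hypothetical.

```python
def damped_alf_forward(f, z0, v0, s0, h, eta=1.0):
    """One damped-ALF step from (z0, v0, s0) to (z2, v2, s2)."""
    s1 = s0 + h / 2
    z1 = z0 + v0 * h / 2
    v1 = f(z1, s1)
    v2 = v0 + 2 * eta * (v1 - v0)  # damped update; eta = 1 gives plain ALF
    z2 = z1 + v2 * h / 2
    s2 = s1 + h / 2
    return z2, v2, s2


def damped_alf_inverse(f, z2, v2, s2, h, eta=1.0):
    """Algebraic inverse of damped_alf_forward (assumes eta != 0.5)."""
    s1 = s2 - h / 2
    z1 = z2 - v2 * h / 2
    v1 = f(z1, s1)
    # Solve v2 = (1 - 2*eta) * v0 + 2*eta * v1 for v0.
    v0 = (v2 - 2 * eta * v1) / (1 - 2 * eta)
    z0 = z1 - v0 * h / 2
    s0 = s1 - h / 2
    return z0, v0, s0


if __name__ == "__main__":
    def f(z, s):
        return -z + s  # toy dynamics, chosen only for illustration

    start = (1.0, f(1.0, 0.0), 0.0)
    end = damped_alf_forward(f, *start, h=0.25, eta=0.8)
    recovered = damped_alf_inverse(f, *end, h=0.25, eta=0.8)
    print("start:", start)
    print("recovered:", recovered)
```

Because the inverse re-evaluates `f` at the same midpoint `(z1, s1)` as the forward, the starting state is recovered up to floating-point rounding, which is the property that lets the reverse-time trajectory be reconstructed without storing it.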

