MALI: A MEMORY EFFICIENT AND REVERSE ACCURATE INTEGRATOR FOR NEURAL ODES

Abstract

Neural ordinary differential equations (Neural ODEs) are a new family of deep learning models with continuous depth. However, the numerical estimation of the gradient in the continuous case is not well solved: existing implementations of the adjoint method suffer from inaccuracy in the reverse-time trajectory, while the naive method and the adaptive checkpoint adjoint method (ACA) have a memory cost that grows with integration time. In this project, based on the asynchronous leapfrog (ALF) solver, we propose the Memory-efficient ALF Integrator (MALI), which has a constant memory cost w.r.t the number of solver steps in integration, similar to the adjoint method, and guarantees accuracy in the reverse-time trajectory (hence accuracy in gradient estimation). We validate MALI on various tasks: on image recognition tasks, to our knowledge, MALI is the first to enable feasible training of a Neural ODE on ImageNet and outperform a well-tuned ResNet, while existing methods fail due to either heavy memory burden or inaccuracy; for time series modeling, MALI significantly outperforms the adjoint method; and for continuous generative models, MALI achieves new state-of-the-art performance. We provide a pypi package: https://jzkay12.github.io/TorchDiffEqPack

1. INTRODUCTION

Recent research builds connections between continuous models and neural networks. The theory of dynamical systems has been applied to analyze the properties of neural networks or to guide the design of networks (Weinan, 2017; Ruthotto & Haber, 2019; Lu et al., 2018). In these works, a residual block (He et al., 2016) is typically viewed as a one-step Euler discretization of an ODE; instead of directly analyzing the discretized neural network, it might be easier to analyze the ODE. Another direction is the neural ordinary differential equation (Neural ODE) (Chen et al., 2018), which takes a continuous depth instead of a discretized depth. The dynamics of a Neural ODE are typically approximated by numerical integration with adaptive ODE solvers. Neural ODEs have been applied to irregularly sampled time series (Rubanova et al., 2019), free-form continuous generative models (Grathwohl et al., 2018; Finlay et al., 2020), mean-field games (Ruthotto et al., 2020), stochastic differential equations (Li et al., 2020) and physically informed modeling (Sanchez-Gonzalez et al., 2019; Zhong et al., 2019). Though the Neural ODE has been widely applied in practice, how to train it has not been extensively studied. The naive method directly backpropagates through an ODE solver, but tracking a continuous trajectory requires a huge memory. Chen et al. (2018) proposed to use the adjoint method to determine the gradient in continuous cases, which achieves constant memory cost w.r.t integration time; however, as pointed out by Zhuang et al. (2020), the adjoint method suffers from numerical errors due to the inaccuracy in the reverse-time trajectory. Zhuang et al. (2020) proposed the adaptive checkpoint adjoint (ACA) method to achieve accuracy in gradient estimation at a much smaller memory cost compared to the naive method, yet the memory consumption of ACA still grows linearly with integration time.
Due to the non-constant memory cost, neither ACA nor the naive method is suitable for large-scale datasets (e.g. ImageNet) or high-dimensional Neural ODEs (e.g. FFJORD (Grathwohl et al., 2018)). In this project, we propose the Memory-efficient Asynchronous Leapfrog Integrator (MALI) to achieve the advantages of both the adjoint method and ACA: constant memory cost w.r.t integration time and accuracy in the reverse-time trajectory. MALI is based on the asynchronous leapfrog (ALF) integrator (Mutze, 2013). With the ALF integrator, each numerical step forward in time is reversible. Therefore, with MALI, we delete the trajectory and keep only the end-time state, hence achieving constant memory cost w.r.t integration time; using the reversibility, we can accurately reconstruct the trajectory from the end-time value, hence achieving accuracy in the gradient. Our contributions are:

1. We propose a new method (MALI) to solve Neural ODEs, which achieves constant memory cost w.r.t the number of solver steps in integration and accuracy in gradient estimation. We provide theoretical analysis.

A Neural ODE is defined as:

    dz(t)/dt = f_θ(t, z(t))   s.t.   z(t_0) = x,   t ∈ [t_0, T],      Loss = L(z(T), y)      (1)

where z(t) is the hidden state evolving with time, T is the end time, t_0 is the start time (typically 0), and x is the initial state. The derivative of z(t) w.r.t t is defined by a function f, and f is defined as a sequence of layers parameterized by θ. The loss function is L(z(T), y), where y is the target variable. Eq. 1 is called the initial value problem (IVP) because only z(t_0) is specified.

Algorithm 1: Numerical Integration
    Input: initial state x, start time t_0, end time T, error tolerance etol, initial stepsize h.
    Initialize z(t_0) = x, t = t_0
    While t < T
        error_est = ∞
        While error_est > etol
            h ← h × DecayFactor
            ẑ, error_est = ψ_h(t, z)
        If error_est < etol
            h ← h × IncreaseFactor
        t ← t + h, z ← ẑ

Notations  We summarize the notations following Zhuang et al. (2020).
• z(t_i) / z(τ_i): hidden state in the forward / reverse-time trajectory at time t_i / τ_i.
• ψ_h(t_i, z_i): the numerical solution at time t_i + h, starting from (t_i, z_i) with stepsize h.
• N_f, N_z: N_f is the number of layers in f in Eq. 1; N_z is the dimension of z.
• N_t / N_r: number of discretized points (outer iterations in Algo. 1) in forward / reverse integration.
• m: average number of inner iterations in Algo. 1 to find an acceptable stepsize.
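To make the reversibility claim from the introduction concrete, below is a minimal scalar sketch of one leapfrog-style ALF step and its algebraic inverse. The function names are ours, and the state pair (z, v), with v approximating dz/dt, follows the ALF form of Mutze (2013); the full MALI solver additionally handles adaptive stepping, vector states, and gradient computation.

```python
import math

def alf_step(f, t, z, v, h):
    """One asynchronous leapfrog step from (z, v) at time t to time t + h."""
    z_half = z + 0.5 * h * v                  # half-step in z with current velocity
    v_new = 2.0 * f(t + 0.5 * h, z_half) - v  # velocity update at the midpoint
    z_new = z_half + 0.5 * h * v_new          # second half-step with new velocity
    return z_new, v_new

def alf_step_inverse(f, t, z_new, v_new, h):
    """Algebraic inverse of alf_step: recover (z, v) from (z_new, v_new).

    Because every operation above can be solved for its input, the forward
    trajectory need not be stored -- it can be reconstructed from the end state.
    """
    z_half = z_new - 0.5 * h * v_new
    v = 2.0 * f(t + 0.5 * h, z_half) - v_new
    z = z_half - 0.5 * h * v
    return z, v

# Toy dynamics to check reversibility (arbitrary choice for illustration).
f = lambda t, z: -z + math.sin(t)
z0, v0, t0, h = 1.0, f(0.0, 1.0), 0.0, 0.1
z1, v1 = alf_step(f, t0, z0, v0, h)
z0_rec, v0_rec = alf_step_inverse(f, t0, z1, v1, h)
print(abs(z0_rec - z0) < 1e-12 and abs(v0_rec - v0) < 1e-12)  # True (up to round-off)
```

In exact arithmetic the reconstruction is exact; in floating point it is accurate to round-off, which is the basis for MALI's accurate reverse-time trajectory.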

2.1. NUMERICAL INTEGRATION

The algorithm for general adaptive-stepsize numerical ODE solvers is summarized in Algo. 1 (Wanner & Hairer, 1996). The solver repeatedly advances in time by a step, which is the outer loop in Algo. 1 (blue curve in Fig. 1). For each step, the solver decreases the stepsize until the estimate of the error is lower than the tolerance, which is the inner loop in Algo. 1 (green curve in Fig. 1). For fixed-stepsize solvers, the inner loop is replaced with a single evaluation of ψ_h(t, z) using a predefined stepsize h. Different methods typically use different ψ, for example, different orders of the Runge–Kutta method (Runge, 1895).
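The outer/inner loop structure of Algo. 1 can be sketched in Python as follows. This is a simplified illustration, not the paper's implementation: `psi` here is a toy embedded pair (Heun's method with the Euler step as error estimate), and the first attempt of each step uses the current stepsize rather than decaying it first.

```python
import math

def adaptive_integrate(psi, z0, t0, T, etol, h0, decay=0.5, growth=1.2):
    """Generic adaptive-stepsize loop in the spirit of Algo. 1.

    psi(t, z, h) is any one-step method returning the state at t + h
    together with an estimate of the local error.
    """
    t, z, h = t0, z0, h0
    while t < T:
        h = min(h, T - t)            # never step past the end time
        err = math.inf
        while err > etol:            # inner loop: shrink h until accepted
            z_new, err = psi(t, z, h)
            if err > etol:
                h *= decay
        t, z = t + h, z_new          # outer loop: advance in time
        if err < etol:
            h *= growth              # grow h for the next step
    return z

def psi(t, z, h, f=lambda t, z: -z):
    """Heun's method; the gap to the Euler step serves as the error estimate."""
    euler = z + h * f(t, z)
    heun = z + 0.5 * h * (f(t, z) + f(t + h, euler))
    return heun, abs(heun - euler)

z_T = adaptive_integrate(psi, 1.0, 0.0, 1.0, 1e-4, 0.5)
# for dz/dt = -z with z(0) = 1, the exact solution gives z(1) = exp(-1) ≈ 0.368
```

Swapping `psi` for a higher-order Runge–Kutta pair, or replacing the inner loop with a single evaluation at a fixed h, recovers the variants discussed above.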

2.2. ANALYTICAL FORM OF GRADIENT IN CONTINUOUS CASE

We first briefly introduce the analytical form of the gradient in the continuous case, then we compare different numerical implementations in the literature to estimate the gradient. The analytical form



2. We validate our method with extensive experiments: (a) for image classification tasks, MALI enables a Neural ODE to achieve better accuracy than a well-tuned ResNet with the same number of parameters; to our knowledge, MALI is the first method to enable training of Neural ODEs on a large-scale dataset such as ImageNet, while existing methods fail due to either heavy memory burden or inaccuracy. (b) In time-series modeling, MALI achieves comparable or better results than other methods. (c) For generative modeling, a FFJORD model trained with MALI achieves new state-of-the-art results on MNIST and CIFAR10.

