GATED NEURAL ODES: TRAINABILITY, EXPRESSIVITY AND INTERPRETABILITY

Abstract

Understanding how the dynamics in biological and artificial neural networks implement the computations required for a task is a salient open question in machine learning and neuroscience. In particular, computations requiring complex memory storage and retrieval pose a significant challenge for these networks to implement or learn. Recently, a family of models described by neural ordinary differential equations (nODEs) has emerged as powerful dynamical neural network models capable of capturing complex dynamics. Here, we extend nODEs by endowing them with adaptive timescales using gating interactions. We refer to these as gated neural ODEs (gnODEs). Using a task that requires memory of continuous quantities, we demonstrate the inductive bias of the gnODEs to learn (approximate) continuous attractors. We further show how reduced-dimensional gnODEs retain their modeling power while greatly improving interpretability, even allowing explicit visualization of the structure of learned attractors. We introduce a novel measure of expressivity which probes the capacity of a neural network to generate complex trajectories. Using this measure, we explore how the phase-space dimension of the nODEs and the complexity of the function modeling the flow field contribute to expressivity. We see that a more complex function for modeling the flow field allows a lower-dimensional nODE to capture a given target dynamics. Finally, we demonstrate the benefit of gating in nODEs on several real-world tasks.

1. INTRODUCTION

How can the dynamical motifs exhibited by an artificial or a biological network implement the computations required for a task? This is a long-standing question in computational neuroscience and machine learning (Vyas et al., 2020; Khona & Fiete, 2022). Recurrent neural networks (RNNs) have often been used to probe this question (Mante et al., 2013; Vyas et al., 2020; Driscoll et al., 2022), as they are flexible dynamical systems that can be easily trained (Rumelhart et al., 1986) to perform computational tasks. RNNs, particularly ones that incorporate gating interactions (Hochreiter & Schmidhuber, 1997; Cho et al., 2014), have been wildly successful in solving complex real-world tasks (Jozefowicz et al., 2015). While RNN models provide a link between dynamics and computation, how their (typically) high-dimensional dynamics implement computation remains hard to interpret. On this note, we may turn to neural ordinary differential equations (nODEs), a class of dynamical models with a velocity field parameterized by a deep neural network (DNN), which can potentially implement more complex computations in lower dimensions than classical RNNs (Chen et al., 2018; Kidger, 2022).¹ This increased complexity in lower latent/phase-space dimensions subsequently helps in extracting interpretable, effective low-dimensional dynamics that may underlie a dataset or task (Kim et al., 2021). Despite their promise, nODEs remain under-explored in the following crucial aspects. Trainability: can we improve the performance of nODEs by introducing gating interactions (Hochreiter & Schmidhuber, 1997; Cho et al., 2014) to tame gradients in dynamical systems? Expressivity: how does the structure of the neural network modeling the velocity flow field influence a nODE's capacity to model complex trajectories? Interpretability: does the capability of low-dimensional nODEs to model complex data improve interpretability of the dynamical computation?
We summarize below the main insights of our exploration of these questions.

Main Contributions

• We leverage our understanding of gating interactions to introduce the gated neural ODE (gnODE). We find that gating endows nODEs with adaptive timescales, and improves trainability of nODEs on tasks involving long timescales or rich representations (Section 2, Appendix B).

• We introduce a novel measure of expressivity related to the capacity of a neural network to store complex dynamical trajectories. nODEs and gnODEs are more expressive compared to RNNs in many parameter regimes (Sections 4, 5.3, Appendices C, F).

• We demonstrate an inductive bias of gnODEs and other gated networks to utilize marginally-stable fixed points in a "flip-flop" task that requires storing continuous memory. We further demonstrate the interpretability of the gnODEs' solutions, which organize the marginally-stable fixed points in an approximate continuous attractor (Section 5.2, Appendix E).

• We show the advantage of gating in nODEs on real-world tasks (Sections 5.4–5.5, Appendix G).

• We determine the critical initialization for nODEs using dynamical mean-field theory (Appendix A).

2. GATED NEURAL ODE

The gated neural ODE (gnODE) is described by

τ ḣ = G_φ(h, x) ⊙ [−h + F_θ(h, x)],  (1)

where τ is the time constant, h ∈ ℝ^N is the hidden/latent state vector, and x(t) ∈ ℝ^D is the input vector. The velocity vector field F_θ : ℝ^N × ℝ^D → ℝ^N and the gating function G_φ : ℝ^N × ℝ^D → ℝ^N are parameterized (via θ and φ, respectively) by neural networks. While F_θ and G_φ can in general each be parameterized by any neural network, in this work we restrict F_θ and G_φ to fully-connected feedforward neural networks (FNNs), F_θ(h, x) = s_{L_h} and G_φ(h, x) = s_{L_z}, where, for ∗ ∈ {h, z},

s_1 = ϕ_a(W⁰_∗ h + U_∗ x + b⁰_∗),  s_{ℓ+1} = ϕ_a(W^ℓ_∗ s_ℓ + b^ℓ_∗),  s_{L_∗} = ϕ_∗(W^{L_∗−1}_∗ s_{L_∗−1} + b^{L_∗−1}_∗).

Here, W^ℓ_∗ ∈ ℝ^{N_{ℓ+1} × N_ℓ}, s_ℓ ∈ ℝ^{N_ℓ}, b^ℓ_∗ ∈ ℝ^{N_{ℓ+1}}, and N_0 = N_{L_∗} = N is the phase-space (or latent) dimension. ϕ_h ∈ {I, tanh} and ϕ_z = σ, where I is the identity function and σ(x) = [1 + e^{−x}]^{−1}. When L = 1, ϕ_a = ϕ_∗. When L > 1, we typically set ϕ_a to be ReLU.

Without the leak term −h and the gating interaction (i.e., setting G_φ(h, x) = 1), Equation (1) reverts to the form in which nODEs are typically studied (Chen et al., 2018): τ ḣ = F_θ(h, x(t)).² We include the leak term −h in our formulation because it allows us to initialize the weights of the (gated or non-gated) nODE in either the stable or the critical regime. Without the leak term, we show that the nODE is dynamically unstable for any initialization except the zero initialization, and we expect this to hinder training (Abarbanel et al. (2008); see Appendix A for details). When we set L_h = L_z = 1, Equation (1) reduces to a "minimal gated recurrent unit" (mGRU; Ravanelli et al. (2018)), which is a simplified version of the popularly used gated recurrent unit (GRU; Cho et al. (2014)). When, in addition, the gating interaction is removed (G_φ(h, x) = 1), Equation (1) reduces to a widely studied class of models known as "Elman" (or "vanilla") RNNs.³
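As an illustrative sketch (not the training setup used in this work), Equation (1) can be integrated with a forward-Euler step using randomly initialized FNNs for F_θ and G_φ. All names, shapes, and the choice of solver below are our own assumptions for demonstration:

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def make_fnn(sizes, rng):
    """Random (W, b) pairs for a fully-connected net with the given layer sizes."""
    return [(rng.normal(0.0, 1.0 / np.sqrt(m), size=(n, m)), np.zeros(n))
            for m, n in zip(sizes[:-1], sizes[1:])]

def fnn(params, h, x, U, out_fn):
    """s_1 = ReLU(W0 h + U x + b0); hidden ReLU layers; out_fn on the last layer.

    Input x enters only the first layer, matching the FNN definition above.
    """
    (W0, b0), rest = params[0], params[1:]
    s = np.maximum(W0 @ h + U @ x + b0, 0.0)   # phi_a = ReLU (the L > 1 case)
    for W, b in rest[:-1]:
        s = np.maximum(W @ s + b, 0.0)
    W, b = rest[-1]
    return out_fn(W @ s + b)                    # phi_h = tanh or phi_z = sigmoid

def gnode_euler_step(h, x, F, G, U_F, U_G, tau=1.0, dt=0.1):
    """One forward-Euler step of tau * dh/dt = G(h, x) * (-h + F(h, x))."""
    g = fnn(G, h, x, U_G, sigmoid)   # gate in (0, 1): element-wise adaptive timescale
    f = fnn(F, h, x, U_F, np.tanh)   # velocity-field target
    return h + (dt / tau) * g * (-h + f)

N, D, hidden = 4, 2, 16              # latent dim, input dim, FNN width (illustrative)
F = make_fnn([N, hidden, N], rng)
G = make_fnn([N, hidden, N], rng)
U_F = rng.normal(0.0, 1.0 / np.sqrt(D), size=(hidden, D))
U_G = rng.normal(0.0, 1.0 / np.sqrt(D), size=(hidden, D))

h = np.zeros(N)
x = np.ones(D)
for _ in range(50):
    h = gnode_euler_step(h, x, F, G, U_F, U_G)
print(h.shape)  # (4,)
```

Because the update is a convex combination of h and the bounded target f (the sigmoid gate scales the effective step size per unit), the state stays in [−1, 1] here; in practice one would replace the fixed-step Euler loop with an adaptive ODE solver and train θ, φ by backpropagation.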



¹ By classical RNNs, we mean the form of RNNs often considered in the neuroscience, physics, and cognitive-science literature, where the interactions between units are additive and the interaction strengths are represented by a matrix (McCulloch & Pitts, 1943; Sompolinsky et al., 1988; Elman, 1990; Vogels et al., 2005; Sussillo & Abbott, 2009; Song et al., 2016; Yang et al., 2019).

² In Chen et al. (2018), τ = 1 and x(t) = t.

³ The form τ ḣ = −h + W⁰_h ϕ_h(h) + U⁰_h x + b⁰_h is also popular in neuroscience models, where h can be interpreted as the internal voltage of a neuron and ϕ_h(h) as the output firing rate of the neuron; W⁰_{h,ij} is the synaptic strength between neuron j and neuron i (Sompolinsky et al., 1988).
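The firing-rate network in the last footnote can be simulated the same way. A minimal sketch (weight scale, gain, and solver are illustrative choices, not from this work) showing that, with weak coupling and no input, the leaky dynamics relax to a fixed point where the velocity vanishes:

```python
import numpy as np

rng = np.random.default_rng(1)
N, D = 10, 3
tau, dt = 1.0, 0.05

# Illustrative leaky firing-rate network: tau * dh/dt = -h + W tanh(h) + U x + b
W = rng.normal(0.0, 0.2 / np.sqrt(N), size=(N, N))  # weak coupling -> stable regime
U = rng.normal(0.0, 1.0 / np.sqrt(D), size=(N, D))
b = np.zeros(N)
x = np.zeros(D)                                     # no input: relax to a fixed point

h = rng.normal(size=N)
for _ in range(2000):                               # integrate to t = 100 * tau
    h = h + (dt / tau) * (-h + W @ np.tanh(h) + U @ x + b)

# At a fixed point the velocity -h + W tanh(h) + U x + b vanishes
v = -h + W @ np.tanh(h) + U @ x + b
print(np.max(np.abs(v)))  # close to 0
```

Initializing W with a larger gain would instead place the network in the unstable (chaotic) regime studied by Sompolinsky et al. (1988); the stable-versus-critical distinction is exactly what the leak term makes tunable for nODEs in Appendix A.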

