JUMPY RECURRENT NEURAL NETWORKS

Abstract

Recurrent neural networks (RNNs) can learn complex, long-range structure in time series data simply by predicting one point at a time. Because of this ability, they have enjoyed widespread adoption in commercial and academic contexts. Yet RNNs have a fundamental limitation: they represent time as a series of discrete, uniform time steps. As a result, they force a tradeoff between temporal resolution and the computational expense of predicting far into the future. To resolve this tension, we propose a Jumpy RNN model which does not predict state transitions over uniform intervals of time. Instead, it predicts a sequence of linear dynamics functions in latent space and intervals of time over which their predictions can be expected to be accurate. This structure enables our model to jump over long time intervals while retaining the ability to produce fine-grained or continuous-time predictions when necessary. In simple physics simulations, our model can skip over long spans of predictable motion and focus on key events such as collisions between two balls. On a set of physics tasks including coordinate and pixel observations of a small-scale billiards environment, our model matches the performance of a baseline RNN while using a fifth of the compute. On a real-world weather forecasting dataset, it makes more accurate predictions while using fewer sampling steps. When used for model-based planning, our method matches a baseline RNN while using half the compute.

1. INTRODUCTION

It is said that change happens slowly and then all at once. Billiard balls move across a table before colliding and changing trajectories; water molecules cool slowly and then undergo a rapid phase transition into ice; and economic systems enjoy periods of stability interspersed with abrupt market downturns. That is to say, many time series exhibit periods of relatively homogeneous change divided by important events. Despite this, recurrent neural networks (RNNs), popular for time series modeling, treat time as a series of uniform intervals, potentially wasting prediction resources on long spans of nearly constant change.

One reason for this is that standard RNNs are sequence models with no explicit notion of time: the amount of time represented by a single RNN update is implicitly set by the training data. For example, a model trained on sequences of daily average temperatures has an implicit time step of one day. For a fixed computational budget, this introduces a trade-off between fidelity and temporal range. A model trained at a resolution of one time step per minute would need over 10K iterations to make a prediction for one week in the future. At the other end of the spectrum, a one-week-resolution model could achieve this in a single step but could not provide information about the intervening days. Selecting a point on this spectrum is therefore a troublesome design decision.

In this work, we present Jumpy RNNs, a simple recurrent architecture that takes update steps at variable, data-dependent time scales while being able to provide dense predictions at intervening points. The core innovation is to define the hidden state as a continuous, piecewise-linear function of time. Specifically, each Jumpy RNN step predicts not only a hidden state h_i, but also a hidden velocity ḣ_i and a span of time ∆ over which the linear latent dynamics h(t) = h_i + ḣ_i(t − τ_i) should be applied, where τ_i is the time of update i. Our model then jumps forward in time by ∆ before updating again.
Any intermediate time step can be produced by decoding the corresponding hidden state h(t). During training, our model learns to use these linear functions to span the non-uniform time durations between key events, where key events emerge as time points at which linear latent extrapolation is ineffective. In Figure 1, for example, we see that our model updates at the collision points between the two balls and the walls. During time spans when the balls are undergoing constant motion, our model performs no cell updates. In contrast, a standard RNN must tick uniformly through time.

We demonstrate our proposed model on several physical dynamics prediction tasks. We show that Jumpy RNNs achieve comparable performance to the baseline while being between three and twenty times more efficient to sample. This includes settings with non-linear, pixel-based observations. Further, we show that our model outperforms RNNs with any fixed step length, demonstrating the importance of data-dependent step sizes. Finally, we demonstrate that a learned Jumpy RNN dynamics model can be leveraged as an efficient forward predictor in a planning domain. Our key contributions are to:

- Identify a trade-off between temporal resolution and the computational expense of RNNs,
- Propose Jumpy RNNs, which make jumpy predictions and interpolate between them,
- Show empirically that Jumpy RNNs are efficient and effective at jumpy time series prediction.

2. JUMPY RECURRENT NEURAL NETWORKS

Consider a continuous-time function x(t) sampled at uniform time steps to form the sequence x_0, x_1, . . . , x_T. We study the problem of generative modeling: given an initial set of observations, the goal is to auto-regressively predict a likely continuation.

Standard RNN. The per-step RNN computation during auto-regressive sequence generation is

h_t = RNNCell(h_{t-1}, φ(x_t)),    x̂_{t+1} = f(h_t),

where φ(·) and f(·) are the observation encoder and decoder respectively, and h_t is the hidden state. This computation is performed at every time step, even when the dynamics of x(t) are simple (or even constant) for long stretches of time. One way to linearly reduce the computation required to predict into the future is to increase the time span ∆ between RNN ticks. Standard RNNs, however, then lose the ability to predict at clock times in between the ∆-spaced time steps. This introduces a trade-off between predictive resolution and the computational cost of predicting far into the future.
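The per-step computation above can be sketched as a minimal NumPy rollout. This is an illustrative toy, not the paper's implementation: it assumes a vanilla tanh cell and linear encoder/decoder, and all parameter names (W_enc, W_hh, W_xh, W_dec) are our own.

```python
import numpy as np

rng = np.random.default_rng(0)
obs_dim, hidden_dim = 2, 16

# Randomly initialized parameters (illustrative; a real model would train these).
W_enc = rng.normal(0, 0.1, (hidden_dim, obs_dim))     # encoder: phi(x) = W_enc @ x
W_hh  = rng.normal(0, 0.1, (hidden_dim, hidden_dim))  # recurrent weights
W_xh  = rng.normal(0, 0.1, (hidden_dim, hidden_dim))  # input weights
W_dec = rng.normal(0, 0.1, (obs_dim, hidden_dim))     # decoder: f(h) = W_dec @ h

def rnn_cell(h_prev, x_enc):
    """h_t = tanh(W_hh h_{t-1} + W_xh phi(x_t)) -- a vanilla RNN cell."""
    return np.tanh(W_hh @ h_prev + W_xh @ x_enc)

def rollout(x0, h0, n_steps):
    """One cell tick per time step: cost grows linearly with the horizon."""
    x, h, preds = x0, h0, []
    for _ in range(n_steps):
        h = rnn_cell(h, W_enc @ x)   # h_t = RNNCell(h_{t-1}, phi(x_t))
        x = W_dec @ h                # x_hat_{t+1} = f(h_t)
        preds.append(x)
    return np.stack(preds)

preds = rollout(np.ones(obs_dim), np.zeros(hidden_dim), n_steps=10)
print(preds.shape)  # (10, 2): ten cell updates are needed for ten predicted steps
```

The point of the sketch is the cost structure: predicting k steps ahead requires k cell updates, regardless of how simple the intervening dynamics are.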

2.1. JUMPY RNN ARCHITECTURE

Continuous Hidden Dynamics with Constant Jumps. Our first step toward resolving the trade-off is to upgrade the standard RNN so that it can learn to linearly interpolate a continuous-time hidden state h(t) between updates. Let ∆ be the time between RNN ticks, so that RNN tick i (starting at i = 0) corresponds to the continuous time point τ_i = i∆. At update i, the RNN predicts both a hidden state h_i and a hidden velocity ḣ_i that describes how the hidden state h(t) evolves over the time interval [τ_i, τ_i + ∆]. Specifically, the operation of this linear-dynamics RNN with constant jumps is given by:

h_i, ḣ_i = RNNCell( [h_{i-1}, ḣ_{i-1}∆], φ(x(τ_i)) )    (1)

h(t) = h_i + (t − τ_i) ḣ_i    for t ∈ [τ_i, τ_i + ∆]    (2)

x̂(t) = f(h(t)),

where [·, ·] denotes concatenation and x̂(t) is the continuous-time output prediction, which can be produced on demand for any time in between ticks. Under this model, the hidden state
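Equations (1) and (2) can be sketched concretely in NumPy. As before this is a hedged toy, not the paper's code: the cell is a vanilla tanh map over the concatenated state [h, ḣ∆], the encoder/decoder are linear, and all names are illustrative. The key property shown is that one cell update yields dense predictions across the whole interval [τ_i, τ_i + ∆].

```python
import numpy as np

rng = np.random.default_rng(1)
obs_dim, hidden_dim = 2, 16
delta = 5.0  # constant jump length between RNN ticks

W_enc = rng.normal(0, 0.1, (hidden_dim, obs_dim))           # encoder phi
W_hh  = rng.normal(0, 0.1, (2 * hidden_dim, 2 * hidden_dim))  # recurrent weights
W_xh  = rng.normal(0, 0.1, (2 * hidden_dim, hidden_dim))      # input weights
W_dec = rng.normal(0, 0.1, (obs_dim, hidden_dim))           # decoder f

def jumpy_cell(h, h_dot, x):
    """Eq. (1): [h_i, hdot_i] = tanh(W_hh [h_{i-1}, hdot_{i-1}*delta] + W_xh phi(x))."""
    state = np.concatenate([h, h_dot * delta])
    new = np.tanh(W_hh @ state + W_xh @ (W_enc @ x))
    return new[:hidden_dim], new[hidden_dim:]

def query(h, h_dot, tau_i, t):
    """Eq. (2) + decode: x_hat(t) = f(h_i + (t - tau_i) * hdot_i)."""
    return W_dec @ (h + (t - tau_i) * h_dot)

# A single tick covers the entire interval [tau_i, tau_i + delta];
# intermediate predictions are decoded on demand, with no extra cell updates.
h, h_dot = jumpy_cell(np.zeros(hidden_dim), np.zeros(hidden_dim), np.ones(obs_dim))
dense = np.stack([query(h, h_dot, tau_i=0.0, t=t)
                  for t in np.linspace(0.0, delta, 6)])
print(dense.shape)  # (6, 2): six intermediate predictions from one cell update
```

Because h(t) is linear in t on each interval, consecutive equally spaced queries differ by a constant increment, which is exactly the piecewise-linear latent structure the model exploits.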



Figure 1: Predicting the dynamics of two billiard balls (left) using a baseline RNN cell (center) and a Jumpy RNN cell (right). Whereas the baseline model produces a hidden state h_t at each time step, our jumpy model predicts a continuous-time hidden state over a predicted interval ∆_i. This allows it to skip over long spans of predictable motion and focus on key events such as collisions.

