JUMPY RECURRENT NEURAL NETWORKS

Abstract

Recurrent neural networks (RNNs) can learn complex, long-range structure in time series data simply by predicting one point at a time. Because of this ability, they have enjoyed widespread adoption in commercial and academic contexts. Yet RNNs have a fundamental limitation: they represent time as a series of discrete, uniform steps. As a result, they force a tradeoff between temporal resolution and the computational expense of predicting far into the future. To resolve this tension, we propose a Jumpy RNN model that does not predict state transitions over uniform intervals of time. Instead, it predicts a sequence of linear dynamics functions in latent space, along with the intervals of time over which their predictions can be expected to be accurate. This structure enables our model to jump over long time intervals while retaining the ability to produce fine-grained or continuous-time predictions when necessary. In simple physics simulations, our model can skip over long spans of predictable motion and focus on key events such as collisions between two balls. On a set of physics tasks including coordinate and pixel observations of a small-scale billiards environment, our model matches the performance of a baseline RNN while using a fifth of the compute. On a real-world weather forecasting dataset, it makes more accurate predictions while using fewer sampling steps. When used for model-based planning, our method matches a baseline RNN while using half the compute.

1. INTRODUCTION

It is said that change happens slowly and then all at once. Billiard balls move across a table before colliding and changing trajectories; water molecules cool slowly and then undergo a rapid phase transition into ice; and economic systems enjoy periods of stability interspersed with abrupt market downturns. That is to say, many time series exhibit periods of relatively homogeneous change divided by important events. Despite this, recurrent neural networks (RNNs), popular for time series modeling, treat time in uniform intervals, potentially wasting prediction resources on long intervals of relatively constant change.

One reason for this is that standard RNNs are sequence models without an explicit notion of time. Instead, the amount of time represented by a single RNN update is implicitly set by the training data. For example, a model trained on sequences of daily average temperatures has an implicit time step of one day. For a fixed computational budget, this introduces a trade-off between fidelity and temporal range. A model trained at a resolution of one time step per minute would require over 10K iterations (7 × 24 × 60 = 10,080) to make a prediction one week into the future. At the other end of the spectrum, a one-week-resolution model could achieve this in a single step, but could not provide information about the intervening days. Selecting a point on this spectrum is therefore a troublesome design decision.

In this work, we present Jumpy RNNs, a simple recurrent architecture that takes update steps at variable, data-dependent time scales while still being able to provide dense predictions at intervening points. The core innovation is to define the hidden state as a continuous, piece-wise linear function of time. Specifically, each Jumpy RNN step predicts not only a hidden state h_i, but also a hidden velocity ḣ_i and a span of time ∆_i over which the linear latent dynamics h(t) = h_i + ḣ_i (t − t_i) should be applied, where t_i is the time of the i-th update. Our model then jumps forward in time by ∆_i before updating again.
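The variable-step rollout described above can be sketched in a few lines. The sketch below uses random weights and illustrative names (`jumpy_cell`, `W_h`, `W_v`, `W_d` are assumptions, not the paper's actual parameterization); the point is only the control flow: each step emits a state, a velocity, and a span ∆, and the clock advances by ∆ rather than by one fixed tick.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy "cell": maps the current hidden state to a new state h_i, a latent
# velocity ḣ_i, and a positive time span ∆_i. The weights here are random
# placeholders; in the real model they would be learned.
H = 4
W_h, W_v, W_d = (rng.normal(size=(H, H)) * 0.1 for _ in range(3))

def jumpy_cell(h):
    h_new = np.tanh(W_h @ h)        # next hidden state h_i
    h_dot = np.tanh(W_v @ h)        # latent velocity  ḣ_i
    delta = np.exp(W_d @ h).mean()  # time span ∆_i (exp keeps it positive)
    return h_new, h_dot, delta

# Roll forward: each step jumps ∆_i units of time instead of one fixed
# tick, so a handful of steps can cover a long horizon.
t, h = 0.0, rng.normal(size=H)
segments = []                       # one (t_i, ∆_i, h_i, ḣ_i) per step
while t < 10.0:
    h, h_dot, delta = jumpy_cell(h)
    segments.append((t, delta, h, h_dot))
    t += delta                      # jump forward by ∆_i
```

Note that the number of updates needed to reach the horizon depends on the predicted spans, not on a fixed sampling rate.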
A prediction for any intermediate time can be produced by decoding the corresponding hidden state h(t). During training, our model learns to use these linear functions to span the non-uniform time durations between key events, where key events emerge as the time points at which linear latent extrapolation is no longer accurate.
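Decoding at an intermediate time amounts to evaluating the piece-wise linear latent function on the segment that contains the query time. The following self-contained sketch (with a hand-built, hypothetical `segments` list standing in for a rollout's output) shows that evaluation:

```python
import numpy as np

# Hypothetical rollout output: each tuple is (t_i, ∆_i, h_i, ḣ_i), i.e. the
# start time, span, hidden state, and latent velocity of one linear segment.
segments = [
    (0.0, 2.0, np.array([0.0, 1.0]), np.array([0.5, -0.5])),
    (2.0, 3.0, np.array([1.0, 0.0]), np.array([0.0,  0.2])),
]

def hidden_at(t, segments):
    """Piece-wise linear latent state: h(t) = h_i + ḣ_i (t − t_i),
    evaluated on the segment whose interval [t_i, t_i + ∆_i) contains t."""
    for t_i, delta_i, h_i, h_dot_i in segments:
        if t_i <= t < t_i + delta_i:
            return h_i + h_dot_i * (t - t_i)
    t_i, _, h_i, h_dot_i = segments[-1]  # past the end: extrapolate last segment
    return h_i + h_dot_i * (t - t_i)

hidden_at(1.0, segments)  # → array([0.5, 0.5])
```

In the full model, this intermediate hidden state would then be passed through a decoder to produce an observation at time t; since h(t) is defined for every t in the covered interval, predictions can be made at arbitrary, even continuous, temporal resolution.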

