How to Train Your HiPPO: State Space Models with Generalized Orthogonal Basis Projections

Abstract

Linear time-invariant state space models (SSMs) are a classical model class from engineering and statistics that has recently shown great promise in machine learning through the Structured State Space sequence model (S4). A core component of S4 is initializing the SSM state matrix to a particular matrix called a HiPPO matrix, which was empirically important for S4's ability to handle long sequences. However, the specific matrix that S4 uses was actually derived in previous work for a particular time-varying dynamical system, and the use of this matrix as a time-invariant SSM had no known mathematical interpretation. Consequently, the theoretical mechanism by which S4 models long-range dependencies actually remains unexplained. We derive a more general and intuitive formulation of the HiPPO framework, which provides a simple mathematical interpretation of S4 as a decomposition onto exponentially-warped Legendre polynomials, explaining its ability to capture long-range dependencies. Our generalization introduces a theoretically rich class of SSMs that also lets us derive more intuitive S4 variants for other bases such as the Fourier basis, and explains other aspects of training S4, such as how to initialize the important timescale parameter. These insights improve S4's performance to 86% on the Long Range Arena benchmark, with 96% on the most difficult Path-X task.

1. Introduction

The Structured State Space model (S4) is a recent deep learning model based on continuous-time dynamical systems that has shown promise on a wide variety of sequence modeling tasks (Gu et al., 2022a). It is defined as a particular linear time-invariant (LTI) state space model (SSM), which gives it multiple properties (Gu et al., 2021): as an SSM, S4 can be simulated as a discrete-time recurrence for efficiency in online or autoregressive settings, and as an LTI model, S4 can be converted into a convolution for parallelizability and computational efficiency at training time. These properties give S4 remarkable computational efficiency and performance, especially when modeling continuous signal data and long sequences.

Despite its potential, several aspects of the S4 model remain poorly understood. Most notably, Gu et al. (2022a) claim that the long-range abilities of S4 arise from instantiating it with a particular "HiPPO matrix" (Gu et al., 2020). However, this matrix was actually derived in prior work for a different (time-varying) setting, and the use of this matrix in S4 (a time-invariant SSM) did not have a mathematical interpretation. Consequently, the mechanism by which S4 truly models long-range dependencies is actually not known. Beyond this initialization, several other aspects of parameterizing and training S4 remain poorly understood. For example, S4 involves an important timescale parameter ∆, and its authors suggest a method for parameterizing and initializing it, but do not discuss its meaning or provide a justification.

This work aims to provide a comprehensive theoretical exposition of several aspects of S4. The major contribution of this work is a cleaner, more intuitive, and much more general formulation of the HiPPO framework. This result directly generalizes all previously known results in this line of work (Voelker et al., 2019; Gu et al., 2020; 2021; 2022a).
As immediate consequences of this framework:

• We prove a theoretical interpretation of S4's state matrix A, explaining S4's ability to capture long-range dependencies by decomposing the input with respect to an infinitely long, exponentially-decaying measure.

• We derive new HiPPO matrices and corresponding S4 variants that generalize to other well-behaved basis functions. For example, our new method S4-FouT produces truncated Fourier basis functions. This method thus automatically captures sliding Fourier transforms (e.g., the STFT and spectrograms), which are ubiquitous as a hand-crafted signal processing tool, and can also represent any local convolution, thus generalizing conventional CNNs.

• We provide an intuitive explanation of the timescale ∆, which has a precise interpretation as controlling the length of dependencies that the model captures. Our framework makes it transparent how to initialize ∆ for a given task, as well as how to initialize the other parameters (in particular, the final SSM parameter C) to make a deep SSM variance-preserving and stable.

Empirically, we validate our theory on synthetic function reconstruction and memorization tasks, showing that the empirical performance of state space models in several settings is predicted by the theory. For example, our new S4-FouT method, which can provably encode a spike function as its convolution kernel, performs best on a continuous memorization task compared to other SSMs and other models, when ∆ is initialized correctly. Finally, we show that the original S4 method is still best on very long range dependencies, achieving a new state of the art of 86% average on Long Range Arena, with 96% on the most difficult Path-X task that even the other SSM variants struggle with.
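One way to see the timescale interpretation of ∆ is a short change-of-variables argument; the following is a sketch consistent with the SSM definitions used here, not the paper's full derivation:

```latex
% Time dilation: if x'(t) = A x(t) + B u(t), define \tilde{x}(t) := x(\Delta t).
% By the chain rule,
\tilde{x}'(t) = \Delta\, x'(\Delta t) = (\Delta A)\,\tilde{x}(t) + (\Delta B)\, u(\Delta t),
% so scaling (A, B) to (\Delta A, \Delta B) replays the same dynamics at speed \Delta.
% Correspondingly, the convolution kernel K(t) = C e^{tA} B dilates:
K_{\Delta A,\,\Delta B}(t) = C\, e^{t \Delta A} (\Delta B) = \Delta\, K_{A,B}(\Delta t),
% i.e. a smaller \Delta stretches the same kernel over a proportionally longer horizon,
% matching the interpretation of \Delta as controlling the length of dependencies.
```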

2. Framework

We present our improved framework for state space models and online reconstruction of signals. Section 2.1 discusses background on SSMs, including their connection to convolutions for time-invariant systems. Section 2.2 defines new subclasses of SSMs with special properties that can be used for online function reconstruction, simplifying and generalizing the original HiPPO framework. An extended background and related work section can be found in Appendix A.

2.1. State Space Models: A Continuous-time Latent State Model

The state space model (SSM) is defined by the differential equations (1) and (2). It maps a 1-D input signal $u(t)$ to an $N$-D latent state $x(t)$ before projecting to a 1-D output signal $y(t)$:

$$x'(t) = A(t)x(t) + B(t)u(t) \tag{1}$$
$$y(t) = C(t)x(t) + D(t)u(t) \tag{2}$$
$$K(t) = Ce^{tA}B, \qquad y(t) = (K * u)(t) \tag{3}$$

We will generally assume $D = 0 \in \mathbb{R}$ and omit it for simplicity, unless explicitly mentioned. SSMs can in general have dynamics that change over time, i.e. the matrix $A \in \mathbb{R}^{N \times N}$ and vectors $B \in \mathbb{R}^{N \times 1}$, $C \in \mathbb{R}^{1 \times N}$ are functions of $t$ in (1) and (2). However, when they are constant, the system is linear time-invariant (LTI) and is equivalent to the convolutional system (3). The function $K(t)$ is called the impulse response, which can also be defined as the output of the system when the input $u(t) = \delta(t)$ is the impulse or Dirac delta function. We will call these time-invariant state space models (TSSMs). These are particularly important because the equivalence to a convolution makes TSSMs parallelizable and very fast to compute, which is critical for S4's efficiency.

Our treatment of SSMs will consider the $(A, B)$ parameters separately from $C$. We will refer to an SSM as either the tuple $(A, B, C)$ (referring to (3)) or $(A, B)$ (referring to Definition 1) when the context is unambiguous. We also drop the T in TSSM when the context is clearly time-invariant.

Definition 1. Given a TSSM $(A, B)$, $e^{tA}B$ is a vector of $N$ functions which we call the SSM basis. The individual basis functions are denoted $K_n(t) = e_n^\top e^{tA} B$, which satisfy

$$x_n(t) = (u * K_n)(t) = \int_{-\infty}^{t} K_n(t-s)\, u(s)\, ds.$$

Here $e_n$ is the one-hot basis vector. This definition is motivated by noting that the SSM convolution kernel is a linear combination of the SSM basis controlled by the coefficient vector $C$:

$$K(t) = \sum_{n=0}^{N-1} C_n K_n(t).$$
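The recurrent/convolutional equivalence above can be checked numerically. The sketch below (not the paper's code) uses a small random stable system, not a HiPPO matrix, and compares a zero-order-hold discretization of (1)-(2) against a direct convolution with sampled kernel $K(t) = Ce^{tA}B$; the two outputs should agree up to discretization error.

```python
import numpy as np
from scipy.linalg import expm

# Illustrative TSSM (A, B, C): random stable dynamics, NOT a HiPPO matrix.
rng = np.random.default_rng(0)
N = 4
A = -np.eye(N) + 0.1 * rng.standard_normal((N, N))  # eigenvalues near -1
B = rng.standard_normal((N, 1))
C = rng.standard_normal((1, N))

dt = 1e-3                   # step size (plays the role of the timescale Delta)
T = 2000
t = np.arange(T) * dt
u = np.sin(2 * np.pi * t)   # example 1-D input signal

# Recurrent view: exact zero-order-hold discretization of x' = Ax + Bu.
Ad = expm(A * dt)
Bd = np.linalg.solve(A, (Ad - np.eye(N)) @ B)
x = np.zeros((N, 1))
y_rec = np.zeros(T)
for k in range(T):
    x = Ad @ x + Bd * u[k]
    y_rec[k] = (C @ x).item()

# Convolutional view: sample the kernel K(t) = C e^{tA} B and convolve.
K = np.zeros(T)
EB = B.copy()               # holds e^{(k*dt)A} B, advanced by one Ad per step
for k in range(T):
    K[k] = (C @ EB).item()
    EB = Ad @ EB
y_conv = np.convolve(u, K)[:T] * dt

print("max |y_rec - y_conv| =", np.abs(y_rec - y_conv).max())
```

The rows of the matrix stacked from `EB` over time are exactly the SSM basis functions $K_n(t)$ of Definition 1; `K` is their linear combination with coefficients `C`.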

