LONG RANGE LANGUAGE MODELING VIA GATED STATE SPACES

Abstract

State space models have been shown to be effective at modeling long range dependencies, especially on sequence classification tasks. In this work we focus on autoregressive sequence modeling over English books, Github source code and ArXiv mathematics articles. Based on recent developments around the effectiveness of gated activation functions, we propose a new layer named Gated State Space (GSS) and show that it trains significantly faster than the diagonal version of S4 (i.e. DSS) on TPUs, is competitive with several well-tuned Transformer-based baselines, and exhibits zero-shot generalization to longer inputs, all while being straightforward to implement. Finally, we show that leveraging self-attention to model local dependencies improves the performance of GSS even further.

1. INTRODUCTION

Modeling long range dependencies in sequential data is a crucial step towards closing the gap with human-level performance on many tasks. Attention-based models like the Transformer (Vaswani et al., 2017) have proven to be a strong choice of backbone architecture for a considerable number of tasks across modalities and scales (Devlin et al., 2019; Brown et al., 2020; Dosovitskiy et al., 2021). Vanilla multi-head attention famously incurs an Ω(L²) penalty when modeling a sequence of length L. This is prohibitive at best for tasks where the model is required to capture long range dependencies from various parts of the input. Over the years, a variety of improvements have been proposed to alleviate this quadratic complexity (Tay et al., 2020; Choromanski et al., 2021; Ramsauer et al., 2021; Wang et al., 2020; Katharopoulos et al., 2020; Peng et al., 2021; Ainslie et al., 2020; Zaheer et al., 2020; Beltagy et al., 2020; Dai et al., 2020; Kitaev et al., 2020; Vyas et al., 2020). In a somewhat orthogonal direction, attention-free models based on state spaces, such as S4 (Gu et al., 2022a) and DSS (Gupta et al., 2022), have shown remarkable improvements on Long Range Arena (LRA) (Tay et al., 2021), a benchmark designed with long range modeling as its focus, consisting of diverse tasks across modalities with sequence lengths of 1k-16k. These models require careful initialization, originally borrowing ideas from the theory of HiPPO matrices (Voelker et al., 2019; Gu et al., 2020), to achieve good results on LRA. In this work, we explore and extend the use of state space models by focusing solely on the task of autoregressive sequence modeling (Brown et al., 2020; Rae et al., 2021; Chowdhery et al., 2022; Zhang et al., 2022; Hoffmann et al., 2022; Srivastava et al., 2022). Several key properties of the state space model family make it particularly attractive in the context of language modeling.
First, it reduces the Ω(L²) complexity in the input sequence length to O(L log L). This complexity results from the use of the Fast Fourier Transform (FFT) (Cooley & Tukey, 1965) for performing convolutions, which will be described in detail in later sections. Second, the state space model is fully parallelizable along the length dimension. This is an arguably subtle but important property at training time. Note that Transformers are also fully parallelizable, a worthy advantage over traditional RNNs, which incur only an O(L) cost but must process the sequence serially. While this parallelism is useful at training time, it may also be a curse at inference time, where decoding every token requires attending to the whole past. The ideal model is parallelizable at training time but incurs a small constant cost (per decoded token) at inference time. This brings us to the final point. Due to the inherent convolution-recurrence equivalence of the state space model, it can be made to accumulate state and unroll like an RNN at inference time without any approximations. Despite these attractive properties, we found that current state space models (e.g. S4, DSS) run slower than we expected on TPUs, our accelerator of choice. We take this opportunity to modify the architecture to reduce the dimensionality of specific bottleneck operations. Our proposed changes borrow from the well-supported empirical success of gating units (Shazeer, 2020). Specifically, Hua et al. (2022) observed that replacing the feed-forward layer in the Transformer with gating units allows for a reduced dimensionality when mixing tokens along the length dimension using self-attention. We extend the use of gating units to the state space model family and observe that, even in our context, they allow for a reduction in dimensionality when performing FFT operations, which we observed to be the main bottleneck behind slow training.
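The two views mentioned above — an O(L log L) FFT-based convolution and an exact RNN-style recurrence — can be checked numerically on a toy model. The sketch below is our own illustration (not the paper's code; all values are random placeholders): it builds a diagonal linear state space model, computes its output once via FFT convolution and once by unrolling the recurrence, and verifies the two agree.

```python
import numpy as np

rng = np.random.default_rng(0)
L, N = 512, 16  # sequence length, state dimension

# A random diagonal state space model (illustrative values only):
#   x_k = Lam * x_{k-1} + B * u_k,   y_k = C . x_k
Lam = rng.uniform(0.5, 0.95, N)   # diagonal of the state matrix
B = rng.normal(size=N)
C = rng.normal(size=N)
u = rng.normal(size=L)            # input sequence

# Convolution view: y = K * u with kernel K_m = sum_n C_n * Lam_n^m * B_n
K = (C * B) @ (Lam[:, None] ** np.arange(L))  # shape (L,)

# O(L log L) causal convolution via FFT; zero-pad to 2L to avoid wrap-around
y_fft = np.fft.irfft(np.fft.rfft(K, 2 * L) * np.fft.rfft(u, 2 * L), 2 * L)[:L]

# Recurrent view: unroll the same linear system one step at a time
x = np.zeros(N)
y_rec = np.empty(L)
for k in range(L):
    x = Lam * x + B * u[k]
    y_rec[k] = C @ x

# Both views compute the same sequence, up to floating point error
assert np.allclose(y_fft, y_rec)
```

The recurrent form is what makes constant-cost-per-token decoding possible: at inference time one only carries the state vector `x` forward, never the full history.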
Furthermore, somewhat contrary to observations made by the S4 and DSS authors, we found performance on language modeling tasks to be much less sensitive to initialization: only the scale and structural aspects of the initialization of the state space variables were important, not the exact values. We were able to successfully train the model while initializing the state space variables randomly. This departs from the design's reliance on the theory of HiPPO matrices, which led the S4 model to employ several numerical linear algebra tricks to make it work. Combining both of these contributions, we propose a layer named Gated State Space (GSS) (Figure 1), which we empirically verified to be 2-3× faster than DSS while matching its perplexity on several language modeling benchmarks (Table 1). Going one step further, we also perform a comparison with the well-tuned and performant baselines reported in Block Recurrent Transformers (Hutchins et al., 2022) on several long range language modeling benchmarks over modalities such as English books, raw source code from Github, and LaTeX source of ArXiv mathematics articles. As detailed in Table 2, while our GSS model currently lags behind on some tasks when compared in the fixed-parameter setting, it is competitive in the fixed-compute setting, where we measure compute as the number of TPUv4 hours spent on training, a good proxy for the cost of training the model. Furthermore, we also experimented with a hybrid model in which we sparingly interleave Transformer layers (with local attention) in a GSS stack to allow for richer modeling of short range interactions. To our delight, this further improves performance at (roughly) no extra training cost, both in terms of parameters and compute. In our experiments we train on sequences of length at most 4k, but evaluate on a wide range of sequence lengths up to 65k.
Performance actually improves as the sequence length is increased, suggesting that GSS utilizes the extra context despite not being trained with it. Further, at inference time, state space models including GSS are quite efficient since decoding can happen in recurrent mode (as much as 60× better in the case of S4 (Gu et al., 2022a)). The hybrid model, which also uses local attention, complicates this advantage somewhat.
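As a rough illustration of how gating can shrink the width at which the expensive FFT mixing runs, here is a minimal numpy sketch of a GSS-style block. All names (`gss_block`, `Wv`, `Wx`, ...), shapes, and the random parameters are our own illustrative assumptions, not the paper's implementation (see Figure 1 and §A.2 for that): a gating branch of width H multiplies an SSM branch whose FFT-based token mixing runs at a reduced width F ≪ H.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def gss_block(u, p):
    """Sketch of a gated state space block. u: (L, E) token embeddings."""
    L = u.shape[0]
    v = gelu(u @ p["Wv"])  # gating branch, width H
    x = gelu(u @ p["Wx"])  # SSM branch at reduced width F << H
    # Per-channel causal convolution kernels K_m[f] = sum_n C[f,n] Lam[f,n]^m B[f,n]
    K = ((p["C"] * p["B"])[None] * p["Lam"][None] ** np.arange(L)[:, None, None]).sum(-1)
    # FFT token mixing runs only at width F — the bottleneck the gating shrinks
    y = np.fft.irfft(np.fft.rfft(K, 2 * L, axis=0) * np.fft.rfft(x, 2 * L, axis=0),
                     2 * L, axis=0)[:L]
    y = y @ p["Wy"]               # project back up to width H
    return (v * y) @ p["Wo"] + u  # gate, project to E, residual

# Demo with illustrative shapes and random (hypothetical) parameters
rng = np.random.default_rng(0)
L, E, H, F, N = 128, 32, 64, 8, 16
p = {
    "Wv": rng.normal(size=(E, H)) / np.sqrt(E),
    "Wx": rng.normal(size=(E, F)) / np.sqrt(E),
    "Lam": rng.uniform(0.5, 0.95, (F, N)),
    "B": rng.normal(size=(F, N)),
    "C": rng.normal(size=(F, N)),
    "Wy": rng.normal(size=(F, H)) / np.sqrt(F),
    "Wo": rng.normal(size=(H, E)) / np.sqrt(H),
}
u = rng.normal(size=(L, E))
out = gss_block(u, p)
assert out.shape == (L, E)
```

Because the convolution is causal, position k of the output depends only on inputs up to k, so the same block can be unrolled recurrently at decode time; the gating and output projections are purely pointwise and do not affect this.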



Figure 1: (a) Our proposed Gated State Space (GSS) layer, (b) Pseudocode for GSS (full implementation in §A.2).

