LONG RANGE LANGUAGE MODELING VIA GATED STATE SPACES

Abstract

State space models have been shown to be effective at modeling long range dependencies, especially on sequence classification tasks. In this work we focus on autoregressive sequence modeling over English books, GitHub source code and arXiv mathematics articles. Based on recent developments around the effectiveness of gated activation functions, we propose a new layer named Gated State Space (GSS) and show that it trains significantly faster than the diagonal version of S4 (i.e. DSS) on TPUs, is competitive with several well-tuned Transformer-based baselines, and exhibits zero-shot generalization to longer inputs, while being straightforward to implement. Finally, we show that leveraging self-attention to model local dependencies improves the performance of GSS even further.

1. INTRODUCTION

Modeling long range dependencies in sequential data is a crucial step towards closing the gap with human-level performance on many tasks. Attention-based models like the Transformer (Vaswani et al., 2017) have proven to be a strong choice of backbone architecture for a considerable number of tasks across modalities and scales (Devlin et al., 2019; Brown et al., 2020; Dosovitskiy et al., 2021). Vanilla multi-head attention famously incurs an Ω(L^2) penalty when modeling a sequence of length L. This is prohibitive at best for tasks where the model is required to capture long range dependencies from various parts of the input. Over the years, a variety of improvements have been proposed to alleviate this quadratic complexity (Tay et al., 2020; Choromanski et al., 2021; Ramsauer et al., 2021; Wang et al., 2020; Katharopoulos et al., 2020; Peng et al., 2021; Ainslie et al., 2020; Zaheer et al., 2020; Beltagy et al., 2020; Dai et al., 2020; Kitaev et al., 2020; Vyas et al., 2020).

On a somewhat orthogonal direction, attention-free models based on state spaces, such as S4 (Gu et al., 2022a) and DSS (Gupta et al., 2022), have shown remarkable improvements on Long Range Arena (LRA) (Tay et al., 2021), a benchmark designed with long range modeling as its focus, consisting of diverse tasks with 1k-16k sequence lengths across modalities. These models require careful initialization, originally borrowing ideas from the theory of HiPPO matrices (Voelker et al., 2019; Gu et al., 2020), to achieve good results on LRA.

In this work, we explore and extend the use of state space models by focusing solely on the task of autoregressive sequence modeling (Brown et al., 2020; Rae et al., 2021; Chowdhery et al., 2022; Zhang et al., 2022; Hoffmann et al., 2022; Srivastava et al., 2022). Several key properties endowed by the state space model family make it particularly attractive in the context of language modeling.
First, it reduces the Ω(L^2) complexity in input sequence length to O(L log L). This complexity results from the use of the Fast Fourier Transform (FFT) (Cooley & Tukey, 1965) for performing convolutions, which will be described in detail in later sections. Second, the state space model is fully parallelizable in the length dimension. This is an arguably subtle but important property at training time. Note
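To make the complexity argument concrete, the following is a minimal NumPy sketch (not the paper's implementation) of how a length-L causal convolution, which would cost O(L^2) if computed directly, can be evaluated in O(L log L) via the FFT by zero-padding both signals so that circular convolution coincides with the desired linear convolution:

```python
import numpy as np

def fft_conv(u, K):
    """Causal convolution y[t] = sum_{s<=t} K[s] * u[t-s] via FFT.

    Zero-padding both signals to length 2L makes the circular
    convolution computed by the FFT equal to the linear (causal)
    convolution on the first L outputs; total cost is O(L log L).
    """
    L = len(u)
    n = 2 * L
    y = np.fft.irfft(np.fft.rfft(u, n) * np.fft.rfft(K, n), n)
    return y[:L]

def direct_conv(u, K):
    """Direct O(L^2) reference implementation, for comparison."""
    L = len(u)
    return np.array(
        [sum(K[s] * u[t - s] for s in range(t + 1)) for t in range(L)]
    )

rng = np.random.default_rng(0)
u = rng.normal(size=256)
K = rng.normal(size=256)
assert np.allclose(fft_conv(u, K), direct_conv(u, K))
```

Here `u` stands in for the input sequence and `K` for the convolution kernel implied by the state space recurrence; the variable names are illustrative, not taken from the paper.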

