HUNGRY HUNGRY HIPPOS: TOWARDS LANGUAGE MODELING WITH STATE SPACE MODELS

Abstract

State space models (SSMs) have demonstrated state-of-the-art sequence modeling performance in some modalities, but underperform attention in language modeling. Moreover, despite scaling nearly linearly in sequence length instead of quadratically, SSMs are still slower than Transformers due to poor hardware utilization. In this paper, we make progress on understanding the expressivity gap between SSMs and attention in language modeling, and on reducing the hardware barrier between SSMs and attention. First, we use synthetic language modeling tasks to understand the gap between SSMs and attention. We find that existing SSMs struggle with two capabilities: recalling earlier tokens in the sequence and comparing tokens across the sequence. To understand the impact on language modeling, we propose a new SSM layer, H3, that is explicitly designed for these abilities. H3 matches attention on the synthetic languages and comes within 0.4 PPL of Transformers on OpenWebText. Furthermore, a hybrid 125M-parameter H3-attention model that retains two attention layers surprisingly outperforms Transformers on OpenWebText by 1.0 PPL. Next, to improve the efficiency of training SSMs on modern hardware, we propose FLASHCONV. FLASHCONV uses a fused block FFT algorithm to improve efficiency on sequences up to 8K, and introduces a novel state-passing algorithm that exploits the recurrent properties of SSMs to scale to longer sequences. FLASHCONV yields a 2× speedup on the long-range arena benchmark and allows hybrid language models to generate text 2.4× faster than Transformers. Using FLASHCONV, we scale hybrid H3-attention language models up to 2.7B parameters on the Pile and find promising initial results, achieving lower perplexity than Transformers and outperforming Transformers in zero- and few-shot learning on a majority of tasks in the SuperGLUE benchmark.

* Equal contribution. Order determined by coin flip.

1. INTRODUCTION

State space models (SSMs) have achieved state-of-the-art sequence modeling performance in domains ranging from time series analysis (Gu et al., 2022a) to audio generation (Goel et al., 2022). However, they have yet to match the performance of Transformers on language modeling, often underperforming Transformers by multiple points in perplexity (Gu et al., 2022a). A natural question is whether this gap in performance is due to inherent inductive biases and capabilities in attention (Edelman et al., 2022; Olsson et al., 2022), or whether it is a function of the significant organizational resources that have been spent training and tuning large attention-based language models (Chowdhery et al., 2022; Hoffmann et al., 2022; Zhang et al., 2022), as well as specialized hardware support for attention, ranging from tensor cores (NVIDIA, 2017) to transformer chips (NVIDIA, 2022b; Kao et al., 2021).

We take first steps towards answering these questions in this paper. First, we use synthetic language modeling tasks to show that there is an expressivity gap between SSMs and attention. Using our insights, we design a new SSM layer that nearly matches attention in language modeling. Second, we propose better hardware-aware algorithms for SSMs that allow them to take advantage of modern accelerators and run faster than attention.

Understanding the Expressivity Gap. To understand the gap between SSMs and attention, we draw on synthetic language modeling tasks that have been proposed as a mechanistic basis for in-context learning in Transformers (Olsson et al., 2022). These synthetic languages focus on the ability to manipulate text: recalling tokens from earlier time steps, or comparing tokens from different points in a sequence. We find that existing SSMs struggle to model these synthetic languages. To probe how important these skills are for language modeling, we propose H3 (Hungry Hungry Hippo), a new SSM-based layer designed to solve these language modeling tasks.
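To make the synthetics concrete, one of these tasks, associative recall, presents a sequence of key-value pairs followed by a query key, and the model must emit the value that followed that key earlier in the sequence. The sketch below generates such an example; the exact token formats in the literature differ, so treat this as an illustrative generator rather than the paper's benchmark:

```python
import random

def associative_recall_example(num_pairs=4, seed=0):
    """Illustrative associative-recall instance: interleaved key-value
    pairs, then a query key; the target is that key's earlier value."""
    rng = random.Random(seed)
    keys = rng.sample("abcdefghij", num_pairs)       # letter keys
    vals = [str(rng.randrange(10)) for _ in keys]    # digit values
    sequence = [tok for kv in zip(keys, vals) for tok in kv]
    query = rng.choice(keys)
    answer = vals[keys.index(query)]
    return sequence + [query], answer
```

Solving this requires exactly the two capabilities above: logging tokens (the key-value pairs) and comparing tokens (matching the query against earlier keys).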
H3 stacks two SSMs, with multiplicative interactions between their outputs and input projections. The SSMs allow H3 to keep a log of tokens (to recall them later), while the multiplicative interactions allow for comparisons across the sequence. H3 matches attention on the synthetic languages and almost closes the gap with Transformers on language modeling, coming within 0.4 perplexity of Transformers on OpenWebText (compared to 3.4 PPL for existing SSMs, even those explicitly designed for language modeling (Mehta et al., 2022)). Furthermore, a simple hybrid H3-attention model that retains two attention layers surprisingly outperforms Transformers on OpenWebText by 1.0 perplexity. To further evaluate H3 on language modeling, we train 125M-, 355M-, 1.3B-, and 2.7B-parameter hybrid H3-attention language models on the Pile (Gao et al., 2020), using hyperparameters from GPT-3 (Brown et al., 2020). These hybrid models outperform Transformer-based language models of the same size in perplexity, and match or outperform them on a majority of tasks in the SuperGLUE benchmark in zero- and few-shot learning. Since the SSM layers in these hybrid models admit a recurrent view, they can also perform inference 2.4× faster than Transformers.

Scaling SSMs. Next, we improve the efficiency of SSMs on modern hardware, to reduce the hardware barrier between attention and SSMs. SSMs scale nearly linearly in sequence length instead of quadratically like attention, but still run slower on modern hardware due to poor hardware utilization. To close this gap, we propose FLASHCONV, a hierarchical algorithm for computing SSMs, inspired by IO-aware attention (Dao et al., 2022b). The technical challenge is that SSMs require an FFT-based convolution over the input sequence, which requires an FFT, a pointwise multiply, and an inverse FFT.
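This FFT-based convolution (FFTConv) computes a length-N causal convolution in O(N log N) time by zero-padding to avoid circular wraparound; a minimal NumPy sketch (where `k` stands in for an SSM's convolution kernel):

```python
import numpy as np

def fftconv(u, k):
    """Causal convolution of input u with kernel k via FFT: zero-pad to
    2N so circular convolution equals linear convolution, multiply
    pointwise in the frequency domain, and invert."""
    n = len(u)
    m = 2 * n                          # padded length
    u_f = np.fft.rfft(u, m)           # FFT of the input
    k_f = np.fft.rfft(k, m)           # FFT of the kernel
    return np.fft.irfft(u_f * k_f, m)[:n]  # inverse FFT, keep first n
```

Each of the three steps reads and writes the full sequence, which is what makes a naive implementation memory-bound.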
When implemented in cuFFT (NVIDIA, 2022a), this operation incurs expensive GPU memory reads/writes, and cannot utilize the specialized matrix multiply units available on modern hardware [1]. To use specialized matrix multiply units, we appeal to classical techniques that split the FFT into blocks and compute it using a series of matrix multiplications. Combined with kernel fusion, this "block" FFT solution increases hardware efficiency, but only as long as the sequence length can fit into GPU SRAM (on-chip memory, analogous to the L1 cache on a CPU), which holds up to sequence length 8K on a modern A100. To scale to sequences longer than 8K, we propose a state-passing algorithm (Figure 1, right), specialized to SSMs. The key insight is that we can use the recurrent properties of SSMs to process the input in chunks, as long as we keep track of an additional state vector. The state-passing algorithm splits the input into the largest chunks that can fit into GPU SRAM, efficiently computes the FFT-based convolution using block FFT, and updates an intermediate state to start the next chunk. Using this state-passing algorithm, FLASHCONV can scale SSMs to any sequence length, even longer than can fit in GPU SRAM at once, while maintaining near-linear compute complexity. FLASHCONV sets state-of-the-art speed on the long-range arena benchmark using S4 (Gu et al., 2022a), outperforming Transformers by 5.8× and previous S4 models by 2×. FLASHCONV trains H3 4-8× faster than attention on long sequences, and is a critical component for scaling to billion-parameter models [2].
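The state-passing idea can be sketched for the simplest case, a scalar diagonal SSM x[t+1] = a·x[t] + b·u[t], y[t] = c·x[t+1]: each chunk is convolved against the truncated kernel k[t] = c·aᵗ·b (here via FFT), the carried state's decaying contribution is added, and the state is updated before moving on. This is a hedged illustration with made-up scalar parameters, not FLASHCONV's fused CUDA implementation:

```python
import numpy as np

def ssm_conv_state_passing(u, a=0.9, b=1.0, c=1.0, chunk=4):
    """State-passing sketch for the scalar diagonal SSM
    x[t+1] = a*x[t] + b*u[t], y[t] = c*x[t+1], with x[0] = 0."""
    n = len(u)
    y = np.empty(n)
    state = 0.0
    for s in range(0, n, chunk):
        blk = u[s:s + chunk]
        L = len(blk)
        k = c * a ** np.arange(L) * b                 # truncated SSM kernel
        m = 2 * L                                     # zero-pad for linear conv
        conv = np.fft.irfft(np.fft.rfft(blk, m) * np.fft.rfft(k, m), m)[:L]
        # The incoming state contributes c * a**(t+1) * state at local step t.
        y[s:s + L] = conv + c * state * a ** (np.arange(L) + 1)
        # Advance the state across the chunk for the next iteration.
        state = state * a ** L + np.sum(b * blk * a ** (L - 1 - np.arange(L)))
    return y
```

Because only the scalar (in general, a small vector) state crosses chunk boundaries, each chunk's FFT can stay entirely in SRAM regardless of total sequence length.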



[1] An A100 GPU has a maximum of 312 TFLOPs/s of FP16 with tensor cores, but only 19.5 TFLOPs/s of FP32 (and 40 TFLOPs/s of FP16) without tensor cores (NVIDIA, 2020). This trend started with the V100 GPUs (NVIDIA, 2017) and has continued with the H100 GPUs (NVIDIA, 2022b).

[2] Code for H3 is available at https://github.com/HazyResearch/H3.



Figure 1: Left: H3 stacks two discrete SSMs with shift and diagonal matrices and uses multiplicative interactions between input projections and their outputs to model comparisons between points in a sequence. Middle: H3 can perform associative recall, which is easy for attention but not for existing SSMs. Right: FLASHCONV uses a new state-passing algorithm over fused block FFTConv to increase the hardware efficiency of SSMs, allowing H3 to scale to billion-parameter models.
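The stacked-SSM computation in the left panel can be caricatured in a few lines of NumPy. In this hedged sketch, the shift SSM is a delay-by-one buffer, the diagonal SSM is an exponentially decayed running sum, and `Wq`, `Wk`, `Wv` are illustrative input projections; the paper's actual parameterization (learned state matrices, multiple heads) is richer:

```python
import numpy as np

def h3_layer_sketch(u, Wq, Wk, Wv, decay=0.9):
    """Caricature of an H3 block on one sequence u of shape (seq_len, d)."""
    q, k, v = u @ Wq, u @ Wk, u @ Wv
    # Shift SSM: the state holds the previous step's k (delay-by-one).
    k_shift = np.vstack([np.zeros_like(k[:1]), k[:-1]])
    # First multiplicative interaction: gate v by the shifted k.
    gated = k_shift * v
    # Diagonal SSM: decayed cumulative sum over time (diagonal recurrence).
    state = np.zeros_like(gated[0])
    out = np.empty_like(gated)
    for t in range(gated.shape[0]):
        state = decay * state + gated[t]
        out[t] = state
    # Second multiplicative interaction: compare the log against q.
    return q * out
```

The shift SSM supplies the "log of tokens" and the two elementwise products supply the cross-sequence comparisons that the synthetics demand.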

