MULTI-HEAD STATE SPACE MODEL FOR SEQUENCE MODELING

Abstract

Recently, state space models (SSMs) have shown promising results on sequence modeling tasks. However, a potential limitation of existing works is that SSMs are usually introduced or initialized in a homogeneous way, encouraging the model to capture only similar temporal dynamics across different features. In this paper, we propose a multi-head state space model (MSSM), in which parallel heads are introduced to learn different temporal dynamics on sequence data. Furthermore, we propose a novel variant of the Transformer, referred to as the Stateformer, which combines MSSMs with attention. Experiments on large-scale automatic speech recognition (ASR) and language modeling tasks show that the MSSM outperforms a range of attention-based baselines. The Stateformer further improves performance, achieving state-of-the-art results on the LibriSpeech ASR task.

1. INTRODUCTION

Transformers (Vaswani et al., 2017) and attention-based models (Bahdanau et al., 2015) have for more than half a decade shown state-of-the-art performance on a wide range of tasks, from speech recognition (Zhang et al., 2020a; Gulati et al., 2020) and neural machine translation (Ng et al., 2019; Chen et al., 2020; Tran et al., 2021) to computer vision (Dosovitskiy et al., 2021; Liu et al., 2021b) and biological applications such as protein sequence modeling (Jumper et al., 2021; Rives et al., 2021). One of the strengths of these models is their significantly smaller maximum path length across time, i.e. the shortest path linking the first encoder input and the final decoder output (Hochreiter et al., 2001). This contrasts with previously state-of-the-art recurrent and convolutional neural networks, whose path lengths scale linearly and logarithmically with sequence length, respectively. However, a well-known drawback of Transformers is the quadratic space and time complexity of the self-attention layer, restricting their applicability in fields requiring longer sequences or on devices with strict limitations on compute resources (Tay et al., 2022). To combat both of these issues, a wide range of restricted sparse attention mechanisms have been proposed, all aiming to reduce the computational cost and scale to longer sequences while retaining as much of the original Transformer performance as possible (Katharopoulos et al., 2020; Kitaev et al., 2020; Beltagy et al., 2020; Zaheer et al., 2020). A notable amount of effort has also been poured into designing alternative, more efficient attention schemes in place of the dot-product approach (Wang et al., 2020a; Choromanski et al., 2020; Shen et al., 2021; Zhang et al., 2019).
Meanwhile, the machine learning community is paying increasing attention to a well-established signal processing and control theory technique, the state space model (SSM) (Kalman, 1960), which has historically been widely used in many time-series and control problems (Brogan, 1991; Hyndman et al., 2008; Durbin & Koopman, 2012). More recently, the linear time-variant state space model has also been used within neural networks to improve time-series forecasting (Seeger et al., 2016; 2017; Rangapuram et al., 2018). Furthermore, a simplified linear time-invariant (LTI) SSM, which is closely related to fully linear recurrent neural network (RNN) layers, has a well-known convolution equivalent, making it particularly attractive for parallelized training, while inference can exploit its fast recurrent formulation (Brogan, 1991; Gu et al., 2021). One potential limitation of existing works in this direction is that SSMs are usually initialized and used in a homogeneous way. For instance, in Gu et al. (2022a), the S4 approach equips SSMs with careful options on parameter initialization for long-range modeling ability. This design forces the model to capture similar temporal dynamics across different features. Since both long- and short-term dependencies are useful in sequence modeling, we develop in this paper a multi-head state space model, which consists of parallel heads that learn different temporal dynamics on sequence data. We investigate the use of multi-head state space models in sequence tasks as both a replacement for and a complement to attention. Since SSMs have been shown to be promising on long sequences, we hypothesize that such a model can handle both short- and long-term dependencies and operate as an effective alternative to attention. Our main technical contributions around the state space model are as follows:

1. Stacked and multi-head generalization. We extend the SSM approach by allowing multi-head parallel processing of projected lower-dimensional signals and stack such layers (each within a residual block) for better performance.

2. Head gating. We propose an inter-head gating approach that allows the different SSMs within a multi-head layer to communicate.

3. Combination with attention. We augment the Transformer architecture by simply placing a bidirectional SSM residual block before the attention for computationally unrestricted applications and state-of-the-art performance; we refer to this model as the Stateformer.

With these contributions, we advance the state of attention-free models on large-scale speech recognition and language modeling tasks, and show that our novel multi-head state space model outperforms strong attention-based baselines. We also show that combining the multi-head state space model with attention achieves state-of-the-art performance in large-scale speech recognition.
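As a rough sketch of the multi-head idea (the head count, projection sizes, diagonal per-head recurrence, and the exact form of the inter-head gate below are illustrative assumptions, not the paper's actual specification), each head projects the input to a lower dimension, runs its own linear recurrence so it can learn its own temporal dynamics, and a shared sigmoid gate lets the heads communicate before their outputs are combined:

```python
import numpy as np

rng = np.random.default_rng(0)

def mssm_layer(u, heads):
    """Hypothetical multi-head SSM layer (illustrative sketch only).

    u: (T, D) input sequence. Each head h = (W_in, a, W_out) projects
    u to a lower dimension, runs a diagonal linear recurrence
    x_k = a * x_{k-1} + v_k (one SSM per head, so heads can capture
    different temporal dynamics), and projects back to D dimensions.
    An inter-head sigmoid gate, computed from the mean over heads,
    mixes the head outputs (a stand-in for the paper's head gating).
    """
    outs = []
    for (W_in, a, W_out) in heads:
        v = u @ W_in                     # down-projection, (T, Dh)
        x = np.zeros_like(v[0])
        ys = []
        for v_k in v:                    # linear RNN over time
            x = a * x + v_k
            ys.append(x)
        outs.append(np.stack(ys) @ W_out)  # up-projection, (T, D)
    H = np.stack(outs, axis=0)           # (n_heads, T, D)
    gate = 1.0 / (1.0 + np.exp(-H.mean(axis=0, keepdims=True)))
    return (gate * H).sum(axis=0)        # gated sum over heads, (T, D)
```

Two heads with decay factors 0.9 and 0.5, for example, would emphasize longer- and shorter-range dependencies respectively, which is the heterogeneity the multi-head design is after.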

2. BACKGROUND

This section covers theory and related work on the state space model and its use in deep learning. We first briefly cover the state space model and a block-diagonal structured restriction, and mention a convolutional point of view used for parallel training (Section 2.1). We finish with a brief discussion of the state space layer as used in neural networks (Section 2.2).

2.1. STATE SPACE MODEL: THE LINEAR RNN

A specific realization of the state space model (SSM) is the linear time-invariant (LTI) model (Brogan, 1991), which transforms an input signal u(t) ∈ R^{D_i} into an output y(t) ∈ R^{C·D_i} through a hidden process x(t) ∈ R^{D_h}:

    ẋ(t) = A x(t) + B u(t),    y(t) = C x(t) + D u(t),    (1)

where the feed-through term D u(t) can be neglected, since it is a simple term to compute separately. To make such a model compatible with discrete signals such as sequences of speech frames or sub-word unit embeddings, a simple discretization, e.g. the bilinear method (Tustin, 1947), can be utilized:

    x_k = Ā x_{k-1} + B̄ u_k,    y_k = C̄ x_k,    (2)
    Ā = (I − ∆A/2)^{-1} (I + ∆A/2),    B̄ = (I − ∆A/2)^{-1} ∆B,    C̄ = C,

which simply represents a fully linear recurrent neural network (RNN) cell. Note that the discretization matrix ∆ ∈ R^{D_h × D_h} is diagonal with positive entries, but can be subsumed into A and B due to scale invariance. Furthermore, while such a model is highly flexible, imposing structure on it can increase its efficiency. One can therefore restrict the matrices A, B, C to a block-diagonal structure:

    A = diag(A_1, A_2, ..., A_{D_i}),    (3)

and analogously for B and C, where the number of blocks in this block-recurrent model equals the number of input dimensions D_i. The original multi-variate state space model can then be seen as D_i smaller independent sub-SSMs, each an R → R^C mapping running in parallel.
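The bilinear discretization and the resulting linear RNN can be sketched in a few lines of NumPy (the matrix shapes follow the definitions above; the concrete step size and matrices are illustrative, and the diagonal ∆ is taken as a scalar since it can be subsumed into A and B):

```python
import numpy as np

def discretize_bilinear(A, B, delta):
    """Bilinear (Tustin) discretization of a continuous LTI SSM.

    A: (Dh, Dh) state matrix, B: (Dh, Di) input matrix,
    delta: scalar step size. Returns the discrete A-bar, B-bar
    of the recurrence x_k = A_bar x_{k-1} + B_bar u_k.
    """
    I = np.eye(A.shape[0])
    inv = np.linalg.inv(I - delta * A / 2.0)
    A_bar = inv @ (I + delta * A / 2.0)
    B_bar = inv @ (delta * B)
    return A_bar, B_bar

def ssm_recurrent(A_bar, B_bar, C, u):
    """Run the fully linear RNN cell: x_k = A_bar x_{k-1} + B_bar u_k,
    y_k = C x_k, over an input sequence u of shape (T, Di)."""
    x = np.zeros(A_bar.shape[0])
    ys = []
    for u_k in u:
        x = A_bar @ x + B_bar @ np.atleast_1d(u_k)
        ys.append(C @ x)
    return np.stack(ys)
```

For a stable continuous A (eigenvalues with negative real part), the bilinear map places the eigenvalues of A-bar inside the unit circle, so the recurrence remains stable; this is one reason the bilinear method is preferred over simple forward-Euler discretization.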



A block diagonal treatment of structured state spaces is also covered in Smith et al. (2022).

