MULTI-HEAD STATE SPACE MODEL FOR SEQUENCE MODELING

Abstract

Recently, state space models (SSMs) have shown promising results on sequence modeling tasks. However, a limitation of existing work is that SSMs are usually introduced or initialized in a homogeneous way, encouraging the model to capture only similar temporal dynamics across different features. In this paper, we propose a multi-head state space model (MSSM), in which parallel heads are introduced to learn different temporal dynamics on sequence data. Furthermore, we propose a novel variant of the Transformer, referred to as the Stateformer, which combines MSSMs with attention. Experiments on large-scale automatic speech recognition (ASR) and language modeling tasks show that the MSSM outperforms a range of attention-based baselines. The Stateformer further improves performance, achieving state-of-the-art results on the LibriSpeech ASR task.

1. INTRODUCTION

Transformers (Vaswani et al., 2017) and attention-based models (Bahdanau et al., 2015) have for more than half a decade shown state-of-the-art performance on a wide range of tasks, from speech recognition (Zhang et al., 2020a; Gulati et al., 2020) and neural machine translation (Ng et al., 2019; Chen et al., 2020; Tran et al., 2021) to computer vision (Dosovitskiy et al., 2021; Liu et al., 2021b) and biological applications such as protein sequence modeling (Jumper et al., 2021; Rives et al., 2021). One of the strengths of these models is their significantly smaller maximum path length across time, i.e. the shortest path linking the first encoder input and the final decoder output (Hochreiter et al., 2001). This contrasts with previously state-of-the-art recurrent and convolutional neural networks, whose path lengths scale linearly and logarithmically, respectively. However, a well-known drawback of transformers is the quadratic space and time complexity of the self-attention layer, restricting their applicability in fields requiring longer sequences or on devices with strict limitations on compute resources (Tay et al., 2022). To combat both of these issues, a wide range of restricted sparse attention mechanisms have been proposed, all aiming to reduce the computational cost and scale to longer sequences while retaining as much of the original transformer performance as possible (Katharopoulos et al., 2020; Kitaev et al., 2020; Beltagy et al., 2020; Zaheer et al., 2020). A notable amount of effort has also been poured into designing alternative, more efficient attention schemes that replace the dot-product approach (Wang et al., 2020a; Choromanski et al., 2020; Shen et al., 2021; Zhang et al., 2019).
Meanwhile, the machine learning community is paying increasing attention to a well-established signal processing and control theory technique, the state space model (SSM) (Kalman, 1960), which has historically been widely used in many time-series and control problems (Brogan, 1991; Hyndman et al., 2008; Durbin & Koopman, 2012). More recently, the linear time-variant state space model has also been used within neural networks to improve time-series forecasting (Seeger et al., 2017; 2016; Rangapuram et al., 2018). Furthermore, a simplified linear time-invariant (LTI) SSM, which is closely related to fully linear recurrent neural network (RNN) layers, has a well-known convolution equivalent, making it particularly attractive for parallelized training, while inference can utilize its fast recurrent formulation (Brogan, 1991; Gu et al., 2021). One potential limitation of existing work in this direction is that SSMs are usually initialized and used in a homogeneous way. For instance, in Gu et al. (2022a), the S4 approach equips SSMs with carefully chosen parameter initializations for long-range modeling ability. This design forces the model to capture similar temporal dynamics across different features. Since both long and short-term dependencies are useful in sequence modeling, in this paper we develop a multi-head state space model, which consists of parallel heads that learn different temporal dynamics on sequence data. We investigate the use of multi-head state space models in sequence tasks as both a replacement for and a complement to attention. Since SSMs have been shown to be promising on long sequences, we hypothesize that such a model can handle both short and long-term dependencies and can operate as an effective alternative to attention. Our main technical contributions around the state space model are as follows:

1. Stacked and multi-head generalization. We extend the SSM approach by allowing multi-head parallel processing of projected lower-dimensional signals, and we stack such layers (within residual blocks) for better performance.

2. Head gating. We propose an inter-head gating approach allowing the different SSMs within a multi-head layer to communicate.

3. Combination with attention. For computationally unrestricted applications, we also augment the transformer architecture by simply including a bidirectional SSM residual block prior to the attention; the resulting model is referred to as the Stateformer.

With these contributions, we advance the state of attention-free models on large-scale speech recognition and language modeling tasks, and show that our novel multi-head state space model outperforms strong attention-based baselines. We also show that combining the multi-head state space model with attention achieves state-of-the-art performance in large-scale speech recognition.

2. BACKGROUND

This section covers theory and related work regarding the state space model and its use in deep learning. We first briefly cover the state space model and a block diagonal structured restriction, and mention a convolution point of view for parallel training (Section 2.1). We finish with a brief discussion of the state space layer as used in neural networks (Section 2.2).

2.1. STATE SPACE MODEL: THE LINEAR RNN

A specific realization of the state space model (SSM) is the linear time-invariant (LTI) model (Brogan, 1991), which transforms an input signal u(t) ∈ R^{D_i} into an output y(t) ∈ R^{C D_i} through a hidden process x(t) ∈ R^{D_h} according to

\dot{x}(t) = A x(t) + B u(t), \qquad y(t) = C x(t) + D u(t)   (1)

where the skip term Du(t) can be neglected in the analysis since it is a simple term to compute. To make such a model compatible with discrete signals, such as sequences of speech frames or sub-word unit embeddings, a simple discretization, e.g. the bilinear method (Tustin, 1947), can be utilized:

x_k = \bar{A} x_{k-1} + \bar{B} u_k, \qquad y_k = \bar{C} x_k   (2)
\bar{A} = (I - \Delta A/2)^{-1} (I + \Delta A/2), \qquad \bar{B} = (I - \Delta A/2)^{-1} \Delta B, \qquad \bar{C} = C

which is simply a fully linear recurrent neural network (RNN) cell. Note that the discretization matrix ∆ ∈ R^{D_h × D_h} is diagonal with positive entries, but can be subsumed into A, B due to scale invariance. Furthermore, while such a model is highly flexible, providing it with structure can increase its efficiency. One can therefore restrict the matrices A, B, C to a block diagonal structure:

A = \mathrm{blockdiag}(A_1, A_2, \ldots, A_{D_i}), \quad B = \mathrm{blockdiag}(B_1, B_2, \ldots, B_{D_i}), \quad C = \mathrm{blockdiag}(C_1, C_2, \ldots, C_{D_i})   (3)

where the structure of this block recurrent model is based on the number of input dimensions D_i [1]. The original multi-variate state space model can now be seen as D_i smaller independent sub-SSMs, where each of these smaller models is an R → R^C mapping running in parallel.
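As a concrete illustration, the bilinear discretization in Equation 2 and the resulting linear recurrence can be sketched in a few lines of NumPy. This is a minimal sketch, not the paper's implementation; the helper names `bilinear_discretize` and `ssm_recurrent` are hypothetical, and the example uses a single sub-SSM with scalar input.

```python
import numpy as np

def bilinear_discretize(A, B, dt):
    """Bilinear (Tustin) discretization of x'(t) = Ax(t) + Bu(t):
    A_bar = (I - dt*A/2)^{-1} (I + dt*A/2), B_bar = (I - dt*A/2)^{-1} dt*B."""
    n = A.shape[0]
    I = np.eye(n)
    inv = np.linalg.inv(I - dt * A / 2.0)
    return inv @ (I + dt * A / 2.0), inv @ (dt * B)

def ssm_recurrent(A_bar, B_bar, C, u):
    """Run the discrete recurrence x_k = A_bar x_{k-1} + B_bar u_k, y_k = C x_k
    over a 1-D input sequence u (the fast inference mode)."""
    x = np.zeros(A_bar.shape[0])
    ys = []
    for u_k in u:
        x = A_bar @ x + B_bar[:, 0] * u_k
        ys.append(C @ x)
    return np.stack(ys)
```

For a stable (Hurwitz) continuous transition matrix, the eigenvalues of the discretized A_bar fall inside the unit circle, so the recurrence does not blow up.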

Denoting the sub-hidden dimension by D_sh = D_h / D_i, we obtain A_n ∈ R^{D_sh × D_sh}, B_n ∈ R^{D_sh × 1} and C_n ∈ R^{C × D_sh}. The skip-connection can also be split into blocks D_n ∈ R^{C × 1}. Similarly, the discretization matrix can now be seen as a diagonal matrix with isotropic blocks ∆_n = δ_n I, with scalar δ_n, of the same size as the transition blocks A_n, meaning that each sub-state space model has its own discretization time-step that can be tuned to its own input channel. Furthermore, because the model is fully linear, one can utilize the well-known duality between convolutions and recurrences, resulting in an efficient expression:

K = (\bar{C}\bar{B}, \bar{C}\bar{A}\bar{B}, \ldots, \bar{C}\bar{A}^{L-1}\bar{B}), \qquad y = K * u   (4)

where L is the input sequence length. Therefore, state space models and linear RNNs can be trained in convolution mode (Equation 4), fully utilizing the power of GPUs, and deployed in recurrent mode (Equation 2), efficiently and incrementally generating predictions. To further improve efficiency, a specific diagonal plus low-rank restriction of the sub-transition blocks can be used (Gu et al., 2022a):

A_n = \Lambda_n + p_n q_n^\top, \qquad p_n, q_n \in \mathbb{R}^{D_{sh}} \quad \forall n \in \{1, \ldots, D_i\}   (5)

where Λ_n is a diagonal matrix. This allows the kernel K to be computed efficiently utilizing the Woodbury matrix identity, with several fast implementations (Pan, 2001; 2014).
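The convolution–recurrence duality of Equation 4 can be checked numerically: materializing the kernel K and causally convolving it with the input reproduces the recurrence exactly. The sketch below is illustrative only (O(L^2) naive convolution; the names `ssm_kernel` and `ssm_conv` are hypothetical), whereas practical implementations use FFTs and the Woodbury-based kernel computation.

```python
import numpy as np

def ssm_kernel(A_bar, B_bar, C, L):
    """Materialize K = (C B_bar, C A_bar B_bar, ..., C A_bar^{L-1} B_bar)."""
    K, x = [], B_bar[:, 0]
    for _ in range(L):
        K.append(C @ x)      # k-th entry is C A_bar^k B_bar
        x = A_bar @ x
    return np.array(K)[:, 0]

def ssm_conv(K, u):
    """Causal convolution y_k = sum_{j<=k} K_j u_{k-j}."""
    L = len(u)
    return np.array([sum(K[j] * u[k - j] for j in range(k + 1)) for k in range(L)])
```

Convolution mode computes all outputs in parallel at training time; recurrent mode produces them one step at a time at inference.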

2.2. PRACTICAL LAYER DESIGN

The structured state space model described in Equation 2 is a general linear R^{D_i} → R^{C D_i} mapping, but it processes each channel independently due to the structuring in Equation 3. To address both issues, a non-linear activation is introduced, followed by a point-wise linear layer (see Figure 1a for a pre-norm residual unit) that mixes information across input channels:

f(u) = \mathrm{Linear}(\mathrm{Act}(\mathrm{SSM}(u)))   (6)

Note the output dimension expansion C is set to 1 when using element-wise activations such as GELU (Hendrycks & Gimpel, 2016), or to 2 when using a gating mechanism such as GLU (Dauphin et al., 2017). To further stabilize the state space, Goel et al. (2022) used the more stable Hurwitz formulation A_n = Λ_n − p_n p_n^⊤, ensuring the eigenvalues of the transition matrix remain negative when Λ_n is negative definite. In practice, the parameters are stored in log-form: the time-step ∆_n = exp(z_{∆_n}) and the diagonal Λ_n = −exp(z_{Λ_n}). Regarding the initialization of this system, Gu et al. (2020) developed several options for setting the transition and input matrices to equip the model with long-range modeling ability, motivated from a signal-memorization point of view. However, all of the parameters were made trainable, moving away from the initialization during training.

3. MULTI-HEAD STATE SPACE MODEL

This section introduces the novel multi-head state space model as a natural extension of the linear time-invariant structured state space model described above. First, we discuss our configuration for the state space layer and its initialization (Section 3.1). Second, we propose novel ideas for how the architecture can be improved, including a multi-head generalization (Section 3.2). Third, a new gating mechanism is introduced which was found notably beneficial (Section 3.3). Furthermore, a bidirectional architecture is discussed (Section 3.4), along with how it can be used with attention-based models (Section 3.5). Finally, since the compute complexity of state space models scales only linearly with sequence length, they can be used prior to sub-sampling layers without much added cost (Section 3.6).

3.1. TIED PARAMETERS & INITIALIZATION

The original design in Equation 3 describes a block diagonal simplification of a multi-variate SSM for efficiency, which can also be seen as several small independent SSMs running in parallel. However, this SSM required a smaller, separate learning rate for its parameters, compared to the rest of the network, to achieve the best possible performance. In this work we opt for a slightly more parameter-efficient blockwise isotropic implementation, similar to Gu et al. (2022b), where we restrict the copies to share input and transition matrices across all blocks:

A_n = \Lambda_n - p_n p_n^\top = A_0 = \Lambda_0 - p_0 p_0^\top, \qquad B_n = B_0 \quad \forall n \in \{1, \ldots, D_i\}   (7)

but keep independent output matrices C_n and skip matrices D_n. We also let the learning rate remain the same as for the remainder of the network. By still having a different time-step discretization for each dimension of the input sequence, temporal features of different scales can be extracted while the recurrence parameters are shared. Furthermore, we choose to simply initialize all parameters of the SSM randomly by sampling from the unit Gaussian distribution N(0, 1), apart from the diagonal matrix, which is shifted: Λ_0 = −exp(z_{Λ_0}), z_{Λ_0} ∼ N(ln(scale), 1). This ensures that the initial eigenvalues of the transition matrix are not too small and can capture enough of the input signal's history. Compared with Gu et al. (2020), the random initialization is simple but shows good performance in practice; its effectiveness was also observed in Mehta et al. (2022). It also provides the potential, for the multi-head state space model (see below), to break the symmetry between heads instead of using the same deterministic initialization.
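The tied-parameter initialization above can be sketched as follows. This is a hedged reconstruction from Equation 7 and the log-form parameterization of Section 2.2, not the paper's code; `init_tied_ssm` and its argument names are hypothetical.

```python
import numpy as np

def init_tied_ssm(D_sh, D_i, C_out=1, scale=1.0, rng=None):
    """Random init sketch for the tied-parameter SSM (Eq. 7): a shared Hurwitz
    transition A_0 = Lambda_0 - p_0 p_0^T and shared input B_0, with independent
    per-channel output matrices C_n and per-channel log time-steps."""
    rng = np.random.default_rng(rng)
    z_Lambda = rng.normal(np.log(scale), 1.0, size=D_sh)  # z ~ N(ln(scale), 1)
    Lambda0 = -np.exp(z_Lambda)                           # strictly negative diagonal
    p0 = rng.normal(0.0, 1.0, size=(D_sh, 1))
    A0 = np.diag(Lambda0) - p0 @ p0.T                     # Hurwitz: eigenvalues < 0
    B0 = rng.normal(0.0, 1.0, size=(D_sh, 1))
    Cn = rng.normal(0.0, 1.0, size=(D_i, C_out, D_sh))    # independent per channel
    log_dt = rng.normal(0.0, 1.0, size=D_i)               # one time-step per channel
    return A0, B0, Cn, log_dt
```

Because A_0 is a negative-definite diagonal minus a positive semi-definite rank-one term, all its eigenvalues are negative at initialization, which is what stabilizes the recurrence.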

3.2. STACKED & MULTI-HEAD GENERALIZATION

The basic architecture described above is highly parameter efficient, dominated by the linear layer, and is significantly smaller than multi-head attention and feed-forward nets (FFNs). Therefore, the design is expanded by stacking, y = f_s(f_{s-1}(... f_1(u))), each layer with its own set of parameters; see the example in Figure 1b. This balances the parameter ratio between FFN and SSM layers and gives the whole model more depthwise/temporal flexibility. Moreover, we found it useful to project the input into several different lower-dimensional representations which are passed on to independent structured state space models, in a design similar to multi-head attention; see Figure 2a. Not only does sharing recurrence parameters within a structured SSM stabilize training, it also gives each structured model in a multi-head framework the flexibility to learn both meaningful time-steps and different types of temporal dynamics. This stacked multi-head state space model (MSSM) requires an additional initial projection layer compared to its single-head implementation, leading to a small increase in parameters. In practice, the projection takes a sequence of D_i-dimensional features and projects it down to H ∈ {2, 4, 8, ...} different D̃_i = D_i/H-dimensional signals, each of which is processed by isotropic blockwise diagonal SSMs. Since the effective dimension H · D̃_i = D_i remains the same, the cost of the multi-head approach boils down to an extra linear projection layer.
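The head-splitting bookkeeping can be sketched as below: project the (L, D_i) sequence, split into H head signals of width D_i/H, process each with its own SSM, and re-flatten. This is a shape-level sketch only; `mssm_heads` is a hypothetical name and `head_ssm` stands in for the per-head structured SSM.

```python
import numpy as np

def mssm_heads(u, W, H, head_ssm):
    """Multi-head sketch: project a (L, D_i) input with W (D_i x D_i), split
    into H heads of width D_i/H, apply an independent SSM per head via the
    hypothetical callable head_ssm((L, D_i/H), head_index), then re-flatten.
    The effective dimension H * (D_i/H) = D_i is unchanged, so the only extra
    cost over the single-head design is the projection W."""
    L, D_i = u.shape
    v = (u @ W).reshape(L, H, D_i // H)
    out = np.stack([head_ssm(v[:, h], h) for h in range(H)], axis=1)
    return out.reshape(L, D_i)
```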

3.3. INTER-HEAD GATED LINEAR UNIT

By default, prior work has used the GELU activation function. In some experiments (Gu et al., 2022a; Goel et al., 2022; Mehta et al., 2022) the GLU activation (with C = 2) was also found beneficial, achieved by gating the channels of an individual sub-state space model:

a_n = y_{n1} \cdot \sigma(y_{n2}), \qquad y_n \in \mathbb{R}^2, \quad \forall n \in \{1, \ldots, D_i\}   (8)

However, a multi-head state space architecture with H headed outputs {y_n^{(h)} ∈ R^2}, h = 1, ..., H, n = 1, ..., D̃_i, offers many more gating possibilities. We propose an inter-head gating (IHG) approach in which the output is computed by mixing the heads according to (where σ refers to the sigmoid):

a_{nc}^{(h)} = y_{nc}^{(h)} \cdot \sigma(y_{nc}^{(h+H/2)}), \qquad \forall c \in \{1, 2\}, \; h \in \{1, \ldots, H/2\}   (9)

allowing different heads to communicate and gate each other, generally leading to improved results. Flattening the output in Equation 9 recovers the original D_i-dimensional output; note that the number of heads in this approach has to be even, see Figure 2b. While our approach is quite different, the concept of letting heads communicate has also been proposed by Shazeer et al. (2020) in the context of attention, where linear layers mix information across heads, and was found to improve performance. Unlike that approach, however, ours allows communication between different heads without any extra parameters.
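Equation 9 amounts to splitting the heads in half and letting the second half gate the first, with no learned parameters. A minimal sketch (the name `inter_head_gate` is hypothetical):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def inter_head_gate(y):
    """Inter-head gating (Eq. 9): head h in {1..H/2} is gated elementwise by
    head h + H/2. y has shape (L, H, D) with H even; the result has shape
    (L, H/2, D) and, with the C = 2 channel expansion, flattens back to the
    original D_i-dimensional output."""
    L, H, D = y.shape
    assert H % 2 == 0, "number of heads must be even"
    values, gates = y[:, : H // 2], y[:, H // 2 :]
    return values * sigmoid(gates)
```

In contrast to the talking-heads approach of Shazeer et al. (2020), no mixing matrix is introduced; the communication is purely multiplicative.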

3.4. BIDIRECTIONAL DESIGN

So far, the proposed architecture has a unidirectional design; for better performance, however, bidirectional context is needed. Therefore, we simply concatenate the outputs of independent forward and backward stacked multi-head modules and project them back down to the original input dimension (Cornegruta et al., 2016; Goel et al., 2022). This new module is then wrapped in a pre-norm block similar to the architectures above:

y \leftarrow \mathrm{Cat}([\mathrm{MSSM}(x), \mathrm{Rev}(\mathrm{MSSM}(\mathrm{Rev}(x)))]), \qquad y \leftarrow \mathrm{Linear}(\mathrm{Act}(y))   (10)

Since there is a linear layer in the bidirectional block, the final linear layers of the MSSM modules are removed. This attention-free design has many advantages over multi-head attention, such as memory requirements that scale linearly with sequence length, making it computationally feasible for longer sequences.

3.5. STATEFORMER: STATE SPACE AUGMENTED TRANSFORMER

A pure multi-head state space model architecture is attractive due to its efficiency and its ability to capture both short and long range dependencies. However, since the state space model is equivalent to a linear RNN, it is expressively more limited and still plagued by a linearly scaling path length. Therefore, for scenarios less restricted by compute limitations, we propose a model combining the MSSM with attention by simply inserting a pre-norm bidirectional block (Equation 10) prior to the self-attention unit in the transformer architecture, referred to as the Stateformer; see Figure 3. This design is inspired by the Conformer (Gulati et al., 2020), where a dedicated convolutional block is inserted to handle local relations. However, we hypothesize that the advantage of the Stateformer is that the MSSM, with its variable-length kernel, and the multi-head attention can both handle short and long-term dependencies and can complement each other. Furthermore, the attention enables a maximum path length of 1, resulting in a possibly more flexible model.

3.6. MULTI-SCALE ARCHITECTURE

In many long sequence problems, it is standard practice to sub-sample or stride the input sequence before it is fed to the main model, to reduce computation. For example, both Listen, Attend and Spell (Chan et al., 2016) and transducer (Graves, 2012) based models for speech recognition utilize this type of sub-sampling to achieve lower frame rates and shorter sequences. In this work we investigate a multi-scale architecture, using MSSM layers prior to the striding, for sub-sampling. There are two advantages of this multi-scale design: (1) MSSMs scale linearly and can potentially handle longer sequences; (2) processing the sequence prior to sub-sampling can introduce better contextual information.
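The time-reduction step used between MSSM blocks can be sketched as frame stacking: `stride` consecutive frames are concatenated, shortening the sequence and widening the features by the same factor. The name `time_reduction` is hypothetical.

```python
import numpy as np

def time_reduction(x, stride=2):
    """Time-reduction sketch: stack `stride` consecutive frames of a (L, D)
    sequence, producing (L // stride, stride * D); trailing frames that do
    not fill a full group are dropped."""
    L, D = x.shape
    L_out = L // stride
    return x[: L_out * stride].reshape(L_out, stride * D)
```

In the multi-scale front, two such reductions with stride 2 (with MSSM blocks before each) give an overall stride of 4, so the expensive upper layers run at a quarter of the input frame rate.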

4. EXPERIMENTS

We investigate our proposed models on large-scale automatic speech recognition (ASR) and masked language modeling (MLM) tasks, and compare with a range of transformer-based schemes on each task. Baseline and proposed models are implemented using an extension of the Fairseq framework (Ott et al., 2019) .

4.1. AUTOMATIC SPEECH RECOGNITION

We first evaluate the proposed models on automatic speech recognition using the widely known LibriSpeech corpus (Panayotov et al., 2015), consisting of about 960 hours of speech data sampled at 16kHz. All speech recognition models were built in the transducer framework (Graves, 2012), which has three components: encoder, predictor and joiner. We train various transducers by keeping the predictor and joiner fixed, and compare attention-based and S4 (Gu et al., 2022a) baselines and the proposed models on the encoder. In these experiments, we used a sliding window of 25ms with a 10ms frame shift to extract 80-dimensional log Mel-filter bank features, which were immediately linearly projected to a dimension of 128. Depending on the model, we used one of the following sub-samplers (SS) to reduce the frame rate to 40ms and increase the dimensionality to 512:

• Conv: 2D-convolution network with an effective stride of 4 and 512 output channels (Wang et al., 2020b; Gulati et al., 2020).
• TR: Time-reduction layer which stacks each 4 consecutive frames (stride = 4).
• MS: Multi-scale front which introduces MSSM layers prior to time-reductions (Section 3.6): 2 MSSM blocks are followed by TR (stride = 2), then 4 MSSM blocks and a final TR (stride = 2).

The attention-based baselines used the convolutional sub-sampler and consist of 20-36 Transformer or Conformer (Gulati et al., 2020) blocks with 512-dimensional features, 8 attention heads, and feed-forward nets with a dimension expansion of 4. Unlike Gulati et al. (2020), our Conformer baseline did not use a macaron-style block, as it was not found useful, and had the convolutional module, with a kernel size of 31, prior to the attention. An S4 model was also prepared as a single-head SSM baseline with careful parameter initialization; the scaled Legendre (LegS) initialization is used.
The S4 baseline and our proposed models simply replace the attention block of the transformer layer, and the configuration of this layer, such as the number of stacks, heads, use of inter-head gating and sub-sampler design, is investigated; D_sh = 32 is fixed for all experiments. The Stateformer uses the same parameters as the transformer combined with the best MSSM configuration. The predictor remains fixed for all experiments and contains three 512-dimensional LSTM layers with layer-norm and dropout. The outputs of both encoder and predictor are projected to 1024 dimensions and fed into an additive joiner with a single linear layer of |Y| = 4097 sentence-piece (Kudo & Richardson, 2018) output units. All models used the Adam optimizer (Kingma & Ba, 2015) with a learning rate linearly warming up to its peak value over 8000 iterations, held fixed until the 60th epoch, and thereafter exponentially decayed by a factor of 0.96 each epoch. A dropout value of 0.10 is used in all encoders and 0.30 in all predictors, and the batch size is set to occupy maximal GPU memory. All models were trained for up to 200 epochs using 32 NVIDIA A100 GPUs. In evaluation, decoding results are generated using the beam search algorithm on the transducers without any external language model. The best checkpoints for the test sets are selected on the respective dev sets. First, the impact of SSM stack depth is examined. As reported in Table 1, the double stacked 16-layer configuration (illustrated in Figure 1b) achieves similar performance to the 20-layer non-stacking model. This demonstrates that the double stacked design is more effective, and it will be the default topology in subsequent experiments. Second, in Table 2, we compare the impact of multi-head, inter-head gating and sub-sampler designs on the 16-layer MSSM models.
Although directly increasing the number of heads from 1 to 2 shows limited improvement, turning on inter-head gating yields consistent gains on multi-head models. Furthermore, additional improvement can be achieved when increasing the number of heads from 2 to 8, keeping IHG activated. Continuing, we replace the time-reduction layer with a multi-scale sub-sampler, which notably improves performance. Note the second row is a transformer model with a simple time-reduction (instead of a convolutional) sub-sampler, which can be seen as a fair baseline for MSSMs without the multi-scale option. In this case, the proposed models significantly outperform the respective transformer baselines of similar size. Next, we couple attention with the best MSSM configuration to form the Stateformer (described in Section 3.5). The final row of Table 2 indicates some degree of complementarity: further gains can be achieved when combining both of these modeling techniques. We also scale the model size up to 140M parameters to verify the performance of a large configuration. Baselines and the proposed models in this large configuration are trained using auxiliary classifiers similar to Szegedy et al. (2015), in which intermediate encoder outputs are trained to predict frame labels every 4 layers. Table 3 compares the proposed models with large-size baselines and state-of-the-art numbers in previous works [2]. Both the 32-layer MSSM and 25-layer Stateformer outperform our large-size baselines and a range of attention-based results in the literature.

4.2. MASKED LANGUAGE MODELING

To explore our model's effectiveness on other modalities, we consider the widely used masked language modeling task on long sequences (Devlin et al., 2018). We follow Beltagy et al. (2020) in preparing a large-scale long-document corpus, including Stories (Trinh & Le, 2018), RealNews (Zellers et al., 2019), Books (Zhu et al., 2015) and English Wikipedia. This corpus, consisting of 6.6 billion tokens, is loaded as in Beltagy et al. (2020) in terms of 4,096-token sequences. It is prohibitive to apply the full attention in transformers to such sequences due to its quadratic memory complexity. Therefore, we pick a few feasible transformer variants as baselines, including sliding-window attention (Beltagy et al., 2020), Performer (Choromanski et al., 2020) and Linformer (Wang et al., 2020a). The sliding window has a window size of 256; Performer uses a random feature dimension of 256 and ReLU as the kernel function; Linformer has a compression ratio of 8. The MSSM model with D_sh = 32 has 8 heads; the remaining hyperparameters are shared with the attention baselines. To validate that our model works well at scale, we test the attention and MSSM layers using a 24-layer model similar in size to RoBERTa-large (Liu et al., 2019). To make a fair comparison, training stops after the same number of data batches. Following Xiong et al. (2021), baselines are trained using the standard masked language modeling objective for 100K steps with a per-GPU batch size of 8 on 32 A100 GPUs [3]. The MSSM is trained for 50K steps with a per-GPU batch size of 8 on 64 GPUs, for rapid development. Perplexities (PPLs) are reported on the validation set.



Footnotes:
[1] A block diagonal treatment of structured state spaces is also covered in Smith et al. (2022).
[2] The Conformer performance varies significantly in prior work. We list a few recent works for comparison.
[3] The training curves of these models are nearly stable after 100K steps.



Figure 1: The standard block used in previous work and an example of a double stacked block. Each block is followed by a feed-forward net, similar to transformers.

Figure 3: Stateformer: bidirectional state space augmented transformer encoder example. It simply has an additional block, with a bidirectional MSSM, prior to the attention unit, as seen in Equation 10.

Table 1: Impact of the stack depth of the state space model in one layer on LibriSpeech.

Table 2: Performance of the base configuration on LibriSpeech. #H is the number of heads; IHG is inter-head gating (Section 3.3); SS is the sub-sampling scheme used in the bottom layer.

Table 3: Performance of large configurations on LibriSpeech, compared with the best results in the literature.

Table 4 summarizes the perplexity of different models. The proposed MSSM achieves performance on par with the attention baselines, and shows notably improved perplexity with further training (20K additional updates). Although the MSSM is expected to learn information from the full sequence, the performance gap to the sliding window is smaller than expected. To better understand the behavior of the proposed model, we analyze the learned state space kernels in Section 4.3.

4.3. MULTI-HEAD INTRODUCES MULTI-MODAL AND MULTI-SCALE MODELING POWER

Having empirically demonstrated the effectiveness of multi-head state space models, we want to answer the following research question: what do the multiple heads learn from the data? Since the state space model can be presented as a convolution (Equation 4), the shapes of the kernel filters can be studied to answer this question.

We first look into the learned multi-head kernels from the ASR model (Figure 4B) and compare them with the learned single-head kernel (Figure 4A). We find that the single-head kernel learned a simple monotonically decaying pattern, which computes a weighted average of the history in a sliding window. On the other hand, the multi-head kernels are much more interesting: while the first head is still monotonically decaying, the kernels from the second and third heads involve context windows with different concentrations. In other words, multi-head state space models introduce multi-modal modeling power.

The kernels from the MLM task bring additional insights. To begin with, the kernels from the first three heads show the same patterns as the ASR kernels (Figure 4C), reaffirming the multi-modal modeling power on the MLM task. This localized behavior can also be viewed similarly to that of sliding-window attention. Additionally, the kernels from the next four heads (Figure 4D) are monotonic with different decay rates, capturing history at multi-scale granularity.

5. CONCLUSION

We propose the multi-head state space model (MSSM) for sequence modeling in this paper. Parallel heads of SSMs, together with inter-head gating, are introduced, enabling multiple temporal dynamics to be effectively learned on sequence data. In addition, we combine the MSSM with attention to form a new type of transformer architecture, referred to as the Stateformer. On the LibriSpeech speech recognition task, the proposed MSSM outperforms the transformer baseline by a large margin, and the Stateformer further improves on this, achieving state-of-the-art performance. On the masked language modeling task, the MSSM achieves on-par performance compared with a range of attention-based baselines trained on a large corpus.

