MEGA: MOVING AVERAGE EQUIPPED GATED ATTENTION

Abstract

The design choices in the Transformer attention mechanism, including weak inductive bias and quadratic computational complexity, have limited its application for modeling long sequences. In this paper, we introduce MEGA, a simple, theoretically grounded, single-head gated attention mechanism equipped with (exponential) moving average to incorporate inductive bias of position-aware local dependencies into the position-agnostic attention mechanism. We further propose a variant of MEGA that offers linear time and space complexity yet yields only minimal quality loss, by efficiently splitting the whole sequence into multiple chunks with fixed length. Extensive experiments on a wide range of sequence modeling benchmarks, including the Long Range Arena, neural machine translation, auto-regressive language modeling, and image and speech classification, show that MEGA achieves significant improvements over other sequence models, including variants of Transformers and recent state space models.

1. INTRODUCTION

Designing a single unified model to capture long range dependencies in sequential data is a central and challenging problem in sequence modeling. A number of different architectures have been developed, including convolutional neural networks (CNNs) (Kim, 2014; Strubell et al., 2017), recurrent neural networks (RNNs) (Goller & Kuchler, 1996; Hochreiter & Schmidhuber, 1997), Transformers (Vaswani et al., 2017), and recent state space models (SSMs) (Gu et al., 2022a; Mehta et al., 2022). Among these models, the Transformer (Vaswani et al., 2017) has stood out for its impressive empirical success on a wide range of language and vision tasks, including machine translation (Vaswani et al., 2017; Ott et al., 2018), language understanding (Devlin et al., 2019; Liu et al., 2019), image recognition (Dosovitskiy et al., 2020; Touvron et al., 2021) and genetic sequence modeling (Madani et al., 2020; Jumper et al., 2021), mainly because of the conceptually attractive attention mechanism (Bahdanau et al., 2015), which directly models interactions between each pair of input tokens.

However, there are two common drawbacks in this design: i) weak inductive bias; and ii) quadratic computational complexity. First, the attention mechanism does not assume prior knowledge of the patterns of dependencies between tokens (e.g. positional inductive bias); instead, it learns to predict the pairwise attention weights directly from data. Second, the cost to compute and store the attention weights is quadratic in the input length. Recent studies have shown the limitations of applying Transformers to longer sequences, both with respect to accuracy and efficiency (Tay et al., 2020).

In this work, we propose a moving average equipped gated attention mechanism (MEGA) to address these two weaknesses simultaneously. The key idea is to incorporate inductive biases into the attention mechanism across the timestep dimension by leveraging the classic exponential moving average (EMA) approach (Hunter, 1986).
EMA captures local dependencies that exponentially decay over time (see Figure 1), and has been widely used in time series modeling (§2). We introduce a multi-dimensional damped form of EMA with learnable coefficients (§3.1), and subsequently develop the moving average equipped gated attention mechanism by integrating the EMA with a variant of the single-head gated attention (Hua et al., 2022) (§3.2). Theoretically, we show that the single-head gated attention is as expressive as the most commonly used multi-head attention (§3.5). Experimentally, on five sequence modeling tasks across various data types, including long-context sequence modeling, neural machine translation, auto-regressive language modeling, and image and speech classification, MEGA significantly outperforms a variety of strong baseline models, in terms of both effectiveness and efficiency (§4) (see Table 1).

2. BACKGROUND

We use X = {x_1, x_2, ..., x_n} ∈ R^{n×d} to denote a sequence of input representations with length n. Let Y = {y_1, y_2, ..., y_n} ∈ R^{n×d} be the sequence of output representations of each layer, with the same length n as the input X. In this paper, we assume the representations of the input and output sequences have the same dimension d.

2.1. SELF-ATTENTION MECHANISM

The traditional self-attention mechanism is a function:

Y = Attn(X) = f(QK^T / τ(X)) V,    (1)

where Attn : R^{n×d} → R^{n×d} is the self-attention function. Q = XW_q + b_q, K = XW_k + b_k, and V = XW_v + b_v are the sequences of queries, keys and values, with learnable parameters W_q, W_k, W_v ∈ R^{d×d} and b_q, b_k, b_v ∈ R^d. f(·) is an attention function, e.g. the softmax function f_softmax(·) (Bahdanau et al., 2015) or the recently proposed squared ReLU function f_relu2(·) (So et al., 2021; Hua et al., 2022). τ(X) is a scaling term, commonly set to τ(X) = √d for f_softmax(·), or τ(X) = n for f_relu2(·). The commonly used multi-head variant of attention performs the attention function h times in parallel.

Following (1), we can define the attention matrix A = f(QK^T / τ(X)) ∈ R^{n×n}. Since it specifies pairwise dependency weights, the matrix A in principle delivers a flexible and powerful mechanism to learn long-distance dependencies with minimal inductive biases. However, it is in practice a challenging task to recognize all the dependency patterns directly from data, particularly when processing long sequences. Moreover, calculating A with h attention heads takes O(hn^2) time and space, which becomes a significant bottleneck.
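As an illustration, softmax self-attention in (1) can be sketched in a few lines of NumPy. This is a minimal single-head sketch with random placeholder weights, not the paper's trained model; the helper name `self_attention` is ours.

```python
import numpy as np

def self_attention(X, Wq, Wk, Wv, bq, bk, bv):
    """Single-head softmax self-attention: Y = f_softmax(Q K^T / sqrt(d)) V."""
    n, d = X.shape
    Q = X @ Wq + bq
    K = X @ Wk + bk
    V = X @ Wv + bv
    scores = Q @ K.T / np.sqrt(d)                # (n, n): quadratic in sequence length
    A = np.exp(scores - scores.max(axis=-1, keepdims=True))
    A /= A.sum(axis=-1, keepdims=True)           # row-wise softmax -> attention matrix A
    return A @ V                                 # (n, d) output representations

rng = np.random.default_rng(0)
n, d = 6, 4
X = rng.normal(size=(n, d))
Ws = [rng.normal(size=(d, d)) for _ in range(3)]
bs = [np.zeros(d) for _ in range(3)]
Y = self_attention(X, *Ws, *bs)
assert Y.shape == (n, d)
```

The `(n, n)` score matrix materialized here is exactly the O(n^2) cost discussed above; the multi-head variant repeats this h times.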

2.2. EXPONENTIAL MOVING AVERAGE (EMA)

The moving average is a classic approach for sequential data modeling, which has been widely used in time series data to smooth out short-term fluctuations and highlight long-term trends or cycles. The Exponential Moving Average (EMA) (Winters, 1960; Hunter, 1986), a special case of moving average, applies weighting factors that decrease exponentially:

y_t = α ⊙ x_t + (1 − α) ⊙ y_{t−1},    (2)

where α ∈ (0, 1)^d is the EMA coefficient representing the degree of weighting decrease, and ⊙ is the element-wise product. A higher α discounts older observations faster (see Figure 1). Using an EMA places a strong inductive bias on the learning of pairwise dependencies: the dependency weight between two tokens decreases exponentially over time with an input-agnostic decay factor α. This property favors local dependencies, and limits long-distance dependencies. Despite the recurrent formulation in (2), the computation of EMA can be represented as n individual convolutions, which can be computed efficiently using fast Fourier transforms (FFTs) (see Appendix A for details).
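To make the recurrence/convolution equivalence concrete, the sketch below (helper names are ours, and we assume the initial state y_{−1} = 0) computes the EMA both by unrolling (2) and as a causal convolution with kernel K_t = α ⊙ (1 − α)^t evaluated via FFT; the two agree up to floating-point error.

```python
import numpy as np

def ema_recurrent(x, alpha):
    """EMA via the recurrence y_t = alpha * x_t + (1 - alpha) * y_{t-1}, y_{-1} = 0."""
    y = np.zeros_like(x)
    prev = np.zeros(x.shape[-1])
    for t in range(len(x)):
        prev = alpha * x[t] + (1 - alpha) * prev
        y[t] = prev
    return y

def ema_fft(x, alpha):
    """Same EMA as a causal convolution with kernel K_t = alpha * (1 - alpha)^t,
    computed with FFTs (zero-padded to avoid circular wrap-around)."""
    n = len(x)
    t = np.arange(n)[:, None]
    kernel = alpha * (1 - alpha) ** t            # (n, d) decaying kernel per dimension
    L = 2 * n                                    # pad so linear != circular convolution
    y = np.fft.irfft(np.fft.rfft(x, L, axis=0) *
                     np.fft.rfft(kernel, L, axis=0), L, axis=0)
    return y[:n]

rng = np.random.default_rng(0)
x = rng.normal(size=(128, 3))
alpha = np.array([0.1, 0.5, 0.9])
assert np.allclose(ema_recurrent(x, alpha), ema_fft(x, alpha))
```

The FFT route replaces the O(n) sequential recurrence with an O(n log n) parallelizable computation, which is the efficiency argument referenced in Appendix A.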




Table 1: Experimental results of Transformer (XFM), S4 and MEGA on five sequence modeling benchmarks: Long Range Arena (LRA), machine translation (WMT16 en-de), language modeling (WikiText-103), image classification (ImageNet-1k), and raw speech classification (SC-Raw). Metrics: LRA (Acc. ↑), WMT16 (BLEU ↑), WT103 (PPL ↓), ImageNet (Acc. ↑), SC (Acc. ↑).

