AN ATTENTION FREE TRANSFORMER

Abstract

We introduce the Attention Free Transformer (AFT), an efficient variant of Transformers (Vaswani et al., 2017) that eliminates the need for dot product attention. AFT offers great simplicity and efficiency compared with standard Transformers: the multi-head attention operation is replaced with a composition of element-wise multiplications/divisions and global/local pooling. During training, AFT has linear time and space complexity w.r.t. both the sequence length and the feature dimension; in the autoregressive decoding mode, AFT has constant memory and time complexity per step. We show that, surprisingly, we are able to train AFT effectively on challenging benchmarks, matching or surpassing standard Transformers and other efficient variants. In particular, AFT achieves the state-of-the-art result on CIFAR10 autoregressive modeling with much reduced complexity, and also outperforms several efficient Transformer variants on Enwik8.

1. INTRODUCTION

Attention mechanisms, represented by Transformers (Vaswani et al., 2017), have driven the advancement of various machine learning problems, including language modeling (Devlin et al., 2018; Radford et al.), image modeling (Chen et al.), and set modeling (Lee et al., 2019). Different from other well known model architectures such as Convolutional Neural Nets (CNNs) or Recurrent Neural Nets (RNNs), Transformers enable direct interaction between every pair of elements within a sequence, which makes them especially powerful at capturing long term dependencies. However, Transformers incur high computational costs. The root cause of this challenge is the attention operation, which has quadratic time and space complexity w.r.t. the context size. This makes it especially difficult for Transformers to scale to inputs with large context sizes.

A number of recent works have been dedicated to addressing the scalability issue of Transformers (Child et al., 2019; Kitaev et al., 2020; Rae et al., 2020; Wang et al., 2020b; Katharopoulos et al., 2020; Tay et al., 2020a; Choromanski et al., 2020). While the techniques adopted in the literature range over sparsity, locality sensitive hashing, low rank decomposition, and kernel approximation, most of them aim to approximate the full attention operation. In this paper, we take a bolder step towards the same goal, proposing a computational module that neither uses nor approximates the standard dot product attention. We hence name our model the Attention Free Transformer (AFT).

Similar to dot product attention, AFT is composed of the interaction of three quantities, namely the query, key and value. What is different, however, is that AFT operates solely with element-wise operations. Concretely, the key and value are first multiplied element-wise, and the result is then pooled over the context dimension (in the causal model, this corresponds to a cumulative sum). The query is then multiplied element-wise with the reduced key-value representation to produce the final output. See Figure 1a for an illustration. AFT maintains the full advantage of dot product attention, namely direct interaction between any two elements in a sequence (up to proper masking). However, the computational cost is drastically reduced to O(Td) time and space, where T, d are the context length and feature dimension, respectively. In the autoregressive decoding mode, AFT also provides constant decoding time and space complexity per step, compared to O(T) for standard Transformers. To the best of our knowledge, AFT is the first model that achieves such efficiency in the context of Transformers. See Table 1 for the complexity analysis of AFT in comparison to other variants.
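The causal form of this computation can be sketched in a few lines of NumPy. All names here are ours, and the sigmoid used on the query is a placeholder choice of element-wise non-linearity for illustration, not necessarily the one used in the paper:

```python
import numpy as np

def aft_causal(Q, K, V):
    """Sketch of a causal attention-free layer.

    Q, K, V: arrays of shape (T, d). The key and value interact
    element-wise, are reduced over the context dimension with a
    cumulative sum (the causal pooling), and the query then gates
    the pooled result element-wise. Cost is O(T*d) time and space.
    """
    sigma_q = lambda x: 1.0 / (1.0 + np.exp(-x))  # placeholder query non-linearity
    kv = K * V                      # element-wise key-value interaction, (T, d)
    pooled = np.cumsum(kv, axis=0)  # causal pooling over the context, (T, d)
    return sigma_q(Q) * pooled      # element-wise query gating, (T, d)

T, d = 8, 4
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((T, d)) for _ in range(3))
Y = aft_causal(Q, K, V)
```

Note that no T-by-T matrix is ever materialized, which is where the reduction from O(T^2) to O(Td) comes from.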

Model                Time @ train      Space @ train      Time/step @ decode   Space/step @ decode
Full Attention       O(T^2 d)          O(T^2 + Td)        O(Td)                O(Td)
Reformer             O(T log T d)      O(T log T + Td)    O(log T + d)         O(Td)
Synthesizer          O(T^2 d)          O(T^2 + Td)        O(Td)                O(Td)
Linear Transformer   O(T d^2)          O(Td + d^2)        O(d^2)               O(d^2)
AFT (ours)           O(Td)             O(Td)              O(d)                 O(d)

We show that AFT can be interpreted as an extreme case of multi-head dot product attention (MHA). In particular, we show that by 1) setting the number of heads equal to the feature dimension in MHA and 2) using ReLU in place of softmax as the non-linearity, MHA can be decomposed into the summation of two AFT modules (see Equation 6). However, this relationship does not hold in general: by varying the non-linearity injected after the query and key in AFT, we can obtain models that do not have an MHA counterpart. This realization allows us to freely explore the design choices (e.g., the non-linearity) of AFT to achieve the best performance. This philosophy is in direct contrast with previous and concurrent "linearized attention" works (Katharopoulos et al., 2020; Choromanski et al., 2020), which are constrained by the design space of MHA.

We perform experiments with AFT on several benchmarks, including unconditional image modeling, image super-resolution, language modeling, machine translation and point cloud generation. We show that AFT works very well as an alternative to the standard Transformer, providing competitive results as well as excellent efficiency. To summarize, our contributions are as follows:
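The O(d) decode entry for AFT in the table follows from the cumulative-sum structure: autoregressive decoding only needs to carry a running d-dimensional sum of key-value products. A minimal sketch (the function name and the sigmoid query gate are our own illustrative choices):

```python
import numpy as np

def aft_decode_step(q_t, k_t, v_t, state):
    """One autoregressive decoding step with O(d) time and memory.

    state: running sum of element-wise key-value products over all
    previous positions, shape (d,). Updating it and gating with the
    query touches only d numbers per step, independent of T.
    """
    state = state + k_t * v_t                    # fold the new position into the running sum
    y_t = (1.0 / (1.0 + np.exp(-q_t))) * state   # placeholder sigmoid gate on the query
    return y_t, state

d = 4
rng = np.random.default_rng(1)
state = np.zeros(d)
outputs = []
for _ in range(6):  # decode 6 steps, each in O(d)
    q, k, v = rng.standard_normal((3, d))
    y, state = aft_decode_step(q, k, v, state)
    outputs.append(y)
```

Contrast this with standard attention, whose decode step must attend over all T cached keys and values.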

2. MULTI-HEAD ATTENTION

At the core of Transformers is the Multi-Head Attention (MHA) operation. Given three sequences, namely the query Q ∈ R^(T×d), key K ∈ R^(T×d) and value V ∈ R^(T×d), and the number of heads h, MHA performs a scaled dot product attention for each head i, defined as:

f_i(Q, K, V) = σ( Q_i (K_i)^T / sqrt(d_k) ) V_i,   s.t.   Q_i = Q W_i^Q,  K_i = K W_i^K,  V_i = V W_i^V,

where W_i^Q ∈ R^(d×d_k), W_i^K ∈ R^(d×d_k), W_i^V ∈ R^(d×d_v) are linear transformations for head i, and σ is the non-linearity, by default the softmax function applied to each row of a matrix. Here d_k, d_v are the dimensions of the key and value, respectively. MHA concatenates the outputs of the h attention heads along the channel dimension, resulting in feature dimension h·d_v. Unless otherwise mentioned, we assume d_k = d_v and h = d/d_k. This means the query, key and value have the same dimension within each head, and the output dimension matches that of the input.
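For reference, the MHA computation above can be sketched directly in NumPy (single sequence, learned projections replaced by random matrices for illustration; all names are ours):

```python
import numpy as np

def softmax_rows(x):
    """Numerically stable softmax applied to each row of a matrix."""
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_attention(Q, K, V, Wq, Wk, Wv, h):
    """Scaled dot product attention per head, outputs concatenated.

    Q, K, V: (T, d); Wq, Wk, Wv: lists of h per-head projections,
    each (d, d_k) with d_k = d // h, so the output is again (T, d).
    Note the (T, T) attention matrix materialized per head: this is
    the source of the quadratic training cost.
    """
    T, d = Q.shape
    d_k = d // h
    heads = []
    for i in range(h):
        Qi, Ki, Vi = Q @ Wq[i], K @ Wk[i], V @ Wv[i]   # (T, d_k) each
        A = softmax_rows(Qi @ Ki.T / np.sqrt(d_k))     # (T, T) attention weights
        heads.append(A @ Vi)                           # (T, d_k)
    return np.concatenate(heads, axis=-1)              # (T, h*d_k) = (T, d)

T, d, h = 8, 16, 4
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((T, d)) for _ in range(3))
Wq, Wk, Wv = ([rng.standard_normal((d, d // h)) * 0.1 for _ in range(h)] for _ in range(3))
Y = multi_head_attention(Q, K, V, Wq, Wk, Wv, h)
```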

3.1. ATTENTION FREE TRANSFORMER

We now define the Attention Free Transformer (AFT), which provides an alternative to MHA. Given Q, K, V, AFT first linearly transforms them into Q = Q W^Q, K = K W^K, V = V W^V, and then combines them using element-wise operations only: the transformed key and value are multiplied element-wise, the result is pooled over the context dimension, and the pooled representation is gated element-wise by a non-linearity of the query. In the causal (autoregressive) mode, this reads

Y_t = σ_q(Q_t) ⊙ Σ_{t'=1..t} σ_k(K_{t'}) ⊙ V_{t'},

where ⊙ denotes element-wise multiplication and σ_q, σ_k are element-wise non-linearities applied to the query and key, respectively.
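A compact sketch of this layer, including the linear projections, follows. The class name is ours, and the sigmoid-on-query / identity-on-key non-linearities are illustrative placeholders rather than the paper's tuned choices:

```python
import numpy as np

class AFTLayer:
    """Sketch of a causal attention-free layer with the linear
    projections W^Q, W^K, W^V. The element-wise non-linearities
    (sigmoid on the query, identity on the key) are placeholder
    choices; the paper treats them as tunable design parameters.
    """

    def __init__(self, d, seed=0):
        rng = np.random.default_rng(seed)
        self.Wq, self.Wk, self.Wv = (rng.standard_normal((d, d)) * 0.1 for _ in range(3))

    def __call__(self, X):
        # Linear projections of the input sequence X, shape (T, d).
        Q, K, V = X @ self.Wq, X @ self.Wk, X @ self.Wv
        pooled = np.cumsum(K * V, axis=0)            # causal context pooling
        return (1.0 / (1.0 + np.exp(-Q))) * pooled   # query gates pooled key-values

layer = AFTLayer(d=8)
X = np.random.default_rng(1).standard_normal((5, 8))
Y = layer(X)
```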



• We propose AFT, a new family of Transformer models that achieves O(Td) time and space complexity in training, as well as O(d) time and space complexity per step in autoregressive decoding.
• We show strong performance of AFT as a drop-in replacement of MHA on various benchmarks, including setting the state-of-the-art result on CIFAR10 in the standard setting and outperforming other efficient Transformer variants.

Table 1: Complexity comparison with different Transformers: Reformer (Kitaev et al., 2020), Synthesizer (Tay et al., 2020a), and Linear Transformer (Katharopoulos et al., 2020) (only variants that support the causal mode are shown). Here T, d denote the sequence length and feature dimension, respectively.

