EFFICIENT ATTENTION VIA CONTROL VARIATES

Abstract

Random-feature-based attention (RFA) is an efficient approximation of softmax attention with linear runtime and space complexity. However, the approximation gap between RFA and conventional softmax attention is not well studied. Building on prior work on RFA, we characterize this gap through the lens of control variates and show that RFA can be decomposed into a sum of control variate estimators, one for each element in the sequence. This new framework reveals that exact softmax attention can be recovered from RFA by manipulating each control variate. Moreover, it allows us to develop a more flexible form of control variates, resulting in a novel attention mechanism that significantly reduces the approximation gap while maintaining linear complexity. Extensive experiments demonstrate that our model outperforms state-of-the-art efficient attention mechanisms on both vision and language tasks.

1. INTRODUCTION

Random-feature-based attention (RFA, also known as Performer; Choromanski et al., 2021; Peng et al., 2021b) is an established fast approximation to the conventional softmax attention mechanism (Bahdanau et al., 2014; Vaswani et al., 2017), which successfully scales Transformer models to much longer sequences (Choromanski et al., 2021). At its core is the use of random features (RF; Rahimi & Recht, 2008) to linearize the exponential kernel in softmax attention, which reduces the computational cost from quadratic to linear runtime and space complexity. Despite this efficiency, recent studies have pointed out that the approximation suffers from substantial performance degradation (Xiong et al., 2021a; Zheng et al., 2022b). In this work, we generalize the formulation of RFA via control variates (Owen, 2013), which characterizes the approximation gap between RFA and softmax attention in theory. We first show that RFA can be decomposed from a global approximation over the whole sequence into a sum of local control variate estimators, each of which is applied to an individual element in the sequence. Under this formulation, RFA is equivalent to employing the same coefficient for all control variate estimators, scaling their variance isotropically (§3.1). Moreover, we prove that if the coefficient of each control variate is optimized to minimize the estimation variance individually, the RFA estimate becomes exact; that is, softmax attention is recovered with zero bias and zero variance (§3.2). Our key observation is that this formulation reveals a localized perspective on the RFA approximation. Instead of directly seeking a better estimate over the entire sequence, we can break the problem down into smaller problems that aim to improve the approximation over each subsequence (§4).
The control variate estimator for each subsequence can be tuned separately and then combined to yield a better overall estimate, which provably reduces the approximation error in the global sense (§4.1). Nevertheless, one caveat is that as the number of sub-problems increases, the approximation gap shrinks at the expense of higher computational complexity. For instance, if we optimized the control variate for every single element, softmax attention would be recovered as desired, but with quadratic complexity. To attain a good trade-off between approximation quality and efficiency, we develop a new Efficient attention via control VAriates (EVA) that implements this divide-and-conquer strategy efficiently. In EVA, the sequence is partitioned into a fixed number of disjoint subsets. For the subset likely to bear the highest correlation with the query, we explicitly optimize the control variate for each element, which recovers exact softmax attention probabilities; for the others, the control variate coefficient is shared locally among all elements within the same subset. The resulting attention mechanism is not only highly effective but also runs with the same computational complexity as RFA (§4.2). Extensive experiments on both language and vision tasks demonstrate that EVA outperforms state-of-the-art efficient attention methods (§5).

2. BACKGROUND

2.1. SOFTMAX ATTENTION MECHANISM

Assume there exists a set of N queries \{q_n\}_{n=1}^N and M key-value pairs K = [k_1, \ldots, k_M] and V = [v_1, \ldots, v_M], where queries, keys, and values are all d-dimensional vectors. The softmax attention mechanism (Bahdanau et al., 2014; Vaswani et al., 2017) is defined as an average over the value vectors weighted by the dot-product similarities between the queries and keys. For the n-th query, the attention mechanism outputs

SoftmaxAttn(q_n, K, V) := \sum_{m=1}^{M} \frac{\exp(q_n^\top k_m)}{\sum_{m'=1}^{M} \exp(q_n^\top k_{m'})} v_m.  (1)

In the case of self-attention (Lin et al., 2017; Vaswani et al., 2017), we have M = N, which results in quadratic computational complexity since the similarity of each query-key pair must be computed explicitly.
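As a reference point for the approximations discussed below, Equation 1 can be sketched in a few lines of NumPy. This is a minimal illustration in our own notation; the usual 1/\sqrt{d} scaling found in transformer implementations is omitted to match the equation above.

```python
import numpy as np

def softmax_attention(Q, K, V):
    """Exact softmax attention (Equation 1): each query output is an
    average of the values weighted by normalized exp(q^T k) scores."""
    scores = Q @ K.T                                      # (N, M) similarities
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V                                    # (N, d) outputs
```

The (N, M) score matrix materialized here is exactly the quadratic bottleneck that the methods below avoid.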

2.2. RANDOM-FEATURE-BASED ATTENTION WITH SELF-NORMALIZED IMPORTANCE SAMPLING

Recently, Zheng et al. (2022b) identified that softmax attention (Equation 1) can be written as an expectation over an attention-like aggregating function,

SoftmaxAttn(q_n, K, V) = \sum_{m=1}^{M} \frac{\exp(q_n^\top k_m)}{\sum_{m'=1}^{M} \exp(q_n^\top k_{m'})} v_m = \mathbb{E}_{\omega \sim p_n(\omega)}[f_n(\omega)],  (2)

where

f_n(\omega) := \frac{\sum_{m=1}^{M} \xi(q_n, \omega)^\top \xi(k_m, \omega) v_m}{\sum_{m'=1}^{M} \xi(q_n, \omega)^\top \xi(k_{m'}, \omega)}, \quad p_n(\omega) := \mathcal{N}(\omega; 0, I) \frac{\sum_{m=1}^{M} \xi(q_n, \omega)^\top \xi(k_m, \omega)}{Z}.  (3)

Here \xi(\cdot, \cdot) is a randomized mapping defined in such a way that \exp(q_n^\top k_m) = \mathbb{E}_{\omega \sim \mathcal{N}(0, I)}[\xi(q_n, \omega)^\top \xi(k_m, \omega)], and Z = \sum_{m=1}^{M} \exp(q_n^\top k_m) denotes the normalizing constant of the distribution p_n. Throughout this paper, we consider the positive randomized mapping \xi(x, \omega) = \exp(\omega^\top x - \frac{1}{2}\|x\|^2) (Choromanski et al., 2021) unless otherwise specified.

Random-feature-based attention (RFA) methods (Choromanski et al., 2021; Peng et al., 2021b) can be interpreted as performing self-normalized importance sampling (SNIS; Hesterberg, 1995) to approximate Equation 2 (Zheng et al., 2022b). In SNIS, one draws Monte Carlo samples from some proposal distribution q(\omega) instead of the true distribution p_n(\omega) and estimates the target expectation as

\mathbb{E}_{\omega \sim p_n(\omega)}[f_n(\omega)] = \mathbb{E}_{\omega \sim q(\omega)}\left[\frac{p_n(\omega)}{q(\omega)} f_n(\omega)\right] \approx \frac{\sum_{s=1}^{S} \frac{p_n(\omega_s)}{q(\omega_s)} f_n(\omega_s)}{\sum_{s=1}^{S} \frac{p_n(\omega_s)}{q(\omega_s)}},

where \omega_1, \ldots, \omega_S \sim q(\omega). Vanilla RFA amounts to constructing the SNIS estimate with q(\omega) = \mathcal{N}(\omega; 0, I). The SNIS representation also turns out to be equivalent to the more established form of RFA,

RFA(q_n, K, V) := \frac{\sum_{s=1}^{S} \frac{p_n(\omega_s)}{q(\omega_s)} f_n(\omega_s)}{\sum_{s=1}^{S} \frac{p_n(\omega_s)}{q(\omega_s)}} = \frac{\sum_{m=1}^{M} \phi(q_n, \omega)^\top \phi(k_m, \omega) v_m}{\sum_{m'=1}^{M} \phi(q_n, \omega)^\top \phi(k_{m'}, \omega)},  (4)

where the random feature \phi(x, \omega) := \frac{1}{\sqrt{S}}[\xi(x, \omega_1), \ldots, \xi(x, \omega_S)]^\top was originally proposed to approximate exponential kernels (see Appendix A for a detailed review). As shown in Equation 4, the RFA estimate considers all key-value pairs and produces a global approximation over the entire sequence.
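A minimal sketch of vanilla RFA (Equation 4) under the positive randomized mapping above; the helper names (`xi`, `rfa`) are ours, and the Monte Carlo samples are shared across all queries and keys as in Performer. Note the O((N + M)Sd) cost: values are aggregated through the S-dimensional feature space rather than an N-by-M score matrix.

```python
import numpy as np

def xi(x, W):
    """Positive randomized mapping xi(x, w) = exp(w^T x - ||x||^2 / 2),
    evaluated for S samples stacked in W of shape (S, d)."""
    return np.exp(W @ x - 0.5 * np.dot(x, x))             # (S,)

def rfa(Q, K, V, S=256, seed=0):
    """Vanilla RFA (Equation 4): linearize exp(q^T k) with random features
    phi(x) = [xi(x, w_1), ..., xi(x, w_S)] / sqrt(S), w_s ~ N(0, I)."""
    rng = np.random.default_rng(seed)
    W = rng.normal(size=(S, Q.shape[1]))
    phi_q = np.stack([xi(q, W) for q in Q]) / np.sqrt(S)  # (N, S)
    phi_k = np.stack([xi(k, W) for k in K]) / np.sqrt(S)  # (M, S)
    num = phi_q @ (phi_k.T @ V)        # sum_m phi_q^T phi_k v_m, O((N+M) S d)
    den = phi_q @ phi_k.sum(axis=0)    # sum_m phi_q^T phi_k
    return num / den[:, None]
```

With enough samples and moderate query/key norms the output tracks exact softmax attention, but the per-kernel variance grows quickly with \|q + k\|^2, which is the degradation the paper targets.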
3. A CONTROL VARIATE PERSPECTIVE ON RFA

3.1. DECOMPOSING RFA INTO A SUM OF CONTROL VARIATE ESTIMATORS

In contrast, our work develops a decomposed representation of RFA based on recent advances in SNIS (Vlassis et al., 2021), which indicate that an SNIS estimate is asymptotically equivalent to a control variate estimate (the detailed derivation is deferred to Appendix B.2). In particular, we have

\frac{\sum_{s=1}^{S} \frac{p_n(\omega_s)}{q(\omega_s)} f_n(\omega_s)}{\sum_{s=1}^{S} \frac{p_n(\omega_s)}{q(\omega_s)}} = \frac{1}{S}\sum_{s=1}^{S} \frac{p_n(\omega_s)}{q(\omega_s)} f_n(\omega_s) - \frac{\sum_{s=1}^{S} \frac{p_n(\omega_s)}{q(\omega_s)} f_n(\omega_s)}{\sum_{s=1}^{S} \frac{p_n(\omega_s)}{q(\omega_s)}} \left(\frac{1}{S}\sum_{s=1}^{S} \frac{p_n(\omega_s)}{q(\omega_s)} - 1\right) := g(\omega) - \beta(\omega)\left(h(\omega) - \mathbb{E}[h(\omega)]\right) := \widetilde{g}(\omega),  (6)

where g(\omega) := \frac{1}{S}\sum_{s=1}^{S} \frac{p_n(\omega_s)}{q(\omega_s)} f_n(\omega_s) is the base estimate, h(\omega) := \frac{1}{S}\sum_{s=1}^{S} \frac{p_n(\omega_s)}{q(\omega_s)} is the control variate (with \mathbb{E}[h(\omega)] = 1), and the control coefficient is \beta(\omega) := \frac{\sum_{s=1}^{S} \frac{p_n(\omega_s)}{q(\omega_s)} f_n(\omega_s)}{\sum_{s=1}^{S} \frac{p_n(\omega_s)}{q(\omega_s)}} = \frac{g(\omega)}{h(\omega)}.

We now examine the formulation of g(\cdot) and h(\cdot) in the context of RFA. According to Equation 3,

g(\omega) = \frac{1}{S}\sum_{s=1}^{S} \frac{p_n(\omega_s)}{q(\omega_s)} f_n(\omega_s) = \sum_{s=1}^{S} \alpha(\omega_s) \sum_{m=1}^{M} \xi(q_n, \omega_s)\xi(k_m, \omega_s) v_m,
h(\omega) = \frac{1}{S}\sum_{s=1}^{S} \frac{p_n(\omega_s)}{q(\omega_s)} = \sum_{s=1}^{S} \alpha(\omega_s) \sum_{m=1}^{M} \xi(q_n, \omega_s)\xi(k_m, \omega_s),

where \alpha(\omega_s) := \frac{1}{S} \frac{\mathcal{N}(\omega_s; 0, I)}{Z q(\omega_s)} collects terms that are constant w.r.t. queries, keys, and values. Our key observation is that by changing the order of summations, both g(\cdot) and h(\cdot) can be decomposed as g(\omega) = \sum_{m=1}^{M} g_m(\omega) and h(\omega) = \sum_{m=1}^{M} h_m(\omega), respectively, where

g_m(\omega) = \sum_{s=1}^{S} \alpha(\omega_s)\xi(q_n, \omega_s)\xi(k_m, \omega_s) v_m, \quad h_m(\omega) = \sum_{s=1}^{S} \alpha(\omega_s)\xi(q_n, \omega_s)\xi(k_m, \omega_s).

As a result, we can decompose the entire RFA estimate in Equation 6 into a summation of M control variate estimates,

\widetilde{g}(\omega) = g(\omega) - \beta(\omega)\left(h(\omega) - \mathbb{E}[h(\omega)]\right) = \sum_{m=1}^{M} g_m(\omega) - \beta(\omega)\left(\sum_{m=1}^{M} h_m(\omega) - \mathbb{E}\left[\sum_{m=1}^{M} h_m(\omega)\right]\right) = \sum_{m=1}^{M} \left[g_m(\omega) - \beta(\omega)\left(h_m(\omega) - \mathbb{E}[h_m(\omega)]\right)\right] := \sum_{m=1}^{M} \widetilde{g}_m(\omega).  (7)

Here \widetilde{g}_m(\omega) = g_m(\omega) - \beta(\omega)(h_m(\omega) - \mathbb{E}[h_m(\omega)]) denotes the control variate estimator attached to the m-th key-value pair, and \beta(\omega) is a coefficient shared across the entire sequence.
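The decomposition above is a finite-sample algebraic identity, so it can be checked numerically. The following sketch (our notation; with proposal q(\omega) = \mathcal{N}(0, I) the factor \alpha(\omega_s) reduces to 1/(SZ)) verifies that g = \sum_m g_m and h = \sum_m h_m, and that the SNIS estimate coincides with the control variate form:

```python
import numpy as np

rng = np.random.default_rng(1)
d, M, S = 4, 5, 8
qn = rng.normal(size=d) * 0.5
Ks = rng.normal(size=(M, d)) * 0.5
Vs = rng.normal(size=(M, d))
Ws = rng.normal(size=(S, d))             # omega_s ~ q(omega) = N(0, I)

xi = lambda x, w: np.exp(w @ x - 0.5 * x @ x)   # positive randomized mapping
Z = np.exp(Ks @ qn).sum()                       # normalizing constant of p_n

xiq = np.array([xi(qn, w) for w in Ws])               # (S,)
xik = np.array([[xi(k, w) for w in Ws] for k in Ks])  # (M, S)
prod = xiq * xik                                      # xi(q_n) xi(k_m), (M, S)

# Importance weights p_n/q and integrand f_n(omega_s) (Equations 2-3)
W_imp = prod.sum(axis=0) / Z                          # (S,)
f = (prod[:, :, None] * Vs[:, None, :]).sum(axis=0) / prod.sum(axis=0)[:, None]

# Global estimates and their per-token decomposition (Section 3.1)
g = (W_imp[:, None] * f).mean(axis=0)
h = W_imp.mean()
h_m = prod.sum(axis=1) / (S * Z)                      # alpha(omega_s) = 1/(S Z)
g_m = h_m[:, None] * Vs

assert np.allclose(g, g_m.sum(axis=0)) and np.isclose(h, h_m.sum())
# SNIS estimate equals the control variate form g - beta (h - E[h]), E[h] = 1
beta = g / h
assert np.allclose(g / h, g - beta * (h - 1.0))
```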

3.2. OPTIMIZING COEFFICIENTS IN RFA LOCALLY RECOVERS SOFTMAX ATTENTION

Based on the decomposition of RFA in Equation 7, we have one local control variate attached to each key-value pair. To see the benefit of such a decomposition, we demonstrate that softmax attention is equivalent to associating each control variate with a locally optimized coefficient \beta_m in RFA.

Proposition 1. Let \widetilde{g}_m(\omega) = g_m(\omega) - \beta_m(h_m(\omega) - \mathbb{E}[h_m(\omega)]). We denote the variance of an estimator g(\omega) as Var[g(\omega)] := Cov[g(\omega), g(\omega)]. Then the optimal \beta_m that minimizes Tr(Var[\widetilde{g}_m(\omega)]) (i.e., the summed variance over all dimensions) is of the form

\beta^*_m := \arg\min_{\beta_m} Tr(Var[\widetilde{g}_m(\omega)]) = v_m = \frac{g_m(\omega)}{h_m(\omega)}.  (8)

Furthermore, by letting \beta_m = \beta^*_m for all m = 1, 2, \ldots, M, we have Tr(Var[\widetilde{g}_m(\omega)]) = 0 for every m. As a result, Tr(Var[\widetilde{g}(\omega)]) = 0 and thus RFA(q_n, K, V) = \widetilde{g}(\omega) = SoftmaxAttn(q_n, K, V).

The proof is deferred to Appendix B.4. This proposition implies that optimizing \beta_m for each key-value pair in the decomposed formulation of RFA recovers exact softmax attention. It not only characterizes the theoretical gap introduced by RFA but also sheds light on how to improve RFA towards softmax attention from a localized perspective. Furthermore, it delineates the trade-off between estimation quality and computational cost. On the one hand, if we use a distinct \beta_m for each estimator, we achieve a perfect estimate, albeit at the expense of computing \exp(q_n^\top k_m) for every query-key pair explicitly with quadratic time and space complexity. On the other hand, if a single shared coefficient is employed, the estimator degrades to conventional RFA, where all the control variate estimators can be merged and computed together in linear complexity (Choromanski et al., 2021; Peng et al., 2021b; Zheng et al., 2022b).
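Proposition 1 can also be verified numerically: with \beta_m = v_m the randomness cancels exactly, so even a handful of samples yields exact softmax attention. A small sketch in our own notation (again with q(\omega) = \mathcal{N}(0, I), so \alpha(\omega_s) = 1/(SZ)):

```python
import numpy as np

rng = np.random.default_rng(2)
d, M, S = 4, 5, 3                      # only 3 samples: the result is still exact
qn, Ks, Vs = rng.normal(size=d), rng.normal(size=(M, d)), rng.normal(size=(M, d))
Ws = rng.normal(size=(S, d))

xi = lambda x, w: np.exp(w @ x - 0.5 * x @ x)
p = np.exp(Ks @ qn)
Z = p.sum()

h_m = np.array([sum(xi(qn, w) * xi(k, w) for w in Ws) for k in Ks]) / (S * Z)
g_m = h_m[:, None] * Vs                # g_m = h_m v_m
E_h_m = p / Z                          # E[h_m] = exp(q^T k_m) / Z (Appendix B.3)

# Proposition 1: beta_m = v_m  =>  g_m - beta_m (h_m - E[h_m]) = E[h_m] v_m
est = sum(g_m[m] - Vs[m] * (h_m[m] - E_h_m[m]) for m in range(M))
exact = (p / Z) @ Vs                   # softmax attention output
assert np.allclose(est, exact)
```

The catch, as noted above, is that forming E[h_m] = \exp(q_n^\top k_m)/Z for every pair is itself a quadratic computation.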

4. EVA: EFFICIENT ATTENTION VIA CONTROL VARIATES

In this section, we demonstrate that the control variate formulation offers a natural way to improve RFA with a finer-grained treatment over control variates. We describe the improved efficient attention mechanism EVA in §4.1 and its practical implementation in §4.2.

4.1. CONTROL VARIATES WITH LOCALLY SHARED COEFFICIENTS

We denote by [M] := \{1, 2, \ldots, M\} the set of all key-value indices. Instead of employing the same coefficient for all control variates as in RFA, we propose to partition [M] into C subsets P_1, P_2, \ldots, P_C and to allocate a locally shared coefficient \beta_c to each subset P_c. Given coefficients \{\beta_c\}_{c=1}^C and the per-token optima \beta^*_m, define the weighted mean squared error (weighted MSE) as \sum_{c=1}^{C}\sum_{m \in P_c} \alpha_m \|\beta_c - \beta^*_m\|^2, where \alpha_m > 0 and \sum_{c=1}^{C}\sum_{m \in P_c} \alpha_m = 1. To see the benefit of partitioning, we demonstrate that there always exist \{\beta_c\}_{c=1}^C that achieve lower weighted MSE than any globally shared coefficient (see Appendix B.5 for a formal argument).

The next question is how to determine \{\beta_c\}_{c=1}^C. According to Proposition 1, a natural choice is to adapt the optimal coefficients (Equation 8) to the case of partitioned subsets. We justify this choice by proving that it is also optimal in minimizing the MSE above weighted by the true attention probabilities.

Proposition 2. Suppose U is a set of key-value indices, \beta^*_m is the optimal coefficient for each m \in U as defined in Proposition 1, and P_1, P_2, \ldots, P_C form an arbitrary partition of U, where each subset P_c is associated with a distinct \beta_c. We consider the following weighted mean squared error,

J(\beta_1, \ldots, \beta_C) := \sum_{c=1}^{C}\sum_{m \in P_c} \frac{\exp(q_n^\top k_m)}{\sum_{m' \in U} \exp(q_n^\top k_{m'})} \|\beta_c - \beta^*_m\|^2.  (9)

Then for each c = 1, \ldots, C we have

\beta^*_c := \arg\min_{\beta_c} J(\beta_1, \ldots, \beta_C) = \frac{\mathbb{E}\left[\sum_{m \in P_c} g_m(\omega)\right]}{\mathbb{E}\left[\sum_{m \in P_c} h_m(\omega)\right]}.  (10)

As a consequence, with \beta_c = \beta^*_c, the partitioning scheme achieves weighted mean squared error no larger than that of any globally shared \beta, that is, J(\beta_1 = \beta^*_1, \ldots, \beta_C = \beta^*_C) \le J(\beta_1 = \beta, \ldots, \beta_C = \beta).

The proof can be found in Appendix B.6. Apart from measuring the squared errors of all coefficients, Equation 9 also weights the contribution of each error by its corresponding softmax weight, which attains closer alignment with true softmax attention.
Therefore, this proposition implies that the partitioned control variate estimators can obtain coefficients much closer to their optima while faithfully respecting softmax attention. The optimal coefficients \beta^*_c can be estimated via Monte Carlo samples as

\beta^*_c \approx \widehat{\beta}_c(\omega) = \frac{\sum_{m \in P_c} g_m(\omega)}{\sum_{m \in P_c} h_m(\omega)},

which is a widely adopted strategy in the control variate literature (Wang et al., 2013; Owen, 2013). The resulting estimator for each subset P_c takes the form

\sum_{m \in P_c} \widetilde{g}_m(\omega) = \sum_{m \in P_c} \left[g_m(\omega) - \widehat{\beta}_c(\omega) h_m(\omega) + \widehat{\beta}_c(\omega) \frac{\exp(q_n^\top k_m)}{Z}\right] = \sum_{m \in P_c} \frac{\exp(q_n^\top k_m)}{Z} \widehat{\beta}_c(\omega),  (11)

where the first two terms cancel after summation since \widehat{\beta}_c(\omega) \sum_{m \in P_c} h_m(\omega) = \sum_{m \in P_c} g_m(\omega), and \mathbb{E}[h_m(\omega)] = \exp(q_n^\top k_m)/Z (see Appendix B.3).

Partially Optimized Coefficients. Given the optimality of using a separate coefficient for each key-value pair, we can further improve the estimate by selecting some subset E \subseteq [M] and employing \beta_m = \beta^*_m = v_m for each m \in E. Without loss of generality, we assume E \cap P_c = \emptyset for all c = 1, \ldots, C and [M] = \bigcup_{c=1}^{C} P_c \cup E. According to Proposition 1, for each m \in E we have

\widetilde{g}_m(\omega) = g_m(\omega) - \beta_m h_m(\omega) + \beta_m \frac{\exp(q_n^\top k_m)}{Z} = \frac{\exp(q_n^\top k_m)}{Z} v_m.  (12)

We choose E by running an additional sparse attention mechanism (e.g., local window attention (Child et al., 2019) or Reformer (Kitaev et al., 2020)), which tends to select tokens that are more relevant to the query in sub-quadratic complexity. Since the estimates on these critical tokens are exact, this strategy not only reduces the overall squared error (Equation 9) but also produces a more informative context for queries, which often translates into better empirical performance. Combining Equations 12 and 11 together, we obtain an improved Efficient attention via control VAriates (EVA),

EVA(q_n, K, V) := \widetilde{g}(\omega) = \sum_{m \in E} \widetilde{g}_m(\omega) + \sum_{m \notin E} \widetilde{g}_m(\omega) = \sum_{m \in E} \frac{\exp(q_n^\top k_m)}{Z} v_m + \sum_{c=1}^{C}\sum_{m \in P_c} \frac{\exp(q_n^\top k_m)}{Z} \widehat{\beta}_c(\omega).  (13)

Comparison with Vanilla RFA.
EVA and vanilla RFA can be re-written in a similar way (see Appendix B.7 for a detailed derivation):

RFA(q_n, K, V) = \frac{\sum_{m=1}^{M} g_m(\omega)}{\sum_{m=1}^{M} h_m(\omega)},  (14)

EVA(q_n, K, V) = \sum_{m \in E} \frac{\exp(q_n^\top k_m)}{Z} \frac{g_m(\omega)}{h_m(\omega)} + \sum_{c=1}^{C}\sum_{m \in P_c} \frac{\exp(q_n^\top k_m)}{Z} \frac{\sum_{m' \in P_c} g_{m'}(\omega)}{\sum_{m' \in P_c} h_{m'}(\omega)}.  (15)

Intuitively, we can think of EVA as a calibrated version of RFA. Instead of directly computing and aggregating the random feature approximation over all tokens as in RFA (Equation 14), EVA (Equation 15) first constructs a local estimate for either a single token (m \in E) or a subset (e.g., P_c), and then corrects these approximations by their corresponding true attention scores (e.g., \sum_{m \in P_c} \exp(q_n^\top k_m) for P_c). These adjusted local estimates are finally aggregated and globally normalized. Thanks to the decomposed representation of RFA, this divide-and-conquer strategy can be realized in a principled manner, which imposes finer-grained control over the estimation accuracy and enjoys increased approximation fidelity.
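To make the calibration view concrete, the following sketch (hypothetical helper names, our own construction) evaluates Equation 13 assuming the true attention scores and the optimal per-subset coefficients of Equation 10 are available. In that idealized case the output matches exact softmax attention; EVA proper instead uses the Monte Carlo estimate \widehat{\beta}_c(\omega) and approximates the correction weights as described in §4.2.

```python
import numpy as np

def eva_idealized(qn, Ks, Vs, E_idx, parts, betas):
    """Equation 13 with the true attention scores as correction weights.
    `E_idx`: tokens treated exactly; `parts`: index arrays partitioning
    the rest; `betas[c]`: shared coefficient for partition c."""
    p = np.exp(Ks @ qn)
    Z = p.sum()
    out = (p[E_idx] / Z) @ Vs[E_idx]           # exact terms, m in E
    for c, P in enumerate(parts):              # corrected partition terms
        out += (p[P].sum() / Z) * betas[c]
    return out

rng = np.random.default_rng(3)
d, M = 4, 8
qn, Ks, Vs = rng.normal(size=d), rng.normal(size=(M, d)), rng.normal(size=(M, d))
E_idx, parts = np.array([0, 1]), [np.array([2, 3, 4]), np.array([5, 6, 7])]
p = np.exp(Ks @ qn)
betas = [(p[P] @ Vs[P]) / p[P].sum() for P in parts]   # optimal beta_c (Eq. 10)
out = eva_idealized(qn, Ks, Vs, E_idx, parts, betas)
exact = (p / p.sum()) @ Vs                     # with optimal beta_c: exact
assert np.allclose(out, exact)
```

With \widehat{\beta}_c(\omega) in place of the optimum, the partition terms incur error only through the deviation of \widehat{\beta}_c(\omega) from \beta^*_c, which is the quantity Proposition 2 controls.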

4.2. PRACTICAL IMPLEMENTATION

According to the formulation of EVA (Equation 13), the terms within E can be computed efficiently due to the limited size of E; however, the partitioned terms require computing \sum_{m \in P_c} \exp(q_n^\top k_m) explicitly for each subset, which again builds up to quadratic computational complexity. As discussed above, \sum_{m \in P_c} \exp(q_n^\top k_m) serves as a weight to correct the contribution of each subset P_c. In this regard, we propose to approximate this correction by

\sum_{m \in P_c} \exp(q_n^\top k_m) \approx \exp(q_n^\top \widetilde{k}_c),

where \widetilde{k}_c is an adaptive vector summarizing the information of all keys belonging to P_c (see Appendix C for more details). This heuristic not only avoids computing the exponential dot-product of each query-key pair explicitly, but also induces a fast approximation of the normalizing constant,

Z = \sum_{m \in E} \exp(q_n^\top k_m) + \sum_{c=1}^{C}\sum_{m \in P_c} \exp(q_n^\top k_m) \approx \sum_{m \in E} \exp(q_n^\top k_m) + \sum_{c=1}^{C} \exp(q_n^\top \widetilde{k}_c).

Equipped with these results, the EVA estimator (Equation 13) reduces to

EVA(q_n, K, V) \approx \frac{\sum_{m \in E} \exp(q_n^\top k_m) v_m + \sum_{c=1}^{C} \exp(q_n^\top \widetilde{k}_c) \widehat{\beta}_c(\omega)}{\sum_{m \in E} \exp(q_n^\top k_m) + \sum_{c=1}^{C} \exp(q_n^\top \widetilde{k}_c)}.  (16)

Parameterization Details. We define E in the same way as simple block-wise local attention (Xiong et al., 2021a). The input sequence is first chunked into multiple blocks (or 2D windows for images), and each query q_n is associated with a specific E_n that only contains the tokens within the same block as the query. The remaining indices [M] \setminus E_n are evenly split into C contiguous chunks \{P^n_1, \ldots, P^n_C\}. Note that we add the superscript n here to denote the dependence on the query position; for notational brevity, we omit it when there is no ambiguity. The pseudo-code of EVA is provided in Algorithm 1 of the Appendix. More implementation details, including the definitions of \widetilde{k}_c and \widehat{\beta}_c(\omega) in Equation 16, are deferred to Appendix C.

Extension to Autoregressive Modeling.
The decoder (or causal) self-attention, where each query can only attend to previous tokens, is the key ingredient in Transformer-based generative modeling (Vaswani et al., 2017; Brown et al., 2020). We demonstrate that it is straightforward to extend EVA to support such autoregressive modeling with only a few modifications. Thanks to the decomposed formulation of EVA, we only need to incorporate two triangular mask matrices into the computation, which mask out information from future singletons m \in E and from entire future subsets P_c, respectively. Unlike previous RFA methods, which are slow during training due to their recurrent computation (Choromanski et al., 2021; Peng et al., 2021b), the resulting causal variant remains highly efficient. More details can be found in Appendix D, including pseudo-code in Algorithm 2.
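Returning to the practical estimator of Equation 16, a minimal non-causal sketch follows. It makes several simplifying assumptions that differ from the paper's parameterization: E is simply the first block of tokens rather than the query's own block, \widetilde{k}_c is the plain mean of the chunk's keys rather than the adaptive summary of Appendix C, and \widehat{\beta}_c(\omega) is a random-feature estimate over the chunk.

```python
import numpy as np

def eva_linear(qn, Ks, Vs, e_size, C, S=128, seed=0):
    """Sketch of the practical EVA estimator (Equation 16), linear in
    sequence length. Simplified assumptions: E = first e_size tokens,
    k_tilde_c = mean of the chunk's keys, beta_c = random-feature
    estimate of the chunk's attention-weighted value average."""
    rng = np.random.default_rng(seed)
    W = rng.normal(size=(S, qn.shape[0]))            # omega_s ~ N(0, I)
    xi = lambda x: np.exp(W @ x - 0.5 * x @ x)       # positive features, (S,)

    E_idx = np.arange(e_size)
    parts = np.array_split(np.arange(e_size, len(Ks)), C)

    w_E = np.exp(Ks[E_idx] @ qn)                     # exact scores on E
    num, den = w_E @ Vs[E_idx], w_E.sum()
    phi_q = xi(qn)
    for P in parts:
        phi_k = np.stack([xi(k) for k in Ks[P]])     # (|P|, S)
        scores = phi_q @ phi_k.T                     # RF estimates of exp(q^T k_m)
        beta_c = scores @ Vs[P] / scores.sum()       # Monte Carlo beta_c
        w_c = np.exp(qn @ Ks[P].mean(axis=0))        # exp(q^T k_tilde_c)
        num, den = num + w_c * beta_c, den + w_c
    return num / den
```

One sanity check of the trade-off discussed above: if every chunk is a singleton, then \widetilde{k}_c = k_m and \widehat{\beta}_c(\omega) = v_m exactly (the random-feature factors cancel in the ratio), so the sketch reduces to exact softmax attention.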

5. EXPERIMENTAL RESULTS

In this section, we evaluate our proposed method on various tasks, including image classification (§5.1), language tasks (§5.2), and the Long Range Arena benchmark (Appendix F). Details of experimental protocols and baselines can be found in Appendix E.

5.1. IMAGE CLASSIFICATION

We explore the ability of different attention mechanisms to learn visual representations in vision transformers (ViTs; Dosovitskiy et al., 2021). In particular, we replace the softmax attention used in ViTs with its efficient variants and evaluate performance on the ImageNet1k dataset (Deng et al., 2009), which contains about 1,280K training and 50K validation images over 1,000 classes. For the transformer model, we consider both a plain ViT (DeiT; Dosovitskiy et al., 2020; Touvron et al., 2021) and a pyramidal ViT (PVT; Wang et al., 2021b). The former maintains the same sequence length (set to 196 by default) across all transformer layers, while the latter processes much longer sequences (up to 3,136 tokens) at early layers and progressively reduces the sequence length to form a hierarchical structure. Detailed experimental settings can be found in Appendix E.2.

Results. We first compare EVA against our main baselines on the standard ViT architectures. As shown in Table 1, EVA improves over previous RFA approaches (including Performer (Choromanski et al., 2021) and LARA (Zheng et al., 2022b)) and local attention by a large margin, and even outperforms conventional softmax attention. We then consider a more challenging setting, where the plain DeiT-Tiny architecture is used but the sequence length is scaled up to 784 (denoted DeiT-Tiny-784). We compare EVA against other attention variants in this setting and report the classification results in Table 2. EVA outperforms most previous baselines and remains highly competitive with softmax attention, illustrating its effectiveness.

5.2. MACHINE TRANSLATION AND LANGUAGE MODELING

We further evaluate EVA in the natural language domain. Specifically, we consider three tasks:

• Masked language modeling (MLM) on a pretraining-scale book corpus, Books3, from the Pile dataset suite (Presser, 2020; Gao et al., 2020), consisting of over 196,640 published books.
• Machine translation (MT) on the WMT14 En-De benchmark (Bojar et al., 2014).
• Autoregressive language modeling (Autoregressive LM) on a large-scale token-level LM benchmark, Wikitext-103 (Merity et al., 2016).

Results. We report MLM validation perplexity, together with the MT and autoregressive LM results.

Ablation Study. We conduct an ablation study on the image classification and MLM tasks to investigate the effects of the main hyper-parameters of EVA (see Table 8 for a more comprehensive analysis). In particular, we vary |E| and the number of partitions C and evaluate performance on both image classification and masked language modeling. As presented in Tables 6 and 7, increasing |E| amounts to obtaining exact estimates for more key-value pairs, which greatly improves empirical performance; moreover, increasing C processes the control variates at a finer scale, which also translates into better modeling quality, consistent with our theoretical analysis (§4.1).

6. RELATED WORK

Control Variates. Control variates are a widely used variance reduction technique in reinforcement learning (Greensmith et al., 2004; Grathwohl et al., 2018; Vlassis et al., 2021), stochastic optimization (Wang et al., 2013), variational inference (Paisley et al., 2012; Ranganath et al., 2014; Geffner & Domke, 2018; Tucker et al., 2017; Grathwohl et al., 2018), Markov chain Monte Carlo (Baker et al., 2019), and many other areas. Our construction with control variates provides a new perspective on designing faster yet more accurate attention approximations.

Efficient Attention Mechanisms. A large body of research focuses on reducing the quadratic complexity of conventional softmax attention. A widely used approach is to define a sparse attention pattern so that each query attends to only a subset of tokens. The sparse pattern can be either learnable (Kitaev et al., 2020; Vyas et al., 2020; Tay et al., 2020; Roy et al., 2021; Madaan et al., 2022) or simply fixed (Liu et al., 2018; Parmar et al., 2018; Child et al., 2019; Beltagy et al., 2020; Ainslie et al., 2020; Zaheer et al., 2020; Liu et al., 2021; Xiong et al., 2021a; Wang et al., 2022; Chen et al., 2022; Hutchins et al., 2022). Another paradigm is to adopt low-rank approximations, including via the Nyström method (Xiong et al., 2021b), down-sampling with learnable projections (Wang et al., 2020; Peng et al., 2021a), or explicitly compressing sequences (Rae et al., 2020; Dai et al., 2020; Ma et al., 2021; Jaegle et al., 2021). There are also studies that combine sparse and low-rank methods for better attention matrix approximation (Nguyen et al., 2021; Zhu et al., 2021; Chen et al., 2021a; Ren et al., 2021; Zhu & Soricut, 2021; Hua et al., 2022; Zeng et al., 2022).
Instead of adopting approximate methods, a recent line of work (Rabe & Staats, 2021; Dao et al., 2022) proposes to compute exact softmax attention in an online manner (Milakov & Gimelshein, 2018) without materializing the full attention matrix. In this way, softmax attention can be computed with linear memory complexity, and the runtime can also be greatly improved by further minimizing memory accesses (Dao et al., 2022).

Random-Feature-based Attention. Random-feature-based methods are a popular alternative that uses random features (Rahimi & Recht, 2008) to linearize the exponential kernels in softmax attention (Katharopoulos et al., 2020; Choromanski et al., 2021; Peng et al., 2021b). Recent work attempts to improve the RFA approximation from several directions, such as designing more accurate random feature maps (Choromanski et al., 2022; Likhosherstov et al., 2022; Chowdhury et al., 2022), incorporating relative positional or other task-specific biases (Liutkus et al., 2021; Luo et al., 2021; Chen, 2021; Zheng et al., 2022a; Qin et al., 2022b; Wu et al., 2022; Qin et al., 2022a), or leveraging connections to fast weight programmers (Peng et al., 2021b; Schlag et al., 2021; Irie et al., 2021). The prior work most closely related to ours is Zheng et al. (2022b), which reinterprets RFA using self-normalized importance sampling (Hesterberg, 1995) and theoretically extends the random feature approximation from individual exponential kernels to the whole softmax attention. Our work further generalizes this result via control variates and characterizes the approximation gap caused by RFA. Scatterbrain (Chen et al., 2021a) is also similar to our work in that it refines the RF approximation on critical local regions. However, it is developed from a different motivation, namely approximating the attention matrix with a combination of sparse and low-rank matrices.
Interestingly, we find that Scatterbrain can be cast as a special case under our framework; see Appendix G for a detailed discussion about connections between EVA and previous attention mechanisms.

7. CONCLUSION AND LIMITATIONS

In this work, we develop an efficient attention mechanism, EVA, via control variates. Our framework reveals a localized perspective on the RFA approximation, which not only bridges the gap between RFA and exact softmax attention but also attains a good trade-off between modeling quality and efficiency. We evaluate our method on both vision and language tasks and demonstrate substantial improvements over previous baselines. Our framework also has some limitations. For instance, the approximation used in computing the control variate estimate for each partitioned subset is crude and might limit the potential modeling capacity. In addition, we only explore the most straightforward partitioning strategy, which evenly splits the sequence into multiple contiguous chunks; in general, the partition could contain arbitrary subsequences or be made adaptive to the inputs via clustering methods, possibly guided by task-specific inductive biases. It would be interesting to address these limitations to further unleash the expressiveness of EVA, which we leave for future work.

A A BRIEF REVIEW OF VANILLA RANDOM FEATURE ATTENTION

Vanilla random-feature attention methods, such as Performer (Choromanski et al., 2021; Peng et al., 2021b), seek to approximate the softmax attention mechanism through random features (Rahimi & Recht, 2008)

\phi(x, \omega) := \frac{1}{\sqrt{S}}[\xi(x, \omega_1), \ldots, \xi(x, \omega_S)]^\top.

Here \omega_1, \ldots, \omega_S \sim \mathcal{N}(0, I), and \xi(x, \omega) is a randomized mapping such that \exp(q_n^\top k_m) = \mathbb{E}_{\omega_s \sim \mathcal{N}(0, I)}[\xi(q_n, \omega_s)^\top \xi(k_m, \omega_s)]. Therefore, we can draw multiple Monte Carlo samples to estimate the exponential kernel,

\exp(q_n^\top k_m) \approx \frac{1}{S}\sum_{s=1}^{S} \xi(q_n, \omega_s)^\top \xi(k_m, \omega_s) := \phi(q_n, \omega)^\top \phi(k_m, \omega),

and then approximate the attention mechanism as

\sum_{m=1}^{M} \frac{\exp(q_n^\top k_m)}{\sum_{m'=1}^{M} \exp(q_n^\top k_{m'})} v_m \approx \frac{\sum_{m=1}^{M} \phi(q_n, \omega)^\top \phi(k_m, \omega) v_m}{\sum_{m'=1}^{M} \phi(q_n, \omega)^\top \phi(k_{m'}, \omega)}.

This was recently generalized as a self-normalized importance sampling estimator that approximates the whole softmax attention (Zheng et al., 2022b), as described in §2.2. We refer to the generalized random-feature-based approximations as RFA.
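The unbiasedness of the kernel estimate can be checked by Monte Carlo. The following sketch (our notation, using the positive randomized mapping of §2.2) draws a large number of samples and verifies that \phi(q, \omega)^\top \phi(k, \omega) concentrates around \exp(q^\top k):

```python
import numpy as np

rng = np.random.default_rng(5)
d, S = 4, 200_000
q, k = rng.normal(size=d) * 0.4, rng.normal(size=d) * 0.4
W = rng.normal(size=(S, d))                     # omega_s ~ N(0, I)

# phi(x, omega) = [xi(x, omega_1), ..., xi(x, omega_S)]^T / sqrt(S)
phi = lambda x: np.exp(W @ x - 0.5 * x @ x) / np.sqrt(S)
est, true = phi(q) @ phi(k), np.exp(q @ k)
assert abs(est - true) / true < 0.05            # concentrates around exp(q^T k)
```

The relative variance of this estimator grows roughly like \exp(\|q + k\|^2), so the small input norms here are deliberate; for larger norms, far more samples would be needed for the same accuracy.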

B PROOFS & DERIVATIONS

B.1 AN EXTENDED REVIEW OF CONTROL VARIATES

The control variate method takes the following form,

\widetilde{g}(\omega) = g(\omega) - \beta h(\omega) + \beta \mathbb{E}[h(\omega)].

Given the particular forms of g(\cdot) and h(\cdot), \beta can be optimized to minimize the estimation variance. For notational convenience, we denote the covariance between a scalar and a random vector as Cov[h(\omega), g(\omega)] := \mathbb{E}[(h(\omega) - \mathbb{E}[h(\omega)])(g(\omega) - \mathbb{E}[g(\omega)])], and the variance of a random vector as Var[g(\omega)] := Cov[g(\omega), g(\omega)]. In particular, we have

Var[\widetilde{g}(\omega)] = Var[g(\omega) - \beta h(\omega)] = Var[g(\omega)] - 2 Cov[\beta h(\omega), g(\omega)] + Var[\beta h(\omega)] = Var[g(\omega)] - 2 Cov[h(\omega), g(\omega)]\beta^\top + Var[h(\omega)]\beta\beta^\top.

We want the optimal \beta to minimize Tr(Var[\widetilde{g}(\omega)]), that is, the summed estimation variance over all dimensions. By differentiating, we obtain

\beta^* = \arg\min_{\beta} Tr(Var[\widetilde{g}(\omega)]) = \frac{Cov[h(\omega), g(\omega)]}{Var[h(\omega)]}.  (20)

Since both the covariance and the variance may be intractable to compute, the optimal \beta^* is generally not available in closed form. Nevertheless, with the optimal coefficient, the variance of the control variate estimate is never larger than that of the plain estimator g(\cdot).
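A toy scalar illustration of these mechanics (not the attention estimator itself): estimating \mathbb{E}[e^w] for w \sim \mathcal{N}(0, 1), using h(w) = w as the control variate, whose mean is known to be zero. The coefficient is set to the empirical version of Equation 20.

```python
import numpy as np

rng = np.random.default_rng(6)
# Estimate E[exp(w)], w ~ N(0, 1); the true value is exp(1/2).
# Control variate: h(w) = w with known mean E[h] = 0.
w = rng.normal(size=(10_000, 64))               # 10k estimates, 64 samples each
g = np.exp(w).mean(axis=1)                      # plain Monte Carlo estimates
h = w.mean(axis=1)
beta = np.cov(np.exp(w).ravel(), w.ravel())[0, 1] / np.var(w)   # beta* (Eq. 20)
g_cv = g - beta * (h - 0.0)                     # control variate estimates
assert g_cv.var() < g.var()                     # variance is reduced
assert abs(g_cv.mean() - np.exp(0.5)) < 0.01    # still (asymptotically) unbiased
```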

B.2 DERIVATION OF SNIS AS CONTROL VARIATE ESTIMATION

For notational convenience, we denote the importance weight as W(\omega_s) := p_n(\omega_s)/q(\omega_s). Then we have

\frac{\sum_{s=1}^{S} W(\omega_s) f(\omega_s)}{\sum_{s=1}^{S} W(\omega_s)}
= \frac{\sum_{s=1}^{S} W(\omega_s) f(\omega_s)}{\sum_{s=1}^{S} W(\omega_s)} - \frac{1}{S}\sum_{s=1}^{S} W(\omega_s) f(\omega_s) + \frac{1}{S}\sum_{s=1}^{S} W(\omega_s) f(\omega_s)
= \frac{1 - \frac{1}{S}\sum_{s=1}^{S} W(\omega_s)}{\sum_{s=1}^{S} W(\omega_s)} \sum_{s=1}^{S} W(\omega_s) f(\omega_s) + \frac{1}{S}\sum_{s=1}^{S} W(\omega_s) f(\omega_s)
= \frac{\sum_{s=1}^{S} W(\omega_s) f(\omega_s)}{\sum_{s=1}^{S} W(\omega_s)} \left(1 - \frac{1}{S}\sum_{s=1}^{S} W(\omega_s)\right) + \frac{1}{S}\sum_{s=1}^{S} W(\omega_s) f(\omega_s)
= \frac{1}{S}\sum_{s=1}^{S} W(\omega_s) f(\omega_s) - \frac{\sum_{s=1}^{S} W(\omega_s) f(\omega_s)}{\sum_{s=1}^{S} W(\omega_s)} \left(\frac{1}{S}\sum_{s=1}^{S} W(\omega_s) - 1\right)
= g(\omega) - \beta(\omega)\left(h(\omega) - \mathbb{E}[h(\omega)]\right).

Note that the expectation of the importance weights equals one, that is,

\mathbb{E}[h(\omega)] = \mathbb{E}\left[\frac{1}{S}\sum_{s=1}^{S} W(\omega_s)\right] = \mathbb{E}_{\omega_1, \ldots, \omega_S \sim q(\omega)}\left[\frac{1}{S}\sum_{s=1}^{S} \frac{p_n(\omega_s)}{q(\omega_s)}\right] = \frac{1}{S}\sum_{s=1}^{S} \mathbb{E}_{\omega_s \sim q(\omega)}\left[\frac{p_n(\omega_s)}{q(\omega_s)}\right] = 1.

Like SNIS, this estimator is still biased due to the dependence of \beta(\omega) on \omega. However, it asymptotically becomes unbiased, since \beta(\omega) is consistent and converges to a constant \beta given a large number of samples,

\beta(\omega) = \frac{g(\omega)}{h(\omega)} \xrightarrow{p} \frac{\mathbb{E}[g(\omega)]}{\mathbb{E}[h(\omega)]} = \mathbb{E}_{p_n(\omega)}[f(\omega)] := \beta.

B.3 DERIVATION OF THE EXPECTATION OF PER-TERM CONTROL VARIATES

According to the definition of the randomized mappings, we have

\mathbb{E}[h_m(\omega)] = \mathbb{E}_{\omega_1, \ldots, \omega_S \sim q(\omega)}\left[\frac{1}{S}\sum_{s=1}^{S} \frac{\mathcal{N}(\omega_s; 0, I)}{Z q(\omega_s)} \xi(q_n, \omega_s)\xi(k_m, \omega_s)\right] = \frac{1}{S}\sum_{s=1}^{S} \frac{1}{Z} \int \xi(q_n, \omega_s)\xi(k_m, \omega_s) \mathcal{N}(\omega_s; 0, I) \, d\omega_s = \frac{\exp(q_n^\top k_m)}{Z}.

B.4 PROOF OF PROPOSITION 1

Proof. We start from the formulation of g_m(\cdot) and h_m(\cdot):

\frac{g_m(\omega)}{h_m(\omega)} = \frac{\sum_{s=1}^{S} \frac{\mathcal{N}(\omega_s; 0, I)}{Z q(\omega_s)} \xi(q_n, \omega_s)\xi(k_m, \omega_s) v_m}{\sum_{s=1}^{S} \frac{\mathcal{N}(\omega_s; 0, I)}{Z q(\omega_s)} \xi(q_n, \omega_s)\xi(k_m, \omega_s)} = v_m.

As a result, we have g_m(\omega) = h_m(\omega) v_m and \mathbb{E}[g_m(\omega)] = \mathbb{E}[h_m(\omega)] v_m. We now derive the optimal \beta_m according to Equation 20,

\beta^*_m = \arg\min_{\beta_m} Tr(Var[\widetilde{g}_m(\omega)]) = \frac{Cov[h_m(\omega), g_m(\omega)]}{Var[h_m(\omega)]} = \frac{\mathbb{E}\left[(h_m(\omega) - \mathbb{E}[h_m(\omega)])(h_m(\omega) - \mathbb{E}[h_m(\omega)])\right] v_m}{\mathbb{E}\left[(h_m(\omega) - \mathbb{E}[h_m(\omega)])(h_m(\omega) - \mathbb{E}[h_m(\omega)])\right]} = v_m = \frac{g_m(\omega)}{h_m(\omega)}.

In terms of the variance, we again use g_m(\omega) = h_m(\omega) v_m to obtain

\widetilde{g}_m(\omega) = g_m(\omega) - \beta_m\left(h_m(\omega) - \mathbb{E}[h_m(\omega)]\right) = g_m(\omega) - v_m h_m(\omega) + v_m \mathbb{E}[h_m(\omega)] = v_m \mathbb{E}[h_m(\omega)] = \frac{\exp(q_n^\top k_m)}{Z} v_m.

Since this holds for every term m = 1, \ldots, M, the estimate becomes exactly softmax attention,

\widetilde{g}(\omega) = \sum_{m=1}^{M} \widetilde{g}_m(\omega) = \sum_{m=1}^{M} \frac{\exp(q_n^\top k_m)}{Z} v_m = \sum_{m=1}^{M} \frac{\exp(q_n^\top k_m)}{\sum_{m'=1}^{M} \exp(q_n^\top k_{m'})} v_m.

Since all randomness is eliminated, the estimate is exact with zero bias and zero variance. That is, RFA(q_n, K, V) = \widetilde{g}(\omega) = SoftmaxAttn(q_n, K, V).

B.5 A FORMAL ANALYSIS OF THE ADVANTAGE OF PARTITIONING

In this section, we demonstrate the advantage of partitioning by showing that there always exists a set \{\beta_c\}_{c=1}^C that achieves lower weighted MSE than any globally shared coefficient, as discussed in §4.1.

Lemma 3. Suppose \beta^*_m is the optimal coefficient for each m \in [M] as defined in Proposition 1, and P_1, P_2, \ldots, P_C form an arbitrary partition of [M], where each subset P_c is associated with a distinct \beta_c. We consider the following weighted mean squared error,

J(\beta_1, \ldots, \beta_C) := \sum_{c=1}^{C}\sum_{m \in P_c} \alpha_m \|\beta_c - \beta^*_m\|^2,  (24)

where \alpha_m > 0 for each m \in [M] and \sum_{c=1}^{C}\sum_{m \in P_c} \alpha_m = 1. Then for any choice of \{\alpha_m\}_{m=1}^{M} and any globally shared coefficient \beta, there exist \{\beta^*_c\}_{c=1}^C such that J(\beta_1 = \beta, \ldots, \beta_C = \beta) \ge J(\beta_1 = \beta^*_1, \ldots, \beta_C = \beta^*_C).

Proof. Let \beta^*_c = \frac{\sum_{m \in P_c} \alpha_m \beta^*_m}{\sum_{m \in P_c} \alpha_m} for each c = 1, \ldots, C. Then we have

\sum_{m \in P_c} \alpha_m (\beta^*_c - \beta^*_m) = \beta^*_c \sum_{m \in P_c} \alpha_m - \sum_{m \in P_c} \alpha_m \beta^*_m = \frac{\sum_{m \in P_c} \alpha_m \beta^*_m}{\sum_{m \in P_c} \alpha_m} \sum_{m \in P_c} \alpha_m - \sum_{m \in P_c} \alpha_m \beta^*_m = \sum_{m \in P_c} \alpha_m \beta^*_m - \sum_{m \in P_c} \alpha_m \beta^*_m = 0.  (25)

According to Equations 25 and 24, for any \beta we have the following inequality,

J(\beta_1 = \beta, \ldots, \beta_C = \beta) = \sum_{c=1}^{C}\sum_{m \in P_c} \alpha_m \|\beta - \beta^*_m\|^2
= \sum_{c=1}^{C}\sum_{m \in P_c} \alpha_m \|\beta - \beta^*_c + \beta^*_c - \beta^*_m\|^2
= \sum_{c=1}^{C}\sum_{m \in P_c} \alpha_m \left[\|\beta - \beta^*_c\|^2 + 2(\beta - \beta^*_c)^\top(\beta^*_c - \beta^*_m) + \|\beta^*_c - \beta^*_m\|^2\right]
= \sum_{c=1}^{C}\sum_{m \in P_c} \alpha_m \|\beta - \beta^*_c\|^2 + 2\sum_{c=1}^{C}(\beta - \beta^*_c)^\top \underbrace{\sum_{m \in P_c} \alpha_m (\beta^*_c - \beta^*_m)}_{= 0} + \sum_{c=1}^{C}\sum_{m \in P_c} \alpha_m \|\beta^*_c - \beta^*_m\|^2
= \sum_{c=1}^{C}\sum_{m \in P_c} \alpha_m \|\beta - \beta^*_c\|^2 + \sum_{c=1}^{C}\sum_{m \in P_c} \alpha_m \|\beta^*_c - \beta^*_m\|^2
\ge \sum_{c=1}^{C}\sum_{m \in P_c} \alpha_m \|\beta^*_c - \beta^*_m\|^2 = J(\beta_1 = \beta^*_1, \ldots, \beta_C = \beta^*_C).

As a result, for any choice of \{\alpha_m\}_{m=1}^{M} and any globally shared coefficient \beta, there always exists a set \{\beta_c\}_{c=1}^C that achieves lower (or equal) weighted MSE, and one such solution is simply \beta_c = \frac{\sum_{m \in P_c} \alpha_m \beta^*_m}{\sum_{m \in P_c} \alpha_m}.
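Lemma 3 is easy to check numerically: the \alpha-weighted mean of the \beta^*_m within each subset never does worse than any shared coefficient. A small sketch with synthetic coefficients (all names and sizes here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(7)
M, C, d = 12, 3, 4
beta_star = rng.normal(size=(M, d))             # per-token optimal coefficients
alpha = rng.random(M)
alpha = alpha / alpha.sum()                     # arbitrary positive weights
parts = np.array_split(np.arange(M), C)

def J(betas):                                   # weighted MSE of Equation 24
    return sum(alpha[m] * np.sum((betas[c] - beta_star[m]) ** 2)
               for c, P in enumerate(parts) for m in P)

# Lemma 3: the alpha-weighted mean within each subset beats any shared beta
opt = [(alpha[P][:, None] * beta_star[P]).sum(axis=0) / alpha[P].sum()
       for P in parts]
for beta in rng.normal(size=(20, d)):           # 20 random shared coefficients
    assert J(opt) <= J([beta] * C) + 1e-12
```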

B.6 PROOF OF PROPOSITION 2

Proof. We first consider the case of partitioned indices, where each subset P_c is associated with some specific β_c. To find the global minimum of J, we differentiate on both sides and obtain

\[
\frac{\partial J(\beta_1, \dots, \beta_C)}{\partial \beta_c} = \frac{\partial}{\partial \beta_c} \sum_{c=1}^C \sum_{m \in P_c} \frac{\exp(q_n^\top k_m)}{\sum_{m' \in U} \exp(q_n^\top k_{m'})} \lVert \beta_c - \beta_m^* \rVert^2 = \sum_{m \in P_c} \frac{\exp(q_n^\top k_m)}{\sum_{m' \in U} \exp(q_n^\top k_{m'})}\, 2 (\beta_c - \beta_m^*).
\]

By setting the partial derivative to zero, we obtain

\[
\beta_c^* = \frac{\sum_{m \in P_c} \exp(q_n^\top k_m)\, \beta_m^*}{\sum_{m \in P_c} \exp(q_n^\top k_m)} = \frac{\sum_{m \in P_c} \exp(q_n^\top k_m)\, v_m}{\sum_{m \in P_c} \exp(q_n^\top k_m)} = \frac{\sum_{m \in P_c} \mathbb E[g_m(\omega)]}{\sum_{m \in P_c} \mathbb E[h_m(\omega)]} = \frac{\mathbb E\!\left[\sum_{m \in P_c} g_m(\omega)\right]}{\mathbb E\!\left[\sum_{m \in P_c} h_m(\omega)\right]}.
\]

As a consequence, with β_c = β*_c, the partitioned scheme must achieve lower weighted mean squared error than any globally shared β, that is, J(β_1 = β*_1, ..., β_C = β*_C) ≤ J(β_1 = β, ..., β_C = β). In fact, with β_c = β*_c, the partitioned scheme usually enjoys much lower error than adopting a globally shared coefficient. To see the error reduction of the partitioned strategy, we first expand

\[
J(\beta_1 = \beta, \dots, \beta_C = \beta) = \sum_{c=1}^C \sum_{m \in P_c} \frac{\exp(q_n^\top k_m)}{\sum_{m' \in U} \exp(q_n^\top k_{m'})} \lVert \beta - \beta_m^* \rVert^2 = \sum_{c=1}^C \sum_{m \in P_c} \frac{\exp(q_n^\top k_m)}{\sum_{m' \in U} \exp(q_n^\top k_{m'})} \lVert \beta - \beta_c^* + \beta_c^* - \beta_m^* \rVert^2
\]
\[
= \sum_{c=1}^C \sum_{m \in P_c} \frac{\exp(q_n^\top k_m)}{\sum_{m' \in U} \exp(q_n^\top k_{m'})} \left( \lVert \beta - \beta_c^* \rVert^2 + 2 (\beta - \beta_c^*)^\top (\beta_c^* - \beta_m^*) + \lVert \beta_c^* - \beta_m^* \rVert^2 \right).
\]

Since β*_m = v_m (Proposition 1) and β*_c = Σ_{m∈P_c} exp(q_n^⊤k_m) v_m / Σ_{m∈P_c} exp(q_n^⊤k_m), the cross term vanishes:

\[
\sum_{c=1}^C \sum_{m \in P_c} \frac{\exp(q_n^\top k_m)}{\sum_{m' \in U} \exp(q_n^\top k_{m'})} (\beta - \beta_c^*)^\top (\beta_c^* - \beta_m^*) = \sum_{c=1}^C (\beta - \beta_c^*)^\top \frac{\sum_{m \in P_c} \exp(q_n^\top k_m)\, (\beta_c^* - v_m)}{\sum_{m' \in U} \exp(q_n^\top k_{m'})} = 0,
\]

where the inner sum is zero by the definition of β*_c. Plugging this result back, we obtain

\[
J(\beta_1 = \beta, \dots, \beta_C = \beta) = \underbrace{\sum_{c=1}^C \sum_{m \in P_c} \frac{\exp(q_n^\top k_m)}{\sum_{m' \in U} \exp(q_n^\top k_{m'})} \lVert \beta - \beta_c^* \rVert^2}_{\ge 0} + \sum_{c=1}^C \sum_{m \in P_c} \frac{\exp(q_n^\top k_m)}{\sum_{m' \in U} \exp(q_n^\top k_{m'})} \lVert \beta_c^* - \beta_m^* \rVert^2
\]
\[
\ge \sum_{c=1}^C \sum_{m \in P_c} \frac{\exp(q_n^\top k_m)}{\sum_{m' \in U} \exp(q_n^\top k_{m'})} \lVert \beta_c^* - \beta_m^* \rVert^2.
\]

The last inequality holds since the first term is always non-negative. Note that the first term computes the squared error between β and each β*_c, weighted by the sum of attention scores over the corresponding subset. As a result, it is usually positive, and the error reduction is significant if each β*_c deviates a lot from β. However, although the optimal partitioned coefficients always lead to lower weighted error with respect to the optimal individual coefficients, this does not necessarily yield lower estimation variance.

[Ablation on the proposal parameterization q_c(ω) := N(ω; µ_c, I) (cf. Table 8): µ_c = q̃_c + k̃_c achieves 76.67; µ_c = q̃_c achieves 76.77; µ_c = 0 achieves 76.24.]

According to the definition of g_m(·) and h_m(·) in §3.1, for the vanilla RFA (Equation 4) we have

\[
\mathrm{RFA}(q_n, K, V) = \frac{\sum_{s=1}^S \frac{p_n(\omega_s)}{q(\omega_s)} f(\omega_s)}{\sum_{s=1}^S \frac{p_n(\omega_s)}{q(\omega_s)}} = \frac{g(\omega)}{h(\omega)} = \frac{\sum_{m=1}^M g_m(\omega)}{\sum_{m=1}^M h_m(\omega)}.
\]

Besides, since v_m = g_m(ω)/h_m(ω) and β_c(ω) = Σ_{m∈P_c} g_m(ω) / Σ_{m∈P_c} h_m(ω), we can re-write EVA as

\[
\mathrm{EVA}(q_n, K, V) := \widehat g(\omega) = \sum_{m \in E} \widehat g_m(\omega) + \sum_{m \notin E} \widehat g_m(\omega) = \sum_{m \in E} \frac{\exp(q_n^\top k_m)}{Z} v_m + \sum_{c=1}^C \sum_{m \in P_c} \frac{\exp(q_n^\top k_m)}{Z} \beta_c(\omega)
\]
\[
= \sum_{m \in E} \frac{\exp(q_n^\top k_m)}{Z} \frac{g_m(\omega)}{h_m(\omega)} + \sum_{c=1}^C \frac{\sum_{m \in P_c} \exp(q_n^\top k_m)}{Z} \frac{\sum_{m \in P_c} g_m(\omega)}{\sum_{m \in P_c} h_m(\omega)}.
\]
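This rewritten form of EVA can be sketched in a few lines of NumPy (a toy setup of ours, with the proposal fixed to N(0, I) so that the per-sample importance weights cancel inside each β_c(ω)). Taking every P_c to be a singleton makes β_c(ω) = v_m, so the estimator collapses to exact softmax attention, matching the limiting case discussed in Appendix G.1.

```python
import numpy as np

rng = np.random.default_rng(2)
d, M, S = 4, 6, 4
q = rng.normal(size=d) * 0.3
K = rng.normal(size=(M, d)) * 0.3
V = rng.normal(size=(M, d))
scores = np.exp(K @ q)
Z = scores.sum()

w = rng.normal(size=(S, d))                    # shared samples, proposal N(0, I)
xi_q = np.exp(w @ q - 0.5 * q @ q)             # (S,)
xi_K = np.exp(w @ K.T - 0.5 * (K * K).sum(-1)) # (S, M)
prod = (xi_q[:, None] * xi_K).mean(0)          # RF estimates of exp(q.k_m), shape (M,)

def eva(E, parts):
    out = (scores[E] / Z) @ V[E]               # exact terms for m in E
    for P in parts:                            # shared beta_c(w) for each subset P_c
        beta_c = prod[P] @ V[P] / prod[P].sum()
        out = out + (scores[P].sum() / Z) * beta_c
    return out

softmax_attn = (scores / Z) @ V
# singleton partitions: beta_c(w) = v_m exactly, so EVA recovers softmax attention
est = eva(np.array([], dtype=int), [np.array([m]) for m in range(M)])
assert np.allclose(est, softmax_attn)
```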

C MORE IMPLEMENTATION DETAILS FOR EVA

In this section, we provide more details of EVA. We also conduct a comprehensive ablation study to test the effect of different components in our implementation and report the results in Table 8. The pseudo-code for EVA is listed in Algorithm 1.

Approximating Σ_{m∈P_c} exp(q_n^⊤k_m) and Parameterizing k̃_c. In our implementation, we approximate the sum of exponentials as Σ_{m∈P_c} exp(q_n^⊤k_m) ≈ exp(q_n^⊤k̃_c). Here we provide an informal justification for this approximation. Our main motivation is the simple intuition that the sum of exponentials grows as fast as the maximum exponential value, as reflected by the following inequality:

\[
\max_{m \in P_c} \exp(q_n^\top k_m) \le \sum_{m \in P_c} \exp(q_n^\top k_m) \le |P_c| \max_{m \in P_c} \exp(q_n^\top k_m).
\]

This means we could approximate the sum of exponentials by first computing the group representative k̃_c := argmax_{k_m : m∈P_c} exp(q_n^⊤k_m), evaluating the corresponding exponential exp(q_n^⊤k̃_c), and then multiplying it by some scalar. However, since the argmax operation still needs to compare each exponential dot product, it would incur quadratic computational costs. To circumvent this, we adopt a heuristic strategy that computes a learnable group representative, which attempts to compensate for the approximation error while evaluating only one exponential dot product. In preliminary experiments, we tried various choices for computing the representative vector of each subset, such as max and average pooling, and found that these strategies produce almost equally good performance. As a result, we adopt average pooling by default due to its simplicity. To be specific, we implement it as k̃_c = σ((1/|P_c|) Σ_{m∈P_c} k_m), where σ(·) is a trainable linear projection with the same hidden dimension as the inputs, followed by a layer normalization operation (Ba et al., 2016) to stabilize training.
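A minimal sketch of the learnable group representative described above, assuming a randomly initialized projection and a scale-free layer normalization (the trainable gain and bias of Ba et al. (2016) are omitted for brevity):

```python
import numpy as np

rng = np.random.default_rng(3)
d, C, P = 8, 4, 16                     # hidden size, #subsets, tokens per subset
K = rng.normal(size=(C, P, d))         # keys, grouped into the subsets P_c

# parameters of the trainable projection sigma(.) -- randomly initialized here
W = rng.normal(size=(d, d)) / np.sqrt(d)
b = np.zeros(d)

def layer_norm(x, eps=1e-5):
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

# k_c = LayerNorm(sigma(average of keys in P_c)): one representative per subset
k_rep = layer_norm(K.mean(axis=1) @ W + b)     # (C, d)
assert k_rep.shape == (C, d)
```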
We leave further improvements to the approximation, such as deriving tighter error bounds or using more expressive pooling methods (Zaheer et al., 2017; Ou et al., 2022), as future work.

Parameterizing β_c(ω). As discussed in §4.1, we have

\[
\beta_c(\omega) = \frac{\sum_{m \in P_c} g_m(\omega)}{\sum_{m \in P_c} h_m(\omega)} = \frac{\sum_{s=1}^S \frac{\mathcal N(\omega_s;0,I)}{Z\, q(\omega_s)} \sum_{m \in P_c} \xi(q_n, \omega_s)\, \xi(k_m, \omega_s)\, v_m}{\sum_{s=1}^S \frac{\mathcal N(\omega_s;0,I)}{Z\, q(\omega_s)} \sum_{m \in P_c} \xi(q_n, \omega_s)\, \xi(k_m, \omega_s)}.
\]

Compared to the SNIS formulation of vanilla RFA (Equation 4),

\[
\mathrm{RFA}(q_n, K, V) = \frac{\sum_{s=1}^S \frac{p_n(\omega_s)}{q(\omega_s)} f(\omega_s)}{\sum_{s=1}^S \frac{p_n(\omega_s)}{q(\omega_s)}} = \frac{\sum_{m=1}^M g_m(\omega)}{\sum_{m=1}^M h_m(\omega)},
\]

we can think of each coefficient β_c(ω) as computing the output of a localized RFA for the group P_c. From this perspective, we can recast each coefficient β_c(ω) as an SNIS estimator as well, which estimates

\[
\mathbb E_{\omega \sim p_c(\omega)}[f_c(\omega)] = \sum_{m \in P_c} \frac{\exp(q_n^\top k_m)}{\sum_{m' \in P_c} \exp(q_n^\top k_{m'})} v_m,
\]

where

\[
f_c(\omega) := \frac{\sum_{m \in P_c} \xi(q_n, \omega)\, \xi(k_m, \omega)\, v_m}{\sum_{m' \in P_c} \xi(q_n, \omega)\, \xi(k_{m'}, \omega)}, \qquad
p_c(\omega) := \frac{\mathcal N(\omega; 0, I) \sum_{m \in P_c} \xi(q_n, \omega)^\top \xi(k_m, \omega)}{\sum_{m' \in P_c} \exp(q_n^\top k_{m'})} = \sum_{m \in P_c} \frac{\exp(q_n^\top k_m)}{\sum_{m' \in P_c} \exp(q_n^\top k_{m'})}\, \mathcal N(\omega; q_n + k_m, I).
\]

This interpretation indicates that a good proposal distribution q_c(ω) should be specific to each subset P_c. To get close to the true distribution p_c(ω) while keeping computation efficient, Zheng et al. (2022b) suggest parameterizing the proposal distribution as q_c(ω) := N(ω; µ_c, I) = N(ω; q̃_c + k̃_c, I), where q̃_c is calculated similarly to Equation 26. We refer readers to Zheng et al. (2022b) for more discussion of the parameterization choice of proposal distributions. We conduct further ablation studies to test the effect of proposal parameterizations in our proposed model, as shown in Table 8. In particular, we found our model is robust to different parameterization approaches. The key to making the algorithm memory-efficient is to use only one sample in calculating β_c(ω).
In this case, we have

\[
\beta_c(\omega) = \frac{\sum_{m \in P_c} g_m(\omega)}{\sum_{m \in P_c} h_m(\omega)} = \frac{\frac{\mathcal N(\omega_c;0,I)}{Z\, q_c(\omega_c)}\, \xi(q_n, \omega_c) \sum_{m \in P_c} \xi(k_m, \omega_c)\, v_m}{\frac{\mathcal N(\omega_c;0,I)}{Z\, q_c(\omega_c)}\, \xi(q_n, \omega_c) \sum_{m \in P_c} \xi(k_m, \omega_c)} = \frac{\sum_{m \in P_c} \xi(k_m, \omega_c)\, v_m}{\sum_{m \in P_c} \xi(k_m, \omega_c)}, \qquad \omega_c \sim q_c(\omega).
\]

Since this degenerate formulation eliminates the dependence on the individual query q_n, we can precompute β_c(ω) for each P_c and then reuse it for every query, which takes up O(Cd) memory. If multiple samples are used instead, the influence of queries must be taken into account explicitly, and thus we need to compute a distinct β_c(ω) for each query, leading to O(NCd) memory usage and a significant computational overhead. On the other hand, if we set C = 1, that is, we use a shared β_c(ω) over all m ∉ E, our approach does not suffer from this issue, since the memory usage is at most O(Nd). To investigate the effect of using larger C or increasing the number of samples, we conduct an ablative analysis in Table 8 and find that 1) when C = 1, the performance degrades a lot when using one sample, which can be largely improved by adopting more samples; while 2) when C > 1, our partitioning strategy dominates and increasing the number of samples only improves performance marginally. This also validates the effectiveness of adopting a finer-grained treatment of control variates.

Partitioning Strategy. EVA significantly improves the random feature approximation by locally estimating each subset of tokens, which is a much easier task than approximating the whole sequence as in previous RFA methods. To achieve this, EVA partitions the whole token sequence into multiple subsets according to the current query position n, denoted by {E^n, P^n_1, P^n_2, ..., P^n_C}_{n=1}^N.
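The single-sample computation of β_c(ω) can be sketched as follows (toy shapes of ours). Note that nothing query-dependent remains, so the C coefficients are computed once and shared by all N queries:

```python
import numpy as np

rng = np.random.default_rng(4)
d, C, P = 4, 3, 5
K = rng.normal(size=(C, P, d)) * 0.3    # keys, grouped into C subsets P_c
V = rng.normal(size=(C, P, d))          # corresponding values
w = rng.normal(size=(C, d))             # ONE sample w_c ~ q_c(w) per subset

# xi(k, w_c) = exp(w_c.k - |k|^2 / 2); with a single sample, the query-dependent
# factors cancel, so each beta_c is precomputed once and reused by all queries
feats = np.exp(np.einsum('cpd,cd->cp', K, w) - 0.5 * (K * K).sum(-1))   # (C, P)
beta = np.einsum('cp,cpd->cd', feats, V) / feats.sum(-1, keepdims=True) # (C, d)

assert beta.shape == (C, d)             # O(C d) storage, independent of #queries
```

Since the feature weights are positive and normalized, each β_c is a convex combination of the values in its subset.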
For elements in subset E^n, we optimize the control variate coefficient to give an exact estimate for each single token m ∈ E^n. In addition, we impose T5-style relative positional encoding (Raffel et al., 2020a) over elements in E^n. For the other subsets P^n_c, we employ a shared coefficient to approximate all tokens belonging to P^n_c. We assume all E^1, ..., E^N have the same cardinality K, and that |P^n_c| is the same for any c = 1, ..., C and n = 1, ..., N. The partition strategy {E^n, P^n_1, P^n_2, ..., P^n_C}_{n=1}^N is decided based on a simple criterion:
• for E^n, it contains the K local neighbors of each query n. To further simplify the implementation and reduce memory usage, we chunk the whole sequence into contiguous blocks of size K, and all adjacent queries belonging to the same block share this block as the subset E^n;
• as for P^n_1, P^n_2, ..., P^n_C, we follow a similar treatment by splitting the complement [M] \ E^n into C contiguous chunks of the same size. For ease of implementation, we simply partition the whole index set [M] into multiple groups instead of [M] \ E^n, which circumvents the overhead of explicitly performing set difference operations in practice. Although this introduces extra approximation error, it amounts to putting more attention weight on tokens belonging to the subset E^n, and we found this approximation does not lead to performance degradation (Table 8).
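The chunk-based criterion above can be sketched as plain index arithmetic (a simplified sketch of ours; an actual implementation would operate on tensors rather than Python lists):

```python
def partition(M, block, C):
    """Chunk [0, M) into contiguous E-blocks of size `block`, and split the
    whole index set into C contiguous groups P_1..P_C (we partition [M] rather
    than [M] minus E_n, matching the simplification described above)."""
    e_blocks = [list(range(i, min(i + block, M))) for i in range(0, M, block)]
    step = M // C
    groups = [list(range(c * step, (c + 1) * step if c < C - 1 else M))
              for c in range(C)]
    return e_blocks, groups

e_blocks, groups = partition(M=16, block=4, C=2)
assert e_blocks[0] == [0, 1, 2, 3]        # queries 0-3 share this block as E_n
assert sum(len(g) for g in groups) == 16  # the groups cover the whole index set
```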

D A CAUSAL VARIANT OF EVA

In this section, we describe the causal variant of EVA, where each query can only attend to historical tokens. Thanks to the partitioning scheme, all future information with respect to the current query token can be masked conveniently. Following the formulation of EVA, we partition the whole sequence into C + 1 subsets {E^n, P^n_1, P^n_2, ..., P^n_C} with respect to each query q_n. To fulfill the causal requirement, we design two different types of masking matrices to deal with E^n and {P^n_c}_{c=1}^C respectively; the details are described below and the pseudo-code is listed in Algorithm 2.

For the autoregressive language modeling experiments (Appendix E.3), the corresponding configurations are listed in Table 14. The vocabulary size is 267,744 with adaptive input embeddings (Baevski & Auli, 2019). During training, we set the sequence length to 512 and evaluate the validation/test PPL with context window sizes in {256, 480}, aligning with previous work (Baevski & Auli, 2019; Kasai et al., 2021). For other random feature baselines, unfortunately, we failed to fully replicate their results as reported in Kasai et al. (2021): RFA in our implementation achieved a test perplexity of 29.0 even with a 449M Transformer model. For EVA, we set |E| = 128 and C = 64 by default for both the 16-layer and 32-layer settings, ensuring a computational cost similar to previous work that also evaluates random feature methods (typically with 128 or 256 random feature dimensions) on the Wikitext-103 language modeling task (Schlag et al., 2021; Kasai et al., 2021).

E.4 EXPERIMENTAL SETTINGS OF EFFICIENCY COMPARISON

For the simulation experiment conducted in §5.3, we adopt the same transformer architecture across all attention variants. In particular, it uses 8 transformer layers, 192 embedding dimensions, and 2 attention heads so that longer sequences can fit into our devices. The batch size is set to 64 across 8 V100 GPUs, and the statistics are computed by averaging the results of 30 runs. Besides, in our ablation study, the efficiency metrics reported in Table 6 and Table 7 are evaluated under the same setup used during training.

Remark on Modeling Short Sequences. Unfortunately, similar to most previous efficient attention baselines, EVA runs slower than softmax attention on shorter sequences (e.g., lengths of 128 or 256), but it soon catches up in running speed, and the reduction in memory consumption remains significant. Besides, in short-sequence settings (such as DeiT-Tiny/Small with sequences of 196 tokens), EVA often performs on par with or better than conventional softmax attention (see Table 1), whereas most previous attention variants usually perform much worse. This implies that EVA achieves a better trade-off between efficiency and quality: for short sequences, EVA can achieve performance competitive with softmax attention (albeit with longer running time), while for long sequences, EVA runs much faster with less memory.

Comparison to Memory-efficient Attention Mechanisms. In this section, we conduct an empirical efficiency comparison between efficient approximate attention methods and FlashAttention, a memory-efficient attention mechanism (Rabe & Staats, 2021; Dao et al., 2022) with optimized memory accesses. FlashAttention computes exact softmax attention in an online manner without materializing the full attention matrix, achieving linear memory complexity with respect to sequence length; in addition, both runtime and memory usage are further improved by minimizing IO accesses.
We benchmark different attention modules on one NVIDIA GeForce RTX 3090 GPU, measuring the memory usage and runtime of a single attention block (8 attention heads with 512 embedding dimensions) for a forward and backward pass. As shown in Figure 2, we observe that FlashAttention achieves a significant memory usage reduction for exact softmax attention and even consumes much less memory than all considered approximate baselines at all sequence lengths. In terms of runtime, we notice that FlashAttention runs faster than most attention baselines at sequence lengths below 2048 despite scaling quadratically, but EVA, along with other more efficient approximate variants, begins to catch up at longer sequence lengths. This implies that the quadratic computational cost of softmax attention still bottlenecks its runtime performance, aligning with one of the main findings in Dao et al. (2022). According to this empirical study, FlashAttention offers a general and effective technique to speed up softmax attention; since many approximate variants (including EVA) exhibit a formulation similar to softmax attention (e.g., Equation 16), we expect they can also benefit from the optimized online softmax calculation and memory accesses of FlashAttention (Dao et al., 2022).

F EXPERIMENTS ON LONG RANGE ARENA

Long Range Arena (LRA; Tay et al., 2021) is a lightweight benchmark that assesses the ability of efficient attention methods to model long sequences in diverse domains. We follow the same hyper-parameter setup as Xiong et al. (2021b) to re-evaluate all attention baselines and report the results. We observe that EVA largely improves previous RFA methods such as Performer (Choromanski et al., 2021) and LARA (Zheng et al., 2022b), and performs competitively with full softmax attention. Notably, EVA even achieves better average results over all tasks, with higher accuracy on the Image and Pathfinder benchmarks, suggesting its capability of capturing long-term dependencies. For the LRA benchmark, we set all attention-specific hyper-parameters to 128 (e.g., the number of landmarks in Nyströmformer (Xiong et al., 2021b) and LARA (Zheng et al., 2022b), the window size in local attention and Combiner (Ren et al., 2021), etc.). We set |E| = 128 and C = 64 by default for EVA without any further tuning and find this setup works well.

G CONNECTIONS TO OTHER ATTENTION MECHANISMS

G.1 RFA, SOFTMAX ATTENTION, AND EVA

As mentioned in the main text, one of the main contributions of this work is to develop a more general framework that bridges RFA and conventional softmax attention. To see how EVA (Equation 13) achieves this goal formally, note that if either |E| = M or C = M, EVA is equivalent to standard softmax attention, while if we set |E| = 0 and C = 1, EVA recovers vanilla RFA.

G.2 CONNECTIONS TO LARA

Notably, EVA and LARA (Zheng et al., 2022b) are two efficient attention mechanisms that are both built upon the self-normalized importance sampling (SNIS) formulation of RFA. LARA puts the main focus on the proposal distribution used in SNIS and tries to design importance sampling proposals that are closer to the true underlying distribution. Its use of multiple proposals further improves the estimation quality of SNIS and achieves strong empirical performance while still keeping linear complexity. In contrast to LARA, in this work we do not focus on the design choice of proposals used in importance sampling but aim to generalize the SNIS formulation further via control variates. As demonstrated in §3.2, our theory clearly delineates how the gap between such SNIS estimation and softmax attention can be closed by manipulating control variates. Since LARA and RFA are both SNIS estimators (their main difference lies in the choice of proposal distributions), our generalization also applies to LARA. To summarize, compared with LARA, EVA is a more general framework and improves conventional RFA from an orthogonal perspective.

G.3 CONNECTIONS TO CLUSTERED ATTENTION

Clustered attention (Vyas et al., 2020) is an efficient attention mechanism that first clusters the set of queries into multiple groups, computes the mean centroid of each group, and then performs attention between query centroids and the original key-value pairs. This framework is fast and effective and enjoys a well-bounded approximation error.

G.4 CONNECTIONS TO SCATTERBRAIN

Scatterbrain is a Special Case of EVA. For notational convenience, we denote E := Supp_n(S). According to Proposition 1, suppose we employ the optimal coefficient β_m for all entries in Supp_n(S) and use the same coefficient β for all remaining entries (in other words, we let C = 1 and partition the whole index set into two subsets {E, [M] \ E}). Then we have

\[
\widehat g_m(\omega) =
\begin{cases}
g_m(\omega) - \beta_m h_m(\omega) + \beta_m \frac{\exp(q_n^\top k_m)}{Z} = \frac{\exp(q_n^\top k_m)\, v_m}{Z}, & \text{if } m \in E, \\[4pt]
g_m(\omega) - \beta h_m(\omega) + \beta \frac{\exp(q_n^\top k_m)}{Z}, & \text{if } m \notin E,
\end{cases}
\]

and the resulting overall estimator becomes

\[
\widehat g(\omega) = \sum_{m=1}^M \widehat g_m(\omega) = \sum_{m \in E} \frac{\exp(q_n^\top k_m)\, v_m}{Z} + \sum_{m \notin E} \left( g_m(\omega) - \beta h_m(\omega) \right) + \beta \sum_{m \notin E} \frac{\exp(q_n^\top k_m)}{Z}
\]
\[
= \sum_{m \in E} \frac{\exp(q_n^\top k_m)\, v_m}{Z} + \sum_{m \notin E} \left( g_m(\omega) - \beta h_m(\omega) \right) + \beta \left( 1 - \sum_{m \in E} \frac{\exp(q_n^\top k_m)}{Z} \right).
\]

Scatterbrain (Chen et al., 2021a) can be recovered as a special case of this estimator if we set the proposal distribution to q(ω) = N(ω; 0, I) and estimate the normalizing constant as

\[
\widehat Z = \sum_{m \in E} \exp(q_n^\top k_m) + \frac{1}{S} \sum_{s=1}^S \sum_{m \notin E} \xi(q_n, \omega_s)\, \xi(k_m, \omega_s),
\]

which is unbiased since E_{ω∼N(0,I)}[ξ(q_n, ω) ξ(k_m, ω)] = exp(q_n^⊤k_m). With these specifications, we obtain

\[
\widehat g(\omega) = \sum_{m \in E} \frac{\exp(q_n^\top k_m)\, v_m}{Z} + \sum_{m \notin E} \left( g_m(\omega) - \beta h_m(\omega) \right) + \beta\, \frac{Z - \sum_{m \in E} \exp(q_n^\top k_m)}{Z}
\]
\[
\approx \sum_{m \in E} \frac{\exp(q_n^\top k_m)\, v_m}{Z} + \sum_{m \notin E} \left( g_m(\omega) - \beta h_m(\omega) \right) + \beta \sum_{m \notin E} h_m(\omega)
= \sum_{m \in E} \frac{\exp(q_n^\top k_m)\, v_m}{Z} + \sum_{m \notin E} g_m(\omega)
\]
\[
= \sum_{m \in E} \frac{\exp(q_n^\top k_m)\, v_m}{Z} + \sum_{m \notin E} \frac{\frac{1}{S}\sum_{s=1}^S \xi(q_n, \omega_s)\, \xi(k_m, \omega_s)\, v_m}{Z}
= \sum_{m \in E} \frac{\exp(q_n^\top k_m)\, v_m}{Z} + \sum_{m \notin E} \frac{\phi(q_n, \omega)^\top \phi(k_m, \omega)\, v_m}{Z}
\]
\[
\approx \frac{\sum_{m \notin E} \phi(q_n, \omega)^\top \phi(k_m, \omega)\, v_m + \sum_{m' \in E} \exp(q_n^\top k_{m'})\, v_{m'}}{\sum_{m \notin E} \phi(q_n, \omega)^\top \phi(k_m, \omega) + \sum_{m' \in E} \exp(q_n^\top k_{m'})},
\]

where the first ≈ uses Σ_{m∉E} h_m(ω) as an estimate of (Z − Σ_{m∈E} exp(q_n^⊤k_m))/Z and the last step estimates the unknown Z with Ẑ. The final expression is exactly Scatterbrain (Equation 30). Note that this equivalence holds irrespective of the choice of the shared coefficient β, which possibly indicates that the formulation of Scatterbrain limits the potential benefit of optimizing control variates under our framework.
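The β-independence can be verified numerically. In the sketch below (a toy setup of ours), we use Ẑ in place of Z throughout, in which case the terms involving β cancel exactly and the estimator equals the Scatterbrain form for any shared coefficient:

```python
import numpy as np

rng = np.random.default_rng(5)
d, M, S = 4, 8, 6
q = rng.normal(size=d) * 0.3
K = rng.normal(size=(M, d)) * 0.3
V = rng.normal(size=(M, d))
E = np.array([0, 1, 2])                      # exactly-treated entries
R = np.array([3, 4, 5, 6, 7])                # the remainder [M] \ E

w = rng.normal(size=(S, d))                  # proposal q(w) = N(0, I)
xi_q = np.exp(w @ q - 0.5 * q @ q)           # (S,)
xi_R = np.exp(w @ K[R].T - 0.5 * (K[R] * K[R]).sum(-1))   # (S, |R|)
approx = (xi_q[:, None] * xi_R).mean(0)      # RF estimates of exp(q.k_m), m not in E

exp_E = np.exp(K[E] @ q)
Z_hat = exp_E.sum() + approx.sum()           # Scatterbrain-style estimate of Z

g = approx[:, None] * V[R] / Z_hat           # g_m(w), with Z replaced by Z_hat
h = approx / Z_hat                           # h_m(w), likewise

def estimator(beta):                         # the control-variate form above
    return ((exp_E / Z_hat) @ V[E]
            + (g - beta * h[:, None]).sum(0)
            + beta * (1.0 - exp_E.sum() / Z_hat))

scatterbrain = (approx @ V[R] + exp_E @ V[E]) / Z_hat
for beta in (np.zeros(d), rng.normal(size=d)):
    assert np.allclose(estimator(beta), scatterbrain)   # beta cancels exactly
```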



Footnotes:
• Note that the expectation of the individual control variate h_m(·) is still in closed form as E[h_m(ω)] = exp(q_n^⊤k_m)/Z. The derivation can be found in Appendix B.3.
• Here we add the superscript n to reflect the dependence on the query position n.
• We retain the repeated augmentation technique in training PVT to be consistent with the original training protocol in Wang et al. (2021b).
• The setup in Baevski & Auli (2019) can be found in the corresponding Fairseq training script: https://github.com/pytorch/fairseq/blob/master/examples/language_model/README.adaptive_inputs.md.



Figure 2: Left and right: Additional empirical memory consumption and running time comparison for different attention mechanisms under various sequence lengths.


Classification accuracy on ImageNet1k in comparison to different RF-based approximations. † vanilla PVT-v2-b3 (Wang et al., 2021b) uses a convolutional kernel to downsample key and value vectors, resulting in fewer FLOPs but with significant performance degradation.

Image classification accuracy on the ImageNet1k dataset with DeiT-Tiny-784.

Masked Language Modeling Perplexity on the Books3 validation dataset.

BLEU scores on the test set of WMT14 En-De. † numbers are taken from Zheng et al. (2022b).

Validation (Val.) and Test perplexity (PPL) on Wikitext-103. 256/480 indicate evaluation context window sizes. † numbers are due to Kasai et al. (2021).

Left and middle: empirical memory consumption and running time comparison, respectively, of different attention mechanisms under various sequence lengths. Right: a snapshot of the MLM validation loss curve versus actual elapsed time during training.

Owing to the linear complexity of EVA, it can be scaled further to much longer sequences. With the input sequence length increased to 4096, EVA (denoted as "EVA-4096") attains lower validation perplexity than exact softmax attention, which demonstrates its capability of scaling to much longer sequences. Besides, machine translation results are compared in Table 4, where in this task C = 8 by default and EVA-m denotes EVA with |E| = m.

MLM validation perplexity on Books3. "-" indicates failure to converge.

Classification results on ImageNet1k dataset under different hyper-parameter configurations of EVA. By default, we set |E| = 49 and C = 49 across all variants below.

Our hyper-parameter configuration for machine translation.

Our hyper-parameter configuration for autoregressive language modeling.

Classification accuracy (%) on LRA benchmark with different efficient attention mechanisms.


ACKNOWLEDGMENTS

We would like to thank the HKU NLP group, the Shark-NLP group, and the anonymous reviewers for their valuable suggestions that greatly helped improve this work. This work is partially supported by the joint research scheme of the National Natural Science Foundation of China (NSFC) and the Research Grants Council (RGC) under grant number N_HKU714/21.


To fulfill the causal requirement, we design two different types of masking matrices to deal with E^n and {P^n_c}_{c=1}^C respectively:
• For E^n, we adopt a single lower-triangular matrix of shape K × K (recall that each set E^n is of size K) to mask future tokens locally, similar to standard decoder softmax attention. Future tokens that do not belong to E^n are handled by the masking scheme for {P^n_c}_{c=1}^C, described below.
• For {P^n_c}_{c=1}^C, we make use of the fact that n ∈ E^n. Since each P^n_c is disjoint from E^n, we only need to mask all subsets P^n_c that appear after E^n. This amounts to first allocating a lower-triangular matrix of shape C × C and then conducting future masking at the subset level.
The pseudo-code for the causal variant of EVA is listed in Algorithm 2.
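The two masks can be sketched directly (a minimal sketch of ours with toy sizes):

```python
import numpy as np

Kb, C = 4, 3          # block size of E_n, number of groups P_1..P_C

# (i) intra-block mask for E_n: a standard K x K lower-triangular matrix,
# so each query attends only to the current and earlier positions in its block
local_mask = np.tril(np.ones((Kb, Kb), dtype=bool))

# (ii) subset-level mask: query block i may use group j only if j is not in
# the future of block i, i.e. a C x C lower-triangular matrix applied per group
block_mask = np.tril(np.ones((C, C), dtype=bool))

assert local_mask[0].sum() == 1     # the first token of a block sees only itself
assert not block_mask[0, 1]         # block 0 cannot attend to the future group 1
```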

E EXPERIMENTAL DETAILS

All of our experiments are conducted with at most 16 NVIDIA V100 GPUs.

E.1 EFFICIENT ATTENTION BASELINES

We compare our proposed attention mechanism EVA against various baselines:
• Performer (Choromanski et al., 2021), which uses plain random features to approximate softmax attention;
• LARA (Zheng et al., 2022b), an advanced RF approximation that makes use of multiple adaptive proposals to construct the SNIS estimator;
• Linformer (Wang et al., 2020), a low-rank approximation that uses a learnable matrix to project the key-value sequence into a shorter one;
• Nyströmformer (Xiong et al., 2021b), a low-rank approximation that adopts the Nyström method to approximate the softmax attention map with a sub-sampled matrix;
• Local attention (Child et al., 2019), a simple sparse approximation that splits the whole sequence into multiple blocks and only allows each query to attend to tokens within the same block;
• Reformer (Kitaev et al., 2020), a sparse approximation where hash functions adaptively distribute sequence tokens into multiple buckets, and each token can only attend to tokens within the same bucket;
• Scatterbrain (Chen et al., 2021a), an approach that combines Performer and sparse attention; the details can be found in Appendix G. Here we implement the sparse module as simple local attention to ensure a fair comparison;
• Combiner (Ren et al., 2021), a probabilistic approach that constructs a structured factorization over the softmax probability distribution via a sparse mechanism. Combiner allows both direct and indirect calculations of conditional probabilities, where the direct probability is implemented as the sparse mechanism while the indirect probability is implemented through a local abstraction over a group of tokens. Similarly, we implement the sparse mechanism as simple local attention, which corresponds to the Combiner-Fixed variant (Ren et al., 2021);
• Transformer-LS, or Long-Short (Zhu et al., 2021), which models long-term and short-term dependencies via low-rank structures and local attention, respectively. The low-rank structure is defined as an input-dependent weight matrix that compresses the sequence into a shorter one, while the local attention is defined as above.

Note that for all mechanisms that involve local attention, we split the sequence into non-overlapping blocks (or 2D windows in the case of images) and each query can only attend to tokens within the same block. We also use relative positional embeddings (Raffel et al., 2020b; Liu et al., 2021) within the local attention computation. Unlike Transformer-LS (Zhu et al., 2021), which allows each query to attend to multiple blocks, we do not use this extension as we find it greatly increases memory consumption, although it does improve model performance.

E.2 IMAGE CLASSIFICATION

Through the experiments on image classification, we consider four different vision transformer (ViT) architectures:
• DeiT-Tiny (Touvron et al., 2021), which maintains a sequence length of 196 across all transformer layers. For this tiny variant, the number of transformer layers is 12, the embedding dimension is 192, and the number of heads is 3;
• DeiT-Small (Touvron et al., 2021), which scales the embedding dimension and number of attention heads in DeiT-Tiny up to 384 and 6, respectively;
• DeiT-Tiny-784, where the architecture is the same as DeiT-Tiny but the patch size in the tokenization step is decreased from 16 to 8. This effectively increases the sequence length from 196 to 784, which we found consistently improves predictive accuracy at the cost of significantly increased time and memory consumption. Under this setting, we also see clearer differences among attention variants, which helps better evaluate the ability of different attention models to learn visual representations;
• PVT-v2-B3 (Wang et al., 2021b), a pyramidal transformer architecture that processes much longer token sequences at early layers and progressively reduces the sequence length to form a hierarchical structure. It patchifies input images into 3136 (56 × 56) tokens and then processes the sequence through 4 stages. Each stage contains several transformer layers and a down-sampling operation, which reduces the sequence length by a factor of 4 and increases the embedding dimension by 2×. Due to the prohibitively long sequences at early layers, PVT applies an additional down-sampling module on input sequences to obtain key and value vectors, which are then passed through a normal softmax attention mechanism. To evaluate different RF approximations, we remove the down-sampling operation and operate directly on the original sequence length, which results in much fewer model parameters than vanilla PVT-v2-B3. We refer readers to Wang et al. (2021b) for detailed architecture configurations.

For training, we do not use the [CLS] token for classification (Touvron et al., 2021); instead, we pool over the output of the last transformer layer to extract features and feed them into the classifier head. We follow the same protocol to train all model variants. Closely following DeiT (Touvron et al., 2021), we employ the AdamW optimizer (Loshchilov & Hutter, 2019) to train models for 300 epochs, with 10 warm-up epochs, a learning rate of 0.001 with cosine decay (Loshchilov & Hutter, 2016), and a batch size of 1024. The adopted augmentation and regularization are the same as in DeiT, except that we remove repeated augmentation (Hoffer et al., 2020) in DeiT models as it often slows down convergence, as also observed in previous studies (Xiao et al., 2021). The specific configurations of each attention mechanism on DeiT-Tiny-784 are listed in Table 9. The hyper-parameter setup for each attention variant closely follows previous practice (Wang et al., 2021a;b; Zheng et al., 2022b) to ensure a similar computational cost.

Comparison to State-of-the-Art Model Architectures. We also compare our model against recent state-of-the-art (SOTA) model architectures with similar parameter sizes on the ImageNet1k benchmark. As reported, EVA improves predictive accuracy and performs competitively with recent SOTA architectures while using fewer parameters and FLOPs.

E.3 MACHINE TRANSLATION AND LANGUAGE MODELING

Our implementation for all language tasks is based on FairSeq toolkit (Ott et al., 2019) . To compare different methods, we report BLEU scores on the test set as the main metric for MT and perplexity for both Autoregressive LM and MLM tasks. For the hyper-parameters |E| and C in EVA, we set |E| = 2C by default, as we find that this choice attains a good trade-off between performance and computational costs across various tasks; while for C, it is determined based on previous practice for each task. Here we provide the detailed experimental protocol for each task.Masked Language Modeling. Following the standard pretraining practice as in RoBERTa (Liu et al., 2019) , in MLM, we aim to reconstruct a subset of tokens in the input sequence that are randomly masked out, which is the core element of BERT-style natural language pretraining (Devlin et al., 2019) . This setting allows us to investigate the generalization ability of our model on larger model sizes and much more data. The task performance is measured with validation perplexity, which reflects how well the model fits the pretraining corpus and also exhibits good correlations with downstream task metrics. For the used corpus Books3, we randomly select 100 books without replacement for the validation split, similar to the setup in C4 dataset (Raffel et al., 2020b) . For the model, we use the RoBERTa-base architecture (Liu et al., 2019) , where all the layer normalization operations (Ba et al., 2016) are placed before attention and FFN blocks (i.e., we adopt the pre-norm architecture), which leads to much more stable training for efficient attention mechanisms. We replace all softmax attention with EVA to test its effectiveness. The training setting and attention-specific parameters, which follow previous studies (Xiong et al., 2021a) to ensure a similar computational cost, can be found in Table 11 and Table 12 respectively.Machine Translation. We follow Ott et al. 
(2018) to process the WMT14 En-De dataset, resulting in around 4.5M/3K/3K English-German sentence pairs for the training/validation/test splits, respectively, with a shared source-target vocabulary of around 32K BPE types. The architecture and training specifics closely follow Vaswani et al. (2017), as listed in Table 13. We follow the protocol of Zheng et al. (2022b) and replace all encoder self-attention blocks in the encoder-decoder Transformer with EVA. For EVA, we find it beneficial to introduce an overlapping variant of E, where the sets E of neighboring chunks are allowed to overlap. Following previous practice in the context of local attention (Xiong et al., 2021a), E contains not only all elements within the designated chunk but also half of the tokens in each of its neighboring chunks. As a result, EVA-32 corresponds to |E| = 32 with a contiguous chunk size of 16. During inference, we follow the same setup as Zheng et al. (2022b): we average the last 10 model checkpoints to obtain the final model parameters and apply beam search with beam size 4, length penalty 0.6, and compound splitting (Zheng et al., 2022b). Note that increasing C also leads to better translation quality, although we found the performance gain slightly less pronounced than that of increasing |E| (c.f. Tables 6 and 7).

Autoregressive Language Modeling. We consider the Wikitext-103 benchmark, which consists of around 103M/218K/246K tokens for the training/validation/test splits, respectively. We adopt the vanilla Transformer decoder architecture (Vaswani et al., 2017), replace all decoder self-attention modules in the Transformer with the causal EVA mechanism, and evaluate EVA under two different setups: 1) a standard 16-layer Transformer LM (around 247M parameters) as in Baevski & Auli (2019), and 2) a larger 32-layer Transformer LM (around 450M parameters) as in Kasai et al. (2021).
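The overlapping construction of E described above can be made concrete with a short sketch: each query's set E covers its own chunk plus half of each neighboring chunk, so |E| = 2 × chunk size away from the sequence boundaries (chunk size 16 gives |E| = 32, matching EVA-32). The function name is hypothetical:

```python
def overlapping_window(pos, seq_len, chunk_size=16):
    """Key positions in the set E for the query at `pos`: the query's own
    chunk plus half of each neighboring chunk, clipped to the sequence.

    Illustrative sketch of the overlapping variant of E; not the paper's
    actual implementation.
    """
    chunk = pos // chunk_size
    start = chunk * chunk_size - chunk_size // 2        # half of left neighbor
    end = (chunk + 1) * chunk_size + chunk_size // 2    # half of right neighbor
    return list(range(max(0, start), min(seq_len, end)))
```

For example, a query at position 20 in a length-128 sequence falls in the second chunk [16, 32) and gets E = [8, 40), i.e., 32 keys.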
We follow their hyper-parameter settings to train all models.

G.3 CONNECTIONS TO CLUSTERED ATTENTION

Clustered attention (Vyas et al., 2020) and EVA share similarities in two aspects. First, both adopt a partitioning technique to reduce the computational complexity while remaining effective; second, both observe that an efficient attention mechanism can be improved by refining the approximation over specific elements. For instance, clustered attention can be improved (Vyas et al., 2020) by selecting the top-k key-value pairs most relevant to each centroid and then refining the approximation by recomputing attention weights over these keys with the original queries; EVA, in turn, directly employs the optimal control variate coefficient for a subset of key-value pairs (m ∈ E) while still remaining efficient, which yields a more accurate approximation. Nevertheless, our main technical contribution is to develop a control variate formulation in the context of RFA and demonstrate how RFA can be further improved locally. In addition, while clustered attention (Vyas et al., 2020) clusters queries, EVA partitions key-value pairs. This property makes EVA more amenable to autoregressive language modeling: since we do not impose a clustering structure over the query set, the causal relation among queries is well maintained.

G.4 CONNECTIONS TO COMBINER

Combiner (Ren et al., 2021) is a recently proposed attention mechanism that also partitions the sequence into chunks combined with local attention. The key difference between EVA and Combiner lies in the motivation: Combiner introduces a structured factorization of the attention probability distribution, while our approach is built from the control variate perspective.

G.5 CONNECTIONS TO SCATTERBRAIN

In this section, we show that Scatterbrain (Chen et al., 2021a) can be cast as a special case of our framework EVA, although the two are proposed from quite different motivations.

A Brief Review of Scatterbrain. Scatterbrain (Chen et al., 2021a) notes that sparse attention and RFA approximate sharp and flat regions of the softmax attention matrix well, respectively. Based on this insight, Scatterbrain first computes a Performer approximation to softmax attention and then cancels out the approximation error on critical regions via a sparse mechanism.

Specifically, Scatterbrain (Chen et al., 2021a) defines a sparse matrix S ∈ R^{N×M} whose non-zero entries are indexed by pairs (n, m). For notational simplicity, we denote Supp(S) = {(i, j) | S_ij ≠ 0} and Supp_n(S) = {m | S_nm ≠ 0}. With the random features ϕ(·, ·) defined in Appendix A, we let

S_nm = exp(q_n^⊤ k_m) − ϕ(q_n, ω)^⊤ ϕ(k_m, ω).

We then add it back to the approximate output:

Σ_{m=1}^{M} ϕ(q_n, ω)^⊤ ϕ(k_m, ω) v_m + Σ_{m ∈ Supp_n(S)} S_nm v_m.

The sparse mechanism can be thought of as modeling the error due to RFA and eliminating it on the support of S. After the correction step, Scatterbrain further applies a post-hoc normalization to obtain a normalized attention output, dividing by

Σ_{m ∉ Supp_n(S)} ϕ(q_n, ω)^⊤ ϕ(k_m, ω) + Σ_{m′ ∈ Supp_n(S)} exp(q_n^⊤ k_m′).

Intuitively, Scatterbrain (Chen et al., 2021a) produces an accurate approximation on the support of the sparse matrix and retains the random feature approximation outside the support.
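The review above can be condensed into a small NumPy sketch. This is illustrative only: the top-2 support selection and all tensor sizes are placeholder assumptions, and Scatterbrain's actual sparse pattern is chosen differently; the sketch only demonstrates the correction-plus-normalization structure.

```python
import numpy as np

rng = np.random.default_rng(0)
N, M, d, r = 4, 6, 8, 256            # queries, keys, head dim, random features
Q = rng.normal(size=(N, d)) / d**0.25
K = rng.normal(size=(M, d)) / d**0.25
V = rng.normal(size=(M, d))
omega = rng.normal(size=(r, d))       # shared Gaussian projections

def phi(X):
    # Positive random features (Choromanski et al., 2021), for which
    # E[phi(q)^T phi(k)] = exp(q^T k).
    return np.exp(X @ omega.T - 0.5 * (X**2).sum(-1, keepdims=True)) / np.sqrt(r)

phiQ, phiK = phi(Q), phi(K)
scores = Q @ K.T

# Placeholder sparse support: top-2 keys per query (an assumption made
# only for this sketch).
supp = np.argsort(-scores, axis=1)[:, :2]

# S_nm = exp(q_n^T k_m) - phi(q_n, w)^T phi(k_m, w) on the support, 0 elsewhere.
S = np.zeros((N, M))
for n in range(N):
    for m in supp[n]:
        S[n, m] = np.exp(scores[n, m]) - phiQ[n] @ phiK[m]

# Performer estimate plus the sparse correction: on the support the summand
# becomes exactly exp(q_n^T k_m); outside it stays the RF approximation.
num = phiQ @ (phiK.T @ V) + S @ V
# Matching post-hoc normalizer, corrected in the same way.
den = phiQ @ phiK.sum(axis=0) + S.sum(axis=1)
out = num / den[:, None]

# Exact softmax attention, for comparison.
A = np.exp(scores)
exact = (A / A.sum(axis=1, keepdims=True)) @ V
```

By construction, the corrected summand ϕ(q_n)^⊤ϕ(k_m) + S_nm equals exp(q_n^⊤ k_m) on the support, which is exactly the cancellation described above.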

