EFFICIENT ATTENTION VIA CONTROL VARIATES

Abstract

Random-feature-based attention (RFA) is an efficient approximation of softmax attention with linear runtime and space complexity. However, the approximation gap between RFA and conventional softmax attention is not well studied. Building on previous progress on RFA, we characterize this gap through the lens of control variates and show that RFA can be decomposed into a sum of multiple control variate estimators, one for each element in the sequence. This new framework reveals that exact softmax attention can be recovered from RFA by manipulating each control variate. Moreover, it allows us to develop a more flexible form of control variates, resulting in a novel attention mechanism that significantly reduces the approximation gap while maintaining linear complexity. Extensive experiments demonstrate that our model outperforms state-of-the-art efficient attention mechanisms on both vision and language tasks.

1. INTRODUCTION

Random-feature-based attention (RFA, also known as Performer; Choromanski et al., 2021; Peng et al., 2021b) is an established fast approximation to the conventional softmax attention mechanism (Bahdanau et al., 2014; Vaswani et al., 2017), which successfully scales Transformer models to process much longer sequences (Choromanski et al., 2021). At its core is the use of random features (RF; Rahimi & Recht, 2008) to linearize the exponential kernel in softmax attention, which reduces the computational cost from quadratic to linear runtime and space complexity. Despite its efficiency, recent studies have pointed out that such approximation suffers from substantial performance degradation (Xiong et al., 2021a; Zheng et al., 2022b).

In this work, we generalize the formulation of RFA via control variates (Owen, 2013), which characterizes the approximation gap between RFA and softmax attention in theory. We first show that RFA can be decomposed from a global approximation over the whole sequence into a sum of local control variate estimators, each of which is applied to an individual element in the sequence. Under this formulation, RFA is equivalent to employing the same coefficient for all control variate estimators to scale their variance isotropically (§3.1). Moreover, we prove that if we optimize the coefficient of each control variate to minimize the estimation variance individually, the RFA estimate becomes exact, that is, softmax attention is recovered with zero bias and zero variance (§3.2).

Our key observation is that this formulation reveals a localized perspective on the RFA approximation. Instead of directly seeking a better estimate over the entire sequence, we can break the problem down into smaller problems that aim at improving the approximation for each subsequence (§4). The control variate estimator for each subsequence can be tuned separately and combined to yield a better estimate, which provably reduces the approximation error in the global sense (§4.1). Nevertheless, one caveat is that as the number of sub-problems increases, the approximation gap is reduced but at the expense of higher computational complexity. For instance, if we optimized the control variate for every single element, softmax attention would be recovered as desired, but with quadratic complexity.

To attain a good trade-off between approximation quality and efficiency, we develop a new attention mechanism, Efficient attention via control VAriates (EVA), that implements this divide-and-conquer strategy efficiently. In EVA, the sequence is partitioned into a fixed number of disjoint subsets. For the subset that likely bears the highest correlation with the query, we explicitly optimize the control variate for each element, which recovers exact softmax attention probabilities; for the other subsets, the control variate coefficient is shared locally among all elements within the same subset. The resulting attention mechanism is not only highly effective but also runs with the same computational complexity as RFA (§4.2).

Extensive experiments on both language and vision tasks demonstrate that EVA outperforms state-of-the-art efficient attention methods (§5).

2. BACKGROUND

2.1. SOFTMAX ATTENTION MECHANISM

Assume there exists a set of $N$ queries $\{q_n\}_{n=1}^{N}$ and $M$ key-value pairs $K = [k_1, \dots, k_M]$ and $V = [v_1, \dots, v_M]$, where queries, keys, and values are all $d$-dimensional vectors. The softmax attention mechanism (Bahdanau et al., 2014; Vaswani et al., 2017) is defined as an average over the value vectors weighted by the dot-product similarities between the queries and keys. For the $n$-th query, the attention mechanism outputs
$$\mathrm{SoftmaxAttn}(q_n, K, V) := \sum_{m=1}^{M} \frac{\exp\left(q_n^\top k_m\right)}{\sum_{m'=1}^{M} \exp\left(q_n^\top k_{m'}\right)}\, v_m. \qquad (1)$$
In the case of self-attention (Lin et al., 2017; Vaswani et al., 2017), we have $M = N$, which results in quadratic computational complexity since the similarity must be computed explicitly for every query-key pair.
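To make the quadratic cost in Equation 1 concrete, below is a minimal NumPy sketch of softmax attention; the function and variable names are illustrative and not taken from any released implementation.

```python
import numpy as np

def softmax_attention(Q, K, V):
    """Q: (N, d) queries, K: (M, d) keys, V: (M, d) values -> (N, d) outputs."""
    scores = Q @ K.T                                # (N, M) similarities q_n^T k_m
    scores -= scores.max(axis=-1, keepdims=True)    # row-wise max subtraction for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over the M keys
    return weights @ V                              # weighted average of value vectors
```

Materializing the $(N, M)$ score matrix is exactly the step that makes self-attention ($M = N$) quadratic in sequence length.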

2.2. RANDOM-FEATURE-BASED ATTENTION WITH SELF-NORMALIZED IMPORTANCE SAMPLING

Recently, Zheng et al. (2022b) identified that softmax attention (Equation 1) can be written as an expectation over an attention-like aggregating function,
$$\mathrm{SoftmaxAttn}(q_n, K, V) = \sum_{m=1}^{M} \frac{\exp\left(q_n^\top k_m\right)}{\sum_{m'=1}^{M} \exp\left(q_n^\top k_{m'}\right)}\, v_m = \mathbb{E}_{\omega \sim p_n(\omega)}\left[f_n(\omega)\right], \qquad (2)$$
where
$$f_n(\omega) := \frac{\sum_{m=1}^{M} \xi(q_n, \omega)^\top \xi(k_m, \omega)\, v_m}{\sum_{m'=1}^{M} \xi(q_n, \omega)^\top \xi(k_{m'}, \omega)}, \qquad p_n(\omega) := \mathcal{N}(\omega; 0, I)\, \frac{\sum_{m=1}^{M} \xi(q_n, \omega)^\top \xi(k_m, \omega)}{Z}.$$
Here $\xi(\cdot, \cdot)$ is a randomized mapping defined such that $\exp\left(q_n^\top k_m\right) = \mathbb{E}_{\omega \sim \mathcal{N}(0, I)}\left[\xi(q_n, \omega)^\top \xi(k_m, \omega)\right]$, and $Z = \sum_{m=1}^{M} \exp\left(q_n^\top k_m\right)$ denotes the normalizing constant of the distribution $p_n$. Throughout this paper, we consider the positive randomized mapping $\xi(x, \omega) = \exp\left(\omega^\top x - \tfrac{1}{2}\|x\|^2\right)$ (Choromanski et al., 2021) unless otherwise specified.

Random-feature-based attention (RFA) methods (Choromanski et al., 2021; Peng et al., 2021b) can be interpreted as performing self-normalized importance sampling (SNIS; Hesterberg, 1995) to approximate Equation 2 (Zheng et al., 2022b). In SNIS, one draws Monte Carlo samples from a proposal distribution $q(\omega)$ instead of the true distribution $p_n(\omega)$ and estimates the target expectation as
$$\mathbb{E}_{\omega \sim p_n(\omega)}\left[f_n(\omega)\right] = \mathbb{E}_{\omega \sim q(\omega)}\left[\frac{p_n(\omega)}{q(\omega)} f_n(\omega)\right] \approx \frac{\sum_{s=1}^{S} \frac{p_n(\omega_s)}{q(\omega_s)} f_n(\omega_s)}{\sum_{s=1}^{S} \frac{p_n(\omega_s)}{q(\omega_s)}},$$
where $\omega_1, \dots, \omega_S \sim q(\omega)$. Vanilla RFA amounts to constructing the SNIS estimate with $q(\omega) = \mathcal{N}(\omega; 0, I)$. The SNIS representation also turns out to be equivalent to the more established form of RFA,
$$\mathrm{RFA}(q_n, K, V) := \frac{\sum_{s=1}^{S} \frac{p_n(\omega_s)}{q(\omega_s)} f_n(\omega_s)}{\sum_{s=1}^{S} \frac{p_n(\omega_s)}{q(\omega_s)}} = \frac{\sum_{m=1}^{M} \phi(q_n, \omega)^\top \phi(k_m, \omega)\, v_m}{\sum_{m'=1}^{M} \phi(q_n, \omega)^\top \phi(k_{m'}, \omega)},$$
where the random feature $\phi(x, \omega) := \frac{1}{\sqrt{S}}\left[\xi(x, \omega_1), \dots, \xi(x, \omega_S)\right]^\top$ was originally proposed to approximate exponential kernels (see Appendix A for a detailed review).
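For concreteness, here is a minimal NumPy sketch of vanilla RFA with the positive randomized mapping above. The sample count $S$ and the Gaussian samples shared across all queries follow the text, while details such as the seed handling are illustrative assumptions rather than the paper's implementation.

```python
import numpy as np

def positive_random_features(X, W):
    """xi(x, w_s) = exp(w_s^T x - 0.5 * ||x||^2); X: (L, d), W: (S, d) -> (L, S).
    The 1/sqrt(S) factor of phi cancels between numerator and denominator, so it is omitted."""
    return np.exp(X @ W.T - 0.5 * np.sum(X**2, axis=-1, keepdims=True))

def rfa(Q, K, V, S=64, seed=0):
    """Linear-complexity RFA approximation of softmax attention."""
    d = Q.shape[-1]
    W = np.random.default_rng(seed).standard_normal((S, d))  # omega_1, ..., omega_S ~ N(0, I)
    phi_q = positive_random_features(Q, W)                   # (N, S)
    phi_k = positive_random_features(K, W)                   # (M, S)
    numerator = phi_q @ (phi_k.T @ V)                        # (N, d), computed in O((N + M) S d)
    denominator = phi_q @ phi_k.sum(axis=0)                  # (N,) normalizer
    return numerator / denominator[:, None]
```

Because the key-value summary `phi_k.T @ V` is shared by all queries, the cost scales linearly with sequence length, in contrast to the explicit $(N, M)$ score matrix of exact softmax attention.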

