YOU ONLY SAMPLE (ALMOST) ONCE: LINEAR COST SELF-ATTENTION VIA BERNOULLI SAMPLING

Anonymous

Abstract

Transformer-based models have come to dominate the landscape in a wide range of natural language processing (NLP) applications. The heart of the transformer model is the self-attention mechanism, which captures the interactions of token pairs in the input sequences and consequently depends quadratically on the input sequence length. Training such models on longer sequences is therefore expensive, and often prohibitively so. We show that a Bernoulli sampling attention mechanism based on Locality Sensitive Hashing (LSH) decreases this quadratic complexity to linear. We bypass the quadratic cost by viewing self-attention as a sum of individual tokens associated with Bernoulli random variables that can, in principle, be sampled at once by a single hash (although in practice, this number may be a small constant). This leads to an efficient sampling scheme to estimate self-attention, relying on specific modifications of LSH (based on feasibility of deployment on GPU architectures). We evaluate our proposed algorithm on the GLUE benchmark with the standard 512 sequence length, where our method achieves comparable or even slightly better performance than a standard pretrained Transformer. To evaluate whether our method can indeed handle longer sequences, we conduct experiments on long sequence (4096) language model pretraining and achieve results consistent with standard self-attention, while observing sizable inference speed-ups and memory savings.

1. INTRODUCTION

The Transformer model (Vaswani et al., 2017) is incredibly effective across a diverse set of natural language processing (NLP) applications including machine translation (Vaswani et al., 2017), language inference (Devlin et al., 2018) and paraphrasing (Raffel et al., 2019). Transformer-based models such as BERT (Devlin et al., 2018) are pretrained in an unsupervised manner and later finetuned on different downstream tasks, often providing state-of-the-art performance on standard benchmarks. While such models show strong empirical performance, their computational and memory requirements remain high. Consequently, in the NLP setting, most current models impose constraints on the sequence length, e.g., BERT and other transformer-based language models (Yang et al., 2019; Liu et al., 2019) limit the sentence length to at most 512 tokens. Multi-head self-attention is central to Transformer-based models and provides a flexible global receptive field for exchanging information among input tokens. While self-attention provides immense benefits, it is also a key bottleneck in training with long sequences. In particular, the output of self-attention is a combination of all tokens, with coefficients determined by the similarities among tokens. While this is empirically beneficial, it involves a sizable resource footprint: for sequence length n, computing the pairwise similarities among all input tokens incurs O(n^2) complexity in both time and memory. This quadratic cost is a roadblock to the potential benefits realizable in applications that require capturing long-term context dependencies. As we discuss in more detail later, mitigating this issue is a major thrust of several recent and ongoing efforts. Our work is inspired by ideas of importance sampling via hashing-based sampling techniques (Spring & Shrivastava, 2017; Charikar & Siminelakis, 2017).
We propose a Bernoulli-based sampling scheme to approximate self-attention that scales linearly with the input sequence length. We achieve this by viewing self-attention as a sum of individual tokens associated with Bernoulli random variables whose success probability is determined by the similarities among tokens. In principle, we can sample all Bernoulli random variables at once with a single hash (although in practice, this number may be a small constant to lower the approximation variance). This leads to an efficient sampling scheme to estimate self-attention, relying on specific modifications of hashing-based importance sampling (based on feasibility of deployment on GPU architectures). The resulting strategy (You Only Sample Almost Once, YOSO-Attention) is far more amenable to an efficient and backpropagation-friendly implementation, and has a favorable empirical performance profile on natural language modeling tasks. We evaluate our proposed algorithm on the GLUE benchmark (Wang et al., 2019) with 512 sequence length as well as on long sequence language model pretraining, where we see promising results together with speed-ups and memory savings.

2. BACKGROUND: SELF-ATTENTION

Self-Attention. Self-attention is a scaled dot-product attention mechanism that captures token dependencies in the input sequence. It can be defined as,

$$\mathcal{A}(Q, K, V) = \mathrm{softmax}\left(\frac{(QW_Q)(KW_K)^T}{\sqrt{d_h}}\right) VW_V = D_P \exp(P)\, VW_V \qquad (1)$$

where P = (QW_Q)(KW_K)^T / sqrt(d_h), and Q, K, V ∈ R^{n×d} are embedding matrices from the input sequence, called queries, keys, and values respectively. Here, n is the input sequence length, d is the embedding dimension of each token, W_Q, W_K, W_V ∈ R^{d×d_h} are learned parameter matrices, d_h is the dimension of the hidden embedding, and D_P is an n × n diagonal matrix which normalizes each row of the exp(P) matrix so that the row entries sum up to 1. For simplicity, we overload the notation for Q, K, V to denote QW_Q, KW_K, VW_V in our presentation.

Multi-Head Self-Attention. Multi-head self-attention in Transformers runs the scaled dot-product attention multiple times, and the attention outputs are concatenated to help the model capture information from multiple representation subspaces (Vaswani et al., 2017). Multi-head self-attention can be written formally as,

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}\left(\mathcal{A}_1(Q, K, V), \cdots, \mathcal{A}_h(Q, K, V)\right) W \qquad (2)$$

where h is the number of heads and A_i, i = 1, ..., h, are heads with different parameter matrices.

Self-Attention Bottleneck. A key bottleneck in self-attention is computing the softmax matrix, softmax(P), which requires calculating all pairwise input token similarities. To reduce this cost, we seek to approximate the softmax matrix by viewing the self-attention output for each query as an expectation over a softmax distribution and computing the approximate self-attention with an efficient sampling mechanism. In the following sections, we first review LSH-based importance sampling and then propose Bernoulli sampling with LSH to estimate self-attention efficiently.
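As a reference point for the approximations that follow, Eq. (1) (using the overloaded notation, i.e., Q, K, V already projected) can be sketched in a few lines of NumPy; the function and variable names here are ours, not from the paper's code:

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(Q, K, V):
    # Q, K, V: (n, d_h) matrices, already projected by W_Q, W_K, W_V.
    n, d_h = Q.shape
    P = Q @ K.T / np.sqrt(d_h)   # (n, n) pairwise similarities: the O(n^2) bottleneck
    return softmax(P) @ V        # row-normalized weights applied to the values

rng = np.random.default_rng(0)
Q, K, V = rng.standard_normal((3, 8, 4))
out = self_attention(Q, K, V)
assert out.shape == (8, 4)
```

The two O(n^2) objects (the similarity matrix P and its row-softmax) are exactly what the sampling scheme in Section 4 avoids materializing.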

3. IMPORTANCE SAMPLING VIA LOCALITY SENSITIVE HASHING

Importance sampling (Press et al., 2007) approximates properties of a target distribution using a weighted average of random draws from another (proposal) distribution. It is known (Press et al., 2007) that importance sampling can be applied directly to the softmax distribution by drawing samples from a uniform distribution, which avoids the harder problem of sampling from the softmax distribution itself. But this leads to a high variance estimate, since the softmax distribution is usually concentrated in a small region. When this idea is used to approximate the softmax matrix in self-attention in particular, the variance tends to grow with the input sequence length. Before proceeding, we summarize an interesting importance sampling method for low variance estimators, specifically, importance sampling via LSH from (Charikar & Siminelakis, 2017; Spring & Shrivastava, 2017).

LSH-based Importance Sampling. Consider the case when the angular distance between a key and a query is small. In this case, the similarity (between the key and the query) as well as the softmax probability will be large. Viewed through the lens of nearest neighbor retrieval, this property coincides with a large collision probability of high similarity key-query pairs, assuming that the retrieval is implemented via LSH. Motivated by this link between softmax probability p and LSH collision probability q, Spring & Shrivastava (2017) and Charikar & Siminelakis (2017) suggest using LSH as an efficient sampler for low variance softmax estimators. (a) Spring & Shrivastava (2017) propose approximating the softmax by sampling, for each query, a set S of neighboring keys formed by the union of colliding keys across m hash tables. The estimator is computed as

$$|S|^{-1} \sum_{i \in S} \frac{p(q, k_i)}{q(q, k_i)}\, v_i$$

where q is a query vector, k_i, v_i are key and value vectors in the sampling set S, and p(·, ·) and q(·, ·) are the softmax probability and collision probability of the given pairs.
The procedure is equivalent to performing importance sampling without replacement, which introduces a dependency among the samples. Deduplication (avoiding double counting) requires memory to store keys in each hash table and runtime to deduplicate keys for each query. If the hash bucket sizes are skewed, the GPU memory needs depend on the bucket sizes and the runtime depends on the size of S. (b) Charikar & Siminelakis (2017) propose a Hash-based Estimator to simulate a proposal distribution for importance sampling via LSH, which can easily be applied in the context of softmax. For each hash table, a key is uniformly selected from the bucket that the query is hashed to, simulating a draw from a proposal distribution. The estimator is computed as

$$m^{-1} \sum_{i=1}^{m} \frac{p(q, k_i)\, |H_i(q)|}{q(q, k_i)}\, v_i$$

where |H_i(q)| denotes the size of the hash bucket in the i-th hash table to which q is hashed. This simulates m samples drawn with replacement from the proposal distribution. However, the probability of one key being sampled depends not only on (a) its angular distance to the query but also (b) the number of keys within the hash bucket, leading to a sampling dependency among all keys. Further, using this estimator for self-attention ties the sparsity of the softmax matrix to the number of hashes used. Specifically, the number of tokens that each query can attend to is bounded by the number of hashes: the procedure samples at most one key per hash table, and so it adds at most one additional nonzero to the softmax matrix.

Remark 1. While LSH-based importance sampling exploits the agreement between high probability p(·, ·) and high collision probability q(·, ·), the alignment is not perfect. Samples from the proposal distribution must be reweighted to compensate for the difference. Further, for different queries, the likelihood ratios between the softmax distribution and the proposal distribution w.r.t. a single key are different.
Therefore, the reweighting has to be done during querying. Although maintaining hash tables for storing keys is not a major problem in general, the high memory cost of the hash tables and the computation time for reweighting hurt efficiency when this approach is applied to self-attention.
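Both estimators above lean on the collision probability of hyperplane LSH (Charikar, 2002): for unit vectors with cosine similarity x, a single random-hyperplane hash collides with probability 1 − arccos(x)/π. This is easy to verify numerically (an illustration of our own, not from the paper):

```python
import numpy as np

rng = np.random.default_rng(0)
d = 16
# Two unit vectors with some cosine similarity.
q = rng.standard_normal(d); q /= np.linalg.norm(q)
k = rng.standard_normal(d); k /= np.linalg.norm(k)

# Empirical collision rate of the hyperplane hash sign(r.x) over many random r.
m = 200_000
R = rng.standard_normal((m, d))
empirical = np.mean(np.sign(R @ q) == np.sign(R @ k))

theory = 1 - np.arccos(q @ k) / np.pi
assert abs(empirical - theory) < 0.01
```

Concatenating τ such hyperplane hashes multiplies the per-hash probabilities, giving the (1 − arccos(x)/π)^τ collision probability used in Section 4.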

4. YOSO-ATTENTION

We start from LSH-based importance sampling and seek to address some of the aforementioned issues when it is deployed for approximating self-attention. Instead of using LSH to simulate sampling from a proposal distribution over tokens, we view attention as a sum of tokens associated with Bernoulli random variables. This modification relates more closely to LSH itself than to LSH-based importance sampling: the probability of a query colliding with a key does not depend on the other keys. This strategy avoids the sampling dependency issue in LSH-based importance sampling and offers an opportunity to develop a scheme more amenable to GPUs.

Remark 2. We assume that the input keys and queries of self-attention are unit length, which unifies the dot-product similarity in self-attention with the cosine similarity in LSH. This is straightforward using the construction of Neyshabur & Srebro (2015): a temperature variable τ is used to bound the squared ℓ2 norm of all queries and keys and to construct new unit length keys and queries while preserving their pairwise similarities. We can then work with the softmax matrix in the angular distance metric and derive our algorithm.

Self-Attention via Bernoulli Sampling. We aim to approximate self-attention, which uses a softmax matrix to capture the context dependency among tokens via their pairwise similarities. If we can represent this context dependency directly using the collision probability q(·, ·), then the challenges discussed for importance sampling can be resolved. The coincidence of softmax probability p(·, ·) and LSH collision probability q(·, ·) makes q(·, ·) a sensible starting point for approximating self-attention. Specifically, to model dependency based on similarity, the collision probability aligns well with the exponential function in softmax over the domain of interest [-1, 1], as shown in Figure 1: both functions have positive zeroth, first, and second order derivatives.
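The unit-length reduction in Remark 2 can be sketched as follows. This is one standard asymmetric augmentation in the spirit of Neyshabur & Srebro (2015); the paper's exact construction may differ, and the function name is ours:

```python
import numpy as np

def unitize(Q, K):
    # Scale so every row has squared norm <= 1, then pad two coordinates so all
    # rows become exactly unit length while every query-key dot product
    # Q_i . K_j is preserved up to the single global scale 1/M2.
    M2 = max((Q ** 2).sum(-1).max(), (K ** 2).sum(-1).max())
    Qs, Ks = Q / np.sqrt(M2), K / np.sqrt(M2)
    zq = np.sqrt(np.clip(1 - (Qs ** 2).sum(-1, keepdims=True), 0, None))
    zk = np.sqrt(np.clip(1 - (Ks ** 2).sum(-1, keepdims=True), 0, None))
    # Queries put their slack in coordinate d, keys in coordinate d+1, so the
    # padding never contributes to a query-key dot product.
    Qh = np.concatenate([Qs, zq, np.zeros_like(zq)], axis=-1)
    Kh = np.concatenate([Ks, np.zeros_like(zk), zk], axis=-1)
    return Qh, Kh, M2

rng = np.random.default_rng(1)
Q, K = rng.standard_normal((5, 3)), rng.standard_normal((7, 3))
Qh, Kh, M2 = unitize(Q, K)
assert np.allclose(Qh @ Kh.T, (Q @ K.T) / M2)  # similarities preserved up to scale
```

After this reduction, arccos of the query-key dot products is a genuine angular distance, so hyperplane LSH applies directly. (As noted in the experiments later, a plain ℓ2 normalization with a temperature hyperparameter works just as well in practice.)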
Figure 1 compares the attention weight exp(τ(x − 1)) with the collision probability of concatenating τ hyperplane hashes (Charikar, 2002), (1 − arccos(x)/π)^τ, for τ = 8. We plot exp(τ(x − 1)) so that the range is between 0 and 1 without changing the actual attention weights in the softmax; we also plot the derivatives of the exponential function and of the collision probability, as well as a lower bound we will use later during backpropagation. Our method can be viewed as using the LSH collision probability as a biased approximation of the exponential function. Note that (a) a positive zeroth order derivative indicates that the dependency is positive, (b) a positive first order derivative ensures that the dependency is monotonic in similarity, and (c) a positive second order derivative means that low similarity corresponds to almost no dependency. This leads us to hypothesize that a collision-based self-attention may be as effective as softmax-based self-attention. It can be formulated as,

$$\sum_{i=1}^{n} \mathcal{B}_i(q, k_i)\, v_i \qquad (3)$$

where B_i(q, k_i) is a Bernoulli random variable whose success probability is given by the collision probability of q with the key k_i; hence, it is determined by the similarity between q and k_i. In a single hash, each B_i(q, k_i) generates a realization that determines whether the corresponding token will be part of the attention output or not. Conceptually, when sampling from the softmax distribution, only one token is sampled as the attention output. In contrast, Bernoulli sampling determines whether each individual token is part of the attention output. In principle, to determine the context dependency among tokens, you only need to sample once (YOSO): a single hash generates realizations of all Bernoulli random variables B_i(q, k_i), i = 1, ..., n. Specifically, when the keys are hashed to a hash table using a single hash, the realization of B_i(q, k_i) for each query q is 1 if q collides with k_i, and 0 otherwise.
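The qualitative agreement between the two weight functions claimed above is easy to check numerically (our own illustration, with variable names that are not from the paper):

```python
import numpy as np

tau = 8
x = np.linspace(-1.0, 1.0, 201)              # cosine similarity range
exp_w = np.exp(tau * (x - 1))                # rescaled softmax weight
coll_w = (1 - np.arccos(x) / np.pi) ** tau   # collision prob. of tau hyperplane hashes

# Both weights live in [0, 1], nearly vanish for dissimilar pairs, peak at
# x = 1, and increase monotonically with similarity.
assert exp_w[0] < 1e-3 and coll_w[0] == 0.0
assert exp_w[-1] == 1.0 and coll_w[-1] == 1.0
assert np.all(np.diff(exp_w) > 0) and np.all(np.diff(coll_w) >= 0)
```

Plotting `exp_w` against `coll_w` reproduces the qualitative picture in Figure 1: the collision probability is a biased but similarly shaped stand-in for the (rescaled) exponential.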
To our knowledge, using the LSH collision probability to replace softmax dependencies in self-attention has not been studied before.

YOSO-Attention. By replacing the softmax dependency with Bernoulli random variables and using LSH as an efficient sampler to estimate the success probabilities, we obtain an efficient self-attention (YOSO-Attention) that approximates softmax-based self-attention:

$$\mathrm{YOSO}(Q, K, V) = \mathcal{B}(Q, K)\, V; \qquad \mathbb{E}[\mathrm{YOSO}(Q, K, V)] = \left(1 - \frac{\arccos(QK^T)}{\pi}\right)^{\tau} V \qquad (4)$$

where B(Q, K) is the Bernoulli sampling matrix using m hashes,

$$\mathcal{B}(Q, K)_{i,j} = \frac{1}{m} \sum_{k=1}^{m} \mathbb{1}_{f_k(Q_{i,:}) = f_k(K_{j,:})} \qquad (5)$$

and f_k, k = 1, ..., m, are hash functions.

Normalizing Attention. In standard self-attention, each row of the softmax matrix is normalized so that the dependencies sum up to 1. Above, we discussed how the pairwise query-key dependencies can be estimated using Bernoulli sampling; we now describe how to normalize the dependencies as in standard self-attention. We could first estimate the dependencies and then normalize them by their estimated sum, B(Q, K)1, where 1 is the all-ones vector; B(Q, K)1 can be computed via Eq. 4 by plugging 1 in for V. To make the estimation of self-attention more efficient, however, we instead adopt an ℓ2 normalization of the attention output, similar to Levy et al. (2015), who use ℓ2 normalization for word embeddings. Attention outputs are then invariant to the scaling B(Q, K)1 under ℓ2 normalization. Empirically, we show that the ℓ2 normalization does not affect the performance of our method, as can be seen in Figure 3. Therefore, we have,
$$\mathrm{N\text{-}YOSO}(Q, K, V) = \ell_2\left(\mathcal{B}(Q, K)\, V\right) \qquad (6)$$

LSH-based Bernoulli Sampling. We now discuss how to implement Bernoulli sampling to approximate self-attention. While a standard LSH procedure could be used, maintaining hash tables that store keys is inefficient on a GPU: the GPU memory required for the hash tables cannot be predetermined, and the workload might be skewed due to skewed bucket sizes. To tackle this issue, we propose LSH-based Bernoulli Sampling, which only stores the summation of the values corresponding to the hashed keys, instead of storing a collection of hashed keys. An overview of our algorithm is shown in Figure 2. To compute Y = B(Q, K)V, the procedure proceeds as follows. For each k ∈ {1, ..., m}, we sample a hash function f_k and create a hash table H^k ∈ R^{2^τ × d} representing 2^τ d-dimensional buckets. For each key K_{j,:}, we add the value V_{j,:} to the bucket indexed by the hash code f_k(K_{j,:}), denoted H^k_{f_k(K_{j,:})}:

$$H^k_{f_k(K_{j,:})} \leftarrow H^k_{f_k(K_{j,:})} + V_{j,:}$$

Note that the size of H^k is O(2^τ d) and is independent of which buckets the keys are hashed to. Once all keys are processed for k ∈ {1, ..., m}, for each query Q_{i,:} we maintain an output vector Y_{i,:} initialized to 0; for each k, we look up the bucket of H^k indexed by f_k(Q_{i,:}) and add its contents to the output vector:

$$Y_{i,:} \leftarrow Y_{i,:} + H^k_{f_k(Q_{i,:}),:}$$

Therefore, each final output (after dividing by m) can be computed as,

$$Y_{i,:} = \frac{1}{m} \sum_{k=1}^{m} \sum_{j=1}^{n} \mathbb{1}_{f_k(Q_{i,:}) = f_k(K_{j,:})}\, V_{j,:} = \sum_{j=1}^{n} \mathcal{B}(Q, K)_{i,j}\, V_{j,:}$$

Remark 3.
The memory and time complexity of this algorithm are O(m 2^τ d) and O(nmd), respectively; both are independent of the hash bucket sizes. Further, we can improve the memory complexity to O(m 2^τ) by reusing the hash table and processing a few dimensions at a time, without increasing the time complexity. The constant τ is small, since it controls the decay rate of the attention weight with respect to the angular distance between query and key, and it can be chosen as a function of log_2(n). In our experiments, τ is set to log_2(n).
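A minimal NumPy sketch of this procedure (our own illustration, not the paper's GPU implementation): each of the m hashes concatenates τ random hyperplanes into an integer code, keys scatter-add their values into a (2^τ, d) table of bucket sums, and each query gathers a single bucket. Averaged over hashes, the result is an unbiased estimate of (1 − arccos(QKᵀ)/π)^τ V:

```python
import numpy as np

def hyperplane_codes(X, R):
    # R: (tau, d) random hyperplanes; returns integer codes in [0, 2^tau).
    bits = (X @ R.T > 0).astype(int)
    return bits @ (1 << np.arange(R.shape[0]))

def yoso_sample(Q, K, V, m, tau, rng):
    # Q, K: (n, d) with unit-length rows; V: (n, d_v).
    n, d_v = V.shape
    Y = np.zeros((n, d_v))
    for _ in range(m):
        R = rng.standard_normal((tau, Q.shape[1]))
        H = np.zeros((2 ** tau, d_v))            # bucket sums; size independent of skew
        np.add.at(H, hyperplane_codes(K, R), V)  # scatter-add each value to its bucket
        Y += H[hyperplane_codes(Q, R)]           # each query gathers one bucket
    return Y / m

rng = np.random.default_rng(0)
n, d = 16, 8
Q = rng.standard_normal((n, d)); Q /= np.linalg.norm(Q, axis=1, keepdims=True)
K = rng.standard_normal((n, d)); K /= np.linalg.norm(K, axis=1, keepdims=True)
tau = 4
# Exact expectation from Eq. (4) (O(n^2); for reference only).
W = (1 - np.arccos(np.clip(Q @ K.T, -1, 1)) / np.pi) ** tau
# Plugging in the identity for V recovers an estimate of B(Q, K) itself.
B_hat = yoso_sample(Q, K, np.eye(n), m=2000, tau=tau, rng=rng)
assert np.abs(B_hat - W).mean() < 0.03
```

Note that the table H never stores keys, so its footprint is fixed at 2^τ × d regardless of how skewed the buckets are, which is the property that makes the scheme GPU-friendly.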

Speed-up.

While not essential, we find a fast random projection for computing the LSH hash codes to be beneficial, since this step takes a large portion of the overall runtime. As suggested by Andoni et al. (2015), we use an approximated random projection to reduce the time complexity to O(nmτ log_2(d)), allowing fast computation of hash codes.

Backpropagation through YOSO-Attention. For training, we also need the backward propagation steps for YOSO-Attention. Here, we discuss this last component, which enables efficient end-to-end training. For backpropagation, the gradient of the loss L w.r.t. V can be estimated similarly to Eq. 4,

$$\nabla_V L = \left(\left(1 - \frac{\arccos(QK^T)}{\pi}\right)^{\tau}\right)^T (\nabla_{\mathrm{YOSO}} L) \approx \mathcal{B}(K, Q)\,(\nabla_{\mathrm{YOSO}} L)$$

The gradients of L w.r.t. Q and K are similar, so we only provide the expression for Q,

$$\nabla_Q L = \left[\left((\nabla_{\mathrm{YOSO}} L)\, V^T\right) \odot \tau \left(1 - \frac{\arccos(QK^T)}{\pi}\right)^{\tau-1} \oslash \left(\pi \sqrt{1 - (QK^T)^2}\right)\right] K \qquad (11)$$

where ⊘, ⊙ denote element-wise division and multiplication. The problem with the true gradient is that it goes to infinity as the alignment score between a query and a key approaches 1, which might lead to divergence. To avoid this numerical issue, we use a lower bound of the actual derivative of the collision probability,

$$\left[\left((\nabla_{\mathrm{YOSO}} L)\, V^T\right) \odot \frac{\tau}{2} \left(1 - \frac{\arccos(QK^T)}{\pi}\right)^{\tau}\right] K,$$

see Figure 1, which can be efficiently estimated via a variation of LSH-based Bernoulli Sampling. Specifically, the approximation decomposes into a sum of d LSH-based Bernoulli Sampling computations,

$$(\widetilde{\nabla}_Q L)_{i,:} = \sum_{l=1}^{d} (\nabla_{\mathrm{YOSO}} L)_{i,l} \sum_{j=1}^{n} \mathcal{B}(Q, K)_{i,j} \left(V_{j,l}\, \frac{\tau}{2}\, K_{j,:}\right)$$

Therefore, following LSH-based Bernoulli Sampling, the memory complexity is O(m 2^τ d^2) and the time complexity is O(nmd^2). The d^2 term in memory can be eliminated by reusing the same hash tables d^2 times without increasing the runtime, which improves the memory complexity to O(m 2^τ). The overall complexity of our method and a comparison to standard self-attention are summarized in Table 1.
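As a sanity check on these formulas, the exact-expectation forms of the true gradient (Eq. 11) and its lower-bound surrogate can be written directly (O(n^2); a reference sketch of our own, whereas the paper estimates these quantities with LSH-based Bernoulli sampling). The surrogate stays finite even when a query and a key align perfectly, exactly where the true gradient blows up:

```python
import numpy as np

def grad_q_true(dY, Q, K, V, tau):
    # Exact-expectation form of Eq. 11; diverges as q.k -> 1.
    C = np.clip(Q @ K.T, -1.0, 1.0)
    W = tau * (1 - np.arccos(C) / np.pi) ** (tau - 1) / (np.pi * np.sqrt(1 - C ** 2))
    return ((dY @ V.T) * W) @ K

def grad_q_lower_bound(dY, Q, K, V, tau):
    # Lower-bound surrogate used for stable backpropagation.
    C = np.clip(Q @ K.T, -1.0, 1.0)
    W = (tau / 2) * (1 - np.arccos(C) / np.pi) ** tau
    return ((dY @ V.T) * W) @ K

# The scalar bound behind the surrogate:
# (tau/2)(1 - arccos(x)/pi)^tau <= tau (1 - arccos(x)/pi)^(tau-1) / (pi sqrt(1-x^2)).
tau = 8
xs = np.linspace(-0.999, 0.999, 1001)
lower = (tau / 2) * (1 - np.arccos(xs) / np.pi) ** tau
exact = tau * (1 - np.arccos(xs) / np.pi) ** (tau - 1) / (np.pi * np.sqrt(1 - xs ** 2))
assert np.all(lower <= exact + 1e-12)

# With K = Q, diagonal alignment scores are exactly 1: the surrogate is finite.
rng = np.random.default_rng(0)
Q = rng.standard_normal((6, 4)); Q /= np.linalg.norm(Q, axis=1, keepdims=True)
dY, V = rng.standard_normal((2, 6, 5))
assert np.isfinite(grad_q_lower_bound(dY, Q, Q, V, tau)).all()
```

Replacing the diverging factor by the bounded weight trades a small bias for bounded gradients, mirroring the role of the lower bound plotted in Figure 1.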
Further, to address the quadratic dependence on d, we discuss in the Appendix a scheme that estimates the same quantity but is linear in d.

5. RELATED WORKS

There are a number of efforts describing ways to reduce the quadratic cost of self-attention w.r.t. the input sequence length. Among these, Linformer (Wang et al., 2020) suggests that low rank attention might be sufficient and adds linear projections (on the sequence) to fixed-size keys and values. Other low rank approximation ideas (Katharopoulos et al., 2020; Choromanski et al., 2020) use separable functions on queries and keys to replace softmax self-attention. By assuming that the rank of self-attention is independent of the input sequence length, these methods achieve O(n) time and memory complexity. Another direction is to exploit the sparsity of the softmax matrix and focus on certain sparsity patterns, computing softmax dependencies only within those patterns; this includes Sparse Transformer (Child et al., 2019), Longformer (Beltagy et al., 2020), Big Bird (Zaheer et al., 2020), and Reformer (Kitaev et al., 2020). Note that instead of using LSH as an approximate nearest neighbor search to dynamically determine the sparsity pattern, as in Reformer, our YOSO-Attention exploits the connection between query-key similarity and the LSH collision probability to model the dependencies among tokens.

6. EXPERIMENTS

In this section, we provide empirical results for the proposed approach. To evaluate our method, we follow the BERT language model pretraining procedure (Devlin et al., 2018) and evaluate performance on both intrinsic tasks and multiple downstream tasks in the GLUE benchmark, as well as runtime and memory relative to standard self-attention. Previously, we assumed that queries and keys are unit length and described a construction to achieve this. In our experiments, we found that simply applying an ℓ2 normalization to queries and keys and using a temperature τ as a hyperparameter does not degrade the model's performance and is more efficient to compute, so we use this simpler version in the experiments.

BERT Pretraining. Following Devlin et al. (2018), the model is pretrained on BookCorpus (Zhu et al., 2015) and English Wikipedia. To evaluate the model's capacity to capture sentence-level information, instead of using Next-Sentence-Prediction (NSP) as the sentence-level loss as in the original BERT, we adopt Sentence-Ordering-Prediction (SOP), proposed in ALBERT (Lan et al., 2019) as a more difficult task than NSP. All models are trained with the Masked-Language-Modeling (MLM) and SOP objectives. We used the same pretraining hyperparameters as Devlin et al. (2018); however, due to limits on computational resources, all models are trained for 500K steps. The batch size is set so that around 2^17 tokens are processed per step (batch size 256 for sequence length 512, and batch size 32 for sequence length 4096).

Number of Hashes during Pretraining. Since the estimation variance decreases as the number of hashes increases, to evaluate the trade-off between efficiency and performance in YOSO, we test four hash settings: 16 hashes, 32 hashes, 64 hashes, and the expectation of collision to simulate infinitely many hashes. We plot the MLM validation perplexity and SOP validation loss curves of the 512-length models pretrained with softmax self-attention and YOSO-Attention in the right plots of Figure 3. The curves of our method using the expectation match and slightly exceed softmax self-attention, indicating that our method is indeed as capable as self-attention. We expect that as the number of hashes increases, the performance of our method will approach the curve using the expectation, as the approximation becomes more accurate. For both MLM and SOP, we confirm that our method is as effective as softmax self-attention.

Table 1: Time/memory complexity of self-attention and YOSO-attention in forward/backward computation.

                    Time                                        Memory
                    Forward                    Backward         Forward            Backward
    Self-Attention  O(n^2 d)                   O(n^2 d)         O(n^2)             O(n^2)
    YOSO-Attention  O(nmτ log_2(d) + nmd)      O(nmd^2)         O(nmτ + m 2^τ)     O(m 2^τ)

Number of Hashes during Validation. YOSO-Attention is a stochastic model. To make inference deterministic, as in dropout (Srivastava et al., 2014), we take the expectation as our output. However, directly computing the expectation involves an O(n^2) cost, so we experiment with different hash settings at validation time, simulating the expectation as the number of hashes increases. We plot the MLM perplexity and SOP loss of the same pretrained models using different numbers of hashes on validation in the center plots of Figure 3. We observe that as the number of hashes increases, the MLM perplexity and SOP loss generally decrease for all pretraining hash settings.

Pretraining on Longer Sequences.
To examine whether our method can scale linearly with sequence length, we continue pretraining the BERT-base models from the corresponding 500K-step checkpoints of the 512-length models, adding additional positional embeddings as suggested in Beltagy et al. (2020). We observe that, compared to the 512 sequence length, the small performance gap between YOSO-Attention and softmax self-attention does not increase, as shown in the left plots of Figure 3, providing evidence that the number of hashes can be chosen independently of the sequence length. (Here and in Table 2, YOSO-x means the model is pretrained with YOSO-Attention using x hashes, and YOSO-x-E means that YOSO-x is finetuned on downstream tasks using the expectation.)

GLUE Tasks. In addition to intrinsic tasks, we examine the effectiveness of our method on diverse downstream tasks and ask how our method compares with standard attention after finetuning. We finetuned all pretrained BERT-base models on the MRPC (Dolan & Brockett, 2005), RTE (Giampiccolo et al., 2007), SST-2 (Socher et al., 2013), QNLI (Rajpurkar et al., 2016), QQP (Chen et al., 2018), and MNLI (Williams et al., 2018) tasks in the GLUE benchmark and report the corresponding dev metrics. For the large datasets, including QNLI, QQP, and MNLI, the extensive resource needs preclude a hyperparameter search, so we used a batch size of 32 and a learning rate of 3e-5 and finetuned our models for 4 epochs. For MRPC, RTE, and SST-2, we follow BERT finetuning and search over batch sizes {8, 16, 32} and learning rates {2e-5, 3e-5, 4e-5, 5e-5}, selecting the best dev set result. Results are listed in Table 2. We observe that YOSO's performance on downstream tasks is comparable with standard attention, with slightly better results in some hash settings. Further, the downstream performance of YOSO generally increases as more hashes are used, providing an adjustable trade-off between efficiency and accuracy.

Longer Sequence Task.
To further evaluate YOSO on long sequence tasks, we extended the positional embeddings of a trained YOSO-64 model and used it as an initialization to train a 4096-length YOSO-128 model with a batch size of 64 and a learning rate of 5e-5 on BookCorpus (Zhu et al., 2015), English Wikipedia, one third of Stories (Trinh & Le, 2018), and one third of Realnews (Zellers et al., 2019) for 100K steps, similar to Longformer pretraining (Beltagy et al., 2020). Then, we finetuned the model on WikiHop (Welbl et al., 2018). Due to computational resource limits, we only tested a small set of hyperparameters (batch size = 32, learning rate ∈ {1e-5, 2e-5, 4e-5}, number of epochs = 10). The dev accuracy is 73.7 for YOSO-128-E, which is comparable to the 73.8 of Longformer-512 (see the caption of Table 3).

Comparisons to Baselines. Apart from comparing YOSO to standard self-attention, we also evaluate its competitiveness with other efficient attention methods. To keep the financial costs of these experiments reasonable, instead of training all methods from scratch, we used RoBERTa-base's pretrained weights as the starting point and trained each model with batch size 512 and learning rate 5e-5 on BookCorpus (Zhu et al., 2015) and English Wikipedia for 95K steps. Then, we finetuned the models on SST-2, QQP, and MNLI. These results are shown in Table 3. We observe that our performance is competitive with the other baselines while the memory consumption of YOSO is much lower (2.6x, 1.9x, and 2.1x memory savings compared to Reformer, Longformer, and Linformer, respectively; see Backward-Cache in Table 4). This has potential ramifications for training such models on more moderate, much less expensive hardware. Further, YOSO is potentially applicable to a wider range of applications, especially where the input sequence represents an unordered set of high dimensional points (where spatial locality of the input sequence may not hold).

Estimation Error.
To assess the effectiveness of our algorithm, using Q, K from the trained model, we generated attention matrices with our algorithm using different numbers of hashes and compared them against standard self-attention. In Figure 4, we see visually that our method produces attention patterns similar to standard self-attention, and the estimate of the attention matrix becomes more accurate as the number of hashes increases. Further, each output of YOSO-Attention is a weighted sum of random variables, as shown in Eq. 3, so one may suspect that as the sequence length increases, the variance of the YOSO-Attention output might also increase. We did not observe this behavior, which may be partly due to the hyperparameter τ = O(log(n)) that controls the decay rate of the LSH collision probability as the similarity changes. We can also ask whether the estimation error of YOSO-Attention for a fixed number of hashes increases with the sequence length. We use Q, K, V generated by the pretrained model and estimate the error between N-YOSO(Q, K, V) and E[N-YOSO(Q, K, V)]. As the left plot of Figure 5 suggests, the relative error of our method stays almost constant as the sequence length increases from 128 to 4096. This indicates that the sampling-based estimation of attention weights in YOSO-Attention scales with the sequence length and preserves the same estimation quality without increasing the number of hashes.

A APPENDIX

In the Appendix, we provide details of our method that were omitted from the main text.

Backpropagation Derivation. When using the expectation of LSH collisions as attention weights, the attention of one query q to keys k_i and associated values v_i, for all i ∈ {1, ..., n}, is defined as

$$y = \sum_{i=1}^{n} \left(1 - \frac{\arccos(q^T k_i)}{\pi}\right)^{\tau} v_i$$

We want to compute the gradient of the loss w.r.t. q, denoted ∇_q L, given the gradient of the loss w.r.t. y, denoted ∇_y L. We start by computing the p-th entry of ∇_q L:

$$\frac{\partial L}{\partial q_p} = \sum_{j=1}^{d} \frac{\partial L}{\partial y_j} \sum_{i=1}^{n} \frac{\partial}{\partial q_p} \left(1 - \frac{\arccos(q^T k_i)}{\pi}\right)^{\tau} (v_i)_j \qquad (13)$$

We use the derivative of the collision probability, $\frac{d}{dx}\left(1 - \frac{\arccos(x)}{\pi}\right)^{\tau} = \frac{\tau \left(1 - \arccos(x)/\pi\right)^{\tau-1}}{\pi \sqrt{1 - x^2}}$, and plug it into Eq. 13,

$$\frac{\partial L}{\partial q_p} = \sum_{j=1}^{d} \frac{\partial L}{\partial y_j} \sum_{i=1}^{n} \left[\frac{\tau \left(1 - \frac{\arccos(q^T k_i)}{\pi}\right)^{\tau-1}}{\pi \sqrt{1 - (q^T k_i)^2}} (k_i)_p\right] (v_i)_j \qquad (15)$$

After swapping the order of the two summations, this becomes

$$\frac{\partial L}{\partial q_p} = \sum_{i=1}^{n} (\nabla_y L)^T v_i\, \frac{\tau \left(1 - \frac{\arccos(q^T k_i)}{\pi}\right)^{\tau-1}}{\pi \sqrt{1 - (q^T k_i)^2}} (k_i)_p$$

Note that only (k_i)_p differs across the entries of ∇_q L, so we can write

$$\nabla_q L = \sum_{i=1}^{n} (\nabla_y L)^T v_i\, \frac{\tau \left(1 - \frac{\arccos(q^T k_i)}{\pi}\right)^{\tau-1}}{\pi \sqrt{1 - (q^T k_i)^2}}\, k_i$$

Eq. 11 is the matrix form of the above,

$$\nabla_Q L = \left[\left((\nabla_{\mathrm{YOSO}} L)\, V^T\right) \odot \tau \left(1 - \frac{\arccos(QK^T)}{\pi}\right)^{\tau-1} \oslash \left(\pi \sqrt{1 - (QK^T)^2}\right)\right] K \qquad (18)$$

Note that π sqrt(1 − (QK^T)^2) approaches 0 as the alignment score between a query and a key approaches 1, so we use the fact that $\frac{1}{2}\left(1 - \frac{\arccos(x)}{\pi}\right) \leq \frac{1}{\pi \sqrt{1 - x^2}}$ for x ∈ [-1, 1] and define a lower bound to replace the actual gradient,

$$\widetilde{\nabla}_Q L = \left[\left((\nabla_{\mathrm{YOSO}} L)\, V^T\right) \odot \frac{\tau}{2} \left(1 - \frac{\arccos(QK^T)}{\pi}\right)^{\tau}\right] K$$

Approximating Random Projection in LSH. In the main text, we discussed how to estimate self-attention using Bernoulli sampling via LSH. The first step of LSH is computing hash codes via random projection,

$$F: \mathbb{R}^d \to \{0, 1\}^{m\tau}, \qquad F(x) = \mathrm{sign}(Rx)$$

where R ∈ R^{(mτ)×d} with R_ij ~ N(0, 1); the output vector is then partitioned into m τ-dimensional binary hash codes. The time complexity of this random projection is O(nmτd). To approximate the random projection efficiently, we follow the construction used in Andoni et al. (2015).
The $m\tau$-dimensional output vector is divided into $m\tau/d$ $d$-dimensional blocks, and the hash codes are estimated by
$$F(x) = \mathrm{concat}\left(\mathrm{sign}(H D_1^3 H D_1^2 H D_1^1 x), \dots, \mathrm{sign}(H D_{m\tau/d}^3 H D_{m\tau/d}^2 H D_{m\tau/d}^1 x)\right)$$
where the $D_i^j$ are diagonal matrices with entries sampled uniformly from $\{-1, +1\}$, and $H$ is the Hadamard matrix. This approximation reduces the time complexity to $O(nm\tau \log_2(d))$.
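As an illustration, here is a minimal sketch of one $d$-bit block of this construction (our simplification; a full code concatenates $m\tau/d$ such blocks, each with fresh diagonal sign matrices), using an iterative fast Walsh-Hadamard transform in place of a dense matrix multiply:

```python
import numpy as np

def fwht(x):
    """Fast Walsh-Hadamard transform of a length-2^k vector, O(d log d)."""
    x = x.copy()
    h = 1
    while h < len(x):
        for i in range(0, len(x), 2 * h):
            a = x[i:i + h].copy()
            b = x[i + h:i + 2 * h].copy()
            x[i:i + h], x[i + h:i + 2 * h] = a + b, a - b
        h *= 2
    return x

def hash_code(x, rng):
    """One d-bit hash block: sign(H D^3 H D^2 H D^1 x)."""
    for _ in range(3):
        # D: random diagonal sign flips, then H via the fast transform
        x = fwht(rng.choice([-1.0, 1.0], size=len(x)) * x)
    return (x > 0).astype(np.int8)

rng = np.random.default_rng(2)
d = 16                      # must be a power of 2 (pad x otherwise)
x = rng.normal(size=d)
print(hash_code(x, rng))    # d sign bits, i.e., one block of the code
```

The three rounds of sign flips and Hadamard transforms make the projection behave like a dense Gaussian one while costing only $O(d \log d)$ per block.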



Figure 1: We compare attention weights using exp(τ(x - 1))

Figure 2: Overview of YOSO-Attention. The hash table stores the sum of values associated with hashed keys.

Figure 3: (a) The left two plots show results on MLM and SOP for sequence length 512. We report MLM validation perplexity and SOP validation loss every 2K training steps. (b) The middle two plots show results on MLM and SOP when using different numbers of hashes at validation. Since the runtime of YOSO-Attention is linear in the number of hashes, these two plots directly reflect the trade-off between performance and inference time. (c) The right two plots show results on MLM and SOP for sequence length 4096. YOSO-x denotes the model pretrained with YOSO-Attention using x hashes, with E denoting the expectation.

Figure 4: Attention matrices generated by self-attention and YOSO-Attention with different hash settings using the same input. Notice that the patterns are preserved well.

(Table 3) without hyperparameter search, but slightly worse than the 75.0 that Longformer-512 achieves with hyperparameter search.

Table 3: Dev set results on SST-2, QQP, and MNLI. We report both F1 score and accuracy for QQP, and accuracy for the others. Reformer-x: Reformer using the HuggingFace implementation (Wolf et al., 2019) with x hashes. Longformer-x: Long-


Figure 5: (a) The relative error is estimated by computing E[N-YOSO(Q, K, V)] from the collision probability, estimating N-YOSO(Q, K, V) multiple times, and taking the mean relative error across estimates as an estimate of the outer expectation. (b) The runtime per token is estimated by estimating N-YOSO(Q, K, V) multiple times, measuring the total elapsed time, and dividing by the number of iterations and the sequence length.

Runtime and Memory. We measure the runtime of our method as the sequence length increases. To show the trend more precisely, we plot the runtime per token, as shown in Figure 5 (right). There is a slight increase in runtime per token as the sequence length grows, but note that the x-axis of the plot is on a log scale, so the increase is small: when the sequence length increases by 32×, the runtime per token increases by only 30%, which is explained by our choice of the hyperparameter τ = O(log(n)). Beyond the plot, we report the training and testing efficiency of our method, as well as of three other efficient attention methods, against standard self-attention. The results were measured by generating Q, K, V of a specified sequence length with a trained model and feeding them into a BERT-base Multi-Head Attention module multiple times. The experiments were performed on a single NVIDIA 2080Ti. From Table 4, we see that while our method has a runtime similar to self-attention for a standard 512-length sequence, the speed-up and memory savings become significant as the sequence length increases. While our method offers runtime savings similar to other efficient attention methods, its memory consumption for training (i.e., Backward-Cache) is much lower than that of all other methods in almost all settings.

7. CONCLUSION

We presented a transformer-based model, YOSO-Attention, that scales linearly in the number of input tokens, which allows the model to be applied to a wide range of long-document NLP tasks. Via a randomized sampling scheme, our model approximates self-attention as a sum of individual tokens associated with Bernoulli random variables that can, in principle, be sampled at once by a single hash. With specific modifications of LSH, YOSO-Attention can be efficiently deployed within a deep learning framework, and we expect that various aspects of this idea and our implementation will find use in other novel settings and applications (e.g., in vision).

Alternative Procedure for Approximating Backpropagation. In the main text, we provided a procedure, shown in Eq. 12, which uses LSH-based Bernoulli sampling d times as a subroutine. The complexity of this procedure is linear in the sequence length n, which is desirable, but the runtime can be large if d is relatively large. Therefore, we provide a second procedure, which is linear with respect to d. The gradient of L w.r.t. the i-th row of Q can be written as a sum over j in which every term carries a factor B(Q, K)_{i,j}; if B(Q, K)_{i,j} is zero, the corresponding term does not need to be computed. The alternative procedure counts the number of successes among the m samples for each entry B(Q, K)_{i,j} and computes a summation term only when B(Q, K)_{i,j} is nonzero; thus the runtime is O(nnz(B(Q, K))(m + d)) (counting the number of successes plus computing the nonzero terms). In the worst case, nnz(B(Q, K)) = n², and it would be as expensive as dense matrix multiplication in complexity, and even worse in practice due to the large memory latency resulting from indirect memory access. However, in practice, B(Q, K) is generally sparse if τ is set properly. Further, the first procedure guarantees a linear complexity scaling of our method for extremely long sequences. As an improvement, we can dynamically select between these two methods based on runtime, in which case the time complexity is O(min(nmd², nnz(B(Q, K))(m + d))).
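A minimal sketch of the sparse accumulation is given below (our illustration: names are ours, and B here is a stand-in 0/1 matrix rather than actual hash-collision counts; we use the lower-bound gradient of the appendix with the collision probability replaced by its sampled estimate):

```python
import numpy as np

rng = np.random.default_rng(3)
n, d, tau = 32, 8, 8

def unit(x):
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

Q, K = unit(rng.normal(size=(n, d))), unit(rng.normal(size=(n, d)))
V = rng.normal(size=(n, d))
G = rng.normal(size=(n, d))     # gradient of the loss w.r.t. the output

# Stand-in for the sampled Bernoulli matrix B(Q, K): mostly zero.
B = (rng.random((n, n)) < 0.05).astype(float)

# Dense reference: grad_Q = [(G V^T) * (tau/2) * B] K
grad_dense = ((G @ V.T) * (tau / 2.0) * B) @ K

# Sparse version: touch only the nonzero (i, j) pairs, so the cost
# scales with nnz(B) rather than n^2.
grad_sparse = np.zeros((n, d))
rows, cols = np.nonzero(B)
for i, j in zip(rows, cols):
    grad_sparse[i] += (tau / 2.0) * B[i, j] * (G[i] @ V[j]) * K[j]

print(np.allclose(grad_dense, grad_sparse))   # prints True
```

Both paths compute the same gradient; only their cost profiles differ, which is what makes the runtime-based selection between them attractive.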

