YOU ONLY SAMPLE (ALMOST) ONCE: LINEAR COST SELF-ATTENTION VIA BERNOULLI SAMPLING

Anonymous

Abstract

Transformer-based models have come to dominate the landscape in a wide range of natural language processing (NLP) applications. The heart of the transformer model is the self-attention mechanism, which captures the interactions of token pairs in the input sequences and consequently depends quadratically on the input sequence length. It is known that training such models on longer sequences is quite expensive, and often prohibitively so. We show that a Bernoulli sampling attention mechanism based on Locality Sensitive Hashing (LSH) decreases the quadratic complexity to linear. We bypass the quadratic cost by considering self-attention as a sum of individual tokens associated with Bernoulli random variables that can, in principle, be sampled at once by a single hash (although in practice, this number may be a small constant). This leads to an efficient sampling scheme to estimate self-attention which relies on specific modifications of LSH (based on feasibility of deployment on GPU architectures). We evaluate our proposed algorithm on the GLUE benchmark with the standard 512 sequence length, and our method achieves comparable or even slightly better performance than a standard pretrained Transformer. To evaluate whether our method can indeed handle longer sequences, we conduct experiments on long-sequence (4096) language model pretraining and achieve results consistent with standard self-attention, while observing sizable inference speed-ups and memory savings.

1. INTRODUCTION

The Transformer model (Vaswani et al., 2017) is incredibly effective across a diverse set of natural language processing (NLP) applications including machine translation (Vaswani et al., 2017), language inference (Devlin et al., 2018) and paraphrasing (Raffel et al., 2019). Transformer-based models such as BERT (Devlin et al., 2018) are pretrained in an unsupervised manner and later finetuned on different downstream tasks, often providing state-of-the-art performance on standard benchmarks. While such models show strong empirical performance, their computational and memory requirements remain high. Consequently, in the NLP setting, most current models impose constraints on the sequence length, e.g., BERT and other transformer-based language models (Yang et al., 2019; Liu et al., 2019) limit the sentence length to at most 512 tokens.

Multi-head self-attention is central to Transformer-based models and provides a flexible global receptive field to exchange information among input tokens. While self-attention provides immense benefits, it is also a key bottleneck in training with long sequences. In particular, the output of self-attention is a combination of all tokens, where the coefficients are determined by the similarities among tokens. While this is empirically beneficial, it involves a sizable resource footprint: for sequence length n, computing pairwise similarities among all input tokens incurs O(n^2) complexity in both time and memory. This quadratic cost is a roadblock to the potential benefits that may be realizable in various applications by capturing long-term context dependencies. As we will discuss in more detail later, this issue is a major thrust of several recent and ongoing efforts focused on mitigating the sizable resource requirements of such models. Our work is inspired by ideas of importance sampling via hashing-based sampling techniques (Spring & Shrivastava, 2017; Charikar & Siminelakis, 2017).
We propose Bernoulli sampling to approximate self-attention, scaling linearly with the input sequence length. We achieve this by viewing self-attention as a sum of individual tokens associated with Bernoulli random variables whose success probability is determined by the similarities among tokens. In principle, we can sample all Bernoulli random variables at once with a single hash (although in practice, this number may be a small constant to lower the approximation variance). This leads to an efficient sampling scheme to estimate self-attention which relies on specific modifications of hashing-based importance sampling (based on feasibility of deployment on GPU architectures). The resulting strategy (You Only Sample Almost Once, YOSO-Attention) is far more amenable to an efficient and backpropagation-friendly implementation, and has a favorable empirical performance profile on natural language modeling tasks. We evaluate our proposed algorithm on the GLUE benchmark (Wang et al., 2019) with 512 sequence length as well as on long-sequence language model pretraining, where we see promising results with speed-ups and memory savings.
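As a toy illustration of the Bernoulli-sampling view described above (not the paper's actual LSH-based estimator; all names and the probabilities p are hypothetical), a weighted sum of value vectors, sum_i p_i * v_i, can be estimated by drawing one Bernoulli variable per token and summing the values whose variables succeed; averaging a small constant number of such draws reduces the variance:

```python
import numpy as np

rng = np.random.default_rng(0)

n, d = 64, 8
p = rng.uniform(0.1, 0.9, size=n)        # hypothetical per-token success probabilities
v = rng.standard_normal((n, d))          # token value vectors

exact = p @ v                            # target quantity: sum_i p_i * v_i

# One joint Bernoulli draw gives an unbiased estimate of the sum;
# averaging a small number of independent draws lowers the variance.
num_draws = 200
draws = rng.random((num_draws, n)) < p   # each row is one joint Bernoulli sample
estimate = (draws.astype(float) @ v).mean(axis=0)
```

Each individual draw is already unbiased, which is the sense in which the Bernoulli variables can, in principle, be sampled "once"; the averaging only controls the estimator's variance.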

2. BACKGROUND: SELF-ATTENTION

Self-Attention. Self-attention is a scaled dot-product attention mechanism that captures token dependencies in the input sequence. It can be defined as

A(Q, K, V) = softmax((QW_Q)(KW_K)^T / sqrt(d_h)) V W_V = D_P exp(P) V W_V,    (1)

where P = (QW_Q)(KW_K)^T / sqrt(d_h), and Q, K, V in R^{n x d} are embedding matrices from the input sequence, called queries, keys and values respectively. Here, n is the input sequence length, d is the embedding dimension of each token, W_Q, W_K, W_V in R^{d x d_h} are learned parameter matrices, d_h is the dimension of the hidden embedding, and D_P is an n x n diagonal matrix that normalizes each row of the exp(P) matrix such that the row entries sum up to 1. For simplicity, we overload the notation Q, K, V to denote QW_Q, KW_K, V W_V in our presentation.

Multi-Head Self-Attention. Multi-head self-attention in Transformers runs the scaled dot-product attention multiple times and concatenates the attention outputs to help the model capture information from multiple representation subspaces (Vaswani et al., 2017). Multi-head self-attention can be formally written as

MultiHead(Q, K, V) = Concat(A_1(Q, K, V), ..., A_h(Q, K, V)) W,    (2)

where h is the number of heads and A_i, i = 1, ..., h, are heads with different parameter matrices.

Self-Attention Bottleneck. A key bottleneck in self-attention is computing the softmax matrix, softmax(P), which requires calculating all pairwise input-token similarities. To reduce this cost, we seek to approximate the softmax matrix by viewing self-attention for each query as an expectation over a softmax distribution and computing the approximated self-attention with an efficient sampling mechanism. In the following sections, we first review LSH-based importance sampling and then propose Bernoulli sampling with LSH to estimate self-attention efficiently.
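The single-head computation in Eq. (1) can be sketched directly. The following is a minimal NumPy version (a sketch; the learned projections W_Q, W_K, W_V are folded into Q, K, V, matching the overloaded notation above), in which the n x n matrix P makes the quadratic cost explicit:

```python
import numpy as np

def self_attention(Q, K, V):
    """Naive scaled dot-product self-attention: O(n^2) time and memory."""
    n, d_h = Q.shape
    P = Q @ K.T / np.sqrt(d_h)                     # n x n pairwise similarities
    E = np.exp(P - P.max(axis=1, keepdims=True))   # numerically stable exp(P)
    A = E / E.sum(axis=1, keepdims=True)           # row-normalize: softmax(P)
    return A @ V                                   # each output row is a convex
                                                   # combination of value rows

rng = np.random.default_rng(0)
n, d_h = 512, 64
Q, K, V = (rng.standard_normal((n, d_h)) for _ in range(3))
out = self_attention(Q, K, V)
```

Both the similarity matrix P and the softmax matrix A have n^2 entries, which is exactly the bottleneck the sampling scheme in the next sections is designed to avoid.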

3. IMPORTANCE SAMPLING VIA LOCALITY SENSITIVE HASHING

Importance sampling (Press et al., 2007) approximates properties of a target distribution by a weighted average of random draws from another distribution. It is known (Press et al., 2007) that importance sampling can be applied directly to the softmax distribution by drawing samples from a uniform distribution, which avoids the harder problem of sampling from the softmax distribution itself. But this leads to a high-variance estimate, since the softmax distribution is usually concentrated in a small region. When using this idea for softmax matrix approximation in self-attention in particular, the variance tends to grow with the input sequence length. Before proceeding, we summarize an importance sampling method that yields low-variance estimators, specifically, importance sampling via LSH from (Charikar & Siminelakis, 2017; Spring & Shrivastava, 2017).

LSH-based Importance Sampling. Consider the case when the angular distance between a key and a query is small. In this case, the similarity between the key and the query, as well as the softmax probability, will be large. Viewed through the lens of nearest neighbor retrieval, this property coincides with a large collision probability of high-similarity key-query pairs, assuming that
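For intuition on the collision behavior discussed above, the following sketch uses the classic signed-random-projection LSH of Charikar, under which two vectors collide in a single hash bit with probability 1 - theta/pi, where theta is the angle between them (this is an illustrative stand-in, not the modified scheme developed in this paper; the vectors q and k are synthetic):

```python
import numpy as np

rng = np.random.default_rng(0)
d, num_hashes = 32, 20000

def angle(u, v):
    cos = u @ v / (np.linalg.norm(u) * np.linalg.norm(v))
    return np.arccos(np.clip(cos, -1.0, 1.0))

q = rng.standard_normal(d)
k = q + 0.3 * rng.standard_normal(d)      # a key at small angular distance from q

W = rng.standard_normal((num_hashes, d))  # one random hyperplane per hash bit
collide = np.sign(W @ q) == np.sign(W @ k)

empirical = collide.mean()                # observed collision frequency
theoretical = 1 - angle(q, k) / np.pi     # Charikar's collision probability
```

Since the collision probability decreases monotonically with the angle, a hash table built from such bits retrieves high-similarity key-query pairs far more often than dissimilar ones, which is the property the importance sampler exploits.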

