YOU ONLY SAMPLE (ALMOST) ONCE: LINEAR COST SELF-ATTENTION VIA BERNOULLI SAMPLING

Anonymous

Abstract

Transformer-based models have come to dominate the landscape in a wide range of natural language processing (NLP) applications. At the heart of the Transformer is the self-attention mechanism, which captures the interactions of all token pairs in the input sequence and consequently scales quadratically with the sequence length. Training such models on longer sequences is therefore quite expensive, and often prohibitively so. We show that a Bernoulli sampling attention mechanism based on Locality Sensitive Hashing (LSH) decreases this quadratic complexity to linear. We bypass the quadratic cost by viewing self-attention as a sum over individual tokens associated with Bernoulli random variables that can, in principle, all be sampled at once by a single hash (although in practice, this number may be a small constant). This leads to an efficient sampling scheme for estimating self-attention that relies on specific modifications of LSH, chosen for feasibility of deployment on GPU architectures. We evaluate the proposed algorithm on the GLUE benchmark with the standard sequence length of 512, where our method achieves comparable or even slightly better performance than a standard pretrained Transformer. To assess whether our method can indeed handle longer sequences, we conduct language model pretraining experiments with sequence length 4096 and obtain results consistent with standard self-attention, while observing sizable inference speed-ups and memory savings.

1. INTRODUCTION

The Transformer model (Vaswani et al., 2017) is incredibly effective across a diverse set of natural language processing (NLP) applications, including machine translation (Vaswani et al., 2017), language inference (Devlin et al., 2018), and paraphrasing (Raffel et al., 2019). Transformer-based models such as BERT (Devlin et al., 2018) are pretrained in an unsupervised manner and later finetuned on different downstream tasks, often providing state-of-the-art performance on standard benchmarks. While such models have strong empirical performance, their computational and memory requirements are high. Consequently, in the NLP setting, most current models impose constraints on the sequence length; e.g., BERT and other Transformer-based language models (Yang et al., 2019; Liu et al., 2019) limit the sequence length to at most 512.

Multi-head self-attention is central to Transformer-based models and provides a flexible global receptive field for exchanging information among input tokens. While self-attention provides immense benefits, it is also a key bottleneck when training with long sequences. In particular, the output of self-attention is a combination of all tokens, with coefficients determined by the similarities among tokens. While this is empirically beneficial, it involves a sizable resource footprint: for sequence length n, computing pairwise similarities among all input tokens incurs O(n^2) complexity in both time and memory. This quadratic cost is a roadblock to the potential benefits that capturing long-term context dependencies could offer in various applications. As we will discuss in more detail later, the foregoing issue is a major thrust of several recent and ongoing efforts focused on mitigating the sizable resource requirements of such models. Our work is inspired by ideas of importance sampling via hashing-based sampling techniques (Spring & Shrivastava, 2017; Charikar & Siminelakis, 2017).
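To make the quadratic bottleneck concrete, the following minimal sketch computes standard softmax self-attention for a single head; the function name and shapes are illustrative, not the authors' implementation. The n x n score matrix Q K^T is the source of the O(n^2) time and memory cost discussed above.

```python
import numpy as np

def self_attention(Q, K, V):
    """Standard softmax self-attention for one head.
    Q, K, V: (n, d) arrays; returns an (n, d) array."""
    d = Q.shape[1]
    scores = Q @ K.T / np.sqrt(d)                 # (n, n): quadratic in n
    scores -= scores.max(axis=1, keepdims=True)   # for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True) # softmax over keys
    return weights @ V                            # (n, d) output

rng = np.random.default_rng(0)
n, d = 8, 4
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
out = self_attention(Q, K, V)
```

Note that the (n, n) `scores` matrix must be materialized (or at least traversed) in full, which is precisely what the sampling scheme described in this paper avoids.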
We propose a Bernoulli-sampling-based approximation of self-attention that scales linearly with the input sequence length. We achieve this by viewing self-attention as a sum over individual tokens associated with Bernoulli random variables.
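The following is a hedged sketch of this view: each attention coefficient is estimated by the empirical collision rate of LSH hashes of the query and key, i.e., the mean of Bernoulli variables realized by hash agreement. The choice of sign random projections, the number of hash rounds `tau`, and the row normalization are illustrative assumptions, not the exact construction developed in this paper; for clarity the sketch also materializes the full (n, n) estimate, whereas the actual method aggregates colliding tokens in hash buckets to stay linear in n.

```python
import numpy as np

def lsh_bernoulli_attention(Q, K, V, tau=32, seed=0):
    """Estimate attention output where weight(i, j) is the fraction of
    `tau` sign-random-projection hash rounds on which q_i and k_j agree
    (an average of Bernoulli samples). Q, K, V: (n, d) arrays."""
    rng = np.random.default_rng(seed)
    n, d = Q.shape
    # Normalize so that angular similarity governs collision probability.
    Qn = Q / np.linalg.norm(Q, axis=1, keepdims=True)
    Kn = K / np.linalg.norm(K, axis=1, keepdims=True)
    R = rng.standard_normal((d, tau))  # tau random hyperplanes
    hq = np.sign(Qn @ R)               # (n, tau) hash bits for queries
    hk = np.sign(Kn @ R)               # (n, tau) hash bits for keys
    # Bernoulli estimate: per-pair fraction of agreeing hash bits.
    B = (hq[:, None, :] == hk[None, :, :]).mean(axis=2)  # (n, n)
    W = B / B.sum(axis=1, keepdims=True)                 # normalize rows
    return W @ V

rng = np.random.default_rng(1)
Q, K, V = (rng.standard_normal((6, 4)) for _ in range(3))
out = lsh_bernoulli_attention(Q, K, V)
```

For sign random projections, the collision probability of a single hash round is 1 - theta/pi, where theta is the angle between the two vectors, so similar query-key pairs collide more often and receive larger estimated weights.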

