TREEFORMER: DENSE GRADIENT TREES FOR EFFICIENT ATTENTION COMPUTATION

Abstract

Standard inference and training with Transformer-based architectures scale quadratically with input sequence length. This is prohibitively expensive for a variety of applications, especially web-page translation and question answering. Consequently, several approaches have recently been developed to speed up attention computation by enforcing different attention structures, such as sparsity (Zaheer et al., 2020), low rank (Wang et al., 2020), or kernel approximations of attention (Choromanski et al., 2021). In this work, we view attention computation as a nearest-neighbor retrieval problem and use decision-tree-based hierarchical navigation to reduce the retrieval cost per query token from linear in the sequence length to nearly logarithmic. Based on such hierarchical navigation, we design TREEFORMER, which can use one of two efficient attention layers: TF-ATTENTION and TC-ATTENTION. TF-ATTENTION computes attention in a fine-grained style, while TC-ATTENTION is a coarse attention layer that also ensures the gradients are "dense". To optimize such challenging discrete layers, we propose a two-level bootstrapped training method. Through extensive experiments on standard NLP benchmarks, especially for long sequences, we demonstrate that our TREEFORMER architecture can be almost as accurate as the baseline Transformer while using 30× fewer FLOPs in the attention layer. Compared to Linformer, accuracy can be as much as 12% higher while using similar FLOPs in the attention layer.

1. INTRODUCTION

The self-attention layer is the key component of Transformers (Vaswani et al., 2017), enabling them to achieve state-of-the-art performance across tasks in Natural Language Processing (Devlin et al., 2019; Radford et al., 2019; Raffel et al., 2019) and Computer Vision (Dosovitskiy et al., 2021). Attention computation scales quadratically (O(n^2)) with the input sequence length n, making it a key bottleneck in scaling Transformers to long inputs. This has resulted in many proposals to compute attention efficiently using different approximations. However, as shown in Tay et al. (2021), methods that offer dramatic speedups (Wang et al., 2020; Choromanski et al., 2021) suffer in accuracy, while methods that match the accuracy of the Transformer (Zaheer et al., 2020) do not offer significant speedups. In this paper we develop a novel efficient attention mechanism that exploits input-dependent sparsity structure in attention computation, motivated by analyses showing the sparse nature of attention (Shi et al., 2021; Gupta et al., 2021). We propose to use decision trees to compute attention efficiently by retrieving only the top nearest-neighbor keys for a given query. Decision trees are low-cost, hierarchical non-linear models that have been popular in both supervised (Chen & Guestrin, 2016; Quinlan, 2014) and unsupervised (Duda et al., 1973) learning. Given a vector, each node of the decision tree decides whether to navigate to the left or the right child depending on the sign of a simple linear projection. We learn decision trees in each attention layer of a Transformer model so that a query pays attention only to keys that map to the same leaf nodes (see footnote 1) of the decision trees. We refer to this as TF-ATTENTION. Such sparse decision trees that retrieve only a single leaf node are typically hard to train due to their discrete nature.
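The per-node routing rule just described can be sketched in a few lines. The following is a minimal NumPy illustration of this kind of tree navigation, not the authors' implementation; the implicit-heap node layout and the `weights`/`biases` parameterization are our own assumptions:

```python
import numpy as np

def route_to_leaf(x, weights, biases, depth):
    """Navigate a balanced binary decision tree of the given depth.

    At internal node i the branch is chosen by the sign of a linear
    projection weights[i] . x + biases[i] (left if negative, right
    otherwise). Nodes use the usual implicit-heap layout: the children
    of node i are 2i+1 and 2i+2. Returns the 0-based index of the leaf
    reached (0 .. 2**depth - 1).
    """
    node = 0
    for _ in range(depth):
        go_right = weights[node] @ x + biases[node] >= 0
        node = 2 * node + (2 if go_right else 1)
    # convert the heap index of a leaf into a 0-based leaf id
    return node - (2 ** depth - 1)
```

In a TF-ATTENTION-style layer, a query would then attend only to the keys whose `route_to_leaf` output matches its own, so each query touches roughly n / 2^depth keys instead of all n.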
Further, queries that pay strong attention to multiple keys force many keys to be bunched together in a single leaf node, creating imbalance; the key computational advantage of decision trees is realized only when they are balanced. To address these issues we propose another attention mechanism, TC-ATTENTION, in which a query attends not only to the keys in its leaf node but also to the keys at all intermediate nodes it traverses. We further combine the retrieved value vectors with query-independent but trainable weights. This allows us to learn a decision tree that clusters keys in a hierarchical and balanced manner. In both methods we train the decision tree end-to-end with the rest of the Transformer components. Since the decision trees themselves are inexpensive to evaluate, this reduces attention computation cost dramatically, typically requiring only a few dot products per query. Training these decision trees end-to-end with the rest of the Transformer poses challenges: naively using the decision trees to restrict attention computation results in poor optimization, with training getting stuck at a high loss. To solve this we introduce a novel bootstrapping method that gradually restricts attention computation in layers using decision trees. While this keeps training expensive initially, it still dramatically reduces inference costs. We further use Dense Gradient Trees (Karthikeyan et al., 2021), allowing us to use simple gradient-based optimization. We evaluate the proposed architecture on popular NLP models such as BERT (Devlin et al., 2019), and on the Long Range Arena (LRA) benchmark (Tay et al., 2021), developed specifically to compare efficient attention methods on long-sequence tasks. We show that TREEFORMER matches or improves the performance of other popular models while achieving lower computational cost; for example,
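A rough sketch of this coarse, path-based attention may help. This is our own simplified reading, not the paper's exact layer: each node summarizes the values of the keys routed through it by their mean, and `level_wts` stands in for the trainable, query-independent per-level weights that keep the gradients dense:

```python
import numpy as np

def tc_attention_output(q, keys, values, weights, biases, depth, level_wts):
    """Coarse tree attention for a single query (simplified sketch).

    Every key is routed down the tree; each node on a key's path
    accumulates that key's value vector. The query then traverses its
    own root-to-leaf path and mixes the per-node mean values with
    trainable, query-independent weights level_wts[lvl].
    """
    n_nodes = 2 ** (depth + 1) - 1
    d = values.shape[1]
    sums = np.zeros((n_nodes, d))
    counts = np.zeros(n_nodes)
    # pass 1: route every key, accumulating value summaries at each node
    for k, v in zip(keys, values):
        node = 0
        for _ in range(depth):
            sums[node] += v
            counts[node] += 1
            node = 2 * node + (2 if weights[node] @ k + biases[node] >= 0 else 1)
        sums[node] += v       # leaf node
        counts[node] += 1
    # pass 2: the query walks its own path and mixes node summaries
    out = np.zeros(d)
    node = 0
    for lvl in range(depth + 1):
        if counts[node] > 0:
            out += level_wts[lvl] * sums[node] / counts[node]
        if lvl < depth:
            node = 2 * node + (2 if weights[node] @ q + biases[node] >= 0 else 1)
    return out
```

Because every key contributes to the summaries at every node on its path, gradients flow to all keys rather than only to those sharing the query's leaf.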
TREEFORMER requires 8-9× fewer FLOPs than models like BigBird (Zaheer et al., 2020) and Performer (Choromanski et al., 2021).

Our contributions in this paper are:
(i) We pose attention computation as a nearest-neighbor retrieval problem and propose a novel architecture, TREEFORMER, that uses dense gradient trees to retrieve the top keys a given query should attend to. We develop two novel decision-tree-based attention mechanisms that, for a given query, pay attention either only to the keys in the matching leaf node (TF-ATTENTION) or to all the keys along the traversed path (TC-ATTENTION).
(ii) We develop a novel bootstrapping method to gradually sparsify attention computation, thereby showing that decision trees can be trained as part of neural networks using back-propagation.
(iii) We experimentally show the effectiveness of the proposed architecture using BERT models and on LRA (Tay et al., 2021). In particular, both TF-ATTENTION and TC-ATTENTION match the accuracy of the baseline Transformer architecture, with an attention layer that is up to 30× cheaper (in FLOPs) than standard Transformers and up to 9× cheaper than the state-of-the-art BigBird (Zaheer et al., 2020).

2. PRELIMINARIES: ATTENTION AND DECISION TREES

2.1 ATTENTION

Let Q, K, V ∈ R^{n×d} be the input query, key, and value matrices, where n is the sequence length and d is the embedding dimension, and let W_Q, W_K, W_V ∈ R^{d×d} be their respective projection matrices. The standard attention in each Transformer layer is defined as

Attention(Q, K, V) = softmax( QW_Q (KW_K)^T / √d ) · VW_V,    (1)

where A := softmax( QW_Q (KW_K)^T / √d ) denotes the n × n attention matrix. Note that for the self-attention layer, all three matrices Q, K, V are the same as the input to the layer. Computing both the attention matrix A and its product with the projected value matrix VW_V is an O(n^2 d + d^2 n) operation. In particular, assuming a cost of c·p·q·r for multiplying two matrices of size p × q and q × r, the main cost of this layer is c(2n^2 d + 2d^2 n).
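Equation (1) and the cost expression above can be made concrete with a small NumPy sketch (single head, and the cost formula reproduced exactly as stated in the text):

```python
import numpy as np

def softmax(x, axis=-1):
    # numerically stable row-wise softmax
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def standard_attention(Q, K, V, Wq, Wk, Wv):
    """Dense attention of Eq. (1): forms the full n x n matrix A."""
    Qp, Kp, Vp = Q @ Wq, K @ Wk, V @ Wv
    A = softmax(Qp @ Kp.T / np.sqrt(Q.shape[1]))  # n x n attention matrix
    return A @ Vp

def attention_cost(n, d, c=1):
    """Main matmul cost of the layer, c(2 n^2 d + 2 d^2 n), as in the text."""
    return c * (2 * n**2 * d + 2 * d**2 * n)
```

The n^2 terms come from the two n × n matrix products (forming A and multiplying it by VW_V); it is exactly these terms that the tree-based retrieval aims to eliminate.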
This computation becomes a major bottleneck in applications which have large sequence lengths n. To mitigate this, several efficient attention mechanisms have been proposed (Child et al., 2019; Wang et al., 2020; Choromanski et al., 2021; Kitaev et al., 2020; Sun et al., 2021; Wang et al., 2022) . 



Footnote 1: Leaf nodes are the bottom nodes of a decision tree, i.e., nodes with zero children.



Child et al. (2019) proposed sparse attention (A) with different sparsity patterns that reduce the attention cost from O(n^2) to O(n^1.5). Later, Zaheer et al. (2020) and Yun et al. (2020) proved the universal approximation power of Transformers with sparse attention, improving the cost to O(ns) with constant sparsity s per query. Alternatively, Wang et al. (2020) considered a low-rank approximation to attention, reducing the computation cost to O(nk) for a rank-k approximation. Choromanski et al. (2021) considered a kernel

