TREEFORMER: DENSE GRADIENT TREES FOR EFFICIENT ATTENTION COMPUTATION

Abstract

Standard inference and training with transformer-based architectures scale quadratically with input sequence length. This is prohibitively expensive for a variety of applications, such as web-page translation and query answering. Consequently, several approaches have been developed recently to speed up attention computation by enforcing different attention structures, such as sparsity (Zaheer et al., 2020), low rank (Wang et al., 2020), or kernel-based approximations of attention (Choromanski et al., 2021). In this work, we view attention computation as a nearest-neighbor retrieval problem and use decision-tree-based hierarchical navigation to reduce the retrieval cost per query token from linear in the sequence length to nearly logarithmic. Based on such hierarchical navigation, we design TREEFORMER, which can use one of two efficient attention layers: TF-ATTENTION and TC-ATTENTION. TF-ATTENTION computes attention in a fine-grained style, while TC-ATTENTION is a coarse attention layer that also ensures the gradients are "dense". To optimize such challenging discrete layers, we propose a two-level bootstrapped training method. Using extensive experiments on standard NLP benchmarks, especially for long sequences, we demonstrate that our TREEFORMER architecture can be almost as accurate as the baseline Transformer while using 30x fewer FLOPs in the attention layer. Compared to Linformer, its accuracy can be as much as 12% higher while using similar FLOPs in the attention layer.

1. INTRODUCTION

The self-attention layer is the key component of Transformers (Vaswani et al., 2017), enabling them to achieve state-of-the-art performance across tasks in Natural Language Processing (Devlin et al., 2019; Radford et al., 2019; Raffel et al., 2019) and Computer Vision (Dosovitskiy et al., 2021). Attention computation scales quadratically (n^2) with the input sequence length (n), making it a key bottleneck in scaling Transformers to long inputs. This has led to many alternative proposals to efficiently compute attention using different approximations. However, as shown in Tay et al. (2021), methods that offer dramatic speedups (Wang et al., 2020; Choromanski et al., 2021) suffer in accuracy, while methods that match the accuracy of the Transformer (Zaheer et al., 2020) do not offer significant speedups.

In this paper we develop a novel efficient attention mechanism that exploits input-dependent sparsity structure in attention computation, motivated by analyses showing the sparse nature of attention (Shi et al., 2021; Gupta et al., 2021). We propose to use decision trees to efficiently compute attention by retrieving only the top nearest-neighbor keys for a given query. Decision trees are low-cost, hierarchical non-linear models that have been popular in both supervised (Chen & Guestrin, 2016; Quinlan, 2014) and unsupervised (Duda et al., 1973) learning. Given a vector, each node of the decision tree decides whether to navigate to the left or the right child depending on the sign of a simple linear projection. We learn decision trees in each attention layer of a Transformer model so that a query pays attention only to keys that map to the same leaf nodes of the decision trees (leaf nodes being the bottom nodes of a tree, with no children). We refer to this as TF-ATTENTION.

Such sparse decision trees that retrieve only a single leaf node are typically hard to train due to their discrete nature. Further, queries that pay strong attention to multiple keys cause many keys to be bunched together in a single leaf node, creating imbalance. The key computational advantage of decision trees is realized only when they are balanced.
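To make the routing concrete, the following is a minimal sketch, in NumPy, of decision-tree attention as described above: each internal node holds a hyperplane (w, b), a vector descends left or right based on the sign of w·x + b, and each query attends (with softmax) only over keys that land in its own leaf. The function names (route_to_leaf, tree_attention) and the heap-ordered parameter layout are illustrative assumptions, not the paper's implementation.

```python
import numpy as np

def route_to_leaf(x, W, b):
    """Route vector x (d,) through a balanced depth-h tree.

    W: (num_internal, d) hyperplane normals, b: (num_internal,) offsets,
    stored in heap order (node 0 is the root; children of i are 2i+1, 2i+2).
    Returns the index of the leaf that x falls into (0 .. 2^h - 1).
    """
    num_internal = W.shape[0]          # a depth-h tree has 2^h - 1 internal nodes
    node = 0
    while node < num_internal:
        go_right = (W[node] @ x + b[node]) > 0   # sign of a linear projection
        node = 2 * node + (2 if go_right else 1)
    return node - num_internal         # convert heap index to leaf id

def tree_attention(Q, K, V, W, b):
    """Each query attends (softmax) only over keys sharing its leaf."""
    n, d = Q.shape
    q_leaf = np.array([route_to_leaf(q, W, b) for q in Q])
    k_leaf = np.array([route_to_leaf(k, W, b) for k in K])
    out = np.zeros_like(V)
    for i in range(n):
        idx = np.nonzero(k_leaf == q_leaf[i])[0]
        if idx.size == 0:              # no keys in this query's leaf
            continue
        scores = K[idx] @ Q[i] / np.sqrt(d)
        weights = np.exp(scores - scores.max())   # numerically stable softmax
        weights /= weights.sum()
        out[i] = weights @ V[idx]
    return out
```

In this sketch each query scores only the keys in its leaf, so with a balanced tree of depth h the per-query cost drops from O(n) to roughly O(h + n / 2^h): h inner products for navigation plus attention over an average of n / 2^h keys, which is the near-logarithmic retrieval cost referred to above.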

