SPARSE TOKEN TRANSFORMERS WITH ATTENTION BACK-TRACKING

Abstract

Despite the success of Transformers in various applications across the text, vision, and speech domains, they are yet to become standard architectures for mobile and edge-device applications due to their heavy memory and computational requirements. While there exist many approaches to reduce the complexity of Transformers, such as pruning of weights, attentions, or tokens, quantization, and distillation, we focus on token pruning, which reduces not only the complexity of the attention operations but also that of the linear layers, which have non-negligible computational costs. However, previous token pruning approaches often remove tokens during the feed-forward stage without considering their impact on later layers' attentions, which risks dropping tokens that are important for the given task. To tackle this issue, we propose an attention back-tracking method that tracks the importance of each attention in a Transformer architecture from the outputs to the inputs, to preserve the tokens that have a large impact on the final predictions. We experimentally validate the effectiveness of the method on both NLP and CV benchmarks, using Transformer architectures for both domains, and the results show that the proposed attention back-tracking allows the model to better retain the full model's performance even at high sparsity rates, significantly outperforming all baselines. Qualitative analysis of the examples further shows that our method does preserve semantically meaningful tokens.

1. INTRODUCTION

Transformers have achieved huge success in various application domains such as natural language processing (NLP) and computer vision (CV), obtaining state-of-the-art performance on a variety of tasks, and are now considered the de facto standard architecture for a number of domains. However, Transformers require several teraFLOPs per input (Devlin et al., 2019; Dosovitskiy et al., 2021b), which is orders of magnitude larger than the computational cost of previous CNN and RNN architectures (Tan & Le, 2019; Bahdanau et al., 2015). To reduce the computational burden of Transformer models, previous works explored model compression methods such as distillation (Wang et al., 2020; Jiao et al., 2020), quantization (Han et al., 2015; Frankle & Carbin, 2018), and pruning (Zaheer et al., 2020; Guo et al., 2019; Kim et al., 2022; Rao et al., 2021; Wang et al., 2021). Pruning approaches for Transformers mostly aim to remove unnecessary model weights (Han et al., 2015; Frankle & Carbin, 2018) or attentions (Zaheer et al., 2020), which can achieve a linear reduction in complexity. Token pruning, in contrast, reduces the complexity of the attention and fully connected layers simultaneously by removing tokens that are less relevant to the target task (Kim et al., 2022; Goyal et al., 2020; Kong et al., 2021). How, then, can we decide which tokens to prune? Previous studies either compute the importance score of each input token as its average attention score (Goyal et al., 2020; Kim et al., 2022) (Figure 1), or learn the importance score of each token with an additional neural network (NN) at each layer (Rao et al., 2021; Kong et al., 2021). However, previous token pruning methods share a common problem: they prune tokens in the input sequence without explicitly evaluating the importance of each token to the final sequence representation and the prediction task.
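To make the feed-forward scoring used by prior work concrete, the following is a minimal NumPy sketch of the average-attention importance score (a token is scored by the average attention it receives, then the top-k tokens are kept). The function names and the toy attention tensor are illustrative assumptions, not taken from any of the cited implementations.

```python
import numpy as np

def feedforward_token_scores(attn):
    """attn: (heads, seq, seq) attention probabilities of one layer.
    Returns the average attention each token position receives,
    i.e., the column mean over heads and query positions."""
    return attn.mean(axis=(0, 1))  # shape: (seq,)

def prune_topk(scores, keep):
    """Keep the `keep` highest-scoring token positions, in order."""
    return np.sort(np.argsort(scores)[-keep:])

# Toy example: 2 heads, 5 tokens, rows softmax-normalized like attention.
rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 5, 5))
attn = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)

kept = prune_topk(feedforward_token_scores(attn), keep=3)
print(kept)  # indices of the 3 tokens this layer would retain
```

Note that this scoring looks only at the current layer's attention, which is exactly the limitation discussed next: a token that scores low here may still be heavily attended to at a later layer.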
This is because all existing works (Kim et al., 2022; Wang et al., 2021; Goyal et al., 2020; Rao et al., 2021; Kong et al., 2021) compute the importance score of an input token at each attention layer while doing the forward pass from the input to the final output (feed-forward token pruning, Figure 1). Thus, they may prune out tokens at earlier layers that are important for the final representation as well as the task loss. For example, in Figure 1, the token Love in layer l + 1 is pruned by the feed-forward method, although the token has a high attention probability in the final-layer (l + 2) representation that is used for the classification task.

To tackle this limitation of conventional feed-forward token pruning methods, we propose an attention back-tracking method that computes the importance score of each input token by its importance to the required output tokens (e.g., the sequence-representation token in the last layer) and to the task performance. As illustrated in Figure 1 (right), we take a backward pass from the last layer's token representations to the input tokens, computing each input token's importance score by accumulating its attention scores along the way. By doing so, we are able to better select and keep the important tokens, so that pruning minimally affects the output and the prediction. We name our novel token pruning method Sparse Token Transformer with Attention Back-Tracking (STTABT).

However, one challenge here is that such backtracking requires us to know the attention score of each token before the input sequence is fed forward through the model. To handle this issue, we introduce ApproxNet, a lightweight attention approximation network trained with knowledge distillation. Moreover, to actually prune the tokens, we need to decide which tokens to retain based on the importance scores. The top-k method, by design, can select a predefined number of tokens (Goyal et al., 2020), but requires sorting the tokens by their importance scores, which is not differentiable.
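The back-tracking idea above can be sketched in a few lines: starting from unit importance on the required output token, importance is propagated backward through each layer's attention matrix, so a token's score reflects how much the final output attends to it, directly or transitively. This is only an illustrative NumPy approximation of the idea, assuming head-averaged attention matrices; the function name and setup are hypothetical, and the actual method backtracks through attentions approximated by ApproxNet rather than exact ones.

```python
import numpy as np

def backtrack_importance(attn_per_layer, output_token=0):
    """attn_per_layer: list of (seq, seq) head-averaged attention matrices,
    ordered from the first layer to the last (rows softmax-normalized).
    output_token: index of the token whose representation must be
    preserved (e.g., the [CLS] / sequence-representation token).
    Returns a per-layer list of input-token importance distributions."""
    seq = attn_per_layer[-1].shape[0]
    imp = np.zeros(seq)
    imp[output_token] = 1.0          # all importance starts at the output token
    layer_importance = []
    for A in reversed(attn_per_layer):
        imp = imp @ A                # row i of A: how token i attends to its inputs
        imp = imp / imp.sum()        # keep the scores a distribution
        layer_importance.append(imp.copy())
    return layer_importance[::-1]    # index 0 = importance at the input layer

# Toy example: 3 layers, 6 tokens.
rng = np.random.default_rng(1)

def softmax_rows(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

attns = [softmax_rows(rng.normal(size=(6, 6))) for _ in range(3)]
imps = backtrack_importance(attns, output_token=0)
print(imps[0].round(3))  # importance of the six raw input tokens
```

A token pruned by this criterion is one that no path of high-probability attentions connects to the output token, which is precisely the property feed-forward pruning cannot check.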
The thresholding method, which prunes tokens whose scores fall under a threshold value (Kim et al., 2022), does not require sorting, but setting the threshold can be tricky since it is difficult to know in advance how many tokens will remain under a given threshold. To remedy the issues in both methods, we propose a learnable and smooth threshold function named Concrete masking, inspired by Concrete Dropout (Gal et al., 2017). Specifically, we jointly train the threshold value for each layer and the importance score, which is computed by attention back-tracking, with the task objective, to find the thresholds that minimize the task loss. Our method is generally applicable to Transformer encoders in any domain, and we validate it on text and image classification tasks. The experimental results show that our model works surprisingly well even at very high pruning rates, at which the baselines show significant performance degradation. For example, on a GLUE benchmark task (QQP), our method achieves 45.54% higher token sparsity than the baseline at the same accuracy (the token keep ratio drops from 17.15% to 9.34%). Moreover, our method attains a very high token sparsity of 18.7%, keeping only half as many tokens as DynamicViT (37.5%), with only a 0.8% accuracy loss on the ImageNet-1K benchmark. Our contributions can be summarized as follows:
• We propose a novel token pruning method for Transformers based on attention back-tracking, which considers the importance of each token to the final representations as well as the task loss, based on attentions approximated by a distilled model.
• We propose Concrete masking, which automatically learns the pruning rate for each layer with Concrete dropout (Gal et al., 2017).
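The Concrete masking idea can be illustrated with a minimal sketch: replacing the hard comparison "score > threshold" with a temperature-controlled sigmoid makes the keep-mask differentiable with respect to both the scores and the per-layer threshold, so the threshold can be trained with the task loss. This is an assumed sigmoid relaxation in the spirit of Concrete Dropout, not the paper's exact formulation; the function name and constants are illustrative.

```python
import numpy as np

def concrete_mask(scores, threshold, temperature=0.05):
    """Smooth keep-mask: sigmoid((score - threshold) / temperature).
    Differentiable w.r.t. both `scores` and `threshold`; approaches a
    hard 0/1 threshold as temperature -> 0, so the expected number of
    kept tokens adapts as the threshold is trained."""
    return 1.0 / (1.0 + np.exp(-(scores - threshold) / temperature))

# Toy importance scores for four tokens and a learnable threshold value.
scores = np.array([0.02, 0.10, 0.45, 0.80])
mask = concrete_mask(scores, threshold=0.30)
print(mask.round(3))  # near-0 for low-score tokens, near-1 for high-score ones
```

Unlike top-k, no sort is needed, and unlike a fixed hard threshold, the gradient through the mask tells the model how moving the threshold would change the task loss.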



Figure 1: Concepts. Comparison of token pruning between the feed-forward approach (left) and attention back-tracking (right). The thickness of each arrow represents the attention probability.

For text classification, we validate the proposed token pruning method on the GLUE (Wang et al., 2019) benchmark with BERT (Devlin et al., 2019). For image classification, we validate it on the ImageNet-1K (Deng et al., 2009) benchmark with DeiT (Touvron et al., 2021).

