SPARSE TOKEN TRANSFORMERS WITH ATTENTION BACK-TRACKING

Abstract

Despite the success of Transformers in applications across the text, vision, and speech domains, they have yet to become standard architectures for mobile and edge devices due to their heavy memory and computational requirements. While many approaches exist to reduce the complexity of Transformers, such as pruning of weights/attentions/tokens, quantization, and distillation, we focus on token pruning, which reduces not only the complexity of the attention operations but also that of the linear layers, which have non-negligible computational costs. However, previous token pruning approaches often remove tokens during the feed-forward stage without considering their impact on later layers' attentions, which risks dropping tokens that are important for the given task. To tackle this issue, we propose an attention back-tracking method that traces the importance of each attention in a Transformer architecture from the outputs back to the inputs, in order to preserve the tokens that have a large impact on the final predictions. We experimentally validate the effectiveness of the method on both NLP and CV benchmarks using Transformer architectures for each domain, and the results show that the proposed attention back-tracking allows the model to better retain the full model's performance even at high sparsity rates, significantly outperforming all baselines. Qualitative analysis of the examples further shows that our method preserves semantically meaningful tokens.

1. INTRODUCTION

Transformers have achieved huge success in various application domains such as natural language processing (NLP) and computer vision (CV), obtaining state-of-the-art performance on a variety of tasks, and are now considered the de facto standard architectures for a number of domains. However, Transformers require several teraflops of computation per input (Devlin et al., 2019; Dosovitskiy et al., 2021b), which is orders of magnitude more than previous CNN and RNN architectures (Tan & Le, 2019; Bahdanau et al., 2015). To reduce the computational burden of Transformer models, previous works have explored model compression methods such as distillation (Wang et al., 2020; Jiao et al., 2020), quantization (Han et al., 2015; Frankle & Carbin, 2018), and pruning (Zaheer et al., 2020; Guo et al., 2019; Kim et al., 2022; Rao et al., 2021; Wang et al., 2021). Pruning approaches for Transformers mostly aim to remove unnecessary model weights (Han et al., 2015; Frankle & Carbin, 2018) or attentions (Zaheer et al., 2020), which can achieve a linear reduction in complexity. Token pruning, in contrast, reduces the complexity of the attention and fully connected layers simultaneously by removing tokens that are less relevant to the target task (Kim et al., 2022; Goyal et al., 2020; Kong et al., 2021). How can we decide which tokens to prune, then? Previous studies either compute the importance score of each input token as its average attention score (Goyal et al., 2020; Kim et al., 2022) (Figure 1), or learn the importance score of each token with an additional neural network (NN) at each layer (Rao et al., 2021; Kong et al., 2021). However, previous token pruning methods share the following problem: they prune tokens from the input sequence without explicitly evaluating the importance of each token to the final sequence representation and prediction task. This is because all existing works (Kim et al., 2022; Wang
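To make the attention-score-based importance criterion concrete, the following is a minimal NumPy sketch of the idea used by the baselines discussed above (average attention received by each token as its importance, keeping the top-scoring tokens). All function and variable names here are hypothetical illustrations, not the actual implementation of any of the cited methods; it operates on a single self-attention head for one example, whereas the real methods average over heads and apply this per layer.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def prune_tokens_by_attention(x, w_q, w_k, keep_ratio=0.5):
    """Keep the tokens that receive the highest average attention.

    x:    (seq_len, d_model) token representations for one example
    w_q:  (d_model, d_k) query projection
    w_k:  (d_model, d_k) key projection
    Returns the pruned token matrix and the kept token indices.
    """
    d_k = w_k.shape[1]
    q, k = x @ w_q, x @ w_k
    # Standard scaled dot-product attention weights: (seq, seq)
    attn = softmax(q @ k.T / np.sqrt(d_k), axis=-1)
    # Importance of token j = mean attention it receives over all queries.
    importance = attn.mean(axis=0)
    n_keep = max(1, int(round(len(x) * keep_ratio)))
    # Keep the top-scoring tokens, restoring their original order.
    kept = np.sort(np.argsort(importance)[-n_keep:])
    return x[kept], kept
```

Note that this criterion is purely local to one layer: a token scoring low here may still matter to a later layer's attention, which is exactly the failure mode the proposed attention back-tracking addresses by propagating importance from the outputs back to the inputs.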

