PREDICTIVE ATTENTION TRANSFORMER: IMPROVING TRANSFORMER WITH ATTENTION MAP PREDICTION

Abstract

Transformer is a ubiquitous model for natural language processing and has also attracted wide attention in other domains such as computer vision. The self-attention maps, learned independently for each layer, are indispensable for a transformer model to encode the dependencies among input tokens; however, learning them effectively remains a challenging problem. In this paper, we address this problem and propose a novel approach to improve self-attention through supplementary prediction modules. The underlying assumption is that the attention structures in the current layer should not be completely independent from those in the previous layers and can be better modeled via a convolutional inductive bias. Specifically, we propose Predictive Attention Transformer, which predicts the attention maps through a chain of convolutional neural networks and obtains significant performance gains on a variety of tasks on top of multiple state-of-the-art models. On the GLUE benchmark, the average performances of BERT-Base, BERT-Large, RoBERTa-Large and T5-Base are lifted by 4.1, 2.5, 0.8 and 1.3 points respectively. For ImageNet classification, we achieve significant improvement over multiple backbone models with different capacities.

1. INTRODUCTION

Transformer (Vaswani et al., 2017) is the state-of-the-art architecture for sequence modeling, achieving superior performance in multiple domains, including natural language understanding (Devlin et al., 2019), image generation (Parmar et al., 2018) and time-series forecasting (Li et al., 2019). The performance of a transformer model largely depends on its capability of inducing reasonable dependencies among input tokens. However, as demonstrated by previous work (Jain & Wallace, 2019), it is difficult for a vanilla attention layer to capture these dependencies effectively without any a priori knowledge. To cope with this problem, recent efforts have tried to improve the effectiveness of attention learning, for example by concatenating self-attention with CNN layers to obtain a better representation (Bello et al., 2019; Wu et al., 2020), or by synthesizing the attention maps directly (Tay et al., 2020). In this paper, we consider another question: can we improve the learning of attention maps via a dedicated prediction model? As we will see, this is possible by augmenting the transformer architecture with a chain of convolutional modules for attention map prediction.

For a multi-layer transformer, the self-attention maps in each layer are learned independently, which introduces a large number of parameters and hurts generalization. Our motivation is to bridge the attention maps across layers, so that a succeeding layer can directly take knowledge from previous layers to induce a better dependency structure. To this end, we propose Predictive Attention Transformer (PA-Transformer), which guides the learning of attention maps via a chain of convolution-based prediction modules. In each block, PA-Transformer takes all attention maps generated by the previous block as a multi-channel image. Then, by performing 2D convolution over that image, the attention maps for the current block can be predicted effectively and efficiently.
In this way, the general patterns of inter-token dependencies are shared across all blocks, benefiting the generalization ability of a multi-layer Transformer. Meanwhile, the self-attention layer in each block is guided by the predicted attention patterns and can be trained to capture complementary relationships. As shown by a real case of image classification in Figure 1 (b), the attention map learned in the second PA-Transformer block correctly highlights the structure of a horse with the help of inherited knowledge from previous layers. Specifically, the convolution-based attention prediction module captures key patterns from a local perspective (probably owing to the 1
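The mechanism described above can be sketched in PyTorch. This is a minimal illustration, not the authors' implementation: the class and parameter names (`PredictiveAttention`, `alpha`), the 3x3 kernel size, and the convex-combination mixing of predicted and computed attention are all assumptions made for the sake of the example. The previous block's per-head attention maps are stacked along the channel dimension and convolved to predict the current block's maps.

```python
import torch
import torch.nn as nn


class PredictiveAttention(nn.Module):
    """Hypothetical sketch of a convolution-based attention prediction module.

    The previous block's attention maps, shaped (batch, heads, seq, seq),
    are treated as a multi-channel image (one channel per head) and passed
    through a 2D convolution to predict the current block's attention maps.
    The prediction is then blended with freshly computed self-attention.
    """

    def __init__(self, num_heads: int, alpha: float = 0.5):
        super().__init__()
        # Heads act as conv channels; 3x3 kernel with padding preserves
        # the (seq, seq) spatial shape. Kernel size is an assumption.
        self.conv = nn.Conv2d(num_heads, num_heads, kernel_size=3, padding=1)
        self.alpha = alpha  # mixing weight (assumed fixed here)

    def forward(self, q: torch.Tensor, k: torch.Tensor,
                prev_attn: torch.Tensor) -> torch.Tensor:
        # q, k: (batch, heads, seq, head_dim)
        # prev_attn: (batch, heads, seq, seq), rows sum to 1
        scores = q @ k.transpose(-2, -1) / q.size(-1) ** 0.5
        computed = scores.softmax(dim=-1)           # standard self-attention
        predicted = self.conv(prev_attn).softmax(dim=-1)  # predicted maps
        # Convex combination keeps each row a valid distribution.
        return self.alpha * predicted + (1 - self.alpha) * computed


# Usage sketch with toy shapes.
B, H, S, D = 2, 4, 8, 16
q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)
prev = torch.rand(B, H, S, S).softmax(dim=-1)
attn = PredictiveAttention(num_heads=H)(q, k, prev)
```

Because both the predicted and the computed maps are row-stochastic after the softmax, their convex combination is also row-stochastic, so the output can be used directly as an attention distribution in the current block.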

