MULTI-HEAD ATTENTION: COLLABORATE INSTEAD OF CONCATENATE

Abstract

Attention layers are widely used in natural language processing (NLP) and are beginning to influence computer vision architectures. Training very large transformer models allowed significant improvements in both fields, but once trained, these networks show symptoms of over-parameterization. For instance, it is known that many attention heads can be pruned without impacting accuracy. This work aims to enhance current understanding of how multiple heads interact. Motivated by the observation that trained attention heads share common key/query projections, we propose a collaborative multi-head attention layer that enables heads to learn shared projections. Our scheme decreases the number of parameters in an attention layer and can be used as a drop-in replacement in any transformer architecture. For instance, by allowing heads to collaborate on a neural machine translation task, we can reduce the key dimension by 4× without any loss in performance. We also show that it is possible to re-parametrize a pre-trained multi-head attention layer into our collaborative attention layer. Even without retraining, collaborative multi-head attention manages to reduce the size of the key and query projections by half without sacrificing accuracy. Our code is public.

1. INTRODUCTION

Since the invention of attention (Bahdanau et al., 2014) and its popularization in the transformer architecture (Vaswani et al., 2017), multi-head attention (MHA) has become the de facto architecture for natural language understanding tasks (Devlin et al., 2019) and neural machine translation. Attention mechanisms have also gained traction in computer vision following the work of Ramachandran et al. (2019) and Bello et al. (2019). Nevertheless, despite their wide adoption, we currently lack a solid theoretical understanding of how transformers operate. In fact, many of their modules and hyperparameters are derived from empirical evidence that is possibly circumstantial. The uncertainty is amplified in multi-head attention, where both the roles of and interactions between heads are still poorly understood. Empirically, it is well known that using multiple heads can improve model accuracy. However, not all heads are equally informative, and it has been shown that certain heads can be pruned without impacting model performance. For instance, Voita et al. (2019) present a method to quantify head utility and prune redundant members. Michel et al. (2019) go further in questioning the utility of multiple heads by testing the effect of heavy pruning in several settings. On the other hand, Cordonnier et al. (2020) prove that multiple heads are needed for self-attention to perform convolution, specifically requiring one head per pixel in the filter's receptive field. Beyond the number of heads, finding the adequate head dimension is also an open question. Bhojanapalli et al. (2020) find that the division of the key/query projection between heads gives rise to a low-rank bottleneck that limits the expressivity of each attention head and can be fixed by increasing the head sizes. In contrast, our approach increases head expressivity by leveraging the low-rankness across heads to share common query/key dimensions.
This work aims to better detect and quantify head redundancy by asking whether independent heads learn overlapping or distinct concepts. This relates to the work on CNN compression that factorizes common filters in a trained convolutional network (Kim et al., 2016) using Tucker decomposition. In attention models, we discover that some key/query projected dimensions are redundant, as trained concatenated heads tend to compute their attention patterns on common features. Our finding implies that MHA can be re-parametrized with better weight sharing for these common projections and fewer parameters. This differs from concurrent work (Shazeer et al., 2020), which orchestrates collaboration between heads on top of the dot product attention scores.

Contribution 1: Introducing the collaborative multi-head attention layer. Section 3 describes a collaborative attention layer that allows heads to learn shared key and query features. The proposed re-parametrization significantly decreases the number of parameters of the attention layer without sacrificing performance. Our Neural Machine Translation experiments in Section 4 show that the number of FLOPs and parameters needed to compute the attention scores can be divided by 4 without affecting the BLEU score on the WMT14 English-to-German task.

Contribution 2: Re-parametrizing pre-trained models into a collaborative form renders them more efficient. Pre-training large language models has been central to the latest NLP developments. But pre-training transformers from scratch remains daunting given its computational cost, even when using more efficient training tasks such as that of Clark et al. (2020). Interestingly, our changes to the MHA layers can be applied post-hoc on pre-trained transformers, as a drop-in replacement of classic attention layers. To achieve this, we compute the weights of the re-parametrized layer using canonical tensor decomposition of the query and key matrices in the original layer.
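To make the redundancy intuition concrete, the following NumPy sketch (all sizes and variable names are illustrative, not taken from the paper, and this is not the paper's actual re-parametrization procedure) constructs heads whose query projections secretly draw on a common low-dimensional subspace, then shows that the concatenated projection has a correspondingly low effective rank, which is exactly the structure that would allow a shared projection:

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical sizes: model width, per-head width, number of heads,
# and the width of the hidden shared subspace (H * d_head > d_shared).
D_in, d_head, H, d_shared = 32, 8, 4, 6

# Heads that secretly draw their query features from a common
# d_shared-dimensional subspace, mimicking the redundancy observed
# in trained multi-head attention.
P = rng.normal(size=(D_in, d_shared))                  # shared feature directions
W_Q_heads = [P @ rng.normal(size=(d_shared, d_head)) for _ in range(H)]

# Concatenate the per-head query projections: D_in x (H * d_head).
W_Q_concat = np.concatenate(W_Q_heads, axis=1)

# The singular value spectrum reveals the shared structure: only
# d_shared non-negligible singular values, so the concatenated
# projection could be re-parametrized with a shared projection of
# width d_shared instead of H * d_head.
s = np.linalg.svd(W_Q_concat, compute_uv=False)
effective_rank = int((s > 1e-10 * s[0]).sum())
print(effective_rank)  # 6
```

In a trained transformer the shared subspace is of course not planted by construction; the observation motivating this work is that the measured spectrum of concatenated key/query projections nonetheless decays quickly.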
Our experiments in Section 4 show that the key/query dimensions can be divided by 3 without any degradation in performance. As a side contribution, we identify a discrepancy between the theory and some implementations of attention layers and show that by correctly modeling the biases of key and query layers, we can clearly differentiate between context and content-based attention.

2. MULTI-HEAD ATTENTION

We first review standard multi-head attention introduced by Vaswani et al. (2017).

2.1. ATTENTION

Let X ∈ ℝ^{T×D_in} and Y ∈ ℝ^{T'×D_in} be two input matrices consisting of T and T' tokens of D_in dimensions each, respectively. An attention layer maps each of the T query tokens from D_in to D_out dimensions as follows:

Attention(Q, K, V) = softmax(QK⊤ / √D_k) V,  with  Q = XW_Q, K = YW_K, V = YW_V     (1)

The layer is parametrized by a query matrix W_Q ∈ ℝ^{D_in×D_k}, a key matrix W_K ∈ ℝ^{D_in×D_k} and a value matrix W_V ∈ ℝ^{D_in×D_out}. Using attention on the same sequence (i.e. X = Y) is known as self-attention and is the basic building block of the transformer architecture.

2.2. CONTENT VS. CONTEXT

Some re-implementations of the original transformer architecture¹ use biases in the linear layers. This differs from the attention operator defined in eq. (1), where the biases b_Q, b_K ∈ ℝ^{D_k} are omitted. With biases, the query and key projections are computed as Q = XW_Q + 1_{T×1} b_Q⊤ and K = YW_K + 1_{T'×1} b_K⊤, respectively, where 1_{a×b} is an all-one matrix of dimension a × b. The exact computation of the (unscaled) attention scores can be decomposed as follows:

QK⊤ = (XW_Q + 1_{T×1} b_Q⊤)(YW_K + 1_{T'×1} b_K⊤)⊤     (2)
    = XW_Q W_K⊤ Y⊤ [context] + 1_{T×1} b_Q⊤ W_K⊤ Y⊤ [content] + XW_Q b_K 1_{1×T'} + 1_{T×T'} b_Q⊤ b_K     (3)

As the last two terms of eq. (3) have a constant contribution over all entries of the same row, they do not affect the computed attention probabilities (the softmax is shift invariant: softmax(x + c) = softmax(x) for any constant c). On the other hand, the first two terms have a clear meaning: XW_Q W_K⊤ Y⊤
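To make the decomposition concrete, the following NumPy sketch (sizes and variable names are illustrative, not from the paper) verifies eq. (3) term by term and checks that the two row-constant bias terms indeed cancel under the softmax:

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable row-wise softmax.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
T, Tp, D_in, D_k = 5, 7, 16, 8   # hypothetical sequence lengths and dims

X = rng.normal(size=(T, D_in))    # query-side tokens
Y = rng.normal(size=(Tp, D_in))   # key-side tokens
W_Q = rng.normal(size=(D_in, D_k))
W_K = rng.normal(size=(D_in, D_k))
b_Q = rng.normal(size=D_k)
b_K = rng.normal(size=D_k)

# Projections with biases, as in the re-implementations discussed above.
Q = X @ W_Q + b_Q                 # broadcasting adds b_Q to every row
K = Y @ W_K + b_K
full_scores = Q @ K.T             # unscaled attention scores, T x T'

# The four terms of the decomposition in eq. (3).
context = X @ W_Q @ W_K.T @ Y.T                      # token-token interaction
content = np.outer(np.ones(T), b_Q) @ W_K.T @ Y.T    # attends to key content only
row_const1 = np.outer(X @ W_Q @ b_K, np.ones(Tp))    # constant along each row
row_const2 = (b_Q @ b_K) * np.ones((T, Tp))          # globally constant

assert np.allclose(full_scores, context + content + row_const1 + row_const2)

# The last two terms are constant within each row, so they cancel in the
# softmax: dropping them leaves the attention probabilities unchanged.
probs_full = softmax(full_scores / np.sqrt(D_k))
probs_reduced = softmax((context + content) / np.sqrt(D_k))
assert np.allclose(probs_full, probs_reduced)
```

This also illustrates why b_K can be discarded entirely at inference time: only the context and content terms influence the attention probabilities.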



¹ https://github.com/... For instance: the original BERT implementation, its HuggingFace re-implementation, and the FairSeq encoder-decoder transformer.

