MULTI-HEAD ATTENTION: COLLABORATE INSTEAD OF CONCATENATE

Abstract

Attention layers are widely used in natural language processing (NLP) and are beginning to influence computer vision architectures. Training very large transformer models allowed significant improvement in both fields, but once trained, these networks show symptoms of over-parameterization. For instance, it is known that many attention heads can be pruned without impacting accuracy. This work aims to enhance our current understanding of how multiple heads interact. Motivated by the observation that trained attention heads share common key/query projections, we propose a collaborative multi-head attention layer that enables heads to learn shared projections. Our scheme decreases the number of parameters in an attention layer and can be used as a drop-in replacement in any transformer architecture. For instance, by allowing heads to collaborate on a neural machine translation task, we can reduce the key dimension by 4× without any loss in performance. We also show that it is possible to re-parametrize a pre-trained multi-head attention layer into our collaborative attention layer. Even without retraining, collaborative multi-head attention manages to reduce the size of the key and query projections by half without sacrificing accuracy. Our code is public.

1. INTRODUCTION

Since the invention of attention (Bahdanau et al., 2014) and its popularization in the transformer architecture (Vaswani et al., 2017), multi-head attention (MHA) has become the de facto architecture for natural language understanding tasks (Devlin et al., 2019) and neural machine translation. Attention mechanisms have also gained traction in computer vision following the work of Ramachandran et al. (2019) and Bello et al. (2019). Nevertheless, despite their wide adoption, we currently lack solid theoretical understanding of how transformers operate. In fact, many of their modules and hyperparameters are derived from empirical evidence that is possibly circumstantial. The uncertainty is amplified in multi-head attention, where both the roles of and interactions between heads are still poorly understood. Empirically, it is well known that using multiple heads can improve model accuracy. However, not all heads are equally informative, and it has been shown that certain heads can be pruned without impacting model performance. For instance, Voita et al. (2019) present a method to quantify head utility and prune redundant members. Michel et al. (2019) go further and question the utility of multiple heads by testing the effect of heavy pruning in several settings. On the other hand, Cordonnier et al. (2020) prove that multiple heads are needed for self-attention to perform convolution, specifically requiring one head per pixel in the filter's receptive field. Beyond the number of heads, finding the adequate head dimension is also an open question. Bhojanapalli et al. (2020) find that the division of the key/query projection between heads gives rise to a low-rank bottleneck limiting the expressivity of each attention head, which can be fixed by increasing the head sizes. In contrast, our approach increases head expressivity by leveraging the low-rankness across heads to share common query/key dimensions.
This work aims to better detect and quantify head redundancy by asking whether independent heads learn overlapping or distinct concepts. This relates to work on CNN compression that factorizes common filters in a trained convolutional network using Tucker decomposition (Kim et al., 2016). In attention models, we discover that some key/query projected dimensions are redundant, as trained concatenated heads tend to compute their attention patterns on common features. Our finding implies that MHA can be re-parametrized with better weight sharing for these common projections and a lower number of parameters. This differs from concurrent work (Shazeer et al., 2020) that orchestrates collaboration between heads on top of the dot product attention scores.

Contribution 1: Introducing the collaborative multi-head attention layer. Section 3 describes a collaborative attention layer that allows heads to learn shared key and query features. The proposed re-parametrization significantly decreases the number of parameters of the attention layer without sacrificing performance. Our Neural Machine Translation experiments in Section 4 show that the number of FLOPS and parameters needed to compute the attention scores can be divided by 4 without affecting the BLEU score on the WMT14 English-to-German task.

Contribution 2: Re-parametrizing pre-trained models into a collaborative form renders them more efficient. Pre-training large language models has been central to the latest NLP developments. But pre-training transformers from scratch remains daunting for its computational cost, even when using more efficient training tasks such as the one of Clark et al. (2020). Interestingly, our changes to the MHA layers can be applied post-hoc on pre-trained transformers, as a drop-in replacement of classic attention layers. To achieve this, we compute the weights of the re-parametrized layer using canonical tensor decomposition of the query and key matrices in the original layer.
Our experiments in Section 4 show that the key/query dimensions can be divided by 3 without any degradation in performance. As a side contribution, we identify a discrepancy between the theory and some implementations of attention layers, and show that by correctly modeling the biases of the key and query layers, we can clearly differentiate between context-based and content-based attention.

2. MULTI-HEAD ATTENTION

We first review standard multi-head attention introduced by Vaswani et al. (2017) .

2.1. ATTENTION

Let X ∈ R^{T×D_in} and Y ∈ R^{T'×D_in} be two input matrices consisting of respectively T and T' tokens of D_in dimensions each. An attention layer maps each of the T query tokens from D_in to D_out dimensions as follows:

Attention(Q, K, V) = softmax(QK^⊤ / √d_k) V,  with  Q = XW_Q, K = YW_K, V = YW_V    (1)

The layer is parametrized by a query matrix W_Q ∈ R^{D_in×D_k}, a key matrix W_K ∈ R^{D_in×D_k} and a value matrix W_V ∈ R^{D_in×D_out}. Using attention on the same sequence (i.e. X = Y) is known as self-attention and is the basic building block of the transformer architecture.

2.2. CONTENT VS. CONTEXT

Some re-implementations of the original transformer architecture^1 use biases in the linear layers. This differs from the attention operator defined in eq. (1), where the biases b_Q, b_K ∈ R^{D_k} are omitted. The query and key projections are then computed as Q = XW_Q + 1_{T×1} b_Q^⊤ and K = YW_K + 1_{T'×1} b_K^⊤, where 1_{a×b} is an all-ones matrix of dimension a×b. The exact computation of the (unscaled) attention scores can be decomposed as follows:

QK^⊤ = (XW_Q + 1_{T×1} b_Q^⊤)(YW_K + 1_{T'×1} b_K^⊤)^⊤    (2)
     = XW_Q W_K^⊤ Y^⊤ [context] + 1_{T×1} b_Q^⊤ W_K^⊤ Y^⊤ [content] + XW_Q b_K 1_{1×T'} + 1_{T×T'} (b_Q^⊤ b_K)    (3)

As the last two terms of eq. (3) have a constant contribution over all entries of the same row, they do not contribute to the computed attention probabilities (softmax is shift invariant: softmax(x + c) = softmax(x) for any constant c). On the other hand, the first two terms have a clear meaning: XW_Q W_K^⊤ Y^⊤ considers the relation between key and query pairs, whereas 1_{T×1} b_Q^⊤ W_K^⊤ Y^⊤ computes attention solely based on key content. These findings suggest that the bias b_K of the key layer can always be disabled without any consequence.
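The shift-invariance argument can be checked numerically. The following NumPy sketch (random stand-in weights and dimensions, not the paper's code) confirms that the two terms involving the key bias b_K drop out of the softmax:

```python
import numpy as np

def softmax(x):
    # Row-wise softmax, stabilized by subtracting the row max.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
T, D_in, D_k = 5, 8, 4
X = rng.normal(size=(T, D_in))
Y = rng.normal(size=(T, D_in))
W_Q, W_K = rng.normal(size=(D_in, D_k)), rng.normal(size=(D_in, D_k))
b_Q, b_K = rng.normal(size=D_k), rng.normal(size=D_k)

# Attention probabilities computed with biased projections, as in
# implementations that use standard linear layers.
Q = X @ W_Q + b_Q
K = Y @ W_K + b_K
full = softmax(Q @ K.T)

# Keep only the context term and the key-content term: the two dropped
# terms are constant within each row and vanish under the softmax.
context = X @ W_Q @ W_K.T @ Y.T
content = np.tile(b_Q @ W_K.T @ Y.T, (T, 1))
assert np.allclose(full, softmax(context + content))  # b_K plays no role
```

Disabling b_K in a trained model therefore leaves the attention probabilities unchanged, up to floating-point error.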
Moreover, the query biases b Q play an additional role: they allow for attention scores that are content-based, rather than solely depending on key-query interactions. This could provide an explanation for the recent success of the Dense-SYNTHESIZER (Tay et al., 2020) , a method that ignores context and computes attention scores solely as a function of individual tokens. That is, perhaps context is not always crucial for attention scores, and content can suffice.

2.3. MULTI-HEAD ATTENTION

Traditionally, the attention mechanism is replicated by concatenation to obtain multi-head attention, defined for N_h heads as:

MultiHead(X, Y) = concat_{i∈[N_h]} [H^(i)] W_O    (4)
H^(i) = Attention(XW_Q^(i), YW_K^(i), YW_V^(i))    (5)

where distinct parameter matrices W_Q^(i), W_K^(i) ∈ R^{D_in×d_k} and W_V^(i) ∈ R^{D_in×d_out} are learned for each head i ∈ [N_h], and the extra parameter matrix W_O ∈ R^{N_h d_out×D_out} projects the concatenation of the N_h head outputs (each in R^{d_out}) to the output space R^{D_out}. In the multi-head setting, we call d_k the dimension of each head and D_k = N_h d_k the total dimension of the query/key space.
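The definition above can be sketched in a few lines of NumPy. This is a minimal illustration with random weights; the function and variable names are our own, not from any reference implementation:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention(Q, K, V):
    # Scaled dot-product attention of eq. (1).
    return softmax(Q @ K.T / np.sqrt(Q.shape[-1])) @ V

def multi_head(X, Y, W_Q, W_K, W_V, W_O):
    # W_Q, W_K: (N_h, D_in, d_k); W_V: (N_h, D_in, d_out); W_O: (N_h*d_out, D_out).
    heads = [attention(X @ W_Q[i], Y @ W_K[i], Y @ W_V[i])
             for i in range(W_Q.shape[0])]
    return np.concatenate(heads, axis=-1) @ W_O  # concatenate, then project

rng = np.random.default_rng(1)
N_h, T, D_in, d_k, d_out, D_out = 4, 6, 16, 4, 4, 16
X = rng.normal(size=(T, D_in))
W_Q = rng.normal(size=(N_h, D_in, d_k))
W_K = rng.normal(size=(N_h, D_in, d_k))
W_V = rng.normal(size=(N_h, D_in, d_out))
W_O = rng.normal(size=(N_h * d_out, D_out))

out = multi_head(X, X, W_Q, W_K, W_V, W_O)  # X = Y: self-attention
assert out.shape == (T, D_out)
```

Each head computes its attention pattern with entirely separate projections; the sections below ask whether this independence is warranted.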

3. IMPROVING THE MULTI-HEAD MECHANISM

Head concatenation is a simple and remarkably practical setup that gives empirical improvements. However, we show that another path could have been taken instead of concatenation. As the multiple heads are inherently solving similar tasks, they can collaborate instead of being independent.

3.1. HOW MUCH DO HEADS HAVE IN COMMON?

We hypothesize that some heads might attend to similar features in the input space, for example computing high attention on the verb of a sentence or extracting some dimensions of the positional encoding. To verify this hypothesis, it does not suffice to look at the similarity between the query (or key) matrices {W_Q^(i)}_{i∈[N_h]} of different heads. To illustrate the issue, consider the case where two heads compute the same key/query representations up to a unitary matrix R ∈ R^{d_k×d_k}, such that W_Q^(2) = W_Q^(1) R and W_K^(2) = W_K^(1) R. Even though the two heads compute identical attention scores, i.e. W_Q^(1) R R^⊤ W_K^(1)⊤ = W_Q^(1) W_K^(1)⊤, they can have orthogonal column-spaces and the concatenation [W_Q^(1), W_Q^(2)] ∈ R^{D_in×2d_k} can be full rank. To disregard artificial differences due to common rotations or scalings of the key/query spaces, we study the similarity of the products W_Q^(i) W_K^(i)⊤ ∈ R^{D_in×D_in} across heads. Figure 1 shows the energy captured by the principal components of the key and query matrices and of their product. It can be seen on the left that the single-head products W_Q^(i) W_K^(i)⊤ are not low rank on average. However, as seen on the right, even if the parameter matrices taken separately are not low rank, their concatenation is indeed low rank. This means that the heads, though acting independently, learn to focus on the same subspaces. The phenomenon is quite pronounced: one third of the dimensions suffices to capture almost all the energy of W_Q W_K^⊤, which suggests that there is inefficiency in the way multi-head attention currently operates.
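The diagnostic behind Figure 1 can be sketched as follows. The snippet uses random weights as stand-ins (so it will not reproduce the low-rank finding, which only appears in trained models); it shows how the rotation-invariant energy curves are computed:

```python
import numpy as np

def captured_energy(M):
    # Cumulative fraction of squared Frobenius norm captured by the
    # top-r singular directions of M, for r = 1, 2, ...
    s = np.linalg.svd(M, compute_uv=False)
    return np.cumsum(s**2) / np.sum(s**2)

rng = np.random.default_rng(2)
N_h, D_in, d_k = 12, 64, 8  # stand-in sizes; use trained weights in practice
W_Q = rng.normal(size=(N_h, D_in, d_k))
W_K = rng.normal(size=(N_h, D_in, d_k))

# Rotation-invariant quantities: the per-head products W_Q^(i) W_K^(i)^T ...
per_head = [captured_energy(W_Q[i] @ W_K[i].T) for i in range(N_h)]
# ... and the product of the concatenated projections, W_Q W_K^T.
joint = captured_energy(np.concatenate(W_Q, axis=1)
                        @ np.concatenate(W_K, axis=1).T)

# For trained heads the joint curve saturates early (around a third of the
# dimensions); random weights, as here, exhibit no such redundancy.
assert np.isclose(joint[-1], 1.0)
assert all(np.isclose(e[-1], 1.0) for e in per_head)
```

Comparing products rather than raw matrices is what makes the measure immune to the unitary-rotation ambiguity described above.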

3.2. COLLABORATIVE MULTI-HEAD ATTENTION

Following the observation that heads' key/query projections are redundant, we propose to learn the key/query projections for all heads at once and to let each head use a re-weighting of these projections. Our collaborative head attention is defined as follows:

CollabHead(X, Y) = concat_{i∈[N_h]} [H^(i)] W_O    (6)
H^(i) = Attention(X W̃_Q diag(m_i), Y W̃_K, Y W_V^(i))    (7)

The main difference with standard multi-head attention defined in eq. (5) is that we do not duplicate the key and query matrices for each head. Instead, each head learns a mixing vector m_i ∈ R^{D̃_k} that defines a custom dot product over the D̃_k projected dimensions of the shared matrices W̃_Q and W̃_K of dimension D_in × D̃_k. This approach leads to: (i) adaptive head expressiveness, with heads being able to use more or fewer dimensions according to attention pattern complexity; (ii) a parameter-efficient representation, as learned projections are shared between heads, hence stored and learned only once.

It is instructive to observe how standard multi-head attention (where heads are simply concatenated) can be seen as a special case of our collaborative framework (with D̃_k = N_h d_k). The left of Figure 2 displays the standard attention computed between the input vectors x_n and y_m, with the mixing matrix M := concat_{i∈[N_h]} [m_i] ∈ R^{N_h×D̃_k} laying out the mixing vectors m_i as rows. In concatenated MHA, the mixing vector m_i of the i-th head is a vector with ones aligned with the d_k dimensions allocated to the i-th head among the D_k = N_h d_k total dimensions. Some alternative collaborative schemes can be seen on the right side of Figure 2. By learning the mixing vectors {m_i}_{i∈[N_h]} instead of fixing them to this "blocks-of-1" structure, we increase the expressive power of each head for a negligible increase in the number of parameters. The size d_k of each head, arbitrarily set to 64 in most implementations, is now adaptive, and the heads can attend to a smaller or bigger subspace if needed.
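The special-case claim can be verified numerically: with shared projections set to the concatenation of the per-head matrices and a "blocks-of-1" mixing matrix, collaborative attention reproduces standard MHA scores exactly. A NumPy sketch with random stand-in weights (the same scaling constant is used on both sides so the comparison is exact):

```python
import numpy as np

def softmax(x):
    # Row-wise softmax, stabilized by subtracting the row max.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(3)
N_h, T, D_in, d_k = 3, 5, 16, 4
scale = np.sqrt(d_k)
W_Q = rng.normal(size=(N_h, D_in, d_k))
W_K = rng.normal(size=(N_h, D_in, d_k))
X = rng.normal(size=(T, D_in))

# Shared projections: simply the concatenation of the per-head matrices.
W_Q_shared = np.concatenate(W_Q, axis=1)   # (D_in, N_h*d_k)
W_K_shared = np.concatenate(W_K, axis=1)
# "Blocks-of-1" mixing matrix: head i selects its own d_k dimensions.
M = np.kron(np.eye(N_h), np.ones(d_k))     # (N_h, N_h*d_k)

# Collaborative scores: head i re-weights the shared dot product by m_i.
Q, K = X @ W_Q_shared, X @ W_K_shared
collab = np.stack([softmax((Q * m) @ K.T / scale) for m in M])

# Standard concatenated MHA scores, head by head.
standard = np.stack([softmax(X @ W_Q[i] @ W_K[i].T @ X.T / scale)
                     for i in range(N_h)])
assert np.allclose(collab, standard)
```

Replacing the fixed block structure of M by learned entries is exactly what the collaborative layer does.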

3.3. HEAD COLLABORATION AS TENSOR DECOMPOSITION

As we show next, there is a simple way to convert any standard attention layer into collaborative attention without retraining. To this end, we must extract the common dimensions of the query/key products {W_Q^(i) W_K^(i)⊤ ∈ R^{D_in×D_in}}_{i∈[N_h]} across the different heads. This can be done using the Tucker tensor decomposition (Tucker, 1966) of the 3rd-order tensor

W_QK := stack_{i∈[N_h]} (W_Q^(i) W_K^(i)⊤) ∈ R^{N_h×D_in×D_in}.    (8)

Following the notation^2 of Kolda & Bader (2009), the Tucker decomposition of a tensor T ∈ R^{I×J×K} is written

T ≈ G ×_1 A ×_2 B ×_3 C = Σ_{p=1}^{P} Σ_{q=1}^{Q} Σ_{r=1}^{R} g_pqr (a_p ∘ b_q ∘ c_r) =: ⟦G; A, B, C⟧,    (9)

with A ∈ R^{I×P}, B ∈ R^{J×Q} and C ∈ R^{K×R} being factor matrices, whereas G ∈ R^{P×Q×R} is the core tensor. Intuitively, the core entry g_pqr = G_{p,q,r} quantifies the level of interaction between the components a_p, b_q and c_r. In the case of attention, it suffices to consider the dot product of the aligned key/query components of the Q and K matrices, which means that the core tensor is super-diagonal (i.e. g_pqr ≠ 0 only if q = r). We further simplify the Tucker decomposition by setting the factor dimensions P, Q and R to D̃_k, a single interpretable hyperparameter equal to the dimension of the shared key/query space, which controls the amount of compression of the decomposition into collaborative heads. These changes lead to a special case of the Tucker decomposition called the canonical decomposition, also known as CP or PARAFAC (Harshman, 1970) in the literature (Kolda & Bader, 2009). For any fixed positive rank R, the decomposition yields

T ≈ Σ_{r=1}^{R} a_r ∘ b_r ∘ c_r =: ⟦A, B, C⟧,    (10)

with A ∈ R^{I×R}, B ∈ R^{J×R} and C ∈ R^{K×R}. What is remarkable is that the above can be used to express any (trained) attention layer parametrized by {W_Q^(i), b_Q^(i), W_K^(i), b_K^(i)}_{i∈[N_h]} as a collaborative layer.
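The relationship between the CP form and the super-diagonal Tucker form can be made concrete with a small NumPy check (factor matrices and dimensions are arbitrary stand-ins):

```python
import numpy as np

rng = np.random.default_rng(4)
I, J, K, R = 6, 5, 4, 3
A = rng.normal(size=(I, R))
B = rng.normal(size=(J, R))
C = rng.normal(size=(K, R))

# CP/PARAFAC form [[A, B, C]]: a sum of R outer products a_r ∘ b_r ∘ c_r.
T_cp = np.einsum('ir,jr,kr->ijk', A, B, C)

# The same tensor written as a Tucker decomposition with a super-diagonal
# core: g_pqr is non-zero only when p = q = r.
G = np.zeros((R, R, R))
G[np.arange(R), np.arange(R), np.arange(R)] = 1.0
T_tucker = np.einsum('pqr,ip,jq,kr->ijk', G, A, B, C)

assert T_cp.shape == (I, J, K)
assert np.allclose(T_cp, T_tucker)
```

In the attention setting, I = N_h and J = K = D_in, so the three CP factors correspond exactly to the mixing matrix and the shared query/key projections introduced next.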
In particular, applying the decomposition to the stacked heads W_QK yields the three matrices that define a collaborative attention layer: the mixing matrix M ∈ R^{N_h×D̃_k}, as well as the key and query projection matrices W̃_Q, W̃_K ∈ R^{D_in×D̃_k}:

W_QK ≈ ⟦M, W̃_Q, W̃_K⟧.    (11)

Biases, on the other hand, can be easily dealt with based on the content/context decomposition of eq. (3), by storing for each head the vector

v_i = W_K^(i) b_Q^(i) ∈ R^{D_in}.    (12)

With this in place, the computation of the (unscaled) attention scores for the i-th head is given by:

(XW_Q^(i) + 1_{T×1} b_Q^(i)⊤)(YW_K^(i) + 1_{T×1} b_K^(i)⊤)^⊤ ≈ X W̃_Q diag(m_i) W̃_K^⊤ Y^⊤ + 1_{T×1} v_i^⊤ Y^⊤,    (13)

where m_i is the i-th row of M. If D̃_k ≥ D_k, the decomposition is exact (eq. (11) is an equality) and our collaborative heads layer can express any concatenation-based attention layer. We also note that the proposed re-parametrization can be applied to the attention layers of many transformer architectures.

Computational cost. Our layer decomposes two matrices into three, of modulable dimensions. To compute the attention scores between T tokens for all the N_h heads, collaborative MHA requires 2T(D_in + N_h)D̃_k + T²N_h D̃_k FLOPS, while concatenation-based MHA uses 2T D_in D_k + T² D_k FLOPS. Assuming that N_h ≪ D_in (as is the case in most implementations), we obtain a theoretical speedup of Θ(D_k/D̃_k). In practice, however, having two matrix multiplications instead of a single larger one makes our implementation slightly slower when large multiplications are well supported by the hardware.
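The FLOPS counts above can be compared directly. A small sketch following the formulas in the text (the sequence length and dimensions below are illustrative BERT-base-like values, not figures from the paper):

```python
def flops_concat_mha(T, D_in, D_k):
    # Concatenation-based MHA scores: key/query projections (2*T*D_in*D_k)
    # plus the T x T attention dot products over D_k dimensions.
    return 2 * T * D_in * D_k + T**2 * D_k

def flops_collab_mha(T, D_in, N_h, Dk_shared):
    # Collaborative MHA: shared projections and per-head mixing, then
    # per-head dot products over the Dk_shared shared dimensions.
    return 2 * T * (D_in + N_h) * Dk_shared + T**2 * N_h * Dk_shared

# Compressing the key/query space from 768 shared dimensions to 256.
T, D_in, N_h, D_k, Dk_shared = 128, 768, 12, 768, 256
assert flops_collab_mha(T, D_in, N_h, Dk_shared) < flops_concat_mha(T, D_in, D_k)
```

Note that the T² term is multiplied by N_h in the collaborative case, so the projection cost must dominate (short sequences relative to D_in) for the full Θ(D_k/D̃_k) speedup to materialize.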

4. EXPERIMENTS

The goal of our experimental section is two-fold. First, we show that collaborative MHA is a drop-in replacement for concatenation-based MHA in transformer architectures. We obtain a significant reduction in the number of parameters and the number of FLOPS without sacrificing performance on a Neural Machine Translation (NMT) task with an encoder-decoder transformer. Secondly, we verify that our tensor decomposition allows one to re-parametrize pre-trained transformers, such as BERT (Devlin et al., 2019) and its variants. To this end, we show that collaborative MHA performs on par with its concatenation-based counterpart on the GLUE benchmark (Wang et al., 2018) for Natural Language Understanding (NLU) tasks, even without retraining. The NMT experiments are based on the FairSeq (Ott et al., 2019) implementation of the transformer-base model of Vaswani et al. (2017). For the NLU experiments, we implemented the collaborative MHA layer as an extension of the Transformers library (Wolf et al., 2019). The flexibility of our layer allows it to be applied to most existing transformer architectures, either at pre-training or, using tensor decomposition, after fine-tuning. We use the tensor decomposition library Tensorly (Kossaifi et al., 2019) with the PyTorch backend (Paszke et al., 2017) to re-parametrize pre-trained attention layers. Our code and datasets are publicly available^3 and all hyperparameters are specified in the Appendix.

4.1. COLLABORATIVE MHA FOR NEURAL MACHINE TRANSLATION

We replace the concatenation-based MHA layers of an encoder-decoder transformer by our collaborative MHA and evaluate on the WMT14 English-to-German translation task. Following Vaswani et al. (2017), we train on the WMT16 train corpus, apply checkpoint averaging and report compound split tokenized BLEU. We use the same hyperparameters as the baseline for all our runs. Results are shown in Figure 3. Our run of the original base transformer with N_h = 8 heads and D_k = 512 total key/query dimensions achieves 27.40 BLEU (instead of 27.30). As observed in the original paper by Vaswani et al. (2017), decreasing the key/query head size d_k degrades the performance (× in Figure 3). However, with collaborative heads (+ in Figure 3), the shared key/query dimension can be reduced by 4× without decreasing the BLEU score. As feed-forward layers and embeddings are left untouched, this translates to a 10% decrease in the number of parameters for a slight increase in training time. When setting a total key/query dimension of D_k = 64, corresponding to d_k = 8 dimensions per head, the classic MHA model suffers a drop of 0.6 BLEU points, whereas the collaborative MHA stays within 0.1 point of the concatenation-based transformer-base model. We conclude that sharing key/query projections across heads allows attention features to be learned and stored only once. This weight sharing enables decreasing D_k without sacrificing expressiveness.

4.2. RE-PARAMETRIZING PRE-TRAINED TRANSFORMERS INTO COLLABORATIVE MHA

We turn to experiments on Natural Language Understanding (NLU) tasks, where transformers have been decisive in improving the state-of-the-art. As pre-training on large text corpora remains an expensive task, we leverage the post-hoc re-parametrization introduced in Section 3.3 to cast already pre-trained models into their collaborative form. We proceed in 3 steps for each GLUE task (Wang et al., 2018). First, we take a pre-trained transformer and fine-tune it on each task individually. Secondly, we replace all the attention layers by our collaborative MHA, using tensor decomposition to compute W̃_Q, W̃_K and M and re-parametrizing the biases into the vectors v_i. This step only takes a few minutes, as shown in Figure 4. Finally, we fine-tune the compressed model again and evaluate its performance.

We experiment with a pre-trained BERT-base model (Devlin et al., 2019). We also repurpose two variants of BERT designed to be more parameter efficient: ALBERT (Lan et al., 2020), an improved transformer with a single layer unrolled, and DistilBERT (Sanh et al., 2019), a smaller version of BERT trained with distillation. We report in Table 2 the median performance of 3 independent runs of the models on the GLUE benchmark (Wang et al., 2018). We first verify that tensor decomposition without compression (D̃_k = D_k = 768) does not alter performance. As shown in Table 2, BERT-base and its decomposition perform similarly, with average scores of 83.0% and 83.2% respectively. We should clarify that, for consistency, we opted to re-fine-tune the model in all cases (even when D̃_k = D_k), and that the slight score variation disappears without re-fine-tuning. Nevertheless, even with re-fine-tuning, re-parametrizing the attention layers into collaborative form is beneficial in 4 out of the 8 tasks, as well as in terms of the average score. We then experiment with compressed decompositions using a smaller D̃_k. Comparing the original models with their well-performing compressed counterparts (gray rows of Table 2) shows that the key/query dimension of BERT and DistilBERT can be reduced by 2× and 3× respectively without sacrificing more than 1.5% of performance. This is especially remarkable given that DistilBERT was designed to be a parameter-efficient version of BERT.

Table 2: Performance of collaborative MHA on the GLUE benchmark (Wang et al., 2018). We report the median of 3 runs for BERT (Devlin et al., 2019), DistilBERT (Sanh et al., 2019) and ALBERT (Lan et al., 2020) with collaborative heads and different compression controlled by D̃_k. Comparing the original models (D_k = 768) with their compressed counterparts shows that the number of parameters can be decreased with less than 1.5% performance drop (gray rows).
ALBERT appears to suffer more from compression, but its dimension can still be reduced by a factor of 1.5× with minor performance degradation. We suspect that unrolling the same attention layer over the depth of the transformer forces the heads to use different projections and decreases their overlap, reducing the opportunity for weight sharing. Our hypothesis is that better performance may be obtained by pre-training the whole BERT architecture variants from scratch.

Recovering from compression with fine-tuning. We further investigate the necessity of the second fine-tuning (step 3 of our experimental protocol) after model compression. Figure 5 shows the performance of BERT-base on 3 GLUE tasks for different compression parameters D̃_k, with and without the second fine-tuning. We find that for compression up to 1.5× (from D_k = 768 to D̃_k = 512), the re-parametrization is accurate and performance is maintained without fine-tuning again. Further compressing the model starts to affect performance. Nevertheless, for compression up to 3× (down to D̃_k = 256), this loss can readily be recovered by a second fine-tuning (in orange in Figure 5).

5. CONCLUSION

This work showed that trained concatenated heads in multi-head attention models extract redundant query/key representations. To mitigate this issue, we propose to replace concatenation-based MHA by collaborative MHA. When our layer is used as a replacement for standard MHA in encoder-decoder transformers for Neural Machine Translation, it enables decreasing the effective individual head size from d_k = 64 to 8 without impacting performance. Further, without pre-training from scratch, switching an MHA layer to its collaborative form halves the number of FLOPS and parameters needed to compute the attention scores, while affecting the GLUE score by less than 1.5%. Our method can be applied to any transformer architecture, and our publicly available code provides post-hoc compression of already trained networks. We believe that using collaborative MHA in models pre-trained from scratch could force heads to extract meaningful shared query/key features. We are curious whether this would translate to faster pre-training, better performance on downstream tasks and improved interpretability of the attention mechanism.



Footnotes:
1. For instance: the original BERT implementation, its HuggingFace re-implementation and the FairSeq encoder-decoder transformer.
2. ∘ represents the vector outer product.
3. https://github.com/...



Figure 2: Left: computation of the attention scores between tokens x_n and y_m using a standard concatenated multi-head attention with N_h = 3 independent heads. The block structure of the mixing matrix M enforces that each head takes dot products over non-overlapping dimensions. Right: we propose to use more general mixing matrices M than (a) head concatenation, such as (b) allowing heads to have different sizes; (c) sharing head projections by learning the full matrix; (d) compressing the number of projections from D_k to D̃_k, as heads can share redundant projections.

Collaborative heads use (2D_in + N_h)D̃_k parameters, as compared to 2D_in D_k in the standard case (ignoring biases). Hence, the compression ratio is ≈ D_k/D̃_k, controlled by the shared key dimension D̃_k. The collaborative factorization introduces a new matrix M of dimension N_h × D̃_k. Nevertheless, as the number of heads is small compared to the hidden dimension (in BERT-base, N_h = 12 whereas D_in = 768), the extra parameter matrix yields a negligible increase compared to the size of the query/key/value matrices of dimension D_in × D_k.
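The parameter counts above can be made concrete for BERT-base-like sizes (the compressed dimension D̃_k = 256 is one of the settings explored in the experiments; the helper names are our own):

```python
def score_params_concat(D_in, D_k):
    # Standard MHA: separate query and key matrices of shape (D_in, D_k) each.
    return 2 * D_in * D_k

def score_params_collab(D_in, N_h, Dk_shared):
    # Collaborative MHA: shared W~_Q and W~_K of shape (D_in, D~_k),
    # plus the (N_h x D~_k) mixing matrix M.
    return (2 * D_in + N_h) * Dk_shared

# BERT-base sizes, compressing from D_k = 768 to D~_k = 256.
full = score_params_concat(768, 768)        # 1,179,648 parameters
small = score_params_collab(768, 12, 256)   #   396,288 parameters
assert small * 2 < full  # roughly a 3x reduction
```

The mixing matrix accounts for only 12 × 256 = 3,072 of the compressed total, confirming that its overhead is negligible.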

Figure 3: Comparison of the BLEU score on the WMT14 EN-DE translation task for an encoder-decoder transformer-base (Vaswani et al., 2017) using collaborative vs. concatenated heads with key/query dimension D_k. We visualize performance as a function of the number of parameters (middle) and training time (right). Collaborative attention consistently improves the BLEU score, and D_k can be decreased by a factor of 4 without a drop in performance.

A PRE-TRAINED MHA INTO COLLABORATIVE MHA

Figure 4: Time to decompose BERT-base from D k = 768 to Dk .

Figure 5: Performance on the MNLI, MRPC and STS-B datasets of a fine-tuned BERT-base model, decomposed with collaborative heads of compressed dimension D̃_k (horizontal axis). Repeating fine-tuning after compression can make the model recover the original performance when compression was drastic. The GLUE baseline gives a reference for catastrophic failure.

Figure 1: Captured energy of the principal components of the key/query matrices and of their product, for N_h = 12 heads of dimension d_k = 64. Bold lines show the means. Even though, by themselves, heads are not low rank (left), the product of their concatenation W_Q W_K^⊤ is low rank (right, in red).

Supplementary Material

A HYPERPARAMETERS FOR NEURAL MACHINE TRANSLATION EXPERIMENTS

Our implementation is based on the Fairseq implementation (Ott et al., 2019). We report in the following tables the specification of the architecture. We used the default hyperparameters when they are not specified below.

B HYPERPARAMETERS FOR NATURAL LANGUAGE UNDERSTANDING EXPERIMENTS

We use standard models downloadable from the HuggingFace repository along with their configurations. We use the HuggingFace default hyperparameters for GLUE fine-tuning in all our runs. We train with a learning rate of 2·10^-5 for 3 epochs for all datasets except SST-2 and RTE, where we train for 10 epochs. In preliminary experiments, we tried to tune the tensor decomposition tolerance hyperparameter among {10^-6, 10^-7, 10^-8} but did not see significant improvement, and kept the default 10^-6 for all our experiments.

