APPROXIMATING HOW SINGLE HEAD ATTENTION LEARNS

Abstract

Why do models often attend to salient words, and how does this behavior evolve throughout training? We approximate model training as a two-stage process: early on in training, when the attention weights are uniform, the model learns to translate an individual input word i to o if they co-occur frequently. Later, the model learns to attend to i when the correct output is o, because it knows that i translates to o. To formalize this, we define a model property, Knowledge to Translate Individual Words (KTIW) (e.g., knowing that i translates to o), and claim that it drives the learning of the attention. This claim is supported by the fact that before the attention mechanism is learned, KTIW can be learned from word co-occurrence statistics, but not the other way around. In particular, we can construct a training distribution that makes KTIW hard to learn; there, the learning of the attention fails, and the model cannot even learn the simple task of copying the input words to the output. Our approximation explains why models sometimes attend to salient words, and inspires a toy example where a multi-head attention model can overcome the above hard training distribution by improving learning dynamics rather than expressiveness. We end by discussing the limitations of our approximation framework and suggest future directions.

1. INTRODUCTION

The attention mechanism underlies many recent advances in natural language processing, such as machine translation Bahdanau et al. (2015) and pretraining Devlin et al. (2019). While many works analyze attention in already-trained models Jain & Wallace (2019); Vashishth et al. (2019); Brunner et al. (2019); Elhage et al. (2021); Olsson et al. (2022), little is understood about how the attention mechanism is learned via gradient descent at training time. These learning dynamics are important, as standard, gradient-trained models can have distinctive inductive biases that set them apart from more esoteric but equally accurate models. For example, in text classification, while standard models typically attend to salient (high gradient influence) words Serrano & Smith (2019), recent work constructs accurate models that attend to irrelevant words instead Wiegreffe & Pinter (2019); Pruthi et al. (2020). In machine translation, while standard gradient descent cannot train a high-accuracy transformer with relatively few attention heads, we can construct one by first training with more heads and then pruning the redundant heads Voita et al. (2019); Michel et al. (2019). To explain these differences, we need to understand how attention is learned at training time.

Our work opens the black box of attention training, focusing on attention in LSTM Seq2Seq models Luong et al. (2015) (Section 2.1). Intuitively, if the model knows that an individual input word i translates to the correct output word o, it should attend to i to minimize the loss. This motivates us to investigate the model's knowledge to translate individual words (abbreviated as KTIW), and we define a lexical probe β to measure this property. We claim that KTIW drives the attention mechanism to be learned. This is supported by the fact that KTIW can be learned when the attention mechanism has not been learned (Section 3.2), but not the other way around (Section 3.3).
Specifically, even when the attention weights are frozen to be uniform, the probe β still strongly agrees with the attention weights of a standardly trained model. On the other hand, when KTIW cannot be learned, the attention mechanism cannot be learned either. In particular, we can construct a distribution where KTIW is hard to learn; as a result, the model fails to learn even the simple task of copying the input to the output.

The problem of understanding how the attention mechanism is learned thus reduces to understanding how KTIW is learned. Section 2.3 builds a simpler proxy model that approximates how KTIW is learned, and Section 3.2 verifies empirically that the approximation is reasonable. This proxy model is simple enough to analyze, and we interpret its training dynamics through the classical IBM Translation Model 1 (Section 4.2), which translates an individual word i to o if they co-occur more frequently. Collapsing this chain of reasoning, we approximate model training in two stages: early on in training, when the attention mechanism has not been learned, the model learns KTIW through word co-occurrence statistics; KTIW then drives the learning of the attention.

Using these insights, we explain why attention weights sometimes correlate with word saliency in binary text classification (Section 5.1): the model first learns to "translate" salient words into labels, and then attends to them. We also present a toy experiment (Section 5.2) where multi-head attention improves learning dynamics by combining differently initialized attention heads, even though a single-head model can express the target function.

Nevertheless, "all models are wrong". Even though our framework successfully explains and predicts the above empirical phenomena, it cannot fully explain the behavior of attention-based models, since it is, after all, an approximation.
Section 6 identifies and discusses two key assumptions: (1) information of a word tends to stay in the local hidden state (Section 6.1) and (2) attention weights are free variables (Section 6.2). We discuss future directions in Section 7.

2. MODEL

Section 2.1 defines the LSTM with attention Seq2Seq architecture. Section 2.2 defines the lexical probe β, which measures the model's knowledge to translate individual words (KTIW). Section 2.3 approximates how KTIW is learned early on in training by building a "bag of words" proxy model. Section 2.4 shows that our framework generalizes to binary classification.

[Figure 1, on the example pair "Dieser Film ist großartig" → "This movie is great". Left panel annotations: (1) the model first learns word translation under uniform attention when training starts; (2) "Film" is more likely to translate to "movie"; (3) the attention α is then attracted to the word "Film". Right panel ("Classical Alignment Learning Procedure") annotations: (1) we count all co-occurrences of the input and output words; (4) "Film" is more likely to translate to "movie"; (5) the alignment $\alpha_{2,l}$, i.e., how much each input word contributes toward the 2nd output word "movie", is attracted to "Film": $\alpha_{2,1} = \beta_{2,1} / \sum_{l=1}^{4} \beta_{2,l}$.]

Figure 1 : Attention mechanism in recurrent models (left, Section 2.1) and word alignments in the classical model (right, Section 4.2) are learned similarly. Both first learn how to translate individual words (KTIW) under uniform attention weights/alignment at the start of training (upper, blue background), which then drives the attention mechanism/alignment to be learned (lower, red background).

2.1. MACHINE TRANSLATION MODEL

We use the dot-attention variant from Luong et al. (2015). The model maps an input sequence $\{x_l\}$ of length $L$ to an output sequence $\{y_t\}$ of length $T$. We first use LSTM encoders to embed $\{x_l\} \subset I$ and $\{y_t\} \subset O$ respectively, where $I$ and $O$ are the input and output vocab spaces, and obtain encoder and decoder hidden states $\{h_l\}$ and $\{s_t\}$. Then we compute the attention logits $a_{t,l}$ by applying a learnable bilinear map to $h_l$ and $s_t$, and use softmax to obtain the attention weights $\alpha_{t,l}$:

$$a_{t,l} = s_t^\top W h_l; \quad \alpha_{t,l} = \frac{e^{a_{t,l}}}{\sum_{l'=1}^{L} e^{a_{t,l'}}}. \quad (1)$$

Next we sum the encoder hidden states $\{h_l\}$ weighted by the attention to obtain the "context vector" $c_t$, concatenate it with the decoder state $s_t$, and obtain the output vocab probabilities $p_t$ by applying a learnable neural network $N$ with one hidden layer and a softmax activation at the output. We train the model by minimizing the sum of the negative log-likelihoods of all output words $y_t$:

$$c_t = \sum_{l=1}^{L} \alpha_{t,l} h_l; \quad p_t = N([c_t, s_t]); \quad \mathcal{L} = -\sum_{t=1}^{T} \log p_{t, y_t}. \quad (2)$$
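A minimal numpy sketch of the forward computation in Equations (1) and (2) (the output network $N$ is omitted; the dimensions and random values are toy assumptions, not the paper's settings):

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())   # subtract max for numerical stability
    return e / e.sum()

rng = np.random.default_rng(0)
d, L = 8, 5                   # toy hidden size and input length
h = rng.normal(size=(L, d))   # encoder hidden states h_1 .. h_L
s_t = rng.normal(size=d)      # decoder hidden state at output step t
W = rng.normal(size=(d, d))   # learnable bilinear attention parameter

a_t = h @ W.T @ s_t           # logits a_{t,l} = s_t^T W h_l, shape (L,)
alpha_t = softmax(a_t)        # attention weights alpha_{t,l}, Eq. (1)
c_t = alpha_t @ h             # context vector c_t = sum_l alpha_{t,l} h_l
```

In the full model, $[c_t, s_t]$ would then be fed through $N$ to produce $p_t$.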

2.2. LEXICAL PROBE β

We define the lexical probe $\beta_{t,l}$ as $\beta_{t,l} := N([h_l, s_t])_{y_t}$, i.e., "the probability assigned to the correct word $y_t$ if the network attends only to the input encoder state $h_l$". If we assume that $h_l$ only contains information about $x_l$, then $\beta$ closely reflects KTIW, since $\beta_{t,l}$ can be interpreted as "the probability that $x_l$ is translated to the output $y_t$". Heuristically, to minimize the loss, the attention weights $\alpha$ should be attracted to positions with larger $\beta_{t,l}$. Hence, we expect the learning of the attention to be driven by KTIW (Figure 1 left). We next discuss how KTIW is learned.

2.3. EARLY DYNAMICS OF LEXICAL KNOWLEDGE

To approximate how KTIW is learned early on in training, we build a proxy model by making a few simplifying assumptions. First, since attention weights are uniform early on in training, we replace the attention distribution with a uniform one. Second, since we are defining individual word translation, we assume that information about each word is localized to its corresponding hidden state. Therefore, similar to Sun & Lu (2020), we replace $h_l$ with an input word embedding $e_{x_l} \in \mathbb{R}^d$, where $e$ is the word embedding matrix and $d$ is the embedding dimension. Third, to simplify analysis, we assume $N$ contains only one linear layer $W \in \mathbb{R}^{|O| \times d}$ before the softmax activation, and we ignore the decoder state $s_t$. Putting these assumptions together, we define a new proxy model that produces the output vocab probability

$$p_t := \sigma\Big(\frac{1}{L} \sum_{l=1}^{L} W e_{x_l}\Big),$$

where $\sigma$ is the softmax. On a high level, this proxy averages the embeddings of the input "bag of words", and produces a distribution over output vocabs to predict the output "bag of words". This implies that the sets of input and output words of each sentence pair are sufficient statistics for this proxy. The probe can be similarly defined as $\beta^{px}_{t,l} := \sigma(W e_{x_l})_{y_t}$. We provide more intuition on how this proxy learns in Section 4.
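A minimal sketch of this proxy model with assumed toy dimensions. Note that its prediction depends only on the bag of input words, not on their order:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def proxy_forward(W, e, x):
    """p_t = softmax((1/L) sum_l W e_{x_l}): the same for every output step t."""
    return softmax(W @ e[x].mean(axis=0))

def probe(W, e, x_l, y_t):
    """beta^px_{t,l} = softmax(W e_{x_l})_{y_t}."""
    return softmax(W @ e[x_l])[y_t]

rng = np.random.default_rng(0)
n_in, n_out, d = 10, 7, 16                     # toy vocab sizes, embedding dim
e = rng.normal(size=(n_in, d)) / np.sqrt(d)    # input word embeddings
W = rng.normal(size=(n_out, d)) / np.sqrt(d)   # output projection

x = np.array([3, 1, 4, 1])
# Word order does not matter: the bag of words is a sufficient statistic.
assert np.allclose(proxy_forward(W, e, x), proxy_forward(W, e, x[::-1]))
```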

2.4. BINARY CLASSIFICATION MODEL

Binary classification can be reduced to "machine translation" with $T = 1$ and $|O| = 2$; we drop the subscript $t = 1$ when discussing classification. We use the standard architecture from Wiegreffe & Pinter (2019). After obtaining the encoder hidden states $\{h_l\}$, we calculate the attention logits $a_l$ by applying a feed-forward neural network with one hidden layer, and take the softmax of $a$ to obtain the attention weights $\alpha$:

$$a_l = v^\top \mathrm{ReLU}(Q h_l); \quad \alpha_l = \frac{e^{a_l}}{\sum_{l'=1}^{L} e^{a_{l'}}},$$

where $Q$ and $v$ are learnable. We sum the hidden states $\{h_l\}$ weighted by the attention, feed the result to a final linear layer, and apply the sigmoid activation function ($\sigma$) to obtain the probability of the positive class:

$$p_{pos} = \sigma\Big(W^\top \sum_{l=1}^{L} \alpha_l h_l\Big) = \sigma\Big(\sum_{l=1}^{L} \alpha_l W^\top h_l\Big).$$

Similar to the machine translation model (Section 2.1), we define the "lexical probe" $\beta_l := \sigma((2y - 1) W^\top h_l)$, where $y \in \{0, 1\}$ is the label and $2y - 1 \in \{-1, 1\}$ controls the sign. On a high level, Sun & Lu (2020) focus on binary classification and provide almost exactly the same arguments as ours. Specifically, their polarity score "$s_l$" equals $\frac{\beta_l}{1 - \beta_l}$ in our context, and they provide a more subtle analysis of how the attention mechanism is learned in binary classification.
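A numpy sketch of this classification forward pass and probe (toy dimensions; parameter names follow the equations above, and the random values are assumptions for illustration):

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
d, L = 8, 6
h = rng.normal(size=(L, d))            # encoder hidden states
Q = rng.normal(size=(d, d))            # attention MLP, hidden layer
v = rng.normal(size=d)                 # attention MLP, output layer
W = rng.normal(size=d)                 # final linear layer

a = np.maximum(h @ Q.T, 0.0) @ v       # a_l = v^T ReLU(Q h_l)
alpha = softmax(a)                     # attention weights
p_pos = sigmoid(alpha @ (h @ W))       # sigma(sum_l alpha_l W^T h_l)

y = 1                                  # suppose the label is positive
beta = sigmoid((2 * y - 1) * (h @ W))  # lexical probe beta_l, one per position
```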

3. EMPIRICAL EVIDENCE

We provide evidence that KTIW drives the learning of the attention early on in training: KTIW can be learned when the attention mechanism has not been learned (Section 3.2), but not the other way around (Section 3.3).

3.1. MEASURING AGREEMENT

We start by describing how we evaluate the agreement between quantities of interest, such as $\alpha$ and $\beta$. For any input-output sentence pair $(x^m, y^m)$ and each output index $t$, each of $\alpha^m_t, \beta^m_t, \beta^{px,m}_t \in \mathbb{R}^{L^m}$ associates each input position $l$ with a real number. Since attention weights and word alignments tend to be sparse, we focus on the agreement of the highest-valued position. For $u, v \in \mathbb{R}^L$, we formally define the agreement of $v$ with $u$ as

$$A(u, v) := \mathbb{1}\big[\,|\{j \mid v_j > v_{\arg\max_i u_i}\}| < 5\%\,L\,\big],$$

which asks "whether the highest-valued position in $u$ is among the top 5% highest-valued positions in $v$". We average the $A$ values across all output words on the validation set to measure the agreement between two model properties, and report Kendall's $\tau$ rank correlation coefficient in Appendix Table 2 for completeness. We denote the random baseline as $\hat{A}$; $\hat{A}$ is close to, but not exactly, 5% because of integer rounding.

Contextualized Agreement Metric. Since different datasets have different sentence-length distributions and different variance of attention weights across random seeds, the raw agreement metric can be hard to interpret directly. We therefore contextualize it with model performance. We use the standard method to train a model to convergence in $T$ steps and denote its attention weights as $\alpha$; next we train the same model from scratch with another random seed, denoting its attention weights at training step $\tau$ as $\alpha(\tau)$ and its performance as $p(\tau)$. Roughly speaking, for $\tau < T$, both $A(\alpha, \alpha(\tau))$ and $p(\tau)$ increase as $\tau$ increases. We define the contextualized agreement $\xi$ as

$$\xi(u, v) := p\big(\inf\{\tau \mid A(\alpha, \alpha(\tau)) > A(u, v)\}\big).$$

In other words, we find the first training step $\tau_0$ at which the attention weights $\alpha(\tau_0)$ agree with the standard attention weights $\alpha$ more than $u$ agrees with $v$, and report the performance at that step. We refer to the model performance when training finishes ($\tau = T$) as $\xi^*$.

Datasets.
We evaluate the agreement metrics A and ξ on multiple machine translation and text classification datasets. For machine translation, we use Multi-30k (En-De), IWSLT'14 (De-En), and News Commentary v14 (En-Nl, En-Pt, and It-Pt). For text classification, we use IMDB Sentiment Analysis, AG News Corpus, 20 Newsgroups (20 NG), Stanford Sentiment Treebank, Amazon review, and Yelp Open Data Set. All of them are in English. The details and citations of these datasets can be found in Appendix A.5. We use token accuracy to evaluate the performance of translation models and accuracy to evaluate the classification models. Due to space limits we round to integers and include a subset of datasets in Table 1 in the main paper; Appendix Table 4 includes the full results.

[Table 1 caption: tasks above the horizontal line are classification, below are translation. The (contextualized) agreement metric A (ξ) is described in Section 3.1. Across all tasks, A(α, β), A(α, β^{uf}), and A(β^{uf}, β^{px}) significantly outperform the random baseline Â, and the corresponding contextualized interpretations ξ are also non-trivial. This implies that 1) the proxy model from Section 2.3 approximates well how KTIW is learned, 2) the attention weights α and the probe β of KTIW strongly agree, and 3) KTIW can still be learned when the attention weights are uniform. Columns: A(α, β^{uf}), A(β^{uf}, β^{px}), A(∆, β^{uf}), A(α, β), Â, ξ(α, β^{uf}).]
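The agreement metric A from Section 3.1 can be sketched as follows (our reading of the definition; `frac=0.05` encodes the 5% threshold):

```python
import numpy as np

def agreement(u, v, frac=0.05):
    """A(u, v): 1 iff the highest-valued position of u is among the top
    `frac` highest-valued positions of v, i.e. fewer than frac * L entries
    of v are strictly larger than v at argmax(u)."""
    u, v = np.asarray(u), np.asarray(v)
    j_star = int(np.argmax(u))
    n_above = int((v > v[j_star]).sum())
    return int(n_above < frac * len(u))

# Both vectors peak at position 1 -> they agree.
assert agreement([0.1, 0.8, 0.1], [0.2, 0.7, 0.1]) == 1
# u peaks where v is smallest -> they disagree.
assert agreement([0.1, 0.8, 0.1], [0.5, 0.0, 0.5]) == 0
```

Averaging this indicator over all output words of a dataset yields the reported A values.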

3.2. KTIW LEARNS UNDER UNIFORM ATTENTION

Even when the attention mechanism has not been learned, KTIW can still be learned. We train the same model architecture with the attention weights frozen to be uniform, and denote its lexical probe as β^{uf}. Across all tasks, A(α, β^{uf}) and A(β^{uf}, β^{px}) significantly outperform the random baseline Â, and the contextualized agreement ξ(α, β^{uf}) is also non-trivial. This indicates that 1) the proxy we built in Section 2.3 approximates KTIW well, and 2) even when the attention weights are uniform, KTIW is still learned.

3.3. ATTENTION FAILS WHEN KTIW FAILS

We consider a simple task of copying the input to the output, where each input is a permutation of the same set of 40 vocab types. Under this training distribution, the proxy model provably cannot learn: every input-output pair contains the exact same set of input-output words. As a result, our framework predicts that KTIW is unlikely to be learned, and hence the learning of the attention is likely to fail. The training curves of learning to copy the permutations are in Figure 2 left, colored in red: the model sometimes fails to learn. For the control experiment, if we instead randomly sample and permute 40 vocabs from 60 vocab types for each training sample, the model successfully learns (blue curve) from this distribution every time. Therefore, even if the model is able to express this task, it might fail to learn it when KTIW is not learned. The same qualitative conclusion holds for the training distribution that mixes permutations of two disjoint sets of words (Figure 2 middle); Appendix A.3 illustrates the intuition. For binary classification, it follows from the model definition that the attention mechanism cannot be learned if KTIW cannot be learned, since

$$p_{correct} = \sigma\Big(\sum_{l=1}^{L} \alpha_l \, \sigma^{-1}(\beta_l)\Big); \quad \sigma(x) = \frac{1}{1 + e^{-x}},$$

and the model needs to attend to positions with higher β in order to predict correctly and minimize the loss. For completeness, we include results in Appendix A.6 where we freeze β and find that the learning of the attention fails.
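To see why the proxy has no signal on this distribution, we can compute the count table C (defined in Section 4) over a corpus of permutations of one fixed word set: every entry comes out identical. A minimal sketch with assumed toy sizes:

```python
import random
from collections import Counter

random.seed(0)
vocab = list(range(40))

def count_table(pairs):
    """C[i, o]: add 1/L for each co-occurrence of input word i and output
    word o within one sentence pair (Section 4)."""
    C = Counter()
    for x, y in pairs:
        for i in x:
            for o in y:
                C[(i, o)] += 1.0 / len(x)
    return C

# Copy task where every input is a permutation of the same 40-word set:
pairs = [(p, p) for p in (random.sample(vocab, 40) for _ in range(100))]
C = count_table(pairs)

# All 40 * 40 entries are equal, so KTIW has no way to prefer i -> i:
assert len(C) == 1600
assert len({round(c, 9) for c in C.values()}) == 1
```

Sampling the 40 words from a larger 60-word vocabulary instead (the control experiment) breaks this symmetry, and the diagonal entries of C dominate.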

4. CONNECTION TO IBM MODEL 1

Section 2.3 built a simple proxy model to approximate how KTIW is learned when the attention weights are uniform early on in training, and Section 3.2 verified that such an approximation is empirically sound. However, it is still hard to intuitively reason about how this proxy model learns. This section provides more intuitions by connecting its initial gradient (Section 4.1) to the classical IBM Model 1 alignment algorithm Brown et al. (1993) (Section 4.2).

4.1. DERIVATIVE AT INITIALIZATION

We continue from the end of Section 2.3. For each input word $i$ and output word $o$, we are interested in the probability that $i$ assigns to $o$, defined as

$$\theta^{px}_{i,o} := \sigma(W e_i)_o.$$

This quantity is directly tied to $\beta^{px}$, since $\beta^{px}_{t,l} = \theta^{px}_{x_l, y_t}$. Using the superscript $m$ to index sentence pairs in the dataset, the total loss $\mathcal{L}$ is

$$\mathcal{L} = -\sum_m \sum_{t=1}^{T^m} \log \sigma\Big(\frac{1}{L^m} \sum_{l=1}^{L^m} W e_{x^m_l}\Big)_{y^m_t}.$$

Suppose each $e_i$ and $W_o$ is independently initialized from a normal distribution $\mathcal{N}(0, I_d/d)$ and we minimize $\mathcal{L}$ over $W$ and $e$ using gradient flow; then the values of $e$ and $W$ are uniquely defined at each continuous time step $\tau$. By some straightforward but tedious calculations (details in Appendix A.2), the derivative of $\theta^{px}_{i,o}$ when training starts is

$$\lim_{d \to \infty} \frac{\partial \theta^{px}_{i,o}}{\partial \tau}(\tau = 0) \xrightarrow{p} 2\Big(C^{px}_{i,o} - \frac{1}{|O|} \sum_{o' \in O} C^{px}_{i,o'}\Big),$$

where $\xrightarrow{p}$ denotes convergence in probability and $C^{px}_{i,o}$ is defined as

$$C^{px}_{i,o} := \sum_m \sum_{l=1}^{L^m} \sum_{t=1}^{T^m} \frac{1}{L^m}\,\mathbb{1}[x^m_l = i]\,\mathbb{1}[y^m_t = o].$$

The derivative above tells us that $\beta^{px}_{t,l} = \theta^{px}_{x_l, y_t}$ is likely to be larger if $C_{x_l, y_t}$ is large. The definition of $C$ may seem hard to interpret on its own, but in the next subsection we find that this quantity corresponds exactly to the "count table" used in the classical IBM Model 1 alignment learning algorithm.

4.2. IBM MODEL 1 ALIGNMENT LEARNING

The classical alignment algorithm aims to learn which input word is responsible for each output word (e.g., knowing that $y_2$ "movie" aligns to $x_2$ "Film" in Figure 1 upper left) from a set of input-output sentence pairs. IBM Model 1 Brown et al. (1993) starts with a 2-dimensional count table $C^{IBM}$ indexed by $i \in I$ and $o \in O$, the input and output vocabs. Whenever vocabs $i$ and $o$ co-occur in an input-output pair with input length $L$, we add $\frac{1}{L}$ to the $C^{IBM}_{i,o}$ entry (steps 1 and 2 in Figure 1 right). After updating $C^{IBM}$ over the entire dataset, $C^{IBM}$ is exactly the same as $C^{px}$ defined in Section 4.1. We drop the superscript of $C$ to keep the notation uncluttered. Given $C$, the classical model estimates a probability distribution of "what output word $o$ does the input word $i$ translate to" (Figure 1 right, step 3):

$$\mathrm{Trans}(o \mid i) = \frac{C_{i,o}}{\sum_{o'} C_{i,o'}}.$$

In a pair of sequences $(\{x_l\}, \{y_t\})$, the probability $\beta^{IBM}$ that $x_l$ is translated to the output $y_t$ is $\beta^{IBM}_{t,l} := \mathrm{Trans}(y_t \mid x_l)$, and the alignment probability $\alpha^{IBM}$ that "$x_l$, rather than some other $x_{l'}$, is responsible for outputting $y_t$" is

$$\alpha^{IBM}_{t,l} = \frac{\beta^{IBM}_{t,l}}{\sum_{l'=1}^{L} \beta^{IBM}_{t,l'}},$$

which increases monotonically with $\beta^{IBM}_{t,l}$ (Figure 1 right, step 5).
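A minimal sketch of these statistics on a tiny invented corpus built around the Figure 1 example (a single counting pass; the full IBM Model 1 iterates with EM):

```python
import numpy as np

# Toy parallel corpus (an assumption for illustration).
pairs = [
    ("Dieser Film ist großartig".split(), "This movie is great".split()),
    ("Dieser Hund ist großartig".split(), "This dog is great".split()),
]
src = sorted({w for x, _ in pairs for w in x})
tgt = sorted({w for _, y in pairs for w in y})

C = np.zeros((len(src), len(tgt)))         # the count table
for x, y in pairs:
    for i in x:
        for o in y:
            C[src.index(i), tgt.index(o)] += 1.0 / len(x)

trans = C / C.sum(axis=1, keepdims=True)   # Trans(o | i)
# "Film" only ever co-occurs with "movie", while "Dieser" also co-occurs
# with "dog", so Trans(movie | Film) > Trans(movie | Dieser):
assert (trans[src.index("Film"), tgt.index("movie")]
        > trans[src.index("Dieser"), tgt.index("movie")])

# Alignment for the output word "movie" in the first pair:
x, _ = pairs[0]
beta = np.array([trans[src.index(i), tgt.index("movie")] for i in x])
alpha = beta / beta.sum()                  # alpha^IBM_{t,l}
assert x[int(np.argmax(alpha))] == "Film"  # the alignment is attracted to "Film"
```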

4.3. VISUALIZING AFOREMENTIONED TASKS

Figure 1 (right) visualizes the count table C for the machine translation task, and illustrates how KTIW is learned and drives the learning of attention. We provide similar visualization for why KTIW is hard to learn under a distribution of vocab permutations (Section 3.3) in Figure 3 , and how word polarity is learned in binary classification (Section 2.4) in Figure 4 . 

5.1. INTERPRETABILITY IN CLASSIFICATION

We use a gradient-based method Ebrahimi et al. (2018) to approximate the influence ∆_l of each input word x_l. The column A(∆, β^{uf}) in Table 1 reports the agreement between ∆ and β^{uf}, which significantly outperforms the random baseline. Since KTIW initially drives the attention mechanism to be learned, this explains why attention weights correlate with word saliency on many classification tasks, even though the training objective does not explicitly reward this.

5.2. MULTI-HEAD IMPROVES TRAINING DYNAMICS

We saw in Section 3.3 that learning to copy sequences under a distribution of permutations is hard, and the model can fail to learn; however, it sometimes still succeeds. Can we improve learning and overcome this hard distribution by ensembling several attention parameters together? We introduce a multi-head attention architecture that sums the context vectors $c_t$ obtained by each head. With $K$ heads indexed by $k$, similar to Section 2.1:

$$a^{(k)}_{t,l} = s_t^\top W^{(k)} h_l; \quad \alpha^{(k)}_{t,l} = \frac{e^{a^{(k)}_{t,l}}}{\sum_{l'=1}^{L} e^{a^{(k)}_{t,l'}}},$$

and the context vectors and final probability $p_t$ are defined as

$$c^{(k)}_t = \sum_{l=1}^{L} \alpha^{(k)}_{t,l} h_l; \quad p_t = N\Big(\Big[\sum_{k=1}^{K} c^{(k)}_t, s_t\Big]\Big),$$

where the $W^{(k)}$ are different learnable parameters. We call an initialization $W^{(k)}_{init}$ good if training with that single head converges, and bad otherwise. We use rejection sampling to find good/bad head initializations and combine them to form 8-head ($K = 8$) attention models. We experiment with 3 scenarios: (1) all head initializations are bad, (2) only one initialization is good, and (3) initializations are sampled independently at random. Figure 2 right presents the training curves. If all head initializations are bad, the model fails to converge (red). However, as long as one of the eight initializations is good, the model converges (blue). If all initializations are sampled independently, the probability that all of them are bad decreases exponentially with the number of heads; hence the model converges with very high probability (green). In this experiment, multi-head attention helps not by increasing expressiveness, since one head is sufficient to accomplish the task, but by improving the learning dynamics.
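A sketch of the multi-head combination (toy dimensions; the output network $N$ is omitted): each head has its own $W^{(k)}$, and the per-head context vectors are summed, which is why a single well-initialized head can carry the combined model.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

rng = np.random.default_rng(0)
d, L, K = 8, 5, 8                      # toy hidden size, length, head count
h = rng.normal(size=(L, d))            # encoder hidden states
s_t = rng.normal(size=d)               # decoder state at step t
Ws = [rng.normal(size=(d, d)) for _ in range(K)]  # one W^(k) per head

# Each head attends independently; the K context vectors are summed.
heads = [softmax(h @ W.T @ s_t) @ h for W in Ws]
c_t = np.sum(heads, axis=0)
```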

6. ASSUMPTIONS

We revisit the approximation assumptions used in our framework. Section 6.1 discusses whether the lexical probe β_{t,l} necessarily reflects local information about the input word x_l, and Section 6.2 discusses whether the attention weights can be freely optimized to attend to large β. These assumptions are accurate enough to predict the phenomena in Sections 3 and 5, but they are not always true and hence warrant further research. We provide simple examples where these assumptions can fail.

6.1. β REMAINS LOCAL

We use a toy classification task to show that early on in training, as expected, $\beta^{uf}$ is larger near positions that contain the keyword; however, counterintuitively, $\beta^{uf}_L$ ($\beta$ at the last position of the sequence) becomes the largest if we train the model for too long under uniform attention weights. In this toy task, each input is a length-40 sequence of words sampled uniformly at random from $\{1, \ldots, 40\}$; a sequence is positive if and only if the keyword "1" appears in it. We restrict "1" to appear only once in each positive sequence, and use rejection sampling to balance positive and negative examples. Let $l^*$ be the position where $x_{l^*} = 1$. For the positive sequences, we examine the log-odds ratio $\gamma_l$ before the sigmoid activation (Section 2.4), since the $\beta$ values will all be close to 1 and comparing $\gamma$ is more informative: $\gamma_l := \log \frac{\beta^{uf}_l}{1 - \beta^{uf}_l}$. We measure four quantities: 1) $\gamma_{l^*}$, the log-odds ratio if the model only attends to the keyword position; 2) $\gamma_{l^*+1}$, one position after the keyword; 3) $\bar{\gamma} := \frac{1}{L}\sum_{l=1}^{L} \gamma_l$, if the attention weights are uniform; and 4) $\gamma_L$, if the model attends to the last hidden state. If $\gamma_l$ only contains information about the word $x_l$, we should expect

$$\text{Hypothesis 1:} \quad \gamma_{l^*} \gg \bar{\gamma} \gg \gamma_L \approx \gamma_{l^*+1}.$$

However, if we accept the conventional wisdom that hidden states contain information about nearby words Khandelwal et al. (2018), we should expect

$$\text{Hypothesis 2:} \quad \gamma_{l^*} \gg \gamma_{l^*+1} \gg \bar{\gamma} \approx \gamma_L.$$

To test these hypotheses, we plot how $\gamma_{l^*}$, $\gamma_{l^*+1}$, $\bar{\gamma}$, and $\gamma_L$ evolve as training proceeds in Figure 5. Hypothesis 2 is indeed true when training starts; however, we find the following to hold asymptotically:

$$\text{Observation 3:} \quad \gamma_L \gg \gamma_{l^*+1} \gg \bar{\gamma} \approx \gamma_{l^*},$$

which differs wildly from Hypothesis 2. If we train under uniform attention weights for too long, information about the keyword can flow freely to other, non-local hidden states.

[Figure 5: when training starts, Hypothesis 2 is true; asymptotically, Observation 3 is true.]

6.2. ATTENTION WEIGHTS ARE FREE VARIABLES

In Section 2.1 we assumed that the attention weights $\alpha$ behave like free variables that can assign arbitrarily high probability to positions with larger $\beta$. However, $\alpha$ is produced by a model, and sometimes learning the correct $\alpha$ is challenging. Let $\pi$ be a random permutation of the integers from 1 to 40, and suppose we want to learn the function $f$ that permutes the input with $\pi$:

$$f([x_1, x_2, \ldots, x_{40}]) := [x_{\pi(1)}, x_{\pi(2)}, \ldots, x_{\pi(40)}]. \quad (22)$$

Inputs $x$ are randomly sampled from a vocab of size 60, as in Section 3.3. Even though $\beta^{uf}$ behaves exactly the same for the copying and permutation tasks, sequence copying is much easier to learn than the permutation function: while the model always reaches perfect accuracy on the former within 300 iterations, it always fails on the latter. This suggests that the LSTM attention model has a built-in inductive bias toward monotonic attention.

7. CONCLUSIONS

Our work opens the black box of attention training. Early on in training, LSTM attention models first learn how to translate individual words from bag-of-words co-occurrence statistics, which then drives the learning of the attention. Our framework explains why attention weights obtained by standard training often correlate with saliency, and how multi-head attention can increase performance by improving the training dynamics rather than expressiveness. These phenomena could not be explained if we treated the training process as a black box.

8. ETHICAL CONSIDERATIONS

We present a new framework for understanding and predicting the behavior of an existing technology: the attention mechanism in recurrent neural networks. We do not propose any new technologies or datasets that could directly raise ethical questions. However, it is useful to keep in mind that our framework is far from solving the question of neural network interpretability, and should not be treated as ground truth in high-stakes domains like medicine or recidivism. We are explicit about the limitations of our framework, which we discuss in Section 6.

9. REPRODUCIBILITY STATEMENT

To promote reproducibility, we provide extensive results in the appendix and describe all experiments in detail. We also attach source code for reproducing all experiments in the supplementary material of this submission.

A APPENDICES

A.1 HEURISTIC THAT α ATTENDS TO LARGER β

That the attention α is attracted to larger β is a heuristic rather than a rigorous theorem, for two reasons. First, there is a non-linear layer after averaging the hidden states, which can interact in arbitrarily complex ways and break this heuristic. Second, even if there were no non-linear operations after the hidden-state aggregation, the optimal attention that minimizes the loss does not necessarily assign any probability to the position with the largest β value when there are more than two output vocabs. Specifically, consider the following model:

$$p_t = \sigma\Big(W_c \sum_{l=1}^{L} \alpha_{t,l} h_l + W_s s_t\Big) = \sigma\Big(\sum_{l=1}^{L} \alpha_{t,l} \gamma_l + \gamma_s\Big),$$

where $W_c$ and $W_s$ are learnable weights, and $\gamma$ is defined as

$$\gamma_l := W_c h_l; \quad \gamma_s := W_s s_t \;\Rightarrow\; \beta_{t,l} = \sigma(\gamma_l + \gamma_s)_{y_t}.$$

Consider the following scenario with a probability distribution $p$ over 3 output vocabs and $\gamma_s$ set to 0:

$$p = \sigma(\alpha_1 \gamma_1 + \alpha_2 \gamma_2 + \alpha_3 \gamma_3),$$

where $\gamma_{1,2,3} \in \mathbb{R}^{|O|=3}$ are the logits, $\alpha$ is a valid attention probability distribution, $\sigma$ is the softmax, and $p$ is the probability distribution produced by this model. Suppose $\gamma_1 = [0, 0, 0]$, $\gamma_2 = [0, -10, 5]$, $\gamma_3 = [0, 5, -10]$, and the correct output is the first output vocab (i.e., the first dimension). Taking the softmax of each $\gamma_l$ and reading off the first dimension,

$$\beta_{l=1} = \tfrac{1}{3} > \beta_{l=2} = \beta_{l=3} \approx e^{-5}.$$

We compute the "optimal α" $\alpha^{opt}$: the attention weights that maximize the correct output word probability $p_0$ and minimize the loss. We find that $\alpha^{opt}_2 = \alpha^{opt}_3 = 0.5$, while $\alpha^{opt}_1 = 0$. In this example, the optimal attention assigns 0 weight to the position $l$ with the highest $\beta_l$.

Fortunately, such pathological examples rarely occur in real datasets, and the optimal α is usually attracted to positions with higher β. We verify this empirically for the following variant of the machine translation model on Multi30K. As before, we obtain the context vector $c_t$.
Instead of concatenating $c_t$ and $d_t$ and passing them into a non-linear neural network $N$, we add them and apply a linear layer followed by a softmax to obtain the output word probability distribution

$$p_t = \sigma(W(c_t + d_t)).$$

This model is desirable because we can now provably find the optimal α using gradient descent (we delay the proof to the end of this subsection). Additionally, this model has performance comparable to the variant in our main paper (Section 2.1), achieving a 38.2 BLEU score vs. 37.9 for the main-paper model. Using $\alpha^{opt}$ to denote the attention that minimizes the loss, we find that $A(\alpha^{opt}, \beta) = 0.53$: β indeed strongly agrees with $\alpha^{opt}$.

We are left to show that gradient descent finds the optimal attention weights. We can rewrite $p_t$ as

$$p_t = \sigma\Big(\sum_{l=1}^{L} \alpha_l W h_l + W d_t\Big),$$

and define

$$\gamma_l := W h_l; \quad \gamma_s := W d_t.$$

Without loss of generality, suppose the first dimensions of $\gamma_{1 \ldots L}$ and $\gamma_s$ are all 0, and the correct token whose probability we want to maximize is the first dimension. Then the loss for the output word is $\mathcal{L} = \log(1 + g(\alpha))$, where

$$g(\alpha) := \sum_{o \in O, o \neq 0} e^{\alpha^\top \gamma'_o + \gamma_{s,o}}, \quad \gamma'_o = [\gamma_{1,o}, \ldots, \gamma_{l,o}, \ldots, \gamma_{L,o}] \in \mathbb{R}^L.$$

Since α is defined on the convex probability simplex and $g(\alpha)$ is convex in α, the global optimum $\alpha^{opt}$ can be found by gradient descent.
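The three-output-vocab counterexample above can be checked numerically (a small sketch; the candidate $\alpha = (0, 0.5, 0.5)$ is the stated optimum):

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

# Logits from the example; the correct output is the first dimension.
g1 = np.array([0.0, 0.0, 0.0])
g2 = np.array([0.0, -10.0, 5.0])
g3 = np.array([0.0, 5.0, -10.0])

beta = [softmax(g)[0] for g in (g1, g2, g3)]
assert np.argmax(beta) == 0              # beta is largest at position 1

p_attend_beta = softmax(g1)[0]           # attend only to the largest-beta position
p_opt = softmax(0.5 * g2 + 0.5 * g3)[0]  # alpha_opt = (0, 0.5, 0.5)
# The optimal attention ignores the largest-beta position entirely:
assert p_opt > p_attend_beta
```

Here attending to the largest-β position yields $p_0 = 1/3$, while $\alpha^{opt}$ averages the opposing logits of positions 2 and 3 so that they cancel, giving a much higher $p_0$.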

A.2 CALCULATING ∂θi,o ∂τ

We drop the $px$ superscript of θ to keep the notation uncluttered. We restate the loss function:

$$\mathcal{L} = -\sum_m \sum_{t=1}^{T^m} \log \sigma\Big(\frac{1}{L^m} \sum_{l=1}^{L^m} W e_{x^m_l}\Big)_{y^m_t},$$

and since we optimize $W$ and $e$ with gradient flow,

$$\frac{\partial W}{\partial \tau} := -\frac{\partial \mathcal{L}}{\partial W}; \quad \frac{\partial e}{\partial \tau} := -\frac{\partial \mathcal{L}}{\partial e}.$$

Let $\theta := We$ denote the unnormalized logits, and define

$$p^m := \sigma\Big(\frac{1}{L^m} \sum_{l=1}^{L^m} W e_{x^m_l}\Big), \quad \mathcal{L}^m_t := -\log p^m_{y^m_t}, \quad \epsilon^m_{t,i,o} := -W_o^\top \frac{\partial \mathcal{L}^m_t}{\partial e_i},$$

so that $\mathcal{L} = \sum_m \sum_{t=1}^{T^m} \mathcal{L}^m_t$ and $\epsilon_{i,o} = \sum_m \sum_{t=1}^{T^m} \epsilon^m_{t,i,o}$. The gradient with respect to $e_i$ is

$$-\frac{\partial \mathcal{L}^m_t}{\partial e_i} = \frac{1}{L^m} \sum_{l=1}^{L^m} \mathbb{1}[x^m_l = i]\Big(W_{y^m_t} - \sum_{o=1}^{|O|} p^m_o W_o\Big).$$

Hence,

$$\epsilon^m_{t,i,y^m_t} = \frac{1}{L^m} \sum_{l=1}^{L^m} \mathbb{1}[x^m_l = i]\Big(\|W_{y^m_t}\|_2^2 - \sum_{o=1}^{|O|} p^m_o W_{y^m_t}^\top W_o\Big),$$

while for $o' \neq y^m_t$,

$$\epsilon^m_{t,i,o'} = \frac{1}{L^m} \sum_{l=1}^{L^m} \mathbb{1}[x^m_l = i]\Big(W_{o'}^\top W_{y^m_t} - \sum_{o=1}^{|O|} p^m_o W_{o'}^\top W_o\Big).$$

If the $W_o$ and $e_i$ are each sampled i.i.d. from $\mathcal{N}(0, I_d/d)$, then by the central limit theorem,

$$\forall o \neq o': \sqrt{d}\,W_o^\top W_{o'} \xrightarrow{d} \mathcal{N}(0, 1); \quad \forall o, i: \sqrt{d}\,W_o^\top e_i \xrightarrow{d} \mathcal{N}(0, 1); \quad \sqrt{d}\,(\|W_o\|_2^2 - 1) \xrightarrow{d} \mathcal{N}(0, 2).$$

Therefore, at $\tau = 0$,

$$\lim_{d \to \infty} \epsilon^m_{t,i,o} \xrightarrow{p} \frac{1}{L^m} \sum_{l=1}^{L^m} \mathbb{1}[x^m_l = i]\Big(\mathbb{1}[y^m_t = o] - \frac{1}{|O|}\Big).$$

Summing over all the $\epsilon^m_{t,i,o}$ terms, we have

$$\epsilon_{i,o} = C_{i,o} - \frac{1}{|O|} \sum_{o'} C_{i,o'}, \quad \text{where } C_{i,o} := \sum_m \sum_{l=1}^{L^m} \sum_{t=1}^{T^m} \frac{1}{L^m}\,\mathbb{1}[x^m_l = i]\,\mathbb{1}[y^m_t = o].$$

We find that $-\frac{\partial W_o}{\partial \tau}^\top e_i$ converges to exactly the same value. Hence

$$\frac{\partial \theta_{i,o}}{\partial \tau} = \frac{\partial (W e)_{i,o}}{\partial \tau} = 2\Big(C_{i,o} - \frac{1}{|O|} \sum_{o'} C_{i,o'}\Big).$$

Since $\lim_{d \to \infty} \sigma(\theta)(\tau = 0) \xrightarrow{p} \frac{1}{|O|} \mathbf{1}_{|I| \times |O|}$, applying the chain rule through the softmax, which is at its uniform point, yields the result for $\theta^{px}_{i,o} = \sigma(W e_i)_o$ stated in Section 4.1.

A.3 MIXTURE OF PERMUTATIONS

For this experiment, each input is either a random permutation of the set {1 . . . 40} or a random permutation of the set {41 . . . 80}. The proxy model can easily learn whether the input words all come from the first set, and hence whether the output words should all come from the first set. However, β^{px} is still the same for every position; as a result, the attention, and hence the model, fails to learn. The count table C can be seen in Figure 6.
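The effect is easy to reproduce in miniature. Below we scale the construction down (permutations of {0..3} or of {4..7} on a copy task, instead of {1..40} and {41..80}; the sizes are assumptions for illustration): the count table C is exactly block-uniform, so the probe value for every input word is uniform (0.25) over its block.

```python
import numpy as np

rng = np.random.default_rng(0)
V, half, n_examples = 8, 4, 500  # scaled-down vocab: {0..3} and {4..7}

C = np.zeros((V, V))
for _ in range(n_examples):
    block = rng.integers(2)                   # pick one of the two disjoint word sets
    x = rng.permutation(half) + half * block  # a random permutation of that set
    y = x                                     # copy task: output equals input
    for i in x:
        for o in y:
            C[i, o] += 1.0 / len(x)

# Every (i, o) pair inside a block receives exactly 1/4 per example, so within a
# block the rows of C are uniform: the per-word translation probe is 0.25 everywhere.
beta = C / C.sum(axis=1, keepdims=True)
```

Cross-block counts are exactly zero and within-block counts are exactly equal, so no input position is preferred over any other.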

A.4 ADDITIONAL TABLES FOR COMPLETENESS

We report several variants of Table 1. We chose token accuracy to contextualize the agreement metric in the main paper, because errors would accumulate much more if we used a not-fully-trained model to auto-regressively generate output words.

• Table 2 contains the same results as Table 1, except that its agreement score A(u, v) is the Kendall Tau rank correlation coefficient, which is a more popular metric.
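For reference, the Kendall Tau coefficient used in Table 2 can be computed as follows. This is a simple O(n²) sketch assuming no ties; the paper's numbers may use a tie-corrected variant, so this is an illustration rather than the original evaluation code:

```python
from itertools import combinations

def kendall_tau(u, v):
    """Kendall rank correlation of two equal-length score vectors (no ties assumed)."""
    n = len(u)
    # +1 for each concordant pair, -1 for each discordant pair.
    s = sum(1 if (u[i] - u[j]) * (v[i] - v[j]) > 0 else -1
            for i, j in combinations(range(n), 2))
    return s / (n * (n - 1) / 2)
```

Identical orderings give 1.0, reversed orderings give -1.0, and partially agreeing orderings fall in between.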

Figure 6: The training distribution mixes random permutations of two disjoint sets of words (left and right, respectively). From the count table, β^{px} could learn that the set of input words {A, B, C, D} corresponds to the set of output words {A′, B′, C′, D′}, but its β value for each input position is still uniformly 0.25.

• Table 4 contains the same results as Table 1, except that results are rounded to two decimal places.
• Table 6 contains the same results as Table 1, except that the statistics are calculated over the training set rather than the validation set.
• Table 3, Table 5, and Table 7 contain the translation results from Tables 2, 4, and 6 respectively, except that p is defined as BLEU score rather than token accuracy; the contextualized metric interpretation ξ changes correspondingly.

A.5 DATASET DESCRIPTION

We summarize the datasets that we use for classification and machine translation. See Table 8 for details on train/test splits and median sequence lengths for each dataset.

AG News Corpus (Zhang et al., 2015): 120,000 news articles and their corresponding topics (world, sports, business, or science/tech). We classify between the world and business articles.

20 Newsgroups (footnote 4): a news dataset containing around 18,000 newsgroup articles split between 20 labeled categories. We classify between the baseball and hockey articles.

Stanford Sentiment Treebank (Socher et al., 2013): a dataset for classifying the sentiment of movie reviews, labeled on a scale from 1 (negative) to 5 (positive). We remove all movies labeled 3, and classify between {4, 5} and {1, 2}.

Multi-Domain Sentiment Dataset (footnote 5): approximately 40,000 Amazon reviews from various product categories, labeled with a corresponding positive or negative label. Since some of the sequences are particularly long, we only use sequences of fewer than 400 words.

Yelp Open Dataset (footnote 6): 20,000 Yelp reviews and their corresponding star ratings from 1 to 5. We classify between reviews with rating ≤ 2 and ≥ 4.

Multi-30k (Elliott et al., 2016): English-to-German translation. The data consists of translated image captions.

IWSLT'14 (Cettolo et al., 2015): German-to-English translation. The data consists of translated TED talk transcriptions.

News Commentary v14 (Cettolo et al., 2015): a collection of news commentary translation datasets in different languages from WMT19 (footnote 7). We use the following translation splits: English-Dutch (En-Nl), English-Portuguese (En-Pt), and Italian-Portuguese (It-Pt). In pre-processing for this dataset, we removed all purely numerical examples.

A.6 α FAILS WHEN β IS FROZEN

For each classification task we initialize a random model and freeze all parameters except for the attention layer (the frozen-β model). We then compute the correlation between this trained attention (denoted α^{fr}) and the normal attention α.
Table 9 reports this correlation at the iteration where α^{fr} is most correlated with α on the validation set. As shown in Table 9, the left column is consistently lower than the right column. This indicates that the model can learn output relevance without attention, but not vice versa.
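Concretely, the frozen-β setup amounts to turning off gradients for every parameter outside the attention layer. Below is a minimal PyTorch sketch with a toy stand-in architecture; the module names and sizes here are assumptions for illustration, not the paper's exact model:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class AttnClassifier(nn.Module):
    """Toy stand-in: embed tokens, attend over positions, classify."""
    def __init__(self, vocab=100, dim=16, n_cls=2):
        super().__init__()
        self.emb = nn.Embedding(vocab, dim)
        self.attn = nn.Linear(dim, 1)   # attention scorer: the only trained part
        self.out = nn.Linear(dim, n_cls)

    def forward(self, x):               # x: (batch, seq_len)
        h = self.emb(x)                 # frozen random "encoder" states
        a = torch.softmax(self.attn(h).squeeze(-1), dim=-1)
        ctx = (a.unsqueeze(-1) * h).sum(dim=1)
        return self.out(ctx)

model = AttnClassifier()
# Frozen-beta model: only the attention parameters receive gradients.
for name, p in model.named_parameters():
    p.requires_grad = name.startswith("attn")

opt = torch.optim.Adam((p for p in model.parameters() if p.requires_grad), lr=1e-2)
x = torch.randint(0, 100, (8, 12))
y = torch.randint(0, 2, (8,))
out_before = model.out.weight.clone()
attn_before = model.attn.weight.clone()
loss = nn.functional.cross_entropy(model(x), y)
opt.zero_grad()
loss.backward()
opt.step()
```

After the optimizer step, only the attention scorer has moved; the embedding and output layers stay at their random initialization, which is exactly the frozen-β condition.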

A.7 TRAINING β uf

We find that A(α, β^{uf}(τ)) first increases and then decreases as training proceeds (i.e., as τ increases), so we chose the maximum agreement over the course of training to report in Table 1.

The bag-of-words logistic regression of Appendix A.10 is equivalent to ∀t, p_t = σ( (1/L) Σ_{l=1}^{L} β^{log}_{x_l} ). Here β^{log}_i denotes the ith column of β^{log}; these are the entries of β^{log} corresponding to predictions for the ith word in the vocab. It is now easy to see the equivalence between logistic regression and our proxy model. If we restrict the rank of β^{log} to be at most min(d, |O|, |I|) by factoring it as β^{log} = W E, where W ∈ R^{|O|×d} and E ∈ R^{d×|I|}, then the logistic regression becomes ∀t, p_t = σ( (1/L) Σ_{l=1}^{L} W E_{x_l} ), which is equivalent to our proxy model ∀t, p_t = σ( (1/L) Σ_{l=1}^{L} W e_{x_l} ). Since d = 256 for the proxy model, which is larger than |O| = 2 in the classification case, the proxy model is not rank-limited and is hence fully equivalent to the logistic regression model. Therefore β^{px} can be interpreted as "keywords" in the same way that logistic regression weights can.

To empirically verify this equivalence, we trained a logistic regression model with ℓ2 regularization on each of our classification datasets. To pick the regularization level, we swept regularization coefficients across ten orders of magnitude and picked the one with the best validation accuracy. We report results for A(β^{uf}, β^{log}) in comparison to A(β^{uf}, β^{px}) in Table 10 (footnote 9). Note that these numbers are similar but not exactly equal: the proxy model did not use ℓ2 regularization, while the logistic regression did.
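The equivalence above can be checked numerically: averaging word embeddings and then projecting with W is the same as multiplying the factored matrix W E by the normalized bag-of-words vector x. A small sketch with arbitrary random weights and a hypothetical token sequence (not the paper's trained parameters):

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_out, vocab = 16, 2, 50
W = rng.normal(size=(n_out, d))   # output projection, |O| x d
E = rng.normal(size=(d, vocab))   # word embeddings as columns, d x |I|

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

seq = [3, 7, 7, 42]               # hypothetical token sequence (with a repeat)

# Proxy model: average the word embeddings, then project and softmax.
p_proxy = softmax(W @ E[:, seq].mean(axis=1))

# Bag-of-words logistic regression with beta_log = W E and
# x = word counts normalized by sequence length.
x = np.bincount(seq, minlength=vocab) / len(seq)
p_logreg = softmax((W @ E) @ x)
```

The two predicted distributions coincide for every sequence, since (W E) x = W (E x) and E x is exactly the average embedding.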



Appendix Tables 3, 5, and 7 include results for BLEU. Empirically, β^{px} converges to the unigram weights of a bag-of-words logistic regression model, and hence β^{px} does capture an interpretable notion of "keywords" (Appendix A.10).

Footnote URLs:
• http://qwone.com/~jason/20Newsgroups/ (20 Newsgroups)
• https://www.cs.jhu.edu/~mdredze/datasets/sentiment/ (Multi-Domain Sentiment)
• https://www.yelp.com/dataset (Yelp Open Dataset)
• http://www.statmt.org/wmt19/translation-task.html (WMT19)
• http://www.dt.fee.unicamp.br/~tiago/smsspamcollection/ (SMS Spam Collection)

Footnote 9: these numbers were obtained from a retrain of all the models in the main table, so, for instance, the LSTM model used to produce β^{uf} might not be exactly the same as the one used for the results in all the other tables, due to random-seed differences.



[Figure residue — lexical translation probabilities and probe values from a worked example: Trans(movie | Dieser) = .04, Trans(movie | ist) = .04, Trans(movie | großartig) = .04, Trans(movie | schlecht) = .03; β_{2,1} = .04, β_{2,2} = .32, β_{2,3} = .04, β_{2,4} = .03.]

Figure 2: Each curve shows accuracy on the test distribution vs. the number of training steps, for 20 random seeds each. Left and Middle show accuracy curves for single-head attention models: when trained on a distribution of permutations of 40 words (Left, red) or a mixture of permutations (Middle), the model sometimes fails to learn and converges more slowly. Right shows the multi-head attention experiments: if all head initializations (head-inits) are bad (red), the model is likely to fail; if one of the head-inits is good (blue), it is likely to learn; and with high probability, at least one out of eight random head-inits is good (green).

Figure 3: The co-occurrence table C is non-informative under a distribution of permutations: the probabilities are uniform over the input words, so the alignment α has no preference over any of them. Therefore, this distribution is hard for the attention-based model to learn.

Figure 4: The model first learns word polarity, which later attracts attention.

Figure 5: When training begins, Hypothesis 2 (Equation 19) holds; asymptotically, however, Observation 3 (Equation 21) holds.

We define ε ∈ R^{|I|×|O|} entry-wise by ε_{i,o} := W_o^T ∂e_i/∂τ = -W_o^T ∂L/∂e_i, and analyze each entry ε_{i,o}. Since differentiation and left-multiplication by the matrix W are both linear operations, we analyze each individual loss term in Equation 34 and then sum them up.







Table 2: the same results as Table 1, except with agreement defined by Kendall Tau (Section A.4). Columns: Task, A(α, β^{uf}), A(β^{uf}, β^{px}), A(∆, β^{uf}).

Table 3: translation results from Table 2, except with performance measured by BLEU rather than token accuracy (Section A.4).

(See Table 8 for details on train/test splits and median sequence lengths for each dataset.)

Table 5: translation results from Table 4, except with performance measured by BLEU rather than token accuracy (Section A.4). Columns: Task, A(α, β^{uf}), A(β^{uf}, β^{px}), A(∆, β^{uf}), A(α, β).

Table 6: Table 1, except with correlations and performance metrics taken over the training set instead of the validation set (Section A.4).

Table 7: translation results from Table 6, except with performance measured by BLEU rather than token accuracy (Section A.4).

IMDB Sentiment Analysis (Maas et al., 2011): a sentiment analysis dataset with 50,000 IMDB movie reviews (25,000 train and 25,000 test) and their corresponding positive or negative sentiment.

Since this trend is consistent across all datasets, reporting the maximum over the course of training in Table 1 minimally inflates the agreement measure, and is comparable to the practice of reporting dev-set results. As discussed in Section 6.1, training under uniform attention for too long might bring unintuitive results.

A.8 MODEL AND TRAINING DETAILS

Classification Our model uses dimension-300 GloVe-6B pre-trained embeddings to initialize the token embeddings wherever they align with our vocabulary. The sequences are encoded with a 1-layer bidirectional LSTM of dimension 256. The rest of the model, including the attention mechanism, is exactly as described in Section 2.4. Our model has 1,274,882 parameters excluding embeddings. Since each classification set has a different vocab size, each model has a slightly different parameter count when embeddings are included: 19,376,282 for IMDB; 10,594,382 for AG News; 5,021,282 for 20 Newsgroups; 4,581,482 for SST; 13,685,282 for Yelp; 12,407,882 for Amazon; and 2,682,182 for SMS.

Translation We use a bidirectional two-layer LSTM of dimension 256 to encode the source, and use the last hidden state h_L as the first hidden state of the decoder. The attention and outputs are then calculated as described in Section 2. The learnable neural network before the outputs, mentioned in Section 2, is a 1-hidden-layer model with a ReLU non-linearity; the hidden layer has dimension 256. Our model contains 6,132,544 parameters excluding embeddings and 8,180,544 including embeddings, on all datasets.

Permutation Copying We use a uni-directional single-layer LSTM with hidden dimension 256 for both the encoder and the decoder.

Classification Procedure For all classification datasets we used a batch size of 32 and trained for 4,000 iterations. For each dataset we train on the pre-defined training set if the dataset has one. Additionally, if a dataset had a pre-defined test set, we randomly sampled at most 4,000 examples from it for validation. Specific dataset split sizes are given in Table 8.

Table 8: median sequence length in the training set and train-set size for each dataset (Section A.5). Note: "src" refers to the input source sequence, and "trg" refers to the output target sequence.

Table 9: the correlation between α^{fr} and α on classification datasets, compared against A(α, β^{uf}), the same column defined in Table 1 (Section A.6).

Table 10 columns: Task, A(β^{uf}, β^{px}), A(β^{uf}, β^{log}).

We report A(β^{uf}, β^{log}) to demonstrate its effective equivalence to A(β^{uf}, β^{px}). These values are not exactly the same due to differences in regularization strategies.


Xiang Zhang, Junbo Zhao, and Yann LeCun. Character-level convolutional networks for text classification. In Advances in Neural Information Processing Systems, pp. 649-657, 2015.

Classification Evaluation We evaluated each model at steps 0, 10, 50, 100, 150, 200, 250, and then every 250 iterations after that.

Classification Tokenization We tokenized the data at the word level. We mapped all words occurring fewer than 3 times in the training set to <unk>. For 20 Newsgroups and AG News we mapped all non-single-digit integer "words" to <unk>. For 20 Newsgroups we also split words with the " " character.
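The word-level <unk> thresholding described above can be sketched as follows. This is a minimal illustration with hypothetical helper names, not the paper's actual preprocessing script:

```python
from collections import Counter

UNK = "<unk>"

def build_vocab(train_seqs, min_count=3):
    """Keep words occurring at least `min_count` times in training; map the rest to UNK."""
    counts = Counter(tok for seq in train_seqs for tok in seq)
    return {UNK} | {tok for tok, c in counts.items() if c >= min_count}

def apply_vocab(seq, vocab):
    """Replace out-of-vocabulary tokens with the UNK symbol."""
    return [tok if tok in vocab else UNK for tok in seq]
```

The default `min_count=3` mirrors the "fewer than 3 occurrences" rule; the dataset-specific digit and word-splitting rules would be applied before this step.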

Classification Training

We trained all classification models on a single GPU. Some datasets took slightly longer to train than others (largely depending on average sequence length), but each training run took at most 45 minutes.

Translation Hyper-parameters For translation, all hidden states in the model have dimension 256. We use the sequence-to-sequence architecture described above. The LSTMs used dropout 0.5.

Translation Procedure For all translation tasks we used batch size 16 when training. For IWSLT'14 and Multi-30k we used the provided dataset splits. For the News Commentary v14 datasets we did a 90-10 split of the data for training and validation, respectively.

Translation Evaluation We evaluated each model at steps 0, 50, 100, 500, 1000, 1500, and then every 2000 iterations after that.

Translation Training

We trained all translation models on a single GPU. IWSLT'14 and the News Commentary datasets took approximately 5-6 hours to train, and Multi-30k took closer to 1 hour.

Translation Tokenization We tokenized both translation datasets using the SentencePiece tokenizer, trained on the corresponding train set with a vocab size of 8,000. We used a single tokenization for source and target tokens, and accordingly also used the same embedding matrix for target and source sequences.

A.9 A NOTE ON THE SMS DATASET

In addition to the classification datasets reported in the tables, we also ran experiments on the SMS Spam Collection V.1 dataset (footnote 8). The attention learned from this dataset had very high variance, so two different random seeds would consistently produce attentions that did not correlate much. The dataset itself was also a bit of an outlier: it had shorter sequence lengths than any of the other datasets (median sequence length 13 on the train and validation sets), the smallest training set of all our datasets (3,500 examples), and by far the smallest vocab (4,691 unique tokens). We decided not to include this dataset in the main paper due to these unusual results and leave further exploration to future work.

A.10 LOGISTIC REGRESSION PROXY MODEL

Our proxy model can be shown to be equivalent to a bag-of-words logistic regression model in the classification case. Specifically, we define a bag-of-words logistic regression model to be

∀t, p_t = σ(β^{log} x),

where x ∈ R^{|I|}, β^{log} ∈ R^{|O|×|I|}, and σ is the softmax function. The entries of x are the number of times each word occurs in the input sequence, normalized by the sequence length, and β^{log} is learned. This is equivalent to

∀t, p_t = σ( (1/L) Σ_{l=1}^{L} β^{log}_{x_l} ),

where β^{log}_{x_l} denotes the x_l-th column of β^{log}.

