APPROXIMATING HOW SINGLE HEAD ATTENTION LEARNS

Abstract

Why do models often attend to salient words, and how does this evolve throughout training? We approximate model training as a two-stage process: early in training, when the attention weights are uniform, the model learns to translate an individual input word i to an output word o if they co-occur frequently. Later, the model learns to attend to i when the correct output is o, because it knows that i translates to o. To formalize this, we define a model property, Knowledge to Translate Individual Words (KTIW) (e.g., knowing that i translates to o), and claim that it drives the learning of the attention. This claim is supported by the fact that before the attention mechanism is learned, KTIW can be learned from word co-occurrence statistics, but not the other way around. In particular, when we construct a training distribution that makes KTIW hard to learn, the learning of attention fails, and the model cannot even learn the simple task of copying the input words to the output. Our approximation explains why models sometimes attend to salient words, and inspires a toy example where a multi-head attention model can overcome the above hard training distribution by improving learning dynamics rather than expressiveness. We close by discussing the limitations of our approximation framework and suggesting future directions.

1. INTRODUCTION

The attention mechanism underlies many recent advances in natural language processing, such as machine translation (Bahdanau et al., 2015) and pretraining (Devlin et al., 2019). While many works focus on analyzing attention in already-trained models (Jain & Wallace, 2019; Vashishth et al., 2019; Brunner et al., 2019; Elhage et al., 2021; Olsson et al., 2022), little is understood about how the attention mechanism is learned via gradient descent at training time. These learning dynamics are important, as standard gradient-trained models can have distinctive inductive biases that set them apart from more esoteric but equally accurate models. For example, in text classification, while standard models typically attend to salient (high gradient influence) words (Serrano & Smith, 2019), recent work constructs accurate models that attend to irrelevant words instead (Wiegreffe & Pinter, 2019; Pruthi et al., 2020). In machine translation, while standard gradient descent cannot train a high-accuracy transformer with relatively few attention heads, we can construct one by first training with more heads and then pruning the redundant ones (Voita et al., 2019; Michel et al., 2019). To explain these differences, we need to understand how attention is learned at training time.

Our work opens the black box of attention training, focusing on attention in LSTM Seq2Seq models (Luong et al., 2015) (Section 2.1). Intuitively, if the model knows that an individual input word i translates to the correct output word o, it should attend to i to minimize the loss. This motivates us to investigate the model's knowledge to translate individual words (abbreviated as KTIW), and we define a lexical probe β to measure this property. We claim that KTIW drives the attention mechanism to be learned. This is supported by the fact that KTIW can be learned when the attention mechanism has not been learned (Section 3.2), but not the other way around (Section 3.3).
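For concreteness, the attention step discussed above can be sketched as a Luong-style dot-product attention over encoder hidden states. This is a minimal illustration of the mechanism, not the paper's actual implementation; the function name and dimensions are ours.

```python
import numpy as np

def dot_attention(decoder_state, encoder_states):
    """Luong-style dot-product attention (illustrative sketch).

    decoder_state:  (d,)   hidden state at the current output step.
    encoder_states: (T, d) hidden states h_1..h_T over the T input words.
    Returns the attention weights alpha (T,) and the context vector (d,).
    """
    scores = encoder_states @ decoder_state        # (T,) unnormalized scores
    scores = scores - scores.max()                 # subtract max for numerical stability
    alpha = np.exp(scores) / np.exp(scores).sum()  # softmax over input positions
    context = alpha @ encoder_states               # attention-weighted sum of h_l
    return alpha, context
```

The weights α are a distribution over input positions, so "attending to i" means putting most of this mass on position i.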
Specifically, even when the attention weights are frozen to be uniform, the probe β still strongly agrees with the attention weights of a standardly trained model. On the other hand, when KTIW cannot be learned, the attention mechanism cannot be learned either. In particular, we can construct a distribution where KTIW is hard to learn; as a result, the model fails to learn even the simple task of copying the input to the output.

The problem of understanding how the attention mechanism is learned thus reduces to understanding how KTIW is learned. Section 2.3 builds a simpler proxy model that approximates how KTIW is learned, and Section 3.2 verifies empirically that the approximation is reasonable. This proxy model is simple enough to analyze, and we interpret its training dynamics with the classical IBM Translation Model 1 (Section 4.2), which translates an individual word i to o if they co-occur more frequently. Collapsing this chain of reasoning, we approximate model training in two stages: early on in training, when the attention mechanism has not been learned, the model learns KTIW through word co-occurrence statistics; KTIW later drives the learning of the attention.

Using these insights, we explain why attention weights sometimes correlate with word saliency in binary text classification (Section 5.1): the model first learns to "translate" salient words into labels, and then attends to them. We also present a toy experiment (Section 5.2) where multi-head attention improves learning dynamics by combining differently initialized attention heads, even though a single-head model can express the target function. Nevertheless, "all models are wrong": even though our framework successfully explains and predicts the above empirical phenomena, it cannot fully capture the behavior of attention-based models, since it is after all an approximation.
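The first stage above, learning word translation from co-occurrence in the spirit of IBM Model 1 under uniform alignments, can be sketched as follows. This is a toy illustration; the function name and counting scheme are ours, not the paper's.

```python
from collections import Counter, defaultdict

def cooccurrence_translation_probs(sentence_pairs):
    """Estimate p(o | i) from raw co-occurrence counts, as in the first
    step of IBM Model 1 with uniform alignments (illustrative sketch)."""
    counts = defaultdict(Counter)
    for src, tgt in sentence_pairs:
        for i in src:        # under uniform alignment, every input word
            for o in tgt:    # co-occurs with every output word in the pair
                counts[i][o] += 1
    # normalize each input word's counts into a distribution over outputs
    return {i: {o: c / sum(cs.values()) for o, c in cs.items()}
            for i, cs in counts.items()}
```

On a toy copying task, e.g. `[(["a", "b"], ["a", "b"]), (["a", "c"], ["a", "c"])]`, the estimated p(b | b) already exceeds p(a | b) after a single counting pass, which is the sense in which KTIW is available before attention is learned.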
Section 6 identifies and discusses two key assumptions: (1) the information of a word tends to stay in the local hidden state (Section 6.1), and (2) attention weights are free variables (Section 6.2). We discuss future directions in Section 7.

2. MODEL

Section 2.1 defines the LSTM with attention Seq2Seq architecture. Section 2.2 defines the lexical probe β, which measures the model's knowledge to translate individual words (KTIW). Section 2.3 approximates how KTIW is learned early on in training by building a "bag of words" proxy model. Section 2.4 shows that our framework generalizes to binary classification.

[Figure 1 here. Left panel: "Attention-based Model Learning Dynamics"; right panel: "Classical Alignment Learning Procedure". The example pair is "Dieser Film ist großartig" → "This movie is great": the model first learns under uniform attention that "Film" is more likely to translate to "movie" (e.g. β_{2,2} = .32 versus β_{2,1} = .04, β_{2,3} = .04, β_{2,4} = .03), and the alignment for the second output word is then attracted to "Film": α_{2,l} = β_{2,l} / Σ_{l'=1}^{4} β_{2,l'}.]
Figure 1: Attention mechanism in recurrent models (left, Section 2.1) and word alignments in the classical model (right, Section 4.2) are learned similarly. Both first learn how to translate individual words (KTIW) under uniform attention weights/alignment at the start of training (upper, blue background), which then drives the attention mechanism/alignment to be learned (lower, red background).
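The normalization step illustrated in Figure 1, turning lexical-probe scores β into alignment weights α, amounts to the following. The function name is ours, and the β values are taken from Figure 1's illustrated example.

```python
def alignment_from_probe(beta_t):
    """Normalize probe scores beta_{t,l} over input positions l into
    alignment weights alpha_{t,l} (a sketch of Figure 1's final step)."""
    total = sum(beta_t)
    return [b / total for b in beta_t]

# Figure 1's example: beta_{2,l} for "Dieser Film ist großartig" vs. "movie"
beta_2 = [0.04, 0.32, 0.04, 0.03]
alpha_2 = alignment_from_probe(beta_2)  # mass concentrates on "Film" (l = 2)
```

Because "Film" has the largest probe score, it receives roughly three quarters of the alignment mass, which is the sense in which KTIW "attracts" the attention.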





