APPROXIMATING HOW SINGLE HEAD ATTENTION LEARNS

Abstract

Why do models often attend to salient words, and how does this behavior evolve throughout training? We approximate model training as a two-stage process: early in training, when the attention weights are uniform, the model learns to translate an individual input word i to an output word o if they co-occur frequently. Later, the model learns to attend to i when the correct output is o, because it already knows that i translates to o. To formalize this, we define a model property, Knowledge to Translate Individual Words (KTIW) (e.g., knowing that i translates to o), and claim that it drives the learning of the attention mechanism. This claim is supported by the fact that before the attention mechanism is learned, KTIW can be learned from word co-occurrence statistics, but not the other way around. In particular, when we construct a training distribution that makes KTIW hard to learn, the learning of the attention mechanism fails, and the model cannot even learn the simple task of copying the input words to the output. Our approximation explains why models sometimes attend to salient words, and it inspires a toy example in which a multi-head attention model can overcome the hard training distribution above by improving learning dynamics rather than expressiveness. We close by discussing the limitations of our approximation framework and suggesting future directions.

1. INTRODUCTION

The attention mechanism underlies many recent advances in natural language processing, such as machine translation Bahdanau et al. (2015) and pretraining Devlin et al. (2019). While many works focus on analyzing attention in already-trained models Jain & Wallace (2019); Vashishth et al. (2019); Brunner et al. (2019); Elhage et al. (2021); Olsson et al. (2022), little is understood about how the attention mechanism is learned via gradient descent at training time. These learning dynamics matter: standard gradient-trained models can have distinctive inductive biases that set them apart from more esoteric but equally accurate models. For example, in text classification, while standard models typically attend to salient (high gradient influence) words Serrano & Smith (2019), recent work constructs accurate models that attend to irrelevant words instead Wiegreffe & Pinter (2019); Pruthi et al. (2020). In machine translation, while standard gradient descent cannot train a high-accuracy transformer with relatively few attention heads, we can construct one by first training with more heads and then pruning the redundant heads Voita et al. (2019); Michel et al. (2019). To explain these differences, we need to understand how attention is learned at training time.

Our work opens the black box of attention training, focusing on attention in LSTM Seq2Seq models Luong et al. (2015) (Section 2.1). Intuitively, if the model knows that an individual input word i translates to the correct output word o, it should attend to i to minimize the loss. This motivates us to investigate the model's knowledge to translate individual words (abbreviated as KTIW), and we define a lexical probe β to measure this property.

We claim that KTIW drives the attention mechanism to be learned. This is supported by the fact that KTIW can be learned when the attention mechanism has not yet been learned (Section 3.2), but not the other way around (Section 3.3). Specifically, even when the attention weights are frozen to be uniform, the probe β still strongly agrees with the attention weights of a standardly trained model. On the other hand, when KTIW cannot be learned, the attention mechanism cannot be learned either. In particular, we can construct a distribution where KTIW is hard to learn; as a result, the model fails to learn even the simple task of copying the input to the output.
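To make the intuition above concrete, the following minimal numerical sketch treats the decoder's output distribution as an attention-weighted mixture of per-word lexical translation distributions. This is only an illustration under that simplifying assumption: the toy vocabulary, the hypothetical lexical table, and helper names such as output_distribution are invented for the example and are not the paper's actual LSTM Seq2Seq model or the formal definition of the probe β (given later in the paper).

```python
import numpy as np

# Hypothetical lexical translation table p(o | i): each row is the output
# distribution the model would produce if it attended only to that single
# input word. (Illustrative numbers, not taken from the paper.)
in_vocab = ["chat", "noir", "<eos>"]       # input words
out_vocab = ["cat", "black", "<eos>"]      # output words
lexical = np.array([
    [0.90, 0.05, 0.05],   # "chat"  -> mostly "cat"
    [0.05, 0.90, 0.05],   # "noir"  -> mostly "black"
    [0.10, 0.10, 0.80],   # "<eos>" -> mostly "<eos>"
])

def output_distribution(attention, lexical):
    """Approximate the decoder's output distribution as an
    attention-weighted mixture of per-word lexical distributions."""
    return attention @ lexical

def cross_entropy(dist, target_idx):
    """Negative log-likelihood of the correct output word."""
    return -np.log(dist[target_idx])

target = 0  # the correct next output word is "cat"

uniform_attention = np.ones(3) / 3                 # early in training
focused_attention = np.array([0.90, 0.05, 0.05])   # attend to "chat"

print(cross_entropy(output_distribution(uniform_attention, lexical), target))
# ~1.05 nats under uniform attention
print(cross_entropy(output_distribution(focused_attention, lexical), target))
# ~0.20 nats when attending to "chat", the word whose lexical
# translation is the correct output "cat"
```

In this toy setting, once the lexical table already maps "chat" to "cat" (i.e., KTIW is in place), shifting attention weight onto "chat" strictly lowers the loss, so gradient descent has a direct incentive to learn that attention pattern; conversely, if the lexical rows were uninformative, no attention pattern would reduce the loss and the attention weights would receive no useful signal.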

