ON THE DYNAMICS OF TRAINING ATTENTION MODELS

Abstract

The attention mechanism has been widely used in deep neural networks as a model component. By now, it has become a critical building block in many state-of-the-art natural language models. Despite its great empirical success, the working mechanism of attention has not yet been investigated at a sufficient theoretical depth. In this paper, we set up a simple text classification task and study the dynamics of training a simple attention-based classification model using gradient descent. In this setting, we show that, for each discriminative word that the model should attend to, a persisting identity exists that relates its embedding and the inner product of its key and the query. This allows us to prove that training must converge to attending to the discriminative words when the attention output is classified by a linear classifier. Experiments are performed that validate our theoretical analysis and provide further insights.

1. INTRODUCTION

Attention-based neural networks have been broadly adopted in many natural language models for machine translation (Bahdanau et al., 2014; Luong et al., 2015), sentiment classification (Wang et al., 2016), image caption generation (Xu et al., 2015), unsupervised representation learning (Devlin et al., 2019), etc. In particular, attention is the key ingredient of the powerful Transformer (Vaswani et al., 2017). Despite its great successes established empirically, the working mechanism of attention has not been well understood (see Section 2). This paper sets up a simple text classification task and considers a basic neural network model with the most straightforward attention mechanism. We study the model's training trajectory to understand why attention can attend to the discriminative words (referred to as the topic words).

More specifically, in this task, each sentence is treated as a bag of words, and its class label, or topic, is indicated by a topic word. The model we consider involves a basic attention mechanism, which creates weighting factors to combine the word embedding vectors into a "context vector"; the context vector is then passed to a classifier. In this setting, we prove a closed-form relationship between the topic word embedding norm and the inner product of its key and the query, referred to as the "score", during gradient-descent training. It is particularly remarkable that this relationship holds irrespective of the classifier architecture or configuration. This relationship suggests the existence of a "synergy" in the amplification of the topic word score and its word embedding; that is, the growths of the two quantities promote each other. This, in turn, allows the topic word embedding to stand out rapidly in the context vector during training. Moreover, when the model takes a fixed linear classifier, this relationship allows rigorous proofs of this "mutual promotion" phenomenon and of the convergence of training to the topic words.
Our theoretical results and their implications are corroborated by experiments performed on a synthetic dataset and real-world datasets. Additional insights are also obtained from these experiments. For example, low-capacity classifiers tend to give stronger training signals to the attention module. The "mutual promotion" effect implied by the discovered relationship can also exhibit itself as "mutual suppression" in the early training phase. Furthermore, in the real-world datasets, where a perfect delimitation of topic and non-topic words does not exist, interesting training dynamics are observed. Due to length constraints, all proofs are presented in the Appendix.

2. RELATED WORKS

Since 2019, a series of works have been published to understand the working mechanism and behaviour of attention. One focus of these works pertains to understanding whether an attention mechanism can provide meaningful explanations (Michel et al., 2019; Voita et al., 2019; Jain & Wallace, 2019; Wiegreffe & Pinter, 2019; Serrano & Smith, 2020; Vashishth et al., 2020). Most of these works are empirical in nature, for example, analyzing the behaviours of a well-trained attention-based model (Clark et al., 2019), observing the impact of altering the output weights of the attention module or of pruning a few heads (Michel et al., 2019; Voita et al., 2019), or a combination of both (Jain & Wallace, 2019; Vashishth et al., 2020). Apart from acquiring insights from experiments, Brunner et al. (2019) and Hahn (2020) show theoretically that self-attention blocks lack identifiability, in the sense that multiple weight configurations may give equally good end predictions. The non-uniqueness of the attention weights therefore makes the architecture lack interpretability. As a fully connected neural network with infinite width can be seen as a Gaussian process (Lee et al., 2018), a few works apply this perspective to understanding attention with an infinite number of heads and infinite width of the network layers (Yang, 2019; Hron et al., 2020). In this paper, we restrict our study to the more realistic non-asymptotic regime.

3. PROBLEM SETUP

Learning Task. To obtain insights into the training dynamics of attention models, we set up a simple topic classification task. Each input sentence contains m non-topic words and one topic word indicating its topic. Note that a topic may have multiple topic words, but a sentence is assumed to contain only one of them. Assume that there are J topics that correspond to the mutually exclusive topic word sets T_1, T_2, ···, T_J. Let T = T_1 ∪ T_2 ∪ ··· ∪ T_J be the set of all topic words. The non-topic words are drawn from a dictionary Θ, which is assumed not to contain any topic word. The training set Ψ consists of sentence-topic pairs, where each pair (χ, y) is generated by (1) randomly picking a topic y ∈ {1, 2, ···, J}, and (2) picking a topic word from the set T_y and combining it with m words drawn uniformly at random from Θ to generate the sentence (or the bag of words) χ. In this task, one aims to develop a classifier from the training set that predicts the topic y of a random sentence χ generated in this way. We will consider the case where |Θ| ≫ |T|, which implies that a topic word appears much more frequently in the sentences than a non-topic word.

Attention Model. For this task, we consider a simple attention mechanism similar to the one proposed by Wang et al. (2016). Each word w is associated with two parameters: an embedding ν_w ∈ R^d and a key κ_w ∈ R^d. Based on a global query q ∈ R^d, the context vector of sentence χ is computed by ν(χ) = Σ_{w∈χ} ν_w exp(q^T κ_w) / Z(χ), where Z(χ) = Σ_{w'∈χ} exp(q^T κ_{w'}). Then ν(χ) is fed into a classifier that predicts the sentence's topic in terms of a distribution over all topics. Denote the loss function by l(χ, y). Our upcoming analysis implies that this attention model, although simple, may offer plenty of insight into the training of more general attention models.

Problem Statement. Our objective is to investigate the training dynamics of this attention model under gradient descent.
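The data-generation process and the context-vector computation above can be sketched in a few lines of numpy. The vocabulary, dimension, and random initialization below are illustrative assumptions, not values from the paper:

```python
import numpy as np

rng = np.random.default_rng(0)

d, m = 8, 3                                  # embedding dimension, non-topic words per sentence
dictionary = [f"w{i}" for i in range(20)]    # Theta (non-topic words)
sentence = list(rng.choice(dictionary, size=m)) + ["t0"]  # m non-topic words + one topic word

# Per-word parameters nu_w (embedding), kappa_w (key), and the global query q.
emb = {w: rng.normal(size=d) for w in sentence}
key = {w: rng.normal(size=d) for w in sentence}
q = rng.normal(size=d)

def context_vector(words):
    """nu(chi) = sum_w nu_w * exp(q^T kappa_w) / Z(chi), with Z the softmax normalizer."""
    scores = np.array([q @ key[w] for w in words])   # q^T kappa_w for each word
    weights = np.exp(scores - scores.max())          # stabilized exponentials
    weights /= weights.sum()                         # divide by Z(chi)
    return weights, sum(a * emb[w] for a, w in zip(weights, words))

a, ctx = context_vector(sentence)
```

The returned `weights` are the attention factors over the bag of words, and `ctx` is the context vector that would be passed to the classifier.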
In particular, we wish to understand whether there is an intrinsic mechanism that allows the attention model to discover the topic word and accelerate training. Moreover, we wish to investigate how the model is optimized beyond this setup, when there is no clear delimitation between topic and non-topic words, as in real-world data.
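To make the object of study concrete, the following numpy sketch trains the embeddings, keys, and query by gradient descent through the attention layer, with a fixed linear classifier and softmax cross-entropy loss; the gradients are derived by hand from the definitions above. All sizes, the learning rate, and the initialization scale are illustrative assumptions, not settings from the paper:

```python
import numpy as np

rng = np.random.default_rng(1)
d, J, m = 8, 2, 6                                    # illustrative sizes
dictionary = [f"w{i}" for i in range(50)]            # Theta
topic_words = ["t0", "t1"]                           # one topic word per topic
vocab = dictionary + topic_words

emb = {w: 0.1 * rng.normal(size=d) for w in vocab}   # nu_w
key = {w: 0.1 * rng.normal(size=d) for w in vocab}   # kappa_w
q = 0.1 * rng.normal(size=d)                         # global query
W = rng.normal(size=(J, d))                          # fixed linear classifier

def sample():
    """Draw one (sentence, topic) pair per the generative process of Section 3."""
    y = int(rng.integers(J))
    return list(rng.choice(dictionary, size=m)) + [topic_words[y]], y

def train_step(sentence, y, lr=0.1):
    global q
    K = np.stack([key[w] for w in sentence])
    V = np.stack([emb[w] for w in sentence])
    s = K @ q                                        # scores q^T kappa_w
    a = np.exp(s - s.max()); a /= a.sum()            # attention weights
    ctx = a @ V                                      # context vector nu(chi)
    z = W @ ctx
    p = np.exp(z - z.max()); p /= p.sum()            # predicted topic distribution
    g = W.T @ (p - np.eye(J)[y])                     # dL/d(ctx)
    ds = a * (V @ g - ctx @ g)                       # dL/ds via the softmax Jacobian
    for i, w in enumerate(sentence):
        emb[w] -= lr * a[i] * g                      # dL/d(nu_w) = a_w * dL/d(ctx)
        key[w] -= lr * ds[i] * q                     # dL/d(kappa_w) = (dL/ds_w) * q
    q -= lr * (ds @ K)                               # dL/dq = sum_w (dL/ds_w) * kappa_w
    return -np.log(p[y])                             # cross-entropy loss

losses = [train_step(*sample()) for _ in range(2000)]
```

Tracking the topic-word score q @ key["t0"] and the embedding norm np.linalg.norm(emb["t0"]) over such a run is one way to observe the "mutual promotion" effect discussed in Section 1.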



The condition that the attention layer directly attends to the word embeddings merely serves to simplify the analysis in Section 4; it is not required for most results presented in Sections 4 and 5. More discussion is given in Appendix A in this regard.

