ON THE DYNAMICS OF TRAINING ATTENTION MODELS

Abstract

The attention mechanism has been widely used in deep neural networks as a model component. By now, it has become a critical building block in many state-of-the-art natural language models. Despite its empirically established success, the working mechanism of attention has not been investigated at a sufficient theoretical depth to date. In this paper, we set up a simple text classification task and study the dynamics of training a simple attention-based classification model using gradient descent. In this setting, we show that, for a discriminative word that the model should attend to, a persisting identity exists relating its embedding and the inner product of its key and the query. This allows us to prove that training must converge to attending to the discriminative words when the attention output is classified by a linear classifier. Experiments are performed, which validate our theoretical analysis and provide further insights.

1. INTRODUCTION

Attention-based neural networks have been broadly adopted in many natural language models for machine translation (Bahdanau et al., 2014; Luong et al., 2015), sentiment classification (Wang et al., 2016), image caption generation (Xu et al., 2015), and unsupervised representation learning (Devlin et al., 2019). In particular, attention is a key ingredient of the powerful Transformer (Vaswani et al., 2017). Despite these empirical successes, the working mechanism of attention is not well understood (see Section 2). This paper sets up a simple text classification task and considers a basic neural network model with the most straightforward attention mechanism. We study the model's training trajectory to understand why attention can attend to the discriminative words (referred to as the topic words). More specifically, in this task, each sentence is treated as a bag of words, and its class label, or topic, is indicated by a topic word. The model we consider involves a basic attention mechanism, which creates weighting factors to combine the word embedding vectors into a "context vector"; the context vector is then passed to a classifier. In this setting, we prove a closed-form relationship, holding throughout gradient-descent training, between the norm of the topic word's embedding and the inner product of its key and the query, referred to as its "score". Remarkably, this relationship holds irrespective of the classifier architecture or configuration. It suggests a "synergy" in the amplification of the topic word's score and its embedding: the growth of each quantity promotes that of the other. This, in turn, allows the topic word embedding to stand out rapidly in the context vector during training. Moreover, when the model uses a fixed linear classifier, this relationship allows rigorous proofs of this "mutual promotion" phenomenon and of the convergence of training to attending to the topic words.
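To make the setting concrete, the forward pass of such a model can be sketched as follows. This is a minimal illustrative sketch in NumPy, not the paper's exact implementation: all names, dimensions, and the use of a single trainable query vector and a fixed linear classifier are assumptions chosen to mirror the description above. Each word w carries an embedding e_w and a key k_w; the inner product of k_w with the query q is the word's "score", the softmax of the scores gives the attention weights, and the weighted sum of the embeddings is the context vector fed to the classifier.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, emb_dim = 10, 4

E = rng.normal(size=(vocab_size, emb_dim))  # word embeddings e_w (trainable)
K = rng.normal(size=(vocab_size, emb_dim))  # keys k_w (trainable)
q = rng.normal(size=emb_dim)                # single query vector (trainable)
w_cls = rng.normal(size=emb_dim)            # fixed linear classifier weights

def forward(sentence):
    """sentence: list of word indices, treated as a bag of words.

    Returns the classifier logit for this sentence."""
    scores = K[sentence] @ q                  # score of each word: <k_w, q>
    weights = np.exp(scores - scores.max())   # numerically stable softmax
    weights /= weights.sum()                  # attention weights, sum to 1
    context = weights @ E[sentence]           # context vector: weighted sum of embeddings
    return float(w_cls @ context)             # linear classifier on the context vector

logit = forward([1, 3, 7])
```

Under gradient descent on a classification loss, the quantities analyzed in the paper are precisely the topic word's embedding norm ||e_w|| and its score <k_w, q> in this sketch.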
Our theoretical results and their implications are corroborated by experiments on a synthetic dataset and on real-world datasets, from which additional insights are obtained. For example, low-capacity classifiers tend to give stronger training signals to the attention module. The "mutual promotion" effect implied by the discovered relationship can also manifest as "mutual suppression" in the early phase of training. Furthermore, in the real-world datasets, where perfect

