ANALYZING ATTENTION MECHANISMS THROUGH THE LENS OF SAMPLE COMPLEXITY AND LOSS LANDSCAPE

Abstract

Attention mechanisms have advanced state-of-the-art deep learning models for many machine learning tasks. Despite significant empirical gains, theoretical analyses of their effectiveness are lacking. In this paper, we address this problem by studying the sample complexity and loss landscape of attention-based neural networks. Our results show that, under mild assumptions, every local minimum of the attention model has low prediction error, and attention models require lower sample complexity than models without attention. Besides revealing why the popular self-attention mechanism works, our theoretical results also provide guidelines for designing future attention models. Experiments on various datasets validate our theoretical findings.

1. INTRODUCTION

Significant research in machine learning has focused on designing network architectures for superior performance, faster convergence, and better generalization. Attention mechanisms are one such design choice that is widely used in many natural language processing and computer vision tasks. Inspired by human cognition, attention mechanisms advocate focusing on relevant regions of the input data to solve the task at hand, rather than ingesting the entire input. Several variants of attention mechanisms have been proposed, and they have advanced the state of the art in machine translation (Bahdanau et al., 2014; Luong et al., 2015; Vaswani et al., 2017), image captioning (Xu et al., 2015), video captioning (Pu et al., 2018), visual question answering (Zhou et al., 2015; Lu et al., 2016), generative modeling (Zhang et al., 2018), etc. In computer vision, spatial/spatiotemporal attention masks are employed to focus only on the regions of images/video frames relevant to the downstream task (Mnih et al., 2014). In natural language tasks, where input-output pairs are sequential data, attention mechanisms focus on the most relevant elements of the input sequence to predict each symbol of the output sequence; hidden state representations of a recurrent neural network are typically used to compute these attention masks.

Substantial empirical evidence on the effectiveness of attention mechanisms motivates us to study the problem from a theoretical lens. To this end, it is important to understand the loss landscape and optimization of neural networks with attention. Analyzing the loss landscape of neural networks is an active ongoing research area, and it can be challenging even for two-layer neural networks (Poggio & Liao, 2017; Rister & Rubin, 2017; Soudry & Hoffer, 2018; Zhou & Feng, 2017; Mei et al., 2018b; Soltanolkotabi et al., 2017; Ge et al., 2017; Nguyen & Hein, 2017a; Arora et al., 2018).
Convergence of gradient descent for two-layer neural networks has been studied in Allen-Zhu et al. (2019); Mei et al. (2018b); Du et al. (2019). Ge et al. (2017) shows that there are no bad local minima for two-layer neural networks under a specific loss landscape design. These works reveal the importance of understanding the loss landscape of neural networks. Unfortunately, these results cannot be directly applied to attention mechanisms: in attention models, the network structure is different, and the attention introduces additional parameters that are jointly optimized. To the best of our knowledge, there is no existing work analyzing the loss landscape and optimization of attention models. In this work, we present a theoretical analysis of self-attention models (Vaswani et al., 2017), which use correlations among elements of the input sequence to learn an attention mask. We summarize our work as follows. We carefully analyze the loss landscape of attention mechanisms in Sections 3 and 4. In Section 3, we show that, under mild assumptions, every stationary point of attention models achieves a low generalization error. Section 4 studies other loss-landscape properties of attention models. After the loss landscape analyses, we discuss how our theoretical results can guide practitioners in designing better attention models in Section 5. We then validate our theoretical findings with experiments on various datasets in Section 6. Section 7 includes a few concluding remarks. Proofs and further technical details are presented in the appendix.

2. ATTENTION MODELS

Attention mechanisms are modules that help neural networks focus only on relevant regions of the input data when making predictions. To compare attention models with non-attention models, we first introduce a two-layer non-attention model as the baseline. The architecture consists of a linear layer, followed by rectified linear units (ReLU) as the non-linear activation, and a second linear layer. Denote the weights of the first layer by $w^{(1)} \in \mathbb{R}^{p\times d}$, the weights of the second layer by $w^{(2)} \in \mathbb{R}^{d}$, and the ReLU function by $\phi(\cdot)$. Then the response for an input $x \in \mathbb{R}^p$ can be written as $y = w^{(2)\top}\phi(\langle w^{(1)}, x\rangle)$. We call this the "baseline model" since it does not employ any attention.

To study attention mechanisms, we focus on the most popular self-attention model, and consider two types of self-attention in this paper. In the first type, the attention weights are determined by a known function $f(x)$:
$$y = w^{(2)\top}\phi\big(\langle w^{(1)},\, x \circ f(x)\rangle\big),$$
where $f(\cdot)$ is a known mapping from $\mathbb{R}^p$ to $\mathbb{R}^p$ that assigns an attention weight to each feature of a given $x$, and $\circ$ denotes the elementwise product. This model is a prototype of the transformer model (Vaswani et al., 2017), with a pre-determined function serving as the attention weights.

Second, we introduce a more practical self-attention setup, namely the transformer model proposed in Vaswani et al. (2017). To mimic the NLP setting, we set the input $x_i = (x_i^1, \ldots, x_i^p) \in \mathbb{R}^{t\times p}$, where the $x_i^j \in \mathbb{R}^t$ are $t$-dimensional vectors. Intuitively, each $x_i$ corresponds to an independent sentence for $i = 1, \ldots, n$, and the $x_i^j$ are fixed-dimensional vector embeddings of the words in sentence $x_i$. Let $w_Q, w_K \in \mathbb{R}^{d_q\times t}$ be the query and key weight matrices, and $w_V \in \mathbb{R}^{d_v\times t}$ the value matrix. For each input $x_i$, the key is computed as $K_i = (w_K x_i)^\top \in \mathbb{R}^{p\times d_q}$; for the $z$-th vector in the input, the query vector is computed as $Q_i^z = (w_Q x_i^z)^\top \in \mathbb{R}^{1\times d_q}$ for $z = 1, \ldots, p$.
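The baseline model and the first type of self-attention model (with a known weight function) can be sketched in a few lines of NumPy. This is a minimal illustration, not the paper's implementation; the toy dimensions, the interpretation of $x \circ f(x)$ as an elementwise product, and the choice of a softmax over features as the example $f$ are assumptions made for concreteness.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def baseline_model(x, w1, w2):
    """Baseline two-layer model: y = w2^T ReLU(w1^T x), no attention."""
    return w2 @ relu(w1.T @ x)

def attention_model(x, w1, w2, f):
    """First self-attention type: the input is reweighted elementwise
    by the known attention function f before the first layer."""
    return w2 @ relu(w1.T @ (x * f(x)))

# Hypothetical example attention function: softmax over the p features,
# so the attention weights are nonnegative and sum to 1.
def f(x):
    e = np.exp(x - x.max())
    return e / e.sum()

# Toy sizes (illustrative only): p = 4 input features, d = 3 hidden units.
rng = np.random.default_rng(0)
p, d = 4, 3
w1 = rng.normal(size=(p, d))   # first-layer weights, in R^{p x d}
w2 = rng.normal(size=d)        # second-layer weights, in R^d
x = rng.normal(size=p)

y_base = baseline_model(x, w1, w2)
y_attn = attention_model(x, w1, w2, f)
```

Both models share the same two-layer backbone; the attention variant differs only in the elementwise reweighting of the input, which is what the later sample-complexity comparison isolates.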
The value matrix is $V_i = w_V x_i \in \mathbb{R}^{d_v\times p}$. The self-attention with respect to the $z$-th vector in the input $x_i$ is then computed as
$$a_i^{\mathrm{self}(z)}(x_i^z, w_Q, w_K) = \mathrm{softmax}\!\left(\frac{Q_i^z K_i^\top}{\sqrt{d_q}}\right) \quad \text{for } z = 1, \ldots, p,$$
and $a_i^{\mathrm{self}} = (a_i^{\mathrm{self}(1)}, \ldots, a_i^{\mathrm{self}(p)})$. This self-attention vector represents the interaction between different words within each sentence. The value vector for the $z$-th word in sentence $x_i$ is $V_i^z = V_i\, a_i^{\mathrm{self}(z)} \in \mathbb{R}^{d_v}$. These value vectors are then passed to a two-layer MLP parameterized by $w^{(1)} \in \mathbb{R}^{p d_v\times d}$ and $w^{(2)} \in \mathbb{R}^{d\times 1}$, resulting in the following general model:
$$y_i = w^{(2)\top}\phi\big(\langle w^{(1)}, \mathrm{vec}(w_V x_i\, a_i^{\mathrm{self}})\rangle\big) + \epsilon_i, \qquad (2)$$
where $\mathrm{vec}(\cdot)$ denotes the vectorization of a matrix, and the $\epsilon_i$ are i.i.d. sub-Gaussian errors.
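The transformer-style computation above can be sketched as follows. This is a minimal single-head NumPy illustration under assumed toy dimensions, not the paper's code; the flattening order in `vec(.)` and the $\sqrt{d_q}$ scaling follow the standard transformer convention.

```python
import numpy as np

def softmax(v):
    e = np.exp(v - v.max())
    return e / e.sum()

def self_attention_features(x, wQ, wK, wV):
    """Compute vec(wV x a_self) for one input x in R^{t x p}:
    each of the p word embeddings attends over all p embeddings."""
    t, p = x.shape
    dq = wQ.shape[0]
    K = (wK @ x).T                 # keys K_i, in R^{p x dq}
    V = wV @ x                     # values V_i, in R^{dv x p}
    cols = []
    for z in range(p):
        q = wQ @ x[:, z]           # query for the z-th word, in R^{dq}
        a = softmax(q @ K.T / np.sqrt(dq))  # attention weights, in R^p
        cols.append(V @ a)         # attended value vector V_i^z, in R^{dv}
    return np.stack(cols, axis=1).reshape(-1)  # vec(.), in R^{p*dv}

def transformer_model(x, wQ, wK, wV, w1, w2):
    """Model (2) without noise: a two-layer ReLU MLP on the attended features."""
    h = self_attention_features(x, wQ, wK, wV)
    return w2 @ np.maximum(w1.T @ h, 0.0)

# Toy dimensions (illustrative only): t=5 embedding dim, p=4 words,
# dq=3 query/key dim, dv=2 value dim, d=6 hidden units.
rng = np.random.default_rng(1)
t, p, dq, dv, d = 5, 4, 3, 2, 6
x = rng.normal(size=(t, p))
wQ = rng.normal(size=(dq, t))
wK = rng.normal(size=(dq, t))
wV = rng.normal(size=(dv, t))
w1 = rng.normal(size=(p * dv, d))
w2 = rng.normal(size=d)
y = transformer_model(x, wQ, wK, wV, w1, w2)
```

Note that, unlike the first self-attention type, here the attention weights depend on the trainable parameters $w_Q$ and $w_K$, which is exactly what makes the landscape analysis in Section 3.2 harder.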

3. SAMPLE COMPLEXITY ANALYSES

In this section, we analyze the loss landscape of the self-attention models introduced in Section 2. In Section 3.1, we consider the sample complexity of the model with a known attention weight function $f(x)$. In Section 3.2, we consider the transformer self-attention model, in which the attention weight function also needs to be learned. Section 3.3 discusses the sample complexity result for the multi-layer self-attention model. To avoid the non-differentiable point of the ReLU $\phi$, we use the softplus activation function $\phi_{\tau_0}(x) = \frac{1}{\tau_0}\log(1 + e^{\tau_0 x})$. Note that $\phi_{\tau_0}$ converges to ReLU as $\tau_0 \to \infty$ (Glorot et al., 2011). All
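The convergence of the softplus surrogate to ReLU is easy to check numerically: the worst-case gap over the real line is $\log(2)/\tau_0$, attained at $x = 0$, so it shrinks linearly in $1/\tau_0$. A small sketch (the grid and $\tau_0$ values are arbitrary illustration choices):

```python
import numpy as np

def softplus(x, tau0):
    """Smoothed ReLU: phi_{tau0}(x) = (1/tau0) * log(1 + exp(tau0 * x)).
    Written in the numerically stable form
    max(z, 0) + log1p(exp(-|z|)) with z = tau0 * x, to avoid overflow."""
    z = tau0 * x
    return (np.maximum(z, 0.0) + np.log1p(np.exp(-np.abs(z)))) / tau0

x = np.linspace(-2.0, 2.0, 201)   # grid including x = 0
relu = np.maximum(x, 0.0)

# Worst-case gap |phi_{tau0} - ReLU| over the grid for growing tau0.
gaps = [np.max(np.abs(softplus(x, tau0) - relu)) for tau0 in (1.0, 10.0, 100.0)]
```

Here `gaps` decreases as $\tau_0$ grows, consistent with the uniform bound $\sup_x |\phi_{\tau_0}(x) - \phi(x)| = \log(2)/\tau_0$.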

