A PRIMAL-DUAL FRAMEWORK FOR TRANSFORMERS AND NEURAL NETWORKS

Abstract

Self-attention is key to the remarkable success of transformers in sequence modeling tasks, including many applications in natural language processing and computer vision. Like neural network layers, these attention mechanisms are often developed by heuristics and experience. To provide a principled framework for constructing attention layers in transformers, we show that self-attention corresponds to the support vector expansion derived from a support vector regression problem, whose primal formulation has the form of a neural network layer. Using our framework, we derive popular attention layers used in practice and propose two new attentions: 1) the Batch Normalized Attention (Attention-BN), derived from the batch normalization layer, and 2) the Attention with Scaled Head (Attention-SH), derived from using less training data to fit the SVR model. We empirically demonstrate the advantages of Attention-BN and Attention-SH in reducing head redundancy, increasing the model's accuracy, and improving the model's efficiency in a variety of practical applications, including image and time-series classification.

* Co-first authors.

1. INTRODUCTION

Transformer models (Vaswani et al., 2017) have achieved impressive success with state-of-the-art performance in a myriad of sequence processing tasks, including those in computer vision (Dosovitskiy et al., 2021; Liu et al., 2021; Touvron et al., 2020; Ramesh et al., 2021; Radford et al., 2021; Arnab et al., 2021; Liu et al., 2022; Zhao et al., 2021; Guo et al., 2021), natural language processing (Devlin et al., 2018; Al-Rfou et al., 2019; Dai et al., 2019; Child et al., 2019; Raffel et al., 2020; Baevski & Auli, 2019; Brown et al., 2020; Dehghani et al., 2018), reinforcement learning (Chen et al., 2021; Janner et al., 2021), and other important applications (Rives et al., 2021; Jumper et al., 2021; Zhang et al., 2019; Gulati et al., 2020; Wang & Sun, 2022). Transformers can also effectively transfer knowledge from pre-trained models to new tasks with limited supervision (Radford et al., 2018; 2019; Devlin et al., 2018; Yang et al., 2019; Liu et al., 2019). The driving force behind the success of transformers is the self-attention mechanism (Cho et al., 2014; Parikh et al., 2016; Lin et al., 2017), which computes a weighted average of the feature representations of the tokens in the sequence, with weights proportional to similarity scores between pairs of representations. The weights calculated by self-attention determine the relative importance between tokens and thus capture the contextual representations of the sequence (Bahdanau et al., 2014; Vaswani et al., 2017; Kim et al., 2017). It has been argued that this flexibility in capturing diverse syntactic and semantic relationships is critical for the success of transformers (Tenney et al., 2019; Vig & Belinkov, 2019; Clark et al., 2019).

1.1. BACKGROUND: SELF-ATTENTION

For a given input sequence $X := [x_1, \dots, x_N]^\top \in \mathbb{R}^{N \times D_x}$ of $N$ feature vectors, self-attention transforms $X$ into the output sequence $H$ in the following two steps:

Step 1. The input sequence $X$ is projected into the query matrix $Q$, the key matrix $K$, and the value matrix $V$ via three linear transformations

$Q = XW_Q^\top; \quad K = XW_K^\top; \quad V = XW_V^\top,$

where $W_Q, W_K \in \mathbb{R}^{D \times D_x}$ and $W_V \in \mathbb{R}^{D_v \times D_x}$ are the weight matrices. We denote $Q := [q_1, \dots, q_N]^\top$, $K := [k_1, \dots, k_N]^\top$, and $V := [v_1, \dots, v_N]^\top$, where the vectors $q_i, k_i, v_i$ for $i = 1, \dots, N$ are the query, key, and value vectors, respectively.

Step 2. The output sequence $H := [h_1, \dots, h_N]^\top$ is then computed as

$H = \mathrm{softmax}\bigl(QK^\top / \sqrt{D}\bigr)\, V := AV, \quad (1)$

where the softmax function is applied to each row of the matrix $QK^\top / \sqrt{D}$. The matrix $A := \mathrm{softmax}\bigl(QK^\top / \sqrt{D}\bigr) \in \mathbb{R}^{N \times N}$ and its components $a_{ij}$ for $i, j = 1, \dots, N$ are called the attention matrix and attention scores, respectively. For each query vector $q_i$, $i = 1, \dots, N$, an equivalent form of Eqn. (1) for computing the output vector $h_i$ is given by

$h_i = \sum_{j=1}^{N} \mathrm{softmax}\bigl(q_i^\top k_j / \sqrt{D}\bigr)\, v_j. \quad (2)$

The self-attention computed by Eqns. (1) and (2) is called the scaled dot-product or softmax attention. In our paper, we call a transformer that uses this attention the softmax transformer. The structure that the attention matrix $A$ learns from training determines the ability of the self-attention to capture contextual representations for each token. Additionally, a residual connection can be added to the output of the self-attention layer: $h_i = x_i + \sum_{j=1}^{N} \mathrm{softmax}\bigl(q_i^\top k_j / \sqrt{D}\bigr)\, v_j$.

Multi-head Attention (MHA). In MHA, the outputs of multiple attention heads are concatenated to compute the final output. This mechanism allows transformers to capture more diverse attention patterns and increases the capacity of the model. Let $H$ be the number of heads and $W_O^{\mathrm{multi}} = [W_O^1, \dots, W_O^H] \in \mathbb{R}^{D_v \times H D_v}$ be the projection matrix for the output, where $W_O^1, \dots, W_O^H \in \mathbb{R}^{D_v \times D_v}$. The MHA is defined as

$\mathrm{MultiHead}\bigl(\{H^s\}_{s=1}^{H}\bigr) = \mathrm{Concat}(H^1, \dots, H^H)\, W_O^{\mathrm{multi}\,\top} = \sum_{s=1}^{H} H^s W_O^{s\,\top} = \sum_{s=1}^{H} A^s V^s W_O^{s\,\top}.$

Despite their remarkable success, most attention layers are developed based on heuristic approaches, and a coherent principled framework for synthesizing attention layers has remained elusive.
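The two steps above, together with the multi-head sum over $H^s W_O^{s\,\top}$, can be sketched in NumPy. This is a minimal illustration only: the function names and the way head weights are bundled into tuples are our own conventions, while the shapes follow the notation above.

```python
import numpy as np

def softmax(Z, axis=-1):
    """Row-wise softmax with the usual max-subtraction for numerical stability."""
    Z = Z - Z.max(axis=axis, keepdims=True)
    E = np.exp(Z)
    return E / E.sum(axis=axis, keepdims=True)

def self_attention(X, W_Q, W_K, W_V):
    """Scaled dot-product (softmax) attention, Eqns. (1)-(2).

    X: (N, D_x); W_Q, W_K: (D, D_x); W_V: (D_v, D_x). Returns H: (N, D_v).
    """
    Q, K, V = X @ W_Q.T, X @ W_K.T, X @ W_V.T
    D = Q.shape[-1]
    A = softmax(Q @ K.T / np.sqrt(D), axis=-1)  # (N, N) attention matrix
    return A @ V                                # output sequence H = A V

def multi_head_attention(X, heads, W_O_list):
    """MHA as the sum over heads of H^s W_O^{s,T}.

    heads: list of (W_Q, W_K, W_V) tuples; W_O_list: list of (D_v, D_v) matrices.
    """
    return sum(self_attention(X, *w) @ W_O.T
               for w, W_O in zip(heads, W_O_list))
```

Note that summing $H^s W_O^{s\,\top}$ over heads is algebraically identical to concatenating the heads and multiplying by $W_O^{\mathrm{multi}\,\top}$, as in the displayed MHA equation.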

1.2. CONTRIBUTION

We derive self-attention as the support vector expansion of a given support vector regression (SVR) problem. The primal representation of the regression function has the form of a neural network layer. Thus, we establish a primal-dual connection between an attention layer in transformers and a neural network layer in deep neural networks. Our framework suggests a principled approach to developing an attention mechanism: starting from a neural network layer and a support vector regression problem, we derive the dual as a support vector expansion to obtain the corresponding attention layer. We then employ this principled approach to invent two novel classes of attentions: the Batch Normalized Attention (Attention-BN), derived from the batch normalization layer in deep neural networks, and the Attention with Scaled Heads (Attention-SH), resulting from solving the support vector regression model with a smaller amount of training data. Our contribution is three-fold.
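The correspondence our framework exploits can be previewed schematically as follows; the precise construction of the SVR problem and its dual is deferred to the body of the paper, so this is only a sketch of the general shape of the argument:

```latex
% Dual (support-vector) expansion of a kernel regression / SVR problem:
%   the dual variables \alpha_j weight kernel evaluations at the training points x_j.
f(x) = \sum_{j=1}^{N} \alpha_j \, k(x, x_j) + b .
% With an exponential dot-product kernel normalized over j,
%   k(x, x_j) \propto \exp\!\bigl(x^\top x_j / \sqrt{D}\bigr),
% evaluating f at the queries q_i, with the keys k_j as training inputs and
% coefficients tied to the values v_j, recovers the softmax-attention output
% of Eqn. (2):
h_i = f(q_i) = \sum_{j=1}^{N} \mathrm{softmax}\bigl(q_i^\top k_j / \sqrt{D}\bigr)\, v_j .
```

The primal form of the same regression function is a (nonlinear) feature map followed by a linear map, i.e., a neural network layer, which is the other side of the primal-dual connection.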

