CONVEXIFYING TRANSFORMERS: IMPROVING OPTIMIZATION AND UNDERSTANDING OF TRANSFORMER NETWORKS

Abstract

Understanding the fundamental mechanism behind the success of transformer networks is still an open problem in the deep learning literature. Although their remarkable performance has been mostly attributed to the self-attention mechanism, the literature still lacks a solid analysis of these networks and an interpretation of the functions learned by them. To this end, we study the training problem of attention/transformer networks and introduce a novel convex analytic approach to improve the understanding and optimization of these networks. Particularly, we first introduce a convex alternative to the self-attention mechanism and reformulate the regularized training problem of transformer networks with our alternative convex attention. Then, we cast the reformulation as a convex optimization problem that is interpretable and easier to optimize. Moreover, as a byproduct of our convex analysis, we reveal an implicit regularization mechanism, which promotes sparsity across tokens. Therefore, we not only improve the optimization of attention/transformer networks but also provide a solid theoretical understanding of the functions learned by them. We also demonstrate the effectiveness of our theory through several numerical experiments.

1. INTRODUCTION

Transformer networks, proposed by Vaswani et al. (2017), have become a dominant architecture in various tasks, especially Natural Language Processing (NLP) (Devlin et al., 2018; Radford et al., 2019), due to their extraordinary generalization properties and high capacity to learn from vast amounts of data. Although there exists substantial empirical evidence on the effectiveness of transformer networks, revealing the underlying theoretical reasons behind their success is still an open research problem due to their highly nonlinear and nonconvex structure. For instance, Takase et al. (2022) discussed the importance of layer normalization and skip connections for transformer networks, showing that even changing the position of these components can considerably impact the performance of a transformer network. However, a solid theoretical analysis of the underlying factors behind these issues is still lacking, likely due to the highly complex and nonconvex structure of transformer networks. A series of papers has also focused on designing new alternatives to the self-attention mechanism that perform similarly and might provide further interpretation of the overall model. One line of work attempted to analyze transformer networks via convex duality by completely changing the structure of the self-attention mechanism and removing FC layers. Even then, these approaches failed to provide solid practical implications or benefits for transformers, since their formulations are extremely challenging and complex to solve in practice. Recently, another line of research has focused on understanding the structures and patterns that emerge throughout the training of transformer networks (Power et al., 2022; Thilak et al., 2022; Barak et al., 2022). In particular, the grokking phenomenon was first observed by Power et al. (2022) on specific algorithmic tasks, such as modular division operations.
Specifically, grokking refers to a sudden transition of validation or test accuracy to perfect generalization, where this transition occurs well past the point of reaching perfect training accuracy. This interesting behavior contradicts the common practice of early stopping in the training of deep learning models and requires further understanding of why it emerges.

In order to remedy the issues associated with standard transformer networks, in this paper we develop a convex optimization perspective to train, analyze, and understand transformer networks. Particularly, we first propose a convex alternative to the self-attention mechanism and then develop our convex analytic framework on the resulting model, as detailed in Figure 1.

1.1. CONTRIBUTIONS

Our contributions can be summarized as follows:

• We propose an alternative formulation to the standard self-attention mechanism and study the regularized training problem of attention/transformer networks with it.

• We convexify the regularized training problem of attention/transformer networks with the proposed attention layer, as shown in Figure 1, and therefore enable finding a globally optimal solution without requiring any nonconvex optimization heuristics, e.g., layer normalization and skip connections.

• We apply our convex analytic framework to various architectures, e.g., networks with or without an FCN layer. Thus, we are able to explain the impact of each component on the models learned throughout training.

• We reveal an implicit regularization mechanism induced by our attention mechanism, and further characterize this regularization as a sparsity-inducing factor across tokens.
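To make the proposed alternative concrete, the sketch below contrasts standard scaled dot-product self-attention with an attention layer that mixes tokens through a fixed row-stochastic matrix, so each output token is a convex combination of the input tokens. This is only an illustrative sketch: the variable names (`W`, `Wq`, `Wk`, `Wv`) and this exact parameterization are our own assumptions for exposition, not the paper's precise formulation.

```python
import numpy as np

def softmax(z, axis=-1):
    """Numerically stable softmax."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def standard_self_attention(X, Wq, Wk, Wv):
    """Standard scaled dot-product self-attention (Vaswani et al., 2017).
    X: (n_tokens, d) token matrix. The attention matrix A is row-stochastic
    but depends nonlinearly on (Wq, Wk), which makes training nonconvex."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    A = softmax(Q @ K.T / np.sqrt(K.shape[1]), axis=-1)
    return A @ V

def convex_combination_attention(X, W):
    """Illustrative alternative: mix tokens with a learned row-stochastic
    matrix W (each row on the probability simplex), so every output token
    is a convex combination of the input tokens. The mixing weights no
    longer depend nonlinearly on X."""
    assert np.all(W >= 0) and np.allclose(W.sum(axis=1), 1.0)
    return W @ X
```

For example, the identity matrix leaves the tokens unchanged, while a uniform mixing matrix maps every token to the token average; both lie inside the simplex constraint enforced above.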



A significant body of research has focused on analyzing certain components of transformer networks via empirical studies. As an example, Liu et al. (2021a); Vashishth et al. (2019); Dong et al. (2021); Voita et al. (2019); Takase et al. (2022) studied the impact of the attention mechanism on transformer networks. Although these studies agreed that attention is an essential component of transformers, they also raised several issues regarding interpretability and optimization. Particularly, Voita et al. (2019) demonstrated that most attention heads can be removed without affecting the performance of the network, which indicates a large amount of redundancy in the network. Vashishth et al. (2019) provided empirical evidence showing that attention might not be needed for some NLP tasks. Additionally, Dong et al. (2021) revealed that although attention is at the heart of transformer networks, training an attention network in the absence of Fully Connected Network (FCN) layers and skip connections is extremely challenging, since the network output degenerates quickly without them.

Figure 1: Summary of our main findings: we first propose an alternative to attention, i.e., taking convex combinations of tokens, and then convexify the whole transformer block (attention + Fully Connected Network (FCN)) with this new attention mechanism. The equivalent convex formulation also reveals a sparsity-inducing regularization across tokens, as detailed in Theorems 1, 2, and 3.
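The sparsity-inducing regularization across tokens can be illustrated in a generic form (the paper's exact regularizer is characterized in its theorems, which are not reproduced in this excerpt) by a group-lasso penalty that sums the Euclidean norms of per-token weight rows; its proximal operator zeroes out entire rows, i.e., prunes whole tokens. The function names below are hypothetical, introduced only for this illustration.

```python
import numpy as np

def token_group_lasso(W):
    """Group-lasso (l2,1) penalty across tokens: the sum of l2 norms of
    the rows of W, where row t holds the weights attached to token t.
    A generic sparsity-inducing regularizer, shown here to illustrate
    how a penalty can drive entire tokens to zero."""
    return np.linalg.norm(W, axis=1).sum()

def prox_token_group_lasso(W, tau):
    """Proximal operator of tau * l2,1 (block soft-thresholding):
    rows with l2 norm below tau are set exactly to zero, so the
    corresponding tokens are pruned; surviving rows are shrunk."""
    norms = np.linalg.norm(W, axis=1, keepdims=True)
    scale = np.maximum(1.0 - tau / np.maximum(norms, 1e-12), 0.0)
    return W * scale
```

Iterating such a proximal step inside a training loop yields solutions in which only a subset of token rows is nonzero, which is the sense in which this kind of penalty promotes sparsity across tokens.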

