CONVEXIFYING TRANSFORMERS: IMPROVING OPTIMIZATION AND UNDERSTANDING OF TRANSFORMER NETWORKS

Abstract

Understanding the fundamental mechanism behind the success of transformer networks is still an open problem in the deep learning literature. Although their remarkable performance has been mostly attributed to the self-attention mechanism, the literature still lacks a solid analysis of these networks and an interpretation of the functions they learn. To this end, we study the training problem of attention/transformer networks and introduce a novel convex analytic approach to improve the understanding and optimization of these networks. In particular, we first introduce a convex alternative to the self-attention mechanism and reformulate the regularized training problem of transformer networks with this alternative convex attention. Then, we cast the reformulation as a convex optimization problem that is interpretable and easier to optimize. Moreover, as a byproduct of our convex analysis, we reveal an implicit regularization mechanism that promotes sparsity across tokens. Therefore, we not only improve the optimization of attention/transformer networks but also provide a solid theoretical understanding of the functions they learn. We also demonstrate the effectiveness of our theory through several numerical experiments.

1. INTRODUCTION

Transformer networks proposed by Vaswani et al. (2017) have become a dominant architecture in various tasks, especially Natural Language Processing (NLP) (Devlin et al., 2018; Radford et al., 2019), due to their extraordinary generalization properties and high capacity to learn from vast amounts of data. Although there is substantial empirical evidence for the effectiveness of transformer networks, revealing the theoretical reasons behind their success remains an open research problem due to their highly nonlinear and nonconvex structure. A significant body of research has analyzed certain components of transformer networks via empirical studies. As an example, Liu et al. (2021a); Vashishth et al. (2019); Dong et al. (2021); Voita et al. (2019); Takase et al. (2022) studied the impact of the attention mechanism on transformer networks. Although these studies agreed that attention is an essential component of transformers, they also raised several issues regarding interpretability and optimization. In particular, Voita et al. (2019) demonstrated that most attention heads can be removed without affecting the performance of the network, which indicates a large amount of redundancy in the network. Vashishth et al. (2019) provided empirical evidence showing that attention might not be needed for some NLP tasks. Additionally, Dong et al. (2021) revealed that although attention is at the heart of transformer networks, training an attention network in the absence of Fully Connected Network (FCN) layers and skip connections is extremely challenging, since the network output degenerates quickly without them. Similarly, Takase et al. (2022) discussed the importance of layer normalization and skip connections for transformer networks, showing that even changing the position of these components might considerably impact the performance of a transformer network.
However, a solid theoretical analysis of the underlying factors behind these issues is still lacking, likely due to the highly complex and nonconvex structure of transformer networks. A series of papers also focused on designing new alternatives to the self-attention mechanism that perform similarly and might provide further interpretations of the overall model. One set of work utilizes multi-layer perceptron based architectures (Tolstikhin et al., 2021; Tatsunami & Taki, 2021; Touvron et al., 2021; Liu et al., 2021b; Yu et al., 2021), while another set of papers proposes Fourier based models (Lee-Thorp et al., 2021; Rao et al., 2021; Li et al., 2020; Guibas et al., 2021). Others proposed replacing the self-attention mechanism with matrix decomposition (Geng et al., 2021). Although these works have been successfully applied to certain applications, they lack a solid theoretical analysis and understanding from an optimization perspective. Recently, Sahiner et al. (2022) attempted to analyze transformer networks via convex duality by completely changing the structure of the self-attention mechanism and removing the FCN layers. Even then, they did not provide solid practical implications/benefits for transformers, since their formulations are extremely challenging and complex to solve in practice. Another recent line of research has focused on understanding the structures and patterns that emerge throughout the training of transformer networks (Power et al., 2022; Thilak et al., 2022; Barak et al., 2022).

Figure 1: Summary of our main findings: We first propose an alternative to attention, i.e., taking convex combinations of tokens, and then convexify the whole transformer block (attention + Fully Connected Network (FCN)) with this new attention mechanism. The equivalent convex formulation also reveals a sparsity-inducing regularization across tokens, as detailed in Theorems 1, 2, and 3.
In particular, the grokking phenomenon was first observed by Power et al. (2022) on specific algorithmic tasks, such as modular division operations. Grokking refers to a sudden transition of validation or test accuracy to perfect generalization, which happens well past the point of perfect training accuracy. This behavior runs counter to the common practice of early stopping in the training of deep learning models and calls for a better understanding of why it emerges. To remedy the issues associated with standard transformer networks, in this paper, we develop a convex optimization perspective to train, analyze, and understand transformer networks. In particular, we first propose a convex alternative to the self-attention mechanism and then develop our convex analytic framework on the resulting model, as detailed in Figure 1.

1.1. CONTRIBUTIONS

Our contributions can be summarized as follows:
• We propose an alternative formulation to the standard self-attention mechanism and study the regularized training problem of attention/transformer networks with it.
• We convexify the regularized training problem of attention/transformer networks with the proposed attention layer, as shown in Figure 1, and therefore enable finding a globally optimal solution without requiring any nonconvex optimization heuristics, e.g., layer normalization and skip connections.
• We also apply our convex analytic framework to various architectures, e.g., networks with or without an FCN layer. Thus, we are able to explain the impact of each component on the models learned throughout training.
• We reveal an implicit regularization mechanism induced by our attention mechanism. We further characterize this regularization as a sparsity-inducing factor across tokens.
• We demonstrate the effectiveness of our convex reformulation via various experimental results. We also show that our reformulation significantly mitigates the grokking phenomenon studied in recent papers (Power et al., 2022; Thilak et al., 2022).

1.2. NOTATIONS

We use lowercase and uppercase bold letters to denote vectors and matrices, respectively. We denote a certain column/element of a vector or matrix using subscripts. As an example, $w_{jk}$ denotes the $(j,k)$th entry of the matrix $\mathbf{W}$. We use $\mathbf{I}_k$ to denote the identity matrix of size $k \times k$ and $\mathbf{0}$ (or $\mathbf{1}$) to denote a vector/matrix of zeros (or ones) of appropriate size. We also use $[n]$ for the set of integers ranging from 1 to $n$. We denote the Euclidean and Frobenius norms by $\|\cdot\|_2$ and $\|\cdot\|_F$, respectively, and use $\mathbb{1}[x \geq 0]$ to denote the 0-1 valued indicator function. Additional notation used throughout the paper is summarized in Table 1.

2. TRANSFORMER NETWORKS

Given a data sample (or sentence) $\mathbf{X} \in \mathbb{R}^{n \times d}$ as a sequence of $n$ tokens with embedding dimension $d$, we define the query, key, and value matrices as
$$\mathbf{Q} = \mathbf{X}\mathbf{W}_q,\quad \mathbf{K} = \mathbf{X}\mathbf{W}_k,\quad \mathbf{V} = \mathbf{X}\mathbf{W}_v,\qquad \mathbf{W}_q, \mathbf{W}_k, \mathbf{W}_v \in \mathbb{R}^{d \times d},$$
which are the main components of the self-attention mechanism. Then, a single transformer block, which consists of self-attention, a residual connection, layer normalization, and point-wise feedforward layers, can be formulated as follows
$$\begin{aligned} \mathbf{A}_{s} &= \mathrm{softmax}\left(\mathbf{Q}\mathbf{K}^\top\right)\mathbf{V}\\ \mathbf{A}_o &= \mathbf{A}_{s}\mathbf{W}_o, \quad \mathbf{W}_o \in \mathbb{R}^{d\times d}\\ \mathbf{X}_A &= \mathrm{LayerNorm}(\mathbf{A}_o) + \mathbf{X}\\ \mathbf{X}_B &= \sigma(\mathbf{X}_A\mathbf{W}_1)\mathbf{W}_2, \end{aligned} \tag{1}$$
where the softmax is applied row-wise and $\sigma(\cdot)$ denotes the activation function of the FCN layer. Although skip connections, layer normalization, and the FCN also play a crucial role in a transformer block, the success of these networks has been mostly attributed to the self-attention part, denoted as $\mathbf{A}_o$ (Vaswani et al., 2017). Therefore, in the following section, we first study the training problem of a simplified transformer network whose output is directly $\mathbf{A}_o$. We then extend our derivations to a transformer network with FCN layers.
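To make the block in (1) concrete, the following NumPy sketch runs one single-head transformer block forward. It is an illustration under our own assumptions: $\sigma$ is taken to be ReLU, LayerNorm has no learnable gain or bias, the softmax is unscaled exactly as written in (1) (many implementations also divide by $\sqrt{d}$), and the FCN width `m` and all weights are hypothetical random choices.

```python
import numpy as np

def softmax_rows(U):
    """Row-wise softmax; subtracting each row's max keeps exp() numerically stable."""
    U = U - U.max(axis=1, keepdims=True)
    E = np.exp(U)
    return E / E.sum(axis=1, keepdims=True)

def layer_norm(A, eps=1e-5):
    """Normalize each token (row) to zero mean and unit variance (no learned gain/bias)."""
    mu = A.mean(axis=1, keepdims=True)
    var = A.var(axis=1, keepdims=True)
    return (A - mu) / np.sqrt(var + eps)

def transformer_block(X, Wq, Wk, Wv, Wo, W1, W2):
    """One single-head block following (1); sigma is assumed to be ReLU."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    A_s = softmax_rows(Q @ K.T) @ V          # each output row is a convex combination of value rows
    A_o = A_s @ Wo
    X_A = layer_norm(A_o) + X                # LayerNorm + skip connection, as written in (1)
    return np.maximum(X_A @ W1, 0.0) @ W2    # point-wise FCN

rng = np.random.default_rng(0)
n, d, m = 5, 8, 16                           # tokens, embedding dim, hypothetical FCN width
X = rng.standard_normal((n, d))
Wq, Wk, Wv, Wo = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(4))
W1 = rng.standard_normal((d, m)) / np.sqrt(d)
W2 = rng.standard_normal((m, d)) / np.sqrt(m)
X_B = transformer_block(X, Wq, Wk, Wv, Wo, W1, W2)
print(X_B.shape)  # (5, 8)
```

The row-stochastic matrix produced by `softmax_rows` is the object that the next section relaxes to the unit simplex.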

3. ATTENTION-ONLY NETWORKS

We first consider a simplified transformer network with only a self-attention layer, which maps an input sequence $\mathbf{X} \in \mathbb{R}^{n\times d}$ to an output sequence $\hat{\mathbf{Y}} \in \mathbb{R}^{n\times c}$ with $c$ outputs as follows
$$\hat{\mathbf{Y}} = \mathrm{softmax}\left(\mathbf{X}\mathbf{W}_q\mathbf{W}_k^\top\mathbf{X}^\top\right)\mathbf{X}\mathbf{W}_v\mathbf{W}_o. \tag{2}$$
We call the model (2) an attention-only network. This is a meaningful model and has been applied to various tasks, including machine translation, language modeling, image captioning, and object recognition (Vashishth et al., 2019). We next consider a standard regression framework with an arbitrary convex loss function. Given a training set $\{\mathbf{X}_i, \mathbf{Y}_i\}_{i=1}^N$, where $\mathbf{X}_i \in \mathbb{R}^{n\times d}$ and $\mathbf{Y}_i \in \mathbb{R}^{n\times c}$ denote the input sequences and the labels/target outputs, respectively, the weight decay regularized training problem for the attention-only network in (2) is as follows
$$\min_{\mathbf{W}_q,\mathbf{W}_k,\mathbf{W}_v,\mathbf{W}_o} \sum_{i=1}^N \mathcal{L}\left(\mathrm{softmax}\left(\mathbf{X}_i\mathbf{W}_q\mathbf{W}_k^\top\mathbf{X}_i^\top\right)\mathbf{X}_i\mathbf{W}_v\mathbf{W}_o, \mathbf{Y}_i\right) + \frac{\beta}{2}\sum_{\# \in \{q,k,v,o\}} \|\mathbf{W}_\#\|_F^2, \tag{3}$$
where $\mathcal{L}(\cdot)$ is an arbitrary convex loss function, including squared loss and cross entropy, and $\beta > 0$ is the regularization coefficient. Although the attention-only model in (2) is quite powerful across various NLP tasks, e.g., natural language inference, neural machine translation, and text classification (Vashishth et al., 2019), the corresponding training problem in (3) is an extremely challenging optimization task and requires various nonconvex optimization heuristics (Dong et al., 2021) to be trained adequately. To remedy these issues, in the following sections, we first reformulate the training problem by replacing the attention part with an alternative convex layer and then cast the reformulated training problem as an interpretable convex optimization problem that enables globally optimizing the network parameters.
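The objective in (3) can be written out directly. This minimal sketch assumes the squared loss as the convex $\mathcal{L}$ and uses random toy data; the function names are illustrative and not taken from the paper.

```python
import numpy as np

def softmax_rows(U):
    """Row-wise softmax with max-subtraction for numerical stability."""
    U = U - U.max(axis=1, keepdims=True)
    E = np.exp(U)
    return E / E.sum(axis=1, keepdims=True)

def attention_only(X, W):
    """Model (2): softmax(X Wq Wk^T X^T) X Wv Wo, an (n, c) output."""
    return softmax_rows(X @ W["q"] @ W["k"].T @ X.T) @ X @ W["v"] @ W["o"]

def objective_3(data, W, beta):
    """Objective (3), assuming the squared loss 0.5 * ||Y_hat - Y||_F^2 as the convex loss."""
    loss = sum(0.5 * np.sum((attention_only(X, W) - Y) ** 2) for X, Y in data)
    reg = 0.5 * beta * sum(np.sum(W[key] ** 2) for key in ("q", "k", "v", "o"))
    return loss + reg

rng = np.random.default_rng(0)
N, n, d, c = 8, 5, 4, 3
data = [(rng.standard_normal((n, d)), rng.standard_normal((n, c))) for _ in range(N)]
W = {key: rng.standard_normal((d, d)) for key in ("q", "k", "v")}
W["o"] = rng.standard_normal((d, c))
print(objective_3(data, W, beta=0.1))
```

Note that the softmax couples $\mathbf{W}_q$ and $\mathbf{W}_k$ nonlinearly with the data, which is the source of nonconvexity discussed next.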

3.1. CONVEX ATTENTION LAYER

We first note that since the $\mathrm{softmax}(\cdot)$ operation is highly nonlinear and nonconvex, the training problem in (3) is a challenging nonconvex optimization problem. As a result, attention networks may not be trained adequately and can end up as trivial models. For example, Dong et al. (2021) show that attention networks are likely to degenerate throughout training, with the output converging to a rank-1 matrix, so that they fail to learn the underlying task. To avoid the issues associated with the nonconvex formulation in (2), we first replace the softmax operation with a simpler yet effective alternative. In particular, since softmax converts each row of its input matrix into a probability distribution, it can be relaxed as a linear operation with unit simplex constraints as follows
$$\forall\, \mathbf{U} \in \mathbb{R}^{n\times n},\ \exists\, \mathbf{W} \in \Delta \ \text{ s.t. } \ \mathrm{softmax}(\mathbf{U})\mathbf{X} = \mathbf{W}\mathbf{X},$$
where $\Delta := \{\mathbf{W} \in \mathbb{R}^{n\times n} : \mathbf{w}_i \geq \mathbf{0},\ \mathbf{1}^\top\mathbf{w}_i = 1,\ \forall i \in [n]\}$, with $\mathbf{w}_i$ denoting the $i$th row of $\mathbf{W}$, is a convex set of constraints, also termed unit simplex constraints. Thus, we simplify and convexify the attention mechanism while keeping its structure as a convex combination of tokens; note that the attention matrix is now a trainable parameter constrained to $\Delta$ rather than a function of the input. Based on this observation, (3) can be reformulated as follows
$$\min_{\substack{\mathbf{W}_1\in\Delta\\ \mathbf{W}_2\in\mathbb{R}^{d\times d},\, \mathbf{W}_3\in\mathbb{R}^{d\times c}}} \sum_{i=1}^N \mathcal{L}\left(\mathbf{W}_1\mathbf{X}_i\mathbf{W}_2\mathbf{W}_3, \mathbf{Y}_i\right) + \frac{\beta}{2}\left(\|\mathbf{W}_2\|_F^2 + \|\mathbf{W}_3\|_F^2\right). \tag{4}$$
Note that the model above is a single-head attention model and may therefore lack sufficient expressive power in practice. Thus, we introduce the concept of heads into the problem in (4) as follows
$$\min_{\substack{\mathbf{W}_{1j}\in\Delta\\ \mathbf{W}_{2j}\in\mathbb{R}^{d\times d},\, \mathbf{W}_{3j}\in\mathbb{R}^{d\times c}}} \sum_{i=1}^N \mathcal{L}\left(\sum_{j=1}^h \mathbf{W}_{1j}\mathbf{X}_i\mathbf{W}_{2j}\mathbf{W}_{3j}, \mathbf{Y}_i\right) + \frac{\beta}{2}\sum_{j=1}^h\left(\|\mathbf{W}_{2j}\|_F^2 + \|\mathbf{W}_{3j}\|_F^2\right). \tag{5}$$
Now, we are ready to apply convex analytic tools to (5), as detailed in the next section.
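The relaxation can be checked numerically. The sketch below, with arbitrary dimensions and random weights, verifies that any row-wise softmax output lies in $\Delta$, that $\Delta$ also contains one-hot rows which softmax only approaches as the logits grow, and evaluates the multi-head model inside (5). Drawing the $\mathbf{W}_{1j}$ as softmax outputs is only a convenient way to obtain members of $\Delta$ for this illustration.

```python
import numpy as np

def softmax_rows(U):
    """Row-wise softmax with max-subtraction for numerical stability."""
    U = U - U.max(axis=1, keepdims=True)
    E = np.exp(U)
    return E / E.sum(axis=1, keepdims=True)

def in_simplex_rows(W, tol=1e-9):
    """Membership test for Delta: nonnegative entries and every row summing to one."""
    return bool(np.all(W >= -tol) and np.allclose(W.sum(axis=1), 1.0))

def multihead_output(X, W1s, W2s, W3s):
    """Model inside (5): sum over heads j of W1j X W2j W3j, with each W1j in Delta."""
    return sum(W1 @ X @ W2 @ W3 for W1, W2, W3 in zip(W1s, W2s, W3s))

rng = np.random.default_rng(0)
n, d, c, h = 4, 6, 3, 2
X = rng.standard_normal((n, d))
S = softmax_rows(rng.standard_normal((n, n)))          # any softmax output is a member of Delta
print(in_simplex_rows(S), in_simplex_rows(np.eye(n)))  # True True
# Softmax only approaches one-hot rows as the logits grow, so Delta is a slightly larger set.
print(np.abs(softmax_rows(50.0 * np.eye(n)) - np.eye(n)).max())
W1s = [softmax_rows(rng.standard_normal((n, n))) for _ in range(h)]
W2s = [rng.standard_normal((d, d)) for _ in range(h)]
W3s = [rng.standard_normal((d, c)) for _ in range(h)]
Y_hat = multihead_output(X, W1s, W2s, W3s)
print(Y_hat.shape)  # (4, 3)
```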

3.2. CONVEX OPTIMIZATION FOR ATTENTION-ONLY NETWORKS

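As a warm-up, let us consider the scalar output prediction problem where the targets are one-dimensional, i.e., $y_i \in \mathbb{R}$. Then, (5) reduces to the following optimization problem
$$\min_{\substack{\mathbf{w}_{1j}\in\Delta\\ \mathbf{w}_{2j}\in\mathbb{R}^{d},\, w_{3j}\in\mathbb{R}}} \sum_{i=1}^N \mathcal{L}\left(\sum_{j=1}^h \mathbf{w}_{1j}^\top\mathbf{X}_i\mathbf{w}_{2j}w_{3j}, y_i\right) + \frac{\beta}{2}\sum_{j=1}^h\left(\|\mathbf{w}_{2j}\|_2^2 + w_{3j}^2\right), \tag{6}$$
where, with a slight abuse of notation, $\Delta$ now denotes the unit simplex in $\mathbb{R}^n$. Next, we apply a rescaling between the parameters $\mathbf{w}_{2j}$ and $w_{3j}$ such that (6) can be described as an $\ell_1$-regularized optimization problem.

Table 2:
Setting | Standard attention: # params | Standard attention: complexity | Alternative attention: # params | Alternative attention: complexity | Convex: # params | Convex: complexity
Scalar output | h(3d^2 + d) | O(n^2 dh) | h(n + d + 1) | O(nd) | nd | O(nd)
Multi output | h(3d^2 + dc) | O(n^2 dh + ndhc) | h(n + d + c) | O(nd + c) | ndc | O(ndc)
Multi output with FCN | h(3d^2 + dc) | O(n^2 dh + ndhc) | h(n + d + c) | O(nd + c) | ndch | O(ndch)

Lemma 1. The problem in (6) is equivalent to the following $\ell_1$-regularized training problem
$$\min_{\substack{\mathbf{w}_{1j}\in\Delta\\ \|\mathbf{w}_{2j}\|_2\leq 1,\, w_{3j}\in\mathbb{R}}} \sum_{i=1}^N \mathcal{L}\left(\sum_{j=1}^h \mathbf{w}_{1j}^\top\mathbf{X}_i\mathbf{w}_{2j}w_{3j}, y_i\right) + \beta\|\mathbf{w}_3\|_1, \tag{7}$$
where $\mathbf{w}_3 := [w_{31}, \ldots, w_{3h}]^\top$. Based on the equivalent formulation in Lemma 1, the next theorem introduces a convex optimization problem that is equivalent to (6).

Theorem 1. The nonconvex optimization problem (6) can be equivalently cast as the following convex optimization problem
$$\min_{\mathbf{Z}\in\mathbb{R}^{n\times d}} \sum_{i=1}^N \mathcal{L}\left(\mathrm{trace}\left(\mathbf{Z}^\top\mathbf{X}_i\right), y_i\right) + \beta\sum_{k=1}^n \|\mathbf{z}_k\|_2. \tag{8}$$
Note that the equivalent convex model in (8) requires a single parameter matrix $\mathbf{Z} \in \mathbb{R}^{n\times d}$, where the $k$th row $\mathbf{z}_k$ plays the role of the attention scores of the $k$th token, since $\mathrm{trace}(\mathbf{Z}^\top\mathbf{X}_i) = \sum_{k=1}^n \mathbf{z}_k^\top\mathbf{x}_{ik}$ with $\mathbf{x}_{ik}^\top$ the $k$th row (token) of $\mathbf{X}_i$. We also remark that the regularization in (8), i.e., the sum of the $\ell_2$ norms of the rows of the parameter matrix $\mathbf{Z}$, is a specific type of regularization known as group $\ell_1$ or group Lasso, introduced by Bakin et al. (1999) and shown to promote group sparsity across parameters (Yuan & Lin, 2006). In our case, the group sparsity is across the token index $k$. Therefore, one can interpret the model in (8) as a sparse linear model, where the sparsity is across tokens. In other words, (8) can be explained as a model that tries to use as few tokens as possible to fit the training labels $\{y_i\}_{i=1}^N$.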
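The token sparsity induced by (8) can be observed on synthetic data. The sketch below assumes the squared loss and solves (8) with proximal gradient descent (ISTA), a standard group Lasso solver that we swap in here for illustration; it is not the training procedure used in the paper, which mentions SGD and Adam. The hypothetical teacher uses only tokens 1 and 4, and $\beta$, the sizes, the noise level, and the iteration count are arbitrary choices.

```python
import numpy as np

def group_soft_threshold(Z, tau):
    """Prox of tau * sum_k ||z_k||_2: shrink each row, setting rows with norm <= tau to zero."""
    norms = np.linalg.norm(Z, axis=1, keepdims=True)
    return np.maximum(0.0, 1.0 - tau / np.maximum(norms, 1e-12)) * Z

def solve_convex_attention(Xs, y, beta, iters=2000):
    """ISTA (swapped in for illustration) on (8) with squared loss 0.5 * sum_i (trace(Z^T X_i) - y_i)^2."""
    N, n, d = Xs.shape
    A = Xs.reshape(N, n * d)                  # trace(Z^T X_i) = A[i] @ Z.ravel()
    step = 1.0 / np.linalg.norm(A, 2) ** 2    # 1 / Lipschitz constant of the smooth part
    Z = np.zeros((n, d))
    for _ in range(iters):
        grad = (A.T @ (A @ Z.ravel() - y)).reshape(n, d)
        Z = group_soft_threshold(Z - step * grad, step * beta)
    return Z

rng = np.random.default_rng(0)
N, n, d = 200, 6, 4
Z_true = np.zeros((n, d))
Z_true[[1, 4]] = rng.standard_normal((2, d))   # hypothetical teacher: only tokens 1 and 4 matter
Xs = rng.standard_normal((N, n, d))
y = np.einsum("kd,ikd->i", Z_true, Xs) + 0.01 * rng.standard_normal(N)
Z_hat = solve_convex_attention(Xs, y, beta=2.0)
print(np.flatnonzero(np.linalg.norm(Z_hat, axis=1)))   # indices of the tokens the model keeps
```

Increasing $\beta$ zeros out more rows, i.e., the model uses fewer tokens, in line with the group-sparsity interpretation above.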
Unlike the nonnegative attention scores in (6), denoted as $\mathbf{w}_{1j} \in \Delta$, the convex parameters $\mathbf{Z} \in \mathbb{R}^{n\times d}$ do not require any constraints. Therefore, one can directly apply standard training algorithms, such as SGD and Adam, to train the convex problem (8). Moreover, an optimal set of parameters for (6) can be recovered from a solution to (8), as shown in the following result.

Proposition 1. After solving the convex optimization problem in (8), one can recover an optimal solution $\{\mathbf{w}^*_{1j}, \mathbf{w}^*_{2j}, w^*_{3j}\}_{j=1}^h$ to the equivalent nonconvex problem (7) as follows
$$\mathbf{w}^*_{1j} = \mathbf{e}_j,\quad \mathbf{w}^*_{2j} = \frac{\mathbf{z}_j}{\|\mathbf{z}_j\|_2},\quad w^*_{3j} = \|\mathbf{z}_j\|_2,\quad \forall j \in [h],$$
which yields an optimal solution to (6) after the rescaling used in the proof of Lemma 1. Here, $\mathbf{e}_j \in \mathbb{R}^n$ is the $j$th standard basis vector, $\mathbf{z}_j \in \mathbb{R}^d$ is the $j$th row of $\mathbf{Z}$, and we assume, after relabeling the tokens if necessary, that the first $h$ of the $n$ rows of $\mathbf{Z}$ are the nonzero ones, since the regularization in (8) induces row sparsity.

Proposition 1 provides an explicit mapping from a solution of the convex formulation (8) to an optimal solution of the nonconvex formulation (6). Therefore, there is no need to solve the challenging nonconvex optimization problem (6), which also requires several optimization heuristics to be trained adequately. Instead, one can solve the convex problem (8) and then use the mapping in Proposition 1 to obtain an optimal solution to (6).
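The recovery in Proposition 1 can be checked numerically. The sketch below uses an arbitrary row-sparse $\mathbf{Z}$ (a toy example, not a trained solution), reads off one head per nonzero row, and confirms that the heads reproduce $\mathrm{trace}(\mathbf{Z}^\top\mathbf{X})$ and that the penalty $\beta\sum_j |w_{3j}|$ of (7) equals the group penalty of (8). It also applies the balanced rescaling $\alpha_j = \sqrt{|w_{3j}|/\|\mathbf{w}_{2j}\|_2}$ from the proof of Lemma 1 to check the weight-decay penalty of (6).

```python
import numpy as np

def recover_heads(Z, tol=1e-10):
    """Proposition 1: one head (e_k, z_k / ||z_k||, ||z_k||) per nonzero row k of Z."""
    n, _ = Z.shape
    heads = []
    for k in range(n):
        norm_k = np.linalg.norm(Z[k])
        if norm_k > tol:
            e_k = np.zeros(n)
            e_k[k] = 1.0                     # a vertex of the unit simplex
            heads.append((e_k, Z[k] / norm_k, norm_k))
    return heads

def heads_predict(heads, X):
    """Nonconvex model: sum_j (w1j^T X w2j) * w3j for one sample X of shape (n, d)."""
    return sum((w1 @ X @ w2) * w3 for w1, w2, w3 in heads)

rng = np.random.default_rng(1)
n, d, beta = 6, 4, 0.5
Z = np.zeros((n, d))
Z[[1, 4]] = rng.standard_normal((2, d))     # row-sparse toy example
X = rng.standard_normal((n, d))
heads = recover_heads(Z)

l1_penalty = beta * sum(abs(w3) for _, _, w3 in heads)           # penalty of (7)
group_penalty = beta * np.linalg.norm(Z, axis=1).sum()           # penalty of (8)
# Balanced rescaling from Lemma 1: alpha = sqrt(|w3| / ||w2||) = sqrt(w3) because ||w2|| = 1.
wd_penalty = 0.5 * beta * sum(
    np.sum((np.sqrt(w3) * w2) ** 2) + (w3 / np.sqrt(w3)) ** 2 for _, w2, w3 in heads
)                                                                # weight-decay penalty of (6)
print(np.isclose(heads_predict(heads, X), np.trace(Z.T @ X)))   # True
```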

3.3. EXTENSION TO MULTIDIMENSIONAL OUTPUTS

In the previous section, we considered a setting with scalar target variables, i.e., $y_i \in \mathbb{R}$. However, for some problems, e.g., multiclass classification, target variables can be multidimensional. Therefore, we now extend the analysis to problems with multiple/vector outputs as follows
$$\min_{\substack{\mathbf{w}_{1j}\in\Delta\\ \mathbf{w}_{2j}\in\mathbb{R}^{d},\, \mathbf{w}_{3j}\in\mathbb{R}^c}} \sum_{i=1}^N \mathcal{L}\left(\sum_{j=1}^h \mathbf{w}_{1j}^\top\mathbf{X}_i\mathbf{w}_{2j}\mathbf{w}_{3j}, \mathbf{y}_i\right) + \frac{\beta}{2}\sum_{j=1}^h\left(\|\mathbf{w}_{2j}\|_2^2 + \|\mathbf{w}_{3j}\|_1^2\right), \tag{9}$$
where $\mathbf{y}_i \in \mathbb{R}^c$ and $c$ denotes the number of outputs/classes. Note that here we put the squared $\ell_1$-norm on $\mathbf{w}_{3j}$ to enable our convex arguments, but this does not impact the performance of the network in practice. Then, following the same derivations yields the convex program in the next result.

Theorem 2. The nonconvex optimization problem (9) is equivalent to the following convex optimization problem
$$\min_{\mathbf{Z}_l\in\mathbb{R}^{n\times d}} \sum_{i=1}^N\sum_{l=1}^c \mathcal{L}\left(\mathrm{trace}\left(\mathbf{Z}_l^\top\mathbf{X}_i\right), y_{il}\right) + \beta\sum_{l=1}^c\sum_{k=1}^n \|\mathbf{z}_{lk}\|_2. \tag{10}$$
Theorem 2 shows that the equivalent convex model is separable over the output index $l$, i.e., instead of the single parameter matrix in (8), here we have $c$ parameter matrices due to the $c$ outputs in the nonconvex model (9) (see Table 2 for details). This also illustrates that the number of outputs directly controls the overparameterization level of the equivalent convex formulation.
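The separability over $l$ in Theorem 2 can be seen directly from the objective. This sketch assumes the squared loss and arbitrary random data, evaluates (10), and checks that it equals the sum of $c$ scalar-output objectives of the form (8), one per output column.

```python
import numpy as np

def objective_8(Z, Xs, y, beta):
    """Scalar-output objective (8) with squared loss; Z: (n, d), Xs: (N, n, d), y: (N,)."""
    preds = np.einsum("kd,ikd->i", Z, Xs)                  # trace(Z^T X_i)
    return 0.5 * np.sum((preds - y) ** 2) + beta * np.linalg.norm(Z, axis=1).sum()

def objective_10(Zs, Xs, Y, beta):
    """Multi-output objective (10); Zs: (c, n, d) holds one Z_l per output, Y: (N, c)."""
    preds = np.einsum("lkd,ikd->il", Zs, Xs)               # trace(Z_l^T X_i)
    reg = beta * np.linalg.norm(Zs, axis=2).sum()          # sum over l and k of ||z_lk||_2
    return 0.5 * np.sum((preds - Y) ** 2) + reg

rng = np.random.default_rng(0)
N, n, d, c, beta = 10, 5, 4, 3, 0.3
Xs = rng.standard_normal((N, n, d))
Y = rng.standard_normal((N, c))
Zs = rng.standard_normal((c, n, d))
total = objective_10(Zs, Xs, Y, beta)
per_output = sum(objective_8(Zs[l], Xs, Y[:, l], beta) for l in range(c))
print(np.isclose(total, per_output))  # True
```

Because the objective splits this way, each $\mathbf{Z}_l$ could be fit independently (for example, in parallel) with any solver for (8).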

3.4. ATTENTION NETWORKS WITH FCN LAYERS

Although the model in (5) exhibits interesting properties in various applications (Dong et al., 2021), it is a linear function of the token matrix $\mathbf{X}$. Therefore, it is likely to perform inadequately, especially on challenging NLP problems. A series of papers (Dong et al., 2021; Geva et al., 2021; Meng et al., 2022; Geva et al., 2022b;a) also confirmed the importance of FCNs via extensive empirical evidence. Therefore, in this section, we add an FCN layer to our attention-only model in (5) and derive an equivalent convex formulation for this new model. Here, we consider the following optimization problem
$$\min_{\substack{\mathbf{w}_{1j}\in\Delta\\ \mathbf{w}_{2j}\in\mathbb{R}^d,\, \mathbf{w}_{3j}\in\mathbb{R}^c}} \sum_{i=1}^N \mathcal{L}\left(\sum_{j=1}^h \sigma\left(\mathbf{w}_{1j}^\top\mathbf{X}_i\mathbf{w}_{2j}\right)\mathbf{w}_{3j}, \mathbf{y}_i\right) + \frac{\beta}{2}\sum_{j=1}^h\left(\|\mathbf{w}_{2j}\|_2^2 + \|\mathbf{w}_{3j}\|_1^2\right), \tag{11}$$
where $\sigma(\cdot)$ is the activation function.

Theorem 3. The nonconvex optimization problem (11) with the gated ReLU activation is equivalent to the following convex optimization problem
$$\min_{\mathbf{Z}_{jl}\in\mathbb{R}^{n\times d}} \sum_{i=1}^N\sum_{l=1}^c \mathcal{L}\left(\sum_{j=1}^h \mathbb{1}_{ij}\,\mathrm{trace}\left(\mathbf{Z}_{jl}^\top\mathbf{X}_i\right), y_{il}\right) + \beta\sum_{l=1}^c\sum_{j=1}^h\sum_{k=1}^n \|\mathbf{z}_{jlk}\|_2, \tag{12}$$
where $\mathbb{1}_{ij} := \mathbb{1}\left[\mathbf{u}_{1j}^\top\mathbf{X}_i\mathbf{u}_{2j} \geq 0\right]$ denotes the indicator function of the gated ReLU activation, i.e., the gate of head $j$ is determined by fixed vectors $\{\mathbf{u}_{1j}, \mathbf{u}_{2j}\}_{j=1}^h$ that can be randomly selected, rather than by the trainable weights. Theorem 3 implies that introducing the activation function further increases the overparameterization level of the equivalent convex formulation. Precisely, (12) has $h$ times more parameters than (10), as shown in Table 2.
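Since the gates in (12) are fixed before training, the model is linear in the $\mathbf{Z}_{jl}$. The sketch below draws $\mathbf{u}_{1j}, \mathbf{u}_{2j}$ at random (one choice the text permits), computes the gates $\mathbb{1}_{ij}$, and evaluates the prediction of (12) for every sample and output; the sizes and random $\mathbf{Z}_{jl}$ are arbitrary, and the parameter count matches the $ndch$ entry of Table 2.

```python
import numpy as np

rng = np.random.default_rng(0)
N, n, d, h, c = 10, 5, 4, 3, 2
Xs = rng.standard_normal((N, n, d))
U1 = rng.standard_normal((h, n))          # fixed gate vectors u_1j, one row per head
U2 = rng.standard_normal((h, d))          # fixed gate vectors u_2j
# 1_ij = 1[u_1j^T X_i u_2j >= 0], shape (N, h); these stay fixed during training.
gates = (np.einsum("jk,ikd,jd->ij", U1, Xs, U2) >= 0).astype(float)

Zs = rng.standard_normal((h, c, n, d))    # trainable Z_jl: one (n, d) matrix per head and output
traces = np.einsum("jlkd,ikd->ijl", Zs, Xs)        # trace(Z_jl^T X_i)
preds = np.einsum("ij,ijl->il", gates, traces)     # sum_j 1_ij trace(Z_jl^T X_i), shape (N, c)
reg = np.linalg.norm(Zs, axis=3).sum()             # sum over l, j, k of ||z_jlk||_2
print(preds.shape, Zs.size)  # (10, 2) 120
```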

4. NUMERICAL EXPERIMENTS

In this section, we present experimental results corroborating our theory in the previous sections.

Student-teacher setting with BERT: We first consider a student-teacher setting with the pretrained BERT model from the Hugging Face repository, i.e., bert-base-uncased. In particular, we feed the samples from the mrpc subset of the glue dataset (Warstadt et al., 2018; Wang et al., 2019) through the pretrained BERT model and save the input and output activations of a certain layer. Then, we train the attention-only models, i.e., standard nonconvex self-attention (3), alternative nonconvex attention (9), and convex (10), from scratch using these pre- and post-activations as our training dataset. All the experiments in this section are performed using a single GPU on Google Colab. We use the same optimizer, i.e., Adam, and the same regularization coefficient β for both approaches, where the learning rate and β are tuned by a grid search on a validation dataset. Note that we do not use any nonconvex optimization heuristics, e.g., layer normalization and skip connections, for the convex model in any of the experiments. In Figure 2, we plot the objective values (i.e., training loss + regularization term) and test losses with respect to time in seconds, using the data extracted from the sixth layer of the pretrained BERT model. We observe that our convex training approach achieves an objective value almost an order of magnitude smaller than that of standard nonconvex training, which is possibly stuck at a local minimum. This effectiveness in training also translates into better generalization, i.e., our convex training approach obtains a lower test loss than standard nonconvex training. To understand the functions learned by each model, we also analyze the attention maps in Figure 3. Here, standard nonconvex training fails to learn the underlying model and outputs a uniform attention map across tokens.
However, our convex training outputs an attention map that is quite similar to the ground truth attention map, and therefore successfully learns the structure in the training data. Hence, these experiments illustrate the effectiveness of our convex training approach in both training and testing.

Algorithmic datasets and Grokking: Inspired by the grokking phenomenon observed in Power et al. (2022), we next validate the effectiveness of our convex training approach against standard transformer networks with the self-attention mechanism in (1) on algorithmic datasets. In particular, we use the same setting as Power et al. (2022) and evaluate the performance on modular division operations with mod 97 and mod 15, where we train the architectures until they reach 99% test accuracy whenever possible. In Figure 4, we first replicate the results in Power et al. (2022) and confirm that the grokking phenomenon indeed emerges here, i.e., the nonconvex curve (purple) reaches 100% training accuracy at around $10^3$ iterations in Figure 4a, while it requires more than $10^5$ iterations to reach perfect generalization in Figure 4b. We also compare the nonconvex and convex training approaches. Here, our convex training approach reaches perfect generalization accuracy 10× faster than the nonconvex one in Figure 4b. Moreover, the convex model also yields a significantly lower test loss in Figure 4c, which implies that it has higher confidence in its test predictions and is therefore more robust than standard nonconvex training.

Figure 5: …, where L denotes the number of layers in each model. We observe that introducing one more layer substantially improves the convergence speed of our convex formulation, while it fails to make a noticeable impact on the nonconvex formulation.

We remark that in the previous sections, we theoretically analyzed only single attention/transformer blocks.
However, since the benefit of depth, i.e., of the number of layers (denoted as L), has already been empirically demonstrated in the deep learning literature, we also propose an extension of our convex model to deeper settings. We simply stack the convex transformer layers in (12) to obtain an arbitrarily deep network. In Figure 5, we compare the performance of two-layer transformer networks with one-layer networks. Here, we observe that while adding one more layer results in significant improvements for the convex model, especially in terms of optimization speed, it fails to make any discernible difference for the nonconvex model. Moreover, we run the algorithms on the mod 15 operation, which is a more challenging task due to the smaller number of samples. In this case, one-layer models are not able to learn the underlying task perfectly, as demonstrated in Figure 6, but our convex model is significantly better in terms of both test accuracy and test loss. By increasing the number of layers to four, we enable both models to achieve perfect generalization accuracy. Our four-layer convex model reaches this level much faster and also yields a lower test loss than the nonconvex model. We next empirically analyze the grokking phenomenon on both our convex and standard nonconvex models. For this purpose, we plot the number of iterations to reach 99% test accuracy for each of our experiments in Figure 7a. Notice that here we do not include the one-layer results for the mod 15 case, since both models fail to achieve perfect generalization in that case. Figure 7a shows that our convex training approach reaches the 99% accuracy level substantially faster than standard nonconvex training. Therefore, we also mitigate the impact of the grokking phenomenon, as demonstrated in Figure 7b, where we quantify the amount of grokking in terms of the number of iterations.
Based on this experiment, we also conjecture that the grokking phenomenon can be mostly attributed to the highly nonlinear and nonconvex structure of standard transformer models.

Figure 6: … Here, one-layer networks fail to achieve 99% test accuracy; however, our convex training approach (light green) still generalizes better than nonconvex training (light purple). We also show that perfect generalization in terms of accuracy can be achieved with four-layer networks.

Figure 7: Amount of grokking in terms of the number of iterations required by our convex and standard nonconvex training approaches, where p and L denote the modulus of the modular division and the number of layers, respectively. Here, we do not include the one-layer results in Figure 6, since both algorithms fail to achieve 99% test accuracy. We demonstrate that the impact of grokking is substantially mitigated with our convex training approach.
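For reference, the sketch below enumerates a modular-division dataset of the kind used by Power et al. (2022): all equations $a \div b = c \pmod p$ with $c = a\,b^{-1} \bmod p$. We assume, as our own choice, that $b$ ranges over residues invertible modulo $p$; for the non-prime $p = 15$ this keeps only the 8 residues coprime to 15, which is consistent with the smaller number of samples noted above. The train/test split is not shown.

```python
from math import gcd

def modular_division_dataset(p):
    """All (a, b, c) with c = a * b^{-1} mod p; b restricted to invertible residues (our assumption)."""
    data = []
    for a in range(p):
        for b in range(1, p):
            if gcd(b, p) == 1:                 # b has a modular inverse only if gcd(b, p) = 1
                c = (a * pow(b, -1, p)) % p    # pow with exponent -1 requires Python 3.8+
                data.append((a, b, c))
    return data

for p in (97, 15):
    print(p, len(modular_division_dataset(p)))
# 97 9312
# 15 120
```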

5. CONCLUSION

In this paper, we studied the regularized training problem of attention/transformer networks and developed a convex analytic framework to train these networks. In particular, we first proposed a convex alternative to the self-attention mechanism and then reformulated the training problem with this alternative attention mechanism as convex optimization problems. Thanks to our convex reformulation, we globally optimize the network parameters without requiring any nonconvex optimization heuristics. In addition, the functions learned by our reformulation are transparent and interpretable. More importantly, the reformulated problem reveals a sparsity-inducing regularization mechanism across tokens in the data, which also sheds more light on the structure of the resulting function and its generalization properties. We then empirically verified the effectiveness of our convex training approach over standard nonconvex training via several numerical experiments. We also note that analyzing transformer networks through the lens of convex optimization theory is important, since it may result in substantial improvements in the understanding and optimization of these networks. However, it is also quite challenging due to the inherent nonconvex structure of the network model. To the best of our knowledge, this paper is the first step in this direction and therefore has some limitations, which can hopefully be eliminated by future work. Specifically, in this paper, we mainly focused on the theory side of convex analysis and empirically validated the theory on a few small-scale problem instances. We hope that a comprehensive and large-scale empirical verification of our theory will be conducted in follow-up work.

A.1 PROOF OF LEMMA 1

We restate (6) for the convenience of the reader:
$$\min_{\substack{\mathbf{w}_{1j}\in\Delta\\ \mathbf{w}_{2j}\in\mathbb{R}^{d},\, w_{3j}\in\mathbb{R}}} \sum_{i=1}^N \mathcal{L}\left(\sum_{j=1}^h \mathbf{w}_{1j}^\top\mathbf{X}_i\mathbf{w}_{2j}w_{3j}, y_i\right) + \frac{\beta}{2}\sum_{j=1}^h\left(\|\mathbf{w}_{2j}\|_2^2 + w_{3j}^2\right). \tag{13}$$
We first apply the following scaling for $\{\mathbf{w}_{2j}, w_{3j}\}_{j=1}^h$
$$\tilde{\mathbf{w}}_{2j} := \alpha_j\mathbf{w}_{2j},\quad \tilde{w}_{3j} := \frac{w_{3j}}{\alpha_j}, \tag{14}$$
where $\alpha_j > 0$.
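Since this scaling does not change the output of the network, i.e., $\sum_{j=1}^h \mathbf{w}_{1j}^\top\mathbf{X}_i\tilde{\mathbf{w}}_{2j}\tilde{w}_{3j} = \sum_{j=1}^h \mathbf{w}_{1j}^\top\mathbf{X}_i\mathbf{w}_{2j}w_{3j}$, the training loss part of the objective function stays the same. Thus, we can directly search for the optimal scaling parameter $\alpha_j > 0$ by minimizing the regularization term via the following AM-GM inequality
$$\sum_{j=1}^h\left(\|\tilde{\mathbf{w}}_{2j}\|_2^2 + \tilde{w}_{3j}^2\right) = \sum_{j=1}^h\left(\alpha_j^2\|\mathbf{w}_{2j}\|_2^2 + \frac{w_{3j}^2}{\alpha_j^2}\right) \geq 2\sum_{j=1}^h \|\mathbf{w}_{2j}\|_2|w_{3j}| = 2\sum_{j=1}^h \|\tilde{\mathbf{w}}_{2j}\|_2|\tilde{w}_{3j}|,$$
where equality is achieved when $\alpha_j = \sqrt{|w_{3j}|/\|\mathbf{w}_{2j}\|_2}$. Thus, we obtain a reformulation of (13) where the regularization term is in a multiplicative form as follows
$$\min_{\substack{\mathbf{w}_{1j}\in\Delta\\ \mathbf{w}_{2j}\in\mathbb{R}^{d},\, w_{3j}\in\mathbb{R}}} \sum_{i=1}^N \mathcal{L}\left(\sum_{j=1}^h \mathbf{w}_{1j}^\top\mathbf{X}_i\mathbf{w}_{2j}w_{3j}, y_i\right) + \beta\sum_{j=1}^h \|\mathbf{w}_{2j}\|_2|w_{3j}|. \tag{15}$$
Next, we apply the following variable change to the reformulation in (15)
$$\mathbf{w}'_{2j} := \frac{\mathbf{w}_{2j}}{\|\mathbf{w}_{2j}\|_2},\quad w'_{3j} := w_{3j}\|\mathbf{w}_{2j}\|_2,$$
which leaves the product $\mathbf{w}_{2j}w_{3j} = \mathbf{w}'_{2j}w'_{3j}$ unchanged. With this variable change, we rewrite (15) as
$$\min_{\substack{\mathbf{w}_{1j}\in\Delta\\ \|\mathbf{w}'_{2j}\|_2 = 1,\, w'_{3j}\in\mathbb{R}}} \sum_{i=1}^N \mathcal{L}\left(\sum_{j=1}^h \mathbf{w}_{1j}^\top\mathbf{X}_i\mathbf{w}'_{2j}w'_{3j}, y_i\right) + \beta\sum_{j=1}^h |w'_{3j}|. \tag{16}$$
Writing the penalty compactly with $\mathbf{w}'_3 := [w'_{31}, \ldots, w'_{3h}]^\top$, (16) is equivalently
$$\min_{\substack{\mathbf{w}_{1j}\in\Delta\\ \|\mathbf{w}'_{2j}\|_2 = 1,\, w'_{3j}\in\mathbb{R}}} \sum_{i=1}^N \mathcal{L}\left(\sum_{j=1}^h \mathbf{w}_{1j}^\top\mathbf{X}_i\mathbf{w}'_{2j}w'_{3j}, y_i\right) + \beta\|\mathbf{w}'_3\|_1.$$
Finally, the equality constraint $\|\mathbf{w}'_{2j}\|_2 = 1$ can be relaxed to $\|\mathbf{w}'_{2j}\|_2 \leq 1$ due to the optimality conditions arising from the regularization term $\|\mathbf{w}'_3\|_1$, which yields (7) and concludes the proof.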
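The rescaling argument can be checked numerically for a single head. The sketch below, with an arbitrary vector $\mathbf{w}_{2j}$ and scalar $w_{3j}$, evaluates the per-head penalty after the rescaling on a grid of $\alpha$, compares it with the AM-GM bound $2\|\mathbf{w}_{2j}\|_2|w_{3j}|$, and confirms that $\alpha_j = \sqrt{|w_{3j}|/\|\mathbf{w}_{2j}\|_2}$ attains the bound while leaving the product $\mathbf{w}_{2j}w_{3j}$ unchanged.

```python
import numpy as np

rng = np.random.default_rng(0)
w2 = rng.standard_normal(5)                 # arbitrary example values for one head
w3 = -1.7
a, b = np.linalg.norm(w2), abs(w3)

def penalty(alpha):
    """||alpha * w2||_2^2 + (w3 / alpha)^2, the per-head regularizer after the rescaling."""
    return np.sum((alpha * w2) ** 2) + (w3 / alpha) ** 2

alpha_star = np.sqrt(b / a)                 # equality case of AM-GM
grid = np.linspace(0.05, 5.0, 2000)
print(penalty(alpha_star), 2 * a * b)       # equal up to rounding
print(min(penalty(al) for al in grid) >= 2 * a * b - 1e-9)   # True: the bound holds on the grid
```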

A.2 PROOF OF PROPOSITION 1

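We first note that, in order to maintain strong duality in our convex problem derivations, we use the arguments in Pilanci & Ergen (2020); Rosset et al. (2007b), where the authors proved that as long as $h$ exceeds a certain threshold $h^*$, the solution is sparse due to the sparsity-inducing regularization in (7), with the upper bound $h^* \leq N + 1$. Note that this $N + 1$ upper bound is the worst-case scenario and $h^* \ll N + 1$ in practice, as validated in Pilanci & Ergen (2020). Thus, below, we assume that there is a sparsity pattern in the solution. Given an optimal solution to (8), denoted as $\mathbf{Z}^* \in \mathbb{R}^{n\times d}$, we first rewrite this solution as a summation of rank-1 matrices as follows
$$\mathbf{Z}^* = \sum_{j=1}^h \mathbf{e}_j\mathbf{z}_j^\top = \sum_{j=1}^h \mathbf{e}_j\frac{\mathbf{z}_j^\top}{\|\mathbf{z}_j\|_2}\|\mathbf{z}_j\|_2,$$
where $\mathbf{e}_j \in \mathbb{R}^n$ is the $j$th standard basis vector and we assume that $h$ of the $n$ rows of $\mathbf{Z}^*$ are nonzero due to the sparsity-inducing regularization in (8). This implies that the output of the optimal model can be equivalently written as
$$\mathrm{trace}\left(\mathbf{Z}^{*\top}\mathbf{X}_i\right) = \sum_{j=1}^h \mathbf{e}_j^\top\mathbf{X}_i\frac{\mathbf{z}_j}{\|\mathbf{z}_j\|_2}\|\mathbf{z}_j\|_2 = \sum_{j=1}^h \mathbf{w}^{*\top}_{1j}\mathbf{X}_i\mathbf{w}^*_{2j}w^*_{3j} \implies \mathbf{w}^*_{1j} = \mathbf{e}_j,\ \mathbf{w}^*_{2j} = \frac{\mathbf{z}_j}{\|\mathbf{z}_j\|_2},\ w^*_{3j} = \|\mathbf{z}_j\|_2.$$
Next, we show that $\{\mathbf{w}^*_{1j}, \mathbf{w}^*_{2j}, w^*_{3j}\}_{j=1}^h$, which is feasible for (7) since $\|\mathbf{w}^*_{2j}\|_2 = 1$, achieves in (7) the same objective value that $\mathbf{Z}^*$ achieves in (8):
$$\begin{aligned} f\left(\{\mathbf{X}_i, y_i\}_{i=1}^N\right) &:= \sum_{i=1}^N \mathcal{L}\left(\sum_{j=1}^h \mathbf{w}^{*\top}_{1j}\mathbf{X}_i\mathbf{w}^*_{2j}w^*_{3j}, y_i\right) + \beta\sum_{j=1}^h |w^*_{3j}|\\ &= \sum_{i=1}^N \mathcal{L}\left(\sum_{j=1}^h \mathbf{e}_j^\top\mathbf{X}_i\frac{\mathbf{z}_j}{\|\mathbf{z}_j\|_2}\|\mathbf{z}_j\|_2, y_i\right) + \beta\sum_{j=1}^h \|\mathbf{z}_j\|_2\\ &= \sum_{i=1}^N \mathcal{L}\left(\sum_{j=1}^h \mathrm{trace}\left(\mathbf{z}_j\mathbf{e}_j^\top\mathbf{X}_i\right), y_i\right) + \beta\sum_{j=1}^h \|\mathbf{z}_j\|_2\\ &= \sum_{i=1}^N \mathcal{L}\left(\mathrm{trace}\left(\mathbf{Z}^{*\top}\mathbf{X}_i\right), y_i\right) + \beta\sum_{k=1}^n \|\mathbf{z}_k\|_2, \end{aligned}\tag{17}$$
where the last equality follows from the fact that only $h$ of the $n$ rows of $\mathbf{Z}^*$ are nonzero. Note that the last line of (17) is the objective of (8) evaluated at $\mathbf{Z}^*$. Finally, since (7) is equivalent to (6) by Lemma 1, applying the rescaling from the proof of Lemma 1 with $\alpha_j = \sqrt{\|\mathbf{z}_j\|_2}$ gives $\tilde{\mathbf{w}}_{2j} = \mathbf{z}_j/\sqrt{\|\mathbf{z}_j\|_2}$ and $\tilde{w}_{3j} = \sqrt{\|\mathbf{z}_j\|_2}$, for which $\frac{\beta}{2}\sum_{j=1}^h\left(\|\tilde{\mathbf{w}}_{2j}\|_2^2 + \tilde{w}_{3j}^2\right) = \beta\sum_{j=1}^h\|\mathbf{z}_j\|_2$; hence the same objective value is attained in (6), which concludes the proof.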
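Extension to multidimensional outputs in Section 3.3: Here we show that the proof above can be straightforwardly extended to the multidimensional output case in Section 3.3. Given an optimal solution to (10), denoted as $\mathbf{Z}^*_l \in \mathbb{R}^{n\times d}$, we first rewrite this solution as a summation of rank-1 matrices as follows
$$\mathbf{Z}^*_l = \sum_{k=1}^h \mathbf{e}_{lk}\mathbf{z}_{lk}^\top = \sum_{k=1}^h \mathbf{e}_{lk}\frac{\mathbf{z}_{lk}^\top}{\|\mathbf{z}_{lk}\|_2}\|\mathbf{z}_{lk}\|_2,$$
where $\mathbf{e}_{lk} \in \mathbb{R}^n$ is the standard basis vector selecting the $k$th nonzero row $\mathbf{z}_{lk}$ of $\mathbf{Z}^*_l$. This implies that the output of the optimal model can be equivalently written as
$$\mathrm{trace}\left(\mathbf{Z}^{*\top}_l\mathbf{X}_i\right) = \sum_{k=1}^h \mathbf{e}_{lk}^\top\mathbf{X}_i\frac{\mathbf{z}_{lk}}{\|\mathbf{z}_{lk}\|_2}\|\mathbf{z}_{lk}\|_2 \implies \mathbf{w}^*_{1j} = \mathbf{e}_{lk},\ \mathbf{w}^*_{2j} = \frac{\mathbf{z}_{lk}}{\|\mathbf{z}_{lk}\|_2},\ \mathbf{w}^*_{3j} = \mathbf{e}_l\|\mathbf{z}_{lk}\|_2,$$
where $\mathbf{e}_l \in \mathbb{R}^c$ is the $l$th standard basis vector and $\{\mathbf{w}^*_{1j}, \mathbf{w}^*_{2j}, \mathbf{w}^*_{3j}\}_{j=1}^{hc}$ denotes an optimal solution to (9), up to the rescaling in the proof of Lemma 1. Note that here the index $j \in [hc]$ instead of $j \in [h]$ as in the scalar output case.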

A.3 PROOF OF THEOREM 1

We first provide a summary of our proof strategy. For the derivations of the convex formulation, we basically need to find the bidual form of (6), i.e., the dual of the dual problem. Thus, we start with taking the dual of (6). To avoid nonconvexity in the dual problem, we reformulate the dual constraint, which makes the problem nonconvex, as a convex constraint. Therefore, we obtain a convex dual problem. Then, we take the dual of the dual problem to get the bidual formulation of ( 6). Since we convexify the dual problem, the bidual formulation is also a convex problem. Therefore, we achieve an equivalent convex formulation of the original nonconvex training problem (6). We also note that a similar proof strategy was also used in Pilanci & Ergen (2020) . In order to take the dual of (7) (i.e. restated below for the convenience of the reader) we need to form the Lagrangian function for the following optimization problem min w1j ∈∆ w2j 2 ≤1,w3j ∈R N i=1 L   h j=1 w 1j X i w 2j w 3j , y i   + β w 3 1 . To construct the Lagrangian function, we first introduce an additional variable ŷ ∈ R N as follows min ŷ∈R N ,w1j ∈∆ w2j 2 ≤1,w3j ∈R N i=1 L (ŷ i , y i ) + β w 3 1 s.t. ŷi = h j=1 w 1j X i w 2j w 3j , ∀i ∈ [n]. Now we can form the Lagrangian for (18) as L(v, y, w 3 ) := N i=1 L (ŷ i , y i ) + β w 3 1 + N i=1 v i   ŷi - h j=1 w 1j X i w 2j w 3j   = N i=1 L (ŷ i , y i ) + N i=1 v i ŷi + β w 3 1 - N i=1 v i h j=1 w 1j X i w 2j w 3j = N i=1 L (ŷ i , y i ) + N i=1 v i ŷi + β w 3 1 - h j=1 N i=1 v i w 1j X i w 2j w 3j Minimizing the Lagrangian L(•) yields the following dual problem of ( 6) max v∈R N -L * (v, y) s.t. max w1∈∆, w2 2 ≤1 N i=1 v i w 1 X i w 2 ≤ β, where L * (•) denotes the Fenchel congregate function of the original loss function L (•) (Boyd & Vandenberghe, 2004) , which is defined as follows L * (v, y) := max z∈R N z v -L (z, y) . 
In order to convexify the dual constraint, we next find its maximizers. Here $x_{ik}\in\mathbb{R}^d$ denotes the $k$-th token (row) of $X_i$:

$$\begin{aligned}\max_{w_1\in\Delta,\ \|w_2\|_2\le1}\sum_{i=1}^Nv_i\,w_1^TX_iw_2&=\max_{w_1\in\Delta}\Big\|\sum_{i=1}^Nv_iX_i^Tw_1\Big\|_2=\max_{w_1\in\Delta}\Big\|\sum_{i=1}^N\sum_{k=1}^nv_iw_{1k}x_{ik}\Big\|_2\\&\le\max_{w_1\in\Delta}\sum_{k=1}^nw_{1k}\Big\|\sum_{i=1}^Nv_ix_{ik}\Big\|_2=\max_{k\in[n]}\Big\|\sum_{i=1}^Nv_ix_{ik}\Big\|_2,\end{aligned}\tag{20}$$

where the upper bound is achieved when $w_1$ is a vector of zeros except for a single one located at the index of the maximum norm of the weighted tokens. Based on the observation in (20), we can equivalently write the dual problem in (19) as

$$\begin{aligned}d^*&=\max_{v\in\mathbb{R}^N}\ -\mathcal{L}^*(v,y)\quad\text{s.t.}\quad\max_{k\in[n]}\Big\|\sum_{i=1}^Nv_ix_{ik}\Big\|_2\le\beta\\&=\max_{v\in\mathbb{R}^N}\ -\mathcal{L}^*(v,y)\quad\text{s.t.}\quad\Big\|\sum_{i=1}^Nv_ix_{ik}\Big\|_2\le\beta,\ \ \forall k\in[n].\end{aligned}\tag{21}$$

Next, we form the Lagrangian for the dual problem (21),

$$L(v,y,\lambda):=-\mathcal{L}^*(v,y)+\sum_{k=1}^n\lambda_k\Big(\beta-\Big\|\sum_{i=1}^Nv_ix_{ik}\Big\|_2\Big).$$

The corresponding optimization problem is $\min_{\lambda\ge0}\max_{v\in\mathbb{R}^N}L(v,y,\lambda)$. For the squared loss it reads

$$\min_{\lambda\ge0}\max_{v\in\mathbb{R}^N}\ -\frac12\|v-y\|_2^2+\frac12\|y\|_2^2+\sum_{k=1}^n\lambda_k\Big(\beta-\Big\|\sum_{i=1}^Nv_ix_{ik}\Big\|_2\Big).$$

Then, we introduce additional variables $r_k\in\mathbb{R}^d$ to equivalently formulate the problem above as

$$\min_{\lambda\ge0}\ \max_{v\in\mathbb{R}^N}\ \min_{r_k:\|r_k\|_2\le1}\ -\mathcal{L}^*(v,y)+\sum_{k=1}^n\lambda_k\Big(\beta-r_k^T\sum_{i=1}^Nv_ix_{ik}\Big).$$

By Sion's minimax theorem (Sion, 1958), we can exchange the order of the minimization and maximization and obtain closed-form solutions for the maximization over the dual variable $v$. This yields

$$\min_{\lambda\ge0}\ \min_{r_k:\|r_k\|_2\le1}\ \sum_{i=1}^N\mathcal{L}\Big(\sum_{k=1}^n\lambda_kr_k^Tx_{ik},\,y_i\Big)+\beta\sum_{k=1}^n\lambda_k.$$

Next, we apply the change of variables $z_k:=\lambda_kr_k$. The problem above then reduces to

$$\min_{\lambda\ge0,\ z_k:\|z_k\|_2\le\lambda_k}\ \sum_{i=1}^N\mathcal{L}\Big(\sum_{k=1}^nz_k^Tx_{ik},\,y_i\Big)+\beta\sum_{k=1}^n\lambda_k.$$

From the optimality conditions, we know that $\lambda_k=\|z_k\|_2$ at a global optimum. Indeed, if $\lambda_k>\|z_k\|_2$, the objective could be decreased further by reducing $\lambda_k$, so $\lambda_k$ would not be optimal. With this, the problem can be reformulated as

$$\min_{z_k}\ \sum_{i=1}^N\mathcal{L}\Big(\sum_{k=1}^nz_k^Tx_{ik},\,y_i\Big)+\beta\sum_{k=1}^n\|z_k\|_2,$$

which is the same formulation as (8) and therefore concludes the proof.
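For intuition, the final problem above is a group lasso over token rows. The sketch below solves its squared-loss instance with proximal gradient descent, where row-wise group soft-thresholding is the proximal step. This solver is our illustrative choice and not necessarily the authors' training procedure. The data and the names `solve_convex_attention` and `group_soft_threshold` are hypothetical. The sketch also checks dual feasibility in (21): for the squared loss, $v=y-\hat y$ should satisfy $\|\sum_iv_ix_{ik}\|_2\le\beta$ for every token $k$.

```python
import numpy as np

def group_soft_threshold(Z, tau):
    """Prox of tau * sum_k ||z_k||_2: shrinks each token row toward zero."""
    norms = np.linalg.norm(Z, axis=1, keepdims=True)
    scale = np.maximum(1.0 - tau / np.maximum(norms, 1e-12), 0.0)
    return Z * scale

def solve_convex_attention(X, y, beta, iters=5000):
    """Proximal gradient sketch for the squared-loss instance of (8):
        min_Z 0.5 * sum_i (trace(Z^T X_i) - y_i)^2 + beta * sum_k ||z_k||_2
    X: (N, n, d) array of token matrices, y: (N,) targets. Returns Z of shape (n, d).
    """
    N, n, d = X.shape
    A = X.reshape(N, n * d)                  # trace(Z^T X_i) = <vec(Z), vec(X_i)>
    step = 1.0 / np.linalg.norm(A, 2) ** 2   # 1 / Lipschitz constant of the smooth part
    Z = np.zeros((n, d))
    for _ in range(iters):
        grad = (A.T @ (A @ Z.ravel() - y)).reshape(n, d)
        Z = group_soft_threshold(Z - step * grad, step * beta)
    return Z

# Synthetic example (hypothetical data): only tokens 0 and 2 carry signal.
rng = np.random.default_rng(0)
N, n, d = 40, 6, 3
X = rng.standard_normal((N, n, d))
Z_true = np.zeros((n, d))
Z_true[[0, 2]] = rng.standard_normal((2, d))
y = np.einsum("ikd,kd->i", X, Z_true) + 0.01 * rng.standard_normal(N)

beta = 1.0
Z = solve_convex_attention(X, y, beta)
print("nonzero token rows:", np.flatnonzero(np.linalg.norm(Z, axis=1) > 1e-8))

# Dual check for (21): with the squared loss, v = y - y_hat should satisfy
# ||sum_i v_i x_ik||_2 <= beta for every token k (equality on nonzero rows).
v = y - np.einsum("ikd,kd->i", X, Z)
dual_norms = np.linalg.norm(np.einsum("i,ikd->kd", v, X), axis=1)
print("max_k ||sum_i v_i x_ik||_2 =", dual_norms.max())
```

If $\beta\ge\max_k\|\sum_iy_ix_{ik}\|_2$, then $Z=0$ is optimal, because $v=y$ is already dual feasible. Smaller $\beta$ lets more token rows become active, which is the token-level sparsity the regularizer promotes.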

STRONG DUALITY PROOF

To obtain the bidual of (7), we utilize semi-infinite duality theory as follows. We first compute the dual of (21) with respect to the dual variable $v$, which gives

$$p^*_\infty:=\min_{\mu}\ \sum_{i=1}^N\mathcal{L}\Big(\int_{w_1\in\Delta,\ \|w_2\|_2\le1}w_1^TX_iw_2\,d\mu(w_1,w_2),\,y_i\Big)+\beta\|\mu\|_{TV},\tag{22}$$

where $\|\mu\|_{TV}$ denotes the total variation norm of the signed measure $\mu$. Note that (22) is an infinite-dimensional training problem, similar to those in Bach (2017). Moreover, this problem is convex with respect to the measure $\mu$ (Bach, 2017). Therefore, strong duality holds, i.e., $d^*=p^*_\infty$, where $d^*$ denotes the optimal value of (21). In addition, although (22) is infinite-dimensional, Carathéodory's theorem implies that it admits an optimal solution with at most $N+1$ heads (Rosset et al., 2007a). Therefore, (22) is equivalent to the following problem:

$$p^*_\infty=\min_{\substack{w_{1j}\in\Delta,\ \|w_{2j}\|_2\le1\\ w_{3j}\in\mathbb{R}}}\ \sum_{i=1}^N\mathcal{L}\Big(\sum_{j=1}^{h^*}w_{1j}^TX_iw_{2j}w_{3j},\,y_i\Big)+\beta\|w_3\|_1,\tag{23}$$

where $h^*\le N+1$. Provided that $h\ge h^*$, (23) and (7) are the same problem. This proves strong duality, i.e., $p^*=p^*_\infty=d^*$, where $p^*$ denotes the optimal value of (7).
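The step from (22) to (23) relies on a standard fact. For a discrete signed measure, the integral becomes a finite sum of heads and the total variation norm becomes an $\ell_1$ norm. Assuming the atoms $(w_{1j},w_{2j})$ are distinct:

$$
\mu=\sum_{j=1}^{h^*}w_{3j}\,\delta_{(w_{1j},w_{2j})}
\ \Longrightarrow\
\int w_1^TX_iw_2\,d\mu(w_1,w_2)=\sum_{j=1}^{h^*}w_{1j}^TX_iw_{2j}\,w_{3j},
\qquad
\|\mu\|_{TV}=\sum_{j=1}^{h^*}|w_{3j}|=\|w_3\|_1 .
$$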

A.5 PROOF OF THEOREM 2

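We first apply the scaling technique in Lemma 1 for $\{w_{2j},w_{3j}\}_{j=1}^h$:

$$\tilde w_{2j}:=\alpha_jw_{2j},\qquad\tilde w_{3j}:=\frac{w_{3j}}{\alpha_j}.$$

Then, following the same steps as in Lemma 1, (9) can be equivalently formulated as

$$\min_{\substack{w_{1j}\in\Delta,\ \|w_{2j}\|_2\le1\\ w_{3j}\in\mathbb{R}^c}}\ \sum_{i=1}^N\mathcal{L}\Big(\sum_{j=1}^hw_{1j}^TX_iw_{2j}w_{3j},\,y_i\Big)+\beta\sum_{j=1}^h\|w_{3j}\|_1.$$

Next, we again construct the Lagrangian function by introducing an additional variable $\hat y_i\in\mathbb{R}^c$, $\forall i\in[N]$, as follows:

$$\min_{\substack{\hat y_i\in\mathbb{R}^c,\ w_{1j}\in\Delta\\ \|w_{2j}\|_2\le1,\ w_{3j}\in\mathbb{R}^c}}\ \sum_{i=1}^N\mathcal{L}(\hat y_i,y_i)+\beta\sum_{j=1}^h\|w_{3j}\|_1\quad\text{s.t.}\quad\hat y_i=\sum_{j=1}^hw_{1j}^TX_iw_{2j}w_{3j},\ \ \forall i\in[N].\tag{25}$$

Now we can form the Lagrangian for (25) as

$$\begin{aligned}L\big(\{v_i\}_{i=1}^N,\{\hat y_i\}_{i=1}^N,\{w_{3j}\}_{j=1}^h\big)&:=\sum_{i=1}^N\mathcal{L}(\hat y_i,y_i)+\beta\sum_{j=1}^h\|w_{3j}\|_1+\sum_{i=1}^Nv_i^T\Big(\hat y_i-\sum_{j=1}^hw_{1j}^TX_iw_{2j}w_{3j}\Big)\\&=\sum_{i=1}^N\mathcal{L}(\hat y_i,y_i)+\sum_{i=1}^Nv_i^T\hat y_i+\beta\sum_{j=1}^h\|w_{3j}\|_1-\sum_{i=1}^Nv_i^T\sum_{j=1}^hw_{1j}^TX_iw_{2j}w_{3j}\\&=\sum_{i=1}^N\mathcal{L}(\hat y_i,y_i)+\sum_{i=1}^Nv_i^T\hat y_i+\beta\sum_{j=1}^h\|w_{3j}\|_1-\sum_{j=1}^h\sum_{i=1}^Nw_{1j}^TX_iw_{2j}\,v_i^Tw_{3j}.\end{aligned}$$

Minimizing the Lagrangian $L(\cdot)$ yields the following dual problem of (9):

$$\max_{v_i\in\mathbb{R}^c}\ -\mathcal{L}^*\big(\{v_i\}_{i=1}^N,\{y_i\}_{i=1}^N\big)\quad\text{s.t.}\quad\max_{w_1\in\Delta,\ \|w_2\|_2\le1}\Big\|\sum_{i=1}^Nv_i\,w_1^TX_iw_2\Big\|_\infty\le\beta,\tag{26}$$

where $\mathcal{L}^*(\cdot)$ denotes the Fenchel conjugate of the loss function $\mathcal{L}(\cdot)$ (Boyd & Vandenberghe, 2004), defined as

$$\mathcal{L}^*\big(\{v_i\}_{i=1}^N,\{y_i\}_{i=1}^N\big):=\max_{Z\in\mathbb{R}^{N\times c}}\ \mathrm{trace}\big(Z^TV\big)-\mathcal{L}(Z,Y),$$

where $V,Y\in\mathbb{R}^{N\times c}$ are the matrix representations of the set of variables $\{v_i,y_i\}_{i=1}^N$. In order to characterize the optimal layer weights explicitly, we next find the maximizers of the dual constraint:

$$\begin{aligned}\max_{w_1\in\Delta,\ \|w_2\|_2\le1}\Big\|\sum_{i=1}^Nv_i\,w_1^TX_iw_2\Big\|_\infty&=\max_{w_1\in\Delta}\max_{l\in[c]}\Big\|\sum_{i=1}^Nv_{il}X_i^Tw_1\Big\|_2=\max_{w_1\in\Delta}\max_{l\in[c]}\Big\|\sum_{i=1}^N\sum_{k=1}^nv_{il}w_{1k}x_{ik}\Big\|_2\\&\le\max_{w_1\in\Delta}\max_{l\in[c]}\sum_{k=1}^nw_{1k}\Big\|\sum_{i=1}^Nv_{il}x_{ik}\Big\|_2=\max_{l\in[c]}\max_{k\in[n]}\Big\|\sum_{i=1}^Nv_{il}x_{ik}\Big\|_2,\end{aligned}\tag{27}$$

where the upper bound is achieved when $w_1$ is a vector of zeros except for a single one located at the index of the maximum norm of the weighted tokens.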
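The bound (27) can be sanity-checked by Monte Carlo. The sketch below draws random feasible pairs $(w_1,w_2)$ with $w_1$ on the simplex and $w_2$ on the unit sphere, for a fixed random choice of $v_i\in\mathbb{R}^c$. It checks that none of them exceeds $\max_{l,k}\|\sum_iv_{il}x_{ik}\|_2$, and that a one-hot $w_1$ paired with an aligned $w_2$ attains the bound. All sizes and data are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)
N, n, d, c = 8, 5, 4, 3
X = rng.standard_normal((N, n, d))   # X_i with n tokens of dimension d (hypothetical sizes)
V = rng.standard_normal((N, c))      # a fixed choice of dual variables v_i in R^c

def constraint_value(w1, w2):
    """|| sum_i v_i (w1^T X_i w2) ||_inf for one head (w1 in simplex, ||w2||_2 <= 1)."""
    scores = np.einsum("k,ikd,d->i", w1, X, w2)
    return np.abs(V.T @ scores).max()

# Right-hand side of (27): M[l, k] = sum_i v_il x_ik, bound = max_{l,k} ||M[l, k]||_2
M = np.einsum("il,ikd->lkd", V, X)
token_norms = np.linalg.norm(M, axis=2)
bound = token_norms.max()

# Random feasible (w1, w2) pairs never exceed the bound ...
samples = []
for _ in range(2000):
    w1 = rng.dirichlet(np.ones(n))
    w2 = rng.standard_normal(d)
    w2 /= np.linalg.norm(w2)
    samples.append(constraint_value(w1, w2))

# ... while a one-hot w1 with w2 aligned to the maximizing M[l, k] attains it.
l_star, k_star = np.unravel_index(token_norms.argmax(), token_norms.shape)
w1_star = np.eye(n)[k_star]
w2_star = M[l_star, k_star] / token_norms[l_star, k_star]
attained = constraint_value(w1_star, w2_star)
print(max(samples), attained, bound)
```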
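Based on the observation in (27), we can equivalently write the dual problem in (26) as follows:

$$\max_{v_i\in\mathbb{R}^c}\ -\mathcal{L}^*\big(\{v_i\}_{i=1}^N,\{y_i\}_{i=1}^N\big)\quad\text{s.t.}\quad\Big\|\sum_{i=1}^Nv_{il}\,x_{ik}\Big\|_2\le\beta,\ \ \forall k\in[n],\ \forall l\in[c].\tag{28}$$

Then directly following the steps in the proof of Theorem 1 yields the following convex optimization problem:

$$\min_{Z_l\in\mathbb{R}^{n\times d}}\ \ldots\,\mathcal{L}\big(\mathrm{trace}(Z_l^TX_i),\,y_{il}\big)+\beta\,\ldots$$

For the attention network with the nonlinearity $\sigma(\cdot)$ in (11), we first apply the scaling technique in Lemma 1 for $\{w_{2j},w_{3j}\}_{j=1}^h$:

$$\tilde w_{2j}:=\alpha_jw_{2j},\qquad\tilde w_{3j}:=\frac{w_{3j}}{\alpha_j}.$$

Then, following the same steps as in Lemma 1, (11) can be equivalently formulated as

$$\min_{\substack{w_{1j}\in\Delta,\ \|w_{2j}\|_2\le1\\ w_{3j}\in\mathbb{R}^c}}\ \sum_{i=1}^N\mathcal{L}\Big(\sum_{j=1}^h\sigma\big(w_{1j}^TX_iw_{2j}\big)w_{3j},\,y_i\Big)+\beta\sum_{j=1}^h\|w_{3j}\|_1.$$

Next, we again construct the Lagrangian function by introducing an additional variable $\hat y_i\in\mathbb{R}^c$, $\forall i\in[N]$, as follows:

$$\min_{\substack{\hat y_i\in\mathbb{R}^c,\ w_{1j}\in\Delta\\ \|w_{2j}\|_2\le1,\ w_{3j}\in\mathbb{R}^c}}\ \sum_{i=1}^N\mathcal{L}(\hat y_i,y_i)+\beta\sum_{j=1}^h\|w_{3j}\|_1\quad\text{s.t.}\quad\hat y_i=\sum_{j=1}^h\sigma\big(w_{1j}^TX_iw_{2j}\big)w_{3j},\ \ \forall i\in[N].$$

The corresponding Lagrangian is

$$L\big(\{v_i\}_{i=1}^N,\{\hat y_i\}_{i=1}^N,\{w_{3j}\}_{j=1}^h\big)=\sum_{i=1}^N\mathcal{L}(\hat y_i,y_i)+\sum_{i=1}^Nv_i^T\hat y_i+\beta\sum_{j=1}^h\|w_{3j}\|_1-\sum_{j=1}^h\sum_{i=1}^N\sigma\big(w_{1j}^TX_iw_{2j}\big)\,v_i^Tw_{3j}.$$

Minimizing the Lagrangian $L(\cdot)$ yields the following dual problem of (11):

$$\max_{v_i\in\mathbb{R}^c}\ -\mathcal{L}^*\big(\{v_i\}_{i=1}^N,\{y_i\}_{i=1}^N\big)\quad\text{s.t.}\quad\max_{w_{1j}\in\Delta,\ \|w_{2j}\|_2\le1}\Big\|\sum_{i=1}^Nv_i\,\sigma\big(w_{1j}^TX_iw_{2j}\big)\Big\|_\infty\le\beta,\ \ \forall j\in[h],\tag{31}$$

where $\mathcal{L}^*(\cdot)$ denotes the Fenchel conjugate of the loss function $\mathcal{L}(\cdot)$ (Boyd & Vandenberghe, 2004), defined as

$$\mathcal{L}^*\big(\{v_i\}_{i=1}^N,\{y_i\}_{i=1}^N\big):=\max_{Z\in\mathbb{R}^{N\times c}}\ \mathrm{trace}\big(Z^TV\big)-\mathcal{L}(Z,Y),$$

where $V,Y\in\mathbb{R}^{N\times c}$ are the matrix representations of the set of variables $\{v_i,y_i\}_{i=1}^N$. We next note that we utilize the gated ReLU nonlinearity introduced in Mishkin et al. (2022). Thus, the activations can be expressed as

$$\sigma\big(w_{1j}^TX_iw_{2j}\big):=\mathbb{1}_{ij}\,w_{1j}^TX_iw_{2j},\qquad\mathbb{1}_{ij}:=\mathbb{1}\big\{u_{1j}^TX_iu_{2j}\ge0\big\},$$

where $\{u_{1j},u_{2j}\}_{j=1}^h$ are fixed vectors that can be randomly selected. For instance, a common choice is $u_{1j}\sim\mathcal{N}(0,I_n)$ and $u_{2j}\sim\mathcal{N}(0,I_d)$. For the rest of the derivations, we use this equivalent formulation of the activation function. In order to characterize the optimal layer weights explicitly, we next find the maximizers of the dual constraint:

$$\begin{aligned}\max_{j\in[h]}\max_{w_{1j}\in\Delta,\ \|w_{2j}\|_2\le1}\Big\|\sum_{i=1}^Nv_i\mathbb{1}_{ij}\,w_{1j}^TX_iw_{2j}\Big\|_\infty&=\max_{j\in[h]}\max_{w_{1j}\in\Delta}\max_{l\in[c]}\Big\|\sum_{i=1}^Nv_{il}\mathbb{1}_{ij}X_i^Tw_{1j}\Big\|_2\\&\le\max_{j\in[h]}\max_{w_{1j}\in\Delta}\max_{l\in[c]}\sum_{k=1}^nw_{1jk}\Big\|\sum_{i=1}^Nv_{il}\mathbb{1}_{ij}x_{ik}\Big\|_2\\&=\max_{j\in[h]}\max_{l\in[c]}\max_{k\in[n]}\Big\|\sum_{i=1}^Nv_{il}\mathbb{1}_{ij}x_{ik}\Big\|_2,\end{aligned}\tag{32}$$

where the upper bound is achieved when each $w_{1j}$ is a vector of zeros except for a single one located at the index of the maximum norm of the weighted tokens. Based on the observation in (32), we can equivalently write the dual problem in (31) as follows:

$$\max_{v_i\in\mathbb{R}^c}\ -\mathcal{L}^*\big(\{v_i\}_{i=1}^N,\{y_i\}_{i=1}^N\big)\quad\text{s.t.}\quad\Big\|\sum_{i=1}^Nv_{il}\mathbb{1}_{ij}x_{ik}\Big\|_2\le\beta,\ \ \forall k\in[n],\ \forall l\in[c],\ \forall j\in[h].\tag{33}$$

Then directly following the steps in the proof of Theorem 2 yields the following convex optimization problem:

$$\min_{Z_{jl}\in\mathbb{R}^{n\times d}}\ \ldots\,\mathbb{1}_{ij}\,\mathrm{trace}\big(Z_{jl}^TX_i\big),\ y_{il}\,\ldots$$

Figure 2: Comparison of the convex and nonconvex models on the dataset extracted from a pretrained BERT architecture in a student-teacher setting. Here, we include two nonconvex models, namely the standard self-attention model (Nonconvex-Standard) in (3) and the alternative attention model (Nonconvex-Alternative) in (9). Our convex training approach achieves a significantly lower objective value and test error than the original nonconvex training.

Figure 4: Comparison of the convex and nonconvex models on the modular division operation mod 97. Here, we train the networks to reach 99% test accuracy and show that our convex training approach exhibits significantly faster convergence and lower test loss than the nonconvex training.

Figure 5: Comparison of one- and two-layer transformer networks on the modular division task mod 97, where L denotes the number of layers in each model. We observe that introducing one more layer substantially improves the convergence speed of our convex formulation, while it fails to make a noticeable impact on the nonconvex formulation.

Figure 6: Comparison of one- and four-layer transformer networks on the modular division task mod 15. Here, one-layer networks fail to achieve 99% test accuracy; however, our convex training approach (light green) still generalizes better than the nonconvex training (light purple). We also show that perfect generalization in terms of accuracy can be achieved with four-layer networks.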
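The role of the fixed gates can be illustrated with a short sketch. It assumes a scalar output ($c=1$) and uses Gaussian gate vectors, as suggested in the text; the sizes, data, and helper name are hypothetical. Once the gates $\mathbb{1}_{ij}$ are fixed, the prediction $\sum_j\mathbb{1}_{ij}\,\mathrm{trace}(Z_j^TX_i)$ is linear in the $Z_j$. Each nonzero row $z_{jk}$ maps back to a gated head with $w_1=e_k$, $w_2=z_{jk}/\|z_{jk}\|_2$ and $w_3=\|z_{jk}\|_2$, by analogy with the ungated case.

```python
import numpy as np

rng = np.random.default_rng(2)
N, n, d, h = 10, 4, 3, 3                  # hypothetical sizes, scalar output (c = 1)
X = rng.standard_normal((N, n, d))

# Fixed random gate vectors, e.g. u_1j ~ N(0, I_n) and u_2j ~ N(0, I_d)
U1 = rng.standard_normal((h, n))
U2 = rng.standard_normal((h, d))
gates = (np.einsum("jk,ikd,jd->ij", U1, X, U2) >= 0).astype(float)   # 1_ij, shape (N, h)

def gated_relu_head(w1, w2, j):
    """sigma(w1^T X_i w2) := 1_ij * w1^T X_i w2, evaluated for all samples i."""
    return gates[:, j] * np.einsum("k,ikd,d->i", w1, X, w2)

# Convex parameterization: one Z_j in R^{n x d} per gate pattern j.
Z = rng.standard_normal((h, n, d))
Z[0, 1] = 0.0                              # a zeroed token row needs no head
convex_pred = np.einsum("ij,ikd,jkd->i", gates, X, Z)   # sum_j 1_ij trace(Z_j^T X_i)

# Map each nonzero row z_jk to a gated head: w1 = e_k, w2 = z_jk / ||z_jk||_2, w3 = ||z_jk||_2.
head_pred = np.zeros(N)
num_heads = 0
for j in range(h):
    for k in range(n):
        norm = np.linalg.norm(Z[j, k])
        if norm > 0:
            head_pred += norm * gated_relu_head(np.eye(n)[k], Z[j, k] / norm, j)
            num_heads += 1
```

Because the gates depend only on the fixed $u_{1j},u_{2j}$ and the data, the gated model stays linear in the trainable variables. This is what allows the convex reformulation to go through.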

Number of parameters and FLOPs for the convex and nonconvex models. Here, we use the following notation: $n$: number of tokens, $d$: embedding dimension, $h$: number of heads, and $c$: number of outputs.
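As a rough companion to this comparison, the sketch below counts only the optimization variables in the formulations derived in this appendix. These are $w_{1j}\in\Delta\subset\mathbb{R}^n$, $w_{2j}\in\mathbb{R}^d$ and $w_{3j}\in\mathbb{R}^c$ for (9), one $Z_l\in\mathbb{R}^{n\times d}$ per output in its convex counterpart, and one $Z_{jl}\in\mathbb{R}^{n\times d}$ per gate pattern and output in the gated ReLU case. It does not reproduce the table's numbers. It also does not cover the standard self-attention model (3) or FLOPs. The function names and sizes are hypothetical.

```python
def nonconvex_alternative_params(n, d, h, c):
    """Variables of (9): per head, w_1j in R^n, w_2j in R^d, w_3j in R^c."""
    return h * (n + d + c)

def convex_params(n, d, c):
    """Variables of the convex reformulation of (9): one Z_l in R^{n x d} per output l."""
    return c * n * d

def convex_gated_relu_params(n, d, h, c):
    """Variables of the gated ReLU convex problem: one Z_jl in R^{n x d} per (j, l)."""
    return h * c * n * d

# Hypothetical sizes, for illustration only.
n, d, h, c = 16, 64, 8, 10
print(nonconvex_alternative_params(n, d, h, c))   # 720
print(convex_params(n, d, c))                     # 10240
print(convex_gated_relu_params(n, d, h, c))       # 81920
```

The convex variable count grows with $n\cdot d$ per output (and per gate pattern), whereas the nonconvex head parameterization grows with $n+d+c$ per head. The regularizer, however, drives many token rows of the convex variables to zero.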


Then the corresponding dual problem is as follows:

In order to characterize the optimal layer weight explicitly, we next find the maximizers of the dual constraint as follows:

Based on the equivalent formulation in (36), the dual problem in (35) can be equivalently written as:

The rest of the derivations directly follows from the proof of Theorem 2 and yields the following result.

Theorem A.1. Based on the characterization of the dual constraint in (36), the non-convex optimization problem (34) can be equivalently cast as the following convex optimization problem:

Remark A.2. Instead of the non-convex formulation in (34), we can also start from the following formulation in this setting:

which also yields the convex formulation in (38).

