A THEORETICAL UNDERSTANDING OF SHALLOW VISION TRANSFORMERS: LEARNING, GENERALIZATION, AND SAMPLE COMPLEXITY

Abstract

Vision Transformers (ViTs) with self-attention modules have recently achieved great empirical success in many vision tasks. Due to non-convex interactions across layers, however, their theoretical learning and generalization analysis is mostly elusive. Based on a data model characterizing both label-relevant and label-irrelevant tokens, this paper provides the first theoretical analysis of training a shallow ViT, i.e., one self-attention layer followed by a two-layer perceptron, for a classification task. We characterize the sample complexity required to achieve a zero generalization error. Our sample complexity bound is positively correlated with the inverse of the fraction of label-relevant tokens, the token noise level, and the initial model error. We also prove that a training process using stochastic gradient descent (SGD) leads to a sparse attention map, which is a formal verification of the general intuition about the success of attention. Moreover, this paper indicates that a proper token sparsification can improve the test performance by removing label-irrelevant and/or noisy tokens, including spurious correlations. Empirical experiments on synthetic data and the CIFAR-10 dataset justify our theoretical results, and the findings generalize to deeper ViTs.

1. INTRODUCTION

As the backbone of Transformers (Vaswani et al., 2017), the self-attention mechanism (Bahdanau et al., 2014) computes feature representations by globally modeling long-range interactions within the input. Transformers have demonstrated tremendous empirical success in numerous areas, including natural language processing (Kenton & Toutanova, 2019; Radford et al., 2019; 2018; Brown et al., 2020), recommendation systems (Zhou et al., 2018; Chen et al., 2019; Sun et al., 2019), and reinforcement learning (Chen et al., 2021; Janner et al., 2021; Zheng et al., 2022). Since the advent of the Vision Transformer (ViT) (Dosovitskiy et al., 2020), Transformer-based models (Touvron et al., 2021; Jiang et al., 2021; Wang et al., 2021; Liu et al., 2021a) have gradually replaced convolutional neural network (CNN) architectures and become prevalent in vision tasks. Various techniques have been developed to train ViTs efficiently. Among them, token sparsification (Pan et al., 2021; Rao et al., 2021; Liang et al., 2022; Tang et al., 2022; Yin et al., 2022) removes redundant tokens (image patches) of the data to reduce the computational cost while maintaining comparable learning performance. For example, Liang et al. (2022) accelerate inference by pruning inattentive tokens with small attention weights.

Despite these empirical advances, one fundamental question remains largely unexplored: under what conditions does a Transformer achieve satisfactory generalization? Some recent works analyze Transformers theoretically from the perspectives of the Lipschitz constant of self-attention (James Vuckovic, 2020; Kim et al., 2021), properties of the neural tangent kernel (Hron et al., 2020; Yang, 2020), and expressive power and Turing-completeness (Dehghani et al., 2018; Yun et al., 2019; Bhattamishra et al., 2020a;b; Edelman et al., 2022; Dong et al., 2021; Likhosherstov et al., 2021; Cordonnier et al., 2019; Levine et al., 2020), with statistical guarantees (Snell et al., 2021; Wei et al., 2021). Likhosherstov et al. (2021) showed a model complexity for function approximation with the self-attention module. Cordonnier et al. (2019) provided sufficient and necessary conditions for multi-head self-attention structures to simulate convolution layers. None of these works, however, theoretically characterizes the generalization performance of the learned model. Only Edelman et al. (2022) proved that a single self-attention head can represent a sparse function of the input with a sample complexity bound on the gap between the training loss and the test loss, but no discussion is provided on what algorithm can train the Transformer to achieve a desirable loss.

Contributions: To the best of our knowledge, this paper provides the first learning and generalization analysis of training a basic shallow Vision Transformer using stochastic gradient descent (SGD). This paper focuses on a binary classification problem on structured data, where tokens with discriminative patterns determine the label through a majority vote, while tokens with non-discriminative patterns do not affect the label. We train a ViT containing a self-attention layer followed by a two-layer perceptron using SGD from a proper initial model. This paper explicitly characterizes the required number of training samples to achieve a desirable generalization performance, referred to as the sample complexity. Our sample complexity bound is positively correlated with the inverse of the fraction of label-relevant tokens, the token noise level, and the error of the initial model, indicating a better generalization performance on data with fewer label-irrelevant patterns and less noise, starting from a better initial model. The highlights of our technical contributions include:

First, this paper proposes a new analytical framework to tackle the non-convex optimization and generalization analysis for shallow ViTs.
Due to the more involved non-convex interactions of learning parameters and the diverse activation functions across layers, the ViT model considered in this paper, i.e., a three-layer neural network with one self-attention layer, is more complicated to analyze than the three-layer CNNs considered in Allen-Zhu et al. (2019a); Allen-Zhu & Li (2019), the most complex neural network models analyzed so far for across-layer non-convex interactions. We consider a structured data model with relaxed assumptions compared with existing models and establish a new analytical framework to overcome the technical challenges specific to ViTs.

Second, this paper theoretically depicts the evolution of the attention map during training and characterizes how "attention" is paid to different tokens. Specifically, we show that under the structured data model, the learning parameters of the self-attention module grow in the direction that projects the data onto the label-relevant patterns, resulting in an increasingly sparse attention map. This insight provides a theoretical justification for magnitude-based token pruning methods such as (Liang et al., 2022; Tang et al., 2022) for efficient learning.

Third, we provide a theoretical explanation for the improved generalization achieved by token sparsification. We quantitatively show that if a token sparsification method removes class-irrelevant and/or highly noisy tokens, then the sample complexity is reduced while the same testing accuracy is achieved. Moreover, token sparsification can also remove spurious correlations to improve the testing accuracy (Likhomanenko et al., 2021; Zhu et al., 2021a). This insight provides a guideline for designing token sparsification and few-shot learning methods for Transformers (He et al., 2022; Guibas et al., 2022).

1.1. BACKGROUND AND RELATED WORK

Efficient ViT learning. To alleviate the memory and computation burden in training (Dosovitskiy et al., 2020; Touvron et al., 2021; Wang et al., 2022), various acceleration techniques have been developed beyond token sparsification. Zhu et al. (2021b) identify the importance of different dimensions in each layer of ViTs and then execute model pruning. Liu et al. (2021b); Lin et al. (2022); Li et al. (2022d) quantize weights and inputs to compress the learning model. Li et al. (2022a) study automated progressive learning that automatically increases the model capacity on-the-fly. Moreover, modifications of attention modules, such as network architectures based on local attention (Wang et al., 2021; Liu et al., 2021a; Chu et al., 2021), can simplify the computation of global attention for acceleration.

Theoretical analysis of learning and generalization of neural networks. One line of research (Zhong et al., 2017b; Fu et al., 2020; Zhong et al., 2017a; Zhang et al., 2020a;b; Li et al., 2022c) analyzes the generalization performance when the number of neurons is smaller than the number of training samples. The neural-tangent-kernel (NTK) analysis (Jacot et al., 2018; Allen-Zhu et al., 2019a;b; Arora et al., 2019; Cao & Gu, 2019; Zou & Gu, 2019; Du et al., 2019; Chen et al., 2020; Li et al., 2022b) considers strongly overparameterized networks and eliminates the non-convex interactions across layers by linearizing the neural network around its initialization. The resulting generalization performance is independent of the feature distribution and cannot explain the advantages of self-attention modules.

Neural network learning on structured data. Li & Liang (2018) provide a generalization analysis of a fully-connected neural network when the data come from separated distributions. Daniely & Malach (2020); Shi et al. (2021); Karp et al. (2021); Brutzkus & Globerson (2021); Zhang et al. (2023) study fully-connected and convolutional neural networks assuming that the data contain discriminative patterns and background patterns. Allen-Zhu & Li (2022) illustrate the robustness of adversarial training by introducing the feature purification mechanism, in which neural networks with non-linear activation functions can memorize data-dependent features. Wen & Li (2021) extend this framework to self-supervised contrastive learning. All these works consider one-hidden-layer neural networks without self-attention.

Notations: Vectors are in bold lowercase, and matrices and tensors are in bold uppercase. Scalars are in normal fonts, and sets are in calligraphic font. For instance, $Z$ is a matrix, and $z$ is a vector. $z_i$ denotes the $i$-th entry of $z$, and $Z_{i,j}$ denotes the $(i,j)$-th entry of $Z$. $[K]$ ($K > 0$) denotes the set of integers from 1 to $K$. We follow the convention that $f(x) = O(g(x))$ (or $\Omega(g(x))$, $\Theta(g(x))$) means that $f(x)$ increases at most, at least, or in the order of $g(x)$, respectively.

2. PROBLEM FORMULATION AND LEARNING ALGORITHM

We study a binary classification problem following the common setup in (Dosovitskiy et al., 2020; Touvron et al., 2021; Jiang et al., 2021). Given $N$ training samples $\{(X^n, y^n)\}_{n=1}^N$ generated from an unknown distribution $\mathcal{D}$ and a fair initial model, the goal is to find an improved model that maps $X$ to $y$ for any $(X, y) \sim \mathcal{D}$. Each data point contains $L$ tokens $x_1^n, x_2^n, \cdots, x_L^n$, i.e., $X^n = [x_1^n, \cdots, x_L^n] \in \mathbb{R}^{d \times L}$, where each token is $d$-dimensional and unit-norm, and $y^n \in \{+1, -1\}$ is a scalar. A token can be an image patch (Dosovitskiy et al., 2020). We consider a general setup that also applies to token sparsification, where some tokens are set to zero to reduce the computational time. Let $\mathcal{S}^n \subseteq [L]$ denote the set of indices of the tokens of $X^n$ that remain after sparsification. Then $|\mathcal{S}^n| \le L$, and $\mathcal{S}^n = [L]$ without token sparsification.

Learning is performed over a basic three-layer Vision Transformer, a neural network with a single-head self-attention layer and a two-layer fully connected network, as shown in (1). This is a simplified model of practical Vision Transformers (Dosovitskiy et al., 2020) that avoids unnecessary complications in analyzing the most critical component of ViTs, the self-attention:

$$F(X^n) = \frac{1}{|\mathcal{S}^n|} \sum_{l \in \mathcal{S}^n} a_{(l)}^\top \mathrm{Relu}\big(W_O W_V X^n \mathrm{softmax}(X^{n\top} W_K^\top W_Q x_l^n)\big), \quad (1)$$

where the query weights $W_Q \in \mathbb{R}^{m_b \times d}$, the key weights $W_K \in \mathbb{R}^{m_b \times d}$, and the value weights $W_V \in \mathbb{R}^{m_a \times d}$ in the attention unit are multiplied with $X^n$ to obtain the query vectors $W_Q X^n$, the key vectors $W_K X^n$, and the value vectors $W_V X^n$, respectively (Vaswani et al., 2017). $W_O \in \mathbb{R}^{m \times m_a}$ and $A = (a_{(1)}, a_{(2)}, \cdots, a_{(L)})$ with $a_{(l)} \in \mathbb{R}^m$, $l \in [L]$, are the hidden-layer and output-layer weights of the two-layer perceptron, respectively. $m$ is the number of neurons in the hidden layer. $\mathrm{Relu}: \mathbb{R}^m \to \mathbb{R}^m$ with $\mathrm{Relu}(x) = \max\{x, 0\}$ entry-wise, and $\mathrm{softmax}: \mathbb{R}^L \to \mathbb{R}^L$ with $\mathrm{softmax}(x) = (e^{x_1}, e^{x_2}, \cdots, e^{x_L}) / \sum_{i=1}^L e^{x_i}$.
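As a concrete illustration, the forward pass in (1) can be sketched in numpy as follows. This is our own illustrative sketch (variable names and dimensions are ours), not the authors' implementation; the attention is computed over the kept tokens $\mathcal{S}^n$ only.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def softmax(x):
    e = np.exp(x - x.max())  # numerically stabilized
    return e / e.sum()

def vit_forward(X, A, W_O, W_V, W_K, W_Q, S):
    """Output F(X) of the shallow ViT in Eq. (1).

    X: (d, L) token matrix; A: (m, L) output-layer weights a_(l);
    W_O: (m, m_a); W_V: (m_a, d); W_K, W_Q: (m_b, d);
    S: list of indices of tokens kept after sparsification.
    """
    Xs = X[:, S]                 # remaining tokens
    V = W_V @ Xs                 # value vectors, shape (m_a, |S|)
    K = W_K @ Xs                 # key vectors, shape (m_b, |S|)
    out = 0.0
    for l in S:
        q = W_Q @ X[:, l]        # query vector of token l
        attn = softmax(K.T @ q)  # attention weights over kept tokens
        out += A[:, l] @ relu(W_O @ (V @ attn))
    return out / len(S)
```

Note that the output is linear in the output-layer weights $A$, while the interactions through $W_O$, $W_V$, $W_K$, $W_Q$ are non-convex, which is the source of the analytical difficulty discussed above.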
Let $\psi = (A, W_O, W_V, W_K, W_Q)$ denote the set of parameters to train. The training problem minimizes the empirical risk $f_N(\psi)$,

$$\min_\psi: \ f_N(\psi) = \frac{1}{N} \sum_{n=1}^N \ell(X^n, y^n; \psi), \quad (2)$$

where $\ell(X^n, y^n; \psi)$ is the Hinge loss, i.e., $\ell(X^n, y^n; \psi) = \max\{1 - y^n \cdot F(X^n), 0\}$. The generalization performance of a learned model $\psi$ is evaluated by the population risk $f(\psi)$, where

$$f(\psi) = f(A, W_O, W_V, W_K, W_Q) = \mathbb{E}_{(X,y)\sim\mathcal{D}}\big[\max\{1 - y \cdot F(X), 0\}\big].$$

The training problem (2) is solved via mini-batch stochastic gradient descent (SGD), as summarized in Algorithm 1. At iteration $t$, $t = 0, 1, 2, \cdots, T-1$, the gradient is computed using a mini-batch $\mathcal{B}_t$ with $|\mathcal{B}_t| = B$, and the step size is $\eta$. Similar to (Dosovitskiy et al., 2020; Touvron et al., 2021; Jiang et al., 2021), $W_V$, $W_Q$, and $W_K$ are updated together with $W_O$ during training, while the output-layer weights $A$ are fixed at their initial values.
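The Hinge objective admits a simple subgradient structure: a sample contributes to the SGD update only when its margin $y^n \cdot F(X^n)$ is below 1. A minimal sketch (function names are ours, for illustration):

```python
import numpy as np

def hinge_loss(scores, labels):
    """Empirical risk as in (2): mean of max{1 - y * F(X), 0} over a batch.
    scores: array of F(X^n); labels: array of y^n in {+1, -1}."""
    return float(np.mean(np.maximum(1.0 - labels * scores, 0.0)))

def active_mask(scores, labels):
    """Samples with a nonzero Hinge subgradient, i.e. y * F(X) < 1.
    Only these samples contribute to a mini-batch SGD step."""
    return labels * scores < 1.0
```

In particular, once every training sample has margin at least 1, both the loss and the gradient vanish and training stops making updates.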

3.1. MAIN THEORETICAL INSIGHTS

Before formally introducing our data model and main theory, we first summarize the major insights. We consider a data model where tokens are noisy versions of either label-relevant patterns that determine the data label or label-irrelevant patterns that do not affect the label. $\alpha_*$ is the fraction of label-relevant tokens, $\sigma$ represents the initial model error, and $\tau$ characterizes the token noise level.

(P1). A convergence and sample complexity analysis of SGD to achieve zero generalization error. We prove that SGD with a proper initialization converges to a model with zero generalization error. The required number of iterations is proportional to $1/\alpha_*$ and $1/(\Theta(1) - \sigma - \tau)$. Our sample complexity bound is linear in $\alpha_*^{-2}$ and $(\Theta(1) - \sigma - \tau)^{-2}$. Therefore, the learning performance is improved, in the sense of a faster convergence and fewer training samples to achieve a desirable generalization, with a larger fraction of label-relevant patterns, a better initial model, and less token noise.

(P2). A theoretical characterization of the increased sparsity of the self-attention module during training. We prove that the attention weights, i.e., the softmax values for each token in the self-attention module, become increasingly sparse during training, with the non-zero weights concentrated at label-relevant tokens. This formally justifies the general intuition that the attention layer makes the neural network focus on the most important part of the data.

(P3). A theoretical guideline for designing token sparsification methods to reduce the sample complexity. Our sample complexity bound indicates that the required number of samples to achieve zero generalization error can be reduced if a token sparsification method removes some label-irrelevant tokens (increasing $\alpha_*$), or tokens with large noise (reducing $\tau$), or both. This insight provides a guideline for designing proper token sparsification methods.
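As an illustration of this guideline, one simple hypothetical sparsification rule scores each token, e.g., by its average received attention weight, and keeps the top fraction. The scoring rule and function names below are our own, not a method from the paper:

```python
import numpy as np

def sparsify_tokens(scores, keep_ratio=0.5):
    """Return the kept-token indices S^n: the top ceil(keep_ratio * L)
    tokens by score. With attention-based scores, this tends to keep
    label-relevant tokens and drop label-irrelevant or noisy ones,
    which increases alpha_* and reduces tau per the guideline above."""
    L = len(scores)
    k = max(1, int(np.ceil(keep_ratio * L)))
    kept = np.argsort(scores)[::-1][:k]
    return np.sort(kept)
```

For example, with scores `[0.1, 0.4, 0.05, 0.45]` and `keep_ratio=0.5`, the rule keeps tokens 1 and 3.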

(P4). A new theoretical framework to analyze the non-convex interactions in three-layer ViTs. This paper develops a new framework to analyze ViTs based on a more general data model than existing works such as (Brutzkus & Globerson, 2021; Karp et al., 2021; Wen & Li, 2021). Compared with the non-convex interactions in three-layer feedforward neural networks, analyzing ViTs poses additional technical challenges: the softmax activation is highly non-linear, and the gradient computation on token correlations is complicated. We develop new tools to handle this problem by exploiting structures in the data and proving that SGD iterations increase the magnitude of label-relevant tokens only, rather than label-irrelevant tokens. This theoretical framework is of independent interest and can potentially be applied to analyze different variants of Transformers and attention mechanisms.

3.2. DATA MODEL

There are $M$ ($2 < M < m_a, m_b$) distinct patterns $\{\mu_1, \mu_2, \cdots, \mu_M\}$ in $\mathbb{R}^d$, where $\mu_1$ and $\mu_2$ are discriminative patterns that determine the binary labels, and the remaining $M-2$ patterns $\mu_3, \mu_4, \cdots, \mu_M$ are non-discriminative patterns that do not affect the labels. Let $\kappa = \min_{1 \le i \ne j \le M} \|\mu_i - \mu_j\| > 0$ denote the minimum distance between patterns. Each token $x_l^n$ of $X^n$ is a noisy version of one of the patterns, i.e., $\min_{j \in [M]} \|x_l^n - \mu_j\| \le \tau$, with noise level $\tau < \kappa/4$. We take $\kappa - 4\tau$ as $\Theta(1)$ for simplicity of presentation.

The label $y^n$ is determined by the tokens that correspond to discriminative patterns through a majority vote. If the number of tokens of $X^n$ that are noisy versions of $\mu_1$ is larger than the number of tokens that correspond to $\mu_2$, then $y^n = 1$. In this case, the tokens that are noisy versions of $\mu_1$ are referred to as label-relevant tokens, and the tokens that are noisy versions of $\mu_2$ are referred to as confusion tokens. Similarly, if there are more tokens that are noisy versions of $\mu_2$ than of $\mu_1$, the former are the label-relevant tokens, the latter are the confusion tokens, and $y^n = -1$. All other tokens that are not label-relevant are called label-irrelevant tokens. Let $\alpha_*$ and $\alpha_\#$ denote the average fractions of label-relevant and confusion tokens over the distribution $\mathcal{D}$, respectively.

We consider a balanced dataset. Let $\mathcal{D}_+ = \{(X^n, y^n) \mid y^n = +1, n \in [N]\}$ and $\mathcal{D}_- = \{(X^n, y^n) \mid y^n = -1, n \in [N]\}$ denote the sets of positive and negative samples, respectively. Then $\big||\mathcal{D}_+| - |\mathcal{D}_-|\big| = O(\sqrt{N})$.

Our model is motivated by and generalizes those used in the state-of-the-art analyses of neural networks on structured data (Li & Liang, 2018; Brutzkus & Globerson, 2021; Karp et al., 2021). All the existing models require that only one discriminative pattern exists in each sample, i.e., either $\mu_1$ or $\mu_2$ but not both, while our model allows both patterns to appear in the same sample.
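To make the data model concrete, the following sketch samples one $(X, y)$ pair. This is our own illustrative code under stated simplifications: per-sample token fractions are fixed rather than random, and the unit-norm constraint on tokens is ignored.

```python
import numpy as np

def gen_sample(mu, alpha_star, alpha_sharp, tau, L, rng):
    """One (X, y) from the structured data model (a sketch).

    mu: (M, d) array of patterns; mu[0], mu[1] are discriminative.
    alpha_star / alpha_sharp: fractions of label-relevant and
    confusion tokens; every token lies within tau of some pattern.
    """
    y = rng.choice([1, -1])
    n_rel = int(round(alpha_star * L))
    n_conf = int(round(alpha_sharp * L))
    assert n_rel > n_conf, "majority vote needs more label-relevant tokens"
    rel, conf = (0, 1) if y == 1 else (1, 0)
    ids = [rel] * n_rel + [conf] * n_conf
    # remaining tokens are noisy versions of non-discriminative patterns
    ids += list(rng.integers(2, len(mu), L - n_rel - n_conf))
    ids = np.array(ids)
    rng.shuffle(ids)
    tokens = []
    for j in ids:
        noise = rng.normal(size=mu.shape[1])
        noise *= tau * rng.random() / np.linalg.norm(noise)  # ||noise|| <= tau
        tokens.append(mu[j] + noise)
    return np.array(tokens).T, y  # X has shape (d, L)
```

By construction, each token is within distance $\tau$ of one of the patterns, and the majority vote between tokens near $\mu_1$ and $\mu_2$ matches the drawn label.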

3.3. FORMAL THEORETICAL RESULTS

Before presenting our main theory, we first characterize the behavior of the initial model through Assumption 1. Some important notations are summarized in Table 1.

Assumption 1. The initial model satisfies $\max(\|W_V^{(0)}\|, \|W_K^{(0)}\|, \|W_Q^{(0)}\|) \le 1$ without loss of generality. There exist three (not necessarily different) sets of orthonormal bases $\mathcal{P} = \{p_1, p_2, \cdots, p_M\}$, $\mathcal{Q} = \{q_1, q_2, \cdots, q_M\}$, and $\mathcal{R} = \{r_1, r_2, \cdots, r_M\}$, where $p_l \in \mathbb{R}^{m_a}$ and $q_l, r_l \in \mathbb{R}^{m_b}$ for all $l \in [M]$, with $q_1 = r_1$ and $q_2 = r_2$, such that

$$\|W_V^{(0)} \mu_j - p_j\| \le \sigma, \quad \|W_K^{(0)} \mu_j - q_j\| \le \delta, \quad \|W_Q^{(0)} \mu_j - r_j\| \le \delta$$

hold for some $\sigma = O(1/M)$ and $\delta < 1/2$.

Assumption 1 characterizes the distance of the query, key, and value vectors of the patterns $\{\mu_j\}_{j=1}^M$ to orthonormal vectors. The requirement on $\delta$ is minor because $\delta$ can be of the same order as $\|\mu_j\|$.

Theorem 1 (Generalization of ViT). Suppose Assumption 1 holds, $\tau \le \min(\sigma, \delta)$, the model is sufficiently large with $m \gtrsim \epsilon^{-2} M^2 \log N$ for some $\epsilon > 0$, and the average fraction of label-relevant patterns satisfies

$$\alpha_* \ge \frac{\alpha_\#}{\epsilon_S\, e^{-(\delta+\tau)}\,(1 - (\sigma + \tau))}$$

for some constant $\epsilon_S \in (0, \frac{1}{2})$, and the mini-batch size and the number of sampled tokens of each data point $X^n$, $n \in [N]$, satisfy

$$B \ge \Omega\big((\alpha_* - e^{-(\delta+\tau)}(\tau + \sigma))^{-2}\big), \quad |\mathcal{S}^n| \ge \Omega(1). \quad (9)$$

Then, as long as the number of training samples $N$ satisfies

$$N \ge \Omega\Big(\frac{1}{(\alpha_* - c'(1-\zeta) - c''(\sigma + \tau))^2}\Big) \quad (10)$$

for some constants $c', c'' > 0$ and $\zeta \gtrsim 1 - \eta^{10}$, then after $T$ iterations with

$$T = \Theta\Big(\frac{1}{(1 - \epsilon - (\sigma+\tau)M/\pi)\,\eta\,\alpha_*}\Big), \quad (11)$$

with probability at least 0.99, the returned model achieves zero generalization error, i.e.,

$$f(A^{(0)}, W_O^{(T)}, W_V^{(T)}, W_K^{(T)}, W_Q^{(T)}) = 0. \quad (12)$$

Theorem 1 characterizes under what conditions on the data the neural network with self-attention in (1), trained with Algorithm 1, achieves zero generalization error.
To show that the self-attention layer improves the generalization performance by reducing the sample complexity required to achieve zero generalization error, we also quantify the sample complexity when there is no self-attention layer in the following proposition.

Proposition 1 (Generalization without self-attention). Suppose the assumptions of Theorem 1 hold. When there is no self-attention layer, i.e., $W_K$ and $W_Q$ are not updated during training, if $N$ satisfies

$$N \ge \Omega\Big(\frac{1}{(\alpha_*(\alpha_* - \sigma - \tau))^2}\Big), \quad (13)$$

then after $T$ iterations with $T$ as in (11), the returned model achieves zero generalization error, i.e.,

$$f(A^{(0)}, W_O^{(T)}, W_V^{(T)}, W_K, W_Q) = 0. \quad (14)$$

Remark 1 (Advantage of the self-attention layer). Because $m \gg m_a, m_b, d$, the number of trainable parameters remains almost the same with or without updating the attention layer. Combining Theorem 1 and Proposition 1, we see that with the additional self-attention layer, the sample complexity is reduced by a factor of $1/\alpha_*^2$ with an approximately equal number of network parameters.

Remark 2 (Generalization improvement by token sparsification). (10) and (11) show that the sample complexity $N$ and the required number of iterations $T$ scale with $1/\alpha_*^2$ and $1/\alpha_*$, respectively. Hence, increasing $\alpha_*$, the fraction of label-relevant tokens, reduces the sample complexity and speeds up the convergence. Similarly, $N$ and $T$ scale with $1/(\Theta(1) - \tau)^2$ and $1/(\Theta(1) - \tau)$, so decreasing $\tau$, the token noise level, also improves the generalization. Note that a properly designed token sparsification method can both increase $\alpha_*$ by removing label-irrelevant tokens and decrease $\tau$ by removing noisy tokens, thus improving the generalization performance.

Remark 3 (Impact of the initial model). The initial model $W_V^{(0)}$, $W_K^{(0)}$, $W_Q^{(0)}$ affects the learning performance through $\sigma$ and $\delta$, both of which decrease as the initial model improves.
Then, from (10) and (11), the sample complexity decreases and the convergence speeds up with a better initial model.

Proposition 2 shows that the attention weights become increasingly concentrated on label-relevant tokens during training. Proposition 2 is a critical component in proving Theorem 1 and is of independent interest.

Proposition 2. During training, the attention weights for each token become increasingly concentrated on those correlated with tokens of the label-relevant pattern, i.e.,

$$\sum_{i \in \mathcal{S}_*^n} \mathrm{softmax}\big(X^{n\top} W_K^{(t)\top} W_Q^{(t)} x_l^n\big)_i = \sum_{i \in \mathcal{S}_*^n} \frac{\exp\big(x_i^{n\top} W_K^{(t)\top} W_Q^{(t)} x_l^n\big)}{\sum_{r \in \mathcal{S}^n} \exp\big(x_r^{n\top} W_K^{(t)\top} W_Q^{(t)} x_l^n\big)} \to 1 - \eta^C \quad (15)$$

at a sublinear rate of $O(1/t)$ when $t$ is large, for a large constant $C > 0$ and all $l \in \mathcal{S}^n$ and $n \in [N]$, where $\mathcal{S}_*^n$ denotes the set of indices of label-relevant tokens and $\mathrm{softmax}(\cdot)_i$ denotes the $i$-th entry of $\mathrm{softmax}(\cdot)$.

Proposition 2 indicates that only label-relevant tokens are highlighted by the learned attention of ViTs, while other tokens receive less weight. This provides a theoretical justification for magnitude-based token sparsification methods.

Proof idea sketch: The main idea is to show that the SGD updates scale up the value, query, and key vectors of discriminative patterns, while keeping the magnitudes of the projections of non-discriminative patterns and the initial model error almost unchanged. To be more specific, by Lemmas 3 and 4, we can identify two groups of neurons in the hidden layer $W_O$, where one group only learns the positive pattern, and the other group only learns the negative pattern. Claim 1 of Lemma 2 states that during the SGD updates, the neuron weights in these two groups evolve in the directions of the projected discriminative patterns $p_1$ and $p_2$, respectively. Meanwhile, Claim 2 of Lemma 2 indicates that $W_K$ and $W_Q$ update in the direction of increasing the magnitudes of the query and key vectors of label-relevant tokens from 1 to $\Theta(\log T)$, such that the attention weights correlated with label-relevant tokens gradually become dominant.
Moreover, by Claim 3 of Lemma 2, the update of $W_V$ increases the magnitudes of the value vectors of label-relevant tokens by adding the partial neuron weights of $W_O$ that are aligned with these value vectors. Due to the above properties, one can simplify the training process to show that the output of the neural network (1) changes linearly in the iteration number $t$. From this analysis, we develop the sample complexity and the number of iterations required for the zero generalization error guarantee.

Technical novelty: Our proof technique is inspired by the feature learning technique for analyzing fully connected networks and convolutional neural networks (Shi et al., 2021; Brutzkus & Globerson, 2021). Our paper makes new technical contributions in the following aspects. First, we provide a new framework for studying the non-convex interactions of multiple weight matrices in a three-layer ViT, while other feature learning works (Shi et al., 2021; Brutzkus & Globerson, 2021; Karp et al., 2021; Allen-Zhu & Li, 2022; Wen & Li, 2021; Zhang et al., 2023) only study one trainable weight matrix in the hidden layer of a two-layer network. Second, we analyze the updates of the self-attention module with the softmax function during training, while other papers either ignore this issue without a convergence analysis (Edelman et al., 2022) or oversimplify the analysis by applying the neural-tangent-kernel (NTK) method, which assumes impractical over-parameterization and updates the weights only around the initialization (Hron et al., 2020; Yang, 2020; Allen-Zhu et al., 2019a; Arora et al., 2019). Third, we consider a more general data model, where discriminative patterns of multiple classes can exist in the same data sample, whereas the data models in (Brutzkus & Globerson, 2021; Karp et al., 2021) allow only one discriminative pattern in each sample.
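As a concrete check of Proposition 2, the attention mass on label-relevant tokens, i.e., the left-hand side of (15), can be computed directly. The toy example below (our own code) also illustrates the mechanism in Claim 2 of Lemma 2: increasing the magnitude of $W_K^\top W_Q$ concentrates the softmax on the tokens with the largest scores.

```python
import numpy as np

def relevant_attention_mass(X, W_K, W_Q, l, relevant_idx):
    """Sum over i in S^n_* of softmax(X^T W_K^T W_Q x_l)_i, Eq. (15)."""
    scores = X.T @ W_K.T @ W_Q @ X[:, l]
    w = np.exp(scores - scores.max())  # stabilized softmax
    w = w / w.sum()
    return float(w[relevant_idx].sum())
```

With two identical label-relevant tokens and one irrelevant token, zero attention weights give a mass of $2/3$ (uniform attention), while scaling $W_K^\top W_Q$ up pushes the mass toward 1, mirroring the growth from 1 to $\Theta(\log T)$ described above.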

4.1. EXPERIMENTS ON SYNTHETIC DATASETS

We first verify the theoretical bounds in Theorem 1 on synthetic data. We set the dimension of the data and the attention embeddings to $d = m_a = m_b = 10$, and let $c_0 = 0.01$. The total number of patterns is $M = 5$, and $\{\mu_1, \mu_2, \cdots, \mu_M\}$ is a set of orthonormal bases. To satisfy Assumption 1, every token that is a noisy version of $\mu_i$ is generated from a Gaussian distribution $\mathcal{N}(\mu_i, c_0^2 \cdot I)$ with mean $\mu_i$ and covariance $c_0^2 I$, where $I \in \mathbb{R}^{d \times d}$ is the identity matrix. We set $W_K^{(0)} = W_Q^{(0)} = \frac{\delta}{2} I / c_0^2$ and $W_V^{(0)} = \frac{\sigma}{2} U / c_0^2$, and each entry of $W_O^{(0)}$ follows $\mathcal{N}(0, \xi^2)$, where $U$ is an $m_a \times m_a$ orthonormal matrix and $\xi = 0.01$. The number of neurons $m$ of $W_O$ is 1000. We set the ratio of different patterns to be the same among all the data for simplicity.

Sample complexity and convergence rate:

We first study the impact of the fraction of label-relevant patterns $\alpha_*$ on the sample complexity. Let the number of tokens after sparsification be $|\mathcal{S}^n| = 100$, the initialization error $\sigma = 0.1$, and $\delta = 0.2$. The fraction of non-discriminative patterns is fixed at 0.5. We run 20 independent experiments for each pair of $\alpha_*$ and $N$ and record the Hinge loss on the testing data. An experiment is successful if the testing loss is smaller than $10^{-3}$. Figure 1(a) shows the success rate of these experiments, where a black block means that all the trials fail and a white block means that they all succeed. The sample complexity is indeed almost linear in $\alpha_*^{-2}$, as predicted in (10). We next explore the impact of $\sigma$: we set $\alpha_* = 0.3$ and $\alpha_\# = 0.2$, with the number of tokens after sparsification fixed at 50 for all the data, and report the results in Figure 1(b).

In Figure 4, the red line with asterisks shows that the sum of the attention weights on label-relevant tokens, i.e., the left-hand side of (15) averaged over all $l$, indeed increases to be close to 1 as the number of iterations increases. Correspondingly, the sum of the attention weights on the other tokens decreases to be close to 0, as shown by the blue line with squares. This verifies Lemma 2 on a sparse attention map.

4.2. EXPERIMENTS ON IMAGE CLASSIFICATION DATASETS

Dataset: To characterize the effect of label-relevant and label-irrelevant tokens on generalization, following the image-integration setup of (Karp et al., 2021), we adopt an image from the CIFAR-10 dataset (Krizhevsky et al., 2010) as the label-relevant image pattern and integrate it with a noisy background image from the IMAGENET Plants synset (Karp et al., 2021; Deng et al., 2009), which plays the role of the label-irrelevant feature. Specifically, we randomly cut out a region of size 26×26 in the IMAGENET image and replace it with a resized CIFAR-10 image.

Architecture: Experiments are implemented on a deep ViT model. Following (Dosovitskiy et al., 2020), the network architecture contains 5 blocks, each with a 4-head self-attention layer and a one-layer perceptron with skip connections and Layer-Normalization.

We first evaluate the impact on generalization of token sparsification that removes label-irrelevant patterns to increase $\alpha_*$. We consider a ten-class classification problem where, in both the training and testing datasets, the images used for integration are randomly selected from CIFAR-10 and IMAGENET. The numbers of samples for training and testing are 50K and 10K, respectively. A model pretrained on CIFAR-100 (Krizhevsky et al., 2010) is used as the initial model, with the output layer randomly initialized. Without token sparsification, the fraction of class-relevant tokens is $\alpha_* \approx 0.66$; $\alpha_* = 1$ implies that all background tokens are removed. Figure 6(a) indicates that a larger $\alpha_*$, obtained by removing more label-irrelevant tokens, leads to higher test accuracy. Moreover, the test performance improves with more training samples. Both observations are consistent with our sample complexity analysis in (10). Figure 6(b) presents the sample complexity required to learn a model with a desirable test accuracy.
We run 10 independent experiments for each pair of $\alpha_*$ and $N$, and an experiment is considered a success if the learned model achieves a test accuracy of at least 77.5%.

We then evaluate the impact of token sparsification on removing spurious correlations (Sagawa et al., 2020), as well as the impact of the initial model. We consider a binary classification problem that differentiates "bird" and "airplane" images. To introduce spurious correlations in the training data, 90% of the bird images in the training data are integrated with the IMAGENET plant background, while only 10% of the airplane images have the plant background. The remaining training data are integrated with a clean background by zero padding. Therefore, the label "bird" is spuriously correlated with the class-irrelevant plant background. The testing data contain 50% birds and 50% airplanes, and each class has 50% plant background and 50% clean background. The numbers of training and testing samples are 10K and 2K, respectively.

We initialize the ViT using two pre-trained models. The first is pre-trained on CIFAR-100, which contains images of 100 classes, not including birds and airplanes. The other initial model is trained on a modified CIFAR-10 with 500 images per class for a total of eight classes, excluding birds and airplanes. The model pre-trained on CIFAR-100 is a better initial model because it is trained on a more diverse dataset with more samples. In Figure 6(c), the token sparsification method removes the tokens of the added background, and the corresponding $\alpha_*$ increases. Note that removing the background in the training dataset also reduces the spurious correlation between birds and plants. Figure 6(c) shows that, from both initial models, the testing accuracy increases as more background tokens are removed. Moreover, a better initial model leads to better testing performance. This is consistent with Remarks 2 and 3.
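The cut-out-and-paste integration used to build these datasets can be sketched as follows (array shapes and function names are our own; the CIFAR-10 image is assumed to be already resized to the patch size):

```python
import numpy as np

def integrate(background, patch, rng):
    """Replace a random (size x size) region of the background image
    with the resized foreground patch, as in the dataset construction."""
    size = patch.shape[0]
    H, W = background.shape[:2]
    top = int(rng.integers(0, H - size + 1))
    left = int(rng.integers(0, W - size + 1))
    out = background.copy()
    out[top:top + size, left:left + size] = patch
    return out
```

Removing the background tokens of such an integrated image leaves only the patch region, which is exactly the operation that increases $\alpha_*$ (and removes the spurious plant background) in the experiments above.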

5. CONCLUSION

This paper provides a novel theoretical generalization analysis of three-layer ViTs. Focusing on a data model with label-relevant and label-irrelevant tokens, this paper explicitly quantifies the sample complexity as a function of the fraction of label-relevant tokens and the token noise projected by the initial model. It proves that the learned attention map becomes increasingly sparse during training, with the attention weights concentrated on label-relevant tokens. Our theoretical results also offer a guideline for designing proper token sparsification methods to improve the test performance. This paper considers a simplified but representative Transformer architecture as a first step toward theoretically examining the role of the self-attention layer.

The appendix contains 6 sections. We first provide a brief comparison between our work and two other related works in Section A. In Section B, we add additional experiments to verify our theory. In Section C, we introduce some definitions and assumptions, in accordance with the main paper, for ease of the subsequent proofs. Section D states a core lemma for the proof, based on which we prove Theorem 1 and Propositions 1 and 2. Section E gives the proof of Lemma 2, with three subsections proving its three main claims. Section F presents the key auxiliary lemmas and their proofs. We finally discuss extensions of our analysis in Section G, including extensions to multi-class classification, more general data models, multi-head attention, and skip connections in Sections G.1, G.2, G.3, and G.4, respectively.

A COMPARISON WITH TWO RELATED WORKS

A.1 COMPARISON WITH (ALLEN-ZHU & LI, 2023)

Allen-Zhu & Li (2023) studies ensemble learning and knowledge distillation. Its main proof idea is that, given large amounts of multi-view data, each single model learns one feature; ensemble learning then integrates all learned features and thus improves over single models. Knowledge distillation applies softmax logits to exploit the information learned by the ensemble model, and it is analyzed in a similar way to single models. The single model considered in (Allen-Zhu & Li, 2023) is a two-layer ReLU network. In contrast, in this paper we consider a two-layer ReLU network with an additional self-attention layer. The network architecture and training algorithm for the self-attention layer are completely different from those for the softmax logits in the knowledge distillation function. In our proof, we analyze the impact of the gradients of W_Q, W_K, and W_V on different patterns (Claims 2 and 3 of Lemma 2), showing that the training process enlarges the magnitude of label-relevant features. We also show that neurons in W_O mainly learn from discriminative patterns (Claim 1 of Lemma 2). This learning process is affected by the error in the initial model and the noise in the tokens; please see "Proof idea sketch" and "Technical novelty" in Section 3.3 on Page 7 of the paper. The technique we develop plays a critical role in our analysis of self-attention layers; it is novel and does not appear in any existing work.

A.2 COMPARISON WITH (JELASSI ET AL., 2022)

Jelassi et al. (2022) is a concurrent work that also theoretically studies Vision Transformers. The major difference between (Jelassi et al., 2022) and our work is that we consider different data models and network architectures. In (Jelassi et al., 2022), the data model requires spatial association between tokens.
There, the attention map is replaced with a positional encoding, and the training of the attention map is simplified to training a linear layer. Our setup models the data mainly based on the category of patterns. We keep the classical structure and training process of self-attention, where W_Q, W_K, and W_V are trained separately. The required numbers of samples and iterations are derived as functions of the fraction of label-relevant patterns. In addition, the non-linear activation they consider is a polynomial activation, instead of ReLU or GeLU as used in practice. Based on these conditions, they are able to study a different and more general labeling function.

B MORE EXPERIMENTS

Following a setup similar to Section 4, we add more experiments. For experiments on synthetic data, we set the dimension of the data and the attention embeddings to d = m_a = m_b = 20. We vary the number of patterns M over 10, 15, and 20. Data generation and the network architecture follow the setup in Section 4. One can observe the same trends in Figures 7 and 8 as in Figures 4 and 5, respectively, indicating that our conclusions, namely that the attention map becomes sparse during training and that pruning label-irrelevant or noisy tokens improves the performance, both hold for different choices of M.
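For concreteness, the synthetic data generation can be sketched as follows. This is our own simplified reading of the data model (tokens are noisy copies of M orthonormal patterns, with a fraction of roughly α_* of them label-relevant); all names and defaults are illustrative, not the paper's code.

```python
import numpy as np

def make_sample(M=10, d=20, L=8, alpha=0.5, tau=0.05, seed=0):
    """Draw one (X, y): L tokens in R^d, each a noisy copy of one of M
    orthonormal patterns. Pattern 0 is relevant to label +1, pattern 1 to
    label -1; the remaining patterns are label-irrelevant."""
    rng = np.random.default_rng(seed)
    # Orthonormal pattern matrix (d x M), requires M <= d.
    patterns = np.linalg.qr(rng.standard_normal((d, d)))[0][:, :M]
    y = rng.choice([1, -1])
    relevant = 0 if y == 1 else 1
    # Each token carries the label-relevant pattern with probability ~alpha.
    ids = np.where(rng.random(L) < alpha, relevant, rng.integers(2, M, size=L))
    X = patterns[:, ids] + tau * rng.standard_normal((d, L))  # d x L tokens
    return X, y

X, y = make_sample()  # X: 20 x 8 token matrix for one sample
```

Varying `M` as in Figures 7 and 8 only changes the number of orthonormal patterns drawn here.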

C PRELIMINARIES

We first formally restate the neural network and the loss functions, together with Algorithm 1 for the training steps after token sparsification. The notation used in the Appendix is summarized in Table 2.

Table 2: Summary of notations
- $F(X^n)$, $\mathrm{Loss}(X^n, y^n)$: the network output for $X^n$ and the loss of a single data point.
- $\mathrm{Loss}_b$, $\mathrm{Loss}$, $\overline{\mathrm{Loss}}$: the loss of a mini-batch, the empirical loss, and the population loss, respectively.
- $p_j(t)$, $q_j(t)$, $r_j(t)$: the features in the value, key, and query vectors at iteration $t$ for pattern $j$; we have $p_j(0) = p_j$, $q_j(0) = q_j$, and $r_j(0) = r_j$.
- $z^n_j(t)$, $n^n_j(t)$, $o^n_j(t)$: the error terms in the value, key, and query vectors of the $j$-th token of the $n$-th data point, relative to their features, at iteration $t$.
- $\mathcal{W}(t)$, $\mathcal{U}(t)$: the sets of lucky neurons at the $t$-th iteration.
- $\phi_n(t)$, $\nu_n(t)$, $p_n(t)$, $\lambda$: bounds on the values of certain attention weights at iteration $t$; $\lambda$ is the threshold between inner products of tokens from the same pattern and from different patterns.
- $\alpha_*$, $\alpha_\#$, $\alpha_{nd}$: the mean fractions of label-relevant, confusion, and non-discriminative tokens, respectively.

For the network^5
$$F(X^n) = \frac{1}{|S^n|} \sum_{l \in S^n} a_{(l)}^\top \mathrm{ReLU}\big( W_O W_V X^n \,\mathrm{softmax}(X^{n\top} W_K^\top W_Q x^n_l) \big),$$
the loss of a single data point, the loss of a mini-batch, the empirical loss, and the population loss are defined as
$$\mathrm{Loss}(X^n, y^n) = \max\{1 - y^n \cdot F(X^n),\ 0\}$$
$$\mathrm{Loss}_b = \frac{1}{B} \sum_{n \in \mathcal{B}_b} \mathrm{Loss}(X^n, y^n) \quad (18)$$
$$\mathrm{Loss} = \frac{1}{N} \sum_{n=1}^{N} \mathrm{Loss}(X^n, y^n) \quad (19)$$
$$\overline{\mathrm{Loss}} = \mathbb{E}_{(X, y) \sim \mathcal{D}}[\mathrm{Loss}]$$

^5 Note that in our proofs in the Appendix, we often use the notation $\mathrm{softmax}(x^{n\top}_i W^{(t)\top}_K W^{(t)}_Q x^n_l)$, which has the same meaning as $\mathrm{softmax}(X^{n\top} W^{(t)\top}_K W^{(t)}_Q x^n_l)_i$.

The formal algorithm is as follows. We assume that every entry of $W^{(0)}_O$ is randomly initialized from $\mathcal{N}(0, \xi^2)$, where $\xi = \frac{1}{\sqrt{M}}$.
Define that $a^{(0)}_{(l)i}$, $i \in [m]$, $l \in [L]$, is initialized uniformly from $\{+\frac{1}{\sqrt{m}}, -\frac{1}{\sqrt{m}}\}$, and that $W^{(0)}_V$, $W^{(0)}_K$ and $W^{(0)}_Q$ come from a pre-trained model.
3: Stochastic gradient descent: for $t = 0, 1, \cdots, T-1$ and $W^{(t)} \in \{W^{(t)}_O, W^{(t)}_V, W^{(t)}_K, W^{(t)}_Q\}$,
$$W^{(t+1)} = W^{(t)} - \eta \cdot \frac{1}{B} \sum_{n \in \mathcal{B}_t} \nabla_{W^{(t)}} \ell(X^n, y^n; W^{(t)}_O, W^{(t)}_V, W^{(t)}_K, W^{(t)}_Q) \quad (21)$$
4: Output: $W^{(T)}_O$, $W^{(T)}_V$, $W^{(T)}_K$, $W^{(T)}_Q$.

Assumption 1 can be interpreted as saying that we initialize $W_V$, $W_K$, and $W_Q$ to matrices that map tokens to orthogonal features, up to added error terms.

Assumption 2. Define $P = (p_1, p_2, \cdots, p_M) \in \mathbb{R}^{m_a \times M}$, $Q = (q_1, q_2, \cdots, q_M) \in \mathbb{R}^{m_b \times M}$ and $R = (r_1, r_2, \cdots, r_M) \in \mathbb{R}^{m_b \times M}$ as three feature matrices, where $\{p_1, \cdots, p_M\}$, $\{q_1, \cdots, q_M\}$ and $\{r_1, \cdots, r_M\}$ are three sets of orthonormal bases. Define the noise terms $z^n_j(t)$, $n^n_j(t)$ and $o^n_j(t)$ with $\|z^n_j(0)\| \le \sigma + \tau$ and $\|n^n_j(0)\|, \|o^n_j(0)\| \le \delta + \tau$ for $j \in [L]$, and let $q_1 = r_1$, $q_2 = r_2$. Suppose $\|W^{(0)}_V\|, \|W^{(0)}_K\|, \|W^{(0)}_Q\| \le 1$, $\sigma, \tau < O(1/M)$ and $\delta < 1/2$. Then, for $x^n_l \in S^n_j$,
1. $W^{(0)}_V x^n_l = p_j + z^n_j(0)$;
2. $W^{(0)}_K x^n_l = q_j + n^n_j(0)$;
3. $W^{(0)}_Q x^n_l = r_j + o^n_j(0)$.

Assumption 2 is a straightforward combination of Assumption 1 and (5), obtained by applying the triangle inequality to bound the error terms for tokens.

Definition 1.
1. $\phi_n(t) = \frac{1}{|S^n_1| e^{\|q_1(t)\|^2 + (\delta+\tau)\|q_1(t)\|} + |S^n| - |S^n_1|}$.
2. $\nu_n(t) = \frac{1}{|S^n_1| e^{\|q_1(t)\|^2 - (\delta+\tau)\|q_1(t)\|} + |S^n| - |S^n_1|}$.
3. $p_n(t) = |S^n_1| e^{\|q_1(t)\|^2 - (\delta+\tau)\|q_1(t)\|}\, \nu_n(t)$.
4. $S^n_* = S^n_1$ if $y^n = 1$ and $S^n_* = S^n_2$ if $y^n = -1$; $S^n_\# = S^n_2$ if $y^n = 1$ and $S^n_\# = S^n_1$ if $y^n = -1$.
5. $\alpha_* = \mathbb{E}\big[\frac{|S^n_*|}{|S^n|}\big]$, $\alpha_\# = \mathbb{E}\big[\frac{|S^n_\#|}{|S^n|}\big]$, $\alpha_{nd} = \sum_{l=3}^{M} \mathbb{E}\big[\frac{|S^n_l|}{|S^n|}\big]$.

Definition 2. Let $\theta^i_1$ be the angle between $p_1$ and $W_{O_{(i,\cdot)}}$, and let $\theta^i_2$ be the angle between $p_2$ and $W_{O_{(i,\cdot)}}$.
Define $\mathcal{W}(t)$, $\mathcal{U}(t)$ as the sets of lucky neurons at the $t$-th iteration such that
$$\mathcal{W}(t) = \{i : \theta^i_1 \le \sigma + \tau,\ i \in [m]\} \quad (22)$$
$$\mathcal{U}(t) = \{i : \theta^i_2 \le \sigma + \tau,\ i \in [m]\} \quad (23)$$

Assumption 3. For one data point $X^n$, if the patches $i$ and $j$ correspond to the same feature $k \in [M]$, i.e., $i \in S^n_k$ and $j \in S^n_k$, we have
$$x^{n\top}_i x^n_j \ge 1 \quad (24)$$
If the patches $i$ and $j$ correspond to different features $k, l \in [M]$, $k \ne l$, i.e., $i \in S^n_k$ and $j \in S^n_l$, we have
$$x^{n\top}_i x^n_j \le \lambda < 1 \quad (25)$$
This assumption is equivalent to the data model in (5) since $\tau < O(1/M)$. For simplicity of presentation, we scale all tokens up slightly so that the threshold of linear separability is $1$. We also take $1 - \lambda$ and $\lambda$ to be $\Theta(1)$ for simplicity.

Definition 3 (Vershynin, 2010). We say $X$ is a sub-Gaussian random variable with sub-Gaussian norm $K > 0$ if $(\mathbb{E}|X|^p)^{\frac{1}{p}} \le K\sqrt{p}$ for all $p \ge 1$. In addition, the sub-Gaussian norm of $X$, denoted $\|X\|_{\psi_2}$, is defined as $\|X\|_{\psi_2} = \sup_{p \ge 1} p^{-\frac{1}{2}} (\mathbb{E}|X|^p)^{\frac{1}{p}}$.

Lemma 1 (Vershynin (2010), Proposition 5.10; Hoeffding-type inequality). Let $X_1, X_2, \cdots, X_N$ be independent centered sub-Gaussian random variables, and let $K = \max_i \|X_i\|_{\psi_2}$. Then for every $a = (a_1, \cdots, a_N) \in \mathbb{R}^N$ and every $t \ge 0$, we have
$$\mathbb{P}\Big( \Big| \sum_{i=1}^{N} a_i X_i \Big| \ge t \Big) \le e \cdot \exp\Big( -\frac{c t^2}{K^2 \|a\|^2} \Big) \quad (26)$$
where $c > 0$ is an absolute constant.
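The restated network output and losses map directly to code. The following is a minimal NumPy sketch written by us (shapes follow the notation above, but the variable names and implementation are our assumptions); it computes $F(X^n)$ and the hinge loss of a single data point, and omits training.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def softmax(v):
    e = np.exp(v - v.max())  # stabilized softmax over one token's logits
    return e / e.sum()

def forward(X, a, WO, WV, WK, WQ):
    """F(X) = (1/|S|) * sum_l a_(l)^T ReLU(WO WV X softmax(X^T WK^T WQ x_l)).

    X: d x L token matrix; a: L x m output weights (one row a_(l) per
    token); WO: m x ma; WV: ma x d; WK, WQ: mb x d.
    """
    _, L = X.shape
    out = 0.0
    for l in range(L):
        attn = softmax(X.T @ WK.T @ WQ @ X[:, l])  # attention weights of x_l
        out += a[l] @ relu(WO @ WV @ X @ attn)     # per-token contribution
    return out / L

def hinge_loss(f_value, y):
    """Loss(X, y) = max{1 - y * F(X), 0}."""
    return max(1.0 - y * f_value, 0.0)
```

A full training run would wrap these in the mini-batch SGD update (21).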

D PROOF OF THE MAIN THEOREM AND PROPOSITIONS

We state Lemma 2 first before we introduce the proof of main theorems. Lemma 2 is the key lemma in our paper to show the training process of our ViT model using SGD. It has three major claims. Claim 1 involves the growth of W Lemma 2. For l ∈ S n 1 for the data with y n = 1, define V n l (t) =W (t) V X n softmax(X n⊤ W (t) K ⊤ W (t) Q x n l ) = s∈S1 softmax(x n s ⊤ W (t) K ⊤ W (t) Q x n l )p 1 + z(t) + j̸ =1 W n j (t)p j -η t b=1 ( i∈W(b) V i (b)W (b) O (i,•) ⊤ + i / ∈W(b) V i (b)λW (b) O (i,•) ⊤ ) (27) with W n l (t) ≤ ν n (t)|S n j | (28) V i (t) ≲ 1 2B n∈B b + - |S n 1 | mL p n (t), i ∈ W(t)] V i (t) ≳ 1 2B n∈B b - |S n 2 | mL p n (t), i ∈ U(t) (30) V i (t) ≥ - 1 √ Bm , if i is an unlucky neuron. ( ) We also have the following claims: Claim 1. For the lucky neuron i ∈ W(t) and b ∈ [T ], we have W (t) O (i,•) p 1 ≳ 1 Bt t b=1 n∈B b η 2 tbm |S n |a 2 |S n 1 |∥p 1 ∥ 2 p n (b) + ξ(1 -(σ + τ )) W (t) O (i,•) p ≤ ξ∥p∥, for p ∈ P/p 1 , ( 1 Bt t b=1 n∈B b η 2 tbm|S n 1 | a 2 |S n | ∥p∥ 2 p n (b)) 2 ≤ ∥W (t) O (i,•) ∥ 2 ≤ M ξ 2 ∥p∥ 2 + ( η 2 t 2 m a 2 ) 2 ∥p∥ 2 (34) and for the noise z l (t), ∥W O (i,•) z l (t)∥ ≤ ((σ + τ ))( √ M ξ + η 2 t 2 m a 2 )∥p∥ For i ∈ U(t), we also have equations as in ( 32) to ( 35), including W (t) O (i,•) p 2 ≳ 1 Bt t b=1 n∈B b η 2 tbm|S n 2 | |S n |a 2 ∥p 2 ∥ 2 p n (b) + ξ(1 -(σ + τ )) W (t) O (i,•) p ≤ ξ∥p∥, for p ∈ P/p 2 , ( ) ( 1 Bt t b=1 n∈B b η 2 tb|S n 2 |m a 2 |S n | ∥p∥ 2 p n (b)) 2 ≤ ∥W (t) O (i,•) ∥ 2 ≤ M ξ 2 ∥p∥ 2 + ( η 2 t 2 m a 2 ) 2 ∥p∥ 2 (38) and for the noise z l (t), ∥W O (i,•) z l (t)∥ ≤ ((σ + τ ))( √ M ξ + η 2 t 2 m a 2 )∥p∥ For unlucky neurons, we have W (t) O (i,•) p ≤ ξ∥p∥, for p ∈ P/{p 1 , p 2 } (40) ∥W (t) O (i,•) z l (t)∥ ≤ ((σ + τ )) √ M ξ∥p∥ (41) ∥W (t) O (i,•) ∥ 2 ≤ M ξ 2 ∥p∥ 2 (42) Claim 2. 
Given conditions in (8), there exists K(t), Q(t) > 0, where t is large enough before the end of training, such that for j ∈ S n * , softmax(x n j ⊤ W (t+1) K W (t+1) Q x n l ) ≳ e (1+K(t))∥q1(t)∥ 2 -(δ+τ )∥q1(t)∥ |S n 1 |e (1+K(t))∥q1(t)∥ 2 -(δ+τ )∥q1(t)∥ + (|S n | -|S n 1 |) (43) softmax(x n j ⊤ W (t+1) K ⊤ W (t+1) Q x n j ) -softmax(x n j ⊤ W (t) K ⊤ W (t) Q x n l ) ≳ |S n | -|S n 1 | (|S n 1 |e (1+K(t))∥q1(t)∥ 2 -(δ+τ )∥q1(t)∥ + (|S n | -|S n 1 |)) 2 e ∥q1(t)∥ 2 -(δ+τ )∥q1(t)∥ • K(t), and for j / ∈ S n * , we have softmax(x n j ⊤ W (t+1) K ⊤ W (t+1) Q x n l ) ≲ 1 |S n 1 |e (1+K(t))∥q1(t)∥ 2 -δ∥q1(t)∥ + (|S n | -|S n 1 |) (45) softmax(x n j ⊤ W (t+1) K W (t+1) Q x n l ) -softmax(x n j ⊤ W (t) K ⊤ W (t) Q x n l ) ≲ - |S n 1 | (|S n 1 |e (1+K(t))∥q1(t)∥ 2 -δ∥q1(t)∥ + (|S n | -|S 1 |)) 2 e ∥q1(t)∥ 2 -δ∥q1(t)∥ • K(t) For i = 1, 2, q i (t) = t-1 l=0 (1 + K(l))q i (47) r i (t) = t-1 l=0 (1 + Q(l))r i ( ) Claim 3. For the update of W (t) V , there exists λ ≤ Θ(1) such that W (t) V x n j = p 1 -η t b=1 ( i∈W(b) V i (b)W (b) O (i,•) ⊤ + i / ∈W(b) λV i (b)W (b) O (i,•) ⊤ ) + z j (t), j ∈ S n 1 (49) W (t) V x n j = p 1 -η t b=1 ( i∈U (b) V i (b)W (b) O (i,•) ⊤ + i / ∈U (b) λV i (b)W (b) O (i,•) ⊤ ) + z j (t), j ∈ S n 2 (50) W (t+1) V x n j = p 1 -η t b=1 m i=1 λV i (b)W (b) O (i,•) ⊤ + z j (t), j ∈ [|S n |]/(S n 1 ∪ S n 2 ) (51) ∥z j (t)∥ ≤ (σ + τ ) To prove Theorem 1, we either show F (X n ) > 1 for y n = 1 or show F (X n ) < -1 for y n = -1. Take y n = 1 as an example, the basic idea of the proof is to make use of Lemma 2 to find a lower bound as a function of α * , σ, τ , etc.. The remaining step is to derive conditions on the sample complexity and the required number of iterations in terms of α * , σ, and τ such that the lower bound is greater than 1. Given a balanced dataset, these conditions also ensure that F (X n ) < -1 for y n = -1. During the proof, we may need to use some of equations as intermediate steps in the proof of Lemma 2. 
Since that these equations are not concise for presentation, we prefer not to state them formally in Lemma 2, but still refer to them as useful conclusions. The following is the details of the proof. Proof of Theorem 1: For y n = 1, define K l + = {i ∈ [m] : W (t) O (i,•) V n l (t) ≥ 0} and K l -= {i ∈ [m] : W (t) O (i,•) V n l (t) < 0}. We have F (X n ) = 1 |S n | l∈S n i∈W(t) 1 m Relu(W O (i,•) V n l (t)) + 1 |S n | l∈S n i∈K l + /W(t) 1 m Relu(W O (i,•) V n l (t)) - 1 |S n | l∈S n i∈K l - 1 m Relu(W (t) O (i,•) V n l (t)) By Lemma 2, we have 1 |S n | l∈S n i∈W(t) 1 m Relu(W (t) O (i,•) V n l (t)) = 1 |S n | l∈S n 1 i∈W(t) 1 m Relu(W (t) O (i,•) V n l (t)) + l / ∈S n 1 i∈W(t) 1 m Relu(W (t) O (i,•) V n l (t)) ≳|S n 1 | 1 a|S n | • W (t) O (i,•) s∈S n 1 p s softmax(x n s ⊤ W (t) K ⊤ W (t) Q x n l ) + z(t) + l̸ =s W l (u)p l -ηt( j∈W(t) V j (t)W (t) O (j,•) ⊤ + j / ∈W(t) V j (t)λW (t) O (j,•) ⊤ ) |W(t)| + 0 ≳ |S n 1 |m |S n |aM (1 -ϵ m - (σ + τ )M π ) 1 Bt t b=1 n∈B b η 2 t 2 m a 2 ( b|S n * | t|S n | ∥p 1 ∥ 2 p n (b) -(σ + τ ))p n (t) + 1 Bt t b=1 n∈B b |S n 1 |p n (b)ηtm |S n |aM (1 -ϵ m - (σ + τ )M π ) • ( 1 Bt t b=1 n∈B b η 2 tbm|S n 1 | a 2 |S n | p n (b)) 2 ∥p 1 ∥ 2 p n (t) where the second step comes from ( 27) and the last step is by (136) . 
By the definition of K l + , we have 1 |S n | l∈S n i∈K l + /W(t) 1 m Relu(W (t) O (i,•) V n l (t)) ≥ 0 (55) Combining ( 136) and ( 138), we can obtain 1 |S n | l∈S n i∈K l - 1 m Relu(W (t) O (i,•) V n l (t)) ≤ |S n 2 |m |S n |aM • (1 -ϵ m - (σ + τ )M π )(ξ∥p∥ + 1 Bt t b=1 n∈B b |S n 1 |p n (b)m |S n |aM • (1 -ϵ m - (σ + τ )M π )( η 2 t 2 m a 2 ) 2 (σ + τ )∥p∥ + η 2 λ √ M ξt 2 m √ Ba 2 ∥p 1 ∥ 2 + ((σ + τ ))( √ M ξ + η 2 t 2 m a 2 ) + 1 Bt t b=1 n∈B b |S n 2 |p n (b)ηtm |S n |aM • (1 -ϵ m - (σ + τ )M π )( η 2 t 2 m a 2 ) 2 ∥p 1 ∥ 2 )ϕ n (t)|S n 2 | + M l=3 |S n l | |S n | (ξ∥p∥ + η 2 λ √ M ξt 2 m √ Ba 2 ∥p∥ 2 + ((σ + τ ))( √ M ξ + η 2 t 2 m a 2 )∥p∥ + 1 Bt t b=1 n∈B b (|S n 2 | + |S n 1 |)p n (t)ηtm |S n |aM (1 -ϵ m - (σ + τ )M π ) • η 2 t 2 m a 2 ξ∥p 1 ∥ 2 )ϕ n (t)(|S n | -|S n 1 |) + |S n 1 | |S n | c M 1 1 Bt t b=1 n∈B b η 2 tbm|S n * | a 2 |S n | ∥p 1 ∥ 2 p n (t) • |(σ + τ ) -p n (t)| + 1 Bt t b=1 n∈B b |S n 1 |p n (t)ηtm |S n |aM c M 2 ( 1 Bt t b=1 n∈B b η 2 tbm|S n 1 | a 2 |S n | • p n (b)) 2 ∥p 1 ∥ 2 p n (t) (56) for some c 1 , c 2 ∈ (0, 1).

Note that at the T-th iteration, K(t) ≳η 1 Bt t b=1 n∈B b |S n 1 |p n (t)ηtm |S n |aM (1 -ϵ m - (σ + τ )M π )( 1 Bt t b=1 n∈B b η 2 tbm|S n 1 | a 2 |S n | p n (b)) 2 ∥p 1 ∥ 2 p n (t) + 1 Bt t b=1 n∈B b η 2 tbm a 2 ( |S n 1 | |S n | p n (b) -(σ + τ ))∥p 1 ∥ 2 p n (t) ϕ n (t)(|S n | -|S n 1 |)∥q 1 (t)∥ 2 ≳ η e ∥q1(t)∥ 2 -(δ+τ )∥q1(t)∥ (57) Since q 1 (T ) ≳ (1 + min l=0,1,••• ,T -1 {K(l)}) T ≳ (1 + η e ∥q1(T )∥ 2 -(δ+τ )∥q1(T )∥ ) T (58) to find the order-wise lower bound of q 1 (T ), we need to check the equation q 1 (T ) ≲ (1 + 1 e ∥q1(T )∥ 2 -(δ+τ )∥q1(T )∥ ) T (59) One can obtain Θ(log T ) = ∥q 1 (T )∥ 2 = o(T ) (60) Therefore, p n (T ) ≳ T C T C + 1-α α ≥ 1 - 1 α 1-α (η -1 ) C ≥ 1 -Θ(η C ) (61) ϕ n (T )(|S n | -|S n 1 |) ≤ η C (62) for some large C > 0. We require that (64) |S n 1 |m |S n |aM (1 -ϵ m - (σ + τ )M π ) 1 BT T b=1 n∈B b |S n 1 |p n (b)ηtm |S n |aM (1 -ϵ m - (σ + τ )M π ) • ( 1 BT T b=1 n∈B b η 2 tbm|S n 1 | a 2 |S n | p n (b)) 2 ∥p 1 ∥ 2 p n (b) + 1 BT T b=1 n∈B b η 2 t 2 m a 2 ( b|S n * | t|S n | ∥p 1 ∥ 2 p n (b) -(σ + τ ))p n (b) ≳ |S n 1 |m |S n |aM (1 -ϵ m - (σ + τ )M π ) 1 N N n=1 |S n 1 |p n (T )ηT m |S n |aM (1 -ϵ m - (σ + τ )M π ) • ( 1 N N n=1 η 2 T 2 m|S n 1 | a 2 |S n | • p n (T )) 2 ∥p 1 ∥ 2 p n (T ) + 1 N N n=1 η 2 T 2 m a 2 ( |S n * | |S n | ∥p 1 ∥ 2 p n (T ) -(σ + τ ))p n (T ) :=a 0 (ηT ) 5 + a 1 (ηT ) 2 >1. We also require η 2 λ √ M ξt 2 m √ Ba 2 ≤ ϵ 0 for some ϵ 0 > 0. We know that 1 N N n=1 |S n * | |S n | p n (T )(p n (T ) -(σ + τ )) -E |S n * | |S n | ≤ 1 N N n=1 |S n * | |S n | p n (T )(p n (T ) -(σ + τ )) -E |S n * | |S n | p n (T )(p n (T ) -(σ + τ )) + E |S n * | |S n | p n (T )(p n (T ) -(σ + τ )) -1 ≲ log N N + c ′ (1 -ζ) + c ′′ ((σ + τ )) for c ′ > 0 and c ′′ > 0.
We can then have t ≥ T = η -1 a 1 = η -1 |S n 1 | |S n | (1 -ϵ m -(σ+τ )M π ) 1 N N n=1 ( |S n * | |S n | ∥p 1 ∥ 2 p n (t) -(σ + τ ))p n (t) =Θ( η -1 (1 -ϵ m -(σ+τ )M π )(E |S n * | |S n | -log N N -c ′ (1 -ζ) -c ′′ (σ + τ )) ) =Θ( η -1 (1 -ϵ m -(σ+τ )M π )E |S n * | |S n | ) (67) where α ≥ 1 -α nd 1 + ϵ S e -(δ+τ ) (1 -(σ + τ )) (68) by ( 157), as long as N ≥ Ω( 1 (α -c ′ (1 -ζ) -c ′′ ((σ + τ ))) 2 ) ( ) and B ≳ 1 ((1 -(τ + σ))e -(δ+τ ) 1 2 α * -(τ + σ)e -(δ+τ ) ) 2 = Θ( 1 (α * -e -(δ+τ ) (τ + σ)) 2 ) = N Θ(1) (70) where ζ ≥ 1 -η 10 . If there is no mechanism like the self-attention to compute the weight using the correlations between tokens, we have c ′ (1 -ζ) = O(α * (1 -α * )), which can scale up the sample complexity in ( 69) by α -2 * . Therefore, we can obtain F (X n ) > 1 (72) Similarly, we can derive that for y = -1, F (X) < -1 (73) Hence, for all n ∈ [N ], Loss(X n , y n ) = 0 We also have Loss = E (X n ,y n )∼D [Loss(X n , y n )] = 0 (75) with the conditions of sample complexity and the number of iterations.

Proof of Proposition 1:

The main proof is the same as the proof of Theorem 1. The only difference is that we need to modify (66) as follows 1 N N n=1 |S n * | |S n | p n (T )(p n (T ) -(σ + τ )) -E |S n * | |S n | ≤ 1 N N n=1 |S n * | |S n | p n (0)(p n (0) -(σ + τ )) -E |S n * | |S n | p n (0)(p n (T ) -(σ + τ )) + E |S n * | |S n | p n (0)(p n (0) -(σ + τ )) -1 ≲ log N N + |1 -Θ(α 2 * ) + Θ(α * )(σ + τ )| (76) where the first step is because p n (T ) does not update since W (t) K and W (t) Q are fixed at initialization W (0) K and W (0) Q , and the second step is by p n (0) = Θ(α * ). Since that log N N + |1 -Θ(α 2 * ) + Θ(α * )(σ + τ )| ≤ Θ(1) • α * , we have N ≥ 1 (Θ(α * ) -1 + Θ(α 2 * ) -Θ(α * )(σ + τ )) 2 = Ω( 1 (α * (α * -σ -τ )) 2 ) (78)

Proof of Proposition 2:

It can be easily derived from Claim 2 of Lemma 2, together with (60) and (61).

E PROOF OF LEMMA 2

We prove the whole lemma by a long induction, which is why we prefer to wrap the three claims into one lemma. To make the argument easier to follow, however, we break this section into three parts and prove the three claims of Lemma 2 separately.
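Before the claim-by-claim proofs, the sparsification effect established in Claim 2 can be illustrated in isolation: when label-relevant tokens have attention logits of roughly $\|q_1(t)\|^2$ and label-irrelevant tokens have logits near zero (orthogonal features), the softmax weight on the relevant tokens approaches 1 as $\|q_1(t)\|$ grows during training. A small numeric sketch (ours, with hypothetical numbers):

```python
import numpy as np

def attention_weights(q_norm_sq, n_relevant, n_total):
    """Softmax attention when the n_relevant label-relevant tokens have
    logits ~ ||q1||^2 and the other tokens have logits ~ 0 (orthogonal
    features), mirroring the softmax bounds in Claim 2."""
    logits = np.zeros(n_total)
    logits[:n_relevant] = q_norm_sq
    e = np.exp(logits - logits.max())
    return e / e.sum()

# Total weight on label-relevant tokens grows toward 1 as ||q1(t)||^2 grows:
early = attention_weights(0.0, n_relevant=2, n_total=10)[:2].sum()  # 0.2
late = attention_weights(5.0, n_relevant=2, n_total=10)[:2].sum()   # ~0.97
```

This only shows the mechanism; the proof below tracks how training actually increases $\|q_1(t)\|$.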

E.1 PROOF OF CLAIM 1 OF LEMMA 2

Although it looks cumbersome, the key idea of Claim 1 is to characterize the growth of W (t) O in terms of p l , l ∈ [M ]. We compare W (t+1) O (i,•) p l and W (t) O (i,•) p l to see the direction of growth by computing the gradient. One can eventually find that lucky neurons grow the most in directions of p 1 and p 2 , i.e., the feature of label-relevant patterns, while unlucky neurons do not change much in magnitude. We start our poof. At the t-th iteration, if l ∈ S n 1 , let V n l (t) = W (t) V X n softmax(X n⊤ W (t) K ⊤ W (t) Q x n l ) = s∈S1 softmax(x n s ⊤ W (t) K ⊤ W (t) Q x n l )p 1 + z(t) + j̸ =1 W n j (t)p j -η( i∈W(t) V i (t)W (t) O (i,•) ⊤ + i / ∈W(t) V i (t)λW (t) O (i,•) ⊤ ) (79) , l ∈ [M ], where the second step comes from ( 27). Then we have W n l (t) ≤ |S n j |e δ∥q1(t)∥ (|S n | -|S n 1 |)e δ∥q1(t)∥ + |S n 1 |e ∥q1(t)∥ 2 -δ∥q1(t)∥ = ν n (t)|S n j | Hence, by ( 18), ∂Loss b ∂W O (i,•) ⊤ = - 1 B n∈B b y n 1 |S n | l∈S n a (l)i 1[W O (i,•) V n l (t) ≥ 0]V n l (t) ⊤ ( ) Define that for j ∈ [M ], I 4 = 1 B n∈B b ηy n 1 |S n | l∈S n a (l)i 1[W (t) O (i,•) V n l (t) ≥ 0] k∈W(t) V k (t)W (t) O (k,•) p j ( ) I 5 = 1 B n∈B b ηy n 1 |S n | l∈S n a (l)i 1[W (t) O (i,•) V n l (t) ≥ 0] k / ∈W(t) V k (t)W (t) O (k,•) p j , and we can then obtain W (t+1) O (i,•) ⊤ , p j -W (t) O (i,•) ⊤ , p j = 1 B n∈B b ηy n 1 |S n | l∈S n a (l)i 1[W (t) O (i,•) V n l (t) ≥ 0]V n l (t) ⊤ p j = 1 B n∈B b ηy n 1 |S n | l∈S n a (l)i 1[W (t) O (i,•) V n l (t) ≥ 0]z l (t) ⊤ p j + 1 B n∈B b ηy n 1 |S n | l∈S n a (l)i 1[W (t) O (i,•) V n l (t) ≥ 0] s∈S l softmax(x ⊤ s W (t) K ⊤ W (t) Q x l )p ⊤ l p j + 1 B n∈B b ηy n 1 |S n | l∈S n a (l)i 1[W (t) O (i,•) V n l (t) ≥ 0] k̸ =l W l (t)p ⊤ k p j + I 4 + I 5 :=I 1 + I 2 + I 3 + I 4 + I 5 , where I 1 = 1 B n∈B b ηy n 1 |S n | l∈S n a (l)i 1[W (t) O (i,•) V n l (t) ≥ 0]z l (t) ⊤ p j (85) I 2 = 1 B n∈B b ηy n 1 |S n | l∈S n a (l)i 1[W (t) O (i,•) V n l (t) ≥ 0] s∈S l softmax(x ⊤ s W (t) K ⊤ W (t) Q x l )p ⊤ l p j ( ) I 3 = 1 B n∈B b ηy n 1 |S 
n | l∈S n a (l)i 1[W (t) O (i,•) V n l (t) ≥ 0] k̸ =l W l (t)p ⊤ k p j ( ) We then show the statements in different cases. (1) When j = 1, since that Pr(y n = 1) = Pr(y n = -1) = 1/2, by Hoeffding's inequality in ( 26), we can derive Pr 1 B n∈B b y n ≥ log B B ≤ B -c (88) Pr z l (t) ⊤ p 1 ≥ ((σ + τ )) 2 log m ≤ m -c (89) Hence, with probability at least 1 -(mB) -c , we have |I 1 | ≤ η((σ + τ )) a log m log B B ( ) For i ∈ W(t), from the derivation in ( 133) later, we have W (t) O (i,•) L s=1 W (t) V x n s softmax(x n s ⊤ W (t) K ⊤ W (t) Q x n l ) > 0 (91) Denote p n (t) = |S n 1 |ν n (t)e ∥q1(t)∥ 2 -2δ∥q1(t)∥ . Hence, I 2 ≳ η • 1 B n∈B b |S n 1 | -|S n 2 | |S n | • 1 a ∥p 1 ∥ 2 • p n (t) ≳ η 1 B n∈B b |S n 1 | |S n | • 1 a ∥p 1 ∥ 2 • p n (t) I 3 = 0 (93) I 4 ≳ 1 B t b=1 n∈B b η 2 b|S n 1 | |S n |a 1 2B n∈B b |S n 1 |m |S n |aM p n (t)∥p 1 ∥ 2 (1 -ϵ m - (σ + τ )M π )W O (i,•) p 1 (94) |I 5 | ≲ 1 B T b=1 n∈B b η 2 b|S n 1 | |S n |a (1 -ϵ m - (σ + τ )M π ) 1 2B n∈B b |S n 2 |m |S n |aM p n (t)∥p 1 ∥ 2 W O (i,•) p 2 + η 2 tm √ Ba 2 W O (i,•) p M (1 + (σ + τ )) Hence, combining ( 90), ( 92), ( 93), ( 94), and (95), we can obtain W (t+1) O (i,•) ⊤ , p 1 -W (t) O (i,•) ⊤ , p 1 ≳ η a • 1 B n∈B b ( |S n 1 | |S n | p n (t) -((σ + τ )) + ηt|S n 1 | |S n | 1 2B n∈B b |S n 1 |m |S n |aM p n (t)(1 -ϵ m - (σ + τ )M π ) • W O (i,•) p 1 (1 -(σ + τ )) - ηt|S n 1 | |S n | 1 2B n∈B b |S n 2 |m |S n |aM p n (t)(1 -ϵ m - (σ + τ )M π ) • W O (i,•) p 2 (1 + (σ + τ )) - ηtmW O (i,•) p M (1 + (σ + τ )) √ Ba )∥p 1 ∥ 2 ≳ η aB n∈B b ( |S n 1 | |S n | p n (t) -((σ + τ )) + ηt|S n 1 | |S n | 1 2B n∈B b |S n 1 |m |S n |aM p n (t) • (1 -ϵ m - (σ + τ )M π )W O (i,•) p 1 )∥p 1 ∥ 2 (96) Since that W (0) O (i,•) ∼ N (0, ξ 2 I ma ) , by the standard property of Gaussian distribution, we have Pr(∥W (0) O (i,•) ∥ ≤ ξ) ≤ ξ (97) Therefore, with high probability for all i ∈ [m], we have ∥W (0) O (i,•) ∥ ≳ ξ (98) Therefore, we can derive W (t+1) O (i,•) p 1 ≳ exp( 1 B(t + 1) t+1 b=1 n∈B b η 2 b(t + 1)m |S n 
|a 2 |S n 1 |∥p 1 ∥ 2 p n (b)) + ξ(1 -(σ + τ )) ≳ exp( 1 B n∈B b η 2 (t + 1) 2 m |S n |a 2 |S n 1 |∥p 1 ∥ 2 p n (t)) + ξ(1 -(σ + τ )) by verifying that η a + η 2 tm a 2 exp(( 1 Θ(1) • η 2 t 2 m a 2 ) -1 + ξ) ≥ exp( 1 Θ(1) • η 2 t 2 m a 2 )(exp( η 2 (2t + 1)m Θ(1) • a 2 ) -1) ≳ exp( 1 Θ(1) • η 2 t 2 m a 2 ) η 2 tm 2a 2 (100) When ηt < a m , we have η a + η 2 tm a 2 (-1 + ξ) ≥ 0 (101) When ηt ≥ a m , we have that g(t) := η 2 tm a 2 ( 1 2 exp( η 2 t 2 m a 2 Θ(1) ) -1 + ξ) + η a ≥ g( a ηm ) > 0 (102) since that g(t) is monotonically increasing. Hence, (99) is verified. Since that ηt ≤ O(1), (103) to simplify the further analysis, we will use the bound W (t+1) O (i,•) p 1 ≳ 1 Bt t+1 b=1 n∈B b η 2 (t + 1)bm |S n |a 2 |S n 1 |∥p 1 ∥ 2 p n (b) + ξ(1 -(σ + τ )) Note that this bound does not order-wise affect the final result of the required number of iterations. (2) When p j ∈ P/p + , we have I 2 = 0 (105) |I 3 | ≤ 1 B n∈B b ν n (t) η|S n l | a log m log B B ∥p∥ 2 (106) |I 4 | ≤ η 2 a t b=1 log m log B B 1 2B t b=1 n∈B b |S n 1 |ηbm |S n |aM p n (b)( (ηt) 2 m a 2 + ξ)∥p∥ (107) |I 5 | ≲ η 2 tm √ Ba 2 ξ∥p∥ 2 + η 2 a t b=1 log m log B B 1 2B n∈B b |S n 2 |m |S n |aM p n (t)ξ∥p∥ with probability at least 1 -(mB) -c . (107) comes from (34). Then, combining (90), ( 105), ( 106), ( 107) and ( 108), we can obtain W (t+1) O (i,•) ⊤ , p j -W (t) O (i,•) ⊤ , p j ≲ η a • 1 B n∈B b ( |S n l | |S n | |S n l |ν n (t) + ((σ + τ )) + t b=1 |S n 1 |p n (b)ηm |S n |aM ( η 2 t 2 m a 2 + ξ)) log m log B B ∥p∥ 2 + η 2 tm √ Ba 2 ξ∥p∥ (109) Furthermore, we have W (t+1) O (i,•) p j ≲ η a (t+1) b=1 • 1 B n∈B b ( |S n l | |S n | |S n l |ν n (b) + ((σ + τ )) + t b=1 |S n 1 |p n (b)ηm |S n |aM ( η 2 t 2 m + ξ)) log m log B B ∥p∥ + η 2 t 2 m √ Ba 2 ξ∥p∥ + ξ∥p∥ ≤ξ∥p∥ (110) where the last step is by ηt ≤ O(1) (111) to ensure a non-zero gradient. 
(3) If i ∈ U(t), following the derivation of ( 104) and ( 110), we can conclude that W (t+1) O (i,•) p 2 ≳ 1 B(t + 1) t+1 b=1 n∈B b η 2 (t + 1)bm|S n 2 | |S n |a 2 ∥p 2 ∥ 2 p n (b) + ξ(1 -(σ + τ )) W (t) O (i,•) p ≤ ξ∥p∥, for p ∈ P/p 2 , ( ) (4) If i / ∈ (W(t) ∪ U(t)), |I 2 + I 3 | ≤ η a log m log B B ∥p∥ 2 Following ( 107) and ( 108), we have |I 4 | ≤ t b=1 η 2 a log m log B B 1 2B n∈B b |S n 1 |m |S n |aM p n (b)( η 2 t 2 m a 2 + ξ)∥p∥ |I 5 | ≲ η 2 tm √ Ba 2 ξ∥p∥ 2 + t b=1 η 2 a log m log B B 1 2B n∈B b |S n 2 |m |S n |aM p n (b)ξ∥p∥ Hence, combining (114), (115), and (116), we can obtain W (t+1) O (i,•) ⊤ , p -W (t) O (i,•) ⊤ , p ≲ η a • (∥p∥ + ((σ + τ )) + t b=1 |S n 1 |p n (b)ηm |S n |M a ( η 2 t 2 m a 2 + ξ)) log m log B B ∥p∥ + η 2 tm √ Ba 2 ξ∥p∥ 2 , (t+1) O (i,•) p ≲ t+1 b=1 η a • (∥p∥ + ((σ + τ )) + t b=1 |S n 1 |p n (b)ηm |S n |aM ( η 2 t 2 m a 2 + ξ)) • log m log B B ∥p∥ + η 2 t 2 m √ Ba 2 ξ∥p∥ 2 + ξ∥p∥ ≤ξ∥p∥ (118) where the last step is by ηt ≤ O(1) (5) We finally study the bound of W O (i,•) and the product with the noise term according to the analysis above. By (52), for the lucky neuron i, since that the update of W  O (i,•) ∥ 2 = M l=1 (W (t+1) O (i,•) p l ) 2 ≥ (W (t+1) O (i,•) p 1 ) 2 ≳( 1 B(t + 1) t+1 b=1 n∈B b η 2 (t + 1)b|S n 1 |m a 2 |S n | ∥p∥ 2 p n (b)) 2 (120) ∥W (t+1) O (i,•) ∥ 2 ≤ M ξ 2 ∥p∥ 2 + ( η 2 (t + 1) 2 m a 2 ) 2 ∥p∥ 2 (121) ∥W (t+1) O (i,•) z l (t)∥ ≤ ((σ + τ )) p∈P W (t) O (i,•) ⊤ , p 2 ≤((σ + τ ))( √ M ξ + η 2 (t + 1) 2 m a 2 )∥p∥ For the unlucky neuron i, we can similarly obtain |W (t+1) O (i,•) z l (t)| ≤ ((σ + τ )) p∈P W (t+1) O (i,•) ⊤ , p ≤ ((σ + τ )) √ M ξ∥p∥ O (i,•) ∥ 2 ≤M ξ 2 ∥p∥ 2 We can also verify that this claim holds when t = 1. The proof of Claim 1 finishes here.

E.2 PROOF OF CLAIM 2 OF LEMMA 2

The proof of Claim 2 is one of the most challenging parts in our paper, since that we need to deal with the complicated softmax function. The core idea of proof is that we pay more attention on the changes of label-relevant features in the gradient update, which should be the most crucial factor based on our data model. We then show the attention map converges to be sparse as long as the data model satisfies (8). We first study the gradient of W (t+1) Q in part (a) and the gradient of W (t+1) K in part (b). (a) By (232), we have η 1 B n∈B b ∂Loss(X n , y n ) ∂W Q =η 1 B n∈B b ∂Loss(X n , y n ) ∂F (X n ) F (X n ) ∂W Q =η 1 B n∈B b (-y n ) 1 |S n | l∈S n m i=1 a (l)i 1[W O (i,•) W V Xsoftmax(X n⊤ W ⊤ K W Q x n l ) ≥ 0] • W O (i,•) s∈S n W V x n s softmax(x n s ⊤ W ⊤ K W Q x n l ) • r∈S n softmax(x n r ⊤ W ⊤ K W Q x n l )W K (x n s -x n r )x n l ⊤ =η 1 B n∈B b (-y n ) 1 |S n | l∈S n m i=1 a (l)i 1[W O (i,•) W V X n softmax(X n⊤ W ⊤ K W Q x n l ) ≥ 0] • W O (i,•) s∈S n W V x n s softmax(x n s ⊤ W ⊤ K W Q x n l ) • (W K x n s - r∈S n softmax(x n r ⊤ W ⊤ K W Q x n l )W K x n r )x n l ⊤ For r, l ∈ S n 1 , by (43) we have softmax(x n j ⊤ W (t) K W (t) Q x n l ) ≳ e ∥q1(t)∥ 2 -(δ+τ )∥q1(t)∥ |S n 1 |e ∥q1(t)∥ 2 -(δ+τ )∥q1(t)∥ + (|S n | -|S n 1 |) For r / ∈ S n 1 and l ∈ S n 1 , we have softmax(x n j ⊤ W (t+1) K ⊤ W (t+1) Q x n l ) ≲ 1 |S n 1 |e (1+K(t))∥q1(t)∥ 2 -(δ+τ )∥q1(t)∥ + (|S n | -|S n 1 |) Published as a conference paper at ICLR 2023 Therefore, for s, r, l ∈ S n 1 , let W (t) K x n s - r∈S n softmax(x n r ⊤ W (t) K ⊤ W (t) Q x n l )W (t) K x n r := β n 1 (t)q 1 (t) + β n 2 (t), where β n 1 (t) ≳ |S n | -|S n 1 | |S n 1 |e ∥q1(t)∥ 2 +(δ+τ )∥q1(t)∥ + |S n | -|S n 1 | := ϕ n (t)(|S n | -|S n 1 |). (129) β n 1 (t) ≲ ν n (t)(|S n | -|S n 1 |) ≲ e 2(τ +δ)∥q1(t)∥ ϕ n (t)(|S n | -|S n 1 |) ≤ ϕ n (t)(|S n | -|S n 1 |) (130) where the last inequality holds when the final iteration log T ≤ Θ(1). 
β n 2 (t) ≈Θ(1) • o n j (t) + Q e (t)r 2 (t) + M l=3 γ ′ l r l (t) - M a=1 r∈S n l softmax(x ⊤ r W (t) K ⊤ W (t) Q x l )r a (t) =Θ(1) • o n j (t) + M l=1 ζ ′ l r l (t) (131) for some Q e (t) > 0 and γ ′ l > 0. Here |ζ ′ l | ≤ β n 1 (t) |S n l | |S n | -|S n 1 | (132) for l ≥ 2. Note that |ζ ′ l | = 0 if |S n | = |S n 1 |, l ≥ 2. Therefore, for i ∈ W(t), W (t) O (i,•) s∈S n W (t) V x n s softmax(x n s ⊤ W (t) K ⊤ W (t) Q x n l ) =W (t) O (i,•) s∈S1 p s softmax(x n s ⊤ W (t) K ⊤ W (t) Q x n l ) + z(t) + l̸ =s W l (u)p l -η t b=1 ( j∈W(b) V j (b)W (b) O (j,•) ⊤ + j / ∈W(b) V j (b)λW (b) O (j,•) ⊤ ) ≳ 1 Bt t b=1 n∈B b η 2 btm|S n 1 | |S n |a 2 (p n (b)) 2 ∥p 1 ∥ 2 -((σ + τ )) η 2 t 2 m a 2 ∥p 1 ∥ 2 + ξ(1 -(σ + τ )) -((σ + τ )) • √ M ξ∥p 1 ∥ -ξ∥p 1 ∥ + η 1 B t b=1 n∈B b |S n 1 |p n (b)m |S n |aM (1 -ϵ m - (σ + τ )M π )( η 2 t 2 m a 2 ) 2 ∥p 1 ∥ 2 -η 1 B t b=1 n∈B b |S n 2 |p n (b)m |S n |aM (1 -ϵ m - (σ + τ )M π )( η 2 t 2 m a 2 )ξ∥p 1 ∥ 2 - η 2 λξ √ M t 2 m √ Ba 2 ∥p 1 ∥ 2 >0 (133) where the first step is by ( 27) and the second step is a combination of (32) to (35). The final step holds as long as σ + τ ≲ O(1), and B ≥ ( λξ √ M ϵ 1 B n∈B b |S n 1 | |S n | (p n (t)) 2 ) 2 Then we study how large the coefficient of q 1 (t) in (125). 
If s ∈ S n 1 , by basic computation given (32) to (35), W (t) O (i,•) W (t) V x n s softmax(x n s ⊤ W (t) K ⊤ W (t) Q x n l ) ≳ η 2 t 2 m a 2 ( 1 Bt t b=1 n∈B b |S n 1 |b |S n |t p n (t) -(σ + τ ))∥p 1 ∥ 2 -((σ + τ )) √ M ξ∥p 1 ∥ -ξ∥p 1 ∥ + η t b=1 1 B • n∈B b |S n 1 |p n (t)m |S n |aM (1 -ϵ m - (σ + τ )M π )( 1 Bt t b=1 n∈B b η 2 tb|S n 1 | m|S n | p n (b)) 2 ∥p 1 ∥ 2 - η 2 λξ √ M t 2 m √ Ba 2 • ∥p 1 ∥ 2 -η t b=1 1 B n∈B b |S n 2 |p n (t)m |S n |aM (1 -ϵ m - (σ + τ )M π )( η 2 t 2 m a 2 ) 2 (σ + τ )∥p 1 ∥ 2 p n (t) |S n 1 | ≳ 1 Bt t b=1 n∈B b η 2 t 2 m a 2 ( |S n 1 |b |S n |t p n (t) -(σ + τ ))∥p 1 ∥ 2 p n (t) |S n 1 | + η t b=1 1 B n∈B b |S n 1 |p n (b)m |S n |aM (1 -ϵ m - (σ + τ )M π )( 1 Bt t b=1 n∈B b η 2 tbm|S n 1 | a 2 |S n | p n (b)) 2 ∥p 1 ∥ 2 p n (t) |S n 1 | (136) where the last step is by ( 134) and (135). If s ∈ S n 2 , from ( 36) to (39), we have W (t) O (i,•) W (t) V x n s softmax(x n s ⊤ W K (t) ⊤ W (t) Q x n l ) ≲(ξ∥p∥ + 1 Bt t b=1 n∈B b |S n 1 |p n (b)ηbm |S n |aM (1 -ϵ m - (σ + τ )M π )( η 2 t 2 m a 2 ) 2 (σ + τ )∥p∥ + η 2 λ √ M ξt 2 m √ Ba 2 ∥p 1 ∥ 2 + ((σ + τ ))( √ M ξ + η 2 t 2 m a 2 ) + 1 Bt t b=1 n∈B b |S n 2 |p n (t)ηbm |S n |aM • (1 -ϵ m - (σ + τ )M π )( η 2 t 2 m a 2 ) 2 ∥p 1 ∥ 2 )ϕ n (t) (137) If i ∈ W(t) and s / ∈ (S n 1 ∪ S n 2 ), W (t) O (i,•) W (t) V x n s softmax(x n s ⊤ W K (t) ⊤ W (t) Q x n l ) ≲(ξ∥p∥ + η 2 λ √ M ξt 2 m √ Ba 2 ∥p∥ 2 + ((σ + τ ))( √ M ξ + η 2 t 2 m a 2 )∥p∥ + η 1 Bt t b=1 n∈B b (|S n 2 | + |S n 1 |)p n (b)ηbm |S n |aM (1 -ϵ m - (σ + τ )M π ) η 2 t 2 m a 2 ξ∥p 1 ∥ 2 )ϕ n (t) by ( 40) to (42). 
Hence, for i ∈ W(t), j ∈ S g 1 , combining ( 129) and ( 136), we have W (t) O (i,•) s∈S n W (t) V x n s softmax(x n s ⊤ W K (t) ⊤ W (t) Q x n l )q 1 (t) ⊤ •(W (t) K x n s - L r=1 softmax(x n r ⊤ W (t) K ⊤ W (t) Q x n l )W (t) K x n r )x n l ⊤ x g j ≳ 1 Bt t b=1 n∈B b |S n 1 |p n (b)ηtm |S n |aM (1 -ϵ m - (σ + τ )M π )( 1 Bt t b=1 n∈B b η 2 tbm|S n 1 | a 2 |S n | p n (b)) 2 ∥p 1 ∥ 2 p n (t) + 1 Bt t b=1 n∈B b η 2 t 2 m a 2 ( |S n 1 |b |S n |t p n (b) -(σ + τ ))∥p 1 ∥ 2 p n (t) ϕ n (t)(|S n | -|S n 1 |)∥q 1 (t)∥ 2 (139) For i ∈ U(t) and l ∈ S n 1 , j ∈ S g 1 W (t) O (i,•) s∈S n W (t) V x n s softmax(x n s ⊤ W (t) K ⊤ W (t) Q x n l )q 1 (t) ⊤ •(W (t) K x n s - r∈S n softmax(x n r ⊤ W (t) K ⊤ W (t) Q x n l )W (t) K x n r )x n l ⊤ x g j ≲ 1 Bt t b=1 n∈B b |S n 2 |p n (b)ηtm |S n |aM (1 -ϵ m - (σ + τ )M π )( η 2 t 2 m a 2 ) 2 ∥p 1 ∥ 2 ϕ n (t)|S n 2 |β 1 (t)∥q 1 (t)∥ 2 + 1 Bt t b=1 n∈B b |S n 1 |p n (b)ηtm |S n |aM (1 -ϵ m - (σ + τ )M π )( η 2 t 2 m a 2 ) 2 (σ + τ )∥p∥ 2 ϕ n (t)|S n 2 | • β 1 (t)∥q 1 (t)∥ 2 (140) For i / ∈ (W(t) ∪ U(t)) and l ∈ S n 1 , j ∈ S g 1 , W (t) O (i,•) s∈S n W (t) V x n s softmax(x n s ⊤ W K (t) ⊤ W (t) Q x n l )q 1 (t) ⊤ •(W (t) K x n s - r∈S n softmax(x n r ⊤ W (t) K ⊤ W (t) Q x n l )x n r )x n l ⊤ x g j ≲(ξ∥p∥ + ((σ + τ ))( η 2 t 2 m a 2 + √ M ξ)∥p 1 ∥ 2 + 1 Bt t b=1 n∈B b (|S n 1 | + |S n 2 |)p n (b)ηtm aM |S n | • (1 -ϵ m - (σ + τ )M π )ξ η 2 t 2 m a 2 ∥p 1 ∥ + η 2 t 2 λξ √ M m∥p∥ 2 √ Ba 2 ) • β 1 (t)∥q 1 (t)∥ 2 To study the case when l / ∈ S n 1 for all n ∈ [N ], we need to check all other l's. Recall that we focus on the coefficient of q 1 (t) in this part. 
Based on the computation in ( 137) and ( 138), we know that the contribution of coefficient from non-discriminative patches is no more than that from discriminative patches, i.e., for l / ∈ (S n 1 ∪ S n 2 ), n ∈ [N ] and k ∈ S n 1 , W (t) O (i,•) s∈S n W (t) V x n s softmax(x n s ⊤ W (t) K ⊤ W (t) Q x n l )q 1 (t) ⊤ •(W (t) K x n s - r∈S n softmax(W (t) K x n r ⊤ W (t) K ⊤ W (t) Q x n l )W (t) K x n r )x n l ⊤ x g j | ≤ W (t) O (i,•) s∈S n W (t) V x n s softmax(x n s ⊤ W (t) K ⊤ W (t) Q x n k )q 1 (t) ⊤ •(W (t) K x n s - r∈S n softmax(W (t) K x n r ⊤ W (t) K ⊤ W (t) Q x n l )W (t) K x n r )x n k ⊤ x g j (142) Similar to (139), we have that for l ∈ S n 2 , j ∈ S g 1 , and i ∈ U(t), W (t) O (i,•) s∈S n W (t) V x n s softmax(x n s ⊤ W (t) K ⊤ W (t) Q x n l )q 1 (t) ⊤ •(W (t) K x n s - r∈S n softmax(W (t) K x n r ⊤ W (t) K ⊤ W (t) Q x n l )W (t) K x n r )x n l ⊤ x g j ≲ 1 Bt t b=1 n∈B b |S n 2 |p n (t)ηtm |S n |aM (1 -ϵ m - (σ + τ )M π )( η 2 t 2 m a 2 ) 2 ∥p 2 ∥ 2 • β 1 (t)λ |S n # | |S n | -|S n * | • ∥q 1 (t)∥ 2 + 1 Bt t b=1 n∈B b |S n 1 |p n (b)ηtm |S n |aM (1 -ϵ m - (σ + τ )M π )( η 2 t 2 m a 2 ) 2 (σ + τ )∥p∥ 2 β 1 (t) • λ |S n # | |S n | -|S n * | ∥q 1 (t)∥ 2 (143) Therefore, by the update rule, W (t+1) Q x j = W (t) Q x j -η ∂L ∂W Q W (t) Q x j = r 1 (t) + K(t)q 1 (t) + Θ(1) • n j (t) + t-1 b=0 |K e (b)|q 2 (b) + M l=3 γ ′ l q l (t) = (1 + K(t))q 1 (t) + Θ(1) • n j (t) + t-1 b=0 |K e (b)|q 2 (b) + M l=3 γ ′ l q l (t) where the last step is by the condition that q 1 (t) = k 1 (t) • r 1 (t), and q 2 (t) = k 2 (t) • r 2 (t) for k 1 (t) > 0 and k 2 (t) > 0 from induction, i.e., q 1 (t) and r 1 (t), q 1 (t) and r 1 (t) are in the same direction, respectively. 
We also have K(t) ≳η 1 Bt t b=1 n∈B b |S n 1 |p n (b)ηtm |S n |aM (1 -ϵ m - (σ + τ )M π )( 1 Bt t b=1 n∈B b η 2 tbm|S n 1 | a 2 |S n | p n (b)) 2 ∥p 1 ∥ 2 • p n (t) + 1 Bt t b=1 n∈B b η 2 t 2 m a 2 ( b|S n 1 | t|S n | p n (t) -(σ + τ ))∥p 1 ∥ 2 p n (t) ϕ n (t)(|S n | -|S n 1 |)∥q 1 (t)∥ 2 -η 1 Bt t b=1 n∈B b |S n 2 |p n (b)ηtm |S n |aM (1 -ϵ m - (σ + τ )M π )( η 2 t 2 m a 2 ) 2 ∥p 1 ∥ 2 ϕ n (t)|S 2 |β 1 (t)∥q 1 (t)∥ 2 -η 1 Bt t b=1 n∈B b |S n 1 |p n (b)ηtm |S n |aM (1 -ϵ m - (σ + τ )M π )( η 2 t 2 m a 2 ) 2 (σ + τ )∥p∥ 2 ϕ n (t)|S n 2 |β 1 (t) • ∥q 1 (t)∥ 2 -η(ξ∥p∥ + ((σ + τ ))( η 2 t 2 m a 2 + √ M ξ)∥p 1 ∥ 2 + 1 Bt t b=1 n∈B b |S n 1 |p n (b)ηtm aM |S n | • (1 -ϵ m - (σ + τ )M π )ξ η 2 t 2 m ∥p 1 ∥ + ηtλξ √ M m∥p∥ 2 √ Ba 2 ) • β 1 (t)∥q 1 (t)∥ 2 -η 1 Bt t b=1 n∈B b |S n 2 |p n (b)ηtm |S n |aM (1 -ϵ m - (σ + τ )M π )( η 2 tbm a 2 ) 2 ∥p 2 ∥ 2 • β 1 (t)λ |S n # | |S n | -|S n * | • ∥q 1 (t)∥ 2 -η 1 Bt t b=1 n∈B b |S n 1 |p n (b)ηtm |S n |aM (1 -ϵ m - (σ + τ )M π )( η 2 t 2 m a 2 ) 2 (σ + τ )∥p∥ 2 β 1 (t) • λ |S n # | |S n | -|S n * | ∥q 1 (t)∥ 2 ≳η 1 Bt t b=1 n∈B b |S n 1 |p n (b)ηtm |S n |aM (1 -ϵ m - (σ + τ )M π )( 1 Bt t b=1 n∈B b η 2 tb|S n 1 |m a 2 |S n | p n (b)) 2 ∥p 1 ∥ 2 • p n (t) + 1 Bt t b=1 n∈B b η 2 t 2 m a 2 ( b|S n 1 | t|S n | p n (t) -(σ + τ ))∥p 1 ∥ 2 p n (t) ϕ n (t)(|S n | -|S n 1 |)∥q 1 (t)∥ 2 >0 (147) |γ ′ l | ≲ 1 B n∈B b K(t) • |S n l | |S n | -|S n 1 | (148) |K e (t)| ≲ 1 B n∈B b K(t) • |S n 2 | |S n | -|S n 1 | Published as a conference paper at ICLR 2023 as long as 1 Bt t b=1 n∈B b ϵ S η 2 t 2 m a 2 p n (t)( |S n 1 |b |S n |t p n (b) -(σ + τ ))∥p 1 ∥ 2 + 1 Bt t b=1 n∈B b |S n 1 |p n (t)ηtm |S n |aM • (1 -ϵ m - (σ + τ )M π )( 1 Bt t b=1 n∈B b η 2 tbm|S n 1 | a 2 |S n | p n (t)) 2 ∥p 1 ∥ 2 p n (t) ϕ n (t) • (|S n | -|S n 1 |)∥q 1 (t)∥ 2 ≳ 1 Bt t b=1 n∈B b |S n 2 |p n (b)ηtm |S n |aM (1 -ϵ m - (σ + τ )M π )( η 2 t 2 m a 2 ) 2 ∥p 2 ∥ 2 • β 1 (t)λ |S n # | |S n | -|S n * | • ∥q 1 (t)∥ 2 + 1 Bt t b=1 n∈B b |S n 1 |p n (b)ηtm |S n |aM (1 -ϵ m - (σ + 
τ )M π )( η 2 t 2 m a 2 ) 2 (σ + τ )∥p∥ 2 β 1 (t) • λ |S n # | |S n | -|S n * | ∥q 1 (t)∥ 2 (150) To find the sufficient condition for (150), we first compare the first terms of both sides in (150). Note that when ηt ≤ O(1), (151) we have η 2 t 2 ≳ η 5 t 5 (152) When |S n | > |S n 1 |, by (130), ϕ n (t)(|S n | -|S n 1 |) ≳ β n 1 (t) (153) From Definition 1, we know 1 ≥ p n (t) ≥ p n (0) = Θ( |S n 1 | |S n 1 |e -(δ+τ ) + |S n | -|S n 1 | ) ≥ Θ(e -(δ+τ ) ) Moreover, ( 1 Bt t b=1 n∈B b |S n 1 |b |S n |t -(τ + σ))e -(δ+τ ) -(1 -(τ + σ))e -(δ+τ ) 1 2 E D [ |S n 1 | |S n | ] ≤ ( 1 Bt t b=1 n∈B b |S n 1 |b |S n |t -(τ + σ))e -(δ+τ ) -e -(δ+τ ) E D [( 1 Bt t b=1 n∈B b |S n 1 |b |S n |t -(τ + σ))] + e -(δ+τ ) E D [ 1 Bt t b=1 n∈B b |S n 1 | |S n | ( 1 2 - b t )] + e -(δ+τ ) (τ + σ)(1 - 1 2 E D [ |S n 1 | |S n | ]) ≤e -(δ+τ ) log Bt Bt + 1 • 1 2t e -(δ+τ ) + (τ + σ)e -(δ+τ ) ≤e -(δ+τ ) log Bt Bt + (τ + σ)e -(δ+τ ) (155) where the first inequality is by the triangle inequality and the second inequality comes from |S n 1 |b/(|S n |t) ≤ 1. Therefore, a sufficient condition for (150) is ϵ S e -(δ+τ ) (1 -(τ + σ)) 1 2 E D [ |S n 1 | |S n | ] ≳ 1 Bt t b=1 n∈B b |S n 2 | |S n | ≥ E D [ |S n 2 | |S n | ] - log Bt Bt by Hoeffding's inequality in (26). Thus, α * =: E D [ |S n 1 | |S n | ] ≥ 1 -α nd 1 + ϵ S e -(δ+τ ) (1 -(τ + σ)) (157) if Bt ≥ 1 ((1 -(τ + σ))e -(δ+τ ) 1 2 α * -τ ) 2 For the second terms on both sides in (150), since (τ + σ) ≤ O(1/M ), the inequality also holds with the same condition as in α * and Bt. and α # = E |S n # | |S n | (159) α nd = E M l=3 |S n l | |S n | Given that |S n | = Θ(1), ( 157) is equivalent to (8). For the second terms on both sides in (150), since (σ + τ ) ≤ 1/M , the inequality also holds with the same condition on α * and Bt.
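The sufficient condition above rests on Hoeffding's inequality in (26): the batch average of Bt bounded fractions |S^n_1|/|S^n| deviates from its expectation by at most about sqrt(log(Bt)/(Bt)). A Monte-Carlo sanity check of this concentration step (the value of α_* and the sampling range are illustrative assumptions, not the paper's data model):

```python
import numpy as np

# Monte-Carlo sanity check of the Hoeffding step: the average of Bt bounded
# draws deviates from its mean by at most ~sqrt(log(Bt)/Bt) w.h.p.
# alpha_star and the sampling range are illustrative assumptions.
rng = np.random.default_rng(0)
Bt = 2000
alpha_star = 0.5

fracs = rng.uniform(alpha_star - 0.3, alpha_star + 0.3, size=Bt)  # values in [0, 1]
deviation = abs(fracs.mean() - alpha_star)
hoeffding_bound = np.sqrt(np.log(Bt) / Bt)

print(deviation <= hoeffding_bound)
```

With Bt = 2000 the bound is about 0.06 while the typical deviation is an order of magnitude smaller, which is why the log(Bt)/(Bt) terms above are treated as lower-order corrections.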

Note that if |S^n| = |S^n_1|, we let |S^n_l|/(|S^n| − |S^n_1|) = 0 for l ∈ [M]. We use the presentation in (148, 149) above and (170, 171) below for simplicity. We then give a brief derivation of W_Q^(t+1) x^n_j for j ∉ S^n_1 in the following. To be specific, for j ∈ S^n \ (S^n_1 ∪ S^n_2),

η⟨(1/B) Σ_{n∈B_b} ∂Loss(X^n, y^n)/∂W_Q^(t) x^n_j, q_1(t)⟩
≳ η ( η Σ_{b=1}^t (1/B) Σ_{n∈B_b} (|S^n_1| p_n(b) m)/(|S^n| aM) (1 − ε_m − (σ+τ)M/π) ( (1/(Bt)) Σ_{b=1}^t Σ_{n∈B_b} (η² t b m |S^n_1|)/(a² |S^n|) p_n(b) )² ∥p_1∥² p′_n(t)
+ (1/(Bt)) Σ_{b=1}^t Σ_{n∈B_b} (η² t² m / a²) ( (|S^n_1| b)/(|S^n| t) p_n(t) − (σ+τ) ) ∥p_1∥² p′_n(t) ) φ_n(t) (|S^n| − |S^n_1|) ∥q_1(t)∥² (161)

where

p′_n(t) = |S^n_1| e^{q_1(t)^⊤ Σ_{b=1}^t K(b) q_1(b) − (δ+τ)∥q_1(t)∥} / ( |S^n_1| e^{q_1(t)^⊤ Σ_{b=1}^t K(b) q_1(b) − (δ+τ)∥q_1(t)∥} + |S^n| − |S^n_1| ). (162)

When K(b) is close to 0⁺, we have

∏_{b=1}^t (1 + K(b)∥q_1(0)∥²) ≳ e^{Σ_{b=1}^t K(b)∥q_1(0)∥²} ≥ Σ_{b=1}^t K(b)∥q_1(0)∥², (163)

where the first step is by log(1 + x) ≈ x when x → 0⁺. Therefore, one can derive that

η⟨(1/B) Σ_{n∈B_b} ∂Loss(X^n, y^n)/∂W_Q^(t) x^n_j, q_1(t)⟩ ≳ Θ(1) · K(t).

Meanwhile, the value of p′_n(t) increases to 1 during training, making the component along q_1(t) the major part of η (1/B) Σ_{n∈B_b} ∂Loss(X^n, y^n)/∂W_Q^(t) x^n_j.
Hence, if j ∈ S n l for l ≥ 3, W (t+1) Q x j = q l (t) + Θ(1) • n j (t) + Θ(1) • t-1 b=0 K(b)q 1 (b) + M l=2 γ ′ l q l (t) Similarly, for j ∈ S n 2 , W Q x j = (1 + K(t))q 2 (t) + Θ(1) • n j (t) + t-1 b=0 |K e (b)|q 1 (b) + M l=2 γ ′ l q l (t) (b) For the gradient of W K , we have ∂Loss b ∂W K = 1 B n∈B b ∂Loss(X n , y n ) ∂F (X) F (X) ∂W K = 1 B n∈B b (-y n ) l∈S n m i=1 a (l)i 1[W O (i,•) W V Xsoftmax(X n⊤ W ⊤ K W Q x n l ) ≥ 0] • W O (i,•) s∈S n W V x n s softmax(x n s ⊤ W ⊤ K W Q x n l )W ⊤ Q x n l • (x n s - r∈S n softmax(x n r ⊤ W ⊤ K W Q x n l )x n r ) ⊤ Hence, for j ∈ S n 1 , we can follow the derivation of ( 144) to obtain W (t+1) K x j = (1 + Q(t))q 1 (t) + Θ(1) • o n j (t) ± t-1 b=0 |Q e (b)|(1 -λ)r 2 (b) + M l=3 γ ′ l r l (t), where Q(t) ≥K(t)(1 -λ) > 0 (169) for λ < 1 introduced in Assumption 3, and |γ l | ≲ 1 B n∈B b Q(t) • |S n l | |S n | -|S n * | (170) |Q e (t)| ≲ 1 B n∈B b Q(t) • |S n # | |S n | -|S n * | Similarly, for j ∈ S n 2 , we have W (t+1) K x j ≈ (1 + Q(t))q 2 (t) + Θ(1) • o n j (t) ± t-1 b=0 |Q e (b)|(1 -λ)r 1 (b) + M l=3 γ ′ l r l (t), (172) For j ∈ S n z , z = 3, 4, • • • , M , we have 1 B n∈B b ∂Loss(X n , y n ) ∂F (X) F (X) ∂W K x n j , q 1 (t) ≲ 1 B n∈B b (-y n ) l∈S n 1 m i=1 a (l)i 1[W O (i,•) W V Xsoftmax(X n⊤ W ⊤ K W Q x n l ) ≥ 0] • W O (i,•) ( s∈S n z +λ s∈S n 1 )W V x n s softmax(x n s ⊤ W ⊤ K W Q x n l ) ∥q 1 (t)∥ 2 ≤λ|Q f (t)|∥q 1 (t)∥ 2 (173) 1 B n∈B b ∂Loss(X n , y n ) ∂F (X) F (X) ∂W K x n j , q z (t) ≲ 1 B n∈B b (-y n ) l∈S n z m i=1 a (l)i 1[W O (i,•) W V Xsoftmax(X n⊤ W ⊤ K W Q x n l ) ≥ 0] • W O (i,•) ( s∈S n z +λ s∈S n 1 )W V x n s softmax(x n s ⊤ W ⊤ K W Q x n l ) ∥q z (t)∥ 2 ≤λ|Q f (t)|∥q z (t)∥ 2 (174) W (t+1) K x j ≈(1 ± c k1 λ|Q f (t)|)q l (t) + Θ(1) • o n j (t) ± c k2 λ • t-1 b=0 |Q f (b)|r 1 (b) ± c k3 λ • t-1 b=0 |Q f (b)|r 2 (b) + M i=3 γ ′ i r i (t), where 0 < c k1 , c k2 , c k3 < 1, and |Q f (t)| ≲ Q(t) (176) Therefore, for l ∈ S n 1 , if j ∈ S n 1 , x n j ⊤ W (t+1) K ⊤ W (t+1) Q x n l ≳(1 + K(t))(1 + 
Q(t))∥q 1 (t)∥ 2 -(δ + τ )∥q 1 (t)∥ + t-1 b=0 K e (b) t-1 b=0 Q e (b)∥q 2 (b)∥∥r 2 (b)∥ + M l=3 γ l γ ′ l ∥q l (t)∥∥r l (t)∥ ≳(1 + K(t))(1 + Q(t))∥q 1 (t)∥ 2 -(δ + τ )∥q 1 (t)∥ - M l=3 ( 1 B n∈B b Q(t) |S n l | |S n | -|S n * | ) 2 ∥r l (t)∥ 2 • M l=3 ( 1 B n∈B b K(t) |S n l | |S n | -|S n * | ) 2 ∥q l (t)∥ 2 ≳(1 + K(t) + Q(t))∥q 1 (t)∥ 2 -(δ + τ )∥q 1 (t)∥ 177) where the second step is by Cauchy-Schwarz inequality. If j / ∈ S n 1 , x n j ⊤ W (t+1) K ⊤ W (t+1) Q x n l ≲Q f (t)∥q 1 (t)∥ 2 + (δ + τ )∥q 1 (t)∥ (178) Hence, for j, l ∈ S n 1 , softmax(x n j ⊤ W (t+1) K W (t+1) Q x n l ) ≳ e (1+K(t))∥q1(t)∥ 2 -(δ+τ )∥q1(t)∥ |S n 1 |e (1+K(t))∥q1(t)∥ 2 -(δ+τ )∥q1(t)∥ + (|S n | -|S n 1 |) (179) softmax(x n j ⊤ W (t+1) K ⊤ W (t+1) Q x n l ) -softmax(x n j ⊤ W (t) K ⊤ W (t) Q x n l ) ≳ e (1+K(t))∥q1(t)∥ 2 -(δ+τ )∥q1(t)∥ |S n 1 |e (1+K(t))∥q1(t)∥ 2 -(δ+τ )∥q1(t)∥ + (|S n | -|S n 1 |) - e ∥q1(t)∥ 2 -(δ+τ )∥q1(t)∥ |S n 1 |e ∥q1(t)∥ 2 -(δ+τ )∥q1(t)∥ + (|S n | -|S n 1 |) = |S n | -|S n 1 | (|S n 1 |e x + (|S n | -|S n 1 |)) 2 e ∥q1(t)∥ 2 -(δ+τ )∥q1(t)∥ (e K(t) -1) ≥ |S n | -|S n 1 | (|S n 1 |e (1+K(t))∥q1(t)∥ 2 -(δ+τ )∥q1(t)∥ + (|S n | -|S 1 |)) 2 e ∥q1(t)∥ 2 -(δ+τ )∥q1(t)∥ • K(t) ) where the second to last step is by the Mean Value Theorem with x ∈ [∥q 1 (t)∥ 2 -(δ + τ )∥q 1 (t)∥, (1 + K(t))∥q 1 (t)∥ 2 -(δ + τ )∥q 1 (t)∥] We then need to study if l / ∈ (S n 1 ∪ S n 2 ) and j ∈ S n 1 , i.e., x n j ⊤ W (t+1) K ⊤ W (t+1) Q x n l ≳ (1 + Q(t)) t-1 b=0 |K(b)|∥q 1 (t)∥∥q 1 (b)∥ -(δ + τ )∥q 1 (t)∥ (182) For j, l / ∈ (S n 1 ∪ S n 2 ), x n j ⊤ W (t+1) K ⊤ W (t+1) Q x n l ≲ ± c k2 λ • t-1 b=0 |Q f (b)|r 1 (b) ± c k3 λ • t-1 b=0 |Q f (b)|r 2 (b) + (1 ± c k1 λ|Q f (t)|)∥q l (t)∥ 2 We know that the magnitude of ∥q 1 (t)∥ increases along the training and finally reaches no larger than Θ( √ log T ). 
At the final step, we have t-1 b=0 K(b)∥q 1 (b)∥ ≥ T e ∥q1(T )∥ 2 -(δ+τ )∥q1(T )∥ ≥ Θ( log T ) Therefore, when t is large enough during the training but before the final step of convergence, we have if j ′ , l / ∈ (S n 1 ∪ S n 2 ) and j ∈ S n 1 , we can obtain (x n j -x n j ′ ) ⊤ W (t+1) K ⊤ W (t+1) Q x n l ≳ Θ(1) • ((1 + K(t))∥q 1 (t)∥ 2 -(δ + τ )∥q 1 (t)∥) We can derive the same conclusion for j ∈ S n 2 in (185). Therefore, by ) 2 e ∥q1(t)∥ 2 -(δ+τ )∥q1(t)∥ (e K(t) -1) |S 2 | ≤ ϵ S e -(δ+τ ) (1 -(σ - τ ))|S n 1 | in (157), we can obtain softmax(x n j ⊤ W (t+1) K W (t+1) Q x n l ) ≳ e (1+K(t))∥q1(t)∥ 2 -(δ+τ )∥q1(t)∥ (|S n 1 | + |S n 2 |)e (1+K(t))∥q1(t)∥ 2 -(δ+τ )∥q1(t)∥ + (|S n | -|S n 1 | -|S n 2 |) ≳ e (1+K(t))∥q1(t)∥ 2 -(δ+τ )∥q1(t)∥ |S n 1 |e (1+K(t))∥q1(t)∥ 2 -(δ+τ )∥q1(t)∥ + (|S n | -|S n 1 |) (186) softmax(x n j ⊤ W (t+1) K ⊤ W (t+1) Q x n l ) -softmax(x n j ⊤ W (t) K ⊤ W (t) Q x n l ) ≳ |S n | -|S n 1 | (|S n 1 |e Θ(1)•((1+K(t))∥q1(t)∥ 2 -(δ+τ )∥q1(t))∥ + (|S n | -|S 1 |)) 2 e Θ(1)(∥q1(t)∥ 2 -(δ+τ )∥q1(t)∥) • K(t) ≳ |S n | -|S n 1 | (|S n 1 |e (1+K(t))∥q1(t)∥ 2 -(δ+τ )∥q1(t)∥ + (|S n | -|S 1 |)) 2 e ∥q1(t)∥ 2 -(δ+τ )∥q1(t)∥ • K(t) (187) Meanwhile, for l ∈ S n 1 and j / ∈ S n 1 , softmax(x n j ⊤ W (t+1) K ⊤ W (t+1) Q x n l ) ≲ 1 |S n 1 |e (1+K(t))∥q1(t)∥ 2 -(δ+τ )∥q1(t)∥ + (|S n | -|S n 1 |) ≤ - |S n 1 | (|S 1 |e (1+K(t))∥q1(t)∥ 2 -(δ+τ )∥q1(t)∥ + (|S n | -|S n 1 |) ) 2 e ∥q1(t)∥ 2 -(δ+τ )∥q1(t)∥ • K(t) (189) where the second to last step is by the Mean Value Theorem with x ∈ [∥q 1 (t)∥ 2 -(δ + τ )∥q 1 (t)∥, (1 + K(t))∥q 1 (t)∥ 2 -(δ + τ )∥q 1 (t)∥] (190) The same conclusion holds if l / ∈ (S n 1 ∪ S n 2 ) and j / ∈ S n 1 . Note that q 1 (t + 1) = (1 + K(t))q 1 (t) q 2 (t + 1) = (1 + K(t))q 2 (t) (192) r 1 (t + 1) = (1 + Q(t))r 1 (t) (193) r 2 (t + 1) = (1 + Q(t))r 2 (t) It can also be verified that this claim holds when t = 1.
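The Mean Value Theorem step above shows that the attention weight each S^n_1 token receives strictly increases with K(t), and that the total mass on S^n_1 approaches 1. A small numeric check of this monotonicity, using the logit form (1 + K(t))∥q_1(t)∥² − (δ+τ)∥q_1(t)∥ from (179); all constants below are illustrative, and the remaining logits are set to 0 as a simplification:

```python
import numpy as np

S, S1 = 50, 25          # |S^n| and |S^n_1|, illustrative

def attn_on_relevant(K, q_norm=1.5, delta_tau=0.2):
    # Aligned logit (1 + K(t))||q_1(t)||^2 - (delta + tau)||q_1(t)||, with the
    # remaining |S| - |S_1| logits set to 0 as a simplification.
    aligned = (1 + K) * q_norm**2 - delta_tau * q_norm
    return np.exp(aligned) / (S1 * np.exp(aligned) + (S - S1))

w_small, w_large = attn_on_relevant(0.0), attn_on_relevant(0.5)
mass_small, mass_large = S1 * w_small, S1 * attn_on_relevant(5.0)
print(w_large > w_small)                 # per-token weight grows with K(t)
print(mass_small < mass_large < 1.0)     # total mass on S_1 approaches 1
```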

E.3 PROOF OF CLAIM 3 OF LEMMA 2

The computation of the gradient of W V is straightforward. The gradient would be related to W O by their connections. One still need to study the influence of the gradient on different patterns, where we introduce the discussion for the term V i (t)'s. For the gradient of W V , by (18) we have 195) Consider a data {X n , y n } where y n = 1. Let l ∈ S n ∂Loss b ∂W V = 1 B n∈B b ∂Loss(X n , y n ) ∂F (X n ) ∂F (X n ) ∂W V = -y 1 B n∈B b 1 |S n | l∈S n m i=1 a * (l)i 1[W O (i,•) W V X n softmax(X n⊤ W ⊤ K W Q x n l ) ≥ 0] • W O (i,•) ⊤ softmax(X n⊤ W ⊤ K W Q x n l ) ⊤ X n⊤ 1 s∈S n 1 softmax(x n s ⊤ W (t) K ⊤ W (t) Q x n l ) ≥ p n (t) Then for j ∈ S g 1 , g ∈ [N ], 1 B n∈B b ∂Loss(X n , y n ) ∂W (t) V W (t) V x j = 1 B n∈B b (-y n ) 1 |S n | l∈S n m i=1 a (l)i 1[W (t) O (i,•) s∈S n softmax(x n s ⊤ W (t) K ⊤ W (t) Q x n l )W (t) V x n s ≥ 0] • W (t) O (i,•) ⊤ s∈S n softmax(x n s ⊤ W (t) K ⊤ W (t) Q x l )x n s ⊤ x g j = i∈W(t) V i (t)W O (i,•) ⊤ + i / ∈W(t) λV i (t)W O (i,•) ⊤ , (197) If i ∈ W(t), by the fact that S n # contributes more to V i (t) compared to S n l for l ≥ 3 and Assumption 3, we have V i (t) ≲ 1 2B n∈B b + - |S n 1 | a|S n | p n (t) + |S n 2 | a|S n | |λ|ν n (t)(|S n | -|S n 1 |) ≲ 1 2B n∈B b + - |S n 1 | a|S n | p n (t) Similarly, if i ∈ U(t), V i (t) ≳ 1 2B n∈B b - |S n 2 | a|S n | p n (t) ( ) if i is an unlucky neuron, by Hoeffding's inequality in (26), we have  V i (t) ≥ 1 √ B • 1 a • √ M ξ∥p∥ ≳ - 1 √ Ba O (i,•) j∈U (b) V j (b)W O (b) (j,•) ⊤ ≲ - 1 Bt t b=1 n∈B b |S n 2 |p n (b)m |S n |aM (1 -ϵ m - (σ + τ )M π )( η 2 t 2 m a 2 ) 2 (σ + τ )∥p 1 ∥ 2 (202) -ηtW O (i,•) j / ∈(W(t)∪U (t)) V j (t)W O (j,•) ⊤ ≲ η 2 t 2 mλξ √ M ∥p∥ 2 √ Ba 2 Hence, (1) If j ∈ S n 1 for one n ∈ [N ],  W (t+1) V x n j = W (t) V x n j -η ∂L ∂W V W (t) V (2) If j ∈ S n 2 , we have W (t+1) V x j = W (0) V x n j -η ∂L ∂W V W (0) V x n j = p 2 -η t+1 b=1 i∈U (b) V i (b)W (b) O (i,•) ⊤ -η t+1 b=1 i / ∈U (b) λV i (b)W (b) O (i,•) ⊤ + z j (t) (3) If j ∈ S n /(S n 1 ∪ S n 2 ), we have W 
(t+1) V x n j = W (0) V x n j -η ∂L ∂W V W (0) V x n j = p k -η t+1 b=1 m i=1 λV i (b)W (b) O (i,•) ⊤ + z j (t) Here ∥z j (t)∥ ≤ (σ + τ ) (207) for t ≥ 1. Note that this claim also holds when t = 1. Proof: Let θ l be the angle between p l and the initial weight for one i ∈ [m] and all l ∈ [M ]. For the lucky neuron i ∈ W(0), θ 1 should be the smallest among {θ l } M l=1 with noise ∆θ. Hence, the probability of the lucky neuron can be bounded as Pr θ 1 + ∆θ ≤ θ l -∆θ ≤ 2π, 2 ≤ l ≤ M = L l=2 Pr θ 1 + ∆θ ≤ θ l -∆θ ≤ 2π =( 2π -θ 1 -2∆θ 2π ) M -1 , where the first step is because the Gaussian W  for small σ > 0. Therefore, Pr i ∈ W(0) = 2π 0 1 2π • ( 2π -θ 1 -2∆θ 2π ) M -1 dθ 1 = - 1 M ( 2π -2∆θ -x 2π ) M 2π 0 ≳ 1 M (1 - ∆θ π ) M ≳ 1 M (1 - (σ + τ )M π ), where the first step comes from that θ 1 follows the uniform distribution on [0, 2π] due to the Gaussian initialization of W O . We can define the random variable v i such that v i = 1, if i ∈ W(0), 0, else We know that v i belongs to Bernoulli distribution with probability 1 M (1 -(σ+τ )M π ). By Hoeffding's inequality in (26), we know that with probability at least 1 -N -10 , 1 M (1 - (σ + τ )M π ) - log N m ≤ 1 m m i=1 v i ≤ 1 M (1 - (σ + τ )M π ) + log N m Let m ≥ Θ(ϵ -2 m M 2 log B), we have |W(0)| = m i=1 v i ≥ m M (1 -ϵ m - (σ + τ )M π ) where we require (σ + τ ) ≤ π M to ensure a positive probability in (216). Likewise, the conclusion holds for U(0). Lemma 4. Let W(t) and U(t) be defined in Definition 2. We then have W(0) ⊆ W(t) U(0) ⊆ U(t) The dataset D can be divided into four groups as D 1 ={(X n , y n )|y n = (1, 1)} D 2 ={(X n , y n )|y n = (1, -1)} D 3 ={(X n , y n )|y n = (-1, 1)} D 4 ={(X n , y n )|y n = (-1, -1)} (231) The hinge loss function for data (X n , y n ) will be Loss(X n , y n ) = max{1 -y n⊤ F (X n ), 0} We can divide the weights W O (i,•) (i ∈ [m]) into two groups, respectively. 
W 1 ={i|a l (i) = 1 √ m • (1, 1)} W 2 ={i|a l (i) = 1 √ m • (1, -1)} W 3 ={i|a l (i) = 1 √ m • (-1, 1)} W 4 ={i|a l (i) = 1 √ m • (-1, -1)} Therefore, for W Ou in the network (228), we have ∂Loss(X n , y n ) ∂W O (i,•) ⊤ = -y n 1 ∂F 1 (X n ) ∂W O 1(i,•) -y n 2 ∂F 2 (X n ) W O 2(i,•) where the derivation of ∂F1(X n ) ∂W O 1(i,•) and ∂F2(X n ) ∂W O 2(i,•) can be found in the analysis of binary classification above. For any i ∈ W 2 , following the proof of Claim 1 of Lemma 2, if the data (X n , y n ) ∈ D 2 , we have - ∂Loss(X n , y n ) ∂W O (i,•) ⊤ = y n 1 ∂F 1 (X n ) ∂W O 1(i,•) +y n 2 ∂F 2 (X n ) W O 2(i,•) ≈∝ 1• 1 √ m p 2 -1•(- 1 √ m )p 2 = 2 √ m p 2 (235) (W O (i,•) -W (t) O (i,•) )p 2 ∝ ∥p 2 ∥ 2 > 0 (236) if (X n , y n ) ∈ D 1 , we have - ∂Loss(X n , y n ) ∂W O (i,•) ⊤ ≈∝ 1 • 1 √ m p 1 + 1 • (- 1 √ m )p 1 = 0 (237) (W O (i,•) -W (t) O (i,•) )p 1 ≈= 0 (238) if (X n , y n ) ∈ D 3 , we have - ∂Loss(X n , y n ) ∂W O (i,•) ⊤ ≈∝ -1 • 1 √ m p 3 + 1 • (- 1 √ m )p 3 = - 2 √ m p 3 (239) (W O (i,•) -W (t) O (i,•) )p 3 ≤ 0 (240) if (X n , y n ) ∈ D 4 , we have - ∂Loss(X n , y n ) ∂W O (i,•) ⊤ ≈∝ -1 • 1 √ m p 4 -1 • (- 1 √ m )p 4 = 0 (241) (W O (i,•) -W (t) O (i,•) )p 4 ≈ 0 By the algorithm, W O (i,•) will update along the direction of p 2 for i ∈ W 2 . We can analyze W V , W K and W Q similarly.
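The four-group construction above can be made concrete with a short sketch: each class is encoded as a label y ∈ {+1, −1}², and output-layer rows are grouped by the signs of their fixed top-layer vector a_(l). The helper names below are illustrative, not from the paper:

```python
import numpy as np

# Class index <-> sign-label encoding for the four-class case above.
LABELS = {0: (1, 1), 1: (1, -1), 2: (-1, 1), 3: (-1, -1)}

def group_of(a_signs):
    # Which group W_1..W_4 a neuron belongs to, by the signs of its a_(l).
    return {(1, 1): 1, (1, -1): 2, (-1, 1): 3, (-1, -1): 4}[tuple(a_signs)]

def predict_class(F):
    # Decode a network output F(X) in R^2 back to one of the four classes.
    y = (int(np.sign(F[0])), int(np.sign(F[1])))
    return {v: k for k, v in LABELS.items()}[y]

print(group_of((1, -1)))                      # neuron group W_2
print(predict_class(np.array([-0.3, 0.7])))   # sign pattern (-1, +1)
```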



Extension to multi-classification is briefly discussed in Section G.1. It is common to fix the output-layer weights at their random initialization in the theoretical analysis of neural networks, including NTK (Allen-Zhu et al., 2019a; Arora et al., 2019), model recovery (Zhong et al., 2017b), and feature-learning (Karp et al., 2021; Allen-Zhu & Li, 2022) types of approaches. The optimization problem over W_Q, W_K, W_V, and W_O with non-linear activations is still highly non-convex and challenging. The condition q_1 = r_1 and q_2 = r_2 is to rule out the trivial case that the initial attention value is very small. This condition can be relaxed, but we keep this form to simplify the presentation. The sample complexity bounds in (10) and (13) are sufficient but not necessary. Thus, rigorously speaking, one cannot compare two cases based on sufficient conditions only. In our analysis, however, these two bounds are derived with exactly the same technique, with the only difference in handling the self-attention layer. Therefore, we believe it is fair to compare these two bounds to show the advantage of ViT.



Tang et al. (2022) prune tokens following criteria designed based on the magnitude of the attention map. Despite the remarkable empirical success, one fundamental question about training Transformers remains vastly open:

initial model. Every entry of W_O is generated from N(0, ξ²). Every entry of a is sampled from {+1/√m, −1/√m} with equal probability. a is not updated during the training.

Figure 1: The impact of α * and σ on sample complexity.

Figure 2: The number of iterations against α_*^{-1}. Attention map: We then evaluate the evolution of the attention map during training. Let |S^n| = 50 for all n ∈ [N]. The number of training samples is N = 200, with σ = 0.1, δ = 0.2, α_* = 0.5, α_# = 0.05. In Figure 4, the red line with asterisks shows that the sum of attention weights on label-relevant tokens, i.e., the left side of (15) averaged over all l, indeed increases to be close to 1 as the number of iterations increases. Correspondingly, the sum of attention weights on the other tokens decreases to be close to 0, as shown by the blue line with squares. This verifies the sparse attention map established in Lemma 2.
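The experiment above can be mimicked with a toy simulation: as training increases the query-key alignment on label-relevant tokens, the summed attention weight on those tokens rises from α_* = 0.5 toward 1 while the rest decays. The linear logit schedule below is an illustrative stand-in for the trained W_K, W_Q:

```python
import numpy as np

S, S1 = 50, 25                       # |S^n| = 50 with alpha_* = 0.5, as above

def relevant_mass(t, rate=0.05):
    logits = np.zeros(S)
    logits[:S1] = rate * t           # query-key alignment grows with iteration t
    w = np.exp(logits) / np.exp(logits).sum()
    return w[:S1].sum()              # summed attention on label-relevant tokens

masses = [relevant_mass(t) for t in (0, 50, 200)]
print(np.isclose(masses[0], 0.5))    # uniform attention at initialization
print(masses[0] < masses[1] < masses[2] < 1.0)
```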

Figure 3: Comparison of ViT and CNN Figure 4: Concentration of attention weights Figure 5: Impact of token sparsification on testing loss

Figure 6: (a) Test accuracy when N and α * change. (b) Relationship of sample complexity against α * . (c) Test accuracy when token sparsification removes spurious correlations.

Figure 7: Concentration of attention weights when (a) M = 10 (b) M = 15 (c) M = 20.

Figure 8: Impact of token sparsification on testing data when (a) M = 10 (b) M = 15 (c) M = 20.

set of sampled tokens of pattern j for the n-th data. S n * , S n # are sets of sampled tokens of the label-relevant pattern and the confusion pattern for the n-th data, respectively. α * , α # , α nd

and fixed during the training. W_V, W_K, and W_Q are initialized from a good pretrained model.

Algorithm 1 Training with SGD
1: Input: Training data {(X^n, y^n)}_{n=1}^N, the step size η, the total number of iterations T, the batch size B.
2: Initialization: Every entry of W_O^(0)
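A minimal runnable sketch of Algorithm 1: the forward pass F(X) = (1/|S|) Σ_l a^⊤ Relu(W_O W_V X softmax(X^⊤ W_K^⊤ W_Q x_l)) with hinge loss, and one SGD step. For brevity only W_Q is updated here, with a finite-difference gradient instead of the analytic one; all sizes, constants, and the data are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)
d, m, L = 5, 8, 4                    # token dim, hidden width, tokens per sample

def forward(X, WQ, WK, WV, WO, a):
    # F(X) = (1/L) sum_l a^T Relu(WO WV X softmax(X^T WK^T WQ x_l))
    out = 0.0
    for l in range(L):
        s = X.T @ WK.T @ WQ @ X[:, l]
        attn = np.exp(s - s.max()); attn /= attn.sum()
        out += a @ np.maximum(WO @ (WV @ X @ attn), 0.0)
    return out / L

def batch_loss(WQ, data, WK, WV, WO, a):
    return float(np.mean([max(0.0, 1.0 - y * forward(X, WQ, WK, WV, WO, a))
                          for X, y in data]))

WQ, WK, WV = np.eye(d), np.eye(d), np.eye(d)
WO = 0.3 * rng.standard_normal((m, d))
a = rng.choice([-1.0, 1.0], size=m) / np.sqrt(m)   # fixed top layer
data = [(rng.standard_normal((d, L)), y) for y in (1, -1, 1, -1)]

eta, eps = 1e-2, 1e-5
loss0 = batch_loss(WQ, data, WK, WV, WO, a)
grad = np.zeros_like(WQ)
for i in range(d):                   # finite-difference gradient w.r.t. WQ
    for j in range(d):
        E = np.zeros_like(WQ); E[i, j] = eps
        grad[i, j] = (batch_loss(WQ + E, data, WK, WV, WO, a) - loss0) / eps
WQ = WQ - eta * grad                 # one SGD step on WQ
loss1 = batch_loss(WQ, data, WK, WV, WO, a)
print(loss1 <= loss0 + 1e-6)         # a small step does not increase the loss
```

In the paper's algorithm W_K, W_V, and W_O are updated jointly with analytic gradients; the skeleton above only shows the per-batch loop structure.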

in terms of different directions of p_l, l ∈ [M]. Claim 2 describes the training dynamics of W_Q and W_K separately to show the tendency toward a sparse attention map. Claim 3 studies the gradient update process of W_V^(t).

where the first step is by letting a = √m and m ≳ M². We replace p_n(b) with p_n(T) because when b reaches the level of T, b^{o_1} p_n(b)^{o_2} is close to b^{o_1} for o_1, o_2 ≥ 0 by (61). Thus, Σ_{b=1}^T b^{o_1} p_n(b)^{o_2} ≳ T^{o_1+1} p_n(Θ(1)·T)^{o_2} ≳ T^{o_1+1} p_n(T)^{o_2}.
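The replacement of p_n(b) by p_n(T) above can be sanity-checked numerically: for an increasing p_n(b) that saturates toward 1 (a logistic ramp below, an illustrative stand-in for the paper's p_n), the sum Σ_{b≤T} b^{o_1} p_n(b)^{o_2} is lower bounded by a constant times T^{o_1+1} p_n(T)^{o_2}:

```python
import numpy as np

T, o1, o2 = 1000, 2, 2
b = np.arange(1, T + 1)
p = 1.0 / (1.0 + np.exp(-(b - T / 4) / (T / 10)))   # increases toward 1 by step T
lhs = np.sum(b**o1 * p**o2)                         # sum_b b^{o1} p(b)^{o2}
rhs = T**(o1 + 1) * p[-1]**o2                       # T^{o1+1} p(T)^{o2}
print(lhs >= 0.1 * rhs)                             # lhs ≳ rhs up to a constant
```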

(i,•)  lies in the subspace spanned by P and p 1 , p 2 , • • • , p M all have a unit norm, we can derive ∥W (t+1)

(1+K(t))∥q1(t)∥ 2 -(δ+τ )∥q1(t)∥ + (|S n | -|S n 1 |) -1 |S n 1 |e ∥q1(t)∥ 2 -(δ+τ )∥q1(t)∥ + (|S n | -|S n 1 1 |e x + (|S n | -|S n 1 |)



OTHER USEFUL LEMMAS

Lemma 3. If the number of neurons m is large enough such that m ≥ ε_m^{-2} M² log N, (208) then the numbers of lucky neurons at initialization, |W(0)| and |U(0)|, satisfy |W(0)|, |U(0)| ≥ (m/M)(1 − ε_m − (σ+τ)M/π). (209)

(i,•) and orthogonal p_l, l ∈ [M] generate independent W_O^(0)(i,•) p_l. From the definition of W(0), we have 2 sin(∆θ/2) ≤ (σ + τ), (211) which implies ∆θ ≲ (σ + τ).
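The angle argument above gives Pr(i ∈ W(0)) ≈ (1/M)(1 − (σ+τ)M/π). In the noiseless case ∆θ = 0 this is exactly 1/M by symmetry of the Gaussian initialization, which a quick Monte-Carlo run confirms (M and the trial count below are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
M, trials = 8, 200_000
corr = rng.standard_normal((trials, M))        # W_O^(0)(i,.) p_l for orthonormal p_l
lucky_frac = (corr.argmax(axis=1) == 0).mean() # fraction closest in angle to p_1

print(abs(lucky_frac - 1 / M) < 0.01)          # empirical rate ~ 1/M
```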

Some important notations


ACKNOWLEDGMENTS

This work was supported by AFOSR FA9550-20-1-0122, ARO W911NF-21-1-0255, NSF 1932196 and the Rensselaer-IBM AI Research Collaboration (http://airc.rpi.edu), part of the IBM AI Horizons Network (http://ibm.biz/AIHorizons). We thank Dr. Shuai Zhang at Rensselaer Polytechnic Institute for useful discussions. We thank all anonymous reviewers for their constructive comments.


as long as B ≳ Θ(1). (220)

Proof: We show this lemma by induction.
(1) t = 0. For i ∈ W(0), by Definition 2, the angle between W_O^(0)(i,•) and p_1 is smaller than (σ + τ). Hence, the desired bound holds for all p ∈ P \ {p_1}.
(2) Suppose that the conclusion holds when t = s. For t = s + 1, from Claim 1 of Lemma 2, we can obtain (222) and (223). Combining (222) and (223), we can approximately compute that if the stated condition holds, the corresponding bound on W^(s+1) follows. Therefore, the conclusion holds for all t ≥ 0. One can develop the proof for U(t) following the above steps.

G EXTENSION TO MORE GENERAL CASES G.1 EXTENSION TO MULTI-CLASSIFICATION

Consider the classification problem with four classes; we use the label y ∈ {+1, −1}² to denote the corresponding class. Similarly to the previous setup, there are four orthogonal discriminative patterns. In the output layer, a_(l)(i) for the data

G.2 EXTENSION TO A MORE GENERAL DATA MODEL

We generalize the patterns from vectors to sets of vectors. Consider that M_1 and M_2 denote sets of discriminative patterns for the binary labels, and κ is the minimum distance between patterns of different sets. Each token x^n_l of X^n is a noisy version of one pattern. Define that for l, s corresponding to b_1, b_2, respectively, we have 2τ + ∆ < κ. To simplify the theoretical analysis, one can similarly rescale all tokens a little, as in Assumption 3, such that tokens corresponding to patterns in the same pattern set have an inner product larger than 1, while tokens corresponding to patterns from different pattern sets have an inner product smaller than λ < 1. Assumption 1 can be modified accordingly, so that the initial query, key, and value features from different sets of patterns are still close to orthogonal to each other. Then, we can follow our main proof idea. To be more specific, for label-relevant tokens x^n_l, by computing (125) and (167), W_Q^(t) x^n_l will grow in the direction of a fixed linear combination of q_{l,1}, …, q_{l,l_m} and r_{l,1}, …, r_{l,l_m}. The coefficients of the linear combination are a function of the fractions of the different pattern vectors µ_{l,b} in M_l. One can still derive a sparse attention map, with the weights on non-discriminative patterns decreasing to be close to zero during training.
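The rescaling argument above can be illustrated numerically: with tokens generated as normalized noisy copies of patterns from two orthogonal sets, within-set inner products stay large while across-set inner products stay small. The dimension, sample counts, and noise level τ below are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)
d, tau = 16, 0.05
basis = np.linalg.qr(rng.standard_normal((d, d)))[0]   # orthonormal directions
p_set1, p_set2 = basis[:, 0], basis[:, 2]              # patterns from M_1 and M_2

def token(p):
    x = p + tau * rng.standard_normal(d)               # noisy copy of a pattern
    return x / np.linalg.norm(x)                       # rescaled to unit norm

xs1 = [token(p_set1) for _ in range(20)]
xs2 = [token(p_set2) for _ in range(20)]
within = np.mean([u @ v for u in xs1 for v in xs1])    # same-set inner products
across = np.mean([abs(u @ v) for u in xs1 for v in xs2])
print(within > 0.9 > 0.5 > across)                     # clear separation
```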

G.3 EXTENSION TO MULTI-HEAD NETWORKS

Suppose there are H heads in total. The network is modified accordingly, with each head h carrying its own query, key, and value weights. One can make similar assumptions for each W_{K_h}^(0). Based on the modified assumptions with H heads, the backbone of the proof remains the same.
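A sketch of the H-head forward computation described above, where each head h has its own W_Q^h, W_K^h, W_V^h and head outputs are combined before the two-layer perceptron. The averaging combination and all sizes are illustrative assumptions; the paper's exact modified formula may differ:

```python
import numpy as np

rng = np.random.default_rng(0)
d, L, H = 6, 5, 3
X = rng.standard_normal((d, L))
heads = [dict(WQ=np.eye(d), WK=np.eye(d),
              WV=rng.standard_normal((d, d)) / np.sqrt(d)) for _ in range(H)]

def head_out(h, l):
    s = X.T @ h["WK"].T @ h["WQ"] @ X[:, l]
    attn = np.exp(s - s.max()); attn /= attn.sum()
    return h["WV"] @ X @ attn          # attended feature from this head

z = np.mean([head_out(h, 0) for h in heads], axis=0)   # combine heads (averaging)
print(z.shape == (d,))
```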

Lucky neurons in W

Hence, the properties of the Relu activation are almost the same as in the single-head case, because lucky neurons are still activated by either of the two label-relevant patterns with high probability. In fact, one can expect a more stable training process with multiple heads due to a more stable Relu gate for lucky neurons.

Consider a basic case where a skip connection is added after the self-attention layer. Let m_a = d. The network is changed into (250). The assumption on W_V^(0) in Assumption 1 should be changed accordingly, while the assumptions on W_Q^(0) and W_K^(0) remain the same. One can easily verify that the gradients of W_K, W_Q, and W_V for (250) are almost the same as those for (1), except for the Relu gate. The major differences come from the gradient of W_O, which also helps determine the Relu gate. One needs to redefine the corresponding quantities for l ∈ S^n_1. The inner product between the lucky neuron and the term Σ_{j≠1} W^n_j(t)(x^n_j + x^n_l) can still be upper bounded by the inner product between the lucky neuron and the term involving softmax(· W_Q x^n_l) p_1, given good initialization of W_K and W_Q. Therefore, (250) can be analyzed following our proof techniques.

For layer normalization, one usually uses that approach to normalize each data sample. It is consistent with our normalization of x^n_l, which plays an important role in our proof. By normalization, the training process becomes more stable because of the unified norm of all tokens.
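The normalization remark can be illustrated with a per-token unit-norm sketch (this mirrors the analysis's normalization of x^n_l rather than the exact mean-variance LayerNorm):

```python
import numpy as np

def normalize_tokens(X):
    # Column-wise (per-token) normalization of a d x L token matrix.
    return X / np.linalg.norm(X, axis=0, keepdims=True)

# Tokens with wildly different magnitudes all end up with unit norm.
X = np.random.default_rng(0).standard_normal((4, 6)) * np.array([1, 5, 10, 0.1, 2, 3])
norms = np.linalg.norm(normalize_tokens(X), axis=0)
print(np.allclose(norms, 1.0))
```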

