THEORETICAL CHARACTERIZATION OF HOW NEURAL NETWORK PRUNING AFFECTS ITS GENERALIZATION

Abstract

It has been observed in practice that applying pruning-at-initialization methods to neural networks and training the sparsified networks can not only retain the testing performance of the original dense models, but also sometimes even slightly boost the generalization performance. A theoretical understanding of such experimental observations is yet to be developed. This work makes the first attempt to study how different pruning fractions affect the model's gradient descent dynamics and generalization. Specifically, this work considers a classification task for overparameterized two-layer neural networks, where the network is randomly pruned at different rates at initialization. It is shown that as long as the pruning fraction is below a certain threshold, gradient descent can drive the training loss toward zero and the network exhibits good generalization performance. More surprisingly, the generalization bound gets better as the pruning fraction gets larger. To complement this positive result, this work further shows a negative result: there exists a large pruning fraction such that while gradient descent is still able to drive the training loss toward zero (by memorizing noise), the generalization performance is no better than random guessing. This further suggests that pruning can change the feature learning process, which leads to the performance drop of the pruned neural network. To our knowledge, this is the first generalization result for pruned neural networks, suggesting that pruning can improve the neural network's generalization.

1. INTRODUCTION

Neural network pruning dates back to the early stage of the development of neural networks (LeCun et al., 1989). Since then, many research works have focused on using neural network pruning as a model compression technique, e.g. (Molchanov et al., 2019; Luo & Wu, 2017; Ye et al., 2020; Yang et al., 2021). However, all these works focused on pruning neural networks after training to reduce inference time, and thus the efficiency gain from pruning cannot be directly transferred to the training phase. It was not until recently that Frankle & Carbin (2018) showed a surprising phenomenon: a neural network pruned at initialization can be trained to achieve performance competitive with that of the dense model. They called this phenomenon the lottery ticket hypothesis. The lottery ticket hypothesis states that there exists a sparse subnetwork inside a dense network at the random initialization stage such that, when trained in isolation, it can match the test accuracy of the original dense network after training for at most the same number of iterations. On the other hand, the algorithm Frankle & Carbin (2018) proposed to find the lottery ticket requires many rounds of pruning and retraining, which is computationally expensive. Many subsequent works focused on developing new methods to reduce the cost of finding such a network at initialization (Lee et al., 2018; Wang et al., 2019; Tanaka et al., 2020; Liu & Zenke, 2020; Chen et al., 2021b). A further investigation by Frankle et al. (2020) showed that some of these methods merely discover the layer-wise pruning ratio instead of the sparsity pattern.

The discovery of the lottery ticket hypothesis sparked further interest in understanding this phenomenon. Another line of research focused on finding a subnetwork inside a dense network at random initialization such that the subnetwork can achieve good performance (Zhou et al., 2019; Ramanujan et al., 2020). Shortly after that, Malach et al. (2020) formalized this phenomenon, which they called the strong lottery ticket hypothesis: under certain assumptions on the weight initialization distribution, a sufficiently overparameterized neural network at initialization contains a subnetwork with roughly the same accuracy as the target network. Later, Pensia et al. (2020) improved the overparameterization parameters and Sreenivasan et al. (2021) showed that this type of result holds even if the weights are binary. Unsurprisingly, as pointed out by Malach et al. (2020), finding such a subnetwork is computationally hard. Nonetheless, all of this analysis is from a function approximation perspective, and none of the aforementioned works has considered the effect of pruning on gradient descent dynamics, let alone the neural networks' generalization.

Interestingly, empirical experiments have found that sparsity can further improve generalization in certain scenarios (Chen et al., 2021a; Ding et al., 2021; He et al., 2022). There have also been empirical works showing that random pruning can be effective (Frankle et al., 2020; Su et al., 2020; Liu et al., 2021b). However, theoretical understanding of this benefit of pruning neural networks is still limited. In this work, we take the first step to answer the following important open question from a theoretical perspective: How does the pruning fraction affect the training dynamics and the model's generalization, if the model is pruned at initialization and trained by gradient descent? We study this question using random pruning. We consider a classification task where the input data consists of a class-dependent sparse signal and random noise. We analyze the training dynamics of a two-layer convolutional neural network pruned at initialization. Specifically, this work makes the following contributions:

• Mild pruning. We prove that there indeed exists a range of pruning fractions where the pruning fraction is small and the generalization error bound gets better as the pruning fraction gets larger. In this case, the signal in the feature is well-preserved and, because pruning purifies the feature, the effect of noise is reduced. We provide a detailed explanation in Section 3. To our knowledge, this is the first theoretical result on generalization for pruned neural networks, which suggests that pruning can improve generalization under some settings. Further, we conduct experiments to verify our results.

• Over pruning. To complement the above positive result, we also show a negative result: if the pruning fraction is larger than a certain threshold, then the generalization performance is no better than simple random guessing, although gradient descent is still able to drive the training loss toward zero. This further suggests that the performance drop of the pruned neural network is not solely caused by the pruned network's own lack of trainability or expressiveness, but also by the change of gradient descent dynamics due to pruning.

• Technically, we develop a novel analysis to bound the effect of pruning on the weight-noise and weight-signal correlations. Further, in contrast to many previous works that considered only the binary case, our analysis handles multi-class classification with the general cross-entropy loss. Here, a key technical development is a gradient upper bound for the multi-class cross-entropy loss, which might be of independent interest.

Pictorially, our result is summarized in Figure 1. We point out that the neural network training we consider is in the feature learning regime, where the weight parameters can go far away from their initialization. This is fundamentally different from the popular neural tangent kernel regime, where neural networks essentially behave similarly to their linearization.

1.1. RELATED WORKS

Figure 1: A pictorial demonstration of our results. The bell-shaped curves model the distribution of the signal in the features, where the mean represents the signal strength and the width of the curve indicates the variance of noise. Our results show that mild pruning preserves the signal strength and reduces the noise variance (and hence yields better generalization), whereas over pruning lowers signal strength albeit reducing noise variance.

The Lottery Ticket Hypothesis and Sparse Training. The discovery of the lottery ticket hypothesis (Frankle & Carbin, 2018) has inspired further investigation and applications. One line of research has focused on developing computationally efficient methods to enable sparse training: static sparse training methods aim to identify a sparse mask at the initialization stage based on different criteria, such as SNIP (loss-based) (Lee et al., 2018), GraSP (gradient-based) (Wang et al., 2019), SynFlow (synaptic strength-based) (Tanaka et al., 2020), a neural tangent kernel based method (Liu & Zenke, 2020) and one-shot pruning (Chen et al., 2021b). Random pruning has also been considered in static sparse training, such as uniform pruning (Mariet & Sra, 2015; He et al., 2017; Gale et al., 2019; Suau et al., 2018), non-uniform pruning (Mocanu et al., 2016), expander-graph-related techniques (Prabhu et al., 2018; Kepner & Robinett, 2019), Erdös-Rényi (Mocanu et al., 2018) and Erdös-Rényi-Kernel (Evci et al., 2020). On the other hand, dynamic sparse training allows the sparse mask to be updated (Mocanu et al., 2018; Mostafa & Wang, 2019; Evci et al., 2020; Jayakumar et al., 2020; Liu et al., 2021c; d; a; Peste et al., 2021). The sparsity pattern can also be learned by using a sparsity-inducing regularizer (Yang et al., 2020). Recently, He et al. (2022) discovered that pruning can exhibit a double descent phenomenon when the data-set labels are corrupted.
Another line of research has focused on pruning neural networks at their random initialization to achieve good performance (Zhou et al., 2019; Ramanujan et al., 2020). In particular, Ramanujan et al. (2020) showed that it is possible to prune a randomly initialized wide ResNet-50 to match the performance of a ResNet-34 trained on ImageNet. This phenomenon is named the strong lottery ticket hypothesis. Later, Malach et al. (2020) proved that under certain assumptions on the initialization distribution, a target network of width d and depth l can be approximated by pruning a randomly initialized network that is a polynomial factor (in d, l) wider and twice as deep, even without any further training. However, finding such a network is computationally hard, which can be shown by reducing the pruning problem to optimizing a neural network. Later, Pensia et al. (2020) improved the widening factor to logarithmic, and Sreenivasan et al. (2021) proved that with a polylogarithmic widening factor, such a result holds even if the network weights are binary. A follow-up work shows that it is possible to find a subnetwork achieving good performance at initialization and then fine-tune it (Sreenivasan et al., 2022). Our work, on the other hand, analyzes the gradient descent dynamics of a pruned neural network and its generalization after training.

Analyses of Training Neural Networks by Gradient Descent. A series of works (Allen-Zhu et al., 2019; Du et al., 2019; Lee et al., 2019; Zou et al., 2020; Zou & Gu, 2019; Ji & Telgarsky, 2019; Chen et al., 2020b; Song & Yang, 2019; Oymak & Soltanolkotabi, 2020) has proved that if a deep neural network is wide enough, then (stochastic) gradient descent provably drives the training loss toward zero at a fast rate, based on the neural tangent kernel (NTK) (Jacot et al., 2018). Further, under certain assumptions on the data, the learned network is able to generalize (Cao & Gu, 2019; Arora et al., 2019). However, as pointed out by Chizat et al. (2019), in the NTK regime the gradient descent dynamics of the neural network essentially behave similarly to those of its linearization, and the learned weights are not far away from the initialization, which prohibits the network from performing any useful feature learning. In order to go beyond the NTK regime, one line of research has focused on the mean field limit (Song et al., 2018; Chizat & Bach, 2018; Rotskoff & Vanden-Eijnden, 2018; Wei et al., 2019; Chen et al., 2020a; Sirignano & Spiliopoulos, 2020; Fang et al., 2021). Recently, people have started to study neural network training dynamics in the feature learning regime, where data from different classes is defined by a set of class-related signals which are low rank (Allen-Zhu & Li, 2020; 2022; Cao et al., 2022; Shi et al., 2021; Telgarsky, 2022). However, none of these previous works considered the effect of pruning. Our work also focuses on the aforementioned feature learning regime, but for the first time characterizes the impact of pruning on the generalization performance of neural networks.

2. PRELIMINARIES AND PROBLEM FORMULATION

In this section, we introduce our notation, data generation process, neural network architecture and the optimization algorithm.

Notations. We use lower case letters to denote scalars and boldface letters and symbols (e.g. $\mathbf{x}$) to denote vectors and matrices. We use $\odot$ to denote the element-wise product. For an integer $n$, we use $[n]$ to denote the set of integers $\{1, 2, \ldots, n\}$. We use $x = O(y)$, $x = \Omega(y)$, $x = \Theta(y)$ to denote that there exists a constant $C$ such that $x \le Cy$, $x \ge Cy$, $x = Cy$ respectively. We use $\tilde{O}$, $\tilde{\Omega}$ and $\tilde{\Theta}$ to hide polylogarithmic factors in these notations. Finally, we use $x = \mathrm{poly}(y)$ if $x = O(y^C)$ for some positive constant $C$, and $x = \mathrm{poly}\log y$ if $x = \mathrm{poly}(\log y)$.

2.1. SETTINGS

Definition 2.1 (Data distribution of K classes). Suppose we are given the set of signal vectors $\{\mu \mathbf{e}_i\}_{i=1}^{K}$, where $\mu > 0$ denotes the strength of the signal and $\mathbf{e}_i$ denotes the i-th standard basis vector, with its i-th entry being 1 and all other coordinates being 0. Each data point $(\mathbf{x}, y)$ with $\mathbf{x} = [\mathbf{x}_1^\top, \mathbf{x}_2^\top]^\top \in \mathbb{R}^{2d}$ and $y \in [K]$ is generated from the following distribution $\mathcal{D}$:
1. The label $y$ is generated from a uniform distribution over $[K]$.
2. A noise vector $\boldsymbol{\xi}$ is generated from the Gaussian distribution $N(0, \sigma_n^2 \mathbf{I})$.
3. With probability 1/2, assign $\mathbf{x}_1 = \boldsymbol{\mu}_y$, $\mathbf{x}_2 = \boldsymbol{\xi}$; with probability 1/2, assign $\mathbf{x}_2 = \boldsymbol{\mu}_y$, $\mathbf{x}_1 = \boldsymbol{\xi}$, where $\boldsymbol{\mu}_y = \mu \mathbf{e}_y$.

The sparse signal model is motivated by the empirical observation that, during the training of neural networks, the output of each ReLU layer is usually sparse instead of dense. This is partially due to the fact that in practice the bias term in the linear layer is used (Song et al., 2021). For samples from different classes, usually a different set of neurons fires. Our study can be seen as a formal analysis of pruning the second-to-last layer of a deep neural network in the layer-peeled model as in Zhu et al. (2021); Zhou et al. (2022). We also point out that our assumption on the sparsity of the signal is necessary for our analysis. If we drop this sparsity assumption and only make an assumption on the $\ell_2$ norm of the signal, then in the extreme case the signal is uniformly distributed across all coordinates and the effect of pruning on the signal and on the noise will be essentially the same: their $\ell_2$ norms will both be reduced by a factor of $\sqrt{p}$.

Network architecture and random pruning. We consider a two-layer convolutional neural network model with polynomial ReLU activation $\sigma(z) = (\max\{0, z\})^q$, where we focus on the case $q = 3$. The network is pruned at initialization by a mask $\mathbf{M}$, where each entry of $\mathbf{M}$ is generated i.i.d. from Bernoulli($p$). Let $\mathbf{m}_{j,r}$ denote the r-th row of $\mathbf{M}_j$.
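Given the data $(\mathbf{x}, y)$, the output of the neural network can be written as $F(\mathbf{W} \odot \mathbf{M}, \mathbf{x}) = (F_1(\mathbf{W}_1 \odot \mathbf{M}_1, \mathbf{x}), F_2(\mathbf{W}_2 \odot \mathbf{M}_2, \mathbf{x}), \ldots, F_K(\mathbf{W}_K \odot \mathbf{M}_K, \mathbf{x}))$, where the j-th output is given by
$$F_j(\mathbf{W}_j \odot \mathbf{M}_j, \mathbf{x}) = \sum_{r=1}^{m} \left[\sigma(\langle \mathbf{w}_{j,r} \odot \mathbf{m}_{j,r}, \mathbf{x}_1 \rangle) + \sigma(\langle \mathbf{w}_{j,r} \odot \mathbf{m}_{j,r}, \mathbf{x}_2 \rangle)\right] = \sum_{r=1}^{m} \left[\sigma(\langle \mathbf{w}_{j,r} \odot \mathbf{m}_{j,r}, \boldsymbol{\mu}_y \rangle) + \sigma(\langle \mathbf{w}_{j,r} \odot \mathbf{m}_{j,r}, \boldsymbol{\xi} \rangle)\right].$$
The mask $\mathbf{M}$ is sampled only once at initialization and remains fixed throughout the entire training process. From now on, we use a tilde over a symbol to denote its masked version, e.g., $\widetilde{\mathbf{W}} = \mathbf{W} \odot \mathbf{M}$ and $\widetilde{\mathbf{w}}_{j,r} = \mathbf{w}_{j,r} \odot \mathbf{m}_{j,r}$. Since $\boldsymbol{\mu}_j \odot \mathbf{m}_{j,r} = 0$ with probability $1 - p$, some neurons will not receive the corresponding signal at all and will only learn noise. Therefore, for each class $j \in [K]$, we split the neurons into two sets based on whether they receive the corresponding signal or not:
$$S^j_{\mathrm{signal}} = \{r \in [m] : \boldsymbol{\mu}_j \odot \mathbf{m}_{j,r} \neq 0\}, \qquad S^j_{\mathrm{noise}} = \{r \in [m] : \boldsymbol{\mu}_j \odot \mathbf{m}_{j,r} = 0\}.$$

Gradient descent algorithm. We consider the network trained with the cross-entropy loss with softmax. We denote $\mathrm{logit}_i(F, \mathbf{x}) := \frac{e^{F_i(\mathbf{x})}}{\sum_{j \in [K]} e^{F_j(\mathbf{x})}}$, and the cross-entropy loss can be written as $\ell(F(\mathbf{x}, y)) = -\log \mathrm{logit}_y(F, \mathbf{x})$. The convolutional neural network is trained by minimizing the empirical cross-entropy loss given by
$$L_S(\mathbf{W}) = \frac{1}{n} \sum_{i=1}^{n} \ell[F(\mathbf{W} \odot \mathbf{M}; \mathbf{x}_i, y_i)] = \mathbb{E}_S\, \ell[F(\mathbf{W} \odot \mathbf{M}; \mathbf{x}_i, y_i)],$$
where $S = \{(\mathbf{x}_i, y_i)\}_{i=1}^{n}$ is the training data set. Similarly, we define the generalization loss as $L_{\mathcal{D}} := \mathbb{E}_{(\mathbf{x}, y) \sim \mathcal{D}}[\ell(F(\mathbf{W} \odot \mathbf{M}; \mathbf{x}, y))]$. The model weights are initialized i.i.d. from the Gaussian $N(0, \sigma_0^2)$. The gradient of the cross-entropy loss is given by $\ell'_{j,i} := \ell'_j(\mathbf{x}_i, y_i) = \mathrm{logit}_j(F, \mathbf{x}_i) - \mathbb{I}(j = y_i)$.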
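Since $\nabla_{\mathbf{w}_{j,r}} L_S(\mathbf{W} \odot \mathbf{M}) = \nabla_{\mathbf{w}_{j,r} \odot \mathbf{m}_{j,r}} L_S(\mathbf{W} \odot \mathbf{M}) \odot \mathbf{m}_{j,r} = \nabla_{\widetilde{\mathbf{w}}_{j,r}} L_S(\widetilde{\mathbf{W}}) \odot \mathbf{m}_{j,r}$, we can write the full-batch gradient descent update of the weights as
$$\widetilde{\mathbf{w}}^{(t+1)}_{j,r} = \widetilde{\mathbf{w}}^{(t)}_{j,r} - \eta \nabla_{\widetilde{\mathbf{w}}_{j,r}} L_S(\widetilde{\mathbf{W}}) \odot \mathbf{m}_{j,r} = \widetilde{\mathbf{w}}^{(t)}_{j,r} - \frac{\eta}{n} \sum_{i=1}^{n} \ell'^{(t)}_{j,i} \cdot \sigma'\big(\langle \widetilde{\mathbf{w}}^{(t)}_{j,r}, \boldsymbol{\xi}_i \rangle\big) \cdot \widetilde{\boldsymbol{\xi}}_{j,r,i} - \frac{\eta}{n} \sum_{i=1}^{n} \ell'^{(t)}_{j,i}\, \sigma'\big(\langle \widetilde{\mathbf{w}}^{(t)}_{j,r}, \boldsymbol{\mu}_{y_i} \rangle\big)\, \boldsymbol{\mu}_{y_i} \odot \mathbf{m}_{j,r},$$
for $j \in [K]$ and $r \in [m]$, where $\widetilde{\boldsymbol{\xi}}_{j,r,i} = \boldsymbol{\xi}_i \odot \mathbf{m}_{j,r}$.

Condition 2.2. We consider the parameter regime described as follows:

Conditions (1) and (2) ensure that there are enough samples in each class with high probability. Condition (3) ensures that our setting is in the high-dimensional regime. Condition (4) ensures that the full model can be trained to exhibit good generalization. Conditions (5), (6) and (7) ensure that the neural network is sufficiently overparameterized and can be optimized efficiently by gradient descent. Conditions (7) and (8) further ensure that the training time is polynomial in $d$. We further discuss the practical consideration of $\eta$ and $\epsilon$ to justify their conditions in Remark D.9.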
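The setting above can be made concrete with a short NumPy sketch. It is only an illustration under stated assumptions, not the authors' implementation: the function names `sample_data`, `forward` and `gd_step`, the 0-based labels, and every numerical value in the usage lines are hypothetical choices. It samples data as in Definition 2.1, draws a Bernoulli($p$) mask once, evaluates the masked network with $\sigma(z) = \max\{0, z\}^3$, and performs one full-batch gradient step in which the gradient is multiplied by the mask.

```python
import numpy as np

def sample_data(n, d, K, mu, sigma_n, rng):
    """Draw n points following Definition 2.1 (labels are 0-based here)."""
    y = rng.integers(0, K, size=n)                  # uniform labels
    signal = np.zeros((n, d))
    signal[np.arange(n), y] = mu                    # mu * e_y
    noise = sigma_n * rng.standard_normal((n, d))   # xi ~ N(0, sigma_n^2 I)
    swap = rng.random(n) < 0.5                      # which patch carries the signal
    x1 = np.where(swap[:, None], noise, signal)
    x2 = np.where(swap[:, None], signal, noise)
    return x1, x2, y

def forward(W, M, x1, x2, q=3):
    """F_j(x) = sum_r [sigma(<w~_{j,r}, x1>) + sigma(<w~_{j,r}, x2>)]."""
    Wt = W * M                                      # masked weights, shape (K, m, d)
    z1 = np.einsum("kmd,nd->nkm", Wt, x1)
    z2 = np.einsum("kmd,nd->nkm", Wt, x2)
    F = (np.maximum(z1, 0) ** q + np.maximum(z2, 0) ** q).sum(axis=2)
    return F, z1, z2

def gd_step(W, M, x1, x2, y, eta, q=3):
    """One full-batch gradient descent step on the empirical cross-entropy loss."""
    n = len(y)
    F, z1, z2 = forward(W, M, x1, x2, q)
    F = F - F.max(axis=1, keepdims=True)            # shift for numerical stability
    logit = np.exp(F) / np.exp(F).sum(axis=1, keepdims=True)
    lprime = logit.copy()
    lprime[np.arange(n), y] -= 1.0                  # l'_{j,i} = logit_j - 1[j = y_i]
    s1 = q * np.maximum(z1, 0) ** (q - 1)           # sigma'(z) on patch 1
    s2 = q * np.maximum(z2, 0) ** (q - 1)           # sigma'(z) on patch 2
    grad = (np.einsum("nk,nkm,nd->kmd", lprime, s1, x1)
            + np.einsum("nk,nkm,nd->kmd", lprime, s2, x2)) / n
    return W - eta * grad * M                       # pruned coordinates never move

# Hypothetical sizes, chosen only to keep the example small.
rng = np.random.default_rng(0)
K, m, d, n, p = 3, 8, 50, 20, 0.5
x1, x2, y = sample_data(n, d, K, mu=2.0, sigma_n=0.1, rng=rng)
M = (rng.random((K, m, d)) < p).astype(float)       # mask sampled once, then fixed
W = 0.01 * rng.standard_normal((K, m, d))
W_next = gd_step(W, M, x1, x2, y, eta=0.1)
```

Because the update is multiplied by the mask, the entries of `W` at pruned positions stay at their initial values, which is the code-level counterpart of the identity $\nabla_{\mathbf{w}_{j,r}} L_S = \nabla_{\widetilde{\mathbf{w}}_{j,r}} L_S \odot \mathbf{m}_{j,r}$.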

3.1. MAIN RESULT

The first main result shows that there exists a threshold on the pruning fraction $p$ such that pruning helps the neural network's generalization.

Theorem 3.1 (Main Theorem for Mild Pruning, Informal). Under Condition 2.2, if $p \in [C_1 \frac{\log d}{m}, 1]$ for some constant $C_1$, then with probability at least $1 - O(d^{-1})$ over the randomness in the data, network initialization and pruning, there exists $T = O(K \eta^{-1} \sigma_0^{2-q} \mu^{-q} + K^2 m^4 \mu^{-2} \eta^{-1} \epsilon^{-1})$ such that
1. The training loss is below $\epsilon$: $L_S(\widetilde{\mathbf{W}}^{(T)}) \le \epsilon$.
2. The generalization loss can be bounded by $L_{\mathcal{D}}(\widetilde{\mathbf{W}}^{(T)}) \le O(K\epsilon) + \exp(-n^2/p)$.

Theorem 3.1 indicates that there exists a threshold of order $\Theta(\frac{\log d}{m})$ such that if $p$ is above this threshold (i.e., the fraction of pruned weights is small), gradient descent is able to drive the training loss toward zero (as item 1 claims) and the overparameterized network achieves good testing performance (as item 2 claims). In the next subsection, we explain why pruning can help generalization via an outline of our proof, and we defer all the detailed proofs to Appendix D.
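A threshold of order $\frac{\log d}{m}$ can be motivated by a back-of-the-envelope union bound. This is only an illustration based on the definitions in Section 2.1, not the formal argument of Appendix D. The signal $\boldsymbol{\mu}_j = \mu \mathbf{e}_j$ occupies a single coordinate, so neuron $(j, r)$ receives it exactly when $(\mathbf{m}_{j,r})_j = 1$, which happens with probability $p$ independently across $r$:

```latex
\begin{align*}
\Pr\big[S^j_{\mathrm{signal}} = \emptyset\big] &= (1-p)^m \le e^{-pm}, \\
\Pr\big[\exists\, j \in [K] : S^j_{\mathrm{signal}} = \emptyset\big] &\le K e^{-pm} \le K d^{-C_1}
\quad \text{whenever } p \ge C_1 \frac{\log d}{m}.
\end{align*}
```

So once $p$ exceeds a sufficiently large multiple of $\frac{\log d}{m}$, every class keeps at least one neuron that sees its signal, except with probability polynomially small in $d$.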

3.2. PROOF OUTLINE

Our proof establishes the following two properties:
• First, we show that after mild pruning the network is still able to learn the signal, and the magnitude of the signal in the feature is preserved.
• Then, we show that given a new sample, pruning reduces the noise effect in the feature, which leads to the improvement of generalization.

We first show the above properties for three stages of gradient descent: initialization, feature growing phase, and converging phase, and then establish the generalization property.

Initialization. First of all, readers might wonder why pruning can preserve the signal at all. Intuitively, a network will achieve good performance if its weights are highly correlated with the signal (i.e., their inner product is large). Two intuitive but misleading heuristics are the following:
• Consider a fixed neuron weight. At random initialization, in expectation, the signal correlation with the weights is given by $\mathbb{E}_{\mathbf{w},\mathbf{m}}[|\langle \mathbf{w} \odot \mathbf{m}, \boldsymbol{\mu} \rangle|] \le p \sigma_0 \mu$, and the noise correlation with the weights is given by $\mathbb{E}_{\mathbf{w},\mathbf{m},\boldsymbol{\xi}}[|\langle \mathbf{w} \odot \mathbf{m}, \boldsymbol{\xi} \rangle|] \le \sqrt{\mathbb{E}_{\mathbf{w},\mathbf{m},\boldsymbol{\xi}}[\langle \mathbf{w} \odot \mathbf{m}, \boldsymbol{\xi} \rangle^2]} = \sigma_0 \sigma_n \sqrt{pd}$ by Jensen's inequality. Based on this argument, taking a sum over all the neurons, pruning will hurt the weight-signal correlation more than the weight-noise correlation.
• Since we are pruning with Bernoulli($p$), a given neuron will not receive the signal at all with probability $1 - p$. Thus, roughly a $p$ fraction of the neurons receive the signal and the remaining $1 - p$ fraction learn purely from noise. Even though for every neuron roughly a $\sqrt{p}$ portion of the $\ell_2$ mass of the noise is reduced, at the same time pruning also creates a $1 - p$ fraction of neurons which do not receive signals at all and will purely output noise after training. Summing up the contributions from every neuron, the signal strength is reduced by a factor of $p$ while the noise strength is reduced by a factor of $\sqrt{p}$. We again reach the conclusion that pruning at any rate will hurt the signal more than the noise.

The above analysis suggests that, at any pruning rate, pruning can only hurt the signal more than the noise at initialization. Such an analysis would be indicative if the network training were in the neural tangent kernel regime, where the weight of each neuron does not travel far from its initialization, so that the above analysis still holds approximately after training. However, when the neural network training is in the feature learning regime, this average-type analysis becomes misleading. Namely, in such a regime, the weights with large correlation with the signal at initialization will quickly evolve into singleton neurons, and those weights with small correlation will remain small. In our proof, we focus on the feature learning regime, and analyze how the network weights change and what the effect of pruning is during the various stages of gradient descent.

We now analyze the effect of pruning on the weight-signal correlation and the weight-noise correlation at initialization. Our first lemma leverages the sparsity of our signal and shows that if the pruning is mild, then it will not hurt the maximum weight-signal correlation much at initialization. On the other hand, the maximum weight-noise correlation is reduced by a factor of $\sqrt{p}$.

Lemma 3.2 (Initialization). With probability at least $1 - 2/d$, for all $i \in [n]$,
$$\sigma_0 \sigma_n \sqrt{pd} \le \max_r \langle \widetilde{\mathbf{w}}^{(0)}_{j,r}, \boldsymbol{\xi}_i \rangle \le \sqrt{2 \log(Kmd)}\, \sigma_0 \sigma_n \sqrt{pd}.$$
Further, suppose $pm \ge \Omega(\log(Kd))$; then with probability $1 - 2/d$, for all $j \in [K]$,
$$\sigma_0 \|\boldsymbol{\mu}_j\|_2 \le \max_{r \in S^j_{\mathrm{signal}}} \langle \widetilde{\mathbf{w}}^{(0)}_{j,r}, \boldsymbol{\mu}_j \rangle \le \sqrt{2 \log(8pmKd)}\, \sigma_0 \|\boldsymbol{\mu}_j\|_2.$$

Given this lemma, we now prove that there exists at least one neuron that is heavily aligned with the signal after training. Similarly to previous works (Allen-Zhu & Li, 2020; Zou et al., 2021; Cao et al., 2022), the analysis is divided into two phases: the feature growing phase and the converging phase.

Feature Growing Phase.
In this phase, the gradient of the cross-entropy loss is large and the weight-signal correlation grows much more quickly than the weight-noise correlation thanks to the polynomial ReLU. We show that the signal strength is relatively unaffected by pruning while the noise level is reduced by a factor of $\sqrt{p}$.

Lemma 3.3 (Feature Growing Phase, Informal). Under Condition 2.2, there exists a time $T_1$ such that
1. The max weight-signal correlation is large: $\max_r \langle \widetilde{\mathbf{w}}^{(T_1)}_{j,r}, \boldsymbol{\mu}_j \rangle \ge m^{-1/q}$ for $j \in [K]$.
2. The weight-noise and cross-class weight-signal correlations are small: if $j \neq y_i$, then $\max_{j,r,i} \langle \widetilde{\mathbf{w}}^{(T_1)}_{j,r}, \boldsymbol{\xi}_i \rangle \le O(\sigma_0 \sigma_n \sqrt{pd})$ and $\max_{j,r,k} \langle \widetilde{\mathbf{w}}^{(T_1)}_{j,r}, \boldsymbol{\mu}_k \rangle \le O(\sigma_0 \mu)$.

Converging Phase. We show that gradient descent can drive the training loss toward zero while the signal in the feature is still large. An important intermediate step in our argument is the development of the following gradient upper bound for the multi-class cross-entropy loss, which introduces an extra factor of $K$ in the gradient upper bound.

Lemma 3.4 (Gradient Upper Bound, Informal). Under Condition 2.2, we have
$$\big\|\nabla L_S(\widetilde{\mathbf{W}}^{(t)}) \odot \mathbf{M}\big\|_F^2 \le O(K m^{2/q} \mu^2)\, L_S(\widetilde{\mathbf{W}}^{(t)}).$$

Proof Sketch. To prove this upper bound, note that for a given input $(\mathbf{x}_i, y_i)$, the term $\ell'^{(t)}_{y_i,i} \nabla F_{y_i}(\mathbf{x}_i)$ should make the major contribution to $\|\nabla \ell(\widetilde{\mathbf{W}}; \mathbf{x}_i, y_i)\|_F$. Further note that
$$|\ell'^{(t)}_{y_i,i}| = 1 - \mathrm{logit}_{y_i}(F; \mathbf{x}_i) = \frac{\sum_{j \neq y_i} e^{F_j(\mathbf{x}_i)}}{\sum_{j} e^{F_j(\mathbf{x}_i)}} \le \frac{\sum_{j \neq y_i} e^{F_j(\mathbf{x}_i)}}{e^{F_{y_i}(\mathbf{x}_i)}}.$$
Now, applying the property that $F_j(\mathbf{x}_i)$ is small for $j \neq y_i$ (which we prove in the appendix), the numerator contributes a factor of $K$. To bound the rest, we utilize a special property of the multi-class cross-entropy loss: $|\ell'^{(t)}_{j,i}| \le |\ell'^{(t)}_{y_i,i}| \le \ell^{(t)}_i$. However, a naive application of this inequality results in a factor of $K^3$ instead of $K$ in our bound. The trick is to further use the fact that $\sum_{j \neq y_i} |\ell'^{(t)}_{j,i}| = |\ell'^{(t)}_{y_i,i}|$.

Using the above gradient upper bound, we can show that the objective can be minimized.
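The softmax properties used in the proof sketch are easy to check numerically. The following minimal sketch is an illustration only; the helper name `ce_derivatives` and the random test logits are assumptions, not part of the paper. It verifies, on random inputs, that $\sum_{j \neq y} |\ell'_j| = |\ell'_y|$ and that $|\ell'_j| \le |\ell'_y| \le \ell$.

```python
import numpy as np

def ce_derivatives(F, y):
    """Return softmax probabilities, l'_j = logit_j - 1[j = y], and the loss -log logit_y."""
    z = F - F.max()                       # shift for numerical stability
    logit = np.exp(z) / np.exp(z).sum()
    lprime = logit.copy()
    lprime[y] -= 1.0
    loss = -np.log(logit[y])
    return logit, lprime, loss

rng = np.random.default_rng(0)
K = 10
for _ in range(1000):
    F = 3.0 * rng.standard_normal(K)      # arbitrary network outputs
    y = int(rng.integers(K))
    logit, lprime, loss = ce_derivatives(F, y)
    off = np.abs(np.delete(lprime, y))    # |l'_j| for j != y
    assert np.isclose(off.sum(), abs(lprime[y]))   # sum_{j != y} |l'_j| = |l'_y|
    assert off.max() <= abs(lprime[y]) + 1e-12     # |l'_j| <= |l'_y|
    assert abs(lprime[y]) <= loss + 1e-12          # 1 - logit_y <= -log logit_y
```

The last inequality is just $1 - s \le -\log s$ for $s \in (0, 1]$, applied to $s = \mathrm{logit}_y$.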
Lemma 3.5 (Converging Phase, Informal). Under Condition 2.2, there exists $T_2$ such that for some time $t \in [T_1, T_2]$ we have
1. The results from the feature growing phase (Lemma 3.3) hold up to constant factors.
2. The training loss is small: $L_S(\widetilde{\mathbf{W}}^{(t)}) \le \epsilon$.

Notice that the weight-noise correlation still remains reduced by a factor of $\sqrt{p}$ after training. Lemma 3.5 proves the statement on the training loss in Theorem 3.1.

Generalization Analysis. Finally, we show that pruning can purify the feature by reducing the variance of the noise by a factor of $p$ when a new sample is given. The lemma below shows that the variance of the weight-noise correlation for the trained weights is reduced by a factor of $p$.

Lemma 3.6. The neural network weight $\widetilde{\mathbf{W}}^{\star}$ after training satisfies
$$\mathbb{P}_{\boldsymbol{\xi}}\Big[\max_{j,r} \langle \widetilde{\mathbf{w}}^{\star}_{j,r}, \boldsymbol{\xi} \rangle \ge (2m)^{-2/q}\Big] \le 2Km \exp\left(-\frac{(2m)^{-4/q}}{O(\sigma_0^2 \sigma_n^2 pd)}\right).$$

Using this lemma, we can show that pruning yields the better generalization bound (i.e., the bound on the generalization loss) claimed in Theorem 3.1.
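The contrast between the average-type heuristics and the max-type statement of Lemma 3.2 can be illustrated with a small Monte Carlo experiment at initialization. This is a sketch under assumed values: the function name `init_correlations`, the sizes `m`, `d`, and the unit scales are hypothetical and are not the parameter regime of Condition 2.2. For a single class with signal $\mu \mathbf{e}_1$, it estimates the maximum weight-signal correlation over neurons and the spread of the weight-noise correlation for several keep probabilities $p$.

```python
import numpy as np

def init_correlations(p, m=1000, d=1000, sigma0=1.0, sigma_n=1.0, mu=1.0,
                      trials=10, seed=0):
    """Average max signal correlation and noise-correlation spread at initialization."""
    rng = np.random.default_rng(seed)
    max_signal, noise_std = [], []
    for _ in range(trials):
        W = sigma0 * rng.standard_normal((m, d))   # one class, m neurons
        M = rng.random((m, d)) < p                 # Bernoulli(p) mask
        Wt = W * M
        xi = sigma_n * rng.standard_normal(d)      # a fresh noise vector
        max_signal.append((Wt[:, 0] * mu).max())   # max_r <w~_r, mu e_1>
        noise_std.append((Wt @ xi).std())          # spread of <w~_r, xi> over neurons
    return float(np.mean(max_signal)), float(np.mean(noise_std))

for p in (1.0, 0.5, 0.25):
    s, nz = init_correlations(p)
    print(f"p={p:4.2f}  max signal corr={s:.2f}  noise corr std={nz:.2f}")
```

In such a run the noise spread should scale roughly like $\sqrt{p}$, while the maximum signal correlation shrinks only mildly, because it is a maximum over about $pm$ Gaussian draws rather than an average over all neurons.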

4. OVER PRUNING

Our second result shows that there exists a relatively large pruning fraction (i.e., small $p$) such that the learned model yields poor generalization, although gradient descent is still able to drive the training error toward zero. The full proof is deferred to Appendix E.

Theorem 4.1 (Main Theorem for Over Pruning, Informal). Under Condition 2.2, if $p = \Theta(\frac{1}{Km \log d})$, then with probability at least $1 - 1/\mathrm{poly}\log d$ over the randomness in the data, network initialization and pruning, there exists $T = O(\eta^{-1} n \sigma_0^{q-2} \sigma_n^{-q} (pd)^{-q/2} + \eta^{-1} \epsilon^{-1} m^4 n \sigma_n^{-2} (pd)^{-1})$ such that
1. The training loss is below $\epsilon$: $L_S(\widetilde{\mathbf{W}}^{(T)}) \le \epsilon$.
2. The generalization loss is large: $L_{\mathcal{D}}(\widetilde{\mathbf{W}}^{(T)}) \ge \Omega(\log K)$.

Remark 4.2. The above theorem indicates that in the over-pruning case, the training loss can still go to zero. However, the generalization loss of our neural network is not much better than that of random guessing, because given any sample, random guessing assigns each class probability $1/K$, which yields a generalization loss of $\log K$. The readers might wonder why the condition for this to happen is $p = \Theta(\frac{1}{Km \log d})$ instead of $O(\frac{1}{Km \log d})$. Indeed, the generalization will still be bad if $p$ is too small. However, in that case the neural network is not only unable to learn the signal but also cannot efficiently memorize the noise via gradient descent.

Proof Outline. Now we analyze the over-pruning case. We first show that there is a good chance that the model will not receive any signal after pruning, due to the sparse signal assumption and the mild overparameterization of the neural network. Then, leveraging this property, we bound the weight-signal and weight-noise correlations for the feature growing and converging phases of gradient descent, as stated in the following two lemmas, respectively. Our result indicates that the training loss can still be driven toward zero by letting the neural network memorize the noise, the proof of which further exploits the fact that high-dimensional Gaussian noise vectors are nearly orthogonal.

Lemma 4.3 (Feature Growing Phase, Informal). Under Condition 2.2, there exists $T_1$ such that
• Some weights have large correlation with the noise: $\max_r \langle \widetilde{\mathbf{w}}^{(T_1)}_{y_i,r}, \boldsymbol{\xi}_i \rangle \ge m^{-1/q}$ for all $i \in [n]$.
• The cross-class weight-noise and weight-signal correlations are small: if $j \neq y_i$, then $\max_{j,r,i} \langle \widetilde{\mathbf{w}}^{(T_1)}_{j,r}, \boldsymbol{\xi}_i \rangle = O(\sigma_0 \sigma_n \sqrt{pd})$ and $\max_{j,r,k} \langle \widetilde{\mathbf{w}}^{(T_1)}_{j,r}, \boldsymbol{\mu}_k \rangle \le O(\sigma_0 \mu)$.

Lemma 4.4 (Converging Phase, Informal). Under Condition 2.2, there exists a time $T_2$ such that for some $t \in [T_1, T_2]$, the results from phase 1 still hold (up to constant factors) and $L_S(\widetilde{\mathbf{W}}^{(t)}) \le \epsilon$.

Finally, since the above lemmas show that the network is purely memorizing the noise, we further show that such a network yields poor generalization performance, as stated in Theorem 4.1.
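The first step of the outline, and the $\log K$ baseline of Remark 4.2, can be illustrated with two short calculations. They are a sketch under the assumption $p = \frac{c}{Km \log d}$ for a hypothetical constant $c > 0$, not the argument of Appendix E. Neuron $(j, r)$ receives the signal of class $j$ only if $(\mathbf{m}_{j,r})_j = 1$, so a union bound over all $Km$ neurons gives

```latex
\begin{align*}
\Pr\big[\exists\, j \in [K],\, r \in [m] : (\mathbf{m}_{j,r})_j = 1\big]
  &\le \sum_{j=1}^{K} \sum_{r=1}^{m} \Pr\big[(\mathbf{m}_{j,r})_j = 1\big] = Kmp = \frac{c}{\log d}, \\
\ell_{\mathrm{uniform}} &= -\log \frac{1}{K} = \log K.
\end{align*}
```

Hence, except with probability $O(1/\log d)$, no neuron sees the signal of its own class, and a predictor that ends up assigning probability $1/K$ to every class incurs cross-entropy loss exactly $\log K$.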

5.1. SIMULATIONS TO VERIFY OUR RESULTS

In this section, we conduct simulations to verify our results. We conduct our experiments on a binary classification task and show that our result holds for ReLU networks. Our experiment settings are as follows: we choose the input to be $\mathbf{x} = [\mathbf{x}_1, \mathbf{x}_2] = [y\mathbf{e}_1, \boldsymbol{\xi}] \in \mathbb{R}^{800}$ with $\mathbf{x}_1, \mathbf{x}_2 \in \mathbb{R}^{400}$.

The observations are summarized as follows. In Figure 2a, when the noise level is $\sigma_n = 0.5$, the pruned network usually performs at a level similar to the full model when $p \le 0.5$ and noticeably better when $p = 0.3$. When $p > 0.5$, the test error increases dramatically while the training accuracy still remains perfect. On the other hand, when the noise level becomes large, $\sigma_n = 1$ (Figure 2b), the full model can no longer achieve good testing performance but mild pruning can improve the model's generalization. Note that the training accuracy in this case is still perfect (omitted in the figure). We observe that in both settings, when the model test error is large, the variance is also large. However, in Figure 2b, despite the large variance, the mean curve is already smooth. In particular, Figure 2c plots the testing error over the training iterations under a $p = 0.5$ pruning rate. This suggests that pruning can be beneficial even when the input noise is large.

5.2. ON THE REAL WORLD DATASET

To further demonstrate the mild/over pruning phenomenon, we conduct experiments on the MNIST (Deng, 2012) and CIFAR-10 (Krizhevsky et al., 2009) datasets. We consider neural network architectures including an MLP with 2 hidden layers of width 1024, VGG, ResNets (He et al., 2016) and wide ResNet (Zagoruyko & Komodakis, 2016). In addition to random pruning, we also add iterative-magnitude-based pruning (Frankle & Carbin, 2018) to our experiments. Both pruning methods are prune-at-initialization methods. Our implementation is based on Chen et al. (2021c). In the real-world setting, we do not expect our theorem to hold exactly. Instead, our theorem implies that (1) there exists a threshold below which the testing performance is not much worse than (or sometimes slightly better than) that of the dense counterpart; and (2) the training error decreases later than the testing error decreases. Our experiments on MLP (Figure 3a) and VGG-16 (Figure 3b) show that this is the case: the test accuracy stays competitive with that of the dense counterpart when the sparsity is less than 79% for MLP and less than 36% for VGG-16. We further provide experiments on ResNet in the appendix to validate our theoretical results.

6. DISCUSSION AND FUTURE DIRECTION

In this work, we provide theory on the generalization performance of pruned neural networks trained by gradient descent under different pruning rates. Our results characterize the effect of pruning under different pruning rates: in the mild pruning case, the signal in the feature is well-preserved and the noise level is reduced, which leads to an improvement in the trained network's generalization; on the other hand, over pruning significantly destroys the signal strength despite reducing the noise variance. One open problem on this topic still appears challenging. In this paper, we characterize two cases of pruning: in mild pruning the signal is preserved, and in over pruning the signal is completely destroyed. However, the transition between these two cases is not well understood. Further, it would be interesting to consider more general data distributions, and to understand how pruning affects the training of multi-layer neural networks. We leave these interesting directions as future work.

C PRELIMINARY FOR ANALYSIS

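In this section, we introduce the signal-noise decomposition of each neuron weight from Cao et al. (2022), together with some properties of the terms in this decomposition that are useful in our analysis.

Definition C.1 (signal-noise decomposition). For each neuron weight $w^{(t)}_{j,r}$ with $j \in [K]$, $r \in [m]$, there exist coefficients $\gamma^{(t)}_{j,r,k}$, $\zeta^{(t)}_{j,r,i}$, $\omega^{(t)}_{j,r,i}$ such that

$$w^{(t)}_{j,r} = w^{(0)}_{j,r} + \sum_{k=1}^{K} \gamma^{(t)}_{j,r,k}\, \|\mu_k\|_2^{-2}\, \mu_k \odot m_{j,r} + \sum_{i=1}^{n} \zeta^{(t)}_{j,r,i}\, \|\xi_{j,r,i}\|_2^{-2}\, \xi_{j,r,i} + \sum_{i=1}^{n} \omega^{(t)}_{j,r,i}\, \|\xi_{j,r,i}\|_2^{-2}\, \xi_{j,r,i},$$

where $\gamma^{(t)}_{j,r,j} \ge 0$, $\gamma^{(t)}_{j,r,k} \le 0$ for $k \ne j$, $\zeta^{(t)}_{j,r,i} \ge 0$, and $\omega^{(t)}_{j,r,i} \le 0$.

It is straightforward to see the following: $\gamma^{(0)}_{j,r,k} = \zeta^{(0)}_{j,r,i} = \omega^{(0)}_{j,r,i} = 0$, and

$$\begin{aligned}
\gamma^{(t+1)}_{j,r,j} &= \gamma^{(t)}_{j,r,j} - I(r \in S^{j}_{\mathrm{signal}})\, \frac{\eta}{n} \sum_{i=1}^{n} \ell'^{(t)}_{j,i}\, \sigma'\big(\langle w^{(t)}_{j,r}, \mu_{y_i} \rangle\big)\, \|\mu_{y_i}\|_2^2\, I(y_i = j), \\
\gamma^{(t+1)}_{j,r,k} &= \gamma^{(t)}_{j,r,k} - I((m_{j,r})_k = 1)\, \frac{\eta}{n} \sum_{i=1}^{n} \ell'^{(t)}_{j,i}\, \sigma'\big(\langle w^{(t)}_{j,r}, \mu_{y_i} \rangle\big)\, \|\mu_{y_i}\|_2^2\, I(y_i = k), \quad \forall j \ne k, \\
\zeta^{(t+1)}_{j,r,i} &= \zeta^{(t)}_{j,r,i} - \frac{\eta}{n}\, \ell'^{(t)}_{j,i}\, \sigma'\big(\langle w^{(t)}_{j,r}, \xi_i \rangle\big)\, \|\xi_{j,r,i}\|_2^2\, I(j = y_i), \\
\omega^{(t+1)}_{j,r,i} &= \omega^{(t)}_{j,r,i} - \frac{\eta}{n}\, \ell'^{(t)}_{j,i}\, \sigma'\big(\langle w^{(t)}_{j,r}, \xi_i \rangle\big)\, \|\xi_{j,r,i}\|_2^2\, I(j \ne y_i),
\end{aligned}$$

where $\{\gamma^{(t)}_{j,r,j}\}_{t=1}^{T}$ and $\{\zeta^{(t)}_{j,r,i}\}_{t=1}^{T}$ are increasing sequences and $\{\gamma^{(t)}_{j,r,k}\}_{t=1}^{T}$ and $\{\omega^{(t)}_{j,r,i}\}_{t=1}^{T}$ are decreasing sequences, because $-\ell'^{(t)}_{j,i} \ge 0$ when $j = y_i$, and $-\ell'^{(t)}_{j,i} \le 0$ when $j \ne y_i$. By Lemma D.4, we have $pd > n + K$, and hence the set of vectors $\{\mu_k\}_{k=1}^{K} \cup \{\xi_i\}_{i=1}^{n}$ is linearly independent with probability 1 over the Gaussian distribution for each $j \in [K]$, $r \in [m]$. Therefore the decomposition is unique.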
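The sign pattern behind these monotonicity claims can be checked numerically. The sketch below assumes the cross-entropy loss with softmax, for which the derivative of the loss with respect to the j-th network output is softmax_j(F) minus the indicator that j is the true class; the helper names and the example outputs are hypothetical.

```python
import math


def softmax(outputs):
    """Numerically stable softmax over a list of network outputs F_1..F_K."""
    shift = max(outputs)
    exps = [math.exp(f - shift) for f in outputs]
    total = sum(exps)
    return [e / total for e in exps]


def loss_derivatives(outputs, label):
    """Derivative of the softmax cross-entropy loss w.r.t. each output F_j."""
    probs = softmax(outputs)
    return [p - (1.0 if j == label else 0.0) for j, p in enumerate(probs)]


outputs = [0.3, 1.2, -0.5, 0.0]  # hypothetical outputs F_j for K = 4 classes
label = 1                        # hypothetical true class y_i
derivs = loss_derivatives(outputs, label)
for j, d in enumerate(derivs):
    role = "true class" if j == label else "other class"
    print(f"j={j} ({role}): -l'_j = {-d:+.4f}")
```

Since the activation derivative and the squared norms in the update rules are non-negative, a non-negative value of the negated derivative for the true class means the coefficients for $j = y_i$ can only grow, while the non-positive values for the other classes mean the remaining coefficients can only shrink.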

D PROOF OF THEOREM 3.1

We first formally restate Theorem 3.1.

Theorem D.1 (Formal Restatement of Theorem 3.1). Under Condition 2.2, choose initialization variance $\sigma_0 = \Theta(m^{-4} n^{-1} \mu^{-1})$ and learning rate $\eta \le O(1/\mu^2)$. For $\epsilon > 0$, if $p \ge C_1 \frac{\log d}{m}$ for some sufficiently large constant $C_1$, then with probability at least $1 - O(d^{-1})$ over the randomness in the data, network initialization and pruning, there exists $T = O(K \eta^{-1} \sigma_0^{2-q} \mu^{-q} + K^2 m^4 \mu^{-2} \eta^{-1} \epsilon^{-1})$ such that the following holds:

1. The training loss is below $\epsilon$: $L_S(W^{(T)}) \le \epsilon$.
2. The weights of the CNN are highly correlated with their corresponding class signal: $\max_r \gamma^{(T)}_{j,r,j} \ge \Omega(m^{-1/q})$ for all $j \in [K]$.
3. The weights of the CNN are not highly correlated with the signals from different classes: $\max_{j \ne k,\, r \in [m]} |\gamma^{(T)}_{j,r,k}| \le O(\sigma_0 \mu)$.
4. None of the weights is highly correlated with the noise: $\max_{j,r,i} \zeta^{(T)}_{j,r,i} = O(\sigma_0 \sigma_n \sqrt{pd})$ and $\max_{j,r,i} |\omega^{(T)}_{j,r,i}| = O(\sigma_0 \sigma_n \sqrt{pd})$.

Moreover, the testing loss is upper-bounded by $L_D(W^{(T)}) \le O(K\epsilon) + \exp(-n^2/p)$.

The proof of Theorem 3.1 analyzes the effect of pruning on the signal and the noise in three stages of gradient descent (initialization, the feature growing phase, and the converging phase) and then establishes the generalization property. We present these analyses in detail in the following subsections. Note that the constant $C$ appearing in the proofs of the subsequent lemmas is defined locally rather than globally: $C$ is the same within each lemma but may differ across lemmas.

D.1 INITIALIZATION

We analyze the effect of pruning on the weight-signal correlation and the weight-noise correlation at initialization. We first present a few supporting lemmas, and then provide our main result, Lemma D.7, which shows that if the pruning is mild, then it does not hurt the maximum weight-signal correlation much at initialization. On the other hand, the maximum weight-noise correlation is reduced by a factor of $\sqrt{p}$.

Lemma D.2. Assume $n = \Omega(K^2 \log Kd)$. Then, with probability at least $1 - 1/d$, $|\{i \in [n] : y_i = j\}| = \Theta(n/K)$ for all $j \in [K]$.

Proof. By Hoeffding's inequality, with probability at least $1 - \delta/2K$, for a fixed $j \in [K]$, we have

$$\left| \frac{1}{n} \sum_{i=1}^{n} I(y_i = j) - \frac{1}{K} \right| \le \sqrt{\frac{\log(4K/\delta)}{2n}}.$$

Therefore, as long as $n \ge 2K^2 \log(4K/\delta)$, we have

$$\left| \frac{1}{n} \sum_{i=1}^{n} I(y_i = j) - \frac{1}{K} \right| \le \frac{1}{2K}.$$

Taking a union bound over $j \in [K]$ and setting $\delta = 1/d$ yields the result.

Lemma D.3. Assume $pm = \Omega(\log d)$. Then, with probability at least $1 - 1/d$, $\sum_{r=1}^{m} (m_{j,r})_k = \Theta(pm)$ for all $j, k \in [K]$.

Proof. When $pm = \Omega(\log d)$, by the multiplicative Chernoff bound, for given $j, k \in [K]$, we have

$$P\left[ \left| \sum_{r=1}^{m} (m_{j,r})_k - pm \right| \ge 0.5pm \right] \le 2 \exp\{-\Omega(pm)\}.$$

Taking a union bound over $j \in [K]$, $k \in [K]$, we have

$$P\left[ \exists j, k \in [K] : \left| \sum_{r=1}^{m} (m_{j,r})_k - pm \right| \ge 0.5pm \right] \le 2K^2 \exp\{-\Omega(pm)\} \le 1/d.$$

Lemma D.4. Assume $p = 1/\mathrm{poly}\log d$. Then with probability at least $1 - 1/d$, for all $j \in [K]$, $r \in [m]$, $\sum_{i=1}^{d} (m_{j,r})_i = \Theta(pd)$.

Proof. By the multiplicative Chernoff bound, for given $j, r$ we have

$$P\left[ \left| \sum_{i=1}^{d} (m_{j,r})_i - pd \right| \ge 0.5pd \right] \le 2 \exp\{-\Omega(pd)\}.$$

Taking a union bound over $j, r$, we have

$$P\left[ \exists j \in [K], r \in [m] : \left| \sum_{i=1}^{d} (m_{j,r})_i - pd \right| \ge 0.5pd \right] \le 2Km \exp\{-\Omega(pd)\} \le 1/d,$$

where the last inequality follows from our choices of $p, K, m, d$.

Lemma D.5. Suppose $p = \Omega(1/\mathrm{poly}\log d)$ and $m, n = \mathrm{poly}\log d$. With probability at least $1 - 1/d$, we have

$$\|\xi_{j,r,i}\|_2^2 = \Theta(\sigma_n^2 pd), \qquad |\langle \xi_{j,r,i}, \xi_{i'} \rangle| \le O(\sigma_n^2 \sqrt{pd \log d}), \qquad |\langle \mu_k, \xi_{j,r,i} \rangle| \le |\langle \mu_k, \xi_i \rangle| \le O(\sigma_n \mu \sqrt{\log d}),$$

for all $j \in [K]$, $r \in [m]$, $i, i' \in [n]$ and $i \ne i'$.

Proof. From Lemma D.4, with probability at least $1 - 1/d$, we have $\sum_{k=1}^{d} (m_{j,r})_k = \Theta(pd)$ for all $j \in [K]$, $r \in [m]$. For a set of Gaussian random variables $g_1, \ldots, g_N \sim N(0, \sigma^2)$, by Bernstein's inequality, with probability at least $1 - \delta$, we have

$$\left| \sum_{i=1}^{N} g_i^2 - \sigma^2 N \right| \lesssim \sigma^2 \sqrt{N \log \frac{1}{\delta}}.$$

Thus, by a union bound over $j, r, i$, with probability at least $1 - 1/d$, we have $\|\xi_{j,r,i}\|_2^2 = \Theta(\sigma_n^2 pd)$. For $i \ne i'$, again by Bernstein's inequality, with probability at least $1 - \delta$,

$$|\langle \xi_{j,r,i}, \xi_{i'} \rangle| \le O\left( \sigma_n^2 \sqrt{pd \log \frac{Kmn}{\delta}} \right)$$

for all $j, r, i$. Plugging in $\delta = 1/d$ gives the result. The proof for $|\langle \mu_k, \xi_i \rangle|$ is similar.

Lemma D.6. Suppose we have $m$ independent Gaussian random variables $g_1, g_2, \ldots, g_m \sim N(0, \sigma^2)$. Then with probability $1 - \delta$,

$$\max_i g_i \ge \sigma \sqrt{\log \frac{m}{\log(1/\delta)}}.$$

Proof. By the standard tail bound for a Gaussian random variable, for every $x > 0$,

$$\left( \frac{\sigma}{x} - \frac{\sigma^3}{x^3} \right) \frac{e^{-x^2/2\sigma^2}}{\sqrt{2\pi}} \le P[g > x] \le \frac{\sigma}{x} \frac{e^{-x^2/2\sigma^2}}{\sqrt{2\pi}}.$$

We want to pick an $x^\star$ such that

$$P\left[ \max_i g_i \le x^\star \right] = (P[g_i \le x^\star])^m = (1 - P[g_i \ge x^\star])^m \le e^{-m P[g_i \ge x^\star]} \le \delta,$$

which requires $P[g_i \ge x^\star] = \Theta\left( \frac{\log(1/\delta)}{m} \right)$ and hence $x^\star = \Theta\left( \sigma \sqrt{\log\big( m / (\log(1/\delta) \log m) \big)} \right)$.

Lemma D.7 (Formal Restatement of Lemma 3.2). With probability at least $1 - 2/d$, for all $i \in [n]$,

$$\sigma_0 \sigma_n \sqrt{pd} \le \max_r \langle w^{(0)}_{j,r}, \xi_i \rangle \le \sqrt{2 \log(Kmd)}\, \sigma_0 \sigma_n \sqrt{pd}.$$

Further, suppose $pm \ge \Omega(\log(Kd))$. Then with probability $1 - 2/d$, for all $j \in [K]$,

$$\sigma_0 \|\mu_j\|_2 \le \max_{r \in S^{j}_{\mathrm{signal}}} \langle w^{(0)}_{j,r}, \mu_j \rangle \le \sqrt{2 \log(8pmKd)}\, \sigma_0 \|\mu_j\|_2.$$

Proof. We first prove the second inequality. From Lemma D.3, we know that $|S^{j}_{\mathrm{signal}}| = \Theta(pm)$. The upper bound can be obtained by taking a union bound over $r \in S^{j}_{\mathrm{signal}}$, $j \in [K]$. To prove the lower bound, we apply Lemma D.6: with probability at least $1 - \delta/K$, for a given $j \in [K]$,

$$\max_{r \in S^{j}_{\mathrm{signal}}} \langle w^{(0)}_{j,r}, \mu_j \rangle \ge \sigma_0 \|\mu_j\|_2 \sqrt{\log \frac{pm}{\log(K/\delta)}}.$$

Now, notice that we can control the constant in $pm$ (by controlling the constant in the lower bound of $p$) such that $pm / \log(Kd) \ge e$. Thus, taking a union bound over $j \in [K]$ and setting $\delta = 1/d$ yields the result. The proof of the first inequality is similar.
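The $\sqrt{p}$ reduction of the weight-noise correlation at initialization can be illustrated with a small Monte Carlo experiment. The sketch below is only an illustration under assumed values: the dimension, the variances, the keep probability `keep_prob` (playing the role of $p$) and the function names are all hypothetical, and it compares the standard deviation of a single correlation rather than the maximum over neurons.

```python
import math
import random


def masked_inner_product(d, keep_prob, sigma_0, sigma_n, rng):
    """Sample <w * m, xi> with w ~ N(0, sigma_0^2 I), xi ~ N(0, sigma_n^2 I),
    and an i.i.d. Bernoulli(keep_prob) mask m over the d coordinates."""
    total = 0.0
    for _ in range(d):
        if rng.random() < keep_prob:
            total += rng.gauss(0.0, sigma_0) * rng.gauss(0.0, sigma_n)
    return total


def empirical_std(samples):
    """Population-style standard deviation of a list of samples."""
    mean = sum(samples) / len(samples)
    return math.sqrt(sum((s - mean) ** 2 for s in samples) / len(samples))


rng = random.Random(0)
d, sigma_0, sigma_n, trials = 256, 0.1, 1.0, 1000
keep_prob = 0.25

dense = [masked_inner_product(d, 1.0, sigma_0, sigma_n, rng) for _ in range(trials)]
pruned = [masked_inner_product(d, keep_prob, sigma_0, sigma_n, rng) for _ in range(trials)]

ratio = empirical_std(pruned) / empirical_std(dense)
predicted = sigma_0 * sigma_n * math.sqrt(keep_prob * d)
print(f"std ratio pruned/dense: {ratio:.3f} (sqrt(p) = {math.sqrt(keep_prob):.3f})")
print(f"pruned std: {empirical_std(pruned):.3f} (sigma_0 sigma_n sqrt(pd) = {predicted:.3f})")
```

The empirical ratio should be close to $\sqrt{p}$, and the pruned standard deviation should be close to the scale $\sigma_0 \sigma_n \sqrt{pd}$ that appears in Lemma D.7.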

D.2 SUPPORTING PROPERTIES FOR ENTIRE TRAINING PROCESS

This subsection establishes a few properties (summarized in Proposition D.10) that will be used in the analysis of the feature growing phase and the converging phase of gradient descent presented in the next two subsections. Define $T^\star = \eta^{-1}\, \mathrm{poly}(1/\epsilon, \mu, d^{-1}, \sigma_n^{-2}, \sigma_0^{-1} n, m, d)$. Denote $\alpha = \Theta(\log^{1/q}(T^\star))$ and $\beta = 2 \max_{i,j,r,k} \{ |\langle w^{(0)}_{j,r}, \mu_k \rangle|, |\langle w^{(0)}_{j,r}, \xi_i \rangle| \}$. We need the following bound to hold for our subsequent analysis:

$$4 m^{1/q} \max_{j,r,i} \left\{ \langle w^{(0)}_{j,r}, \mu_{y_i} \rangle,\; Cn\alpha \frac{\mu \sqrt{\log d}}{\sigma_n pd},\; \langle w^{(0)}_{j,r}, \xi_i \rangle,\; 3Cn\alpha \sqrt{\frac{\log d}{pd}} \right\} \le 1. \qquad (1)$$

Under Condition 2.2, $\mu = \Theta(\sigma_n \sqrt{d} \log d) = \Theta(1)$. In both mild pruning and over pruning we require $p \ge \Omega(1/\mathrm{poly}\log d)$. Since $\alpha = \Theta(\log^{1/q}(T^\star))$, if we assume $T^\star \le O(\mathrm{poly}(d))$ for the moment (which we justify in the next paragraph), then $\alpha = O(\log^{1/q}(d))$. Then, for $d$ large enough, we have

$$4 m^{1/q}\, Cn\alpha \frac{\mu \sqrt{\log d}}{\sigma_n pd} \le \frac{\mathrm{poly}\log d}{\sqrt{d}} \le 1.$$

Finally, for the quantity $4 m^{1/q} \max_{j,r,i} \{ \langle w^{(0)}_{j,r}, \mu_{y_i} \rangle, \langle w^{(0)}_{j,r}, \xi_i \rangle \}$, by Lemma 3.2, our assumption of $K = O(\log d)$ in Condition 2.2 and our choice of $\sigma_0 = \Theta(m^{-4} n^{-1} \mu^{-1})$ in Theorem 3.1 (or Theorem D.1), this quantity can also be made smaller than 1.

Now, to justify that $T^\star \le O(\mathrm{poly}(d))$, we only need to show that all the quantities $T^\star$ depends on are polynomial in $d$. First, based on Condition 2.2, $n, m = \mathrm{poly}\log(d)$, and $\mu = \Theta(\sigma_n \sqrt{d} \log d) = \Theta(1)$ further implies $\sigma_n^{-2} = \Theta(d \log^2 d)$. Since Theorem 3.1 only requires $\sigma_0 = \Theta(m^{-4} n^{-1} \mu^{-1})$, this implies $\sigma_0^{-1} \le O(\mathrm{poly}\log d)$. Hence $\sigma_0^{-1} n = O(\mathrm{poly}\log d)$. Together with our assumption that $\epsilon, \eta \ge \Omega(1/\mathrm{poly}(d))$ (which implies $1/\epsilon, 1/\eta \le O(\mathrm{poly}(d))$), all terms involved in $T^\star$ are at most of order $\mathrm{poly}(d)$. Hence $T^\star = \mathrm{poly}(d)$.

Remark D.9. Here we comment on our assumptions on $\epsilon$ and $\eta$ in Condition 2.2. For $\epsilon$: the cross-entropy loss is (1) not strongly convex and (2) attains its infimum only at infinity, so in practice it is minimized to a constant level, say 0.001. We make this assumption to avoid the pathological case where $\epsilon$ is exponentially small in $d$ (say $\epsilon = 2^{-d}$), which is unrealistic. Thus, for a realistic setting, we assume $\epsilon \ge \Omega(1/\mathrm{poly}(d))$, or equivalently $1/\epsilon \le O(\mathrm{poly}(d))$. For $\eta$: the only restriction we have is $\eta = O(1/\mu^2)$ in Theorem 3.1 and Theorem 4.1. However, in practice, we do not use a learning rate that is exponentially small, say $\eta = 2^{-d}$. Thus, as with $\epsilon$, we assume $\eta \ge \Omega(1/\mathrm{poly}(d))$, or equivalently $1/\eta \le O(\mathrm{poly}(d))$. We make the above assumption to simplify the analysis of the magnitude of $F_j(X)$ for $j \ne y$ given a sample $(X, y)$.

Proposition D.10. Under Condition 2.2, during the training time $t < T^\star$, we have

1. $\gamma^{(t)}_{j,r,j}, \zeta^{(t)}_{j,r,i} \le \alpha$;
2. $\omega^{(t)}_{j,r,i} \ge -\beta - 6Cn\alpha \sqrt{\frac{\log d}{pd}}$;
3. $\gamma^{(t)}_{j,r,k} \ge -\beta - 2Cn\alpha \frac{\mu \sqrt{\log d}}{\sigma_n pd}$.

Notice that the lower bounds have absolute value smaller than the upper bound.

Proof of Proposition D.10. We use induction. Induction hypothesis: suppose Proposition D.10 holds for all $t < T \le T^\star$. We next show that it also holds for $t = T$ via the following lemmas.

Lemma D.11. Under Condition 2.2, for $t < T$, there exists a constant $C$ such that

$$\begin{aligned}
\langle w^{(t)}_{j,r} - w^{(0)}_{j,r}, \mu_k \rangle &= \left( \gamma^{(t)}_{j,r,k} \pm Cn\alpha \frac{\mu \sqrt{\log d}}{\sigma_n pd} \right) I((m_{j,r})_k = 1), \\
\langle w^{(t)}_{j,r} - w^{(0)}_{j,r}, \xi_i \rangle &= \zeta^{(t)}_{j,r,i} \pm 3Cn\alpha \sqrt{\frac{\log d}{pd}} \quad (j = y_i), \\
\langle w^{(t)}_{j,r} - w^{(0)}_{j,r}, \xi_i \rangle &= \omega^{(t)}_{j,r,i} \pm 3Cn\alpha \sqrt{\frac{\log d}{pd}} \quad (j \ne y_i).
\end{aligned}$$

Proof. From Lemma D.5, there exists a constant $C$ such that with probability at least $1 - 1/d$,

$$\frac{|\langle \xi_{j,r,i}, \xi_{i'} \rangle|}{\|\xi_{j,r,i}\|_2^2} \le C \sqrt{\frac{\log d}{pd}}, \qquad \frac{|\langle \xi_{j,r,i}, \mu_k \rangle|}{\|\xi_{j,r,i}\|_2^2} \le C \frac{\mu \sqrt{\log d}}{\sigma_n pd}, \qquad \frac{|\langle \mu_k, \xi_i \rangle|}{\|\mu_k\|_2^2} \le C \frac{\sigma_n \sqrt{\log d}}{\mu}.$$

Using the signal-noise decomposition and assuming $(m_{j,r})_k = 1$, we have

$$\begin{aligned}
\left| \langle w^{(t)}_{j,r} - w^{(0)}_{j,r}, \mu_k \rangle - \gamma^{(t)}_{j,r,k} \right| &= \left| \sum_{i=1}^{n} \zeta^{(t)}_{j,r,i} \|\xi_{j,r,i}\|_2^{-2} \langle \xi_{j,r,i}, \mu_k \rangle + \sum_{i=1}^{n} \omega^{(t)}_{j,r,i} \|\xi_{j,r,i}\|_2^{-2} \langle \xi_{j,r,i}, \mu_k \rangle \right| \\
&\le C \frac{\mu \sqrt{\log d}}{\sigma_n pd} \sum_{i=1}^{n} |\zeta^{(t)}_{j,r,i}| + C \frac{\mu \sqrt{\log d}}{\sigma_n pd} \sum_{i=1}^{n} |\omega^{(t)}_{j,r,i}| \\
&\le 2C \frac{\mu \sqrt{\log d}}{\sigma_n pd}\, n\alpha,
\end{aligned}$$
where the second last inequality is by Lemma D.5 and the last inequality is by induction hypothesis. To prove the second equality, for j = y i , w (t) j,r -w (0) j,r , ξ i -ζ (t) j,r,i = K k=1 γ (t) j,r,k • ⟨µ k , ξ i ⟩ ∥µ k ∥ 2 2 + i ′ ̸ =i ζ (t) j,r,i ′ • ξ j,r,i ′ , ξ i ξ j,r,i ′ 2 2 + n i ′ =1 ω (t) j,r,i ′ ξ j,r,i ′ , ξ i ξ j,r,i ′ 2 2 ≤ C σ n √ log d µ K k=1 |γ (t) j,r,k | + C log d pd i ′ ̸ =i |ζ (t) j,r,i ′ | + C log d pd n i ′ =1 |ω (t) j,r,i ′ | = C σ n √ log d µ Kα + 2Cnα log d pd ≤ 3Cnα log d pd . where the last inequality is by n ≫ K and µ = Θ(σ n √ d log d). The proof for the case of j ̸ = y i is similar. Lemma D.12 (Off-diagonal Correlation Upper Bound). Under Condition 2.2, for t < T , j ̸ = y i , we have that w (t) j,r , µ yi ≤ w (0) j,r , µ yi + Cnα µ √ log d σ n pd , w (t) j,r , ξ i ≤ w (0) j,r , ξ i + 3Cnα log d pd , F j ( W (t) j , x i ) ≤ 1. Proof. If j ̸ = y i , then γ (t) j,r,k ≤ 0 and we have that w (t) j,r , µ yi ≤ w (0) j,r , µ yi + γ (t) j,r,yi + Cnα µ √ log d σ n pd I((m j,r ) yi = 1) ≤ w (0) j,r , µ yi + Cnα µ √ log d σ n pd . Further, we can obtain w (t) j,r , ξ i ≤ w (0) j,r , ξ i + ω (t) j,r,i + 3Cnα log d pd ≤ w (0) j,r , ξ i + 3Cnα log d pd . Then, we have the following bound: F j ( W (t) j , x i ) = m r=1 [σ(⟨ w j,r , µ yi ⟩) + σ(⟨ w j,r , ξ i ⟩)] ≤ m2 q+1 max j,r,i w (0) j,r , µ yi , Cnα µ √ log d σ n pd , w (0) j,r , ξ i , 3Cnα log d pd q ≤ 1. where the first inequality is by Equation (1). Lemma D.13 (Diagonal Correlation Upper Bound). Under Condition 2.2, for t < T, j = y i , we have w (t) j,r , µ j ≤ w (0) j,r , µ j + γ (t) j,r,j + Cnα µ √ log d σ n pd , w (t) j,r , ξ i ≤ w (0) j,r , ξ i + ζ (t) j,r,i + 3Cnα log d pd . If max{γ (t) j,r,j , ζ (t) j,r,i } ≤ m -1/q , we further have that F j ( W (t) j , x i ) ≤ O(1). Proof. The two inequalities are immediate consequences of Lemma D.11. 
If max{γ (t) j,r,j , ζ (t) j,r,i } ≤ m -1/q , we have F j ( W (t) j , x i ) = m r=1 [σ(⟨ w j,r , µ j ⟩) + σ(⟨ w j,r , ξ i ⟩)] ≤ 2 • 3 q m max j,r,i γ (t) j,r , ζ (t) j,r,i , w (0) j,r , µ j , w (0) j,r , ξ i , Cnα µ √ log d σ n pd , 3Cnα log d pd q ≤ O(1). Lemma D.14. Under Condition 2.2, for t ≤ T , we have that 1. ω (t) j,r,i ≥ -β -6Cnα log d pd ; 2. γ (t) j,r,k ≥ -β -2Cnα µ √ log d σnpd . Proof. When j = y i , we have ω (t) j,r,i = 0. We only need to consider the case of j ̸ = y i . When ω (T -1) j,r,i ≤ -0.5β -3Cnα log d pd , by Lemma D.11 we have w (T -1) j,r , ξ i ≤ w (0) j,r , ξ i + ω (T -1) j,r,i + 3Cnα log d pd ≤ 0. Thus, ω (T ) j,r,i = ω (T -1) j,r,i - η n • ℓ ′(T -1) j,i • σ ′ w (T -1) j,r , ξ i ξ j,r,i 2 2 I(j ̸ = y i ) = ω (T -1) j,r,i ≥ -β -6Cnα log d pd . When ω (T -1) j,r,i ≥ -0.5β -3Cnα log d pd , we have ω (T ) j,r,i = ω (T -1) j,r,i - η n • ℓ ′(T -1) j,i • σ ′ w (T -1) j,r , ξ i ξ j,r,i 2 2 I(j ̸ = y i ) ≥ -0.5β -3Cnα log d pd - η n σ ′ 0.5β + 3Cnα log d pd ξ j,r,i 2 2 ≥ -β -6Cnα log d pd , where the last inequality is by setting η ≤ nq -1 0.5β + 3Cnα log d pd 2-q (C 2 σ 2 n d) -1 and C 2 is the constant such that ξ j,r,i 2 2 ≤ C 2 σ 2 n pd for all j, r, i in Lemma D.5. For γ (t) j,r,k , the proof is similar. Consider I((m j,r ) k ) = 1. When γ (t) j,r,k ≤ -0.5β -Cnα µ √ log d σnpd , by Lemma D.11, we have w (t) j,r , µ k ≤ w (0) j,r , µ k + γ (t) j,r,k + Cnα µ √ log d σ n pd ≤ 0. Hence, γ (T ) j,r,k = γ (T -1) j,r,k - η n n i=1 ℓ ′(T -1) j,i σ ′ w (T -1) j,r , µ k µ 2 I(y i = k) = γ (T -1) j,r,k ≥ -β -2Cnα µ √ log d σ n pd . 
When γ (t) j,r,k ≥ -0.5β -Cnα µ √ log d σnpd , we have γ (T ) j,r,k = γ (T -1) j,r,k - η n n i=1 ℓ ′(T -1) j,i σ ′ w (T -1) j,r , µ k µ 2 I(y i = k) ≥ -0.5β -Cnα µ √ log d σ n pd -C 2 η K σ ′ 0.5β + Cnα µ √ log d σ n pd µ 2 ≥ -β -2Cnα µ √ log d σ n pd , where the first inequality follows from the fact that there are Θ( n K ) samples such that I(y i = k), and the last inequality follows from picking η ≤ K(0.5β + Cnα µ √ log d σnpd ) 2-q µ -2 q -1 C -1 2 . Lemma D.15. Under Condition 2.2, for t ≤ T , we have γ (t) j,r,j , ζ (t) j,r,i ≤ α. Proof. For y i ̸ = j or r / ∈ S j signal , γ (t) j,r,j , ζ j,r,i = 0 ≤ α. If y i = j, then by Lemma D.12 we have ℓ ′(t) j,i = 1 -logit j (F ; X) = i̸ =j e Fi(X) K i=1 e Fi(X) ≤ Ke e Fj (X) . (2) Recall that γ (t+1) j,r,j = γ (t) j,r,j -I(r ∈ S j signal ) η n n i=1 ℓ ′(t) j,i • σ ′ w (t) j,r , µ yi ∥µ yi ∥ 2 2 I(y i = j), ζ (t+1) j,r,i = ζ (t) j,r,i - η n • ℓ ′(t) j,i • σ ′ w (t) j,r , ξ i ξ j,r,i 2 2 I(j = y i ). We first bound ζ (T ) j,r,i . Let T j,r,i be the last time t < T that ζ (t) j,r,i ≤ 0.5α. Then we have ζ (T ) j,r,i = ζ (Tj,r,i) j,r,i - η n ℓ ′(Tj,r,i) i • σ ′ w (Tj,r,i) j,r , ξ i I(y i = j) ξ j,r,i 2 2 I1 - Tj,r,i<t<T η n ℓ ′(Tj,r,i) j,i σ ′ w (t) j,r , ξ i I(y i = j) ξ j,r,i 2 2 I2 . We bound I 1 , I 2 separately. We first bound I 1 as follows. |I 1 | ≤ q η n ζ (Tj,r,i) j,r,i + 0.5β + 3Cnα log d pd q-1 C 2 σ 2 n pd ≤ q2 q n -1 ηα q-1 C 2 σ 2 n pd ≤ 0.25α, where the first inequality follows from Lemma D.13, the second inequality follows because β ≤ 0.1α and 3Cnα log d pd ≤ 0.1α, and the last inequality follows because η ≤ n/(q2 q+2 α q-2 σ 2 n d). For T j,r,i < t < T , by Lemma D.11, we have that w (t) j,r , ξ i ≥ 0.5α -0.5β -3Cnα log d pd ≥ 0.25α and w (t) j,r , ξ i ≤ α + 0.5β + 3Cnα log d pd ≤ 2α. 
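Now we bound $I_2$ as follows:

$$\begin{aligned}
|I_2| &\le \sum_{T_{j,r,i} < t < T} \frac{\eta}{n} Ke \exp\{-F_j(X)\}\, \sigma'\big(\langle w^{(t)}_{j,r}, \xi_i \rangle\big)\, I(y_i = j)\, \|\xi_{j,r,i}\|_2^2 \\
&\le \sum_{T_{j,r,i} < t < T} \frac{\eta}{n} Ke \exp\big\{-\sigma\big(\langle w^{(t)}_{j,r}, \xi_i \rangle\big)\big\}\, \sigma'\big(\langle w^{(t)}_{j,r}, \xi_i \rangle\big)\, I(y_i = j)\, \|\xi_{j,r,i}\|_2^2 \\
&\le \frac{qKe\eta 2^{q-1} T^\star}{n} \exp(-\alpha^q / 4^q)\, \alpha^{q-1} \sigma_n^2 pd \\
&\le 0.25\, T^\star \exp(-\alpha^q / 4^q)\, \alpha^{q-2} \alpha \\
&\le 0.25\alpha,
\end{aligned}$$

where the first inequality follows from Equation (2), the second inequality follows because $F_j(X) \ge \sigma(\langle w^{(t)}_{j,r}, \xi_i \rangle)$, the third inequality uses the bounds $0.25\alpha \le \langle w^{(t)}_{j,r}, \xi_i \rangle \le 2\alpha$ for $T_{j,r,i} < t < T$ together with $\|\xi_{j,r,i}\|_2^2 = O(\sigma_n^2 pd)$ from Lemma D.5 and the fact that the sum has at most $T^\star$ terms, the fourth inequality follows by choosing $\eta \le n / (qKe 2^{q+1} \sigma_n^2 d)$, and the last inequality follows by choosing $\alpha = \Theta(\log^{1/q}(T^\star))$. Plugging in the bounds on $I_1, I_2$ finishes the proof for $\zeta^{(T)}_{j,r,i}$. To prove $\gamma^{(t)}_{j,r,j} \le \alpha$, we pick $\eta \le 1/(qe 2^{q+2} \mu^2)$, and the rest of the proof is similar.

Lemma D.14 and Lemma D.15 imply that Proposition D.10 holds for all $t \le T$.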
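The final step of the bound on $I_2$ needs $T^\star \exp(-\alpha^q / 4^q)\, \alpha^{q-2} \le 1$. The sketch below checks this numerically for one concrete choice $\alpha = 4 (2 \log T^\star)^{1/q}$, which is of order $\log^{1/q}(T^\star)$. The specific constant is an assumption made for illustration, not the constant used in the proof, and the function names are hypothetical.

```python
import math


def alpha_for(t_star, q):
    """One concrete choice alpha = 4 * (2 log T*)^(1/q), of order log^(1/q)(T*)."""
    return 4.0 * (2.0 * math.log(t_star)) ** (1.0 / q)


def residual_factor(t_star, q):
    """Evaluate T* exp(-alpha^q / 4^q) alpha^(q-2), which should be at most 1."""
    alpha = alpha_for(t_star, q)
    return t_star * math.exp(-(alpha ** q) / (4.0 ** q)) * alpha ** (q - 2)


for q in (3, 4, 5):
    for t_star in (1e3, 1e6, 1e9):
        print(f"q={q}, T*={t_star:.0e}: alpha={alpha_for(t_star, q):.2f}, "
              f"factor={residual_factor(t_star, q):.2e}")
```

With this choice, $\exp(-\alpha^q / 4^q) = (T^\star)^{-2}$, so the factor equals $\alpha^{q-2} / T^\star$, which is small because $\alpha$ grows only poly-logarithmically in $T^\star$.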

D.3 FEATURE GROWING PHASE

In this subsection, we first present a supporting lemma, and then provide our main result of Lemma D.17, which shows that the signal strength is relatively unaffected by pruning while the noise level is reduced by a factor of √ p. During the feature growing phase of training, the output of F j (X) = O(1) for all j ∈ [K]. There- fore, logit i (F, X) = O( 1 K ) and 1 -logit i (F, X) = Θ(1) until w (t) j,r , µ j reaches m -1/q . Lemma D.16. Under the same assumption as Theorem D.1, for T = nη -1 C4σ 2-q 0 (σn √ pd) -q C3(2C1) q-1 [log d] (q-1)/2 , the following results hold: • |ζ (t) j,r,i | = O(σ 0 σ n √ pd) for all j ∈ [K], r ∈ [m], i ∈ [n] and t ≤ T . • |ω (t) j,r,i | = O(σ 0 σ n √ pd) for all j ∈ [K], r ∈ [m], i ∈ [n] and t ≤ T . Proof. Define Ψ (t) = max j,r,i {ζ (t) j,r,i , |ω j,r,i |}. Then we have Ψ (t+1) ≤ Ψ (t) + max j,r,i      η n |ℓ ′(t) j,i | • σ ′    w (0) j,r , ξ i + K k=1 γ (t) j,r,k ⟨µ k , ξ i ⟩ ∥µ k ∥ 2 2 + n i ′ =1 Ψ (t) ξ j,r,i ′ , ξ i ξ j,r,i ′ 2 2    ξ j,r,i 2 2      ≤ Ψ (t) + η n q O( log dσ 0 σ n pd) + K log 1/q T ⋆ µσ n √ log d µ 2 + O(σ 2 n pd) + nO(σ 2 n √ pd log d) Θ(σ 2 n pd) Ψ (t) q-1 O(σ 2 n pd) ≤ Ψ (t) + η n O( log dσ 0 σ n pd) + O(Ψ (t) ) q-1 O(σ 2 n pd), where the second inequality follows by |ℓ ′(t) j,i | and applying the bounds from Lemma D.5, and the last inequality follows by choosing 2K log T ⋆ σnd √ p = O(1/ √ d) ≪ σ 0 . Let C 1 , C 2 , C 3 be the con- stants for the upper bound to hold in the big O notation. For any T = nη -1 C4σ 2-q 0 (σn √ pd) -q C3(2C1) q-1 [log d] (q-1)/2 = Θ( nη -1 σ 2-q 0 (σn √ pd) -q [log d] (q-1)/2 ), we use induction to show that Ψ (t) ≤ C 4 σ 0 σ n pd, ∀t ∈ [T ]. Suppose that Equation equation 3 holds for t ∈ [T ′ ] for T ′ ≤ T -1. 
Then Ψ (T ′ +1) ≤ Ψ (T ′ ) + η n C 1 log dσ 0 σ n pd + C 2 C 4 σ 0 σ n pd q-1 C 3 σ 2 n pd ≤ Ψ (T ′ ) + η n 2C 1 log dσ 0 σ n pd q-1 C 3 σ 2 n pd ≤ (T ′ + 1) η n 2C 1 log dσ 0 σ n pd q-1 C 3 σ 2 n pd ≤ T η n 2C 1 log dσ 0 σ n pd q-1 C 3 σ 2 n pd ≤ C 4 σ 0 σ n pd, where the last inequality follows by picking T = nη -1 C4σ 2-q 0 (σn √ pd) -q C3(2C1) q-1 [log d] (q-1)/2 = Θ( nη -1 σ 2-q 0 (σn √ pd) -q [log d] (q-1)/2 ). Therefore, by induction, we have Ψ (t) ≤ C 4 σ 0 σ n √ pd for all t ∈ [T ]. Lemma D.17 (Formal Restatement of Lemma 3.3). Under the same assumption as Theorem D.1, there exists time T 1 = log 2m -1/q log(1+Θ( η K )µ q σ q-2 0 ) = O(Kη -1 σ 2-q 0 µ -q log 2m -1/q ) such that 1. max r γ (T1) j,r,j ≥ m -1/q for j ∈ [K]. 2. |ζ j,r,i | ≤ O(σ 0 σ n √ pd) for all j ∈ [K], r ∈ [m], i ∈ [n] and t ≤ T 1 . 3. |γ (t) j,r,k | ≤ O(σ 0 µ poly log d) for all j, k ∈ [K], j ̸ = k, r ∈ [m] and t ≤ T 1 . Proof. Consider a fixed class j ∈ [K]. Denote T 1 to be the last time for t ∈ 0, nη -1 C4σ 2-q 0 (σn √ pd) -q C3(2C1) q-1 [log d] (q-1)/2 satisfying max r γ (t) j,r ≤ m -1/q . Then for t ≤ T 1 , max j,r,i ζ (t) j,r,i , |ω (t) j,r,i | ≤ O(σ 0 σ p √ pd) ≤ O(m -1/q ) and max j,r γ (t) j,r,j . Thus, by Lemma D.13, we obtain that F j ( W (t) , x i ) ≤ O(1), ∀y i = j. Thus, ℓ ′(t) j,i = Θ(1). For j ∈ S j signal , we have γ (t+1) j,r,j = γ (t) j,r,j - η n n i=1 ℓ ′(t) j,i • σ ′    w (0) j,r , µ j + γ (t) j,r,j + n i ′ =1 ζ (t) j,r,i ξ j,r,i , µ j ξ j,r,i 2 2 + n i ′ =1 ω (t) j,r,i ξ j,r,i , µ j ξ j,r,i 2 2    ∥µ j ∥ 2 2 I(y i = j) ≥ γ (t) j,r,j - η n n i=1 ℓ ′(t) j,i σ ′ w (0) j,r , µ j + γ (t) j,r,j -O(nσ 0 σ n pd σ n µ √ log d σ 2 n pd ) I(y i = j). Let γ (t) j,r,j = γ (t) j,r,j + w (0) j,r , µ j -O(nσ 0 σ n √ pd σnµ √ log d σ 2 n pd ) and A (t) = max r γ (t) j,r,j . Note that by our choice of µ, we have nµ √ log d σnpd = o(1). Since max r w (0) j,r , µ j ≥ Ω(σ 0 µ) by Lemma D.7, max r w (0) j,r , µ j ≥ Ω(σ 0 µ) -O(nσ 0 σ n pd σnµ √ log d σ 2 n pd ) = Ω(σ 0 µ). 
Then we have A (t+1) ≥ A (t) - η n n i=1 ℓ ′(t) j,i σ ′ (A (t) )µ 2 I(y i = j) ≥ A (t) + Θ( η K )µ 2 [A (t) ] q-1 ≥ (1 + Θ( η K µ 2 [A (t) ] q-2 ))A (t) ≥ (1 + Θ( η K µ q σ q-2 ))A (t) . Therefore, the sequence A (t) will exponentially grow and will reach 2m -1/q within log 2m -1/q log(1+Θ( η K )µ q σ q-2 0 ) = O(Kη -1 σ 2-q 0 µ -q log 2m -1/q ) ≤ Θ( nη -1 σ 2-q 0 (σn √ pd) -q [log d] (q-1)/2 ). Thus, max r γ (t) j,r ≥ A (t) -max j,r | w (0) j,r , µ j | ≥ 2m -1/q -O(σ 0 µ) ≥ 2 -m -1/q = m -1/q . Now we prove that under the same assumption as Theorem D.1, for T = O(Kη -1 σ 2-q 0 µ -q ), we have |γ (t) j,r,k | ≤ O(σ 0 µ poly log d) for all r ∈ [m], j, k ∈ [K], j ̸ = k and t ≤ T . We show that there exists a time T ′ ≥ T such that for all t ≤ T ′ , max j,r,k |γ (t) j,r,k | ≤ O(σ 0 µ poly log d). Let T ′ = O(K 2 η -1 σ 2-q 0 µ -q log d). Define Φ (t) = max r∈[m], j,k∈[K], j̸ =k {|γ (t) j,r,k |}. Since we assume T ≤ Θ( nη -1 σ 2-q 0 (σn √ pd) -q [log d] (q-1)/2 ), by Lemma D.16, we have ζ (t) j,r,i , |ω (t) j,r,i | ≤ O(σ 0 σ n √ pd). Φ (t+1) ≤ Φ (t) + max j,r,k,i      η n n i=1 I(y i = k)|ℓ ′(t) j,i |σ ′    w (0) j,r , µ k + n i ′ =1 ζ (t) j,r,i ′ ξ j,r,i ′ , µ k ξ j,r,i ′ 2 2 + n i ′ =1 ω (t) j,r,i ′ ξ j,r,i ′ , µ k ξ j,r,i ′ 2 2    µ 2      ≤ Φ (t) + η K 1 K q O(σ 0 µ log d) + nO(σ 0 σ n pd) σ n µ √ log d σ 2 n pd ) q-1 µ 2 ≤ Φ (t) + qη K 2 O(σ 0 µ log d) q-1 µ 2 , where the first inequality follows because γ (t) j,r,k < 0, the second inequality follows because there are Θ(n/K) samples from a given class k and |ℓ ′(t) j,i | = Θ( 1 K ) , and the last inequality follows because µ = σ n √ d log d. Now, let C be the constant such that the above holds with big O. Then, we use induction to show that Φ (t) ≤ C 2 σ 0 µ for all t ≤ T . We proceed as follows. 
Φ (t+1) ≤ Φ (t) + qη K 2 Cσ 0 µ log d q-1 µ 2 ≤ T qη K 2 Cσ 0 µ log d q-1 µ 2 ≤ C 2 σ 0 µpoly log d, where the last inequality follows by picking T = C2K 2 η -1 σ 2-q 0 µ -q √ log d C q-1 = O(K 2 η -1 σ 2-q 0 µ -q log d).
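The exponential-growth step in the proof of Lemma D.17 can be illustrated with a toy recursion $A \leftarrow A + c A^{q-1}$, where `c` stands in for the $\Theta(\eta/K)\mu^2$ factor and `a0` for an initial value of order $\sigma_0 \mu$. These numerical values and the function names are assumptions chosen for illustration. Because $A$ is non-decreasing and $q > 2$, each step multiplies $A$ by at least $1 + c\, a_0^{q-2}$, which gives the logarithmic bound on the hitting time used in the text.

```python
import math


def steps_to_reach(a0, c, q, target, max_steps=1_000_000):
    """Iterate A <- A + c * A^(q-1) from a0 and count steps until A >= target."""
    a, steps = a0, 0
    while a < target:
        a += c * a ** (q - 1)
        steps += 1
        if steps > max_steps:
            raise RuntimeError("threshold not reached within max_steps")
    return steps


def geometric_step_bound(a0, c, q, target):
    """Upper bound log(target / a0) / log(1 + c * a0^(q-2)) on the hitting time."""
    rate = 1.0 + c * a0 ** (q - 2)
    return math.ceil(math.log(target / a0) / math.log(rate))


q, m = 3, 8                     # hypothetical activation power and width
a0, c = 0.01, 0.5               # hypothetical initial correlation and step factor
target = 2.0 * m ** (-1.0 / q)  # threshold 2 m^(-1/q) from the proof
steps = steps_to_reach(a0, c, q, target)
bound = geometric_step_bound(a0, c, q, target)
print(f"steps taken: {steps}, geometric bound: {bound}")
```

The actual hitting time is typically much smaller than the bound, since the growth rate $1 + c A^{q-2}$ increases as $A$ grows, whereas the bound uses the initial rate throughout.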

D.4 CONVERGING PHASE

In this subsection, we show that gradient descent can drive the training loss toward zero while the signal in the feature is still large. An important intermediate step in our argument is the development of the following gradient upper bound for multi-class cross-entropy loss. In this phase, we are going to show that • max r γ (t) j,r,j ≥ m 1/q for all j ∈ [K]. • max j̸ =k,r∈[m] |γ (t) j,r,k | ≤ β 1 where β 1 = O(σ 0 µ). • max j,r,i {ζ (t) j,r,i , |ω (t) j,r,i |} ≤ β 2 where β 2 = O(σ 0 σ n √ pd) Define W ⋆ as follows: w ⋆ j,r = w (0) j,r + Θ(m log(1/ϵ)) µ j µ 2 . Lemma D.18. Based on the result from the feature growing phase, W (T1) -W ⋆ 2 F ≤ O(Km 3 log 2 (1/ϵ)µ -2 ). Proof. We first compute W (T1) -W (0) 2 F = K j=1 m r=1 γ (T1) j,r,j µ j ⊙ m j,r µ 2 + k̸ =j γ (T1) j,r,k µ k ⊙ m j,r µ 2 + i ζ (T1) j,r,i ξ j,r,i ξ j,r,i 2 2 + i ω (T1) j,r,i ξ j,r,i ξ j,r,i 2 2 2 2 ≤ j r   γ (T1) j,r,j 1 µ + k̸ =j γ (T1) j,r,k 1 µ + i ζ (T1) j,r,i 1 ξ j,r,i 2 + i ω (T1) j,r,i 1 ξ j,r,i 2   2 ≤ j r O( 1 µ ) + K O(σ 0 ) + n O(σ 0 ) 2 ≤ j r O( 1 µ 2 ) = O(Km 1 µ 2 ), where the first inequality follows from triangle inequality, the second inequality follows from Lemma D.17, and the last inequality follows from our choice of σ 0 . On the other hand, W (0) -W ⋆ 2 F = j,r m 2 log 2 (1ϵ) 1 µ 2 = O(Km 3 log 2 (1/ϵ) 1 µ 2 ). Thus, we obtain W (T1) -W ⋆ 2 F ≤ 4 W (T1) -W (0) 2 F + 4 W (0) -W ⋆ 2 F ≤ O(Km 3 log 2 (1/ϵ) 1 µ 2 ). Lemma D.19 (Gradient Upper Bound). Under Condition 2.2, for t ≤ T ⋆ , there exists constant C = O(Km 2/q max{µ 2 , σ 2 n pd}) such that ∇L S ( W (t) ) ⊙ M 2 F ≤ CL S ( W (t) ). Proof. We need to prove that |ℓ ′(t) yi,i | ∇F ( W (t) , x i ) ⊙ M 2 F ≤ C. Assume y i ̸ = j. 
Then we obtain ∇F j ( W j , x i ) ⊙ M F ≤ r σ ′ w (t) j,r , µ yi µ yi + σ ′ w (t) j,r , ξ i ξ i 2 ≤ r σ ′ w (t) j,r , µ yi ∥µ yi ∥ 2 + σ ′ w (t) j,r , ξ i ξ i 2 ≤ m 1/q F j ( W j , x i ) (q-1)/q max{µ, Cσ n pd} ≤ m 1/q max{µ, Cσ n pd}, where the first and second inequality follow from triangle inequality, the third inequality follows from Hölder's inequality, and the last inequality follows from Lemma D.12. Similarly, on the other hand, if y i = j, then ∇F yi ( W) ⊙ M F ≤ m 1/q F yi ( W yi , x i ) (q-1)/q max{µ, Cσ n pd}. Therefore, j̸ =yi |ℓ ′(t) j,i | ∇F j ( W j , x i ) ⊙ M j 2 F ≤ j̸ =yi |ℓ ′(t) j,i |m 2/q O(max{µ 2 , σ 2 n pd}) = |ℓ ′(t) yi,i |m 2/q O(max{µ 2 , σ 2 n pd}) ≤ Ke exp{-F yi (x i )}m 2/q O(max{µ 2 , σ 2 n pd}), and |ℓ ′(t) yi,i | ∇F yi ( W yi , x i ) ⊙ M yi 2 F ≤ Ke exp{-F yi (x i )}m 2/q F yi ( W yi , x i ) 2(q-1)/q O(max{µ 2 , σ 2 n pd}, where the inequality follows from Equation equation 2. Thus, K j=1 |ℓ ′(t) j,i | 2 ∇F j ( W j , x i ) ⊙ M j 2 F ≤ |ℓ ′(t) yi,i | K j=1 |ℓ ′(t) j,i | ∇F j ( W j , x i ) ⊙ M j 2 F ≤ |ℓ ′(t) yi,i |Ke exp{-F yi (x i )}m 2/q O(max{µ 2 , σ 2 n pd} F yi ( W yi , x i ) (q-1)/q + 1 ≤ |ℓ ′(t) yi,i |O(Km 2/q max{µ 2 , σ 2 n pd}), where the first inequality follows because |ℓ ′(t) j,i | ≤ |ℓ ′(t) yi,i |, and the last inequality uses the fact that exp{-x}(1 + x (q-1)/q ) = O(1) for all x ≥ 0. 
The gradient norm can be bounded by ∇L S ( W (t) ) 2 F ≤ 1 n n i=1 ∇L( W (t) , x i ) F 2 =   1 n n i=1 K j=1 |ℓ ′(t) j,i | 2 ∇F j ( W (t) j , x i ) 2 F   2 ≤ 1 n n i=1 |ℓ ′(t) yi,i | ∇F ( W (t) j , x i ) F 2 ≤ 1 n n i=1 |ℓ ′(t) yi,i |O(Km 2/q max{µ 2 , σ 2 n d}) 2 ≤ O(Km 2/q max{µ 2 , σ 2 n d}) 1 n n i=1 |ℓ ′(t) yi,i | ≤ O(Km 2/q max{µ 2 , σ 2 n d})L S ( W (t) ), where the first inequality uses triangle inequality, the second inequality follows because |ℓ ′(t) j,i | ≤ |ℓ ′(t) yi,i |, the third inequality uses the bound equation 4, the fourth inequality uses Jensen's inequality and the last inequality follows because |ℓ ′(t) yi,i | ≤ ℓ (t) i . Lemma D.20. For T 1 ≤ t ≤ T ⋆ , we have for all j ̸ = y i , ∇F yi ( W (t) yi , x i ), W ⋆ yi -∇F j ( W (t) j , x i ), W ⋆ j ≥ q log 2qK ϵ . Proof of Lemma D.20. The proof of this lemma depends on the next two lemmas. Lemma D.21. For T 1 ≤ t ≤ T ⋆ and j = y i , we have ∇F j ( W (t) j , x i ), W ⋆ j ≥ Θ(m 1/q log(1/ϵ)). Proof. By Lemma D.17, we have max r w (t) j,r , µ j = max r      w (0) j,r , µ j + γ (t) j,r,j + n i=1 ζ (t) j,r,i ξ j,r,i , µ j ξ j,r,i 2 2 + n i=1 ω (t) j,r,i ξ j,r,i , µ j ξ j,r,i 2 2      ≥ m -1/q -O(σ 0 µ log d) -O(nσ 0 σ n √ d µ √ log d σ n pd ) ≥ Θ(m -1/q ), where the last inequality follows by picking σ 0 ≤ O(m -1 n -1 µ -1 (log d) -1/2 ). On the other hand, w (t) j,r , ξ i ≤ w (0) j,r , ξ i + |ω (t) j,r,i | + |ζ (t) j,r,i | + O(n log d pd α) + O(n µ √ log d σ n pd α) ≤ O(1), (5) where the first inequality follows from Lemma D.11 and the second inequality follows from Equation (1) and Proposition D.10. 
Therefore, ∇F j ( W (t) j , x i ), W ⋆ j = r σ ′ w (t) j,r , µ j µ j , w ⋆ j,r + r σ ′ w (t) j,r , ξ i ξ i , w ⋆ j,r ≥ r σ ′ w (t) j,r , µ j Θ(m log(1/ϵ)) - r O(σ 0 σ n pd log d + σ n √ log d µ m log(1/ϵ)) ≥ Θ(m 1/q log(1/ϵ)) -O(mσ 0 σ n pd log d + σ n √ log d µ m 2 log(1/ϵ)) ≥ Θ(m 1/q log(1/ϵ)), where the last inequality follows because mσ 0 σ n √ pd log d = o(1) and σn √ log d µ m 2 = o(1) by our choices of µ, σ 0 . Lemma D.22. For T 1 ≤ t ≤ T and j ̸ = y i , we have ∇F j ( W (t) j , x i ), W ⋆ j ≤ O(1). Proof. First we have w (t) j,r , µ yi = w (0) j,r , µ yi + γ (t) j,r,yi + n i=1 ζ (t) j,r,i ξ j,r,i , µ j ξ j,r,i 2 2 + n i=1 ω (t) j,r,i ξ j,r,i , µ j ξ j,r,i 2 2 ≤ O(σ 0 µ log d + σ 0 µ poly log d + nσ 0 σ n pd σ n µ √ log d σ 2 n pd ) ≤ O(1), where the first inequality follows from Lemma D.7, Lemma D.5 and Lemma D.17, and the last inequality follows from our choices of σ 0 , µ. Then, we have ∇F j ( W (t) j , x i ), W ⋆ j = r σ ′ w (t) j,r , µ yi µ yi , w ⋆ j,r + r σ ′ w (t) j,r , ξ i ξ i , w ⋆ j,r ≤ mO(σ 0 µ log d) + mO(σ 0 σ n d log d + m log(1/ϵ) σ n √ log d µ ) ≤ O(1), where the second inequality follows from Equation equation 6 and Equation equation 5, and the last inequality follows from our choices of µ, σ 0 . Applying the lower bound and upper bound from Lemma D.21 and Lemma D.22, we have ∇F yi ( W (t) yi , x i ), W ⋆ yi -∇F j ( W (t) j , x i ), W ⋆ j ≥ Θ(m 1/q log(1/ϵ)) -O(1) ≥ q log 2qK ϵ . Lemma D.23. Under the same assumption as Theorem D.1, we have W (t) -W ⋆ 2 F -W (t+1) -W ⋆ 2 F ≥ 5ηL S ( W (t) ) -ηϵ. Proof. To simplify our notation, we define F (t) j (x i ) = ∇F j ( W (t) j , x i ), W ⋆ j . We use the fact that the network is q-homogeneous. 
W (t) -W ⋆ 2 F -W (t+1) -W ⋆ 2 F = 2η ∇L S ( W (t) ) ⊙ M, W (t) -W ⋆ -η 2 ∇L S ( W (t) ) ⊙ M 2 F = 2η n n i=1 K j=1 ℓ ′(t) j,i qF j ( W (t) j ; x i , y i ) -∇F j ( W (t) j , x i ), W ⋆ j -η 2 ∇L S ( W (t) ) ⊙ M 2 F ≥ 2qη n n i=1   log(1 + K j=1 e Fj -Fy i ) -log(1 + K j=1 e ( Fj -Fy i )/q )   -η 2 ∇L S ( W (t) ) ⊙ M 2 F ≥ 2qη n n i=1 ℓ( W (t) ; x i , y i ) -log(1 + Ke -log(2qK/ϵ) ) -η 2 ∇L S ( W (t) ) ⊙ M 2 F ≥ 2qη n n i=1 ℓ( W (t) ; x i , y i ) - ϵ 2q -η 2 ∇L S ( W (t) ) ⊙ M 2 F ≥ CηL S ( W (t) ) -ηϵ, where the first inequality follows from the convexity of the cross-entropy loss with softmax, the second inequality follows from Lemma D.20, the third inequality follows because log(1 + x) ≤ x, and the last inequality follows from Lemma D.19 for some constant C. Lemma D.24 (Formal Restatement of Lemma 3.5). Under the same assumption as Theorem D.1, choose T 2 = T 1 + ∥ W (T 1 ) -W ⋆ ∥ 2 F 2ηϵ = T 1 + O(Km 3 log 2 (1/ϵ)µ -2 ). Then for any time t during this stage, we have max r γ (t) j,r,j ≥ m 1/q for all j ∈ [K], max j,r,i {|ζ (t) j,r,i |, |ω (t) j,r,i |} ≤ 2β 1 , max j̸ =k,r∈[m] {|γ (t) j,r,k |} ≤ 2β 2 , and 1 t -T 1 t s=T1 L S ( W (s) ) ≤ W (T1) -W ⋆ 2 F Cη(t -T 1 ) + ϵ C . Proof. From Lemma D.17, we have max r γ (T1) j,r,j ≥ m 1/q and since γ (t) is an increasing sequence over t, we have max r γ (t) j,r,j ≥ m 1/q for all t ∈ [T 1 , T 2 ]. We have W (s) -W ⋆ 2 F -W (s+1) -W ⋆ 2 F ≥ CηL S ( W (s) ) -ηϵ. Taking a telescopic sum from T 1 to t yields t s=T1 L S ( W (s) ) ≤ W (T1) -W ⋆ 2 F + ηϵ(t -T 1 ) Cη . Combining Lemma D.18, we have t s=T1 L S ( W (s) ) ≤ O(η -1 W (T1) -W ⋆ 2 F ) = O(η -1 Km 3 log 2 (1/ϵ)µ -2 ). Define Ψ (t) = max j,r,i {ζ (t) j,r,i , |ω (t) j,r,i |} and Φ (t) = max j̸ =k,r∈[m] |γ (t) j,r,k | and β 2 = O(σ 0 µ ). Now we use induction to prove Ψ (t) ≤ 2β 1 and Φ (t) ≤ 2β 2 . Suppose the result holds for time t ≤ t ′ . 
Then Ψ (t+1) ≤ Ψ (t) + max j,r,i      η n |ℓ ′(t) j,i | • σ ′    w (0) j,r , ξ i + K k=1 γ (t) j,r,k ⟨µ k , ξ i ⟩ ∥µ k ∥ 2 2 + n i ′ =1 Ψ (t) ξ j,r,i ′ , ξ i ξ j,r,i ′ 2 2    ξ j,r,i 2 2      ≤ Ψ (t) + η n q max i |ℓ ′(t) yi,i | O( log dσ 0 σ n pd) + K log 1/q T ⋆ µσ n √ log d µ 2 + O(σ 2 n pd) + nO(σ 2 n √ pd log d) Θ(σ 2 n pd) Ψ (t) q-1 O(σ 2 n pd) ≤ Ψ (t) + η n n i=1 |ℓ ′(t) yi,i | O( log dσ 0 σ n pd) + O(Ψ (t) ) q-1 O(σ 2 n pd), where the second inequality follows by |ℓ ′(t) j,i | ≤ |ℓ ′(t) yi,i | and applying the bounds from Lemma D.5, and the last inequality follows by choosing K log 1/q T ⋆ √ d = O( 1 √ d ) ≪ σ 0 σ n √ pd. Unrolling the recursion by taking a sum from T 1 to t ′ we have Ψ (t ′ +1) (i) ≤ Ψ (T1) + η n t ′ s=T1 n i=1 |ℓ ′(s) yi,i |O(σ 2 n pd poly log d)β q-1 1 (ii) ≤ Ψ (T1) + η n O(σ 2 n pd poly log d)β q-1 1 t ′ s=T1 n i=1 ℓ (s) i = Ψ (T1) + η n O(σ 2 n pd poly log d)β q-1 1 t ′ s=T1 L S ( W (s) ) (iii) ≤ Ψ (T1) + 1 n O(Km 3 µ -2 σ 2 n pd poly log d)β q-1 1 (iv) ≤ β 1 + O(Km 3 )β q-1 1 (v) ≤ 2β 1 , where (i) follows from induction hypothesis Ψ (t) ≤ 2β 1 , (ii) follows from the property of crossentropy loss with softmax |ℓ ′ j,i | ≤ |ℓ ′ yi,i | ≤ ℓ i , (iii) follows from Equation equation 7, (iv) follows from our choice of µ, n, K, and (v) follows because O(Km 3 )β 1 q-2 ≤ O(Km 3 σ 0 σ n √ pd) ≤ 1. Therefore, by induction Ψ (t) ≤ 2β 1 holds for time t ≤ t ′ + 1. 
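On the other hand,

$$
\begin{aligned}
\Phi^{(t'+1)} &\overset{(i)}{\le} \Phi^{(t)} + \max_{j,r,k,i} \left\{ \frac{\eta}{n} \sum_{i=1}^n I(y_i = k) |\ell'^{(t)}_{j,i}|\, \sigma'\!\left( \langle w^{(0)}_{j,r}, \mu_k \rangle + \sum_{i'=1}^n \zeta^{(t)}_{j,r,i'} \frac{\langle \xi_{j,r,i'}, \mu_k \rangle}{\|\xi_{j,r,i'}\|_2^2} + \sum_{i'=1}^n \omega^{(t)}_{j,r,i'} \frac{\langle \xi_{j,r,i'}, \mu_k \rangle}{\|\xi_{j,r,i'}\|_2^2} \right) \mu^2 \right\} \\
&\overset{(ii)}{\le} \Phi^{(t)} + \Theta\!\left(\frac{\eta}{K}\right) \max_{j,i} |\ell'^{(t)}_{j,i}| \left[ O(\sigma_0 \mu \log d) + n\, O(\sigma_0 \sigma_n \sqrt{pd}) \frac{\sigma_n \mu \sqrt{\log d}}{\sigma_n^2 pd} \right]^{q-1} \mu^2 \\
&\overset{(iii)}{\le} \Phi^{(T_1)} + \Theta\!\left(\frac{\eta}{K}\right) \mu^2 \sum_{s=T_1}^{t} \sum_{i=1}^n \ell^{(s)}_i\, O(\sigma_0 \mu \log d)^{q-1} \\
&\overset{(iv)}{\le} \beta_2 + O(m^3) \beta_2^{q-1} \\
&\overset{(v)}{\le} 2\beta_2,
\end{aligned}
$$

where (i) follows because $\gamma^{(t)}_{j,r,k} \le 0$, (ii) follows from Lemma D.7 and Lemma D.5, (iii) follows because $\max_{j,i} |\ell'^{(t)}_{j,i}| \le \max_i |\ell'^{(t)}_{y_i,i}| \le \max_i \ell^{(t)}_i \le \sum_i \ell^{(t)}_i$, (iv) follows from Equation 7, and (v) follows because $O(m^3) \beta_2^{q-2} \le O(m^3 \sigma_0 \mu) \le 1$.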

D.5 GENERALIZATION ANALYSIS

In this subsection, we show that pruning can purify the feature by reducing the variance of the noise by a factor of $p$ when a new sample is given. Now the network has parameter

$$
w^\star_{j,r} = w^{(0)}_{j,r} + \sum_{k=1}^K \gamma^\star_{j,r,k} \frac{\mu_k \odot m_{j,r}}{\mu^2} + \sum_{i=1}^n \zeta^\star_{j,r,i} \frac{\xi_{j,r,i}}{\|\xi_{j,r,i}\|_2^2} + \sum_{i=1}^n \omega^\star_{j,r,i} \frac{\xi_{j,r,i}}{\|\xi_{j,r,i}\|_2^2}.
$$

We have

$$
\|w^\star_{j,r}\|_2 = O\!\left( \sigma_0 \sqrt{pd} + \mu^{-1} \log^{1/q}(T^\star) + K \sigma_0\, \mathrm{poly}\log d + n \sigma_0 \sigma_n \sqrt{pd} \cdot \frac{1}{\sigma_n \sqrt{pd}} \right) = O(\sigma_0 \sqrt{pd}).
$$

Lemma D.25 (Formal Restatement of Lemma 3.6). With probability at least $1 - 2Km \exp\!\left( -\frac{(2m)^{-4/q}}{O(\sigma_0^2 \sigma_n^2 pd)} \right)$, $\max_{j,r} |\langle w^\star_{j,r}, \xi \rangle| \le (2m)^{-2/q}$.

Proof. Since $\langle w^\star_{j,r}, \xi \rangle$ follows a Gaussian distribution with variance $O(\sigma_0^2 \sigma_n^2 pd)$, we have

$$
P\big[ |\langle w^\star_{j,r}, \xi \rangle| \ge (2m)^{-2/q} \big] \le 2 \exp\!\left( -\frac{(2m)^{-4/q}}{O(\sigma_0^2 \sigma_n^2 pd)} \right).
$$

Applying a union bound over $j \in [K]$, $r \in [m]$ gives the result.

$(K \eta^{-1} \sigma_0^{2-q} \mu^{-q} + K^2 m^4 \mu^{-2} \eta^{-1} \epsilon^{-1})$ iterations, we can find $W^\star$ such that
• $L_S(W^\star) \le \epsilon$.
• $L_D(W^\star) \le O(K\epsilon) + \exp(-n^2/p)$.

Proof. Let $E$ be the event that Lemma D.25 holds. Then, we can divide $L_D(W^\star)$ into two parts:

$$
E[\ell(F(W^\star, x))] = \underbrace{E[I(E)\, \ell(F(W^\star, x))]}_{I_1} + \underbrace{E[I(E^c)\, \ell(F(W^\star, x))]}_{I_2}.
$$

Since $L_S(W^\star) \le \epsilon$, for each class $j \in [K]$ there must exist one training sample $(x_i, y_i) \in S$ with $y_i = j$ such that $\ell(F(W^\star, x_i)) \le K\epsilon \le 1$ by the pigeonhole principle. This implies that $\sum_{j' \ne j} \exp(F_{j'}(x_i) - F_j(x_i)) \le 2K\epsilon$. Conditioning on the event $E$, by Lemma D.25, we have $|F_j(W^\star, x) - F_j(W^\star, x_i)| \le$

$$
\begin{aligned}
&\le K + \sum_{j' \ne y} F_{j'}(x) \\
&= K + \sum_{j' \ne y} \sum_{r \in [m]} \Big[ \sigma(\langle w^\star_{j',r}, \mu_y \rangle) + \sigma(\langle w^\star_{j',r}, \xi \rangle) \Big] \\
&\le K + Km\, (O(\sigma_0 \mu \log d))^q + O(m (\sigma_0 \sigma_n \sqrt{d})^q)\, \|\xi/\sigma_n\|_2^q \\
&\le 2K + \|\xi/\sigma_n\|_2^q,
\end{aligned}
$$

where the first inequality follows because $F_y(x) \ge 0$, the second and third inequalities follow from the property of the log function, and the last inequality follows from our choice of $\sigma_0 \le O(m^{-4} n^{-1} \sigma_n^{-1} d^{-1/2})$.
We further have

$$
\begin{aligned}
I_2 &\le \sqrt{E[I(E^c)]}\, \sqrt{E[\ell(F(W^\star, x))^2]} \\
&\le \sqrt{P(E^c)}\, \sqrt{4K^2 + E\|\xi/\sigma_n\|_2^{2q}} \\
&\le \exp(-C m^{-2/q} \sigma_0^{-2} \sigma_n^{-2} p^{-1} d^{-1} + \log(d)) \\
&\le \exp(-n^2/p),
\end{aligned}
$$

where the first inequality follows from the Cauchy-Schwarz inequality, the second inequality follows from Equation 8, the third inequality follows from Lemma D.25, and the last inequality follows because $\sigma_0 \le O(m^{-4} n^{-1} \sigma_n^{-1} d^{-1/2})$.

E PROOF OF THEOREM 4.1

In this section, we show that there exists a relatively large pruning fraction (i.e., small $p$) such that while gradient descent is still able to drive the training error toward zero, the learned model yields poor generalization. We first provide a formal restatement of Theorem 4.1.

Theorem E.1 (Formal Restatement of Theorem 4.1). Under Condition 2.2, choose initialization variance $\sigma_0 = \Theta(m^{-4} n^{-1} \mu^{-1})$ and learning rate $\eta \le O(1/\mu^2)$. For $\epsilon > 0$, if $p = \Theta(\frac{1}{Km \log d})$, then with probability at least $1 - 1/\log(d)$, there exists $T = O(\eta^{-1} n \sigma_0^{q-2} \sigma_n^{-q} (pd)^{-q/2} + \eta^{-1} \epsilon^{-1} m^4 n \sigma_n^{-2} (pd)^{-1})$ such that the following holds:
1. The training loss is below $\epsilon$: $L_S(W^{(T)}) \le \epsilon$.
2. The model weights do not learn any of their corresponding signal: $\gamma^{(t)}_{j,r,j} = 0$ for all $j \in [K]$, $r \in [m]$.
3. The model weights are highly correlated with the noise: $\max_{r \in [m]} \zeta^{(T)}_{j,r,i} \ge \Omega(m^{-1/q})$ if $y_i = j$.
Moreover, the testing loss is large: $L_D(W^{(T)}) \ge \Omega(\log K)$.

The proof of Theorem 4.1 consists of the analysis of over-pruning in three stages of gradient descent (initialization, the feature growing phase, and the converging phase) and the establishment of the generalization property. We present these analyses in detail in the following subsections.

Proof. First, the probability that a given class $j$ receives no signal is $(1 - p)^m$. We use the inequality $1 + t \ge \exp\{O(t)\}$ for all $t \in (-1/4, 1/4)$. Then the probability that $|S^j_{\mathrm{signal}}| = 0$ for all $j \in [K]$ is given by

$$
(1 - p)^{Km} \ge \exp\{-O(pKm)\} \ge 1 - O\!\left(\frac{1}{\log d}\right).
$$

E.2 FEATURE GROWING PHASE

Lemma E.3 (Formal Restatement of Lemma 4.3). Under the same assumption as Theorem E.1, there exists $T_1 < T^\star$ such that $T_1 = O(\eta^{-1} n \sigma_0^{q-2} \sigma_n^{-q} (pd)^{-q/2})$ and we have
• $\max_r \zeta_{y_i,r,i} \ge m^{-1/q}$ for all $i \in [n]$.
• $\max_{j,r,i} |\omega$

Let $T_i$ be the last time that $\zeta^{(t)}_{j,r,i} \le m^{-1/q}$. We can compute the growth of $B^{(t)}_i$ as

$$
\begin{aligned}
B^{(t+1)}_i &\ge B^{(t)}_i + \Theta\!\left(\frac{\eta \sigma_n^2 pd}{n}\right) [B^{(t)}_i]^{q-1} \\
&\ge B^{(t)}_i + \Theta\!\left(\frac{\eta \sigma_n^2 pd}{n}\right) [B^{(0)}_i]^{q-2} B^{(t)}_i \\
&\ge \left( 1 + \Theta\!\left(\frac{\eta \sigma_0^{q-2} \sigma_n^q p^{q/2} d^{q/2}}{n}\right) \right) B^{(t)}_i.
\end{aligned}
$$

Therefore, $B_i$ will reach $2m^{-1/q}$ within $O(\eta^{-1} n \sigma_0^{q-2} \sigma_n^{-q} (pd)^{-q/2})$ iterations. On the other hand, by Proposition D.10, we have $|\omega$

Proof. We derive the following bound:

$$
\begin{aligned}
\|W^{(T_1)} - W^\star\|_F &\le \|W^{(T_1)} - W^{(0)}\|_F + \|W^{(0)} - W^\star\|_F \\
&\le \dots \\
&\le Km\big( O(\sqrt{K} \sigma_0) + O(n^{1/2} \sigma_n^{-1} (pd)^{-1/2} \log^{1/q} T^\star) \big) + O(m^2 n^{1/2} \log(1/\epsilon) \sigma_n^{-1} (pd)^{-1/2}) \\
&\le O(m^2 n^{1/2} \log(1/\epsilon) \sigma_n^{-1} (pd)^{-1/2}),
\end{aligned}
$$

where the first inequality follows from the triangle inequality, the second inequality follows from the expression of $W^{(T_1)}$, $W^\star$, and the third inequality follows from Lemma D.5 and the fact that $\zeta^{(t)}_{j,r,i} > 0$ if and only if $j = y_i$.

Lemma E.5. For $T_1 \le t \le T^\star$, we have

$$
\langle \nabla F_{y_i}(W_{y_i}, x_i), W^\star_{y_i} \rangle - \langle \nabla F_j(W_j, x_i), W^\star_j \rangle \ge q \log \frac{2qK}{\epsilon}.
$$

Lemma E.6. For $T_1 \le t \le T^\star$ and $j = y_i$, we have

$$
\langle \nabla F_j(W^{(t)}_j, x_i), W^\star_j \rangle \ge \Theta(m^{1/q} \log(1/\epsilon)).
$$

Proof. First of all, from Lemma E.9 we know there exists $T \in [T_1, T_2]$ such that $L_S(W^{(T)}) \le \epsilon$.

$$
\le m\, O(\sigma_0^q \sigma_n^q d^{q/2} + n^q (pd)^{-q/2}) \le 1,
$$

where the last inequality follows because $\sigma_0 \le O(m^{-1/q} \mu^{-1})$ and $d \ge \Omega(m^{2/q} n^2)$. Thus, with probability at least $1 - 1/d$, $\ell(F(W^{(t)}; x)) \ge \log(1 + (K - 1) e^{-1})$.



We point out that, as in many previous works (Allen-Zhu & Li, 2020; Zou et al., 2021; Cao et al., 2022), the polynomial ReLU activation helps simplify the analysis of gradient descent, because it gives a much larger separation between signal and noise (and thus a cleaner analysis) than ReLU. Our analysis can be generalized to the ReLU activation by using the arguments in (Allen-Zhu & Li, 2022).



(1) Number of classes: $K = O(\log d)$.
(2) Total number of training samples: $n = \mathrm{poly}\log d$.
(3) Dimension: $d \ge C_d$ for some sufficiently large constant $C_d$.
(4) Relationship between signal strength and noise strength: $\mu = \Theta(\sigma_n \sqrt{d} \log d) = \Theta(1)$.
(5) Number of neurons in the network: $m = \Omega(\mathrm{poly}\log d)$.
(6) Initialization variance: $\sigma_0 = \Theta(m^{-4} n^{-1} \mu^{-1})$.
(7) Learning rate: $\Omega(1/\mathrm{poly}(d)) \le \eta \le O(1/\mu^2)$.
(8) Target training loss: $\epsilon = \Theta(1/\mathrm{poly}(d))$.

where $\xi_i$ is sampled from a Gaussian distribution. The class labels $y$ take values in $\{\pm 1\}$. We use 100 training examples and 100 testing examples. The network has width 150 and is initialized from a Gaussian distribution with variance 0.01. Then, a $p$ fraction of the weights is randomly pruned. We use a learning rate of 0.001 and train the network for 1000 iterations of gradient descent.
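The setup above can be sketched in a few lines of standard-library Python. This is only an illustrative sketch under several assumptions that the text does not state: the data model (`x = y * mu * e_1 + xi`), the input dimension, the fixed alternating output signs, the cubic activation (`q = 3`), the logistic loss, and independent per-weight pruning are all hypothetical choices, and the function names are made up for this example. The one feature it is meant to show is that the gradient is multiplied by the mask, so pruned weights stay at zero throughout training.

```python
import math
import random


def sigmoid(t):
    """Numerically stable logistic function."""
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)


def softplus(t):
    """Numerically stable log(1 + exp(t))."""
    return max(t, 0.0) + math.log1p(math.exp(-abs(t)))


def make_data(n, d, mu, sigma_n, rng):
    """Assumed data model: x = y * mu * e_1 + xi, with xi ~ N(0, sigma_n^2 I)."""
    data = []
    for _ in range(n):
        y = rng.choice([-1, 1])
        x = [rng.gauss(0.0, sigma_n) for _ in range(d)]
        x[0] += y * mu
        data.append((x, y))
    return data


def init_pruned(width, d, sigma_0, prune_frac, rng):
    """Gaussian init; each weight is pruned independently w.p. prune_frac."""
    mask = [[0.0 if rng.random() < prune_frac else 1.0 for _ in range(d)]
            for _ in range(width)]
    weights = [[rng.gauss(0.0, sigma_0) * mask[r][k] for k in range(d)]
               for r in range(width)]
    return weights, mask


def forward(weights, x, q=3):
    """Two-layer net with activation max(z, 0)^q; output signs alternate."""
    out = 0.0
    for r, w in enumerate(weights):
        z = sum(wk * xk for wk, xk in zip(w, x))
        out += (1.0 if r % 2 == 0 else -1.0) * max(z, 0.0) ** q
    return out


def train_loss(weights, data):
    """Average logistic loss log(1 + exp(-y f(x)))."""
    return sum(softplus(-y * forward(weights, x)) for x, y in data) / len(data)


def masked_gd_step(weights, mask, data, lr, q=3):
    """One full-batch gradient step; the gradient is multiplied by the mask."""
    n, d = len(data), len(weights[0])
    grad = [[0.0] * d for _ in weights]
    for x, y in data:
        dloss_df = -y * sigmoid(-y * forward(weights, x, q))
        for r, w in enumerate(weights):
            z = sum(wk * xk for wk, xk in zip(w, x))
            if z > 0.0:
                sign = 1.0 if r % 2 == 0 else -1.0
                coef = dloss_df * sign * q * z ** (q - 1) / n
                for k in range(d):
                    grad[r][k] += coef * x[k]
    return [[w[k] - lr * grad[r][k] * mask[r][k] for k in range(d)]
            for r, w in enumerate(weights)]


def run(prune_frac, steps=50, seed=0):
    """Small-scale run; sizes are reduced so pure Python stays fast."""
    rng = random.Random(seed)
    data = make_data(n=20, d=30, mu=2.0, sigma_n=0.5, rng=rng)
    weights, mask = init_pruned(30 // 3, 30, 0.1, prune_frac, rng)
    for _ in range(steps):
        weights = masked_gd_step(weights, mask, data, lr=0.001)
    return weights, mask, train_loss(weights, data)
```

Calling `run(prune_frac=0.5)` returns the trained weights, the mask, and the final training loss. The sizes here are much smaller than the width 150, 100 examples, and 1000 iterations quoted above; a vectorized library would be the natural choice at that scale.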

Figure 2: Figure (a) shows the relationship between pruning rates p and training/testing error under noise variance σ n = 0.5. Figure (b) shows the relationship between pruning rates p and testing error under noise variance σ n = 1. The training error is omitted since it stays effectively at zero across all pruning rates. Figure (c) shows a particular training curve under pruning rate p = 50% and noise variance σ n = 1. Each data point is created by taking an average over 10 independent runs.

Figure 4: Experiment results of ResNet-20-128 at various sparsity levels under random pruning and IMP. Each data point is averaged over 2 runs.

Assume $pm = \Omega(\log d)$ and $m = \mathrm{poly}\log d$. Then, with probability $1 - 1/d$, for all $j \in [K]$, $k \in [K]$, we have $\sum_{r=1}^m (m_{j,r})_k = \Theta(pm)$, which implies that $|S^j_{\mathrm{signal}}| = \Theta(pm)$ for all $j \in [K]$.
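The concentration claim above can be checked numerically with a short simulation. The sketch below assumes each mask entry is an independent Bernoulli($p$) draw with $p$ the keep probability, and it tracks a single mask coordinate per class; the function name and the particular values of `m`, `K`, and `p` are illustrative choices, not taken from the paper.

```python
import random


def signal_counts(m, K, p, rng):
    """For each of K classes, count how many of m neurons keep one fixed
    coordinate, where every mask entry is an independent Bernoulli(p) draw."""
    return [sum(1 for _ in range(m) if rng.random() < p) for _ in range(K)]


rng = random.Random(0)
m, K, p = 2000, 5, 0.5
counts = signal_counts(m, K, p, rng)
# Ratios close to 1 mean the count is Theta(p * m) with a constant near 1.
ratios = [c / (p * m) for c in counts]
print(ratios)
```

With these values the standard deviation of each count is about 22 around a mean of 1000, so every ratio should sit close to 1.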

Theorem D.26 (Formal Restatement of Generalization Part of Theorem 3.1). Under the same assumptions as Theorem D.1, within O

have $\exp(F_{j'}(x) - F_j(x)) \le 2K\epsilon e^2 = O(K\epsilon)$. Next we bound the term $I_2$.

$$
\ell(F(W^\star, x)) = \log\Big( 1 + \sum_{j' \ne y} \exp(F_{j'}(x) - F_y(x)) \Big) \le \sum_{j' \ne y} \log(1 + \exp(F_{j'}(x)))
$$

E.1 INITIALIZATION

Lemma E.2. When $m = \mathrm{poly}\log d$ and $p = \Theta(\frac{1}{Km \log d})$, with probability $1 - O(1/\log d)$, for every class $j \in [K]$ we have $|S^j_{\mathrm{signal}}| = 0$.
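To get a feel for the probability in Lemma E.2, the following sketch evaluates $(1 - p)^{Km}$ directly and compares it with the elementary bound $1 - pKm$ (Bernoulli's inequality). It assumes the hidden constant in $p = \Theta(1/(Km \log d))$ equals 1, which is an illustrative choice only; the function names and the sample values of `K`, `m`, and `d` are likewise hypothetical.

```python
import math


def prob_all_signal_pruned(p, K, m):
    """Probability that K * m independent Bernoulli(p) mask entries are all 0."""
    return (1.0 - p) ** (K * m)


def bernoulli_lower_bound(p, K, m):
    """Bernoulli's inequality: (1 - p)^(K m) >= 1 - p K m for 0 <= p <= 1."""
    return 1.0 - p * K * m


K, m = 10, 200
for d in (10**3, 10**6, 10**12):
    p = 1.0 / (K * m * math.log(d))  # assumed constant 1 in the Theta(.)
    exact = prob_all_signal_pruned(p, K, m)
    bound = bernoulli_lower_bound(p, K, m)  # equals 1 - 1/log(d) here
    print(d, exact, bound)
```

Under this choice of $p$ the lower bound is exactly $1 - 1/\log d$, so the failure probability shrinks only logarithmically in $d$.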

,i | = O(σ 0 σ n √ pd). • max j,r,k |γ (t) j,r,k | ≤ O(σ 0 µ).Proof. First of all, recall that from Definition C.1 we have for j= y i w , ξ i -O(n log 1/q T ⋆ log d pd ) -O(nσ 0 σ n pd log d pd ) .Since max j=yi,r w(0) j,r , ξ i ≥ Ω(σ 0 σ n √ pd), we have B (0) i ≥ Ω(σ 0 σ n pd) -O(n log 1/q T ⋆ log d pd ) -O(nσ 0 σ n pd log d pd )≥ Ω(σ 0 σ n pd).

Based on the result from the feature growing phase, $\|W^{(T_1)} - W^\star\|_F \le O(m^2 n^{1/2} \log(1/\epsilon) \sigma_n^{-1} (pd)^{-1/2})$.


example $(x, y)$. Taking a union bound over $r$, with probability at least $1 - d^{-1}$, we have $\langle w^{(t)}_{y,r}, \xi \rangle = O(\sigma_0 \sigma_n \sqrt{d} + n (pd)^{-1/2})$ for all $r \in [m]$. Then,

Shiwei Liu, Decebal Constantin Mocanu, Amarsagar Reddy Ramapuram Matavalam, Yulong Pei, and Mykola Pechenizkiy. Sparse evolutionary deep learning with over one million artificial neurons on commodity hardware. Neural Computing and Applications, 33(7):2589-2604, 2021c.

Shiwei Liu, Lu Yin, Decebal Constantin Mocanu, and Mykola Pechenizkiy. Do we actually need dense over-parameterization? In-time over-parameterization in sparse training. In International Conference on Machine Learning, pp. 6989-7000. PMLR, 2021d.

Tianlin Liu and Friedemann Zenke. Finding trainable sparse networks through neural tangent transfer. In International Conference on Machine Learning, pp. 6336-6347. PMLR, 2020.

Jian-Hao Luo and Jianxin Wu. An entropy-based pruning method for CNN compression. arXiv preprint arXiv:1706.05791, 2017.

Eran Malach, Gilad Yehudai, Shai Shalev-Shwartz, and Ohad Shamir. Proving the lottery ticket hypothesis: Pruning is all you need. In International Conference on Machine Learning, pp. 6682-6691. PMLR, 2020.

Advances in Neural Information Processing Systems, 33:2599-2610, 2020.

Alexandra Peste, Eugenia Iofinova, Adrian Vladu, and Dan Alistarh. AC/DC: Alternating compressed/decompressed training of deep neural networks. Advances in Neural Information Processing Systems, 34, 2021.

Ameya Prabhu, Girish Varma, and Anoop Namboodiri. Deep expander networks: Efficient deep networks from graph theory. In Proceedings of the European Conference on Computer Vision (ECCV), pp. 20-35, 2018.

Vivek Ramanujan, Mitchell Wortsman, Aniruddha Kembhavi, Ali Farhadi, and Mohammad Rastegari. What's hidden in a randomly weighted neural network? In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 11893-11902, 2020.

Difan Zou and Quanquan Gu. An improved analysis of training over-parameterized deep neural networks. Advances in Neural Information Processing Systems, 32, 2019.

Difan Zou, Yuan Cao, Dongruo Zhou, and Quanquan Gu. Gradient descent optimizes overparameterized deep ReLU networks. Machine Learning, 109(3):467-492, 2020.

Difan Zou, Yuan Cao, Yuanzhi Li, and Quanquan Gu. Understanding the generalization of Adam in learning neural networks with proper regularization. arXiv preprint arXiv:2108.11371, 2021.

A EXPERIMENT DETAILS

The MLP, VGG, and ResNet-32 experiments are run on NVIDIA A5000, and the ResNet-50 and ResNet-20-128 experiments are run on 4 NVIDIA V100s. We list the hyperparameters we used in training. All of our models are trained with SGD, and the detailed settings are summarized below.

Summary of architectures, dataset and training hyperparameters

We plot the experiment result of ResNet-20-128 in Figure 4. This figure further verifies our result that there exists a pruning rate threshold such that the testing performance of the pruned network is on par with the testing performance of the dense model while the training accuracy remains perfect.

First recall from Condition 2.2 that m, n = poly(log d) and µ


Proof. By Lemma D.5, we have ξ j,r,i , w ⋆ j,r = Θ(m log(1/ϵ)) and by Lemma E.3 for j = y i , max r w (t) j,r , ξ i ≥ max r ζ j,r,i -max r w (0) j,r , ξ i -O(n log d d α) ≥ Θ(m -1/q ). Then we have ≥ Θ(m 1/q log(1/ϵ)).

Lemma E.7. For T 1 ≤ t ≤ T ⋆ and j ̸ = y i , we have

Proof. We first compute w ⋆ j,r , where the inequality follows from Lemma D.5 and Lemma D.15. Thus, we have where the last inequality follows from our choice of σ 0 ≤ O(m -1/q µ -1 ).

Lemma E.8. Under the same assumption as Theorem E.1, we have

Proof. To simplify our notation, we define The proof is exactly the same as the proof of Lemma D.23. where the first inequality follows from the convexity of the cross-entropy loss with softmax, the second inequality follows from Lemma D.20, the third inequality follows because log(1 + x) ≤ x, and the last inequality follows from Lemma D.19 for some constant C > 0.

Lemma E.9 (Formal Restatement of Lemma 4.4). Under the same assumption as Theorem E.1, ). Then for any time t during this stage we have max j,r |ω

Proof. We have Taking a telescopic sum from T 1 to t yields Cη . Combining Lemma E.4, we have ).

E.4 GENERALIZATION ANALYSIS

Theorem E.10 (Formal Restatement of the Generalization Part of Theorem 4.1). Under the same assumption as Theorem E.1, within $O(\eta^{-1} n \sigma_0^{q-2} \sigma_n^{-q} (pd)^{-q/2} + \eta^{-1} \epsilon^{-1} m^4 n \sigma_n^{-2} (pd)^{-1})$ iterations, we can find $W^{(T)}$ such that $L_S(W^{(T)}) \le \epsilon$ and $L_D(W^{(T)}) \ge \Omega(\log K)$.
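The $\Omega(\log K)$ scale of the test loss can be illustrated with a small computation. The sketch below takes the per-sample bound $\log(1 + (K - 1)e^{-1})$ that appears in the proof and reads it, as an assumption on our part, as the cross-entropy loss incurred when every wrong-class output trails the true-class output by exactly 1. It then checks that this quantity is at least $\log K - 1$. The function names are illustrative.

```python
import math


def cross_entropy(logits, y):
    """Softmax cross-entropy loss, computed with a stable log-sum-exp."""
    top = max(logits)
    lse = top + math.log(sum(math.exp(v - top) for v in logits))
    return lse - logits[y]


def margin_one_bound(K):
    """log(1 + (K - 1)/e): loss when each wrong logit trails the true one by 1."""
    return math.log1p((K - 1) / math.e)


for K in (2, 10, 100, 1000):
    logits = [1.0] + [0.0] * (K - 1)  # true class 0 leads by a margin of 1
    print(K, cross_entropy(logits, 0), margin_one_bound(K), math.log(K) - 1.0)
```

Since $1 + (K - 1)/e \ge K/e$, the bound is never smaller than $\log K - 1$, so it grows at the same rate as the loss of uniform random guessing, $\log K$.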

