OVER-PARAMETERIZED MODEL OPTIMIZATION WITH POLYAK-ŁOJASIEWICZ CONDITION

Abstract

This work pursues the optimization of over-parameterized deep models for superior training efficiency and test performance. We first theoretically emphasize the importance of two properties of over-parameterized models, namely the convergence gap and the generalization gap. Subsequent analyses unveil that these two gaps can be upper-bounded by the ratio of the Lipschitz constant to the Polyak-Łojasiewicz (PL) constant, a crucial term abbreviated as the condition number. These discoveries lead to a structured pruning method with a novel pruning criterion: we devise a gating network that dynamically detects and masks out poorly-behaved nodes of a deep model during training. To this end, the gating network is learned by minimizing the condition number of the target model, a process that can be implemented as an extra regularization loss term. Experimental studies demonstrate that the proposed method outperforms the baselines in terms of both training efficiency and test performance, exhibiting the potential to generalize to a variety of deep network architectures and tasks.

1. INTRODUCTION

Most practical deep models are over-parameterized, with the model size exceeding the training sample size, and can perfectly fit all training points (Du et al., 2018; Vaswani et al., 2019). Recent empirical and theoretical studies demonstrate that over-parameterization plays an essential role in model optimization and generalization (Liu et al., 2021b; Allen-Zhu et al., 2019). Indeed, a plethora of state-of-the-art models that are prevalent in the community are over-parameterized, such as Transformer-based models for natural language modeling tasks (Brown et al., 2020; Devlin et al., 2018; Liu et al., 2019) and wide residual networks for computer vision tasks (Zagoruyko & Komodakis, 2016). However, training over-parameterized models is usually time-consuming and can take anywhere from hours to weeks to complete. Notwithstanding some prior works (Liu et al., 2022; Belkin, 2021) on theoretical analyses of over-parameterized models, those findings remain siloed from the common practices of training those networks. This work seeks to optimize over-parameterized models, in pursuit of superior training efficiency and generalization capability. We first analyze two key theoretical properties of over-parameterized models, namely the convergence gap and the generalization gap, which can be quantified by the convergence rate and the sample complexity, respectively. Theoretical analysis of over-parameterized models is intrinsically challenging, as the over-parameterized optimization landscape is often nonconvex, limiting convexity-based analysis.
Inspired by recent research on the convergence analysis of neural networks and other non-linear systems (Bassily et al., 2018; Gupta et al., 2021; Oymak & Soltanolkotabi, 2019), we propose to use the Polyak-Łojasiewicz (PL) condition (Polyak, 1963; Karimi et al., 2016; Liu et al., 2022) as the primary mathematical tool to analyze the convergence rate and sample complexity of over-parameterized models, along with the widely used Lipschitz condition (Allen-Zhu et al., 2019). Our theoretical analysis shows that the aforementioned properties can be controlled by the ratio of the Lipschitz constant to the PL constant, which is referred to as the condition number (Gupta et al., 2021). A small condition number indicates a large decrease in training loss after parameter updates and high algorithmic stability relative to data perturbation, i.e., fast convergence and good generalization ability. More promisingly, such a pattern can be observed in empirical studies. As shown in Figure 1(a-c), where BERT models were applied to WikiText-2 for language modeling, the training loss of the model with a small condition number (T2) decreases much faster than that of the model with a large condition number (T1), especially when the differences in condition numbers are pronounced (between 40 and 80 epochs); its test performance also improves much faster and is ultimately better.

[Figure 1: Benefits of PL regularization for BERT optimization. Panels (a-c) show that models with a smaller condition number (e.g., T2 < T1 in general) achieve faster training convergence and better test performance. In addition, pruning heads with a large condition number, i.e., Masked T, reduces the condition number, leading to more rapid and accurate training. Panel (d) shows that the heads of T2 have different condition numbers; the heads with the largest condition numbers are pruned to produce Masked T in panels (a-c).]
Such theoretical and empirical findings motivate us to formulate a novel regularized optimization problem which adds the minimization of the condition number to the objective function; we call this new additional term PL regularization. In this way, we can directly regularize the condition number while training over-parameterized models, thereby improving their convergence speed and generalization performance. Our empirical analysis further reveals that, given an over-parameterized model, different model components exhibit distinct contributions to model optimization. Figure 1(d) plots the heatmap of the condition numbers of all model heads at epoch 10 and shows that the condition number varies considerably across heads. Given the fact that over-parameterized models contain a large number of redundant parameters, we argue that it is possible to reduce the condition number of an over-parameterized network during training by identifying and masking out poorly-behaved sub-networks with large condition numbers. Figure 1(a-c) illustrates the potential efficacy of this approach. After disabling 75% of the heads of the BERT model T2 according to the condition number ranking at epoch 10, the masked BERT (Masked T) possesses a smaller condition number and achieves faster convergence and better test performance. This phenomenon motivates us to impose PL regularization, and hence improve model optimization, by adopting a pruning approach. More specifically, we introduce a binary mask for periodically sparsifying parameters, and the mask is learned via a gating network whose input summarizes the optimization dynamics of sub-networks in terms of PL regularization. An overview of the proposed method is provided in Appendix E.1. The proposed pruning approach to enforcing PL regularization is related to structured pruning works, which focus on compressing model size while maintaining model accuracy.
The significant difference lies in that we utilize the condition number to identify the important components, which can thus simultaneously guarantee the convergence and generalization of model training. More importantly, as a consequence of this difference, compared with most pruning works, which obtain a sparse model at a slight cost of degraded accuracy, our method achieves even better test performance than the dense model when up to 75% of the parameters are pruned. Experimental results demonstrate that our method outperforms state-of-the-art pruning methods in terms of training efficiency and test performance. The contributions of this work are threefold:
• We are the first to propose using PL regularization in the training objective function of over-parameterized models. This proposal is founded on a theoretical analysis of the optimization and generalization properties of over-parameterized models, which shows that a small condition number implies fast convergence and good generalization.
• We describe a PL regularization-driven structured pruning method for over-parameterized model optimization. Specifically, we introduce a learnable mask, guided by the PL-based condition number, to dynamically sparsify poorly-behaved sub-networks during model training to optimize training efficiency and test performance.
• The proposed analysis framework and optimization method can be applied to a wide range of deep models and tasks. Experimental studies using three widely used deep models, namely BERT, Switch-Transformer, and VGG-16, demonstrate that the proposed method outperforms the baselines on both training efficiency and test performance.

2. ANALYSIS OF OVER-PARAMETERIZED MODELS WITH PL CONDITION

In this section, we first show that, in the context of over-parameterized models, the optimal test error can be decomposed into a convergence gap and a generalization gap, which can be quantified via the convergence rate and sample complexity, respectively. Next, using the PL* condition, we show that both the convergence rate and sample complexity are upper-bounded by the condition number and can thus be improved by keeping the condition number small.

2.1. OPTIMIZATION PROPERTIES OF OVER-PARAMETERIZED MODELS

Let X denote the input space and Y the target space. The standard optimization problem can be formulated as finding a minimizer of the expected risk: f* = argmin_{f ∈ Y^X} R(f), where the expected risk is defined as R(f) := E_{(x,y)∼D}[L(f(x), y)], D denotes an unknown data distribution over X × Y, and L denotes a non-negative loss function. To restrict the search space, f is typically chosen from a hypothesis class H ⊂ Y^X; the minimizer in this case is given by f*_H = argmin_{f ∈ H} R(f). In practice, the data distribution is unknown and thus the minimizer is found by minimizing the empirical risk: f*_{H,S} = argmin_{f ∈ H} R_S(f), where the empirical risk is defined as R_S(f) := (1/n) Σ_{i=1}^n L(f(x_i), y_i), and S = {(x_i, y_i)}_{i=1}^n denotes the training set drawn i.i.d. from D. The empirical risk minimization problem is often solved using algorithms such as gradient descent; the solution found by these algorithms is denoted by f̂_{H,S}. Finally, when f is parameterized by w ∈ R^m, we denote the loss function by L(x, y; w), or simply L(w), and use L_S(w) to denote the empirical risk.

Definition 1 (Over-parameterized model (Liu et al., 2020a)). A model is an over-parameterized model if the number of parameters m is larger than the number of data points.

Motivated by previous work (Gühring et al., 2020), the Bayes error, i.e., the optimal test error, can be decomposed into five components: empirical error, convergence gap, generalization gap, estimation error, and approximation error. In the over-parameterized setting, following previous works (Brutzkus et al., 2017; Belkin, 2021; Liu et al., 2022), we assume that the optimal function f*_{H,S} fits or interpolates the training data exactly. Therefore, for over-parameterized models, the Bayes error decomposition can be simplified as R(f*) ≤ 2 |R_S(f̂_{H,S}) − R_S(f*_{H,S})| + 3 sup_{f ∈ H} |R(f) − R_S(f)|.
The first and second terms refer to the convergence gap and the generalization gap, respectively. (The detailed decomposition can be found in Appendix A.) Thus, the Bayes error can be better approximated if we restrict the convergence gap and the generalization gap. Furthermore, the convergence gap can be quantified in proportion to the reciprocal of the convergence rate, and the number of instances required to attain a specific generalization gap can be quantified in proportion to the sample complexity.

Definition 2 (Convergence Rate (Alzubi et al., 2018)). If the model sequence f(w_0), ..., f(w_t) generated during optimization converges to the optimal solution f(w*), the convergence rate is defined as lim_{t→∞} |f(w_{t+1}) − f(w*)| / |f(w_t) − f(w*)| = γ, where γ is a real number in the range [0, 1]. The convergence rate is measured by the limit of the ratio of successive differences. The sequence converges sublinearly, linearly, or superlinearly depending on whether γ = 1, γ ∈ (0, 1), or γ = 0, respectively. For an algorithm with a linear rate of convergence, R_S(f(w_t)) ≤ exp(−tγ) R_S(f(w_0)).

Definition 3 (Sample Complexity (Alzubi et al., 2018)). For any ε, δ ∈ (0, 1), the sample complexity n(ε, δ) is defined as the smallest n < ∞ for which there exists a learning algorithm f such that, for any distribution D generating a training dataset S with data samples x_1, ..., x_n of size n, Pr(|R(f_S) − R_S(f_S)| ≤ ε) ≥ 1 − δ. The sample complexity denotes the number of examples required to guarantee a generalization gap smaller than ε with probability 1 − δ.
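The limit in Definition 2 can be approximated numerically from a recorded sequence of function values. A minimal sketch in Python, assuming the optimal value is known (here 0) and approximating the limit by the last observed ratio:

```python
def convergence_rate(losses, optimum=0.0):
    """Estimate gamma = lim |f(w_{t+1}) - f*| / |f(w_t) - f*| from a loss sequence."""
    gaps = [abs(l - optimum) for l in losses]
    ratios = [gaps[t + 1] / gaps[t] for t in range(len(gaps) - 1) if gaps[t] > 0]
    return ratios[-1]  # the limit is approximated by the last observed ratio

# A linearly convergent sequence f(w_t) - f* = 0.5**t has gamma = 0.5.
losses = [0.5 ** t for t in range(20)]
print(convergence_rate(losses))  # -> 0.5
```

A sublinear sequence such as 1/t would instead yield ratios approaching γ = 1, matching the classification in Definition 2.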

2.2. PL * CONDITION AND THEORETICAL PROPERTIES

The previous subsection explained that faster convergence and better generalization are crucial to the test performance of over-parameterized models and that these measures can be quantified by the convergence rate and sample complexity, respectively. In this subsection, we show that both quantities can be analyzed through the PL* condition and Lipschitz continuity; the results suggest that encouraging an over-parameterized model to have a smaller condition number will generally result in faster convergence and better generalization.

Definition 4 (Lipschitz Continuity (Nesterov, 2003)). The function f is L_f-Lipschitz with respect to parameters w if there exists a constant L_f such that |f(w_2) − f(w_1)| ≤ L_f ∥w_1 − w_2∥ for all w_1, w_2.

Definition 5 (PL* Condition (Liu et al., 2022)). The function f satisfies the µ-PL* condition on a set W if there exists a constant µ > 0 such that, for all w ∈ W, (1/2) ∥∇f(w)∥² ≥ µ f(w).

The inequality in the PL* condition implies that every stationary point is a global minimizer, which is also a property of strong convexity. However, unlike strong convexity, the PL* condition does not assume the uniqueness of the minimizer (Bassily et al., 2018). This difference makes PL* a more appropriate mathematical framework for analyzing the theoretical properties of over-parameterization, as the optimization of over-parameterized models results in manifolds of global minima. Following the original arguments of Liu et al. (2022) and Charles & Papailiopoulos (2018), we use the PL* condition and Lipschitz continuity to analyze the convergence rate and sample complexity of over-parameterized models trained using gradient descent (GD).

Theorem 6. Suppose the loss function L(w) and its first derivative are both L_f-Lipschitz continuous and the loss function satisfies the µ-PL* condition in the ball B(w_0, R) := {w ∈ R^m : ∥w − w_0∥ ≤ R} with R > 0. Then we have the following: (a) Convergence Rate of GD (Karimi et al., 2016): Gradient descent with step size η = 1/L_f converges to a global solution at an exponential rate: L_S(w_t) ≤ (1 − µ/L_f)^t L_S(w_0).
(b) Sample Complexity: Suppose an algorithm A has pointwise hypothesis stability γ with respect to a bounded loss function such that 0 ≤ L(w) ≤ M, and the model f parameterized by w is the output of algorithm A trained on dataset S. Then, for any ε, δ, we have Pr(|R(f_S) − R_S(f_S)| ≤ ε) ≥ 1 − δ with sample complexity n(ε, δ) ≤ 6 L_f² M / (µ ε² δ) + M² / (2 ε² δ).

The proof of the theorem is deferred to Appendix B. The theorem indicates that 1) a smaller L_f/µ results in a faster convergence rate and decreases the convergence gap; and 2) a smaller L_f/µ results in a smaller sample complexity, i.e., fewer samples are required to obtain a generalization gap smaller than ε with probability higher than 1 − δ. In other words, the theorem suggests that the convergence rate and generalization ability of an over-parameterized model can be improved by encouraging the model to have a smaller L_f/µ, which is referred to as the condition number in (Nesterov, 2003; Gupta et al., 2021). Following previous work, we refer to the ratio of the Lipschitz constant L_f to the PL constant µ as the condition number.

Definition 7 (Condition Number (Gupta et al., 2021)). Suppose the loss function L satisfies the PL* condition and has L_f-Lipschitz continuous gradients; the condition number is defined as L_f/µ.
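The rate in Theorem 6(a) can be checked numerically on a toy problem. The sketch below uses an illustrative quadratic (not the paper's setup): its gradient is L_f-Lipschitz, it satisfies the µ-PL* condition, and gradient descent with step size η = 1/L_f obeys the bound L_S(w_t) ≤ (1 − µ/L_f)^t L_S(w_0):

```python
import numpy as np

# Quadratic loss L(w) = 0.5 * w^T A w with eigenvalues mu = 1 and L_f = 10.
A = np.diag([1.0, 10.0])
mu, L_f = 1.0, 10.0

def loss(w):
    return 0.5 * w @ A @ w

def grad(w):
    return A @ w

w = np.array([1.0, 1.0])
L0 = loss(w)
for t in range(1, 51):
    w = w - grad(w) / L_f                                # GD with eta = 1/L_f
    assert loss(w) <= (1 - mu / L_f) ** t * L0 + 1e-12   # Theorem 6(a) bound
print("convergence bound holds; condition number L_f/mu =", L_f / mu)
```

Shrinking the condition number (e.g., setting the largest eigenvalue closer to µ) tightens the factor (1 − µ/L_f) and visibly speeds up the loss decrease, which is exactly the mechanism the theorem attributes to a small L_f/µ.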

3. MODEL OPTIMIZATION WITH PL CONDITION

Based on the analysis in the previous section, we advocate optimizing the condition number L_f/µ of over-parameterized models by adding it as a regularization term to the optimization problem; we term this technique PL regularization. In this section, we first introduce the analytic form of PL regularization. We then propose a PL regularization-driven structured pruning method.

3.1. FORMULATION OF PL REGULARIZED OPTIMIZATION

To minimize the condition number, we propose the following regularized risk minimization problem, which adds the condition number to the training error:

min_{w ∈ W} L_S(w) + α L_f/µ,   (4)

where µ is the PL constant of the neural network with respect to parameters w; L_f is the Lipschitz constant with respect to parameters w; L_S(w) is the training error of the neural network with parameters w; and α is a trade-off parameter. While it is desirable to train a network with a small condition number, directly optimizing µ is often difficult. Following the uniform conditioning developed in previous works (Liu et al., 2022; Scaman et al., 2022), for an over-parameterized model f with a Lipschitz continuous loss function, e.g., the mean squared error loss, the analysis of µ reduces to calculating and optimizing the minimal eigenvalue of the neural tangent kernel (NTK) associated with f. Denote the NTK matrix of the function f as K(w) = DF(w) DF(w)^T ∈ R^{n×n}, where DF(w) represents the differential map of f at w; denote the smallest eigenvalue of K(w) by λ_min(K(w)); and let µ* = inf_{w ∈ W} λ_min(K(w)) based on uniform conditioning (Liu et al., 2022; 2020a; Scaman et al., 2022). Consequently, the objective function (4) is represented as:

min_{w ∈ W} L_S(w) + α L_f/µ*.   (5)

Following previous works (Liu et al., 2022; 2020b), the following proposition shows that the condition number is upper-bounded in terms of the Hessian norm ∥H_f∥_2, which is defined as ∥H_f∥_2 = max_i ∥H_{f_i}∥_2, where H_{f_i} = ∂²f_i/∂w²; ∥H_f∥_2 is termed the Hessian norm of f hereinafter.

Proposition 8. Given a point w_0, let B(w_0, R) denote the ball centered at w_0 with radius R > 0. Suppose the loss function satisfies the µ*-PL* condition and has L_f-Lipschitz gradients for all w ∈ B(w_0, R). Then the condition number is upper-bounded by

L_f/µ* ≤ (∥DF(w_0)∥_2 + sup_{w ∈ B} √n R ∥H_f(w)∥_2) / inf_{w ∈ B} λ_min(K(w)).

The proof can be found in Appendix C.
Proposition 8 motivates us to optimize this upper bound of the condition number for PL-regularized optimization.
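The quantity µ* = inf_w λ_min(K(w)) can be probed empirically for a small network by forming the NTK matrix K(w) = DF(w) DF(w)^T explicitly. A hedged sketch with a finite-difference Jacobian on a toy two-layer network (the network, sizes, and names are illustrative, not the paper's implementation; in practice DF would come from autodiff):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, h = 4, 3, 5                        # n data points, d inputs, h hidden units
X = rng.normal(size=(n, d))
w = rng.normal(size=(d * h + h,)) * 0.5  # m = 20 parameters > n (over-parameterized)

def f(wvec):
    """Toy two-layer network; returns outputs on all n points."""
    W1 = wvec[:d * h].reshape(d, h)
    w2 = wvec[d * h:]
    return np.tanh(X @ W1) @ w2

def jacobian(wvec, eps=1e-6):
    """Finite-difference differential map DF(w), shape (n, m)."""
    J = np.zeros((n, wvec.size))
    for j in range(wvec.size):
        e = np.zeros_like(wvec)
        e[j] = eps
        J[:, j] = (f(wvec + e) - f(wvec - e)) / (2 * eps)
    return J

K = jacobian(w) @ jacobian(w).T          # NTK matrix K(w), n x n and PSD
lam_min = float(np.linalg.eigvalsh(K)[0])
print("lambda_min(K(w)) =", lam_min)     # positive when DF has full row rank
```

Because m > n here, K(w) is generically positive definite, which is the regime in which the PL* constant µ* is bounded away from zero.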

3.2. IMPLEMENTATION OF PL-REGULARIZED OPTIMIZATION VIA PRUNING

In this subsection, we adopt a pruning approach governed by the Hessian norm to solve the new regularized risk minimization problem.

3.2.1. LEARNING OBJECTIVE OF PL REGULARIZATION-DRIVEN PRUNING ALGORITHM

First, we propose the following theorem to show that pruning parameters under the guidance of minimizing the Hessian norm ∥H_f∥_2 decreases the upper bound of the condition number.

Theorem 9. Suppose the neural network f parameterized by w (w ∈ B(w_0, R)) and its loss function L satisfy the same assumptions as in Proposition 8. Then pruning parameters from the parameter set under the guidance of minimizing the Hessian norm ∥H_f∥_2 decreases the upper bound of the condition number.

Proof sketch. Based on the Taylor expansion around the initialization w_0, we show that pruning parameters governed by minimizing the Hessian norm ∥H_f∥_2 minimizes the upper bound of the Lipschitz constant (according to Proposition 8); and, based on matrix theory, pruning parameters maximizes the minimal eigenvalue of the NTK (Proposition 14); details can be found in Appendix D.

To disable unnecessary or harmful parameters, we introduce a binary pruning mask to implement structured pruning, i.e., each element in the mask controls a subset of parameters. This design is particularly suitable for parameters with inherent structure, e.g., parameters from the same convolutional filter in a CNN or the same head in a Transformer. Suppose the parameters w consist of p subsets and let w_[i] denote the i-th subset. Then the mask m, a Boolean vector of length p, sparsifies the parameters as follows: w_[i] is disabled, i.e., all elements of w_[i] are set to 0, if m_i = 0; otherwise, w_[i] remains unchanged. With the introduction of the pruning mask, Eq. 5 can be transformed into the following optimization problem, which additionally learns the binary mask to prune the network:

min_{w,m} L_prune = L_S(m * w) + α ∥H_f(m * w)∥_2,   (6)

where * denotes the Khatri-Rao product, i.e., m * w = [m_1 w_[1]; ...; m_p w_[p]]^T, and ∥H_f(m * w)∥_2 denotes the Hessian norm w.r.t. the pruned model.
In practice, we optimize the Hessian norm by minimizing the trace of the Hessian, which can be efficiently approximated; details can be found in Appendix E.2. We further conduct an empirical analysis showing that pruning parameters by penalizing the trace of the Hessian leads to a decrease in the condition number; the result is included in Figure 5 of Appendix D. Directly solving Eq. 6 for m is difficult, so we adopt a gating network to learn a binary mask and perform structured pruning. The mask is defined via a differentiable gating function: m = g(d_w; v), where d_w and v are the input and parameters of the function g; d_w, dubbed the parameter features, encodes the optimization dynamics characterized by the PL* condition. In the following two subsections, we explain the details of the parameter features d_w and the gating function g. Algorithm 2 summarizes the proposed pruning algorithm.
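The Hessian trace mentioned above is commonly approximated with Hutchinson's stochastic estimator, tr(H) ≈ (1/s) Σ_i v_i^T H v_i with Rademacher probes v_i; this is one standard choice, not necessarily the exact procedure of Appendix E.2. A self-contained sketch, using an explicit toy Hessian in place of the Hessian-vector products an autodiff framework would supply via double backprop:

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy symmetric Hessian with true trace 1 + 2 + 3 + 4 = 10.
H = np.array([[1.0, 0.1, 0.0, 0.0],
              [0.1, 2.0, 0.0, 0.0],
              [0.0, 0.0, 3.0, 0.1],
              [0.0, 0.0, 0.1, 4.0]])

def hvp(v):
    """Hessian-vector product oracle; in practice, computed by double backprop."""
    return H @ v

def hutchinson_trace(hvp, dim, n_samples=2000, rng=rng):
    total = 0.0
    for _ in range(n_samples):
        v = rng.choice([-1.0, 1.0], size=dim)  # Rademacher probe, E[v v^T] = I
        total += v @ hvp(v)
    return total / n_samples

est = hutchinson_trace(hvp, 4)
print(round(est, 2))  # close to the true trace, 10
```

Only matrix-vector products are needed, so the estimator scales to network-sized Hessians where forming H explicitly is infeasible.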

3.2.2. PARAMETER FEATURES

d_w[i] ∈ R² is designed to encode two pieces of information about the optimization dynamics of each subset w_[i]. According to the discussion in Section 3.1, the PL* condition at w can be analyzed through the minimum eigenvalue of the corresponding NTK. Therefore, the first parameter feature of w_[i] is chosen as λ_min(K(w_[i])). The second parameter feature is empirically set to the entropy of the eigenvalue distribution of the corresponding tangent kernel, which provides a summary measure of the distribution envelope. Given the NTK matrix K(w_[i]), the entropy, denoted by ρ, is defined as ρ(K(w_[i])) = − Σ_{k=1}^n λ̂_k log λ̂_k, where λ̂_k = λ_k / Σ_j λ_j is the k-th normalized eigenvalue of K(w_[i]). ρ is maximal when the eigenvalues are all equal and small when a single eigenvalue is much larger than all others.
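Both parameter features can be computed directly from the eigenvalues of the subset's tangent kernel. A small sketch (the function name and the toy kernels are illustrative):

```python
import numpy as np

def ntk_features(K):
    """Return (lambda_min, eigen-entropy rho) for an NTK matrix K."""
    lam = np.linalg.eigvalsh(K)
    lam_min = float(lam[0])            # first feature: smallest eigenvalue
    p = np.clip(lam, 1e-12, None)      # guard against tiny negative round-off
    p = p / p.sum()                    # normalized eigenvalues lambda_hat
    rho = float(-(p * np.log(p)).sum())  # second feature: eigen-entropy
    return lam_min, rho

# Equal eigenvalues -> maximal entropy log(n); one dominant eigenvalue -> near 0.
_, rho_flat = ntk_features(np.eye(4))
_, rho_peak = ntk_features(np.diag([100.0, 1e-6, 1e-6, 1e-6]))
print(rho_flat)             # log(4) ~ 1.386
print(rho_flat > rho_peak)  # -> True
```

This matches the stated behavior of ρ: a flat eigenvalue spectrum maximizes the entropy, while a single dominant eigenvalue drives it toward zero.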

3.2.3. GATING NETWORK

The gating network is implemented by a two-layer feedforward network (FN) with parameters v, which outputs a score between 0 and 1 indicating the importance of each w_[i]:

g(d_w; v) = softmax([FN_v(d_w[1]); ...; FN_v(d_w[p])]) ∈ [0, 1]^p.

The parameter v is updated according to the pre-defined pruning times and pruning ratios {t_j, r_j}_{j=1}^q. For example, under the one-shot pruning schedule, i.e., {t, r}, we update v once at epoch t to prune r% of the parameters; under the linear pruning schedule, i.e., {t_j, r_f/t_max}_{j=1}^{t_max}, where t_max denotes the maximum number of training iterations and r_f denotes the target pruning ratio, we update v at every epoch to prune (r_f/t_max)% of the parameters (Hoefler et al., 2021). Once the update of v is complete, we apply the TopK function to generate the binary mask m: m = TopK(g(d_w; v), k), where m_i = 1 if g(d_w; v)_i is among the top k elements of g(d_w; v), and m_i = 0 otherwise. At pruning time t_j, the hyperparameter k is set according to the pruning ratio: k = (1 − r_j) p. By utilizing such a learnable gating function, the mask can adaptively induce sparsity in the network until the pruned network produces good accuracy and fulfills the target pruning ratio. The final output m_{t_q} is used to produce the sparse over-parameterized model f(m_{t_q} * w).
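The softmax gating and TopK mask generation can be sketched as follows (a numpy stand-in; the raw scores stand in for the two-layer FN outputs and are arbitrary placeholders, not learned values):

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())            # numerically stable softmax over unit scores
    return e / e.sum()

def topk_mask(scores, k):
    """Binary mask m: 1 for the k largest gating scores, 0 elsewhere."""
    m = np.zeros_like(scores)
    m[np.argsort(scores)[-k:]] = 1.0
    return m

p, r = 8, 0.5                          # p prunable units, pruning ratio r
k = int((1 - r) * p)                   # keep k = (1 - r) * p = 4 units
raw = np.array([0.2, 1.5, -0.3, 0.9, 2.1, 0.1, -1.0, 0.7])  # illustrative FN_v outputs
g = softmax(raw)                       # g(d_w; v) in [0, 1]^p
m = topk_mask(g, k)
print(m)                               # 1s at the 4 highest-scoring units
# Masking: parameter subset w_[i] is zeroed wherever m_i = 0 (the m * w product).
```

Since softmax is monotone, TopK over the gated scores selects the same units as TopK over the raw FN outputs; the softmax matters for training v, not for the final mask.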

4. RELATED WORK

Over-parameterized Model. Deep models used in practice are often heavily over-parameterized (Zhang et al., 2021; Nakkiran et al., 2021). Recent progress has been made in the convergence analysis of over-parameterized models optimized by (stochastic) gradient descent algorithms (Du et al., 2018; Liu et al., 2020b; Oymak & Soltanolkotabi, 2019). In particular, the PL condition (Polyak, 1963) and its slight variant, the PL* condition (Liu et al., 2022), have attracted interest in connection with convergence guarantees. Some works prove that, for over-parameterized models with proper random initialization, gradient descent-based methods provably find a global minimum in polynomial time (Du et al., 2018; Belkin, 2021; Karimi et al., 2016; Liu et al., 2022). Furthermore, Charles & Papailiopoulos (2018) derive a generalization bound for models satisfying the PL condition via hypothesis stability (Bousquet & Elisseeff, 2002). Inspired by the success of this recent theoretical analysis, this work uses the PL* condition to analyze convergence and generalization ability, which is then used to guide the training of practical over-parameterized models.

Network Pruning. Network pruning aims to minimize the number of network parameters while maintaining model performance. Most existing structured pruning methods rely on approximations of the loss differences resulting from parameter perturbation. There are three classes of pruning techniques. The first uses zeroth-order information, such as the magnitudes of the weights and activation values of the running model, e.g., by adding an L1 or L2 penalty to the objective function (Han et al., 2015b;a; He et al., 2018) or by using the Lottery Ticket Hypothesis and its extensions (Frankle & Carbin, 2018; Frankle et al., 2019; Morcos et al., 2019; You et al., 2020; Yu et al., 2020; Liu et al., 2021b).
The second uses gradient-based first-order methods to approximate the loss change via a first-order Taylor expansion of the training loss (Liu & Wu, 2019; You et al., 2019; Mozer & Smolensky, 1988; Molchanov et al., 2016; Sanh et al., 2020). The third uses the Hessian or Fisher information of the loss to select weights for deletion (Hassibi et al., 1993; Hassibi, 1992; Cun et al., 1990; Molchanov et al., 2019; Liu et al., 2021a). Unlike structured pruning works, which are often accompanied by negative impacts on model generalization, the proposed pruning method based on PL regularization simultaneously improves convergence and generalization ability.

5. EXPERIMENTS

This section presents experiments on three over-parameterized models, namely BERT (Devlin et al., 2018) , Switch-Transformer (Fedus et al., 2021) , and VGG-16 (Simonyan & Zisserman, 2014) . Detailed ablation experiments can be found in Appendix F. Evaluation Measure. We evaluate the effectiveness of our algorithm for model optimization in terms of training efficiency and generalization ability. Training efficiency is measured by the training error of the model trained with fixed training iterations (or epochs) (Shazeer et al., 2017; Fedus et al., 2021) . Generalization ability is measured by the test performance (Liu et al., 2021b) . For fair comparisons, all models are evaluated with the same pruning ratio, which is computed as the number of disabled parameters divided by the parameter counts of the corresponding structure.

5.1. RESULTS ON BERT

This experiment focuses on the Multi-Head Attention (MHA) component of BERT-Base (Devlin et al., 2018), as MHA plays a critical role in the success of BERT. Following the basic setup of Devlin et al. (2018), we train BERT with the self-supervised masked language modeling (MLM) task on WikiText-2 (Merity et al., 2016). The baseline methods contain vanilla BERT (Wang et al., 2019), two Transformer-based pruning methods, namely BERT-LTH (Chen et al., 2020) and Att-Score (Michel et al., 2019), and two pruning approaches, namely SNIP (Lee et al., 2018) and GraSP (Wang et al., 2020). More experimental details can be found in Appendix E.3.

5.2. RESULTS ON SWITCH-TRANSFORMER

This experiment focuses on the Mixture-of-Experts (MoE) layers (Shazeer et al., 2017) of Switch-Transformer, which have shown excellent performance in NLP tasks recently. For comparison, the baseline methods include vanilla Switch-Transformer, Switch-LTH (Chen et al., 2020), Exp-Score (Michel et al., 2019), SNIP (Lee et al., 2018), and GraSP (Wang et al., 2020). All experiments are conducted on WikiText-103 (Merity et al., 2016) with the masked language modeling task. More experimental details can be found in Appendix E.3. The results indicate that while the sparsely-activated MoE architecture selects different experts for each data point, certain poorly-behaved experts also exist. By modeling the optimization dynamics that govern convergence and generalization ability, the proposed PL regularization-driven pruning method can disable under- or over-specialized experts and keep well-behaved experts. We also investigate the benefits of applying PL regularization to both the heads of MHA and the experts of MoE, as shown in the last row of Table 2. Compared with the vanilla Switch-Transformer, applying PL regularization to both heads and experts yields a 28% improvement in training perplexity after 70k iterations and a 13% improvement in test performance. This result further demonstrates that the proposed PL regularization is flexible and applicable to different architectures.

5.3. RESULTS ON VGG-16

This experiment focuses on VGG-16 (Simonyan & Zisserman, 2014) trained on the CIFAR-10 and CIFAR-100 datasets. The baselines include vanilla VGG-16 (Li et al., 2016), three filter pruning methods, namely FPGM (He et al., 2019), CHIP (Sui et al., 2021), and DPFPS (Ruan et al., 2021), and three general pruning approaches, namely LTH (Frankle & Carbin, 2018), SNIP (Lee et al., 2018), and GraSP (Wang et al., 2020). All models are trained from scratch and are pruned with a linear schedule: starting from epoch 15, we prune the same number of filters at each epoch until the target sparsity is reached. More details on the baselines and experimental settings can be found in Appendix E.4. Additional experimental results for ResNet-56 can be found in Appendix F.1. Table 3 shows the performance of all methods trained on CIFAR-10 and CIFAR-100 with a 50% pruning ratio. As we can see, our method outperforms all pruning baselines in terms of training efficiency and test accuracy on both datasets. In particular, compared with the vanilla VGG-16, our method achieves higher accuracy with fewer parameters, demonstrating the effectiveness of PL regularization implemented as structured pruning for VGG-16 optimization.

[Table 3: Train loss (@100 and @150 epochs) and test accuracy (@100, @150 epochs, and final) of all methods on CIFAR-10 and CIFAR-100 with a 50% pruning ratio.]

5.4. QUANTITATIVE ANALYSIS

[Figure 2: Heatmaps of the condition number of the experts retained by LTH, Exp-Score, SNIP, GraSP, and our method. The baselines fail to identify the well-behaved experts (the experts marked in red are well-behaved but missed by the baselines).]

This experiment aims to understand why the baselines perform differently in terms of training efficiency and test performance. Theorem 6 states that the condition number controls the convergence and generalization ability of a model. Thus, we take a trained Switch-Transformer as an example and compare the condition numbers of the important experts identified by different pruning criteria, including the weight magnitude of LTH, the gradient-based score of Exp-Score, the gradient value of SNIP, and the Hessian value of GraSP. Figure 2 presents heatmaps of the condition number and marks the important experts with squares. It shows that our method retains well-behaved experts with small condition numbers, while the baselines fail to retain these experts. The reason is that these baselines are biased toward the training loss, adversely affecting generalization. This observation confirms the importance of simultaneously considering both the convergence gap and the generalization gap for over-parameterized model optimization. Furthermore, the criteria of recent pruning works can be regarded as parameter features and fed into the proposed gating network, which may help optimize PL regularization by providing rich information about the parameters. We leave this for future work.

6. CONCLUSION

This paper establishes a principled connection between model optimization and the PL condition and advocates utilizing this theoretical finding when optimizing over-parameterized models. A new optimization problem is therefore formulated with the proposed PL regularization, and we illustrate that it can be solved via a structured pruning method. We postulate that the effect of such regularization may be achieved in other ways, e.g., by using more advanced optimization algorithms which induce implicit regularization. Moreover, experiments demonstrate the superiority of the proposed PL regularization-driven pruning algorithm, unveiling the potential of adopting theoretically founded measures as the pruning criteria.

A BAYES ERROR DECOMPOSITION

Motivated by the previous work (Gühring et al., 2020), the Bayes error, i.e., the optimal test error, can be decomposed as:

R(f^*) \le \underbrace{R_S(\hat{f}_{H,S})}_{\text{empirical error}} + \underbrace{|R_S(f^*_{H,S}) - R_S(\hat{f}_{H,S})|}_{\text{convergence gap}} + \underbrace{|R(f^*_{H,S}) - R_S(f^*_{H,S})|}_{\text{generalization gap}} + \underbrace{|R(f^*_{H}) - R(f^*_{H,S})|}_{\text{estimation error}} + \underbrace{|R(f^*_{H}) - R(f^*)|}_{\text{approximation error}}. \quad (10)

In the over-parameterized setting, we assume that the optimal function f^*_{H,S} fits or interpolates the training data exactly, i.e., R_S(f^*_{H,S}) = 0. Thus the empirical error equals the convergence gap. Moreover, the estimation error can be associated with the generalization gap:

|R(f^*_{H}) - R(f^*_{H,S})| \le 2 \sup_{f \in H} |R(f) - R_S(f)| \quad (11)

(Bousquet et al., 2003). Furthermore, as over-parameterization refers to manifolds of potential interpolating predictors, H for the over-parameterized model can be omitted, i.e., R(f^*_{H}) = R(f^*). Therefore, for over-parameterized models, Eq. 10 can be simplified as

R(f^*) \le 2\, \underbrace{|R_S(\hat{f}_{H,S}) - \underbrace{R_S(f^*_{H,S})}_{0}|}_{\text{convergence gap}} + 3 \underbrace{\sup_{f \in H} |R(f) - R_S(f)|}_{\text{generalization gap}}. \quad (12)

B PROOF OF THEOREM 6

B.1 CONVERGENCE RATE

The convergence rate of the over-parameterized model has been studied by many works (Liu et al., 2022; Karimi et al., 2016; Allen-Zhu et al., 2019). We assume that the loss function has an L_f-Lipschitz continuous gradient, i.e., for all parameters w and v:

L(w) \le L(v) + \langle \nabla L(v), w - v \rangle + \frac{L_f}{2} \|w - v\|^2. \quad (13)

Using this assumption, we have:

L_S(w_{t+1}) - L_S(w_t) \le \langle \nabla L_S(w_t), w_{t+1} - w_t \rangle + \frac{L_f}{2} \|w_{t+1} - w_t\|^2.

Given the gradient method with learning rate \eta = 1/L_f, i.e., w_{t+1} = w_t - \eta \nabla L_S(w_t) = w_t - \frac{1}{L_f} \nabla L_S(w_t), we have

L_S(w_{t+1}) - L_S(w_t) \le -\frac{1}{2 L_f} \|\nabla L_S(w_t)\|^2. \quad (14)

Under the PL* condition at point w_t, we have

L_S(w_{t+1}) - L_S(w_t) \le -\frac{\mu}{L_f} L_S(w_t).

Therefore, we have

L_S(w_t) \le \left(1 - \frac{\mu}{L_f}\right) L_S(w_{t-1}) \le \left(1 - \frac{\mu}{L_f}\right)^t L_S(w_0).
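The geometric decay above can be checked numerically on a toy problem. The sketch below is our own construction, not the paper's experiment: a strongly convex quadratic satisfies the PL* condition with mu = lambda_min(A) and has an L_f-Lipschitz gradient with L_f = lambda_max(A), so gradient descent with eta = 1/L_f should obey L(w_t) <= (1 - mu/L_f)^t L(w_0).

```python
import numpy as np

rng = np.random.default_rng(0)

# Quadratic L(w) = 0.5 * w^T A w with a chosen spectrum: mu = 0.5, L_f = 5.0
Q, _ = np.linalg.qr(rng.normal(size=(20, 20)))
eigs = np.linspace(0.5, 5.0, 20)
A = Q @ np.diag(eigs) @ Q.T
mu, L_f = eigs[0], eigs[-1]

def loss(w):
    return 0.5 * w @ A @ w

w = rng.normal(size=20)
losses = [loss(w)]
for _ in range(200):
    w = w - (1.0 / L_f) * (A @ w)          # gradient step with eta = 1/L_f
    losses.append(loss(w))

# Check the linear rate of Appendix B.1: L(w_t) <= (1 - mu/L_f)^t * L(w_0)
rate = 1.0 - mu / L_f
for t, l in enumerate(losses):
    assert l <= rate ** t * losses[0] + 1e-12
print("linear rate (1 - mu/L_f) =", rate)
```

The assertion passes with margin because each eigen-mode contracts by (1 - lambda_i/L_f)^2, which is at most (1 - mu/L_f) for lambda_i in [mu, L_f].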

B.2 SAMPLE COMPLEXITY

Following previous work (Charles & Papailiopoulos, 2018), we first calculate the generalization bound ϵ for the function f with pointwise hypothesis stability γ. Then, we analyze the pointwise hypothesis stability γ for a model trained on a loss function satisfying the PL* condition with parameter μ. Finally, given the pointwise hypothesis stability γ and the generalization bound ϵ, we bound the sample complexity of f.

B.2.1 GENERALIZATION ANALYSIS VIA POINTWISE HYPOTHESIS STABILITY

A useful approach to analyzing the generalization performance of learning algorithms is algorithmic stability (Bousquet & Elisseeff, 2002). A learning algorithm A is stable if small changes in the training set result in small differences in the output predictions of the trained model. Following these foundational works, we analyze the generalization bound from the perspective of pointwise hypothesis stability. Given a dataset S = \{x_1, \ldots, x_n\} where x_i \sim D, we define the dataset S^i as S \setminus x_i, i.e., S^i = \{x_1, \ldots, x_{i-1}, x_{i+1}, \ldots, x_n\}. For our purposes, f is the output of the learning algorithm A, parameterized by w. We denote the model trained on the dataset S and on S^i as w_S and w_{S^i}, respectively; accordingly, the loss functions are represented as L(w_S; x) and L(w_{S^i}; x). For a given dataset S, the empirical loss is denoted as L_S(w_S) and given by L_S(w_S) = \frac{1}{n} \sum_{i=1}^{n} L(f(x_i; w_S), y_i).

Definition 10 (Pointwise Hypothesis Stability (Bousquet & Elisseeff, 2002)). An algorithm A has pointwise hypothesis stability γ with respect to a loss function L if, for all i \in \{1, \ldots, n\},

\mathbb{E}_S\left[\,|L(w_S; x_i) - L(w_{S^i}; x_i)|\,\right] \le \gamma.

We can then use the pointwise hypothesis stability to establish generalization bounds.

Theorem 11 (Bousquet & Elisseeff (2002)). Suppose we have a learning algorithm A with pointwise hypothesis stability γ with respect to a bounded loss function L such that 0 \le L(w; x) \le M. For any δ, we have with probability at least 1 − δ,

R(w_S) \le R_S(w_S) + \sqrt{\frac{M^2 + 12 M n \gamma}{2 n \delta}}.

In the following, we derive stability bounds for models trained on risk functions satisfying the PL* condition.
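The definition above can be probed empirically. The sketch below is entirely our own toy setup: closed-form ridge regression stands in for a strongly convex (hence PL*) learning algorithm A, and we Monte Carlo estimate E_S[|L(w_S; x_i) - L(w_{S^i}; x_i)|] by leave-one-out retraining, checking that it shrinks as n grows, consistent with gamma = O(1/n).

```python
import numpy as np

rng = np.random.default_rng(1)

def fit(X, y, lam=1e-2):
    # Closed-form ridge regression: a strongly convex stand-in for algorithm A.
    d = X.shape[1]
    return np.linalg.solve(X.T @ X + lam * len(y) * np.eye(d), X.T @ y)

def pointwise_stability(n, d=5, trials=100):
    # Monte Carlo estimate of E_S[|L(w_S; x_i) - L(w_{S^i}; x_i)|]
    gammas = []
    for _ in range(trials):
        X = rng.normal(size=(n, d))
        y = X @ rng.normal(size=d) + 0.1 * rng.normal(size=n)
        w_S = fit(X, y)
        i = int(rng.integers(n))               # leave point x_i out
        keep = np.arange(n) != i
        w_Si = fit(X[keep], y[keep])
        # squared loss evaluated at the held-out point x_i
        gammas.append(abs((X[i] @ w_S - y[i]) ** 2 - (X[i] @ w_Si - y[i]) ** 2))
    return float(np.mean(gammas))

g_small, g_large = pointwise_stability(50), pointwise_stability(800)
print(g_small, g_large)   # the stability parameter shrinks as n grows
```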

B.2.2 POINTWISE HYPOTHESIS STABILITY WITH PL * CONDITION

Theorem 12. Assume that for all training sets S and models f parameterized by w, the loss function L is PL* with parameter μ. In addition, assume that the over-parameterized model f with parameters w_S trained on S is capable of converging to some global minimizer w^*_S. Then for all S, if |L_S(w_S) - L_S(w^*_S)| \le \epsilon_f, then f has pointwise hypothesis stability with parameter γ satisfying

\gamma \le \frac{L_f^2}{\mu n}.

Proof. For a training set S, let w_S denote the parameters of f trained on S, and let w_{S^i} denote the parameters of f trained on S^i, where S^i denotes S \setminus x_i. Let w^*_S and w^*_{S^i} denote the corresponding optimal solutions, respectively. By the triangle inequality,

|L(w_S; x_i) - L(w_{S^i}; x_i)| \le |L(w_S; x_i) - L(w^*_S; x_i)| + |L(w^*_S; x_i) - L(w^*_{S^i}; x_i)| + |L(w^*_{S^i}; x_i) - L(w_{S^i}; x_i)|. \quad (20)

The three terms can be bounded separately.

The first and third terms. As discussed in previous work (Karimi et al., 2016), the PL* condition implies the quadratic growth (QG) condition

\frac{\mu}{2} \|w_S - w^*_S\|^2 \le |L_S(w_S) - L_S(w^*_S)|. \quad (21)

Using the assumption |L_S(w_S) - L_S(w^*_S)| \le \epsilon_f, this implies

\|w_S - w^*_S\| \le \sqrt{\frac{2}{\mu} |L_S(w_S) - L_S(w^*_S)|} = \sqrt{\frac{2 \epsilon_f}{\mu}}. \quad (22)

Since the loss function L is L_f-Lipschitz, we can bound the first and third terms as

|L(w_S; x_i) - L(w^*_S; x_i)| \le L_f \|w_S - w^*_S\| \le L_f \sqrt{\frac{2 \epsilon_f}{\mu}}; \qquad |L(w_{S^i}; x_i) - L(w^*_{S^i}; x_i)| \le L_f \|w_{S^i} - w^*_{S^i}\| \le L_f \sqrt{\frac{2 \epsilon_f}{\mu}}. \quad (23)

In the over-parameterized setting, the model f optimized on the loss function reaches a global optimizer w^* and has a vanishingly small constant \epsilon_f, so the first and third terms equal 0.

The second term. The PL* condition implies that L(w^*_S; x_i) = 0 and L_{S^i}(w^*_{S^i}) = 0, so the second term can be manipulated as

|L(w^*_S; x_i) - L(w^*_{S^i}; x_i)| = |L(w^*_{S^i}; x_i)| = |n L_S(w^*_{S^i}) - (n-1) L_{S^i}(w^*_{S^i})| = n |L_S(w^*_{S^i})|. \quad (24)

Note that since \nabla L_{S^i}(w^*_{S^i}) = 0, we get

\|\nabla L_S(w^*_{S^i})\|^2 = \frac{1}{n^2} \|\nabla L(w^*_{S^i}; x_i)\|^2 \le \frac{L_f^2}{n^2}. \quad (25)

Furthermore, the PL* condition implies \mu L_S(w^*_{S^i}) \le \|\nabla L_S(w^*_{S^i})\|^2, so we obtain

n |L_S(w^*_{S^i})| \le \frac{n}{\mu} \|\nabla L_S(w^*_{S^i})\|^2 \le \frac{n}{\mu} \cdot \frac{L_f^2}{n^2} = \frac{L_f^2}{\mu n}. \quad (26)

The overall bound. Plugging Eqs. 23 and 26 into Eq. 20, we obtain the desired result: for all S and w with loss function L satisfying the PL* condition with parameter μ, algorithm A has pointwise hypothesis stability γ \le \frac{L_f^2}{\mu n}. ■

Therefore, we can compute the sample complexity as

n(\epsilon, \delta) \le \frac{6 L_f^2 M}{\mu \epsilon^2 \delta} + \frac{M^2}{2 \epsilon^2 \delta}. \quad (28)

C PROOF OF PROPOSITION 8

Proof. Consider an arbitrary point w \in B(w_0, R). For every sample i in the dataset S, we have

DF_i(w) = DF_i(w_0) + \int_0^1 H_{f_i}(w_0 + \tau (w - w_0)) (w - w_0)\, d\tau. \quad (29)

Since \tau \in [0, 1], the point w_0 + \tau (w - w_0) lies inside the ball, thus \|H_{f_i}(w_0 + \tau (w - w_0))\|_2 \le \sup_{w \in B} \|H_f(w)\|_2. Therefore, Eq. 29 can be bounded as

\|DF_i(w) - DF_i(w_0)\|_2 \le \sup_{\tau \in [0,1]} \|H_{f_i}(w_0 + \tau (w - w_0))\|_2 \cdot \|w - w_0\|_2 \le R \cdot \|H_f(w)\|_2. \quad (30)

By the triangle inequality, we have

\|DF(w)\|_2 - \|DF(w_0)\|_2 \le \|DF(w) - DF(w_0)\|_2 \le \|DF(w) - DF(w_0)\|_F = \sqrt{\sum_i \|DF_i(w) - DF_i(w_0)\|_2^2} \le \sqrt{n}\, R \cdot \|H_f(w)\|_2. \quad (31)

By rearranging Eq. 31, we have

\|DF(w)\|_2 \le \|DF(w_0)\|_2 + \sqrt{n}\, R \cdot \|H_f(w)\|_2. \quad (32)

Using the definition of L_f, we have

L_f = \sup_{w \in B} \|DF(w)\|_2 \le \|DF(w_0)\|_2 + \sqrt{n}\, R \cdot \sup_{w \in B} \|H_f(w)\|_2. \quad (33)

Combining Eq. 33 and the definition of \mu^*, we have

\frac{L_f}{\mu^*} \le \frac{\|DF(w_0)\|_2 + \sqrt{n}\, R \cdot \sup_{w \in B} \|H_f(w)\|_2}{\inf_{w \in B} \lambda_{\min}(K(w))}.
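The algebra behind Eq. 28, substituting gamma = L_f^2/(mu*n) into the Theorem 11 bound eps = sqrt((M^2 + 12*M*n*gamma)/(2*n*delta)) and solving for n, can be packaged as a small helper; the function name and the example constants are ours, chosen only for illustration.

```python
def sample_complexity(L_f, mu, M, eps, delta):
    """Eq. 28: with gamma = L_f^2 / (mu * n), the bound
    eps = sqrt((M^2 + 12*M*n*gamma) / (2*n*delta)) solves to
    n = 6*L_f^2*M / (mu*eps^2*delta) + M^2 / (2*eps^2*delta)."""
    return 6 * L_f**2 * M / (mu * eps**2 * delta) + M**2 / (2 * eps**2 * delta)

n = sample_complexity(L_f=1.0, mu=0.5, M=1.0, eps=0.1, delta=0.05)
print(round(n, 6))  # -> 25000.0
```

Note that the dominant first term scales with L_f^2/mu, the condition-number-like ratio that the proposed PL regularization seeks to minimize.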

D PROOF OF THEOREM 9

To prove that minimizing the upper bound of the condition number can be achieved by pruning parameters controlled by the Hessian norm \|H_f\|_2, we separately analyze the impact of pruning parameters controlled by the Hessian norm \|H_f\|_2 on the Lipschitz constant, and the impact of pruning parameters on the minimal eigenvalue of the NTK (Proposition 14).

Proof. Using the same argument as in Proposition 8, shown in Eq. 33, where L_f = \sup_{w \in B} \|DF(w)\|_2 \le \|DF(w_0)\|_2 + \sqrt{n}\, R \cdot \sup_{w \in B} \|H_f(w)\|_2, we can conclude that pruning parameters reduces the parameter size, leading to a decrease of \|w - w_0\| (i.e., the radius R of the ball), and that pruning parameters governed by minimizing the Hessian norm \|H_f\|_2 decreases the upper bound of the Lipschitz constant. In summary, the pruning operation controlled by the Hessian norm minimizes the upper bound of L_f. ■

Then we use matrix analysis to show that maximizing the minimal eigenvalue of the NTK can be achieved via pruning. Let us denote the matrix G = DF\, DF^\top. The matrix G and the NTK matrix (K = DF^\top DF) share the same non-zero eigenvalues, thus we analyze the effect of pruning on the eigenvalues of the NTK via the matrix G \in \mathbb{R}^{m \times m}, where m denotes the parameter size. Pruning any parameter w_p from the weight set w corresponds to deleting the p-th row and the p-th column from G, yielding \tilde{G}_p. According to the Cauchy Interlacing Theorem, increasing the minimal eigenvalue of G can be achieved by pruning the parameter w_m. We first state the Cauchy Interlacing Theorem as follows:

Theorem 13 (Cauchy Interlacing Theorem, Horn & Johnson (2012)). Let B \in M_m be Hermitian, let y \in \mathbb{C}^m and a \in \mathbb{R} be given, and let

A = \begin{pmatrix} B & y \\ y^* & a \end{pmatrix} \in M_{m+1}.

Then \lambda_1(A) \le \lambda_1(B) \le \lambda_2(A) \le \cdots \le \lambda_m(A) \le \lambda_m(B) \le \lambda_{m+1}(A).

Theorem 13 allows us to conclude that if we remove the parameter w_m, obtaining \tilde{G}_m from G, then \lambda_{\min}(\tilde{G}_m) \ge \lambda_{\min}(G).

Proposition 14. For an over-parameterized model f parameterized by w \in \mathcal{W}, the pruning operation on the parameter set w increases the minimal eigenvalue of the NTK matrix associated with f.

Proof. For a parameter w_p and the matrix G, let G_{p \to m} denote the permuted matrix obtained by permuting the p-th row and p-th column of G to the m-th row and m-th column. Then pruning parameter w_p is equivalent to pruning w_m of the matrix G_{p \to m}. Following Theorem 13, pruning parameter w_m of G_{p \to m} increases its minimal eigenvalue. Moreover, using the fact that G_{p \to m} = P G P^\top = P G P^{-1}, where P denotes a permutation matrix, G_{p \to m} and G are similar matrices with the same eigenvalues. Thus, the minimal eigenvalue of G is increased. In summary, pruning any parameter w_p increases the minimal eigenvalue of G. ■

Proof of Theorem 9. Proposition 8 shows that the pruning operation controlled by minimizing the Hessian norm \|H_f\|_2 minimizes the upper bound of the Lipschitz constant L_f, and Proposition 14 shows that pruning parameters maximizes the minimal eigenvalue of the NTK. Combining these two propositions, we conclude that pruning parameters controlled by minimizing the Hessian norm \|H_f\|_2 minimizes the condition number of the model. ■

To further verify this theoretical claim, we empirically investigate the evolution of the condition number of the pruned model and the vanilla model. We apply a small BERT with 4 Transformer layers to WikiText-2 for the language modeling task. We prune the heads of vanilla BERT using a strategy of pruning 5% of heads every 10 epochs, controlled by minimizing the Hessian trace (details can be found in Appendix E.2); the pruned network is denoted as Masked-BERT. As shown in Figure 3, compared with vanilla BERT, Masked-BERT presents a smaller condition number¹.
In particular, we observe a sharp decrease in the condition number of Masked-BERT after each pruning operation; to highlight this, red lines are included in the figure to indicate the condition number before and after pruning.
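Proposition 14's interlacing argument is easy to sanity-check numerically: deleting a row/column pair from the Hermitian matrix G = DF DF^T can only raise its smallest eigenvalue. The snippet below is our own toy construction, with a random matrix standing in for the Jacobian DF; it verifies the claim for every choice of pruned index p.

```python
import numpy as np

rng = np.random.default_rng(2)

# G = DF DF^T for a random "Jacobian" DF (8 parameters, 12 samples).
# G is PSD, and by the Cauchy Interlacing Theorem (Theorem 13), deleting
# the p-th row and column (pruning parameter w_p) cannot decrease lambda_min.
DF = rng.normal(size=(8, 12))
G = DF @ DF.T

lam_min = np.linalg.eigvalsh(G)[0]          # eigvalsh returns ascending order
for p in range(G.shape[0]):
    keep = np.arange(G.shape[0]) != p
    G_pruned = G[np.ix_(keep, keep)]
    assert np.linalg.eigvalsh(G_pruned)[0] >= lam_min - 1e-10
print("lambda_min never decreases under row/column deletion:", lam_min)
```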

E EXPERIMENTAL DETAILS

E.1 ALGORITHM OVERVIEW

Figure 4 shows an overview of the proposed method. We impose PL regularization for model optimization by adopting a pruning approach. More specifically, we introduce a binary mask for periodically sparsifying parameters, and the mask is learned via a gating network whose input summarizes the optimization dynamics of sub-networks in terms of the PL condition. Algorithm 2 summarizes the proposed pruning algorithm.

To further verify that pruning parameters controlled by minimizing the Hessian trace decreases the condition number of the model, we empirically investigate the evolution of the trace of the Hessian and the condition number. Here a small BERT with 4 Transformer layers is applied to WikiText-2 for language modeling. As shown in Figure 5, Masked-BERT, where we prune 50% of heads at epoch 10 controlled by minimizing the trace of the Hessian, presents a small value of the Hessian trace, accompanied by a smaller condition number. This phenomenon demonstrates that minimizing the Hessian trace is an effective way to control the condition number of the model.

Algorithm 2
Input: Training set S, pruning times and ratios {t_j, r_j}_{j=1}^q, maximum iterations t_max
1: Initialize: m = 1
2: while not converged do
3:   for t = t_0, ..., t_max do
4:     if t = t_j then
5:       Calculate d_{w_{t_j}} = [λ_min(K(w_{t_j}[1])), ρ(K(w_{t_j}[1])), ..., λ_min(K(w_{t_j}[p])), ρ(K(w_{t_j}[p]))]
6:       for each batch in S do
7:         v = v − η∇_v L_total with α = 1   ▷ Update v according to Eq. 36
⋮
12:     for each batch in S do
13:       w_{t+1} = w_t − η∇_w L_total with α = 0   ▷ Update w according to Eq. 36

We follow the original basic settings when training the baseline models. For BERT-LTH and Switch-LTH, we use iterative magnitude pruning (IMP) (Chen et al., 2020) to the target sparsity with a rewinding step of 5% of the maximum training epochs. For Att-Score and Exp-Score, all the attention heads and experts across the model are sorted by a gradient-based proxy importance score (Michel et al., 2019).
Then they are pruned with the same iterative strategy as IMP. For SNIP and GraSP, we implement them in a structured way, i.e., we sum up the relevant indicators within each component, and the attention heads and experts are pruned at the initialization stage of the training process (Lee et al., 2018; Wang et al., 2020). A protection strategy is used for the pruning operation of all baselines and our model: we keep at least one head or expert in each layer to preserve activation across the model.
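The mask-generation step of Algorithm 2 (m = TopK(g(d_w; v), (1 − r_j)p)) can be sketched as follows. The linear scorer standing in for the gating network g(d_w; v), the two-dimensional feature vector, and all constants are illustrative assumptions of ours, not the paper's implementation.

```python
import numpy as np

rng = np.random.default_rng(3)

def topk_mask(scores, k):
    """Binary mask keeping the k highest-scoring sub-networks
    (the TopK step of Algorithm 2)."""
    m = np.zeros_like(scores)
    m[np.argsort(scores)[-k:]] = 1.0
    return m

# Hypothetical gating network g(d_w; v): a linear scorer over the per-head
# features d_w = [lambda_min(K), condition number, ...] computed in step 5.
p, r_j = 12, 0.5                       # p heads, pruning ratio r_j
d_w = rng.normal(size=(p, 2))          # toy parameter features
v = rng.normal(size=2)                 # gating parameters (learned via L_total)
m = topk_mask(d_w @ v, int((1 - r_j) * p))
print(m, "-> keep", int(m.sum()), "of", p, "heads")
```

The retained sub-networks are then reparameterized as w = m * w, which is what periodically sparsifies the model during training.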

E.4 EXPERIMENTAL DETAILS OF VGG-16

First, we summarize the architecture-specific pruning baselines, including FPGM, CHIP, and DPFPS. FPGM (He et al., 2019) compresses CNN models by pruning filters with norm- and distance-based criteria, pruning all the weighted layers with the same pruning rate at the same time. CHIP (Sui et al., 2021) first extracts feature maps from a pre-trained model and calculates channel independence, which is used to sort and prune the filters, and then fine-tunes the sparse pre-trained model. DPFPS (Ruan et al., 2021) prunes structured parameters in a dynamic sparsity manner, where the sparsity allocation ratios are distributed differently over layers during training.

Then, we report the detailed experimental settings. For training VGG-16 on CIFAR-10 and CIFAR-100, we use similar configurations to (Wang et al., 2020). For all baselines, the network is trained with Kaiming initialization (He et al., 2015) using SGD for 200 epochs. The learning rate is decayed from the initial value of 0.1 by a factor of 0.1 at 1/2 and 3/4 of the total number of epochs. To obtain more stable results, we conduct each experiment in 3 trials. In this paper, we evaluate the performance of all models on VGG-16 with 50% and 90% pruning ratios, respectively. The pruning policy of our method is a linear pruning schedule: starting from epoch 15, we prune the same number of filters at each epoch until the target pruning ratio is reached. The pruning policy of the baselines follows their original settings given a target pruning ratio. In detail, for LTH, SNIP, and GraSP, we conduct unstructured pruning with the target pruning ratio at the initialization stage. For FPGM, we utilize the norm-based criterion for the 50% target pruning ratio. For the 90% target pruning ratio, we first prune 50% of filters by the norm-based criterion and then prune 40% of filters by the distance-based criterion. For CHIP, given the 50% pruning ratio, following the original work, we assign the pruning ratio of the first 7 layers as 20% and the ratio of the remaining layers as 60%; given the 90% pruning ratio, the ratio of the first 7 layers is assigned as 75% and that of the remaining layers as 95%.
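The step decay described above (initial rate 0.1, scaled by 0.1 at 1/2 and 3/4 of the 200 epochs) can be written as a small helper; this is a sketch of the stated setup, not the authors' code, and the function name is ours.

```python
def lr_at_epoch(epoch, total_epochs=200, base_lr=0.1, gamma=0.1):
    """Step schedule used for VGG-16: decay base_lr by gamma at 1/2 and 3/4
    of the total number of training epochs."""
    lr = base_lr
    if epoch >= total_epochs // 2:
        lr *= gamma
    if epoch >= 3 * total_epochs // 4:
        lr *= gamma
    return round(lr, 6)

print([lr_at_epoch(e) for e in (0, 99, 100, 149, 150, 199)])
# -> [0.1, 0.1, 0.01, 0.01, 0.001, 0.001]
```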

F ADDITIONAL EXPERIMENTAL RESULTS

In this section, we report additional experimental results.

F.1 RESULTS ON RESNET-56

This experiment focuses on ResNet-56 (He et al., 2016) trained on the CIFAR-10 dataset. The baselines include five filter pruning methods, namely FPGM (He et al., 2019), DPFPS (Ruan et al., 2021), LTH (Frankle & Carbin, 2018), SNIP (Lee et al., 2018), and GraSP (Wang et al., 2020). All models are trained from scratch and pruned with a linear schedule, which starts from epoch 5 and prunes the same number of filters per epoch until the target sparsity is reached. Table 5 shows the performance of all methods trained on CIFAR-10 with a 25% pruning ratio. As we can see, our method outperforms all pruning baselines in terms of training efficiency and test accuracy.

F.2 TRAINING EFFICIENCY COMPARISON

As shown in Figure 6, compared to vanilla BERT trained on WikiText-2, our model reaches at iteration 8k the test perplexity that vanilla BERT reaches at iteration 10k, which is a 1.3× speedup in terms of step time, i.e., our method only requires 0.7× the iterations to achieve the same test perplexity. As for the saved wall-clock time, Table 6 reports the wall-clock time in seconds (first row) and the time saved compared with vanilla BERT (second row). We can observe that our method saves more wall-clock time than the other baselines. In particular, it saves 29% of the wall-clock time of vanilla BERT. Furthermore, the training of the gating network accounts for only 2-3% of the total training time of BERT, demonstrating the efficiency of our method including the pruning loop. We find similar results for Switch-Transformer trained on WikiText-103. As shown in Figure 7 and Table 7, compared to vanilla Switch-Transformer, our method yields a 2× speedup, requiring 44% less wall-clock time.

F.3 THE EFFECT OF THE PRUNING SCHEDULE

We also evaluate the impact of the one-shot pruning and linear pruning schedules for BERT. Specifically, we compare the performance of the one-shot and iterative pruning strategies on the heads of BERT with the same 75% pruning ratio. As shown in Table 9, the one-shot pruning strategy achieves competitive performance compared to the linear pruning schedule while saving more computational cost.
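The linear pruning schedule used throughout (start at a given epoch, remove the same number of filters or heads each epoch until the target ratio is met) can be sketched as below; the function and its argument names are our own illustration, with any remainder absorbed into the earliest steps.

```python
def linear_prune_schedule(total_filters, target_ratio, start_epoch, end_epoch):
    """Per-epoch prune counts for the linear schedule described above."""
    to_prune = int(total_filters * target_ratio)
    steps = end_epoch - start_epoch + 1
    per_epoch, rem = divmod(to_prune, steps)
    # distribute any remainder over the first `rem` epochs
    return [per_epoch + (1 if i < rem else 0) for i in range(steps)]

sched = linear_prune_schedule(total_filters=512, target_ratio=0.25,
                              start_epoch=5, end_epoch=20)
print(sum(sched), sched[:4])  # -> 128 [8, 8, 8, 8]
```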

F.4 THE EFFECT OF PRUNING RATIO

This subsection evaluates the performance of the proposed PL regularization-driven structured pruning method with varying pruning ratios. Table 10 and Table 11 show the performance of BERT with the 50% and 90% pruning ratios, respectively. Tables 12 and 13 show the performance of VGG-16 with the 75% and 90% pruning ratios, respectively. We can observe that even with a high pruning ratio, our method outperforms the pruning baselines in terms of training efficiency and test performance.

In practice, in order to find the optimal model, researchers usually train multiple models with varying parameter initializations. In over-parameterized settings, such a process requires substantial computational resources and time. Our method can not only reduce the computational cost of training one model, but also reduce the sensitivity to initialization, so that fewer models need to be trained. To verify the latter benefit, we conduct experiments on two BERT models with different parameter initializations: BERT-A and BERT-B denote the BERT model with good and bad initialization, respectively. Table 14 and Figure 8 show that imposing our method on BERT-B significantly improves its original performance; it outperforms the vanilla BERT-A and slightly underperforms the masked BERT-A model implemented with PL regularization. This phenomenon suggests the potential of our method to avoid multiple initializations while saving computation and time.



China and Shanghai Key Laboratory of Data Science, School of Computer Science, Fudan University, Shanghai, China; School of Mathematics & Statistics, The University of Glasgow, Glasgow, UK; Microsoft Research Asia, Shanghai, China; Department of Engineering Science, University of Oxford, Oxford, England; Department of Electrical Engineering and Computer Science, University of Michigan, Michigan, United States; Department of Computer Science, University of Colorado Boulder, Boulder, Colorado, United States; School of Microelectronics, Fudan University, Shanghai, China. * The corresponding author.

¹ The condition number is computed as λ_max(K)/λ_min(K), where λ_max and λ_min denote the largest and smallest eigenvalues, respectively.



B.2.3 BOUND OF SAMPLE COMPLEXITY

Based on the above subsection, if the loss function L of model f satisfies the PL* condition with parameter μ, and the algorithm A has pointwise hypothesis stability with parameter γ \le \frac{L_f^2}{\mu n}, the generalization gap ϵ of the obtained model f can be bounded as \epsilon = \sqrt{\frac{M^2 + 12 M n \gamma}{2 n \delta}}.

Figure 3: The evolution of the condition number of the vanilla BERT and the pruned BERT (denoted as Masked-BERT).

Figure 4: Algorithm overview. The gating network generates a binary mask m based on the parameter features d_w. PL regularization helps the pruned model obtain a smaller condition number L_f/μ.

Figure 5: Connection between the trace of the Hessian, Trace(H_f(m ∗ w)), and the condition number L_f/μ.

Algorithm 2 (excerpt): 9: m_{t_j} = TopK(g(d_w; v), (1 − r_j)p) ▷ Generate binary mask m; 10: w_{t_j} = m_{t_j} ∗ w_{t_j} ▷ Reparameterize w

DETAILS OF BERT AND SWITCH-TRANSFORMER

BERT and Switch-Transformer are trained from scratch on the self-supervised Masked Language Model (MLM) task. The MLM objective is a cross-entropy loss on predicting the masked tokens. The models uniformly select 15% of the input tokens for possible replacement. Considering the model capacity and computational resources, we train BERT on WikiText-2 and Switch-Transformer on WikiText-103. The training hyperparameters are presented in Table 4. All baseline models, including our method, are trained with the same training setup. These experiments are conducted on 8 NVIDIA GeForce RTX 3090 GPUs.
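The 15% token-selection step of the MLM setup can be sketched as below; this is a minimal illustration of the stated masking rate, and the function name and the 128-token toy sequence are ours.

```python
import numpy as np

rng = np.random.default_rng(4)

def select_mlm_positions(num_tokens, rate=0.15):
    """Uniformly select a fraction `rate` of input token positions for
    possible replacement, as in the MLM training setup described above."""
    k = max(1, int(num_tokens * rate))
    return rng.choice(num_tokens, size=k, replace=False)

positions = select_mlm_positions(128)
print(sorted(positions.tolist()))
```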

Figure 6: Training efficiency comparison of vanilla BERT and our model.

Figure 7: Training efficiency comparison of vanilla Switch-Transformer and our model.

Figure 8: Training error and testing performance of BERT models with different parameter initialization.

∀w 1 , w 2 ∈ W, where W indicates the hypothesis set.

Results of BERT on WikiText-2. We denote perplexity as PPL, and ∆PPL represents the perplexity difference between training and test. Boldface indicates the best result. Table 1 shows the performance of all models on WikiText-2 with the {5, 75%} pruning policy, i.e., we prune 75% of heads at epoch 5 during training. Our method shows better training efficiency and generalization ability than all baselines. In particular, our method improves test perplexity over vanilla BERT, and improves training perplexity by 60%, 29%, and 33% after 8k, 10k, and 15k iterations, demonstrating the effectiveness of PL regularization for BERT model optimization. We also evaluate the training efficiency by measuring the number of iterations required to converge to a fixed training error. Compared to vanilla BERT, our method requires only 0.7× the iterations to achieve a satisfactory test perplexity (see Appendix F.2). Other pruning methods, on the other hand, can improve training efficiency but at the cost of sacrificing test performance, as their pruning criteria focus on the training loss and ignore generalization ability. Furthermore, we analyze the effect of the pruning times of the mask for BERT optimization (shown in Appendix F.3). Interestingly, compared with the model that applies PL regularization iteratively throughout training, the model that applies PL regularization during the early stages of training has better training efficiency and test performance. This phenomenon indicates that the heads important for BERT optimization can potentially be discovered in the early stages of training.

Results of Switch-Transformer on WikiText-103 (Best results in Boldface).

Table 2 lists the performance of Switch-Transformer on WikiText-103 with the {5, 50%} pruning policy. Our method (Ours (experts only)) significantly outperforms all the baselines in both training efficiency and generalization ability, demonstrating that our method is applicable to the MoE architecture. In particular, compared with the vanilla Switch-Transformer, when used on the MoE architecture, our method improves training perplexity by 62%, 41%, and 24% at 50k, 60k, and 70k iterations, while improving test performance by 13%.

Results of VGG-16 on CIFAR-10 and CIFAR-100 (Best results in Boldface).

Hyperparameters for training BERT and Switch Transformer

Results of ResNet-56 on CIFAR-10 (Best results in Boldface). This subsection compares training efficiency by measuring the number of training iterations required to converge to a given test perplexity. We also evaluate the wall-clock time saved when the model converges to a given test perplexity. The calculation of parameter features and the training of the gating network are included in these statistics.

Comparison of saved clock time on BERT.

Comparison of clock time saved when model achieves a given test perplexity. The symbol * denotes expert pruning, and * * denotes expert and head pruning.

Performance comparison on WikiText-2. For simplicity, we denote perplexity as PPL, and ∆PPL represents the perplexity difference between training and test. Boldface indicates the best result.

Comparison of one-shot pruning and the linear pruning schedule for BERT.

Performance Comparison of BERT on Wikitext-2. All models are pruned by 50% sparsity. Boldface indicates the best result among pruned models.

Performance Comparison of BERT on Wikitext-2. All models are pruned by 90% sparsity. Boldface indicates the best result among pruned models.

Performance Comparison of VGG-16 on CIFAR-10 and CIFAR-100. All models are pruned by 75% sparsity. Boldface indicates the best result among pruned models. Models without results mean that these models cannot converge under the current setting.

Performance Comparison of VGG-16 on CIFAR-10 and CIFAR-100. All models are pruned by 90% sparsity. Boldface indicates the best result among pruned models. Models without results could not converge under the current setting.

Artificial deep learning models are typically over-parameterized, incurring heavy computational effort during model training and inference. Drawing on theoretical insights, this work finds an efficient optimization strategy for over-parameterized models that achieves faster convergence and better test performance.

Results of BERT with different parameter initialization.

