NTK-SAP: IMPROVING NEURAL NETWORK PRUNING BY ALIGNING TRAINING DYNAMICS

Abstract

Pruning neural networks before training has received increasing interest due to its potential to reduce training time and memory. One popular approach is to prune connections based on a certain metric, but it is not entirely clear which metric is the best choice. Recent advances in neural tangent kernel (NTK) theory suggest that the training dynamics of sufficiently large neural networks are closely related to the spectrum of the NTK. Motivated by this finding, we propose to prune the connections that have the least influence on the spectrum of the NTK. This method can help maintain the NTK spectrum, which may help align the training dynamics of the pruned network with those of its dense counterpart. However, one possible issue is that the fixed-weight-NTK corresponding to a given initial point can be very different from the NTK corresponding to later iterates during the training phase. We further propose to sample multiple realizations of random weights to estimate the NTK spectrum. Note that our approach is weight-agnostic, which is different from most existing methods, which are weight-dependent. In addition, we use random inputs to compute the fixed-weight-NTK, making our method data-agnostic as well. We name our foresight pruning algorithm Neural Tangent Kernel Spectrum-Aware Pruning (NTK-SAP). Empirically, our method achieves better performance than all baselines on multiple datasets. Our code is available at https://github.com/YiteWang/NTK-SAP.

1. INTRODUCTION

The past decade has witnessed the success of deep neural networks (DNNs) in various applications. Modern DNNs are usually highly over-parameterized, making training and deployment computationally expensive. Network pruning has emerged as a powerful tool for reducing time and memory costs. There are mainly two types of pruning methods: post-hoc pruning (Han et al., 2015; Renda et al., 2020; Molchanov et al., 2019; LeCun et al., 1989; Hassibi & Stork, 1992) and foresight pruning (Lee et al., 2018; Wang et al., 2020; Alizadeh et al., 2022; Tanaka et al., 2020; de Jorge et al., 2020b; Liu & Zenke, 2020). The former prunes the network after training the large network, while the latter prunes the network before training. In this paper, we focus on foresight pruning.

SNIP (Lee et al., 2018) is probably the first foresight pruning method for modern neural networks. It prunes those initial weights that have the least impact on the initial loss function. SNIP can be viewed as a special case of saliency-score-based pruning: prune the connections with the least "saliency scores," where a saliency score is a metric that measures the importance of a connection. For SNIP, the saliency score of a connection is the difference in the loss function before and after pruning that connection. GraSP (Wang et al., 2020) and SynFlow (Tanaka et al., 2020) use two different saliency scores. One possible issue with these saliency scores is that they are tied to the initial few steps of training, and thus may not be good choices for later stages of training. Is there a saliency score that is more directly related to the whole training dynamics?

Recently, there have been many works on the global optimization of neural networks; see, e.g., Liang et al. (2018b;a; 2019; 2021); Venturi et al. (2018); Safran & Shamir (2017); Li et al. (2022); Ding et al. (2022); Soltanolkotabi et al. (2019); Sun et al. (2020a); Lin et al. (2021a;b); Zhang et al. (2021) and the surveys Sun et al. (2020b); Sun (2020). Among them, one line of research uses the neural tangent kernel (NTK) (Jacot et al., 2018) to describe the gradient descent dynamics of DNNs when the network size is large enough.
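To make the saliency-score view concrete, the sketch below implements a generic top-k saliency pruning step in NumPy, using a SNIP-style first-order score |θ ⊙ ∂L/∂θ| with random placeholder gradients standing in for real loss gradients. The function name `topk_mask` and the toy sizes are our own illustration, not part of any of the cited methods' code.

```python
import numpy as np

def topk_mask(scores, density):
    """Keep the fraction `density` of connections with the largest saliency scores."""
    k = max(1, int(density * scores.size))
    flat = scores.ravel()
    thresh = np.partition(flat, -k)[-k]  # k-th largest score
    return (scores >= thresh).astype(np.float32)

# SNIP-style saliency: |theta * dL/dtheta|, i.e., the first-order change in the
# loss from removing a connection. Gradients here are random placeholders.
rng = np.random.default_rng(0)
theta = rng.normal(size=(4, 4))
grad = rng.normal(size=(4, 4))  # would be dL/dtheta on a real mini-batch
mask = topk_mask(np.abs(theta * grad), density=0.5)
```

The resulting binary `mask` is applied elementwise to the weights; with `density=0.5` it keeps half of the connections.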
More specifically, for large enough DNNs, the NTK is asymptotically constant during training, and the convergence behavior can be characterized by the spectrum of the NTK. This theory indicates that the spectrum of the NTK might be a reasonable metric for the whole training dynamics, rather than for just a few initial iterations. It is then natural to consider the following conceptual pruning method: prune the connections that have the least impact on the NTK spectrum.

There are a few questions on implementing this conceptual pruning method. First, what metric to compute? Computing the whole eigenspectrum of the NTK is too time-consuming. Following the practice in numerical linear algebra and deep learning (Lee et al., 2019a; Xiao et al., 2020), we use the nuclear norm (sum of eigenvalues) as a scalar indicator of the spectrum. Second, what "NTK" matrix to pick? We call the NTK matrix defined for the given architecture with a random initialization a "fixed-weight-NTK", and use "analytic NTK" (Jacot et al., 2018) to refer to the asymptotic limit of the fixed-weight-NTK as the network width goes to infinity. The analytic NTK is the one studied in NTK theory (Jacot et al., 2018), and we think its spectrum may serve as a performance indicator of a certain architecture throughout the whole training process.¹ However, computing its nuclear norm is still too time-consuming (whether using the analytic form given in Jacot et al. (2018) or handling an ultra-wide network). The nuclear norm of a fixed-weight-NTK is easy to compute, but the fixed-weight-NTK may be quite different from the analytic NTK. To resolve this issue, we notice a less-mentioned fact: the analytic NTK is also the limit of the expectation (over random weights) of the fixed-weight-NTK², and thus it can be approximated by the expectation of the fixed-weight-NTK for a given width. The expectation of the fixed-weight-NTK should be a better approximation of the analytic NTK than a single fixed-weight-NTK.
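Because the NTK Θ = JJ^T (with J the Jacobian of the network outputs with respect to the parameters) is positive semi-definite, its nuclear norm equals its trace, so it can be evaluated without any eigendecomposition: tr(JJ^T) = ||J||_F². A minimal NumPy check of this identity, with a small random matrix standing in for a real network Jacobian:

```python
import numpy as np

rng = np.random.default_rng(0)
N, P = 8, 20                   # N inputs, P parameters (toy sizes)
J = rng.normal(size=(N, P))    # stand-in for the Jacobian d f(x_i) / d theta
ntk = J @ J.T                  # fixed-weight-NTK, shape (N, N), PSD

# Nuclear norm of a PSD matrix = sum of (nonnegative) eigenvalues = trace,
# and tr(J J^T) = ||J||_F^2, so no eigendecomposition is needed.
nuclear = np.linalg.norm(ntk, ord='nuc')
print(nuclear, np.trace(ntk), np.sum(J**2))
```

In practice this means the scalar indicator reduces to a sum of squared per-example gradient norms, which is cheap to accumulate mini-batch by mini-batch.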
Of course, to estimate the expectation, we can use a few "samples" of weight configurations and average the resulting fixed-weight-NTKs. One more possible issue arises: would computing, say, 100 fixed-weight-NTKs take 100 times more computation? We use one more computational trick to keep the cost low: for each mini-batch of inputs, we use a fresh sample of the weight configuration to compute one fixed-weight-NTK (or, more precisely, its nuclear norm). This does not increase the computation cost compared to computing the fixed-weight-NTK for one weight configuration with 100 mini-batches. We call this the "new-input-new-weight" (NINW) trick.

We name the proposed foresight pruning algorithm Neural Tangent Kernel Spectrum-Aware Pruning (NTK-SAP). We show that NTK-SAP is competitive on multiple datasets, including CIFAR-10, CIFAR-100, Tiny-ImageNet, and ImageNet. In summary, our contributions are:

• We propose a theory-motivated foresight pruning method named NTK-SAP, which prunes networks based on the spectrum of the NTK.

• We introduce a multi-sampling formulation which uses different weight configurations to better capture the expected behavior of pruned neural networks. A "new-input-new-weight" (NINW) trick is leveraged to reduce the computational cost, and may be of independent interest.

• Empirically, we show that NTK-SAP, as a data-agnostic foresight pruning method, achieves state-of-the-art performance in multiple settings.
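The NINW trick can be sketched as follows, assuming a toy one-hidden-layer ReLU network with hand-written gradients (all names and sizes here are illustrative, not the paper's implementation): each mini-batch of random inputs is paired with a fresh draw of the weights, so averaging R mini-batch trace estimates simultaneously averages over R weight configurations at no extra cost.

```python
import numpy as np

rng = np.random.default_rng(0)
d, h = 10, 32                          # input dim, hidden width (toy sizes)

def ntk_trace(X, W1, w2):
    """tr(Theta) for a one-hidden-layer ReLU net f(x) = w2 @ relu(W1 @ x)."""
    total = 0.0
    for x in X:
        z = W1 @ x
        a = np.maximum(z, 0.0)
        # gradient w.r.t. W1 is outer(w2 * relu'(z), x); w.r.t. w2 it is relu(z)
        g1 = np.outer(w2 * (z > 0), x)
        total += np.sum(g1**2) + np.sum(a**2)   # squared per-example grad norm
    return total

# New-input-new-weight: a fresh weight sample for every mini-batch, so the
# average over R mini-batches is also an average over R weight draws.
R, B = 5, 16
est = 0.0
for _ in range(R):
    W1 = rng.normal(size=(h, d)) / np.sqrt(d)   # NTK-style 1/sqrt(fan-in) scaling
    w2 = rng.normal(size=h) / np.sqrt(h)
    X = rng.normal(size=(B, d))                 # random (data-agnostic) inputs
    est += ntk_trace(X, W1, w2) / R
```

Note that the loop performs exactly R forward/backward-style passes, the same as estimating the trace for a single weight configuration over R mini-batches.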

2. RELATED WORK AND BACKGROUND

Pruning after training (post-hoc pruning). Post-hoc pruning can be dated back to the 1980s (Janowsky, 1989; Mozer & Smolensky, 1989), and such methods usually require multiple rounds of a train-prune-retrain procedure (Han et al., 2015; LeCun et al., 1989). Most of these pruning methods use



¹ Please see Appendix B.2 for more discussions from an empirical perspective.
² This is a subtle point, and we refer the readers to Appendix B.1 for more discussions.




