TRAINABILITY PRESERVING NEURAL PRUNING

Abstract

Many recent works have shown that trainability plays a central role in neural network pruning: unattended, broken trainability can lead to severe under-performance and can unintentionally amplify the effect of the retraining learning rate, resulting in biased (or even misinterpreted) benchmark results. This paper introduces trainability preserving pruning (TPP), a scalable method that preserves network trainability against pruning, aiming for improved pruning performance and greater robustness to retraining hyper-parameters (e.g., learning rate). Specifically, we propose to penalize the Gram matrix of convolutional filters to decorrelate the pruned filters from the retained filters. Beyond the convolutional layers, in the spirit of preserving the trainability of the whole network, we also propose to regularize the batch normalization parameters (scale and bias). Empirical studies on linear MLP networks show that TPP performs on par with the oracle trainability recovery scheme. On nonlinear ConvNets (ResNet56/VGG19) on CIFAR10/100, TPP outperforms its counterpart approaches by a clear margin. Moreover, results with ResNets on ImageNet-1K suggest that TPP consistently compares favorably against other top-performing structured pruning approaches.

1. INTRODUCTION

Neural pruning aims to remove redundant parameters without seriously compromising performance. It normally consists of three steps (Reed, 1993; Han et al., 2015; 2016b; Li et al., 2017; Liu et al., 2019b; Wang et al., 2021b; Gale et al., 2019; Hoefler et al., 2021; Wang et al., 2023): pretrain a dense model; prune the unnecessary connections to obtain a sparse model; retrain the sparse model to regain performance. Pruning is usually categorized into two classes: unstructured pruning (a.k.a. element-wise or fine-grained pruning) and structured pruning (a.k.a. filter or coarse-grained pruning). Unstructured pruning takes a single weight as the basic pruning element, while structured pruning takes a group of weights (e.g., a 3d filter or a 2d channel) as the basic pruning element. Structured pruning is better suited for acceleration because of its regular sparsity. Unstructured pruning, in contrast, results in irregular sparsity, which is hard to exploit for acceleration unless customized hardware and libraries are available (Han et al., 2016a; 2017; Wen et al., 2016). Recent papers (Renda et al., 2020; Le & Hua, 2021) report an interesting phenomenon: during retraining, a larger learning rate (LR) helps achieve a significantly better final performance, empowering two simple baseline methods, random pruning and magnitude pruning, to match or beat many more complex pruning algorithms. The reason behind this is argued (Wang et al., 2021a; 2023) to be related to the trainability of neural networks (Saxe et al., 2014; Lee et al., 2020; Lubana & Dick, 2021). They make two major observations to explain the LR-effect mystery (Wang et al., 2023). (1) The weight removal operation immediately breaks the trainability, or dynamical isometry (Saxe et al., 2014) (the ideal case of trainability), of the trained network.
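To make the pruning step above concrete, the following is a minimal sketch of the standard magnitude criterion for structured pruning, scoring each convolutional filter by its L1-norm in the style of Li et al. (2017). The function name and toy tensor are illustrative, not from the paper's implementation.

```python
import numpy as np

def rank_filters_by_l1(conv_weight):
    """Score each filter of a conv layer with weight shape (out_ch, in_ch, kh, kw)
    by its L1-norm; return filter indices from least to most important.
    The lowest-ranked filters are the ones a structured pruner would remove."""
    scores = np.abs(conv_weight).reshape(conv_weight.shape[0], -1).sum(axis=1)
    return np.argsort(scores)

# Toy 4-filter layer; filter 2 is scaled to near-zero, so it ranks least important.
rng = np.random.default_rng(0)
w = rng.normal(size=(4, 3, 3, 3))
w[2] *= 1e-3
order = rank_filters_by_l1(w)
print(order[0])  # prints 2
```

At a given sparsity ratio, the first k indices of `order` would be designated the "unimportant" group that the retraining (or, in TPP, the regularization) must later compensate for.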
(2) The broken trainability slows down optimization in retraining, where a greater LR helps the model converge faster, so a better performance is observed earlier; using a smaller LR can actually do as well, but needs more epochs. Although these works (Lee et al., 2020; Lubana & Dick, 2021; Wang et al., 2021a; 2023) provide a plausibly sound explanation, a more practical issue is how to recover the broken trainability, or maintain it, during pruning. In this regard, Wang et al. (2021a) propose to apply weight orthogonalization based on QR decomposition (Trefethen & Bau III, 1997; Mezzadri, 2006) to the pruned model. However, their method is shown to work only for linear MLP networks. For modern deep convolutional neural networks (CNNs), how to maintain trainability during pruning remains elusive. We introduce trainability preserving pruning (TPP), a novel filter pruning algorithm (see Fig. 1) that maintains trainability via a regularized training process. By our observation, the primary cause of pruning breaking trainability lies in the dependency among parameters. The main idea of our approach is thus to decorrelate the pruned weights from the kept weights so as to "cut off" this dependency, so that the subsequent sparsifying operation barely hurts the network trainability. Specifically, we propose to regularize the Gram matrix of the weights: all the entries representing the correlation between the pruned filters (i.e., unimportant filters) and the kept filters (i.e., important filters) are encouraged to diminish to zero. This is the first technical contribution of our method. The second lies in how to treat the other entries.
Conventional dynamical isometry wisdom suggests orthogonality, namely, self-correlation of 1 and cross-correlation of 0, even among the kept filters. We find, however, that directly translating the orthogonality idea here is unnecessary or even harmful, because the overly strong penalty constrains the optimization and leads to a deteriorated local minimum. Rather, we propose not to impose any regularization on the correlation entries among the kept filters. Finally, modern deep models are typically equipped with batch normalization (BN) (Ioffe & Szegedy, 2015). However, previous filter pruning papers rarely explicitly take BN into account (with two exceptions (Liu et al., 2017; Ye et al., 2018); the differences between our work and theirs will be discussed in Sec. 3.2) to mitigate the side effect when a BN unit is removed along with its associated filter. Since BN parameters are also part of the network's trainable parameters, their unattended removal likewise leads to severely crippled trainability (especially at large sparsity). Therefore, the BN parameters (both scale and bias) ought to be explicitly taken into account when we develop the pruning algorithm. Based on this idea, we propose to regularize the two learnable parameters of BN to minimize the influence of their later absence. Practically, our TPP is easy to implement and robust to hyper-parameter variations. With ResNet50 on ImageNet, TPP delivers encouraging results compared to many recent SOTA filter pruning methods.
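The two ideas above can be sketched as a single regularizer: penalize only the Gram-matrix entries coupling pruned filters with kept filters (leaving kept-kept correlations free), and push the BN scale/bias of the pruned channels toward zero. This is an assumed, simplified form for illustration, not the paper's exact loss (which is given in Eqs. (3) and (5)); the function name and argument layout are hypothetical.

```python
import numpy as np

def tpp_regularizer(conv_weight, bn_gamma, bn_beta, pruned_idx):
    """Illustrative TPP-style penalty (assumed form).
    conv_weight: (out_ch, in_ch, kh, kw); bn_gamma/bn_beta: (out_ch,);
    pruned_idx: indices of the filters designated unimportant."""
    n = conv_weight.shape[0]
    W = conv_weight.reshape(n, -1)
    gram = W @ W.T  # (n, n) filter correlation (Gram) matrix
    kept_idx = np.setdiff1d(np.arange(n), pruned_idx)
    # Penalize only the pruned-kept cross-correlation entries; entries among
    # kept filters are deliberately left unregularized.
    cross = gram[np.ix_(pruned_idx, kept_idx)]
    gram_loss = np.sum(cross ** 2)
    # Drive BN scale and bias of the to-be-pruned channels toward zero so
    # their eventual removal barely perturbs the network.
    bn_loss = np.sum(bn_gamma[pruned_idx] ** 2) + np.sum(bn_beta[pruned_idx] ** 2)
    return gram_loss + bn_loss
```

During training this term would be added to the task loss with a regularization coefficient; once the unimportant filters' weights and BN parameters have been driven to zero, the penalty vanishes and removing those filters leaves the kept filters' correlations untouched.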

Contributions. (1) We present the first filter pruning method (trainability preserving pruning) that effectively maintains trainability during pruning for modern deep networks, via a customized weight Gram matrix as the regularization target. (2) Apart from the weight regularization, a BN regularizer is introduced to account for the subsequent absence of BN parameters in pruning; this issue has been overlooked by most previous pruning papers, although it proves quite important for preserving trainability, especially in the large-sparsity regime. (3) Practically, the proposed method easily scales to large networks and datasets.



Figure 1: Illustration of the proposed TPP algorithm on a typical residual block. Weight parameters are classified into two groups, as in a typical pruning algorithm: important (white) and unimportant (orange or blue), right from the beginning (before any training starts), based on the filter L1-norms. The proposed TPP regularization terms are then enforced only on the unimportant parameters, which is the key to maintaining trainability when the unimportant weights are eventually eliminated from the network. Notably, the critical part of a regularization-based pruning algorithm lies in its specific regularization terms, i.e., Eqs. (3) and (5), which we will show perform more favorably than other alternatives (see Tabs. 1 and 10).

Code availability: https://github.com/MingSun-Tse/TPP.

