BALANCING TRAINING TIME VS. PERFORMANCE WITH BAYESIAN EARLY PRUNING

Abstract

Pruning is an approach to alleviating the overparameterization of deep neural networks (DNNs) by zeroing out, or pruning, DNN elements with little to no efficacy at a given task. In contrast to related works that prune before or after training, this paper presents a novel method to perform early pruning of DNN elements (e.g., neurons or convolutional filters) during the training process while preserving performance upon convergence. To achieve this, we model the future efficacy of DNN elements in a Bayesian manner conditioned upon efficacy data collected during training, and prune DNN elements which are predicted to have low efficacy after training completion. Empirical evaluations show that the proposed Bayesian early pruning improves the computational efficiency of DNN training: using our approach, we achieve a 48.6% faster training time for ResNet-50 on ImageNet, reaching a validation accuracy of 72.5%.

1. INTRODUCTION

Deep neural networks (DNNs) are known to be overparameterized (Allen-Zhu et al., 2019) as they usually have more learnable parameters than needed for a given learning task. So, a trained DNN contains many ineffectual parameters that can be safely pruned or zeroed out with little to no effect on its predictive accuracy. Pruning (LeCun et al., 1989) is an approach to alleviating the overparameterization of a DNN by identifying and removing its ineffectual parameters while preserving its predictive accuracy on the validation/test dataset. Pruning is typically applied to the DNN after training to speed up testing-time evaluation. For standard image classification tasks with the MNIST, CIFAR-10, and ImageNet datasets, it can reduce the number of learnable parameters by up to 50% or more while maintaining test accuracy (Han et al., 2015; Li et al., 2017; Molchanov et al., 2017). In particular, the overparameterization of a DNN also leads to considerable training time being wasted on those DNN elements (e.g., connection weights, neurons, or convolutional filters) which are eventually ineffectual after training and can thus be safely pruned. Our work in this paper considers early pruning of such DNN elements by identifying and removing them throughout the training process instead of after training.[1] As a result, this can significantly reduce the time incurred by the training process without much compromising the final test accuracy (upon convergence). Recent work (Section 5) in foresight pruning (Lee et al., 2019; Wang et al., 2020) shows that pruning heuristics applied at initialization work well for pruning connection weights without significantly degrading performance. In contrast to these works, we prune throughout the training procedure, which improves performance after convergence, albeit with somewhat longer training times. In this work, we pose early pruning as a constrained optimization problem (Section 3.1).
A key challenge in the optimization is accurately modeling the future efficacy of DNN elements. We achieve this through the use of a multi-output Gaussian process which models the belief of future efficacy conditioned upon efficacy measurements collected during training (Section 3.2). Although the posed optimization problem is NP-hard, we derive an efficient Bayesian early pruning (BEP) approximation algorithm which appropriately balances the inherent training time vs. performance trade-off in pruning prior to convergence (Section 3.3). Our algorithm relies on a measure of network element efficacy, termed saliency (LeCun et al., 1989). The development of saliency functions is an active area of research with no clear optimal choice. To accommodate this, our algorithm is agnostic, and therefore flexible, to the choice of saliency function. We use BEP to prune neurons and convolutional filters to achieve practical speedup during training (Section 4).[2] Our approach also compares favorably to state-of-the-art works such as SNIP (Lee et al., 2019), GraSP (Wang et al., 2020), and momentum-based dynamic sparse reparameterization (Dettmers & Zettlemoyer, 2019).
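To make the forecasting idea concrete, the following is a minimal sketch of predicting each element's saliency at convergence from measurements collected so far. It deliberately simplifies the paper's model: it fits an independent single-output Gaussian process with an RBF kernel per element (rather than a multi-output GP), and all function names, kernel choices, and hyperparameters are illustrative assumptions, not the authors' implementation.

```python
import numpy as np

def rbf_kernel(t1, t2, lengthscale=10.0, variance=1.0):
    """Squared-exponential kernel between two vectors of iteration indices."""
    d2 = (t1[:, None] - t2[None, :]) ** 2
    return variance * np.exp(-0.5 * d2 / lengthscale ** 2)

def gp_predict(t_obs, y, t_star, noise=1e-2):
    """Posterior mean of a GP regression at t_star given (t_obs, y)."""
    K = rbf_kernel(t_obs, t_obs) + noise * np.eye(len(t_obs))
    k_star = rbf_kernel(t_star, t_obs)
    mu = y.mean()  # constant prior mean set to the empirical mean
    return mu + k_star @ np.linalg.solve(K, y - mu)

def predict_final_saliency(saliency_history, T):
    """Forecast saliency at convergence (iteration T) for every element.

    saliency_history: (t, M) array of saliency measurements collected at
    SGD iterations 1..t. Returns a length-M array of predicted saliencies.
    """
    t, M = saliency_history.shape
    ts = np.arange(1, t + 1, dtype=float)
    t_star = np.array([float(T)])
    return np.array([gp_predict(ts, saliency_history[:, a], t_star)[0]
                     for a in range(M)])
```

Under this sketch, an element whose measured saliency trends toward zero receives a low predicted final saliency and becomes a pruning candidate well before iteration T.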

2. PRUNING

Consider a dataset of $D$ training examples $\mathcal{X} = \{x_1, \ldots, x_D\}$, $\mathcal{Y} = \{y_1, \ldots, y_D\}$ and a neural network $N_{v_t}$ parameterized by a vector of $M$ pruneable network elements (e.g., weight parameters, neurons, or convolutional filters) $v_t \triangleq [v_t^a]_{a=1,\ldots,M}$, where $v_t$ represents the network elements after $t$ iterations of stochastic gradient descent (SGD) for $t = 1, \ldots, T$. Let $\mathcal{L}(\mathcal{X}, \mathcal{Y}; N_{v_t})$ be the loss function for the neural network $N_{v_t}$. Pruning aims at refining the network elements $v_T$ given some sparsity budget $B$ while preserving the accuracy of the neural network after convergence (i.e., $N_{v_T}$), which can be stated as a constrained optimization problem (Molchanov et al., 2017):

$$\min_{m \in \{0,1\}^M} \big| \mathcal{L}(\mathcal{X}, \mathcal{Y}; N_{m \odot v_T}) - \mathcal{L}(\mathcal{X}, \mathcal{Y}; N_{v_T}) \big| \quad \text{s.t.} \quad \|m\|_0 \le B \tag{1}$$

where $\odot$ is the Hadamard product and $m$ is a pruning mask. Note that we abuse the Hadamard product for notational simplicity: for $a = 1, \ldots, M$, $m_a \times v_T^a$ corresponds to pruning $v_T^a$ if $m_a = 0$, and keeping $v_T^a$ otherwise. Pruning a network element refers to zeroing the network element or the weight parameters which compute the network element. Any weight parameters which reference the output of the pruned network element are also zeroed since the element outputs a constant 0. The above optimization problem is difficult due to the NP-hardness of combinatorial optimization. This leads to the approach of using a saliency function $s$ which measures the efficacy of network elements at minimizing the loss function. A network element with small saliency can be pruned since it is not salient in minimizing the loss function. Consequently, pruning can be done by maximizing the saliency of the network elements given the sparsity budget $B$:

$$\max_{m \in \{0,1\}^M} \sum_{a=1}^{M} m_a \, s(a; \mathcal{X}, \mathcal{Y}, N_{v_T}, \mathcal{L}) \quad \text{s.t.} \quad \|m\|_0 \le B \tag{2}$$

where $s(a; \mathcal{X}, \mathcal{Y}, N_{v_T}, \mathcal{L})$ measures the saliency of $v_T^a$ at minimizing $\mathcal{L}$ after convergence through $T$ iterations of SGD.
The above optimization problem can be efficiently solved by selecting the $B$ most salient network elements in $v_T$. The construction of the saliency function has been discussed in many existing works: some approaches derive the saliency function from first-order (LeCun et al., 1989; Molchanov et al., 2017) and second-order (Hassibi & Stork, 1992; Wang et al., 2020) Taylor series approximations of $\mathcal{L}$. Other common saliency functions include the $L_1$ (Li et al., 2017) or $L_2$ (Wen et al., 2016) norm of the network element weights, as well as mean activation (Polyak & Wolf, 2015). In this work, we use a first-order Taylor series approximation saliency function defined for neurons and convolutional filters[3] (Molchanov et al., 2017); however, our approach remains flexible to an arbitrary choice of saliency function on a plug-and-play basis.
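A minimal sketch of both pieces follows: a first-order Taylor-style saliency in the spirit of Molchanov et al. (2017), scored as the batch-averaged $|$activation $\times$ gradient$|$ per element, and the greedy top-$B$ selection that solves (2). The array-based interface is an illustrative assumption, not the paper's implementation.

```python
import numpy as np

def taylor_saliency(activations, gradients):
    """First-order Taylor-style saliency: |activation * dL/d(activation)|,
    averaged over the batch.

    activations, gradients: (batch, M) arrays for M pruneable elements.
    Returns a length-M saliency vector.
    """
    return np.abs(activations * gradients).mean(axis=0)

def top_b_mask(saliency, B):
    """Solve (2) exactly: keep the B most salient elements, prune the rest."""
    mask = np.zeros(saliency.shape[0], dtype=int)
    mask[np.argsort(saliency)[-B:]] = 1  # indices of the B largest saliencies
    return mask
```

Applying `mask` elementwise to $v_T$ (the Hadamard product $m \odot v_T$ of (1)) then yields the pruned network.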

3.1. PROBLEM STATEMENT

As mentioned before, existing pruning works based on the saliency function typically prune after training convergence (i.e., (2)) to speed up testing-time evaluation, which wastes considerable time on training those network elements which will eventually be pruned. To resolve this issue, we extend the pruning problem definition (2) along the temporal dimension, allowing network elements to be pruned during the training process consisting of $T$ iterations of SGD.



Footnotes:
[1] In contrast, foresight pruning (Wang et al., 2020) removes DNN elements prior to the training process.
[2] Popular deep learning libraries do not accelerate sparse matrix operations over dense matrix operations. Thus, pruning network connections cannot be easily capitalized upon with performance improvements. It is also unclear whether moderately sparse matrix operations (i.e., operations on matrices generated by connection pruning) can be significantly accelerated on massively parallel architectures such as GPUs (see Yang et al. (2018), Fig. 7). See Section 5 in Buluç & Gilbert (2008) for challenges in parallel sparse matrix multiplication.
[3] Implementation details of this saliency function can be found in Appendix A.1.

