ROBUSTNESS TO PRUNING PREDICTS GENERALIZATION IN DEEP NEURAL NETWORKS

Abstract

Why over-parameterized neural networks generalize as well as they do is a central concern of theoretical analysis in machine learning today. Following Occam's razor, it has long been suggested that simpler networks generalize better than more complex ones. Successfully quantifying this principle has proved difficult, given that many measures of simplicity, such as parameter norms, grow with the size of the network and thus fail to capture the observation that larger networks tend to generalize better in practice. In this paper, we introduce a new, theoretically motivated measure of a network's simplicity: the smallest fraction of the network's parameters that can be kept while pruning without adversely affecting its training loss. We show that this measure is highly predictive of a model's generalization performance across a large set of convolutional networks trained on CIFAR-10. Lastly, we study the mutual information between the predictions of our new measure and strong existing measures based on models' margins, the flatness of their minima, and their optimization speed. We show that our new measure is similar to, but more predictive than, existing flatness-based measures.

1. INTRODUCTION

The gap between learning-theoretic generalization bounds for highly overparameterized neural networks and their empirical generalization performance remains a fundamental mystery to the field (Zhang et al., 2016; Jiang et al., 2020; Allen-Zhu et al., 2019). While these models are already being used successfully in many applications, improving our understanding of how neural networks perform on unseen data is crucial for safety-critical use cases. By understanding which factors drive generalization in neural networks, we may further be able to develop more efficient and performant network architectures and training methods. Numerous theoretically and empirically motivated attempts have been made to identify generalization measures, that is, properties of the trained model, training procedure, and training data that distinguish models that generalize well from those that do not (Jiang et al., 2020). A number of generalization measures have attempted to quantify Occam's razor, i.e. the principle that simpler models generalize better than complex ones (Neyshabur et al., 2015; Bartlett et al., 2017). This has proven to be non-trivial, as many measures, particularly norm-based measures, grow with the size of the model and thus incorrectly predict that larger networks generalize worse than smaller ones. Other approaches have tried to establish a connection between model compression and generalization (Arora et al., 2018; Zhou et al., 2019). While these approaches are theoretically elegant and yield tighter bounds than bounds based on the size of uncompressed networks, the resulting bounds nonetheless grow with the size of the original network. Recent empirical studies (Jiang et al., 2020; 2019), on the other hand, identify three classes of generalization measures that do seem predictive of generalization: measures that estimate the flatness of local minima, the speed of optimization, and the margin of training samples to decision boundaries.
While these measures are correlated with generalization, their failure to fully explain the test performance of models demonstrates a need for other notions of model simplicity to explain generalization in neural networks. In this paper, we leverage the empirical observation that large fractions of trained neural networks' parameters can be pruned, that is, set to 0, without hurting the models' performance (Gale et al., 2019; Zhu & Gupta, 2018; Han et al., 2015). Based on this insight, we introduce a new measure of a model's simplicity which we call prunability: the smallest fraction of weights we can keep while pruning a network without hurting its training loss. In a range of empirical studies, we demonstrate that a model's prunability is indeed highly informative of its generalization. In particular, we find that the larger the fraction of parameters that can be pruned without hurting a model's training loss, the better the model will generalize. Overall, we show that the smaller the fraction of parameters a model actually "uses", i.e. the simpler it is, the better the network's generalization performance. In summary, our main contributions are the following:

1. We introduce a new generalization measure called prunability that captures a model's simplicity (Sec. 4) and show that across a large set of models this measure is highly informative of a network's generalization performance (Sec. 5.4.1).

2. We show that even in a particularly challenging setting in which we observe a test loss double descent (Nakkiran et al., 2020; He et al., 2016), prunability is informative of models' test performance and is competitive with existing strong generalization measures (Sec. 5.4.2).

3. Lastly, we investigate whether the success of prunability can be explained by its relationship to flat local minima. We find that while prunability makes similar predictions to some existing measures that estimate the flatness of minima, it differs from them in important ways, in particular exhibiting a stronger correlation with generalization performance than these flatness-based measures (Sec. 5.4.1 and 5.4.3).
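To make the definition of prunability concrete, the following minimal Python sketch binary-searches for the smallest keep-fraction under simple global magnitude pruning such that the training loss stays within a small tolerance of its unpruned value. The weight vector, the quadratic "training loss", and the `magnitude_prune` helper are hypothetical stand-ins for illustration, not the paper's actual pruning procedure or models.

```python
def magnitude_prune(weights, keep_frac):
    """Zero out all but the largest-magnitude `keep_frac` of the weights."""
    n_keep = max(1, round(keep_frac * len(weights)))
    threshold = sorted((abs(w) for w in weights), reverse=True)[n_keep - 1]
    return [w if abs(w) >= threshold else 0.0 for w in weights]

def prunability(weights, loss_fn, tol=1e-3, steps=20):
    """Binary-search the smallest keep-fraction whose pruned loss
    stays within `tol` of the unpruned training loss."""
    base = loss_fn(weights)
    lo, hi = 0.0, 1.0
    for _ in range(steps):
        mid = (lo + hi) / 2
        if loss_fn(magnitude_prune(weights, mid)) <= base + tol:
            hi = mid  # pruning this hard is still harmless
        else:
            lo = mid  # too aggressive; must keep more weights
    return hi

# Toy "model": the loss depends only on the two large weights,
# so the remaining small weights should be prunable.
weights = [3.0, -2.5, 0.01, -0.02, 0.005, 0.0]
loss = lambda w: (w[0] - 3.0) ** 2 + (w[1] + 2.5) ** 2
print(prunability(weights, loss))  # → 0.25: only 2 of the 6 weights are kept
```

Under this sketch, a smaller returned fraction means a more prunable, and by our hypothesis simpler, model; in practice the loss would be the network's training loss evaluated over the training set rather than a closed-form function.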

2. RELATED WORK

Zhang et al. (2016) demonstrate that neural networks can perfectly memorize randomly labeled data while still attaining good test set performance on correctly labeled data, a phenomenon that complexity measures such as the VC-dimension and Rademacher complexity fail to explain. More recently, increasing the parameter count of models has been shown to improve their generalization performance, yielding 'double descent' curves in which the test loss first decreases, then increases, and then decreases again as the parameter count grows (He et al., 2016; Nakkiran et al., 2020; Belkin et al., 2019). In contrast, many existing generalization measures grow monotonically with the model's size and thus fail to capture double descent, making it a particularly interesting setting in which to study new generalization measures (Maddox et al., 2020). While generalization measures have been studied for a long time, Jiang et al. (2020) recently brought new empirical rigor to the field. They perform a large-scale empirical study by generating a large set of trained neural networks with a wide range of generalization gaps. Proposing a range of new evaluation criteria, they test how informative previously proposed generalization measures actually are of generalization. In this paper, we evaluate our new generalization measure in the same framework and use the strongest measures as baselines for comparison against our proposed measure, the neural network's prunability. A common theme across many generalization measures is that they try to formalize Occam's razor, the idea that simpler models generalize better than more complex models.
Rissanen (1986) formalizes this principle as a model's Minimum Description Length, which was later applied to small neural networks by Hinton & Van Camp (1993). Similarly, Neyshabur et al. (2015; 2017) and numerous other approaches suggest that networks with smaller parameter norms generalize better. We are aware of two existing generalization measures that are based on compression or pruning. Arora et al. (2018) derive a PAC-Bayesian generalization bound based on the size of a network following a particular compression method based on noise stability. Zhou et al. (2019), on the other hand, formulate a generalization bound that is non-vacuous even for large network architectures based on the description length of a compressed network, which in turn also grows with the size of the original network. While these approaches are theoretically elegant, all such measures grow with the number of parameters in a given network and thus do not capture the phenomenon that more highly overparameterized networks tend to generalize better. Three classes of generalization measures seem to be particularly predictive of generalization. First, measures that estimate the flatness of local minima, either based on random perturbations of the

