GENERALIZATION BOUNDS VIA DISTILLATION

Abstract

This paper theoretically investigates the following empirical phenomenon: given a high-complexity network with poor generalization bounds, one can distill it into a network with nearly identical predictions but low complexity and vastly smaller generalization bounds. The main contribution is an analysis showing that the original network inherits this good generalization bound from its distillation, assuming the use of well-behaved data augmentation. This bound is presented both in an abstract and in a concrete form, the latter complemented by a reduction technique to handle modern computation graphs featuring convolutional layers, fully-connected layers, and skip connections, to name a few. To round out the story, a (looser) classical uniform convergence analysis of compression is also presented, as well as a variety of experiments on cifar10 and mnist demonstrating similar generalization performance between the original network and its distillation.

1. OVERVIEW AND MAIN RESULTS

Generalization bounds are statistical tools which take as input various measurements of a predictor on training data, and output a performance estimate for unseen data; that is, they estimate how well the predictor generalizes to unseen data. Despite extensive development spanning many decades (Anthony & Bartlett, 1999), there is growing concern that these bounds are not only disastrously loose (Dziugaite & Roy, 2017), but worse that they do not correlate with the underlying phenomena (Jiang et al., 2019b), and even that the basic method of proof is doomed (Zhang et al., 2016; Nagarajan & Kolter, 2019). As an explicit demonstration of the looseness of these bounds, Figure 1 calculates bounds for a standard ResNet architecture achieving test errors of respectively 0.008 and 0.067 on mnist and cifar10; the observed generalization gap is $10^{-1}$, while standard generalization techniques upper bound it with $10^{15}$.

In contrast with this dilemma, there is evidence that these networks can often be compressed or distilled into simpler networks, while still preserving their output values and low test error. Meanwhile, these simpler networks exhibit vastly better generalization bounds: again referring to Figure 1, those same networks from before can be distilled with hardly any change to their outputs, while their bounds shrink by a factor of roughly $10^{10}$. Distillation is widely studied (Buciluȗ et al., 2006; Hinton et al., 2015), but usually the original network is discarded and only the final distilled network is preserved. The purpose of this work is to carry the good generalization bounds of the distilled network back to the original network; in a sense, the explicit simplicity of the distilled network is used as a witness to the implicit simplicity of the original network. The main contributions are as follows.
Figure 1: Generalization bounds throughout distillation. These two subfigures track a sequence of increasingly distilled/compressed ResNet8 networks along their horizontal axes, respectively for cifar10 and mnist data. This horizontal axis measures distillation distance $\Phi_{\gamma,m}$, as defined below in eq. (1.1). The bottom curves measure various training and testing errors, whereas the top two curves measure respectively a generalization bound presented here (cf. Theorem 1.3 and Lemma 3.1) and a generalization measure. Notably, the top two curves drop throughout a long interval during which test error remains small. For further experimental details, see Section 2.

• The main theoretical contribution is a generalization bound for the original, undistilled network which scales primarily with the generalization properties of its distillation, assuming that well-behaved data augmentation is used to measure the distillation distance. An abstract version of this bound is stated in Lemma 1.1, along with a sufficient data augmentation technique in Lemma 1.2. A concrete version of the bound, suitable to handle the ResNet architecture in Figure 1, is described in Theorem 1.3. Handling sophisticated architectures with only minor proof alterations is another contribution of this work, and is described alongside Theorem 1.3. This abstract and concrete analysis is sketched in Section 3, with full proofs deferred to appendices.

• Rather than using an assumption on the distillation process (e.g., the aforementioned "well-behaved data augmentation"), this work also gives a direct uniform convergence analysis, culminating in Theorem 1.4. This is presented partially as an open problem or cautionary tale, as its proof is vastly more sophisticated than that of Theorem 1.3, but ultimately results in a much looser analysis. This analysis is sketched in Section 3, with full proofs deferred to appendices.
• While this work is primarily theoretical, it is motivated by Figure 1 and related experiments: Figures 2 to 4 demonstrate that not only does distillation improve generalization upper bounds, but moreover it makes them sufficiently tight to capture intrinsic properties of the predictors, for example removing the usual bad dependence on width in these bounds (cf. Figure 3 ). These experiments are detailed in Section 2.

1.1. AN ABSTRACT BOUND VIA DATA AUGMENTATION

This subsection describes the basic distillation setup and the core abstract bound based on data augmentation, culminating in Lemmas 1.1 and 1.2; a concrete bound follows in Section 1.2.

Given a multi-class predictor $f : \mathbb{R}^d \to \mathbb{R}^k$, distillation finds another predictor $g : \mathbb{R}^d \to \mathbb{R}^k$ which is simpler, but close in distillation distance $\Phi_{\gamma,m}$, meaning the softmax outputs $\phi_\gamma$ are close on average over a set of points $(z_i)_{i=1}^m$:

\[
\Phi_{\gamma,m}(f, g) := \frac{1}{m} \sum_{i=1}^m \big\| \phi_\gamma(f(z_i)) - \phi_\gamma(g(z_i)) \big\|_1,
\qquad \text{where } \phi_\gamma(f(z)) \propto \exp\big(f(z)/\gamma\big).
\tag{1.1}
\]

The quantity $\gamma > 0$ is sometimes called a temperature (Hinton et al., 2015). Decreasing $\gamma$ increases sensitivity near the decision boundary; in this way, it is naturally related to the concept of margins in generalization theory, as detailed in Appendix B. Due to these connections, the use of softmax is beneficial in this work, though not completely standard in the literature (Buciluȗ et al., 2006).

We can now outline Figure 1 and the associated empirical phenomenon which motivates this work. (Please see Section 2 for further details on these experiments.) Consider a predictor $f$ which has good test error but bad generalization bounds; by treating the distillation distance $\Phi_{\gamma,m}(f, g)$ as an objective function and increasingly regularizing $g$, we obtain a sequence of predictors $(g_0, \ldots, g_t)$, where $g_0 = f$, which trade off between distillation distance and predictor complexity. The curves in Figure 1 are produced in exactly this way, and demonstrate that there are predictors nearly identical to the original $f$ which have vastly smaller generalization bounds. Our goal here is to show that this is enough to imply that $f$ in turn must also have good generalization bounds, despite its apparent complexity. To sketch the idea, by a bit of algebra (cf. Lemma A.2), we can upper bound error probabilities with expected distillation distances and errors:

\[
\Pr_{x,y}\Big[ \operatorname*{arg\,max}_{y'} f(x)_{y'} \neq y \Big]
\;\leq\; 2\, \mathbb{E}_x \big\| \phi_\gamma(f(x)) - \phi_\gamma(g(x)) \big\|_1
\;+\; 2\, \mathbb{E}_{x,y} \big[ 1 - \phi_\gamma(g(x))_y \big].
\]
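As a concrete illustration (not code from the paper), the temperature softmax $\phi_\gamma$ and the distillation distance of eq. (1.1) can be written in a few lines of numpy, and the error-probability bound above can be sanity-checked by Monte Carlo on synthetic logits; the random logits, perturbation scale, and sample size below are illustrative assumptions, not the paper's experimental setup.

```python
import numpy as np

def softmax(logits, gamma=1.0):
    """Temperature softmax phi_gamma: rows are examples, columns are classes."""
    z = logits / gamma
    z = z - z.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def distillation_distance(f_logits, g_logits, gamma=1.0):
    """Phi_{gamma,m}(f, g): average L1 distance between temperature-softmax outputs."""
    diff = softmax(f_logits, gamma) - softmax(g_logits, gamma)
    return np.abs(diff).sum(axis=1).mean()

# Synthetic setup: g is a small perturbation of f, labels are arbitrary.
rng = np.random.default_rng(0)
m, k, gamma = 5000, 10, 1.0
f_logits = rng.normal(size=(m, k))
g_logits = f_logits + 0.1 * rng.normal(size=(m, k))
y = rng.integers(0, k, size=m)

phi_g = softmax(g_logits, gamma)
err_f = (f_logits.argmax(axis=1) != y).mean()            # Pr[argmax f(x) != y]
dist = distillation_distance(f_logits, g_logits, gamma)  # E ||phi(f) - phi(g)||_1
soft_err_g = (1.0 - phi_g[np.arange(m), y]).mean()       # E [1 - phi(g(x))_y]

# Empirical check of the bound from Lemma A.2 on this sample.
assert err_f <= 2 * dist + 2 * soft_err_g
```

Note that each $\|\phi_\gamma(f(z)) - \phi_\gamma(g(z))\|_1$ term lies in $[0, 2]$ since both arguments are probability vectors, so $\Phi_{\gamma,m}$ is likewise bounded by 2; shrinking `gamma` sharpens both softmax outputs and makes the distance more sensitive to disagreements near the decision boundary.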

