CATASTROPHIC FISHER EXPLOSION: EARLY PHASE FISHER MATRIX IMPACTS GENERALIZATION

Abstract

The early phase of training has been shown to be important for deep neural networks in two ways. First, the degree of regularization in this phase significantly impacts the final generalization. Second, it is accompanied by a rapid change in the local loss curvature that is influenced by regularization choices. Connecting these two findings, we show that stochastic gradient descent (SGD) implicitly penalizes the trace of the Fisher Information Matrix (FIM) from the beginning of training. We argue that this is an implicit regularizer in SGD by showing that explicitly penalizing the trace of the FIM can significantly improve generalization. We further show that the early value of the trace of the FIM correlates strongly with the final generalization. We highlight that in the absence of implicit or explicit regularization, the trace of the FIM can increase to a large value early in training, which we refer to as catastrophic Fisher explosion. Finally, to gain insight into the regularization effect of penalizing the trace of the FIM, we show that (1) it limits memorization by reducing the learning speed of examples with noisy labels more than that of clean examples, and (2) trajectories with a low initial trace of the FIM end in flat minima, which are commonly associated with good generalization.

1. INTRODUCTION

Implicit regularization in gradient-based training of deep neural networks (DNNs) remains relatively poorly understood, despite being considered a critical component in their empirical success (Neyshabur, 2017; Zhang et al., 2016; Jiang et al., 2020b). Recent work suggests that the early phase of training of DNNs might hold the key to understanding these implicit regularization effects. Golatkar et al. (2019); Keskar et al. (2017); Sagun et al. (2018); Achille et al. (2019) show that a drop in performance due to a lack of regularization in this phase is hard to recover from by introducing regularization later, while, on the other hand, removing regularization after the early phase has a relatively small effect on the final performance. Other works show that the early phase of training also has a dramatic effect on the trajectory in terms of properties such as the local curvature of the loss surface or the gradient norm (Jastrzebski et al., 2020; Frankle et al., 2020). In particular, Achille et al. (2019) show that using a large learning rate has a dramatic effect on the early optimization trajectory in terms of the loss curvature. These observations lead to a question: what is the mechanism by which regularization in the early phase impacts the optimization trajectory and generalization? We investigate this question mainly through the lens of the Fisher Information Matrix (FIM), a matrix that can be seen as approximating the local curvature of the loss surface in DNNs (Martens, 2020; Thomas et al., 2020). Our main contribution is to show that the implicit regularization effect of using a large learning rate or a small batch size can be modeled as an implicit penalization of the trace of the FIM (Tr(F)) from the very beginning of training. We demonstrate on image classification tasks that the value of Tr(F) early in training correlates with the final generalization performance across settings with different learning rates or batch sizes.
We then show evidence that explicitly regularizing Tr(F) (which we call the Fisher penalty) significantly improves generalization when training with a sub-optimal learning rate. On the other hand, growth of Tr(F) early in training, which may occur in practice when using a relatively small learning rate, coincides with poor generalization. We call this phenomenon the catastrophic Fisher explosion. Figure 1 illustrates this effect on the TinyImageNet dataset (Le & Yang, 2015). Our second contribution is an analysis of why implicitly or explicitly regularizing Tr(F) impacts generalization. We reveal two effects of implicit or explicit regularization of Tr(F): (1) penalizing Tr(F) discourages memorizing noisy labels, and (2) a small Tr(F) in the early phase of training biases optimization towards a flat minimum, as characterized by the trace of the Hessian.
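Concretely, the Fisher penalty amounts to minimizing the training loss plus a term proportional to Tr(F). The following is a minimal sketch on a toy binary logistic-regression model, where Tr(F) has the closed form E_x[p(1-p)||x||²]. The function names, the values of `lam` and `lr`, and the finite-difference gradient for the penalty term are illustrative assumptions of this sketch, not the paper's settings; the paper applies the penalty to deep networks via automatic differentiation.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def ce_loss(w, X, y):
    p = sigmoid(X @ w)
    return -np.mean(y * np.log(p + 1e-12) + (1 - y) * np.log(1 - p + 1e-12))

def fisher_trace(w, X):
    # For logistic regression, E_{yhat ~ p}[(p - yhat)^2] = p(1 - p),
    # so Tr(F) = E_x[ p(1 - p) ||x||^2 ] in closed form.
    p = sigmoid(X @ w)
    return np.mean(p * (1 - p) * np.sum(X ** 2, axis=1))

def fisher_penalty_step(w, X, y, lam=0.1, lr=0.1, eps=1e-6):
    """One gradient step on ce_loss + lam * Tr(F).

    The cross-entropy gradient is analytic; the penalty gradient is taken
    by central finite differences for brevity (in practice one would
    differentiate through Tr(F) with autodiff / double backprop).
    """
    p = sigmoid(X @ w)
    g_ce = X.T @ (p - y) / len(y)
    g_pen = np.zeros_like(w)
    for i in range(len(w)):
        e = np.zeros_like(w)
        e[i] = eps
        g_pen[i] = (fisher_trace(w + e, X) - fisher_trace(w - e, X)) / (2 * eps)
    return w - lr * (g_ce + lam * g_pen)
```

In this toy setting a single step decreases the penalized objective; the strength of the regularization is controlled by `lam`.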

2. IMPLICIT AND EXPLICIT REGULARIZATION OF THE FIM

Fisher Information Matrix. Consider a probabilistic classification model p_θ(y|x), where θ denotes its parameters. Let ℓ(x, y; θ) be the cross-entropy loss function calculated for input x and label y, and let g(x, y; θ) = ∂ℓ(x, y; θ)/∂θ denote the gradient computed for an example (x, y). The central object that we study is the Fisher Information Matrix F, defined as

F(θ) = E_{x∼X, ŷ∼p_θ(y|x)} [ g(x, ŷ; θ) g(x, ŷ; θ)^T ],   (1)

where the expectation is often approximated using the empirical distribution X induced by the training set. We denote its trace by Tr(F). Later, we also look into the Hessian H(θ) = ∂²ℓ(x, y; θ)/∂θ², whose trace we denote by Tr(H). The FIM can be seen as an approximation to the Hessian (Martens, 2020). In particular, as p_θ(y|x) → p(y|x), where p(y|x) is the empirical label distribution, the FIM converges to the Hessian.
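Definition (1) implies Tr(F) = E_{x, ŷ}[||g(x, ŷ; θ)||²], i.e., the trace equals the expected squared norm of gradients taken at labels sampled from the model's own predictive distribution. The following is a minimal sketch for a toy softmax classifier with logits z = Wx (the model and the names `fisher_trace_exact` / `fisher_trace_mc` are illustrative, not from the paper), comparing a closed-form computation with the Monte Carlo estimate that sampling ŷ ∼ p_θ(y|x) suggests:

```python
import numpy as np

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def fisher_trace_exact(W, X):
    """Tr(F) for logits z = W @ x under cross-entropy loss.

    The gradient at a sampled label yhat is (p - e_yhat) x^T, hence
    Tr(F) = E_x[ ||x||^2 * sum_k p_k ||p - e_k||^2 ].
    """
    total = 0.0
    for x in X:
        p = softmax(W @ x)
        eye = np.eye(len(p))
        s = sum(p[k] * np.sum((p - eye[k]) ** 2) for k in range(len(p)))
        total += np.dot(x, x) * s
    return total / len(X)

def fisher_trace_mc(W, X, n_samples=300, rng=None):
    """Monte Carlo estimate of Tr(F): average squared gradient norm
    over labels sampled from the model's predictive distribution."""
    rng = rng if rng is not None else np.random.default_rng(0)
    total = 0.0
    for x in X:
        p = softmax(W @ x)
        p = p / p.sum()  # guard against floating-point drift before sampling
        for _ in range(n_samples):
            k = rng.choice(len(p), p=p)
            g = np.outer(p - np.eye(len(p))[k], x)  # per-example gradient
            total += np.sum(g ** 2)
    return total / (len(X) * n_samples)
```

For deep networks the closed form is unavailable, and Tr(F) is estimated as in `fisher_trace_mc`: sample labels from the model, backpropagate, and average the squared gradient norms.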



Jastrzębski et al. (2019); Golatkar et al. (2019); Lewkowycz et al. (2020); Leclerc & Madry (2020) independently suggest that rapid changes in the local curvature of the loss surface in the early phase critically affect the final generalization.

Figure 1: The catastrophic Fisher explosion phenomenon demonstrated for a Wide ResNet trained using stochastic gradient descent on the TinyImageNet dataset. Training is done with either a learning rate optimized using grid search (η₁ = 0.0316, red) or a small learning rate (η₂ = 0.001, blue). Training with η₂ leads to large overfitting (left) and a sharp increase in the trace of the Fisher Information Matrix (FIM, middle). The trace of the FIM is closely related to the gradient norm (right).

Figure 2: Association between the value of Tr(F) in the initial phase of training (Tr(F_i)) and test accuracy on the ImageNet, CIFAR-10 and CIFAR-100 datasets. Each point corresponds to multiple seeds and a specific value of the learning rate. Tr(F_i) is recorded during the early phase of training (2-7 epochs; see the main text for details). The plots show that early Tr(F) is predictive of final generalization. Analogous results illustrating the influence of batch size are shown in Appendix A.1.

