ON FLAT MINIMA, LARGE MARGINS AND GENERALIZABILITY

Abstract

The intuitive connection to robustness and convincing empirical evidence have made the flatness of the loss surface an attractive measure of generalizability for neural networks. Yet it suffers from various problems such as computational difficulties, reparametrization issues, and a growing concern that it may only be an epiphenomenon of optimization methods. We provide empirical evidence that, under the cross-entropy loss, once a neural network reaches a non-trivial training error, the flatness correlates well (as measured by the Pearson correlation coefficient) with the classification margins, which allows us to better reason about the concerns surrounding flatness. Our results lead to the practical recommendation that when assessing generalizability one should consider a margin-based measure instead, as it is computationally more efficient, provides further insight, and is highly correlated with flatness. We also use our insight to replace the misleading folklore that small-batch methods generalize better because they are able to escape sharp minima. Instead we argue that large-batch methods did not have enough time to maximize margins and hence generalize worse.

1. INTRODUCTION

Understanding under which conditions a neural network will generalize from seen to unseen data is crucial, as it motivates design choices and principles which can greatly improve performance. Complexity or generalization measures are used to quantify the properties of a neural network that lead to good generalization. Currently, however, established complexity measures such as the VC-dimension (Vapnik, 1998) or Rademacher complexity (Bartlett & Mendelson, 2002) do not correlate with the generalizability of neural networks (e.g. see Zhang et al. (2016)). Hence many recommendations, such as reducing model complexity, early stopping, or adding explicit regularization, are no longer applicable or necessary. There is therefore an ongoing effort to devise new complexity measures that may guide recommendations on how to obtain models that generalize well.

A popular approach is to consider the flatness of the loss surface around a neural network. Hochreiter & Schmidhuber (1997) used the minimum description length (MDL) argument of Hinton & Van Camp (1993) to claim that the flatness of a minimum can be used as a generalization measure. Motivated by this measure, Hochreiter & Schmidhuber (1997), and more recently Chaudhari et al. (2019), developed algorithms with explicit regularization intended to converge to flat solutions. Keskar et al. (2016) then presented empirical evidence that flatness relates to improved generalizability and used it to explain the behavior of stochastic gradient descent (SGD) with large and small batch sizes. Other works since have empirically corroborated that flatter minima generalize better (e.g. Jiang et al. (2019); Li et al. (2018); Bosman et al. (2020)).

There are, however, various unresolved issues that make it difficult to construct practical deep learning recommendations from flatness. For one, flatness is computationally expensive to compute. The most common way to compute it is via the Hessian, which grows quadratically in the number of parameters and becomes too large for modern networks containing millions of parameters. It is also not clear to what extent flatness is a true measure of generalizability, capable of discerning which neural network will or will not generalize. Dinh et al. (2017) showed that reparametrizations affect flatness and that a flat model can be made arbitrarily sharp without changing any of its generalization properties. In addition, Probably Approximately Correct Bayes (PAC-Bayes) bounds that bound the generalizability in terms of the flatness are also either affected by rescaling, impossible to evaluate, or loose (Neyshabur et al., 2017; Arora et al., 2018; Petzka et al., 2020). While there have been attempts to resolve the issues around reparametrization (Liang et al., 2019; Tsuzuku et al., 2019), it remains to be established whether flatness is merely an epiphenomenon of stochastic gradient descent or of other complexity measures, as Achille et al. (2018) and Jastrzebski et al. (2018) suggest.

This motivates investigating possible correlations with more well-understood measures of generalization that may help alleviate the issues surrounding flat minima, while allowing flat minima to be used when appropriate. In this paper we demonstrate a correlation with classification margins, which are a well-understood generalization measure. Margins represent the linearized distance to the decision boundaries of the classification region (Elsayed et al., 2018).
An immediate consequence of such a relationship is that, to assess generalizability, we can simply use a computationally cheap and more robust margin-based complexity measure. Our contributions demonstrate further practical implications of the relationship between margins and flatness, which open doors to valuable future work such as a better understanding of why and when a model generalizes and more principled algorithm design.

• We show that under certain conditions flatness and margins are strongly correlated. We do so by deriving the Hessian trace for the affine classifier. Based on its form, we derive an expression in terms of classification margins which we show correlates increasingly well with the Hessian trace as training accuracy increases, for various neural network architectures. By being able to relate the two complexity measures, we can provide various practical recommendations and offer different perspectives on phenomena that may not be explainable without such a view. These are shown in the following contributions.

• We use our insight to replace the misleading folklore that, unlike large-batch methods, small-batch methods are able to escape sharp minima (Keskar et al., 2016). We instead employ a margin perspective and use our empirical results along with recent results by Banburski et al. (2019) and Hoffer et al. (2017) to argue that a large-batch method was unable to train long enough to maximize the margins. With our explanation, we help reframe the small and large-batch discussion and build further intuition.

• We show that once a neural network is able to correctly predict the label of every element in the training set, it can be made arbitrarily flat by scaling the last layer. We are motivated by the relationship to margins, which suffer from the same issue. We highlight this scaling issue because, in some instances, it may still be beneficial for algorithm design to be guided by convergence to flat regions. Hence, we need to account for scaling issues which make it difficult to use flatness to assess whether a network generalizes better than another.

Other works have made connections between flatness and well-behaved classification margins via visualizations (see Huang et al. (2019); Wang et al. (2018)), but they have not demonstrated a quantifiable relationship. Further work has used both the classification margins and flatness to construct PAC-Bayes bounds (Neyshabur et al., 2017; Arora et al., 2018) and has related flatness to increased robustness (Petzka et al., 2020; Borovykh et al., 2019), but did not show when and to what extent these quantities are related.

We structure the paper as follows. In Section 2, we discuss our notation, motivate the choice of the cross-entropy loss and of the Hessian trace as the flatness measure, and provide further background on the classification margins. In Section 3, we present our contribution showing a strong correlation between the margins and flatness by deriving the Hessian trace of an affine classifier and extending the result empirically to the non-linear case. In Section 4, we combine recent results based on classification margins to offer a different perspective on the misleading folklore on why large-batch methods generalize worse. In Section 5, we highlight that networks can be made arbitrarily flat. Lastly, we offer our thoughts and future work in Section 6.

2. PROBLEM SETTING

We first define the basic notation that we use for a classification task. We let $X$ represent the input space and $Y = \{1, \ldots, C\}$ the output space, where $C$ is the number of possible classes. The network architecture is given by $\phi: \Theta \times X \to \mathbb{R}^{|Y|}$, where $\Theta$ is the corresponding parameter space. We measure the performance of a parameter vector by defining some loss function $\ell: \mathbb{R}^C \times Y \to \mathbb{R}$. If we have a joint probability distribution $\mathcal{D}$ relating input and output space, then we would like to minimize the expected loss $L_{\mathcal{D}}(\theta) = \mathbb{E}_{(x,y)\sim\mathcal{D}}[\ell(\phi(\theta, x), y)]$. Since we usually only have access to some finite dataset $D$, we denote the empirical loss by $L_D(\theta) = \frac{1}{|D|}\sum_{i=1}^{|D|} \ell(\phi(\theta, x_i), y_i)$. If $L_{\mathcal{D}}$ and $L_D$ are close, then we say a model generalizes well, as we were able to train on a finite dataset and extrapolate to the true distribution. We use the cross-entropy loss, given by $\ell(\phi(\theta, x), y) = -\log(S_y(\phi(\theta, x)))$, where the softmax function $S: \mathbb{R}^C \to \mathbb{R}^C$ is given by $S(a)_i = \frac{e^{a_i}}{\sum_{j=1}^C e^{a_j}}$ (see Goodfellow et al. (2016)).

The choice of the cross-entropy function as the loss function has a significant impact on how the flatness measure behaves. Unlike the multiclass mean squared error (MMSE), exponential-type losses such as the cross-entropy loss have been shown to include implicit regularization which leads to margin-maximizing solutions for neural networks (Banburski et al., 2019). Also, various properties of flat minima which have been proven for the MMSE loss by Mulayoff & Michaeli (2020) are not applicable to the cross-entropy loss, further highlighting the fundamental differences between the loss functions. While the MMSE loss has shown some promise for many classification tasks (Hui & Belkin, 2020), the cross-entropy loss is still the most widely used loss and was primarily used for the empirical evidence around flat minima (Keskar et al., 2016; Chaudhari et al., 2019), which motivates our choice.

The qualitative description of a flat region was given by Hochreiter & Schmidhuber (1997) as "a large connected region in parameter space where the error remains approximately constant". We measure the flatness by the trace of the Hessian of the loss with respect to the parameters (in short, the Hessian trace), denoted by $\mathrm{Tr}(H_\theta(L_D(\theta)))$ (Dinh et al., 2017). Since the Hessian is symmetric, the Hessian trace is equivalent to the sum of its eigenvalues, which for a fixed parameter space is proportional to the expected increase of the second-order approximation of the loss around a fixed minimum $\theta$ in a random direction $\theta'$ with $\theta' \sim \mathcal{N}(\theta, I)$. Since we apply flatness arguments only close to minima, we assume that all eigenvalues are positive and that the Hessian trace is a good measure of flatness (Sagun et al., 2017). Even though the Hessian trace is only an approximation of flatness, the Hessian is often preferred as it allows us to reason about various directions in parameter space via its eigenvectors and eigenvalues (see Sagun et al. (2017); Chaudhari et al. (2019)) and alleviates the issue of infinitely long but sharp ridges making a minimum infinitely flat (Dinh et al., 2017; Freeman & Bruna, 2016). The Hessian has also been linked to feature robustness via its use in the second-order approximation of the loss (e.g. Petzka et al. (2020); Borovykh et al. (2019)) and is a promising quantity to relate to the margins.
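To make the proportionality claim concrete, the following short computation (a standard argument, spelled out here for completeness rather than taken from the paper) relates the expected increase of the second-order approximation in a random direction $\theta' \sim \mathcal{N}(\theta, I)$ to the Hessian trace, writing $\delta = \theta' - \theta \sim \mathcal{N}(0, I)$:

$$\mathbb{E}_{\delta \sim \mathcal{N}(0, I)}\Big[\nabla L_D(\theta)^\top \delta + \tfrac{1}{2}\,\delta^\top H_\theta \delta\Big] = \tfrac{1}{2}\,\mathbb{E}\big[\delta^\top H_\theta \delta\big] = \tfrac{1}{2}\,\mathrm{Tr}\big(H_\theta\,\mathbb{E}[\delta\delta^\top]\big) = \tfrac{1}{2}\,\mathrm{Tr}(H_\theta),$$

since $\mathbb{E}[\delta] = 0$ and $\mathbb{E}[\delta\delta^\top] = I$ (and the gradient term vanishes at a minimum in any case).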
As we are working with non-linear functions, it is intractable to compute exact distances to the decision boundary; therefore we use a measure which is related to the linearized distance as described in Elsayed et al. (2018). Under this view, larger margins are better because the data is further from the decision boundary. Specifically, we define the margins as in Neyshabur et al. (2017): for some vector $v \in \mathbb{R}^C$ and label $y$ we let the margin of $v$ be $\gamma(v, y) = |v_y - \max_{j \neq y} v_j|$. Since we use the margin in different contexts, we distinguish the output margins $\gamma(\phi(\theta, x), y)$ and the margins of the model output after the softmax layer $\gamma(S(\phi(\theta, x)), y)$. Due to the intuition of margins relating to the regularity of the classification regions, they have been proven and shown to be a good generalization measure for linear networks (Langford & Shawe-Taylor, 2003) and later, when correctly adjusted, for neural networks (see Bartlett et al. (2017); Jiang et al. (2018; 2019)). Due to results by Banburski et al. (2019) and Soudry et al. (2018), Poggio et al. (2019) claimed that a large part of the mystery around generalizability has been solved, since standard optimization methods maximize the margin instead of memorizing data.
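As a concrete illustration, here is a minimal PyTorch sketch (ours, not from the paper) computing both margin variants for a batch of outputs; the helper name `margins` and the toy logits are our own:

```python
import torch

def margins(outputs, labels):
    """gamma(v, y) = |v_y - max_{j != y} v_j| for each row v of `outputs` (shape [N, C])."""
    true_class = outputs.gather(1, labels.unsqueeze(1)).squeeze(1)
    others = outputs.clone()
    others.scatter_(1, labels.unsqueeze(1), float("-inf"))  # mask out the true class
    return (true_class - others.max(dim=1).values).abs()

logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.3, 0.2]])
labels = torch.tensor([0, 2])
print(margins(logits, labels))                               # output margins
print(margins(torch.softmax(logits, dim=1), labels))         # softmax margins
```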

3.1. THE AFFINE CROSS-ENTROPY HESSIAN TRACE

Generally, it is difficult to derive a closed-form solution of the Hessian trace due to the non-linear nature of neural networks. To gain insight into what may determine the flatness or sharpness of a solution, we consider an affine prediction function for which we derive the following simple and insightful expression for the Hessian trace:

Proposition 3.1 (Affine Cross-Entropy Hessian Trace (ACEHT)). Assume an affine predictor given by $\phi((\theta, b), x) = \theta x + b$ where $(\theta, b) \in \mathbb{R}^{C \times d} \times \mathbb{R}^C = \Theta$. Then the trace of the Hessian under the cross-entropy loss is:
$$\mathrm{Tr}\big(H(\ell(\phi((\theta, b), x), y))\big) = (|x|^2 + 1)\Big(1 - \sum_{j=1}^C S_j^2(\phi(\theta, x))\Big) = (|x|^2 + 1)\big(1 - |S(\phi(\theta, x))|^2\big).$$

The derivation is in Appendix C. We immediately observe that the trace of the Hessian is a product of the size of the input and $1 - \eta(S(\phi(\theta, x)))$, where $\eta(S(\phi(\theta, x))) = \sum_{j=1}^C S_j^2(\phi(\theta, x))$, and we can view $1 - \eta(S(\phi(\theta, x)))$ as a confidence measure. In the visualization provided in Figure 1 we clearly see that $1 - \eta(S(\phi(\theta, x)))$ is only zero when the predictor predicts one class with probability 1, regardless of whether it is the correct class or not. When the model is least confident, namely when every class is predicted with probability $1/C$, then $1 - \eta(S(\phi(\theta, x)))$ is highest. Hence, in the affine case with a cross-entropy loss, the Hessian trace can be seen as an indication of the model's confidence in its prediction. This confidence interpretation is also connected to classification margins by observing that $S_y \geq \gamma(S(\phi(\theta, x)), y)$ and hence
$$1 - \sum_{j=1}^C S_j^2(\phi(\theta, x)) \;\leq\; 1 - S_y^2(\phi(\theta, x)) \;\leq\; 1 - \gamma^2(S(\phi(\theta, x)), y).$$
Therefore, if the margins are large then the region will also be flat. The intuition for this is that the error in the upper bound becomes smaller as $S_y$ becomes larger, i.e. when the model predicts correctly and confidently. We will also provide evidence for a converse, i.e. that a flat minimum has large margins, in the following experimental sections. Finally, we note that without the expression in Proposition 3.1 we would not have been able to derive the upper bound $1 - \gamma^2(S(\phi(\theta, x)), y)$ without guesswork.
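As a sanity check of Proposition 3.1 (our own sketch, not part of the paper's experiments), one can compare the closed-form ACEHT with an autograd Hessian trace for a small affine model; the dimensions and seed below are arbitrary:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
C, d = 3, 5                                     # number of classes, input dimension
theta, b = torch.randn(C, d), torch.randn(C)
x, y = torch.randn(d), torch.tensor([1])

def loss_fn(theta, b):
    return F.cross_entropy((theta @ x + b).unsqueeze(0), y)

# Empirical Hessian trace via autograd (only the theta-theta and b-b blocks contribute).
H = torch.autograd.functional.hessian(loss_fn, (theta, b))
trace = H[0][0].reshape(C * d, C * d).diagonal().sum() + H[1][1].diagonal().sum()

# Closed form from Proposition 3.1: (|x|^2 + 1) * (1 - |S|^2).
S = torch.softmax(theta @ x + b, dim=0)
aceht = (x.norm() ** 2 + 1) * (1 - (S ** 2).sum())
print(trace.item(), aceht.item())               # should agree up to numerical error
```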

3.2. EXTENSION TO THE NON-LINEAR CASE

Now we attempt to extend the derivation of the previous section to the non-linear case. This is a challenging undertaking, so we resort to numerical evidence. To extend the results from the affine case we consider both the ACEHT and the upper bound $\mathrm{ACEHT}(S(\phi(\theta, x))) \leq (|x|^2 + 1)(1 - \gamma^2(S(\phi(\theta, x)), y))$, to which we refer as the "margin bound". We compare both quantities to the empirically derived Hessian trace. To compute the empirical Hessian trace we use the PyHessian package (Yao et al., 2019), which implements Hutchinson's method (Bai et al., 1996; Avron & Toledo, 2011). We compare the quantities in terms of their distributions over the data. Specifically, let $(X, Y) \sim \mathcal{D}$ and fix $\theta$; we then compute the Pearson correlation coefficient (r-value) (Lee Rodgers & Nicewander, 1988) between the random variables $\mathrm{Tr}(H(\ell(\phi(\theta, X), Y)))$ and $\mathrm{ACEHT}(S(\phi(\theta, X)))$, and similarly for the margin bound. The choice of the r-value is natural because in the affine case the ACEHT and the Hessian trace are equivalent, so a linear relationship should be expected. Our method is also more general than comparing a single statistic of the above random variables, such as the average (which is generally used for flatness measures). For example, while the smallest margin over the dataset is commonly used as a generalization measure (Bartlett et al., 2017; Jiang et al., 2019; Neyshabur et al., 2017), Jiang et al. (2018) showed that higher moments of the distribution are a much better predictor of generalizability, as we will also see in Section 4. Figure 2 is an example of such a fit for an affine predictor. While the high r-value of 0.97 confirms our analytic results, we also observe that the fit is not perfect, even though the relationship is exact. The inaccuracies are due to the numerical methods used and become more pronounced the higher the Hessian trace is. To avoid outliers heavily impacting the linear regression model in the non-linear case, we use the scikit-learn class LocalOutlierFactor (Breunig et al., 2000) to remove outliers before fitting the line. With this we prevent hand-picked points from skewing the results and also stabilize our results.
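Below is a minimal sketch of this comparison. The paper uses PyHessian; to keep the example self-contained we instead use a hand-rolled Hutchinson estimator with Rademacher probes and scikit-learn's LocalOutlierFactor, so the helper names (`hutchinson_trace`, `margin_bound`, `correlation`) and the per-example setup are our own assumptions:

```python
import torch
import torch.nn.functional as F
from scipy.stats import pearsonr
from sklearn.neighbors import LocalOutlierFactor

def hutchinson_trace(model, x, y, n_probes=100):
    """Per-example Hessian-trace estimate E_v[v^T H v] with Rademacher probes v."""
    params = [p for p in model.parameters() if p.requires_grad]
    loss = F.cross_entropy(model(x.unsqueeze(0)), y.unsqueeze(0))
    grads = torch.autograd.grad(loss, params, create_graph=True)
    est = 0.0
    for _ in range(n_probes):
        vs = [torch.randint_like(p, 2) * 2.0 - 1.0 for p in params]   # Rademacher (+-1) probes
        hv = torch.autograd.grad(grads, params, grad_outputs=vs, retain_graph=True)
        est += sum((h * v).sum() for h, v in zip(hv, vs)).item()
    return est / n_probes

def margin_bound(model, x, y):
    """(|x|^2 + 1) * (1 - gamma(S, y)^2), the margin bound used in Section 3.2."""
    s = torch.softmax(model(x.unsqueeze(0)), dim=1).squeeze(0)
    top_other = s[torch.arange(len(s)) != y].max()
    gamma = (s[y] - top_other).abs()
    return ((x.norm() ** 2 + 1) * (1 - gamma ** 2)).item()

def correlation(model, data):
    """data: iterable of (input, label) tensor pairs, e.g. a few hundred training points."""
    pairs = [(hutchinson_trace(model, x, y), margin_bound(model, x, y)) for x, y in data]
    keep = LocalOutlierFactor().fit_predict(pairs) == 1      # drop outliers before the fit
    ht = [p[0] for p, k in zip(pairs, keep) if k]
    mb = [p[1] for p, k in zip(pairs, keep) if k]
    return pearsonr(ht, mb)[0]                               # the r-value reported in Figure 3
```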

3.2.1. EMPIRICAL EVIDENCE

We present our results using the convolutional neural network LeNet on the MNIST dataset, as they are representative of what we have observed on other architectures, hyperparameters, and datasets (see Appendix B). Our results use stochastic gradient descent with a fixed learning rate and batch size chosen to achieve an appropriate performance on the classification task. Because of the computational difficulty of computing the empirical Hessian trace for every single element in the input data, we consider 1,000 randomly selected datapoints from the training set. To highlight the computational difficulty of using even very optimized numerical tools, such as PyHessian, we note that it takes us roughly 1.5 hours to compute the Hessian trace for the whole MNIST dataset, while it only takes 5 seconds for the margins. In Figure 3 we present the plots relating the empirical Hessian trace to the ACEHT and the margin bound over the randomly sampled datapoints. Figures 3a and 3b show that for most of training the correlation is between 0.8 and 1. Combining Figures 3a and 3c, it can be seen that the r-value increases with the model's training accuracy. Furthermore, the datapoints which are incorrectly predicted do not show a correlation. With that we confirm the intuition that flatter solutions are indeed more robust and have larger margins. While we have found flatness and margins to be highly correlated in scenarios in which others have identified flatness to be a good generalization measure (Jiang et al., 2019; Keskar et al., 2016; Chaudhari et al., 2019), it may be that this is also an epiphenomenon of stochastic gradient descent or some other process, and there may be situations in which the relationship does not hold. However, our general advice to consider margins more is not impacted by this. In the scenarios where generalizability and flatness have been linked, we have also shown that margins and flatness are correlated, hence it is advantageous to use margins instead for computational reasons or for more complete intuition. The only situation in which it is more likely that margins and flatness are not correlated is when flatness has not yet been linked to generalizability. In such a situation it may also be better to use the better-understood margin measure instead of a flatness measure to assess generalizability. In the next section we consider the first case, where we examine a general scenario in which flatness has been used to reason about generalizability and offer a more insightful margin perspective.

4. PERSPECTIVE ON LARGE AND SMALL-BATCH METHODS

We now show how our results lead to a better understanding of phenomena which have been misleadingly attributed to flat minima. To do so, we consider the experiments which rekindled the debate around flat minima by Keskar et al. (2016), where flatness was used to explain why small-batch methods tend to generalize better than large-batch methods. The idea was that small-batch methods converge to flatter minima because they are able to "escape" sharp minima more easily. However, it has been shown that the minima of both methods appear to lie in the same attractive basin (Sagun et al., 2017; Freeman & Bruna, 2016; Draxler et al., 2018), meaning that small-batch methods do not seem to escape any attractive basin but are merely in a different area of the same basin. While the results gave credence to flatter minima generalizing better, flatter minima do not seem to provide the full picture of why large-batch methods tend to do worse, and we believe that an explanation in terms of the margins is more illuminating.

4.1. EXPERIMENT SETUP

We will replicate the experiment by Keskar et al. (2016) for a fully connected network with batch-normalized layers on the MNIST dataset as described in Appendix A. We chose the large-batch size to be 4096 and the small-batch size to be 256. To have a fair comparison, we use the same seed and take 10,000 gradient steps for both methods, instead of basing the stopping time on epochs. We also use stochastic gradient descent without momentum. With our setup we observe a similar phenomenon as Keskar et al. (2016) in Table 1. The small and large-batch methods both attain the same training accuracy and comparable training loss. However, the small-batch method is at a considerably flatter minimum and generalizes better than the large-batch method. We will now show that instead of considering the flatness, it is more insightful to consider margins to explain the difference in generalizability. While the upper bound of the ACEHT is in terms of the softmax margins, we consider the output margins in this section. The reason is that most margin-based generalization measures use the output margins. Another, more practical reason is that towards the end of training the softmax margins are all very close to 1, making it difficult to visualize and observe the distribution. We also do not use a normalized version of the margins (such as Bartlett et al. (2017); Jiang et al. (2018)). Our reasoning is that because we use the same architecture, the same dataset, and train in a similar manner, the margin distributions are comparable.
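For concreteness, a sketch of this protocol (ours; the constructor `make_fc_batchnorm_net` stands in for the Appendix A architecture, which we do not reproduce here) could look as follows:

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def train(batch_size, seed=0, steps=10_000, lr=0.1):
    torch.manual_seed(seed)                               # same seed for both runs
    mnist = datasets.MNIST("data", train=True, download=True, transform=transforms.ToTensor())
    loader = DataLoader(mnist, batch_size=batch_size, shuffle=True)
    model = make_fc_batchnorm_net()                       # hypothetical constructor (Appendix A network)
    opt = torch.optim.SGD(model.parameters(), lr=lr)      # plain SGD, no momentum
    step, it = 0, iter(loader)
    while step < steps:                                   # fixed number of gradient steps, not epochs
        try:
            x, y = next(it)
        except StopIteration:
            it = iter(loader)
            x, y = next(it)
        opt.zero_grad()
        F.cross_entropy(model(x.flatten(1)), y).backward()
        opt.step()
        step += 1
    return model

small_batch_model = train(batch_size=256)
large_batch_model = train(batch_size=4096)
```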

4.2. A MARGIN PERSPECTIVE ON LARGE AND SMALL-BATCH SIZES

In Figure 4 we see that the output margins and the Hessian trace are correlated, as expected from Section 3. We can also roughly see that the small-batch method has fewer low margins than the large-batch method. To emphasize this difference we consider Figure 4c, where we plot the histogram and box-plot of the output margin distribution for both the large-batch and the small-batch method. We also display the skewness of each, which is the third standardized moment. The box-plots and the skewness confirm that the small-batch method is dominated by large margins, indicating better generalizability (as discussed in Bartlett et al. (2017); Jiang et al. (2018)). The idea with a left-skewed margin distribution is that the tail of low-margin datapoints is mostly comprised of outliers and will not massively affect the robustness to input perturbations. This soft-margin SVM perspective is in contrast to hard-margin SVMs, where the margin is defined to be the minimum of all the distances to the decision boundary (Shalev-Shwartz & Ben-David, 2014). If a hard-margin view were adopted, then the small-batch method would be predicted to generalize worse, because it has the smallest margin, as we see in Figure 4c. However, the distribution of the small-batch method is also more left-skewed, which points to this minimal margin being an outlier rather than being indicative of generalizability.

We now want to explain why the small-batch method generalizes well. As observed in Jastrzębski et al. (2017), a smaller batch size is similar to a larger learning rate, hence at every step the process will advance further than a large-batch method would. It has already been noted by Hoffer et al. (2017) that training longer leads the large-batch method to generalize just as well as the small-batch method because it has time to "catch up", even though the decrease in training loss may be barely noticeable. We have also seen that SGD converges to margin-maximizing solutions by Banburski et al. (2019). Therefore, a method that is able to train or advance further will also be closer to a margin-maximizing solution. We therefore expect that large-batch methods not having had enough time to maximize margins is the driving force behind the large vs small-batch phenomenon.

Figure 4: We plot both the small-batch method (orange) and the large-batch method (blue). In Figures 4a and 4b we plot the output margins against the Hessian trace for each datapoint. We observe a strong relationship between the Hessian trace and the output margins. In Figure 4c we plot both the histogram and box-plot and display the skewness (the third standardized moment) for both the large and small-batch method's margin distributions. We observe that the distribution of the small-batch method is more left-skewed, which indicates better generalizability independent of the flatness.
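A small sketch of the summary statistics behind Figure 4c (ours; the margin arrays below are synthetic placeholders, whereas in the experiment they are the per-datapoint output margins of the two trained networks):

```python
import numpy as np
from scipy.stats import skew

rng = np.random.default_rng(0)
# Synthetic stand-ins for the per-datapoint output margins of each run.
margins_small = 10.0 - rng.gamma(shape=2.0, scale=1.0, size=10_000)
margins_large = 8.0 - rng.gamma(shape=4.0, scale=1.0, size=10_000)

for name, m in [("small batch", margins_small), ("large batch", margins_large)]:
    # skew() is the third standardized moment; a more negative value means most datapoints
    # sit at large margins while the low-margin tail consists of relatively few outliers.
    print(f"{name}: min={m.min():.2f}  median={np.median(m):.2f}  skewness={skew(m):.2f}")
# A hard-margin (SVM) view would compare only the minima; the soft-margin view compares
# the whole distributions, which is what the box-plots and skewness in Figure 4c summarize.
```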

5. BECOMING FLATTER WITH INCREASING MARGINS

Reparametrization problems such as those shown by Dinh et al. (2017) are neither a new phenomenon nor should they necessarily discourage the design of algorithms which attempt to find flat minima. Rather, they inform us which aspects of a generalization measure need to be adjusted so that it can be used in a practical setting. For SVMs, the problem of scaling the hyperplane normal to increase the margins of correctly classified points is solved by scaling the normal to make it a unit vector, transforming the functional margin into the geometric margin (Shalev-Shwartz & Ben-David, 2014). In the case of neural networks, it is also known that scaling the last layer increases the margins of data which has been correctly predicted (Neyshabur et al., 2017). This scaling issue has been successfully addressed (see Bartlett et al. (2017); Elsayed et al. (2018); Jiang et al. (2018)). Due to the relationship to the classification margins, it is natural to ask whether flatness suffers from a similar problem. We confirm this with the following proposition:

Proposition 5.1. For a given neural network $\phi$ let $T_\alpha: \Theta \to \Theta$ be such that for all $x \in X$ and $\theta \in \Theta$ we have $\phi(T_\alpha(\theta), x) = \alpha\phi(\theta, x)$. Now assume a $\theta' \in \Theta$ and a datapoint $(x', y')$ for which $\mathrm{argmax}_{k \in \{1, \ldots, C\}}\, (\phi(\theta', x'))_k = y'$. Then for all $s, t \in \{1, \ldots, \dim(\Theta)\}$:
$$\lim_{\alpha \to \infty} \partial_{\theta_s}\partial_{\theta_t}\, \ell(\phi(T_\alpha(\theta'), x'), y') = 0.$$

The proof is in Appendix D. From the proposition we immediately derive the following corollary:

Corollary 5.2. Assume that $\phi$ and $\theta$ predict every datapoint in a set $D$ correctly. Then for all $s, t \in \{1, \ldots, \dim(\Theta)\}$:
$$\lim_{\alpha \to \infty} \partial_{\theta_s}\partial_{\theta_t}\, L_D(T_\alpha(\theta)) = 0.$$

Due to the corollary, if a network has achieved full training accuracy, then the network is equivalent under the $T_\alpha$ transformation to an arbitrarily flat network. We note that such a $T_\alpha$ transform exists for most networks. Scaling the last layer is one simple instance of such a transform. Another is that for fully connected and convolutional networks with ReLU non-linearities, by non-negative homogeneity, scaling each layer also results in a valid $T_\alpha$ transformation. The crucial property of the $T_\alpha$ map is that it does not change the relative order of the model outputs; therefore, given two networks which have achieved full training accuracy, we cannot determine which network should generalize better based solely on the flatness of the local geometry. We note that Banburski et al. (2019) mentioned such an issue, but they did not discuss it in the context of flat minima and their arguments relied on further structure which we believe is less illuminating than our presentation and proofs.
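The corollary is easy to probe numerically. The following sketch (ours; the toy architecture and the Hutchinson trace estimator are our own choices, not the paper's setup) scales the last layer of a small network that classifies all of its "training" points correctly and watches the estimated Hessian trace shrink while the predictions stay fixed:

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 3))
x = torch.randn(32, 4)
y = net(x).argmax(dim=1)          # labels = current predictions, so training accuracy is 100%

def hessian_trace(model, n_probes=50):
    """Hutchinson estimate of Tr(H) for the cross-entropy loss over (x, y)."""
    params = list(model.parameters())
    loss = F.cross_entropy(model(x), y)
    grads = torch.autograd.grad(loss, params, create_graph=True)
    est = 0.0
    for _ in range(n_probes):
        vs = [torch.randint_like(p, 2) * 2.0 - 1.0 for p in params]      # Rademacher probes
        hv = torch.autograd.grad(grads, params, grad_outputs=vs, retain_graph=True)
        est += sum((h * v).sum() for h, v in zip(hv, vs)).item()
    return est / n_probes

for alpha in [1.0, 10.0, 100.0]:
    scaled = copy.deepcopy(net)
    with torch.no_grad():                       # T_alpha: scale the last layer, so phi -> alpha * phi
        scaled[2].weight.mul_(alpha)
        scaled[2].bias.mul_(alpha)
    assert torch.equal(scaled(x).argmax(dim=1), y)   # predictions (and accuracy) are unchanged
    print(alpha, hessian_trace(scaled))               # the trace estimate shrinks towards 0
```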

6. CONCLUSIONS

In this paper, we have related flatness to the classification margins in a principled manner, in contrast to other works that have made a more intuitive or less quantifiable connection (Huang et al., 2019; Wang et al., 2018; Neyshabur et al., 2017; Petzka et al., 2020). Our results lead to the immediate practical recommendation of using margins instead of the computationally expensive flatness to assess generalizability. We also use our results to replace the misleading notion that small-batch methods generalize better because they "escape" sharp minima, instead arguing that small-batch methods have more time to maximize margins. We were also motivated by the flatness and margin relationship to highlight that neural networks can be made arbitrarily flat. This implies that the generalizability of two networks cannot be distinguished based on flatness alone, an issue which needs to be addressed to make flatness a viable generalization measure. Based on our results, future work may assess whether flatness is an epiphenomenon of the optimization methods, because recent work on margins (e.g. Banburski et al. (2019); Soudry et al. (2018)) can now be applied to reason about flatness. Furthermore, by relating properties of the parameter space (flatness) to properties of the input space (margins), there is now an opportunity to further explore results such as those by Sagun et al. (2017), who found that the Hessian with respect to the parameters of a neural network upon convergence has as many positive eigenvalues as the number of classes in the dataset used. Overall, our results enable more principled discussion on how flatness may contribute to generalizability.

B APPENDIX: FLATNESS AND MARGIN CORRELATION

Here we present further evidence of the flatness and margin correlation discussed in Section 3. As in Section 3, we have used appropriate learning rates and batch sizes to obtain a reasonable performance for the task, and we have observed our results to hold for different hyperparameters. One instance where we demonstrate two different batch sizes is the fully connected network with batch normalization on MNIST (Section B.1), where we present results for batch sizes of 256 and 4096. We again only consider 1,000 randomly selected datapoints from the training set due to the computational difficulty of computing the Hessian trace. If the network achieves full training accuracy and there are no incorrectly classified datapoints, we set the r-value to zero. Overall, we observe the same results as in Section 3 and a correlation between 0.8 and 1. As before, the correlation increases with increasing training accuracy for correctly predicted datapoints.

C APPENDIX: DERIVATIVES OF THE CROSS-ENTROPY LOSS

C.1 GENERAL FORM

For the general form we consider the cross-entropy loss for a predictor function which is scaled by some scalar $\alpha$. Specifically, we assume an arbitrary input-output pair $(x, y) \in X \times Y$ and compute the partial derivatives with respect to the parameters $\theta$ of the predictor function $\alpha\phi(\theta, x)$. Since the equations can become very long, we declutter the notation by letting $S = S(\alpha\phi(\theta, x))$, $\phi = \phi(\theta, x)$, and for two $d$-dimensional vectors $x, y \in \mathbb{R}^d$ we write $\langle x, y \rangle = \sum_{i=1}^d x_i y_i$. We also denote elementwise multiplication by $\odot$ and let $\Phi$ be a matrix such that $(\Phi)_{ij} = \phi_j$.

Lemma C.1. The first partial derivative of the cross-entropy loss with respect to an element $\theta_i$ is:
$$\partial_{\theta_i}\, \ell(\alpha\phi(\theta, x), y) = -\alpha\Big(\partial_{\theta_i}\phi_y - \sum_{l=1}^C \partial_{\theta_i}\phi_l\, S_l\Big) = -\alpha\big(\partial_{\theta_i}\phi_y - \langle \partial_{\theta_i}\phi, S \rangle\big).$$

Proof. We have
$$\partial_{\theta_i}\, \ell(\alpha\phi, y) = \partial_{\theta_i}\big(-\log(S_y)\big) = -\frac{1}{S_y}\,\partial_{\theta_i} S_y.$$
With some manipulation we compute $\partial_{\theta_i} S_y$:
$$\partial_{\theta_i} S_y = \partial_{\theta_i}\, \frac{e^{\alpha\phi_y}}{\sum_{l=1}^C e^{\alpha\phi_l}} = \alpha S_y\Big(\partial_{\theta_i}\phi_y - \sum_{l=1}^C \partial_{\theta_i}\phi_l\, S_l\Big).$$
Combining the two equations we obtain Lemma C.1:
$$\partial_{\theta_i}\, \ell(\alpha\phi, y) = -\alpha\Big(\partial_{\theta_i}\phi_y - \sum_{l=1}^C \partial_{\theta_i}\phi_l\, S_l\Big).$$

Lemma C.2. The second partial derivative of the cross-entropy loss with respect to elements $\theta_s$ and $\theta_t$ is:
$$\partial_{\theta_t}\partial_{\theta_s}\, \ell(\alpha\phi(\theta, x), y) = -\alpha\Big(\partial_{\theta_t}\partial_{\theta_s}\phi_y - \langle \partial_{\theta_t}\partial_{\theta_s}\phi, S \rangle - \big\langle \partial_{\theta_s}\phi,\; \alpha S \odot \big(\partial_{\theta_t}\phi - (\partial_{\theta_t}\Phi)S\big)\big\rangle\Big).$$

Proof. Differentiating the first-order derivative given by Lemma C.1, we obtain by the multi-variable chain rule:
$$\partial_{\theta_t}\partial_{\theta_s}\, \ell(\alpha\phi, y) = -\alpha\Big(\partial_{\theta_t}\partial_{\theta_s}\phi_y - \langle \partial_{\theta_t}\partial_{\theta_s}\phi, S \rangle - \langle \partial_{\theta_s}\phi, \partial_{\theta_t} S \rangle\Big).$$
To compute $\partial_{\theta_t} S$ we use the expression for $\partial_{\theta_i} S_y$ above and obtain
$$(\partial_{\theta_t} S)_i = \partial_{\theta_t} S_i = \alpha S_i\big(\partial_{\theta_t}\phi_i - \langle \partial_{\theta_t}\phi, S \rangle\big),$$
which after some simplification reduces to
$$\partial_{\theta_t} S = \alpha S \odot \big(\partial_{\theta_t}\phi - (\partial_{\theta_t}\Phi)S\big).$$
Combining the two expressions yields Lemma C.2.

C.2 AFFINE CROSS-ENTROPY HESSIAN TRACE

We now present the proof of Proposition 3.1:

Proof. Throughout the proof we make use of Lemma C.2 and let $\alpha = 1$. We also note that any second derivative of $\phi((\theta, b), x)$ with respect to the parameters is zero since $\phi$ is an affine classifier. We first consider the derivatives with respect to elements of $\theta$, where we use $\theta_{i,j}$ to denote the element in the $i$th row and $j$th column of the matrix $\theta$. Notice that $\partial_{\theta_{i,j}}\phi = x_j e_i$. By Lemma C.2 the second-order derivatives are given by
$$\partial_{\theta_{i,j}}\partial_{\theta_{s,t}}\, \ell(\phi, y) = \big\langle x_t e_s,\; S \odot (x_j e_i - x_j S_i \mathbf{1})\big\rangle,$$
and when computing the trace we only need the elements on the diagonal, hence
$$\partial_{\theta_{i,j}}\partial_{\theta_{i,j}}\, \ell(\phi, y) = x_j\big(S_i(x_j - x_j S_i)\big) = x_j^2\, S_i (1 - S_i).$$
Now we consider derivatives with respect to elements of $b$ and notice that $\partial_{b_i}\phi = e_i$. For the second derivative we then get
$$\partial_{b_i}\partial_{b_j}\, \ell(\phi, y) = \big\langle e_j,\; S \odot (e_i - S_i \mathbf{1})\big\rangle = \delta_{ij} S_i - S_i S_j,$$
which on the diagonal equals $S_i(1 - S_i)$. Finally, summing up the diagonal of the total Hessian we get
$$\mathrm{Tr}\big(H(\ell(\phi((\theta, b), x), y))\big) = \Big(\sum_{j=1}^d x_j^2 + 1\Big)\sum_{i=1}^C S_i(1 - S_i) = (|x|^2 + 1)\Big(1 - \sum_{j=1}^C S_j^2\Big),$$
where we used the fact that $\sum_i S_i = 1$.



Figure 1: To visualize how $1 - \eta(S(\phi(\theta, x)))$ represents the confidence of a model's prediction, we plot $1 - \eta(a)$ for all $a \in \mathbb{R}^3$ such that $a$ is a valid probability distribution over three classes (i.e. for all the elements of the standard 2-simplex). Since there are only two free variables, $x$ and $y$ in the plot represent $a$ via $a = [x, y, 1 - x - y]$. We see that $1 - \eta(a)$ is only zero when $a = e_i$ for some $i$, namely when a model would be most confident. We also note that $1 - \eta(a)$ is largest when a model would be least confident in its prediction, i.e. when $a = [1/3, 1/3, 1/3]$.

Figure 2: We consider the affine predictor $\phi((\theta, b), x) = \theta x + b$ with arbitrarily chosen parameters $(\theta, b)$ on 1,000 randomly sampled datapoints from the MNIST dataset. Each scatter point represents a datapoint for which we compute the ACEHT and the empirical Hessian trace. The linear relationship between the ACEHT distribution and the empirical Hessian trace both confirms our derivation of the ACEHT and provides a baseline for the numerical methods used for the empirical Hessian trace.

Figure 3: We observe how the ACEHT and the softmax margin (SM) bound relate, via the Pearson correlation coefficient, to the observed Hessian trace (HT) of LeNet trained on MNIST. In Figures 3a and 3b we see that for correctly classified datapoints (orange) the empirical Hessian trace correlates well with the ACEHT and the SM bound. In Figure 3c we observe that the increase in correlation occurs with an increase in training accuracy. To demonstrate the evolution of the distributions throughout training we plot the ACEHT and empirical HT distributions against each other in Figures 3d, 3e and 3f. We observe that while the most apparent outliers were removed, some still skew the linear regression.

B.2.1 FULLY CONNECTED NETWORK WITH BATCH NORMALIZATION ON CIFAR10

[Figure panels for the appendix experiments: (a) ACEHT to HT, (b) SM Bound to HT, (c) Training Accuracy, (d) Initialization, (e) Step 5000, (f) Final Step (10,000).]


D APPENDIX: SCALING PROOF

To prove Proposition 5.1 we first prove the following lemma:

Lemma D.1. Assume that the argmax of $\phi$ is the correct class $y$ and is unique. Then for $k \in \mathbb{N}$, $k \geq 1$ and $i \neq y$:
$$\lim_{\alpha \to \infty} \alpha^k S_i(\alpha\phi) = 0.$$

Proof. Let $y$ be such that $\phi_y = \max_{k \in \{1, \ldots, C\}} \phi_k$. For $i \neq y$ we have
$$\alpha^k S_i(\alpha\phi) = \frac{\alpha^k}{\sum_{l=1}^C e^{\alpha(\phi_l - \phi_i)}} = \frac{\alpha^k}{e^{\alpha(\phi_y - \phi_i)} + \sum_{l=1, l \neq y}^C e^{\alpha(\phi_l - \phi_i)}},$$
and applying L'Hopital's rule $k$ times gives
$$\lim_{\alpha \to \infty} \alpha^k S_i(\alpha\phi) = \lim_{\alpha \to \infty} \frac{k!\, e^{\alpha(\phi_i - \phi_y)}}{(\phi_y - \phi_i)^k + \sum_{l=1, l \neq y}^C (\phi_l - \phi_i)^k\, e^{\alpha(\phi_l - \phi_y)}}.$$
Since we assumed that $y$ is the unique index with $\phi_y = \max_{k \in \{1, \ldots, C\}} \phi_k$, we have $\phi_l < \phi_y$ for all $l \neq y$, so $e^{\alpha(\phi_l - \phi_y)} \to 0$ as $\alpha \to \infty$ and the above limit is zero.

We are now ready to prove Proposition 5.1:

Proof. By Lemma C.2, applied with the scaled predictor $\alpha\phi$ (recall that $\phi(T_\alpha(\theta'), x') = \alpha\phi(\theta', x')$, and write $S = S(\alpha\phi)$), the second partial derivatives consist of the term $-\alpha(\partial_{\theta_t}\partial_{\theta_s}\phi_y - \langle \partial_{\theta_t}\partial_{\theta_s}\phi, S \rangle)$ and the term $\alpha^2\langle \partial_{\theta_s}\phi,\; S \odot (\partial_{\theta_t}\phi - (\partial_{\theta_t}\Phi)S)\rangle$. We first show that the latter term always goes to zero. Expanding we get
$$\alpha^2\big\langle \partial_{\theta_s}\phi,\; S \odot (\partial_{\theta_t}\phi - (\partial_{\theta_t}\Phi)S)\big\rangle = \alpha^2\, \partial_{\theta_s}\phi_y\, S_y\Big(\partial_{\theta_t}\phi_y - \sum_{k=1}^C \partial_{\theta_t}\phi_k S_k\Big) + \alpha^2 \sum_{l=1, l \neq y}^C \partial_{\theta_s}\phi_l\, S_l\Big(\partial_{\theta_t}\phi_l - \sum_{k=1}^C \partial_{\theta_t}\phi_k S_k\Big).$$
We now show that each term in the sum goes to zero. Consider $l \neq y$:
$$\Big|\alpha^2\, \partial_{\theta_s}\phi_l\, S_l\Big(\partial_{\theta_t}\phi_l - \sum_{k=1}^C \partial_{\theta_t}\phi_k S_k\Big)\Big| \leq \alpha^2 S_l \sum_{k=1}^C \big|\partial_{\theta_s}\phi_l\big(\partial_{\theta_t}\phi_l - \partial_{\theta_t}\phi_k S_k\big)\big| \leq \alpha^2 S_l\, M$$
by the triangle inequality and $0 \leq S_k \leq 1$, where $M = \sum_{k=1}^C |\partial_{\theta_s}\phi_l(\partial_{\theta_t}\phi_l - \partial_{\theta_t}\phi_k S_k)|$ is finite since $0 \leq S_k \leq 1$, the terms $\partial_{\theta_s}\phi_l$, $\partial_{\theta_t}\phi_l$, $\partial_{\theta_t}\phi_k$ are constants in $\alpha$, and the sum is finite. By Lemma D.1, $\alpha^2 S_l \to 0$ as $\alpha \to \infty$ and hence the $l \neq y$ terms vanish. We now consider $l = y$:
$$\Big|\alpha^2\, \partial_{\theta_s}\phi_y\, S_y\Big(\partial_{\theta_t}\phi_y - \sum_{k=1}^C \partial_{\theta_t}\phi_k S_k\Big)\Big| \leq \alpha^2\, |\partial_{\theta_s}\phi_y|\, S_y\Big(\big|\partial_{\theta_t}\phi_y - \partial_{\theta_t}\phi_y S_y\big| + \sum_{k=1, k \neq y}^C \big|\partial_{\theta_t}\phi_k S_k\big|\Big).$$
For the first term, $\alpha^2 |\partial_{\theta_t}\phi_y - \partial_{\theta_t}\phi_y S_y| = \alpha^2 |\partial_{\theta_t}\phi_y|(1 - S_y) = \alpha^2 |\partial_{\theta_t}\phi_y| \sum_{k \neq y} S_k$, which goes to zero by Lemma D.1; the remaining term $\alpha^2 \sum_{k=1, k \neq y}^C |\partial_{\theta_t}\phi_k| S_k$ also goes to zero by Lemma D.1, and $S_y \leq 1$. We are left with showing that $\alpha(\partial_{\theta_t}\partial_{\theta_s}\phi_y - \langle \partial_{\theta_t}\partial_{\theta_s}\phi, S \rangle)$ goes to zero (this is only guaranteed because $y$ is the true, correctly predicted label). Using the same method as above:
$$\big|\alpha\big(\partial_{\theta_t}\partial_{\theta_s}\phi_y - \langle \partial_{\theta_t}\partial_{\theta_s}\phi, S \rangle\big)\big| \leq \alpha\Big(\big|\partial_{\theta_t}\partial_{\theta_s}\phi_y - \partial_{\theta_t}\partial_{\theta_s}\phi_y S_y\big| + \sum_{l=1, l \neq y}^C \big|\partial_{\theta_t}\partial_{\theta_s}\phi_l\big| S_l\Big),$$
and the result follows by again applying Lemma D.1.

Table 1: The results of our trained fully connected network with batch-normalized layers on MNIST, optimized with SGD and a 0.1 learning rate. The results reflect the observations made by Keskar et al. (2016), i.e. the small-batch method has a smaller Hessian trace and generalizes better.

Shai Shalev-Shwartz and Shai Ben-David. Understanding machine learning: From theory to algorithms. Cambridge University Press, 2014.

Daniel Soudry, Elad Hoffer, Mor Shpigel Nacson, Suriya Gunasekar, and Nathan Srebro. The implicit bias of gradient descent on separable data. The Journal of Machine Learning Research, 19(1):2822-2878, 2018.

