TEMPERATURE CHECK: THEORY AND PRACTICE FOR TRAINING MODELS WITH SOFTMAX-CROSS-ENTROPY LOSSES

Abstract

The softmax function combined with a cross-entropy loss is a principled approach to modeling probability distributions that has become ubiquitous in deep learning. The softmax function is defined by a lone hyperparameter, the temperature, that is commonly set to one or regarded as a way to tune model confidence after training; however, less is known about how the temperature impacts training dynamics or generalization performance. In this work we develop a theory of early learning for models trained with softmax-cross-entropy loss and show that the learning dynamics depend crucially on the inverse temperature β as well as the magnitude of the logits at initialization, ‖βz‖₂. We follow up these analytic results with a large-scale empirical study of a variety of model architectures trained on CIFAR10, ImageNet, and IMDB sentiment analysis. We find that generalization performance depends strongly on the temperature, but only weakly on the initial logit magnitude. We provide evidence that the dependence of generalization on β is not due to changes in model confidence, but is a dynamical phenomenon. It follows that the addition of β as a tunable hyperparameter is key to maximizing model performance. Although we find the optimal β to be sensitive to the architecture, our results suggest that tuning β over the range 10⁻² to 10¹ improves performance over all architectures studied. We find that smaller β may lead to better peak performance at the cost of learning stability.

1. INTRODUCTION

Deep learning has led to breakthroughs across a slew of classification tasks (LeCun et al., 1989; Krizhevsky et al., 2012; Zagoruyko and Komodakis, 2017). Crucial components of this success have been the use of the softmax function to model predicted class probabilities, combined with the cross-entropy loss as a measure of distance between the predicted distribution and the label (Kline and Berardi, 2005; Golik et al., 2013). Significant work has gone into improving the generalization performance of softmax-cross-entropy learning. A particularly successful approach has been to reduce overfitting by lowering model confidence; this has been done by regularizing outputs using confidence regularization (Pereyra et al., 2017) or by softening targets using label smoothing (Müller et al., 2019; Szegedy et al., 2016). Another way to manipulate model confidence is to tune the temperature of the softmax function, which is otherwise commonly set to one. Adjusting the softmax temperature during training has been shown to be important in metric learning (Wu et al., 2018; Zhai and Wu, 2019) and when performing distillation (Hinton et al., 2015), as well as for post-training calibration of prediction probabilities (Platt, 2000; Guo et al., 2017). The interplay between temperature, learning, and generalization is complex and not well understood in the general case. Although significant recent theoretical progress has been made in understanding generalization and learning in wide neural networks approximated as linear models, analysis of linearized learning dynamics has largely focused on the case of squared-error losses (Jacot et al., 2018; Du et al., 2019; Lee et al., 2019; Novak et al., 2019a; Xiao et al., 2019). Infinitely wide networks trained with softmax-cross-entropy loss have been shown to converge to max-margin classifiers in a particular function-space norm (Chizat and Bach, 2020), but the timescales of convergence are not known.
Additionally, many well-performing models operate best away from the linearized regime (Novak et al., 2019a; Aitchison, 2019). This means that understanding the deviations of models from their linearization around initialization is important for understanding generalization (Lee et al., 2019; Chizat et al., 2019). In this paper, we investigate the training of neural networks with softmax-cross-entropy losses. In general this problem is analytically intractable; to make progress we pursue a strategy that combines analytic insights at short times with a comprehensive set of experiments that capture the entirety of training. At short times, models can be understood in terms of a linearization about their initial parameters along with nonlinear corrections. In the linear regime we find that networks trained with different inverse temperatures, β = 1/T, behave identically provided the learning rate is scaled so that the effective learning rate η̃ = ηβ² is fixed. Here, networks begin to learn over a timescale τ_z ∼ ‖Z₀‖₂/η̃, where Z₀ are the initial logits of the network after being multiplied by β. This implies that we expect learning to begin faster for networks with smaller logits. The learning dynamics begin to become nonlinear over another, independent timescale τ_nl ∼ β/η̃, suggesting more nonlinear learning for small β. From previous results we expect that neural networks will perform best in the regime where they quickly exit the linear regime (Chizat et al., 2019; Lee et al., 2020; Lewkowycz et al., 2020). We combine these analytic results with extensive experiments on competitive neural networks across a range of architectures and domains, including Wide Residual Networks (Zagoruyko and Komodakis, 2017) on CIFAR10 (Krizhevsky, 2009), ResNet-50 (He et al., 2016) on ImageNet (Deng et al., 2009), and GRUs (Chung et al., 2014) on the IMDB sentiment analysis task (Maas et al., 2011).
In the case of residual networks, we consider architectures with and without batch normalization, which can appreciably change the learning dynamics (Ioffe and Szegedy, 2015). For all models studied, we find that generalization performance is poor for ‖Z₀‖₂ ≫ 1 but otherwise largely independent of ‖Z₀‖₂. Moreover, learning becomes slower and less stable at very small β; indeed, the optimal learning rate scales like η* ∼ 1/β, and the resulting early-learning timescale can be written as τ_z* ∼ ‖Z₀‖₂/β. For all models studied, we observe strong performance for β ∈ [10⁻², 10¹], although the specific optimal β is architecture dependent. Emphatically, the optimal β is often far from 1. For models without batch normalization, smaller β can give stronger results on some training runs, with others failing to train due to instability. Overall, these results suggest that model performance can often be improved by tuning β over the range [10⁻², 10¹].

2. THEORY

We begin with a precise description of the problem setting before discussing a theory of learning at short times. We will show the following:
• The inverse temperature β and the logit scale ‖Z₀‖ control timescales which determine the rate of change of the loss, the relative change in the logits, and the time for learning to leave the linear learning regime.
• Small β causes training to access the nonlinear learning regime. We will see empirically that increasing access to the nonlinear regime can improve generalization.
• The largest allowable learning rate is set by the timescale to leave the linearized learning regime, which suggests that networks with small β will train more slowly.
All numerical results in this section use a Wide ResNet (Zagoruyko and Komodakis, 2017) trained on CIFAR10.

2.1. BASIC MODEL AND NOTATION

We consider a classification task with K classes. For an N-dimensional input x, let z(x, θ) be the pre-softmax output of a classification model parameterized by θ ∈ R^P, such that the classifier predicts the class i corresponding to the largest output value z_i. We will mainly consider θ trained by SGD on a training set (X, Y) of M input-label pairs. We focus on models trained with cross-entropy loss at a non-trivial inverse temperature β. The softmax-cross-entropy loss can be written as

    L(θ, X, Y) = -∑_{i=1}^{K} Y_i · ln σ(βz_i(X, θ)) = -∑_{i=1}^{K} Y_i · ln σ(Z_i(X, θ)),    (1)

where we define the rescaled logits Z = βz, and σ(Z)_i = e^{Z_i} / ∑_j e^{Z_j} is the softmax function. Here Z(X, θ) is the M × K dimensional matrix of rescaled logits on the training set. As we will see later, the statistics of the individual σ(Z)_i have a strong influence on the learning dynamics. While the statistics of σ(Z)_i are intractable at intermediate magnitudes ‖Z‖₂, they can be understood in the limits of large and small ‖Z‖₂. For a fixed model z(x, θ), β controls the certainty of the predicted probabilities. Values of β with β ≪ 1/‖z‖₂ give ‖Z‖₂ ≪ 1, and the outputs of the softmax will be close to 1/K independent of i (the maximum-entropy distribution on K classes). Larger values of β with β ≫ 1/‖z‖₂ give ‖Z‖₂ ≫ 1; the resulting distribution places probability close to 1 on one label and (exponentially) close to 0 on the others. The continuous-time learning dynamics (exact in the limit of small learning rate) are given by

    θ̇ = ηβ ∑_{i=1}^{K} (∂z_i(X, θ(t))/∂θ)^T (Y_i - σ(Z_i(X, θ(t)))),    (2)

for learning rate η. We will drop the explicit dependence of Z_i on θ from here onward, and denote time dependence explicitly as Z_i(X, t) where needed.
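As a concrete reference, the temperature-scaled loss above can be sketched in a few lines of NumPy. The function names and toy logits are our own, not from the paper; this is a minimal illustration, not the paper's implementation:

```python
import numpy as np

def softmax(Z):
    # Row-wise softmax, with the usual max-subtraction for numerical stability.
    e = np.exp(Z - Z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def xent_loss(z, Y, beta=1.0):
    """Softmax-cross-entropy at inverse temperature beta.

    z : (M, K) raw logits; Y : (M, K) one-hot labels.
    The loss is evaluated on the rescaled logits Z = beta * z.
    """
    Z = beta * z
    log_probs = Z - Z.max(axis=-1, keepdims=True)
    log_probs = log_probs - np.log(np.exp(log_probs).sum(axis=-1, keepdims=True))
    return -(Y * log_probs).sum(axis=-1).mean()

# Small beta pushes predictions toward the uniform distribution on K classes,
# large beta toward a one-hot distribution on the argmax class.
z = np.array([[2.0, -1.0, 0.5]])
print(softmax(1e-3 * z))   # close to [1/3, 1/3, 1/3]
print(softmax(1e3 * z))    # close to [1, 0, 0]
```

At β ≪ 1/‖z‖₂ the loss approaches the maximum-entropy value ln K regardless of the labels, matching the limit discussed above.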
In function space, the dynamics of the model outputs on an input x are given by

    dz_i(x)/dt = ηβ ∑_{j=1}^{K} (Θ̂_θ)_{ij}(x, X)(Y_j - σ(Z_j(X))),    (3)

where we define the M × M × K × K dimensional tensor Θ̂_θ, the empirical neural tangent kernel (NTK), as

    (Θ̂_θ)_{ij}(x, X) ≡ (∂z_i(x)/∂θ)(∂z_j(X)/∂θ)^T    (4)

for class indices i and j; this tensor is block-diagonal in the infinite-width limit. From Equation 3, we see that the early-time dynamics in function space depend on β, the initial softmax input Z(X, 0) on the training set, and the initial Θ̂_θ. Changing these observables across a model family will lead to different learning trajectories early in learning. Since significant work has already studied the effects of the NTK, here we focus on the effects of changing β and ‖Z₀‖_F ≡ ‖Z(X, 0)‖_F (the Frobenius norm of the M × K dimensional matrix of training logits), independent of Θ̂_θ.
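A minimal way to see the structure of the empirical NTK is to compute it by brute force for a toy model. The sketch below is our own construction: it compares a finite-difference Jacobian contraction against the closed form for a linear model z(x) = Wx, for which the kernel is exactly diagonal in the class indices (the limiting block-diagonal structure mentioned above):

```python
import numpy as np

def ntk_linear(x1, x2, K):
    # For z_i(x) = W_i . x, we have dz_i/dW_j = delta_ij * x, so the empirical NTK
    # (Theta)_ij(x1, x2) = delta_ij * (x1 . x2): diagonal in the class indices.
    return np.eye(K) * (x1 @ x2)

def ntk_finite_diff(f, theta, x1, x2, eps=1e-5):
    # Generic empirical NTK via finite-difference Jacobians:
    # (Theta)_ij = (dz_i(x1)/dtheta) . (dz_j(x2)/dtheta)^T.
    def jac(x):
        z0 = f(x, theta)
        J = np.zeros((z0.size, theta.size))
        for p in range(theta.size):
            t = theta.copy()
            t[p] += eps
            J[:, p] = (f(x, t) - z0) / eps
        return J
    return jac(x1) @ jac(x2).T

K, N = 3, 4
rng = np.random.default_rng(0)
W = rng.normal(size=(K, N))
f = lambda x, th: th.reshape(K, N) @ x   # linear "network"
x1, x2 = rng.normal(size=N), rng.normal(size=N)
print(np.allclose(ntk_finite_diff(f, W.ravel(), x1, x2),
                  ntk_linear(x1, x2, K), atol=1e-4))  # True
```

For nonlinear architectures the same finite-difference contraction applies, but the kernel is only approximately block-diagonal at finite width.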

2.2. LINEARIZED DYNAMICS

For small changes in θ, the tangent kernel is approximately constant throughout learning (Jacot et al., 2018), and we drop the explicit θ dependence in this subsection. The linearized dynamics of z(x, t) depend only on the initial value of Θ̂ and the β-scaled logit values Z(X, t). This suggests that there is a universal timescale across β and η which can be used to compare linearized trajectories with different parameter values. Indeed, if we define an effective learning rate η̃ ≡ ηβ², we have

    dZ_i(x)/dt = η̃ ∑_{j=1}^{K} (Θ̂)_{ij}(x, X)(Y_j - σ(Z_j(X))),    (5)

which removes the explicit β dependence of the dynamics. We note that a similar rescaling exists for the continuous-time versions of other optimizers like momentum (Appendix B). The effective learning rate η̃ is useful for understanding the nonlinear dynamics as well: plotting learning curves against η̃t causes early-time collapse for fixed Z₀ across β and η (Figure 1). We see a strong, monotonic dependence on β of the time at which the nonlinear model begins to deviate from its linearization. We will return to and explain this phenomenon in Section 2.4. Unless otherwise noted, we will analyze all timescales in units of η̃ instead of η, as this allows for the appropriate early-time comparisons between models with different β.


Figure 1: For fixed initial training set logits Z₀, plotting learning curves against η̃t = β²ηt causes the learning curves to collapse to the learning curve of the linearized model at early times (right), in contrast to un-scaled curves (left). Models with large β follow linearized dynamics the longest.
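The collapse in Figure 1 can be reproduced in miniature: Euler-integrating the raw dynamics for two very different (β, η) pairs, holding Z₀ fixed and stepping on a common η̃t grid, yields identical rescaled-logit trajectories. The single-example, class-diagonal-kernel setup below is our own toy construction, not the paper's experiment:

```python
import numpy as np

def softmax(Z):
    e = np.exp(Z - Z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def run(beta, eta, Z0, Y, theta0=1.0, d_eta_t=0.1, steps=50):
    """Euler-integrate dz/dt = eta*beta*theta0*(Y - softmax(beta*z)) for a single
    training example with a class-diagonal kernel of scale theta0, taking fixed
    steps of size d_eta_t in the rescaled time eta_tilde*t, eta_tilde = beta**2*eta."""
    eta_tilde = beta ** 2 * eta
    dt = d_eta_t / eta_tilde      # equal steps on the eta_tilde*t grid
    z = Z0 / beta                 # fix the initial rescaled logits Z0 = beta*z
    traj = []
    for _ in range(steps):
        Z = beta * z
        traj.append(Z.copy())
        z = z + dt * eta * beta * theta0 * (Y - softmax(Z))
    return np.array(traj)

rng = np.random.default_rng(1)
Z0 = 0.1 * rng.normal(size=(1, 3))
Y = np.array([[1.0, 0.0, 0.0]])
A = run(beta=1e-2, eta=1.0, Z0=Z0, Y=Y)
B = run(beta=1e1, eta=1e-3, Z0=Z0, Y=Y)
print(np.allclose(A, B))  # True: same Z trajectory on the eta_tilde*t grid
```

The per-step change in Z is d_eta_t · θ₀ · (Y - σ(Z)), with no remaining β or η dependence, which is exactly the content of Equation 5.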

2.3. EARLY LEARNING TIMESCALE

We now define and compute the early learning timescale, τ_z, which measures the time it takes for the logits to change significantly from their initial value. Specifically, we define τ_z such that for t ≪ τ_z we expect ‖Z(x, t) - Z(x, 0)‖_F ≪ ‖Z(x, 0)‖_F, and for t ∼ τ_z, ‖Z(x, t) - Z(x, 0)‖_F ∼ ‖Z(x, 0)‖_F (or larger). This is synonymous with the timescale over which the model begins to learn. As we will show below, τ_z ∝ ‖Z₀‖_F/η̃; therefore, in units of η̃, τ_z depends only on ‖Z₀‖_F and not on β. To see this, note that at very short times it follows from Equation 5 that

    Z_i(x, t) - Z_i(x, 0) ≈ η̃ ∑_{j=1}^{K} (Θ̂)_{ij}(x, X)(Y_j - σ(Z_j(X))) t + O(t²).    (6)

It follows that we can define a timescale over which the logits (on the training set) change appreciably from their initial value as

    τ_z ≡ (1/η̃) ‖Z₀‖_F / ‖Θ̂(X, X)(Y - σ(Z₀))‖_F,    (7)

where the norms are once again taken across all classes as well as training points. This definition has the desired properties for t ≪ τ_z and t ≫ τ_z. In units of η̃, τ_z depends only on ‖Z₀‖_F, in two ways. The first is a linear scaling in ‖Z₀‖_F; the second comes from the contribution of the gradient term ‖Θ̂(X, X)(Y - σ(Z(X, 0)))‖_F. As previously discussed, since σ(Z₀) saturates at small and large values of ‖Z₀‖_F, the gradient term also saturates for large and small ‖Z₀‖_F, and the ratio of the saturating values is an O(1) constant independent of ‖Z₀‖_F and β.

Figure 2: The timescale τ_z depends only on ‖Z₀‖_F in units of η̃ = β²η (left, inset). τ_z depends linearly on ‖Z₀‖_F, up to an O(1) coefficient which saturates at large and small ‖Z₀‖_F (left, main).
Accuracy increases more quickly for small initial ‖Z₀‖_F, though late-time dynamics are similar (center). Rescaling time to t/τ_z causes early accuracy curves to collapse (right).

The quantitative and conceptual nature of τ_z can both be confirmed numerically. We compute τ_z explicitly as ‖[Z_i(x, t) - Z_i(x, 0)]/t‖ for short times (Figure 2, left). When plotted over a wide range of ‖Z₀‖_F and β, the ratio τ_z/‖Z₀‖_F (in rescaled time units) undergoes a saturating, O(1) variation from small to large ‖Z₀‖_F. The quantitative dependence of the transition on the NTK is confirmed in Appendix C. Additionally, for fixed β and varying ‖Z₀‖_F, rescaling time by 1/τ_z causes accuracy curves to collapse at early times (Figure 2, right), even though they are very different at early times without the rescaling (Figure 2, center). We note here that the late-time accuracy curves seem similar across ‖Z₀‖_F even without rescaling, a point which we will return to in Section 3.2.
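The definition of τ_z above is straightforward to evaluate numerically. The sketch below is our own toy construction, using a class-diagonal kernel and working in units of η̃ (i.e., setting η̃ = 1):

```python
import numpy as np

def softmax(Z):
    e = np.exp(Z - Z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def tau_z(Z0, Theta, Y):
    """Early-learning timescale in units of eta_tilde:
    tau_z = ||Z0||_F / ||Theta (Y - softmax(Z0))||_F,
    for a class-diagonal data-data kernel Theta of shape (M, M)."""
    grad = Theta @ (Y - softmax(Z0))          # (M, K) initial function-space gradient
    return np.linalg.norm(Z0) / np.linalg.norm(grad)

rng = np.random.default_rng(0)
M, K = 8, 3
Theta = np.eye(M)                              # toy kernel
Y = np.eye(K)[rng.integers(0, K, size=M)]      # random one-hot labels
for scale in [1e-3, 1e-1, 1e1]:
    Z0 = scale * rng.normal(size=(M, K))
    print(scale, tau_z(Z0, Theta, Y))
# tau_z grows roughly linearly with ||Z0||_F; the coefficient varies only by an
# O(1) factor between the small- and large-||Z0||_F saturation regimes.
```

Because the gradient term saturates at both extremes of ‖Z₀‖_F, the linear scaling τ_z ∝ ‖Z₀‖_F dominates, as in Figure 2 (left).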

2.4. NONLINEAR TIMESCALE

While linearized dynamics are useful for understanding some features of learning, the best-performing networks often reside in the nonlinear regime (Novak et al., 2019a). Here we define the nonlinear timescale, τ_nl, corresponding to the time over which the network deviates appreciably from the linearized equations. We will show that τ_nl ∝ β/η̃. Therefore, in terms of β and ‖Z₀‖_F, networks with small β will access the nonlinear regime early in learning, while networks with large β will be effectively linearized throughout training. We note that a similar point was raised in Chizat et al. (2019), primarily in the context of MSE loss. We define τ_nl to be the timescale over which the change in Θ̂_θ (which contributes to the second-order term in Equation 6) can no longer be neglected. Examining the second time derivative of Z, we have

    d²Z_i/dt² = η̃ ∑_{j=1}^{K} [ -(Θ̂_θ)_{ij}(x, X) (d/dt)σ(Z(X)) + (d(Θ̂_θ)_{ij}(x, X)/dt)(Y_j - σ(Z_j(X))) ].    (8)

The first term is the second derivative under a fixed kernel (the linearized dynamics), while the second term is due to the change in the kernel (neglected in the linearized limit). A direct calculation shows that the second term, which we denote Z̈_nl, can be written as

    (Z̈_nl)_i = β⁻¹ η̃² ∑_{j=1}^{K} ∑_{k=1}^{K} (Y_k(X) - σ(Z_k(X)))^T (∂z_k(X)/∂θ) · (∂/∂θ)[Θ̂_θ]_{ij} (Y_j(X) - σ(Z_j(X))).    (9)

This gives us a nonlinear timescale defined, at initialization, by τ_nl ≡ ‖Ż(X, 0)‖_F / ‖Z̈_nl(X, 0)‖_F. We can interpret τ_nl as the time it takes for changes in the kernel to contribute to learning. Though computing ‖Z̈_nl(X, 0)‖_F exactly is analytically intractable, its basic scaling in terms of β and ‖Z₀‖_F (and therefore that of τ_nl) is computable. We first note the explicit β⁻¹η̃² dependence. The remaining terms are independent of β and vary by at most O(1) with ‖Z₀‖_F; indeed, as described above, ‖Y(X) - σ(Z(X, 0))‖_F saturates for large and small ‖Z₀‖_F.
Moreover, the derivative ∂z(X, 0)/∂θ is the square root of the NTK and, at initialization, is independent of ‖Z₀‖_F. Together with our analysis of τ_z, we have that, up to some O(1) dependence on ‖Z₀‖_F, τ_nl ∝ β/η̃. Therefore, the degree of nonlinearity early in learning is controlled by β alone. Once again, we can confirm the quantitative and conceptual understanding of τ_nl numerically. Qualitatively, we see that for fixed ‖Z₀‖_F, models with smaller β deviate sooner from the linearized dynamics when learning curves are plotted against η̃t (Figure 1). We compute τ_nl explicitly by taking ‖[Z(t) - Z_lin(t)]/t²‖ at small times, where Z_lin(t) is the solution to Equation 5; τ_nl/β has an O(1) dependence on ‖Z₀‖_F only (Figure 3).

Figure 3: τ_nl/β as a function of ‖Z₀‖_F for β ranging from 10⁻³ to 10³; the curves vary only by an O(1) factor with ‖Z₀‖_F.

2.5. LEARNING RATES AND LEARNING SPEED

The timescales τ_z and τ_nl can be combined to gain information about training with SGD. Indeed, the largest allowable learning rate is controlled by the curvature of the loss function (Du et al., 2019; Allen-Zhu et al., 2019). A necessary condition for training to converge is that the step size be small compared with the inverse curvature, which happens when τ_nl ≥ c, where c is an architecture-dependent O(1) constant which has not yet been calculated theoretically (see Lewkowycz et al. (2020) for empirical estimates of c). This predicts a maximum effective learning rate η̃ of O(β), which in turn implies a raw learning rate of O(β⁻¹) (as η = η̃/β²). However, if η̃ is O(β), then τ_z is O(β⁻¹). This means that, early in learning, networks with smaller β will take more SGD steps to reach an appreciable change in the logits. Therefore, we predict that networks with small β take longer to train, which we also observe in practice.

3. EXPERIMENTS

3.1. OPTIMAL LEARNING RATE

We begin our empirical investigation by training Wide ResNets (Zagoruyko and Komodakis, 2017) without batch normalization on CIFAR10, as this architecture is well within the regime where our theory applies. In order to understand the effects of the different timescales on learning, we control β and ‖Z₀‖_F independently using a correlated initialization strategy outlined in Appendix D.1. Before considering model performance, it is first useful to understand the scaling of the learning rate with β. We define the optimal learning rate η* as the learning rate with the best generalization performance. To find it, we initialize networks with different β and conduct a learning rate sweep for each β. The optimal learning rate η* has a clear 1/β dependence (Figure 4). This matches the prediction in Section 2.5, suggesting that the maximum learning rate corresponds to the regime where the nonlinear effects become important at the fastest rate for which training still converges. Again, as predicted, networks with smaller β also learn more slowly in terms of the number of SGD steps. Thus, at small β we expect learning to take place slowly and nonlinear effects to become important by the time the function has changed appreciably. At large β, by contrast, our results suggest that the network will have learned a significant amount before the dynamics become appreciably nonlinear.

3.2. PHASE PLANE

In the preceding discussion two quantities emerged that control the behavior of the early-time dynamics: the inverse temperature β and the rescaled logit norm ‖Z₀‖_F. In attempting to understand the behavior of real neural networks trained using softmax-cross-entropy loss, it therefore makes sense to consider neural networks that span the β-‖Z₀‖_F phase plane, the space of allowable pairs (β, ‖Z₀‖_F). By construction, the phase plane is characterized by the timescales involved in early learning. To summarize, τ_z ∼ ‖Z₀‖_F/η̃ sets the timescale for early learning, with larger values of ‖Z₀‖_F leading to a longer time before significant accuracy gains are made (Section 2.3). Meanwhile, τ_nl ∼ β/η̃ controls the timescale for the learning dynamics to leave the linearized regime: small β leads to immediate departures from linearity, while models with large β may stay linearized throughout their learning trajectories (Section 2.4). In Figure 5(a), we show a schematic of the phase plane. The colormap shows the test performance of a Wide ResNet (Zagoruyko and Komodakis, 2017), without batch normalization, trained on CIFAR10 in different parts of the phase plane. The value of β makes a large difference in generalization, with optimal performance achieved at β ≈ 10⁻². In general, larger β performed worse than small β, as expected. Moreover, we observe similar generalization for all sufficiently large β; this is to be expected, since models in this regime are close to their linearization throughout training (see Figure 1) and we expect linearized models to have β-independent performance. Generalization was largely insensitive to ‖Z₀‖_F so long as the network was sufficiently well-conditioned to be trainable. This suggests that long-term learning is insensitive to τ_z. In Figure 5(b), we plot the accuracy after 20 steps of optimization (with the optimal learning rate).
For fixed ‖Z₀‖_F, training was slow for the smallest β and became faster with increasing β. For fixed β, training was fastest for small ‖Z₀‖_F and slowed as ‖Z₀‖_F increased. Both phenomena were predicted by our theory, showing that both parameters are important in determining the early-time dynamics. However, we note that the relative accuracy across the phase plane at early times did not correlate with the performance at late times. This highlights that differences in generalization are a dynamical phenomenon. Another indication of this fact is that at the end of training, at time t_f, the final training set logit norms ‖Z_f‖_F ≡ ‖Z(X, t_f)‖_F tend towards 1, independent of the initial β and ‖Z₀‖_F (Figure 5(c)). With the exception of the poorly-performing large-‖Z₀‖_F regime, the different models reach similar levels of certainty by the end of training, despite having different generalization performances. Therefore generalization is not well correlated with the final model certainty (a typical motivation for tuning β).

Under review as a conference paper at ICLR 2021

3.3. ARCHITECTURE DEPENDENCE OF THE OPTIMAL β

Having demonstrated that β controls the generalization performance of neural networks with softmax-cross-entropy loss, we now discuss the question of choosing the optimal β. Here we investigate this question through the lens of a number of different architectures. We find the optimal choice of β to be strongly architecture dependent. Whether or not the optimal β can be predicted analytically is an open question that we leave for future work. Nonetheless, we show that all architectures considered display an optimal β between approximately 10⁻² and 10¹, and we observe that tuning β often improves performance over the naive setting of β = 1.

3.3.1. WIDE RESNET ON CIFAR10

Figure 6: (a) Test accuracy versus β for Wide ResNets with weights drawn from normal distributions of different variances σ_w², with and without batch normalization. (b) At small β, learning is less stable, as evidenced by low average performance but high maximum performance (over 10 random seeds). (c) We see similar phenomenology on the IMDB sentiment analysis task trained with GRUs, where average-case best performance is near β = 1 but peak performance is at small β.

In Figure 6(a) we show the accuracy against β for several Wide ResNets whose weights are drawn from normal distributions of different variances, σ_w², trained without batch normalization, as well as a network with σ_w² = 1 trained with batch normalization (averaged over 10 seeds). The best average performance is attained for β < 1 and σ_w = 1 without batch normalization; in particular, networks with large σ_w are dramatically improved by tuning β. The network with batch normalization is better at all β, with optimal β ≈ 10. However, we see that the best-performing seed is often at a lower β (Figure 6(b)), with large-σ_w networks competitive with σ_w = 1, and even with the batch-normalized network at fixed β (though batch normalization with β = 10 still performs best overall). This suggests that small β can improve best-case performance, at the cost of stability. Our results emphasize the importance of tuning β, especially for models that have not otherwise been optimized.

3.3.2. RESNET50 ON IMAGENET

Table 1: Accuracy on the ImageNet dataset for ResNet-50. Tuning β significantly improves accuracy.

Method                                               Accuracy (%)
ResNet-50 (Ghiasi et al., 2018)                      76.51 ± 0.07
ResNet-50 + Dropout (Ghiasi et al., 2018)            76.80 ± 0.04
ResNet-50 + Label Smoothing (Ghiasi et al., 2018)    77.17 ± 0.05
ResNet-50 + Temperature check (β = 0.3)              77.37 ± 0.02

Motivated by our results on CIFAR10, we experimentally explored the effects of β as a tunable hyperparameter for ResNet-50 trained on ImageNet. We follow the experimental protocol established by Ghiasi et al. (2018). A key difference between this procedure and standard training is that we train for substantially longer: the number of training epochs is increased from 90 to 270. Ghiasi et al. (2018) found that this longer training regimen was beneficial when using additional regularization. Table 1 shows that tuning β improves accuracy for ResNet-50 with batch normalization. However, we did not find β < 1 to be optimal for ResNet-50 without normalization. This further emphasizes the subtle architecture dependence that warrants further study.

3.3.3. GRUS ON IMDB SENTIMENT ANALYSIS

To further explore the architecture dependence of the optimal β, we train GRUs (from Maheswaranathan et al. (2019)) whose weights are drawn from two different distributions on the widely studied IMDB sentiment analysis task (Maas et al., 2011). We plot the results in Figure 6(c) and observe that they look qualitatively similar to the results on CIFAR10 without batch normalization: peak performance averaged over an ensemble of networks is near β ∼ 1, but smaller β can give better peak performance at the expense of stability.

3.4. PROPOSED TUNING PROCEDURE

Our results suggest the following tuning procedure for networks trained with SGD/momentum:
• Train/tune a model as normal. Note the optimal learning rate η₀.
• For the best parameter set, sweep over β ∈ [10⁻², 10¹], scaling the learning rate as η₀/β.
• If the best-performing β is at an endpoint of the range, continue tuning beyond it.
If there is sufficient compute, the β search can instead be folded into the overall hyperparameter tuning. We note that our observation of optimal β < 1 in classification settings stands in contrast to the observation of optimal β > 1 in the Bayesian inference setting (Wenzel et al., 2020). We speculate that the difference is related to the fact that in Bayesian inference the uncertainty of estimates is important, and that the objective (learning an SDE which reproduces a target distribution) may have qualitatively different properties than classification tasks.
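The procedure above can be sketched as a simple sweep. Here `train_and_eval` is a placeholder for an actual training pipeline (its name and signature are our own), and the stand-in objective exists only to make the example runnable:

```python
import numpy as np

def tune_beta(train_and_eval, eta0, betas=np.logspace(-2, 1, 7)):
    """Sweep beta over [1e-2, 1e1], scaling the learning rate as eta0 / beta,
    and return the best (beta, score). `train_and_eval(beta, eta)` should
    return a validation score for one full training run."""
    results = {b: train_and_eval(beta=b, eta=eta0 / b) for b in betas}
    best = max(results, key=results.get)
    if best in (betas[0], betas[-1]):
        print("best beta is at an endpoint of the range; widen the sweep")
    return best, results[best]

# Stand-in objective for illustration only (peaks near beta = 0.1).
fake_eval = lambda beta, eta: -abs(np.log10(beta) + 1.0)
beta_star, score = tune_beta(fake_eval, eta0=0.1)
print(beta_star)
```

With more compute, the same sweep can simply be merged into the outer hyperparameter search rather than run as a separate stage.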

4. CONCLUSIONS

Our empirical results show that tuning β can yield sometimes significant improvements to model performance. Perhaps most surprisingly, we observe gains on ImageNet even with the highly-optimized ResNet-50 model. Our results on CIFAR10 suggest that the effect of β may be even stronger in networks which are not yet highly optimized, and the results on IMDB show that this effect holds beyond the image classification setting. It is possible that even more gains can be made by more carefully tuning β jointly with other hyperparameters, in particular the learning rate schedule and batch size. One key lesson of our theoretical work is that properties of learning dynamics must be compared using the right units. For example, τ_nl ∝ 1/(βη), which at first glance suggests that models with smaller β will become nonlinear more slowly than their large-β counterparts. However, analyzing τ_nl with respect to the effective learning rate η̃ = β²η yields τ_nl ∝ β/η̃. Thus we see that, in fact, networks with smaller β tend to become nonlinear before much learning has occurred, while networks with large β can remain in the linearized regime throughout training. Our numerical results confirm this intuition developed using the theoretical analysis. As discussed above, our analysis suggests a range of good β, but does not predict the optimal value. Architecture-dependent, non-scaling multiplicative factors in key learning parameters have been observed in other contexts (Lewkowycz et al., 2020), and their numerical estimation is in general difficult. Extending the theoretical results to make predictions about these quantities is an interesting avenue for future work. Another area that warrants further study is the instability of training at small β.

A LINEARIZED LEARNING DYNAMICS

A.1 FIXED POINTS

For the linearized learning dynamics, the trajectory z(x, t) can be written in terms of the trajectories on the training set as

    z(x, t) - z(x, 0) = Θ̂(x, X) Θ̂⁺(X, X)(z(X, t) - z(X, 0)),    (10)

where ⁺ denotes the pseudo-inverse. Therefore, if one can solve for z(X, t), then in principle properties of generalization are computable. However, in general Equation 3 does not admit an analytic solution even for fixed Θ̂, in contrast to the case of mean squared loss. It need not even have an equilibrium: if the model can achieve perfect training accuracy, the logits will grow indefinitely. However, there is a guaranteed fixed point if an appropriate L₂ regularization is added to the training objective. Given a regularizer (λ_θ/2)‖δθ‖² on the change in parameters δθ = θ(t) - θ(0), the dynamics in the linearized regime are given by

    ż(x) = βη Θ̂(x, X)(Y(X) - σ(βz(X))) - λ_θ δz(x),    (11)

where the last term comes from the fact that (∂z/∂θ)δθ = δz(x) in the linearized limit. We can write down self-consistent equations for equilibria, which are approximately solvable in certain limits. For an arbitrary input x, the equilibrium solution z*(x) satisfies

    0 = β Θ̂(x, X)(Y(X) - σ(βz*(X))) - λ_θ δz*(x).    (12)

This can be rewritten in terms of the training set as

    δz*(x) = Θ̂(x, X) Θ̂⁺(X, X) δz*(X),    (13)

similar to kernel learning. It remains then to solve for z*(X). We have:

    δz*(X) = (β/λ_θ) Θ̂(X, X)[Y(X) - σ(βz*(X))].    (14)

We immediately note that the solution depends on the initialization. We assume z(x, 0) = 0, so δz = z, in order to simplify the analysis. The easiest case to analyze is ‖βz*(X)‖_F ≪ 1.
Then we have:

    z*(X) = (β/λ_θ) Θ̂(X, X)[Y(X) - (1/K)(1 + βz*(X))],    (15)

which gives us

    z*(X) = (β/λ_θ)[1 + (β²/(Kλ_θ)) Θ̂(X, X)]⁻¹ Θ̂(X, X)(Y(X) - 1/K).    (16)

The self-consistency condition for this solution is therefore (β/λ_θ)‖Θ̂‖_F ≪ 1, which simplifies the solution to

    z*(X) = (β/λ_θ) Θ̂(X, X)(Y(X) - 1/K).    (17)

This is equivalent to the solution after a single step of (full-batch) SGD with an appropriate learning rate. We note that, unlike linearized dynamics with L₂ loss and a full-rank kernel, there is no guarantee that the solution converges to 0 training error. The other natural limit is ‖βz*(X)‖₂ ≫ 1. We focus on the 2-class case, in order to take advantage of a conserved quantity of learning with cross-entropy loss. We note that the vector on the right-hand side of Equation 3 sums to 0 for every training point, since both Y and σ(βz) sum to 1. Suppose that at initialization Θ̂_θ has no logit-logit interactions, as is the case for most architectures in the infinite-width limit with random initialization. More formally, suppose we can write Θ̂_θ = Id_{K×K} ⊗ Θ̂_x, where Θ̂_x is M × M. Then the sum of the logits for any input x is conserved during linearized training, as we have:

    1^T ż(x) = ηβ 1^T [Id_{K×K} ⊗ Θ̂_x](Y - σ(βz(X)))    (18)
             = ηβ Θ̂_x 1^T (Y - σ(βz(X))) = 0.    (19)

(Note that if Θ̂_θ has explicit dependence on the logits, there still is a conserved quantity, which is more complicated to compute.) Now we can analyze ‖βz*(X)‖_F ≫ 1. With two classes, and z(X) = 0 at initialization, we have z₁* = -z₂*. Therefore, without loss of generality, we focus on z₁*, the logit of the first class. In this limit, the leading-order correction to the softmax is approximately:

    σ(βz₁*) ≈ 1_{z₁*>0} - sign(z₁*) e^{-2β|z₁*|}.    (20)

The self-consistency equation is then:

    z₁*(X) = (β/λ_θ) Θ̂(X, X)[Y(X) - 1_{z₁*>0} + sign(z₁*) e^{-2β|z₁*|}].    (21)

The vector on the right-hand side has entries that are O(e^{-2β|z₁*|}) for correct classifications, and O(1) for incorrect ones.
If we assume that the training error is 0, then we have:

z*_1(X) = (β/λ_θ) Θ(X, X) sign(z*_1) e^{−2β|z*_1|}.   (22)

This is still non-trivial to solve, but we see that the self-consistency condition is ln(β||Θ||_F/λ_θ) ≫ 1. Here it may also be difficult to train and generalize well: the individual elements of the right-hand-side vector are broadly distributed due to the exponential, so the outputs of the model are sensitive to, and may depend on only, a small number of datapoints. Even if the equilibrium solution has no training loss, generalization error may be high for the same reason. This suggests that even for NTK learning (with L2 regularization), the scale of ||βz|| plays an important role in achieving both good training accuracy and good generalization. In the NTK regime there is one unique solution, so (in the continuous-time limit) the initialization does not matter; rather, the ratio of β to λ_θ (compared to the appropriate norm of Θ) needs to be balanced to avoid falling into the small-βz regime (where training error might be large) or the large-βz regime (where a few datapoints might dominate and reduce generalization).
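The regularized self-consistency condition δz*(X) = (β/λ_θ) Θ(X, X)[Y(X) − σ(βz*(X))] can be checked numerically by damped fixed-point iteration. The sketch below is a toy illustration, not an experiment from this work: Θ is a random positive-semidefinite stand-in for an NTK, and all sizes and names are hypothetical.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
M, K = 8, 3                         # training points, classes (toy sizes)
beta, lam = 0.5, 1.0                # inverse temperature and L2 strength
A = rng.normal(size=(M * K, M * K))
Theta = A @ A.T / (M * K)           # random PSD stand-in for the NTK
Y = np.eye(K)[rng.integers(0, K, M)]  # one-hot labels

def F(z):
    # right-hand side of the self-consistency equation
    v = Theta @ (Y - softmax(beta * z)).ravel()
    return (beta / lam) * v.reshape(M, K)

z = np.zeros((M, K))                # z(x, 0) = 0, so delta_z = z
for _ in range(2000):
    z = 0.9 * z + 0.1 * F(z)        # damped fixed-point iteration
resid = np.abs(z - F(z)).max()      # should be ~0 at the fixed point
```

For (β/λ_θ)||Θ|| small the damped map is a contraction, so the iteration converges to the unique regularized equilibrium.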

A.2 DYNAMICS NEAR EQUILIBRIUM

The dynamics near the equilibrium can be analyzed by expanding around the fixed-point equation. We focus on the dynamics on the training set. For small perturbations Δz(X) = z(X) − z*(X), the dynamics are given by

Δż(X) = −η(β² [Id_z ⊗ Θ(X, X)] σ_z(βz*(X)) + λ_θ) Δz(X)

where σ_z is the derivative of the softmax,

σ_z(z) ≡ ∂σ(z)/∂z = diag(σ(z)) − σ(z)σ(z)ᵀ.   (24)

We can perform some analysis in the large- and small-β cases (once again ignoring λ_θ). For small ||βz*(X)||_F, we have (β/λ_θ)||Θ||_F ≪ 1, which leads to:

σ_z(βz*(X)) = (1/K)Id − (1/K²)11ᵀ + O(β Θ).

This matrix has K − 1 eigenvalues equal to 1/K, and one zero eigenvalue (corresponding to the conservation of probability). Therefore ||β² [Id_z ⊗ Θ(X, X)] σ_z(βz*(X))||_F ≪ λ_θ, and the well-conditioned regularizer dominates the approach to equilibrium. In the large-β case (ln(β||Θ||/λ_θ) ≫ 1), the values of σ(βz(X)) are exponentially close to 0 (K − 1 values) or 1 (the value corresponding to the largest logit). This means that σ_z(βz(X)) has entries which are exponentially small in ||βz(X)||_F: if either of σ(βz_i(X)) and σ(βz_j(X)) is exponentially small, the corresponding element of σ_z(βz(X)) is as well; for the largest logit i, the diagonal entry is σ(βz_i(X))(1 − σ(βz_i(X))), which is also exponentially small. From Equation 22, we have λ_θ ≪ β² e^{−2β|z*_1|}; therefore, though the σ_z term of H is exponentially small, it dominates the linearized dynamics near the fixed point, and the approach to equilibrium is slow. We analyze the conditioning of these dynamics in the remainder of this section.
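The small-logit spectrum claimed above can be verified directly: at z = 0 the softmax derivative is (1/K)Id − (1/K²)11ᵀ. A minimal numerical check (K here is an arbitrary illustrative value):

```python
import numpy as np

def softmax_jacobian(z):
    e = np.exp(z - z.max())
    s = e / e.sum()
    return np.diag(s) - np.outer(s, s)   # diag(sigma) - sigma sigma^T

K = 5
eigs = np.sort(np.linalg.eigvalsh(softmax_jacobian(np.zeros(K))))
# expect: one zero eigenvalue (probability conservation),
# and K - 1 eigenvalues equal to 1/K
```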

A.3 CONDITIONING OF DYNAMICS

Understanding the conditioning of the linearized dynamics requires understanding the spectrum of the Hessian matrix H = [Id_z ⊗ Θ(X, X)] σ_z(βz*(X)). In the limit of large model size, the first factor is block-diagonal with training-set-by-training-set blocks (no logit-logit interactions), and the second factor is block-diagonal with K × K blocks (no datapoint-datapoint interactions). We will use the following lemma to bound the conditioning:

Lemma: Let M = AB be a matrix that is the product of two matrices. The condition number κ(M) ≡ λ_{M,max}/λ_{M,min} satisfies

κ(B)/κ(A) ≤ κ(M) ≤ κ(A)κ(B).   (26)

Proof: Consider the vector v that is the singular vector of B associated with λ_{B,min}. Note that ||ABv||/||v|| ≤ λ_{A,max}λ_{B,min}. Analogously, for w, the singular vector associated with λ_{B,max}, ||ABw||/||w|| ≥ λ_{A,min}λ_{B,max}. This gives the two bounds:

λ_{M,min} ≤ λ_{A,max}λ_{B,min},  λ_{M,max} ≥ λ_{A,min}λ_{B,max}.

Therefore the condition number is bounded below by

κ(M) ≥ (λ_{A,min}λ_{B,max})/(λ_{A,max}λ_{B,min}) = κ(B)/κ(A),

which, together with the trivial upper bound, gives Equation 26. In particular, this means that a poorly conditioned σ_z(βz*(X)) will lead to poor conditioning of the linearized dynamics if the NTK Θ(X, X) is (relatively) well conditioned. This bound will be important in establishing the poor conditioning of the linearized dynamics in the large-logit regime ||βz|| ≫ 1.
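The lemma can be sanity-checked numerically on random matrices (a check of the stated bound itself, not of any quantity from the paper):

```python
import numpy as np

def cond(M):
    s = np.linalg.svd(M, compute_uv=False)
    return s[0] / s[-1]

rng = np.random.default_rng(2)
ok = True
for _ in range(100):
    A = rng.normal(size=(6, 6))
    B = rng.normal(size=(6, 6))
    kA, kB, kAB = cond(A), cond(B), cond(A @ B)
    lower = kB / kA <= kAB * (1 + 1e-9)   # small slack for float error
    upper = kAB <= kA * kB * (1 + 1e-9)
    ok = ok and lower and upper
```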

A.3.1 SMALL LOGIT CONDITIONING

For ||βz*(X)||_F ≪ 1, the Hessian H is

H = (1/K)[Id − (1/K)11ᵀ] ⊗ Θ(X, X).

Since H is the Kronecker product of two matrices, the condition numbers multiply, and we have κ(H) = κ(Θ) (restricting to the nonzero part of the spectrum), which is well conditioned so long as the NTK is. Regardless, the well-conditioned regularization due to λ_θ dominates the approach to equilibrium.

A.3.2 LARGE LOGIT CONDITIONING

Now consider ||βz*(X)||_F ≫ 1. Here we will show that the linearized dynamics are poorly conditioned, and that κ(H) is exponentially large in β. We first examine σ_z(βz*(x)) for an individual x ∈ X. To 0th order (in an as-of-yet-undefined expansion), σ_z is zero: at large β the softmax returns either 0 or 1, which by Equation 24 gives 0 in all entries. The size of the corrections turns out to be exponentially dependent on ||βz*||_F; the entries will have broad, log-normal-type distributions with magnitudes which scale

B.2 CONTINUOUS TIME EQUATIONS

We can write down the continuous-time version of the learning dynamics as follows. For SGD, for small learning rates we have:

dθ/dt = −ηg.

For the momentum equations we have

dν/dt: dv/dt = −γv − g   (42)
dθ/dt = ηv.

From these equations, we can see that in the continuous-time limit there are coordinate transformations which cause sets of trajectories with different parameters to collapse to a single trajectory. SGD is the simplest: rescaling time to τ ≡ ηt makes the learning curves identical for all learning rates. For momentum, instead of a single universal learning curve, there is a one-parameter family of curves controlled by the ratio T_mom ≡ η/γ². Consider rescaling time to τ = at and ν = bv, where a and b will be chosen to put the equations in a canonical form. In the new coordinates, we have

dν/dτ = −(γ/a)ν − (b/a)g   (44)
dθ/dτ = ην/(ab).

The canonical form we choose is

dν/dτ = −λν − g   (46)
dθ/dτ = ν   (47)

from which we arrive at a = b = √η, which gives λ = γ/√η. Note that this is not a unique canonical form; for example, if we fix a coefficient of −1 on ν, we end up with

dν/dτ = −ν − (η/γ²)g   (48)
dθ/dτ = ν

with a = γ. This is a different time rescaling, but still controlled by T_mom. Working in the canonical form of Equations 46 and 47, we can analyze the dynamics. One immediate question is the difference between λ ≫ 1 and λ ≪ 1. We note that the integral

ν(τ) = e^{−λτ}ν(0) − ∫₀^τ e^{−λ(τ−τ′)} g(τ′) dτ′

solves the differential equation for ν. Therefore, for λ ≫ 1, ν(τ) only depends on the current value g(τ), and we have ν(τ) ≈ −g(τ)/λ. Therefore, we have, approximately:

dθ/dτ ≈ −(1/λ)g.   (51)

This means that for large λ all the curves will approximately collapse, with effective learning rate √η/λ = η/γ (dynamics similar to SGD). For λ ≪ 1, the momentum is essentially the integrated gradient across all time. If ν(0) = 0, then we have

dθ/dτ ≈ −∫₀^τ g(τ′) dτ′.   (52)

In this limit, θ(τ) is the double integral of the gradient with respect to time.
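The SGD time-rescaling collapse can be illustrated on a one-dimensional quadratic loss; this is a hypothetical toy problem, not one of the paper's experiments. Gradient-descent curves at different learning rates are sampled at matching rescaled times τ = ηt:

```python
import numpy as np

def theta_at_tau(eta, tau_grid, theta0=5.0, k=1.0):
    # gradient descent on L = k * theta^2 / 2, sampled at tau = eta * t
    out, theta, t = [], theta0, 0
    for tau in tau_grid:
        n = int(round(tau / eta))   # number of steps reaching this tau
        while t < n:
            theta -= eta * k * theta
            t += 1
        out.append(theta)
    return np.array(out)

tau_grid = np.arange(0.0, 5.0, 0.5)
a = theta_at_tau(0.01, tau_grid)
b = theta_at_tau(0.001, tau_grid)
gap = np.abs(a - b).max()           # shrinks as eta -> 0
```

The gap between the two sampled curves is already small and shrinks further as the learning rates are reduced, reflecting convergence to the common continuous-time trajectory.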
Given the form of the controlling parameter T_mom, we can choose the parameterization γ = γ̄√η, under which T_mom = γ̄^{−2}. The dynamical equations then become:

dν/dτ = −γ̄ν − g   (53)
dθ/dτ = ν

which automatically removes the explicit dependence on η. One regime of particular interest is the early-time dynamics, starting from ν(0) = 0. Integrating directly, we have:

θ(τ) = −(1/2)gτ² + (1/6)γ̄gτ³ + …

This means that τ alone is the correct timescale for early learning, at least until τγ̄ ∼ 1, which in the original parameters corresponds to t ∼ 1/γ (the time it takes for the momentum to be first "fully integrated").
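The early-time expansion can be checked by integrating the canonical equations with a constant gradient; the values of g, γ̄, and the step size below are arbitrary illustrative choices.

```python
# Euler-integrate dnu/dtau = -gbar*nu - g, dtheta/dtau = nu with nu(0) = 0
# and constant g, then compare against the early-time series
# theta(tau) = -g*tau^2/2 + gbar*g*tau^3/6.
g, gbar, dtau = 2.0, 0.3, 1e-5
nu, theta, tau = 0.0, 0.0, 0.0
for _ in range(10_000):             # integrate to tau = 0.1, where gbar*tau << 1
    nu += dtau * (-gbar * nu - g)
    theta += dtau * nu
    tau += dtau
series = -0.5 * g * tau**2 + (gbar * g / 6.0) * tau**3
err = abs(theta - series)           # small: series truncation + Euler error
```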

B.3 DETAILED ANALYSIS OF MOMENTUM TIMESCALES

One important subtlety is that 1 is not the correct value to compare λ to. The real timescale is the one over which g changes significantly. We can approximate it in the following way. Suppose that some relative change Δ_θ/||θ|| ∼ c of the parameters leads to an appreciable relative change in g. Then the timescale over which θ changes by that amount is the one we must compare λ to. We can compute that timescale as follows, assuming g fixed, so that Equation 51 approximately holds. The timescale τ_c of the change is then given by:

Δ_θ/||θ|| = (1/λ)(||g||/||θ||) τ_c ∼ c   (56)

which gives τ_c ∼ cλ||θ||/||g||. In particular, this means that the approximation is good when λτ_c ≫ 1, which gives γ²/η ≫ ||g||/||θ||: the former is a function of the dynamical parameters, the latter a function of the geometry of L with respect to θ. One consequence of this analysis is that if ||θ|| remains roughly constant for fixed η and γ, then late in learning, when the gradients become small, the dynamics shift into the regime where λ is large, and we effectively have SGD.

B.4 CONNECTING DISCRETE AND CONTINUOUS TIME

One use for the continuous-time rescalings is to compare learning curves for the actual discrete optimization performed with different learning rates. For small learning rates the curves are expected to collapse, while for larger learning rates the deviations from the continuous-time expectation can be informative. With momentum, we only have perfect collapse when γ and η are scaled together. However, one typical use case for momentum is to fix the parameter γ and sweep over learning rates. With this setup, if g is changing slowly compared to γ (more precisely, γ²/η ≫ ||g||/||θ||), as may be the case at later training times, the change in parameters from a single step is Δ_θ ∼ (η/γ)||g||, and rescaling t to ηt (as for SGD) collapses the dynamics. Therefore, given a collection of fixed-γ curves with varying η, it is possible to get intermediate- and late-time dynamics on the same scale. However, at early times, while the momentum is still getting "up to speed" (i.e. in the first 1/γ steps), the appropriate timescale is η^{−1/2}. Therefore, in order to get learning curves to collapse across different η at early times, we need to rescale γ with η as implied by Equations 46 and 47: namely, one must fix γ̄ and set γ = γ̄√η. We note that, since γ < 1, this gives the restriction η < γ̄^{−2} on the maximum learning rate that can be supported by the rescaled momentum parameter.
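The early-time collapse under γ = γ̄√η can be illustrated with discrete heavy-ball updates on a constant gradient. This is a hypothetical toy setting; the discrete update form below is one standard discretization consistent with the continuous equations of Section B.2.

```python
import numpy as np

def theta_curve(eta, gbar, tau_grid, g=1.0):
    # heavy-ball steps v <- (1 - gamma) v - g, theta <- theta + eta * v,
    # with gamma = gbar * sqrt(eta), sampled at tau = sqrt(eta) * t
    gamma = gbar * np.sqrt(eta)
    v, theta, t, out = 0.0, 0.0, 0, []
    for tau in tau_grid:
        n = int(round(tau / np.sqrt(eta)))
        while t < n:
            v = (1 - gamma) * v - g
            theta += eta * v
            t += 1
        out.append(theta)
    return np.array(out)

tau_grid = np.arange(0.0, 2.0, 0.25)
a = theta_curve(1e-4, 0.5, tau_grid)
b = theta_curve(1e-6, 0.5, tau_grid)
gap = np.abs(a - b).max()           # curves collapse as eta -> 0 at fixed gbar
```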

B.5 MOMENTUM EQUATIONS WITH SOFTMAX-CROSS-ENTROPY LEARNING

For cross-entropy learning with softmax inputs βz, all the scales acquire dependence on β. If we define ℓ_z ≡ ||∂L/∂(βz)||_F and g_z ≡ ||∂z/∂θ||_F, then we have, approximately, ||g|| ≈ β ℓ_z g_z. Consider the goal of obtaining identical early-time learning curves for different values of β. (The curves are only globally consistent across β in the linearized regime.) In order to get learning curves to collapse, we want dL/dτ to be independent of β in the rescaled time units. We note that the change in the loss function Δ_L from a single step of SGD goes as

Δ_L ∼ ηβ² ℓ_z² g_z².

This suggests that one way to collapse learning curves is to plot them against the rescaled time η̄t, where η̄ = ηβ². When tuning hyperparameters across β, one could use η = η̄/β², sweeping over η̄ in order to easily obtain comparable learning curves. However, a better goal for a learning-rate rescaling is to stay within the continuous-time limit; that is, to keep the change in parameters Δ_θ from a single step small across β. We have

Δ_θ ∼ ηβ ℓ_z g_z

which suggests that the maximum allowable learning rate will scale as 1/β. This suggests setting η = η̄β^{−1}, and rescaling time as ηβt, in order to best explore the continuous learning dynamics. We can perform a similar analysis for the momentum optimizers. We begin by analyzing the continuous-time equations for the dynamics of the loss:

dν/d(βτ) = −(γ/β)ν − g   (63)
dθ/d(βτ) = (1/β)ν   (64)
dL/d(βτ) = g • ν   (65)

This rescaling causes the trajectories of L to collapse at early times if γ/β is constant across β. One scheme to arrive at the final canonical form, across β, is via the following definitions of η̄, γ̄, and τ̄:

• η = η̄β^{−2}
• γ̄ ≡ β√η γ = √η̄ γ
• τ̄ ≡ β√η t = √η̄ t

where curves with fixed γ̄ will collapse. The latter two definitions are similar to before, except with η replaced by η̄.
The dynamical equations are then:

dν/dτ̄ = −γ̄ν − g   (66)
dθ/dτ̄ = (1/β)ν   (67)
dL/dτ̄ = g • ν   (68)

The change in parameters from a single step (assuming constant g and saturation) is

Δ_θ = (√η̄/(γ̄β)) ||g||.

If we instead want the change in parameters from a single step to be invariant to β, so that the continuous-time approximation holds while maintaining collapse of trajectories, we first note that

Δ_θ ∼ (√η/γ̄) β ℓ_z g_z   (70)

from a single step of the momentum optimizer. To keep Δ_θ invariant to β, we can set:

• η = η̄β^{−1}
• γ̄ ≡ β√η γ = β^{1/2}√η̄ γ
• τ̄ ≡ β√η t = β^{1/2}√η̄ t

Note that the relationship between γ̄ and γ is the same in both schemes when measured with respect to the raw learning rate η.
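The η̄ = ηβ² rescaling for SGD discussed above is in fact exact for a linear model with zero initial logits, since the update to βz then depends on η and β only through ηβ². A toy check with hypothetical data (plain SGD, not one of the momentum schemes):

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def loss_and_grad(W, X, Y, beta):
    p = softmax(beta * (X @ W))
    loss = -np.mean(np.sum(Y * np.log(p), axis=1))
    gW = beta * X.T @ (p - Y) / X.shape[0]   # one factor of beta from the chain rule
    return loss, gW

rng = np.random.default_rng(3)
X = rng.normal(size=(32, 5))
Y = np.eye(3)[rng.integers(0, 3, 32)]
eta_bar = 0.1
drops = []
for beta in (0.5, 1.0, 2.0):
    W = np.zeros((5, 3))                     # zero logits at initialization
    L0, g = loss_and_grad(W, X, Y, beta)
    L1, _ = loss_and_grad(W - (eta_bar / beta**2) * g, X, Y, beta)
    drops.append(L0 - L1)                    # one-step loss drop, same for all beta
```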

C SOFTMAX-CROSS-ENTROPY GRADIENT MAGNITUDE

C.1 MAGNITUDE OF GRADIENTS IN FULLY-CONNECTED NETWORKS

The value of τ_z has nontrivial (but bounded) dependence on ||Z_0||_F via the ||Θ(Y − σ(Z_0(X)))||_F term in Equation 7. We can confirm the dependence for highly overparameterized models by using the theoretical Θ. In particular, for wide neural networks, the tangent kernel is block-diagonal in the logits and easily computable. The numerically computed τ_z/||Z_0||_F correlates well with ||Θ(Y − σ(Z_0(X)))||_F^{−1} for wide (2000 hidden units per layer) fully connected networks (Figure 7). The ratio depends on details such as the network nonlinearity; for example, ReLU units tend to have a larger ratio than Erf units (left and middle panels). The ratio also depends on properties of the dataset; for example, it increases on CIFAR10 when the training labels are randomly shuffled (right panel). Therefore, in general, the ratio τ_z/||Z_0||_F at large and small ||Z_0||_F depends subtly on the relationship between the NTK and the properties of the data distribution. A full analysis of this relationship is beyond the scope of this work; the exact form of the transition is likely even more sensitive to these properties, and is therefore harder to analyze than the ratio alone.

D EXPERIMENTAL DETAILS

D.1 CORRELATED INITIALIZATION

In order to avoid confounding the effects of changing β and ||Z_0||_F with changes to Θ, we use a correlated initialization strategy (similar to Chizat et al. (2019)) which fixes Θ while allowing for independent variation of β and ||Z_0||_F. Given a network with final hidden layer h(x, θ) and output weights W_O, we define a combined network z_c(x, θ) in terms of a correlation coefficient c ∈ [−1, 1], where δ_ij is the Kronecker delta (1 if i = j and 0 otherwise). Under this approach, the initial magnitude of the training-set logits is given by ||Z_0||_F = β√(1 + c) ||z_0||_F, where ||z_0||_F is the initial magnitude of the logits of the base model. By manipulating β and c, we can independently change β and ||Z_0||_F, with the caveat that ||Z_0||_F ≤ √2 β||z_0||_F since c ≤ 1. It follows that the small-β, large-||Z_0||_F region of the phase plane (upper left in Figure 5) is inaccessible for most well-conditioned models, where ||z_0||_F ∼ 1 at initialization. If only one set of weights is trained, Θ is independent of c. Unless otherwise noted, all empirical studies in Sections 2 and 3.2 involve training a wide resnet on CIFAR10 with SGD on GPUs, using the above correlated initialization strategy to fix Θ. All experiments used JAX (Bradbury et al., 2018), and experiments involving linearization or direct computation of the NTK used Neural Tangents (Novak et al., 2019b).
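The defining equation for z_c is not reproduced above, so the construction below is a hypothetical two-head reconstruction consistent with the stated norm relation: a second output head W2 is drawn with correlation c to the first, and the combined logits are taken as z_c = (β/√2) h (W1 + W2). With this choice, ||Z_0||_F ≈ β√(1 + c) ||z_0||_F in expectation:

```python
import numpy as np

rng = np.random.default_rng(4)
M, N, K, beta, c = 1000, 2000, 4, 2.0, 0.4
h = rng.normal(size=(M, N)) / np.sqrt(N)   # final hidden activations, training set
W1 = rng.normal(size=(N, K))               # base output head
W2 = c * W1 + np.sqrt(1 - c**2) * rng.normal(size=(N, K))  # correlated second head
Z0_base = h @ W1                           # base-model logits z_0
Z0 = (beta / np.sqrt(2)) * h @ (W1 + W2)   # combined-network logits at init
ratio = np.linalg.norm(Z0) / np.linalg.norm(Z0_base)
target = beta * np.sqrt(1 + c)             # predicted norm ratio
```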



Figure 3: The time to deviation from linearized dynamics, τ_nl, varies widely over β and ||Z_0||_F (left); this can be largely explained by a linear dependence on β (right), in units of η̄ = β²η. There is an O(1) dependence on ||Z_0||_F which is consistent across varying β for fixed ||Z_0||_F.

Figure 5: Properties of early learning dynamics, which affect generalization, can be determined by location in the β-||Z_0||_F phase plane (a). At the optimal learning rate η*, small β and larger ||Z_0||_F lead to slower early learning (b), and larger β increases the time before nonlinear dynamics contribute to learning. Large ||Z_0||_F gives poorly conditioned linearized dynamics. Generalization for a wide resnet trained on CIFAR10 is highly sensitive to β, and relatively insensitive to ||Z_0||_F outside the poor-conditioning regime. Final logit variance is relatively insensitive to the parameters (c).

Figure 6: Dependence of test accuracy on β tuning for various architectures. (a) For WRN with batchnorm, trained on CIFAR10, the optimal β ≈ 10. Without batchnorm, the performance of the network can be nearly recovered with β-scaling alone, with β ≈ 10^{−2}. Even poorly conditioned networks (obtained by increasing the weight scale σ_w) recover performance. (b) For β < 10^{−2}, learning is less stable, as evidenced by low average performance but high maximum performance (over 10 random seeds). (c) We see similar phenomenology on the IMDB sentiment analysis task trained with GRUs, where average-case best performance is near β = 1 but peak performance is at small β.

Figure 7: τ_z/||Z_0||_F is highly correlated with ||Θ(Y − σ(Z_0))||_F^{−1}, with Θ computed in the infinite-width limit (in units of effective learning rate η̄ = β²η). The ratio between normalized timescales at large and small ||Z_0||_F depends on the nonlinearity (left and middle), as well as on training-set statistics (right, CIFAR10 with shuffled labels).

A.3.2 LARGE LOGIT CONDITIONING (CONTINUED)

as exp(−β|z*_1|). There will be two scaling regimes: one with a small number of labels, in the sense ln(K) ≪ √β, where the largest logit dominates the statistics, and one where the number of labels is large (and the central limit theorem applies to the partition function). In both cases, however, there is still exponential dependence on β; we focus on the first regime, which is easier to analyze and more realistic (e.g. for 10⁶ labels, "large" β is only ∼ 15).

Let z_1 be the largest of the K logits, z_2 the second largest, and so on. Then, using Equation 24, the entries of σ_z(βz(x)) are exponentially small in the logit gaps: the off-diagonal entries are −σ(βz_i)σ(βz_j) for i ≠ j, and the diagonal entries are σ(βz_i)(1 − σ(βz_i)), with σ(βz_i) ≈ e^{−β(z_1−z_i)} for i > 1. The eigenvalues can be approximately computed as the diagonal entries, with eigenvectors approximately equal to the coordinate basis vectors (all non-explicit eigenvector components 0). This expansion is valid provided that β/K ≫ 1 (so that e^{−β(z_1−z_i)} ≫ e^{−β(z_1−z_{i+1})}).

Therefore the spectrum of any individual block σ_z(βz(x)) is exponentially small in β. Using the bound in the Lemma, we have:

κ(H) ≥ e^{β(z_2−z_K)}/κ(Θ(X, X)).

This is a very loose bound, as it assumes that the largest eigendirections of σ_z are aligned with the smallest eigendirections of Id_z ⊗ Θ, and vice versa. It is possible that κ(H) is closer in magnitude to the upper bound e^{β(z_2−z_K)}κ(Θ(X, X)). Regardless, κ(H) is exponentially large in β, meaning that the conditioning is exponentially poor for large ||βz*||_F.
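The exponential dependence of the conditioning on β can be seen directly from σ_z for a fixed set of logits (hypothetical values; the exact zero mode from probability conservation is dropped before computing the condition number):

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

z = np.array([1.0, 0.5, 0.2, -0.3])       # hypothetical logits, z_1 largest
conds = []
for beta in (2.0, 4.0, 8.0):
    s = softmax(beta * z)
    J = np.diag(s) - np.outer(s, s)        # sigma_z(beta z), Equation 24
    eigs = np.sort(np.linalg.eigvalsh(J))[1:]   # drop the zero mode
    conds.append(eigs[-1] / eigs[0])       # grows exponentially with beta
```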

B SGD AND MOMENTUM RESCALINGS

B.1 DISCRETE EQUATIONS

Consider full-batch SGD training. The update equations for the parameters θ are:

θ_{t+1} = θ_t − η g_t

where we denote g_t ≡ ∇_θ L for ease of notation. Training with momentum, the equations of motion are given by:

v_{t+1} = (1 − γ)v_t − g_t
θ_{t+1} = θ_t + η v_{t+1}

where γ ∈ [0, 1]. One key point to consider later will be the relative magnitude Δ_θ of updates to the parameters. For SGD, the magnitude of updates is η||g||. For momentum with slowly varying gradients the magnitude is η||g||/γ.
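A minimal sketch of the saturated momentum update magnitude, using the discrete updates above with a constant gradient (toy values):

```python
import numpy as np

g = np.array([3.0, -4.0])        # constant gradient, ||g|| = 5
eta, gamma = 0.01, 0.1
v = np.zeros(2)
step_norms = []
for t in range(500):
    v = (1 - gamma) * v - g      # v_{t+1} = (1 - gamma) v_t - g_t
    step_norms.append(np.linalg.norm(eta * v))
# per-step magnitude saturates at eta * ||g|| / gamma, larger than
# the SGD step eta * ||g|| by the factor 1 / gamma
saturated = step_norms[-1]
```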

