QUADRATIC MODELS FOR UNDERSTANDING NEURAL NETWORK DYNAMICS

Abstract

In this work, we show that recently proposed quadratic models capture optimization and generalization properties of wide neural networks that cannot be captured by linear models. In particular, we prove that quadratic models for shallow ReLU networks exhibit the "catapult phase" from Lewkowycz et al. (2020) that arises when training such models with large learning rates. We then empirically show that the behaviour of quadratic models parallels that of neural networks in generalization, especially in the catapult phase regime. Our analysis further demonstrates that quadratic models are an effective tool for the analysis of neural networks.

1. INTRODUCTION

A recent remarkable finding on neural networks, originating from Jacot et al. (2018) and termed the "transition to linearity" (Liu et al., 2020), is that, as network width goes to infinity, such models become linear functions in the parameter space. Thus, a linear (in parameters) model can be built to accurately approximate wide neural networks under certain conditions. While this finding has helped improve our understanding of trained neural networks (Du et al., 2019; Nichani et al., 2021; Zou & Gu, 2019; Montanari & Zhong, 2020; Ji & Telgarsky, 2019; Chizat et al., 2019), not all properties of finite width neural networks can be understood in terms of linear models, as is shown in several recent works (Yang & Hu, 2020; Ortiz-Jiménez et al., 2021; Long, 2021; Fort et al., 2020). In this work, we show that properties of finitely wide neural networks in optimization and generalization that cannot be captured by linear models are, in fact, manifested in quadratic models.

The training dynamics of linear models with respect to the choice of the learning rate are well-understood (Polyak, 1987). Indeed, such models exhibit linear training dynamics, i.e., there exists a critical learning rate, η_crit, such that the loss converges monotonically if and only if the learning rate is smaller than η_crit (see Figure 1a). Recent work (Lee et al., 2019) showed that the training dynamics of a wide neural network f(w; x) can be accurately approximated by those of the linear model

f_lin(w; x) = f(w_0; x) + (w − w_0)^T ∇f(w_0; x), (1)

where ∇f(w_0; x) denotes the gradient of f with respect to the trainable parameters w at an initial point w_0 and input sample x. This approximation holds for learning rates less than η_crit ≈ 2/‖∇f(w_0; x)‖², when the width is sufficiently large. However, the training dynamics of finite width neural networks f can sharply differ from those of linear models when using large learning rates. A striking non-linear property of wide neural networks discovered in Lewkowycz et al.
(2020) is that when the learning rate is larger than η_crit but smaller than a certain maximum learning rate, η_max, gradient descent still converges but experiences a "catapult phase": the loss initially grows exponentially and then decreases after reaching a large value, along with a decrease of the norm of the tangent kernel (see Figure 2a); such training dynamics are therefore non-linear (see Figure 1b). As linear models cannot exhibit such a catapult phase, under what models and conditions does this phenomenon arise? The work of Lewkowycz et al. (2020) first observed the catapult phase phenomenon in finite width neural networks and analyzed it for a two-layer linear neural network. However, a theoretical understanding of this phenomenon for general non-linear neural networks remains open.

In this work, we utilize a quadratic model as a tool to shed light on the optimization and generalization discrepancies between finite and infinite width neural networks. We call this model the Neural Quadratic Model (NQM), as it is given by the second order Taylor series expansion of f(w; x) around the point w_0:

f_quad(w) = f(w_0) + (w − w_0)^T ∇f(w_0) + (1/2)(w − w_0)^T H_f(w_0)(w − w_0), (2)

where the first two terms are exactly f_lin(w). Here we suppress the dependence on the input data x in the notation, and H_f(w_0) is the Hessian of f with respect to w evaluated at w_0. Indeed, we note that NQMs are contained in a more general class of quadratic models:

g(w; x) = w^T φ(x) + (γ/2) w^T Σ(x) w, (3)

where φ(x) is a vector and Σ(x) a matrix, both depending on the input, and γ is a scale parameter.
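The catapult phase can be reproduced in a few lines with the two-layer linear network analyzed by Lewkowycz et al. (2020). The following is our own illustrative sketch, not code from the paper; the width m, the single example x = 1, and the super-critical learning rate η = 3/λ_0 ∈ (2/λ_0, 4/λ_0) are arbitrary choices. Training f(u, v) = u·v/√m with this η, the loss first climbs, the tangent kernel λ = (‖u‖² + ‖v‖²)/m shrinks until η < 2/λ, and the loss then converges.

```python
import numpy as np

# Toy model from Lewkowycz et al. (2020): f(u, v) = u.v / sqrt(m) on one example.
rng = np.random.default_rng(0)
m = 1000                                   # width (illustrative choice)
u = rng.standard_normal(m)
v = rng.standard_normal(m)
y = 1.0                                    # regression target

def kernel(u, v):
    """Tangent kernel on the single example: lambda = (|u|^2 + |v|^2) / m."""
    return (u @ u + v @ v) / m

eta = 3.0 / kernel(u, v)                   # super-critical: 2/lambda0 < eta < 4/lambda0

losses, kernels = [], []
for _ in range(2000):
    f = u @ v / np.sqrt(m)
    r = f - y                              # residual
    losses.append(0.5 * r * r)
    kernels.append(kernel(u, v))
    # simultaneous gradient-descent update of both layers
    u, v = u - eta * r * v / np.sqrt(m), v - eta * r * u / np.sqrt(m)

print(max(losses) / losses[0])             # loss spikes far above its initial value
print(kernels[0], kernels[-1])             # tangent kernel norm decreases
print(losses[-1])                          # ...and training still converges
```

With η below 2/λ_0 the same loop decreases the loss monotonically and leaves the kernel nearly constant, which is the linear-dynamics regime.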



Unless stated otherwise, we always consider the setting where models are trained with squared loss using gradient descent. For non-differentiable functions, e.g. neural networks with ReLU activation functions, we define the gradient based on the update rule used in practice. Similarly, we use H f to denote the second derivative of f in Eq. (2).
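As a concrete instance of this convention and of Eq. (2), the sketch below (our illustration, with an arbitrary width and a single scalar input) builds f_lin and f_quad for a shallow ReLU network f(w; x) = v^T ReLU(u x)/√m, taking ReLU'(z) = 1 for z > 0 and 0 otherwise, as in standard implementations. Since the gradient and Hessian are computed by hand here, this is a sketch rather than the paper's code.

```python
import numpy as np

rng = np.random.default_rng(0)
m = 200                                    # hidden width (illustrative)
x = 1.0                                    # single scalar input
u0 = rng.standard_normal(m)                # hidden-layer weights at w_0
v0 = rng.standard_normal(m)                # output-layer weights at w_0

relu = lambda z: np.maximum(z, 0.0)
drelu = lambda z: (z > 0).astype(float)    # ReLU "derivative": 0 at z = 0, as in practice

def f(u, v):
    """Shallow ReLU network f(w; x) = v.relu(u * x) / sqrt(m)."""
    return v @ relu(u * x) / np.sqrt(m)

# Gradient of f at w_0 = (u0, v0).
grad_u = v0 * drelu(u0 * x) * x / np.sqrt(m)
grad_v = relu(u0 * x) / np.sqrt(m)

def f_lin(u, v):
    return f(u0, v0) + grad_u @ (u - u0) + grad_v @ (v - v0)

def f_quad(u, v):
    # 0.5 * d^T H_f d: the Hessian only couples each pair (u_i, v_i),
    # with d^2 f / (du_i dv_i) = relu'(u0_i * x) * x / sqrt(m).
    cross = (u - u0) * (v - v0) * drelu(u0 * x) * x / np.sqrt(m)
    return f_lin(u, v) + cross.sum()

# Perturb the weights without flipping any unit's activation
# (scaling u0 by a positive factor preserves every sign pattern).
u1 = 1.3 * u0
v1 = v0 + 0.3 * np.abs(v0)
err_lin = abs(f_lin(u1, v1) - f(u1, v1))
err_quad = abs(f_quad(u1, v1) - f(u1, v1))
print(err_quad, err_lin)                   # f_quad is exact here; f_lin is not
```

Because this network is piecewise quadratic in (u, v), f_quad reproduces f exactly until a perturbation crosses an activation boundary, whereas f_lin already misses the cross terms between the two layers; this is one way NQMs retain non-linear effects that the linearization discards.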



Figure 1: Optimization dynamics for linear and non-linear models based on the choice of learning rate. (a) Linear models converge monotonically if the learning rate is less than η_crit and diverge otherwise. (b) Unlike linear models, finitely wide neural networks and NQMs Eq. (2) (or general quadratic models Eq. (3)) can additionally exhibit a catapult phase when η_crit < η < η_max.
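The convergence/divergence dichotomy for linear models can be checked in a few lines. This is a minimal sketch under our own toy setup (one training sample, squared loss), not code from the paper: each gradient step multiplies the residual by 1 − η‖x‖², so the loss decreases monotonically precisely when η < η_crit = 2/‖x‖².

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(10)          # single training input (illustrative)
y = 1.0                              # target
eta_crit = 2.0 / (x @ x)             # critical learning rate for this model

def gd_losses(eta, steps=60):
    """Squared loss of the linear model f(w) = w.x under gradient descent."""
    w = np.zeros_like(x)
    losses = []
    for _ in range(steps):
        r = w @ x - y                # residual; multiplied by (1 - eta*|x|^2) each step
        losses.append(0.5 * r * r)
        w -= eta * r * x             # gradient of 0.5*r^2 with respect to w is r*x
    return losses

sub = gd_losses(0.9 * eta_crit)      # sub-critical: monotone convergence
sup = gd_losses(1.1 * eta_crit)      # super-critical: divergence
print(sub[-1], sup[-1])
```

There is no intermediate regime here: a linear model either contracts the residual at every step or expands it at every step, which is exactly why the catapult phase cannot occur for linear models.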

Figure 2: (a) Optimization dynamics of wide neural networks with sub-critical and super-critical learning rates. With sub-critical learning rates (0 < η < η_crit), the tangent kernel of wide neural networks is nearly constant during training, and the loss decreases monotonically. The whole optimization path is contained in the ball B(w_0, R) := {w : ‖w − w_0‖ ≤ R} with a finite radius R. With super-critical learning rates (η_crit < η < η_max), the catapult phase occurs: the loss first increases and then decreases, along with a decrease of the norm of the tangent kernel. The optimization path goes beyond the finite-radius ball. (b) Test loss of f_quad, f and f_lin plotted against different learning rates. With sub-critical learning rates, all three models have nearly identical test loss. With super-critical learning rates, f and f_quad achieve a smaller best test loss than is attainable with any sub-critical learning rate. Experimental details are in Appendix J.4.

