QUADRATIC MODELS FOR UNDERSTANDING NEURAL NETWORK DYNAMICS

Abstract

In this work, we show that recently proposed quadratic models capture optimization and generalization properties of wide neural networks that cannot be captured by linear models. In particular, we prove that quadratic models for shallow ReLU networks exhibit the "catapult phase" of Lewkowycz et al. (2020) that arises when training such models with large learning rates. We then empirically show that the behaviour of quadratic models parallels that of neural networks in generalization, especially in the catapult phase regime. Our analysis further demonstrates that quadratic models are an effective tool for the analysis of neural networks.

1. INTRODUCTION

A recent remarkable finding on neural networks, originating from Jacot et al. (2018) and termed the "transition to linearity" (Liu et al., 2020), is that, as network width goes to infinity, such models become linear functions in the parameter space. Thus, a linear (in parameters) model can be built to accurately approximate wide neural networks under certain conditions. While this finding has helped improve our understanding of trained neural networks (Du et al., 2019; Nichani et al., 2021; Zou & Gu, 2019; Montanari & Zhong, 2020; Ji & Telgarsky, 2019; Chizat et al., 2019), not all properties of finite width neural networks can be understood in terms of linear models, as is shown in several recent works (Yang & Hu, 2020; Ortiz-Jiménez et al., 2021; Long, 2021; Fort et al., 2020). In this work, we show that properties of finitely wide neural networks in optimization and generalization that cannot be captured by linear models are, in fact, manifested in quadratic models.

The training dynamics of linear models with respect to the choice of the learning rate are well-understood (Polyak, 1987). Such models exhibit linear training dynamics: there exists a critical learning rate, η_crit, such that the loss converges monotonically if and only if the learning rate is smaller than η_crit (see Figure 1a). In contrast, wide neural networks, Eq. (2) (or general quadratic models, Eq. (3)), can additionally exhibit a catapult phase when η_crit < η < η_max.

Figure 1: (a) Optimization dynamics for f (wide neural networks): linear dynamics and catapult dynamics. (b) Generalization performance for f, f_lin and f_quad.

Figure 2: (a) Optimization dynamics of wide neural networks with sub-critical and super-critical learning rates. With sub-critical learning rates (0 < η < η_crit), the tangent kernel of wide neural networks is nearly constant during training, and the loss decreases monotonically; the whole optimization path is contained in the ball B(w_0, R) := {w : ‖w − w_0‖ ≤ R} with a finite radius R. With super-critical learning rates (η_crit < η < η_max), the catapult phase happens: the loss first increases and then decreases, along with a decrease of the norm of the tangent kernel, and the optimization path goes beyond the finite-radius ball. (b) Test loss of f_quad, f and f_lin plotted against different learning rates. With sub-critical learning rates, all three models have nearly identical test loss. With super-critical learning rates, f and f_quad achieve a smaller best test loss than with any sub-critical learning rate. Experimental details are in Appendix J.4.

Recent work (Lee et al., 2019) showed that the training dynamics of a wide neural network f(w; x) can be accurately approximated by those of the linear model

f_lin(w; x) = f(w_0; x) + (w − w_0)^T ∇f(w_0; x),

where ∇f(w_0; x) denotes the gradient of f with respect to the trainable parameters w at an initial point w_0 and input sample x. This approximation holds for learning rates less than η_crit ≈ 2/‖∇f(w_0; x)‖², when the width is sufficiently large. However, the training dynamics of finite width neural networks can sharply differ from those of linear models when using large learning rates. A striking non-linear property of wide neural networks discovered in Lewkowycz et al. (2020) is that when the learning rate is larger than η_crit but smaller than a certain maximum learning rate, η_max, gradient descent still converges but experiences a "catapult phase": the loss initially grows exponentially and then decreases after reaching a large value, along with a decrease of the norm of the tangent kernel (see Figure 2a); such training dynamics are therefore non-linear (see Figure 1b). As linear models cannot exhibit a catapult phase, under what models and conditions does this phenomenon arise? The work of Lewkowycz et al.
(2020) first observed the catapult phase phenomenon in finite width neural networks and analyzed it for a two-layer linear neural network. However, a theoretical understanding of this phenomenon for general non-linear neural networks remains open.

In this work, we utilize a quadratic model as a tool to shed light on the optimization and generalization discrepancies between finite and infinite width neural networks. We call this model the Neural Quadratic Model (NQM), as it is given by the second order Taylor expansion of f(w; x) around the point w_0:

f_quad(w) = f(w_0) + (w − w_0)^T ∇f(w_0) + ½ (w − w_0)^T H_f(w_0)(w − w_0),

where the first two terms constitute f_lin(w). Here we suppress the dependence on the input data x in the notation, and H_f(w_0) is the Hessian of f with respect to w evaluated at w_0. Indeed, NQMs are contained in a more general class of quadratic models:

g(w; x) = w^T φ(x) + ½ γ w^T Σ(x) w,

where w are trainable parameters and x is input data. We discuss the optimization dynamics of such general quadratic models in Section 3.3 and show empirically that they exhibit the catapult phase phenomenon in Appendix J.3. Note that the two-layer linear network analyzed in Lewkowycz et al. (2020) is a special case of Eq. (3), obtained when φ(x) = 0 (see Appendix I).

Main Contributions. We prove that NQMs, f_quad, which approximate shallow fully-connected ReLU activated neural networks, exhibit catapult phase dynamics. Specifically, we analyze the optimization dynamics of f_quad by deriving the evolution of f_quad and the tangent kernel during gradient descent with squared loss, for a single training example and for multiple uni-dimensional training examples. We identify three learning rate regimes yielding different optimization dynamics for f_quad: (1) monotonic convergence (linear dynamics); (2) convergence via a catapult phase (catapult dynamics); and (3) divergence.
We provide a number of experimental results corroborating our theoretical analysis (see Section 3). We then empirically show that NQMs for shallow network architectures parallel neural networks in generalization (see Figure 2b and Section 4). Several related works study higher order approximations to neural networks but not catapult dynamics: Huang & Yau (2020) analyzed higher order approximations under gradient flow (infinitesimal learning rates); Bai & Lee (2019) studied different quadratic models with randomized second order terms; and Zhang et al. (2019) considered losses that are quadratic in the parameters, for which no catapult phase happens.

Discontinuity in dynamics transition. In the ball B(w_0, R) := {w : ‖w − w_0‖ ≤ R} with constant radius R > 0, the transition to linearity of a wide neural network (with linear output layer) is continuous in the network width m. That is, the deviation of the network function from its linear approximation within the ball is continuously controlled by the Hessian of the network function H_f, whose norm scales with m (Liu et al., 2020):

‖f(w) − f_lin(w)‖ ≤ sup_{w∈B(w_0,R)} ‖H_f(w)‖ R² = Õ(1/√m).    (4)

Using the inequality from Eq. (4), we obtain ‖f_quad − f_lin‖ = Õ(1/√m); hence f_quad transitions to linearity continuously in B(w_0, R) as well. Given the continuous nature of the transition to linearity, one may expect the transition from non-linear dynamics to linear dynamics for f and f_quad to be continuous in m as well. Namely, one would expect the domain of catapult dynamics, [η_crit, η_max], to shrink and ultimately converge to a single point, i.e., η_crit = η_max, as m goes to infinity, with non-linear dynamics turning into linear dynamics. However, as we show both analytically and empirically, the transition is not continuous, for both network functions f and NQMs f_quad, since the domain of the catapult dynamics can be independent of the width m (or γ).
Additionally, the length of the optimization path of f in catapult dynamics grows with m: otherwise, the optimization path would be contained in a ball of constant radius independent of m, in which f can be approximated by f_lin. Since f_lin diverges in catapult dynamics, the approximation would imply that f diverges as well, contradicting the fact that f can converge in catapult dynamics (see Figure 2a).

2. NOTATION AND PRELIMINARY

We use bold lowercase letters to denote vectors and capital letters to denote matrices. We denote the set {1, 2, ..., n} by [n]. We use ‖·‖ to denote the Euclidean norm for vectors and the spectral norm for matrices, and ⊙ to denote element-wise multiplication (the Hadamard product) for vectors. We use λ_max(A) and λ_min(A) to denote the largest and smallest eigenvalues of a matrix A, respectively. Given a model f(w; x), where x is input data and w are model parameters, we use ∇_w f to denote the partial derivative ∂f(w; x)/∂w. When clear from context, we write ∇f := ∇_w f for ease of notation. We use H_f and H_L to denote the Hessian (second derivative matrix) of the function f(w; x) and of the loss L(w) with respect to the parameters w, respectively.

In this paper, we consider the following supervised learning task: given training data {(x_i, y_i)}_{i=1}^n with inputs x_i ∈ R^d and labels y_i ∈ R for i ∈ [n], we minimize the empirical risk with the squared loss L(w) = ½ Σ_{i=1}^n (f(w; x_i) − y_i)². Here f(w; ·) is a parametric family of models, e.g., a neural network or a kernel machine, with parameters w ∈ R^p. We use full-batch gradient descent to minimize the loss, and we denote the trainable parameters at iteration t by w(t). With constant step size (learning rate) η, the update rule for the parameters is:

w(t + 1) = w(t) − η (dL/dw)(w(t)), ∀t ≥ 0.

Definition 1 (Tangent kernel). The tangent kernel K(w; ·, ·) of f(w; ·) is defined as K(w; x, z) = ⟨∇f(w; x), ∇f(w; z)⟩ for all x, z ∈ R^d. In the context of the optimization problem with n training examples, the tangent kernel matrix K ∈ R^{n×n} satisfies K_{i,j}(w) = K(w; x_i, x_j) for i, j ∈ [n].

The critical learning rate for optimization is given as follows.

Definition 2 (Critical learning rate). With an initialization of parameters w_0, the critical learning rate of f(w; ·) is defined as η_crit := 2/λ_max(H_L(w_0)).
A learning rate η is said to be sub-critical if 0 < η < η_crit and super-critical if η_crit < η < η_max. Here η_max is the maximum learning rate such that the optimization of L(w) initialized at w_0 can converge.

Dynamics for linear models. When f is linear in w, the gradient ∇f and the tangent kernel are constant: K(w(t)) = K(w_0). Therefore, the gradient descent dynamics are:

F(w(t + 1)) − y = (I − ηK(w_0))(F(w(t)) − y), ∀t ≥ 0,    (7)

where F(w) = [f_1(w), ..., f_n(w)]^T with f_i(w) = f(w; x_i). Noting that H_L(w_0) = ∇F(w_0)^T ∇F(w_0) and the tangent kernel K(w_0) = ∇F(w_0)∇F(w_0)^T share the same positive eigenvalues, we have λ_max(H_L(w_0)) = λ_max(K(w_0)), and hence η_crit = 2/λ_max(K(w_0)). Therefore, from Eq. (7), if 0 < η < η_crit the loss L decreases monotonically, and if η > η_crit the loss L diverges. Note that the critical and maximum learning rates are equal in this setting.
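This dichotomy is easy to check numerically. The sketch below (our own illustration, not code from the paper; the feature map, sizes and seed are arbitrary) trains a linear model f_lin(w; x_i) = φ(x_i)^T w by full-batch gradient descent and contrasts a sub-critical with a super-critical learning rate:

```python
import numpy as np

rng = np.random.default_rng(0)
n, p = 10, 100
Phi = rng.standard_normal((n, p)) / np.sqrt(p)  # rows: features phi(x_i) of a linear model
y = rng.standard_normal(n)

K = Phi @ Phi.T                                  # constant tangent kernel of the linear model
eta_crit = 2 / np.linalg.eigvalsh(K).max()       # critical learning rate 2/lambda_max(K)

def gd_losses(eta, steps=200):
    """Run full-batch GD on the squared loss and record the loss at each step."""
    w = np.zeros(p)
    losses = []
    for _ in range(steps):
        r = Phi @ w - y                          # residual F(w) - y
        losses.append(0.5 * r @ r)
        w -= eta * Phi.T @ r                     # gradient step
    return np.array(losses)

sub = gd_losses(0.9 * eta_crit)                  # sub-critical: monotone convergence
sup = gd_losses(1.1 * eta_crit)                  # super-critical: divergence
print(sub[-1], sup[-1])
```

In the linear case the residual evolves exactly as F(w(t+1)) − y = (I − ηK)(F(w(t)) − y), so every eigenmode shrinks when η < η_crit and the top mode grows otherwise; there is no intermediate catapult regime.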

3. OPTIMIZATION DYNAMICS IN NEURAL QUADRATIC MODELS

In this section, we analyze the gradient descent dynamics of the NQM that approximates two-layer fully-connected ReLU activated neural networks up to second order. We show that the extra quadratic term in NQMs allows for catapult convergence: the loss increases in an early stage and then converges afterwards. This type of convergence happens with super-critical learning rates and cannot happen for linear models. Interestingly, the top eigenvalues of the tangent kernel typically decrease after the catapult phase, while they are nearly constant when training with sub-critical learning rates, where the loss converges monotonically.

Neural Quadratic Model (NQM). Consider the NQM that approximates the following two-layer neural network:

f(u, v; x) = (1/√m) Σ_{i=1}^m v_i σ((1/√d) u_i^T x),    (9)

where u_i ∈ R^d, v_i ∈ R for i ∈ [m] are trainable parameters and x ∈ R^d is the input. The corresponding NQM (derived in Appendix A) is

g(u, v; x) = f(u_0, v_0; x) + (1/√(md)) Σ_{i=1}^m v_{0,i} (u_i − u_{0,i})^T x 1_{u_{0,i}^T x ≥ 0} + (1/√m) Σ_{i=1}^m (v_i − v_{0,i}) σ((1/√d) u_{0,i}^T x) + (1/√(md)) Σ_{i=1}^m (v_i − v_{0,i})(u_i − u_{0,i})^T x 1_{u_{0,i}^T x ≥ 0}.    (10)

Given training data {x_i, y_i}_{i=1}^n, we minimize the empirical risk with the squared loss L(w) = ½ Σ_{i=1}^n (g(w; x_i) − y_i)² using GD with a constant learning rate η. Throughout this section, we denote g(u(t), v(t); x) by g(t) and its tangent kernel K(u(t), v(t)) by K(t), where t is the GD iteration. We assume ‖x_i‖ = O(1) and |y_i| = O(1) for i ∈ [n], and that the width of f is much larger than the input dimension d and the data size n, i.e., m ≫ max{d, n}; hence d and n can be regarded as small constants. Throughout the paper, we use big-O and small-o notation with respect to the width m. Below, we start with the single training example case, which already showcases the non-linear dynamics of NQMs.
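A useful sanity check on Eq. (10): as long as no unit's activation pattern changes, the network Eq. (9) is exactly bilinear in (u, v), so f_quad reproduces f exactly, not just to second order. The NumPy sketch below (our own illustration; width, dimension and seed are arbitrary) builds the NQM from Eq. (10) and verifies this, zeroing the perturbation of any unit whose activation would flip:

```python
import numpy as np

rng = np.random.default_rng(0)
m, d = 200, 5
x = rng.standard_normal(d)
u0 = rng.standard_normal((m, d))
v0 = rng.standard_normal(m)
act = (u0 @ x >= 0).astype(float)        # activation pattern at initialization

def f(u, v):                              # two-layer ReLU network, Eq. (9)
    return (v @ np.maximum(u @ x / np.sqrt(d), 0)) / np.sqrt(m)

def f_quad(u, v):                         # NQM, Eq. (10)
    du, dv = u - u0, v - v0
    lin_u = (v0 * act * (du @ x)).sum() / np.sqrt(m * d)
    lin_v = (dv @ np.maximum(u0 @ x / np.sqrt(d), 0)) / np.sqrt(m)
    cross = (dv * act * (du @ x)).sum() / np.sqrt(m * d)
    return f(u0, v0) + lin_u + lin_v + cross

du = 0.1 * rng.standard_normal((m, d))
dv = 0.1 * rng.standard_normal(m)
flip = np.sign((u0 + du) @ x) != np.sign(u0 @ x)
du[flip] = 0.0                            # stay inside the region with a fixed pattern

err_quad = abs(f(u0 + du, v0 + dv) - f_quad(u0 + du, v0 + dv))
print(err_quad)                           # floating-point noise only
```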

3.1. CATAPULT DYNAMICS WITH SINGLE TRAINING EXAMPLE

In this subsection, we consider the training dynamics of NQMs with a single training example (x, y), where x ∈ R^d and y ∈ R. In this case, the tangent kernel matrix K reduces to a scalar, which we denote by λ to distinguish it from a matrix. Under gradient descent with step size η, the updates for g(t) − y and λ(t), which we refer to as the dynamics equations, can be derived as follows (see Appendix B.1):

Dynamics equations.

g(t + 1) − y = (1 − ηλ(t) + (‖x‖²/(md)) η² (g(t) − y) g(t)) (g(t) − y) := µ(t)(g(t) − y),

where we denote R_g(t) := (‖x‖²/(md)) η² (g(t) − y) g(t), and

λ(t + 1) = λ(t) + η (‖x‖²/(md)) (g(t) − y)² (ηλ(t) − 4g(t)/(g(t) − y)), ∀t ≥ 0,    (12)

where we denote R_λ(t) := η (‖x‖²/(md)) (g(t) − y)² (ηλ(t) − 4g(t)/(g(t) − y)).

Note that as the loss is given by L(t) = ½(g(t) − y)², to understand convergence it suffices to analyze the dynamics equations above. Compared to the linear dynamics Eq. (7), these non-linear dynamics have the extra terms R_g(t) and R_λ(t), which are induced by the non-linear term in the NQM. We will see that the convergence of gradient descent depends on the scale and sign of R_g(t) and R_λ(t). For example, for a constant learning rate slightly larger than η_crit (which would result in divergence for linear models), R_λ(t) stays negative during training, resulting in both a monotonic decrease of the tangent kernel λ and convergence of the loss. For the scale of λ(0), which is non-negative by Definition 1, we can show that with high probability over random initialization, λ(0) = Θ(1) (see Appendix D). As λ(t) = λ(0) + Σ_{τ=0}^{t−1} R_λ(τ), to track the scale of |µ(t)| we will focus on the scale and sign of R_g(t) and R_λ(t) in the following analysis. We start by establishing monotonic convergence for sub-critical learning rates.

Monotonic convergence: sub-critical learning rates (η < 2/λ(0)). The key observation we use is that when |g(t)| is small, i.e., of the order o(√m), and λ(t) = Θ(1), |R_g(t)| and |R_λ(t)| are of the order o(1) (see Proposition 3 in Appendix F).
Then the dynamics equations approximately reduce to those of linear dynamics:

g(t + 1) − y = (1 − ηλ(t) + o(1))(g(t) − y), λ(t + 1) = λ(t) + o(1).

Note that at initialization, with high probability over random initialization, the output satisfies |g(0)| = O(1) when ‖x‖ = O(1) (Jacot et al., 2018), and we have shown λ(0) = Θ(1). Hence the loss decreases monotonically, and since it decreases exponentially, the cumulative change of the tangent kernel satisfies Σ_t |R_λ(t)| = o(1).

Catapult convergence: super-critical learning rates (2/λ(0) < η < 4/λ(0)). The training dynamics are given by the following theorem.

Theorem 1 (Catapult dynamics on a single training example). Consider training an NQM, Eq. (10), with squared loss on a single training example by gradient descent. If the learning rate is super-critical, i.e., 2/λ(0) < η < 4/λ(0), then there exist T_1, T_2, T_3, T_4 with 0 < T_1 < T_2 < T_3 < T_4 such that the training dynamics exhibit:

(i) Increasing phase: t ∈ [0, T_1]. In this phase, L(t) = o(m). The loss grows exponentially and the tangent kernel is nearly constant, i.e., |λ(t) − λ(0)| = o(1).

(ii) Peak phase: t ∈ [T_2, T_3]. In this phase, L(t) = Θ(m). The tangent kernel decreases significantly: λ(t + 1) − λ(t) < 0 and |λ(t + 1) − λ(t)| = Θ(1).

(iii) Decreasing phase: t ∈ [T_4, ∞). In this phase, the loss satisfies L(t) = o(m) again and decreases. The tangent kernel is nearly constant until convergence: |λ(t) − λ(∞)| = o(1).

Furthermore, T_1 = o(log m), T_2 = Θ(log m), T_3 − T_2 = Θ(1) and T_4 = Θ(log m).

Proof of Theorem 1. We analyze the training dynamics in each phase sequentially.

(i) Increasing phase. The loss grows exponentially at iteration 0: by the choice of the learning rate, µ(0) satisfies |µ(0)| = |1 − ηλ(0) + R_g(0)| = |1 − ηλ(0) + o(1)| > 1. Therefore |g(1) − y| > |g(0) − y|, and the tangent kernel is almost unchanged: λ(1) = λ(0) + o(1). We can apply this argument recursively for the following steps as long as |g(t)| = o(√m), by Proposition 3. Note that in the increasing phase the loss grows exponentially since |µ(t)| > 1.
Therefore, the tangent kernel is nearly constant, since the cumulative change is Σ_{t=0}^{T_1} |R_λ(t)| = Σ_{t=0}^{T_1} Θ(g(t)²/m) = o(1), where we use the fact that g(t)² grows exponentially up to the order o(m); and we get T_1 = o(log m). Furthermore, it is not hard to see that until the loss grows to the order Θ(m), the tangent kernel does not change much, hence the loss keeps increasing exponentially.

(ii) Peak phase. The key observation here is that when |g(t)| is large, i.e., of the order Θ(√m), |R_g(t)| and |R_λ(t)| are of the order Θ(1) (see Proposition 4 in Appendix F), which can lead to the decrease of the loss. In the peak phase we have |g(t)| = Θ(√m); then by Proposition 4 the scale of |R_λ(t)| is Θ(1), and R_λ(t) < 0 since λ(t) < 4/η (in the increasing phase λ(t) is almost unchanged, hence λ(t) ≈ λ(0) < 4/η before the peak phase, and it will then decrease significantly). Consequently, by Eq. (12), λ(t) decreases significantly: |R_λ(t)| = Θ(1) and λ(t + 1) = λ(t) + R_λ(t) < λ(t), which remains smaller than 4/η. Similarly, when |g(t)| = Θ(√m), |R_g(t)| = Θ(1) and R_g(t) > 0 by Proposition 4. We can then see that the growth of the loss slows down compared to the increasing phase:

|µ(t)| = |1 − ηλ(t) + R_g(t)| < |1 − ηλ(0) + R_g(0)| = |1 − ηλ(0) + o(1)| ≈ |µ(0)|.

In general, the loss grows exponentially prior to the peak phase, hence T_2 = Θ(log m). The peak phase lasts only Θ(1) steps, i.e., |T_3 − T_2| = Θ(1), since the decrease of λ(t) is Θ(1) at each step and the training dynamics enter the decreasing phase once |µ(t)| < 1, which happens once λ(t) is sufficiently small.

(iii) Decreasing phase. The peak phase ends when |µ(t)| < 1; then |g(t) − y| starts to decrease. We note that our analysis implicitly assumes that the optimization path does not arrive at the saddle point, i.e., |µ(t)| = 1, which is generally true with discrete steps.
Recall that in the peak phase λ(t) decreases, and when |µ(t)| > 1, |g(t)| increases, which causes R_g(t) to increase since R_g(t) scales with |g(t)|. As a result, |µ(t)| = |1 − ηλ(t) + R_g(t)| decreases and ultimately falls below 1. Note that although |g(t) − y| starts to decrease, it is still of the order Θ(√m), which makes λ(t) decrease significantly, as R_λ(t) = Θ(1) by Proposition 4. The decrease of λ(t) stops once |g(t)| falls to the order o(√m), as then R_λ(t) = o(1) again by Proposition 3. From that moment, T_4 − T_3 = Θ(log m), as the loss decreases exponentially and the linear dynamics dominate again. Therefore, similarly to the increasing phase, starting from t such that |g(t)| = o(√m), the loss decreases exponentially, hence the change of λ(t) until convergence is o(1).

Divergence (η > η_max = 4/λ(0)). Initially, the dynamics follow those of the increasing phase in catapult convergence: the loss increases exponentially and the tangent kernel is nearly constant. However, when |g(t)| grows to the order Θ(√m), corresponding to the peak phase in catapult convergence, λ(t) does not decrease but instead increases significantly. Specifically, since η > 4/λ(0), we approximately have η > 4/λ(t) at the end of the increasing phase by the same analysis as in catapult convergence. By Proposition 4, R_λ(t) > 0, so λ(t) increases: λ(t + 1) = λ(t) + R_λ(t) > λ(t). A larger λ(t) leads to a faster increase of λ(t), hence |µ(t)| becomes even larger. As a result, |g(t) − y| grows faster, and therefore the loss diverges.
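All three regimes can be reproduced by iterating the dynamics equations directly. In the sketch below (our own illustration), the constant c plays the role of ‖x‖²/(md), and λ(0) and g(0) are set to representative Θ(1) values rather than computed from an actual initialization; with λ(0) = 2 we have η_crit = 1 and η_max = 2:

```python
import numpy as np

def run(eta, lam0=2.0, g0=0.5, y=1.0, c=1e-3, steps=20000):
    """Iterate the scalar dynamics equations for g(t) - y and lambda(t)."""
    g, lam, losses = g0, lam0, []
    for _ in range(steps):
        r = g - y
        losses.append(0.5 * r * r)
        if losses[-1] > 1e12:                              # stop once clearly diverged
            break
        mu = 1 - eta * lam + c * eta**2 * r * g            # contraction factor with R_g(t)
        lam = lam + eta * c * r * (eta * lam * r - 4 * g)  # kernel update, Eq. (12)
        g = y + mu * r
    return np.array(losses), lam

sub, _ = run(0.8)          # sub-critical: monotone convergence, kernel nearly constant
cat, lam_cat = run(1.5)    # super-critical: catapult, kernel decreases
div, _ = run(2.5)          # beyond eta_max = 4/lambda(0): divergence
```

In the catapult run the loss overshoots to the order 1/c (playing the role of Θ(m)) before R_λ(t) pulls the kernel down and the loss decays again.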

3.2. CATAPULT DYNAMICS WITH MULTIPLE TRAINING EXAMPLES

In this subsection we show that the catapult phase happens for NQMs, Eq. (10), with multiple training examples. We assume uni-dimensional input data, which is common in the literature and simplifies the analysis of neural networks (see, for example, Williams et al. (2019); Savarese et al. (2019)).

Assumption 1. The input dimension d = 1, and every x_i is non-zero, i.e., |x_i| > 0 for all i ∈ [n].

Since each x_i is a scalar, by the homogeneity of the ReLU activation function we can compute the exact eigenvectors of K(t) for all t ≥ 0. To that end, we group the data into two sets S_+ and S_− according to their sign: S_+ := {i : x_i ≥ 0, i ∈ [n]}, S_− := {i : x_i < 0, i ∈ [n]}. We then have the following proposition for the tangent kernel K (the proof is deferred to Appendix C):

Proposition 1 (Eigenvectors and low rank structure of K). For any u, v ∈ R^m, rank(K) ≤ 2. Furthermore, p_1 and p_2 are eigenvectors of K, where p_{1,i} = x_i 1_{i∈S_+} and p_{2,i} = x_i 1_{i∈S_−} for i ∈ [n].

Note that when all x_i have the same sign, rank(K) = 1 and K has only one eigenvector (either p_1 or p_2, depending on the sign). This is in fact a simpler setting, since we only need to consider one direction, and its analysis is covered by that for rank(K) = 2. Therefore, in the following we assume rank(K) = 2. We denote the two eigenvalues of K(t) corresponding to p_1 and p_2 by λ_1(t) and λ_2(t) respectively, i.e., K(t)p_1 = λ_1(t)p_1 and K(t)p_2 = λ_2(t)p_2. Without loss of generality, we assume λ_1(0) ≥ λ_2(0). We analyze the dynamics equations for multiple training examples (see Eqs. (14) and (15), which are the update equations of g(t) − y and K(t)) with different learning rates in the same way as before, and we formulate the result for the catapult dynamics, which happen when training with super-critical learning rates, in the following theorem:

Theorem 2 (Catapult dynamics on multiple training examples). Consider training an NQM, Eq. (10), with squared loss on multiple training examples by gradient descent.
Under Assumption 1, if the learning rate is super-critical, i.e., 2/λ_1(0) < η < min{2/λ_2(0), 4/λ_1(0)}, then there exist T_1, T_2, T_3, T_4 with 0 < T_1 < T_2 < T_3 < T_4 such that the training dynamics exhibit:

(i) Increasing phase: t ∈ [0, T_1]. In this phase, L(t) = o(m). The loss grows exponentially; both eigenvalues are nearly constant, i.e., |λ_k(t) − λ_k(0)| = o(1) for k = 1, 2.

(ii) Peak phase: t ∈ [T_2, T_3]. In this phase, L(t) = Θ(m). For the tangent kernel: (a) if 2/λ_2(0) ≤ 4/λ_1(0), both eigenvalues decrease significantly, i.e., λ_1(t+1) − λ_1(t) < 0 and λ_2(t+1) − λ_2(t) < 0, and both differences are of the order Θ(1); (b) if 2/λ_2(0) > 4/λ_1(0), only λ_1(t) decreases significantly.

(iii) Decreasing phase: t ∈ [T_4, ∞). In this phase, the loss satisfies L(t) = o(m) again and decreases. Both eigenvalues are nearly constant until convergence: |λ_k(t) − λ_k(∞)| = o(1) for k = 1, 2.

In the experiment illustrating this theorem, the two critical values given by our analysis are 2/λ_1(0) = 0.37 and 2/λ_2(0) = 0.39. When η < 0.37, linear dynamics dominate, hence the kernel is nearly constant; when 0.37 < η < 0.39, the catapult phase happens in direction p_1 and only λ_1(t) decreases; when 0.39 < η < η_max, the catapult phase happens in both p_1 and p_2, hence both λ_1(t) and λ_2(t) decrease. The experimental details can be found in Appendix J.1.
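Proposition 1 is easy to verify numerically at initialization: with scalar inputs, 1{u_i x_j ≥ 0} depends only on the signs of u_i and x_j, so entries of K mixing S_+ and S_− vanish and each sign block is rank one. A sketch (our own illustration; the inputs are arbitrary non-zero scalars of both signs, as required by Assumption 1):

```python
import numpy as np

rng = np.random.default_rng(0)
m = 500
x = np.array([0.5, -1.2, 2.0, -0.3, 1.1, -2.2, 0.7, -0.9])  # scalar inputs, both signs
u = rng.standard_normal(m)                                   # first-layer weights (d = 1)
v = rng.standard_normal(m)

pre = np.outer(x, u)                       # pre[j, i] = x_j * u_i
ind = (pre >= 0).astype(float)
Ju = v * ind * x[:, None] / np.sqrt(m)     # d f(x_j)/d u_i, with d = 1
Jv = np.maximum(pre, 0) / np.sqrt(m)       # d f(x_j)/d v_i
K = Ju @ Ju.T + Jv @ Jv.T                  # tangent kernel matrix at initialization

p1 = x * (x >= 0)                          # candidate eigenvector supported on S+
p2 = x * (x < 0)                           # candidate eigenvector supported on S-
lam1 = (p1 @ K @ p1) / (p1 @ p1)
lam2 = (p2 @ K @ p2) / (p2 @ p2)
```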

3.3. CONNECTION TO WIDE NEURAL NETWORKS AND GENERAL QUADRATIC MODELS

Wide neural networks. We have seen that NQMs, which have a fixed Hessian, exhibit the catapult phase phenomenon. Therefore, a change in the Hessian of wide neural networks during training is not required to produce the catapult phase. In our analysis, the catapult phase arises because the Hessians of the outputs "align" with the eigenvectors of the tangent kernel: the H_{g_i} for i ∈ S_+ are proportional to each other with coefficients p_{1,i}, and likewise the H_{g_i} for i ∈ S_− with coefficients p_{2,i}; e.g., H_{g_j}/p_{1,j} = H_{g_k}/p_{1,k} if j, k ∈ S_+. We believe this idea can be used to analyze the catapult dynamics of wide neural networks with a changing Hessian. Behaviour of the top eigenvalues of the tangent kernel similar to that of NQMs is observed for wide neural networks when training with different learning rates (see Figure 5 in Appendix J).

General quadratic models. As mentioned in the introduction, NQMs are contained in the general class of quadratic models of the form given in Eq. (3). We show that the two-layer linear neural network analyzed in Lewkowycz et al. (2020) is a special case of Eq. (3), and we provide a more general condition for such models to exhibit catapult dynamics in Appendix I. Furthermore, we empirically observe that a broader class of quadratic models g can exhibit catapult dynamics simply by letting φ(x) and Σ(x) be random and assigning a small value to γ (see Appendix J.3).
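For instance, in the special case φ(x) = 0, Eq. (3) reduces to the two-layer linear network f(u, v) = u^T v/√m of Lewkowycz et al. (2020) (with input x = 1), and a few lines of NumPy reproduce its catapult (our own illustration; the width, seed and the choice η = 3/λ(0), inside the super-critical interval (2/λ(0), 4/λ(0)), are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
m, y, steps = 1000, 1.0, 5000
u = rng.standard_normal(m)
v = rng.standard_normal(m)
lam0 = (u @ u + v @ v) / m        # tangent kernel at initialization (concentrates near 2)
eta = 3.0 / lam0                  # super-critical: 2/lam0 < eta < 4/lam0

losses = []
for _ in range(steps):
    r = u @ v / np.sqrt(m) - y    # residual f(u, v) - y
    losses.append(0.5 * r * r)
    u, v = u - eta * r * v / np.sqrt(m), v - eta * r * u / np.sqrt(m)
losses = np.array(losses)
lam_end = (u @ u + v @ v) / m     # kernel after training: smaller than lam0
```

The loss climbs to the order Θ(m) before the kernel drops enough for gradient descent to contract again, exactly the catapult shape of Figure 2a.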

4. QUADRATIC MODELS PARALLEL NEURAL NETWORKS IN GENERALIZATION

In this section, we empirically compare the test performance of the three models considered in this paper as the learning rate varies: (1) the NQM, f_quad; (2) the corresponding neural network, f; and (3) the linear model, f_lin. We run our experiments on three vision datasets: CIFAR-2 (a 2-class subset of CIFAR-10 (Krizhevsky et al., 2009)), MNIST (LeCun et al., 1998) and SVHN (the Street View House Numbers dataset) (Netzer et al., 2011); one speech dataset: the Free Spoken Digit Dataset (FSDD) (Jakobovski, 2020); and one text dataset: AG NEWS (Gulli, 2005). In all experiments, we train the models by minimizing the squared loss using standard GD/SGD with a constant learning rate η, and we report the best test loss achieved during training for each learning rate. Experimental details can be found in Appendix J.4. We also report the best test accuracy in Appendix J.5, and results for networks with 3 layers in Appendix J.6. From the experimental results, we observe the following:

Sub-critical learning rates. In accordance with our theoretical analyses, all three models have nearly identical test loss for any sub-critical learning rate. Specifically, as the width m increases, f and f_quad transition to linearity in the ball B(w_0, R): ‖f − f_lin‖ = Õ(1/√m) and ‖f_quad − f_lin‖ = Õ(1/√m), where R > 0 is a constant large enough to contain the optimization path for sub-critical learning rates. Thus, the generalization performance of the three models is similar when m is large, as shown in Figure 4.

Super-critical learning rates. The best test loss of both f(w) and f_quad(w) is consistently smaller than that achieved with sub-critical learning rates, and it decreases as the learning rate increases over a range of values beyond η_crit, as was observed for wide neural networks in Lewkowycz et al. (2020).
As discussed in the introduction, with super-critical learning rates both f_quad and f can exhibit a catapult phase, while the loss of f_lin diverges. Together with the similar generalization behaviour of f_quad and f at super-critical learning rates, we believe NQMs are a better model than the linear approximation f_lin for understanding the training and testing dynamics of f. In Figure 4 we report the results for networks with the ReLU activation function. We also ran the experiments with the Tanh and Swish (Ramachandran et al., 2017) activation functions, and we observe the same generalization phenomena for f, f_lin and f_quad (see Appendix J.7).

5. CONCLUSIONS

In this paper, we use quadratic models as a tool to better understand the optimization and generalization properties of finite width neural networks trained with large learning rates. Notably, we prove that quadratic models exhibit properties of neural networks, such as catapult dynamics, that cannot be explained by linear models, which importantly include the linear approximations to neural networks given by the neural tangent kernel. Interestingly, we show empirically that quadratic models mimic the generalization properties of neural networks when trained with large learning rates, and that such models perform better than linearized neural networks.

A DERIVATION OF NQM

We derive the NQM that approximates the two-layer fully-connected ReLU activated neural network of Eq. (2). The first derivatives of f are

∂f/∂u_i = (1/√(md)) v_i 1_{u_i^T x ≥ 0} x^T, ∂f/∂v_i = (1/√m) σ((1/√d) u_i^T x), ∀i ∈ [m].

Each block of the Hessian H_f of f is

∂²f/∂u_i² = 0, ∂²f/∂v_i² = 0, ∂²f/∂u_i∂v_i = (1/√(md)) 1_{u_i^T x ≥ 0} x^T, ∀i ∈ [m].

Hence f_quad takes the following form:

NQM: f_quad(u, v; x) = f(u_0, v_0; x) + (1/√(md)) Σ_{i=1}^m (u_i − u_{0,i})^T x 1_{u_{0,i}^T x ≥ 0} v_{0,i} + (1/√m) Σ_{i=1}^m (v_i − v_{0,i}) σ((1/√d) u_{0,i}^T x) + (1/√(md)) Σ_{i=1}^m (u_i − u_{0,i})^T x 1_{u_{0,i}^T x ≥ 0} (v_i − v_{0,i}).
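These derivative formulas can be spot-checked with central finite differences at a generic point (a quick sketch, our own illustration; it assumes no pre-activation u_i^T x lies within the step eps of zero, which holds almost surely):

```python
import numpy as np

rng = np.random.default_rng(0)
m, d = 50, 4
x = rng.standard_normal(d)
u = rng.standard_normal((m, d))
v = rng.standard_normal(m)

def f(u, v):                                  # two-layer ReLU network
    return (v @ np.maximum(u @ x / np.sqrt(d), 0)) / np.sqrt(m)

# analytic first derivatives from the formulas above
df_du = (v * (u @ x >= 0))[:, None] * x[None, :] / np.sqrt(m * d)
df_dv = np.maximum(u @ x / np.sqrt(d), 0) / np.sqrt(m)

eps, errs = 1e-6, []
for i in range(5):                            # spot-check a few coordinates
    e = np.zeros((m, d)); e[i, 0] = eps
    errs.append(abs((f(u + e, v) - f(u - e, v)) / (2 * eps) - df_du[i, 0]))
    e2 = np.zeros(m); e2[i] = eps
    errs.append(abs((f(u, v + e2) - f(u, v - e2)) / (2 * eps) - df_dv[i]))
max_err = max(errs)
print(max_err)
```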

B DERIVATION OF DYNAMICS EQUATIONS B.1 SINGLE TRAINING EXAMPLE

The NQM can be equivalently written as: g(u, v; x) = g(u 0 , v 0 ; x) + u -u 0 , ∇ u g(u, v; x) u=u0,v=v0 + v -v 0 , ∇ v g(u, v; x) u=u0,v=v0 + u -u 0 , ∂ 2 g(u, v; x) ∂u∂v u=u0,v=v0 (v -v 0 ) , since ∂ 2 g ∂u 2 = 0 and ∂ 2 g ∂v 2 = 0. And the tangent kernel λ(u, v; x) takes the form λ(u, v; x) = ∇ u g(u, v; x) u=u0,v=v0 + ∂ 2 g(u, v; x) ∂u∂v u=u0 (v -v 0 ) 2 F + ∇ v g(u, v; x) u=u0,v=v0 + (u -u 0 ) T ∂ 2 g(u, v; x) ∂u∂v u=u0,v=v0 2 . Here ∇ ui g(u, v; x) u=u0,v=v0 = 1 √ md m i=1 v 0,i 1 {u T 0,i x≥0} x, ∀i ∈ [m], ∇ v g(u, v; x) u=u0,v=v0 = 1 √ md σ u T 0 x . In the following, we will consider the dynamics of g and λ with GD, hence for simplicity of notations, we denote ∇ u g(0) := ∇ u g(u, v; x) u=u0,v=v0 , ∇ v g(0) := ∇ v g(u, v; x) u=u0,v=v0 , ∂ 2 g(0) ∂u∂v := ∂ 2 g(u, v; x) ∂u∂v u=u0,v=v0 . By gradient descent with learning rate η, at iteration t, we have the update equations for weights u and v: u(t + 1) = u(t) -η(g(t) -y) ∇ u g(0) + ∂ 2 g(0) ∂u∂v (v(t) -v(0)) , v(t + 1) = v(t) -η(g(t) -y) ∇ v g(0) + (u(t) -u(0)) T ∂ 2 g(0) ∂u∂v . Then we plug them in the expression of λ(t + 1) and we get λ(t + 1) = ∇ u g(0) + ∂ 2 g(0) ∂u∂v (v(t + 1) -v(0)) 2 F + ∇ v g(0) + (u(t + 1) -u(0)) T ∂ 2 g(0) ∂u∂v 2 = ∇ u g(0) + ∂ 2 g(0) ∂u∂v v(t) -η(g(t) -y) ∇ v g(0) + (u(t) -u(0)) T ∂ 2 g(0) ∂u∂v -v(0) 2 F + ∇ v g(0) + u(t) -η(g(t) -y) ∇ u g(0) + ∂ 2 g(0) ∂u∂v (v(t) -v(0)) -u(0) T ∂ 2 g(0) ∂u∂v 2 = λ(t) + η 2 (g(t) -y) 2 ∂ 2 g(0) ∂u∂v ∇ v g(0) + (u(t) -u(0)) T ∂ 2 g(0) ∂u∂v 2 F + η 2 (g(t) -y) 2 ∇ u g(0) + ∂ 2 g(0) ∂u∂v (v(t) -v(0)) T ∂ 2 g(0) ∂u∂v 2 -2η(g(t) -y) ∇ u g(0) + ∂ 2 g(0) ∂u∂v (v(t) -v(0)), ∂ 2 g(0) ∂u∂v ∇ v g(0) + (u(t) -u(0)) T ∂ 2 g(0) ∂u∂v -2η(g(t) -y) ∇ v g(0) + (u(t) -u(0)) T ∂ 2 g(0) ∂u∂v , ∇ u g(0) + ∂ 2 g(0) ∂u∂v (v(t) -v(0)) T ∂ 2 g(0) ∂u∂v . 
Due to the structure of $\frac{\partial^2 g(0)}{\partial u\partial v}$, we have
$$\Big\|\frac{\partial^2 g(0)}{\partial u\partial v}\Big(\nabla_v g(0)+(u(t)-u(0))^T\frac{\partial^2 g(0)}{\partial u\partial v}\Big)\Big\|^2=\frac{\|x\|^2}{md}\Big\|\nabla_v g(0)+(u(t)-u(0))^T\frac{\partial^2 g(0)}{\partial u\partial v}\Big\|^2=\frac{\|x\|^2}{md}\|\nabla_v g(t)\|^2,$$
and
$$\Big\|\Big(\nabla_u g(0)+\frac{\partial^2 g(0)}{\partial u\partial v}(v(t)-v(0))\Big)^T\frac{\partial^2 g(0)}{\partial u\partial v}\Big\|^2=\frac{\|x\|^2}{md}\Big\|\nabla_u g(0)+\frac{\partial^2 g(0)}{\partial u\partial v}(v(t)-v(0))\Big\|^2=\frac{\|x\|^2}{md}\|\nabla_u g(t)\|^2.$$
Furthermore,
$$\Big\langle\nabla_u g(0)+\frac{\partial^2 g(0)}{\partial u\partial v}(v(t)-v(0)),\;\frac{\partial^2 g(0)}{\partial u\partial v}\Big(\nabla_v g(0)+(u(t)-u(0))^T\frac{\partial^2 g(0)}{\partial u\partial v}\Big)\Big\rangle$$
$$=\frac{\|x\|^2}{md}\Big(g(0)+\langle u(t)-u(0),\nabla_u g(0)\rangle+\langle v(t)-v(0),\nabla_v g(0)\rangle+\Big\langle u(t)-u(0),\frac{\partial^2 g(0)}{\partial u\partial v}(v(t)-v(0))\Big\rangle\Big)=\frac{\|x\|^2}{md}\,g(t),$$
where we used $\big\langle\nabla_u g(0),\frac{\partial^2 g(0)}{\partial u\partial v}\nabla_v g(0)\big\rangle=\frac{\|x\|^2}{md}\,g(0)$. Similarly,
$$\Big\langle\nabla_v g(0)+(u(t)-u(0))^T\frac{\partial^2 g(0)}{\partial u\partial v},\;\Big(\nabla_u g(0)+\frac{\partial^2 g(0)}{\partial u\partial v}(v(t)-v(0))\Big)^T\frac{\partial^2 g(0)}{\partial u\partial v}\Big\rangle=\frac{\|x\|^2}{md}\,g(t).$$
As a result,
$$\lambda(t+1)=\lambda(t)+\frac{\|x\|^2}{md}\eta^2(g(t)-y)^2\lambda(t)-\frac{4\|x\|^2}{md}\eta(g(t)-y)g(t)=\lambda(t)+\frac{\eta\|x\|^2}{md}(g(t)-y)^2\Big(\eta\lambda(t)-\frac{4g(t)}{g(t)-y}\Big).$$
For $g$, we plug the update equations for $u$ and $v$ into the expression for $g(t+1)$:
$$g(t+1)=g(0)+\langle u(t+1)-u(0),\nabla_u g(0)\rangle+\langle v(t+1)-v(0),\nabla_v g(0)\rangle+\Big\langle u(t+1)-u(0),\frac{\partial^2 g(0)}{\partial u\partial v}(v(t+1)-v(0))\Big\rangle.$$
Substituting $u(t+1)=u(t)-\eta(g(t)-y)\big(\nabla_u g(0)+\frac{\partial^2 g(0)}{\partial u\partial v}(v(t)-v(0))\big)$ and $v(t+1)=v(t)-\eta(g(t)-y)\big(\nabla_v g(0)+(u(t)-u(0))^T\frac{\partial^2 g(0)}{\partial u\partial v}\big)$, expanding, and applying the identities above, the first-order terms contribute $-\eta(g(t)-y)\lambda(t)$ and the second-order term contributes $\frac{\|x\|^2}{md}\eta^2(g(t)-y)^2g(t)$, so that
$$g(t+1)=g(t)-\eta(g(t)-y)\lambda(t)+\frac{\|x\|^2}{md}\eta^2(g(t)-y)^2g(t).$$
Therefore,
$$g(t+1)-y=\Big(1-\eta\lambda(t)+\frac{\|x\|^2}{md}\eta^2(g(t)-y)g(t)\Big)(g(t)-y).$$
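The two recursions above are easy to probe numerically. Below is a minimal sketch in which the constant `c` stands in for $\|x\|^2/(md)$; all concrete numbers (`c`, `y`, the initial values, and the learning rates) are illustrative choices, not values from the paper.

```python
import numpy as np

# Iterate the derived single-example dynamics equations:
#   g(t+1) - y   = (1 - eta*lam + c*eta^2*(g-y)*g) * (g - y)
#   lam(t+1)     = lam + c*eta^2*(g-y)^2*lam - 4*c*eta*(g-y)*g
# with c standing in for ||x||^2/(md). Constants are illustrative.
def run_dynamics(eta, c=1e-4, y=0.0, g0=1.0, lam0=1.0, steps=2000):
    g, lam = g0, lam0
    peak = abs(g0 - y)
    for _ in range(steps):
        r = g - y
        g_new = y + (1.0 - eta * lam + c * eta**2 * r * g) * r
        lam = lam + c * eta**2 * r**2 * lam - 4.0 * c * eta * r * g
        g = g_new
        peak = max(peak, abs(g - y))
    return g, lam, peak

# With lam(0) = 1, eta_crit = 2 and eta_max = 4.
g_sub, lam_sub, peak_sub = run_dynamics(eta=0.5)   # sub-critical
g_sup, lam_sup, peak_sup = run_dynamics(eta=3.0)   # super-critical
```

The sub-critical run converges monotonically with a nearly constant $\lambda$, while the super-critical run first spikes (the catapult) and then converges with a visibly smaller $\lambda$.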

B.2 MULTIPLE TRAINING EXAMPLES

We follow notation analogous to Appendix B.1 for the first and second order derivatives of $g$. Specifically, for $k\in[n]$, we denote
$$\nabla_u g_k(0):=\nabla_u g(u,v;x_k)\big|_{u=u_0,v=v_0},\qquad\nabla_v g_k(0):=\nabla_v g(u,v;x_k)\big|_{u=u_0,v=v_0},\qquad\frac{\partial^2 g_k(0)}{\partial u\partial v}:=\frac{\partial^2 g(u,v;x_k)}{\partial u\partial v}\Big|_{u=u_0,v=v_0}.$$
By GD with learning rate $\eta$, the update equations for the weights $u$ and $v$ at iteration $t$ are
$$u(t+1)=u(t)-\eta\sum_{k=1}^n(g_k(t)-y_k)\Big(\nabla_u g_k(0)+\frac{\partial^2 g_k(0)}{\partial u\partial v}(v(t)-v(0))\Big),$$
$$v(t+1)=v(t)-\eta\sum_{k=1}^n(g_k(t)-y_k)\Big(\nabla_v g_k(0)+(u(t)-u(0))^T\frac{\partial^2 g_k(0)}{\partial u\partial v}\Big).$$
We consider the evolution of $K(t)$ first:
$$K_{i,j}(t+1)=\Big\langle\nabla_u g_i(0)+\frac{\partial^2 g_i(0)}{\partial u\partial v}(v(t+1)-v(0)),\;\nabla_u g_j(0)+\frac{\partial^2 g_j(0)}{\partial u\partial v}(v(t+1)-v(0))\Big\rangle+\Big\langle\nabla_v g_i(0)+(u(t+1)-u(0))^T\frac{\partial^2 g_i(0)}{\partial u\partial v},\;\nabla_v g_j(0)+(u(t+1)-u(0))^T\frac{\partial^2 g_j(0)}{\partial u\partial v}\Big\rangle.$$
Substituting the update equations and expanding produces cross terms of order $\eta$ and squared terms of order $\eta^2$, each involving products of the second derivatives $\frac{\partial^2 g_i(0)}{\partial u\partial v}$ with the gradients $\nabla_u g_k(0)$, $\nabla_v g_k(0)$. To evaluate them, we separate the data into two sets according to their sign, $S_+:=\{i:x_i\ge0,\ i\in[n]\}$ and $S_-:=\{i:x_i<0,\ i\in[n]\}$, and consider two scenarios: (1) $x_i$ and $x_j$ have different signs; (2) $x_i$ and $x_j$ have the same sign.
(1) A direct calculation shows that if $x_i$ and $x_j$ have different signs, i.e., $i\in S_+,j\in S_-$ or $i\in S_-,j\in S_+$, then
$$\frac{\partial^2 g_i(0)}{\partial u\partial v}\frac{\partial^2 g_j(0)}{\partial u\partial v}=0,\qquad\frac{\partial^2 g_i(0)}{\partial u\partial v}\nabla_u g_j(0)=0,\qquad\frac{\partial^2 g_i(0)}{\partial u\partial v}\nabla_v g_j(0)=0.$$
Without loss of generality, assume $i\in S_+,j\in S_-$. Then $K_{i,j}(t+1)=K_{i,j}(t)$.

(2) If $x_i$ and $x_j$ have the same sign, i.e., $i,j\in S_+$ or $i,j\in S_-$, then
$$\frac{\partial^2 g_i(0)}{\partial u\partial v}\frac{\partial^2 g_j(0)}{\partial u\partial v}=\frac{1}{\sqrt m}\frac{\partial^2 g_i(0)}{\partial u\partial v}x_j,\qquad\frac{\partial^2 g_i(0)}{\partial u\partial v}\nabla_u g_j(0)=\frac{1}{\sqrt m}\nabla_u g_i(0)x_j,\qquad\frac{\partial^2 g_i(0)}{\partial u\partial v}\nabla_v g_j(0)=\frac{1}{\sqrt m}\nabla_v g_i(0)x_j.$$
For $i,j\in S_+$, expanding $K_{i,j}(t+1)$ with these identities gives
$$K_{i,j}(t+1)=K_{i,j}(t)-\frac{4\eta}{m}x_ix_j\big((g(t)-y)\odot m_+\big)^T\big(g(t)\odot m_+\big)+\frac{\eta^2}{m}x_ix_j\big((g(t)-y)\odot m_+\big)^TK(t)\big((g(t)-y)\odot m_+\big).$$
Similarly, for $i,j\in S_-$,
$$K_{i,j}(t+1)=K_{i,j}(t)-\frac{4\eta}{m}x_ix_j\big((g(t)-y)\odot m_-\big)^T\big(g(t)\odot m_-\big)+\frac{\eta^2}{m}x_ix_j\big((g(t)-y)\odot m_-\big)^TK(t)\big((g(t)-y)\odot m_-\big).$$
Combining the results together,
$$K(t+1)=K(t)+\frac{\eta^2}{m}\big((g(t)-y)\odot m_+\big)^TK(t)\big((g(t)-y)\odot m_+\big)\,p_1p_1^T+\frac{\eta^2}{m}\big((g(t)-y)\odot m_-\big)^TK(t)\big((g(t)-y)\odot m_-\big)\,p_2p_2^T-\frac{4\eta}{m}\big((g(t)-y)\odot m_+\big)^T\big(g(t)\odot m_+\big)\,p_1p_1^T-\frac{4\eta}{m}\big((g(t)-y)\odot m_-\big)^T\big(g(t)\odot m_-\big)\,p_2p_2^T.$$
Now we derive the evolution of $g(t)-y$. Suppose $i\in S_+$.
Then
$$g_i(t+1)=g_i(0)+\langle u(t+1)-u(0),\nabla_u g_i(0)\rangle+\langle v(t+1)-v(0),\nabla_v g_i(0)\rangle+\Big\langle u(t+1)-u(0),\frac{\partial^2 g_i(0)}{\partial u\partial v}(v(t+1)-v(0))\Big\rangle.$$
Substituting the update equations for $u$ and $v$ and expanding with the identities above,
$$g_i(t+1)=g_i(t)-\eta\sum_{k\in S_+}(g_k(t)-y_k)K_{k,i}(t)+\frac{\eta^2}{m}\sum_{k\in S_+}\sum_{j\in S_+}(g_k(t)-y_k)(g_j(t)-y_j)g_j(t)\,x_kx_i.$$
Similarly, for $i\in S_-$,
$$g_i(t+1)=g_i(t)-\eta\sum_{k\in S_-}(g_k(t)-y_k)K_{k,i}(t)+\frac{\eta^2}{m}\sum_{k\in S_-}\sum_{j\in S_-}(g_k(t)-y_k)(g_j(t)-y_j)g_j(t)\,x_kx_i.$$
Combining the results together,
$$g(t+1)-y=\Big(I-\eta K(t)+\frac{\eta^2}{m}\big((g(t)-y)\odot m_+\big)^T\big(g(t)\odot m_+\big)\,p_1p_1^T+\frac{\eta^2}{m}\big((g(t)-y)\odot m_-\big)^T\big(g(t)\odot m_-\big)\,p_2p_2^T\Big)(g(t)-y).$$
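As a sanity check on the multiple-example derivation, one can run gradient descent directly on the bilinear model $g_i=\frac{1}{\sqrt m}\sum_k v_ku_kx_i\mathbb 1_{\{u_k(0)x_i\ge0\}}$ (scalar inputs, activation pattern frozen at initialization); this model coincides with its own quadratic Taylor expansion, so its GD trajectory follows the update equations above. The data, width, step count, and learning-rate choice below are illustrative.

```python
import numpy as np

# GD on the frozen-activation bilinear model for scalar inputs, tracking the
# top eigenvalue of the tangent kernel K(t) through a catapult.
rng = np.random.default_rng(0)
n, m = 8, 4000
x = np.concatenate([rng.normal(2.0, 1.0, n // 2), rng.normal(-2.0, 1.0, n // 2)])
y = np.sign(x)
u = rng.normal(size=m)
v = rng.normal(size=m)
A = (np.outer(u, x) >= 0.0).astype(float)        # m x n activation pattern at init

def forward(u, v):
    # g_i = (1/sqrt(m)) sum_k v_k u_k x_i 1{u_k(0) x_i >= 0}
    return x * (A.T @ (u * v)) / np.sqrt(m)

def kernel(u, v):
    # K_ij = (1/m) x_i x_j sum_k (u_k^2 + v_k^2) 1{u_k(0)x_i>=0} 1{u_k(0)x_j>=0}
    return np.outer(x, x) * (A.T @ ((u**2 + v**2)[:, None] * A)) / m

lam1_0 = np.linalg.eigvalsh(kernel(u, v))[-1]
eta = 2.5 / lam1_0                               # super-critical: in (2/lam1, 4/lam1)
losses = []
for _ in range(5000):
    r = forward(u, v) - y
    losses.append(0.5 * np.sum(r**2))
    c = A @ (r * x) / np.sqrt(m)                 # shared factor in both gradients
    u, v = u - eta * v * c, v - eta * u * c      # simultaneous GD update
lam1_T = np.linalg.eigvalsh(kernel(u, v))[-1]
```

The loss spikes to order $m$ before settling, and the top kernel eigenvalue ends well below its initial value, as predicted by the kernel update equation.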

C PROOF OF PROPOSITION 1

Restatement of Proposition 1: For any $u,v\in\mathbb R^m$, $\mathrm{rank}(K)\le2$. Furthermore, $p_1,p_2$ are eigenvectors of $K$, where $p_{1,i}=x_i\mathbb 1_{\{i\in S_+\}}$ and $p_{2,i}=x_i\mathbb 1_{\{i\in S_-\}}$ for $i\in[n]$.

Proof. By Definition 1,
$$K_{i,j}=\frac1m\sum_{k=1}^m(v_k^2+u_k^2)x_ix_j\mathbb 1_{\{u_kx_i\ge0\}}\mathbb 1_{\{u_kx_j\ge0\}},\qquad i,j\in[n].$$
By the definition of an eigenvector, we can compute
$$\sum_{j=1}^nK_{i,j}p_{1,j}=\frac1m\sum_{j=1}^n\sum_{k=1}^m(v_k^2+u_k^2)x_ix_j^2\mathbb 1_{\{u_kx_i\ge0\}}\mathbb 1_{\{u_kx_j\ge0\}}\mathbb 1_{\{j\in S_+\}}=x_i\mathbb 1_{\{i\in S_+\}}\sum_{j=1}^nx_j^2\mathbb 1_{\{j\in S_+\}}\frac1m\sum_{k=1}^m(v_k^2+u_k^2)\mathbb 1_{\{u_kx_j\ge0\}},$$
where we used the fact that $K_{i,j}=0$ if $x_ix_j<0$. Since $p_{1,i}=x_i\mathbb 1_{\{i\in S_+\}}$ and the factor $\sum_{j=1}^nx_j^2\mathbb 1_{\{j\in S_+\}}\frac1m\sum_{k=1}^m(v_k^2+u_k^2)\mathbb 1_{\{u_kx_j\ge0\}}$ does not depend on $i$, we see that $p_1$ is an eigenvector of $K$ with eigenvalue
$$\lambda_1=\sum_{j=1}^nx_j^2\mathbb 1_{\{j\in S_+\}}\frac1m\sum_{k=1}^m(v_k^2+u_k^2)\mathbb 1_{\{u_kx_j\ge0\}}.$$
The same analysis shows that $p_2$ is another eigenvector of $K$, with eigenvalue $\lambda_2=\sum_{j=1}^nx_j^2\mathbb 1_{\{j\in S_-\}}\frac1m\sum_{k=1}^m(v_k^2+u_k^2)\mathbb 1_{\{u_kx_j\ge0\}}$. For the rank of $K$, it is not hard to verify that $K=\frac{\lambda_1}{\|p_1\|^2}p_1p_1^T+\frac{\lambda_2}{\|p_2\|^2}p_2p_2^T$, hence the rank of $K$ is at most 2.
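Proposition 1 is easy to verify numerically for a random draw of scalar data and weights; the sketch below (arbitrary sizes) checks both the rank bound and the two eigenvectors.

```python
import numpy as np

# Numerical check of Proposition 1: the tangent kernel K of the NQM with scalar
# inputs has rank at most 2, with p_1 (positive inputs) and p_2 (negative
# inputs) as eigenvectors. Sizes and the random draw are arbitrary.
rng = np.random.default_rng(0)
n, m = 10, 50
x = rng.normal(size=n)
u = rng.normal(size=m)
v = rng.normal(size=m)

act = (np.outer(u, x) >= 0.0).astype(float)      # m x n indicator matrix
# K_ij = (1/m) sum_k (v_k^2 + u_k^2) x_i x_j 1{u_k x_i >= 0} 1{u_k x_j >= 0}
K = np.outer(x, x) * (act.T @ ((u**2 + v**2)[:, None] * act)) / m

p1 = x * (x >= 0)
p2 = x * (x < 0)
lam1 = p1 @ K @ p1 / (p1 @ p1)
lam2 = p2 @ K @ p2 / (p2 @ p2)
```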

D SCALE OF THE TANGENT KERNEL FOR SINGLE TRAINING EXAMPLE

Proposition 2 (Scale of the tangent kernel). For any $\delta\in(0,1)$, if $m\ge c\log(4/\delta)$, where $c$ is an absolute constant, then with probability at least $1-\delta$, $\|x\|^2/(2d)\le\lambda(0)\le3\|x\|^2/(2d)$.

Proof. Note that at $t=0$,
$$\lambda(0)=\frac{1}{md}\sum_{i=1}^m\big(u_{0,i}^Tx\,\mathbb 1_{\{u_{0,i}^Tx\ge0\}}\big)^2+\frac{1}{md}\sum_{i=1}^mv_{0,i}^2\|x\|^2\mathbb 1_{\{u_{0,i}^Tx\ge0\}}.$$
Under the NTK initialization, $v_{0,i}\sim\mathcal N(0,1)$ and $u_{0,i}\sim\mathcal N(0,I)$ for each $i\in[m]$. Consider the random variables $\zeta_i:=u_{0,i}^Tx\,\mathbb 1_{\{u_{0,i}^Tx\ge0\}}$ and $\xi_i:=v_{0,i}\mathbb 1_{\{u_{0,i}^Tx\ge0\}}$. Both are sub-Gaussian, since $u_{0,i}^Tx$ and $v_{0,i}$ are sub-Gaussian: for any $t\ge0$,
$$\mathbb P\{|\zeta_i|\ge t\}\le\mathbb P\{|u_{0,i}^Tx|\ge t\}\le2\exp\big(-t^2/(2\|x\|^2)\big),\qquad\mathbb P\{|\xi_i|\ge t\}\le\mathbb P\{|v_{0,i}|\ge t\}\le2\exp(-t^2/2),$$
where the second inequality in each case follows from the definition of sub-Gaussian variables. Since $\xi_i$ is sub-Gaussian, $\xi_i^2$ is sub-exponential with bounded sub-exponential norm, $\|\xi_i^2\|_{\psi_1}\le\|\xi_i\|_{\psi_2}^2\le C$ for an absolute constant $C>0$; similarly $\|\zeta_i^2\|_{\psi_1}\le C\|x\|^2$. By Bernstein's inequality, for every $t\ge0$,
$$\mathbb P\Big\{\Big|\sum_{i=1}^m\xi_i^2-\frac m2\Big|\ge t\Big\}\le2\exp\Big(-c\min\Big(\frac{t^2}{\sum_i\|\xi_i^2\|_{\psi_1}^2},\frac{t}{\max_i\|\xi_i^2\|_{\psi_1}}\Big)\Big),$$
where $c>0$ is an absolute constant. Taking $t=m/4$, with probability at least $1-2\exp(-m/c')$ we have $m/4\le\sum_{i=1}^m\xi_i^2\le3m/4$, where $c'=c/(4C)$. Similarly, with probability at least $1-2\exp(-m/c')$, $\frac m4\|x\|^2\le\sum_{i=1}^m\zeta_i^2\le\frac{3m}4\|x\|^2$. As a result, by a union bound, with probability at least $1-4\exp(-m/c')$,
$$\frac{\|x\|^2}{2d}\le\lambda(0)\le\frac{3\|x\|^2}{2d}.$$
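A quick Monte-Carlo check of Proposition 2, with illustrative width $m$ and dimension $d$ (the event in the proposition holds with high probability, so a single large-$m$ draw should land inside the bounds):

```python
import numpy as np

# Monte-Carlo check of Proposition 2: at NTK initialization, the single-example
# tangent kernel lambda(0) lies in [||x||^2/(2d), 3||x||^2/(2d)] w.h.p.
# The width m and input dimension d below are illustrative choices.
rng = np.random.default_rng(0)
d, m = 20, 5000
x = rng.normal(size=d)
u0 = rng.normal(size=(m, d))
v0 = rng.normal(size=m)

pre = u0 @ x                                    # u_{0,i}^T x for each unit
mask = pre >= 0.0
nx2 = x @ x                                     # ||x||^2
lam0 = (np.sum((pre * mask)**2) + nx2 * np.sum(v0[mask]**2)) / (m * d)
```

Note that the expectation of `lam0` is exactly $\|x\|^2/d$, the midpoint scale of the two bounds.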

E SCALE OF THE TANGENT KERNEL FOR MULTIPLE TRAINING EXAMPLES

Proof. As shown in Proposition 1, $p_1$ and $p_2$ are eigenvectors of $K$, hence we have the two eigenvalues
$$\lambda_1(0)=\frac{p_1^TK(0)p_1}{\|p_1\|^2},\qquad\lambda_2(0)=\frac{p_2^TK(0)p_2}{\|p_2\|^2}.$$
Take $\lambda_1(0)$ as an example:
$$\lambda_1(0)\|p_1\|^2=\frac1m\sum_{i,j=1}^nx_ix_j\mathbb 1_{\{x_i\ge0\}}\mathbb 1_{\{x_j\ge0\}}\sum_{k=1}^m(u_{0,k}^2+v_{0,k}^2)x_ix_j\mathbb 1_{\{u_{0,k}x_i\ge0\}}\mathbb 1_{\{u_{0,k}x_j\ge0\}}=\frac1m\sum_{k=1}^m(u_{0,k}^2+v_{0,k}^2)\mathbb 1_{\{u_{0,k}\ge0\}}\sum_{i,j=1}^nx_i^2x_j^2\mathbb 1_{\{x_i\ge0\}}\mathbb 1_{\{x_j\ge0\}}.$$
Similar to the proof of Proposition 2, we consider $\xi_k:=v_{0,k}\mathbb 1_{\{u_{0,k}\ge0\}}$, which is sub-Gaussian; hence $\xi_k^2$ is sub-exponential with $\|\xi_k^2\|_{\psi_1}\le C$ for an absolute constant $C>0$. By Bernstein's inequality, taking $t=m/4$ as before, with probability at least $1-2\exp(-m/c')$ we have $m/4\le\sum_{k=1}^m\xi_k^2\le3m/4$, where $c'=c/(4C)$. The same analysis applies to $\zeta_k:=u_{0,k}\mathbb 1_{\{u_{0,k}\ge0\}}$: with probability at least $1-2\exp(-m/c')$, $m/4\le\sum_{k=1}^m\zeta_k^2\le3m/4$. As a result, with probability at least $1-4\exp(-m/c')$,
$$\lambda_1(0)\|p_1\|^2=\frac1m\sum_{k=1}^m(u_{0,k}^2+v_{0,k}^2)\mathbb 1_{\{u_{0,k}\ge0\}}\sum_{i,j=1}^nx_i^2x_j^2\mathbb 1_{\{x_i\ge0\}}\mathbb 1_{\{x_j\ge0\}}\in\Big[\frac12\sum_{i,j=1}^nx_i^2x_j^2\mathbb 1_{\{x_i\ge0\}}\mathbb 1_{\{x_j\ge0\}},\ \frac32\sum_{i,j=1}^nx_i^2x_j^2\mathbb 1_{\{x_i\ge0\}}\mathbb 1_{\{x_j\ge0\}}\Big].$$
Applying the same analysis to $\lambda_2(0)$ (with the indicator $\mathbb 1_{\{u_{0,k}\le0\}}$), with probability at least $1-4\exp(-m/c')$,
$$\lambda_2(0)\|p_2\|^2\in\Big[\frac12\sum_{i,j=1}^nx_i^2x_j^2\mathbb 1_{\{x_i\le0\}}\mathbb 1_{\{x_j\le0\}},\ \frac32\sum_{i,j=1}^nx_i^2x_j^2\mathbb 1_{\{x_i\le0\}}\mathbb 1_{\{x_j\le0\}}\Big].$$
The largest eigenvalue is $\max\{\lambda_1(0),\lambda_2(0)\}$. Combining the results together, with probability at least $1-4\exp(-m/c')$, $\frac12M\le\|K(0)\|\le\frac32M$, where
$$M=\max\Big\{\frac{\sum_{i,j=1}^nx_i^2x_j^2\mathbb 1_{\{x_i\ge0\}}\mathbb 1_{\{x_j\ge0\}}}{\sum_{i=1}^nx_i^2\mathbb 1_{\{x_i\ge0\}}},\ \frac{\sum_{i,j=1}^nx_i^2x_j^2\mathbb 1_{\{x_i\le0\}}\mathbb 1_{\{x_j\le0\}}}{\sum_{i=1}^nx_i^2\mathbb 1_{\{x_i\le0\}}}\Big\}.$$

F SCALE ANALYSIS FOR $R_\lambda$ AND $R_g$

Proposition 3. Let $\rho:=|g(t)|/\sqrt m$. Assume $m\gg1$ and $|\lambda(t)|\le C$ for some constant $C>0$.
If $\rho\ll1$, i.e., $|g(t)|=o(\sqrt m)$, then
(i) $|R_g(t)|\le\frac{\|x\|^2\eta^2}{d}\rho^2+\frac{\|x\|^2\eta^2|y|}{d\sqrt m}\rho=o(1)$;
(ii) $|R_\lambda(t)|\le\frac{|\eta^2\lambda(t)-4\eta|\,\|x\|^2}{d}\rho^2+\frac{2\eta^2\|x\|^2|y|\lambda(t)+4\eta\|x\|^2|y|}{d\sqrt m}\rho+\frac{\eta^2\|x\|^2y^2\lambda(t)}{dm}=o(1)$.

Proof of Proposition 3. According to the dynamics equations, i.e., Eqs. (11) and (12),
$$R_g(t)=\frac{\|x\|^2}{md}\eta^2(g(t)-y)g(t),\qquad R_\lambda(t)=\frac{\eta\|x\|^2}{md}(g(t)-y)^2\Big(\eta\lambda(t)-\frac{4g(t)}{g(t)-y}\Big).$$
Let $\rho=o(1)$. A simple application of the triangle inequality gives
$$|R_g(t)|=\Big|\frac{\|x\|^2\eta^2g(t)^2}{md}-\frac{\|x\|^2\eta^2g(t)y}{md}\Big|\le\frac{\|x\|^2\eta^2}{d}\rho^2+\frac{\|x\|^2\eta^2|y|}{d\sqrt m}\rho=o(1),$$
and
$$|R_\lambda(t)|=\Big|\frac{\|x\|^2\eta^2}{md}(g(t)-y)^2\lambda(t)-\frac{4\|x\|^2\eta}{md}g(t)(g(t)-y)\Big|\le\frac{|\eta^2\lambda(t)-4\eta|\,\|x\|^2}{d}\rho^2+\frac{2\eta^2\|x\|^2|y|\lambda(t)+4\eta\|x\|^2|y|}{d\sqrt m}\rho+\frac{\eta^2\|x\|^2y^2\lambda(t)}{dm}=o(1).$$

Proposition 4. Let $\rho:=|g(t)|/\sqrt m$ and assume $m\gg1$. If $\rho\in[C_1,C_2]$ for some constants $C_1,C_2>0$, i.e., $|g(t)|=\Theta(\sqrt m)$, then
(i) $R_g(t)\in\big[\frac{\|x\|^2\eta^2C_1^2}{d}-\epsilon,\ \frac{\|x\|^2\eta^2C_2^2}{d}+\epsilon\big]$;
(ii) if $\lambda(t)\le4/\eta$, then $R_\lambda(t)\in\big[\frac{\|x\|^2C_2^2\eta(\eta\lambda(t)-4)}{d}-\epsilon,\ \frac{\|x\|^2C_1^2\eta(\eta\lambda(t)-4)}{d}+\epsilon\big]$; otherwise $R_\lambda(t)\in\big[\frac{\|x\|^2C_1^2\eta(\eta\lambda(t)-4)}{d}-\epsilon,\ \frac{\|x\|^2C_2^2\eta(\eta\lambda(t)-4)}{d}+\epsilon\big]$, where $\epsilon=O(1/\sqrt m)$.

Proof of Proposition 4. According to the dynamics equations, i.e., Eqs. (11) and (12), $R_g(t)$ and $R_\lambda(t)$ take the forms displayed above. Let $C_1\le\rho\le C_2$. Then
$$R_g(t)=\frac{\|x\|^2\eta^2g(t)^2}{md}-\frac{\|x\|^2\eta^2g(t)y}{md}=\frac{\|x\|^2\eta^2}{d}\rho^2-\frac{\|x\|^2\eta^2y}{\sqrt md}\rho,$$
hence
$$\frac{\|x\|^2\eta^2}{d}C_1^2-\frac{\|x\|^2\eta^2y}{\sqrt md}C_2\le R_g(t)\le\frac{\|x\|^2\eta^2}{d}C_2^2+\frac{\|x\|^2\eta^2y}{\sqrt md}C_2.$$
And
$$R_\lambda(t)=\frac{\|x\|^2\eta^2}{md}(g(t)-y)^2\lambda(t)-\frac{4\|x\|^2\eta}{md}g(t)(g(t)-y)=\frac{(\eta^2\lambda(t)-4\eta)\|x\|^2}{d}\rho^2+\frac{4\|x\|^2\eta y-2\|x\|^2\eta^2y\lambda(t)}{\sqrt md}\rho+\frac{\|x\|^2\eta^2y^2\lambda(t)}{md}.$$
If $\lambda(t)\le4/\eta$, we have
$$R_\lambda(t)\ge\frac{(\eta^2\lambda(t)-4\eta)\|x\|^2}{d}C_2^2-\frac{4\|x\|^2\eta y+2\|x\|^2\eta^2y\lambda(t)}{\sqrt md}C_2,\qquad R_\lambda(t)\le\frac{(\eta^2\lambda(t)-4\eta)\|x\|^2}{d}C_1^2+\frac{4\|x\|^2\eta y+2\|x\|^2\eta^2y\lambda(t)}{\sqrt md}C_1+\frac{\|x\|^2\eta^2y^2\lambda(t)}{md}.$$
If $\lambda(t)\ge4/\eta$, we have
$$R_\lambda(t)\ge\frac{(\eta^2\lambda(t)-4\eta)\|x\|^2}{d}C_1^2-\frac{4\|x\|^2\eta y+2\|x\|^2\eta^2y\lambda(t)}{\sqrt md}C_2,\qquad R_\lambda(t)\le\frac{(\eta^2\lambda(t)-4\eta)\|x\|^2}{d}C_2^2+\frac{4\|x\|^2\eta y+2\|x\|^2\eta^2y\lambda(t)}{\sqrt md}C_1+\frac{\|x\|^2\eta^2y^2\lambda(t)}{md}.$$
Picking $\epsilon=\max\Big\{\frac{4\|x\|^2\eta y+2\|x\|^2\eta^2y\lambda(t)}{\sqrt md}C_2+\frac{\|x\|^2\eta^2y^2\lambda(t)}{md},\ \frac{\|x\|^2\eta^2y}{\sqrt md}C_2\Big\}$, we obtain the result.

G SCALE ANALYSIS FOR $R_K$ AND $R_g$

Proposition 5. Let $\rho:=\|g(t)\|/\sqrt m$. Assume $m\gg1$ and $\lambda_1(t)\le C$ for some constant $C>0$. If $\rho\ll1$, i.e., $\|g(t)\|=o(\sqrt m)$, then
(i) $\|R_g(t)\|\le2\eta^2\big(\rho^2+\|y\|\rho/\sqrt m\big)\sum_ix_i^2=o(1)$;
(ii) $\|R_K(t)\|\le2\eta^2C\big(\rho^2+\|y\|\rho/\sqrt m\big)\sum_ix_i^2+8\eta\big(\rho^2+\|y\|\rho/\sqrt m\big)\sum_ix_i^2=o(1)$.

Proof of Proposition 5. According to the dynamics equations for multiple training examples, i.e., Eqs. (14) and (15),
$$R_g(t)=\frac{\eta^2}{m}\big((g(t)-y)\odot m_+\big)^T\big(g(t)\odot m_+\big)p_1p_1^T+\frac{\eta^2}{m}\big((g(t)-y)\odot m_-\big)^T\big(g(t)\odot m_-\big)p_2p_2^T,$$
$$R_K(t)=\frac{\eta^2}{m}\big((g(t)-y)\odot m_+\big)^TK(t)\big((g(t)-y)\odot m_+\big)p_1p_1^T+\frac{\eta^2}{m}\big((g(t)-y)\odot m_-\big)^TK(t)\big((g(t)-y)\odot m_-\big)p_2p_2^T-\frac{4\eta}{m}\big((g(t)-y)\odot m_+\big)^T\big(g(t)\odot m_+\big)p_1p_1^T-\frac{4\eta}{m}\big((g(t)-y)\odot m_-\big)^T\big(g(t)\odot m_-\big)p_2p_2^T.$$
Let $\rho=o(1)$. Since $\|p_1\|^2\le\sum_ix_i^2$ and $\|p_2\|^2\le\sum_ix_i^2$, a simple application of the triangle inequality gives
$$\|R_g(t)\|\le2\eta^2\big(\rho^2+\|y\|\rho/\sqrt m\big)\sum_ix_i^2=o(1)$$
and
$$\|R_K(t)\|\le2\eta^2C\big(\rho^2+\|y\|\rho/\sqrt m\big)\sum_ix_i^2+8\eta\big(\rho^2+\|y\|\rho/\sqrt m\big)\sum_ix_i^2=o(1).$$

Proposition 6. Let $\rho:=\|g(t)\|/\sqrt m$. Assume $m\gg1$ and $\eta<4/\lambda_1(t)$. If $\rho\in[C_1,C_2]$ for some constants $C_1,C_2>0$, i.e., $\|g(t)\|=\Theta(\sqrt m)$, then
(i) $\|R_g(t)\|\in\big[\min\{\|p_1\|^2C_1^2,\|p_2\|^2C_2^2\}\eta^2-\epsilon,\ \max\{\|p_1\|^2C_1^2,\|p_2\|^2C_2^2\}\eta^2+\epsilon\big]$;
(ii) $\|R_K(t)\|\in\big[\min\{\|p_1\|^2C_1^2,\|p_2\|^2C_2^2\}\,\eta\|\eta K(t)-4I\|-\epsilon,\ \max\{\|p_1\|^2C_1^2,\|p_2\|^2C_2^2\}\,\eta\|\eta K(t)-4I\|+\epsilon\big]$, where $\epsilon=O(1/\sqrt m)$.

Proof of Proposition 6. According to the dynamics equations for multiple training examples, i.e., Eqs.
(14) and (15), $R_g(t)$ and $R_K(t)$ take the forms displayed in the proof of Proposition 5. Note that $g(t)\odot m_++g(t)\odot m_-=g(t)$. We further denote $\rho_+:=\|g(t)\odot m_+\|/\sqrt m$ and $\rho_-:=\|g(t)\odot m_-\|/\sqrt m$; it is not hard to see that $\rho_+^2+\rho_-^2=\rho^2$. We then have
$$R_g(t)=\frac{\eta^2}{m}\|g(t)\odot m_+\|^2p_1p_1^T+\frac{\eta^2}{m}\|g(t)\odot m_-\|^2p_2p_2^T-\frac{\eta^2}{m}(y\odot m_+)^T(g(t)\odot m_+)p_1p_1^T-\frac{\eta^2}{m}(y\odot m_-)^T(g(t)\odot m_-)p_2p_2^T.$$
Therefore,
$$\min\{\|p_1\|^2C_1^2,\|p_2\|^2C_2^2\}\eta^2-O\Big(\frac1{\sqrt m}\Big)\le\|R_g(t)\|\le\max\{\|p_1\|^2C_1^2,\|p_2\|^2C_2^2\}\eta^2+O\Big(\frac1{\sqrt m}\Big).$$
For $R_K(t)$, since the top eigenvalue of $\eta K(t)-4I$ is negative by our assumption, we have
$$\|R_K(t)\|\le\max\{\|p_1\|^2C_1^2,\|p_2\|^2C_2^2\}\,\eta\|\eta K(t)-4I\|+O(1/\sqrt m),\qquad\|R_K(t)\|\ge\min\{\|p_1\|^2C_1^2,\|p_2\|^2C_2^2\}\,\eta\|\eta K(t)-4I\|-O(1/\sqrt m).$$

H PROOF OF THEOREM 2 AND ANALYSIS OF OPTIMIZATION DYNAMICS FOR MULTIPLE TRAINING EXAMPLES

By Eq. (5), the tangent kernel $K$ at step $t$ is defined as
$$K_{i,j}(t)=\langle\nabla_vg_i(t),\nabla_vg_j(t)\rangle+\langle\nabla_ug_i(t),\nabla_ug_j(t)\rangle=\frac1m\sum_{k=1}^m\big(u_k(t)^2+v_k(t)^2\big)x_ix_j\mathbb 1_{\{u_k(0)x_i\ge0\}}\mathbb 1_{\{u_k(0)x_j\ge0\}},\qquad\forall i,j\in[n].$$
As in the single-example case, the largest eigenvalue of the tangent kernel is bounded away from 0:

Proposition 7. For any $\delta\in(0,1)$, if $m\ge c\log(4/\delta)$, where $c$ is an absolute constant, then with probability at least $1-\delta$, $M/2\le\lambda_{\max}(K(0))\le3M/2$, where
$$M=\max\Big\{\frac{\sum_{i,j=1}^nx_i^2x_j^2\mathbb 1_{\{x_i\ge0\}}\mathbb 1_{\{x_j\ge0\}}}{\sum_{i=1}^nx_i^2\mathbb 1_{\{x_i\ge0\}}},\ \frac{\sum_{i,j=1}^nx_i^2x_j^2\mathbb 1_{\{x_i\le0\}}\mathbb 1_{\{x_j\le0\}}}{\sum_{i=1}^nx_i^2\mathbb 1_{\{x_i\le0\}}}\Big\}.$$
The proof can be found in Appendix E. For simplicity of notation, given $p,m\in\mathbb R^n$, we define the matrices
$$K_{p,m}(t):=\big((g(t)-y)\odot m\big)^TK(t)\big((g(t)-y)\odot m\big)\,pp^T,\qquad Q_{p,m}(t):=\big((g(t)-y)\odot m\big)^T\big(g(t)\odot m\big)\,pp^T.$$
It is not hard to see that for all $t$, $K_{p,m}$ and $Q_{p,m}$ are rank-1 matrices.
In particular, $p$ is the only eigenvector of $K_{p,m}$ and $Q_{p,m}$ with a nonzero eigenvalue. With the above notation, the update equations for $g(t)-y$ and $K(t)$ during gradient descent with learning rate $\eta$ read:

Dynamics equations.
$$g(t+1)-y=\Big(I-\eta K(t)+\underbrace{\frac{\eta^2}{m}\big(Q_{p_1,m_+}(t)+Q_{p_2,m_-}(t)\big)}_{R_g(t)}\Big)(g(t)-y),$$
$$K(t+1)=K(t)+\underbrace{\frac{\eta^2}{m}\big(K_{p_1,m_+}(t)+K_{p_2,m_-}(t)\big)-\frac{4\eta}{m}\big(Q_{p_1,m_+}(t)+Q_{p_2,m_-}(t)\big)}_{R_K(t)},$$
where $m_+,m_-\in\mathbb R^n$ are mask vectors: $m_{+,i}=\mathbb 1_{\{i\in S_+\}}$, $m_{-,i}=\mathbb 1_{\{i\in S_-\}}$. We are now ready to discuss the three optimization regimes for the multiple training example case, in parallel with the single training example case.

Monotonic convergence: sub-critical learning rates ($\eta<2/\lambda_1(0)$). We use the key observation that when $\|g(t)\|$ is small, i.e., $o(\sqrt m)$, both $\|R_g(t)\|$ and $\|R_K(t)\|$ are $o(1)$. At initialization, $\|g(0)\|=O(1)$ with high probability over the random initialization. By Proposition 5, when $\|g(t)\|=o(\sqrt m)$, the optimization follows linear dynamics. By the choice of the learning rate, $\|I-\eta K(t)\|<1$ for all $t\ge0$, hence $\|g(t)-y\|$ decreases exponentially. The cumulative change of the norm of the tangent kernel is $o(1)$: $\|R_K(t)\|=o(1)$ at each step, and since the loss decreases exponentially, the per-step changes sum to $o(1)$.

Catapult convergence: super-critical learning rates ($2/\lambda_1(0)<\eta<\min\{2/\lambda_2(0),4/\lambda_1(0)\}$). Restatement of Theorem 2: Consider training an NQM (Eq. (10)) with squared loss on multiple training examples by gradient descent. Under Assumption 1, if the learning rate is super-critical, i.e., $2/\lambda_1(0)<\eta<\min\{2/\lambda_2(0),4/\lambda_1(0)\}$, then there exist $T_1,T_2,T_3,T_4$ with $0<T_1<T_2<T_3<T_4$ such that the training dynamics exhibit:

(i) Increasing phase: $t\in[T_1,T_2]$. In this phase, the loss increases and the tangent kernel is nearly constant: $|\lambda_k(t)-\lambda_k(0)|=o(1)$ for $k=1,2$.

(ii) Peak phase: $t\in[T_2,T_3]$. In this phase, $\mathcal L(t)=\Theta(m)$. For the tangent kernel: (a) if $2/\lambda_2(0)\le4/\lambda_1(0)$, both eigenvalues decrease significantly, i.e., $\lambda_1(t+1)-\lambda_1(t)<0$ and $\lambda_2(t+1)-\lambda_2(t)<0$, and both differences are of the order $\Theta(1)$; (b) if $2/\lambda_2(0)>4/\lambda_1(0)$, only $\lambda_1(t)$ decreases significantly.
Proof of Theorem 2. We assume $2/\lambda_2(0)\le4/\lambda_1(0)$, since in this scenario the catapult phase happens in both directions $p_1$ and $p_2$, i.e., $g(t)$ has a significant projection onto both directions at the peak of the loss. In fact, as $p_1$ is orthogonal to $p_2$, the training dynamics in the two directions are almost independent of each other, due to the special structure of the Hessian. If instead $2/\lambda_2(0)>4/\lambda_1(0)$, the catapult phase happens mainly in the direction $p_1$, since in the direction $p_2$ the linear dynamics dominate; the analysis for this simpler setting is implied by the analysis below.

(i) Increasing phase. Initially, $\|g(t)-y\|$ grows exponentially following linear dynamics, by Proposition 5. By the choice of the learning rate, for $t$ in the increasing phase,
$$\big|p_1^T(I-\eta K(t))p_1\big|/\|p_1\|^2=|1-\eta\lambda_1(t)|=|1-\eta\lambda_1(0)+o(1)|>1.$$
Following the same analysis, $\big|p_2^T(I-\eta K(t))p_2\big|/\|p_2\|^2>1$ as well. Therefore, $g(t)-y$ increases in the directions $p_1$ and $p_2$. (If instead $2/\lambda_2(0)>4/\lambda_1(0)$, then $\big|p_2^T(I-\eta K(t))p_2\big|/\|p_2\|^2<1$, and $g(t)-y$ only grows in the direction $p_1$.) Note that when both quantities become smaller than 1, the loss stops increasing and the dynamics enter the peak phase. While $\|g(t)\|=o(\sqrt m)$, the cumulative change of $\lambda_1(t)$ and $\lambda_2(t)$ from initialization is $o(1)$, since $\|g(t)\|$ increases exponentially. Specifically, for $\lambda_1(t)$,
$$|\lambda_1(t)-\lambda_1(0)|\le\sum_{t=0}^{T_2}\big|p_1^TR_K(t)p_1\big|/\|p_1\|^2=\sum_{t=0}^{T_2}\Theta\big(\|g(t)\|^2/m\big)=o(1).$$
The same analysis applies to $\lambda_2(t)$, giving $|\lambda_2(t)-\lambda_2(0)|=o(1)$. Furthermore, it is not hard to see that until the loss grows to the order $\Theta(m)$, the tangent kernel does not change much, hence the loss keeps increasing exponentially.

Proposition 8. With learning rate $\frac{2}{\lambda(0)}<\eta<\frac{4}{\lambda(0)}$, if $\Sigma(x)^2=\|x\|^2\cdot I$, then $g(w)$ exhibits the catapult phase.

Proof.
By direct computation, we get
$$g(t+1)=\big(1-\eta\lambda(t)+\gamma^2\eta^2\|x\|^2g(t)^2\big)g(t),\qquad\lambda(t+1)=\lambda(t)+\gamma^2\eta\|x\|^2g(t)^2\big(\eta\lambda(t)-4\big).$$
We note that the evolution of $g$ and $\lambda$ is almost the same as in Eq. (11) and Eq. (12) if we regard $\gamma^2=1/m$. Hence the same analysis applies and shows the catapult phase phenomenon. It is worth pointing out that the two-layer linear neural network with input $x\in\mathbb R^d$ analyzed in Lewkowycz et al. (2020),
$$f(U,v;x)=\frac{1}{\sqrt m}v^TUx,\qquad v\in\mathbb R^m,\ U\in\mathbb R^{m\times d},$$
is a special case of our model with $w=\big(\mathrm{Vec}(U)^T,v^T\big)^T$, $\gamma=1/\sqrt m$ and
$$\Sigma=\begin{pmatrix}0&I_m\otimes x\\ I_m\otimes x^T&0\end{pmatrix}\in\mathbb R^{(md+m)\times(md+m)}.$$

J EXPERIMENTAL SETTINGS AND ADDITIONAL RESULTS

J.1 VERIFICATION OF NON-LINEAR TRAINING DYNAMICS OF NQMS, I.E., FIGURE 3

We train the NQM which approximates the two-layer fully-connected neural network with ReLU activation function on 128 data points, where each input is drawn i.i.d. from N(-2, 1) if the label is -1 or from N(2, 1) if the label is 1. The network width is 5,000.
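For concreteness, one way to generate this synthetic dataset (the paper does not specify the class proportions; balanced classes are assumed here):

```python
import numpy as np

# 128 scalar inputs: N(2, 1) for label +1 and N(-2, 1) for label -1.
# Balanced classes and the seed are illustrative assumptions.
rng = np.random.default_rng(0)
n = 128
y = np.repeat([-1.0, 1.0], n // 2)
x = rng.normal(loc=2.0 * y, scale=1.0)
```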

J.2 EXPERIMENTS FOR TRAINING DYNAMICS OF WIDE NEURAL NETWORKS WITH MULTIPLE EXAMPLES

We train a two-layer fully-connected neural network with ReLU activation function on 128 data points, where each input is drawn i.i.d. from N(-2, 1) if the label is -1 or from N(2, 1) if the label is 1. The network width is 5,000. See the results in Figure 5.



Unless stated otherwise, we always consider the setting where models are trained with squared loss using gradient descent. For non-differentiable functions, e.g. neural networks with ReLU activation functions, we define the gradient based on the update rule used in practice. Similarly, we use H f to denote the second derivative of f in Eq. (2). For general quadratic models in Eq. (3), the transition to linearity is continuously controlled by γ.



Figure 1: Optimization dynamics for linear and non-linear models based on choice of learning rate. (a) Linear models either converge monotonically if learning rate is less than η crit and diverge otherwise. (b) Unlike linear models, finitely wide neural networks and NQMs Eq. (2) (or general quadratic models Eq. (3)) can additionally observe a catapult phase when η crit < η < η max .

(i) Increasing phase. At the beginning, $|g(t)|$ grows exponentially following linear dynamics. Specifically, since $|g(0)|=O(1)$, by Proposition 3 we have $|R_g(0)|=o(1)$ and $|R_\lambda(0)|=o(1)$.

Figure 3: Training dynamics of NQMs in the multiple-example case with different learning rates. By our analysis, the two critical values are $2/\lambda_1(0)=0.37$ and $2/\lambda_2(0)=0.39$. When $\eta<0.37$, linear dynamics dominate, hence the kernel is nearly constant; when $0.37<\eta<0.39$, the catapult phase happens in $p_1$ and only $\lambda_1(t)$ decreases; when $0.39<\eta<\eta_{\max}$, the catapult phase happens in $p_1$ and $p_2$, hence both $\lambda_1(t)$ and $\lambda_2(t)$ decrease. The experimental details can be found in Appendix J.1.

Figure 4: Best test loss plotted against different learning rates for f (w), f lin (w) and f quad (w) across a variety of datasets and network architectures.

(iii) Decreasing phase: $t\in[T_4,\infty)$. In this phase, the loss satisfies $\mathcal L(t)=o(m)$ again and decreases. Both eigenvalues are nearly constant until convergence: $|\lambda_k(t)-\lambda_k(\infty)|=o(1)$ for $k=1,2$. Furthermore, $T_1=o(\log m)$, $T_2=\Theta(\log m)$, $T_3-T_2=\Theta(1)$ and $T_4=\Theta(\log m)$.

Figure 5: Training dynamics of wide neural networks in the multiple-example case with different learning rates. Compared to the training dynamics of NQMs, i.e., Figure 3, the behaviour of the top eigenvalues is almost the same for each learning rate: when $\eta<0.37$, the kernel is nearly constant; when $0.37<\eta<0.39$, only $\lambda_1(t)$ decreases; when $0.39<\eta<\eta_{\max}$, both $\lambda_1(t)$ and $\lambda_2(t)$ decrease. See the experimental setting in Appendix J.2.

Figure 12: Best test loss plotted against different learning rates for f_quad, f, and f_lin. We choose a 2-layer FC network as the architecture and train the models on AG NEWS with GD.

as an example) and deep networks, have better test performance when the catapult dynamics occurs. While this was observed for some synthetic examples of neural networks in Lewkowycz et al. (2020), we systematically demonstrate the improved generalization of NQMs across a range of experimental settings. Namely, we consider fully-connected and convolutional neural networks with ReLU and other activation functions trained with GD/SGD on multiple vision, speech and text datasets (see Section 4).

(see the formal statement, Proposition 5, in Appendix G). Then the dynamics equations approximately reduce to those of the linear dynamics for multiple training examples


(ii) Peak phase. We use the key observation that when $\|g(t)\|$ is large, i.e., of the order $\Theta(\sqrt m)$, $R_K(t)$ and $R_g(t)$ are of the order $\Theta(1)$ (see the formal statement, Proposition 6, in Appendix G), which can lead to the decrease of the loss. When $\|g(t)\|$ increases to the order $\Theta(\sqrt m)$, since $g(t)$ only grows in the directions $p_1$ and $p_2$, $g(t)$ is mostly aligned with $p_1$ and $p_2$, i.e., $g(t)^Tp_1/\|p_1\|+g(t)^Tp_2/\|p_2\|\approx\|g(t)\|$. Here, by our assumption, $\|y\|=O(1)$, which is small compared to the scale of $\|g(t)\|$, hence we can omit it. Since $\eta$ is initially chosen in the interval $(2/\lambda_1(0),4/\lambda_1(0))$ and $\lambda_1(t)=\lambda_1(0)+o(1)$ while $\|g(t)\|=o(\sqrt m)$, we in general have $\lambda_1(t)<4/\eta$ at the beginning of the peak phase. Therefore, by Proposition 6, the tangent kernel decreases significantly in the direction $p_1$: specifically, $p_1^TK(t+1)p_1<p_1^TK(t)p_1$, i.e., $\lambda_1(t+1)<\lambda_1(t)$. The same analysis works for the direction $p_2$, whence $p_2^TK(t+1)p_2<p_2^TK(t)p_2$, i.e., $\lambda_2(t+1)<\lambda_2(t)$. Similarly, as $\|g(t)\|$ increases, $\|R_g(t)\|$ increases as well by Proposition 6. Therefore, the factor $\|I-\eta K(t)+R_g(t)\|$ acting on $g(t)-y$ decreases, which slows down the increase of the loss. In general, the loss grows exponentially prior to the peak phase, hence $T_2=\Theta(\log m)$; and the peak phase only lasts $\Theta(1)$ steps during training, i.e., $|T_3-T_2|=\Theta(1)$, since the decrease of $\lambda_1(t)$ and $\lambda_2(t)$ is $\Theta(1)$ at each step.

(iii) Decreasing phase. With the tangent kernel $K(t)$ decreasing and $g(t)-y$ increasing in the directions $p_1$ and $p_2$, the factor $\|I-\eta K(t)+R_g(t)\|$ ultimately becomes smaller than 1 in both directions, which ends the peak phase, since $\|g(t)-y\|$ starts to decrease. Again, due to the large scale of $\|g(t)\|$, i.e., $\Theta(\sqrt m)$, the tangent kernel still decreases significantly. Similar to the single training example case, the decrease of $\lambda_1(t)$ and $\lambda_2(t)$ stops when $\|g(t)\|$ decreases to the order $o(\sqrt m)$, as then $\|R_K(t)\|=o(1)$. In general we have $\|I-\eta K(t)\|$ smaller than 1 in both directions.
Hence, by Proposition 5, the training dynamics become linear: the loss decreases monotonically and the tangent kernel is nearly constant until convergence.

Divergence ($\eta>\eta_{\max}=4/\lambda_1(0)$). Similar to the increasing phase of the catapult convergence, initially $\|g(t)-y\|$ increases in the directions $p_1$ and $p_2$, since linear dynamics dominate and the learning rate is larger than $\eta_{\mathrm{crit}}$. Also, by an analysis similar to that for the catapult convergence, we approximately have $\eta>4/\lambda_1(t)$ at the end of the increasing phase. Consider the evolution of $K(t)$ in the direction $p_1$. When $\|g(t)\|$ increases to the order $\Theta(\sqrt m)$, $g(t)\odot m_+$ becomes aligned with $p_1$, and a simple calculation shows that $\lambda_1(t)$ now increases, since $\eta\lambda_1(t)-4>0$. Consequently, the factor $\|I-\eta K(t)+R_g(t)\|$ becomes even larger, which makes $\|g(t)-y\|$ grow faster, and ultimately leads to divergence of the optimization.
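The divergence regime admits the same kind of numerical sketch as the catapult phase, here in the single-example recursion with $y=0$ (constants illustrative; with $\lambda(0)=1$, $\eta_{\max}=4$):

```python
# Sketch of the divergence regime in the single-example recursion: once
# eta > 4/lambda(0), the (eta*lambda - 4) factor is positive, so lambda can
# only grow and the residual explodes. c stands in for ||x||^2/(md).
c, eta = 1e-4, 4.5
g, lam = 1.0, 1.0
diverged = False
for _ in range(200):
    g_new = (1.0 - eta * lam + c * eta**2 * g * g) * g
    lam = lam + c * eta * g * g * (eta * lam - 4.0)   # y = 0 form of the lambda update
    g = g_new
    if abs(g) > 1e8:
        diverged = True
        break
```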

I SPECIAL CASE OF QUADRATIC MODELS WHEN φ(x) = 0

In this section, we show that the catapult phase phenomenon also occurs under some special settings, and how two-layer linear neural networks fit into our quadratic model. We consider one training example $(x,y)$ with label $y=0$ and assume the initial tangent kernel satisfies $\lambda(0)=\Omega(1)$. Letting the feature vector $\phi(x)=0$, the quadratic model Eq. (3) becomes a purely quadratic function of the parameters, $g(w)=\frac{\gamma}{2}w^T\Sigma(x)w$. For this quadratic model, we have the catapult phase result of Proposition 8.

J.3 TRAINING DYNAMICS OF GENERAL QUADRATIC MODELS AND NEURAL NETWORKS

As discussed at the end of Section 3, a more general quadratic model can exhibit the catapult phase phenomenon. Specifically, we consider a general quadratic model (Eq. (3)). We train the general quadratic model with different learning rates, and with different $\gamma$ respectively, to see how the catapult phase phenomenon depends on these two factors. For comparison, we also run the experiments for neural networks. The experimental settings are as follows.

General quadratic models. We set the input dimension $d=100$. We let the feature vector $\phi(x)=x/\|x\|$, where $x_i\sim N(0,1)$ i.i.d. for each $i\in[d]$. We let $\Sigma$ be a diagonal matrix with $\Sigma_{i,i}\in\{-1,1\}$ chosen randomly and independently. The weight parameters $w$ are initialized from $N(0,I_d)$. Unless stated otherwise, $\gamma=10^{-3}$ and the learning rate is set to 2.8.

Neural networks. We train a two-layer fully-connected neural network with ReLU activation function on 20 data points of CIFAR-2. Unless stated otherwise, the network width is $10^4$ and the learning rate is set to 2.8. See the results in Figure 6.

For the architectures of the two-layer fully-connected neural network and the two-layer convolutional neural network, we set the widths to 5,000 and 1,000 respectively. Specific to Figure 2(b), we use the architecture of a two-layer fully-connected neural network. Due to the large number of parameters in NQMs, we choose a small subset of each dataset.
We use the first class (airplanes) and the third class (birds) of CIFAR-10, which we call CIFAR-2, and select 256 data points from it as the training set. We use the digits 0 and 2 of SVHN, and select 256 data points as the training set. We select 128, 256, and 128 data points from the MNIST, FSDD and AG NEWS datasets respectively as the training sets. The size of the test set is 2,000 for all. When implementing SGD, we use a batch size of 32. For each setting, we report the average result of 5 independent runs.

In this section, we report the best test accuracy for f, f_lin and f_quad corresponding to the best test loss in Figure 4. We use the same setting as in Appendix J.4.

J.6 TEST PERFORMANCE OF f, f_lin AND f_quad WITH THE ARCHITECTURE OF A 3-LAYER FC NETWORK

In this section, we extend our results for shallow neural networks discussed in Section 4 to 3-layer fully-connected neural networks. In the same way, we compare the test performance of the three models f, f_lin and f_quad upon varying the learning rate. We observe the same phenomenon for the 3-layer ReLU-activated FC network as for shallow neural networks; see Figures 10 and 11. We use the first class (airplanes) and the third class (birds) of CIFAR-10 (CIFAR-2) and select 100 data points from it as the training set. We use the digits 0 and 2 of SVHN, and select 100 data points as the training set. We select 100 data points from the AG NEWS dataset as the training set. For the speech dataset FSDD, we select 100 data points in classes 1 and 3 as the training set. The size of the test set is 500 for all. For each setting, we report the average result of 5 independent runs.

