QUADRATIC MODELS FOR UNDERSTANDING NEURAL NETWORK DYNAMICS

Abstract

In this work, we show that recently proposed quadratic models capture optimization and generalization properties of wide neural networks that cannot be captured by linear models. In particular, we prove that quadratic models for shallow ReLU networks exhibit the "catapult phase" of Lewkowycz et al. (2020) that arises when training such models with large learning rates. We then empirically show that the behaviour of quadratic models parallels that of neural networks in generalization, especially in the catapult phase regime. Our analysis further demonstrates that quadratic models are an effective tool for the analysis of neural networks.

1. INTRODUCTION

A remarkable recent finding on neural networks, originating from Jacot et al. (2018) and termed the "transition to linearity" (Liu et al., 2020), is that, as network width goes to infinity, such models become linear functions of their parameters. Thus, a linear (in parameters) model can be built to accurately approximate wide neural networks under certain conditions. While this finding has helped improve our understanding of trained neural networks (Du et al., 2019; Nichani et al., 2021; Zou & Gu, 2019; Montanari & Zhong, 2020; Ji & Telgarsky, 2019; Chizat et al., 2019), not all properties of finite-width neural networks can be understood in terms of linear models, as is shown in several recent works (Yang & Hu, 2020; Ortiz-Jiménez et al., 2021; Long, 2021; Fort et al., 2020). In this work, we show that properties of finitely wide neural networks in optimization and generalization that cannot be captured by linear models are, in fact, manifested in quadratic models. The training dynamics of linear models with respect to the choice of the learning rate¹ are well-understood (Polyak, 1987). Indeed, such models exhibit linear training dynamics, i.e., there exists a critical learning rate, η_crit, such that the loss converges monotonically if and only if the learning rate is smaller than η_crit (see Figure 1a).
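For squared loss, this critical learning rate has a closed form: for a linear model with loss L(w) = ½‖Φw − y‖², the Hessian ΦᵀΦ is constant, and gradient descent converges monotonically if and only if η < η_crit = 2/λ_max(ΦᵀΦ). A minimal sketch of this dichotomy (the matrix Φ, targets y, and step counts are arbitrary illustrative choices, not taken from the paper):

```python
import numpy as np

def gd_losses(Phi, y, lr, steps):
    """Gradient descent on the squared loss L(w) = 0.5*||Phi w - y||^2."""
    w = np.zeros(Phi.shape[1])
    losses = []
    for _ in range(steps):
        r = Phi @ w - y              # residual
        losses.append(0.5 * r @ r)
        w -= lr * Phi.T @ r          # gradient step (gradient is Phi^T r)
    return np.array(losses)

rng = np.random.default_rng(0)
Phi = rng.standard_normal((5, 3))
y = rng.standard_normal(5)

# For a linear model the Hessian Phi^T Phi does not depend on w,
# so eta_crit = 2 / lambda_max is fixed throughout training.
lam_max = np.linalg.eigvalsh(Phi.T @ Phi).max()
eta_crit = 2.0 / lam_max

below = gd_losses(Phi, y, 0.9 * eta_crit, 200)  # monotone convergence
above = gd_losses(Phi, y, 1.1 * eta_crit, 200)  # divergence
```

Below η_crit every eigen-direction of the residual contracts, so the loss is monotonically non-increasing; above it, the component along the top eigenvector grows geometrically and the loss diverges.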



¹ Unless stated otherwise, we always consider the setting where models are trained with squared loss using gradient descent.



Figure 1: Optimization dynamics for linear and non-linear models based on the choice of learning rate. (a) Linear models converge monotonically if the learning rate is less than η_crit and diverge otherwise. (b) Unlike linear models, finitely wide neural networks and NQMs (Eq. (2)), or general quadratic models (Eq. (3)), can additionally exhibit a catapult phase when η_crit < η < η_max.
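The catapult phase is already visible in the smallest quadratic model: a two-parameter model f(u, v) = uv fit to a single target y with squared loss. Its tangent kernel λ = u² + v² evolves during training, so a learning rate slightly above 2/λ at initialization first drives the loss up, which shrinks the kernel, after which the loss converges. A toy sketch (initial values and the learning rate are illustrative choices, not taken from the paper):

```python
import numpy as np

def catapult_run(u, v, y, lr, steps):
    """Gradient descent on L = 0.5*(u*v - y)^2 for the quadratic model f = u*v."""
    losses = []
    for _ in range(steps):
        r = u * v - y                        # residual
        losses.append(0.5 * r * r)
        u, v = u - lr * r * v, v - lr * r * u  # simultaneous update
    return np.array(losses)

u0, v0, y = 2.5, 0.5, 1.0
lam0 = u0**2 + v0**2         # tangent kernel at initialization: 6.5
eta_crit = 2.0 / lam0        # about 0.308

# A learning rate in (2/lam0, 4/lam0): the loss first spikes, then converges.
losses = catapult_run(u0, v0, y, 0.35, 3000)
```

Because the effective step η·λ exceeds 2 at initialization, the residual grows for the first few iterations; the same updates shrink u² + v², and once η·λ drops below 2 the loss decays, i.e., the trajectory "catapults" into a flatter region rather than diverging.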

