PREDICTING THE OUTPUTS OF FINITE NETWORKS TRAINED WITH NOISY GRADIENTS

Anonymous authors
Paper under double-blind review

Abstract

A recent line of work has studied wide deep neural networks (DNNs) by approximating them as Gaussian Processes (GPs). A DNN trained with gradient flow was shown to map to a GP governed by the Neural Tangent Kernel (NTK), whereas earlier works showed that a DNN with an i.i.d. prior over its weights maps to the so-called Neural Network Gaussian Process (NNGP). Here we consider a DNN training protocol, involving noise, weight decay and finite width, whose outcome corresponds to a certain non-Gaussian stochastic process. An analytical framework is then introduced to analyze this non-Gaussian process, whose deviation from a GP is controlled by the finite width. Our contribution is three-fold: (i) In the infinite width limit, we establish a correspondence between DNNs trained with noisy gradients and the NNGP, not the NTK. (ii) We provide a general analytical form for the finite width correction (FWC) for DNNs with arbitrary activation functions and depth, and use it to predict the outputs of empirical finite networks with high accuracy. Analyzing the FWC behavior as a function of n, the training set size, we find that it is negligible both in the very small n regime and, surprisingly, in the large n regime (where the GP error scales as O(1/n)). (iii) We flesh out algebraically how these FWCs can improve the performance of finite convolutional neural networks (CNNs) relative to their GP counterparts on image classification tasks.

1. INTRODUCTION

Deep neural networks (DNNs) have been rapidly advancing the state-of-the-art in machine learning, yet a complete analytic theory remains elusive. Recently, several exact results were obtained in the highly over-parameterized regime, N → ∞, where N denotes the width of fully connected networks (FCNs) or the number of channels of convolutional neural networks (CNNs) (Daniely et al., 2016). This facilitated the derivation of an exact correspondence with Gaussian Processes (GPs) known as the Neural Tangent Kernel (NTK) (Jacot et al., 2018). The latter holds when highly over-parameterized DNNs are trained by gradient flow, namely with a vanishing learning rate and no stochasticity. The NTK result has provided the first example of a DNN-to-GP correspondence valid after end-to-end DNN training. This theoretical breakthrough allows one to think of DNNs as inference problems with underlying GPs (Rasmussen & Williams, 2005). For instance, it provides a quantitative description of the generalization properties (Cohen et al., 2019; Rahaman et al., 2018) and training dynamics (Jacot et al., 2018; Basri et al., 2019) of DNNs. Roughly speaking, highly over-parameterized DNNs generalize well because they have a strong implicit bias towards simple functions, and train well because low-error solutions in weight space can be reached by a small change to the random values of the weights at initialization.

Despite its novelty and importance, the NTK correspondence suffers from a few shortcomings: (a) Its deterministic training is qualitatively different from the stochastic training used in practice, which may lead to poorer performance when combined with a small learning rate (Keskar et al., 2016). (b) It under-performs, often by a large margin, CNNs trained with SGD (Arora et al., 2019). (c) Deriving explicit finite width corrections (FWCs) is challenging, as it requires solving a set of coupled ODEs (Dyer & Gur-Ari, 2020; Huang & Yau, 2019). Thus, there is a need for an extended theory of end-to-end trained deep networks that is valid for finite-width DNNs.

Our contribution is three-fold. First, we prove a correspondence between a DNN trained with noisy gradients and a Stochastic Process (SP) which, as N → ∞, tends to the Neural Network Gaussian Process (NNGP) (Lee et al., 2018; Matthews et al., 2018). In those works, the NNGP kernel is determined by the distribution of the DNN weights at initialization, which are i.i.d. random variables, whereas in our correspondence the weights are sampled across the stochastic training dynamics, drifting far away from their initial values. We call ours the NNSP correspondence, and show that it holds when the training dynamics in output space are ergodic. Second, we predict the outputs of trained finite-width DNNs, significantly improving upon the corresponding GP predictions. This is done by deriving the leading FWCs, which are found to scale with the width as 1/N. The accuracy with which we can predict the empirical DNNs' outputs serves as a strong verification of our aforementioned ergodicity assumption. In the regime where the GP RMSE scales as 1/n, we find that the leading FWC is a decaying function of n and is thus overall negligible. In the small n regime we find that the FWC is small and grows with n. We thus conclude that finite-width corrections are important for intermediate values of n (Fig. 1).
Third, we propose an explanation for why finite CNNs trained on image classification tasks can outperform their infinite-width counterparts, as observed by Novak et al. (2018). The key difference is that weight sharing is beneficial in finite CNNs, whereas it plays no role in their GP counterparts; our theory, which accounts for the finite width, quantifies this difference (§4.2).

Overall, the NNSP correspondence provides a rich analytical and numerical framework for exploring the theory of deep learning, unique in its ability to incorporate finite over-parameterization, stochasticity, and depth. We note that there are several factors that make finite SGD-trained DNNs used in practice different from their GP counterparts, e.g. large learning rates and early stopping (Lee et al., 2020). Importantly, our framework quantifies the contribution of finite-width effects to this difference, distilling it from the contribution of these other factors.
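To make the training protocol behind the NNSP correspondence concrete, the following sketch trains a small fully connected network with full-batch gradient descent, weight decay, and injected Gaussian gradient noise, and then time-averages its output on a test point after a burn-in period. Under the ergodicity assumption discussed above, such a time average plays the role of a posterior mean and, as the width N grows, should approach the corresponding NNGP prediction. This is a minimal illustrative sketch, not the paper's experimental code: the architecture, all hyperparameter values, and the particular noise and weight-decay scaling (a Langevin-like choice with an assumed temperature) are our own assumptions.

```python
import numpy as np
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy data: n training points in d dimensions, a scalar regression target,
# and a single test point x_star (all values illustrative).
n, d, N = 32, 4, 512
X = torch.randn(n, d)
y = torch.sin(X.sum(dim=1, keepdim=True))
x_star = torch.randn(1, d)

# A width-N fully connected network; the NNSP/NNGP statements concern N -> infinity.
model = nn.Sequential(nn.Linear(d, N), nn.ReLU(), nn.Linear(N, 1))

lr, steps, burn_in = 1e-3, 5000, 2500        # assumed values
weight_decay = 1e-4                          # assumed value
temperature = 1e-2                           # assumed value
noise_std = np.sqrt(2.0 * lr * temperature)  # Langevin-like noise scale (assumption)

outputs = []
for t in range(steps):
    loss = 0.5 * ((model(X) - y) ** 2).sum()
    model.zero_grad()
    loss.backward()
    with torch.no_grad():
        for p in model.parameters():
            # noisy gradient step: gradient + weight-decay drift + Gaussian noise
            p -= lr * (p.grad + weight_decay * p) + noise_std * torch.randn_like(p)
        if t >= burn_in:
            # collect samples of the test-point output along the trajectory
            outputs.append(model(x_star).item())

print("time-averaged output at x_star:", np.mean(outputs))
```

Comparing such time-averaged outputs at several widths against the GP prediction, and against the GP prediction plus a 1/N correction, is the kind of numerical check the FWC results above refer to; the precise kernels and correction formulae are derived in the later sections of the paper.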

1.1. RELATED WORK

The idea of leveraging the dynamics of the gradient descent algorithm to approximate Bayesian inference has been considered in various works (Welling & Teh, 2011; Mandt et al., 2017; Teh et al., 2016; Maddox et al., 2019; Ye et al., 2017). However, to the best of our knowledge, a correspondence with a concrete SP or a non-parametric model was not established, nor were the DNNs' outputs compared with analytical predictions. Another recent paper (Yaida, 2020) studied Bayesian inference with weakly non-Gaussian priors induced by finite-N DNNs. Unlike here, there was no attempt to establish a correspondence with trained DNNs. The formulation presented here has the conceptual advantage of representing a distribution over function space for arbitrary training and test data, rather than over specific draws of data sets. This is useful for studying the large n behavior of learning curves, where analytical insights into generalization can be gained (Cohen et al., 2019). A somewhat related line of work studied the mean field regime of shallow NNs (Mei et al., 2018; Chen et al., 2020; Tzen & Raginsky, 2020). We point out the main differences from our work: (a) The NN output is scaled differently with width. (b) In the mean field regime one is interested in the dynamics (finite t) of the distribution over the NN parameters, in the form of a PDE of the Fokker-Planck type. In contrast, in our framework we are interested in the distribution over function space at equilibrium.



Finite width corrections were studied recently in the context of the NTK correspondence by several authors. Hanin & Nica (2019) study the NTK of finite DNNs, but in a regime where the depth scales together with the width, whereas we keep the depth fixed. Dyer & Gur-Ari (2020) obtained a finite-N correction to the linear integral equation governing the evolution of the predictions on the training set. Our work differs in several aspects: (a) We describe a different correspondence, under a different training protocol, with qualitatively different behavior. (b) We derive relatively simple formulae for the outputs, which become entirely explicit at large n. (c) We account for all sources of finite-N corrections, whereas finite-N NTK randomness remained an empirical source of corrections not accounted for by Dyer & Gur-Ari (2020). (d) Our formalism differs considerably: its statistical mechanical nature enables one to import various standard tools for treating randomness and ergodicity breaking, and for taking non-perturbative effects into account. (e) We impose no smoothness limitation on the activation functions and provide FWCs at a generic data point, not just on the training set.

