CRITICAL INITIALIZATION OF WIDE AND DEEP NEURAL NETWORKS THROUGH PARTIAL JACOBIANS: GENERAL THEORY AND APPLICATIONS

Abstract

Deep neural networks are notorious for defying theoretical treatment. However, when the number of parameters in each layer tends to infinity, the network function is a Gaussian process (GP) and a quantitatively predictive description is possible. The Gaussian approximation allows one to formulate criteria for selecting hyperparameters, such as the variances of weights and biases, as well as the learning rate. These criteria rely on the notion of criticality defined for deep neural networks. In this work we describe a new practical way to diagnose criticality. We introduce partial Jacobians of a network, defined as derivatives of preactivations in layer $l$ with respect to preactivations in layer $l_0 \le l$. We derive recurrence relations for the norms of partial Jacobians and utilize these relations to analyze the criticality of deep fully connected neural networks with LayerNorm and/or residual connections. We derive and implement a simple and cheap numerical test that allows one to select the optimal initialization for a broad class of deep neural networks, including fully connected, convolutional and attention layers. Using these tools we show quantitatively that proper stacking of LayerNorm (applied to preactivations) and residual connections leads to an architecture that is critical for any initialization. Finally, we apply our methods to analyze the MLP-Mixer architecture and show that it is everywhere critical.

1. INTRODUCTION

When the number of parameters in each layer becomes large, the functional-space description of deep neural networks simplifies dramatically. The network function, $f(x)$, in this limit, is a Gaussian process (Neal, 1996; Lee et al., 2018) with a kernel, sometimes referred to as the neural network Gaussian process (NNGP) kernel (Lee et al., 2018), determined by the network architecture and hyperparameters (e.g., depth, precise choices of layers and activation functions, as well as the distribution of weights and biases). A similar line of reasoning was developed earlier for recurrent neural networks (Molgedey et al., 1992). Furthermore, for special choices of parameterization and MSE loss function, the training dynamics under gradient descent can be solved exactly in terms of the neural tangent kernel (NTK) (Jacot et al., 2018; Lee et al., 2019). A large body of work has been devoted to the calculation of the NNGP kernel and NTK for different architectures, the calculation of finite-width corrections to these quantities, and empirical investigation of the training dynamics of wide networks (Novak et al., 2018b; Xiao et al., 2018; Hron et al., 2020; Dyer & Gur-Ari, 2019; Andreassen & Dyer, 2020; Lewkowycz & Gur-Ari, 2020; Aitken & Gur-Ari, 2020; Geiger et al., 2020; Hanin, 2021; Roberts et al., 2022; Yaida, 2020; Shankar et al., 2020; Arora et al., 2019b;a; Lee et al., 2020; Yang et al., 2018; Yang & Hu, 2021; Yang, 2019b;a; Matthews et al., 2018; Garriga-Alonso et al., 2018; Allen-Zhu et al., 2019; Tsuchida et al., 2021; Martens et al., 2021). One important result that arose from these works is that the network architecture determines the most appropriate initialization of the weights and biases (Poole et al., 2016; Schoenholz et al., 2016; Lee et al., 2018).
To state this result, we consider networks with or without LayerNorm (Ba et al., 2016) and residual connections (He et al., 2016), for which the preactivations can be defined as follows:

$h^{l+1}_i(x) = \sum_{j=1}^{N_l} w^{l+1}_{ij}\, \phi(\tilde h^l_j(x)) + b^{l+1}_i + \mu\, h^l_i(x)$ ,   (1)

where $\tilde h^l_j = \mathrm{LayerNorm}(h^l_j)$ and the parameter $\mu$ controls the strength of the residual connections. For the input layer, $h^1_i(x) = \sum_{j=1}^{N_0} w^1_{ij} x_j + b^1_i$. In the $(l+1)$-th layer, the weights $w^{l+1} \in \mathbb{R}^{N_{l+1} \times N_l}$ and biases $b^{l+1} \in \mathbb{R}^{N_{l+1}}$ are drawn from the normal distributions $\mathcal{N}(0, \sigma_w^2/N_l)$ and $\mathcal{N}(0, \sigma_b^2)$, respectively. The hyperparameters $\sigma_w$ and $\sigma_b$ need to be tuned. $\phi(\cdot)$ is the activation function and $x \in \mathbb{R}^{N_0}$ is the input. For the results discussed in this work, $x$ can be sampled from either a realistic (i.e., highly correlated) dataset or a high-entropy distribution. For a network of depth $L$, the network function is given by $f(x) = h^L(x)$. Different network architectures and activation functions $\phi$ lead to different "optimal" choices of $(\sigma_w, \sigma_b)$. The optimal choice can be understood, in the language of statistical mechanics, as a critical point (or manifold) in the $\sigma_b$-$\sigma_w$ plane. The notion of criticality becomes sharp as the network depth $L$ becomes large. Criticality ensures that both the NNGP kernel and the norm of the gradients remain $O(L^0)$ as the network gets deeper (Roberts et al., 2022). Very deep networks will not train unless initialized critically, since otherwise the gradients explode or vanish exponentially. Note, however, that high trainability does not by itself imply that the trained model performs well (in test accuracy) after training.
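The forward pass defined by (1) can be sketched numerically as follows. This is a minimal NumPy illustration, not code from the paper; the function names (`layer_norm`, `forward`) and the particular hyperparameter values are ours.

```python
import numpy as np

def layer_norm(h, eps=1e-5):
    # Normalize preactivations to zero mean and unit variance.
    return (h - h.mean(axis=-1, keepdims=True)) / np.sqrt(h.var(axis=-1, keepdims=True) + eps)

def forward(x, depth, width, sigma_w, sigma_b, mu, rng, phi=np.tanh):
    """One random draw of the network in Eq. (1):
    h^{l+1} = W^{l+1} phi(LayerNorm(h^l)) + b^{l+1} + mu * h^l,
    with W ~ N(0, sigma_w^2 / N_l) and b ~ N(0, sigma_b^2)."""
    n0 = x.shape[-1]
    # Input layer: h^1 = W^1 x + b^1 (no LayerNorm, no residual branch).
    W = rng.normal(0.0, sigma_w / np.sqrt(n0), size=(width, n0))
    b = rng.normal(0.0, sigma_b, size=width)
    h = x @ W.T + b
    for _ in range(depth - 1):
        W = rng.normal(0.0, sigma_w / np.sqrt(width), size=(width, width))
        b = rng.normal(0.0, sigma_b, size=width)
        h = phi(layer_norm(h)) @ W.T + b + mu * h
    return h

rng = np.random.default_rng(0)
x = rng.normal(size=64)
h_out = forward(x, depth=10, width=512, sigma_w=1.5, sigma_b=0.1, mu=0.5, rng=rng)
```

Setting `mu=0.0` and replacing `layer_norm` with the identity recovers the plain MLP; $\mu > 0$ turns on the residual branch.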

1.1. RESULTS

Here we focus on the two main results of this work: (i) an empirical method to check the criticality of a neural network, and (ii) an architecture based on layer normalization and residual connections that is critical for any initialization. First we introduce the notion of a partial Jacobian.

Definition 1.1. Let $h^l_i(x)$ be the preactivations of a neural network $f(x)$. The partial Jacobian $J^{l_0,l}_{ij}$ is defined as the derivative of the preactivations at layer $l$ with respect to the preactivations at layer $l_0 \le l$:

$J^{l_0,l}_{ij}(x) = \dfrac{\partial h^l_j(x)}{\partial h^{l_0}_i(x)}$ .

The partial Jacobian is a random matrix with vanishing mean at initialization. We introduce a deterministic measure of the magnitude of $J^{l_0,l}_{ij}$: its squared Frobenius norm, averaged over parameter initializations.

Definition 1.2. Let $J^{l_0,l}_{ij}$ be a partial Jacobian of a neural network $f(x)$. The averaged partial Jacobian norm (APJN) is defined as

$\mathcal{J}^{l_0,l}(x) \equiv \mathbb{E}_\theta \left[ \dfrac{1}{N_l} \sum_{j=1}^{N_l} \sum_{i=1}^{N_{l_0}} \left( \dfrac{\partial h^l_j(x)}{\partial h^{l_0}_i(x)} \right)^2 \right]$ ,

where $\mathbb{E}_\theta$ indicates averaging over parameter initializations.

In what follows, we show that criticality, as studied previously in the literature, occurs when the APJN either remains finite or varies algebraically as $l$ becomes large. To prove this we derive a recurrence relation for $\mathcal{J}^{l_0,l}(x)$ in the limit $N_l \to \infty$ and analyze it at large depth. The algebraic behaviour of the APJN with depth is characterized by an architecture-dependent critical exponent, $\zeta$, so that $\mathcal{J}^{l_0,l}(x) \approx l^{-\zeta}$. Such behaviour is familiar from statistical mechanics when a system is tuned to a critical point (Cardy, 1996). Away from criticality, there are two phases: ordered and chaotic. In the ordered phase the APJN vanishes exponentially with depth, whereas in the chaotic phase the APJN grows exponentially:

$\mathcal{J}^{l_0,l} \approx c_{l_0}\, e^{\pm l/\xi}$ .   (4)

Here $\xi$ is the correlation length. It characterizes how fast gradients explode or vanish.

Theorem 1.3 (Main result). Let $f(x)$ be a deep MLP network with a Lipschitz continuous activation $\phi(\cdot)$.
Assume that LayerNorm is applied to the preactivations and that there are residual connections with strength $\mu$ acting according to (1). In the limit $N_l \to \infty$ the correlation length is bounded from below for $\sigma_b^2 < \infty$:

$\xi \ge \dfrac{1}{\left| \log\left( (1-\mu^2)\,\dfrac{A}{B} + \mu^2 \right) \right|}$ ,
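The APJN of Definition 1.2 can be estimated by Monte Carlo over initializations, propagating the Jacobian through the network by the chain rule. The sketch below is ours, not the paper's implementation, and for brevity it treats only the plain MLP case ($\mu = 0$, no LayerNorm) with $\phi = \tanh$, where the layer-$l$ Jacobian factor is $W^{l+1}\,\mathrm{diag}(\phi'(h^l))$.

```python
import numpy as np

def apjn(depth, width, sigma_w, sigma_b, n_samples=10, seed=0):
    """Monte-Carlo estimate of the APJN J^{1,L} for a plain tanh MLP
    (mu = 0, no LayerNorm): E_theta[ ||dh^L / dh^1||_F^2 / N_L ]."""
    rng = np.random.default_rng(seed)
    vals = []
    for _ in range(n_samples):
        x = rng.normal(size=width)
        W = rng.normal(0.0, sigma_w / np.sqrt(width), size=(width, width))
        b = rng.normal(0.0, sigma_b, size=width)
        h = W @ x + b                       # h^1
        J = np.eye(width)                   # dh^1 / dh^1
        for _ in range(depth - 1):
            W = rng.normal(0.0, sigma_w / np.sqrt(width), size=(width, width))
            b = rng.normal(0.0, sigma_b, size=width)
            dphi = 1.0 - np.tanh(h) ** 2    # phi'(h^l) for phi = tanh
            J = (W * dphi[None, :]) @ J     # chain rule: W diag(phi'(h^l)) J
            h = W @ np.tanh(h) + b
        vals.append((J ** 2).sum() / width)
    return float(np.mean(vals))

# In the ordered phase (here sigma_w = 0.5 < 1 for tanh at sigma_b = 0),
# the APJN should decay exponentially with depth.
shallow = apjn(depth=5, width=128, sigma_w=0.5, sigma_b=0.0)
deep = apjn(depth=20, width=128, sigma_w=0.5, sigma_b=0.0)
```

Sweeping $(\sigma_w, \sigma_b)$ and checking whether this estimate decays, grows, or stays $O(1)$ with depth is the spirit of the numerical test described above; the general case would additionally propagate the LayerNorm Jacobian and the residual term $\mu I$.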

