ON THE NEURAL TANGENT KERNEL OF EQUILIBRIUM MODELS

Abstract

This work studies the neural tangent kernel (NTK) of the deep equilibrium (DEQ) model, a practical "infinite-depth" architecture that directly computes the infinite-depth limit of a weight-tied network via root-finding. Although the NTK of a fully-connected neural network can be stochastic when its width and depth tend to infinity simultaneously, we show that, in contrast, under mild conditions a DEQ model retains a deterministic NTK even when its width and depth go to infinity at the same time. Moreover, this deterministic NTK can be computed efficiently via root-finding.

1. INTRODUCTION

Implicit models form a new class of machine learning models: instead of stacking explicit "layers", they output z such that g(x, z) = 0, where g can be a fixed-point equation (Bai et al., 2019), a differential equation (Chen et al., 2018b), or an optimization problem (Gould et al., 2019). This work focuses on deep equilibrium (DEQ) models, a class of models that effectively represent an "infinite-depth" weight-tied network with input injection. Specifically, let f_θ be a network parameterized by θ and let x be an input injection; a DEQ finds z* such that f_θ(z*, x) = z*, and uses z* as the input for downstream tasks. One interesting question to ask is: what does a DEQ become if its width also goes to infinity? It is well known that under certain random initializations, neural networks of various structures converge to Gaussian processes as their widths go to infinity (Neal, 1996; Lee et al., 2017; Yang, 2019; Matthews et al., 2018; Novak et al., 2018; Garriga-Alonso et al., 2018). Recent advances in deep learning theory have also shown that in the infinite-width limit, with proper initialization (the NTK initialization), training the network f_θ with gradient descent is equivalent to solving kernel regression with respect to the neural tangent kernel (NTK) (Arora et al., 2019; Jacot et al., 2018; Yang, 2019; Huang et al., 2020). These kernel regimes provide important insights into how neural networks work. However, the infinite-depth (denote depth by d) regime introduces several caveats. Since the NTK involves the infinite-width (denote width by n) limit, a natural question arises: how should we let n, d → ∞? Hanin & Nica (2019) proved that as long as d/n ∈ (0, ∞), the NTK of a vanilla fully-connected neural network (FCNN) becomes stochastic. On the other hand, if we first take n → ∞ and then d → ∞ [1], Jacot et al. (2019) showed that the NTK of an FCNN converges either to a constant (freeze) or to the Kronecker delta (chaos).
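To make the DEQ forward pass concrete, here is a minimal numpy sketch (our own illustrative toy, not code from any of the cited works): a single layer f(z, x) = tanh(W z + U x + b), rescaled to be contractive so that a unique fixed point exists, solved by plain forward iteration. All sizes and the contraction rescaling are assumptions for the example.

```python
import numpy as np

# Toy DEQ "layer": z = tanh(W z + U x + b), solved for its fixed point z*
# by forward iteration. W is rescaled so the map is a contraction, which
# guarantees a unique fixed point and convergence of the iteration.
rng = np.random.default_rng(0)
n = 64                                   # hidden width (illustrative)
W = rng.standard_normal((n, n)) / np.sqrt(n)
W *= 0.5 / np.linalg.norm(W, 2)          # spectral norm 0.5 => contraction
U = rng.standard_normal((n, n)) / np.sqrt(n)
b = np.zeros(n)

def f(z, x):
    return np.tanh(W @ z + U @ x + b)

def deq_forward(x, tol=1e-10, max_iter=500):
    """Return z* with f(z*, x) = z*, found by fixed-point iteration."""
    z = np.zeros_like(x)
    for _ in range(max_iter):
        z_next = f(z, x)
        if np.linalg.norm(z_next - z) < tol:
            return z_next
        z = z_next
    return z

x = rng.standard_normal(n)
z_star = deq_forward(x)
residual = np.linalg.norm(f(z_star, x) - z_star)
print(residual)   # near zero: z* is (numerically) a fixed point
```

In practice DEQs replace this naive iteration with faster root-finders (e.g. Broyden's method), but the object computed is the same z*.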
In this work, we prove that with proper initialization, the NTK of the DEQ enjoys a limit-exchange property, lim_{d→∞} lim_{n→∞} Θ_n^{(d)}(x, y) = lim_{n→∞} lim_{d→∞} Θ_n^{(d)}(x, y), with high probability, where Θ_n^{(d)} denotes the empirical NTK of a neural network with d layers and n neurons per layer. Intuitively, we name the left-hand side the "DEQ-of-NTK" and the right-hand side the "NTK-of-DEQ". The NTK-of-DEQ converges to a meaningful deterministic fixed point that depends on the input in a non-trivial way, thus avoiding the freeze-vs-chaos scenario. Furthermore, analogous to DEQ models, we can compute these kernels by solving fixed-point equations, rather than iteratively applying the updates as for the traditional NTK. We evaluate our approach and demonstrate that it matches the performance of existing regularized NTK methods.

[1] The computed quantity is lim_{d→∞} lim_{n→∞} Θ_n^{(d)}(x, y).
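To illustrate the idea of obtaining a limiting kernel by solving a fixed-point equation rather than unrolling it depth by depth, the sketch below iterates the standard NNGP recursion of a weight-tied ReLU layer with input injection until it converges. The variance parameters (sw2, su2, sb2) and the ReLU arc-cosine formula are illustrative assumptions, not the paper's exact parameterization, and the full NTK adds an analogous derivative-kernel term that we omit here.

```python
import numpy as np

# Assumed toy recursion: for z_t = ReLU(W z_{t-1} + U x + b) with
# W_ij ~ N(0, sw2/n), U ~ N(0, su2/d), b ~ N(0, sb2), the NNGP kernel obeys
#   Lam_t(x, y) = sw2 * Sig_{t-1}(x, y) + su2 * <x, y>/d + sb2
#   Sig_t(x, y) = E[ReLU(u) ReLU(v)],  (u, v) ~ N(0, Lam_t),
# and the infinite-depth kernel is the fixed point Sig* of this map.

def relu_gauss_expect(kxx, kxy, kyy):
    """E[ReLU(u)ReLU(v)] for (u,v) ~ N(0, [[kxx,kxy],[kxy,kyy]]) (arc-cosine kernel)."""
    c = np.clip(kxy / np.sqrt(kxx * kyy), -1.0, 1.0)
    theta = np.arccos(c)
    return np.sqrt(kxx * kyy) * (np.sin(theta) + (np.pi - theta) * np.cos(theta)) / (2 * np.pi)

def deq_nngp(x, y, sw2=1.0, su2=1.0, sb2=0.1, tol=1e-12, max_iter=10_000):
    """Fixed point of the kernel recursion for the input pair (x, y)."""
    d = len(x)
    cxx, cxy, cyy = x @ x / d, x @ y / d, y @ y / d
    sxx = sxy = syy = 0.0
    for _ in range(max_iter):
        lxx = sw2 * sxx + su2 * cxx + sb2
        lxy = sw2 * sxy + su2 * cxy + sb2
        lyy = sw2 * syy + su2 * cyy + sb2
        nxx = lxx / 2.0                   # E[ReLU(u)^2] = Lam/2
        nyy = lyy / 2.0
        nxy = relu_gauss_expect(lxx, lxy, lyy)
        if max(abs(nxx - sxx), abs(nxy - sxy), abs(nyy - syy)) < tol:
            break
        sxx, sxy, syy = nxx, nxy, nyy
    return sxx, sxy, syy

rng = np.random.default_rng(0)
x, y = rng.standard_normal(8), rng.standard_normal(8)
kxx, kxy, kyy = deq_nngp(x, y)
print(kxx, kxy, kyy)   # the limiting kernel depends non-trivially on x and y
```

The input-injection term su2 * <x, y>/d re-enters the recursion at every step, which is exactly what keeps the fixed point input-dependent instead of freezing to a constant.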

2. BACKGROUND AND PRELIMINARIES

A vanilla FCNN has the form g^{(t)} = σ(W^{(t)} g^{(t−1)} + b^{(t)}) for the t-th layer, and in principle t can be as large as one wants. A weight-tied FCNN with input injection (FCNN-IJ) ties the weights across layers and makes the bias term depend on the original input, taking the form z^{(t)} := f(z^{(t−1)}, x) = σ(W z^{(t−1)} + U x + b). Bai et al. (2019) proposed the DEQ model, which is equivalent to running an infinite-depth FCNN-IJ but is computed more cleverly. The forward pass of a DEQ is done by solving f(z*, x) = z*. For a stable system, this is equivalent to computing lim_{t→∞} f^{(t)}(z^{(0)}, x), where f^{(t)} denotes the t-fold composition of f. The backward pass differentiates directly through the fixed point z* via the implicit function theorem, thus avoiding storing the Jacobian of each layer. This method traces back to some of the original work on recurrent backpropagation (Almeida, 1990; Pineda, 1988), but with specific emphasis on: 1) computing the fixed point directly via root-finding rather than forward iteration; and 2) incorporating elements from modern deep networks in the single "layer", such as self-attention transformers (Bai et al., 2019), multi-scale convolutions (Bai et al., 2020), etc. DEQ models achieve nearly state-of-the-art performance on several large-scale tasks, including Cityscapes semantic segmentation and ImageNet classification, while requiring only constant memory. Although a general DEQ model is not always guaranteed to find a stable fixed point, with careful parameterization and a suitable update method, monotone operator DEQs ensure the existence of a unique stable fixed point (Winston & Kolter, 2020).

The study of large-width limits of neural networks dates back to Neal (1996), who first discovered that a single-layer network with randomly initialized parameters becomes a Gaussian process (GP) in the large-width limit.
This connection between neural networks and GPs was later extended to multiple layers (Lee et al., 2017; Matthews et al., 2018) and various other architectures (Yang, 2019; Novak et al., 2018; Garriga-Alonso et al., 2018). The networks studied in this line of work are randomly initialized, and the GP kernels they induce are often referred to as the NNGP. A closely related yet orthogonal line of work to ours is the mean-field theory of neural networks, which studies the relation between depth and large-width networks (hence a GP kernel in the limit) at initialization. Poole et al. (2016) and Schoenholz et al. (2016) showed that at initialization, the correlations between all inputs on an infinitely wide network become either perfectly correlated (order) or decorrelated (chaos) as depth increases. They suggested initializing the neural network on the "edge of chaos" to make sure that signals can propagate deep enough in the forward direction and that the gradient neither vanishes nor explodes during backpropagation (Raghu et al., 2017; Schoenholz et al., 2016). These mean-field behaviors were later proven for various other structures such as RNNs, CNNs, and NTKs as well (Chen et al., 2018a; Xiao et al., 2018; Gilboa et al., 2019; Hayou et al., 2019). We emphasize that despite the similar appearance, our setting avoids the order-vs-chaos regime completely by adding input injection. The injection guarantees that the converged NTK depends non-trivially on the inputs, as we will see later in the experiments.

While previous results hold either only at initialization or for networks with only the last layer trained, Jacot et al. (2018) proved that the analogous limiting behavior holds for fully trained networks as well. They showed that the kernel induced by a fully trained infinite-width network is

Θ(x, y) = E_{θ∼N}[⟨∂f(θ, x)/∂θ, ∂f(θ, y)/∂θ⟩],

where N denotes the Gaussian distribution over parameters at initialization. They also gave a recursive formulation of the NTK for FCNNs.
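For a one-hidden-layer ReLU network under the NTK parameterization, the expectation above has a well-known closed form (the arc-cosine kernels), and the empirical kernel at a random initialization concentrates around it as the width grows. The following sketch (illustrative, with assumed normalizations) checks this numerically by computing the parameter gradients of f(x) = v · relu(W x)/√n by hand.

```python
import numpy as np

# Empirical vs. analytic NTK for f(x) = v . relu(W x) / sqrt(n),
# W_ij, v_i ~ N(0, 1) (assumed NTK-style parameterization):
#   Theta_n(x, y) = <df/dtheta (x), df/dtheta (y)>
#                 = relu(Wx).relu(Wy)/n + (x.y) * sum_i v_i^2 1[hx_i>0] 1[hy_i>0]/n
# which concentrates around Sigma_1(x, y) + (x.y) * Sigma_1_dot(x, y).
rng = np.random.default_rng(0)
d, n = 4, 200_000
x = np.array([1.0, 0.0, 0.0, 0.0])
y = np.array([0.6, 0.8, 0.0, 0.0])

W = rng.standard_normal((n, d))
v = rng.standard_normal(n)

def empirical_ntk(x, y):
    hx, hy = W @ x, W @ y
    term_v = np.maximum(hx, 0) @ np.maximum(hy, 0) / n        # grad wrt v
    term_W = (x @ y) * np.sum(v**2 * (hx > 0) * (hy > 0)) / n  # grad wrt W
    return term_v + term_W

def analytic_ntk(x, y):
    nx, ny = np.linalg.norm(x), np.linalg.norm(y)
    theta = np.arccos(np.clip(x @ y / (nx * ny), -1.0, 1.0))
    sigma1 = nx * ny * (np.sin(theta) + (np.pi - theta) * np.cos(theta)) / (2 * np.pi)
    sigma1_dot = (np.pi - theta) / (2 * np.pi)
    return sigma1 + (x @ y) * sigma1_dot

emp, ana = empirical_ntk(x, y), analytic_ntk(x, y)
rel_err = abs(emp - ana) / abs(ana)
print(emp, ana, rel_err)   # empirical kernel is close to the analytic limit
```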
One may ask what happens if both the width and the depth of a fully trained network go to infinity. This question requires careful formulation, as one must consider the order of the two limits: Hanin & Nica (2019) proved that the width and depth cannot tend to infinity simultaneously and still yield a deterministic NTK, suggesting that one cannot always swap the two limits. An interesting example is that Huang et al. (2020) showed that the infinite-depth limit of the ResNet NTK is deterministic, but if we let the width and depth go to infinity at the same rate, the ResNet behaves in a log-Gaussian fashion (Li et al., 2021). Meanwhile, the infinite-depth limit of the NTK does not always have favorable properties. It turns out that the vanilla FCNN does not have a meaningful convergence: its NTK tends either to a constant kernel or to the Kronecker delta kernel (Jacot et al., 2019).




Arora et al. (2019), Alemohammad et al. (2020), and Yang (2020) later provided NTK formulations for convolutional networks, recurrent networks, and other architectures.

