ON THE NEURAL TANGENT KERNEL OF EQUILIBRIUM MODELS

Abstract

Existing analyses of the neural tangent kernel (NTK) for infinite-depth networks show that the kernel typically becomes degenerate as the number of layers grows. This raises the question of how to apply such methods to practical "infinite depth" architectures such as the recently proposed deep equilibrium (DEQ) model, which directly computes the infinite-depth limit of a weight-tied network via root-finding. In this work, we show that because of the input injection component of these networks, DEQ models have non-degenerate NTKs even in the infinite-depth limit. Furthermore, we show that these kernels themselves can be computed by a root-finding problem analogous to that of traditional DEQs, and we highlight methods for computing the NTK for both fully-connected and convolutional variants. We evaluate these models empirically, showing they match or improve upon the performance of existing regularized NTK methods.

1. INTRODUCTION

Recent works empirically observe that as the depth of a weight-tied input-injected network increases, its output tends to converge to a fixed point. Motivated by this phenomenon, DEQ models were proposed to represent an effectively "infinite depth" network via root-finding. A natural question to ask is: what do DEQs become if their widths also go to infinity? It is well known that under certain random initializations, neural networks of various structures converge to Gaussian processes as their widths go to infinity (Neal, 1996; Lee et al., 2017; Yang, 2019; Matthews et al., 2018; Novak et al., 2018; Garriga-Alonso et al., 2018). Recent advances in deep learning theory have also shown that in the infinite width limit, with proper initialization (the NTK initialization), training the network f_θ with gradient descent is equivalent to solving kernel regression with respect to the neural tangent kernel (Arora et al., 2019; Jacot et al., 2018; Yang, 2019; Huang et al., 2020). However, as the depth goes to infinity, Jacot et al. (2019) showed that the NTKs of fully-connected neural networks (FCNN) converge either to a constant (freeze) or to the Kronecker delta (chaos). In this work, we show that with input injection, the DEQ-NTKs converge to meaningful fixed points that depend on the input in a non-trivial way, thus avoiding both freeze and chaos. Furthermore, analogous to DEQ models, we can compute these kernels by solving a fixed-point equation, rather than simply iteratively applying the updates associated with the traditional NTK. Moreover, such derivations carry over to other structures, such as convolutional DEQs (CDEQ), as well. We evaluate the approach and demonstrate that it typically matches or improves upon the performance of existing regularized NTK methods.

2. BACKGROUND AND PRELIMINARIES

Bai et al. (2019) proposed the DEQ model, which is equivalent to running an infinite-depth network with tied weights and input injection.
These methods trace back to some of the original work on recurrent backpropagation (Almeida, 1990; Pineda, 1988), but with specific emphasis on: 1) computing the fixed point directly via root-finding rather than forward iteration; and 2) incorporating elements of modern deep networks into the single "layer", such as self-attention transformers (Bai et al., 2019), multi-scale convolutions (Bai et al., 2020), etc. The DEQ algorithm finds the infinite-depth fixed point using quasi-Newton root-finding methods, and then backpropagates using implicit differentiation without storing the derivatives of the intermediate layers, thus achieving constant memory complexity. Furthermore, although a traditional DEQ model is not always guaranteed to find a stable fixed point, with careful parameterization and update methods, monotone operator DEQs can ensure the existence of a unique stable fixed point (Winston & Kolter, 2020). On the side of connecting neural networks to kernel methods, Neal (1996) first discovered that a single-layer network with randomly initialized parameters becomes a Gaussian process (GP) in the large width limit. This connection between neural networks and GPs was later extended to multiple layers (Lee et al., 2017; Matthews et al., 2018) and various other architectures (Yang, 2019; Novak et al., 2018; Garriga-Alonso et al., 2018). The networks studied in this line of work are randomly initialized, and one can imagine these networks as having fixed parameters throughout the training process, except for the last classification layer. Following the naming convention of Arora et al. (2019), we call these networks weakly-trained, while networks where every layer is updated are called fully-trained.
Weakly-trained nets induce the kernel Θ(x, y) = E_{θ∼N}[f(θ, x) · f(θ, y)], where x, y ∈ R^d are two samples, θ represents the parameters of the network, N is the initialization distribution (often Gaussian) over θ, and f(θ, ·) ∈ R is the output of the network. One related topic in studying the relation between the Gaussian process kernel and depth is mean-field theory. Poole et al. (2016) and Schoenholz et al. (2016) showed that the correlations between all inputs on an infinitely wide weakly-trained net become either perfectly correlated (order) or decorrelated (chaos) as depth increases, which aligns with the observation in Jacot et al. (2019). They suggested initializing the neural network on the "edge of chaos" to make sure that signals can propagate deep enough in the forward direction, and that the gradient does not vanish or explode during backpropagation (Raghu et al., 2017; Schoenholz et al., 2016). These mean-field behaviors were later proven for various other structures such as RNNs, CNNs, and NTKs as well (Chen et al., 2018a; Xiao et al., 2018; Gilboa et al., 2019; Hayou et al., 2019). We emphasize that despite the similar appearance, our setting avoids the order-vs-chaos regime completely by adding input injection. This structure guarantees that the converged nets depend non-trivially on the inputs, as we will see later in the experiments. It can be unsatisfying that the previous results only involve weakly-trained nets. Interestingly, similar limiting behavior was proven by Jacot et al. (2018) to hold for fully-trained networks as well. They showed that the kernel induced by a fully-trained infinite-width network is the following:

Θ(x, y) = E_{θ∼N}[ ⟨∂f(θ, x)/∂θ, ∂f(θ, y)/∂θ⟩ ].   (1)

They also gave a recursive formulation for the NTK of FCNNs. Arora et al. (2019) and Yang (2020) later provided formulations for the convolutional NTK and other structures. One may ask what happens if both the width and the depth go to infinity.
It turns out that the vanilla FCNN does not have a meaningful convergence: it gives either constant kernels or Kronecker delta kernels (Jacot et al., 2019). On the bright side, this assertion is not always the case for other network structures. For example, the NTK induced by ResNet has a meaningful fixed point in the large-depth limit (Huang et al., 2020). This may seem to offer one explanation for why ResNet outperforms FCNN, but unfortunately they also show that the infinite-depth ResNet NTK is no different from the depth-one ResNet NTK. This conclusion makes the significance of infinite depth questionable.

2.1. NOTATIONS

Throughout the paper, we write θ for the parameters of some network f_θ or, equivalently, f(θ, ·). We write capital letters such as W for matrices or tensors (which should be clear from context), and use [W]_i to represent the element of W indexed by i. We write lower-case letters such as x for vectors or scalars. For a ∈ Z₊, let [a] = {1, …, a}. Denote by σ(x) = √2 max(0, x) the normalized ReLU and by σ̇ its derivative (which only needs to be well-defined almost everywhere). A symbol σ²_a with a subscript is always used to denote the variance of some distribution. We write N(µ, Σ) for the Gaussian distribution with mean µ ∈ R^d and covariance matrix Σ ∈ R^{d×d}. We let S^{d−1} be the unit sphere embedded in R^d.
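As a quick numerical sanity check (ours, not part of the paper), the following sketch verifies that the normalized ReLU indeed satisfies E_{x∼N(0,1)}[σ(x)²] = 1:

```python
import numpy as np

# Monte Carlo check (our illustration) that the normalized ReLU
# sigma(x) = sqrt(2) * max(0, x) satisfies E_{x~N(0,1)}[sigma(x)^2] = 1.
rng = np.random.default_rng(0)
x = rng.standard_normal(1_000_000)
second_moment = np.mean((np.sqrt(2.0) * np.maximum(0.0, x)) ** 2)
print(second_moment)  # close to 1.0
```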

3. DEQ-NTK OF FULLY-CONNECTED NEURAL NETWORKS

In this section, we show how to derive the NTK of the fully-connected DEQ (DEQ-NTK). From now on, we refer to fully-connected DEQs simply as DEQs. Recall that DEQs are equivalent to infinitely deep fully-connected neural nets with input injection (FCNN-IJ), and one can either solve the fixed point exactly using root-finding (up to machine epsilon and root-finding algorithm accuracy) or approximate the DEQ by finite-depth forward iterations. In section 3.1, we derive the NTK of the approximated DEQ using finite-depth iteration, and in section 3.2 we demonstrate how to obtain the exact convergence point of the DEQ-NTK. The details of this section can be found in appendix A.

3.1. FINITE DEPTH ITERATION OF DEQ-NTK

Let d be the input dimension, x, y ∈ R^d be a pair of inputs, and N_h be the width of the h-th layer for h ∈ [L+1], with N_0 = d and N_{L+1} = 1. Define the FCNN-IJ with L hidden layers as follows:

f^(h)_θ(x) = √(σ_W²/N_h) W^(h) g^(h−1)(x) + √(σ_U²/N_h) U^(h) x + √(σ_b²/N_h) b^(h),   h ∈ [L+1],
g^(h)(x) = σ(f^(h)_θ(x)),   h ∈ [L],

where W^(h) ∈ R^{N_h×N_{h−1}} and U^(h) ∈ R^{N_h×d} are the internal weights, and b^(h) ∈ R^{N_h} are the bias terms. These parameters are chosen using the NTK initialization. We pick σ_W, σ_U, σ_b ∈ R arbitrarily in this section.

NTK initialization. We randomly initialize every entry of every W, U, b from N(0, 1). Without loss of generality (WLOG), we assume the width of the hidden layers N_h = N is the same across different layers.

We remind the reader to distinguish the FCNN-IJ from a recurrent neural network (RNN): our model injects the original input into each layer, whereas an RNN has a sequence of inputs (x_1, …, x_T) and injects x_t into the t-th layer. Here is a crucial distinction between finite-width DEQs and infinite-width DEQs:

Remark 1. In the finite-width regime, one typically has to assume the DEQ has tied weights, that is, W^(1) = … = W^(L+1); otherwise it is unlikely that the network converges at all. In fact, one needs to be very careful with the parametrization of the weights to guarantee that the fixed point is unique and stable. This is not the case in the infinite-width regime. As we shall see, even with distinct weights in each layer, the convergence of DEQ-NTKs depends only on σ_W², σ_U², σ_b², and the nonlinearity σ. Assuming untied weights makes the analysis easier, but the same argument can be made rigorous for tied weights as well; see Yang (2019; 2020). Our main theorem is the following:

Theorem 1.
Recursively define the following quantities for h ∈ [L]:

Σ^(0)(x, y) = x^⊤ y,   (2)

Λ^(h)(x, y) = [ Σ^(h−1)(x, x)  Σ^(h−1)(x, y) ; Σ^(h−1)(y, x)  Σ^(h−1)(y, y) ] ∈ R^{2×2},   (3)

Σ^(h)(x, y) = σ_W² E_{(u,v)∼N(0,Λ^(h))}[σ(u)σ(v)] + σ_U² x^⊤ y + σ_b²,   (4)

Σ̇^(h)(x, y) = σ_W² E_{(u,v)∼N(0,Λ^(h))}[σ̇(u)σ̇(v)].   (5)

Then the L-depth iteration of the DEQ-NTK can be expressed as

Θ^(L)(x, y) = Σ_{h=1}^{L+1} ( Σ^(h−1)(x, y) · Π_{h′=h}^{L+1} Σ̇^(h′)(x, y) ),   (6)

where by convention we set Σ̇^(L+1)(x, y) = 1 for the L-depth iteration.

Proof sketch. The first step is to show that at each layer h ∈ [L], the representation f^(h)_θ(x) is associated with a Gaussian process with covariance given by eq. (4) as N → ∞. Then, using the characterization in eq. (1), calculate the NTK as

Θ^(L)(x, y) = E_θ[⟨∂f(θ, x)/∂θ, ∂f(θ, y)/∂θ⟩]
 = E_θ[⟨∂f(θ, x)/∂W, ∂f(θ, y)/∂W⟩] + E_θ[⟨∂f(θ, x)/∂U, ∂f(θ, y)/∂U⟩] + E_θ[⟨∂f(θ, x)/∂b, ∂f(θ, y)/∂b⟩].

Calculating each term using the chain rule, we obtain eq. (6).
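The L-depth iteration can be sketched numerically. The snippet below (a minimal sketch, not the authors' code; the function name and default variances are ours) specializes eqs. (2)-(6) to the normalized ReLU, for which the Gaussian expectations have known closed forms (Daniely et al., 2016):

```python
import numpy as np

# Minimal sketch of the L-depth DEQ-NTK iteration (Theorem 1), specialized
# to the normalized ReLU so the Gaussian expectations have closed forms
# (Daniely et al., 2016). Function name and defaults are our assumptions.
def deq_ntk_finite(x, y, L, sw2=0.25, su2=0.25, sb2=0.5):
    x = x / np.linalg.norm(x)           # place inputs on the unit sphere
    y = y / np.linalg.norm(y)
    xy = float(x @ y)
    sigma = xy                          # Sigma^(0)(x, y) = x^T y
    theta = sigma                       # Theta^(0)(x, y) = Sigma^(0)(x, y)
    for _ in range(L):
        rho = np.clip(sigma, -1.0, 1.0)          # rho = Sigma^(h-1)(x, y)
        e_ss = (np.sqrt(1 - rho**2) + (np.pi - np.arccos(rho)) * rho) / np.pi
        e_dd = (np.pi - np.arccos(rho)) / np.pi  # E[sigma'(u) sigma'(v)]
        sigma = sw2 * e_ss + su2 * xy + sb2      # eq. (4)
        sigma_dot = sw2 * e_dd                   # eq. (5)
        theta = sigma_dot * theta + sigma        # recursive form of eq. (6)
    return theta
```

As L grows, the returned value stabilizes, consistent with the fixed point derived in section 3.2.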

3.2. FIXED POINT OF DEQ-NTK

Based on eq. (6), we are now ready to answer what the fixed point of Θ^(L) is. By convention, we assume the two samples x, y ∈ S^{d−1}, and we require the parameters σ_W², σ_U², σ_b² to obey the DEQ-NTK initialization.

DEQ-NTK initialization. Let every entry of every W, U, b follow the NTK initialization described in section 3.1, with the additional requirement σ_W² + σ_U² + σ_b² = 1. Let the nonlinear activation function σ be the normalized ReLU: σ(x) = √2 max(0, x).

Definition 3.1 (Normalized activation). We call an activation function σ : R → R normalized if E_{x∼N(0,1)}[σ(x)²] = 1.

Using normalized activations along with the DEQ-NTK initialization, we can derive the main convergence theorem:

Theorem 2. Use the same notation and settings as in theorem 1. With input data x, y ∈ S^{d−1} and parameters σ_W², σ_U², σ_b² following the DEQ-NTK initialization, the fixed point of the DEQ-NTK is

Θ*(x, y) ≜ lim_{L→∞} Θ^(L)(x, y) = Σ*(x, y) / (1 − Σ̇*(x, y)),   (7)

where Σ*(x, y) ≜ ρ* is the root of R_σ(ρ) − ρ, with

R_σ(ρ) ≜ σ_W² ( √(1 − ρ²) + (π − cos⁻¹ ρ) ρ ) / π + σ_U² x^⊤ y + σ_b²,   (8)

and

Σ̇*(x, y) ≜ lim_{h→∞} Σ̇^(h)(x, y) = σ_W² (π − cos⁻¹(ρ*)) / π.   (9)

Proof. Since x ∈ S^{d−1}, σ is a normalized activation, and the DEQ-NTK initialization holds, one can easily verify by induction that for all h ∈ [L]:

Σ^(h)(x, x) = σ_W² E_{u∼N(0,1)}[σ(u)²] + σ_U² x^⊤ x + σ_b² = 1.

This indicates that in eq. (3), the covariance matrix has the special structure Λ^(h)(x, y) = [1 ρ; ρ 1], where ρ = Σ^(h−1)(x, y) depends on h, x, y. For simplicity, we omit the h, x, y in Λ^(h)(x, y). As shown in Daniely et al. (2016):

E_{(u,v)∼N(0,Λ)}[σ(u)σ(v)] = ( √(1 − ρ²) + (π − cos⁻¹ ρ) ρ ) / π,   (10)
E_{(u,v)∼N(0,Λ)}[σ̇(u)σ̇(v)] = (π − cos⁻¹ ρ) / π.   (11)

Adding input injection and bias, we derive eq. (8) from eq. (10) and, similarly, eq. (9) from eq. (11). Notice that iterating eqs. (2) to (4) to solve for Σ^(h)(x, y) is equivalent to iterating (R_σ ∘ ⋯ ∘ R_σ)(ρ) with initial input ρ = x^⊤ y. Taking the derivative,

| dR_σ(ρ)/dρ | = | σ_W² (1 − cos⁻¹(ρ)/π) | < 1,   if σ_W² < 1 and −1 ≤ ρ < 1.

For x ≠ y we have −1 ≤ ρ ≤ c < 1 for some c (this is because we only have a finite number of inputs x, y), and by the DEQ-NTK initialization we have σ_W² < 1, so the above inequality holds. Hence R_σ(ρ) is a contraction on [0, c], and we conclude that the fixed point ρ* is attractive. By lemma 1, if σ_W² < 1 then the limit of eq. (6) exists, so we can rewrite the summation form of eq. (6) in recursive form:

Θ^(0)(x, y) = Σ^(0)(x, y),
Θ^(L+1)(x, y) = Σ̇^(L+1)(x, y) · Θ^(L)(x, y) + Σ^(L+1)(x, y),   (12)

and directly solve the fixed-point iteration:

lim_{L→∞} Θ^(L+1)(x, y) = lim_{L→∞} ( Σ̇^(L+1)(x, y) · Θ^(L)(x, y) + Σ^(L+1)(x, y) )
 ⟹ lim_{L→∞} Θ^(L)(x, y) = Σ̇*(x, y) · lim_{L→∞} Θ^(L)(x, y) + Σ*(x, y).

Solving for lim_{L→∞} Θ^(L)(x, y), we get Θ*(x, y) = Σ*(x, y) / (1 − Σ̇*(x, y)).

Remark 2. Note that our Σ*(x, y) always depends on the inputs x and y, so the information between two inputs is preserved even as the depth goes to infinity. On the contrary, as pointed out by Jacot et al. (2019), without input injection Σ^(h)(x, y) always converges to 1 as h → ∞, even if x ≠ y.
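The scalar fixed point can be computed directly by one-dimensional root-finding, as theorem 2 suggests. Below is a minimal sketch (assuming the normalized ReLU and the DEQ-NTK initialization; the function name is ours). The paper uses the modified Powell hybrid method, while this sketch substitutes SciPy's `brentq`, an interchangeable one-dimensional root-finder:

```python
import numpy as np
from scipy.optimize import brentq

# Sketch of Theorem 2: solve R_sigma(rho) - rho = 0 by 1-D root-finding,
# then assemble Theta* = Sigma* / (1 - Sigma-dot*).
def deq_ntk_fixed_point(xy, sw2=0.25, su2=0.25, sb2=0.5):
    def residual(rho):                   # R_sigma(rho) - rho, from eq. (8)
        r_sigma = sw2 * (np.sqrt(1 - rho**2)
                         + (np.pi - np.arccos(rho)) * rho) / np.pi \
                  + su2 * xy + sb2
        return r_sigma - rho
    rho_star = brentq(residual, -1.0, 1.0)                        # Sigma*(x, y)
    sigma_dot_star = sw2 * (np.pi - np.arccos(rho_star)) / np.pi  # eq. (9)
    return rho_star / (1.0 - sigma_dot_star)                      # eq. (7)
```

Since the residual is strictly decreasing when σ_W² < 1, the root is unique, and the value agrees with the limit of the finite-depth iteration of theorem 1.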

4. DEQ WITH CONVOLUTION LAYERS

In this section we show how to derive the NTKs for convolutional DEQs (CDEQ). Although in this paper only the CDEQ with the vanilla convolution structure is considered in experiments, we remark that our derivation is general enough for other CDEQ structures as well, for instance CDEQ with a global pooling layer. The details of this section can be found in appendix B. Unlike the FCNN-IJ, whose intermediate NTK representation is a real number, for convolutional neural networks (CNN) the intermediate NTK representation is a four-way tensor. In the following, we present the notation, the CNN with input injection (CNN-IJ) formulation, the CDEQ-NTK initialization, and our main theorem.

Notation. We adopt the notation of Arora et al. (2019). Let x, y ∈ R^{P×Q} be a pair of inputs, and let q ∈ Z₊ be the filter size (WLOG assumed odd). By convention, we always pad the representations (both the input layer and the hidden layers) with zeros. Denote the convolution operation as

[w * x]_{ij} = Σ_{a=−(q−1)/2}^{(q−1)/2} Σ_{b=−(q−1)/2}^{(q−1)/2} [w]_{a+(q+1)/2, b+(q+1)/2} [x]_{a+i, b+j},   for i ∈ [P], j ∈ [Q].

Denote

D_{ij,i′j′} = { (i+a, j+b, i′+a′, j′+b′) ∈ [P]×[Q]×[P]×[Q] : −(q−1)/2 ≤ a, b, a′, b′ ≤ (q−1)/2 }.

Intuitively, D_{ij,i′j′} is a q×q×q×q set of indices centered at (ij, i′j′). For any tensor T ∈ R^{P×Q×P×Q}, let [T]_{D_{ij,i′j′}} be the natural sub-tensor and let Tr(T) = Σ_{i,j} T_{ij,ij}.

Formulation of CNN-IJ. Define the CNN-IJ as follows:
• Let the input x^(0) = x ∈ R^{P×Q×C₀}, where C₀ is the number of input channels and C_h is the number of channels in layer h. Assume WLOG that C_h = C for all h ∈ [L].
• For h = 1, …, L, let the pre-activation be

x̃^(h)_(β) = Σ_{α=1}^{C_{h−1}} √(σ_W²/C_h) W^(h)_{(α),(β)} * x^(h−1)_(α) + Σ_{α=1}^{C₀} √(σ_U²/C_h) U^(h)_{(α),(β)} * x^(0)_(α),   (13)

[x^(h)_(β)]_{ij} = (1/[S]_{ij}) [σ(x̃^(h)_(β))]_{ij},   for i ∈ [P], j ∈ [Q],   (14)

where W^(h)_{(α),(β)} ∈ R^{q×q} represents the convolution filter from the α-th channel in layer h−1 to the β-th channel in layer h. Similarly, U^(h)_{(α),(β)} ∈ R^{q×q} injects the input into each convolution window. S ∈ R^{P×Q} is a normalization matrix. Let W, U, S, σ_U², σ_W² be chosen by the CDEQ-NTK initialization described below. Notice here we assume WLOG that the number of channels in the hidden layers is the same.
• The final output is defined to be f_θ(x) = Σ_{α=1}^{C_L} ⟨W^(L+1)_(α), x^(L)_(α)⟩, where each W^(L+1)_(α) ∈ R^{P×Q} is sampled from the standard Gaussian distribution.

CDEQ-NTK initialization. Let 1_q ∈ R^{q×q} and X ∈ R^{P×Q} be two all-ones matrices. Let X̄ ∈ R^{(P+2)×(Q+2)} be the output of zero-padding X. We index the rows of X̄ by {0, 1, …, P+1} and the columns by {0, 1, …, Q+1}. For position i ∈ [P], j ∈ [Q], let ([S]_{ij})² = [1_q * X̄]_{ij} in eq. (14). Let every entry of every W, U be sampled from N(0, 1), and require σ_W² + σ_U² = 1.

Using the notation defined above, we now state the CDEQ-NTK.

Theorem 3. Let x, y ∈ R^{P×Q×C₀} be such that ‖x_{ij}‖₂ = ‖y_{ij}‖₂ = 1 for i ∈ [P], j ∈ [Q]. Define the following expressions recursively (some arguments x, y are omitted in the notation), for (i, j, i′, j′) ∈ [P]×[Q]×[P]×[Q] and h ∈ [L]:

K^(0)(x, y) = Σ_{α∈[C₀]} x_(α) ⊗ y_(α),   (15)

[Σ^(0)(x, y)]_{ij,i′j′} = (1/([S]_{ij} [S]_{i′j′})) Tr( [K^(0)(x, y)]_{D_{ij,i′j′}} ),   (16)

Λ^(h)_{ij,i′j′}(x, y) = [ [Σ^(h−1)(x, x)]_{ij,ij}  [Σ^(h−1)(x, y)]_{ij,i′j′} ; [Σ^(h−1)(y, x)]_{i′j′,ij}  [Σ^(h−1)(y, y)]_{i′j′,i′j′} ] ∈ R^{2×2},   (17)

[K^(h)(x, y)]_{ij,i′j′} = σ_W² [S]_{ij} · [S]_{i′j′} E_{(u,v)∼N(0,Λ^(h)_{ij,i′j′})}[σ(u)σ(v)] + σ_U² [S]_{ij} · [S]_{i′j′} [K^(0)]_{ij,i′j′},   (18)

[K̇^(h)(x, y)]_{ij,i′j′} = σ_W² [S]_{ij} · [S]_{i′j′} E_{(u,v)∼N(0,Λ^(h)_{ij,i′j′})}[σ̇(u)σ̇(v)],   (19)

[Σ^(h)(x, y)]_{ij,i′j′} = Tr( [K^(h)(x, y)]_{D_{ij,i′j′}} ).   (20)

Define the linear operator L : R^{P×Q×P×Q} → R^{P×Q×P×Q} via [L(M)]_{ij,i′j′} = Tr( [M]_{D_{ij,i′j′}} ). Then the CDEQ-NTK can be found by solving the following linear system:

Θ*(x, y) = K̇*(x, y) ⊙ L( Θ*(x, y) ) + K*(x, y),   (21)

where K*(x, y) = lim_{L→∞} K^(L)(x, y) and K̇*(x, y) = lim_{L→∞} K̇^(L)(x, y). The limits exist if σ_W² < 1.
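The normalization matrix S from the CDEQ-NTK initialization can be computed in a few lines. The sketch below (our illustration; the function name is ours) counts, for each position, the number of non-padding entries in its q × q window:

```python
import numpy as np

# Sketch of the normalization matrix S in the CDEQ-NTK initialization:
# ([S]_ij)^2 = [1_q * X_bar]_ij, the number of non-padding positions in
# the q x q window around (i, j). Function name is our assumption.
def normalization_matrix(P, Q, q=3):
    pad = (q - 1) // 2
    X_bar = np.zeros((P + 2 * pad, Q + 2 * pad))
    X_bar[pad:pad + P, pad:pad + Q] = 1.0      # zero-padded all-ones matrix X
    S2 = np.empty((P, Q))
    for i in range(P):
        for j in range(Q):
            S2[i, j] = X_bar[i:i + q, j:j + q].sum()   # [1_q * X_bar]_ij
    return np.sqrt(S2)
```

For q = 3, interior positions get [S]_ij = 3 (all nine window entries are valid), while corner positions get [S]_ij = 2 (only four are).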

5. EXPERIMENTS

In this section, we evaluate the performance of DEQ-NTK and CDEQ-NTK on the MNIST and CIFAR-10 datasets. We also compare the performance of the finite-depth NTK and the finite-depth iteration of DEQ-NTK.

Implementation. For DEQ-NTK, in line with the theory, we normalize the dataset so that each data point has unit length. The fixed point Σ*(x, y) is solved using the modified Powell hybrid method (Powell, 1970). Notice that these root-finding problems are one-dimensional and hence can be solved quickly. For CDEQ-NTK, the input data x has dimension P × Q × C₀, and we normalize x such that ‖x_{ij}‖₂ = 1 for all i ∈ [P], j ∈ [Q]. We set q = 3 and stride 1. The fixed point Σ*(x, y) ∈ R^{P×Q×P×Q} is approximated by running 20 iterations of eq. (17), eq. (18), and eq. (20). The actual CDEQ-NTK Θ(x, y) is then calculated by solving the sparse linear system eq. (21). After obtaining the NTK matrix, we apply kernel regression (without regularization unless stated otherwise). For any label y ∈ {1, …, n}, denote its one-hot encoding by e_y. Letting 1 ∈ R^n be the all-ones vector, we train on the new encoding −0.1 · 1 + e_y; that is, we change the "1" to 0.9 and the "0"s to −0.1, as suggested by Novak et al. (2018).

Result. On MNIST, we test the performance of DEQ-NTK with σ_W² = 0.25, σ_U² = 0.25, σ_b² = 0.5 and achieve 98.6% test accuracy. The results are listed in table 1. On CIFAR-10, we trained DEQ-NTK with three different sets of random initializations.
These initializations are not fine-tuned, yet we can still see that they are comparable, or even superior, to the finite-depth NTK with carefully chosen regularization. For CDEQ-NTK, we compute the kernel matrix on 2000 training samples and test on 10000 samples. See the results in table 2. The results for neural ODE (Chen et al., 2018b) and the monotone operator DEQ are taken from Winston & Kolter (2020).
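For concreteness, the kernel regression step with the shifted one-hot encoding can be sketched as follows (a hypothetical minimal pipeline on synthetic data, not the paper's implementation):

```python
import numpy as np

# Sketch of the kernel regression step. Labels use the shifted one-hot
# encoding -0.1 * 1 + e_y, i.e. 0.9 for the true class and -0.1 elsewhere
# (Novak et al., 2018). Function name is our assumption.
def kernel_regression(K_train, K_test_train, y_train, n_classes):
    n = len(y_train)
    Y = np.full((n, n_classes), -0.1)
    Y[np.arange(n), y_train] = 0.9          # shifted one-hot targets
    alpha = np.linalg.solve(K_train, Y)     # unregularized kernel regression
    return (K_test_train @ alpha).argmax(axis=1)
```

Passing the train-train kernel block as `K_test_train` reproduces the training labels exactly, since the unregularized solve interpolates the targets.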

MNIST

We should emphasize that the calculation of the NTK requires a huge amount of computing resources, even for shallow networks. On the other hand, our method provides an efficient way to compute a special NTK with infinite depth. Typically, training a DEQ-NTK on all of CIFAR-10 takes around 400 CPU hours, and the training time of CDEQ-NTK is about half that of its finite-depth CNTK counterpart, as we only need to calculate Σ^(h), whereas the actual CNTK needs to calculate both Σ^(h) and Θ^(h).

CIFAR-10

As proven in Jacot et al. (2019), the NTK will always "freeze" in our setting. Therefore the rows of the NTK matrix become linearly dependent as the depth increases, and its kernel regression does not have a unique solution. To circumvent this issue, we add a regularization term r ∝ Θ(x, x)/n, where n is the size of the training data. Such regularization is known to guarantee uniform stability (Bousquet & Elisseeff, 2002), and it still interpolates the data in the classification sense (training accuracy is 100%).
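A sketch of this regularized solve (the proportionality constant `scale` is our assumption; the paper only states the proportionality r ∝ Θ(x, x)/n):

```python
import numpy as np

# Sketch of the regularized variant: add a ridge term r on the scale of the
# average kernel diagonal (Θ(x, x)/n) before solving, which keeps the linear
# system well-posed when deep NTKs freeze. `scale` is a hypothetical value.
def regularized_solve(K_train, Y, scale=1e-2):
    n = K_train.shape[0]
    r = scale * np.trace(K_train) / n       # r proportional to Θ(x, x)/n
    return np.linalg.solve(K_train + r * np.eye(n), Y)
```

Unlike the unregularized solve, this remains well-defined even when the kernel matrix is singular.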

Method                         Parameters                               Acc.
DEQ-NTK                        σ_W² = 0.25, σ_U² = 0.25, σ_b² = 0.5     59.08%
DEQ-NTK                        σ_W² = 0.6, σ_U² = 0.4, σ_b² = 0         59.77%
DEQ-NTK                        σ_W² = 0.8, σ_U² = 0.2, σ_b² = 0         59.43%
NTK with ZCA regularization    σ_W² = 2, σ_b² =

6. CONCLUSION

We derive NTKs for both fully-connected DEQs and convolutional DEQs, and show that they can be computed more efficiently than the finite-depth NTK and CNTK, especially when the depth is large. Moreover, the performance of DEQ-NTK and CDEQ-NTK is comparable to that of their finite-depth NTK counterparts. Our analysis shows that one can avoid the freeze and chaos phenomena in infinitely deep NTKs by using input injection. One interesting question that remains open is to further understand the role of σ_W², σ_U², σ_b² in the fixed-point computation, and how they affect the generalization of DEQ-NTKs.

A DETAILS OF SECTION 3

In this section, we give the detailed derivation of the DEQ-NTK. Two terms differ from the standard NTK: Σ^(h)(x, y), and the extra term E_θ[⟨∂f(θ, x)/∂U, ∂f(θ, y)/∂U⟩] in the kernel. Let us restate the FCNN-IJ here: let d be the input dimension, x, y ∈ R^d be a pair of inputs, and N_h be the width of the h-th hidden layer, with N_0 = d and N_{L+1} = 1. Define the FCNN-IJ with L layers as follows:

f^(h)_θ(x) = √(σ_W²/N_h) W^(h) g^(h−1)(x) + √(σ_U²/N_h) U^(h) x + √(σ_b²/N_h) b^(h),   h ∈ [L],
g^(h)(x) = σ(f^(h)_θ(x)),

where W^(h) ∈ R^{N_h×N_{h−1}} and U^(h) ∈ R^{N_h×d} are the internal weights, and b^(h) ∈ R^{N_h} are the bias terms. These parameters are chosen using the NTK initialization, and σ_W, σ_U, σ_b ∈ R are arbitrary in this section.

Proof of theorem 1. First we note that

E[ [f^(h+1)(x)]_i · [f^(h+1)(y)]_i | f^(h) ] = (σ_W²/N) Σ_{j=1}^N σ([f^(h)(x)]_j) σ([f^(h)(y)]_j) + (σ_U²/N) Σ_{j=1}^N x^⊤ y + σ_b² → Σ^(h+1)(x, y)   a.s.,

where the equality follows by expanding the original expression and using the fact that W, U, b are all independent, and the convergence follows from the strong law of large numbers. This shows how the covariance changes as depth increases with input injection. Recall the splitting:

Θ^(L)(x, y) = E_θ[⟨∂f(θ, x)/∂θ, ∂f(θ, y)/∂θ⟩]
 = E_θ[⟨∂f(θ, x)/∂W, ∂f(θ, y)/∂W⟩] + E_θ[⟨∂f(θ, x)/∂U, ∂f(θ, y)/∂U⟩] + E_θ[⟨∂f(θ, x)/∂b, ∂f(θ, y)/∂b⟩].

The following identities have been proven in many places (see, e.g., Arora et al. (2019)):

E_θ[⟨∂f(θ, x)/∂W, ∂f(θ, y)/∂W⟩] = Σ_{h=1}^{L+1} ( σ_W² E_{(u,v)∼N(0,Λ^(h))}[σ(u)σ(v)] · Π_{h′=h}^{L+1} Σ̇^(h′)(x, y) ),

E_θ[⟨∂f(θ, x)/∂b, ∂f(θ, y)/∂b⟩] = Σ_{h=1}^{L+1} ( σ_b² · Π_{h′=h}^{L+1} Σ̇^(h′)(x, y) ).

So we only need to deal with the second term E_θ[⟨∂f(θ, x)/∂U, ∂f(θ, y)/∂U⟩].
Write f = f_θ(x) and f̃ = f_θ(y). By the chain rule, we have

⟨∂f/∂U^(h), ∂f̃/∂U^(h)⟩ = ⟨(∂f/∂f^(h)) (∂f^(h)/∂U^(h)), (∂f̃/∂f̃^(h)) (∂f̃^(h)/∂U^(h))⟩
 = ⟨∂f^(h)/∂U^(h), ∂f̃^(h)/∂U^(h)⟩ · ⟨∂f/∂f^(h), ∂f̃/∂f̃^(h)⟩
 → σ_U² x^⊤ y · Π_{h′=h}^{L+1} Σ̇^(h′)(x, y),

where the last line uses the existing conclusion that ⟨∂f/∂f^(h), ∂f̃/∂f̃^(h)⟩ → Π_{h′=h}^{L+1} Σ̇^(h′)(x, y); this convergence holds almost surely as N → ∞ by the law of large numbers. Finally, summing ⟨∂f/∂U^(h), ∂f̃/∂U^(h)⟩ over h ∈ [L], we conclude the assertion.

We now proceed to explain more about the fixed-point convergence in theorem 1. Let us first show that the limit converges.

Lemma 1. Using the same notation and settings as in theorem 1 and theorem 2, Θ^(L)(x, y) in eq. (6) converges absolutely if σ_W² < 1.

Proof. Since we pick x, y ∈ S^{d−1}, and by the DEQ-NTK initialization, we always have Σ^(h)(x, y) < 1 for x ≠ y. Let ρ = Σ^(h)(x, y). By eq. (5) and eq. (11), if σ_W² < 1 then there exists c such that Σ̇^(h)(x, y) < c < 1 for all (finitely many) x ≠ y on S^{d−1} and all large enough h. Consequently, the tail of the sum in eq. (6) is dominated by a geometric series with ratio c < 1, which converges absolutely. Hence Θ*(x, y) converges absolutely if σ_W² < 1, and the limit exists.

B DETAILS OF SECTION 4

We first explain the choice of S in the CDEQ-NTK initialization. In the original CNTK paper (Arora et al., 2019), the normalization is simply 1/q². However, due to the zero-padding, 1/q² does not normalize all [Σ^(h)(x, x)]_{ij,i′j′} as expected: only the variances away from the corners are normalized to 1, while the ones near the corners are not. Now we give the proof of our main theorem.

Proof of theorem 3. Similar to the proof of theorem 1, we can split the CDEQ-NTK

Θ^(L)(x, y) = E_θ[⟨∂f(θ, x)/∂θ, ∂f(θ, y)/∂θ⟩]

into terms. Omitting the inputs x, y, define

[K̃^(h)]_{ij,i′j′} = σ_W² [S]_{ij} · [S]_{i′j′} E_{(u,v)∼N(0,Λ^(h)_{ij,i′j′})}[σ(u)σ(v)].

As shown in Arora et al. (2019), we have

⟨∂f_θ(x)/∂W^(h), ∂f_θ(y)/∂W^(h)⟩ → Tr( K̇^(L) ⊙ L( K̇^(L−1) ⊙ L( ⋯ K̇^(h) ⊙ L(K̃^(h−1)) ⋯ ) ) ).

Write f = f_θ(x) and f̃ = f_θ(y). Following the same steps, by the chain rule we have

⟨∂f/∂U^(h), ∂f̃/∂U^(h)⟩ → Tr( K̇^(L) ⊙ L( K̇^(L−1) ⊙ L( ⋯ K̇^(h) ⊙ L(K^(0)) ⋯ ) ) ).

Rewriting the above two equations in recursive form, we can calculate the L-depth iteration of the CDEQ-NTK as follows:
• For the first layer, Θ^(0)(x, y) = Σ^(0)(x, y).
• For h = 1, …, L − 1, let [Θ^(h)(x, y)]_{ij,i′j′} = Tr( [K̇^(h)(x, y) ⊙ Θ^(h−1)(x, y) + K^(h)(x, y)]_{D_{ij,i′j′}} ).   (22)
• For h = L, let Θ^(L)(x, y) = K̇^(L)(x, y) ⊙ Θ^(L−1)(x, y) + K^(L)(x, y).
• The final kernel value is Tr(Θ^(L)(x, y)).

Using eq. (22), we can find the following recursive relation:

Θ^(L+1)(x, y) = K̇^(L+1)(x, y) ⊙ L( Θ^(L)(x, y) ) + K^(L+1)(x, y).   (23)



Figure 1: Finite-depth NTK vs. finite-depth iteration of DEQ-NTK. In all experiments, the NTK is initialized with the σ_W² and σ_b² given in the title. For the DEQ-NTK we set σ_U² = σ_b² − 0.1 with σ_b² taken from the title, and the actual σ_b² = 0.1. All models are trained on 1000 CIFAR-10 data points and tested on 100 test points, over 20 random draws. The error bars represent the 95% confidence interval (CI). As expected, as the depth increases, the performance of the NTKs drops until eventually their 95% CI becomes a singleton, whereas the performance of the DEQs stabilizes. Also note that with larger σ_W², the freezing of the NTK takes more depth to occur.

([S]_{ij})² is simply the number of non-zero entries in the q × q window of the zero-padded all-ones matrix X̄ centered at (i, j).

Table 1: Performance of DEQ-NTK on the MNIST dataset, compared to neural ODE (Chen et al., 2018b) and monotone operator DEQ (Winston & Kolter, 2020).

Table 2: Performance of DEQ-NTK and CDEQ-NTK on the CIFAR-10 dataset; see Lee et al. (2020) for the NTK with ZCA regularization. On a smaller dataset with 1000 training and 100 test points from CIFAR-10, we also evaluate the performance of the NTK and the finite-depth iteration of DEQ-NTK as depth increases; see fig. 1.


At this point, we need to show that K*(x, y) ≜ lim_{L→∞} K^(L)(x, y) and K̇*(x, y) ≜ lim_{L→∞} K̇^(L)(x, y) exist. Let us first agree that for all h ∈ [L] and (ij, i′j′) ∈ [P] × [Q] × [P] × [Q], the diagonal entries of Λ^(h)_{ij,i′j′} are all ones. Indeed, these diagonal entries are 1 at h = 0 by initialization, and due to the CDEQ-NTK initialization, [Σ^(h)(x, x)]_{ij,ij} = 1 for all iterations h; this is true by the definition of S. Note that iterating eqs. (17) to (20) to solve for [Σ^(h)(x, y)]_{ij,i′j′} is equivalent to iterating a map f : R^{P×Q×P×Q} → R^{P×Q×P×Q} in which R_σ is applied entrywise. If f is a contraction, then Σ*, and hence K* and K̇*, also exist. We should keep in mind that f maps R^{P×Q×P×Q} to itself, so we should be careful with the metric spaces. We want every entry of Σ^(h)(x, y) to converge; since this tensor has finitely many entries, this is equivalent to saying that its ℓ∞ norm converges (imagine flattening the tensor into a vector). So we equip the domain and co-domain of f with the ℓ∞ norm (these are finite-dimensional spaces, so we could really equip them with any norm, but picking the ℓ∞ norm makes the proof easy).

Now, if we flatten the four-way tensor into a vector, then L can be represented by a matrix whose entry in row (ij, i′j′) and column (kl, k′l′) is 1 if (kl, k′l′) ∈ D_{ij,i′j′} and 0 otherwise. In other words, the ℓ₁ norm of the (ij, i′j′)-th row is the number of non-zero entries in D_{ij,i′j′}, and by the CDEQ-NTK initialization, this row ℓ₁ norm divided by [S]_{ij} · [S]_{i′j′} is at most 1. Using the fact that ‖L‖_{ℓ∞→ℓ∞} equals the maximum row ℓ₁ norm, together with the fact that R_σ is a contraction (proven in theorem 2), we conclude that f is indeed a contraction. In the same spirit, we can also show that eq. (23) is a contraction if σ_W² < 1; hence eq. (21) indeed has a unique fixed point. This finishes the proof.

B.1 COMPUTATION OF CDEQ-NTK

One may wish to directly compute a fixed point (or, more precisely, a fixed tensor) of Θ^(L) ∈ R^{P×Q×P×Q}, as in eq. (8). However, due to the linear operator L (which is just the ensemble of the trace operators in eq. (20)), the entries depend on each other. Hence the system involves a (P·Q·P·Q) × (P·Q·P·Q)-dimensional matrix representing L. Even if we exploit the fact that only entries on the same "diagonal" depend on each other, L is at least of size P × Q × P × Q, which is 32⁴ for CIFAR-10 data. Moreover, this system is nonlinear, so we cannot compute the fixed point Σ* efficiently by root-finding. Instead, we approximate it using finite-depth iterations, and we observe that in experiments they typically converge to 10⁻⁶ accuracy in ℓ∞ within 15 iterations.
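The finite-depth approximation loop can be sketched generically as follows (our illustration; the linear contraction in the usage example is a stand-in for the actual CDEQ update map):

```python
import numpy as np

# Generic sketch of the finite-depth approximation used for Sigma*: iterate
# an update map until successive iterates agree in the l_inf norm (the paper
# reports ~1e-6 accuracy within 15 iterations for the CDEQ update).
def iterate_to_tolerance(update, T0, tol=1e-6, max_iter=50):
    T = T0
    for k in range(1, max_iter + 1):
        T_next = update(T)
        if np.max(np.abs(T_next - T)) < tol:    # l_inf convergence check
            return T_next, k
        T = T_next
    return T, max_iter

# usage with a linear contraction as a stand-in for the CDEQ update map
rng = np.random.default_rng(0)
A = 0.5 * np.eye(4)                 # contraction factor 0.5 (< 1)
b = rng.standard_normal(4)
T_star, n_iters = iterate_to_tolerance(lambda t: A @ t + b, np.zeros(4))
```

Because the update is a contraction, the iterates converge geometrically to the unique fixed point, here (I − A)⁻¹ b.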

