A FAST, WELL-FOUNDED APPROXIMATION TO THE EMPIRICAL NEURAL TANGENT KERNEL

Abstract

Empirical neural tangent kernels (eNTKs) can provide a good understanding of a given network's representation: they are often far less expensive to compute and more broadly applicable than infinite-width NTKs. For networks with O output units (e.g. an O-class classifier), however, the eNTK on N inputs is of size NO × NO, taking O((NO)^2) memory and up to O((NO)^3) computation. Most existing applications have therefore used one of a handful of approximations yielding N × N kernel matrices, saving orders of magnitude of computation, but with limited to no justification. We prove that one such approximation, which we call "sum of logits," converges to the true eNTK at initialization. Our experiments demonstrate the quality of this approximation for various uses across a range of settings.

1. INTRODUCTION

The pursuit of a theoretical foundation for deep learning has led researchers to uncover interesting connections between neural networks (NNs) and kernel methods. It has long been known that randomly initialized NNs in the infinite-width limit are Gaussian processes with what is termed the Neural Network Gaussian Process (NNGP) kernel, and that training the last layer with gradient flow under squared loss corresponds to the posterior mean (Neal, 1996; Williams, 1996; Hazan & Jaakkola, 2015; Lee et al., 2017; Matthews et al., 2018; Novak et al., 2018; Yang, 2019). More recently, Jacot et al. (2018) (building off a line of closely related prior work) showed that the same is true if we train all the parameters of the network, but with a different kernel called the Neural Tangent Kernel (NTK). Yang (2020) and Yang & Littwin (2021) later showed that this connection is architecturally universal, extending it from fully-connected NNs to most networks currently used in practice, such as ResNets and Transformers. Lee et al. (2019) also showed that the dynamics of training wide but finite-width NNs with gradient descent can be approximated by a linear model obtained from the first-order Taylor expansion of the network around its initialization, and demonstrated experimentally that this approximation holds remarkably well even for networks that are not especially wide.

In addition to the theoretical insights from these results, NTKs have had significant impact in diverse practical settings. Arora et al. (2019b) show very strong performance of NTK-based models on a variety of low-data classification and regression tasks. The condition number of an NN's NTK has been shown to correlate directly with the trainability and generalization capabilities of the NN (Xiao et al., 2018; 2020). We thus believe NTKs will continue to be used in both theoretical and empirical deep learning.
Unfortunately, however, computing the eNTK for practical networks is extremely challenging, and often not even computationally feasible. The eNTK of a NN is defined as the outer product of the Jacobians of the network's output with respect to its parameters:

Θ_θ(x_1, x_2) = [J_θ f_θ(x_1)] [J_θ f_θ(x_2)]^⊤,    (1)

where J_θ f_θ(x) denotes the Jacobian of the function f at a point x with respect to the flattened vector of all its parameters, θ ∈ R^P. Assuming f : R^D → R^O, where D is the input dimension and O the number of outputs, we have J_θ f_θ(x) ∈ R^{O×P} and Θ_θ(x_1, x_2) ∈ R^{O×O}. Thus, computing the eNTK between a set of N_1 data points and a set of N_2 data points yields N_1 N_2 blocks, each of shape O × O, which we usually reshape into an N_1 O × N_2 O matrix. On tasks with large datasets and multiple output neurons, e.g. a classification model with O classes, the eNTK quickly becomes impractical regardless of how fast each entry can be computed, simply due to its NO × NO size. For instance, the full eNTK of a classification model even on the relatively mild CIFAR-10 dataset (Krizhevsky, 2009), stored in double precision, takes over 1.8 terabytes of memory. For practical usage, we need to do something better. In this work, we present a simple trick yielding a strong approximation of the eNTK that removes the factor of O^2 from the size of the kernel matrix, resulting in a factor of O^2 improvement in memory and up to O^3 in computation. Since for typical classification datasets O is at least 10 (e.g. CIFAR-10) and potentially 1,000 or more (e.g. ImageNet; Deng et al., 2009), this provides multiple orders of magnitude of savings over the original eNTK (1).
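To make the sizes in (1) concrete, the following minimal NumPy sketch computes the full eNTK of a hypothetical two-layer ReLU network via numerical central-difference Jacobians (the network, dimensions, and helper names are illustrative, not an implementation from this paper). Even with only N = 4 inputs and O = 3 outputs, the kernel is already a 12 × 12 matrix; the (NO)^2 scaling is what makes this intractable at dataset scale.

```python
import numpy as np

rng = np.random.default_rng(0)
D, H, O = 5, 16, 3   # input dim, hidden width, number of outputs (toy values)
W1 = rng.normal(size=(H, D)) / np.sqrt(D)
W2 = rng.normal(size=(O, H)) / np.sqrt(H)
theta = np.concatenate([W1.ravel(), W2.ravel()])   # flattened parameters, theta in R^P

def f(theta, x):
    """Two-layer ReLU network f_theta : R^D -> R^O."""
    A = theta[:H * D].reshape(H, D)
    B = theta[H * D:].reshape(O, H)
    return B @ np.maximum(A @ x, 0.0)

def jacobian(theta, x, eps=1e-6):
    """Central-difference Jacobian J_theta f_theta(x), shape (O, P)."""
    P = theta.size
    J = np.zeros((O, P))
    for p in range(P):
        e = np.zeros(P)
        e[p] = eps
        J[:, p] = (f(theta + e, x) - f(theta - e, x)) / (2 * eps)
    return J

X = rng.normal(size=(4, D))              # N = 4 inputs
Js = [jacobian(theta, x) for x in X]     # one (O, P) Jacobian per input

# Full eNTK of (1): block (i, j) is J_i J_j^T, an O x O matrix;
# stacking all N x N blocks gives an (N*O) x (N*O) kernel.
K = np.block([[Ji @ Jj.T for Jj in Js] for Ji in Js])
print(K.shape)   # (12, 12), i.e. (N*O, N*O)
```

In practice one would use automatic differentiation rather than finite differences, but the block structure and the resulting NO × NO size are the same.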
We prove that, under appropriate initialization of the NN, this approximation converges to the original eNTK at a rate of O(n^{-1/2}) for a network of depth L and width n in each layer, and that the predictions of kernel regression with the approximate kernel do the same. Finally, we present diverse experimental investigations supporting our theoretical results across a range of architectures and settings. We hope this approximation further enables researchers to employ NTKs towards theoretical and empirical advances in wide networks.

Infinite NTKs   In the infinite-width limit of properly initialized NNs, Θ_θ converges almost surely at initialization to a particular deterministic kernel, and remains constant over the course of training. Algorithms are available to compute this limit exactly, but they tend to be substantially more expensive than computing (1) directly for all but extremely wide networks. The convergence to this infinite-width regime is also slow in practice, and moreover it eliminates some of the interest of the framework: neural architecture search, predicting generalization of a pre-trained representation, and meta-learning are all considerably less interesting when we only consider infinite-width networks that do essentially no feature learning. Thus, in this paper, we focus only on the empirical NTK (1).
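The "sum of logits" idea can be sketched as follows: instead of differentiating each of the O outputs separately, differentiate the scalar g(x) = Σ_o f_o(x), giving a single P-dimensional gradient per input and hence an N × N kernel. The NumPy toy example below (a hypothetical two-layer network, not code from this paper; some formulations additionally normalize the kernel by a power of O) also checks the identity that each scalar kernel entry equals the sum of all O × O entries of the corresponding eNTK block.

```python
import numpy as np

rng = np.random.default_rng(0)
D, H, O = 5, 16, 3   # toy dimensions, chosen for illustration
W1 = rng.normal(size=(H, D)) / np.sqrt(D)
W2 = rng.normal(size=(O, H)) / np.sqrt(H)
theta = np.concatenate([W1.ravel(), W2.ravel()])

def f(theta, x):
    """Two-layer ReLU network f_theta : R^D -> R^O."""
    A = theta[:H * D].reshape(H, D)
    B = theta[H * D:].reshape(O, H)
    return B @ np.maximum(A @ x, 0.0)

def jacobian(theta, x, eps=1e-6):
    """Central-difference Jacobian of f, shape (O, P)."""
    P = theta.size
    J = np.zeros((O, P))
    for p in range(P):
        e = np.zeros(P)
        e[p] = eps
        J[:, p] = (f(theta + e, x) - f(theta - e, x)) / (2 * eps)
    return J

def grad_sum(theta, x, eps=1e-6):
    """Gradient of the scalar g(x) = sum_o f_o(x), shape (P,)."""
    P = theta.size
    g = np.zeros(P)
    for p in range(P):
        e = np.zeros(P)
        e[p] = eps
        g[p] = (f(theta + e, x).sum() - f(theta - e, x).sum()) / (2 * eps)
    return g

X = rng.normal(size=(4, D))
G = np.stack([grad_sum(theta, x) for x in X])   # (N, P): one gradient per input
pK = G @ G.T                                    # (N, N) "sum of logits" kernel

# Sanity check: each scalar entry is the sum of all O*O entries
# of the corresponding eNTK block from (1).
block = jacobian(theta, X[0]) @ jacobian(theta, X[1]).T   # O x O block
assert np.allclose(pK[0, 1], block.sum(), atol=1e-5)
print(pK.shape)   # (4, 4) instead of (N*O, N*O)
```

Since only one gradient per input is stored, memory drops from O((NO)^2) to O(N^2), and each kernel entry costs a single length-P inner product.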

2. RELATED WORK

Among the numerous recent works that have used eNTKs, either to gain insights about various phenomena in deep learning or to propose new algorithms, few have reported the computational costs and implementation details of computing eNTKs. Nevertheless, all are in agreement about the expense of such computations (Park et al., 2020; Holzmüller et al., 2022; Fort et al., 2020). Several recent works have, mostly "quietly," employed various techniques to avoid dealing with the full eNTK matrix; however, to the best of our knowledge, none provide any rigorous justification.

Park et al. (2020); Chen et al. (2021) used properties of the eNTK at initialization to develop practical algorithms for neural architecture search. Wei et al. (2022); Bachmann et al. (2022) estimate the generalization ability of a specific network, randomly initialized or pre-trained on a different dataset, with efficient cross-validation. Zhou et al. (2021) use NTK regression for efficient meta-learning, and Wang et al. (2021); Holzmüller et al. (2022); Mohamadi et al. (2022) use NTKs for active learning. There has also been significant theoretical insight gained from empirical studies of networks' NTKs. Here are a few examples: Fort et al. (2020) use NTKs to study how the loss geometry of the NN evolves under gradient descent. Franceschi et al. (2021) employ NTKs to analyze the behaviour of Generative Adversarial Networks (GANs). Nguyen et al. (2020; 2021) use NTKs for dataset distillation. He et al. (2020); Adlam et al. (2020) use NTKs to predict and analyze the uncertainty of a NN's predictions. Tancik et al. (2020) use NTKs to analyze the behaviour of MLPs in learning high-frequency functions, leading to new insights into our understanding of neural radiance fields.

Figure 1: Comparison of wall-clock time for evaluating the eNTK and pNTK on a pair of input datapoints across various datasets and ResNet depths.

