NNGEOMETRY: EASY AND FAST FISHER INFORMATION MATRICES AND NEURAL TANGENT KERNELS IN PYTORCH

Abstract

Yet these theoretical tools are often difficult to implement using current libraries for practical-size networks, given that they require per-example gradients and a large amount of memory, since they scale as the number of parameters (for the FIM) or the number of examples × the cardinality of the output space (for the NTK). NNGeometry is a PyTorch library that offers a simple interface for computing various linear algebra operations such as matrix-vector products, trace, Frobenius norm, and so on, where the matrix is either the FIM or the NTK, leveraging recent advances in approximating these matrices. We hereby introduce the library and motivate our design choices, then we demonstrate it on modern deep neural networks.



Practical and theoretical advances in deep learning have been accelerated by the development of an ecosystem of libraries allowing practitioners to focus on developing new techniques instead of spending weeks or months reinventing the wheel. In particular, automatic differentiation frameworks such as Theano (Bergstra et al., 2011), Tensorflow (Abadi et al., 2016) or PyTorch (Paszke et al., 2019) have been the backbone for the leap in performance of the last decade's increasingly deep neural networks, as they make it possible to efficiently compute the average gradients used in the stochastic gradient algorithm or variants thereof. While they are versatile in the neural networks that can be designed by varying the type and number of their layers, they are however specialized to the very task of computing these average gradients, so more advanced techniques can be burdensome to implement. While the popularity of neural networks has grown thanks to their ever-improving performance, other techniques have emerged; amongst them, we highlight some involving Fisher Information Matrices (FIM) and Neural Tangent Kernels (NTK). Approximate 2nd order (Schraudolph, 2002) or natural gradient techniques (Amari, 1998) aim at accelerating training, elastic weight consolidation (Kirkpatrick et al., 2017) proposes to fight catastrophic forgetting in continual learning, and WoodFisher (Singh & Alistarh, 2020) tackles the problem of network pruning so as to minimize its computational footprint while retaining prediction capability. These 3 methods all use the Fisher Information Matrix while formalizing the problem they aim at solving, but resort to different approximations when it comes to implementation. Similarly, following the work of Jacot et al. (2018), a line of work studies the NTK either in its limiting infinite-width regime, or during training of actual finite-size networks.
All of these papers start by formalizing the problem at hand in a very concise mathematical formula, then face the experimental challenge that computing the FIM or NTK involves performing operations for which off-the-shelf automatic differentiation libraries are not well adapted. An even greater obstacle comes from the fact that these matrices scale with the number of parameters (for the FIM) or the number of examples in the training set (for the empirical NTK). This is prohibitively large for modern neural networks involving millions of parameters or large datasets, a problem circumvented by a series of techniques to approximate the FIM (Ollivier, 2015; Martens & Grosse, 2015; George et al., 2018). NNGeometry aims at making the use of these approximations effortless, so as to accelerate development or analysis of new techniques, allowing researchers to spend more time on the theory and less time fighting development bugs. NNGeometry's interface is designed to be as close as possible to the mathematical formulas. In summary, this paper and library contribute:
• We introduce NNGeometry by describing and motivating design choices.
  - A unified interface for all FIM and NTK operations, regardless of how these are approximated.
  - Implicit operations for the ability to scale to large networks.
• Using NNGeometry, we get new empirical insights on FIMs and NTKs:
  - We compare different approximations in different scenarios.
  - We scale some NTK evolution experiments to TinyImagenet.

1. PRELIMINARIES

1.1. NETWORK LINEARIZATION

Neural networks are parametric functions $f(x, w) : \mathcal{X} \times \mathbb{R}^d \to \mathbb{R}^c$ where $x \in \mathcal{X}$ are covariates from an input space, and $w \in \mathbb{R}^d$ are the network's parameters, arranged in layers composed of weight matrices and biases. The function returns a value in $\mathbb{R}^c$, such as the $c$ scores in softmax classification, or $c$ real values in $c$-dimensional regression. Neural networks are trained by iteratively adjusting their parameters $w^{(t+1)} \leftarrow w^{(t)} + \delta w^{(t)}$ using steps $\delta w^{(t)}$ typically computed using the stochastic gradient algorithm or variants thereof, in order to minimize the empirical risk of a loss function. In machine learning, understanding and being able to control the properties of the solution obtained by an algorithm is of crucial interest, as it can provide generalization guarantees, or help design more efficient or accurate algorithms. Contrary to (kernelized) linear models, where closed-form expressions of the empirical risk minimizer exist, deep networks are non-linear functions, whose generalization properties and learning dynamics are not yet fully understood. Amongst the recent advances toward improving theory is the study of the linearization (in $w$) of the deep network function $f(x, w)$:

$$f(x, w + \delta w) = f(x, w) + J(x, w)\,\delta w + o(\|\delta w\|) \tag{1}$$

where $o$ is the Landau notation (footnote 1) and $J(x, w) = \frac{\partial f(x, w)}{\partial w}$ is the Jacobian with respect to the parameters $w$, evaluated at $(x, w)$, mapping changes in parameter space $\delta w$ to corresponding changes in output space using the identity $\delta f(x, w, \delta w) = J(x, w)\,\delta w$. For tiny steps $\delta w$, we neglect the term $o(\|\delta w\|)$, thus $f$ is close to its linearization. This happens for instance at small step sizes, or in the large-width limit with the specific parameter initialization scheme proposed by Jacot et al. (2018).
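The linearization of equation 1 can be checked numerically on a toy model. The sketch below is only an illustration under stated assumptions: the two-layer `tanh` network, the finite-difference `jacobian` helper and all names are hypothetical and are not part of NNGeometry, and NumPy stands in for PyTorch.

```python
import numpy as np

def f(x, w):
    """Toy network (assumption): w packs a 3x2 and a 2x3 weight matrix."""
    W1 = w[:6].reshape(3, 2)
    W2 = w[6:].reshape(2, 3)
    return W2 @ np.tanh(W1 @ x)

def jacobian(x, w, eps=1e-6):
    """Central finite-difference Jacobian of f w.r.t. w, shape (c, d)."""
    cols = []
    for j in range(w.size):
        e = np.zeros_like(w)
        e[j] = eps
        cols.append((f(x, w + e) - f(x, w - e)) / (2 * eps))
    return np.stack(cols, axis=1)

rng = np.random.default_rng(0)
x = rng.normal(size=2)
w = rng.normal(size=12)
direction = rng.normal(size=12)
direction /= np.linalg.norm(direction)
J = jacobian(x, w)

def linearization_error(step):
    """Norm of f(x, w + dw) - (f(x, w) + J dw) for a step of given length."""
    dw = step * direction
    return np.linalg.norm(f(x, w + dw) - (f(x, w) + J @ dw))

steps = (1e-1, 1e-2, 1e-3)
errors = [linearization_error(s) for s in steps]
```

The error divided by the step length shrinks as the step shrinks, which is what the $o(\|\delta w\|)$ term expresses.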

1.2. PARAMETER SPACE METRICS AND FISHER INFORMATION MATRIX

While neural networks are trained by tuning their parameters $w$, the end goal of machine learning is not to find the best parameter values, but rather to find good functions, in a sense that depends on the task at hand. For instance, different parameter values can represent the same function (Dinh et al., 2017). Conversely, 2 parameter space steps $\delta w_1$ and $\delta w_2$ with the same Euclidean norm can produce very different changes in a function ($\delta f(x, w, \delta w_1) \neq \delta f(x, w, \delta w_2)$). In order to quantify changes of a function, one generally defines a distance (footnote 2) on the function space. Examples of such distances are the $L^k$-norms, Wasserstein distances, or the KL divergence used in information geometry. To each of these function space distances corresponds a parameter space metric. We continue our exposition by focusing on the KL divergence, which is closely related to the Fisher Information Matrix, but our library can be used for other function space distances. Suppose $f$ is interpreted as the log-probability of a density $p$: $\log p(x, w) = f(x, w)$. The KL divergence gives a sense of how much the probability distribution changes when adding a small increment $\delta w$ to the parameters of $f(x, w)$. We can approximate it as:

$$\mathrm{KL}\left(p(x, w) \,\|\, p(x, w + \delta w)\right) = \int_{x \in \mathcal{X}} \log \frac{p(x, w)}{p(x, w + \delta w)} \, dp(x, w) \tag{2}$$

$$= \frac{1}{2} \int_{x \in \mathcal{X}} \left\| \frac{1}{p(x, w)} J(x, w)\,\delta w \right\|^2 dp(x, w) + o\left(\|\delta w\|^2\right) \tag{3}$$

where we used this form (derived in the appendix) in order to emphasize how steps in parameter space $\delta w$ affect distances measured on the function space: equation 3 is the result of i) taking a step $\delta w$ in parameter space; ii) multiplying by $J(x, w)$ to push the change to the function space; iii) weighting this function space change using $p(x, w)^{-1}$; iv) squaring and summing. In particular, because of the properties of the KL divergence, there is no second derivative of $f$ involved, even if equation 3 is equivalent to taking the 2nd order Taylor series expansion of the KL divergence.
We can rewrite this in a more concise way:

$$\mathrm{KL}\left(f(x, w) \,\|\, f(x, w + \delta w)\right) = \frac{1}{2}\, \delta w^\top F_w\, \delta w + o\left(\|\delta w\|^2\right) \tag{4}$$

which uses the $d \times d$ FIM $F_w = \int_{x \in \mathcal{X}} \frac{1}{p(x, w)^2} J(x, w)^\top J(x, w) \, dp(x, w)$. In particular, we can now define the norm $\|\delta w\|_{F_w} = \delta w^\top F_w\, \delta w$ used in the natural gradient algorithm (Amari (1998), also see Martens (2020) for a more thorough discussion of the FIM), in elastic weight consolidation (Kirkpatrick et al., 2017), or in pruning (Singh & Alistarh, 2020). Other quantities also share the same structure of a covariance of parameter space vectors, such as the covariance of loss gradients in TONGA (Roux et al., 2008), the second moment of loss gradients (footnote 3) (Kunstner et al., 2019; Thomas et al., 2020), or posterior covariances in Bayesian deep learning (e.g. in Maddox et al. (2019)).
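The quadratic approximation of the KL divergence can be verified numerically on a small categorical model. This is a hedged sketch, not the library's code: it assumes a linear softmax classifier and uses the standard score-function form of the FIM for a single input, $F = \sum_y p(y|x)\, \nabla_w \log p(y|x)\, \nabla_w \log p(y|x)^\top$. All function names are hypothetical.

```python
import numpy as np

def log_softmax(z):
    z = z - z.max()
    return z - np.log(np.exp(z).sum())

def log_probs(x, w, c=3):
    """Log-probabilities of a linear softmax model with c classes (assumption)."""
    return log_softmax(w.reshape(c, -1) @ x)

def fisher(x, w, c=3):
    """Score-function FIM for one input x: sum_y p_y g_y g_y^T."""
    p = np.exp(log_probs(x, w, c))
    Jz = np.kron(np.eye(c), x[None, :])      # d logits / d w, shape (c, d)
    G = (np.eye(c) - p[None, :]) @ Jz        # row y holds grad_w log p_y
    return (G * p[:, None]).T @ G

def kl(x, w, dw, c=3):
    """Exact KL(p_w || p_{w+dw}) for the categorical output distribution."""
    lp, lq = log_probs(x, w, c), log_probs(x, w + dw, c)
    return float(np.sum(np.exp(lp) * (lp - lq)))

rng = np.random.default_rng(0)
x = rng.normal(size=2)
w = rng.normal(size=6)
dw = rng.normal(size=6)
dw *= 1e-3 / np.linalg.norm(dw)              # small step in parameter space

exact = kl(x, w, dw)
quadratic = 0.5 * dw @ fisher(x, w) @ dw
```

For a step this small, `exact` and `quadratic` agree up to higher-order terms, and no second derivative of the model is needed to build `fisher`.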

1.3. NEURAL TANGENT KERNEL

Another very active line of research around the linearization of equation 1 is to take inspiration from the rich literature on kernel methods by defining the neural tangent kernel (NTK):

$$k_w(x, y) = J(x, w)\, J(y, w)^\top$$

In the limit of infinite network width, Jacot et al. (2018) have shown that the tangent kernel remains constant through training using gradient descent, which makes it possible to directly apply kernel learning theory to deep learning. While this regime is of theoretical interest, it arguably does not explain what happens at finite width, where the NTK evolves during training. While kernels are functions of the whole input space $\mathcal{X} \times \mathcal{X}$, we often only have access to a limited number of samples in a dataset. We thus resort to using the kernel evaluated at points $x_i$ of a

2.2.1. ABSTRACT OBJECTS

In section 1, we have worked with abstract mathematical objects $\delta w$, $\delta f(x, w, \delta w)$, $J(x, w)$, $F_w$ and $K_w$. We now map these mathematical objects to Python classes in NNGeometry. We start with the parameter space, that we previously identified as $\mathbb{R}^d$. Closer to how they are actually implemented in deep learning frameworks, vectors in the parameter space $w$ can equivalently be considered as a set of weight matrices and bias vectors $w = \{W_1, b_1, \dots, W_l, b_l\}$. Parameter space vectors are represented by the class PVector in NNGeometry, which is essentially a dictionary of PyTorch Parameters, with basic algebra logic: PVectors can be readily added, subtracted, and scaled by a scalar with standard Python operators. As an illustration, w_sum = w1 + w2 internally loops through all parameter tensors of w1 and w2 and returns a new PVector w_sum. Similarly, and more interestingly, parameter space metrics such as the FIM are represented by classes prefixed with PMat. For instance, the natural gradient $\delta_{nat} = -\eta F^{-1} \nabla_w L$ applies the linear operator $w \mapsto F^{-1} w$ to the parameter space vector $\nabla_w L$, and can be implemented cleanly and concisely using delta_nat = -eta * F.solve(nabla_L), even if it internally involves different operations for different layer types, and different approximation techniques. Function space vectors FVector define objects associated to vectors of the output space, evaluated on a dataset of $n$ examples $X$. As an example, getting back to the linearization $\delta f(x, w, \delta w) = J(x, w)\,\delta w$, we define $\delta f(X) = (\delta f(x_1, w, \delta w), \dots, \delta f(x_n, w, \delta w))$ as the $\mathbb{R}^{c \times n}$ function space vector of output changes for all examples of $X$. Gram matrices of the NTK are linear operators on this space, represented by objects prefixed with FMat.
Borrowing from the vocabulary of differential geometry, we also define PushForward objects that are linear operators from parameter space to function space, and PullBack objects that are linear operators from function space to parameter space. While the following consideration can be ignored at first glance, the structure of the parameter space is internally encoded using a LayerCollection object. This gives the flexibility of defining our parameter space as the parameters of a subset of layers, in order to treat different layers in different ways. An example use case is to use KFAC for the parameters of linear layers, and block-diagonal for GroupNorm layers, as KFAC is not defined for the latter.
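To make the "dictionary of tensors with algebra logic" idea concrete, here is a minimal sketch. `ToyPVector` is a hypothetical stand-in written for this illustration with NumPy arrays; it is not NNGeometry's PVector class and does not reproduce its interface.

```python
import numpy as np

class ToyPVector:
    """Hypothetical parameter-space vector: a dict of named arrays."""

    def __init__(self, tensors):
        self.tensors = dict(tensors)

    def __add__(self, other):
        # Loop through all parameter tensors and add them pairwise.
        return ToyPVector({k: v + other.tensors[k] for k, v in self.tensors.items()})

    def __sub__(self, other):
        return ToyPVector({k: v - other.tensors[k] for k, v in self.tensors.items()})

    def __rmul__(self, scalar):
        # Enables `scalar * vector` with a plain Python number.
        return ToyPVector({k: scalar * v for k, v in self.tensors.items()})

    def dot(self, other):
        return float(sum(np.sum(v * other.tensors[k]) for k, v in self.tensors.items()))

    def flatten(self):
        # View the same vector as a single element of R^d.
        return np.concatenate([v.ravel() for v in self.tensors.values()])

rng = np.random.default_rng(0)
w1 = ToyPVector({"W1": rng.normal(size=(3, 2)), "b1": rng.normal(size=3)})
w2 = ToyPVector({"W1": rng.normal(size=(3, 2)), "b1": rng.normal(size=3)})

w_sum = w1 + w2
step = -0.1 * w_sum
```

The per-layer dictionary and the flat vector in $\mathbb{R}^d$ are two views of the same object, which is why layer-aware approximations can share one interface with dense ones.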

2.2.2. CONCRETE REPRESENTATIONS

These abstract objects are implemented in memory using concrete representations. NNGeometry comes with a number of representations. Amongst them, most notably, are parameter space approximations proposed in recent literature (Ollivier, 2015; Martens & Grosse, 2015; Grosse & Martens, 2016; George et al., 2018), and an implicit representation for each abstract linear operator, that makes it possible to compute linear algebra operations without ever computing or storing the matrix in memory. PMatDense (resp. FMatDense) and PMatDiag represent the full dense matrix and the diagonal matrix and need no further introduction. PMatLowRank only computes and stores $J(X, w)$, the $c \times n \times d$ stacked Jacobian for all examples of the given dataset. Next come representations that do not consider neural networks as black-box functions, but instead are adapted to the layered structure of the networks: PMatBlockDiag uses dense blocks of the FIM for parameters of the same layer, and puts zeros elsewhere, ignoring cross-layer covariance. PMatQuasiDiag (Ollivier, 2015) uses the full diagonal and adds to each bias element the interaction with the corresponding row of the weight matrix. PMatKFAC uses KFAC (Martens & Grosse, 2015) and its extension to convolution layers KFC (Grosse & Martens, 2016) to approximate each layer block with the Kronecker product of 2 much smaller matrices, thus saving memory and compute compared to PMatBlockDiag. PMatEKFAC uses the EKFAC (George et al., 2018) extension of KFAC. The last representation that comes with this first release of NNGeometry, PMatImplicit, makes it possible to compute certain linear algebra operations using the full dense matrix, but without the need to ever store it in memory, which permits scaling to large networks (see experiments in section 3). As an illustration, the vector-matrix-vector product $v^\top F v$ can be computed using equation 3.
Each representation comes with its advantages and drawbacks, making it possible to trade off memory against approximation accuracy. For a new project, we recommend starting with a small network using the PMatDense representation, then gradually switching to representations with a lower memory footprint while experimenting with actual modern networks. While linear algebra operations associated to each representation internally involve very different mechanisms, NNGeometry's core contribution is to give easy access to these operations by using the same simple methods (figure 1).
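The idea behind an implicit $v^\top F v$ can be sketched as follows. This is an illustration only, under the assumption that $F$ is the plain average $\frac{1}{n} \sum_i J_i^\top J_i$ of per-example Jacobian outer products (any weighting is omitted), with random arrays standing in for real Jacobians; it is not the PMatImplicit implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
n, c, d = 16, 3, 50
J = rng.normal(size=(n, c, d))      # stand-in for per-example Jacobians J(x_i, w)
v = rng.normal(size=d)

# Dense route: materialize the d x d matrix, then take the quadratic form.
F_dense = np.einsum("ncd,nce->de", J, J) / n
vTFv_dense = float(v @ F_dense @ v)

# Implicit route: push v to function space, square and sum. No d x d storage.
Jv = J @ v                           # shape (n, c)
vTFv_implicit = float(np.sum(Jv ** 2) / n)
```

The implicit route needs $O(ncd)$ operations and never allocates $d^2$ entries, which is what makes the operation feasible when $d$ is in the millions.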
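The memory and compute saving of a Kronecker-factored block comes from a standard identity: with row-major flattening, $(A \otimes B)\,\mathrm{vec}(V) = \mathrm{vec}(A V B^\top)$. The sketch below illustrates that identity with generic factors `A` and `B`; the helper name `kron_matvec` is hypothetical and the factor ordering is an assumption of this example, not a statement about how PMatKFAC orders its factors.

```python
import numpy as np

def kron_matvec(A, B, V):
    """Compute (A kron B) @ V.ravel() without forming the Kronecker product."""
    return (A @ V @ B.T).ravel()

rng = np.random.default_rng(0)
m, k = 4, 3                          # e.g. output and input sizes of one layer
A = rng.normal(size=(m, m))
B = rng.normal(size=(k, k))
V = rng.normal(size=(m, k))          # a parameter-space vector for that layer

fast = kron_matvec(A, B, V)
slow = np.kron(A, B) @ V.ravel()     # forms the full (m*k) x (m*k) block

# Storage for a 512 x 512 layer: dense block versus the two factors.
dense_entries = (512 * 512) ** 2
factored_entries = 512 ** 2 + 512 ** 2
```

For the 512 x 512 layer, the two factors hold about five orders of magnitude fewer entries than the dense block.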

2.2.3. GENERATORS

In order to compute FIMs and NTKs, we need to compute Jacobians $J(x, w)$ for examples $x$ coming from a dataset. NNGeometry's generator is the component that actually populates the representations by computing the required elements of the matrices, depending on the representation. While a naive idea would be to loop through examples $x_i$, compute $f(x_i, w)$ and compute gradients with respect to parameters using PyTorch's automatic differentiation, this is rather inefficient as it does not make use of the parallelism of GPUs. NNGeometry's generator instead makes it possible to use minibatches of examples by intercepting PyTorch's gradients and using techniques such as those of Goodfellow (2015) and Rochette et al. (2019).

Let us consider $f(x, w) : \mathcal{X} \times \mathbb{R}^d \to \mathbb{R}^c$. In order to simplify exposition, we focus on fully connected layers and suppose that $f$ can be written $f(x, w) = \sigma_l \circ g_l(\cdot, w) \circ \sigma_{l-1} \circ g_{l-1}(\cdot, w) \circ \dots \circ \sigma_1 \circ g_1(x, w)$, where $\sigma_k$ are activation functions and $g_k$ are parametric affine transformations that compute the pre-activations $s_k$ of a layer using a weight matrix $W_k$ and a bias vector $b_k$ with the following expression: $s_k = g_k(a_{k-1}, w) = W_k a_{k-1} + b_k$. For each example $x_i$ in a minibatch, we denote these intermediate quantities with a superscript: $s_k^{(i)}$ and $a_k^{(i)}$. Backpropagation computes the intermediate gradients $\frac{\partial f(x_i, w)}{\partial s_k^{(i)}}$; we denote by $Ds_k$ the matrix obtained by stacking them for a minibatch of size $m$, and by $a_{k-1}$ the corresponding matrix of input activations of the same layer. These are already computed when performing the backpropagation algorithm, then used to obtain the average gradient w.r.t. the weight matrix by means of the matrix/matrix product $\frac{\partial}{\partial W_k} \left\{ \sum_i f(x_i, w) \right\} = Ds_k\, a_{k-1}^\top$. The observation of Goodfellow (2015) is that we can in addition obtain the individual gradients $\frac{\partial f(x_i, w)}{\partial W_k} = \frac{\partial f(x_i, w)}{\partial s_k^{(i)}}\, a_{k-1}^{(i)\top}$, an operation that can be done efficiently and simultaneously for all examples of the minibatch using the bmm PyTorch function.
While we used this already known trick as an example of how to take advantage of minibatching, NNGeometry's generator incorporates similar tricks in several other places, including in implicit operations. Instead of reimplementing backpropagation, as is for example done by Dangel et al. (2019), we chose to use PyTorch's internal automatic differentiation mechanism, as it already handles most corner cases encountered by deep learning practitioners: we do not have to reimplement backward computations for every new layer, but instead we just have to compute individual gradients by intercepting gradients with respect to pre-activations $Ds_k$. Other generators are to be added to NNGeometry in the future, either by using different ways of computing the Jacobians, or by populating representations using other matrices such as the Hessian matrix, or the KFRA approximation of the FIM (Botev et al., 2017).
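The per-example gradient trick described above can be sketched for a single fully connected layer. This is a minimal illustration, assuming a toy scalar output $f_i = \sum \tanh(W a^{(i)} + b)$ and NumPy arrays whose rows are the examples of the minibatch; `np.einsum` plays the role that a batched matrix product such as bmm plays in PyTorch. All names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n_in, n_out = 5, 4, 3
A = rng.normal(size=(m, n_in))       # layer inputs a^(i), one row per example
W = rng.normal(size=(n_out, n_in))
b = rng.normal(size=n_out)

def f_single(W, a):
    """Toy scalar output for one example (assumption of this sketch)."""
    return np.tanh(W @ a + b).sum()

S = A @ W.T + b                      # pre-activations s^(i), shape (m, n_out)
Ds = 1.0 - np.tanh(S) ** 2           # df_i / ds^(i), shape (m, n_out)

# Gradient of the sum over the minibatch: one matrix/matrix product.
sum_grad = Ds.T @ A                  # shape (n_out, n_in)

# Individual gradients: one outer product per example, done in a single call.
per_example = np.einsum("bi,bj->bij", Ds, A)   # shape (m, n_out, n_in)
```

Both quantities are built from the same `Ds` and `A`, so the individual gradients come at little extra cost once backpropagation has produced the gradients with respect to pre-activations.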

3. EXPERIMENTAL SHOWCASE

Equipped with NNGeometry, we experiment with a large network: we train a 24M-parameter Resnet50 network on TinyImagenet. We emphasize that given the size of the network, we would not have been able to compute operations involving the true $F$ without NNGeometry's PMatImplicit representation, since $F$ would require 2.3 petabytes of memory (24M × 24M × 4 bytes for float32).
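The memory figure quoted above follows from a short calculation, reproduced here as a sketch. The helper name is hypothetical, and decimal units (1 PB = $10^{15}$ bytes, 1 MB = $10^6$ bytes) are assumed.

```python
def dense_matrix_bytes(n_rows, n_cols, bytes_per_entry=4):
    """Storage of a dense matrix, 4 bytes per entry for float32."""
    return n_rows * n_cols * bytes_per_entry

d = 24_000_000                                   # 24M parameters
dense_bytes = dense_matrix_bytes(d, d)           # full d x d FIM
dense_petabytes = dense_bytes / 1e15
diag_megabytes = dense_matrix_bytes(d, 1) / 1e6  # diagonal only, for comparison
```

Storing only the diagonal of the same matrix takes on the order of a hundred megabytes, which shows how wide the gap between the dense matrix and its approximations is.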

3.1. QUALITY OF FIM APPROXIMATIONS

We start by comparing the accuracy of several PMat representations at computing various linear algebra operations. We use a Monte-Carlo estimate of the FIM, where we use 5 samples from $p(y|x)$ for each example $x$. Here, since TinyImagenet is a classification task, $p(y|x)$ is a multinoulli distribution with the event probabilities given by the softmax layer. We compare the approximate value obtained for each representation to a "true" value, obtained using the full matrix with the PMatImplicit representation. For the trace and $v^\top F v$, we compare these quantities using the relative difference $\frac{\text{approx} - \text{true}}{\text{true}}$. For $F v$, we report the cos-angle $\frac{1}{\|F v\|_2 \|F_{approx} v\|_2} \langle F v, F_{approx} v \rangle$, and for the solve operation, we report the cos-angle between $v$ and $(F_{approx} + \lambda I)^{-1} (F + \lambda I)\, v$. Since the latter is highly dependent on the Tikhonov regularization parameter $\lambda$, we plot the effect on the cos-angle of varying the value of $\lambda$. The results can be observed in figures 3, 4, 5, 6. From this experiment, there is no best representation for all linear algebra operations. Instead, this analysis suggests using PMatKFAC when possible for operations involving the inverse FIM, and PMatEKFAC for operations involving the (forward) FIM. Other representations are less accurate, but should not be discarded as they can offer other advantages, such as lower memory footprint and faster operations.

Figure 3: Residual $\|v - v'\|_2$ and cos angle between $v$ and $v' = (F_{approx} + \lambda I)^{-1} (F + \lambda I)\, v$ for a 24M-parameter Resnet50 at different points during training on TinyImagenet, using different approximations $F_{approx}$ of $F$, for $v$ uniformly sampled on the unit sphere (higher is better).

Figure 4: Cos angle between $F v$ and $F_{approx} v$ for a 24M-parameter Resnet50 at different points during training on TinyImagenet, using different approximations $F_{approx}$ of $F$, for $v$ uniformly sampled on the unit sphere (higher is better).
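The comparison metrics of this section can be sketched on a toy problem. This is an illustration under stated assumptions: a small random positive semi-definite matrix stands in for $F$, its diagonal stands in for an approximation $F_{approx}$, and the function names are hypothetical helpers rather than NNGeometry functions.

```python
import numpy as np

def relative_difference(approx, true):
    return (approx - true) / true

def cos_angle(u, v):
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))

def solve_cos_angle(F, F_approx, v, lam):
    """Cos-angle between v and (F_approx + lam I)^{-1} (F + lam I) v."""
    I = np.eye(F.shape[0])
    v_prime = np.linalg.solve(F_approx + lam * I, (F + lam * I) @ v)
    return cos_angle(v, v_prime)

rng = np.random.default_rng(0)
d = 6
B = rng.normal(size=(d, d))
F = B @ B.T                          # toy stand-in for the true FIM
F_approx = np.diag(np.diag(F))       # diagonal approximation
v = rng.normal(size=d)
v /= np.linalg.norm(v)               # uniformly sampled on the unit sphere

trace_rel = relative_difference(np.trace(F_approx), np.trace(F))
vFv_rel = relative_difference(v @ F_approx @ v, v @ F @ v)
Fv_cos = cos_angle(F @ v, F_approx @ v)
solve_cos = {lam: solve_cos_angle(F, F_approx, v, lam) for lam in (1e-2, 1.0, 1e2)}
```

In this toy setting the diagonal approximation reproduces the trace exactly, and the solve cos-angle approaches 1 as $\lambda$ grows because both regularized matrices become dominated by $\lambda I$. This is why the solve metric has to be reported as a function of $\lambda$.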

3.2. NEURAL TANGENT KERNEL EIGENVECTORS

In the line of Baratin et al. (2020) and Paccolat et al. (2020), we observe the evolution of the NTK during training. We use the Resnet50 on the 200 classes of TinyImagenet, but in order to be able to plot a 2d matrix for analysis, we extract the function

$$f_{c_1, c_2}(x, w) = (f(x, w))_{c_2} - (f(x, w))_{c_1}.$$

On this larger network, we reproduce the conclusion of Baratin et al. (2020) and Paccolat et al. (2020) that the NTK evolution is not purely random during training, but instead adapts to the task in a very specific way.

4. CONCLUSION

We introduced NNGeometry, a PyTorch library that makes it possible to compute various linear algebra operations involving Fisher Information Matrices and Neural Tangent Kernels, using an efficient implementation that is versatile enough for current usages of these matrices, while being easy enough to use that it saves time for the user.



1. The Landau notation $o$ (pronounced "little-o") means a function whose exact value is irrelevant, with the property that $\lim_{x \to 0} \frac{o(x)}{x} = 0$, or in other words that is negligible compared to $x$ for small $x$.
2. We here use the notion of distance informally.
3. The second moment of loss gradients is sometimes called empirical Fisher.



Figure 1: Computing a vector-Fisher-vector product $v^\top F v$, for a 10-fold classification model defined by model, can be implemented with the same piece of code for 2 representations of the FIM using NNGeometry, even if they involve very different computations under the hood.

The backpropagation algorithm applied to computing gradients of a sum $S = \sum_i f(x_i, w)$ works by sequentially computing intermediate gradients $\frac{\partial f(x_i, w)}{\partial s_k^{(i)}}$ from top layers to bottom layers. Denote by $Ds_k = \left( \frac{\partial f(x_1, w)}{\partial s_k^{(1)}}, \dots, \frac{\partial f(x_m, w)}{\partial s_k^{(m)}} \right)$ the matrix obtained by stacking these gradients for a minibatch of size $m$, and $a_k = \left( a_k^{(1)}, \dots, a_k^{(m)} \right)$

Figure 3: Residual $\|v - v'\|_2$

The function $f_{c_1, c_2}$ is a binary classifier of class $c_2$ vs class $c_1$. We plot at different points during training i) the Gram matrix of examples from the 2 classes $c_1$ and $c_2$ (figure 7, top row) and ii) a kernel PCA of points from classes $c_1$ and $c_2$ projected on the first 2 principal components (figure 7, bottom row). The Gram matrix is computed for validation set examples of classes $c_1$ and $c_2$.
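The Gram matrix and kernel PCA computation described here can be sketched as follows. This is an illustration only: random arrays stand in for the per-example Jacobians of a real network, all names are hypothetical, and the kernel PCA is the textbook version (double centering of the Gram matrix followed by an eigendecomposition).

```python
import numpy as np

rng = np.random.default_rng(0)
n, c, d = 8, 4, 20
J_full = rng.normal(size=(n, c, d))  # stand-in for Jacobians J(x_i, w), one per example
c1, c2 = 0, 1

# The Jacobian of f_{c1,c2} = f_{c2} - f_{c1} is the difference of two output rows.
J_bin = J_full[:, c2, :] - J_full[:, c1, :]     # shape (n, d)

# Gram matrix of the NTK for this scalar function, shape (n, n).
K = J_bin @ J_bin.T

# Kernel PCA: double-center the Gram matrix, then eigendecompose.
H = np.eye(n) - np.ones((n, n)) / n
eigvals, eigvecs = np.linalg.eigh(H @ K @ H)    # eigenvalues in ascending order
top = np.argsort(eigvals)[::-1][:2]             # indices of the 2 largest
projection = eigvecs[:, top] * np.sqrt(eigvals[top])   # shape (n, 2)
```

Each row of `projection` gives the coordinates of one example on the first 2 principal components, which is the quantity a scatter plot of the kernel PCA would display.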

Figure 6: Relative difference of trace computed using $F_{approx}$ and $F$ (lower is better). As we observe, all 3 representations PMatDiag, PMatQuasiDiag and PMatEKFAC estimate the trace very accurately, since the only remaining fluctuation comes from Monte-Carlo sampling of the FIM. On the other hand, the estimation provided by PMatKFAC is less accurate.

Memory usage and computational cost: A FIM is $d \times d$ where $d$ is the total number of parameters. With a memory cost in $O(d^2)$, this is prohibitively costly even for moderate-size networks. Typical linear algebra operations have a computational cost in either $O(d^2)$ (e.g. matrix-vector product) or even $O(d^3)$ (e.g. matrix inverse). NNGeometry instead comes with recent, less memory-intensive approximations.

Figure 5: Relative difference between $v^\top F v$ and $v^\top F_{approx} v$ for a 24M-parameter Resnet50 at different points during training on TinyImagenet, using different approximations $F_{approx}$ of $F$, for $v$ uniformly sampled on the unit sphere (higher is better).

acknowledgement

We hope that NNGeometry will help make progress across deep learning subfields as FIMs and NTKs are used in a range of applications.

availability

https://github.com/OtUmm7ojOrv/nngeometry.

