ISAAC NEWTON: INPUT-BASED APPROXIMATE CURVATURE FOR NEWTON'S METHOD

Abstract

We present ISAAC (Input-baSed ApproximAte Curvature), a novel method that conditions the gradient using selected second-order information and has an asymptotically vanishing computational overhead, assuming a batch size smaller than the number of neurons. We show that it is possible to compute a good conditioner based on only the input to a respective layer without a substantial computational overhead. The proposed method allows effective training even in small-batch stochastic regimes, which makes it competitive to first-order as well as second-order methods.

1. INTRODUCTION

While second-order optimization methods are traditionally much less explored than first-order methods in large-scale machine learning (ML) applications due to their memory requirements and prohibitive computational cost per iteration, they have recently become more popular in ML, mainly due to their fast convergence properties when compared to first-order methods [1]. The expensive computation of an inverse Hessian (also known as a pre-conditioning matrix) in the Newton step has also been tackled via estimating the curvature from the change in gradients. Loosely speaking, these algorithms are known as quasi-Newton methods; for a comprehensive treatment, see Nocedal & Wright [2]. Various approximations to the pre-conditioning matrix have been proposed in recent literature [3]-[6]. From a theoretical perspective, second-order optimization methods are not nearly as well understood as first-order methods, and it is an active research direction to fill this gap [7], [8]. Motivated by the task of training neural networks, and by the observation that invoking local curvature information associated with neural network objective functions can achieve much faster progress per iteration than standard first-order methods [9]-[11], several methods have been proposed. One of these methods, which has received significant attention, is known as Kronecker-factored Approximate Curvature (K-FAC) [12]. Its main ingredient is a sophisticated approximation to the generalized Gauss-Newton matrix and the Fisher information matrix quantifying the curvature of the underlying neural network objective function, which can then be inverted efficiently. Inspired by the K-FAC approximation and the Tikhonov regularization of the Newton method, we introduce a novel two-parameter regularized Kronecker-factorized Newton update step.
The proposed scheme disentangles the classical Tikhonov regularization and, in a specific limit, allows us to condition the gradient using selected second-order information with an asymptotically vanishing computational overhead. While this case makes the presented method highly attractive from the computational complexity perspective, we demonstrate that its empirical performance on high-dimensional machine learning problems remains comparable to existing SOTA methods. The contributions of this paper can be summarized as follows: (i) we propose a novel two-parameter regularized K-FAC approximated Gauss-Newton update step; (ii) we prove that for an arbitrary pair of regularization parameters, the proposed update direction is always a direction of decreasing loss; (iii) in the limit, as one regularization parameter grows, we obtain an efficient and effective conditioning of the gradient with an asymptotically vanishing overhead; (iv) we empirically analyze the method and find that our efficient conditioning method maintains the performance of its more expensive counterpart; (v) we demonstrate the effectiveness of the method in small-batch stochastic regimes and observe performance competitive to first-order as well as quasi-Newton methods.

2. PRELIMINARIES

In this section, we review aspects of second-order optimization, with a focus on generalized Gauss-Newton methods. In combination with Kronecker factorization, this leads us to a new regularized update scheme. We consider the training of an L-layer neural network f(x; θ) defined recursively as

    z_i ← a_{i-1} W^{(i)}    (pre-activations),    a_i ← φ(z_i)    (activations),        (1)

where a_0 = x is the vector of inputs and a_L = f(x; θ) is the vector of outputs. Unless noted otherwise, we assume these vectors to be row vectors (i.e., in R^{1×n}), as this allows for a direct extension to the (batch) vectorized case (i.e., in R^{b×n}) introduced later. For any layer i, let W^{(i)} ∈ R^{d_{i-1}×d_i} be a weight matrix and let φ be an element-wise nonlinear function. We consider a convex loss function L(y, y′) that measures the discrepancy between y and y′. The training optimization problem is then

    arg min_θ E_{x,y} [L(f(x; θ), y)],        (2)

where θ = (θ^{(1)}, ..., θ^{(L)}) with θ^{(i)} = vec(W^{(i)}). The classical Newton method for solving (2) is the update rule

    θ′ = θ − η H_θ^{-1} ∇_θ L(f(x; θ), y),        (3)

where η > 0 denotes the learning rate and H_θ is the Hessian corresponding to the objective function in (2). The stability and efficiency of an estimation problem solved via the Newton method can be improved by adding a Tikhonov regularization term [13], leading to a regularized Newton method

    θ′ = θ − η (H_θ + λI)^{-1} ∇_θ L(f(x; θ), y),        (4)

where λ > 0 is the so-called Tikhonov regularization parameter. It is well known [14], [15] that, under the assumption of approximating the model f with its first-order Taylor expansion, the Hessian corresponds to the so-called generalized Gauss-Newton (GGN) matrix G_θ, and hence (4) can be expressed as

    θ′ = θ − η (G_θ + λI)^{-1} ∇_θ L(f(x; θ), y).        (5)

A major practical limitation of (5) is the computation of the inverse term.
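As a standalone numerical illustration (not from the paper; the quadratic objective and all names here are our own assumptions for the example), the Tikhonov-regularized Newton update (4) can be sketched as:

```python
import numpy as np

# Toy objective: L(theta) = 0.5 * theta^T A theta - b^T theta, so H_theta = A
# and grad(theta) = A theta - b; the exact minimizer is theta* = A^{-1} b.
rng = np.random.default_rng(0)
Q = rng.standard_normal((5, 5))
A = Q @ Q.T + np.eye(5)            # symmetric positive definite Hessian
b = rng.standard_normal(5)

eta, lam = 1.0, 0.1                # learning rate eta and Tikhonov parameter lambda
theta = np.zeros(5)
for _ in range(100):
    grad = A @ theta - b
    # theta' = theta - eta * (H + lambda I)^{-1} grad   (Eq. 4)
    theta = theta - eta * np.linalg.solve(A + lam * np.eye(5), grad)

assert np.allclose(theta, np.linalg.solve(A, b), atol=1e-6)
```

With lam = 0 this reduces to the plain Newton step (3), which converges in a single iteration on this quadratic; the regularization trades per-step progress for stability.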
A method that alleviates this difficulty is known as Kronecker-Factored Approximate Curvature (K-FAC) [12], which approximates the block-diagonal (i.e., layer-wise) empirical Hessian or GGN matrix. Inspired by K-FAC, there have been other works discussing approximations of G_θ and its inverse [15]. In the following, we discuss a popular approach that allows for (moderately) efficient computation. The generalized Gauss-Newton matrix G_θ is defined as

    G_θ = E[(J_θ f(x; θ))^⊤ ∇²_f L(f(x; θ), y) J_θ f(x; θ)],        (6)

where J and ∇² denote the Jacobian and Hessian matrices, respectively. Correspondingly, the diagonal block of G_θ corresponding to the weights of the i-th layer W^{(i)} is

    G_{W^{(i)}} = E[(J_{W^{(i)}} f(x; θ))^⊤ ∇²_f L(f(x; θ), y) J_{W^{(i)}} f(x; θ)].

According to the backpropagation rule J_{W^{(i)}} f(x; θ) = J_{z_i} f(x; θ) ⊗ a_{i-1}, the identity a^⊤ b = a ⊗ b (identifying the outer product of row vectors with its vectorization), and the mixed-product property of the Kronecker product, we can rewrite G_{W^{(i)}} as

    G_{W^{(i)}} = E[(J_{z_i} f(x; θ) ⊗ a_{i-1})^⊤ (∇²_f L(f(x; θ), y))^{1/2} (∇²_f L(f(x; θ), y))^{1/2} (J_{z_i} f(x; θ) ⊗ a_{i-1})]        (7)
               = E[(ḡ ⊗ a_{i-1})^⊤ (ḡ ⊗ a_{i-1})] = E[(ḡ^⊤ ḡ) ⊗ (a_{i-1}^⊤ a_{i-1})],        (8)

where

    ḡ = (∇²_f L(f(x; θ), y))^{1/2} J_{z_i} f(x; θ).        (9)

Remark 1 (Monte-Carlo Low-Rank Approximation for ḡ^⊤ ḡ). As ḡ is a matrix of shape m × d_i, where m is the dimension of the output of f, ḡ is generally expensive to compute. Therefore, [12] use a low-rank Monte-Carlo approximation to estimate ∇²_f L(f(x; θ), y) and thereby ḡ^⊤ ḡ. For this, we need to use the distribution underlying the probabilistic model of our loss L (e.g., Gaussian for the MSE loss, or a categorical distribution for cross entropy). Specifically, by sampling from this distribution p_{f(x)} defined by the network output f(x; θ), we can get an estimator of ∇²_f L(f(x; θ), y) via the identity

    ∇²_f L(f(x; θ), y) = E_{ŷ∼p_{f(x)}} [∇_f L(f(x; θ), ŷ)^⊤ ∇_f L(f(x; θ), ŷ)].        (10)

An extensive reference for this (as well as alternatives) can be found in Appendix A.2 of Dangel et al. [15].
The respective rank-1 approximation (denoted by ≜) of ∇²_f L(f(x; θ), y) is

    ∇²_f L(f(x; θ), y) ≜ ∇_f L(f(x; θ), ŷ)^⊤ ∇_f L(f(x; θ), ŷ),        (11)

where ŷ ∼ p_{f(x)}. Respectively, we can estimate ḡ^⊤ ḡ using this rank-1 approximation with

    ḡ ≜ ∇_f L(f(x; θ), ŷ) J_{z_i} f(x; θ) = ∇_{z_i} L(f(x; θ), ŷ).

In analogy to ḡ, we introduce the gradient of the training objective with respect to the pre-activations z_i as

    g_i = ∇_f L(f(x; θ), y) J_{z_i} f(x; θ) = ∇_{z_i} L(f(x; θ), y).        (12)

In other words, for a given layer i, let g ∈ R^{1×d_i} denote the gradient of the loss between the output and the ground truth, and let ḡ ∈ R^{m×d_i} denote the derivative of the network f times the square root of the Hessian of the loss function (which may be approximated according to Remark 1), each of them taken with respect to the output z_i of the given layer. Note that ḡ is not equal to g and that they require one backpropagation pass each (or potentially many in the case of ḡ). This makes computing ḡ costly. Applying the K-FAC approximation [12] to (8), the expectation of Kronecker products can be approximated as the Kronecker product of expectations:

    G = E[(ḡ^⊤ ḡ) ⊗ (a^⊤ a)] ≈ E[ḡ^⊤ ḡ] ⊗ E[a^⊤ a],        (13)

where, for clarity, we drop the index of a_{i-1} in (8) and denote it by a; similarly, we denote G_{W^{(i)}} by G. While the expectation of Kronecker products is generally not equal to the Kronecker product of expectations, this K-FAC approximation (13) has been shown to be fairly accurate in practice and to preserve the "coarse structure" of the GGN matrix [12]. The K-FAC decomposition (13) is convenient because the Kronecker product satisfies the identity (A ⊗ B)^{-1} = A^{-1} ⊗ B^{-1} for invertible matrices A and B, which significantly simplifies the computation of the inverse.
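The Kronecker-inverse identity can be checked numerically; this standalone sketch (matrix sizes and the diagonal shifts are arbitrary choices of ours) is not code from the paper:

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.standard_normal((3, 3)) + 3 * np.eye(3)   # shifted to be well-conditioned
B = rng.standard_normal((4, 4)) + 4 * np.eye(4)

# Inverting the 12 x 12 Kronecker product directly ...
lhs = np.linalg.inv(np.kron(A, B))
# ... equals the Kronecker product of the two small inverses.
rhs = np.kron(np.linalg.inv(A), np.linalg.inv(B))
assert np.allclose(lhs, rhs, atol=1e-6)
```

This is exactly why (13) is useful: inverting the two small factors replaces inverting the large product.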
In practice, E[ḡ^⊤ ḡ] and E[a^⊤ a] can be computed by averaging over a batch of size b as

    E[ḡ^⊤ ḡ] ≃ ḡ^⊤ ḡ / b,    E[a^⊤ a] ≃ a^⊤ a / b,        (14)

where (overloading notation) we denote batches of g, ḡ, and a as g ∈ R^{b×d_i}, ḡ ∈ R^{rb×d_i}, and a ∈ R^{b×d_{i-1}}. Here, our layer has d_{i-1} inputs and d_i outputs, b is the batch size, and r is either the number of outputs m or the rank of an approximation according to Remark 1. Correspondingly, the K-FAC approximation of the GGN matrix and its inverse are concisely expressed as

    G ≈ (ḡ^⊤ ḡ) ⊗ (a^⊤ a) / b²,    G^{-1} ≈ ((ḡ^⊤ ḡ)^{-1} ⊗ (a^⊤ a)^{-1}) · b².        (15)

Equipped with the standard terminology and setting, we now introduce the novel, regularized update step. First, inspired by the K-FAC approximation (13), the Tikhonov-regularized Gauss-Newton method (5) can be approximated by

    θ^{(i)′} = θ^{(i)} − η ((ḡ^⊤ ḡ / b + λI)^{-1} ⊗ (a^⊤ a / b + λI)^{-1}) ∇_{θ^{(i)}} L(f(x; θ), y),        (16)

with regularization parameter λ > 0. A key observation, motivated by the structure of the above update, is to disentangle the two occurrences of λ into two independent regularization parameters λ_g, λ_a > 0. By defining the Kronecker-factorized Gauss-Newton update step as

    ζ = λ_g λ_a ((ḡ^⊤ ḡ / b + λ_g I)^{-1} ⊗ (a^⊤ a / b + λ_a I)^{-1}) ∇_{θ^{(i)}} L(f(x; θ), y),        (17)

we obtain the concise update equation

    θ^{(i)′} = θ^{(i)} − η* ζ.        (18)

This update (18) is equivalent to update (16) in the case of η* = η / (λ_g λ_a) and λ = λ_g = λ_a. This equivalence does not restrict η*, λ_g, λ_a in any way, and changing λ_g or λ_a does not mean that we change our learning rate or step size η*. Parameterizing ζ in (17) with the multiplicative factor λ_g λ_a makes the formulation more convenient for the analysis. In this paper, we investigate the theoretical and empirical properties of the iterative update rule (18) and, in particular, show how the regularization parameters λ_g, λ_a affect the Kronecker-factorized Gauss-Newton update step ζ.
When analyzing the Kronecker-factorized Gauss-Newton update step ζ, a particularly useful tool is the vector product identity

    ((ḡ^⊤ ḡ)^{-1} ⊗ (a^⊤ a)^{-1}) vec(g^⊤ a) = vec((ḡ^⊤ ḡ)^{-1} g^⊤ a (a^⊤ a)^{-1}),        (19)

where g^⊤ a is the gradient with respect to the weight matrix.
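This identity can be verified numerically. The following standalone check (dimensions and the regularizer 0.1 are arbitrary choices of ours) uses row-major vectorization, under which the identity holds because both Kronecker factors are symmetric:

```python
import numpy as np

rng = np.random.default_rng(2)
b, d_in, d_out = 4, 6, 5
a  = rng.standard_normal((b, d_in))       # batch of layer inputs a
g  = rng.standard_normal((b, d_out))      # batch of gradients g
gb = rng.standard_normal((b, d_out))      # stands in for the batch matrix ḡ

# Regularize so both factors are invertible (cf. Eq. 17).
M_g = np.linalg.inv(gb.T @ gb / b + 0.1 * np.eye(d_out))
M_a = np.linalg.inv(a.T @ a / b + 0.1 * np.eye(d_in))

lhs = np.kron(M_g, M_a) @ (g.T @ a).ravel()   # (M_g ⊗ M_a) vec(g^T a)
rhs = (M_g @ g.T @ a @ M_a).ravel()           # vec(M_g g^T a M_a)
assert np.allclose(lhs, rhs)
```

The practical point is the shapes: the left side inverts/multiplies a (d_i·d_{i-1})-dimensional object, while the right side only touches d_i × d_i and d_{i-1} × d_{i-1} matrices.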

3. THEORETICAL GUARANTEES

In this section, we investigate the theoretical properties of the Kronecker-factorized Gauss-Newton update direction ζ as defined in (17). We recall that ζ introduces a Tikhonov regularization, as is commonly done in implementations of second-order methods. Not surprisingly, we show that by decreasing the regularization parameters λ_g, λ_a, the update rule (18) collapses (in the limit) to the classical Gauss-Newton method, and hence, in the regime of small λ_g, λ_a, the variable ζ describes the Gauss-Newton direction. Moreover, by increasing the regularization strength, we converge (in the limit) to the conventional gradient descent update step. The key observation is that, as we disentangle the regularization of the two Kronecker factors ḡ^⊤ ḡ and a^⊤ a and consider the setting where only one regularizer is large (λ_g → ∞, to be precise), we obtain an update direction that can be computed highly efficiently. We show that this setting describes an approximated Gauss-Newton update scheme, whose superior numerical performance is then empirically demonstrated in Section 4.

    ζ = (I_m − (1/(bλ_g)) ḡ^⊤ (I_b + (1/(bλ_g)) ḡ ḡ^⊤)^{-1} ḡ) · g^⊤ · (I_b − (1/(bλ_a)) a a^⊤ (I_b + (1/(bλ_a)) a a^⊤)^{-1}) · a.        (20)

(i) In the limit of λ_g, λ_a → 0, (1/(λ_g λ_a)) ζ is the K-FAC approximation of the Gauss-Newton step, i.e.,

    lim_{λ_g, λ_a → 0} (1/(λ_g λ_a)) ζ ≈ G^{-1} ∇_{θ^{(i)}} L(f(x; θ), y),

where ≈ denotes the K-FAC approximation (15). (ii) In the limit of λ_g, λ_a → ∞, ζ is the gradient, i.e., lim_{λ_g, λ_a → ∞} ζ = ∇_{θ^{(i)}} L(f(x; θ), y). The proof is deferred to the Supplementary Material. We want to show that ζ is well defined and points in the correct direction not only for λ_g and λ_a numerically close to zero, because we want to explore the full spectrum of settings for λ_g and λ_a. Thus, we prove that ζ is a direction of increasing loss, independent of the choices of λ_g and λ_a.
((λ_g I_m + ḡ^⊤ ḡ / b)^{-1} ⊗ (λ_a I_n + a^⊤ a / b)^{-1}) (≈ G^{-1}) is PSD, and therefore the direction of the update step remains correct. This leads us to our primary contribution: from our formulation of ζ, we find that, in the limit λ_g → ∞, the update no longer depends on ḡ. This is computationally very beneficial, as computing ḡ is costly: it requires one or even many additional backpropagation passes. In addition, it allows conditioning the gradient update by multiplying a b × b matrix between g^⊤ and a, which is very fast.

Theorem 3 (Efficient Update Direction / ISAAC). In the limit of λ_g → ∞, the update step ζ converges to lim_{λ_g→∞} ζ = ζ*, where

    ζ* = g^⊤ · (I_b − (1/(bλ_a)) a a^⊤ (I_b + (1/(bλ_a)) a a^⊤)^{-1}) · a.        (21)

Proof. We first show the characterization (21). Note that, according to (22), λ_g · (λ_g I_m + ḡ^⊤ ḡ / b)^{-1} converges to I_m in the limit of λ_g → ∞, and therefore (21) holds. (i) The statement follows from the fact that the term ḡ does not appear in the equivalent characterization (21) of ζ*. (ii) We first note that the matrix a a^⊤ is of dimension b × b and can be computed in O(b²n) time. Next, the matrix I_b − (1/(bλ_a)) a a^⊤ (I_b + (1/(bλ_a)) a a^⊤)^{-1} is of shape b × b and can be multiplied with a in O(b²n) time.

Notably, (21) can be computed with a vanishing computational overhead and with only minor modifications to the implementation: only the g^⊤ a expression has to be replaced by (21) in the backpropagation step. As this can be done independently for each layer, the method also lends itself to being applied only to individual layers. As we see in the experimental section, in many cases in the mini-batch regime (i.e., b < n), the optimal (or a good) choice for λ_g actually lies in the limit to ∞. This is a surprising result, leading to the efficient and effective ζ* = ζ_{λ_g→∞} optimizer.

Remark 4 (Extension to the Natural Gradient). In some cases, it might be more desirable to use the Fisher-based natural gradient instead of the Gauss-Newton method. The difference to our setting is that in (5) the GGN matrix G is replaced by the empirical Fisher information matrix F. We note that our theory also applies to F, and that ζ* also efficiently approximates the natural gradient update step F^{-1}∇. The i-th diagonal block of F, F_{θ^{(i)}} = E[(g_i^⊤ g_i) ⊗ (a_{i-1}^⊤ a_{i-1})], has the same form as a block of the GGN matrix G, G_{θ^{(i)}} = E[(ḡ_i^⊤ ḡ_i) ⊗ (a_{i-1}^⊤ a_{i-1})]. Thus, we can replace ḡ with g in our theoretical results to obtain their counterparts for F.
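The two equivalent forms of ζ* can be checked numerically: the b × b conditioner from (21) against the n × n Kronecker-factor limit g^⊤ a (I_n + a^⊤ a/(b λ_a))^{-1}. This is a standalone sketch with arbitrary dimensions, not code from the paper:

```python
import numpy as np

rng = np.random.default_rng(3)
b, d_in, d_out, la = 8, 16, 12, 0.5     # batch size, layer widths, lambda_a
a = rng.standard_normal((b, d_in))      # batch of layer inputs
g = rng.standard_normal((b, d_out))     # batch of pre-activation gradients

# Efficient b x b form (Eq. 21):
# zeta* = g^T (I_b - (1/(b la)) a a^T (I_b + (1/(b la)) a a^T)^{-1}) a
aaT = a @ a.T / (b * la)
C = np.eye(b) - aaT @ np.linalg.inv(np.eye(b) + aaT)
zeta_star = g.T @ C @ a

# Equivalent n x n form obtained as the lambda_g -> inf limit of (17):
zeta_ref = g.T @ a @ np.linalg.inv(np.eye(d_in) + a.T @ a / (b * la))

assert zeta_star.shape == (d_out, d_in)
assert np.allclose(zeta_star, zeta_ref)
```

For b ≪ n, the first form only ever inverts a b × b matrix, which is where the vanishing overhead comes from.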

4. EXPERIMENTS

In the previous section, we discussed the theoretical properties of the proposed update directions ζ and ζ*. In this section, we evaluate them empirically.

E1: IMPACT OF REGULARIZATION PARAMETERS

For (E1), we study the dependence of the model's performance on the regularization parameters λ_g and λ_a. Here, we train a 5-layer deep neural network on the MNIST classification task [16] with a batch size of 60 for a total of 40 epochs or 40 000 steps. The plots in Figure 1 demonstrate that the advantage of training by conditioning with curvature information can be achieved by considering both the layer inputs a and the gradients with respect to random samples ḡ, but also by using only the layer inputs a. In the plot, we show the performance of ζ for different choices of λ_g and λ_a, each in the range from 10^-6 to 10^6. The right column shows ζ*, i.e., λ_g = ∞, for different λ_a. The bottom-right corner is gradient descent, which corresponds to λ_g = ∞ and λ_a = ∞. Newton's method or the general K-FAC approximation corresponds to the area with small λ_g and λ_a. The interesting finding here is that the performance does not suffer from increasing λ_g toward ∞, i.e., from left to right in the plot. In addition, in Figure 3, we consider the case of regression with an auto-encoder trained with the MSE loss on MNIST [16] and Fashion-MNIST [17]. Here, we follow the same principle as above and also find that ζ* performs well.

Figure 4: Training a 5-layer ReLU network with 400 neurons per layer on the MNIST classification task (as in Figure 1) but with the Adam optimizer [19].

In Figure 2, we compare the loss for different methods, distinguishing between loss per time (left) and loss per number of steps (right). We can observe the relative behavior of the methods for the settings λ_a, λ_g ∈ {0.1, 0.01} shown there.

BERT: To demonstrate the utility of ζ* also in large-scale models, we evaluate it for fine-tuning BERT [20] on three natural language tasks. In Table 1, we summarize the results for the BERT fine-tuning task. For the "Corpus of Linguistic Acceptability" (CoLA) [21] data set, we fine-tune both the BERT-Base and the BERT-Mini models and find that we outperform the gradient descent baseline in both cases.
For the "Microsoft Research Paraphrase Corpus" (MRPC) [22] data set, we fine-tune the BERT-Base model and find that we outperform the baseline both in terms of accuracy and F1-score. Finally, on the "Semantic Textual Similarity Benchmark" (STS-B) [23] data set, we fine-tune the BERT-Mini model and achieve higher Pearson and Spearman correlations than the baseline. While for training with CoLA and MRPC we were able to use the Adam optimizer [19] (which is recommended for this task and model) in conjunction with ζ* in place of the gradient, for STS-B, Adam did not work well. Therefore, for STS-B, we evaluated it using the SGD-with-momentum optimizer. For each method, we performed a grid search over the hyperparameters. We note that we use a batch size of 8 in all BERT experiments.

ResNet: In addition, we conduct an experiment where we train the last layer of a ResNet with ζ*, while the remainder of the model is updated using the gradient ∇. Here, we train a ResNet-18 [24] on CIFAR-10 [25] using SGD with a batch size of 100. In Figure 6, we plot the test accuracy against the number of epochs. The times for each method lie within 1% of each other. We consider three settings: the typical setting with momentum and weight decay, a setting with only momentum, and a setting with vanilla SGD without momentum. The results show that the proposed method outperforms SGD in each of these cases. While the improvements are rather small in the case of the default training, they are especially large in the case of no weight decay and no momentum.

E5: RUNTIME AND MEMORY

Finally, we also evaluate the runtime and memory requirements of each method. The runtime evaluation is displayed in Table 2. We report both CPU and GPU runtime using PyTorch [26] and (for K-FAC) the backpack library [15]. Note that the CPU runtime is more representative of the pure computational cost, as for the first rows of the GPU runtime the overhead of calling the GPU is dominant. When comparing runtimes between the gradient and ζ* on the GPU, we can observe an overhead of around 2.5 s independent of the model size. The overhead in CPU time is also very small, at less than 1% for the largest model and only 1.3 s for the smallest model. We will publish the source code of our implementation. In the appendix, we give a PyTorch [26] implementation of the proposed method (ζ*).

5. RELATED WORK

Our method is related to K-FAC by Martens and Grosse [12]. K-FAC uses the approximation (13) to approximate the blocks of the Hessian of the empirical risk of neural networks. In most implementations of K-FAC, the off-diagonal blocks of the Hessian are also set to zero. One of the main claimed benefits of K-FAC is its speed (compared to stochastic gradient descent) for large-batch-size training. That said, recent empirical work has shown that this advantage of K-FAC disappears once the additional computational costs of hyperparameter tuning for large-batch training are accounted for. There is a line of work that extends the basic idea of K-FAC to convolutional layers [27]. Botev et al. [18] further extend these ideas to present KFLR, a Kronecker-factored low-rank approximation, and KFRA, a Kronecker-factored recursive approximation of the Gauss-Newton step. Singh and Alistarh [28] propose WoodFisher, a Woodbury-matrix-inverse-based estimate of the inverse Hessian, and apply it to neural network compression. Yao et al. [29] propose AdaHessian, a second-order optimizer that incorporates the curvature of the loss function via an adaptive estimation of the Hessian. Frantar et al. [6] propose M-FAC, a matrix-free approximation of the natural gradient through a queue of the (e.g., 1 000) most recent gradients. These works fundamentally differ from our approach in that their objective is to approximate the Fisher or Gauss-Newton matrix-inverse vector products. In contrast, this work proposes to approximate the Gauss-Newton matrix by only one of its Kronecker factors, which we find to achieve good performance at a substantial computational speedup and reduction of memory footprint. For an overview of this area, we refer to Kunstner et al. [30] and Martens [31]. For an overview of the technical aspects of backpropagation of second-order quantities, we refer to Dangel et al. [15], [32]. Taking a step back, K-FAC is one of many Newton-type methods for training neural networks.
Other prominent examples of such methods include subsampled Newton methods [33], [34] (which approximate the Hessian by subsampling the terms in the empirical risk function and evaluating the Hessian of the subsampled terms) and sketched Newton methods [3]-[5] (which approximate the Hessian by sketching, e.g., by projecting the Hessian to a lower-dimensional space by multiplying it with a random matrix). Another quasi-Newton method [35] proposes approximating the Hessian by a block-diagonal matrix, using the structure of the gradient and Hessian to further approximate these blocks. The main features that distinguish K-FAC from this group of methods are K-FAC's superior empirical performance and K-FAC's lack of theoretical justification.

B IMPLEMENTATION DETAILS

Unless noted otherwise, for all experiments, we tune the learning rate on a grid of (1, 0.3, 0.1, 0.03, 0.01, 0.003, 0.001). We verified this range to cover the full reasonable range of learning rates; specifically, for every single experiment, we made sure that no learning rate outside this range performs better. For all language model experiments, we used the respective Huggingface PyTorch implementation. All other hyperparameter details are given in the main paper.

C ADDITIONAL PROOFS

Proof of Theorem 1. We first show that ζ as defined in (17) can be expressed as in (20). Indeed, by using (19), the Woodbury matrix identity, and the regularized inverses, we can see that

    ζ = λ_g λ_a ((ḡ^⊤ ḡ / b + λ_g I)^{-1} ⊗ (a^⊤ a / b + λ_a I)^{-1}) g^⊤ a
      = λ_g λ_a · (λ_g I_m + ḡ^⊤ ḡ / b)^{-1} g^⊤ a (λ_a I_n + a^⊤ a / b)^{-1}
      = λ_g λ_a · (1/λ_g) (I_m − (1/(bλ_g)) ḡ^⊤ (I_b + (1/(bλ_g)) ḡ ḡ^⊤)^{-1} ḡ) g^⊤ a (1/λ_a) (I_n − (1/(bλ_a)) a^⊤ (I_b + (1/(bλ_a)) a a^⊤)^{-1} a)
      = (I_m − (1/(bλ_g)) ḡ^⊤ (I_b + (1/(bλ_g)) ḡ ḡ^⊤)^{-1} ḡ) · g^⊤ · a · (I_n − (1/(bλ_a)) a^⊤ (I_b + (1/(bλ_a)) a a^⊤)^{-1} a)
      = (I_m − (1/(bλ_g)) ḡ^⊤ (I_b + (1/(bλ_g)) ḡ ḡ^⊤)^{-1} ḡ) · g^⊤ · (a − (1/(bλ_a)) a a^⊤ (I_b + (1/(bλ_a)) a a^⊤)^{-1} a)
      = (I_m − (1/(bλ_g)) ḡ^⊤ (I_b + (1/(bλ_g)) ḡ ḡ^⊤)^{-1} ḡ) · g^⊤ · (I_b − (1/(bλ_a)) a a^⊤ (I_b + (1/(bλ_a)) a a^⊤)^{-1}) · a.

To show Assertion (i), we note that, according to (17),

    lim_{λ_g,λ_a→0} (1/(λ_g λ_a)) ζ = lim_{λ_g,λ_a→0} ((ḡ^⊤ ḡ / b + λ_g I)^{-1} ⊗ (a^⊤ a / b + λ_a I)^{-1}) g^⊤ a
                                    = ((ḡ^⊤ ḡ / b)^{-1} ⊗ (a^⊤ a / b)^{-1}) g^⊤ a ≈ G^{-1} g^⊤ a,

where the first equality uses the definition of ζ in (17), the second equality is due to the continuity of matrix inversion, and the last approximate equality follows from the K-FAC approximation (15). To show Assertion (ii), we consider lim_{λ_g→∞} and lim_{λ_a→∞} independently, that is,

    lim_{λ_g→∞} λ_g · (λ_g I_m + ḡ^⊤ ḡ / b)^{-1} = lim_{λ_g→∞} (I_m + (1/(bλ_g)) ḡ^⊤ ḡ)^{-1} = I_m,        (22)
    lim_{λ_a→∞} λ_a · (λ_a I_n + a^⊤ a / b)^{-1} = lim_{λ_a→∞} (I_n + (1/(bλ_a)) a^⊤ a)^{-1} = I_n.        (23)

This then implies

    lim_{λ_g,λ_a→∞} λ_g (λ_g I_m + ḡ^⊤ ḡ / b)^{-1} · g^⊤ · a · λ_a (λ_a I_n + a^⊤ a / b)^{-1} = I_m · g^⊤ a · I_n = g^⊤ a,        (24)

which concludes the proof.
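The Woodbury step used in the proof above can be checked numerically in isolation (a standalone sketch; G stands in for the batch matrix ḡ, and the dimensions are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(4)
b, m, lam = 5, 7, 0.3
G = rng.standard_normal((b, m))   # stands in for the b x m batch matrix ḡ

# (lam I_m + G^T G / b)^{-1}
lhs = np.linalg.inv(lam * np.eye(m) + G.T @ G / b)
# = (1/lam) (I_m - (1/(b lam)) G^T (I_b + (1/(b lam)) G G^T)^{-1} G)
rhs = (np.eye(m)
       - G.T @ np.linalg.inv(np.eye(b) + G @ G.T / (b * lam)) @ G / (b * lam)) / lam
assert np.allclose(lhs, rhs)
```

Note that the left side inverts an m × m matrix while the right side only inverts a b × b matrix, mirroring the complexity argument of Theorem 3.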



Code will be made available at github.com/Felix-Petersen/isaac

6. CONCLUSION

In this work, we presented ISAAC Newton, a novel approximate-curvature method based on layer inputs. We demonstrated it to be a special case of the regularization-generalized Gauss-Newton method and empirically demonstrated its utility. Specifically, our method features an asymptotically vanishing computational overhead in the mini-batch regime, while achieving competitive empirical performance on various benchmark problems.



Theorem 1 (Properties of ζ). The K-FAC-based update step ζ as defined in (17) can be expressed as in (20).

Moreover, ζ admits the following asymptotic properties:

Theorem 2 (Correctness of ζ is independent of λ_g and λ_a). ζ is a direction of increasing loss, independent of the choices of λ_g and λ_a. Proof. Recall that (λ_g I_m + ḡ^⊤ ḡ / b) and (λ_a I_n + a^⊤ a / b) are positive semi-definite (PSD) matrices by definition. Their inverses (λ_g I_m + ḡ^⊤ ḡ / b)^{-1} and (λ_a I_n + a^⊤ a / b)^{-1} are therefore also PSD. As the Kronecker product of PSD matrices is PSD, the conditioning matrix

(i) Here, the update direction ζ* is based only on the inputs and does not require computing ḡ (which would require a second backpropagation pass), making it efficient. (ii) The computational cost of computing the update ζ* lies in O(bn² + b²n + b³), where n is the number of neurons in each layer. This comprises the conventional cost of computing the gradient ∇ = g^⊤ a, lying in O(bn²), and the overhead of computing ζ* instead of ∇, lying in O(b²n + b³). The overhead is vanishing, assuming n ≫ b. For b > n, the complexity lies in O(bn² + n³).

) of ζ*. (ii) We first note that the matrix a a^⊤ is of dimension b × b and can be computed in O(b²n) time.

Remark (Relation between the Update Directions ζ and ζ*). When comparing the update direction ζ in (20) without regularization (i.e., λ_g → 0, λ_a → 0) with ζ* (i.e., λ_g → ∞) as given in (21), it can be directly seen that ζ* corresponds to a particular pre-conditioning of ζ, since ζ* = Mζ for M = (1/(bλ_g)) ḡ^⊤ ḡ. As the last theoretical property of our proposed update direction ζ*, we show that in specific networks, ζ* coincides with the Gauss-Newton update direction. Theorem 4 (ζ* is Exact for the Last Layer). For the case of linear regression or, more generally, the last layer of networks, with the mean squared error, ζ* is the Gauss-Newton update direction. Proof. The Hessian matrix of the mean squared error loss is the identity matrix. Correspondingly, the expectation of ḡ^⊤ ḡ is I. Thus, ζ* = ζ. The direction ζ* corresponds to the Gauss-Newton update direction with an approximation of G that can be expressed as G ≈ E[I ⊗ (a^⊤ a)]. Remark 4 (Extension to the Natural Gradient).

Figure 1: Logarithmic training loss (top) and test accuracy (bottom) on the MNIST classification task. The axes are the regularization parameters λ_g and λ_a in logarithmic scale with base 10. Training with a 5-layer ReLU-activated network with 100 (left; a, e), 400 (center; b, c, f, g), and 1 600 (right; d, h) neurons per layer. The optimizer is SGD, except for (c, g), where the optimizer is SGD with momentum. The top-left sector is ζ, the top-right column is ζ*, and the bottom-right corner is ∇ (gradient descent). For each experiment and each of the three sectors, we use one learning rate, i.e., ζ, ζ*, and ∇ have their own learning rates to make a fair comparison between the methods; within each sector, the learning rate is constant. We can observe that in the limit of λ_g → ∞ (i.e., toward the right) the performance remains good, showing the utility of ζ*.

Figure 2: Training loss of the MNIST auto-encoder trained with gradient descent, K-FAC, ζ, and ζ*. Comparing the performance per real time (left) and per number of update steps (right). Runtimes are for a CPU core.

Figure 3: Training an auto-encoder on MNIST (left) and Fashion-MNIST (right). The model is the same as used by Botev et al. [18], i.e., it is a ReLU-activated 6-layer fully connected model with dimensions 784-1000-500-30-500-1000-784. Displayed is the logarithmic training loss.

Figure 5: Training on the MNIST classification task using ζ* only in selected layers. Runtimes are for CPU.

We can see that training performs well for n ∈ {100, 400, 1 600} neurons per layer at a batch size of only 60. Also, in all other experiments, we use small batch sizes of between 8 and 100.

E3: ζ* IN INDIVIDUAL LAYERS

In Figure 5, we train the 5-layer fully connected model with 400 neurons per layer. Here, we consider the setting where we use ζ* in some of the layers while using the default gradient ∇ in the other layers. Specifically, we consider the settings where all, the first, the final, the first three, the final three, the odd-numbered, and the even-numbered layers are updated by ζ*. We observe that all settings with ζ* perform better than the gradient baseline.

Figure 6: ResNet-18 trained on CIFAR-10 with image augmentation and a cosine learning rate schedule. To ablate the optimizer, two additional settings are added, specifically, without weight decay and without momentum. Results are averaged over 5 runs and the standard deviation is indicated with the colored areas.

In contrast, the runtime of ζ is around 4 times the runtime of the gradient, and K-FAC has an even substantially larger runtime. Regarding memory, ζ* (contrasting the other approaches) also requires only a small additional footprint.


Figure 7: Training loss of the MNIST auto-encoder trained with gradient descent, K-FAC, ζ, and ζ*, as well as SGD with momentum, SGD with a 10× larger batch size (600), K-FAC with a 10× larger batch size (600), and Adam. Comparing the performance per real time (left) and per number of epochs (right). We display both the training loss (top) and the test loss (bottom). Runtimes are for a CPU core.

Figure 8: Test accuracy for training on the MNIST classification task using ζ* only in selected layers. Runtimes are for CPU.

Figure 9: Reproduction of the experiments in Figure 1 but with the Fisher-based natural gradient formulation from Remark 4. For a description of the experimental settings, see the caption of Figure 1. We observe that, for large λ_g, the behavior is similar to Figure 1, which is expected as they are the same in the limit of λ_g → ∞. Further, we observe that (in this case of the Fisher-based ζ) good performance can be achieved not only in the limit of λ_g → ∞ but also in the limit of λ_a → ∞. Moreover, in this specific experiment, λ_a → ∞ has slightly better optimal performance than λ_g → ∞, but λ_a → ∞ is more sensitive to changes in λ_g than the case of λ_g → ∞ is to changes in λ_a. This phenomenon was also (to a lesser extent) visible in the experiments of Figure 1. We remark that the case of λ_g → ∞ (i.e., ζ*) is computationally more efficient than λ_a → ∞.

Table 1: BERT results for fine-tuning pre-trained BERT-Base (B-B) and BERT-Mini (B-M) models on the CoLA, MRPC, and STS-B text classification tasks. Larger values are better for all metrics. MCC is the Matthews correlation coefficient. Results are averaged over 10 runs.

Table 2: Runtimes and memory requirements for different models. Runtime is the training time per epoch on MNIST at a batch size of 60, i.e., for 1 000 training steps. The K-FAC implementation is from the backpack library [15]. The GPU is an Nvidia A6000. The implementation of ζ* can be done by replacing the backpropagation step of a respective layer by (21). As all "ingredients" are already available in popular deep learning frameworks, it requires only little modification (contrasting K-FAC and ζ, which require at least one additional backpropagation pass).

ACKNOWLEDGMENTS

This work was supported by the IBM-MIT Watson AI Lab, the DFG in the Cluster of Excellence EXC 2117 "Centre for the Advanced Study of Collective Behaviour" (Project-ID 390829875), the Land Salzburg within the WISS 2025 project IDA-Lab (20102-F1901166-KZP and 20204-WISS/225/197-2019), and the National Science Foundation (NSF) (grants no. 1916271, 2027737, and 2113373).

A PYTORCH IMPLEMENTATION

We display a PyTorch [26] implementation of ISAAC for a fully-connected layer below. The important part (i.e., the part beyond the boilerplate) is the conditioning of the weight gradient in the backward pass.
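A minimal sketch of such a layer is given below; the class and parameter names (ISAACLinear, la for λ_a) are our own, and the backward pass substitutes (21) for the plain weight gradient g^⊤ a. This is an illustrative sketch, not the verbatim listing:

```python
import torch

class ISAACLinearFunction(torch.autograd.Function):
    """Linear layer whose weight gradient is conditioned per Eq. (21)."""

    @staticmethod
    def forward(ctx, x, weight, bias, la):
        ctx.save_for_backward(x, weight)
        ctx.la = la
        out = x @ weight.t()
        return out + bias if bias is not None else out

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        b = x.shape[0]
        # b x b conditioner: C = I_b - (1/(b*la)) x x^T (I_b + (1/(b*la)) x x^T)^{-1}
        xxT = x @ x.t() / (b * ctx.la)
        eye = torch.eye(b, dtype=x.dtype, device=x.device)
        C = eye - xxT @ torch.linalg.inv(eye + xxT)
        grad_w = grad_out.t() @ (C @ x)   # zeta* replaces the plain g^T a
        grad_x = grad_out @ weight        # earlier layers see the usual gradient
        return grad_x, grad_w, grad_out.sum(0), None


class ISAACLinear(torch.nn.Linear):
    """nn.Linear drop-in; `la` is the regularization parameter lambda_a (assumed name)."""

    def __init__(self, in_features, out_features, la=1.0):
        super().__init__(in_features, out_features)
        self.la = la

    def forward(self, x):
        return ISAACLinearFunction.apply(x, self.weight, self.bias, self.la)
```

For la → ∞, the conditioner approaches I_b and the weight gradient reduces to the ordinary g^⊤ a, consistent with Theorem 1(ii).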

