FISHER-LEGENDRE (FISHLEG) OPTIMIZATION OF DEEP NEURAL NETWORKS

Abstract

Incorporating second-order gradient information (curvature) into optimization can dramatically reduce the number of iterations required to train machine learning models. In natural gradient descent, such information comes from the Fisher information matrix, which yields a number of desirable properties. As exact natural gradient updates are intractable for large models, successful methods such as KFAC and sequels approximate the Fisher in a structured form that can easily be inverted. However, this requires model/layer-specific tensor algebra and certain approximations that are often difficult to justify. Here, we use ideas from Legendre-Fenchel duality to learn a direct and efficiently evaluated model for the product of the inverse Fisher with any vector, in an online manner, leading to natural gradient steps that get more accurate over time despite noisy gradients. We prove that the resulting "Fisher-Legendre" (FishLeg) optimizer converges to a (global) minimum of non-convex functions satisfying the PL condition, which applies in particular to deep linear networks. On standard auto-encoder benchmarks, we show empirically that FishLeg outperforms standard first-order optimization methods, and performs on par with or better than other second-order methods, especially when using small batches. Thanks to its generality, we expect our approach to facilitate the handling of a variety of neural network layers in future work.

1. INTRODUCTION & SUMMARY OF CONTRIBUTIONS

The optimization of machine learning models often benefits from the use of second-order gradient (curvature) information, which can dramatically increase the per-iteration progress on the training loss. However, second-order optimizers for machine learning face a number of key challenges:

1. ML models tend to have many parameters, such that estimating the curvature along all dimensions is intractable. There are normally two ways around this: (i) using iterative methods where each iteration only exploits curvature information along a single dimension ("Hessian-free" methods; Martens et al., 2010), and (ii) developing custom curvature approximations that can be efficiently inverted to obtain parameter updates (Martens & Grosse, 2015; Grosse & Martens, 2016; Botev et al., 2017; George et al., 2018; Bahamou et al., 2022; Soori et al., 2022). The latter yields state-of-the-art performance in the optimization of deep networks (Goldfarb et al., 2020). However, the development of structured curvature approximations must be done on a case-by-case basis for each architecture (e.g. fully-connected or convolutional layer) and requires mathematical assumptions that are difficult to justify.

2. ML datasets tend to be large as well, such that the loss, its gradients, and its curvature can only be stochastically estimated from mini-batches. While noise can be mitigated by a combination of large mini-batches and momentum, estimating the various components of the curvature matrix before inverting it (as opposed to estimating the inverse directly) introduces a bias that can potentially be detrimental.

Here, we focus primarily on the Fisher information matrix (FIM) as a notoriously effective source of curvature information for training ML models. For probabilistic models, preconditioning the gradient by the inverse FIM yields so-called natural gradient updates (Amari, 1998). We introduce a novel framework for second-order optimization of high-dimensional probabilistic models in the presence of gradient noise. Instead of approximating the FIM in a way that can easily be inverted (e.g. as in KFAC; Martens & Grosse, 2015, and related approaches), we directly parameterize the inverse FIM. We make the following contributions:

• We show that the inverse FIM can be computed through the Legendre-Fenchel conjugate of a cross-entropy between model distributions.
• We provide an algorithm (FishLeg) which meta-learns the inverse FIM in an online fashion, and we prove convergence of the corresponding parameter updates.
• We first study its application to deep linear networks, an example of non-convex and pathologically curved loss functions with a Fisher matrix known in closed form (Bernacchia et al., 2018; Huh, 2020), and find that convergence occurs much faster than with SGD with momentum or Adam.
• We then show that, on standard auto-encoder benchmarks, FishLeg operating on a block-diagonal Kronecker parameterization of the inverse Fisher performs similarly to, and sometimes outperforms, previous approximate natural gradient methods (Goldfarb et al., 2020), whilst being only twice as slow as SGD with momentum in wall-clock time per iteration.

Similar to Amortized Proximal Optimization (Bae et al., 2022), FishLeg can accommodate arbitrary parameterizations of the inverse Fisher, thus facilitating future applications of the natural gradient to a broad range of network architectures where manually approximating the FIM in an easily invertible form is otherwise difficult.

2. BACKGROUND

2.1. FISHER INFORMATION AND THE NATURAL GRADIENT

We consider a probabilistic model parameterized by a vector θ, which attributes a negative log-likelihood ℓ(θ, D) = −log p(D|θ) to any collection D of data points drawn from a data distribution p⋆(D). This covers a broad range of models, including discriminative models for regression or classification, as well as generative models for density modelling. The goal is to find parameters θ⋆ to approximate the true data distribution p⋆(D) by the model distribution p(D|θ⋆). The Fisher information matrix (FIM) measures how much information can be obtained about parameters θ after observing data D under the model p(D|θ), and captures redundancies between parameters (Rao, 1992). The FIM is defined as

I(θ) = E_{D∼p(D|θ)} [∇_θ ℓ(θ, D) ∇_θ ℓ(θ, D)ᵀ].    (1)

By this definition, the FIM is a positive semi-definite matrix. It can be shown that, under certain regularity conditions and if ℓ is twice differentiable w.r.t. θ, the FIM can also be computed as

I(θ) = E_{D∼p(D|θ)} [∇²_θ ℓ(θ, D)].    (2)

It is important to note that the average is computed over the model distribution p(D|θ), not the data distribution p⋆(D). Averaging Eq. 1 over p⋆(D) results in the empirical Fisher matrix (Kunstner et al., 2019), while averaging Eq. 2 over p⋆(D) results in the Hessian of the loss. The FIM, the empirical Fisher and the Hessian are all different and sometimes confused (Thomas et al., 2020). We consider cases in which parameters θ⋆ are obtained by maximum likelihood:

θ⋆ = argmin_θ L(θ),    (3)

where the population loss function is defined as

L(θ) = E_{D∼p⋆} [ℓ(θ, D)],    (4)

which is in general non-convex, in particular when the model distribution is parameterised by deep neural networks. The natural gradient update takes the form

θ_{t+1} = θ_t − η I(θ)⁻¹ g(θ),  with  g(θ) = ∇_θ L(θ),    (5)

where η is a learning rate and g is the gradient of the loss. In practice, the true distribution of data p⋆(D) is unknown and we only have access to the empirical loss and a stochastic estimate ĝ of the gradient based on a finite sample of data D ∼ p⋆. Similarly, while the FIM may sometimes be computed exactly (Bernacchia et al., 2018; Huh, 2020), in most circumstances it is also estimated from a mini-batch D ∼ p. To guarantee invertibility of the FIM, a small amount of "damping" can be added to it when it is nearly singular. Note that Eq. 5 is identical to gradient descent when the FIM equals the identity matrix. The natural gradient has several nice properties that make it an efficient optimizer (Amari, 1998), but computing and inverting the FIM is usually costly.
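To make Eqs. 1 and 5 concrete, the following minimal sketch estimates the FIM of a toy linear-Gaussian model by Monte Carlo, sampling labels from the model (not the data) distribution, and then takes one damped natural-gradient step. The model choice, sizes and function names are our own illustrative assumptions, not the paper's code:

import jax, jax.numpy as jnp

def nll(theta, x, y):                 # ℓ(θ, D) for one data point, p(y|x,θ) = N(θᵀx, 1)
    return 0.5 * (y - x @ theta) ** 2

def fisher_mc(theta, key, n=10_000, d=5):
    kx, ky = jax.random.split(key)
    x = jax.random.normal(kx, (n, d))
    y = x @ theta + jax.random.normal(ky, (n,))   # labels sampled from the MODEL (Eq. 1)
    scores = jax.vmap(jax.grad(nll), (None, 0, 0))(theta, x, y)
    return scores.T @ scores / n                  # E[∇ℓ ∇ℓᵀ] ≈ I(θ), here ≈ identity

key = jax.random.PRNGKey(0)
theta = jax.random.normal(key, (5,))
F = fisher_mc(theta, key)

# One damped natural-gradient step (Eq. 5), with g from a data mini-batch:
X = jax.random.normal(key, (64, 5)); Y = jnp.zeros(64)
g = jax.grad(lambda th: jax.vmap(nll, (None, 0, 0))(th, X, Y).mean())(theta)
gamma, eta = 1e-3, 0.1
theta_new = theta - eta * jnp.linalg.solve(F + gamma * jnp.eye(5), g)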

2.2. THE LEGENDRE-FENCHEL CONJUGATE

The Legendre-Fenchel (LF) conjugate is a useful tool in optimization theory to map pairs of problems via duality (Boyd & Vandenberghe, 2004) and was introduced in the context of mechanics and thermodynamics (Zia et al., 2009). In the following, we deviate from the standard textbook notation to avoid conflicts with the rest of the paper and to facilitate the translation of these classical results to our problem. Consider a twice differentiable function H(δ) of a vector δ, generally non-convex, and assume a minimum of H exists. The LF conjugate is equal to

H⋆(u) = min_δ [H(δ) − uᵀδ].    (6)

The LF conjugate, also known as the convex conjugate, is defined also for non-differentiable functions and is always convex. We summarize here two properties of the conjugate H⋆ that we use in our derivations (see chapter 3.3 of Boyd & Vandenberghe, 2004 and Zia et al., 2009 for details):

Property 1. The gradient of the conjugate is equal to ∇_u H⋆(u) = δ̂(u), where

δ̂(u) = argmin_δ [H(δ) − uᵀδ].    (7)

The minimizer δ̂(u) also satisfies ∇_δ H(δ̂(u)) = u, implying that ∇_u H⋆ is a (local) inverse function of ∇_δ H. The inverse is global when the function H is strictly convex:

∇_u H⋆(u) = (∇_δ H)⁻¹(u).    (8)

Property 2. The Hessian matrix of the conjugate is equal to the inverse (if it exists) of the Hessian matrix of H, computed at δ̂(u):

∇²_u H⋆(u) = [∇²_δ H(δ̂(u))]⁻¹.    (9)
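As a quick numerical illustration of Properties 1 and 2 (a sketch under the assumption of a strictly convex quadratic H; our own example, not from the paper), one can verify that the minimizer δ̂(u) inverts ∇_δ H, and that the conjugate's Hessian inverts that of H:

import jax, jax.numpy as jnp

A = jnp.array([[3.0, 0.5], [0.5, 2.0]])             # H(δ) = ½δᵀAδ, strictly convex
H = lambda d: 0.5 * d @ (A @ d)

def delta_hat(u):                                   # argmin_δ H(δ) − uᵀδ; closed form A⁻¹u
    return jnp.linalg.solve(A, u)

u = jnp.array([1.0, -2.0])
assert jnp.allclose(jax.grad(H)(delta_hat(u)), u)   # Property 1: ∇_δH(δ̂(u)) = u
Hstar_hess = jax.jacobian(delta_hat)(u)             # ∇²_u H⋆ = ∇_u δ̂(u) by Eq. 7
assert jnp.allclose(Hstar_hess, jnp.linalg.inv(A), atol=1e-5)   # Property 2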

3. RELATED WORK

Natural gradient. The Hessian of non-convex losses is not positive definite, and second-order methods such as Newton's update typically do not converge to a minimum (Dauphin et al., 2014). The natural gradient method substitutes the Hessian with the FIM, which is positive (semi)definite (Amari, 1998). The natural gradient has been applied to deep learning in recent years (Martens & Grosse, 2015; Bernacchia et al., 2018; Huh, 2020; Kerekes et al., 2021), and has been proven approximately equivalent to local loss optimization (Benzing, 2022; Meulemans et al., 2020; Amid et al., 2022). Previous work on natural gradient approximately estimated and inverted the FIM analytically (Martens & Grosse, 2015; Martens et al., 2018). Instead, we use an exact formulation of the inverse FIM as the Hessian of the LF conjugate of a cross-entropy, and we provide an algorithm to meta-learn an approximation to the inverse FIM during training.

Legendre-Fenchel conjugate. For convex losses, Eq. 8 implies that the parameter update ∆θ = −∇_g L⋆(g(θ)) + ∇_g L⋆(0) converges to a minimum in one step (Maddison et al., 2021; Chraibi et al., 2019). However, this update is not practical because finding L⋆ is at least as hard as finding a minimum of the loss. Chraibi et al. (2019) propose to learn L⋆ by a neural network, and similar "amortized duality" approaches have recently been developed for optimal transport (Nhan Dam et al., 2019; Korotin et al., 2019; Makkuva et al., 2020; Amos, 2022). Maddison et al. (2021) prove that, if a surrogate for L⋆ satisfies relative smoothness in dual space, then linear convergence rates can be proven for non-Lipschitz or non-strongly convex losses. Here we consider non-convex losses and we use the LF conjugate of the cross-entropy for meta-learning the natural gradient update.

Meta-learning optimizers. Many standard optimizers have comparable performance when appropriately tuning a handful of static hyperparameters (Schmidt et al., 2021). Learned optimizers often outperform well-tuned standard ones by tuning a larger number of hyperparameters online during training (Chen et al., 2021; Hospedales et al., 2021; Bae et al., 2022). Different methods for learning optimizers have been proposed, including hypergradients (Maclaurin et al., 2015; Andrychowicz et al., 2016; Franceschi et al., 2017; Wichrowska et al., 2017; Wu et al., 2018; Park & Oliva, 2019; Micaelli & Storkey, 2021), reinforcement learning (Li & Malik, 2016; Bello et al., 2017), evolution strategies (Metz et al., 2019; Vicol et al., 2021), and implicit gradients (Lorraine et al., 2020; Rajeswaran et al., 2019; Clarke et al., 2021). In our work, the hyperparameters do not minimize the loss; instead, they minimize an auxiliary loss designed to learn the natural gradient update.

4.1. COMPUTATION OF THE NATURAL GRADIENT VIA THE LEGENDRE-FENCHEL CONJUGATE

We use the properties of the LF conjugate, in particular Eq. 9, to directly learn the inverse of the (damped) FIM, and use it in the natural gradient update of Eq. 5. This is different from previous work on natural gradient, which aimed at approximating the FIM analytically and then inverting it. We prove the following:

Theorem 1. Assume that the negative log-likelihood ℓ(θ, D) = −log p(D|θ) is twice differentiable. Let γ > 0 be a small damping parameter. Define the regularized cross-entropy between p(D|θ) and p(D|θ+δ),

H_γ(θ, δ) = E_{D∼p(D|θ)} [ℓ(θ + δ, D)] + (γ/2) ‖δ‖²,    (10)

and the function

δ̂_γ(θ, u) = argmin_δ [H_γ(θ, δ) − uᵀδ].    (11)

Then, the inverse damped Fisher information matrix exists and is equal to

(I(θ) + γI)⁻¹ = ∇_u δ̂_γ(θ, 0).    (12)

Proof. Here we provide a sketch of the proof; see Appendix A.2 for the full proof. Lemma 1 proves that the regularized cross-entropy in Eq. 10 has a unique minimum at δ = 0, therefore δ̂_γ(θ, 0) = 0. The proof continues by expressing the damped FIM as the Hessian of the regularized cross-entropy, which is done by reparameterizing Eq. 2. Next, we use Eq. 9 to express the inverse damped FIM as the Hessian of the LF conjugate of the regularized cross-entropy (seen as a function of δ). From there, the theorem follows from Eq. 7.

To clarify the connection with the LF conjugate, we note that Eq. 11 represents the gradient of the LF conjugate of the regularized cross-entropy, while Eq. 12 is the Hessian of the LF conjugate (see Appendix A.2 for details, and Appendix A.3 for an informal mathematical argument that helps in understanding Theorem 1). Our goal is to compute the damped natural gradient update by solving the optimization problem in Eq. 11, and substituting the inverse damped FIM computed by Eq. 12 into the definition of the natural gradient step in Eq. 5, yielding

θ_{t+1} = θ_t − η ∇_u δ̂_γ(θ_t, 0) g(θ_t).    (13)

Importantly, note that Eq. 11 does not need to be solved for all possible pairs of θ and u. Indeed, given the current value of parameters θ_t, it is sufficient to compute δ̂_γ(θ_t, u) (i) at small values of u, because ∇_u δ̂_γ is evaluated at u = 0 in Eq. 12, and (ii) along the direction u ∝ g(θ_t), because ∇_u δ̂_γ is multiplied by g(θ_t) in Eq. 13, such that knowing the slope of δ̂_γ(θ_t, u) along any other direction would be irrelevant. Nevertheless, computing Eq. 11 for such restricted (θ, u) pairs may still not be easier than computing the FIM by standard methods, e.g. by computing Eq. 1 and inverting the resulting matrix. In section 4.2, we therefore propose a practical instantiation of the ideas developed in this section, which yields an online algorithm with provable convergence guarantees.
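As a minimal worked example of Theorem 1 (our own illustration, not from the paper), consider a scalar Gaussian model p(D|θ) = N(D; θ, 1), for which ℓ(θ + δ, D) = ½(D − θ − δ)² + const and I(θ) = 1. The regularized cross-entropy is H_γ(θ, δ) = E_{D∼N(θ,1)}[½(D − θ − δ)²] + (γ/2)δ² = ½(1 + δ²) + (γ/2)δ², so that δ̂_γ(θ, u) = argmin_δ [(1 + γ)δ²/2 − uδ + ½] = u/(1 + γ). Hence ∇_u δ̂_γ(θ, 0) = 1/(1 + γ) = (I(θ) + γ)⁻¹, exactly as Eq. 12 states.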

4.2. ONLINE META-LEARNING OF NATURAL GRADIENT UPDATES

In principle, Theorem 1 provides an exact implementation of natural gradient by solving the optimization problem in Eq. 11 and then applying the update of Eq. 13. However, this approach is impractical due to the complexity of solving Eq. 11. To make it practical, we propose learning an approximation δ_γ(θ, u, λ) of the true δ̂_γ(θ, u), through a set of auxiliary parameters λ. The auxiliary parameters λ are updated online, during optimization, to push δ_γ towards a solution of Eq. 11 but without requiring convergence. This is akin to a meta-learning approach whereby only one or a few steps are taken towards a desired solution for meta-parameters λ, which are optimized concurrently with parameters θ in an inner loop (Hospedales et al., 2021). Here we choose gradient-based meta-learning updates (Finn et al., 2017), but other choices are possible, e.g. implicit differentiation (Lorraine et al., 2020; Rajeswaran et al., 2019) or evolutionary strategies (Metz et al., 2019; Gao & Sener, 2022). For the auxiliary parameterization, we choose a linear function of u,

δ_γ(θ, u, λ) = Q(λ) u,    (14)

where Q(λ) is a positive definite matrix which will thus effectively estimate the inverse damped FIM (see Eq. 12). Appropriate choices for Q should take into account the architecture of the model, as discussed further below. Although the r.h.s. of Eq. 14 does not depend on θ explicitly, an implicit dependence will arise from us learning the parameters λ in a way that depends on the momentary model parameters θ. Indeed, for learning λ, we perform gradient descent using Adam on the following auxiliary loss A, a choice justified by Eq. 11:

A_γ(θ, u, λ) = H_γ(θ, Q(λ)u) − uᵀQ(λ)u.    (15)

In summary, we alternate between updating λ to minimize the auxiliary loss in Eq. 15, and updating θ according to Eq. 13 to minimize the main loss:

λ_{t+1} = λ_t − α AdamUpdate(∇_λ A_γ(θ_t, ε g(θ_t), λ_t)),    (16)
θ_{t+1} = θ_t − η Q(λ_{t+1}) g(θ_t),    (17)

where α is a learning rate for the auxiliary parameters and ε is a small scalar (both are hyperparameters). Note that the auxiliary loss is computed at ε g(θ) because, as argued in section 4.1, it is sufficient to compute δ̂ at small values of u in the direction of u ∝ g(θ_t). In Appendix A.9, we provide an analysis of these coupled dynamics in a simple linear-Gaussian model. We parameterize Q(λ) as a positive definite Kronecker-factored block-diagonal matrix, with block sizes reflecting the layer structure of deep fully-connected neural networks (section A.5). Although this is similar to parameterizations used in previous studies of natural gradient (Martens & Grosse, 2015), we stress that our approach is more flexible since it learns the inverse FIM rather than approximating and inverting the FIM analytically. In fact, we show in Appendix A.6 that alternative forms of the matrix Q, different from those used before, may provide a better approximation of the inverse FIM (see Fig. 4).
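The following minimal sketch implements the coupled updates of Eqs. 14-17 on a toy linear-Gaussian regression problem, with a simple diagonal parameterization Q(λ) = diag(exp(λ)) (to keep Q positive definite) and plain gradient descent on λ in place of Adam. All names and hyperparameter values are illustrative assumptions, not the paper's code:

import jax, jax.numpy as jnp

key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (256, 10))
y = X @ jnp.ones(10)                                 # teacher targets

loss = lambda th: 0.5 * jnp.mean((y - X @ th) ** 2)  # empirical negative log-likelihood
F = X.T @ X / X.shape[0]                             # exact FIM of this Gaussian model
gamma, eps, eta, alpha = 1e-3, 0.01, 0.5, 0.1

def aux_loss(lam, u):                                # Eq. 15, with H_γ quadratic in δ here
    Qu = jnp.exp(lam) * u                            # δ = Q(λ)u with Q = diag(exp(λ))
    return 0.5 * Qu @ (F @ Qu) + 0.5 * gamma * Qu @ Qu - u @ Qu

theta, lam = jnp.zeros(10), jnp.zeros(10)
for t in range(100):
    g = jax.grad(loss)(theta)
    lam = lam - alpha * jax.grad(aux_loss)(lam, eps * g)   # Eq. 16 (SGD in place of Adam)
    theta = theta - eta * jnp.exp(lam) * g                 # Eq. 17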

4.3. CONVERGENCE OF ONLINE NATURAL GRADIENT

In this section we prove that the update of Eq. 17 converges to a minimum of the loss for the class of PL (Polyak-Łojasiewicz) functions, which are generally non-convex. Charles & Papailiopoulos (2018) proved that deep linear networks, which are non-convex (Saxe et al., 2013), are PL almost everywhere. We study deep linear networks empirically in section 5.1. We stress that the goal of this section is to prove convergence, rather than to obtain tight bounds on the convergence rate. We provide convergence guarantees for both the true gradient g and stochastic gradients ĝ. Theorem 2 is a special case of Theorem 1 in Radhakrishnan et al. (2020), while Theorem 3 is new to our knowledge. In both cases, the crucial assumption is that Q is a positive definite matrix, which holds for our chosen parameterization (see section A.5). For details on definitions and proofs, see Appendix A.4.

Theorem 2. Assume the loss function L is ξ-smooth and µ-PL. Let θ⋆ ∈ argmin_θ L(θ). Assume that the eigenvalues of Q(λ_t) are lower- and upper-bounded uniformly in time by λ_min > 0 and λ_max, respectively. We have the following rate of convergence for optimization using the update rule of Eq. 17 with η = 1/(ξλ_max):

L(θ_t) − L(θ⋆) ≤ (1 − µλ_min/(2ξλ_max))ᵗ (L(θ_0) − L(θ⋆)).    (18)

Theorem 3. Let the same assumptions as in Theorem 2 hold. In addition, assume that E[‖ĝ(θ)‖²] ≤ G². We have the following rate for optimization using update rule 17 with η_t = 2/(µλ_min(t+1)):

L(θ_t) − L(θ⋆) ≤ A/(t+1),  where  A = max{2ξλ²_max G²/(µ²λ²_min), L(θ_0) − L(θ⋆)}.    (19)
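These rates can be checked numerically. The sketch below verifies the bound of Eq. 18 for a quadratic loss, which is ξ-smooth and µ-PL with ξ = λmax(A) and µ = 2λmin(A), using a fixed positive definite Q; all constants are our own toy choices:

import jax.numpy as jnp

A = jnp.diag(jnp.array([10.0, 1.0, 0.1]))     # L(θ) = ½θᵀAθ
xi, mu = 10.0, 2 * 0.1                        # ξ = λmax(A); µ = 2λmin(A), since ‖∇L‖² ≥ 2λmin(A)·L
Q = jnp.diag(jnp.array([0.5, 1.0, 2.0]))      # plays the role of Q(λ_t), fixed here
lmin, lmax = 0.5, 2.0
eta = 1.0 / (xi * lmax)
rate = 1.0 - mu * lmin / (2 * xi * lmax)

theta = jnp.ones(3)
L0 = 0.5 * theta @ A @ theta
for t in range(1, 50):
    theta = theta - eta * Q @ (A @ theta)     # update rule of Eq. 17 (exact gradient)
    L = 0.5 * theta @ A @ theta
    assert L <= rate ** t * L0 + 1e-9         # geometric bound of Eq. 18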

4.4. DETAILS OF THE ALGORITHM AND FURTHER REMARKS

Here we detail our practical implementation of the online FishLeg algorithm presented in section 4.2, also summarized in Algorithm 1. Our code is available on GitHub.

Auxiliary optimization. Since H_γ is evaluated on a model-sampled mini-batch, it is a noisy quantity whose variance depends on the hyperparameter ε. While this variance could be reduced, e.g. by using an antithetic estimator (Gao & Sener, 2022), we opted instead for analytically taking the small-ε limit of the auxiliary loss. By Taylor-expanding Eq. 15 (using Eqs. 23 and 24), and dropping terms that are constant in λ, we arrive at

Ã_γ(θ, u, λ) ≡ uᵀ [½ Q(λ) ∇²_δ H_γ(θ, 0) Q(λ) − Q(λ)] u.    (20)

The auxiliary update of Eq. 16 thus becomes

λ_{t+1} = λ_t − α AdamUpdate(∇_λ Ã_γ(θ_t, g/‖g‖, λ_t)),

where the normalization of g is important for continuing to learn about curvature even when gradients are small. In order to avoid differentiating through the Hessian-vector product in Eq. 20, we note that the gradient of Ã w.r.t. λ can be written as

∇_λ Ã(θ, u, λ) = [∇²_δ H(θ, 0) Q(λ)u − u]ᵀ ∇_λ [Q(λ)u].    (21)

Algorithm 1 in the appendix shows how to implement this efficiently using automatic differentiation tools. As a side note, Eq. 21 makes it clear that Q will get to approximate the inverse of I(θ) + γI (which is also the Hessian of the regularized cross-entropy, see Eq. 24), at least in the relevant subspace occupied by u = g/‖g‖ over iterations. Moreover, note that the noise in H_γ simply adds stochasticity to the auxiliary loss but does not bias its gradient w.r.t. λ. Finally, Eq. 20 reveals a connection with Hessian-free optimization, which we discuss in Appendix A.8.

Damping. Practitioners of second-order optimization in deep learning have consistently found it critical to apply damping when the FIM becomes near-singular. Although adaptive damping schemes exist and could be incorporated in FishLeg, here we followed Goldfarb et al. (2020) and used static damping; specifically, we used a fixed γ in Eq. 10, treated as a tuned hyperparameter.

Momentum. We have found it useful to use momentum, as many other optimizers do (e.g. KFAC, KBFGS). In most of our experiments, we implemented momentum on g (as per line 16 of Algorithm 1). For our wall-clock time results (Fig. 3), however, we applied momentum to our natural gradient estimate Q(λ)g instead, in order to preserve the low-rank structure of g at each layer of the neural network. This enabled us to substantially speed up the computation of Q(λ)g on small mini-batches, but did not affect training and test errors noticeably.

Initialization of λ. We find it useful to initialize λ such that Q(λ_0) = (η_SGDm/η) × I, where η_SGDm is a learning rate that is known to give good results with SGD-with-momentum. With such an initialization, FishLeg therefore initially behaves like SGDm. Future work could investigate smarter initialization / warm-starting schemes, e.g. starting Q(λ) in an Adam-like diagonal approximation of the empirical Fisher.
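For concreteness, here is a small JAX sketch of the trick behind Eq. 21 (cf. lines 4-9 of Algorithm 1): build the surrogate dᵀδ with d held off the differentiation tape, so that one reverse pass yields ∇_λ Ã without differentiating through the Hessian-vector product. The quadratic H and the diagonal Q are our own toy stand-ins:

import jax, jax.numpy as jnp
from jax.lax import stop_gradient

F = jnp.array([[2.0, 0.3], [0.3, 1.0]])               # stand-in for ∇²_δ H_γ(θ, 0)
H = lambda d: 0.5 * d @ (F @ d)                       # regularized cross-entropy around θ

def hvp(f, x, v):                                     # Hessian-vector product via jvp-of-grad
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

def aux_grad(lam, u):
    def surrogate(lam):
        delta = jnp.exp(lam) * u                      # δ = Q(λ)u, on the tape
        d = hvp(H, jnp.zeros_like(u), stop_gradient(delta)) - u   # off the tape
        return d @ delta                              # reverse pass gives Eq. 21
    return jax.grad(surrogate)(lam)

lam, u = jnp.zeros(2), jnp.array([1.0, 0.5])
# Check against the direct gradient of Ã_γ (Eq. 20):
aux = lambda lam: (lambda Qu: 0.5 * Qu @ (F @ Qu) - u @ Qu)(jnp.exp(lam) * u)
assert jnp.allclose(aux_grad(lam, u), jax.grad(aux)(lam), atol=1e-5)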

5.1. DEEP LINEAR NETWORKS

We first studied FishLeg in the context of deep linear networks, whose loss function is non-convex and pathologically curved, and for which the Fisher matrix is known analytically (Bernacchia et al., 2018; Huh, 2020). The results presented in Fig. 1 were obtained with a network of 20 (linear) layers of size n = 20 each. We generated data by instantiating a teacher network with random Gaussian weights N(0, 1/n) in each layer, and a predictive density with mean equal to the activation of the last layer and isotropic Gaussian noise with variance σ₀² = 0.001². Input samples were drawn from a standard normal distribution, and the teacher's predictive distribution was sampled accordingly to obtain output labels. We investigated the behaviour of FishLeg both in the large data regime (4000 training samples) and the small data regime (40 training samples). In both regimes, FishLeg very rapidly drove the training loss to the minimum set by the noise (σ₀), a minimum which was attained by neither manually-tuned SGDm nor Adam (Fig. 1A). This is consistent with previous such comparisons in similar settings (Bernacchia et al., 2018). In the small data regime, FishLeg displays some overfitting but continues to compare favourably against first-order methods on test error. In wall-clock time, FishLeg was about 5 times slower than SGDm per training iteration in this case (Fig. 1C), but we show that a significant speedup can be obtained by updating λ every 10 iterations at little performance cost (see Fig. 3). We used this simple experiment to assess the effectiveness of auxiliary loss optimization during the course of training. The lower the auxiliary loss, the closer it is to the Legendre conjugate of the cross-entropy at g, and therefore the closer Q(λ)g is to the natural gradient. We assessed how much progress was made on the auxiliary loss, relative to how it would evolve if λ were held constant as θ is being optimized (Fig. 1B). This revealed two distinct phases of learning. At first, θ (and therefore the FIM) changes rapidly, but λ rapidly adapts to keep the auxiliary loss in check (compare black and dark purple curves); we speculate that this mostly reflects adaptation to the overall scale of the (inverse) FIM. Later on, as θ changes less rapidly, λ begins to learn the more subtle geometry of the FIM, giving rise to rather modest improvements on the auxiliary loss (compare black and lighter purple curves).
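For reference, here is a sketch of the teacher-student data generation described above (layer count, width and noise level are taken from the text; everything else, including names, is an illustrative assumption):

import jax, jax.numpy as jnp

n, depth, sigma0, n_train = 20, 20, 1e-3, 4000
keys = jax.random.split(jax.random.PRNGKey(0), depth + 2)
Ws = [jax.random.normal(keys[l], (n, n)) / jnp.sqrt(n) for l in range(depth)]  # N(0, 1/n) weights

def forward(Ws, x):                                  # deep *linear* network: no nonlinearity
    for W in Ws:
        x = x @ W.T
    return x

X = jax.random.normal(keys[-2], (n_train, n))        # standard normal inputs
Y = forward(Ws, X) + sigma0 * jax.random.normal(keys[-1], (n_train, n))  # sample teacher's predictive density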

5.2. SECOND-ORDER OPTIMIZATION BENCHMARK

We applied FishLeg to the auto-encoder benchmarks previously used to compare second-order optimization methods; the details of these experiments (model architectures, datasets, etc.) can be found in Goldfarb et al. (2020), and hyperparameters specific to FishLeg are listed in Table 1 (Appendix).

To compare FishLeg to other optimizers, we used the code provided by Goldfarb et al., which implements Adam, RMSprop, KFAC and KBFGS. These autoencoder benchmarks are difficult problems on which the family of second-order methods has been shown to improve substantially over first-order optimizers (e.g. SGDm) and optimizers that exploit second-order information in diagonal form (Adam, RMSprop). Yet, no clear differences in performance have emerged within second-order methods. As far as training loss and test errors are concerned, our results confirmed this trend: FishLeg performed similarly to KFAC and KBFGS on the FACES and MNIST datasets, although it did converge within fewer iterations on the CURVES dataset (Fig. 2, top).

Interestingly, however, we found FishLeg to be consistently more robust than other methods, especially when using relatively small batch sizes. In particular, it converged more reliably over repeated runs with different random initializations (Fig. 2, middle). This is perhaps due to other methods achieving their best performance in a near-critical hyperparameter regime in which optimization tends to fail half of the time (these failed runs are often discarded when averaging). In contrast, FishLeg achieved similar or better performance in a more stable hyperparameter regime, and did not fail a single one of the 10 runs we performed. Similarly, FishLeg was also more consistent than other methods, i.e. it displayed a lower variance in training loss across runs (Fig. 2, bottom).

Given the heterogeneity of software systems in which the various methods were implemented (e.g. JAX vs PyTorch), we ran a clean wall-clock time comparison between SGDm, KFAC and FishLeg using a unified CPU-only implementation applied to the FACES and MNIST benchmarks. This ensured, e.g., that the loss and its gradients were computed in exactly the same way across methods. Overall, one iteration of vanilla FishLeg was ∼5 times slower than one iteration of SGDm. However, we were able to bring this down to only twice as slow by updating λ every 10 iterations, which did not significantly affect performance. Combined with FishLeg's faster progress per iteration, this meant that FishLeg retained a significant advantage in wall-clock time over SGD (Fig. 3), similar to KFAC. In practice we think that it might make sense to update λ more frequently at the beginning of training, and let these updates become sparser as optimization progresses.

Figure 3: Wall-clock time comparisons on MNIST and FACES, with optimizers all implemented in the same way on CPU (Intel Xeon Platinum 8380H @ 2.90GHz) with OpenBLAS compiled for that architecture and multi-threaded with OpenMP (8 threads). Within each dataset, all curves show the same number of epochs to facilitate comparisons of CPU time per iteration. In FishLeg, the auxiliary parameters were updated only every 10 iterations, with no noticeable drop in performance but a large wall-clock speedup. Similarly, KFAC's preconditioning matrices were inverted at each layer only every 20 iterations.

6. DISCUSSION

We provided a general framework for approximating the natural gradient through online meta-learning of the LF conjugate of a specific cross-entropy. Our framework is general: different choices can be made for how to meta-learn the LF conjugate (Eq. 11), parameterize its gradient (Eq. 14), and evaluate/differentiate the auxiliary loss (Eq. 15). Beyond our specific implementation, future work will study alternative choices that may be more efficient. For example, implicit differentiation (Lorraine et al., 2020; Rajeswaran et al., 2019) or evolution strategies (Metz et al., 2019; Gao & Sener, 2022) may be used to meta-learn the LF conjugate. The auxiliary loss could also be evaluated without taking the small-ε limit, but using antithetic estimators to reduce variance (Gao & Sener, 2022). Alternative parameterizations could be used for δ or Q. Indeed, preliminary results show that more expressive choices for Q yield better approximations of the inverse FIM (Appendix A.6). Previous work on natural gradient aimed at computing and inverting analytically an approximation to the FIM, which was done on a case-by-case basis for dense (Martens & Grosse, 2015) and convolutional (Grosse & Martens, 2016) layers, and standard recurrent networks (Martens et al., 2018). By parameterizing the inverse FIM directly, our approach allows the user to express their assumptions about the structure of parameter precisions, which is typically easier to reason about than the structure of parameter covariances. We therefore expect that FishLeg will facilitate future applications of natural gradient optimization to a broader range of network architectures.

A APPENDIX

A.1 ALGORITHM

Algorithm 1 FishLeg algorithm (online setting)

1:  function UPDATE_AUX(θ, g, λ, Adam state)
2:    g ← g/‖g‖                                ▷ normalize gradient
3:    initialize the adjoint of λ to prepare for the automatic differentiation reverse pass (c.f. line 9)
4:    δ ← Q(λ)g                                ▷ can exploit fast matrix-vector products without forming Q
5:    ▷ Hessian-vector product on the regularized cross-entropy H (with δ taken off the automatic differentiation tape!), evaluated on a mini-batch different from the one used to obtain g:
6:    h ← Hess_vec_prod(fun δ → H(θ, δ), at δ = 0, along v = stop_gradient(δ))
7:    d ← h − g
8:    aux_loss ← ½ (d − g)ᵀ stop_gradient(δ)   ▷ log the value of the auxiliary loss
9:    ∆λ ← adjoint of λ after completion of reverse pass on surrogate auxiliary loss dᵀδ
10:   λ, Adam state ← Adam_update(λ, ∆λ, α, Adam state)   ▷ α is the learning rate
11:   return aux_loss, λ, Adam state
12:
13: t ← 0, initialize θ_0 and λ_0
14: while not converged do
15:   L, g ← value and gradient of the negative log-likelihood evaluated at θ_t on a mini-batch
16:   ḡ ← βḡ + (1 − β)g                        ▷ momentum
17:   ▷ Note: the following step need not be performed at every iteration; wall-clock time speedups can be obtained by running it every 10 iterations, though this might depend on the problem:
18:   aux_loss, λ_{t+1}, Adam state ← UPDATE_AUX(θ_t, ḡ, λ_t, Adam state)
19:   θ_{t+1} ← θ_t − η Q(λ_{t+1}) ḡ
20:   t ← t + 1

A.2 PROOF OF THEOREM 1

We start by stating the following Lemma, which is a modified version of Gibbs' inequality (see e.g. Chapter 2.6 of MacKay et al., 2003).

Lemma 1. For a fixed θ, the regularized cross-entropy between p(D|θ) and p(D|θ+δ),

H_γ(θ, δ) = E_{D∼p(D|θ)} [ℓ(θ + δ, D)] + (γ/2) ‖δ‖²,    (22)

has a unique global minimum at δ = 0.

Proof of Theorem 1. We start by computing the gradient and Hessian of the regularized cross-entropy H_γ(θ, δ) with respect to δ, computed at δ = 0:

∇_δ H_γ(θ, 0) = E_{D∼p(D|θ)} [∇_θ ℓ(θ, D)] + 0 = 0,    (23)
∇²_δ H_γ(θ, 0) = E_{D∼p(D|θ)} [∇²_θ ℓ(θ, D)] + γI = I(θ) + γI.    (24)

The gradient is equal to the expectation of the score function, which is known to be equal to zero (as also implied by Lemma 1). The Hessian is equal to the damped FIM, as expressed by Eq. 2. Note that the damped FIM is positive definite and so its inverse exists. Using the property of the LF conjugate in Eq. 9, we express the inverse damped Fisher matrix as the Hessian of the LF conjugate of H_γ. We denote by H⋆_γ(θ, u) the LF conjugate of H_γ(θ, δ), seen as a function of its second argument δ. Using the definition of the LF conjugate, Eq. 6, that is equal to

H⋆_γ(θ, u) = min_δ [H_γ(θ, δ) − uᵀδ].    (25)

We denote by δ̂_γ(θ, u) the minimizer of this expression, i.e.

δ̂_γ(θ, u) = argmin_δ [H_γ(θ, δ) − uᵀδ].    (26)

Using the property of the LF conjugate, Eq. 9, we have that

∇²_u H⋆_γ(θ, u) = [∇²_δ H_γ(θ, δ̂_γ(θ, u))]⁻¹.    (27)

Comparing with the expression (24) of the damped FIM, the right hand side of this equation is equal to the inverse damped FIM when δ̂_γ(θ, u) is equal to zero. By Lemma 1, we have that δ̂_γ(θ, 0) = 0 and therefore

(I(θ) + γI)⁻¹ = ∇²_u H⋆_γ(θ, 0).    (28)

Finally, using the properties of the LF conjugate (Eq. 7), we have that

δ̂_γ(θ, u) = ∇_u H⋆_γ(θ, u),    (29)

and the inverse FIM is equal to

(I(θ) + γI)⁻¹ = ∇_u δ̂_γ(θ, 0).    (30)

A.3 INFORMAL ARGUMENT FOR THEOREM 1

Theorem 1 suggests that computation of the inverse damped FIM requires computing δ̂_γ(θ, u) near u = 0. By Lemma 1, we have that δ̂_γ(θ, 0) = 0, therefore we may hypothesize that δ̂_γ(θ, u) is near zero when u is also near zero. Under this assumption, and using Eqs. 23 and 24, we may approximate the regularized cross-entropy H_γ(θ, δ) by a second-order Taylor expansion in δ:

H_γ(θ, δ) − uᵀδ ≈ H_γ(θ, 0) + ½ δᵀ(I(θ) + γI)δ − uᵀδ.

Minimizing this expression with respect to δ results in

δ̂_γ(θ, u) = (I(θ) + γI)⁻¹ u,    (31)

which implies the statement of the theorem, (I(θ) + γI)⁻¹ = ∇_u δ̂_γ(θ, 0).

A.4 PROOFS OF CONVERGENCE

Definition 1 (PL condition). We say that a function satisfies the PL condition with parameter µ ∈ ℝ_{>0} if the following holds:

‖∇L(θ)‖² ≥ µ (L(θ) − L(θ⋆)),  ∀θ.    (32)

We use the notation µ-PL to denote the class of functions satisfying the PL condition with parameter µ. This condition implies that every stationary point is a global minimum. It does not, however, imply either uniqueness of the global minimum or convexity.

Definition 2 (Smoothness). We say that a function is smooth with parameter ξ if the following holds:

L(θ′) ≤ L(θ) + ∇L(θ)ᵀ(θ′ − θ) + (ξ/2) ‖θ′ − θ‖²,  ∀θ, θ′.    (33)

We use the notation ξ-smooth to denote the class of smooth functions with parameter ξ.

Proof of Theorem 2. We show that the error at each iteration t+1 is linearly related to the error at iteration t:

L(θ_{t+1}) − L(θ_t) ≤ −η g(θ_t)ᵀQg(θ_t) + (ξη²/2) g(θ_t)ᵀQ²g(θ_t)    (34)
 ≤ −η g(θ_t)ᵀQg(θ_t) + (ξλ_max η²/2) g(θ_t)ᵀQg(θ_t)    (35)
 = −η (1 − ξλ_max η/2) g(θ_t)ᵀQg(θ_t)    (36)
 ≤ −(ηλ_min/2) ‖g(θ_t)‖²    (37)
 ≤ −(ηµλ_min/2) (L(θ_t) − L(θ⋆))    (38)
 = −(µλ_min/(2ξλ_max)) (L(θ_t) − L(θ⋆)).    (39)

The first line is a result of the smoothness assumption and the update rule of Eq. 17. The second line follows from the definition of λ_max. The third line is only a rearrangement of terms. The fourth line follows from the definition of λ_min and the choice of η = 1/(ξλ_max). The fifth line holds by the PL condition. The last line follows from replacing the value of η = 1/(ξλ_max). We thus have

L(θ_{t+1}) − L(θ⋆) ≤ (1 − µλ_min/(2ξλ_max)) (L(θ_t) − L(θ⋆)).

Applying this recursive relation over t, starting from θ_0, we arrive at

L(θ_t) − L(θ⋆) ≤ (1 − µλ_min/(2ξλ_max))ᵗ (L(θ_0) − L(θ⋆)).

Proof of Theorem 3. We show that the error at each iteration t+1 is related to the error at iteration t as follows:

E[L(θ_{t+1}) − L(θ_t) | θ_t] ≤ E[−η_t g(θ_t)ᵀQĝ(θ_t) + (ξη_t²/2) ĝ(θ_t)ᵀQ²ĝ(θ_t) | θ_t]    (40)
 ≤ −η_t g(θ_t)ᵀQg(θ_t) + E[(ξλ²_max η_t²/2) ‖ĝ(θ_t)‖² | θ_t]    (41)
 ≤ −η_t λ_min ‖g(θ_t)‖² + ξη_t² λ²_max G²/2    (42)
 ≤ −η_t µλ_min (L(θ_t) − L(θ⋆)) + ξη_t² λ²_max G²/2.    (43)

The first line is a result of the smoothness assumption and the update rule of Eq. 17. The second line follows from the definition of λ_max. The third line follows from the upper bound assumption on the norm of the gradient. The fourth line holds by the PL condition. We thus have

L(θ_{t+1}) − L(θ⋆) ≤ (1 − η_t µλ_min)(L(θ_t) − L(θ⋆)) + ξη_t² λ²_max G²/2.

Recall the choice of η_t = 2/(µλ_min(t+1)). By definition of A, we have L(θ_t) − L(θ⋆) ≤ A/(t+1) when t = 0. Using induction, we prove the same for L(θ_{t+1}) − L(θ⋆):

L(θ_{t+1}) − L(θ⋆) ≤ (1 − 2/(t+1)) A/(t+1) + 2ξλ²_max G²/(µ²λ²_min (t+1)²)    (44)
 ≤ A/(t+1) − A/(t+1)²    (45)
 = A (1/(t+1) − 1/(t+1)²)    (46)
 ≤ A/(t+2).    (47)

That completes the proof.

A.5 PARAMETERIZATION OF THE MATRIX Q
Consider a multi-layer perceptron (MLP) neural network with L layers. The activation a^ℓ_i of neuron i at layer ℓ is equal to

a^ℓ_i = σ( Σ_{j=1}^{N_{ℓ-1}+1} W^ℓ_{ij} a^{ℓ-1}_j )  for 1 ≤ i ≤ N_ℓ,    (48)

where layer ℓ has N_ℓ neurons and ℓ = 1, ..., L. The symbol W^ℓ denotes the weight matrix from layer ℓ−1 to layer ℓ, and has size N_ℓ × (N_{ℓ-1} + 1). This weight matrix includes a bias, obtained by setting the (N_ℓ+1)-th activation at each layer equal to one:

a^ℓ_{N_ℓ+1} = 1.    (49)

The function σ can be any nonlinearity, such as a ReLU, or a Softmax in the case of the last layer. The output of the neural network is defined as the activation in the last layer, a^L. For convenience of notation, the input to the neural network, denoted by x and of dimension N_0, is defined as the activation at layer 0:

a^0_i = x_i  for 1 ≤ i ≤ N_0,    (51)
a^0_{N_0+1} = 1.    (52)

We structure the matrix Q as block-diagonal, where each block corresponds to one layer and has the same number of rows and columns, equal to N_ℓ(N_{ℓ-1}+1). For layer ℓ, we parameterize the corresponding block, denoted by Q_ℓ, as

Q_ℓ = (R_ℓ R_ℓᵀ) ⊗ (L_ℓ L_ℓᵀ),    (53)

where the matrix R_ℓ has size (N_{ℓ-1}+1) × (N_{ℓ-1}+1) while the matrix L_ℓ has size N_ℓ × N_ℓ. Both matrices L_ℓ, R_ℓ are lower triangular. This parameterization ensures that the matrix Q is positive definite. The auxiliary parameters λ are represented by the matrices L_ℓ, R_ℓ for all layers ℓ = 1, ..., L, for a total of ½ Σ_{ℓ=1}^{L} [N_ℓ(N_ℓ+1) + (N_{ℓ-1}+1)(N_{ℓ-1}+2)] auxiliary parameters. This number is much smaller than the total number of entries of the matrix Q, which is equal to Σ_{ℓ=1}^{L} N_ℓ²(N_{ℓ-1}+1)².
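The block structure of Eq. 53 admits fast matrix-vector products without ever forming Q_ℓ. The sketch below (illustrative sizes and variable names of our own) checks the standard identity (B ⊗ A) vec(X) = vec(A X Bᵀ), with column-major vec, that underlies this:

import jax, jax.numpy as jnp

N_out, N_in = 4, 3                                   # layer output size and input size (incl. bias)
kL, kR, kG = jax.random.split(jax.random.PRNGKey(0), 3)
L = jnp.tril(jax.random.normal(kL, (N_out, N_out)))  # lower-triangular Kronecker factors
R = jnp.tril(jax.random.normal(kR, (N_in, N_in)))
G = jax.random.normal(kG, (N_out, N_in))             # gradient w.r.t. the layer's weight matrix

fast = L @ (L.T @ G @ R) @ R.T                       # (L Lᵀ) G (R Rᵀ): never forms Q_ℓ
Q = jnp.kron(R @ R.T, L @ L.T)                       # explicit Q_ℓ, for verification only
slow = (Q @ G.T.reshape(-1)).reshape(N_in, N_out).T  # column-major vec / unvec
assert jnp.allclose(fast, slow, atol=1e-3)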

A.6 PRELIMINARY INVESTIGATIONS OF MORE FLEXIBLE APPROXIMATIONS OF THE INVERSE FIM

Directly parameterizing and learning the inverse FIM lends more flexibility than parameterizing and learning the FIM in an easy-to-invert form. We conducted preliminary experiments with alternative parameterizations Q(λ) of the inverse FIM (i.e. beyond the one described in Appendix A.5), all constrained to afford fast Q(λ)g products. In particular, we experimented with:

• a simple diagonal matrix Q (constrained to be positive definite), mostly included as a useful baseline;
• the block Kronecker form used in the paper and described in Appendix A.5, known to perform much better than a diagonal matrix;
• a modification of the block Kronecker form that introduces full inner and outer diagonal rescaling: at each layer ℓ, we took Q_ℓ = A_ℓ (R_ℓ ⊗ L_ℓ) B_ℓ² (R_ℓᵀ ⊗ L_ℓᵀ) A_ℓ, where A_ℓ and B_ℓ are two diagonal matrices of the appropriate size. The presence of B_ℓ makes this similar to EKFAC (George et al., 2018), but A_ℓ makes it even more general;
• a sum of block Kronecker approximations: at each layer ℓ, we took Q_ℓ = (R_ℓ⁽¹⁾ R_ℓ⁽¹⁾ᵀ) ⊗ (L_ℓ⁽¹⁾ L_ℓ⁽¹⁾ᵀ) + (R_ℓ⁽²⁾ R_ℓ⁽²⁾ᵀ) ⊗ (L_ℓ⁽²⁾ L_ℓ⁽²⁾ᵀ).

All these approximations lend themselves to efficient Qg products, e.g. by exploiting standard properties of the Kronecker product (a sketch for the sum-of-Kronecker case is given below). To assess how well these approximations could approximate the inverse FIM at a point in optimization where high-quality curvature information matters, we pre-trained an MLP with layer sizes [784, 300, 100, 30, 100, 300, 784] for 200 iterations on MNIST autoencoding using SGD-with-momentum. The resulting model parameters θ were subsequently frozen, and the auxiliary loss was minimized to convergence (asymptotic value shown here) under each of the four different parameterizations of Q(λ) described above. We found that diagonal rescaling improved only slightly over the vanilla block Kronecker approximation. However, there was a larger improvement going from the block Kronecker approximation to the sum-of-block-Kronecker approximation; indeed, the improvement on the auxiliary loss was a sizeable fraction of the improvement going from a diagonal approximation to single block-Kronecker terms. The latter is known to be highly consequential for training, and so we speculate that sum-of-Kronecker approximations might improve natural gradient descent in future work.
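A minimal sketch of the fast product for the sum-of-Kronecker parameterization (reusing the vec identity from the A.5 sketch; the helper name is ours): the trick is applied term by term and the results added, so the cost stays cubic in the layer width,

import jax.numpy as jnp

def sum_kron_qg(L1, R1, L2, R2, G):
    term = lambda L, R: L @ (L.T @ G @ R) @ R.T      # (L Lᵀ) G (R Rᵀ) via the vec identity
    return term(L1, R1) + term(L2, R2)               # Q⁽¹⁾g + Q⁽²⁾g, still O(N³) per layer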

A.7 FISHLEG COMPLEXITY

In this section, we provide an analysis of the complexity of FishLeg for a deep fully connected network with L layers of size N each, processing data in mini-batches of size K. We assume that Q has the block-diagonal Kronecker form of Eq. 53. The complexity of the forward pass is O(LKN²), and with automatic differentiation the backward pass has the same complexity. At each layer, the gradient of the loss w.r.t. the weight matrix has rank R = min(N, K), such that computing the Qg product required for the FishLeg update step (Eq. 17) using Kronecker identities has complexity O(LKN²) or O(LN³), whichever is smallest. Finally, the inner update step (Eq. 16) is dominated by (i) the Qg product (line 4 of Algorithm 1) and (ii) the Hessian-vector product (line 6), which has the same O(LKN²) complexity as the forward pass. Altogether, FishLeg has complexity O(LKN²). In contrast, KFAC has complexity O(LN² max(K, N)), as the N³ cost of inverting the Kronecker factors is unavoidable even when the batch size is small.



Figure 1: Application to deep linear networks. A deep linear network (20 layers of size 20 each) with Gaussian output likelihood is trained on a number of data samples (left: 4000; right: 40) generated by a noisy teacher in the same model class. (A) Loss evaluated on the training (solid) and testing set (dashed), as a function of training epoch. The gray line shows the noise floor, i.e. the test loss evaluated on the teacher network. Lines show the mean across 10 independent experiments (teacher, initial conditions, etc.), shadings show ±1 stdev. (B) Evolution of the auxiliary loss during training (black), with Fisher-matrix-vector products evaluated exactly according to the analytical expression given in (Bernacchia et al., 2018). To assess how well the auxiliary optimization steps minimize the auxiliary loss (a moving target when θ changes), the auxiliary parameters λ are frozen every 500 iterations and subsequently used to re-evaluate the auxiliary loss (colored lines). (C) Same as (A), here as a function of wall-clock time. Parameters: mini-batch size = 40, η = 0.04, α = 0.001, β = 0.9, η_SGDm = 0.002, η_Adam = 0.0002.

Figure 2: Comparison between FishLeg and other optimizers on standard auto-encoder benchmarks (batch size 100). Top: Training loss and test error as a function of training epochs on the 3 datasets, averaged over 10 different seeds. Middle: reliability of FishLeg across runs. For each optimizer and dataset, the fraction of runs (estimated from 10 runs) that converged properly (i.e. where the loss did not run off to NaN) is shown; error bars show ±1 s.e.m. Bottom: variability (standard deviation) of the training loss across runs, averaged over epochs (discarding any epochs where the loss might have exploded). See supplementary Fig. 5 for results on batch size 1000.

Figure 4: Comparison of various forms of Q(λ). See text in Appendix A.6 for details.

A.8 FISHLEG AS AMORTIZATION OF CG STEPS IN HESSIAN-FREE OPTIMIZATION

FishLeg can be thought of as a way of gradually amortizing the conjugate gradients (CG)-based inner loop in Hessian-free optimization. Looking back at Eq. 20, one might wonder: why not directly optimize over the product v = Q(λ)g, instead of optimizing over λ? Treating v as free parameters, this would entail minimizing ½ vᵀ(I(θ) + γI)v − vᵀg, which is in fact exactly what Hessian-free optimization does using CG (Byrd et al., 2011; Martens et al., 2010; though I is sometimes replaced by the Gauss-Newton approximation to the Hessian). Although each step of CG has a cost comparable to the computation of g, in practice several iterations are required, which tends to remove any advantage in wall-clock time. By learning Q(λ) instead of Q(λ)g, FishLeg learns to amortize those CG steps over a progressively growing subspace of noisy gradients.
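For comparison, here is a matrix-free sketch of the CG inner solve that Hessian-free methods run at each iteration (toy damped Fisher of our own choosing; jax.scipy's CG accepts a linear map, so no matrix is formed or inverted):

import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

F = jnp.array([[2.0, 0.3], [0.3, 1.0]])          # stand-in Fisher I(θ)
gamma = 1e-3
g = jnp.array([1.0, -0.5])

mvp = lambda v: F @ v + gamma * v                # damped Fisher-vector product
v, _ = cg(mvp, g, maxiter=50)                    # v ≈ (I(θ)+γI)⁻¹g, i.e. minimizes ½vᵀ(I+γI)v − vᵀg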

A.9 ANALYTICAL RESULTS IN A SIMPLE LINEAR-GAUSSIAN SYSTEM

Here, we provide an analytical derivation of the behaviour of FishLeg in a simple one-layer linear network with a Gaussian likelihood, and draw some insights. Consider a regression model of the form p(y|x) = N(y; θᵀx, 1), and a teacher in the same model class with θ⋆ = 0. Let the input distribution be x ∼ N(0, Σ). The log-likelihood, in the limit of large data batches sampled from the teacher network, is given by ℓ(θ) = ½ θᵀΣθ. For analytical convenience, we will work in the continuous-time limit. In this limit, the standard gradient flow is given by

dθ/dt = −∇_θ ℓ(θ) = −Σθ.    (54)

This flow is easily seen to converge to the correct solution (the teacher parameter θ⋆ = 0), but to do so slowly along the eigenvectors of Σ associated with small eigenvalues. Indeed, writing θ = U a where the columns of U are the orthonormal eigenvectors of Σ, Eq. 54 implies a_i(t) = a_i(0) e^{−λ_i t}, where λ_i > 0 is the i-th eigenvalue of Σ. Thus, when Σ is poorly conditioned, convergence is very slow in its bottom subspace. The question arises: how well does FishLeg mitigate this problem?

For this simple model, the Fisher matrix is I = Σ, and the cross-entropy of Eq. 10 (unregularized, as it is not necessary here) is given by H(·, δ) = ½ δᵀΣδ. To simplify our analytical derivations, instead of using Adam for optimizing the auxiliary loss as we do in our experiments, here we consider simple gradient descent in continuous time,

dQ/dt = −α ∇_Q Ã(θ, u, Q),    (57)

which is evaluated at u = g/‖g‖, where g = Σθ is the momentary gradient of the primary loss (the model's negative log-likelihood). Expanding the auxiliary loss of Eq. 20, Ã(θ, u, Q) = uᵀ(½ QΣQ − Q)u, we obtain

dQ/dt = −α [½ (ΣQᵀuuᵀ + uuᵀQᵀΣ) − uuᵀ].    (58)

Letting z = QΣθ, and after some algebra, the auxiliary flow of Q in Eq. 57 implies the following flow for z:

dz/dt = αΣ(θ − z).    (59)

With the same notation, the primary FishLeg flow is

dθ/dt = −z.    (60)

Thus, in this simple linear-Gaussian setup, FishLeg boils down to a pair of coupled linear ODEs. One can already see that, assuming a separation of timescales with α ≫ 1, such that θ changes much more slowly than Q (i.e. than z), the FishLeg flow becomes dθ/dt = −θ, which implies exponential decay of the loss irrespective of Σ; indeed, for this model it is exactly natural gradient descent. To drive the point home, we rewrite these coupled ODEs in the orthonormal eigenbasis of Σ = UΛUᵀ, which yields, for each mode i,

d/dt (θ̃_i, z̃_i)ᵀ = [[0, −1], [αλ_i, −αλ_i]] (θ̃_i, z̃_i)ᵀ,

where we have defined θ̃ = Uᵀθ and z̃ = Uᵀz, and Λ is a diagonal matrix containing the eigenvalues {λ_i} of the input covariance. The eigenvalues {β_i} of this state matrix are then easily shown to satisfy

αλ_i (1 + β_i) = −β_i²,    (61)

assuming the correct ordering of the β_i's w.r.t. the λ_i's. For each λ_i, as α grows large (i.e. good separation of timescales between θ and Q), there are only two ways for β_i to satisfy Eq. 61. Either β_i remains finite (O(α⁰)), implying that (1 + β_i) must be O(1/α), which in turn implies β_i → −1. Or β_i grows large and real, in which case (1 + β_i) must be negative because −β_i² is, and more specifically β_i must grow as −O(α). These qualitative arguments can be confirmed by writing down the closed-form solution of the quadratic Eq. 61. In summary, with sufficient separation of timescales between the primary and auxiliary parameter updates, FishLeg converges uniformly in all directions in parameter space, with a uniform dominant timescale that does not depend on the condition number of Σ. This is in contrast with standard gradient descent, which is slowed down by the small negative eigenvalues in the spectrum of its state matrix −Σ.

Note that in the above derivations, the transition from Eq. 57 to Eq. 58 made the implicit assumption that Q was unconstrained; in particular, it could lose symmetry. In this case, the derivation shows that unless the auxiliary flow is substantially faster than the primary flow, FishLeg can suffer from oscillations / poor damping, especially in the bottom subspace of Σ (c.f. solutions of Eq. 61 for small α). In practice, however, one would normally write Q = LLᵀ with a positivity constraint on the diagonal of L, and formulate the auxiliary flow in terms of dL/dt instead. This would guarantee Q ≻ 0 throughout optimization. Although analytical derivations become more tedious in this case, we speculate based on numerical simulations that this parameterization mitigates the problem of oscillatory optimization dynamics.
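The two regimes of Eq. 61 are easy to check numerically; a quick sketch of the argument above (our own illustration): for small α the roots are complex, matching the oscillatory regime discussed above, while for large α they split towards −1 and −O(α),

import jax.numpy as jnp

lam = 0.01                                        # a small eigenvalue λ_i of Σ
for alpha in [1.0, 10.0, 1e2, 1e4]:
    # β² + αλβ + αλ = 0, equivalently αλ(1 + β) = −β² (Eq. 61)
    print(alpha, jnp.roots(jnp.array([1.0, alpha * lam, alpha * lam])))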

