FISHER-LEGENDRE (FISHLEG) OPTIMIZATION OF DEEP NEURAL NETWORKS

Abstract

Incorporating second-order gradient information (curvature) into optimization can dramatically reduce the number of iterations required to train machine learning models. In natural gradient descent, such information comes from the Fisher information matrix, which yields a number of desirable properties. As exact natural gradient updates are intractable for large models, successful methods such as KFAC and its successors approximate the Fisher in a structured form that can easily be inverted. However, this requires model- or layer-specific tensor algebra and certain approximations that are often difficult to justify. Here, we use ideas from Legendre-Fenchel duality to learn a direct and efficiently evaluated model for the product of the inverse Fisher with any vector, in an online manner, leading to natural gradient steps that become more accurate over time despite noisy gradients. We prove that the resulting "Fisher-Legendre" (FishLeg) optimizer converges to a (global) minimum of non-convex functions satisfying the PL condition, which applies in particular to deep linear networks. On standard auto-encoder benchmarks, we show empirically that FishLeg outperforms standard first-order optimization methods, and performs on par with or better than other second-order methods, especially when using small batches. Thanks to its generality, we expect our approach to facilitate the handling of a variety of neural network layers in future work.

1. INTRODUCTION & SUMMARY OF CONTRIBUTIONS

The optimization of machine learning models often benefits from the use of second-order gradient (curvature) information, which can dramatically increase the per-iteration progress on the training loss. However, second-order optimizers for machine learning face a number of key challenges:

1. ML models tend to have many parameters, such that estimating the curvature along all dimensions is intractable. There are normally two ways around this: (i) using iterative methods where each iteration only exploits curvature information along a single dimension ("Hessian-free" methods; Martens et al., 2010), and (ii) developing custom curvature approximations that can be efficiently inverted to obtain parameter updates (Martens & Grosse, 2015; Grosse & Martens, 2016; Botev et al., 2017; George et al., 2018; Bahamou et al., 2022; Soori et al., 2022). The latter yields state-of-the-art performance in the optimization of deep networks (Goldfarb et al., 2020). However, structured curvature approximations must be developed on a case-by-case basis for each architecture (e.g. fully-connected or convolutional layers), and they require mathematical assumptions that are difficult to justify.

2. ML datasets tend to be large as well, such that the loss, its gradients, and its curvature can only be estimated stochastically from mini-batches. While noise can be mitigated by a combination of large mini-batches and momentum, estimating the various components of the curvature matrix before inverting it (as opposed to estimating the inverse directly) introduces a bias that can potentially be detrimental.

Here, we focus primarily on the Fisher information matrix (FIM) as a notoriously effective source of curvature information for training ML models. For probabilistic models, preconditioning the gradient by the inverse FIM yields so-called natural gradient updates (Amari, 1998).
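The bias from inverting a noisy curvature estimate follows from Jensen's inequality: the expectation of an inverse exceeds the inverse of the expectation. The toy example below is our own illustration (not from the paper), using a scalar "curvature" with uniform noise so that the inverse is well-defined for every sample:

```python
import numpy as np

# Our own toy illustration of the inversion bias: for a noisy scalar
# curvature estimate c_hat with E[c_hat] = c_true, Jensen's inequality
# gives E[1 / c_hat] > 1 / E[c_hat]. This motivates estimating the
# inverse curvature directly rather than inverting a noisy estimate.
rng = np.random.default_rng(0)
c_true = 2.0
# Strictly positive noise, so 1 / c_hat never blows up.
c_hat = rng.uniform(c_true - 0.5, c_true + 0.5, size=1_000_000)

mean_of_inverse = np.mean(1.0 / c_hat)   # biased: exceeds 1 / c_true
inverse_of_mean = 1.0 / np.mean(c_hat)   # close to 1 / c_true = 0.5
```

Here `mean_of_inverse` comes out strictly above `inverse_of_mean` (analytically, ln(2.5/1.5) ≈ 0.511 vs. 0.5), and the same effect holds matrix-wise when averaging noisy Fisher blocks before inversion.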
We introduce a novel framework for second-order optimization of high-dimensional probabilistic models in the presence of gradient noise. Instead of approximating the FIM in a way that can easily be inverted (e.g. as in KFAC; Martens & Grosse, 2015, and related approaches), we directly parameterize the inverse FIM. We make the following contributions:

• We show that the inverse FIM can be computed through the Legendre-Fenchel conjugate of a cross-entropy between model distributions.

• We provide an algorithm (FishLeg) which meta-learns the inverse FIM in an online fashion, and we prove convergence of the corresponding parameter updates.

• We first study its application to deep linear networks, an example of non-convex and pathologically curved loss functions with a Fisher matrix known in closed form (Bernacchia et al., 2018; Huh, 2020), and find that convergence occurs much faster than with SGD with momentum or Adam.

• We then show that, on standard auto-encoder benchmarks, FishLeg operating on a block-diagonal Kronecker parameterization of the inverse Fisher performs similarly to, and sometimes outperforms, previous approximate natural gradient methods (Goldfarb et al., 2020), whilst being only about twice as slow as SGD with momentum in wall-clock time per iteration.

Similar to Amortized Proximal Optimization (Bae et al., 2022), FishLeg can accommodate arbitrary parameterizations of the inverse Fisher, thus facilitating future applications of the natural gradient to a broad range of network architectures where manually approximating the FIM in an easily invertible form is otherwise difficult.
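To illustrate why a Kronecker parameterization of the inverse Fisher is cheap to apply, here is a minimal sketch of a per-layer inverse-Fisher-vector product. The factor names (L, R) and the Cholesky-style parameterization Q = (L Lᵀ) ⊗ (R Rᵀ) are our own illustrative assumptions, not necessarily the paper's exact choices; the point is the standard identity (A ⊗ B) vec(G) = vec(B G Aᵀ), which avoids ever forming the full matrix:

```python
import numpy as np

def kron_inv_fisher_vecprod(L, R, G):
    """Apply Q = (L @ L.T) kron (R @ R.T) to vec(G) without forming Q.

    Illustrative sketch: L acts on the input dimension of a weight
    gradient G (shape n_out x n_in), R on the output dimension, using
    the column-major identity (A kron B) vec(G) = vec(B @ G @ A.T).
    """
    A = L @ L.T   # positive semi-definite input-side factor
    B = R @ R.T   # positive semi-definite output-side factor
    return B @ G @ A.T

rng = np.random.default_rng(1)
n_out, n_in = 4, 3
L = rng.normal(size=(n_in, n_in))
R = rng.normal(size=(n_out, n_out))
G = rng.normal(size=(n_out, n_in))   # stands in for a weight gradient

fast = kron_inv_fisher_vecprod(L, R, G)
# Sanity check against the explicit Kronecker product (column-major vec).
full = np.kron(L @ L.T, R @ R.T) @ G.flatten(order='F')
```

The fast path costs O(n_out² n_in + n_out n_in²) instead of O(n_out² n_in²), which is what makes per-layer Kronecker structure practical for large layers.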

2.1. FISHER INFORMATION AND THE NATURAL GRADIENT

We consider a probabilistic model parameterized by a vector θ, which attributes a negative log-likelihood ℓ(θ, D) = −log p(D | θ) to any collection D of data points drawn from a data distribution p*(D). This covers a broad range of models, including discriminative models for regression or classification, as well as generative models for density modelling. The goal is to find parameters θ* such that the model distribution p(D | θ*) approximates the true data distribution p*(D). The Fisher information matrix (FIM) measures how much information can be obtained about parameters θ after observing data D under the model p(D | θ), and captures redundancies between parameters (Rao, 1992). The FIM is defined as

    I(θ) = E_{D ∼ p(D|θ)} [∇_θ ℓ(θ, D) ∇_θ ℓ(θ, D)^⊤]    (1)

By this definition, the FIM is a positive semi-definite matrix. It can be shown that under certain regularity conditions, and if ℓ is twice differentiable w.r.t. θ, then the FIM can also be computed as

    I(θ) = E_{D ∼ p(D|θ)} [∇²_θ ℓ(θ, D)]    (2)

It is important to note that both averages are computed over the model distribution p(D | θ), not the data distribution p*(D). Averaging Eq. 1 over p*(D) results in the empirical Fisher matrix (Kunstner et al., 2019), while averaging Eq. 2 over p*(D) results in the Hessian of the loss. The FIM, the empirical Fisher and the Hessian are all different and sometimes confused (Thomas et al., 2020). We consider cases in which parameters θ* are obtained by maximum likelihood:

    θ* = arg min_θ L(θ)    (3)

where the population loss function is defined as

    L(θ) = E_{D ∼ p*(D)} [ℓ(θ, D)]    (4)

which is in general non-convex, in particular when the model distribution is parameterized by deep neural networks. The natural gradient update takes the form

    θ_{t+1} = θ_t − η I(θ_t)^{-1} g(θ_t)    with g(θ) = ∇_θ L(θ)    (5)
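To make the definitions concrete, here is a minimal Monte Carlo sketch of Eq. 1 and the natural gradient update of Eq. 5. The 1-D Gaussian model, sample sizes, and damping term are our own illustrative choices (the paper works with deep networks, but the definitions are identical); crucially, the FIM estimate samples D from the model p(D | θ), while the gradient is averaged over data from p*(D):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy model for illustration (our choice, not the paper's): a 1-D Gaussian
# with theta = (mu, log_sigma), so that
# l(theta, d) = 0.5 * ((d - mu) / sigma)^2 + log_sigma + const.
def score(theta, d):
    # Per-sample gradient of the negative log-likelihood w.r.t. theta.
    mu, log_sigma = theta
    sigma = np.exp(log_sigma)
    z = (d - mu) / sigma
    return np.stack([-z / sigma, 1.0 - z**2], axis=-1)

def fisher_mc(theta, n_samples=200_000):
    # Eq. 1: expected outer product of scores, with D sampled from the
    # MODEL p(D | theta) -- not from the data distribution p*(D).
    mu, log_sigma = theta
    d = rng.normal(mu, np.exp(log_sigma), size=n_samples)
    s = score(theta, d)
    return s.T @ s / n_samples

theta = np.array([0.5, 0.2])
I_hat = fisher_mc(theta)
# For this model the FIM is known in closed form: diag(1 / sigma^2, 2).
I_true = np.diag([np.exp(-2 * theta[1]), 2.0])

# One natural gradient step (Eq. 5) on data drawn from p*(D), with a small
# damping term added to the estimated FIM for numerical safety.
data = rng.normal(2.0, 1.0, size=1024)       # samples from p*(D)
g = score(theta, data).mean(axis=0)          # stochastic gradient of L(theta)
eta, damping = 0.5, 1e-4
theta_new = theta - eta * np.linalg.solve(I_hat + damping * np.eye(2), g)
```

Swapping the model samples in `fisher_mc` for data samples would instead yield the empirical Fisher, which, as noted above, is a different object.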

