FISHER-LEGENDRE (FISHLEG) OPTIMIZATION OF DEEP NEURAL NETWORKS

Abstract

Incorporating second-order gradient information (curvature) into optimization can dramatically reduce the number of iterations required to train machine learning models. In natural gradient descent, such information comes from the Fisher information matrix, which yields a number of desirable properties. As exact natural gradient updates are intractable for large models, successful methods such as KFAC and its successors approximate the Fisher in a structured form that can easily be inverted. However, this requires model- and layer-specific tensor algebra, along with approximations that are often difficult to justify. Here, we use ideas from Legendre-Fenchel duality to learn a direct and efficiently evaluated model for the product of the inverse Fisher with any vector, in an online manner, leading to natural gradient steps that get more accurate over time despite noisy gradients. We prove that the resulting "Fisher-Legendre" (FishLeg) optimizer converges to a (global) minimum of non-convex functions satisfying the PL condition, which applies in particular to deep linear networks. On standard auto-encoder benchmarks, we show empirically that FishLeg outperforms standard first-order optimization methods, and performs on par with or better than other second-order methods, especially when using small batches. Thanks to its generality, we expect our approach to facilitate the handling of a variety of neural network layers in future work.
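As a toy illustration of the duality-based idea sketched above (and not the paper's actual parameterization or algorithm), note that for a positive-definite Fisher F and a fixed vector u, the quadratic auxiliary objective A(Q) = ½ uᵀQFQu − uᵀQu is minimized over symmetric Q precisely when Qu = F⁻¹u. Gradient descent on A therefore learns an inverse-Fisher-vector product while only ever evaluating Fisher-vector products; F is never inverted explicitly. A minimal NumPy sketch of this principle, with all names and constants hypothetical:

```python
import numpy as np

def aux_grad(Q, F, u):
    """Gradient of A(Q) = 0.5 * u^T Q F Q u - u^T Q u with respect to Q.

    A is minimized (over symmetric Q) when Q u = F^{-1} u, so descending
    this gradient teaches Q to produce inverse-Fisher-vector products
    using only Fisher-vector products (F @ (Q @ u)); F is never inverted.
    """
    v = F @ (Q @ u)  # Fisher-vector product
    return 0.5 * (np.outer(u, v) + np.outer(v, u)) - np.outer(u, u)

# Toy example: a fixed 2x2 "Fisher" and a target vector u.
F = np.array([[2.0, 0.0], [0.0, 4.0]])
u = np.array([1.0, 1.0])
Q = np.eye(2)  # initial guess at F^{-1}
for _ in range(2000):
    Q -= 0.05 * aux_grad(Q, F, u)  # plain gradient descent on A
```

After this loop, Q @ u closely matches the true F⁻¹u = [0.5, 0.25]. The point of the construction is that the same recipe applies when F is only accessible through (stochastic) Fisher-vector products, which is what makes an online, amortized approach possible.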

1. INTRODUCTION & SUMMARY OF CONTRIBUTIONS

The optimization of machine learning models often benefits from the use of second-order gradient (curvature) information, which can dramatically increase the per-iteration progress on the training loss. However, second-order optimizers for machine learning face two key challenges:

1. ML models tend to have many parameters, such that estimating the curvature along all dimensions is intractable. There are normally two ways around this: (i) using iterative methods where each iteration only exploits curvature information along a single dimension ("Hessian-free" methods; Martens, 2010), and (ii) developing custom curvature approximations that can be efficiently inverted to obtain parameter updates (Martens & Grosse, 2015; Grosse & Martens, 2016; Botev et al., 2017; George et al., 2018; Bahamou et al., 2022; Soori et al., 2022). The latter yields state-of-the-art performance in the optimization of deep networks (Goldfarb et al., 2020). However, structured curvature approximations must be developed on a case-by-case basis for each architecture (e.g. fully-connected or convolutional layers) and rely on mathematical assumptions that are difficult to justify.

2. ML datasets tend to be large as well, such that the loss, its gradients, and its curvature can only be estimated stochastically from mini-batches. While noise can be mitigated by a combination of large mini-batches and momentum, estimating the various components of the curvature matrix before inverting it (as opposed to estimating the inverse directly) introduces a bias that can be detrimental.
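To make the notion of a curvature-preconditioned update concrete, the following is a minimal sketch of one damped natural-gradient step using the empirical Fisher (the average outer product of per-example gradients). The function name, damping scheme, and constants are illustrative, not the method proposed here; indeed, the sketch does exactly what point 1 above warns against, forming and solving against the full d×d Fisher, which is only feasible for small d:

```python
import numpy as np

def natural_gradient_step(theta, per_example_grads, lr=0.1, damping=1e-3):
    """One damped natural-gradient step (illustrative only).

    per_example_grads: (N, d) array of gradients, one row per example.
    The empirical Fisher is estimated as the average outer product of
    per-example gradients; Tikhonov damping keeps it well-conditioned.
    Explicitly forming and solving against this d x d matrix is what
    becomes intractable for models with many parameters.
    """
    g_bar = per_example_grads.mean(axis=0)
    fisher = per_example_grads.T @ per_example_grads / len(per_example_grads)
    fisher += damping * np.eye(theta.size)
    return theta - lr * np.linalg.solve(fisher, g_bar)
```

For instance, if every per-example gradient equals theta itself (as for a loss of ½‖θ‖² on each example), the step shrinks theta toward the minimum at zero. Mini-batch noise in per_example_grads perturbs both the Fisher estimate and its inverse, illustrating the bias issue raised in point 2.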

