META-STORM: GENERALIZED FULLY-ADAPTIVE VARIANCE REDUCED SGD FOR UNBOUNDED FUNCTIONS

Abstract

We study the application of variance reduction (VR) techniques to general nonconvex stochastic optimization problems. In this setting, the recent work STORM (Cutkosky & Orabona, 2019) overcomes the drawback of having to compute gradients of "mega-batches" that earlier VR methods rely on. STORM utilizes recursive momentum to achieve the VR effect and is later made fully adaptive in STORM+ (Levy et al., 2021), where full-adaptivity removes the requirement of obtaining certain problem-specific parameters, such as the smoothness of the objective and bounds on the variance and norm of the stochastic gradients, in order to set the step size. However, STORM+ crucially relies on the assumption that the function values are bounded, excluding a large class of useful functions. In this work, we propose META-STORM, a generalized framework of STORM+ that removes this bounded function values assumption while still attaining the optimal convergence rate for non-convex optimization. META-STORM not only maintains full-adaptivity, removing the need to obtain problem-specific parameters, but also improves the convergence rate's dependency on the problem parameters. Furthermore, META-STORM admits a large range of parameter settings that subsumes previous methods, allowing for more flexibility in a wider range of settings. Finally, we demonstrate the effectiveness of META-STORM through experiments across common deep learning tasks. Our algorithm improves upon the previous work STORM+ and is competitive with widely used algorithms after the addition of per-coordinate updates and exponential moving average heuristics.

1. INTRODUCTION

In this paper, we consider the stochastic optimization problem

    min_{x ∈ R^d} F(x) := E_{ξ∼D}[f(x, ξ)],                    (1)

where F : R^d → R is possibly non-convex. We assume access only to a first-order stochastic oracle via sample functions f(x, ξ), where ξ is drawn from a distribution D representing the randomness in the sampling process. Optimization problems of this form are ubiquitous in machine learning and deep learning. Empirical risk minimization (ERM) is one instance, where F(x) is the loss function evaluated on a sample or a minibatch represented by ξ.

An important advance in solving Problem (1) is the recent development of variance reduction (VR) techniques, which improve the convergence rate to critical points from the O(1/T^{1/4}) of vanilla SGD to O(1/T^{1/3}) (Fang et al., 2018; Li et al., 2021) for the class of mean-squared smooth functions (Arjevani et al., 2019). In contrast to earlier VR algorithms, which often require computing gradients over large batches, recent methods such as Cutkosky & Orabona (2019); Levy et al. (2021); Huang et al. (2021) avoid this drawback by using a weighted average of past gradients, often known as momentum. When the weights are selected appropriately, momentum reduces the error in the gradient estimates, which improves the convergence rate.

A different line of work on adaptive methods (Duchi et al., 2011; Kingma & Ba, 2014), some of which incorporate momentum techniques, has shown tremendous success in practice. These adap-
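To make the recursive-momentum idea concrete, the following is a minimal sketch of the STORM-style estimator d_t = g_t(x_t) + (1 − a)(d_{t−1} − g_t(x_{t−1})), in which the correction term evaluates the same stochastic sample ξ_t at both the current and previous iterates so that much of the gradient noise cancels. This is an illustrative toy (a quadratic objective with additive Gaussian noise, a fixed step size, and a fixed momentum parameter a), not the full META-STORM algorithm, which additionally sets its step sizes adaptively; the oracle `noisy_grad` and all parameter values here are hypothetical choices for illustration.

```python
import numpy as np

def noisy_grad(x, xi):
    # Hypothetical stochastic oracle for F(x) = 0.5 * ||x||^2:
    # the true gradient x plus the shared noise sample xi.
    return x + xi

def storm_sketch(x0, lr=0.1, a=0.1, steps=200, noise=1.0, seed=0):
    """Toy STORM-style recursive momentum (fixed lr and a, not adaptive).

    Maintains the estimator
        d_t = g_t(x_t) + (1 - a) * (d_{t-1} - g_t(x_{t-1})),
    where both gradients in the correction use the SAME sample xi_t.
    """
    rng = np.random.default_rng(seed)
    x = np.asarray(x0, dtype=float)
    xi = rng.normal(0.0, noise, size=x.shape)
    d = noisy_grad(x, xi)                    # d_1 = g_1(x_1, xi_1)
    for _ in range(steps):
        x_prev, x = x, x - lr * d            # descent step using the estimator
        xi = rng.normal(0.0, noise, size=x.shape)  # fresh sample xi_t
        # Recursive-momentum update: the shared sample xi makes the
        # correction term cancel most of the noise in the new gradient.
        d = noisy_grad(x, xi) + (1.0 - a) * (d - noisy_grad(x_prev, xi))
    return x

x_final = storm_sketch(np.ones(5))
```

Despite the unit-variance noise, the iterates settle near the minimizer 0, whereas plain SGD with the same fixed step size would hover at a noticeably larger noise floor; the error recursion e_t = a·ξ_t + (1 − a)·e_{t−1} shrinks the estimator's stationary variance by roughly a factor of a/(2 − a).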

