META-STORM: GENERALIZED FULLY-ADAPTIVE VARIANCE REDUCED SGD FOR UNBOUNDED FUNCTIONS

Abstract

We study the application of variance reduction (VR) techniques to general nonconvex stochastic optimization problems. In this setting, the recent work STORM (Cutkosky & Orabona, 2019) overcomes the drawback of having to compute gradients of "mega-batches" that earlier VR methods rely on. STORM utilizes recursive momentum to achieve the VR effect and was later made fully adaptive in STORM+ (Levy et al., 2021), where full adaptivity removes the requirement of obtaining certain problem-specific parameters, such as the smoothness of the objective and bounds on the variance and norm of the stochastic gradients, in order to set the step size. However, STORM+ crucially relies on the assumption that the function values are bounded, excluding a large class of useful functions. In this work, we propose META-STORM, a generalized framework of STORM+ that removes this bounded function value assumption while still attaining the optimal convergence rate for nonconvex optimization. META-STORM not only maintains full adaptivity, removing the need to obtain problem-specific parameters, but also improves the convergence rate's dependency on the problem parameters. Furthermore, META-STORM admits a large range of parameter settings that subsumes previous methods, allowing for more flexibility in a wider range of settings. Finally, we demonstrate the effectiveness of META-STORM through experiments across common deep learning tasks. Our algorithm improves upon the previous work STORM+ and is competitive with widely used algorithms after the addition of per-coordinate updates and exponential moving average heuristics.

1. INTRODUCTION

In this paper, we consider the stochastic optimization problem

min_{x ∈ R^d} F(x) := E_{ξ∼D}[f(x, ξ)],  (1)

where F : R^d → R is possibly non-convex. We assume access only to a first-order stochastic oracle via sample functions f(x, ξ), where ξ is drawn from a distribution D representing the randomness of the sampling process. Optimization problems of this form are ubiquitous in machine learning and deep learning. Empirical risk minimization (ERM) is one instance, where F(x) is the loss function that can be evaluated on a sample or a minibatch represented by ξ. An important advance in solving Problem (1) is the recent development of variance reduction (VR) techniques, which improve the convergence rate to critical points over vanilla SGD from O(1/T^{1/4}) to O(1/T^{1/3}). Separately, adaptive methods remove the burden of obtaining certain problem-specific parameters, such as smoothness, in order to set the right step size to guarantee convergence. STORM+ (Levy et al., 2021) is the first algorithm to bridge the gap between fully-adaptive algorithms and VR methods, achieving the variance-reduced convergence rate of O(1/T^{1/3}) without requiring knowledge of any problem-specific parameter. It is also the first work to demonstrate the interplay between adaptive momentum and adaptive step sizes in adapting to the problem's structure while still achieving the VR rate.
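The recursive-momentum estimator at the heart of STORM can be sketched as follows. This is a minimal illustration, not the authors' implementation: the step size lr and momentum parameter a are held constant here, whereas STORM+/META-STORM set them adaptively, and grad_fn is an assumed interface for the stochastic first-order oracle.

```python
import numpy as np

def storm(grad_fn, x0, T, lr=0.1, a=0.5, seed=0):
    """Sketch of STORM's recursive-momentum update (constant lr and a
    for simplicity; the adaptive choices are derived in the paper).

    grad_fn(x, xi) -- stochastic gradient of F at x for sample xi.
    """
    rng = np.random.default_rng(seed)
    x_prev = x0.copy()
    d = grad_fn(x0, rng.standard_normal())      # d_1 = grad f(x_1, xi_1)
    x = x_prev - lr * d
    for _ in range(1, T):
        xi = rng.standard_normal()              # fresh sample xi_t
        # Recursive momentum: correct the previous estimate with the
        # gradient difference evaluated at the SAME sample xi_t.
        # This difference term is what yields the variance reduction.
        d = grad_fn(x, xi) + (1.0 - a) * (d - grad_fn(x_prev, xi))
        x_prev, x = x, x - lr * d
    return x
```

On a toy quadratic objective with a noisy gradient oracle, the iterates contract toward the minimizer while d tracks the true gradient without any mega-batches.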
However, STORM+ relies on the strong assumption that the function values are bounded, which generally does not hold in practice. Moreover, the convergence rate of STORM+ has high polynomial dependencies on the problem parameters, compared to what can be achieved by appropriately configuring the step sizes and momentum parameters given knowledge of those parameters (see Section 3.1). Our contributions: In this work, we propose META-STORM-SG and META-STORM, two flexible algorithmic frameworks that attain the optimal variance-reduced convergence rate for general nonconvex objectives. Both generalize STORM+ by allowing a wider range of parameter selection and removing the restrictive bounded function value assumption, while maintaining its desirable fully-adaptive property of eliminating the need to obtain any problem-specific parameter. These improvements are enabled by our novel analysis framework, which also establishes a convergence rate with much better dependency on the problem parameters. We present a comparison of META-STORM and its sibling META-STORM-SG against recent VR methods in Table 1. In the appendix, we propose another algorithm, META-STORM-NA, with even less restrictive assumptions, at the cost of losing adaptivity to the variance parameter. We complement our theoretical results with experiments across three common tasks: image classification, masked language modeling, and sentiment analysis. Our algorithms improve upon the previous work, STORM+. Furthermore, the addition of heuristics such as exponential moving average and per-coordinate updates improves our algorithms' generalization performance. These versions of our algorithms are competitive with widely used algorithms such as Adam and AdamW.



Table 1: Comparison of the convergence rates after T iterations under constant success probability. The assumptions and definitions of the parameters referenced can be found in Section 1.2. Assumptions 1 and 2 are used by all algorithms, thus we leave them out of the table.

- STORM (Cutkosky & Orabona, 2019): not fully adaptive; rate O((κ^{1/2} + κ^{3/4} G^{-1/2} + σ + G log^{3/4} T)/T^{1/2} + σ^{1/3}/T^{1/3}); Assumptions 3', 4; κ = O(β(F(x_1) − F*)).
- Super-ADAM (Huang et al., 2021): not fully adaptive; rate O((κ^{1/2} + σ log T)/T^{1/2} + 1/T^{1/3}); Assumption 3'; κ = O(β(F(x_1) − F*)); does not adapt to σ.
- STORM+ (Levy et al., 2021): rate O(κ_1/T^{1/2} + κ_2 σ^{1/3}/T^{1/3}); Assumptions 3', 4, 6; κ_1 = O(β^{9/4} + G^5 + β^{3/2} G^6 + B^{9/8}), κ_2 = O(β^{3/2} + B^{3/4}).
- META-STORM-SG, p = 1/2 (ours): rate O((κ_1 + κ_2 log(1 + σ^2 T))/T^{1/2} + σ^{1/3}/T^{1/3}); Assumptions 3, 4; κ_1 = O(F(x_1) − F* + σ^2 + G^2 + κ_2 log κ_2), κ_2 = O((1 + G^2)β).
- META-STORM, p = 1/2 (ours): rate O((κ_1 + κ_2 log(1 + σ^2 T))/T^{1/2} + σ^{1/3}/T^{1/3}); Assumptions 3, 5; κ_1 = O(F(x_1) − F* + σσ̃ + σ̃^2 + σ̃^3 + κ_2 log κ_2), κ_2 = O((1 + σ̃^3)β).

