TWO-TAILED AVERAGING: ANYTIME ADAPTIVE ONCE-IN-A-WHILE OPTIMAL WEIGHT AVERAGING FOR BETTER GENERALIZATION

Abstract

Tail averaging improves on Polyak averaging's non-asymptotic behaviour by excluding a number of leading iterates of stochastic optimization from its calculations. In practice, with a finite number of optimization steps and a learning rate that cannot be annealed to zero, tail averaging can get much closer to a local minimum point of the training loss than either the individual iterates or the Polyak average. However, the number of leading iterates to ignore is an important hyperparameter, and starting averaging too early or too late leads to inefficient use of resources or suboptimal solutions. Our work focusses on improving generalization, which makes setting this hyperparameter even more difficult, especially in the presence of other hyperparameters and overfitting. Furthermore, before averaging starts, the loss is only weakly informative of the final performance, which makes early stopping unreliable. To alleviate these problems, we propose an anytime variant of tail averaging intended for improving generalization, not pure optimization, that has no hyperparameters and approximates the optimal tail at all optimization steps. Our algorithm is based on two running averages with adaptive lengths bounded in terms of the optimal tail length, one of which achieves approximate optimality with some regularity. Requiring only the additional storage for two sets of weights and periodic evaluation of the loss, the proposed two-tailed averaging algorithm is a practical and widely applicable method for improving generalization.

1. INTRODUCTION

For the series of iterates produced by Stochastic Gradient Descent (SGD) (Robbins and Monro, 1985) to converge to a local minimum point of the training loss, the learning rate must be annealed to zero. Polyak averaging (Polyak and Juditsky, 1992; Ruppert, 1988) improves on SGD and achieves a statistically optimal convergence rate by averaging all iterates to produce the final solution. Tail or suffix averaging (Jain et al., 2018; Rakhlin et al., 2011) takes this further and improves the non-asymptotic behaviour by dropping a number of leading iterates from the average, speeding up the decay of the effect of the initial state while allowing the learning rate to stay constant. Both of these properties are advantageous in practice, where a finite number of optimization steps are taken, and because large learning rates may bias optimization towards flatter and wider minima, which improves generalization (Hochreiter and Schmidhuber, 1997; Keskar et al., 2016). Focussing on large learning rates, flat minima, and generalization, Izmailov et al. (2018) propose Stochastic Weight Averaging (SWA), which takes the same form as tail averaging but is motivated from an ensembling point of view.

Tail averaging starts after a given number of optimization steps. Setting this hyperparameter to minimize the training loss already poses some difficulties, which only become more pronounced and numerous in the context of generalization, our primary focus in this work.

• Triggering averaging too early is inefficient, as the average must grow long for the early weights to matter less.

• Triggering averaging too late is also inefficient, as it discards valuable information.

• Due to interdependencies, tuning other hyperparameters may become harder.

• The learning curves have a sudden drop at the onset of averaging and are not necessarily informative of the eventual performance before that. Early stopping of seemingly unpromising training runs can thus cull solutions that would benefit more from averaging.

Motivated by these problems, we propose the two-tailed averaging algorithm with the following features.

• Anytime: An estimate of the optimal tail is maintained at all optimization steps.

• Adaptive: It has no hyperparameters. The number of weights averaged (the length of the tail) is determined adaptively based on the evolution of generalization performance.

• Optimal once in a while: The tail length achieves near optimality regularly.

The algorithm is very easy to implement. Its principal cost is the storage required for a second running average; in addition, it performs more evaluations of generalization performance (e.g. of the validation loss). The main idea, sketched in Figure 1, is to maintain two running averages of optimization iterates: a short and a long one, with the long average being our estimate of the optimal weights.
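The two-average scheme can be sketched in a few lines of Python. This is a minimal illustration of our reading of the idea, not a reference implementation: the names `TwoTailedAverager`, `update`, `maybe_switch`, and `evaluate` are illustrative, the averages are kept as running sums with counts, and `evaluate` stands in for any measure of generalization performance such as the validation loss.

```python
import numpy as np


class TwoTailedAverager:
    """Illustrative sketch: maintain a short and a long running average
    of optimizer iterates; the long one is the estimate of the optimal
    weights."""

    def __init__(self):
        self.short_sum, self.short_n = None, 0
        self.long_sum, self.long_n = None, 0

    def _add(self, total, weights):
        return weights.copy() if total is None else total + weights

    def update(self, weights):
        # Add the current optimizer iterate to both running averages.
        self.short_sum = self._add(self.short_sum, weights)
        self.short_n += 1
        self.long_sum = self._add(self.long_sum, weights)
        self.long_n += 1

    def maybe_switch(self, evaluate):
        # `evaluate` maps averaged weights to a loss (e.g. the validation
        # loss). When the short average matches or beats the long one, the
        # short average is renamed to long and the short one is restarted
        # from an empty state, as in Figure 1. Assumes `update` has been
        # called at least once, so both counts are positive.
        if self.short_n > 0:
            short_loss = evaluate(self.short_sum / self.short_n)
            long_loss = evaluate(self.long_sum / self.long_n)
            if short_loss <= long_loss:
                self.long_sum, self.long_n = self.short_sum, self.short_n
                self.short_sum, self.short_n = None, 0
        # The long average is the current estimate of the optimal weights.
        return self.long_sum / self.long_n
```

Note that the long average always continues from the final state of the preceding short average, so its length adapts without any hyperparameter controlling when averaging starts.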

2.1. AVERAGING IN PURE OPTIMIZATION

Polyak averaging as originally proposed (Ruppert, 1988; Polyak and Juditsky, 1992) computes the equally weighted average $\bar\theta_t = \frac{1}{t+1}\sum_{i=0}^{t} \theta_i$ of all iterates $\theta_i$ from the optimizer up to the current time step $t$. The convergence rate of $\bar\theta_t$ was analyzed in the convex case with an appropriately decaying learning rate. Beyond this strictest interpretation, Polyak (or Polyak-Ruppert) averaging may refer to using $\bar\theta_t$ without the convexity assumption, without a decaying learning rate, or with another optimizer such as Adam (Kingma and Ba, 2014). In practice, where finite budget considerations override the asymptotic optimality guarantees offered by theory, Polyak averaging may refer to an exponential moving average (EMA) of the form $\bar\theta_0 = \theta_0$, $\bar\theta_t = (1-\beta_t)\theta_t + \beta_t\bar\theta_{t-1}$ ($t \geqslant 1$), where $\beta_t < 1$ may be a constant near 1 or it may be scheduled as by Martens (2020). The idea here is to improve the rate of decay of the effect of the initial error by downweighting early iterates. Tail averaging (TA) (Jain et al., 2018), also known as suffix averaging (Rakhlin et al., 2011), considers a finite optimization budget of $n$ steps with a constant learning rate. At the cost of introducing a hyperparameter $s$ to control the start of averaging, it improves the rate of decay of the effect of the initial error while obtaining near-minimax rates on the variance. Tail averaging is defined as $\bar\theta_t = \theta_t$ ($t < s$) and $\bar\theta_t = \frac{1}{t-s+1}\sum_{i=s}^{t} \theta_i$ ($t \geqslant s$).
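The three averaging schemes above can be made concrete with a small NumPy sketch. The function names are ours, and the batch formulations are used for clarity; in practice one would keep running sums instead of storing all iterates.

```python
import numpy as np


def polyak_average(iterates):
    # Equally weighted (Polyak-Ruppert) average of all iterates up to
    # the current step: bar_theta_t = (1/(t+1)) * sum_{i=0}^{t} theta_i.
    return np.mean(iterates, axis=0)


def ema(iterates, beta=0.999):
    # Exponential moving average with a constant beta:
    # bar_theta_0 = theta_0,
    # bar_theta_t = (1 - beta) * theta_t + beta * bar_theta_{t-1}.
    avg = iterates[0].copy()
    for theta in iterates[1:]:
        avg = (1.0 - beta) * theta + beta * avg
    return avg


def tail_average(iterates, s):
    # Tail (suffix) average: the raw iterate before step s, then the
    # equally weighted average of theta_s, ..., theta_t.
    t = len(iterates) - 1
    if t < s:
        return iterates[-1]
    return np.mean(iterates[s:], axis=0)
```

Setting `s = 0` recovers the Polyak average, while a constant `beta` near 1 in `ema` gives a smooth but never fully vanishing weight on early iterates.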



Figure 1: Example evolution of the two running averages of weights over optimization steps. Continuous lines indicate optimization iterates being added to averages over the course of optimization. There is always one short (S) and one long (L) average, with the long one having more iterates averaged and a better loss. When the loss with the short one would become better, the short average is renamed to long (marked by arrowheads), and the short one is restarted from an empty state (marked by discontinuities between double vertical lines). The long average thus always starts from the final state of the preceding short average and incorporates more iterates until the new short average becomes at least as good in terms of the loss. In any interval labeled with L (from an arrowhead to a vertical bar), there is at least one point where the length of the long average is near optimal.

