TRAINABLE WEIGHT AVERAGING: EFFICIENT TRAINING BY OPTIMIZING HISTORICAL SOLUTIONS

Abstract

Stochastic gradient descent (SGD) and its variants are the de facto methods for training deep neural networks (DNNs). While recent improvements to SGD mainly focus on the descent algorithm itself, few works pay attention to utilizing the historical solutions: as an iterative method, SGD goes through substantial exploration before convergence. A recent, interesting attempt is stochastic weight averaging (SWA), which significantly improves generalization by simply averaging the solutions at the tail stage of training. In this paper, we observe that the averaging coefficients can be determined in a trainable manner and propose Trainable Weight Averaging (TWA), a novel optimization method that operates in the reduced subspace spanned by historical solutions. TWA is far more flexible than SWA and can be applied at the head stage of training to improve efficiency while preserving good generalization capability. Further, we propose a distributed training scheme that relieves the memory burden of large-scale training with efficient parallel computation. In extensive numerical experiments, (i) TWA achieves consistent improvements over SWA with less sensitivity to the learning rate; (ii) applying TWA at the head stage of training substantially speeds up convergence, saving over 40% of training time on CIFAR and 30% on ImageNet while improving generalization compared with regular training. The code is available at https://github.com/nblt/TWA.

1. INTRODUCTION

Training deep neural networks (DNNs) usually requires a large amount of time. As the sizes of models and datasets grow, more efficient optimization methods with better performance are increasingly demanded. In existing works, great efforts have been made to improve the efficiency of stochastic gradient descent (SGD) and its variants, mainly focusing on adaptive learning rates (Duchi et al., 2011; Zeiler, 2012; Kingma & Ba, 2015; Loshchilov & Hutter, 2019; Yao et al., 2021; Heo et al., 2021) or accelerated schemes (Polyak, 1964; Nesterov, 1983; 1988; 2003). As an iterative descent method, SGD generates a series of solutions during optimization. These historical solutions provide dynamic information about the training and have brought many interesting perspectives, e.g., trajectory (Li et al., 2022) and landscape (Garipov et al., 2018), to name a few. In fact, they can also be utilized to improve training performance, resulting in the so-called stochastic weight averaging (SWA) (Izmailov et al., 2018), which achieves significantly better generalization by simply averaging the tail-stage explorations of SGD. A similar idea can be found in Szegedy et al. (2016), which designs an exponential moving average (EMA), suggesting that a heuristic weighting strategy can outperform uniform averaging. The success of SWA and EMA encourages more in-depth investigation of the roles of historical solutions obtained during training (Athiwaratkun et al., 2018; Nikishin et al., 2018; Yang et al., 2019). In this paper, our main purpose is to utilize historical solutions by optimizing over them, rather than using fixed averaging (e.g., SWA) or a heuristic combination (e.g., EMA). With such an optimized averaging scheme, we can achieve higher accuracy using only the solutions from a relatively early state (i.e., the head stage). In other words, we speed up training while improving performance.
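The difference between fixed and optimized averaging can be sketched on a toy quadratic problem. The snapshot setup, loss, step size, and all names below (`thetas`, `c`, `lr`) are illustrative assumptions, not the paper's actual training procedure: the point is only that treating the averaging coefficients as trainable variables, and descending on them within the subspace spanned by the historical solutions, can do no worse than a fixed uniform average.

```python
import numpy as np

# Minimal sketch of the trainable-averaging idea: instead of uniformly
# averaging historical solutions theta_1..theta_k (as SWA does), treat
# the coefficients c as trainable and optimize them by gradient descent.
rng = np.random.default_rng(0)
d, k = 50, 5                        # parameter dimension, number of snapshots

# Toy setting: historical solutions scattered around an unknown minimizer.
theta_star = rng.normal(size=d)
thetas = np.stack([theta_star + 0.5 * rng.normal(size=d) for _ in range(k)])

def loss(w):
    """Quadratic stand-in for the training loss."""
    return 0.5 * np.sum((w - theta_star) ** 2)

# SWA baseline: fixed uniform averaging of the snapshots.
w_swa = thetas.mean(axis=0)

# Trainable averaging: optimize c so that w(c) = c @ thetas lowers the loss.
c = np.full(k, 1.0 / k)             # initialize at the uniform average
lr = 0.002
for _ in range(2000):
    w = c @ thetas                  # current point in the spanned subspace
    grad_w = w - theta_star         # dL/dw for the quadratic loss
    c -= lr * (thetas @ grad_w)     # chain rule: dL/dc = Theta dL/dw

w_twa = c @ thetas
assert loss(w_twa) <= loss(w_swa)   # trained coefficients do no worse
```

Since `c` is only `k`-dimensional, the inner optimization is cheap relative to descending in the full `d`-dimensional parameter space, which is what makes this kind of subspace training attractive.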
The idea of utilizing these early solutions in DNN training is mainly inspired by two facts. On the

