TRAINABLE WEIGHT AVERAGING: EFFICIENT TRAINING BY OPTIMIZING HISTORICAL SOLUTIONS

Abstract

Stochastic gradient descent (SGD) and its variants are the de facto methods for training deep neural networks (DNNs). While recent improvements to SGD mainly focus on the descent algorithm itself, few works pay attention to utilizing the historical solutions: as an iterative method, SGD goes through substantial exploration before convergence. An interesting recent attempt is stochastic weight averaging (SWA), which significantly improves generalization by simply averaging the solutions at the tail stage of training. In this paper, we observe that the averaging coefficients could instead be determined in a trainable manner and propose Trainable Weight Averaging (TWA), a novel optimization method operating in the reduced subspace spanned by historical solutions. TWA has much greater flexibility and can be applied to the head stage of training to achieve training efficiency while preserving good generalization capability. Further, we propose a distributed training scheme that resolves the memory burden of large-scale training with efficient parallel computation. In extensive numerical experiments, (i) TWA achieves consistent improvements over SWA with less sensitivity to the learning rate; (ii) applying TWA in the head stage of training largely speeds up convergence, resulting in over 40% time saving on CIFAR and 30% on ImageNet with improved generalization compared with regular training. The code of our implementation is available at https://github.com/nblt/TWA.

1. INTRODUCTION

Training deep neural networks (DNNs) usually requires a large amount of time. As the sizes of models and datasets grow, more efficient optimization methods with better performance are increasingly demanded. In existing works, great efforts have been made to improve the efficiency of stochastic gradient descent (SGD) and its variants, mainly focusing on adaptive learning rates (Duchi et al., 2011; Zeiler, 2012; Kingma & Ba, 2015; Loshchilov & Hutter, 2019; Yao et al., 2021; Heo et al., 2021) or accelerated schemes (Polyak, 1964; Nesterov, 1983; 1988; 2003). As an iterative descent method, SGD generates a series of solutions during optimization. These historical solutions provide dynamic information about the training and have brought many interesting perspectives, e.g., trajectory (Li et al., 2022) and landscape (Garipov et al., 2018), to name a few. In fact, they can also be utilized to improve training performance, as in stochastic weight averaging (SWA) (Izmailov et al., 2018), which achieves significantly better generalization by simply averaging the tail-stage explorations of SGD. A similar idea can be found in Szegedy et al. (2016), which designs an exponential moving average (EMA) and suggests that such a heuristic weighting can outperform uniform averaging. The success of SWA and EMA encourages more in-depth investigation of the roles of historical solutions obtained during training (Athiwaratkun et al., 2018; Nikishin et al., 2018; Yang et al., 2019). In this paper, our main purpose is to utilize historical solutions by optimizing over them, rather than using fixed averaging (e.g., SWA) or a heuristic combination (e.g., EMA). With such an optimized averaging scheme, we can achieve higher accuracy using only the solutions from a relatively early state (i.e., the head stage). In other words, we speed up the training and meanwhile improve the performance.
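To make the contrast between the two fixed averaging schemes concrete, here is a minimal NumPy sketch of uniform (SWA-style) and exponential-moving-average (EMA-style) weight averaging over flattened checkpoints; the function names and the `decay` value are illustrative, not from the papers' implementations:

```python
import numpy as np

def swa_average(checkpoints):
    """Uniform average of tail-stage checkpoints, as in SWA."""
    return np.mean(checkpoints, axis=0)

def ema_average(checkpoints, decay=0.9):
    """Exponential moving average over a checkpoint stream, as in EMA."""
    avg = checkpoints[0].astype(float).copy()
    for w in checkpoints[1:]:
        avg = decay * avg + (1.0 - decay) * w
    return avg

# Tiny example with 3 "checkpoints" of a 2-parameter model.
ckpts = [np.array([1.0, 0.0]), np.array([0.0, 1.0]), np.array([1.0, 1.0])]
print(swa_average(ckpts))             # fixed, uniform coefficients (1/3 each)
print(ema_average(ckpts, decay=0.5))  # heuristic, recency-weighted coefficients
```

In both cases the coefficients are decided in advance; TWA instead treats them as trainable parameters.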
The idea of utilizing these early solutions in DNNs' training is mainly inspired by two facts. On one hand, high test accuracy commonly starts appearing at an early stage of training. For example, a PreAct ResNet-164 model (He et al., 2016) achieves over 80% test accuracy within 10 training epochs on CIFAR-10 (Krizhevsky & Hinton, 2009), while requiring the whole 200 epochs to reach its final 95% accuracy. This observation also coincides with recent findings that the key connectivity patterns of DNNs emerge early in training (You et al., 2020; Frankle et al., 2020), indicating that a well-explored solution space has formed. On the other hand, simply averaging the solutions collected at the SWA stage immediately provides a huge accuracy improvement, e.g., over 16% on CIFAR-100 with Wide ResNet-28-10 (Zagoruyko & Komodakis, 2016) compared with before averaging (Izmailov et al., 2018). These facts point to a promising direction: sufficiently utilizing these early explorations may quickly compose the final solution while obtaining good accuracy.
Figure 1: TWA intuition (historical solutions w_1, w_2, w_3, the SWA average w_swa, the TWA solution w_twa, and the projected update −ηP(P⊤∇L(w))).
As the model parameters go through a rapid evolution at the early stage of training, a simple averaging strategy with fixed coefficients as in SWA and EMA can result in large estimation errors. We introduce Trainable Weight Averaging (TWA), which allows the averaging coefficients to be adjusted explicitly in a trainable manner. Specifically, we construct a subspace that contains all solutions sampled during training and then conduct efficient optimization therein. As optimization in such a subspace takes into account all possible averaging choices, we are able to adaptively search for a good set of averaging coefficients regardless of the quality of the sampled solutions and largely reduce the estimation errors. The proposed optimization scheme is essentially gradient projection onto a tiny subspace.
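The projected update can be sketched as follows. This is a simplified NumPy illustration of the idea rather than the paper's implementation; `P` is assumed to already have orthonormal columns spanning the historical solutions:

```python
import numpy as np

def twa_step(w, grad, P, lr):
    """One TWA update: project the gradient onto the subspace spanned by
    the (orthonormal) columns of P, then take a descent step.
    w: (d,) current weights; grad: (d,) loss gradient; P: (d, n) basis."""
    coeffs = P.T @ grad           # n-dimensional projected gradient
    return w - lr * (P @ coeffs)  # move only within the n-dim subspace

# Example: d = 3 parameters, n = 2 basis directions.
P = np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])  # orthonormal columns
w = np.zeros(3)
g = np.array([2.0, -1.0, 5.0])  # third component lies outside the subspace
w_new = twa_step(w, g, P, lr=0.1)
print(w_new)  # the update ignores the out-of-subspace component
```

The update touches only `n` effective degrees of freedom, which is what reduces the trainable dimension from millions to the subspace dimension.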
Hence, the degrees of freedom for training are substantially reduced from the original millions to dozens or hundreds (the dimension of the subspace), so TWA enjoys fast convergence and is meanwhile immune to overfitting; the latter explains why the better training accuracy of TWA over SWA or EMA translates into better test accuracy. In extensive experiments with various network architectures on different tasks, applying TWA to the head stage of training yields superior performance. For instance, we attain a 1.5 ∼ 2.2% accuracy improvement on CIFAR-100 and 0.1% on ImageNet, with over 40% and 30% of training epochs reduced, respectively, compared with regular training. In summary, we make the following contributions:
• We propose Trainable Weight Averaging (TWA), which allows the averaging coefficients to be determined in a trainable manner instead of by a pre-defined strategy. It brings consistent improvements over SWA and EMA with reduced estimation error.
• We successfully apply TWA to the head stage of training, achieving a large time saving (e.g., over 40% on CIFAR and 30% on ImageNet) compared to regular training, along with improved performance and a reduced generalization gap.
• TWA is easy to implement and can be flexibly plugged into different stages of training to bring consistent improvements. It provides a new scheme for efficient DNN training by sufficiently utilizing historical explorations.

2. METHOD

In this section, we first formulate the optimization objective of TWA. Then a detailed training algorithm is introduced, which consists of two phases: Gram-Schmidt orthogonalization and projected optimization.
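A minimal sketch of the first phase, assuming the sampled solutions are flattened into the rows of a NumPy array (a textbook Gram-Schmidt procedure for illustration, not the paper's code):

```python
import numpy as np

def gram_schmidt(W):
    """Orthonormalize the rows of W (n sampled solutions, each of dim d).
    Returns P of shape (d, m) with orthonormal columns spanning the same
    subspace, where m <= n after dropping dependent directions."""
    basis = []
    for w in W:
        v = w.astype(float).copy()
        for q in basis:
            v -= (q @ v) * q    # remove components along previous directions
        norm = np.linalg.norm(v)
        if norm > 1e-10:        # skip (near-)linearly dependent solutions
            basis.append(v / norm)
    return np.stack(basis, axis=1)

W = np.array([[1.0, 1.0, 0.0],
              [1.0, 0.0, 0.0]])
P = gram_schmidt(W)
print(P.T @ P)  # close to the identity: columns are orthonormal
```

In practice, an equivalent reduced QR factorization (e.g., `np.linalg.qr`) can serve the same purpose and is numerically more robust.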



Algorithm 1: TWA. Input: sampled weights {w_i}_{i=1}^n, batch size b, loss function L : W × X × Y → R+, learning rate η. Initialize t = 0 and P = [e_1, e_2, ..., e_n]. While not converged: sample batch data B = {(x_k, y_k)}_{k=1}^b and update the weights with the projected gradient step w ← w − ηP(P⊤∇L(w)).
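Putting the two phases together, the training loop can be sketched as follows. This is a self-contained toy illustration on a quadratic loss; the model, data, and all names are illustrative stand-ins, not the paper's implementation:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy quadratic "loss": L(w) = 0.5 * ||A w - b||^2 over d parameters.
d = 10
A = rng.standard_normal((20, d))
b = rng.standard_normal(20)
loss = lambda w: 0.5 * np.sum((A @ w - b) ** 2)
grad = lambda w: A.T @ (A @ w - b)

# Phase 0: pretend these are n historical SGD solutions (here: random points).
n = 4
history = [rng.standard_normal(d) for _ in range(n)]

# Phase 1: orthonormalize the historical solutions (reduced QR for brevity).
P, _ = np.linalg.qr(np.stack(history, axis=1))  # (d, n), orthonormal columns

# Phase 2: projected optimization inside the n-dimensional subspace.
w = np.mean(history, axis=0)  # start from the SWA-style uniform average
for _ in range(200):
    w = w - 0.01 * P @ (P.T @ grad(w))

print(loss(np.mean(history, axis=0)), ">=", loss(w))
```

The optimized combination `w` remains a linear combination of the historical solutions but generally achieves a lower loss than the fixed uniform average it starts from.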

