TRAINABLE WEIGHT AVERAGING: EFFICIENT TRAINING BY OPTIMIZING HISTORICAL SOLUTIONS

Abstract

Stochastic gradient descent (SGD) and its variants are considered the de facto methods to train deep neural networks (DNNs). While recent improvements to SGD mainly focus on the descent algorithm itself, few works pay attention to utilizing the historical solutions, even though SGD, as an iterative method, goes through substantial exploration before convergence. A recent and interesting attempt is stochastic weight averaging (SWA), which significantly improves generalization by simply averaging the solutions at the tail stage of training. In this paper, we realize that the averaging coefficients could be determined in a trainable manner and propose Trainable Weight Averaging (TWA), a novel optimization method in the reduced subspace spanned by historical solutions. TWA has much greater flexibility and can be applied to the head stage of training to achieve training efficiency while preserving good generalization capability. Further, we propose a distributed training scheme to resolve the memory burden of large-scale training with efficient parallel computation. In extensive numerical experiments, (i) TWA achieves consistent improvements over SWA with less sensitivity to learning rate; (ii) applying TWA in the head stage of training largely speeds up the convergence, resulting in over 40% time saving on CIFAR and 30% on ImageNet with improved generalization compared with regular training. The code is available at https://github.com/nblt/TWA.

1. INTRODUCTION

Training deep neural networks (DNNs) usually requires a large amount of time. As the sizes of models and datasets grow, more efficient optimization methods with better performance are increasingly in demand. In the existing works, great efforts have been made to improve the efficiency of stochastic gradient descent (SGD) and its variants, which mainly focus on adaptive learning rates (Duchi et al., 2011; Zeiler, 2012; Kingma & Ba, 2015; Loshchilov & Hutter, 2019; Yao et al., 2021; Heo et al., 2021) or accelerated schemes (Polyak, 1964; Nesterov, 1983; 1988; 2003). As an iterative descent method, SGD generates a series of solutions during optimization. These historical solutions provide dynamic information about the training and have brought many interesting perspectives, e.g., trajectory (Li et al., 2022) and landscape (Garipov et al., 2018), to name a few. In fact, they can also be utilized to improve the training performance, resulting in the so-called stochastic weight averaging (SWA) (Izmailov et al., 2018), which shows significantly better generalization by simply averaging the tail stage explorations of SGD. A similar idea can be found in Szegedy et al. (2016), which designs an exponential moving average (EMA) and suggests that a heuristic strategy could be better than equal averaging. The success of SWA and EMA encourages more in-depth investigations into the roles of historical solutions obtained during the training (Athiwaratkun et al., 2018; Nikishin et al., 2018; Yang et al., 2019).

In this paper, our main purpose is to utilize historical solutions by optimizing them, rather than using fixed averaging (e.g., SWA) or a heuristic combination (e.g., EMA). With such an optimized averaging scheme, we can achieve higher accuracy using only the solutions from a relatively early stage (i.e. the head stage). In other words, we speed up the training and meanwhile improve the performance.
The idea of utilizing these early solutions in DNNs' training is mainly inspired by two facts. On the one hand, high test accuracy commonly starts appearing at an early stage. For example, a PreAct ResNet-164 model (He et al., 2016) achieves over 80% test accuracy within 10 training epochs on CIFAR-10 (Krizhevsky & Hinton, 2009), while it requires the full 200 epochs to reach its final 95% accuracy. This observation also coincides with the recent findings that the key connectivity patterns of DNNs emerge early in training (You et al., 2020; Frankle et al., 2020), indicating that a well-explored solution space has already formed. On the other hand, simply averaging the solutions collected at the SWA stage immediately provides a huge accuracy improvement, e.g. over 16% on CIFAR-100 with Wide ResNet-28-10 (Zagoruyko & Komodakis, 2016) compared with before averaging (Izmailov et al., 2018). These facts point to a promising direction: sufficiently utilizing these early explorations may be capable of quickly composing the final solution while obtaining good accuracy.

As the model parameters go through a rapid evolution at the early stage of training, a simple averaging strategy with fixed weighting coefficients as in SWA and EMA can result in large estimation errors. We introduce Trainable Weight Averaging (TWA), which allows explicit adjustments for the averaging coefficients in a trainable manner. Specifically, we construct a subspace that contains all sampled solutions during the training and then conduct efficient optimization therein. As optimization in such a subspace takes into account all possible averaging choices, we are able to adaptively search for a good set of averaging coefficients regardless of the quality of sampled solutions and largely reduce the estimation errors. The proposed optimization scheme is essentially gradient projection onto a tiny subspace. Hence, the degree of freedom for training is substantially reduced from the original millions to dozens or hundreds (equal to the dimension of the subspace), making TWA converge fast while remaining resistant to overfitting. The latter explains why the better training accuracy of TWA over SWA or EMA can translate into better test accuracy.

In extensive experiments with various network architectures on different tasks, we reach superior performance with TWA applied to the head stage of training. For instance, we attain 1.5∼2.2% accuracy improvement on CIFAR-100 and 0.1% on ImageNet with over 40% and 30% training epochs reduced, respectively, compared with the regular training. In summary, we make the following contributions:

• We propose Trainable Weight Averaging (TWA), which allows the averaging coefficients to be determined in a trainable manner instead of by a pre-defined strategy. It brings consistent improvements over SWA or EMA with reduced estimation error.

2. METHOD

In this section, we first formulate the optimization target of TWA. Then a detailed training algorithm is introduced, which consists of two phases: Schmidt orthogonalization and projected optimization. Note that in this paper, the model's weights are flattened into a vector, i.e., $w \in \mathbb{R}^D$, where $D$ denotes the number of parameters.

2.1. OPTIMIZATION TARGET

In SWA, weight averaging is simply given by $w_{swa} = \frac{1}{n} \sum_{i=1}^{n} w_i$, where $n$ solutions of the network collected at the tail of training are equally weighted. Such an averaging strategy has proven quite effective, with improved generalization ability. However, equal averaging is not always the best choice, which motivates some heuristic modifications of the weighting strategy, e.g., EMA. Both SWA and EMA are fixed averaging strategies, which may not adequately adapt to the head stage of training and would result in estimation errors, because early historical solutions have not yet reached a stationary distribution. In this paper, we propose to optimize the averaging coefficients of different weights with the hope of reducing the corresponding estimation error. Specifically, the set of possible TWA solutions considered, i.e., the candidates for $w_{twa}$, can be represented as follows:

$$A = \{\alpha_1 w_1 + \alpha_2 w_2 + \cdots + \alpha_n w_n \mid \alpha_i \in \mathbb{R}\}. \tag{1}$$

The weight vectors of consecutive solutions could have a high cosine similarity. To decouple them and for better optimization, we further orthogonalize $\{w_i\}_{i=1}^{n}$ and find a set of orthogonal bases $\{e_i\}_{i=1}^{n}$ to support the solution space, i.e., $A = \{\beta_1 e_1 + \beta_2 e_2 + \cdots + \beta_n e_n \mid \beta_i \in \mathbb{R}\}$. Then we search for a good solution $w_{twa}$ in $A$ by solving the following problem:

$$\min_{\beta_1, \beta_2, \cdots, \beta_n} \; \mathbb{E}_{(x,y) \sim \mathcal{D}} \left[ L\left(f(w_{twa}; x), y\right) \right] + \frac{\lambda}{2} \sum_{i=1}^{n} \beta_i^2, \quad \text{s.t.} \quad w_{twa} = \beta_1 e_1 + \beta_2 e_2 + \cdots + \beta_n e_n, \tag{2}$$

where $L(\cdot, \cdot)$ is the loss function as in regular training, $\mathcal{D}$ is the data distribution, and the second term is a regularization term with coefficient $\lambda > 0$. Note that both SWA and EMA correspond to particular feasible points of (2), chosen without optimization. Optimizing over $\beta_i$ brings benefits in terms of training loss, and a good generalization ability could also be expected: in regular training, the number of optimization variables is $D$, which is very large, but in (2) there are only $n$ averaging coefficients $\{\beta_i\}_{i=1}^{n}$ to be optimized. This significant dimensionality reduction could lead to better generalization.

2.2. TRAINING ALGORITHM

Instead of directly optimizing $\beta_i$, we note that there exists a bijection between the coefficient space $\mathbb{R}^n$ of $\{\beta_i\}_{i=1}^{n}$ and an $n$-dimensional subspace of the parameter space $\mathbb{R}^D$, i.e., each set of coefficients is uniquely mapped to one point in the parameter space, and these points together form a complete subspace (with dimensionality $n$). We could alternatively optimize these coefficients in such a subspace. We first focus on finding a set of orthogonal bases $\{e_i\}_{i=1}^{n}$ to span the subspace that covers $\{w_i\}_{i=1}^{n}$. This is a standard Schmidt orthogonalization, which sequentially takes the following steps for $k = 1, \cdots, n$:

$$\begin{cases} e_k = w_k - (w_k^\top e_1) e_1 - (w_k^\top e_2) e_2 - \cdots - (w_k^\top e_{k-1}) e_{k-1}, \\ e_k = e_k / \|e_k\|_2. \end{cases} \tag{3}$$

We then initialize $w_{twa}^{(0)}$ as one point in the subspace (e.g. $\frac{1}{n} \sum_{i=1}^{n} w_i$), and optimize the network's parameters therein. Let $P = [e_1, e_2, \cdots, e_n]$; such optimization can be easily achieved by projecting the gradient onto the subspace via the projection matrix $P P^\top$. We summarize the detailed training procedures of the proposed TWA in Algorithm 1 with an intuitive illustration in Figure 1. The detailed implementation is described in Appendix B.
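As a rough illustration of the two phases described above, the following Python sketch orthogonalizes a few toy weight vectors as in Eq. (3) and then runs gradient descent with the gradient projected by $P P^\top$. It is a minimal sketch under stated assumptions, not the authors' implementation: the helper names (`dot`, `gram_schmidt`, `project`, `twa_toy`), the quadratic toy loss, the step size, and the choice to drop near-dependent vectors are all illustrative, and the regularization term of (2) is omitted.

```python
import math


def dot(u, v):
    """Inner product of two equally sized vectors."""
    return sum(a * b for a, b in zip(u, v))


def gram_schmidt(ws, eps=1e-12):
    """Orthonormalize sampled weight vectors following Eq. (3)."""
    basis = []
    for w in ws:
        e = list(w)
        for b in basis:
            c = dot(w, b)  # coefficient (w_k^T e_j)
            e = [ei - c * bi for ei, bi in zip(e, b)]
        norm = math.sqrt(dot(e, e))
        if norm > eps:  # assumption: drop (near-)linearly dependent vectors
            basis.append([ei / norm for ei in e])
    return basis


def project(g, basis):
    """Compute P P^T g without forming the D x D projection matrix."""
    out = [0.0] * len(g)
    for b in basis:
        c = dot(g, b)
        out = [oi + c * bi for oi, bi in zip(out, b)]
    return out


def twa_toy(ws, target, lr=0.5, steps=50):
    """Projected gradient descent on the toy loss 0.5 * ||w - target||^2."""
    basis = gram_schmidt(ws)
    n = len(ws)
    w = [sum(col) / n for col in zip(*ws)]  # initialize at the plain average
    for _ in range(steps):
        grad = [wi - ti for wi, ti in zip(w, target)]  # gradient of the toy loss
        step = project(grad, basis)
        w = [wi - lr * si for wi, si in zip(w, step)]
    return w, basis


# Three toy "historical solutions" in R^4 that span the first three axes.
ws = [[1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 1.0, 0.0]]
target = [0.2, -0.3, 0.5, 0.9]
w_twa, basis = twa_toy(ws, target)
print(w_twa)
```

Because every update is a projected gradient and the starting point lies in the span of the sampled vectors, the iterate never leaves the subspace; in this toy case it approaches the point of the subspace closest to `target`.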

3. OPTIMIZATION PROVIDES BETTER FLEXIBILITY

The key difference between SWA and TWA is that the averaging coefficients in TWA are determined in a trainable manner, or more precisely, are data-dependent. This potentially enables a more precise estimate of the center minimum and better tolerance of outliers, which SWA is unaware of. Notice that from the viewpoint of coefficient optimization, there is no essential difference between SWA and EMA, as both provide specific and data-independent solutions. Thus, in the following we only compare TWA with SWA.

Mandt et al. (2017) demonstrated that under appropriate assumptions, running SGD with a constant learning rate is equivalent to sampling from a stationary Gaussian distribution, and the variance of the distribution is controlled by the learning rate. Accordingly, we assume the solutions at the tail stage of SGD training are sampled from a Gaussian distribution $\mathcal{N}(\mu, \Sigma)$ centered at the minimum $\mu$ with covariance $\Sigma$. Approximately, the sampled solutions $\{w_i\}_{i=1}^{n}$ are independent random variables from $\mathcal{N}(\mu, \Sigma)$, as long as there are sufficient iterations between adjacent samplings. SWA and TWA provide two estimators for the minimum $\mu$, i.e., $w_{swa}$ and $w_{twa}$. As an averaged solution, $w_{swa}$ is statistically a better estimate than any single solution due to the effect of variance reduction, while $w_{twa}$ approaches the center by minimizing the training loss. As long as the training loss serves as meaningful supervision (which holds under the typical assumption that $\mu$ is the center minimum with the lowest training loss (Mandt et al., 2017; Izmailov et al., 2018)), $w_{twa}$ could approach $\mu$ better than $w_{swa}$ thanks to the posterior optimization of the averaging coefficients. In this regard, $w_{twa}$ could have a lower expected squared error:

$$\mathbb{E}\left(\|w_{twa} - \mu\|_2^2\right) \le \mathbb{E}\left(\|w_{swa} - \mu\|_2^2\right).$$

The advantages of optimizing averaging coefficients could be more prominent in the head stage of training, where the weights are going through a rapid evolution. A simple averaging strategy such as SWA could introduce a large estimation error (as illustrated in Tables 3 and 4), while TWA can correct it to a smaller estimation error via optimization. In fact, TWA provides much more flexibility to sufficiently utilize historical solutions and produces an optimized solution adaptively.
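The variance-reduction effect that makes $w_{swa}$ a better estimator than any single solution can be checked with a small Monte Carlo experiment. The sketch below is only an illustration under simplifying assumptions that are not in the paper: the covariance is isotropic ($\Sigma = \sigma^2 I$), the samples are exactly independent, and the function name `mean_sq_error` and all numeric settings are hypothetical. It compares only a single sample with an equal average and does not simulate TWA itself.

```python
import random


def mean_sq_error(n_avg, dim=5, sigma=1.0, trials=2000, seed=0):
    """Estimate E||w_bar - mu||^2 when w_bar averages n_avg Gaussian samples."""
    rng = random.Random(seed)
    mu = [0.0] * dim  # the "center minimum" in this toy setting
    total = 0.0
    for _ in range(trials):
        samples = [[rng.gauss(m, sigma) for m in mu] for _ in range(n_avg)]
        estimate = [sum(col) / n_avg for col in zip(*samples)]
        total += sum((e - m) ** 2 for e, m in zip(estimate, mu))
    return total / trials


single = mean_sq_error(1)     # error of one sampled solution
averaged = mean_sq_error(10)  # error of an equal average of 10 solutions
print(single, averaged)
```

Under these assumptions the expected squared error is $D\sigma^2$ for one sample and $D\sigma^2/n$ for the average of $n$ samples, so the second number should be roughly ten times smaller than the first.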

4. AN EFFICIENT IMPLEMENTATION FOR DISTRIBUTED TRAINING

The above discussion shows promising improvement from optimizing the historical solutions. The only issue one may worry about is the storage burden (the additional time complexity is small, as shown in Table 6). During optimization, TWA requires the projection matrix $P$, which involves dozens or hundreds of historical weights and indeed poses a storage challenge for large models. It is preferable to place $P$ on GPUs to enable efficient matrix operations. However, the size of $P$ increases as the model becomes larger, making it prohibitive to store on a single GPU. To cope with this, we design an efficient scheme with parallel distributed training to enable a) partition of the memory burden of $P$ across multiple GPUs and b) efficient parallel computation of the gradient projection. As a result, we successfully optimize more than 900 historical solution coefficients for ResNet-50 on the ImageNet task using 4 V100 GPUs. In our experiments, we use at most 300 historical solutions, so there is still space available for larger tasks.

Suppose that there are $k$ GPUs for parallel training. We first uniformly divide $P$ into $k$ sub-matrices as $P = [P_1, P_2, \cdots, P_k]$, where each GPU stores a local sub-matrix $P_i$, $i = 1, \cdots, k$. Recall that for an iteration in distributed training, each GPU computes a local gradient $g_i$ and synchronizes it with other GPUs to obtain the global gradient $g$ through an efficient all-reduce operation (Rabenseifner, 2004). We mimic such a process for gradient projection: the local projected gradient $P_i P_i^\top g$ is first computed on each card and then synchronized with the others to obtain the global projected gradient with another all-reduce operation. We illustrate such a process in Figure 2. For averaging $n$ historical solutions, each of size $B$, the memory burden for each GPU card is reduced to $\lceil n/k \rceil B$, while the computation of the gradient projection is also reduced to $O(\lceil n/k \rceil D)$ ($D$ is the number of parameters). Hence, we can achieve efficient TWA training by making full use of the memory that remains on each GPU beyond what the forward/backward pass needs.
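The partitioned projection works because the columns of $P$ are orthonormal, so $P P^\top g = \sum_{i=1}^{k} P_i P_i^\top g$. The following single-process Python sketch simulates this identity. It is an illustration only: `shard_basis`, `local_projection`, and `all_reduce_sum` are hypothetical helper names, the "workers" are just list entries, and the plain sum stands in for a real all-reduce across GPUs.

```python
import math


def dot(u, v):
    """Inner product of two equally sized vectors."""
    return sum(a * b for a, b in zip(u, v))


def local_projection(g, shard):
    """P_i P_i^T g for the basis vectors held by one worker."""
    out = [0.0] * len(g)
    for e in shard:
        c = dot(e, g)
        out = [oi + c * ei for oi, ei in zip(out, e)]
    return out


def shard_basis(basis, k):
    """Split n basis vectors into contiguous shards of at most ceil(n/k)."""
    size = math.ceil(len(basis) / k)
    return [basis[i:i + size] for i in range(0, len(basis), size)]


def all_reduce_sum(vectors):
    """Stand-in for an all-reduce: element-wise sum over workers."""
    return [sum(col) for col in zip(*vectors)]


# Orthonormal toy basis (n = 3) in R^4 and a global gradient g.
basis = [[0.6, 0.8, 0.0, 0.0], [-0.8, 0.6, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]]
g = [1.0, 2.0, 3.0, 4.0]

shards = shard_basis(basis, k=2)  # worker 0 holds two vectors, worker 1 one
distributed = all_reduce_sum([local_projection(g, s) for s in shards])
single_device = local_projection(g, basis)
print(distributed, single_device)
```

Each simulated worker touches only its own shard of at most $\lceil n/k \rceil$ vectors, which mirrors the per-GPU memory and computation figures given above.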

5. EXPERIMENTS

In this section, a series of numerical experiments are conducted, demonstrating the effectiveness of our proposed TWA for fast convergence and better performance. First, we show that TWA improves SWA in the existing SWA settings, i.e., used in the tail stage of training. Second, we apply TWA to the head stage of training, which brings significant efficiency improvements together with better performance. Then, we visualize loss / accuracy surfaces to demonstrate the improvements of TWA.

5.1. EXPERIMENTAL SETTINGS

Datasets. We experiment on three benchmark image datasets, i.e., CIFAR-10, CIFAR-100 (Krizhevsky & Hinton, 2009), and ImageNet (Deng et al., 2009). Following prior works (Izmailov et al., 2018; Yang et al., 2019), we apply standard preprocessing for experiments on CIFAR datasets, and adopt the preprocessing and data augmentation procedures in the public PyTorch example on ImageNet (Paszke et al., 2017).

Architectures. We use two representative architectures, VGG-16 (Simonyan & Zisserman, 2014) and PreAct ResNet-164 (He et al., 2016), for CIFAR experiments. For ImageNet, we use ResNet-18 and ResNet-50 (He et al., 2016).

Training. The main body of experiments contains two parts: (1) for the tail stage of training, we use the same hyper-parameters as in SWA (Izmailov et al., 2018) and also try a larger tail learning rate; (2) for the head stage of training, we adopt the standard training protocol with a step-wise learning rate. For CIFAR, we run all experiments with 3 seeds and report the mean test accuracy. We use the SGD optimizer with momentum 0.9, weight decay $10^{-4}$, and batch size 128. We train the models for 200 epochs with an initial learning rate of 0.1 and decay it by a factor of 10 at the 100th and the 150th epochs. For ImageNet, we follow the official PyTorch implementation (see the footnote for the link). For TWA, we sample solutions once after each training epoch for CIFAR and uniformly sample 5 times per epoch for ImageNet. We use a scaled learning rate (Figure 4), which takes 10 epochs of training for CIFAR and 2 epochs for ImageNet for fast convergence. The regularization coefficient $\lambda$ defaults to $10^{-5}$. More details (including the number of historical solutions used) can be found in Appendix A.
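For concreteness, the step-wise CIFAR schedule described above (initial learning rate 0.1, divided by 10 at the 100th and 150th epochs) can be written as a small function. This is a hedged sketch: the name `step_lr` is hypothetical, and it assumes 0-indexed epochs with the decay taking effect from epoch 100 and epoch 150 onward, a convention the text does not specify.

```python
def step_lr(epoch, base_lr=0.1, milestones=(100, 150), factor=0.1):
    """Step-wise learning rate: multiply by `factor` at each milestone passed."""
    drops = sum(1 for m in milestones if epoch >= m)
    return base_lr * factor ** drops


# Learning rate at a few representative (0-indexed) epochs.
schedule = {epoch: step_lr(epoch) for epoch in (0, 99, 100, 149, 150, 199)}
print(schedule)
```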

5.2. IMPROVING SWA SOLUTIONS

In this part, we compare SWA and TWA in the original SWA settings. Specifically, TWA and SWA use the same weights sampled from the tail stage of training. For CIFAR, we try two different tail learning rates: 0.05, the recommended one in (Izmailov et al., 2018) , and 0.10, a larger one for the case with greater variance. The results in Table 1 show that TWA brings consistent improvements over SWA. Especially when the learning rate is not well-tuned, SWA's performance suffers a distinctive drop, but TWA is less sensitive since the estimation error could be well controlled through training. For instance, in the case of CIFAR-100 with VGG-16 and a tail learning rate of 0.10, the estimation error of SWA substantially increases while TWA achieves a significant accuracy improvement, i.e., 4.18%, over SWA. Note that for a fair comparison, TWA starts from the last sampled weights, not the averaged solution of SWA. On ImageNet, we experiment with ResNet-18/50 (He et al., 2016) . Following (Izmailov et al., 2018; Yang et al., 2019) , we start from pre-trained models in torchvision.models and collect model weights by running SGD optimizer up to 10 epochs (with a constant learning rate 0.005). In Table 2 , we report the test accuracy and observe that with more sampling epochs, both TWA and SWA achieve better performance. Notably, TWA performs better than SWA by 0.1 ∼ 0.3%, and such improvements are more obvious in the "5 EPOCHS" case. For example, using 5 epochs of sampled weights, TWA achieves 70.23% accuracy with ResNet-18 and 76.78% accuracy with ResNet-50, outperforming the SWA counterparts with 10 epochs. This indicates that TWA requires fewer historical solution samples to achieve a comparable or even better performance than SWA, due to its ability to reduce the estimation variance with optimized averaging coefficients. 

5.3. EFFICIENT TRAINING AND BETTER GENERALIZATION

In the head stage of training, SWA usually fails due to the large estimation variance caused by fast-evolving solutions and a large learning rate. Since TWA could reduce the variance and is less sensitive to the learning rate, it can also be expected to work well in this stage. If so, it is promising to simultaneously attain generalization improvements and training efficiency.

We first investigate the experiments on CIFAR datasets. The original training schedule contains 200 epochs and we take the explorations of the first 100 epochs for TWA. The results are given in Table 3. It can be observed that TWA achieves better performance than regular SGD training, with a significantly reduced generalization gap. For instance, we attain 1.52% accuracy improvement on CIFAR-100 with VGG-16 while the generalization gap is reduced by 9.56%. This suggests that a better solution could already be composed from these historical solutions without further training at more delicate learning rates, which may instead bring overfitting problems and harm generalization. For comparison, we also apply SWA to average these samples, which shows degraded performance due to the estimation error. Apart from the good accuracy, TWA also shows great potential for improving training efficiency: we use only 10 epochs to complete the convergence, while regular SGD needs 100 epochs. As TWA and SGD have nearly the same computation overhead per epoch, the time saving of TWA is around 45%.

Generally, utilizing more epochs of explorations can provide a better estimation for the center minimum and hence lead to better performance. We therefore also study the impact of different averaging epochs, and the results are illustrated in Figure 3, where the final accuracy of SGD and the accuracy reached by SGD before averaging are also given for reference. It can be observed that the model's performance is consistently improved with more epochs of explorations. Notably, although each individual historical solution from a relatively short period of exploration is not good, satisfactory solutions have already emerged in the subspace spanned by these solutions. Through proper optimization in the subspace, TWA can then find them, e.g., on CIFAR-100 with the PreAct ResNet-164 model, averaging over 50 epochs via TWA already matches the final performance of regular SGD training.

For comparison on ImageNet, Lookahead (Zhang et al., 2019) reported 75.49% accuracy at the 60th epoch (Table 2 in Zhang et al. (2019)) with an aggressive learning rate decay (i.e., the learning rate is decayed at the 30th, 48th, and 58th epochs), while our TWA reaches 75.70% with the same budget but simply using the conventional decay. TWA is very flexible and can be readily applied to different training stages. We also conduct an experiment that averages the solutions of the final training period (i.e. epochs 61-90) and simply performs TWA for one epoch of training, as presented in Table 5. Such cheap training still brings significant improvements (e.g. +0.51% on ResNet-50 for ImageNet). Thus, TWA can serve as an effective approach for composing a better final solution.

Scaled learning rate. The optimization of TWA is conducted in a very low-dimensional space, which also suppresses the sensitivity to the learning rate. In fact, we can allow a very large learning rate to accelerate the training. We therefore design a scaled learning rate, which linearly scales up the learning rate and reduces the training epochs accordingly, as shown in Figure 4. Within an appropriate range, scaling the learning rate largely speeds up the convergence without affecting the final performance. For example, with a learning rate of 4, TWA approaches the final accuracy within only 1 epoch and converges within 5 epochs.

Comparison with EMA. EMA serves as an alternative to SWA, which averages the model weights along the training trajectory with exponential decay. It requires a hyper-parameter $\gamma$ to control the averaging horizon. As a manually defined averaging strategy, EMA could be sensitive to learning rates, datasets, architectures, etc. Here we compare the performance of EMA with SWA and TWA in the head stage of training, where we try $\gamma$ = 0.99/0.999. The results are illustrated in Figure 5. We observe that the performance of EMA varies significantly with different choices of $\gamma$. It could perform notably better than SWA in the early stage of training, as more weight is given to the latest solutions. Note that both EMA and SWA are fixed averaging strategies for adapting to different training stages (essentially, they can be viewed as particular solutions of TWA). By optimizing the averaging coefficients, TWA could consistently achieve better performance.

Optimized averaging coefficients. In Figure 6, we visualize the averaging coefficients $\alpha_i$ learned by TWA. The detailed derivation can be found in Appendix D. We observe that all historical solutions could contribute to the final solution. Solutions from the later training stage are given more importance, as expected. Different from fixed averaging strategies like SWA or EMA, such averaging coefficients make it possible to take full advantage of the historical solutions through delicate optimization and to better adapt to the training dynamics.
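The remark that SWA and EMA are particular, data-independent choices of averaging coefficients can be made concrete by writing out the coefficient each one assigns to the $i$-th sampled solution. The sketch below assumes the common EMA recursion $w_{ema} \leftarrow \gamma\, w_{ema} + (1-\gamma)\, w_t$ initialized at the first solution; this exact form, and the names `swa_coefficients` and `ema_coefficients`, are assumptions for illustration rather than details taken from the paper.

```python
def swa_coefficients(n):
    """Equal weights 1/n on each of the n sampled solutions."""
    return [1.0 / n] * n


def ema_coefficients(n, gamma):
    """Effective weights of an EMA started at the first solution."""
    coeffs = [1.0]  # after the first solution, w_ema = w_1
    for _ in range(n - 1):
        # w_ema <- gamma * w_ema + (1 - gamma) * w_t
        coeffs = [gamma * c for c in coeffs] + [1.0 - gamma]
    return coeffs


swa = swa_coefficients(4)
ema = ema_coefficients(3, gamma=0.9)
print(swa, ema)
```

Both vectors sum to one and are fixed before any data is seen, whereas TWA treats the coefficients as free variables and fits them to the training loss.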

5.4. LANDSCAPE VISUALIZATION

Following (Garipov et al., 2018; Izmailov et al., 2018) , we visualize the training loss and test error surfaces of SWA and TWA in Figure 7 on CIFAR-100 with PreAct ResNet-164. We set the SGD solution after 125 training epochs as the origin and plot the TWA and SWA solutions on the plane. For the case with a default learning rate of 0.05, TWA achieves slightly better test accuracy with lower training loss. This shows that in the subspace, minimizing the training loss is meaningful and results in lower test errors. Especially for the case with a larger learning rate of 0.10, the superiority of TWA over SWA is more significant (over 0.7% improvement on test accuracy), since the variance grows larger and the variance reduction effect of TWA becomes more obvious.

6. RELATED WORK

Improving the model's generalization capability is of great importance and has received wide attention. The recent efforts mainly focus on two aspects: (1) proper regularization to search for flatter minima (Keskar et al., 2016; Li et al., 2018), such as weight decay (Krogh & Hertz, 1991), dropout (Srivastava et al., 2014), label smoothing (Szegedy et al., 2016), Shake-Shake (Gastaldi, 2017), MixUp (Zhang et al., 2018), SAM (Foret et al., 2020) and AMP (Zheng et al., 2021); (2) effective data augmentation to diversify the dataset, such as Cutout (DeVries & Taylor, 2017), AutoAugment (Cubuk et al., 2019) and RandAugment (Cubuk et al., 2020). Different from these techniques, we improve the generalization ability by constraining the training to a low-dimensional subspace spanned by historical explorations, which regularizes the model complexity. We note that TWA is orthogonal to these methods, and it is promising to combine them for further improvements.

A lot of effort has been made to speed up DNNs' training. Apart from the well-known methods on adaptive learning rates, e.g. Adam (Kingma & Ba, 2015), and accelerated schemes, e.g. Nesterov momentum (Nesterov, 1983), Zhang et al. (2019) proposed a new method in which a look-ahead search direction generated by another "fast" optimizer is utilized, achieving faster convergence and better learning stability. Goyal et al. (2017) adopted a large mini-batch to speed up the training and introduced a scaling rule for adjusting the learning rates. In this paper, we realize training efficiency by sufficiently utilizing the historical solutions of DNNs' training and conducting training in a subspace with substantially reduced dimensions.

For utilizing historical explorations, SWA (Izmailov et al., 2018) adopts a simple averaging strategy at the tail of training. Cha et al. (2021) extended it to the domain generalization task with a dense and overfit-aware stochastic weight sampling strategy. We are the first to propose utilizing the explorations at the head stage of training to achieve training efficiency. The exponentially decaying running average (Hunter, 1986; Szegedy et al., 2016) is a common technique adopted by practitioners. It requires a manually set averaging horizon and generally performs comparably to SGD (Izmailov et al., 2018). Another closely related work is model soups (Wortsman et al., 2022), which improves the model performance by averaging the weights from different fine-tuning configurations in a greedy order. We differ in that the historical solutions come from one single configuration, and in that we mainly focus on improving training efficiency and optimizing the averaging coefficients in a trainable manner.

7. CONCLUSION

In this work, we propose TWA, a novel training algorithm that optimizes the averaging coefficients of historical solutions in DNNs' training to achieve efficiency and better performance. It differs from manually set averaging strategies such as SWA or EMA and adapts better to different stages of training. We further design a parallel framework for large-scale training that is efficient in both memory and computation. Extensive experiments demonstrate the superior performance of TWA on benchmark computer vision tasks with various architectures.

A TRAINING DETAILS

For SWA experiments, we replicate the SWA baseline by using the publicly released implementation of Izmailov et al. (2018). We use the VGG-16 architecture with batch normalization so that it shares a unified learning rate setting with PreAct ResNet-164. For ImageNet experiments, we follow the official PyTorch implementation. We use a scaled learning rate for TWA training with 20x and 30x factors on CIFAR and ImageNet, respectively (e.g. the original learning rate of 0.1 is scaled up to 2 on CIFAR). CIFAR experiments are performed on one Nvidia Geforce GTX 2080 TI GPU, while ImageNet experiments are on four NVIDIA Tesla A100. The number of historical solutions optimized by TWA is presented in Table A1. We now disclose the specific hyper-parameters in the following.

Published as a conference paper at ICLR 2023

[...] with an initial learning rate of 3. We observe that it could similarly bring generalization improvement and training efficiency.

B IMPLEMENTATION

Let $\beta = [\beta_1, \beta_2, \cdots, \beta_n]^\top \in \mathbb{R}^n$ and $P = [e_1, e_2, \cdots, e_n] \in \mathbb{R}^{D \times n}$. Let $\eta$ be the learning rate. We could optimize (A.1) with the following gradient descent:

$$\beta^{(t+1)} = \beta^{(t)} - \eta \left( P^\top \frac{\partial L(w_{twa}^{(t)})}{\partial w_{twa}} + \lambda \beta^{(t)} \right). \tag{A.4}$$

Since $w_{twa} = P\beta$, we have the corresponding update in the parameter space:

$$w_{twa}^{(t+1)} = w_{twa}^{(t)} - \eta \left( P P^\top \frac{\partial L(w_{twa}^{(t)})}{\partial w_{twa}} + \lambda P \beta^{(t)} \right) \tag{A.5}$$

$$= (1 - \eta\lambda)\, w_{twa}^{(t)} - \eta P P^\top \frac{\partial L(w_{twa}^{(t)})}{\partial w_{twa}}. \tag{A.6}$$

As $\beta$ does not explicitly appear in (A.6), we could treat the coefficient $\beta$ as an implicit variable. In practice, we optimize (A.1) by directly updating the model weights $w_{twa}$ with weight decay $\lambda$, which is an optimization in the reduced subspace with projection matrix $P P^\top$.
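The equivalence between the coefficient-space update (A.4) and the parameter-space update (A.6) can be verified numerically on a tiny example. The sketch below is illustrative only: the orthonormal basis, the vector `grad` standing in for $\partial L / \partial w_{twa}$, and the values of $\eta$ and $\lambda$ are arbitrary assumptions, and `combine` and `coords` are hypothetical helper names.

```python
def dot(u, v):
    """Inner product of two equally sized vectors."""
    return sum(a * b for a, b in zip(u, v))


def combine(basis, beta):
    """Return P beta, where the entries of `basis` are the columns e_i of P."""
    dim = len(basis[0])
    return [sum(b * e[d] for b, e in zip(beta, basis)) for d in range(dim)]


def coords(basis, v):
    """Return P^T v."""
    return [dot(e, v) for e in basis]


basis = [[0.6, 0.8, 0.0], [-0.8, 0.6, 0.0]]  # orthonormal columns of P
beta = [1.0, -2.0]
grad = [0.3, -0.1, 0.7]  # stands in for dL/dw evaluated at w_twa^(t)
eta, lam = 0.1, 0.01

# (A.4): update in coefficient space, then map back with w = P beta.
pt_grad = coords(basis, grad)
beta_next = [b - eta * (g + lam * b) for b, g in zip(beta, pt_grad)]
w_from_beta = combine(basis, beta_next)

# (A.6): update directly in parameter space with weight decay.
w = combine(basis, beta)
proj_grad = combine(basis, pt_grad)  # P P^T grad
w_direct = [(1 - eta * lam) * wi - eta * gi for wi, gi in zip(w, proj_grad)]
print(w_from_beta, w_direct)
```

The two results agree up to floating-point rounding, which is why $\beta$ never needs to be stored explicitly during training.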

C SWA WITH DIFFERENT STARTING EPOCHS

We test the performance of SWA with different starting epochs to average on ImageNet with the ResNet-50 model. We observe that the performance of SWA gradually improves as solutions from later stages are averaged, showing that SWA could not adapt well to the head stage of training where the solutions are fast-evolving. Hence, a good solution for SWA may require manually selecting which period to average. We also notice that TWA (with 0-60 epoch solutions) consistently outperforms SWA wherever the averaging begins, confirming that TWA could automatically find a good set of averaging coefficients and provide better performance.

In the solution set $A = \{\alpha_1 w_1 + \alpha_2 w_2 + \cdots + \alpha_n w_n \mid \alpha_i \in \mathbb{R}\}$, we do not explicitly require the sum of averaging coefficients $\sum_{i=1}^{n} \alpha_i$ to be 1, since the network's performance would be sensitive to a direct scaling over all parameters (i.e. $kw$); that is, a good solution in $A$ will inherently have a coefficient sum very close to 1. We verify the sum of the averaging coefficients of the attained solution $w_{twa}$ on CIFAR-100 with the PreResNet-164 model and observe that $\sum_{i=1}^{n} \alpha_i = 1.02_{\pm <0.01}$.

E NUMERICAL DISCUSSION FOR DDP TRAINING

We numerically measure the averaged epoch training time and memory burden for SGD and TWA in the DDP training setting. Specifically, we experiment with the ResNet-50 model on ImageNet and use 1, 2, and 4 GPUs with a batch size of 256 per GPU and a total of 300 historical solutions. The experiments are conducted on NVIDIA Tesla A100 40G GPUs. From the results reported in Table E4, we observe that TWA brings minor additional costs, e.g. +2.8% on time cost and +2.9% on memory burden with 4 GPUs, compared with regular SGD training. The additional memory burden becomes even smaller with more GPUs. This shows that TWA could provide efficient and scalable averaging for large-scale problems.

F ABLATION STUDY

We conduct an ablation study in Table F5 to analyze the impact of the regularization coefficient λ. We observe that such regularization brings improvements but is not significant. This is because the main regularization effects come from the significant decrease of training variables, i.e., regular training has D variables but TWA contains only n. Since such regularization is easy to implement and virtually brings little training cost, we include it in our method.



1 Available at https://github.com/pytorch/examples/tree/main/imagenet.
2 Available at https://huggingface.co/



Figure 1: TWA intuition.

Figure 2: An efficient parallel scheme for distributed training.

Figure 3: Performance comparison before and after TWA w.r.t. the epochs of weights used. "SGD final" indicates the accuracy reached by regular SGD training and "TWA" is the corresponding accuracy reached by Algorithm 1 with these epochs of weights. The final accuracy of SGD training is plotted for reference. TWA dramatically lifts the SGD accuracy and outperforms the final accuracy of SGD within 100 epochs. The experiments are repeated over 3 trials.

Figure 4: Left: Scaled learning rate schedules with different scaling factors; Middle and Right: Test accuracy curves of TWA w.r.t. different schedules on CIFAR-10/100. Training in the subspace is highly robust to the scaled learning rate schedule, which enlarges the learning rate and reduces the number of training epochs accordingly. In this way, TWA achieves very fast convergence.

Figure 5: Comparisons with EMA.

Figure 6: Averaging coefficients of TWA.

Figure 7: Train loss and test error surface of TWA and SWA with different SWA_LR.

DISCUSSION ON THE SUM-ONE CONSTRAINT

Let $\alpha = [\alpha_1, \alpha_2, \dots, \alpha_n]^\top \in \mathbb{R}^n$ and $W = [w_1, w_2, \dots, w_n] \in \mathbb{R}^{D \times n}$, so that $w_{twa} = W\alpha$. Multiplying both sides by $W^\top$ on the left gives $W^\top w_{twa} = W^\top W \alpha$, and hence $\alpha = (W^\top W)^{-1} W^\top w_{twa}$. Further, we can establish the relation between $\alpha$ and $\beta$: $\alpha = (W^\top W)^{-1} W^\top P \beta$, since they are the coordinates of $w_{twa}$ under two different sets of bases.
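As a rough numerical check of the relation between α and β, the NumPy sketch below builds a small random W, takes an orthonormal basis P of its column space, and verifies that $(W^\top W)^{-1} W^\top P \beta$ recovers the averaging coefficients. Obtaining P by a reduced QR decomposition is an assumption made here for illustration, and all sizes and variable names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
D, n = 50, 5                        # hypothetical sizes: D parameters, n solutions
W = rng.normal(size=(D, n))         # columns play the role of w_1, ..., w_n
alpha_true = rng.normal(size=n)     # arbitrary averaging coefficients (sum not fixed to 1)
w_twa = W @ alpha_true              # w_twa = W alpha

# Assumption: P is an orthonormal basis of span(W), obtained by reduced QR.
P, _ = np.linalg.qr(W)              # P has shape (D, n) and P.T @ P = I
beta = P.T @ w_twa                  # coordinates of w_twa under the basis P

# alpha = (W^T W)^{-1} W^T P beta, solved without forming an explicit inverse.
alpha = np.linalg.solve(W.T @ W, W.T @ (P @ beta))

print("recovered alpha matches:", np.allclose(alpha, alpha_true))
print("sum of coefficients:", alpha.sum())
```

Solving the normal equations with `np.linalg.solve` avoids forming an explicit inverse. This is adequate for a small, well-conditioned toy problem, though nearly collinear solutions would call for a more careful solver.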

Test accuracy (%) on CIFAR-10/100 for tail training with different learning rates

Top-1 accuracy (%) on ImageNet for tail training with different averaging epochs



For ImageNet, the effort required for each training epoch is much greater, and hence efficient methods to reduce the number of training epochs are highly desirable. The comparison results of SGD/SWA/TWA are shown in Table 4. Besides reducing the generalization gap, TWA takes only 2 epochs to average the historical solutions of the first 60 epochs, reaching comparable or even better performance than regular SGD training with 90 epochs. For comparison, Lookahead (Zhang et al., 2019), another advanced optimizer recently proposed for improving convergence, reported 75.49% accuracy at the 60th epoch (Table 2 in Zhang et al. (



Wall-clock time per epoch

Test accuracy (%) and generalization gap (%) on CIFAR-10/100 with Adam Optimizer

D×n , the optimization target for TWA can be formulated as,

SWA with different starting epochs. Column headers: Epoch 10-60, Epoch 20-60, Epoch 30-60, Epoch 40-60, Epoch 50-60, Epoch 0-60.

Time and memory comparisons of SGD and TWA with DDP training.

ACKNOWLEDGEMENTS

We are very grateful to the anonymous reviewers for their valuable feedback on the paper. We thank Minqi Chen at Huawei Technologies for the great support. The research leading to these results has received funding from the National Natural Science Foundation of China (61977046), the Shanghai Science and Technology Program (22511105600), and the Shanghai Municipal Science and Technology Major Project (2021SHZDZX0102).

annex

We use the same schedule and hyper-parameters as in Izmailov et al. (2018). For VGG-16, we use a weight decay of $5 \times 10^{-4}$ and train the model for 300 epochs with weight averaging over epochs 161 to 300. For PreAct ResNet-164, we use a weight decay of $3 \times 10^{-4}$ and train the model for 225 epochs with weight averaging over epochs 126 to 225.

For TWA training, we use the same weights as SWA and initialize $w_{twa}$ as the last checkpoint (i.e., epoch 300 / 225). We train the models for 10 epochs with an initial learning rate of 2 and decay it by a factor of 10 at the 5th and 8th epochs. The regularization coefficient λ is set to $5 \times 10^{-5}$.

A.1.2 IMAGENET

Following Izmailov et al. (2018) and Yang et al. (2019), we start from pre-trained models (taken from torchvision.models) and collect weights by running the SGD optimizer for up to 10 epochs (with a constant learning rate of 0.005 and a weight decay of $1 \times 10^{-4}$). We uniformly sample the solutions 5 times per epoch.

For TWA training, we use the same weights as SWA and initialize $w_{twa}$ as the pre-trained model. For ImageNet, there are many iterations in one epoch, and hence we conduct TWA training for one epoch, in which we linearly decay the learning rate from 0.03 to 0. The regularization coefficient λ is set to $1 \times 10^{-5}$.
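The linear decay used for the one-epoch TWA training on ImageNet can be sketched as a small function. This is a minimal illustration under the assumption that the learning rate is updated once per iteration; the function name, argument names, and the step count in the usage lines are hypothetical and not taken from the authors' code.

```python
def linear_decay_lr(step, total_steps, lr_start=0.03):
    """Learning rate at a 0-based iteration `step`, decayed linearly to 0.

    Assumption: the rate is updated per iteration and reaches exactly 0
    when step == total_steps.
    """
    if not 0 <= step <= total_steps:
        raise ValueError("step must lie in [0, total_steps]")
    return lr_start * (1.0 - step / total_steps)


# Example with a hypothetical epoch of 1000 iterations.
total = 1000
lrs = [linear_decay_lr(s, total) for s in range(total + 1)]
print(lrs[0], lrs[total // 2], lrs[-1])
```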

A.2.1 CIFAR

For regular training, we train the models for 200 epochs with an initial learning rate of 0.1 and decay it by a factor of 10 at the 100th and the 150th epoch. We use the SGD optimizer with momentum 0.9, weight decay $1 \times 10^{-4}$, and batch size 128 by convention.

For TWA training, we initialize $w_{twa}$ as $\frac{1}{n} \sum_{i=1}^{n} w_i$, i.e., the center of the sampled solutions. We train the models for 10 epochs with an initial learning rate of 2 and decay it by a factor of 10 at the 5th and 8th epochs. The regularization coefficient λ is set to $1 \times 10^{-5}$.
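The CIFAR setup for TWA described above (center initialization and a step-decayed learning rate) can be sketched in plain Python. Two assumptions are made here: "decay at the 5th and 8th epochs" is read as applying the decay from 0-based epoch indices 5 and 8 onward, the same convention as milestone-based schedulers such as PyTorch's `MultiStepLR`, and weights are represented as flat lists of floats. The helper names are hypothetical.

```python
def twa_lr(epoch, base_lr=2.0, milestones=(5, 8), gamma=0.1):
    """Step-decay learning rate for a 0-based epoch index.

    Assumption: the rate is multiplied by `gamma` once for every
    milestone that the epoch index has reached.
    """
    n_decays = sum(epoch >= m for m in milestones)
    return base_lr * gamma ** n_decays


def center(solutions):
    """Element-wise mean of equally sized weight vectors (lists of floats)."""
    n = len(solutions)
    return [sum(vals) / n for vals in zip(*solutions)]


# Learning rates for the 10 TWA epochs, and a toy center initialization.
schedule = [twa_lr(e) for e in range(10)]
w_init = center([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
print(schedule)
print(w_init)
```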

A.2.2 IMAGENET

We follow the training protocol described in He et al. (2016). Specifically, we train the models for 90 epochs with an initial learning rate of 0.1 and decay it by a factor of 10 every 30 epochs. We use the SGD optimizer with momentum 0.9, weight decay $1 \times 10^{-4}$, and batch size 256.

For TWA training, we uniformly sample solutions 5 times per epoch and initialize $w_{twa}$ as $\frac{1}{n} \sum_{i=1}^{n} w_i$. We train the models for 2 epochs with learning rates of 0.3 and 0.03, respectively. For the extra one-epoch training, we use the same training protocol as in subsection A.1.2, i.e., linearly decaying the learning rate from 0.03 to 0. The regularization coefficient λ is set to $1 \times 10^{-5}$.

A.2.3 RESULTS ON ADAM OPTIMIZER

Adam (Kingma & Ba, 2015) is another mainstream optimizer with adaptive gradient descent, which enjoys fast convergence and insensitivity to the initial learning rate. Here, we apply TWA to the solutions generated by the Adam optimizer, and the results are in Table A2. The training settings are the same as in Table 3, with the default $\beta_1 = 0.9$ and $\beta_2 = 0.999$ for Adam. TWA is trained for 5 epochs

Published as a conference paper at ICLR 2023

G ADDITIONAL RESULTS ON NLP TASKS

For NLP datasets, we try a fine-tuning task with pre-trained models and compare the performance of SWA and TWA. Specifically, we experiment with the Corpus of Linguistic Acceptability (CoLA), a text classification task in the General Language Understanding Evaluation (GLUE; Wang et al., 2018) benchmark. In the experiment, we use a pre-trained BERT (Devlin et al., 2018) model (bert-base-uncased) from the Hugging Face community (see footnote 2). We fine-tune BERT on CoLA for 3 epochs with the AdamW optimizer (Loshchilov & Hutter, 2017), learning rate 2e-5, and weight decay 0.0. The model weights at the end of these epochs are collected for SWA and TWA. In TWA, we train the fine-tuned model for 1 epoch with a learning rate of 0.5 and regularization coefficient λ = 0.001. From the results below, we observe that TWA achieves better performance than the competing methods. This further demonstrates the broad applicability of TWA.

Table G6: Fine-tune results on CoLA.

