NEURAL NETWORKS WITH LATE-PHASE WEIGHTS

Abstract

The largely successful method of training neural networks is to learn their weights using some variant of stochastic gradient descent (SGD). Here, we show that the solutions found by SGD can be further improved by ensembling a subset of the weights in late stages of learning. At the end of learning, we obtain back a single model by taking a spatial average in weight space. To avoid incurring increased computational costs, we investigate a family of low-dimensional late-phase weight models which interact multiplicatively with the remaining parameters. Our results show that augmenting standard models with late-phase weights improves generalization in established benchmarks such as CIFAR-10/100, ImageNet and enwik8. These findings are complemented with a theoretical analysis of a noisy quadratic problem which provides a simplified picture of the late phases of neural network learning.

1. INTRODUCTION

Neural networks trained with SGD generalize remarkably well on a wide range of problems. A classic technique to further improve generalization is to ensemble many such models (Lakshminarayanan et al., 2017). At test time, the predictions made by each model are combined, usually through a simple average. Although largely successful, this technique is costly both during learning and inference. This has prompted the development of ensembling methods with reduced complexity, for example by collecting models along an optimization path generated by SGD (Huang et al., 2017), by performing interpolations in weight space (Garipov et al., 2018), or by tying a subset of the weights over the ensemble (Lee et al., 2015; Wen et al., 2020). An alternative line of work explores the use of ensembles to guide the optimization of a single model (Zhang et al., 2015; Pittorino et al., 2020).

We join these efforts and develop a method that fine-tunes the behavior of SGD using late-phase weights: late in training, we replicate a subset of the weights of a neural network and randomly initialize them in a small neighborhood. Together with the stochasticity inherent to SGD, this initialization encourages the late-phase weights to explore the loss landscape. As the late-phase weights explore, the shared weights accumulate gradients. After training we collapse this implicit ensemble into a single model by averaging in weight space.

Building upon recent work on ensembles with shared parameters (Wen et al., 2020) we explore a family of late-phase weight models involving multiplicative interactions (Jayakumar et al., 2020). We focus on low-dimensional late-phase models that can be ensembled with negligible overhead. Our experiments reveal that replicating the ubiquitous batch normalization layers (Ioffe & Szegedy, 2015) is a surprisingly simple and effective strategy for improving generalization.
Furthermore, we find that late-phase weights can be combined with stochastic weight averaging (Izmailov et al., 2018) , a complementary method that has been shown to greatly improve generalization.

2.1. LEARNING WITH LATE-PHASE WEIGHTS

Late-phase weights. To apply our learning algorithm to a given neural network model f_w we first specify its weights w in terms of two components, base and late-phase (θ and φ, resp.). The two components interact according to a weight interaction function w = h(θ, φ). Base weights are learned throughout the entire training session, and until time step T_0 both θ and φ are learned and treated on equal grounds. At time step T_0, a hyperparameter of our algorithm, we introduce K late-phase components Φ = {φ_k}_{k=1}^K, which are learned together with θ until the end. This procedure yields a late-phase ensemble of K neural networks with parameter sharing: reusing the base weights θ, each late-phase weight φ_k defines a model with parameters w_k = h(θ, φ_k).

Late-phase weight averaging at test time. Our ensemble defined by the K late-phase weight configurations in Φ is kept only during learning. At test time, we discard the ensemble and obtain a single model by averaging over the K late-phase weight components. That is, given some input pattern x, we generate a prediction y(x) using the averaged model, computed once after learning:

    y(x) = f_w̄(x),    w̄ ≡ h(θ, (1/K) Σ_{k=1}^K φ_k).    (1)

Hence, the complexity of inference is independent of K, and equivalent to that of the original model.

Late-phase weight initialization. We initialize our late-phase weights from a reference base weight. We first learn a base parameter φ_0 from time step t = 0 until T_0, treating φ_0 as any other base parameter in θ. Then, at time t = T_0, each configuration φ_k is initialized in the vicinity of φ_0. We explore perturbing φ_0 using a symmetric Gaussian noise model,

    φ_k = φ_0 + σ_0 Z(φ_0) ε_k,    (2)

where ε_k is a standard normal variate of appropriate dimension and σ_0 is a hyperparameter controlling the noise amplitude.
We allow for a φ_0-dependent normalization factor, which we set so as to ensure layerwise scale-invariance; this helps find a single σ_0 that governs the initialization of the entire network. More concretely, for a given neural network layer l with weights φ_0^(l) of dimension D^(l), we choose Z(φ_0^(l)) = √D^(l) / ‖φ_0^(l)‖.

Our perturbative initialization (Eq. 2) is motivated by ongoing studies of the nonconvex, high-dimensional loss functions that arise in deep learning. Empirical results and theoretical analyses of simplified models point to the existence of dense clusters of connected solutions with a locally flat geometry (Hochreiter & Schmidhuber, 1997a) that are accessible by SGD (Huang et al., 2017; Garipov et al., 2018; Baldassi et al., 2020). Indeed, the eigenspectrum of the loss Hessian evaluated at weight configurations found by SGD reveals a large number of directions of low curvature (Keskar et al., 2017; Chaudhari et al., 2019; Sagun et al., 2018). For reasons that are not yet completely understood, this appears to be a recurring phenomenon in overparameterized nonlinear problems (Brown & Sethna, 2003; Waterfall et al., 2006). Based on these observations, we assume that the initial parameter configuration φ_0 can be perturbed in a late phase of learning without leading to mode hopping across the different models w_k. While mode coverage is usually a sought-after property when learning neural network ensembles (Fort et al., 2020), here it would preclude us from taking the averaged model at the end of learning (Eq. 1).
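As a concrete illustration, the perturbative initialization of Eq. 2 with the layerwise normalization factor above can be sketched as follows. This is a minimal NumPy sketch under our own naming conventions, not the released implementation:

```python
import numpy as np

def init_late_phase(phi0, K, sigma0, rng):
    """Draw K late-phase copies of a reference weight phi0 (Eq. 2).

    The layerwise factor Z(phi0) = sqrt(D) / ||phi0|| makes the perturbation
    scale-invariant across layers, so a single sigma0 can govern the
    initialization of the entire network.
    """
    D = phi0.size                               # layer dimension D^(l)
    Z = np.sqrt(D) / np.linalg.norm(phi0)       # layerwise normalization factor
    # phi_k = phi0 + sigma0 * Z(phi0) * eps_k, with eps_k ~ N(0, I)
    return [phi0 + sigma0 * Z * rng.standard_normal(phi0.shape)
            for _ in range(K)]
```

At t = T_0 this would be applied layer by layer to each late-phase parameter group (for instance, the batch normalization weights).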

Stochastic learning algorithm.

Having decomposed our weights into base and late-phase components, we now present a stochastic algorithm which learns both θ and Φ. Our algorithm operates in the standard stochastic (minibatch) neural network optimization setting (Bottou, 2010). Given a loss function L(D, w) = (1/|D|) Σ_{x∈D} L(x, w) to be minimized with respect to the weights w on a set of data D, at every round we randomly sample a subset M from D and optimize instead the stochastic loss L(M, w). However, in contrast to the standard setting, in late stages of learning (t > T_0) we simultaneously optimize K parameterizations W := {w_k | w_k = h(θ, φ_k)}_{k=1}^K, instead of one.
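To make the shared-gradient structure concrete, here is a minimal sketch of one update step for t > T_0, under the simplifying assumptions of plain SGD and a multiplicative interaction h(θ, φ) = θ ⊙ φ. The function names are ours and this is not the paper's reference implementation; updating θ with the gradient averaged over the K parameterizations, while each φ_k receives its own gradient, is one natural reading of how the shared weights accumulate gradients from the exploring late-phase members:

```python
import numpy as np

def late_phase_sgd_step(theta, phis, grad_w, lr):
    """One SGD step over K shared-base parameterizations w_k = h(theta, phi_k).

    Illustrative sketch: h is taken to be elementwise multiplication,
    h(theta, phi) = theta * phi. Each phi_k is updated with its own
    gradient; theta accumulates the average gradient over the ensemble.
    grad_w(w) returns dL/dw evaluated at w (e.g., on a minibatch M).
    """
    K = len(phis)
    grad_theta = np.zeros_like(theta)
    for k, phi in enumerate(phis):
        w = theta * phi                       # w_k = h(theta, phi_k)
        g = grad_w(w)                         # dL/dw at w_k
        grad_theta += (g * phi) / K           # chain rule: dL/dtheta, averaged
        phis[k] = phi - lr * (g * theta)      # chain rule: dL/dphi_k
    theta = theta - lr * grad_theta
    return theta, phis
```

After training, the late-phase components are collapsed as in Eq. 1: the averaged model w̄ = h(θ, (1/K) Σ_k φ_k) is the one evaluated at test time.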



We provide code to reproduce our experiments at https://github.com/seijin-kobayashi/late-phase-weights

