GIT RE-BASIN: MERGING MODELS MODULO PERMUTATION SYMMETRIES

Abstract

The success of deep learning is due in large part to our ability to solve certain massive non-convex optimization problems with relative ease. Though non-convex optimization is NP-hard, simple algorithms, often variants of stochastic gradient descent, exhibit surprising effectiveness in fitting large neural networks in practice. We argue that neural network loss landscapes often contain (nearly) a single basin after accounting for all possible permutation symmetries of hidden units, a la Entezari et al. (2021). We introduce three algorithms to permute the units of one model to bring them into alignment with a reference model in order to merge the two models in weight space. This transformation produces a functionally equivalent set of weights that lie in an approximately convex basin near the reference model. Experimentally, we demonstrate the single basin phenomenon across a variety of model architectures and datasets, including the first (to our knowledge) demonstration of zero-barrier linear mode connectivity between independently trained ResNet models on CIFAR-10. Additionally, we investigate intriguing phenomena relating model width and training time to mode connectivity. Finally, we discuss shortcomings of the linear mode connectivity hypothesis, including a counterexample to the single basin theory.

1. INTRODUCTION

We investigate the unreasonable effectiveness of stochastic gradient descent (SGD) algorithms on the high-dimensional non-convex optimization problems of deep learning. In particular:

1. Why does SGD thrive in optimizing high-dimensional non-convex deep learning loss landscapes despite being noticeably less robust in other non-convex optimization settings, such as policy learning (Ainsworth et al., 2021), trajectory optimization (Kelly, 2017), and recommender systems (Kang et al., 2016)?
2. What are all the local minima? When linearly interpolating between initialization and final trained weights, why does the loss smoothly and monotonically decrease (Goodfellow & Vinyals, 2015; Frankle, 2020; Lucas et al., 2021; Vlaar & Frankle, 2021)?
3. How can two independently trained models with different random initializations and data batch orders inevitably achieve nearly identical performance? Furthermore, why do their training loss curves often look identical?

We posit that these phenomena point to the existence of some yet uncharacterized invariance(s) in the training dynamics causing independent training runs to exhibit similar characteristics. Hecht-Nielsen (1990) noted the permutation symmetries of hidden units in neural networks: briefly, one can swap any two units of a hidden layer in a network and, assuming weights are adjusted accordingly, network functionality will not change. Entezari et al. (2021) conjectured that, once these permutation symmetries are accounted for, most SGD solutions can be permuted such that linear interpolation between them incurs no barrier in the loss; we refer to such solutions as being linearly mode connected (LMC) (Frankle et al., 2020), an extension of mode connectivity (Garipov et al., 2018; Draxler et al., 2018). If true, this conjecture will both materially expand our understanding of how SGD works in the context of deep learning and offer a credible explanation for the preceding phenomena in particular.

Contributions. In this paper, we attempt to uncover what invariances may be responsible for the phenomena cited above and the unreasonable effectiveness of SGD in deep learning. We make the following contributions:
1. Matching methods. We propose three algorithms, grounded in concepts and techniques from combinatorial optimization, to align the weights of two independently trained models. Where appropriate, we prove hardness results for these problems and propose approximation algorithms. Our fastest method identifies permutations in mere seconds on current hardware.
2. Relationship to optimization algorithms. We demonstrate by means of counterexample that linear mode connectivity is an emergent property of training procedures, not of model architectures. We connect this result to prior work on the implicit biases of SGD.
3. Experiments, including zero-barrier LMC for ResNets. Empirically, we explore the existence of linear mode connectivity modulo permutation symmetries in experiments across MLPs, CNNs, and ResNets trained on MNIST, CIFAR-10, and CIFAR-100. We contribute the first-ever demonstration of zero-barrier LMC between two independently trained ResNets. We explore the relationship between LMC and both model width and training time. Finally, we show evidence of our methods' ability to combine models trained on independent datasets into a merged model that outperforms both input models in terms of test loss (but not accuracy) and is no more expensive in compute or memory than either input model.
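To make the matching idea concrete, the following is a minimal sketch of weight matching for a single hidden layer, where finding the best permutation reduces to a linear assignment problem solvable with `scipy.optimize.linear_sum_assignment`. The function name, layer sizes, and similarity score below are illustrative assumptions for a two-layer MLP, not the paper's implementation, which generalizes the idea across many layers and architectures.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def match_hidden_units(W1_a, W2_a, W1_b, W2_b):
    """Return a permutation matrix P aligning model B's hidden units to A's.

    For a single hidden layer, maximizing <W1_a, P W1_b> + <W2_a, W2_b P^T>
    over permutation matrices P is exactly a linear assignment problem.
    """
    # similarity[i, j]: how well B's hidden unit j matches A's unit i,
    # measured on both the incoming (W1) and outgoing (W2) weights.
    similarity = W1_a @ W1_b.T + W2_a.T @ W2_b
    rows, cols = linear_sum_assignment(similarity, maximize=True)
    P = np.zeros_like(similarity)
    P[rows, cols] = 1.0
    return P

# Usage sketch: build B as a hidden-unit permutation of A, so the correct
# alignment is known, then recover it and merge by linear interpolation.
rng = np.random.default_rng(0)
d_in, d_h, d_out = 32, 8, 16
W1_a, W2_a = rng.normal(size=(d_h, d_in)), rng.normal(size=(d_out, d_h))
sigma = rng.permutation(d_h)
W1_b, W2_b = W1_a[sigma], W2_a[:, sigma]

P = match_hidden_units(W1_a, W2_a, W1_b, W2_b)
W1_b_aligned, W2_b_aligned = P @ W1_b, W2_b @ P.T
# Midpoint of the merged model, in the spirit of linear mode connectivity.
W1_mid, W2_mid = 0.5 * (W1_a + W1_b_aligned), 0.5 * (W2_a + W2_b_aligned)
```

With real independently trained models the recovered permutation is of course only approximate, and one would interpolate all parameters (including biases) after applying the permutation to every layer it touches.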

2. BACKGROUND

Although our methods can be applied to arbitrary model architectures, we proceed with the multi-layer perceptron (MLP) for its ease of presentation (Bishop, 2007). Consider an $L$-layer MLP,
$$f(x; \Theta) = z_{L+1}, \qquad z_{\ell+1} = \sigma(W_\ell z_\ell + b_\ell), \qquad z_1 = x,$$
where $\sigma$ denotes an element-wise nonlinear activation function. Furthermore, consider a loss, $\mathcal{L}(\Theta)$, that measures the suitability of a particular set of weights $\Theta$ towards some goal, e.g., fitting to a training dataset. Central to our investigation is the phenomenon of permutation symmetries of weight space. Given $\Theta$, we can apply some permutation to the output features of any intermediate layer $\ell$ of the model, denoted by a permutation matrix $P \in S_d$:foot_0
$$z_{\ell+1} = P^\top P z_{\ell+1} = P^\top P \, \sigma(W_\ell z_\ell + b_\ell) = P^\top \sigma(P W_\ell z_\ell + P b_\ell)$$
for $\sigma$, an element-wise operator. It follows that as long as we reorder the input weights of layer $\ell+1$ according to $P^\top$, we will have a functionally equivalent model. To be precise, if we define $\Theta'$ to be identical to $\Theta$ with the exception of
$$W'_\ell = P W_\ell, \qquad b'_\ell = P b_\ell, \qquad W'_{\ell+1} = W_{\ell+1} P^\top,$$
then $f(x; \Theta') = f(x; \Theta)$ for all inputs $x$.
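This functional equivalence under permutation can be checked numerically. The following is a minimal sketch in NumPy for a two-layer ReLU MLP; the layer sizes and random weights are arbitrary illustrations, not values from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

# A tiny 2-layer MLP: f(x) = W2 @ relu(W1 @ x + b1) + b2.
d_in, d_hidden, d_out = 4, 8, 3
W1, b1 = rng.normal(size=(d_hidden, d_in)), rng.normal(size=d_hidden)
W2, b2 = rng.normal(size=(d_out, d_hidden)), rng.normal(size=d_out)

def forward(x, W1, b1, W2, b2):
    return W2 @ np.maximum(W1 @ x + b1, 0.0) + b2

# A random permutation matrix P acting on the hidden layer.
P = np.eye(d_hidden)[rng.permutation(d_hidden)]

# Permuted weights per the text: W1' = P W1, b1' = P b1, W2' = W2 P^T.
W1p, b1p, W2p = P @ W1, P @ b1, W2 @ P.T

# The two parameterizations compute the same function.
x = rng.normal(size=d_in)
assert np.allclose(forward(x, W1, b1, W2, b2), forward(x, W1p, b1p, W2p, b2))
```

The key step is that $P\,\sigma(v) = \sigma(P v)$ for any element-wise $\sigma$, so the permutation commutes through the activation and is undone by $P^\top$ at the next layer.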



We denote the set of all d × d permutation matrices (isomorphic to the symmetric group) as S_d, to the possible chagrin of pure mathematicians.



[Table caption] Permutation symmetries of deep learning models vs. an upper estimate of the number of atoms in the known, observable universe. Deep learning loss landscapes contain incomprehensible amounts of geometric repetition.

