WASSERSTEIN BARYCENTER-BASED MODEL FUSION AND LINEAR MODE CONNECTIVITY OF NEURAL NETWORKS

Abstract

Based on the concepts of the Wasserstein barycenter (WB) and the Gromov-Wasserstein barycenter (GWB), we propose a unified mathematical framework for neural network (NN) model fusion and use it to reveal new insights about the linear mode connectivity of SGD solutions. In our framework, fusion occurs in a layer-wise manner and builds on an interpretation of a node in a network as a function of the layer preceding it. The versatility of our mathematical framework allows us to discuss model fusion and linear mode connectivity for a broad class of NNs, including fully connected NNs, CNNs, ResNets, RNNs, and LSTMs, in each case exploiting the specific structure of the network architecture. We present extensive numerical experiments to: 1) illustrate the strengths of our approach in relation to other model fusion methodologies, and 2) from a certain perspective, provide new empirical evidence for recent conjectures which say that two local minima found by gradient-based methods end up lying in the same basin of the loss landscape after a proper permutation of weights is applied to one of the models.

1. INTRODUCTION

The increasing use of edge devices such as mobile phones, tablets, and vehicles, along with the growing sophistication of the sensors they carry (e.g., cameras, GPS, and accelerometers), has led to the generation of an enormous amount of data. However, data privacy concerns, communication costs, bandwidth limits, and time sensitivity prevent the gathering of local data from edge devices into a single centralized location. These obstacles have motivated the design and development of federated learning strategies, which aim to pool information from locally trained neural networks (NNs) with the objective of building strong centralized models without relying on the collection of local data McMahan et al. (2017); Kairouz et al. (2019). For these reasons, the problem of NN fusion, i.e., combining multiple models that were trained differently into a single model, is a fundamental task in federated learning. A standard fusion method for aggregating models with the same architecture is FedAvg McMahan et al. (2017), which performs element-wise averaging of the parameters of the local models; this is also known as vanilla averaging Singh & Jaggi (2019). Although easy to implement, vanilla averaging performs poorly when fusing models whose weights do not have a one-to-one correspondence. This happens because, even when models are trained on the same dataset, it is possible to obtain models that differ only by a permutation of their weights Wang et al. (2020); Yurochkin et al. (2019); this feature is known as the permutation invariance property of neural networks. Moreover, vanilla averaging is not naturally designed to work with local models that have different architectures (e.g., different widths).
In order to address these challenges, Singh & Jaggi (2019) proposed to first find the best alignment between the neurons (weights) of the different networks using optimal transport (OT) Villani (2008); Santambrogio (2015); Peyré & Cuturi (2018), and then to carry out a vanilla averaging step. Other approaches, such as those proposed in Wang et al. (2020); Yurochkin et al. (2019), interpret the nodes of local models as random permutations of latent "global nodes" modeled according to a Beta-Bernoulli process prior Thibaux & Jordan (2007). Through these "global nodes", nodes from different input NNs can be embedded into a common space where comparisons and aggregation are meaningful. Most works in the literature on the fusion problem have focused on aggregating fully connected (FC) neural networks and CNNs, and have not, for the most part, explored other kinds of architectures such as RNNs and LSTMs. One exception to this general state of the art is the work Wang et al. (2020), which considers the fusion of RNNs but ignores the hidden-to-hidden weights during the matching of neurons, thus discarding some useful information in the pre-trained RNNs. For more references on the fusion problem, see the Appendix.

A different line of research that has attracted considerable attention in the past few years is the quest for a comprehensive understanding of the loss landscape of deep neural networks, a fundamental component in studying the optimization and generalization properties of NNs Li et al. (2018); Mei et al. (2018); Neyshabur et al. (2017); Nguyen et al. (2018); Izmailov et al. (2018). Due to the overparameterization, scale, and permutation invariance properties of neural networks, the loss landscapes of DNNs have many local minima Keskar et al. (2016); Zhang et al. (2021). Several works have asked, and answered affirmatively, the question of whether there exist paths of small loss increase connecting different local minima found by SGD Garipov et al. (2018); Draxler et al. (2018). This phenomenon is often referred to as mode connectivity Garipov et al. (2018), and the loss increase along a path between two models is often referred to as an (energy) barrier Draxler et al. (2018). It has been observed that low-barrier paths are non-linear; i.e., linear interpolation of two different models will not usually produce a neural network with small loss. These observations suggest that, from the perspective of the local structure of loss landscapes, different SGD solutions belong to different (well-separated) basins Neyshabur et al. (2020). However, recent work Entezari et al. (2021) has conjectured that local minima found by SGD do end up lying in the same basin of the loss landscape after a proper permutation of weights is applied to one of the models. The question of how to find these desired permutations remains, in general, elusive.

The purpose of this paper is twofold. On one hand, we present a large family of barycenter-based fusion algorithms that can be used to aggregate models within the families of fully connected NNs, CNNs, ResNets, RNNs, and LSTMs. The most general family of fusion algorithms that we introduce relies on the concept of the Gromov-Wasserstein barycenter (GWB), which allows us to use the information in hidden-to-hidden layers of RNNs and LSTMs, in contrast to previous approaches in the literature such as that proposed in Wang et al. (2020). To motivate the GWB-based fusion algorithm for RNNs and LSTMs, we first discuss a Wasserstein barycenter (WB) based fusion algorithm for fully connected, CNN, and ResNet models, which closely follows the OT fusion algorithm of Singh & Jaggi (2019). By creating a link between the NN model fusion problem and the problem of computing Wasserstein (or Gromov-Wasserstein) barycenters, we aim to exploit the many tools developed in the last decade for the computation of WBs (or GWBs) (see the Appendix for references) and to leverage the mathematical structure of OT problems. Using our framework, we are able to fuse models with different architectures and to build target models with arbitrarily specified dimensions (at least in terms of width). On the other hand, through several numerical experiments in a variety of settings (architectures and datasets), we provide new evidence backing certain aspects of the conjecture put forward in Entezari et al. (2021) about the local structure of NNs' loss landscapes. Indeed, we find that there exist sparse couplings between different models that can map different local minima found by SGD into basins that are separated only by low energy barriers. These sparse couplings, which can be thought of as approximations to actual permutations, are obtained using our fusion algorithms, which, surprisingly, only use training data to set the values of some hyperparameters. We explore this conjecture in imaging and natural language processing (NLP) tasks and provide visualizations of our findings. Consider, for example, Figure 1 (left), which visualizes the fusion of two FC NNs independently trained on the MNIST dataset. We can observe that the basins in which model 1 and permuted model 2 (i.e., model 2 after multiplying its weights by the coupling obtained by our fusion algorithm) land are close to each other and are separated only by low energy barriers. Our main contributions can then be summarized as follows: (a) we formulate the NN model fusion problem as a series of Wasserstein (Gromov-Wasserstein) barycenter problems, bridging in this way the NN fusion problem with computational OT; (b) we empirically demonstrate that our framework is highly effective at fusing different types of networks, including RNNs and LSTMs; (c) we visualize the result of our fusion algorithm when aggregating two neural networks in a 2D plane. By doing this we not only illustrate how our fusion algorithms perform, but also present empirical evidence for the conjecture made in Entezari et al. (2021), casting light on the loss landscape of a variety of neural networks. At the time of completing this work, we became aware of two very recent preprints which also explore the conjecture made in Entezari et al. (2021) empirically. In particular, Ainsworth et al.
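The notion of an energy barrier along the linear path between two models, discussed above, can be made concrete with a short sketch. The `barrier` function and the toy one-dimensional loss below are illustrative assumptions, not the implementation used in any of the cited works:

```python
import numpy as np

def barrier(loss_fn, theta_a, theta_b, n_points=11):
    """Largest excess of the interpolated loss over the linear
    interpolation of the two endpoint losses, along the segment
    theta(t) = (1 - t) * theta_a + t * theta_b."""
    la, lb = loss_fn(theta_a), loss_fn(theta_b)
    excesses = []
    for t in np.linspace(0.0, 1.0, n_points):
        theta_t = (1 - t) * theta_a + t * theta_b
        excesses.append(loss_fn(theta_t) - ((1 - t) * la + t * lb))
    return max(excesses)

# Toy non-convex loss with two minima at w = +1 and w = -1.
loss = lambda w: float(np.sum((w**2 - 1.0) ** 2))
w1, w2 = np.array([1.0]), np.array([-1.0])
print(barrier(loss, w1, w2))  # midpoint w = 0 has loss 1, endpoints 0 -> 1.0
```

Two SGD solutions being "in the same basin" after a suitable permutation corresponds, in this language, to the barrier between one model and the permuted other being close to zero.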


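The align-then-average strategy of Singh & Jaggi (2019) described in this introduction can be sketched at toy scale: an entropic (Sinkhorn) OT solver matches the hidden neurons of two one-layer models via the cost between their incoming weight rows, after which model 2 is mapped into model 1's neuron ordering and the two are averaged. Everything below (the plain Sinkhorn loop, the hyperparameters, and the barycentric projection standing in for the full layer-wise algorithm) is an illustrative assumption, not the authors' implementation:

```python
import numpy as np

def sinkhorn(C, eps=0.05, n_iter=500):
    """Entropic OT plan between two uniform discrete measures with cost C."""
    n, m = C.shape
    a, b = np.ones(n) / n, np.ones(m) / m
    K = np.exp(-C / eps)
    v = np.ones(m) / m
    for _ in range(n_iter):
        u = a / (K @ v)
        v = b / (K.T @ u)
    return u[:, None] * K * v[None, :]

rng = np.random.default_rng(1)
W1 = rng.standard_normal((4, 3))                     # model 1: one layer, 4 neurons
perm = [2, 0, 3, 1]
W2 = W1[perm] + 0.01 * rng.standard_normal((4, 3))   # model 2: permuted + noise

# Cost: squared distances between incoming weight rows of the two models.
C = ((W1[:, None, :] - W2[None, :, :]) ** 2).sum(-1)
T = sinkhorn(C)                                      # near-permutation coupling

# Map model 2 into model 1's ordering (barycentric projection), then average.
W2_aligned = (T / T.sum(axis=1, keepdims=True)) @ W2
fused = 0.5 * (W1 + W2_aligned)
```

Because the two toy models differ only by a permutation plus small noise, the coupling `T` concentrates near that permutation and `fused` stays close to `W1`, whereas vanilla averaging of `W1` and `W2` directly would not.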