WASSERSTEIN BARYCENTER-BASED MODEL FUSION AND LINEAR MODE CONNECTIVITY OF NEURAL NETWORKS

Abstract

Based on the concepts of the Wasserstein barycenter (WB) and the Gromov-Wasserstein barycenter (GWB), we propose a unified mathematical framework for neural network (NN) model fusion and use it to reveal new insights about the linear mode connectivity of SGD solutions. In our framework, fusion occurs in a layer-wise manner and builds on an interpretation of each node in a network as a function of the layer preceding it. The versatility of our mathematical framework allows us to discuss model fusion and linear mode connectivity for a broad class of NNs, including fully connected NNs, CNNs, ResNets, RNNs, and LSTMs, in each case exploiting the specific structure of the network architecture. We present extensive numerical experiments to: 1) illustrate the strengths of our approach relative to other model fusion methodologies, and 2) provide, from a certain perspective, new empirical evidence for recent conjectures stating that two local minima found by gradient-based methods lie in the same basin of the loss landscape after a proper permutation of weights is applied to one of the models.

1. INTRODUCTION

The increasing use of edge devices such as mobile phones, tablets, and vehicles, together with the growing sophistication of their sensors (e.g., cameras, GPS, and accelerometers), has led to the generation of enormous amounts of data. However, data privacy concerns, communication costs, bandwidth limits, and time sensitivity prevent the gathering of local data from edge devices into a single centralized location. These obstacles have motivated the design and development of federated learning strategies, which aim to pool information from locally trained neural networks (NNs) in order to build strong centralized models without relying on the collection of local data McMahan et al. (2017); Kairouz et al. (2019). Due to these considerations, the problem of NN fusion, i.e. combining multiple models that were trained differently into a single model, is a fundamental task in federated learning. A standard fusion method for aggregating models with the same architecture is FedAvg McMahan et al. (2017), which performs element-wise averaging of the parameters of the local models; this is also known as vanilla averaging Singh & Jaggi (2019). Although easy to implement, vanilla averaging performs poorly when fusing models whose weights do not have a one-to-one correspondence. Indeed, even models trained on the same dataset may differ only by a permutation of their weights Wang et al. (2020); Yurochkin et al. (2019); this is known as the permutation invariance property of neural networks. Moreover, vanilla averaging is not naturally suited to fusing local models with different architectures (e.g., different widths).
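To make the permutation invariance issue concrete, the following minimal NumPy sketch (not from the paper; all names such as `W1`, `v1`, and `perm` are illustrative) builds two functionally identical one-hidden-layer ReLU networks that differ only by a permutation of their hidden units, and shows that aligning the units before averaging recovers the original weights, whereas vanilla element-wise averaging ignores the correspondence:

```python
import numpy as np

rng = np.random.default_rng(0)

# Model 1: hidden-layer weights W1 and output weights v1.
W1 = rng.normal(size=(4, 3))
v1 = rng.normal(size=4)

# Model 2 is model 1 with its hidden units permuted: functionally identical.
perm = np.array([2, 0, 3, 1])
W2, v2 = W1[perm], v1[perm]

def forward(W, v, x):
    # Output of a one-hidden-layer ReLU network: v . relu(W x).
    return v @ np.maximum(W @ x, 0.0)

x = rng.normal(size=3)
# Permuting hidden units does not change the function the network computes.
assert np.isclose(forward(W1, v1, x), forward(W2, v2, x))

# Vanilla averaging (FedAvg) mixes unrelated units; the result is generally
# a different function from either input model.
W_avg, v_avg = (W1 + W2) / 2, (v1 + v2) / 2

# Aligning model 2's units back to model 1's ordering before averaging
# recovers model 1 exactly, since the two models are permutations of each other.
inv = np.argsort(perm)
W_aligned, v_aligned = W2[inv], v2[inv]
assert np.allclose((W1 + W_aligned) / 2, W1)
assert np.allclose((v1 + v_aligned) / 2, v1)
```

In practice the permutation relating two independently trained models is unknown; methods such as the OT-based alignment discussed next estimate it from the weights themselves.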
In order to address these challenges, Singh & Jaggi (2019) proposed to first find the best alignment between the neurons (weights) of the different networks using optimal transport (OT) Villani (2008); Santambrogio (2015); Peyré & Cuturi (2018), and to then carry out a vanilla averaging step. Other approaches, such as those proposed in Wang et al. (2020); Yurochkin et al. (2019), interpret the nodes of local models as random permutations of latent "global nodes" modeled according to a Beta-Bernoulli process prior Thibaux & Jordan (2007). By using "global nodes", nodes from different input NNs can be embedded into a common space where comparisons and aggregation are meaningful. Most works in the literature discussing the fusion problem have mainly focused on the aggregation of fully connected (FC) neural networks and CNNs, but have not, for the most

