TRAVERSING BETWEEN MODES IN FUNCTION SPACE FOR FAST ENSEMBLING

Abstract

Deep ensembles are a simple yet powerful way to improve the performance of deep neural networks. Motivated by this, recent works on mode connectivity have shown that the parameters of ensembles are connected by low-loss subspaces, and that one can efficiently collect ensemble parameters in those subspaces. While this provides a way to train ensembles efficiently, at inference time one must still execute multiple forward passes using all the ensemble parameters, which often becomes a serious bottleneck for real-world deployment. In this work, we propose a novel framework to reduce such costs. Given a low-loss subspace connecting two modes of a neural network, we build an additional neural network that predicts the output of the original network evaluated at a certain point in the low-loss subspace. This additional network, which we call a "bridge", is a lightweight network that takes minimal features from the original network and predicts outputs for the low-loss subspace without forward passes through the original network. We empirically demonstrate that such bridge networks can indeed be trained, and that they significantly reduce inference costs.

1. INTRODUCTION

Deep Ensemble (DE) (Lakshminarayanan et al., 2017) is a simple algorithm that improves both the predictive accuracy and the uncertainty calibration of deep neural networks: a neural network is trained multiple times on the same data but with different random seeds. Due to this randomness, the parameters obtained from the multiple training runs reach different local optima, called modes, on the loss surface (Fort et al., 2019). These parameters represent a set of diverse functions and serve as an effective approximation of Bayesian Model Averaging (BMA) (Wilson and Izmailov, 2020).

An apparent drawback of DE is that it requires multiple training runs, whose cost can be substantial, especially in large-scale settings where parallel training is not feasible. Garipov et al. (2018) and Draxler et al. (2018) showed that modes in the loss surface of a deep neural network are connected by relatively simple low-dimensional subspaces on which every parameter retains low training error, and that the parameters along those subspaces are good candidates for ensembling. Based on this observation, Garipov et al. (2018) and Huang et al. (2017) proposed algorithms that quickly construct deep ensembles without running multiple independent training runs.

While these fast ensembling methods based on mode connectivity reduce the training cost, they do not address another important drawback of DE: the inference cost. One must still execute multiple forward passes using all the parameters collected for the ensemble, and this cost often becomes critical in real-world scenarios where training is done in a resource-abundant setting with plenty of computation time, but inference must be done in a resource-limited environment at deployment. In such settings, reducing the inference cost matters far more than reducing the training cost. In this paper, we propose a novel approach to scale up DE by reducing its inference cost.
We start from an assumption: if two modes in an ensemble are connected by a simple subspace, we can predict the outputs corresponding to the parameters on that subspace using only the outputs computed from the modes. In other words, we can predict the outputs evaluated on the subspace without forwarding the actual subspace parameters through the network. If this is indeed possible then, for instance, given two modes, we can approximate an ensemble of three models drawn from three different locations (one from the subspace connecting the two modes, and one from each mode) with only two forward passes and a small auxiliary forward pass. We show that this idea can actually be implemented with an additional lightweight network whose inference cost is low relative to that of the original neural network. This additional network, which we call a "bridge network", takes some features from the original neural network (e.g., features from the penultimate layer) and directly predicts the outputs computed from the connecting subspace. In other words, the bridge network lets us travel between modes in function space. We present two types of bridge networks depending on the number of modes involved in prediction, network architectures for bridge networks, and training procedures. Through empirical validation on various image classification benchmarks, we show that 1) bridge networks can predict the outputs of connecting subspaces accurately with minimal computation cost, and 2) DEs augmented with bridge networks significantly reduce inference costs without a significant sacrifice in performance.
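The idea above can be sketched numerically. The following is a minimal, hypothetical illustration (not the paper's actual architecture): a "bridge" here is a single linear layer that maps the concatenated penultimate features of the two modes to the predicted class probabilities of a model on the connecting subspace, so a three-model ensemble costs two full forward passes plus one cheap bridge pass.

```python
import numpy as np

rng = np.random.default_rng(0)
D_ft, K = 16, 10  # feature dimension of the original network, number of classes

# Hypothetical quantities for one input: penultimate features h_i, h_j from the
# two modes, and their class-probability outputs p_i, p_j.
h_i, h_j = rng.normal(size=D_ft), rng.normal(size=D_ft)
p_i, p_j = np.full(K, 1.0 / K), np.full(K, 1.0 / K)

# Bridge parameters: a single linear layer (illustrative; in practice the
# bridge would be trained to match the subspace model's outputs).
W = rng.normal(size=(2 * D_ft, K)) * 0.01

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def bridge(h_i, h_j, W):
    """Predict the subspace model's output from the two modes' features,
    without a forward pass through the original network."""
    return softmax(np.concatenate([h_i, h_j]) @ W)

# Approximate three-model ensemble: two real forward passes + one bridge pass.
p_mid = bridge(h_i, h_j, W)
p_ens = (p_i + p_j + p_mid) / 3.0
```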

2.1. PROBLEM SETUP

In this paper, we discuss the K-way classification problem with D-dimensional inputs. A classifier is constructed with a deep neural network f_θ : ℝ^D → ℝ^K, which is decomposed into a feature extractor f^(ft)_ϕ : ℝ^D → ℝ^{D_ft} and a classifier f^(cls)_ψ : ℝ^{D_ft} → ℝ^K, i.e., f_θ(x) = (f^(cls)_ψ ∘ f^(ft)_ϕ)(x). Here, ϕ ∈ Φ and ψ ∈ Ψ denote the parameters of the feature extractor and the classifier, respectively, θ = (ϕ, ψ) ∈ Θ, and D_ft is the dimension of the feature. An output of the classifier corresponds to a class probability vector.
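The decomposition f_θ = f^(cls)_ψ ∘ f^(ft)_ϕ can be sketched as follows. This is a minimal illustration with arbitrary dimensions and one linear layer per component, not the networks used in the experiments.

```python
import numpy as np

rng = np.random.default_rng(0)
D, D_ft, K = 32, 16, 10  # input dim D, feature dim D_ft, number of classes K

# Parameters: phi for the feature extractor, psi for the classifier head.
phi = rng.normal(size=(D, D_ft))
psi = rng.normal(size=(D_ft, K))

def feature_extractor(x, phi):       # f^(ft)_phi : R^D -> R^{D_ft}
    return np.maximum(x @ phi, 0.0)  # linear map followed by ReLU

def classifier(h, psi):              # f^(cls)_psi : R^{D_ft} -> R^K
    logits = h @ psi
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)  # class probability vector

def f_theta(x, theta):               # f_theta = f^(cls)_psi ∘ f^(ft)_phi
    phi, psi = theta
    return classifier(feature_extractor(x, phi), psi)

x = rng.normal(size=(4, D))          # a batch of 4 inputs
probs = f_theta(x, (phi, psi))
```

Each row of `probs` is a point on the K-simplex, matching the statement that the classifier output is a class probability vector.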

2.2. FINDING LOW-LOSS SUBSPACES

While only a few low-loss subspaces are known to connect modes of deep neural networks, in this paper we focus on the Bezier curves suggested in Garipov et al. (2018). Let θ_i and θ_j be two parameters (usually corresponding to modes) of a neural network. The quadratic Bezier curve between them is defined as {θ_{i,j}(r) = (1 − r)^2 θ_i + 2r(1 − r) θ^(be)_{i,j} + r^2 θ_j | r ∈ [0, 1]}, where θ^(be)_{i,j} is a control-point parameter characterizing the curve. Based on this curve parameterization, a low-loss subspace connecting (θ_i, θ_j) is found by minimizing the expected loss E_{r∼U[0,1]}[L(θ_{i,j}(r))] with respect to θ^(be)_{i,j}.



Figure 1: Comparing ensembles with a Bezier curve (left) and a type II bridge network (right).

Here, θ_{i,j}(r) denotes the point at position r on the curve, and θ^(be)_{i,j} is trained while the endpoints θ_i and θ_j are kept fixed.
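The curve parameterization can be sketched as follows. This minimal example treats the parameters as flat vectors (in practice they are full network weights) and shows one step of the curve-training recipe: sample r ∼ U[0, 1] and evaluate the network at θ_{i,j}(r); the control point used here is an arbitrary illustrative value, not a trained one.

```python
import numpy as np

rng = np.random.default_rng(0)

# Two "modes" theta_i, theta_j and a control point theta^(be)_{i,j}.
theta_i = rng.normal(size=100)
theta_j = rng.normal(size=100)
theta_be = 0.5 * (theta_i + theta_j) + 0.1 * rng.normal(size=100)

def bezier_point(r, theta_i, theta_be, theta_j):
    """Quadratic Bezier curve:
    theta_{i,j}(r) = (1-r)^2 theta_i + 2r(1-r) theta_be + r^2 theta_j."""
    return (1 - r) ** 2 * theta_i + 2 * r * (1 - r) * theta_be + r ** 2 * theta_j

# The endpoints of the curve recover the two modes.
assert np.allclose(bezier_point(0.0, theta_i, theta_be, theta_j), theta_i)
assert np.allclose(bezier_point(1.0, theta_i, theta_be, theta_j), theta_j)

# One training step (sketch): sample r ~ U[0, 1] and compute theta_{i,j}(r);
# the task loss at this point would then be minimized w.r.t. theta_be only,
# with theta_i and theta_j held fixed.
r = rng.uniform()
theta_r = bezier_point(r, theta_i, theta_be, theta_j)
```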

