TRAINING INDEPENDENT SUBNETWORKS FOR ROBUST PREDICTION

Abstract

Recent approaches to efficiently ensemble neural networks have shown that strong robustness and uncertainty performance can be achieved with a negligible increase in parameter count over the original network. However, these methods still require multiple forward passes for prediction, leading to a significant computational cost. In this work, we show a surprising result: the benefits of using multiple predictions can be achieved "for free" under a single model's forward pass. In particular, we show that, using a multi-input multi-output (MIMO) configuration, one can utilize a single model's capacity to train multiple subnetworks that independently learn the task at hand. By ensembling the predictions made by the subnetworks, we improve model robustness without increasing compute. We observe a significant improvement in negative log-likelihood, accuracy, and calibration error on CIFAR10, CIFAR100, ImageNet, and their out-of-distribution variants compared to previous methods.

1. INTRODUCTION

Uncertainty estimation and out-of-distribution robustness are critical problems in machine learning. In medical applications, a confident misprediction may be a misdiagnosis that is not referred to a physician, as in decision-making with a "human-in-the-loop." This can have disastrous consequences, and the problem is particularly challenging when patient data deviate significantly from the training set, for example in demographics, disease types, epidemics, and hospital locations (Dusenberry et al., 2020b; Filos et al., 2019). Using a distribution over neural networks is a popular solution stemming from the classic Bayesian and ensemble learning literature (Hansen & Salamon, 1990; Neal, 1996), and recent advances such as BatchEnsemble and its extensions achieve strong uncertainty and robustness performance (Wen et al., 2020; Dusenberry et al., 2020a; Wenzel et al., 2020). These methods demonstrate that significant gains can be had with negligible additional parameters compared to the original model. However, they still require multiple (typically 4-10) forward passes for prediction, leading to a significant runtime cost.

In this work, we show a surprising result: the benefits of using multiple predictions can be achieved "for free" under a single model's forward pass. The insight we build on comes from sparsity. Neural networks are heavily overparameterized models. The lottery ticket hypothesis (Frankle & Carbin, 2018) and other work on model pruning (Molchanov et al., 2016; Zhu & Gupta, 2017) show that one can prune away 70-80% of the connections in a neural network without adversely affecting performance. The remaining sparse subnetwork, called the winning ticket, retains its predictive accuracy. This suggests that a neural network has sufficient capacity to fit 3-4 independent subnetworks simultaneously. We show that, using a multi-input multi-output (MIMO) configuration, we can concurrently train multiple independent subnetworks within one network.
These subnetworks co-habit the network without explicit separation. The advantage of doing this is that, at test time, we can evaluate all of the subnetworks simultaneously, leveraging the benefits of an ensemble in a single forward pass.

The proposed MIMO configuration requires only two changes to a neural network architecture. First, replace the input layer: instead of taking a single datapoint as input, take M datapoints as inputs, where M is the desired number of ensemble members. Second, replace the output layer: instead of a single head, use M heads that make M predictions based on the last hidden layer. During training, the M inputs are sampled independently from the training set, and each of the M heads is trained to predict its matching input (Figure 1a). Since the features derived from the other inputs are not useful for predicting the matching input, the heads learn to ignore the other inputs and make their predictions independently. At test time, the same input is repeated M times, so the heads make M independent predictions on the same input, forming an ensemble whose single robust prediction is computed in a single forward pass (Figure 1b).

A core component of the robustness of ensembles such as Deep Ensembles is the diversity of their members (Fort et al., 2019). While it is possible that a single network makes a confident misprediction, it is far less likely that multiple independently trained networks make the same mistake. Our model operates on the same principle: by realizing multiple independent winning lottery tickets, we reduce the impact of any one of them making a confident misprediction. For this method to be effective, it is essential that the subnetworks make independent predictions. We show empirically that the subnetworks use disjoint parts of the network and that the functions they represent are as diverse as those of independently trained neural networks.
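To make the two architectural changes concrete, the following is a minimal sketch in NumPy (all names and sizes are illustrative; the shared body here is a single tanh layer rather than the full residual networks used in practice). The input layer consumes M concatenated datapoints, M heads read the shared hidden layer, and test-time prediction repeats one input M times and averages the heads:

```python
import numpy as np

rng = np.random.default_rng(0)

M = 3   # number of subnetworks / ensemble members
D = 8   # input dimension per datapoint
H = 32  # width of the shared hidden layer
K = 5   # number of classes

# Input layer takes M concatenated datapoints instead of one.
W_in = rng.normal(scale=0.1, size=(M * D, H))
# M independent output heads on top of the shared hidden layer.
W_heads = rng.normal(scale=0.1, size=(M, H, K))

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def mimo_forward(xs):
    """xs: (M, D) array, one input per subnetwork; returns (M, K) predictions."""
    h = np.tanh(xs.reshape(-1) @ W_in)                  # shared hidden features
    return softmax(np.einsum('h,mhk->mk', h, W_heads))  # one prediction per head

# Training time: M independently sampled inputs; head m is trained on input m.
train_xs = rng.normal(size=(M, D))
train_preds = mimo_forward(train_xs)          # shape (M, K)

# Test time: repeat the same input M times and average the heads.
x = rng.normal(size=D)
test_preds = mimo_forward(np.tile(x, (M, 1)))
ensemble_pred = test_preds.mean(axis=0)       # single-pass ensemble prediction
```

A full implementation would train each head with its own cross-entropy loss against its matching label; the key point the sketch shows is that all M predictions come out of one forward pass.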

Summary of contributions.

1. We propose a multi-input multi-output (MIMO) configuration for network architectures, enabling multiple independent predictions in a single forward pass "for free." Ensembling these predictions significantly improves uncertainty estimation and robustness with minor changes to the number of parameters and compute cost.
2. We analyze the diversity of the individual members and show that they are as diverse as independently trained neural networks.
3. We demonstrate that, when adjusting for wall-clock time, MIMO networks achieve a new state-of-the-art on CIFAR10, CIFAR100, ImageNet, and their out-of-distribution variants.

2. MULTI-INPUT MULTI-OUTPUT NETWORKS

The MIMO model is applicable in a supervised classification or regression setting. Denote the set of training examples $X = \{(x^{(n)}, y^{(n)})\}_{n=1}^{N}$, where $x^{(n)}$ is the $n$-th datapoint with corresponding label $y^{(n)}$ and $N$ is the size of the training set. In the usual setting, for an input $x$, the output



Figure 1: In the multi-input multi-output (MIMO) configuration, the network takes M = 3 inputs and gives M outputs. The hidden layers remain unchanged. The black connections are shared by all subnetworks, while the colored connections are for individual subnetworks. (a) During training, the inputs are independently sampled from the training set and the outputs are trained to classify their corresponding inputs. (b) During testing, the same input is repeated M times and the outputs are averaged in an ensemble to obtain the final prediction.

