ROBUSTNESS TO CORRUPTION IN PRE-TRAINED BAYESIAN NEURAL NETWORKS

Abstract

We develop ShiftMatch, a new training-data-dependent likelihood for robustness to corruption in Bayesian neural networks (BNNs). ShiftMatch is inspired by the training-data-dependent "EmpCov" priors of Izmailov et al. (2021a), and efficiently matches test-time spatial correlations to those at training time. Critically, ShiftMatch is designed to leave the neural network's training-time likelihood unchanged, allowing it to use publicly available samples from pre-trained BNNs. Using pre-trained HMC samples, ShiftMatch gives strong performance improvements on CIFAR-10-C, outperforms EmpCov priors (though ShiftMatch uses extra information from a minibatch of corrupted test points), and is perhaps the first Bayesian method capable of convincingly outperforming plain deep ensembles.

1. INTRODUCTION

Neural networks are increasingly being deployed in real-world, safety-critical settings such as self-driving cars (Bojarski et al., 2016) and medical imaging (Esteva et al., 2017). Accurate uncertainty estimation in these settings is critical, and a common approach is to use Bayesian neural networks (BNNs) to reason explicitly about uncertainty in the weights (MacKay, 1992; Neal, 2012; Graves, 2011; Blundell et al., 2015; Gal & Ghahramani, 2016; Maddox et al., 2019; Aitchison, 2020a;b; Ober & Aitchison, 2021; Unlu & Aitchison, 2021; Khan & Rue, 2021; Daxberger et al., 2021). BNNs are indeed highly effective at improving uncertainty estimation in the in-distribution setting, where the train and test distributions are equal (Zhang et al., 2019; Izmailov et al., 2021b). Critically, we also need models to continue to perform effectively (or at least degrade gracefully) when presented with corrupted inputs.

Superficially, BNNs seem like a good choice for this setting: we would hope they would give more uncertainty in regions far from the training data, and thus degrade gracefully as inputs become gradually more corrupted and diverge from the training data. However, recent work has highlighted that BNNs, even with gold-standard Hamiltonian Monte Carlo (HMC) inference, can fail to generalise to corrupted images, potentially performing worse than ensembles (Lakshminarayanan et al., 2017; Ovadia et al., 2019; Izmailov et al., 2021a;b).

Izmailov et al. (2021a) gave a key intuition as to why this failure might occur. In particular, consider directions in input space with zero variance under the training data. As the weights in these directions have little or no effect on the output, any weight regularisation reduces the weights in these directions to zero. The zero weights imply that these directions continue to have no effect on the outputs, even if corruption subsequently increases variance in these input directions. However, BNNs do not work like this.
BNNs sample weights in these zero-input-variance directions from the prior. That is fine in the training data domain, as there is no variance in the input in those directions, so the non-zero weights do not affect the outputs. However, if corruption subsequently increases input variance in those directions, then those new high-variance inputs will interact with the non-zero weights to corrupt the output.
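This failure mode can be illustrated with a minimal linear toy example (a hypothetical setup, not the paper's experiment): a weight drawn from the prior along a zero-variance input direction leaves training outputs untouched, but corrupts outputs once corruption adds variance along that direction.

```python
import numpy as np

rng = np.random.default_rng(0)

# Training inputs: the second input dimension has zero variance.
X_train = np.stack([rng.normal(size=100), np.zeros(100)], axis=1)

# A Bayesian weight sample: drawn from the prior, so the weight on the
# zero-variance direction is generically non-zero.
w_prior = rng.normal(size=2)

# On training data, the zero-variance direction contributes nothing:
train_out = X_train @ w_prior
train_out_first_dim_only = X_train[:, 0] * w_prior[0]
assert np.allclose(train_out, train_out_first_dim_only)

# Corruption adds variance along the previously-constant direction:
X_corrupt = X_train + np.stack([np.zeros(100), rng.normal(size=100)], axis=1)
corrupt_out = X_corrupt @ w_prior

# The outputs now change, because the prior-sampled weight interacts with
# the new input variance -- the failure mode described above.
print(np.max(np.abs(corrupt_out - train_out)) > 0)  # True
```

An L2-regularised (non-Bayesian) fit would instead drive `w_prior[1]` to zero, leaving the outputs unaffected by this corruption.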

Izmailov et al. (2021a) suggested an approach to fixing this issue: modify the prior over weights at the input layer to reduce the prior variance of weights in directions where the inputs have little variance. While their approach did outperform BNNs with standard Gaussian priors, it performed only comparably to deep ensembles (Izmailov et al., 2021a, their Figs. 4, 11 and 12). This failure is surprising: part of the promise of BNNs is that they should perform well for corrupted inputs by giving more uncertainty away from the training data. However, the performance of any Bayesian method on in-distribution and corrupted inputs depends heavily on the choice of model (prior and likelihood combined).

We begin by noting that the training-data-dependent EmpCov priors from Izmailov et al. (2021a) can be equivalently viewed as training-data-dependent likelihoods. We then develop a new training-data-dependent likelihood, ShiftMatch, which has two advantages over EmpCov priors (Izmailov et al., 2021a). First, EmpCov priors apply only to the input layer, so they may not be effective for more complex corruptions that are best understood and corrected at later layers. In contrast, our likelihoods modify the activity at every layer in the network, so they have the potential to fix complex, nonlinear corruptions. Second, EmpCov modifies the prior over weights, preventing the use of publicly available samples from BNNs with standard Gaussian priors (Izmailov et al., 2021b).

We found that ShiftMatch improved considerably over all previous baselines for BNN robustness to corruption, including BNNs with standard Gaussian and EmpCov priors, and non-Bayesian methods like plain deep ensembles (Figs. 1 and 3). Further, ShiftMatch can be combined with non-Bayesian methods: we found significant improvements for all methods tested, namely stochastic gradient descent (SGD), ensembles, and Bayes (HMC) (Fig. 4). ShiftMatch also scales to ImageNet (Fig. 5), where it offers improved performance over test-time batchnorm (Nado et al., 2020).
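ShiftMatch is described above as matching test-time spatial correlations to those at training time. As a rough illustration of second-moment matching in general (a hypothetical sketch, not the authors' exact transformation), one can whiten a batch of test activations with its own covariance and re-colour it with the training covariance:

```python
import numpy as np

def matrix_sqrt(S, eps=1e-6):
    """Symmetric PSD matrix square root via eigendecomposition."""
    vals, vecs = np.linalg.eigh(S)
    vals = np.clip(vals, eps, None)
    return vecs @ np.diag(np.sqrt(vals)) @ vecs.T

def match_second_moments(train_acts, test_acts, eps=1e-6):
    """Map test activations so their covariance matches the training one.
    Hypothetical helper for illustration only."""
    d = train_acts.shape[1]
    S_train = np.cov(train_acts, rowvar=False) + eps * np.eye(d)
    S_test = np.cov(test_acts, rowvar=False) + eps * np.eye(d)
    # Whiten with the test covariance, then re-colour with the training one.
    W = matrix_sqrt(S_train) @ np.linalg.inv(matrix_sqrt(S_test))
    return (test_acts - test_acts.mean(axis=0)) @ W.T + train_acts.mean(axis=0)

rng = np.random.default_rng(1)
train = rng.normal(size=(500, 4))
test = rng.normal(size=(200, 4)) * 3.0 + 1.0   # "corrupted": rescaled, shifted
matched = match_second_moments(train, test)

# The matched test batch's covariance is close to the training covariance.
print(np.allclose(np.cov(matched, rowvar=False),
                  np.cov(train, rowvar=False), atol=0.1))  # True
```

The key property mirrored here is that the transformation is computed from a minibatch of (corrupted) test points, which is exactly the extra information ShiftMatch uses relative to the other baselines.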

2. BACKGROUND

In BNNs, we use Bayes' theorem to compute a posterior distribution over the weights, w, given the training inputs, X_train, and labels, y_train:

    P(w | X_train, y_train) = P(y_train | X_train, w) P(w) / P(y_train | X_train).    (1)
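Eq. (1) can be made concrete with a toy posterior computation (a hypothetical setup for illustration): a 1-D linear model y = w·x with Gaussian noise, with the weight w restricted to a discrete grid so the evidence is just a normalising sum.

```python
import numpy as np

# Discrete grid of candidate weights and a uniform prior P(w).
w_grid = np.linspace(-2.0, 2.0, 101)
prior = np.full_like(w_grid, 1.0 / len(w_grid))

x_train = np.array([0.5, 1.0, 1.5])
y_train = np.array([1.1, 1.9, 3.2])   # roughly y = 2x
sigma = 0.5                           # assumed observation noise

# Likelihood P(y_train | X_train, w), evaluated for every w on the grid.
resid = y_train[None, :] - w_grid[:, None] * x_train[None, :]
log_lik = -0.5 * np.sum((resid / sigma) ** 2, axis=1)

# Bayes' theorem: the evidence P(y_train | X_train) is the normaliser.
unnorm = np.exp(log_lik) * prior
posterior = unnorm / unnorm.sum()

print(w_grid[np.argmax(posterior)])   # 2.0 (grid point nearest the best fit)
```

For BNNs the weight space is far too large for such grid enumeration, which is why approximate inference (e.g. HMC, variational methods) is needed.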



Code available at https://github.com/xidulu/ShiftMatch



Figure 1 caption (partially recovered): One setting (following Recht et al., 2019) uses the 2000 new images from CIFAR-10.1 to compute the training statistics, demonstrating that we incur only a small performance penalty when the full training data is unavailable but we have access to data from roughly the same distribution. Intensity 0 denotes the clean CIFAR-10 test set without corruption. ShiftMatch significantly improves robustness to corruption compared with plain HMC and even outperforms deep ensembles, though it uses additional information from a minibatch of corrupted test datapoints that is not used by the other methods.

This is especially important as some gold-standard Bayesian sampling methods are extremely expensive: for example, the Hamiltonian Monte Carlo (HMC) runs in Izmailov et al. (2021b) took one hour to obtain a single sample for a ResNet trained on CIFAR-10 using "a cluster of 512 TPUs". In contrast, ShiftMatch keeps the prior and training-time likelihood unchanged, allowing us to directly re-use the gold-standard HMC samples from Izmailov et al. (2021b). Indeed, ShiftMatch is highly efficient, as it requires no further retraining or fine-tuning at test time, allowing us to, e.g., use a very large batch size.

