ROBUSTNESS TO CORRUPTION IN PRE-TRAINED BAYESIAN NEURAL NETWORKS

Abstract

We develop ShiftMatch¹, a new training-data-dependent likelihood for robustness to corruption in Bayesian neural networks (BNNs). ShiftMatch is inspired by the training-data-dependent "EmpCov" priors of Izmailov et al. (2021a), and efficiently matches test-time spatial correlations to those at training time. Critically, ShiftMatch is designed to leave the neural network's training-time likelihood unchanged, allowing it to use publicly available samples from pre-trained BNNs. Using pre-trained HMC samples, ShiftMatch gives strong performance improvements on CIFAR-10-C, outperforms EmpCov priors (though ShiftMatch uses extra information from a minibatch of corrupted test points), and is perhaps the first Bayesian method capable of convincingly outperforming plain deep ensembles.

1. INTRODUCTION

Neural networks are increasingly being deployed in real-world, safety-critical settings such as self-driving cars (Bojarski et al., 2016) and medical imaging (Esteva et al., 2017). Accurate uncertainty estimation in these settings is critical, and a common approach is to use Bayesian neural networks (BNNs) to reason explicitly about uncertainty in the weights (MacKay, 1992; Neal, 2012; Graves, 2011; Blundell et al., 2015; Gal & Ghahramani, 2016; Maddox et al., 2019; Aitchison, 2020a;b; Ober & Aitchison, 2021; Unlu & Aitchison, 2021; Khan & Rue, 2021; Daxberger et al., 2021). BNNs are indeed highly effective at improving uncertainty estimation in the in-distribution setting, where the train and test distributions are equal (Zhang et al., 2019; Izmailov et al., 2021b). Critically, we also need models to continue to perform effectively (or at least degrade gracefully) when presented with corrupted inputs. Superficially, BNNs seem like a good choice for this setting: we would hope that they give more uncertainty in regions far from the training data, and thus degrade gracefully as inputs become gradually more corrupted and diverge from the training data. However, recent work has highlighted that BNNs, even with gold-standard Hamiltonian Monte Carlo (HMC) inference, can fail to generalise to corrupted images, potentially performing worse than deep ensembles (Lakshminarayanan et al., 2017; Ovadia et al., 2019; Izmailov et al., 2021a;b).

Izmailov et al. (2021a) gave a key intuition as to why this failure might occur. Consider directions in input space with zero variance under the training data. As weights in these directions have little or no effect on the output, any weight regularisation shrinks them to zero. The zero weights imply that these directions continue to have no effect on the outputs, even if corruption subsequently increases variance in these input directions. However, BNNs do not work like this.
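To make this intuition concrete, here is a minimal numerical sketch (not from the paper's code; the two-feature layout and ridge regulariser are purely illustrative) showing that an L2-regularised point estimate drives the weight on a zero-variance input direction to exactly zero:

```python
import numpy as np

rng = np.random.default_rng(0)

# Training inputs: feature 0 varies, feature 1 is constant (zero variance).
n = 200
X = np.stack([rng.normal(size=n), np.zeros(n)], axis=1)
y = 2.0 * X[:, 0] + rng.normal(scale=0.1, size=n)

# Ridge-regularised (MAP) solution: w = (X^T X + lam I)^{-1} X^T y
lam = 1.0
w_map = np.linalg.solve(X.T @ X + lam * np.eye(2), X.T @ y)

# The weight on the zero-variance feature is exactly zero, so later
# corruption along that direction cannot affect this model's outputs.
print(w_map)  # second entry is 0.0
```

A Bayesian posterior behaves differently: since the likelihood carries no information about the zero-variance direction, the posterior there simply equals the prior, leaving non-zero sampled weights in that direction.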
BNNs instead sample weights in these zero-input-variance directions from the prior. That is fine on the training-data domain: there is no input variance in those directions, so the non-zero weights do not affect the outputs. However, if corruption subsequently increases input variance in those directions, the new high-variance inputs interact with the non-zero weights to corrupt the output. Izmailov et al. (2021a) suggested an approach for fixing this issue: modify the prior over input-layer weights to reduce the prior variance in directions where the inputs have little variance. While their approach did outperform BNNs with standard Gaussian priors, it performed only comparably to deep ensembles (Izmailov et al., 2021a, their Figs. 4, 11, 12). This failure is surprising: part of the promise of BNNs is that they should perform well for corrupted inputs by giving more uncertainty away from the training data.
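A minimal sketch of this EmpCov-style idea for a linear input layer (the scaling constants and variable names here are assumptions for illustration, not the authors' implementation): the prior covariance over input-layer weights is taken proportional to the empirical covariance of the training inputs, plus a small ridge term, so that prior samples have little weight variance along low-variance input directions.

```python
import numpy as np

rng = np.random.default_rng(1)

# Toy training inputs: large variance in feature 0, near-zero in feature 1.
X_train = np.stack([rng.normal(size=500), 1e-3 * rng.normal(size=500)], axis=1)

# EmpCov-style prior covariance for input-layer weights:
# proportional to the empirical input covariance, plus eps*I for invertibility.
alpha, eps = 1.0, 1e-4
Sigma = alpha * np.cov(X_train, rowvar=False) + eps * np.eye(2)

# Prior samples now have tiny variance along the low-variance input direction,
# so corruption that adds variance there barely perturbs the pre-activations.
w = rng.multivariate_normal(np.zeros(2), Sigma, size=1000)
print(w.std(axis=0))  # large along feature 0, near zero along feature 1
```

This mimics what explicit regularisation does for a point estimate, but at the level of the prior, which is why it helps under corruption while remaining a fully Bayesian model.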



¹ Code available at https://github.com/xidulu/ShiftMatch

