REGRESSION PRIOR NETWORKS

Abstract

Prior Networks are a class of models which yield interpretable measures of uncertainty and have been shown to outperform state-of-the-art ensemble approaches on a range of tasks. They can also be used to distill an ensemble of models via Ensemble Distribution Distillation (EnD²), such that its accuracy, calibration, and uncertainty estimates are retained within a single model. However, Prior Networks have so far been developed only for classification tasks. This work extends Prior Networks and EnD² to regression tasks by considering the Normal-Wishart distribution. The properties of Regression Prior Networks are demonstrated on synthetic data, selected UCI datasets, and two monocular depth estimation tasks. They yield performance competitive with ensemble approaches.

1. INTRODUCTION

Neural Networks have become the standard approach to addressing a wide range of machine learning tasks (Girshick, 2015; Simonyan & Zisserman, 2015; Villegas et al., 2017; Mikolov et al., 2010; 2013a;b; Hinton et al., 2012; Hannun et al., 2014; Caruana et al., 2015; Alipanahi et al., 2015). However, in order to improve the safety of AI systems (Amodei et al., 2016) and avoid costly mistakes in high-risk applications, such as self-driving cars, it is desirable for models to yield estimates of uncertainty in their predictions. Ensemble methods are known to yield both improved predictive performance and robust uncertainty estimates (Gal & Ghahramani, 2016; Lakshminarayanan et al., 2017; Maddox et al., 2019). Importantly, ensemble approaches allow interpretable measures of uncertainty to be derived via a mathematically consistent probabilistic framework. Specifically, the overall total uncertainty can be decomposed into data uncertainty, which arises from inherent noise in the data, and knowledge uncertainty, which is due to the model having limited knowledge of the test data (Malinin, 2019). Uncertainty estimates derived from ensembles have been applied to the detection of misclassifications, out-of-domain inputs and adversarial attacks (Carlini & Wagner, 2017; Smith & Gal, 2018), and to active learning (Kirsch et al., 2019).

Unfortunately, ensemble methods may be computationally expensive to train and are always expensive during inference. A class of models called Prior Networks (Malinin & Gales, 2018; 2019; Malinin, 2019; Sensoy et al., 2018) was proposed as an approach to modelling uncertainty in classification tasks by emulating an ensemble with a single model. Prior Networks parameterize a higher-order conditional distribution over output distributions, such as the Dirichlet distribution. This enables Prior Networks to efficiently yield the same interpretable measures of total, data and knowledge uncertainty as an ensemble.
Unlike ensembles, the behaviour of a Prior Network's higher-order distribution is specified via a loss function, such as the reverse KL-divergence (Malinin & Gales, 2019), and training data. However, such Prior Networks yield predictive performance consistent with that of a single model trained via Maximum Likelihood, which is typically worse than that of an ensemble. This can be overcome via Ensemble Distribution Distillation (EnD²) (Malinin et al., 2020), an approach that distills an ensemble into a Prior Network such that measures of ensemble diversity are preserved. This makes it possible to retain both the predictive performance and the uncertainty estimates of an ensemble at low computational and memory cost. Finally, it is important to point out that a related class of evidential methods has appeared concurrently (Sensoy et al., 2018; Amini et al., 2020). Structurally, these yield models similar to Prior Networks, but they are trained in a different fashion.

While Prior Networks have many attractive properties, they have so far only been applied to classification tasks. In this work we develop Prior Networks for regression tasks by considering the Normal-Wishart distribution, a higher-order distribution over the parameters of multivariate normal distributions. Specifically, we extend theoretical work from (Malinin, 2019), where such models are considered but never evaluated. We derive all measures of uncertainty, the reverse KL-divergence training objective, and the Ensemble Distribution Distillation objective in closed form. Regression Prior Networks are then evaluated on synthetic data, selected UCI datasets, and the NYUv2 and KITTI monocular depth estimation tasks, where they are shown to yield performance comparable to, or better than, state-of-the-art single-model and ensemble approaches. Crucially, via EnD², they retain the predictive performance and uncertainty estimates of an ensemble within a single model.

2. REGRESSION PRIOR NETWORKS

In this section we develop Prior Network models for regression tasks. While typical regression models yield point-estimate predictions, we consider probabilistic regression models which parameterize a distribution p(y|x, θ) over the target y ∈ R^K. Typically, this is a normal distribution:

p(y|x, θ) = N(y|µ, Λ), {µ, Λ} = f(x; θ)  (1)

where µ is the mean and Λ the precision matrix, a positive-definite symmetric matrix. While normal distributions are usually defined in terms of the covariance matrix Σ = Λ^-1, parameterization via the precision is more numerically stable during optimization (Bishop, 2006; Goodfellow et al., 2016). While a range of distributions over continuous random variables could be considered, we focus on the normal, as it makes the fewest assumptions about the nature of y and is mathematically simple.

As in the case of classification, we can consider an ensemble of networks which parameterize multivariate normal distributions (MVNs) {p(y|x, θ^(m))}_{m=1}^{M}. This ensemble can be interpreted as a set of draws from a higher-order implicit distribution over normal distributions. A Prior Network for regression would, therefore, emulate this ensemble by explicitly parameterizing a higher-order distribution over the parameters µ and Λ of a normal distribution. One sensible choice is the Normal-Wishart distribution (Murphy, 2012; Bishop, 2006), which is the conjugate prior to the MVN. This parallels how the Dirichlet distribution, the conjugate prior to the categorical, was used in classification Prior Networks. The Normal-Wishart distribution is defined as follows:

NW(µ, Λ|m, L, κ, ν) = N(µ|m, κΛ) W(Λ|L, ν)  (2)

where m and L are the prior mean and the inverse of the positive-definite prior scatter matrix, while κ and ν are the strengths of belief in each prior, respectively. The parameters κ and ν are conceptually similar to the precision α_0 of the Dirichlet distribution.
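To make the Normal-Wishart concrete, the sketch below draws (µ, Λ) pairs from it with NumPy/SciPy and checks, by Monte Carlo, that averaging the resulting normal densities recovers the multivariate Student's T posterior predictive of equation 4. The parameter values are purely illustrative, and L is treated as the scale matrix of the Wishart (so E[Λ] = νL); this is a sketch of the underlying mathematics, not of the paper's implementation.

```python
import numpy as np
from scipy.stats import wishart, multivariate_t

rng = np.random.default_rng(0)

# Illustrative Normal-Wishart parameters (K = 2); L is used as the
# Wishart scale matrix, so E[Lambda] = nu * L.
K = 2
m = np.zeros(K)
L = 0.5 * np.eye(K)
kappa, nu = 5.0, 10.0

def sample_normal_wishart(n):
    """Draw (mu, Lambda) pairs: Lambda ~ W(L, nu), mu ~ N(m, (kappa*Lambda)^-1)."""
    Lams = wishart(df=nu, scale=L).rvs(size=n, random_state=rng)
    mus = np.stack([
        rng.multivariate_normal(m, np.linalg.inv(kappa * Lam)) for Lam in Lams
    ])
    return mus, Lams

def gaussian_pdf(y, mu, Lam):
    """Density N(y | mu, Lambda^-1), parameterized by the precision Lambda."""
    d = y - mu
    return np.sqrt(np.linalg.det(Lam) / (2 * np.pi) ** K) * np.exp(-0.5 * d @ Lam @ d)

# Monte Carlo estimate of the posterior predictive E[p(y | mu, Lambda)] ...
y = np.array([0.5, -0.3])
mus, Lams = sample_normal_wishart(20000)
mc = np.mean([gaussian_pdf(y, mu, Lam) for mu, Lam in zip(mus, Lams)])

# ... against the closed-form Student's T of equation 4.
shape = (kappa + 1) / (kappa * (nu - K + 1)) * np.linalg.inv(L)
pred = multivariate_t(loc=m, shape=shape, df=nu - K + 1)
print(mc, pred.pdf(y))  # the two densities should agree closely
```

Each sampled pair (µ, Λ) defines one member of an implicit ensemble of normal distributions, which is exactly the behaviour a Regression Prior Network is meant to emulate.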
The Normal-Wishart is a compound distribution which decomposes into a product of a conditional normal distribution over the mean and a Wishart distribution over the precision. Thus, a Regression Prior Network (RPN) parameterizes the Normal-Wishart distribution over the mean and precision of normal output distributions as follows:

p(µ, Λ|x, θ) = NW(µ, Λ|m, L, κ, ν), {m, L, κ, ν} = Ω = f(x; θ)  (3)

where Ω = {m, L, κ, ν} is the set of parameters of the Normal-Wishart predicted by the neural network. The posterior predictive of this model is the multivariate Student's T distribution (Murphy, 2012), which is the heavy-tailed generalization of the multivariate normal distribution:

p(y|x, θ) = E_p(µ,Λ|x,θ)[p(y|µ, Λ)] = T(y | m, (κ + 1)/(κ(ν - K + 1)) · L^-1, ν - K + 1)  (4)

In the limit, as ν → ∞, the T distribution converges to a normal distribution. The posterior predictive of the Prior Network given in equation 4 only has a defined mean and variance when ν > K + 1.

Figure 1 depicts the desired behaviour of an ensemble of normal distributions sampled from a Normal-Wishart distribution. Specifically, the ensemble should be consistent for in-domain inputs in regions of low/high data uncertainty, as in figures 1a-b, and highly diverse both in the location of the mean and in the structure of the covariance for out-of-distribution inputs, as in figure 1c. Samples of continuous output distributions from a regression Prior Network should yield the same behaviour.

Measures of Uncertainty. Given an RPN which displays these behaviours, we can compute closed-form expressions for all uncertainty measures previously discussed for ensembles and Dirichlet Prior Networks (Malinin, 2019). Measures of knowledge, total and data uncertainty are obtained by considering the mutual information between y and the parameters of the output distribution {µ, Λ}:

I[y, {µ, Λ}]  (Knowledge Uncertainty)
  = H[ E_p(µ,Λ|x,θ)[p(y|µ, Λ)] ]  (Total Uncertainty)
  - E_p(µ,Λ|x,θ)[ H[p(y|µ, Λ)] ]  (Expected Data Uncertainty)  (5)
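The decomposition above can be illustrated numerically in the scalar case (K = 1), where the Normal-Wishart reduces to a Normal-Gamma over (µ, λ) and the posterior predictive is a univariate Student's t. The sketch below, a minimal SciPy example with illustrative parameter values rather than the paper's implementation, computes total uncertainty as the entropy of the predictive t, expected data uncertainty by Monte Carlo over the Gamma precision marginal, and knowledge uncertainty as their difference.

```python
import numpy as np
from scipy.stats import gamma, t

rng = np.random.default_rng(0)

def uncertainties(m, L, kappa, nu, n_samples=100_000):
    """Total, expected data, and knowledge uncertainty (in nats) for a
    univariate (K = 1) Regression Prior Network output, where the
    Normal-Wishart reduces to a Normal-Gamma over (mu, lambda)."""
    # Total uncertainty: entropy of the Student's t posterior predictive
    # (equation 4 with K = 1: df = nu, scale^2 = (kappa+1) / (kappa*nu*L)).
    scale = np.sqrt((kappa + 1) / (kappa * nu * L))
    total = t(df=nu, loc=m, scale=scale).entropy()
    # Expected data uncertainty: E[H[N(y|mu, 1/lambda)]] under the Gamma
    # marginal of the precision, lambda ~ Gamma(nu/2, scale=2L), by Monte Carlo.
    lam = gamma(a=nu / 2, scale=2 * L).rvs(n_samples, random_state=rng)
    exp_data = np.mean(0.5 * np.log(2 * np.pi * np.e / lam))
    # Knowledge uncertainty is the difference: the mutual information of eq. 5.
    return total, exp_data, total - exp_data

low_conf = uncertainties(m=0.0, L=1.0, kappa=2.0, nu=4.0)
high_conf = uncertainties(m=0.0, L=1.0, kappa=200.0, nu=400.0)
print(low_conf, high_conf)
```

As κ and ν grow, the sampled normal distributions concentrate, so knowledge uncertainty shrinks towards zero, matching the intuition that a confident Normal-Wishart corresponds to an ensemble with little diversity.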

