FAST PREDICTIVE UNCERTAINTY FOR CLASSIFICATION WITH BAYESIAN DEEP NETWORKS

Abstract

In Bayesian Deep Learning, distributions over the output of classification neural networks are approximated by first constructing a Gaussian distribution over the weights, then sampling from it to receive a distribution over the categorical output distribution. This is costly. We extend existing work to construct a Dirichlet approximation of this output distribution, yielding an analytic map between Gaussian distributions in logit space and Dirichlet distributions in the output space. We argue that the resulting Dirichlet distribution has theoretical and practical advantages, in particular more efficient computation of the uncertainty estimate, scaling to large datasets and networks like ImageNet and DenseNet. We demonstrate the use of this Dirichlet approximation by using it to construct a lightweight uncertainty-aware output ranking for the ImageNet setup.

1. INTRODUCTION

Quantifying the uncertainty of Neural Networks' (NNs) predictions is important in safety-critical applications such as medical-diagnosis (Begoli et al., 2019) and self-driving vehicles (McAllister et al., 2017; Michelmore et al., 2018) . Architectures for classification tasks produce a probability distribution as their output, constructed by applying the softmax to the point-estimate output of the penultimate layer. However, it has been shown that this distribution is overconfident (Nguyen et al., 2015; Hein et al., 2019) and thus cannot be used for predictive uncertainty quantification. Approximate Bayesian methods provide quantified uncertainty over the network's parameters and thus the outputs in a tractable fashion. The commonly used Gaussian approximate posterior (MacKay, 1992a; Graves, 2011; Blundell et al., 2015; Ritter et al., 2018) approximately induces a Gaussian distribution over the logits of a NN (Mackay, 1995) . But the associated predictive distribution does not have an analytic form. It is thus generally approximated by Monte Carlo (MC) integration requiring multiple samples. Predictions in Bayesian Neural Networks (BNNs) are thus generally expensive operations. In this paper, we re-consider an old but largely overlooked idea originally proposed by David JC MacKay (1998) in a different setting (arguably the inverse of the Deep Learning setting) which transforms a Dirichlet distribution into a Gaussian. Dirichlet distributions are generally defined on the simplex. But when its variable is defined on the inverse softmax's domain, its shape effectively approximates a Gaussian. The inverse of this approximation, which will be called the Laplace Bridge here (Hennig et al., 2012) , analytically maps a Gaussian distribution onto a Dirichlet distribution. Given a Gaussian distribution over the logits of a NN, one can thus efficiently obtain an approximate Dirichlet distribution over the softmax outputs (Figure 1 ). Our contributions are: We re-visit MacKay's derivation with particular attention to a symmetry constraint that becomes necessary in our "inverted" use of the argument from the Gaussian to the Dirichlet family. We then validate the quality of this approximation both theoretically and empirically, and demonstrate significant speed-up over MC-integration. Finally, we show a use-case, leveraging the analytic properties of Dirichlets to improve the popular top-k metric through uncertainties. Section 2 provides the mathematical derivation. Section 3 discusses the Laplace Bridge in the context of NNs. We compare it to the recent approximations of the predictive distributions of NNs in Section 4. Experiments are presented in Section 5. Figure 1 : Densities on the simplex of the true distribution (top row, computed by MC integration) and "Laplace Bridge" approximation constructed in this paper (bottom row). For column (a) and (b), two different Gaussians were constructed, such that the resulting MAP estimate is the same, but the uncertainty differs. For (c), (d) and (e) the same mean with decreasing uncertainty was used. We find that in all cases the Laplace Bridge is a good approximation and captures the desired properties.

2. THE LAPLACE BRIDGE

Laplace approximations 1 are a popular and light-weight method to approximate a general probability distribution q(x) with a Gaussian N (x|µ, Σ) when q(x) is twice differentiable and the Hessian at the mode is positive definite. It sets µ to a mode of q, and Σ = -(∇ 2 log q(x)| µ ) -1 , the inverse Hessian of log q at that mode. This scheme can work well if the true distribution is unimodal and defined on the real vector space. The Dirichlet distribution, which has the density function Dir(π|α) := Γ K k=1 α k K k=1 Γ(α k ) K k=1 π α k -1 k , is defined on the probability simplex and can be "multimodal" in the sense that the distribution diverges in the k-corner of the simplex when α k < 1. Both issues preclude a Laplace approximation, at least in the naïve form described above. However, MacKay (1998) noted that both can be fixed, elegantly, by a change of variable (see Figure 2 ). Details of the following argument can be found in the supplements. Consider the K-dimensional variable π ∼ Dir(π|α) defined as the softmax of z ∈ R K : π k (z) := exp(z k ) K l=1 exp(z l ) , for all k = 1, . . . , K. We will call z the logit of π. When expressed as a function of z, the density of the Dirichlet in π has to be multiplied by the absolute value of the determinan of the Jacobian det ∂π ∂z = k π k (z), thus removing the -1 terms in the exponent: Dir z (π(z)|α) := Γ K k=1 α k K k=1 Γ(α k ) K k=1 π k (z) α k , This density of z, the Dirichlet distribution in the softmax basis, can now be accurately approximated by a Gaussian through a Laplace approximation, yielding an analytic map from the parameter space α ∈ R K + to the parameter space of the Gaussian (µ ∈ R K and symmetric positive definite Σ ∈ 1 For clarity: Laplace approximations are also one out of several possible ways to construct a Gaussian approximation to the weight posterior of a NN, by constructing a second-order Taylor approximation of the empirical risk at the trained weights. This is not the way they are used in this section. The Laplace Bridge is agnostic to how the input Gaussian distribution is constructed. It could, e.g., also be constructed as a variational approximation, or the moments of Monte Carlo samples. See also Section 5. 2012)). Visualization of the Laplace Bridge for the Beta distribution (1D special case of the Dirichlet). Left: "Generic" Laplace approximations of standard Beta distributions by Gaussians. Note that the Beta Distribution (red curve) does not even have a valid approximation because the Hessian is not positive semi-definite. Middle: Laplace approximation to the same distributions after basis transformation through the softmax (4). The transformation makes the distributions "more Gaussian" (i.e. uni-modal, bell-shaped, with support on the real line), thus making the Laplace approximation more accurate. Right: The same Beta distributions, with the back-transformation of the Laplace approximations from the middle figure to the simplex, yielding a much improved approximate distribution. In particular, in contrast to the left-most image, the dashed lines now actually are probability distributions (they integrate to 1 on the simplex). R K×K ), given by µ k = log α k - 1 K K l=1 log α l , Σ k = δ k 1 α k - 1 K 1 α k + 1 α - 1 K K u=1 1 α u . The corresponding derivations require care because the Gaussian parameter space is evidently larger than that of the Dirichlet and not fully identified by the transformation. A pseudo-inverse of this map was provided by Hennig et al. (2012) . It maps the Gaussian parameters to those of the Dirichlet as α k = 1 Σ kk 1 - 2 K + e µ k K 2 K l=1 e -µ l (this equation ignores off-diagonal elements of Σ, more discussion below). Together, Eqs. 5, 6 and 7 will here be used for Bayesian Deep Learning, and jointly called the Laplace Bridge. A full derivation of the Laplace Bridge can be found in B and C. Even though the Laplace Bridge implies a reduction of the expressiveness of the distribution, we show in Section 3 that this map is still sufficiently accurate. Figure 1 shows the quality of the resulting approximation. We consider multiple different µ, Σ in three dimensions. We exhaustively sample from the Gaussian and apply the softmax. The resulting histogram is compared to the PDF of the corresponding Dirichlet. The first part of the figure emphasizes that a point estimate is insufficient. Since the mean for the Dirichlet is the normalized parameter vector α, the parameters (α 1 = [2, 2, 6] and α 2 = [11, 11, 51] ) yield the same point estimate even though their distributions are clearly different. The second part shows how the Laplace Bridge maps w.r.t decreasing uncertainty.

3. THE LAPLACE BRIDGE FOR BNNS

Let f θ : R N → R K be an L-layer Neural Network parametrized by θ ∈ R P , with a Gaussian approximate posterior N (θ|µ θ , Σ θ ). For any input x ∈ R N , one way to obtain an approximate Gaussian distribution on the pre-softmax output (logit vector) f θ (x) =: z is as q(z|x) ≈ N (z|f mµ θ (x), J(x) Σ θ J(x)) , where J(x) is the P × K Jacobian matrix representing the derivative ∂z ∂θ (Mackay, 1995) . This is a linearized approximation of the network which doesn't account for non-linearities and is often computationally infeasible. For the experiments with larger networks we therefore use a last-layer Laplace approximation of the network as successfully used, e.g. by Snoek et al. (2015) ; Wilson et al. (2016) ; Brosse et al. (2020) ; Kristiadi et al. (2020) . It is given by q(z|x) ≈ N (z|µ W (l) φ(x), φ(x) T Σ W (l) φ(x)) , (9) where φ(x) denotes the output of a forward pass through the first L-1 layers, µ W (l) is the maximum likelihood estimate for the weights of the last layer, and Σ W (l) is the inverse of the negative Hessian w.r.t. the loss Σ W (l) = -{ ∂ 2 L ∂ 2 W (l) } -1 . Approximating the density of the softmax of this Gaussian random variable as a Dirichlet, using the Laplace Bridge, analytically approximates the predictive distribution in a single step, as opposed to many samples. From Eq. ( 7), this requires O(K) computations to construct the K parameters α k of the Dirichlet. In contrast, MC-integration has computational costs of O(M J), where M is the number of samples and J is the cost of sampling from q(z|x) (typically J is of order K 2 after an initial O(K 3 ) operation for a matrix decomposition of the covariance). The Monte Carlo approximation has the usual sampling error of O(1/ √ M ), while the Laplace Bridge has a fixed but small error (empirical comparison in Section 5.2). We now discuss several qualitative properties of the Laplace Bridge relevant for the uncertainty quantification use case in Deep Learning. For output classes of "comparably high" probability, as defined below, the variance Var(π k |α) under the Laplace Bridge increases with the variance of the underlying Gaussian. In this sense, the Laplace Bridge approximates the uncertainty information encoded in the output of a BNN. Proposition 1 (proof in supplements). Let Dir(π|α) be obtained via the Laplace Bridge from a Gaussian O N (z|µ, Σ) over R K . Then, for each k = 1, . . . , K, letting α =k := l =k α l , if α k > 1 4 9α 2 =k + 10α =k + 1 -α =k -1 , then the variance Var(π k |α) of the k-th component of π is increasing in Σ kk . Intuitively, this result describes the condition that needs to be fulfilled such that the variance of the resulting Dirichlet scales with the variance of the k-th component of the Gaussian. It can be seen as a proxy for a high-quality approximation. An empirical evaluation showing that this condition is fulfilled in most cases can be found in the supplements. Further benefits of this approximation arise from the convenient analytical properties of the Dirichlet exponential family. For example, a point estimate of the posterior predictive distribution is directly given by the Dirichlet's mean, Eπ = α 1 K l=1 α l , . . . , α K K l=1 α l , Additionally, Dirichlets have Dirichlet marginals: If p(π) = Dir(π|α), then p      π 1 , π 2 , . . . , π j , k>j π k      = Dir   α 1 , α 2 , . . . , α j , k>j α k   . ( ) An additional benefit of the Laplace Bridge for BNNs is that it is more flexible than an MC-integral. If we let p(π) be the distribution over π := softmax(z) := [e z1 / l e z l , . . . , e z K / l e z l ] , then the MC-integral can be seen as a "point-estimate" of this distribution since it approximates Eπ. In contrast, the Dirichlet distribution Dir(π|α) approximates the distribution p(π). Thus, the Laplace Bridge enables tasks that can be done only with a distribution but not a point estimate. For instance, one could ask "what is the distribution of the first L classes?" when one is dealing with K-class (L < K) classification. Since the marginal distribution can be computed analytically (11), the Laplace Bridge provides a convenient yet cheap way of answering this question.

4. RELATED WORK

In Bayesian Neural Networks, analytic approximations of posterior predictive distributions have attracted a great deal of research. In the binary classification case, for example, the probit approximation has been proposed already in the 1990s (Spiegelhalter and Lauritzen, 1990; MacKay, 1992b) . However, while there exist some bounds (Titsias, 2016) and approximations of the expected log-sum-exponent function (Ahmed and Xing, 2007; Braun and McAuliffe, 2010) , in the multi-class case, obtaining a good analytic approximation of the expected softmax function under a Gaussian measure is still considered an open problem. The Laplace Bridge can be used to produce a close analytical approximation of this integral. It furthers the trend of sampling-free solutions within Bayesian Deep Learning (Wu et al., 2018; Haussmann et al., 2019, etc.) . The crucial difference is that, unlike these methods, the Laplace Bridge approximates the full distribution over the softmax outputs of a deep network. Recently, it has been proposed to model the distribution of softmax outputs of a network directly. Similar to the Laplace Bridge, Malinin and Gales (2018; 2019) ; Sensoy et al. (2018) proposed to use the Dirichlet distribution to model the posterior predictive for non-Bayesian networks. They further proposed novel training techniques in order to directly learn the Dirichlet. In contrast, the Laplace Bridge tackles the problem of approximating the distribution over the softmax outputs of the ubiquitous Gaussian-approximated Bayesian networks (Graves, 2011; Blundell et al., 2015; Louizos and Welling, 2016; Sun et al., 2017, etc) without any additional training procedure. This allows the Laplace Bridge to be used with pre-trained networks and emphasizes its non-invasive nature.

5. EXPERIMENTS

We conduct three main experiments. First, we compare the Laplace Bridge to the MC-integral in the example application of out-of-distribution (OOD) detection (Section 5.1). Secondly, we compare their computational cost and contextualize the speed-up for the prediction process in Section 5.2. Thirdly, in Section 5.3, we present analysis on ImageNet (Russakovsky et al., 2014) to demonstrate the scalability of the LB and the advantage of having a full Dirichlet distribution over softmax outputs. For a preliminary experiment to showcase useful features of the LB on the MNIST dataset and code for all experiments we refer the reader to Appendix D. All experiments were conducted using different forms of Laplace approximations on the weight space. For the smaller experiments, a full (all-layer) Laplace approximation with a diagonal Hessian (LeCun et al., 1990) was used. For the experiments with larger networks, the Laplace approximation has been applied only to the last-layer of the network. This scheme has been successfully used by Snoek et al. (2015) ; Wilson et al. (2016) ; Brosse et al. (2020) , etc. and it has been shown theoretically to mitigate overconfidence problems in ReLU networks (Kristiadi et al., 2020) . For the last-layer experiments we use diagonal and Kronecker-factorized (Ritter et al., 2018) (KFAC) approximations of the Hessian, since inverting the exact Hessian is too costly. A detailed mathematical explanation and setup of the experiments can be found in appendix D. While the Laplace Bridge could also be applied to different approximations of a Gaussian posterior predictive such as Variational Inference (Graves, 2011; Blundell et al., 2015) , we used a Laplace approximation in our experiments to construct such an approximation. This is for two reasons: (i) it is one of the fastest ways to get a Gaussian posterior predictive and (ii) it can be applied to pre-trained networks which is especially useful for ImageNet experiments. Nevertheless, we want to emphasize again that the Laplace Bridge can be applied to any Gaussian over the outputs independent of the way it was generated. Despite the overlap in nomenclature, the Laplace Bridge is not restricted to Gaussians arising from a Laplace approximation of the network.

5.1. OOD DETECTION

We compare the performance of the Laplace Bridge to the MC-integral (Diagonal and KFAC) on a standard OOD detection benchmark suite, to test whether the Laplace Bridge gives similar results to the MC sampling methods and compare their computational overhead. Following prior literature, we use the standard mean-maximum-confidence (MMC) and area under the ROC-curve (AUROC) metrics (Hendrycks and Gimpel, 2016) . For an in-distribution dataset, a higher MMC value is desirable while for the OOD dataset we want a lower MMC value (optimally, 1/K in K-class classification problems). For the AUROC metric, the higher the better, since it represents how good a method is for distinguishing in-and out-of-distribution datasets. The test scenarios are: (i) A two-layer convolutional network trained on the MNIST dataset. To approximate the posterior over the parameter of this network, a full (all-layer) Laplace approximation with a diagonal Hessian is used. The OOD datasets for this case are FMNIST (Xiao et al (Bulatov, 2011) , and KMNIST (Clanuwat et al., 2018) . (ii) For larger datasets, i.e. CIFAR-10 ( Krizhevsky et al., 2014) , SVHN (Netzer et al., 2011) , and CIFAR-100 (Krizhevsky et al., 2014) , we use a ResNet-18 network (He et al., 2016) . Since this network is large, Equation ( 8) in conjunction with a full Laplace approximation is too costly. We, therefore, use a last-layer Laplace approximation to obtain the approximate diagonal and KFAC Gaussian posterior. The OOD datasets for CIFAR-10, SVHN, and CIFAR-100 are SVHN and CIFAR100; CIFAR-10 and CIFAR-100; and SVHN and CIFAR-10, respectively. In all scenarios, the networks are well-trained with 99% accuracy on MNIST, 95.4% on CIFAR-10, 76.6% on CIFAR-100, and 100% on SVHN. For the sampling baseline, we use 1000 posterior samples to compute the predictive distribution which is motivated by Figure 3 . We use the mean of the Dirichlet to obtain a comparable approximation to the MC-integral. The results are presented in Table 1 . The Laplace Bridge yields, on average, better results than diagonal sampling and ties with KFAC sampling w.r.t both metrics. However, the Laplace Bridge is around 400 times faster than both sampling-based methods and is therefore preferable. Further comparisons with ensemble networks and other methods to approximate the integral of a softmax-Gaussian can be found in Appendix D. In both cases the LB is comparable or better in terms of OOD detection while being faster and yielding a full distribution.

5.2. TIME COMPARISON

We compare the computational cost of the density-estimated p sample distribution via sampling and the Dirichlet obtained from the Laplace Bridge p LB for approximating the true p true over softmax-Gaussian samplesfoot_0 . Different amounts of samples are drawn from the Gaussian, the softmax is applied and the KL divergence between the histogram of the samples with the true distribution is computed. We use KL-divergences D KL (p true p sample ) and D KL (p true p LB ), respectively, to measure similarity between approximations and ground truth while the number of samples for p sample is increased exponentially. The true distribution p true is constructed via MC with 100k samples. The experiment is conducted for three different Gaussian distributions over R 3 . Since the softmax applied to a Gaussian does not have an analytic form, the algebraic calculation of the approximation error is not possible and an empirical evaluation via sampling is the best option. The logistic-Normal (Aitchison and Shen, 1980 ) is a different distribution and not useful for our purposes as it has no analytic expected value. The fact that there is no analytic solution is part of the justification for using the Laplace Bridge in the first place. Figure 3 suggests that the number of samples required such that the distribution p sample approximates the true distribution p true as good as the Dirichlet distribution obtained via the Laplace Bridge is large, i.e. somewhere between 500 and 10000. This translates to a wall-clock time advantage of at least a factor of 100 before sampling becomes competitive in quality with the Laplace Bridge. To better understand how light-weight and non-invasive the Laplace Bridge is, we have timed different parts of the process for our setup. the last step of the prediction pipeline it is important to compare its time to that of a forward pass through the rest of the network. Re-using the ResNet-18 and CIFAR10 setup we measure the time in seconds for a forward pass, for the application of the Laplace Bridge, and for the sampling procedure with 10, 100, and 1000 samples. The resulting sum of timings for the entire test set is given in Table 2 ). We find that sampling is between 95% (for 1000 samples) and 21% (for 10 samples) of the entire prediction while the LB is only 3%. These results show that the time accelerations through the LB are a meaningful improvement for the prediction process as a whole.

5.3. UNCERTAINTY-AWARE OUTPUT RANKING ON IMAGENET

Classification tasks on large datasets with many classes, like ImageNet, are not often done in a Bayesian fashion since the posterior inference and sampling are expensive. The Laplace Bridge, in conjunction with the last-layer Bayesian approximations, can be used to alleviate this problem. Furthermore, having a full distribution over the softmax outputs of a BNN gives rise to new possibilities. For example, one could subsume all classes which have sufficiently overlapping marginal distributions into one if they are semantically similar as illustrated in Figure 4 . Another possibility is to improve the standard classification metrics. Large classification tasks like ImageNet are often compared along a top-5 metric, i.e. it is tested whether the correct class is within the five most probable estimates of the network. Although widely accepted, this metric has some pathologies. Consider two examples: i) Assume the network has to classify a hypothetical image of "raptor" and it is confident that the label is either a "hawk" or an "eagle". Then all probability mass should be distributed between those two classes. ii) Assume the network has to classify an image of which it is confident that it is a "fish" but it is uncertain between ten different possible fish species. In both cases a static rule (always 5) is ill-suited. Leveraging the probabilistic output provided by the Laplace Bridge, we propose a simple decision rule that can handle such examples and is more fine-grained due to its awareness of uncertainty. One may call such a rule uncertainty-aware top-k; pseudocode for the algorithm is given in the In the first column, the overlap between the marginal of all classes is large, signifying high uncertainty, i.e. the prediction is "I do not know". In the second column, "notebook" and "laptop" have confident, yet overlapping marginal densities and therefore yield a top-2 prediction: "either a notebook or a laptop". In the third column "desktop computer", "screen" and "monitor" have overlapping marginal densities, yielding a top-3 estimate. The last case shows a top-1 estimate: the network is confident that "laptop" is the only correct label. supplements. Instead of taking the top-k as a decision threshold for an arbitrary k we take the uncertainty/confidence of the model to inform the decision. This is more flexible and therefore able to handle situations in which different numbers of classes are plausible outcomes. The Dirichlet distribution obtained from the Laplace Bridge provides this capability. In particular, since the marginal distribution over each component of a Dirichlet distribution is a Beta(α i , j =i α j ), this can be done analytically and efficiently. The proposed decision rule uses the area of overlap between the marginal distributions of the sorted outcomes. This is similar to hypotheses testing, i.e. t-tests (Nickerson, 2000) or its Bayesian alternatives (Masson, 2011). If, for example, two Beta densities overlap more than 5%, we cannot say that they are different distributions with high confidence. All distributions that have sufficient overlap should become the new top-k estimate. Pseudo code for the uncertainty aware top-k ranking can be found in Algorithm 1. Figure 4 shows four examples from the "laptop" class of ImageNet. We evaluate this decision rule on the test set of ImageNet. The overlap is calculated through the inverse CDFfoot_1 of the respective Beta marginals. The original top-1 accuracy of DenseNet on ImageNet is 0.744. Meanwhile, the uncertainty-aware top-k accuracy is 0.797, where k is on average 1.688. A more detailed analysis of the distribution of top-k estimates (see Figure 5 ) shows that most of the predictions given by the uncertainty-aware metric still yielded a top-1 prediction. This means that using uncertainty does not imply adding meaningless classes to the prediction. However, there are some non-negligible cases where k equals to 2, 3, or 10 where all values larger than 10 are in the 10 bin for visibility. This indicates that whenever the class labels are ambiguous, our method can detect it, and thus yields a significantly higher accuracy.

Algorithm 1 Uncertainty-aware top-k

Input: A Dirichlet parameter α ∈ R K obtained by applying the Laplace Bridge to the Gaussian over the logit of an input, a percentile threshold T e.g. 0.05, a function class of that returns the underlying class of a sorted index.

α = sort descending(α)

// start with the highest confidence α 0 = i α i C = {class of(1)} // initialize top-k, must include at least one class for i = 2, . . . , K do F i-1 = Beta(α i-1 , α 0 -αi-1 ) // the previous marginal CDF F i = Beta(α i , α 0 -αi ) // the current marginal CDF l i-1 = F -1 i-1 (T /2) // left T 2 percentile of the previous marginal  r i = F -1 i (1 -T /2) // right T 2 percentile of the current marginal if r i > l i-1 then C = C ∪ {class of(i)} //

6. CONCLUSION

We have adapted an old but overlooked approximation scheme for new use in Bayesian Deep Learning. Given a Gaussian approximation to the weight-space posterior of a neural network (which can be constructed by various means, including another Laplace approximation), and an input, the Laplace Bridge analytically maps the marginal Gaussian prediction on the logits onto a Dirichlet distribution over the softmax vectors. The associated computational cost of O(K) for K-class prediction compares favorably to that of Monte Carlo sampling. The proposed method both theoretically and empirically preserves predictive uncertainty, offering an attractive, low-cost, high-quality alternative to Monte Carlo sampling. In conjunction with a low-cost, last-layer Bayesian approximation, it can be useful in real-time applications wherever uncertainty is required-especially because it reduces the cost of predicting a posterior distribution at test time.

A APPENDIX FIGURES

The parameters of Figure 1 are the µ and Σ mapped from α = (2, 2, 6) and α = (11, 11, 51) through the Laplace Bridge for a) and b) respectively. The values have been chosen to have the same mode of the Dirichlet distribution but clearly yield different uncertainties. For c), d) and e) they share µ = (-1, 2, -1) and vary Σ with 10 • I3, 1 • I3 and 0.1 • I3 where I3 is the three-dimensional identity matrix. The parameters of Figure 2 are from left to right α, β = (0.8, 0.9), (4, 2, ), (2, 7).

CHANGE OF VARIABLE FOR PDF

Let z be an n-dimensional continuous random variable with joint density function px. If y = G(x), where G is a differentiable function, then y has density py: p(y) = f G -1 (y) det   dG -1 (z) dz z=y   (12) where the differential is the Jacobian of the inverse of G evaluated at y. This procedure, also known as 'change of basis', is at the core of the Laplace bridge since it is used to transform the Dirichlet into the softmax basis.

PROOF FOR PROPOSITION

Proof. Considering that α k is a decreasing function of Σ kk by definition ( 7), it is sufficient to show that under the hypothesis, the derivative of ∂ ∂α k Var(π k |α) is negative. By definition, the variance Var(π k |α) is Var(π k |α) = α k α k +α =k - α 2 k (α k +α =k ) 2 α k + α =k + 1 . The derivative is therefore ∂ ∂α k Var(π k |α) = α =k (α 2 =k -α =k α k + α =k -α k (2α k + 1)) (α k + α =k ) 3 (α k + α =k + 1) 2 . Solving ∂ ∂α k Var(π k |α) < 0 for α k yields α k > 1 4 9α 2 =k + 10α =k + 1 -α =k -1 . Therefore, under this hypothesis, Var(π k |α) is a decreasing function of α k .

EXPERIMENTAL EVALUATION OF THE PROPOSITION

To test how often the condition is fulfilled we count its frequency. The fact that the condition is fulfilled implies a good approximation. The fact that the condition is not fulfilled does not automatically imply a bad approximation.

B APPENDIX (DERIVATION OF LB)

Assume we have a Dirichlet in the standard basis with parameter vector α and probability density function: Dir(π|α) := Γ K k=1 α k K k=1 Γ(α k ) K k=1 π α k -1 k , ( ) We aim to transform the basis of this distribution via the softmax transform to be in the new base π:  π k (z) := exp(z k ) K l=1 exp(z l ) , Usually, to transform the basis we would need the inverse transformation H -1 (z) as described in the main paper. However, the softmax does not have an analytic inverse. Therefore David JC MacKay uses the following trick. Assume we know that the distribution in the transformed basis is: Dirz(π(z)|α) := Γ K k=1 α k K k=1 Γ(α k ) K k=1 π k (z) α k , then we can show that the original distribution is the result of the basis transform by the softmax. The Dirichlet in the softmax basis: We show that the density over π shown in Equation 15 transforms into the Dirichlet over z. First, we consider the special case where π is confined to a I -1 dimensional subspace satisfying i πi = c. In this subspace we can represent π by an I -1 dimensional vector a such that πi = ai i, ..., I -1 (16) πI = c - I-1 i ai and similarly we can represent z by an I -1 dimensional vector ϕ: zi = ϕi i, ..., I -1 (18) zI = 1 - I-1 i ϕi then we can find the density over ϕ (which is proportional to the required density over z) from the density over π (which is proportional to the given density over π) by finding the determinant of the (I -1) × (I -1) Jacobian J given by J ik = ∂zi ∂ϕi = I j ∂zi ∂πj ∂πj ∂ϕ k = δ ik zi -ziz k + zizI = zi(δ ik -(z k -zI )) We define two additional I -1 dimensional helper vectors z + k := z k -zI and n k := 1, and use det(I -xy T ) = 1 -x • y from linear algebra. It follows that det J = I-1 i=1 zi × det[I -nz + T ] = I-1 i=1 zi × (1 -n • z + ) (21) = I-1 i=1 zi × 1 - k z + k = I I i=1 xi Therefore, using Equation 15 we find that P (z) = P (π) | det J| ∝ I i=1 z α i -1 i ( ) This result is true for any constant c since it can be put into the normalizing constant. Thereby we make sure that the integral of the distribution is 1 and we have a valid probability distribution.

C APPENDIX (DERIVATION OF INVERSION)

Through the figures of the 1D Dirichlet approximation in the main paper we have already established that the mode of the Dirichlet lies at the mean of the Gaussian distribution and therefore π(y) = α i α i . Additionally, the elements of y must sum to zero. These two constraints combined yield only one possible solution for µ. µ k = log α k - 1 K K l=1 log α l Calculating the covariance matrix Σ is more complicated but layed out in the following. The logarithm of the Dirichlet is, up to additive constants log py(y|α) = k α k π k Using π k as the softmax of y as shown in Equation 14we can find the elements of the Hessian L L kl = α(δ kl πk -πk πl ) where α := k α k and π = α k α for the value of π at the mode. Analytically inverting L is done via a lengthy derivation using the fact that we can write L = A + XBX and inverting it with the Schur-complement. This process results in the inverse of the Hessian L -1 kl = δ kl 1 α k - 1 K 1 α k + 1 α l - 1 K K u 1 αu We are mostly interested in the diagonal elements, since we desire a sparse encoding for computational reasons and we otherwise needed to map a K ×K covariance matrix to a K ×1 Dirichlet parameter vector which would be a very overdetermined mapping. Note that K is a scalar not a matrix. The diagonal elements of Σ = L -1 can be calculated as Σ kk = 1 α k 1 - 2 K + 1 K 2 k l 1 α l . To invert this mapping we transform Equation 23to α k = e µ k K l α 1/K l (28) by applying the logarithm and re-ordering some parts. Inserting this into Equation 27 and re-arranging yields K l α 1/K l = 1 Σ kk e -µ 1 - 2 K + 1 K 2 K u e -µu which can be re-inserted into Equation 28 to give α k = 1 Σ k k 1 - 2 K + e -µ k K 2 K l e -µ k which is the final mapping. With Equations 23 and 27 we are able to map from Dirichlet to Gaussian and with Equation 30we are able to map the inverse direction.

D APPENDIX (EXPERIMENTAL DETAILS)

The exact experimental setups, i.e. network architectures, learning rates, random seeds, etc. can be found in the accompanying anonymized GitHub repositoryfoot_2 . This section is used to justify some of the decisions we made during the process in more detail, highlight some miscellaneous interesting things and showcase the additional experiments promised in the main paper.

EXPERIMENTAL SETUP

The (F-, K-, not-)MNIST networks have been trained with Adam, a learning rate of 1e -6 and batch size of 32 over 6 epochs. The batch size was chosen to be so low because the Jacobian for the full-layer Laplace approximation would otherwise not fit into memory. CIFAR10, CIFAR100 and SVHN have been trained with SGD, a start learning rate of 0.1, momentum of 0.9, weight decay of 5e -4 and a batch size of 128. We applied a learning rate schedule that multiplies the learning rate with 0.1 at epoch 30, 80 and 130 for CIFAR10, at 100, 150, 200 for CIFAR100 and at 10, 20 and 25 for SVHN. (F-, K-, not-)MNIST has been trained on a small 2 conv-layer + 1 linear layer setup and all other OOD experiments were conducted with a ResNet-18. For the ImageNet experiments we used the pre-trained weights from PyTorch and only computed the Hessian information with BACKpack (Dangel et al., 2019) in one backward pass. All prior uncertainties for the Laplace approximation have been chosen such that the resulting MMC is not more than 5 -10% away from the MAP estimate to have comparable results.

MATHEMATICAL DESCRIPTION OF THE SETUP

In principle, the Gaussian over the weights required by the Laplace Bridge for BNNs can be constructed by any Gaussian approximate Bayesian method such as variational Bayes (Graves, 2011; Blundell et al., 2015) and Laplace approximations for NNs (MacKay, 1992a; Ritter et al., 2018) . We will focus on the Laplace approximation, which uses the same principle as the Laplace Bridge. However, in the Laplace approximation for neural networks, the posterior distribution over the weights of a network is the one that is approximated as a Gaussian, instead of a Dirichlet distribution over the outputs as in the Laplace Bridge. Given a dataset D := {(xi, ti)} D i=1 and a prior p(θ), let p(θ|D) ∝ p(θ)p(D|θ) = p(θ) (x,t)∈D p(y = t|θ, x) , be the posterior over the parameter θ of an L-layer network f θ . Then we can get an approximation of the posterior p(θ|D) by fitting a Gaussian N (θ|µ θ , Σ θ ) where µ θ = θMAP , Σ θ = (-∇ 2 | θ MAP log p(θ|D)) -1 =: H -1 θ . That is, we fit a Gaussian centered at the mode θMAP of p(θ|D) with the covariance determined by the curvature at that point. We assume that the prior p(θ) is a zero-mean isotropic Gaussian N (θ|0, σ 2 I) and the likelihood function is the Categorical density p(D|θ) = (x,t)∈D Cat(y = t|softmax(f θ (x))) . For various applications in Deep Learning, an approximation with full Hessian is often computationally too expensive. Indeed, for each input x ∈ R N , one has to do K backward passes to compute the Jacobian J(x). Moreover, it requires an O(P K) storage which is also expensive since P is often in the order of millions. A cheaper alternative is to fix all but the last layer of f θ and only apply the Laplace approximation on WL, the last layer's weight matrix. This scheme has been used successfully by Snoek et al. (2015) ; Wilson et al. (2016) ; Brosse et al. (2020) , etc. and has been shown theoretically that it can mitigate overconfidence problems in ReLU networks (Kristiadi et al., 2020) . In this case, given the approximate last-layer posterior p(W L |D) ≈ N (vec(W L )|vec(W L MAP ), H -1 W L ) , one can efficiently compute the distribution over the logits. That is, let φ : R N → R Q be the first L -1 layers of f θ , seen as a feature map. Then, for each x ∈ R N , the induced distribution over the logit W L φ(x) =: z is given by p(z|x) = N (z|W L MAP φ(x), (φ(x) ⊗ I)H -1 W L (φ(x) ⊗ I)) , where ⊗ denotes the Kronecker product. An even more efficient last-layer approximation can be obtained using a Kronecker-factored matrix normal distribution (Louizos and Welling, 2016; Sun et al., 2017; Ritter et al., 2018) . That is, we assume the posterior distribution to be p(W L |D) ≈ MN (W L |W L MAP , U, V) , where U ∈ R K×K and V ∈ R Q×Q are the Kronecker factorization of the inverse Hessian matrix H -1 W L (Martens and Grosse, 2015) . In this case, for any x ∈ R N , one can easily show that the distribution over logits is given by p(z|x) = N (z|W L MAP φ(x), (φ(x) Vφ(x))U) , which is easy to implement and computationally cheap. Finally, and even more efficient, is a last-layer approximation scheme with a diagonal Gaussian approximate posterior, i.e. the so-called mean-field approximation. In this case, we assume the posterior distribution to be p(W L |D) ≈ N (vec(W L )|vec(W L MAP ), diag(σ 2 )) , where σ 2 is obtained via the diagonal of the Hessian of the log-posterior w.r.t. vec(W L ) at vec(W L MAP ). To give a good preliminary intuitive understanding of the LBs usefulness we empirically investigate the approximation quality of the Laplace Bridge in a "real-world" BNN on the MNIST dataset. A CNN with 2 convolutional and 2 fully-connected layers is trained on the first three digits of MNIST (the digits 0, 1, and 2). Adam optimizer with learning rate 1e-3 and weight decay 5e-4 is used. The batch size is 128. To obtain the posterior over the weights of this network, we perform a full (all-layer) Laplace approximation using Back-PACK (Dangel et al., 2019) to get the diagonal Hessian. The network is then evaluated on the full test set of MNIST (containing all ten classes). We present the results in Figure 6 . We show for each k = 1, . . . , K, the average variance 1 D k D k i=1 Var(π k (f θ (xi)) ) of the resulting Dirichlet distribution over the softmax outputs, where D k is the number of test points predicted with label k. The results show that the variance of the Dirichlet distribution obtained via the Laplace Bridge is useful for uncertainty quantification: OOD data can be easily detected since the mean variance of the first three classes is nearly zero while that of the others is higher.

OOD DETECTION

Every experiment has been conducted with 5 different seeds. In the tables the mean and standard deviations are presented. The reason why the sampling procedure for the CIFAR-10 and CIFAR-100 case are similarly fast even though we draw from a 10-vs 100-dimensional Gaussian is because the sampling procedures were parallelized on a GPU. All prior uncertainties over the weights were chosen such that the MMC of the sampling averages was around 5% lower than the MAP estimate. To compare the OOD detection capabilities of the LB with other more standard methods we first evaluate it compared to an ensemble of 100 different sets of weights. As shown in Table 4 we find that the Laplace Bridge yields comparable results while being more than a thousand times faster. Table 4 : Results for sampling from all weights instead of just the last layer. Number of samples was 100. Sampling all weights and doing a forward pass seems to be worse than using the Laplace Bridge even though it takes much longer.  ) i exp(a (i) ) 1 Z exp - 1 2 i (a (i) -ā(i) ) 2 v (i) ≈ exp(τ (v (j) )ā (j) ) i exp(τ (v (i) )ā (i) ) as an approximation. The logit output of class i is denoted by a (i) , its average by ā(i) , its variance by v (i) and  τ (v) = 1/ 1 + π • v/8. (y) ≈ p 1 + p Σp -Σp + 1 2 diag(Σ) - 1 2 diag(Σ) ( ) where p is the vector of logit outputs, Σ is its covariance matrix, and represents the entry-wise product. The results in Table 5 suggest that the LB shows slightly better performance than the other two approaches. Additionally, the LB not only yields an approximation of the integral but additionally a fully parameterized Dirichlet distribution of the output classes and is therefore preferable. Table 5 : Comparison of the Laplace Bridge Dirichlet mean with other methods to compute the integral over a softmax-Gaussian. We find that the LB is slightly better than the other two methods while additionally not only providing an estimate for the integral but also giving a fully parameterized Distribution. 

UNCERTAINTY-AWARE OUTPUT RANKING ON IMAGENET

The prior covariances for the Laplace approximation of the Hessian over the weights were chosen such that uncertainty estimate of the Laplace bridge MMC over the outputs was not more than 5% lower than the MAP estimate. The length of list generated by our uncertainty aware method was chosen such that it contained at least one and maximally ten samples. Originally we wanted to choose the maximal length according to the size of the largest category (e.g. fishes or dogs) but the class tree hierarchy of ImageNet does not answer this question meaningfully. We chose ten because there are no reasonable bins larger than ten when looking at a histogram. Pseudo code for the uncertainty aware top-k ranking can be found in Algorithm 1.



I.e. samples are obtained by first sampling from a Gaussian and transforming it via the softmax function. Also known as the quantile function or percent point function https://github.com/19219181113/LB_for_BNNs



Figure 2: (Adapted fromHennig et al. (2012)). Visualization of the Laplace Bridge for the Beta distribution (1D special case of the Dirichlet). Left: "Generic" Laplace approximations of standard Beta distributions by Gaussians. Note that the Beta Distribution (red curve) does not even have a valid approximation because the Hessian is not positive semi-definite. Middle: Laplace approximation to the same distributions after basis transformation through the softmax (4). The transformation makes the distributions "more Gaussian" (i.e. uni-modal, bell-shaped, with support on the real line), thus making the Laplace approximation more accurate. Right: The same Beta distributions, with the back-transformation of the Laplace approximations from the middle figure to the simplex, yielding a much improved approximate distribution. In particular, in contrast to the left-most image, the dashed lines now actually are probability distributions (they integrate to 1 on the simplex).

Figure 3: KL-divergence plotted against the number of samples (left) and wall-clock time (right). Monte Carlo density estimation becomes as good as the Laplace Bridge after around 750 to 10000 samples and takes at least 100 times longer. The three lines represent three different samples.

Figure4: Upper row: images from the "laptop" class of ImageNet. Bottom row: Beta marginals of the top-k predictions for the respective image. In the first column, the overlap between the marginal of all classes is large, signifying high uncertainty, i.e. the prediction is "I do not know". In the second column, "notebook" and "laptop" have confident, yet overlapping marginal densities and therefore yield a top-2 prediction: "either a notebook or a laptop". In the third column "desktop computer", "screen" and "monitor" have overlapping marginal densities, yielding a top-3 estimate. The last case shows a top-1 estimate: the network is confident that "laptop" is the only correct label.

Figure 5: A histogram of ImageNet predictions' length using the proposed uncertainty-aware top-k. Most test images are a top-1 prediction, indicating high confidence. There are some top-2, top-3, and top-10 predictions, showing increased uncertainty. All predictions > 10 are grouped in the 10-bin for visibility.

Figure 6: Average variance of the Dirichlet distributions of each MNIST class. The in-distribution uncertainty (variance) is nearly nil, while out-of-distribution variance is higher. This implies usability for OOD detection.

., OOD detection results. The Laplace Bridge (LB) wins most comparisons with Diagonal sampling and draws even with KFAC sampling w.r.t. both metrics. However, the LB is around 400 times faster on average which is the key property of the LB. 1000 samples were drawn from the Gaussian over the outputs motivated by Figure3. The MNIST experiments were done with a Laplace approximation of the entire network while the others only used the last layer. The same table with error bars is in Appendix D.

Training a ResNet-18 on CIFAR10 over 130 epochs takes us 71 minutes and 30 seconds. Computing a Hessian for the network from the training data can be done with(Dangel et al., 2019) at the cost of one backward pass or around 29 seconds. No other changes to the training procedure are necessary to apply the LB. Since the Laplace Bridge only applies to Contextualization of the timings for the entire predictive process. We find that the LB provides a significant speed-up of the process as a whole.

overlap detected, add the current class



However, we think the Laplace Bridge should primarily not be compared to other OOD detection algorithms such as those listed inHendrycks and Gimpel (2016) but rather to algorithms that approximate the integral of the softmax transform applied to Gaussian samples since they have comparable speed. The two algorithms we compare to the LB are the Extended MacKay approach Gibbs (1997) (see equation 5.33 in (Gibbs)) which uses exp(a(j)

Every experiment has been conducted with 5 different seeds. The presented curves are the averages over these 5 experiments with errorbars. The reason why taking one sample is slower than two is because of the way random numbers are generated for the normal distribution. For further information read up on the Box-Mueller Transform.

OOD detection results. Same table as in the main experiments but with error estimates. The Laplace Bridge (LB) wins most comparisons with Diagonal sampling and draws even with KFAC sampling w.r.t. both metrics. However, the LB is around 400 times faster on average. 1000 samples were drawn from the Gaussian over the outputs motivated by Figure3. The (F-, K-, not-)MNIST experiments were done with a Laplace approximation of the entire network while the others only used the last layer. Five runs with different seeds per experiment were conducted. Standard deviation is reported to the third decimal.

