SAM AS AN OPTIMAL RELAXATION OF BAYES

Abstract

Sharpness-aware minimization (SAM) and related adversarial deep-learning methods can drastically improve generalization, but their underlying mechanisms are not yet fully understood. Here, we establish SAM as a relaxation of the Bayes objective where the expected negative-loss is replaced by the optimal convex lower bound, obtained by using the so-called Fenchel biconjugate. The connection enables a new Adam-like extension of SAM to automatically obtain reasonable uncertainty estimates, while sometimes also improving its accuracy. By connecting adversarial and Bayesian methods, our work opens a new path to robustness.

1. INTRODUCTION

Sharpness-aware minimization (SAM) (Foret et al., 2021) and related adversarial methods (Zheng et al., 2021; Wu et al., 2020; Kim et al., 2022) have been shown to improve generalization, calibration, and robustness in various applications of deep learning (Chen et al., 2022; Bahri et al., 2022), but the reasons behind their success are not fully understood. The original proposal of SAM was geared towards biasing training trajectories towards flat minima, and the effectiveness of such minima has various Bayesian explanations, for example, those relying on the optimization of description lengths (Hinton & Van Camp, 1993; Hochreiter & Schmidhuber, 1997), PAC-Bayes bounds (Dziugaite & Roy, 2017; 2018; Jiang et al., 2020; Alquier, 2021), or marginal likelihoods (Smith & Le, 2018). However, SAM is not known to directly optimize any such Bayesian criteria, even though some connections to PAC-Bayes do exist (Foret et al., 2021). The issue is that the 'max-loss' used in SAM fundamentally departs from a Bayesian-style 'expected-loss' under the posterior; see Fig. 1(a). The two methodologies are distinct, and little is known about their relationship.

Here, we establish a connection by using a relaxation of the Bayes objective where the expected negative-loss is replaced by the tightest convex lower bound. The bound is optimal, and obtained by Fenchel biconjugates, which naturally yield the maximum loss used in SAM (Fig. 1(a)). From this, SAM can be seen as optimizing the relaxation of Bayes to find the mean of an isotropic Gaussian posterior while keeping the variance fixed. Higher variances lead to smoother objectives, which biases the solution towards flatter regions (Fig. 1(b)). Essentially, the result connects SAM and Bayes through a Fenchel biconjugate that replaces the expected loss in Bayes by a maximum loss.

What do we gain by this connection? The generality of our result makes it possible to easily combine the complementary strengths of SAM and Bayes.
For example, we show that the relaxed-Bayes objective can be used to learn the variance parameter, which yields an Adam-like extension of SAM (Alg. 1). The variances are cheaply obtained from the vector that adapts the learning rate, and SAM's hyperparameter is adjusted for each parameter dimension via the variance vector. The extension improves the performance of SAM on standard benchmarks while giving uncertainty estimates comparable to the best Bayesian methods. Our work complements similar extensions for SGD, RMSprop, and Adam from the Bayesian deep-learning community (Gal & Ghahramani, 2016; Mandt et al., 2017; Khan et al., 2018; Osawa et al., 2019). So far there is no work on such connections between SAM and Bayes, except for a recent empirical study by Kaddour et al. (2022). Husain & Knoblauch (2022) give an adversarial interpretation of Bayes, but it does not cover methods like SAM. Our work here is focused on connections to approximate Bayesian methods, which can have both theoretical and practical impact. We discuss a new path to robustness by connecting the two fields of adversarial and Bayesian methodologies.

[Figure 1 caption] Our main result in Theorem 2 connects the two by using the optimal convex relaxation (Fenchel biconjugate) of the negative expected loss. It shows that the roles of ρ and σ are exactly the same, and a SAM minimizer obtained for a fixed ρ can always be recovered from the relaxation for some σ. An example is shown in Panel (b) for a loss (gray line) with 3 local minima indicated by A, B, and C. The expected loss is smoother (blue line), but the relaxation, which upper bounds it, is even smoother (red line). Higher σ gives smoother objectives where the sharp minima A and C slowly disappear. The SAM minimizer is shown with a red star, which matches the minimizer of the relaxation.

2. SAM AS AN OPTIMAL RELAXATION OF BAYES

In SAM, the training loss ℓ(θ) is replaced by the maximum loss within a ball of radius ρ > 0 around the parameter θ ∈ Θ ⊂ ℝ^P, as shown below for an ℓ₂-regularized problem,

E_SAM(θ; ρ, δ) = sup_{‖ε‖≤ρ} ℓ(θ + ε) + (δ/2)‖θ‖²,   (1)

with δ > 0 as the regularization parameter. The use of a 'maximum' above differs from Bayesian strategies that use an 'expectation', for example, the variational formulation by Zellner (1988),

L(q) = E_{θ∼q}[−ℓ(θ)] − D_KL(q(θ) ‖ p(θ)),   (2)

where q(θ) is the generalized posterior (Zhang, 1999; Catoni, 2007), p(θ) is the prior, and D_KL(·‖·) is the Kullback-Leibler divergence (KLD). For an isotropic Gaussian posterior q(θ) = N(θ | m, σ²I) and prior p(θ) = N(θ | 0, I/δ), the objective with respect to the mean m becomes

E_Bayes(m; σ, δ) = E_{ε∼N(0,σ²I)}[ℓ(m + ε)] + (δ/2)‖m‖²,   (3)

which closely resembles Eq. 1, but with the maximum replaced by an expectation. The above objective is obtained by simply plugging in the posterior and prior, collecting the terms that depend on m, and rewriting θ = m + ε.

There are some resemblances between the two objectives which indicate that they might be connected. For example, both use local neighborhoods: the isotropic Gaussian can be seen as a 'softer' version of the 'hard-constrained' ball used in SAM. The size of the neighborhood is decided by their respective parameters (σ or ρ in Fig. 1(a)). But, apart from these resemblances, are there any deeper connections? Here, we answer this question. We will show that Eq. 1 is equivalent to optimizing, with respect to m, an optimal convex relaxation of Eq. 3 while keeping σ fixed. See Theorem 2, where an equivalence is drawn between σ and ρ, revealing similarities between the mechanisms used in SAM and Bayes to favor flatter regions.

2.1. FENCHEL BICONJUGATE OF THE EXPECTED NEGATIVE-LOSS

[Figure 2 caption] Panel (a): the space M lies above the parabola, while Ω is a half-space. Panel (b) shows f(μ) and its biconjugate f**(μ) with the blue mesh and red surface, respectively. We use the loss from Fig. 1(b). The f** function perfectly matches the original loss at the boundary where v → 0 (black curve), clearly showing the regions for the three minima A, B, and C. Panel (c) shows the same but with respect to the standard mean-variance (ω, v) parameterization. Here, convexity of f** is lost, but we see more clearly the negative loss −ℓ(ω) appear at v → 0 (black curve). At this limit, we also get f(μ) = f**(μ) = −ℓ(ω), but for higher variances, both f and f** are smoothed versions of −ℓ(ω). The slices at v = 0.5, 2, and 4 give rise to the curves shown in Fig. 1(b).

Consider a function f(u), defined over u ∈ U and taking values in the extended real line. Given a dual space V, the Fenchel biconjugate f** and conjugate f* are defined as follows,

f**(u) = sup_{v′∈V} ⟨v′, u⟩ − f*(v′),   f*(v′) = sup_{u′∈U} ⟨u′, v′⟩ − f(u′),   (4)

where ⟨·, ·⟩ denotes the inner product. Note that f need not be convex, but both f* and f** are convex functions. Moreover, the biconjugate is the optimal convex lower bound: f(u) ≥ f**(u). We will use the conjugate functions to connect the max-loss in SAM to the expected-loss involving regular, minimal exponential-family distributions (Wainwright et al., 2008) of the following form,

q_λ(θ) = exp(⟨λ, T(θ)⟩ − A(λ)).   (5)

In what follows, we will always use the expectation parameter μ to parameterize the distribution, denoting it q_μ(θ). As an example, for a Gaussian N(θ | ω, v) over a scalar θ with mean ω and variance v > 0, we have

T(θ) = (θ, θ²),   μ = (ω, ω² + v),   λ = (ω/v, −1/(2v)).   (6)

Because v > 0, the space M is the epigraph of the parabola (ω, ω²), and Ω is a negative half-space because the second natural parameter is constrained to be negative; see Fig. 2(a). The dual spaces can be used to define the Fenchel conjugates. We will use the conjugates of the 'expected negative-loss' acting over μ ∈ interior(M),

f(μ) = E_{θ∼q_μ}[−ℓ(θ)].   (7)

For all μ ∉ interior(M), we set f(μ) = +∞. Using Eq. 4, we can obtain the biconjugate,

f**(μ) = sup_{λ′∈Ω} ⟨λ′, μ⟩ − sup_{μ′∈M} (⟨μ′, λ′⟩ − E_{θ∼q_{μ′}}[−ℓ(θ)]),   (8)

where we use λ′ and μ′ to denote arbitrary points from Ω and M, respectively, to avoid confusion with the change of coordinates μ′ = ∇A(λ′). We show in App. A that, for an isotropic Gaussian, the biconjugate takes a form where the max-loss automatically emerges. We define μ to be the expectation parameter of a Gaussian N(θ | ω, vI), and λ′ to be the natural parameter of a (different) Gaussian N(θ | m, σ²I). The equation below shows the biconjugate with the max-loss included in the second term,

f**(μ) = sup_{m,σ} E_{θ∼q_μ}[−(1/(2σ²))‖θ − m‖²] − sup_ε (ℓ(m + ε) − (1/(2σ²))‖ε‖²).   (9)

The max-loss here is in a proximal form, which has an equivalence to the trust-region form used in Eq. 1 (Parikh & Boyd, 2013, Sec. 3.4). The first term in the biconjugate arises after simplification of the ⟨λ′, μ⟩ term in Eq. 8, while the second term is obtained by using the theorem below to switch the maximum over μ′ to a maximum over θ, and subsequently over ε by rewriting θ = m + ε.

Theorem 1. For a Gaussian distribution that takes the form (5), the conjugate f*(λ′) can be obtained by taking the supremum over Θ instead of M, that is,

sup_{μ′∈M} ⟨μ′, λ′⟩ − E_{θ∼q_{μ′}}[−ℓ(θ)] = sup_{θ∈Θ} ⟨T(θ), λ′⟩ − [−ℓ(θ)],   (10)

assuming that ℓ(θ) is lower-bounded, continuous, and majorized by a quadratic function. Moreover, there exists a λ′ for which f*(λ′) is finite and, assuming ℓ(θ) is coercive, dom(f*) ⊆ Ω.

A proof of this theorem is in App. B, where we also discuss its validity for generic distributions. The red surface in Fig. 2(b) shows the convex lower bound f** to the function f (shown with the blue mesh) for a Gaussian with mean ω and variance v (Eq. 6). Both functions have a parabolic shape, which is due to the shape of M. Fig. 2(c) shows the same but with respect to the standard mean-variance parameterization (ω, v).
In this parameterization, convexity of f** is lost, but we see more clearly the negative loss −ℓ(ω) appear at v → 0 (shown with the thick black curve). At this limit, we get f(μ) = f**(μ) = −ℓ(ω), but for higher variances, both f and f** are smoothed versions of −ℓ(ω). The function f**(μ) is a 'lifted' version of the one-dimensional ℓ(θ) to the two-dimensional μ-space. Such liftings have been used in the optimization literature (Bauermeister et al., 2022). Our work connects them to exponential families and probabilistic inference. The convex relaxation preserves the shape at the limit v → 0, while giving smoothed lower bounds for all v > 0. For a high enough variance, the local minima at A and C disappear, biasing the optimization towards the flatter region B of the loss. We will now use this observation to connect to SAM.
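The smoothing mechanism described above is easy to verify numerically. The sketch below is our own illustration, not from the paper: the 1-D loss with one flat and two sharper minima is a hypothetical stand-in for the one in Fig. 1(b), the SAM max-loss of Eq. 1 is computed by grid search, and the Gaussian expected loss of Eq. 3 by Gauss-Hermite quadrature.

```python
import numpy as np

# Toy 1-D loss with one flat and two sharper minima (a hypothetical
# stand-in for the loss in Fig. 1(b)).
def loss(t):
    return 0.05 * t**2 + 1.0 - np.cos(3.0 * t) * np.exp(-0.1 * t**2)

theta = np.linspace(-4.0, 4.0, 401)

# SAM max-loss (without the regularizer), by grid search:
# sup_{|eps| <= rho} loss(theta + eps).
rho = 0.5
eps_grid = np.linspace(-rho, rho, 101)
max_loss = np.max(loss(theta[:, None] + eps_grid[None, :]), axis=1)

# Bayesian expected loss (without the regularizer):
# E_{eps ~ N(0, sigma^2)}[loss(theta + eps)], via Gauss-Hermite quadrature.
nodes, weights = np.polynomial.hermite_e.hermegauss(50)
weights = weights / weights.sum()
def expected_loss(sigma):
    return loss(theta[:, None] + sigma * nodes[None, :]) @ weights

# The max-loss upper-bounds the plain loss, and larger sigma gives a
# smoother expected loss (its total variation over the grid shrinks).
for sigma in (0.1, 0.5, 2.0):
    tv = np.abs(np.diff(expected_loss(sigma))).sum()
    print(f"sigma = {sigma}: total variation = {tv:.2f}")
```

For small σ the expected loss retains all three minima; for large σ the oscillatory part is averaged out and only the flat region survives, which is the bias towards flat minima discussed above.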

2.2. SAM AS AN OPTIMAL RELAXATION OF THE BAYES OBJECTIVE

We will now show that a minimizer of the SAM objective in Eq. 1 can be recovered by minimizing a relaxation of the Bayes objective where the expected-loss f is replaced by its convex lower bound f**. The new objective, which we call relaxed-Bayes, always lower bounds the Bayes objective of Eq. 2,

sup_{μ∈M} L(q_μ) ≥ sup_{μ∈M} f**(μ) − D_KL(q_μ(θ) ‖ p(θ))   (11)
= sup_{μ∈M} sup_{m,σ} −sup_ε (ℓ(m + ε) − (1/(2σ²))‖ε‖²) − D_KL(q_μ(θ) ‖ (1/Z) p(θ) e^{−(1/(2σ²))‖θ−m‖²}) + log Z,

where in the second line we substituted f**(μ) from (9) and merged E_{θ∼q_μ}[(1/(2σ²))‖θ − m‖²] with the KLD term. Here, Z denotes the normalizing constant of the second distribution in the KLD. The relaxed-Bayes objective has a specific difference-of-convex (DC) structure because both f** and the KLD are convex with respect to μ. Optimizing the relaxed-Bayes objective relates to double-loop algorithms where the inner loop uses the max-loss to optimize for (m, σ), while the outer loop solves for μ. We will use this structure to simplify and connect to SAM.

Due to the DC structure, it is possible to switch the order of maximization with respect to μ and (m, σ), and eliminate μ by a direct maximization (Toland, 1978). The maximum occurs when the KLD term is 0, that is, when the two distributions in its arguments are equal. A detailed derivation is in App. C, which gives us a minimization problem over (m, σ),

E_relaxed(m, σ; δ′) = sup_ε (ℓ(m + ε) − (1/(2σ²))‖ε‖²) + (δ′/2)‖m‖² − (P/2) log(σ²δ′),   (12)

where the last two terms together equal −log Z, δ′ = 1/(σ² + 1/δ₀), and we use a Gaussian prior p(θ) = N(θ | 0, I/δ₀). This objective uses the max-loss in proximal form, which is similar to the SAM loss in Eq. 1, with the difference that the hard constraint is replaced by a softer penalty term. The theorem below is our main result; it shows that, for a given ρ, a minimizer of the SAM objective in Eq. 1 can be recovered by optimizing Eq. 12 with respect to m while keeping σ fixed.

Theorem 2.
For every (ρ, δ), there exist (σ, δ′) such that

argmin_{θ∈Θ} E_SAM(θ; ρ, δ) = argmin_{m∈Θ} E_relaxed(m, σ; δ′),   (13)

assuming that the SAM perturbation satisfies ‖ε*‖ = ρ at a stationary point.

A proof is given in App. D and essentially follows from the equivalences between proximal and trust-region formulations. The result establishes SAM as a special case of relaxed-Bayes, and shows that ρ and σ play exactly the same role. This formalizes the intuitive resemblance discussed at the beginning of Sec. 2, and suggests that, in both Bayes and SAM, flatter regions are preferred for higher σ because such values give rise to smoother objectives (Fig. 1(b)).

The result can be useful for combining the complementary strengths of SAM and Bayes. For example, uncertainty estimation in SAM can now be performed by estimating σ through the relaxed-Bayes objective. Another way is to exploit the DC structure in μ-space to devise new optimization procedures to improve SAM. In Sec. 3, we will demonstrate both of these improvements of SAM. It might also be possible to use SAM to improve Bayes, something we would like to explore in the future. For Bayes, it is still challenging to get good performance in deep learning while optimizing for both the mean and the variance, despite some recent success (Osawa et al., 2019; Farquhar et al., 2020). SAM works well despite using a simple posterior, and it is possible that the poor performance of Bayes is due to poor algorithms, not simplistic posteriors. We believe that our result can be used to borrow algorithmic techniques from SAM to improve Bayes, and we hope to explore this in the future.
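The lower-bound property underlying Eq. 11 is f**(μ) ≤ f(μ), and it can be checked numerically for the 1-D Gaussian family of Eq. 6. The sketch below is our own sanity check, not from the paper: the toy loss is a made-up function satisfying the assumptions of Theorem 1 (lower-bounded, continuous, coercive, quadratically majorized), and the suprema in Eq. 9 are approximated on finite grids, so the comparison is only approximate.

```python
import numpy as np

def loss(t):
    # Hypothetical toy loss: lower-bounded, continuous, coercive,
    # and majorized by the quadratic 2 + 0.05 t^2.
    return 1.0 - np.cos(2.0 * t) * np.exp(-0.1 * t**2) + 0.05 * t**2

# f(mu) = E_{theta ~ N(omega, v)}[-loss(theta)] via Gauss-Hermite quadrature.
nodes, weights = np.polynomial.hermite_e.hermegauss(60)
weights = weights / weights.sum()
def f(omega, v):
    return -(loss(omega + np.sqrt(v) * nodes) @ weights)

# f**(mu) via Eq. 9, using E[(theta - m)^2] = v + (omega - m)^2 and
# grids for the suprema over (m, sigma) and eps.
eps = np.linspace(-6.0, 6.0, 2401)
m_grid = np.linspace(-4.0, 4.0, 161)
def f_biconjugate(omega, v):
    best = -np.inf
    for sigma in np.linspace(0.05, 1.0, 40):
        inner = np.max(loss(m_grid[:, None] + eps[None, :])
                       - eps[None, :]**2 / (2 * sigma**2), axis=1)
        vals = -(v + (omega - m_grid)**2) / (2 * sigma**2) - inner
        best = max(best, vals.max())
    return best

for omega, v in [(0.0, 0.5), (1.0, 1.0), (-2.0, 0.2)]:
    print(f"mu=({omega}, {v}): f = {f(omega, v):.4f}, f** = {f_biconjugate(omega, v):.4f}")
```

On these test points, f** stays below f, as required of a convex lower bound.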

3. BAYESIAN-SAM TO OBTAIN UNCERTAINTY ESTIMATES FOR SAM

In this section, we will show an application of our result to improve an aspect of SAM. We will derive a new algorithm that can automatically obtain reasonable uncertainty estimates for SAM. We will do so by optimizing the relaxed-Bayes objective to get σ. A straightforward way is to directly optimize Eq. 12, but we will instead use the approach of Khan et al. (2018), based on the natural-gradient method of Khan & Rue (2021) called the Bayesian learning rule (BLR). This approach has been shown to work well on large problems, and is more promising because it leads to Adam-like algorithms, enabling us to leverage existing deep-learning techniques to improve training (Osawa et al., 2019), for example, by using momentum and data augmentation. To learn variances on top of SAM, we will use a multivariate Gaussian posterior q_μ(θ) = N(θ | ω, V) with mean ω and a diagonal covariance V = diag(s)⁻¹, where s is a vector of precision values (inverses of the variances). This posterior is more expressive than the isotropic Gaussian used in Sec. 2.

Algorithm 1: Our Bayesian-SAM (bSAM) is a simple modification of SAM with Adam; the original Adam/SAM quantities that bSAM replaces are noted in the comments of the listing below. The advantage of bSAM is that it automatically obtains variances σ² through the vector s that adapts the learning rate (see line 11). Two other main changes are in line 3, where Gaussian sampling is added, and line 5, where the vector s is used to adjust the perturbation ε. The change in line 8 improves the variance by adding the regularizer δ and a constant γ (just like Adam), while line 9 uses a Newton-like update where s is used instead of √s. We denote by a ∘ b, a/b, and |a| the element-wise multiplication, division, and absolute value, respectively. N denotes the number of training examples.
Input: learning rates α, β₁; ℓ₂-regularizer δ; SAM parameter ρ; β₂ = 0.999; γ = 0.1 (in place of Adam's γ = 10⁻⁸)
1: Initialize mean ω ← (NN weight init), scale s ← 1 (in place of Adam's s ← 0), and momentum vector g_m ← 0
2: while not converged do
3:   θ ← ω + e, where e ∼ N(e | 0, σ²) with σ² ← 1/(N·s)   {Sample to improve the variance}
4:   g ← (1/B) Σ_{i∈M} ∇ℓᵢ(θ), where M is a minibatch of B examples
5:   ε ← ρ g/s   {Rescale by the vector s, instead of the scalar ‖g‖ used in SAM}
6:   g ← (1/B) Σ_{i∈M} ∇ℓᵢ(ω + ε)
7:   g_m ← β₁ g_m + (1 − β₁)(g + δω)
8:   s ← β₂ s + (1 − β₂)(√s ∘ |g| + δ + γ)   {Improved variance estimate, in place of Adam's (g + δω)²}
9:   ω ← ω − α g_m/s   {Scale the gradient by s, in place of Adam's √s + γ}
10: end while
11: Output: mean ω and variances σ² = 1/(N·s)

The BLR uses gradients with respect to μ to update λ. The pair (μ, λ) is defined similarly to Eq. 6. We will use the update given in Khan & Rue (2021, Eq. 6), which, when applied to Eq. 11, gives us

λ ← (1 − α)λ + α [∇f**(μ) + λ₀],   (14)

where α > 0 is the learning rate and λ₀ = (0, −δ₀I/2) is the natural parameter of p(θ). This is obtained by using the fact that ∇_μ D_KL(q_μ(θ) ‖ p(θ)) = λ − λ₀; see Khan & Rue (2021, Eq. 23). For α = 1, the update is equivalent to the difference-of-convex (DC) algorithm (Tao & An, 1997). The BLR generalizes this procedure by using α < 1, and automatically exploits the DC structure in μ-space, unlike a naive direct minimization of (12) with gradient descent. The update (14) requires ∇f**, which is obtained by maximizing a variant of Eq. 9 that replaces the isotropic covariance by a diagonal one. For a scalable method, we solve this inexactly by using the 'local-linearization' approximation used in SAM, along with an additional approximation that performs a single block-coordinate-descent step to optimize the variances. The main change is in line 5, where the perturbation ε is adapted by using the vector s, as opposed to SAM, which uses the scalar ‖g‖. The modifications in lines 3 and 8 are aimed at improving the variance estimates, because with just Adam we may not get good variances; see the discussion in Khan et al. (2018, Sec. 3.4).
Finally, the modification in line 9 is due to the Newton-like updates arising from the BLR (Khan & Rue, 2021), where s is used instead of √s to scale the gradient, as explained in App. E. Due to their overall structural similarity, the bSAM algorithm has a computational complexity similar to that of Adam with SAM. The only additional overhead compared to SAM-type methods is the computation of a random noise perturbation, which is negligible compared to gradient computations. For Bayesian methods, we report the performance of the predictive distribution p(y | D) = ∫ p(y | D, θ) q(θ) dθ, which we approximate using an average over 32 models drawn from q(θ).

[Table 1 caption] The shaded rows show that bSAM consistently improves over the baselines and is the overall best method.
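To make Alg. 1 concrete, here is a minimal numpy sketch of the bSAM loop (our own illustration, not the paper's implementation): the quadratic per-example losses, problem sizes, and hyperparameter values are made up, and the line numbers in the comments refer to Alg. 1.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy problem: P parameters, N "training examples" with per-example
# quadratic losses 0.5 * ||theta - target_i||^2 (hypothetical stand-ins).
P, N, B = 3, 100, 10
targets = rng.normal(size=(N, P))
def grad_minibatch(theta, idx):
    # Gradient of (1/B) sum_{i in idx} 0.5 * ||theta - targets[i]||^2.
    return theta - targets[idx].mean(axis=0)

# Hyperparameters (illustrative values only).
alpha, beta1, beta2 = 0.01, 0.9, 0.999
delta, rho, gamma = 0.1, 0.05, 0.1

omega = rng.normal(size=P) * 0.1   # line 1: mean
s = np.ones(P)                     # line 1: scale (precision estimate)
g_m = np.zeros(P)                  # line 1: momentum

for step in range(100):
    idx = rng.choice(N, size=B, replace=False)
    # Line 3: sample parameters with variance sigma^2 = 1/(N*s).
    theta = omega + rng.normal(size=P) * np.sqrt(1.0 / (N * s))
    # Line 4: minibatch gradient at the sample.
    g = grad_minibatch(theta, idx)
    # Line 5: perturbation rescaled element-wise by s (SAM: rho * g/||g||).
    eps = rho * g / s
    # Line 6: gradient at the perturbed mean.
    g = grad_minibatch(omega + eps, idx)
    # Line 7: momentum on the regularized gradient.
    g_m = beta1 * g_m + (1 - beta1) * (g + delta * omega)
    # Line 8: improved variance estimate (Adam uses (g + delta*omega)^2).
    s = beta2 * s + (1 - beta2) * (np.sqrt(s) * np.abs(g) + delta + gamma)
    # Line 9: Newton-like update, scaling by s instead of sqrt(s).
    omega = omega - alpha * g_m / s

sigma2 = 1.0 / (N * s)  # line 11: posterior variances come for free
print("mean:", np.round(omega, 3), "var:", np.round(sigma2, 4))
```

Note how the variances require no extra computation: they are read off from the same vector s that scales the update, which is the point made in the caption of Alg. 1.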

4. EXPERIMENTS

4.1. ILLUSTRATION ON A TOY PROBLEM

[Figure caption fragment] ...close to 0.5, while red and blue are closer to 0 and 1, respectively.

We see that the bSAM result is comparable to the exact posterior and corrects the overconfident predictions of SAM. Performing a Laplace approximation around the SAM solution leads to underconfident predictions. The posteriors are visualized in Fig. 5(c), and other details are discussed in App. F.1. We show results on another small toy example ('two moons') in App. F.2. We can also use this toy problem to see the effect of using f vs. f** in the Bayesian objective. Fig. 7 in App. F.3 shows a comparison of optimizing Eq. 11 to an optimization of the original Bayesian objective Eq. 2. For both, we use a full covariance to avoid the inaccuracies arising due to a diagonal covariance assumption. We find that the posteriors are very similar for both f and f**, indicating that the gap introduced by the relaxation is also small.

4.2. REAL DATASETS

We first perform a set of experiments without any data augmentation, which helps us quantify the improvements obtained by the regularization induced by our Bayesian approach. This is useful in applications where it is not easy to do data augmentation (for example, tabular or time-series data). All further details about hyperparameter settings and experiments are in App. G. The results are shown in Table 1, and the proposed bSAM method overall performs best with respect to test accuracy as well as uncertainty metrics. bSAM also works well when training smaller networks on MNIST-like datasets; see Table 6 in App. F.8. In Table 2 we show results with data augmentation (random horizontal flipping and random cropping). We consider CIFAR-10, CIFAR-100, and TinyImageNet, and still find bSAM to perform favorably, although the margin of improvement is smaller with data augmentation. We also find that bSAM not only improves uncertainty estimates and test accuracy, it also reduces the sensitivity of SAM to the hyperparameter ρ. This is shown in Fig. 4 for SAM-SGD, SAM-Adam, and bSAM, where bSAM shows the least sensitivity to the choice of ρ. Plots for other metrics are in App. F.4.

To understand the reasons behind the gains obtained with bSAM, we perform several ablation studies. First, we study the effect of the Gaussian sampling in line 3 of Alg. 1, which is not used in SAM. Experiments in App. F.5 show that Gaussian sampling consistently improves the test accuracy and uncertainty metrics. Second, in App. F.6, we study the effect of posterior averaging on the quality of bSAM's predictions, and show that increasing the number of sampled models consistently improves the performance. Finally, we study the effect of "m-sharpness", where we distribute the mini-batch across several compute nodes and use a different perturbation for each node (m refers to the number of splits).
This has been shown to improve the performance of both SAM (Foret et al., 2021) and Bayesian methods (Osawa et al., 2019) . In App. F.7, we show that larger values of m lead to better performance for bSAM too. In our experiments, we used m = 8. For completeness, we show how m-sharpness is implemented in bSAM in App. E.4.

5. DISCUSSION

In this paper, we show a connection between SAM and Bayes through a Fenchel biconjugate. When the expected negative loss in the Bayes objective is replaced by the biconjugate, we obtain a relaxed Bayesian objective which involves a maximum loss, similar to SAM. We showed that SAM can be seen as optimizing the mean while keeping the variance fixed. We then derived an extension where the variances are optimized using a Bayesian learning algorithm applied to the relaxation. The numerical results show that the new variant improves both uncertainty estimates and generalization, as expected of a method that combines the strengths of Bayes and SAM.

We expect this result to be useful for researchers interested in designing new robust methods. The principles used in SAM are related to adversarial robustness, which is different from Bayes. We show that the Fenchel conjugate is the bridge between the two, and we used it to obtain an extension of SAM. Since SAM gives excellent performance on many deep-learning benchmarks, the techniques used there should be useful in improving Bayesian methods. Finally, the use of convex duality can help improve the theoretical understanding of SAM-like methods. It is possible to use results from the convex-optimization literature to study and improve adversarial as well as Bayesian methods. Our work may open such directions and provide a new path for research in improving robustness. Expectations with respect to distributions arise naturally in many other problems, and often sampling is used to approximate them. Our paper gives a different way to approximate such expectations using convex optimization rather than sampling. Such problems may also benefit from our results.

APPENDIX A DERIVATION OF THE FENCHEL BICONJUGATE

Here, we give the derivation of the Fenchel biconjugate given in Eq. 9. For an isotropic Gaussian, we have T(θ) = (θ, θθᵀ) and λ = (m/σ², −(1/(2σ²))I). We will use the following identity to simplify,

⟨λ, T(θ)⟩ = mᵀθ/σ² + Tr(−(1/(2σ²)) θθᵀ) = −(1/(2σ²))‖m − θ‖² + (1/(2σ²))‖m‖².   (15)

Using this, the first term in Eq. 8 becomes

⟨λ, μ⟩ = E_{θ∼q_μ}[⟨λ, T(θ)⟩] = E_{θ∼q_μ}[−(1/(2σ²))‖m − θ‖²] + (1/(2σ²))‖m‖².   (16)

Similarly, for the second term in Eq. 8, we use Theorem 1 and Eq. 15 to simplify,

sup_{μ′∈M} ⟨μ′, λ⟩ − E_{θ∼q_{μ′}}[−ℓ(θ)] = sup_{θ∈Θ} ⟨λ, T(θ)⟩ + ℓ(θ)
= sup_{θ∈Θ} −(1/(2σ²))‖θ − m‖² + (1/(2σ²))‖m‖² + ℓ(θ)
= sup_ε (ℓ(m + ε) − (1/(2σ²))‖ε‖²) + (1/(2σ²))‖m‖²,   (17)

where in the last step we performed the change of variables θ − m = ε. Subtracting (17) from (16), we arrive at Eq. 9.

B PROOF OF THEOREM 1

Theorem 1. For a Gaussian distribution that takes the form (5), the conjugate f*(λ′) can be obtained by taking the supremum over Θ instead of M, that is,

sup_{μ′∈M} ⟨μ′, λ′⟩ − E_{θ∼q_{μ′}}[−ℓ(θ)] = sup_{θ∈Θ} ⟨T(θ), λ′⟩ − [−ℓ(θ)],

assuming that ℓ(θ) is lower-bounded, continuous, and majorized by a quadratic function. Moreover, there exists a λ′ for which f*(λ′) is finite and, assuming ℓ(θ) is coercive, dom(f*) ⊆ Ω.

Proof. We first pull out the expectation on the left-hand side of Eq. 10,

sup_{μ′∈M} ⟨μ′, λ′⟩ − E_{θ∼q_{μ′}}[−ℓ(θ)] = sup_{μ′∈M} E_{θ∼q_{μ′}}[⟨λ′, T(θ)⟩ + ℓ(θ)].

We denote the expression inside the expectation by J(θ) = ⟨λ′, T(θ)⟩ + ℓ(θ), and prove that the supremum over μ′ is equal to the supremum over θ. We do so separately for the two cases where the supremum J* = sup_{θ∈Θ} J(θ) is finite and infinite, respectively.

When J* is finite: For this case, our derivation proceeds in two steps, where we show that the following two inequalities hold at the same time,

sup_{μ′∈M} E_{θ∼q_{μ′}}[J(θ)] ≥ J*,   sup_{μ′∈M} E_{θ∼q_{μ′}}[J(θ)] ≤ J*.

This is only possible when the two sides are equal, which will establish the result. To prove the first inequality, we take a sequence {θ_t}_{t=1}^∞ for which J(θ_t) → J* as t → ∞.
By the definition of the supremum, such a sequence always exists. From the sequence {θ_t}_{t=1}^∞ we now construct a sequence {μ′_t}_{t=1}^∞ for which the expectation of J(θ) also converges to J*. This establishes that the supremum over μ′ is greater than or equal to the supremum over θ. It is only 'greater than or equal' because, hypothetically, there could be another sequence {μ′_t}_{t=1}^∞ which achieves a higher value. We will see shortly that this cannot be the case.

Published as a conference paper at ICLR 2023

We use the sequence {μ′_t}_{t=1}^∞ with μ′_t = (θ_t, Σ_t + θ_tθ_tᵀ), where {Σ_t}_{t=1}^∞ ⊂ S₊^{P×P} is a sequence of symmetric positive-definite matrices such that Σ_t → 0. Now, we can see that the expectation E_{θ∼q_{μ′_t}}[J(θ)] converges to the supremum as follows:

lim_{t→∞} E_{θ∼q_{μ′_t}}[J(θ)] = lim_{t→∞} E_{ε∼N(0,I)}[J(θ_t + Σ_t^{1/2} ε)] = E_{ε∼N(0,I)}[lim_{t→∞} J(θ_t + Σ_t^{1/2} ε)] = sup_{θ∈ℝ^P} J(θ).

In the above steps, we used the fact that J is Gaussian-integrable and continuous. A sufficient condition for the integrability is that ℓ(θ) is lower-bounded and majorized by quadratics. With this, we have established that the supremum over μ′ is at least as large as the one over θ.

To prove the second inequality, we show that the supremum over μ′ is also less than or equal to the supremum over θ. To see this, note that J(θ′) ≤ J* for all θ′ ∈ Θ. This implies that, for any distribution q_{μ′}, we have E_{θ′∼q_{μ′}}[J(θ′)] ≤ J*. Since the inequality holds for any distribution, we also have sup_{μ′∈M} E_{θ′∼q_{μ′}}[J(θ′)] ≤ J*. Having established both inequalities, we now know that the suprema in (10) agree whenever the supremum over θ is finite.

When J* is infinite: For such cases, we show that sup_{μ′∈M} E_{θ∼q_{μ′}}[J(θ)] = ∞, which completes the proof. Again, by the definition of the supremum, there exists a sequence {θ_t}_{t=1}^∞ for which J(θ_t) → ∞ as t → ∞. This means that, for any M > 0, there is a parameter θ_{t(M)} in the sequence such that J(θ_{t(M)}) > M. From the previous part of the proof, we know that we can construct a sequence q_{μ′_t} of Gaussians whose expectations converge to J(θ_{t(M)}). Since the above holds for any M > 0, the supremum over the expectation is also +∞.

Conditions for the existence of a λ′ for which J* is finite: The final step of our proof is to give conditions under which there exists a λ′ for which J* is finite, and therefore the result is non-trivial.
Since ℓ(θ) is majorized by a quadratic function, there exist a constant c ∈ ℝ, a vector b ∈ ℝ^P, and a matrix A ∈ S₊^{P×P} such that the following bound holds: ℓ(θ) ≤ c + bᵀθ + (1/2)θᵀAθ. Then, for the candidate λ′ = (−b, −(1/2)A), we know that J(θ) = ℓ(θ) + ⟨λ′, T(θ)⟩ ≤ c has a finite upper bound for any θ. Hence, the supremum is finite for this choice of λ′ ∈ Ω.

Domain of f*(λ′): For a Gaussian distribution, valid (λ′₁, λ′₂) ∈ Ω satisfy the constraint λ′₂ ≺ 0. We now show that if λ′ ∉ Ω, then f*(λ′) = +∞. Since λ′₂ ⊀ 0, there exists a direction d ≠ 0 such that dᵀλ′₂d ≥ 0 and λ′₁ᵀd ≥ 0. Taking the sequence {θ_t}_{t=0}^∞ with θ_t = t·d, we see that

J(θ_t) = t λ′₁ᵀd + t² dᵀλ′₂d + ℓ(td) ≥ ℓ(td) → ∞ as t → ∞,

since coercivity of ℓ(θ) implies that ℓ(θ_t) → ∞ for any sequence {θ_t}_{t=1}^∞ with ‖θ_t‖ → ∞.

Some additional comments on the assumptions used in the theorem are in order.
1. Any reasonable loss function that one would want to minimize is lower-bounded, so this assumption is always satisfied in practice.
2. The existence of an upper-bounding quadratic holds for functions with Lipschitz-continuous gradients, which is a standard assumption in optimization known as the descent lemma. In practice, our conjugates are also applicable to cases where this assumption is not satisfied. In such a case, one can consider the conjugate function only in a local neighbourhood around the current solution by restricting the set of admissible θ in the supremum. It may not matter that the bounding quadratic will eventually intersect the loss far away, outside of our current region of interest.
3. We assumed that ℓ(θ) is continuous since virtually all loss functions used in deep learning are continuous. Interestingly, if ℓ(θ) is not continuous, the conjugate function f* will involve its semicontinuous relaxation.
4. The coercivity assumption on ℓ(θ) is another standard assumption in optimization. For non-coercive losses, there could exist quadratic upper bounds which are flat in some direction.
This would lead to a dual variable λ′ at the boundary of Ω. The above Theorem 1 also holds for other exponential-family distributions, such as scalar Bernoulli and Gamma distributions, but we do not give more details here because they are not relevant for connecting to SAM. There are also cases where the result does not hold. The simplest case is a Gaussian with a fixed covariance. The above proof does not work there, because the closure (see, e.g., Csiszár & Matúš (2005); Malagò et al. (2011)) of the family of distributions does not contain all Dirac delta distributions. Intuitively, the requirement is that we should be able to approach any delta as the limit of a sequence of members of the family, which excludes the fixed-covariance case.

C DERIVATION OF THE EQUIVALENT RELAXED OBJECTIVE (12)

Inserting the biconjugate (9) into the relaxed-Bayes objective (11) yields

sup_{μ∈M} sup_{m,σ} −sup_ε (ℓ(m + ε) − (1/(2σ²))‖ε‖²) + E_{θ∼q_μ}[−(1/(2σ²))‖θ − m‖² + log p(θ) − log q_μ(θ)],

where we expanded the KLD as D_KL(q_μ(θ) ‖ p(θ)) = E_{θ∼q_μ}[log q_μ(θ)] − E_{θ∼q_μ}[log p(θ)]. Interchanging the order of maximization and noticing that the first term does not depend on μ, we can find a closed-form expression for the maximization in μ as follows:

sup_{μ∈M} E_{θ∼q_μ}[−(1/(2σ²))‖θ − m‖² + log p(θ) − log q_μ(θ)]
= sup_{μ∈M} E_{θ∼q_μ}[log((1/Z) exp(−(1/(2σ²))‖θ − m‖²) p(θ)) − log q_μ(θ)] + log Z
= log Z − inf_{μ∈M} D_KL(q_μ(θ) ‖ (1/Z) exp(−(1/(2σ²))‖θ − m‖²) p(θ))
= log Z.

In the last step, the infimum of the KLD is attained at zero with

q_μ(θ) = (1/Z) exp(−(1/(2σ²))‖θ − m‖²) p(θ).   (18)

Finally, we compute Z as

Z = ∫ exp(−(1/(2σ²))‖θ − m‖²) p(θ) dθ
= (2πσ²)^{P/2} ∫ (2πσ²)^{−P/2} exp(−(1/(2σ²))‖θ − m‖²) p(θ) dθ
= (σ²)^{P/2} (σ² + 1/δ₀)^{−P/2} exp(−‖m‖²/(2(σ² + 1/δ₀))),

where in the last step we inserted the normalization constant for the product of the two Gaussians N(θ | m, σ²I) and N(θ | 0, (1/δ₀)I), as given, for example, in Rasmussen & Williams (2006, App. 2). This gives the expression

log Z = (P/2) log(σ²) − (P/2) log(σ² + 1/δ₀) − ‖m‖²/(2(σ² + 1/δ₀)) = (P/2) log(σ²) + (P/2) log(δ′) − (δ′/2)‖m‖²,   (19)

where δ′ = 1/(σ² + 1/δ₀). Switching to a minimization over (m, σ²) and inserting the result for the exact μ-solution (18) leads to the final objective

inf_{m,σ²} sup_ε (ℓ(m + ε) − (1/(2σ²))‖ε‖²) + (δ′/2)‖m‖² − (P/2) log(σ²) − (P/2) log(δ′),

which is the energy function in (12).
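The closed-form log Z above can be sanity-checked numerically. The sketch below is our own check, not from the paper: it compares the closed form against direct quadrature for P = 1 with arbitrary values of m, σ², and δ₀.

```python
import numpy as np
from scipy.integrate import quad

# Check, for P = 1:
#   log Z = (1/2) log(sigma^2) - (1/2) log(sigma^2 + 1/delta0)
#           - m^2 / (2 (sigma^2 + 1/delta0)).
m, sigma2, delta0 = 0.7, 0.3, 2.0

def integrand(theta):
    # exp(-(theta - m)^2 / (2 sigma^2)) * N(theta | 0, 1/delta0)
    prior = np.sqrt(delta0 / (2 * np.pi)) * np.exp(-0.5 * delta0 * theta**2)
    return np.exp(-(theta - m)**2 / (2 * sigma2)) * prior

Z_numeric, _ = quad(integrand, -20.0, 20.0)
logZ_closed = (0.5 * np.log(sigma2) - 0.5 * np.log(sigma2 + 1 / delta0)
               - m**2 / (2 * (sigma2 + 1 / delta0)))
print(np.log(Z_numeric), logZ_closed)  # the two values should agree
```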

D PROOF OF THEOREM 2

Theorem 2. For every (ρ, δ), there exist (σ, δ') such that

  argmin_{θ∈Θ} E_SAM(θ; ρ, δ) = argmin_{m∈Θ} E_relaxed(m, σ; δ'),

assuming that the SAM perturbation satisfies ‖ε*‖ = ρ at a stationary point.

Proof. We will show that any stationary point of E_SAM is also a stationary point of E_relaxed when σ and δ' are chosen appropriately. The correspondence between the stationary points follows directly from the equivalences between proximal and trust-region formulations (Parikh & Boyd, 2013, Sec. 3.4). To show this, we first write the two objectives in terms of perturbations ε and ε', respectively:

  min_m sup_ε ℓ(m + ε) − (1/(2σ²))‖ε‖² + (δ'/2)‖m‖²,
  min_θ sup_{ε'} ℓ(θ + ε') − i{‖ε'‖ ≤ ρ} + (δ/2)‖θ‖²,

where i{‖·‖ ≤ ρ} is the indicator function, which is zero inside the ℓ2-ball of radius ρ and +∞ otherwise. Since these are min-max problems, we need an appropriate notion of "stationary point", which we take here to be a local Nash equilibrium (Jin et al., 2020, Prop. 3). At such points, the first derivatives in m (or θ) and ε (or ε') vanish. Taking the derivatives and setting them to zero, we get the following conditions for the two problems:

  δ' m* = −∇ℓ(m* + ε*),   ε* = σ² ∇ℓ(m* + ε*),
  δ θ* = −∇ℓ(θ* + ε'*),   ε'* = (ρ/µ) ∇ℓ(θ* + ε'*),

for some constant multiplier µ > 0. Here, we used the assumption that the constraint ‖ε'*‖ ≤ ρ is active at the optimal point. Since the constraint is active, the gradient ∇ℓ(θ* + ε'*) is an element of the (non-trivial) normal cone at ε'*, which gives us the multiplier µ in the second line above. The two sets of equations are structurally equivalent: any pair (m*, ε*) satisfying the first line also satisfies the second line for σ² = ρ/µ and δ' = δ. The opposite is also true when using the pair (θ*, ε'*) in the first line, which proves the theorem.

We assumed that the constraint on the perturbation is active, that is, ‖ε'*‖ = ρ. This assumption would be violated if there exists a local maximum of the loss within a ρ-ball around the parameter θ*.
However, such a local maximum is unlikely to exist since the parameter θ * is determined by minimization and tends to lie inside a flat minimum.
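The proximal/trust-region correspondence underlying the proof can be illustrated numerically in one dimension: solving the quadratic-penalty inner problem by fixed-point iteration yields a perturbation that is also the maximizer of the ball-constrained SAM inner problem with radius ρ set to its magnitude. This is our own sketch; the loss ℓ(θ) = sin(θ) and the constants are arbitrary, and it assumes the constraint is active, as in the proof.

```python
import numpy as np

def penalty_perturbation(m, sigma2, dloss, iters=200):
    """Solve eps = sigma^2 * l'(m + eps), the stationarity condition of the
    quadratic-penalty inner problem, by fixed-point iteration."""
    eps = 0.0
    for _ in range(iters):
        eps = sigma2 * dloss(m + eps)
    return eps

def ball_perturbation(m, rho, loss, n=20001):
    """Maximize l(m + eps) over |eps| <= rho on a grid (SAM inner problem)."""
    grid = np.linspace(-rho, rho, n)
    return grid[np.argmax(loss(m + grid))]
```

When the loss is locally increasing, the ball-constrained maximizer sits on the boundary and coincides with the penalty solution, matching the active-constraint assumption.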

E DERIVATION OF THE BSAM ALGORITHM

In this section of the appendix, we show how the natural-gradient descent update (14) leads to the proposed bSAM algorithm shown in Alg. 1. As mentioned in the main text, for bSAM we consider a generalized setting where our Gaussian distribution q_µ(θ) = N(θ | ω, V) has diagonal covariance. This corresponds to the following setting:

  T(θ) = (θ, θ ∘ θ),   µ = (ω, ω ∘ ω + 1/s),   λ = (s ∘ ω, −(1/2) s),   (19)

where ω ∈ R^P is the mean and diag(s) = V^{−1} collects the entries of the inverse diagonal covariance matrix. All operations (squaring, multiplication, division) are performed entrywise.
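As a quick sanity check of the parametrization (19) (reconstructed notation; the function names are ours), the maps between (ω, s), the expectation parameters µ, and the natural parameters λ can be inverted exactly:

```python
import numpy as np

def to_expectation(omega, s):
    # mu = (omega, omega^2 + 1/s): mean and second moment of N(omega, diag(1/s))
    return omega, omega**2 + 1.0 / s

def to_natural(omega, s):
    # lambda = (s * omega, -s/2): natural parameters of the diagonal Gaussian
    return s * omega, -0.5 * s

def from_natural(lam1, lam2):
    s = -2.0 * lam2
    return lam1 / s, s

def from_expectation(mu1, mu2):
    s = 1.0 / (mu2 - mu1**2)   # variance = second moment minus squared mean
    return mu1, s
```

All operations are entrywise, matching the diagonal-covariance setting.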

E.1 THE GRADIENT OF THE BICONJUGATE

First, we discuss how to compute the gradient of the Fenchel biconjugate, as the BLR update (14) requires it in every iteration. The diagonal Gaussian setup leads to a slightly generalized expression for the Fenchel biconjugate function:

  −f**(µ) = min_{m∈R^P, b∈R^P_+} sup_{ε∈R^P} ℓ(m + ε) − (1/2)‖Σ^{−1/2} ε‖² + E_{θ∼q_µ}[ (1/2)‖Σ^{−1/2}(θ − m)‖² ]
          = min_{m∈R^P, b∈R^P_+} sup_{ε∈R^P} ℓ(m + ε) − (1/2)‖Σ^{−1/2} ε‖² + (1/2)‖Σ^{−1/2}(ω − m)‖² + (1/2) tr(Σ^{−1} V),   (20)

where Σ = diag(1/b) denotes the diagonal covariance matrix with 1/b on its diagonal. It follows from results in convex analysis, for instance from Rockafellar & Wets (1998, Proposition 11.3), that ∇f**(µ) = (Σ*^{−1} m*, −(1/2) b*); that is, the gradient of the biconjugate can be constructed from the optimal solution (m*, b*) of the optimization problem (20).

In (20) there are three optimization steps: over ε, m, and b. For the first two, we make approximations very similar to those used for SAM (Foret et al., 2021), and simply add an additional optimization over b. To simplify computation, we make one further approximation: instead of an iterative joint minimization in m and b, we perform a single block-coordinate descent step where we first minimize in m for a fixed b = s and then afterwards in b for a fixed m = ω. This gives rise to a simple algorithm that works well in practice.

Optimization with respect to ε. To simplify the supremum in ε, we use SAM's technique of local linearization (Foret et al., 2021) with one difference: instead of linearizing at the current solution, we sample a random linearization point θ ∼ N(θ | ω, diag(s)^{−1}). This is preferable from a Bayesian perspective and is also used in the other Bayesian variants (Khan et al., 2018). Linearization at the sampled θ then gives us

  ℓ(m + ε) ≈ ℓ(θ) + ⟨∇ℓ(θ), m + ε − θ⟩.

Using this approximation in (20), we get the closed-form solution ε* = Σ ∇ℓ(θ).

Optimization with respect to m.
To get a closed-form solution for the optimization in m, we again consider a linearized loss: ℓ(m + ε) ≈ ℓ(ω + ε) + ⟨∇ℓ(ω + ε), m − ω⟩. Moreover, we set ε = V ∇ℓ(θ), corresponding to a fixed b = s. This reduces the optimization of (20) to

  m* = argmin_{m∈R^P} ⟨∇ℓ(ω + ε), m⟩ + (1/2)‖Σ*^{−1/2}(m − ω)‖²,

which has the closed-form solution

  m* = ω − Σ* ∇ℓ(ω + ε),   (21)

where we leave the dependence on Σ* implicit for now.

Optimization with respect to Σ via b. Inserting ε* into (20) with the linearized loss, the minimization over b while fixing m = ω also has a simple closed-form solution, as shown below:

  b* = argmin_{b∈R^P_+} ‖Σ^{1/2} ∇ℓ(θ)‖² + ∑_{i=1}^P b_i/s_i = √(s ∘ g²) = |g| ∘ √s,   (22)

where θ ∼ N(θ | ω, diag(s)^{−1}) and g = ∇ℓ(θ).
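The closed-form minimizer b* = |g| ∘ √s can be checked coordinatewise against a brute-force grid search over the objective g_i²/b_i + b_i/s_i; a small self-contained check (our own sketch):

```python
import numpy as np

def b_objective(b, g, s):
    # per-coordinate objective ||Sigma^{1/2} g||^2 + sum_i b_i/s_i
    # with Sigma = diag(1/b), i.e. g_i^2 / b_i + b_i / s_i
    return g**2 / b + b / s

def b_closed_form(g, s):
    # b* = sqrt(s * g^2) = |g| * sqrt(s), elementwise
    return np.abs(g) * np.sqrt(s)
```

Setting the derivative -g_i²/b_i² + 1/s_i to zero recovers the same expression analytically.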

E.2 SIMPLIFYING THE NATURAL-GRADIENT UPDATES

Having a way to approximate the gradient of the biconjugate, we now rewrite and simplify the BLR update (14). The derivations are largely similar to the ones used to derive the variational online Newton (VON) method (Khan et al., 2018). Inserting the parametrization (19) and our expression ∇f**(µ) = (Σ*^{−1} m*, −(1/2) b*) into the BLR update (14), we get the following updates:

  s ∘ ω ← (1 − α) s ∘ ω + α Σ*^{−1} m*,
  s ← (1 − α) s + α (b* + δ0),

where we also used the prior λ0 = (0, −(1/2) δ0 I). By changing the order of the updates, we can write them in an equivalent form that resembles an adaptive gradient method:

  s ← (1 − α) s + α (b* + δ0),
  ω ← ω − α V [ Σ*^{−1}(ω − m*) + δ0 ω ].

Published as a conference paper at ICLR 2023

Now, inserting the solutions Σ*^{−1}(ω − m*) = ∇ℓ(ω + ε) and b* = |∇ℓ(θ)| ∘ √s from Eqs. (21) and (22) into these updates, we arrive at

  s ← (1 − α) s + α ( |∇ℓ(θ)| ∘ √s + δ0 ),   (25)
  ω ← ω − α V ( ∇ℓ(ω + ε) + δ0 ω ),   (26)

where ε = V ∇ℓ(θ) and θ ∼ N(θ | ω, V).
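The claim that reordering the two BLR updates is exact (not an approximation) can be verified numerically: updating the natural parameter s ∘ ω and then dividing by the new s gives the same mean as the adaptive-gradient form, because Σ*^{−1} = diag(b*) by the closed-form b-solution. A small sketch with random values (our own; variable names are ours):

```python
import numpy as np

def natural_update(omega, s, m_star, b_star, alpha, delta0):
    """Joint update in natural parameters (s * omega, s), then recover the mean."""
    sw_new = (1 - alpha) * s * omega + alpha * b_star * m_star  # Sigma_*^{-1} m_* = b_* m_*
    s_new = (1 - alpha) * s + alpha * (b_star + delta0)
    return sw_new / s_new, s_new

def adaptive_update(omega, s, m_star, b_star, alpha, delta0):
    """Equivalent adaptive-gradient form: update s first, then omega."""
    s_new = (1 - alpha) * s + alpha * (b_star + delta0)
    omega_new = omega - alpha * (b_star * (omega - m_star) + delta0 * omega) / s_new
    return omega_new, s_new
```

Both forms produce identical iterates, which justifies writing the method in the Adam-like shape used below.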

E.3 FURTHER MODIFICATIONS FOR LARGE-SCALE DEEP LEARNING

Now we consider further modifications to the updates (25)-(26) to arrive at our bSAM method shown in Alg. 1. These are similar to the ones used for VOGN (Osawa et al., 2019). First, we consider a stochastic approximation to deal with larger problems. To that end, we briefly describe the general learning setup. The loss is written as a sum of individual losses, ℓ(θ) = ∑_{i=1}^N ℓ_i(θ), where N is the number of datapoints. The N datapoints are disjointly partitioned into K minibatches {B_i}_{i=1}^K of B examples each. We then apply a stochastic variant to the bound

  (N/K) ∑_{i=1}^K f_i**(µ) ≤ f**(µ),

where the individual functions are f_i(µ) = (1/B) ∑_{j∈B_i} E_{θ∼q_µ}[−ℓ_j(θ)]. Essentially, the variant computes the gradient on an incrementally sampled f_i** rather than considering the full dataset. With s = Nŝ and δ0 = Nδ, this turns (25)-(26) into the following updates:

  Nŝ ← (1 − α) Nŝ + α ( Nδ + N (√(Nŝ) ∘ |g|) ),   (27)
  ω ← ω − α [ Nδω + N g_ε ] / (Nŝ),   (28)

where g = (1/B) ∑_{j∈M} ∇ℓ_j(θ), θ ∼ N(θ | ω, diag(σ²)) with σ² = (Nŝ)^{−1}, and M denotes a randomly drawn minibatch. Moreover, we have g_ε = (1/B) ∑_{j∈M} ∇ℓ_j(ω + ε) with ε = g/(Nŝ) = ρ g/ŝ, where we introduce the step-size ρ > 0 in the adversarial step to absorb the 1/N factor and to gain additional control over the perturbation strength. Following the practical improvements of Osawa et al. (2019), we introduce a damping parameter γ > 0 and a separate step-size β2 > 0 on the precision estimate, as well as an exponential moving average for the gradient with parameter β1 > 0. Under these changes, and dividing out the N-factors in (27)-(28), we arrive at the following method:

  g_m ← β1 g_m + (1 − β1) (δω + g_ε),   (29)
  ŝ ← β2 ŝ + (1 − β2) ( δ + γ + √(Nŝ) ∘ |g| ),   (30)
  ω ← ω − α g_m / ŝ,   (31)

with g and g_ε defined as above. In practice, the method performed better without the √N factor in the ŝ-update. Renaming ŝ as s, the updates (29)-(31) are equivalent to bSAM (Alg. 1).
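Putting (29)-(31) together gives an Adam-like loop. The following is a minimal numpy sketch of one bSAM step on a toy problem; the function and variable names, the toy loss, and the hyperparameter values are our own choices, and we follow the practical variant that drops the √N factor (so the s-update uses √s ∘ |g|).

```python
import numpy as np

def bsam_step(w, s, g_m, grad_fn, N, lr, rho, delta, gamma, beta1, beta2, rng):
    """One bSAM step, sketching updates (29)-(31) without the sqrt(N) factor."""
    sigma = np.sqrt(1.0 / (N * s))                 # posterior std, sigma^2 = 1/(N s)
    theta = w + rng.normal(size=w.shape) * sigma   # sampled linearization point
    g = grad_fn(theta)                             # minibatch gradient at the sample
    eps = rho * g / s                              # adversarial perturbation
    g_eps = grad_fn(w + eps)                       # gradient at the perturbed point
    g_m = beta1 * g_m + (1 - beta1) * (delta * w + g_eps)       # gradient EMA
    s = beta2 * s + (1 - beta2) * (delta + gamma + np.abs(g) * np.sqrt(s))
    w = w - lr * g_m / s
    return w, s, g_m
```

Here s plays the role of Adam's (square-rooted) second-moment estimate, while 1/(N s) is at the same time the learned per-parameter posterior variance.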

E.4 BSAM ALGORITHM WITH m-SHARPNESS

The bSAM algorithm parallelizes well across multiple accelerators. This is done by splitting the minibatch into m parts and computing independent perturbations for each part in parallel. The algorithm remains exactly the same as Alg. 1, with lines 3-6 replaced by the following lines 1-8:

1: Equally partition M into {M_1, …, M_m} of B examples each
2: for k = 1, …, m in parallel do
3:   θ_k ← ω + e_k,  e_k ∼ N(e_k | 0, σ²),  σ² ← 1/(N · s)
4:   g_k ← (1/B) ∑_{i∈M_k} ∇ℓ_i(θ_k)
5:   ε_k ← ρ g_k / s
6:   g_{ε,k} ← (1/B) ∑_{i∈M_k} ∇ℓ_i(ω + ε_k)
7: end for
8: g_ε ← (1/m) ∑_{k=1}^m g_{ε,k},  g ← (1/m) ∑_{k=1}^m g_k
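The replaced lines can be sketched in numpy as follows; grad_fn, the toy data, and the use of ρ in the perturbation (line 5) follow our reconstruction above and are illustrative rather than the paper's exact code.

```python
import numpy as np

def msharpness_grads(w, s, minibatch, m, N, rho, grad_fn, rng):
    """Independent perturbations per sub-batch (lines 1-8 of the variant)."""
    parts = np.array_split(minibatch, m)         # line 1: equal partition
    sigma = np.sqrt(1.0 / (N * s))               # line 3: sigma^2 = 1/(N s)
    gs, g_eps_list = [], []
    for part in parts:                           # line 2: in parallel in practice
        e_k = rng.normal(size=w.shape) * sigma   # line 3: Gaussian sample
        g_k = grad_fn(w + e_k, part)             # line 4: gradient at the sample
        eps_k = rho * g_k / s                    # line 5: sub-batch perturbation
        g_eps_list.append(grad_fn(w + eps_k, part))  # line 6
        gs.append(g_k)
    return np.mean(g_eps_list, axis=0), np.mean(gs, axis=0)  # line 8
```

With the noise and perturbations switched off (large s), the averaged gradients reduce to the plain minibatch gradient, as expected.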

F ADDITIONAL DETAILS AND EXPERIMENTS

F.1 DETAILS ON THE LOGISTIC REGRESSION EXPERIMENT

Fig. 5(a) shows the binary classification dataset (red vs. blue circles) adopted from Murphy (2012, Ch. 8.4). The data can be separated by a linear classifier with two parameters whose decision boundary passes through zero. We show the decision boundaries obtained from the MAP solution and from samples of the exact Bayesian posterior in Fig. 5(a).

In this section, we empirically check the effect of using f** in the Bayesian objective instead of f. To that end, we compute an exact solution to the relaxed problem as well as to the original problem (2) for a full Gaussian variational family. For small problems, gradients of the biconjugate function can be approximated using convex optimization. To see this, notice that

  ∇f**(µ) = argmin_{λ'∈Ω} f*(λ') − ⟨λ', µ⟩ = argmin_{λ'∈Ω, c∈R} c − ⟨λ', µ⟩  s.t.  c ≥ f*(λ').   (32)

This is an optimization problem with a linear cost function and convex constraints. To implement the constraint, we approximate the loss by a maximum over linear functions, which is always possible for convex losses. The linear functions are first-order approximations of the loss at L points θ_i ∼ q_µ drawn from the current variational distribution. Under this approximation, the constraint in (32) can be written as L separate constraints, for i ∈ {1, …, L}:

  sup_{θ∈Θ} ⟨θ, λ'_1⟩ + θ⊤λ'_2 θ + ℓ(θ_i) + ⟨∇ℓ(θ_i), θ − θ_i⟩ ≤ c
  ⇔ −(1/4)(λ'_1 + ∇ℓ(θ_i))⊤ (λ'_2)^{−1} (λ'_1 + ∇ℓ(θ_i)) − c ≤ ⟨∇ℓ(θ_i), θ_i⟩ − ℓ(θ_i).   (33)

The convex program (32) with convex constraints (33) is then solved using the software package CVXPY (Diamond & Boyd, 2016) to obtain the gradient of the biconjugate function. Having the gradients of the biconjugate available, the posterior is obtained with the Bayesian learning rule of Khan & Rue (2021),

  λ ← (1 − α)λ + α (λ0 + ∇f(µ))   for Bayes (2),
  λ ← (1 − α)λ + α (λ0 + ∇f**(µ))   for relaxed Bayes (11).   (34)

Fig. 7 compares the two solutions obtained by iterating (34) on the logistic regression problem described in App. F.1. We can see that the relaxation induces a lower bound as expected, but the relaxed solution is still a reasonable approximation to the true posterior.

We show the sensitivity of all considered SAM variants to the hyperparameter ρ in Fig. 8. As the bSAM algorithm learns the variance vector, the method is expected to be less sensitive to the choice of ρ. We confirm this here for a ResNet-20 on the CIFAR-100 dataset without data augmentation.

In the derivation of the bSAM algorithm in Sec. 3 we performed a local linearization in order to approximate the inner problem. There, we claim that performing this local linearization not at the mean but rather at a sampled point is preferable from a Bayesian viewpoint. A concurrent work by Liu et al. (2022) also shows that combining SAM with random sampling improves the performance. The following Table 3 confirms that the noisy linearization helps in practice.

Table 5: Effect of "m-sharpness" mini-batch size for SAM-SGD and bSAM for a ResNet-20-FRN on CIFAR-100 without data augmentation. As the effective minibatch size decreases, all performance metrics (accuracy and uncertainty) tend to improve. This boost in performance is not captured by our theory, and understanding it is an interesting direction for future work.
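The property exploited throughout this section, namely that f** is the tightest convex lower bound of f, can also be illustrated in one dimension with two discrete Legendre transforms, without any convex-optimization package; this standalone sketch (ours, not the paper's CVXPY implementation) uses an arbitrary double-well function:

```python
import numpy as np

def biconjugate_1d(f_vals, x, lam):
    """Discrete Fenchel conjugate and biconjugate on grids x (primal), lam (dual)."""
    # f*(lam) = max_x lam*x - f(x)
    f_conj = np.max(lam[:, None] * x[None, :] - f_vals[None, :], axis=1)
    # f**(x) = max_lam lam*x - f*(lam): a max of affine functions, hence convex
    return np.max(lam[:, None] * x[None, :] - f_conj[:, None], axis=0)

x = np.linspace(-2.0, 2.0, 401)
lam = np.linspace(-30.0, 30.0, 1201)
f = (x**2 - 1.0)**2            # nonconvex double well with minima at +/- 1
f_cc = biconjugate_1d(f, x, lam)
```

On the grid, f** never exceeds f and is convex, flattening out the region between the two wells, exactly the behavior visualized in Fig. 2.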

F.8 EXPERIMENTS ON MNIST

In Table 6 we compare bSAM to several baselines for smaller neural networks on MNIST and FashionMNIST. In these smaller settings too, bSAM performs competitively with state-of-the-art methods.

G CHOICE OF HYPERPARAMETERS AND OTHER DETAILS

G.1 EXPERIMENTS IN TABLE 1 AND TABLE 6

In the following, we list the details of the experimental setup. First, we provide some general remarks, followed by tables of the detailed parameters used for each dataset. For all experiments, the hyperparameters are selected using a grid search over a moderate number of configurations to find the best validation accuracy. We always use a batch size of B = 128. For SAM-SGD, SAM-Adam and bSAM we split each minibatch into m = 8 subbatches; for VOGN we set m = 16 and consider independently computed perturbations for each subbatch. Finally, to demonstrate that bSAM does not require an excessive amount of tuning and that parameters transfer well, we use the same hyperparameters found on CIFAR-10 for the CIFAR-100 experiments. We summarize all hyperparameters in Table 7.

MNIST. The architecture is a 784-500-300-9 fully-connected multilayer perceptron (MLP) with ReLU activation functions. All methods are trained for 75 epochs. We use a cosine learning-rate decay scheme, annealing the learning rate to zero. The hyperparameters of the individual methods are shown in Table 7. For SWAG, we run SGD for 60 epochs (with the same parameters) and then collect SWAG statistics with fixed α = 0.01, β1 = 0.95. For all SWAG runs, we set the rank to 20, and collect a sample every 100 steps for 15 epochs.

FashionMNIST. The architecture is a LeNet-5 with ReLU activations. We use a cosine learning-rate decay scheme, annealing the learning rate to zero. We train all methods for 120 epochs. For SWAG, we run SGD for 105 epochs and collect for 15 epochs with α = 0.001 and β1 = 0.8.

CIFAR-10/100. The architecture is a ResNet-20 with filter response normalization nonlinearity adopted from Izmailov et al. (2021). The ResNet-20 has the same number of parameters as the one used in the original publication (He et al., 2016). We train for 180 epochs and decay the learning rate by a factor of 10 at epochs 100 and 120. For SWAG, we run SGD for 165 epochs and collect for 15 epochs with α = 0.0001, β1 = 0.9.

G.2 EXPERIMENTS IN TABLE 2

The bSAM method depends on the number of data samples N. To account for the random cropping and horizontal flipping, we rescale this number by a factor of 4, which improved the performance. This corresponds to a tempered posterior, as also suggested in Osawa et al. (2019). In all experiments, all methods use a cosine learning rate schedule which anneals the learning rate to zero across 180 epochs. The batch size is again chosen as B = 128. For an overview of all the hyperparameters, please see Table 8.

H TABLE OF NOTATIONS

The following Table 10 provides an overview of the notation used throughout the paper. Multidimensional quantities are written in boldface; scalars and scalar-valued functions are in regular face.

Symbol Description
⟨·, ·⟩   An inner product on a finite-dimensional vector space.
‖·‖   The ℓ2-norm.
f*, f**   Fenchel conjugate and biconjugate of the function f.
θ   Parameter vector of the model, e.g., the neural network.
T(θ)   Sufficient statistics of an exponential family distribution over the parameters.
µ, µ', µ0   Expectation (or mean/moment) parameters of an exponential family.
λ, λ', λ0   Natural parameters of an exponential family.
Ω   Space of valid natural parameters.
M   Space of valid expectation parameters.
q_µ   Exponential family distribution with moments µ.
q_λ   Exponential family distribution with natural parameters λ.
A(λ)   Log-partition function of the exponential family in natural parameters.
A*(µ)   Negative entropy of q_µ, the convex conjugate of the log-partition function.



The notation used in the paper is summarized in Table 10 in App. H.




Figure 1: Panel (a) highlights the main difference between SAM and Bayes. SAM uses the max-loss (the red dot), while Bayes uses the expected-loss (the blue dot shows ℓ(θ + ε) for an ε ∼ N(ε | 0, σ²)). Our main result in Theorem 2 connects the two by using the optimal convex relaxation (Fenchel biconjugate) of the negative expected loss. It shows that the roles of ρ and σ are exactly the same, and a SAM minimizer obtained for a fixed ρ can always be recovered from the relaxation for some σ. An example is shown in Panel (b) for a loss (gray line) with 3 local minima indicated by A, B, and C. The expected loss is smoother (blue line) but the relaxation, which upper bounds it, is smoother still (red line). Higher values of σ give smoother objectives where the sharp minima A and C slowly disappear. The SAM minimizer is shown with a red star, which matches the minimizer of the relaxation.

Fig. 1(b) shows an illustration where the relaxation is shown to upper bound the expected loss. We will now describe the construction of the relaxation based on Fenchel biconjugates.

Figure 2: Panel (a) shows the dual structure of the exponential family for a Gaussian, defined in Eq. 6. The space M lies above the parabola, while Ω is a half-space. Panel (b) shows f(µ) and its biconjugate f**(µ) with the blue mesh and red surface, respectively. We use the loss from Fig. 1(b). The f** function perfectly matches the original loss at the boundary where v → 0 (black curve), clearly showing the regions for the three minima A, B and C. Panel (c) shows the same but with respect to the standard mean-variance (ω, v) parameterization. Here, convexity of f** is lost, but we see more clearly the negative loss −ℓ(ω) appear at v → 0 (black curve). At this limit, we also get f(µ) = f**(µ) = −ℓ(ω), but for higher variances, both f and f** are smoothed versions of −ℓ(ω). The slices at v = 0.5, 2, and 4 give rise to the curves shown in Fig. 1(b).

with natural parameter λ ∈ Ω, sufficient statistics T(θ), log-partition function A(λ), and a base measure of 1, with Ω = {λ : A(λ) < ∞}. Dual spaces naturally arise for such distributions and can be used to define conjugates. For every λ ∈ Ω, there exists a unique 'dual' parameterization in terms of the expectation parameter µ = E_{θ∼q_λ}[T(θ)]. Every such µ satisfies µ = ∇A(λ), and together they form the interior of a dual space which we denote by M. We can also write λ in terms of µ using the inverse map λ = ∇A*(µ), and in general, both µ and λ are valid parameterizations.

Panel (c) shows three pairs of blue and red lines. The blue line in each pair is a slice of the function f, while the red line is a slice of f**. The values of v at which the functions are plotted are 0.5, 2, and 4. The slices at these variances are visualized in Fig. 1(b) along with the original loss.

Alg. 1 (fragment): {scale the gradient by s instead of √s}; 10: end while; 11: Return the mean m and the variance σ² ← 1/(N · s).

…used in the previous section, and is expected to give better uncertainty estimates. As before, we will use a Gaussian prior p(θ) = N(θ | 0, I/δ0).

Figure 3: For a 2D logistic regression, bSAM gives predictive uncertainties that are similar to those obtained by an exact Bayesian posterior. White areas show uncertain predictive probabilities around 0.5. SAM's point estimate gives overconfident predictions, while Laplace leads to underconfidence.

Figure 4: bSAM is less sensitive to the choice of ρ than both SAM-SGD and SAM-Adam. We show accuracy and NLL. Plots for the other metrics are in App. F.4.

approaches, see Khan et al. (2017; 2018); Khan & Rue (2021).

Fig. 5(b) shows the 2D probability density function of the exact Bayesian posterior over the two parameters, where the black cross indicates the mode (MAP solution), and the orange triangles the weight samples corresponding to the classifiers shown in the left figure. In Fig. 5(c), we show the solutions found by SAM, bSAM and an approximate Bayesian posterior. In this example, the posterior approximation found by bSAM is similar to the Bayesian one. Performing a Laplace approximation at the SAM solution leads to the covariance shown by the green dashed ellipse. It overestimates the extent of the posterior as it is based only on local information.

Figure 6: bSAM improves SAM's predictive uncertainty. Using a Laplace approximation at the SAM solution can lead to an overestimation of the predictive uncertainty.

Figure 7: Fig. 7(a) compares the posterior approximations obtained by solving (2) and our lower bound (11) for a full Gaussian variational family. Fig. 7(b) shows the gap in the evidence lower bound (ELBO) induced by the relaxation.

ABLATION STUDY: EFFECT OF GAUSSIAN SAMPLING (LINE 3 IN ALG. 1)

Figure 8: Sensitivity of SAM variants to ρ. For bSAM, we plot the markers at 100 · ρ to roughly align the minima/maxima of the curves. The proposed bSAM method adapts to the shape of the loss and is overall more robust to a misspecified parameter ρ, while also giving the overall best performance for all four metrics (Accuracy ↑, NLL ↓, ECE ↓, AUROC ↑).

Mean and variance of a normal distribution, corresponding to µ and λ.
Mean and variance of another normal distribution, corresponding to µ' and λ'.
N(e | m, σ²I)   e follows a normal distribution with mean m and covariance σ²I.
δ0   Precision of an isotropic Gaussian prior.



Comparison with data augmentation. Similar to Table




Performing the loss-linearization at the noisy point rather than at the mean ("bSAM w/o Noise") typically improves the results with respect to all metrics.

F.6 ABLATION STUDY: EFFECT OF BAYESIAN MODEL AVERAGING

In Table 4 we show how the predictive performance and uncertainty of bSAM depend on the number of MC samples used to approximate the integral in the Bayesian marginalization.

Increasing the number of MC samples used to approximate the marginal Bayesian predictive distribution tends to improve the performance in terms of NLL and ECE.

F.7 ABLATION STUDY: DEPENDENCE OF SAM AND BSAM ON "M-SHARPNESS"

It has been observed in Foret et al. (2021) that the performance of SAM depends on the number m of subbatches the minibatch is split into. This phenomenon is referred to as m-sharpness. From a theoretical viewpoint, a larger value of m corresponds to a looser lower bound for bSAM, as the sum of biconjugates is always smaller than the biconjugate of the sum. Here we investigate the dependence of SAM and bSAM on the parameter m for a ResNet-20-FRN on CIFAR-100 without data augmentation. The other parameters are chosen as before, and the batch size is set to 128. The results are summarized in Table 5.

Comparison on small datasets.

Our neural network outputs the natural parameters of the categorical distribution as a minimal exponential family (number of classes minus one output neurons). The loss function is the negative log-likelihood.

Table 7: Hyperparameters used in the experiments without data augmentation. "×" denotes that this method does not use that hyperparameter. " " indicates a fixed hyperparameter across all datasets.

Table 8: Hyperparameters used in the experiments with data augmentation shown in Table 2. "×" denotes that this method does not use that hyperparameter. " " indicates a fixed hyperparameter across all datasets.

G.3 RESNET-18 EXPERIMENTS IN TABLE 2

We trained all methods over 180 epochs with a cosine learning rate scheduler, annealing the learning rate to zero. The batch size is set to 200, and both SAM and bSAM use m-sharpness with m = 8.

Hyperparameters used in the ResNet-18 experiments of Table 2. "×" denotes that this method does not use that hyperparameter. " " indicates a fixed hyperparameter across all datasets.


ACKNOWLEDGEMENTS

We would like to thank Pierre Alquier for his valuable comments and suggestions. This work is supported by the Bayes duality project, JST CREST Grant Number JPMJCR2112.

AUTHOR CONTRIBUTIONS STATEMENT

List of Authors: Thomas Möllenhoff (T.M.), Mohammad Emtiyaz Khan (M.E.K.). T.M. proposed the Fenchel biconjugates, their connections to SAM (Theorems 1 and 2), and derived the new bSAM algorithm. M.E.K. provided feedback to simplify these. T.M. conducted all the experiments with suggestions from M.E.K. Both authors wrote the paper together.

