ANNEALED FISHER IMPLICIT SAMPLER

Abstract

Sampling from an un-normalized target distribution is an important problem in many scientific fields. An implicit sampler uses a parametric transform x = G θ (z) to push forward an easy-to-sample latent code z to obtain a sample x. Such samplers are favored for fast inference speed and flexible architecture. Thus it is appealing to train an implicit sampler for sampling from the un-normalized target. In this paper, we propose a novel approach to training an implicit sampler by minimizing the Fisher Divergence between sampler and target distribution. We find that the trained sampler works well for relatively simple targets but may fail for more complicated multi-modal targets. To improve the training for multi-modal targets, we propose another adaptive training approach that trains the sampler to gradually learn a sequence of annealed distributions. We construct the annealed distribution path to bridge a simple distribution and the complicated target. With the annealed approach, the sampler is capable of handling challenging multi-modal targets. In addition, we also introduce a few MCMC correction steps after the sampler to better spread the samples. We call our proposed sampler the Annealed Fisher Implicit Sampler (AFIS). We test AFIS on several sampling benchmarks. The experiments show that our AFIS outperforms baseline methods in many aspects. We also show in theory that the added MC correction steps get faster mixing by using the learned sampler as MCMC's initialization.

1. INTRODUCTION

Sampling from an un-normalized distribution is an important problem in many scientific fields, such as Bayesian statistics (Green, 1995), biology (Schütte et al., 1999), physics simulations (Olsson, 1995), and machine learning (Andrieu et al., 2003). Typically, the problem is formulated as follows: given a known, differentiable, un-normalized target potential function $\log q(x)$, one wants to sample from the target distribution. Due to the success of deep neural networks, it has become increasingly popular to train a deep generative model to learn to sample (Hu et al., 2018; Wu et al., 2020; Matthews et al., 2022; Corenflos et al., 2021). Such learned models, which can approximately sample from the target distribution, are called samplers. Training a neural network (i.e., a parameterized transform) $x = G_\theta(z)$ to push forward an easy-to-sample latent code $z \sim p_Z(z)$ to obtain a sample is an appealing approach. Such approaches are favored for fast sampling because they only need a single forward pass of the neural network transform.

Let $G_\theta(\cdot)$ denote the parametric transform and $q(x)$ the un-normalized target distribution with unknown normalizing constant $Z = \int q(x)\,dx$. Let $p_\theta(x)$ denote the sampler-induced distribution. Some previous work takes a normalizing flow model as the sampler and minimizes the KL divergence between the sampler-induced and target distributions, regardless of the normalizing constant:

$$D_{KL}(p_\theta, q) = \mathbb{E}_{x \sim p_\theta}\big[\log p_\theta(x) - \log q(x) + \log Z\big].$$

Note that $Z$ is parameter-free and can be ignored during training. However, minimizing the KL divergence relies on an explicit log-likelihood of the sampler-induced distribution, which cannot be computed for a general transform. A transform with no explicit likelihood is referred to as an implicit sampler. In this paper, we focus on implicit samplers. Note that the unknown normalizing constant vanishes when considering the score function of a distribution, $s(x) = \nabla_x \log p(x)$.
Thus, we can use a score-based divergence to get rid of the unknown normalizing constant for implicit samplers. Fisher divergence (FD), a popular score-based probability divergence, and its variants have been very successful in recent years, especially in training deep generative models such as energy-based models (Kingma & Cun, 2010; Martens et al., 2012; Song et al., 2019) and score-based diffusion models (Song et al., 2020; Kingma et al., 2021; Vahdat et al., 2021; Song & Ermon, 2019; Ho et al., 2020). Assume $p(x)$ and $q(x)$ are two probability densities. The Fisher Divergence between $p$ and $q$ is defined as

$$D_{FD}(p, q) = \frac{1}{2}\,\mathbb{E}_{x \sim p(x)} \big\|\nabla_x \log p(x) - \nabla_x \log q(x)\big\|_2^2.$$

It is always non-negative and equals 0 if and only if $p(x) = q(x)$ almost surely under the probability measure $p$. Because it depends on the target only through its score, Fisher Divergence is suitable for measuring the dissimilarity between the sampler and the un-normalized target distribution, and can therefore be used for training the implicit sampler.

In this paper, we first propose a novel approach to learning a sampler by minimizing the Fisher Divergence between the sampler and the un-normalized target distribution. We call such a sampler the Fisher Implicit Sampler. We then show that the proposed sampler is capable of handling relatively simple target distributions, but fails for more challenging multi-modal targets. To remedy this issue and unlock the full potential of the Fisher Implicit Sampler, we additionally propose a novel adaptive training approach that trains the implicit sampler gradually using a sequence of annealed distributions instead of the target distribution alone. We anneal the target distribution to bridge the hard-to-sample target and an easy-to-sample prior. More precisely, we extend the target distribution $q(x)$ to a sequence of annealed distributions $\{q_k(x)\}_k$ for $k = 0, \dots, K$, where $q_K(x)$ is the target density and $q_0(x)$ is an easy-to-sample prior distribution, typically a normal distribution.
The design of such an annealed path gradually reduces the learning difficulty for the sampler. Moreover, we find that a few steps of MC correction after the sampler help the samples spread better at little cost, as also used in some previous work (Wu et al., 2020; Arbel et al., 2021; Matthews et al., 2022). Combining all of these, we call our proposed sampler the Annealed Fisher Implicit Sampler (AFIS), as illustrated in Figure 1. We validate our AFIS on sampling benchmarks, showing improvements over baseline approaches. The main contributions of our work are summarized as follows:

• We propose a novel loss function to minimize the Fisher Divergence. We show that minimizing the proposed loss is equivalent to minimizing the Fisher Divergence between the sampler and the target distribution. Note that our objective is largely different from those in previous work.

• We provide an insightful understanding of the difficulty in learning multi-modal targets by minimizing Fisher Divergence. Based on this understanding, we apply the annealing technique to training samplers.

• We bring in a novel annealing technique and MC correction steps with our sampler, leading to improved sampling performance at little additional cost.
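The Fisher Divergence defined in the introduction depends on the target only through its score, so rescaling the un-normalized density should not change its value. The following is a minimal sketch of that point, not the paper's implementation: it assumes a one-dimensional sampler distribution $N(0, 1)$, a hypothetical un-normalized Gaussian target, and finite-difference scores. All function names are illustrative.

```python
import math
import random

def score(log_density, x, h=1e-5):
    """Central finite-difference estimate of d/dx log_density(x)."""
    return (log_density(x + h) - log_density(x - h)) / (2.0 * h)

def fisher_divergence(sample_p, log_p, log_q, n=20000, seed=0):
    """Monte Carlo estimate of 0.5 * E_{x~p} |d/dx log p(x) - d/dx log q(x)|^2."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        x = sample_p(rng)
        total += (score(log_p, x) - score(log_q, x)) ** 2
    return 0.5 * total / n

# Sampler distribution p = N(0, 1); its log-density is only needed up to a constant.
sample_p = lambda rng: rng.gauss(0.0, 1.0)
log_p = lambda x: -0.5 * x * x

# Hypothetical target q proportional to exp(-(x - 1)^2 / 2), with two scalings.
log_q = lambda x: -0.5 * (x - 1.0) ** 2
log_q_scaled = lambda x: log_q(x) + math.log(7.0)  # same target, different constant

fd = fisher_divergence(sample_p, log_p, log_q)
fd_scaled = fisher_divergence(sample_p, log_p, log_q_scaled)
print(fd, fd_scaled)
```

In this toy case the two scores differ by the constant 1 everywhere, so both estimates should agree with the analytic value 0.5 up to finite-difference rounding.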

2.1. TRAIN IMPLICIT SAMPLERS WITH SCORE-BASED DIVERGENCE

The learning-to-sample problem arises in many application fields of machine learning. Assume we only have access to an un-normalized target distribution $q(x)$ (or its logarithm $\log q(x)$), and the goal is to approximately sample from the target. In recent years, training a neural network-based transform to approximately sample from the target distribution has become an appealing method. Such a transform is called a neural sampler. Let $G_\theta$ denote a neural network which transforms a relatively simple latent code $z \sim p_Z(z)$ to a sample $x = G_\theta(z)$. Here, $p_Z(z)$ is an easy-to-sample latent distribution, usually the standard Normal distribution. A general neural sampler does not have an explicit expression for its log-likelihood function; we call such samplers implicit samplers. Because the target distribution is un-normalized and the sampler's log-likelihood is unavailable, implicit samplers cannot be trained by directly minimizing the KL or related divergences. An alternative is to consider score-based divergences.

The Stein Neural Sampler of Hu et al. (2018) is trained by minimizing Stein's Discrepancy between the sampler and target distributions. The Stein Discrepancy (SD) (Gorham & Mackey, 2015) is defined as

$$D_{SD}(p, q) = \sup_{f \in \mathcal{F}} \mathbb{E}_{x \sim p}\big[\langle \nabla_x \log q(x), f(x)\rangle + \langle \nabla_x, f(x)\rangle\big].$$

The calculation of Stein's discrepancy relies on solving a maximization problem w.r.t. the test function $f$. When the function class $\mathcal{F}$ is carefully chosen, the optimal $f$ may have an explicit solution or an easier formulation. For instance, Hu et al. (2018) found that if $\mathcal{F}$ is taken to be $\mathcal{F} = \{f : \mathbb{E}_p \|f\|_2^2 \le \delta\}$, the SD is equivalent to the regularized representation

$$D_{SD}(p, q) = \max_f \mathbb{E}_{x \sim p}\big[\langle \nabla_x \log q(x), f(x)\rangle + \langle \nabla_x, f(x)\rangle - \lambda f(x)^T f(x)\big].$$

They used two neural networks: $G_\theta$ to parametrize an implicit sampler and $f_\eta$ to parametrize the test function. Let $p_\theta(x)$ denote the implicit sampler distribution induced by $x = G_\theta(z)$ with $z \sim p_Z(z)$.
The Stein Neural Sampler solves a minimax problem over the parameter pair $(\theta, \eta)$ to obtain a sampler that minimizes the SD between the sampler and the target:

$$\min_\theta \max_\eta L(\theta, \eta) = \min_\theta \max_\eta \mathbb{E}_{x \sim p_\theta}\big[\langle \nabla_x \log q(x), f_\eta(x)\rangle + \langle \nabla_x, f_\eta(x)\rangle - \lambda f_\eta(x)^T f_\eta(x)\big].$$

Here the notation $x \sim p_\theta$ means $x = G_\theta(z)$ with $z \sim p_Z(z)$. They called the above SD the Fisher Stein Discrepancy (FSD) and the corresponding sampler the FSD Neural Sampler. The Stein Neural Sampler opens the door to training implicit samplers by minimizing score-based divergences. In fact, the FSD Neural Sampler calculates a surrogate of the Fisher Divergence: the FSD's test function $f$ provides an approximation of the Fisher Divergence. However, as we show in Section 3.1, their calculation of the Fisher Divergence only provides partial gradient updates of the sampler's parameters, which leads to training failure even for simple targets.

2.2. SCORE FUNCTION ESTIMATION

Since the implicit sampler does not have an explicit log-likelihood function or score function, training it with a score-based divergence inevitably requires estimating the score function (or an equivalent component). Score matching (Hyvärinen & Dayan, 2005) and its variants provide powerful approaches to estimating the score function from samples. Assume one only has samples $x \sim p$ available and wants to use a parametric distribution $q_\phi(x)$ to approximate $p$. Such an approximation can be made by minimizing the Fisher Divergence between $p$ and $q_\phi$. We can rewrite the Fisher Divergence as

$$D_{FD}(p, q_\phi) = \mathbb{E}_{x \sim p}\big[\|\nabla_x \log p(x)\|_2^2 + \|\nabla_x \log q_\phi(x)\|_2^2 - 2\langle \nabla_x \log p(x), \nabla_x \log q_\phi(x)\rangle\big].$$

Under certain conditions, the equality

$$\mathbb{E}_{x \sim p}\langle \nabla_x \log p(x), \nabla_x \log q_\phi(x)\rangle = -\mathbb{E}_{x \sim p}\,\Delta \log q_\phi(x)$$

holds (usually referred to as Stein's Identity (Stein, 1981; Gorham & Mackey, 2017)). Here $\Delta \log q_\phi(x) = \sum_i \frac{\partial^2}{\partial x_i^2} \log q_\phi(x)$ denotes the Laplacian operator applied to $\log q_\phi(x)$. Combining this equality with the fact that the first term of the FD, $\mathbb{E}_{x \sim p}\|\nabla_x \log p(x)\|_2^2$, does not depend on the parameter $\phi$, we see that minimizing $D_{FD}(p, q_\phi)$ is equivalent to minimizing the objective

$$L(\phi) = \mathbb{E}_{x \sim p}\big[\|\nabla_x \log q_\phi(x)\|_2^2 + 2\Delta \log q_\phi(x)\big].$$

This objective can be estimated using only samples from $p$, and is thus tractable when $q_\phi$ is well-defined. More specifically, in some cases one only needs to define a score network $s_\phi(x) : \mathbb{R}^D \to \mathbb{R}^D$, instead of a density, to estimate the score function of $p$. This technique was proposed in Hyvärinen & Dayan (2005) and is known as Score Matching. Other variants of score matching have also been studied (Song et al., 2019; Vincent, 2011; Pang et al., 2020; Meng et al., 2020; Lu et al., 2022; Bao et al., 2020). Score matching related techniques have been widely used in training energy-based models and score-based diffusion models in recent years.
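As a small worked illustration of the score matching objective $L(\phi)$, consider a one-dimensional linear score model $s_\phi(x) = a x + b$, for which the objective becomes the sample mean of $(a x + b)^2 + 2a$ and can be minimized in closed form. This is a sketch under those assumptions (the linear model and all names are illustrative, not the score network used in the paper).

```python
import random

def fit_linear_score(samples):
    """Minimize mean(s(x)^2 + 2 s'(x)) over linear models s(x) = a * x + b,
    using the closed-form minimizer of this quadratic objective."""
    n = len(samples)
    mean = sum(samples) / n
    var = sum((x - mean) ** 2 for x in samples) / n
    a = -1.0 / var   # stationarity in a gives a * Var[x] = -1
    b = mean / var   # stationarity in b gives b = -a * mean
    return a, b

def score_matching_loss(a, b, samples):
    """Empirical objective mean((a x + b)^2 + 2 a); the density of p is never needed."""
    return sum((a * x + b) ** 2 + 2.0 * a for x in samples) / len(samples)

rng = random.Random(0)
samples = [rng.gauss(2.0, 0.5) for _ in range(20000)]  # samples from N(2, 0.5^2)
a, b = fit_linear_score(samples)
print(a, b)
```

For samples from $N(2, 0.5^2)$ the true score is $-(x - 2)/0.25 = -4x + 8$, so the fitted pair `(a, b)` should be close to `(-4, 8)` even though only samples were used.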
In this paper, we use score matching related techniques to estimate the score function of the sampler's distribution.

3. ANNEALED FISHER IMPLICIT SAMPLER

3.1. MINIMIZING THE FISHER DIVERGENCE: S2D LOSS

Let $G_\theta(\cdot) : \mathbb{R}^{D_Z} \to \mathbb{R}^{D_X}$ be an implicit sampler (i.e., a neural transform), $p_Z$ the latent distribution, $p_\theta$ the sampler-induced distribution of $x = G_\theta(z)$, and $q(x)$ the un-normalized target. Our goal is to reduce the FD between $p_\theta$ and $q$ in order to train the sampler. Recall that the Fisher Divergence between $p_\theta$ and $q$ is

$$D_{FD}(p_\theta, q) = \mathbb{E}_{x \sim p_\theta} \big\|\nabla_x \log p_\theta(x) - \nabla_x \log q(x)\big\|_2^2.$$

In our learning-to-sample setting, the target score function $\nabla_x \log q(x)$ is known. A direct solution seems to work if one uses an additional score network $s_\phi(\cdot) : \mathbb{R}^{D_X} \to \mathbb{R}^{D_X}$ to approximate the sampler's score function. Samples from the implicit sampler are cheap to obtain, so estimating the sampler's score function is not hard with score matching related techniques. We call this step the Score Estimation Step. With a good approximation $s_\phi(x)$ of the sampler's score function, one may wish to minimize the approximated Fisher Divergence to update the sampler:

$$\theta^* = \arg\min_\theta \mathbb{E}_{x = G_\theta(z),\, z \sim p_Z(z)} \big\|s_\phi(x) - \nabla_x \log q(x)\big\|_2^2.$$

We call this step the Score Difference Minimization Step. By alternating the above two steps, one may hope that the Fisher divergence is minimized and the sampler is thereby trained. We name the resulting approach the Direct Method. Interestingly, the Direct Method coincides with the FSD Neural Sampler, as we state in Proposition 1. We put the detailed proof in Appendix A due to limited space.

Proposition 1. Estimating the sampler's score function $s_\phi(\cdot)$ with score matching is equivalent to maximizing the Fisher Stein Discrepancy objective to obtain the FSD's optimal test function. More specifically, the optimal score estimate $s^*$ and the FSD's optimal test function $f^*$ satisfy

$$f^*(x) = \frac{1}{2\lambda}\big[\nabla_x \log q(x) - s^*(x)\big].$$

Moreover, the Direct Method is equivalent to FSD when training an implicit sampler.
Although the Direct Method seems reasonable, it fails, as we show in the experiment on a simple Banana target in Figure 2. We find that even if the sampler's score function is estimated perfectly at each iteration, the Direct Method still gives only a partial parameter gradient for minimizing the Fisher Divergence. We start by analyzing the gradient of the Fisher Divergence w.r.t. the sampler's parameters. The Fisher Divergence is

$$L_{FD}(\theta) = \mathbb{E}_{x \sim p_\theta} \big\|\nabla_x \log q(x) - \nabla_x \log p_\theta(x)\big\|_2^2.$$

One wants to adjust $\theta$ to minimize $L_{FD}(\theta)$. Writing $s_d(x) := \nabla_x \log q(x)$ for the target score and $s_\theta(x) := \nabla_x \log p_\theta(x)$ for the sampler's score, the $\theta$ gradient of the above objective is

$$\frac{\partial}{\partial\theta}\mathbb{E}_{p_\theta}\|s_d(x) - s_\theta(x)\|^2 = \mathbb{E}_{p_\theta}\Big[\|s_d(x) - s_\theta(x)\|^2 \frac{\partial}{\partial\theta}\log p_\theta(x)\Big] + \mathbb{E}_{p_\theta}\Big[2\big(s_\theta(x) - s_d(x)\big)^T \frac{\partial}{\partial\theta}s_\theta(x)\Big].$$

The first gradient term coincides with the Direct Method if we asynchronously estimate the sampler's score function perfectly. More precisely, with perfect score estimation $s_\phi(x) = \nabla_x \log p_\theta(x)$, and with $s_\phi$ held fixed when differentiating, we have

$$\begin{aligned}
\frac{\partial}{\partial\theta}\mathbb{E}_{x \sim p_\theta}\|\nabla_x \log q(x) - s_\phi(x)\|_2^2
&= \frac{\partial}{\partial\theta}\int \|\nabla_x \log q(x) - s_\phi(x)\|_2^2\, p_\theta(x)\,dx \\
&= \int \|\nabla_x \log q(x) - s_\phi(x)\|_2^2\, \frac{\partial}{\partial\theta}p_\theta(x)\,dx \\
&= \int \|\nabla_x \log q(x) - s_\phi(x)\|_2^2\, p_\theta(x)\frac{\partial}{\partial\theta}\log p_\theta(x)\,dx \\
&= \mathbb{E}_{x \sim p_\theta}\Big[\|\nabla_x \log q(x) - s_\phi(x)\|_2^2\, \frac{\partial}{\partial\theta}\log p_\theta(x)\Big].
\end{aligned}$$

The above equation reveals that the Direct Method only uses a partial gradient to minimize the FD between the sampler and the target. In many cases, this partial gradient leads to training failure, as we observe in Figure 2. Hu et al. (2018) initialized the FSD Neural Sampler with an implicit sampler trained by Kernelized Stein Discrepancy (KSD) before training with FSD. However, such initialization limits the usage of FSD, because the optimization might start from a local minimum that is close to a local minimum of KSD, which can potentially mislead the sampler. In order to minimize the Fisher Divergence correctly, we propose a novel training objective called the Score Square Difference (S2D) loss, which accounts for the full parameter gradient of the Fisher Divergence.
The S2D loss is defined as the difference between the squared norms of the target score and the sampler's score, where the sampler's score function is estimated asynchronously with a score network $s_\phi(x)$. More precisely, our S2D loss is defined as

$$L_{S2D}(\theta) := \mathbb{E}_{x \sim p_\theta}\big[\|\nabla_x \log q(x)\|_2^2 - \|s_\phi(x)\|_2^2\big],$$

where $s_\phi(\cdot)$ is the estimated score function of the sampler distribution and is held fixed while $\theta$ is updated. The score function is usually estimated by score matching related techniques. The notation $x \sim p_\theta$ means $x = G_\theta(z)$, $z \sim p_Z(z)$. Proposition 2 shows that, if the sampler's score function is estimated perfectly, the parameter gradient of the S2D loss is the same as the gradient of the Fisher Divergence.

Proposition 2. Assume $s_\phi(x) = \nabla_x \log p_\theta(x)$. Then the following equality holds:

$$\frac{\partial}{\partial\theta}L_{S2D}(\theta) = \frac{\partial}{\partial\theta}L_{FD}(\theta).$$

We give the detailed proof in Appendix B. This proposition says that, if we alternate between estimating the sampler's score function and minimizing the S2D loss, we are actually minimizing the Fisher Divergence between the sampler and the target. The S2D loss is a surrogate of the Fisher Divergence that provides the same parameter gradient, so gradient-based minimization of the S2D loss follows the same update directions as minimization of the intractable Fisher Divergence. Figure 3 illustrates the relation between the S2D loss and the Fisher Divergence. The black curve stands for the intractable Fisher Divergence, and the green curve represents the S2D loss. The S2D loss shares the same parameter gradient as the Fisher Divergence. We refer to a sampler trained with this approach as the Fisher Implicit Sampler (FIS). We give an algorithm for FIS in Algorithm 1. We take standard score matching as an illustration of the score estimation step, but other score estimation techniques such as denoising score matching and sliced score matching also work.
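Proposition 2 and the partial-gradient issue of the Direct Method can be checked numerically on a one-dimensional example where everything is available in closed form. The sketch below assumes an affine sampler $x = \mu + \sigma z$ with $z \sim N(0, 1)$, a Gaussian target $N(m, s^2)$, and a perfectly estimated but frozen score $s_\phi(x) = -(x - \mu_0)/\sigma_0^2$ at the current parameters $(\mu_0, \sigma_0)$. These choices and all names are illustrative assumptions, not part of the paper's experiments.

```python
def fisher_divergence(mu, sigma, m, s):
    """Closed-form E_{x~N(mu, sigma^2)} |score_q(x) - score_p(x)|^2 for q = N(m, s^2)."""
    return (mu - m) ** 2 / s ** 4 + (1.0 / sigma - sigma / s ** 2) ** 2

def s2d_loss(mu, sigma, m, s, mu0, sigma0):
    """E_z[score_q(x)^2 - s_phi(x)^2] with x = mu + sigma * z and
    s_phi(x) = -(x - mu0) / sigma0^2 frozen at the current sampler."""
    target_term = ((mu - m) ** 2 + sigma ** 2) / s ** 4
    sampler_term = ((mu - mu0) ** 2 + sigma ** 2) / sigma0 ** 4
    return target_term - sampler_term

def direct_loss(mu, sigma, m, s, mu0, sigma0):
    """E_z[(score_q(x) - s_phi(x))^2] with the same frozen s_phi (Direct Method)."""
    c = 1.0 / sigma0 ** 2 - 1.0 / s ** 2
    offset = m / s ** 2 - mu0 / sigma0 ** 2
    return (c * mu + offset) ** 2 + (c * sigma) ** 2

def grad(f, mu, sigma, h=1e-6):
    """Central finite-difference gradient of f(mu, sigma)."""
    d_mu = (f(mu + h, sigma) - f(mu - h, sigma)) / (2 * h)
    d_sigma = (f(mu, sigma + h) - f(mu, sigma - h)) / (2 * h)
    return d_mu, d_sigma

m, s = 3.0, 1.5          # target N(m, s^2)
mu0, sigma0 = 0.5, 0.8   # current sampler parameters

g_fd = grad(lambda a, b: fisher_divergence(a, b, m, s), mu0, sigma0)
g_s2d = grad(lambda a, b: s2d_loss(a, b, m, s, mu0, sigma0), mu0, sigma0)
g_direct = grad(lambda a, b: direct_loss(a, b, m, s, mu0, sigma0), mu0, sigma0)
print(g_fd, g_s2d, g_direct)
```

In this example the S2D gradient matches the Fisher Divergence gradient at $(\mu_0, \sigma_0)$, while the Direct Method gradient differs from it, which is consistent with the analysis above.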
Algorithm 1: Fisher Implicit Sampler training

Input: un-normalized target $\log q(x)$, latent distribution $p_Z(z)$, implicit sampler $G_\theta$, score network $s_\phi$, mini-batch size $B$, max iteration $M$.
Randomly initialize $(\theta^{(0)}, \phi^{(0)})$.
for t in 0:M do
    # update score network parameter
    Get mini-batch $x_i = G_{\theta^{(t)}}(z_i)$, $z_i \sim p_Z(z)$, $i = 1, \dots, B$.
    Calculate score matching objective: $L_{SM}(\phi) = \frac{1}{B}\sum_{i=1}^{B}\big[\|s_\phi(x_i)\|_2^2 + 2\langle \nabla_x, s_\phi(x_i)\rangle\big]$.
    Minimize $L_{SM}(\phi)$ to get $\phi^{(t+1)}$.
    # update sampler parameter
    Get mini-batch latent code $z_i \sim p_Z(z)$, $i = 1, \dots, B$.
    Use the re-parametrization trick to calculate the S2D loss for the sampler: $L_{S2D}(\theta) = \frac{1}{B}\sum_{i=1}^{B}\big[\|\nabla_x \log q(G_\theta(z_i))\|_2^2 - \|s_{\phi^{(t+1)}}(G_\theta(z_i))\|_2^2\big]$.
    Minimize $L_{S2D}(\theta)$ to get $\theta^{(t+1)}$.
end
return $(\theta, \phi)$.

Figure 2 shows that our proposed FIT (the training procedure of Algorithm 1 with the S2D loss) can successfully train an implicit sampler from scratch to sample from the widely used banana-shaped distribution, while the Direct Method fails to train a correct sampler. Although FIT is capable of handling benchmark targets, we find that it fails on more challenging multi-modal targets with well-separated modes. To remedy this multi-modal failure and fully unlock the potential of the S2D loss, we propose to combine annealing techniques with FIT for multi-modal targets. The idea of annealing is widely used in the sampling and stochastic optimization literature (Neal, 2001; Salimans et al., 2015; Chen et al., 2016; Doucet et al., 2001; Van Laarhoven & Aarts, 1987). The technique constructs a bridge of distributions between a relatively simple prior distribution and a complicated target. The learning (or another operation such as sampling or optimization) is then carried out gradually on each intermediate distribution, from the prior to the target. Typically, annealing lowers the barrier of operating on the target by spreading the difficulty across all intermediate distributions.
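The alternating scheme of Algorithm 1 can be sketched on a one-dimensional toy problem. This is a simplified illustration under stated assumptions, not the paper's implementation: the sampler is affine, $x = \mu + \sigma z$; the score network is replaced by a linear model $s_\phi(x) = a x + b$ fitted by closed-form score matching; the target is a Gaussian $N(m, s^2)$; and the sampler gradient is written out by hand. The function name and hyperparameters are hypothetical.

```python
import random

def train_fis_1d(m, s, steps=500, batch=256, lr=0.05, seed=0):
    """Toy run of the alternating scheme: affine sampler x = mu + sigma * z,
    linear score model s_phi(x) = a * x + b, target q = N(m, s^2)."""
    rng = random.Random(seed)
    mu, sigma = 0.0, 1.0
    for _ in range(steps):
        # Score estimation step: closed-form score matching on a fresh mini-batch.
        xs = [mu + sigma * rng.gauss(0.0, 1.0) for _ in range(batch)]
        mean = sum(xs) / batch
        var = sum((x - mean) ** 2 for x in xs) / batch
        a, b = -1.0 / var, mean / var

        # Sampler step: gradient of the S2D loss through x = mu + sigma * z,
        # treating the score model (a, b) as constant.
        g_mu, g_sigma = 0.0, 0.0
        for _ in range(batch):
            z = rng.gauss(0.0, 1.0)
            x = mu + sigma * z
            target_score = -(x - m) / s ** 2
            # d/dx [ target_score(x)^2 - (a x + b)^2 ]
            dl_dx = 2.0 * target_score * (-1.0 / s ** 2) - 2.0 * a * (a * x + b)
            g_mu += dl_dx / batch          # dx/dmu = 1
            g_sigma += dl_dx * z / batch   # dx/dsigma = z
        mu -= lr * g_mu
        sigma -= lr * g_sigma
    return mu, sigma

mu, sigma = train_fis_1d(m=3.0, s=1.5)
print(mu, sigma)
```

On this unimodal target the learned `(mu, sigma)` should end up close to the target's `(3.0, 1.5)`. The toy setup says nothing about the multi-modal failure discussed above, which requires a more expressive sampler and target.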

3.2. ANNEALED FISHER IMPLICIT TRAINING

By executing FIT steps repeatedly, the sampler is trained to minimize the Fisher divergence between $p_\theta$ and the target $q$. However, directly minimizing the Fisher divergence is problematic in practice. If the sampler's distribution is too dissimilar to the target, the Fisher divergence can be hard to estimate accurately, as mentioned in Wenliang & Kanagawa (2020). The Fisher divergence can be arbitrarily small even when two distributions differ substantially in terms of KL divergence, so its estimate is likely to be inaccurate when the two distributions are too dissimilar. Due to this issue, the Fisher divergence might not be estimated accurately during training, making the training fail. In fact, this issue occurs frequently in real applications: the sampler is often initialized to concentrate around the origin, while the target distribution rarely concentrates around the origin.

To remedy the inaccurate score estimation issue, we need to guide the sampler to start by learning a relatively simple target and then move on to the more challenging one. Based on this intuition, we introduce a gradual relaxation of the target distribution. More precisely, we construct a sequence of annealed distributions $\{q_k\}$, $k \in \{0, \dots, K\}$, which gradually transforms a relatively simple distribution $q_0$ into the target distribution $q_K = q$. Typically, $q_0$ is chosen as $N(0, I)$ for simplicity. We let the sampler gradually learn to sample from each $q_k$ with $k$ increasing from $k = 0$ to $k = K$. Since $q_k$ is simpler than $q_K$ when $k$ is small, the estimation of the Fisher divergence is easier, and the sampler can learn to approximate $q_k$. As $k$ is gradually increased to $K$, the sampler gradually learns to sample from the final target $q_K = q$.

Such easy-to-hard techniques are commonly known as annealing. Marinari & Parisi (1992); Geyer & Thompson (1995) proved faster mixing time with a temperature-annealed target. Wenzel et al. (2020) utilized an annealing path to connect the model and the posterior in the Bayesian inference regime. Mandt et al. (2016); Huang et al. (2018); Fu et al. (2019) annealed the KL regularization in variational inference. D'Angelo & Fortuin (2021) proposed to anneal the target when running the Stein Variational Gradient Descent algorithm for better mixing speed. Perhaps the annealing approaches most similar to ours are those of Neal (2001) and Wu et al. (2020), which construct a geometric distribution path $p_k(x)$ between a Gaussian prior and the target density. We utilize an annealing path similar to that of Wu et al. (2020).

In this paper, we anneal the target distribution $q$ with a geometric interpolation starting from a standard Gaussian prior:

$$\log q_k(x) = \lambda_k \log q_K(x) + (1 - \lambda_k)\log q_0(x),$$

with $q_0 = N(0, I)$ and $0 \le \lambda_k \le 1$ a pre-defined annealing schedule with $\lambda_0 = 0$ and $\lambda_K = 1$. The score function is then linearly interpolated:

$$\nabla_x \log q_k(x) = \lambda_k \nabla_x \log q_K(x) + (1 - \lambda_k)\nabla_x \log q_0(x),$$

where $\nabla_x \log q_K(x) = \nabla_x \log q(x)$ is the target score function and $\nabla_x \log q_0(x)$ is the prior score. For the standard Normal prior, we have $\nabla_x \log q_0(x) = -x$. We name the training of our FIS combined with the annealing technique Annealed Fisher Implicit Training; the resulting sampler is the AFIS. Because of the page limit, we put the full AFIS algorithm in Appendix F. By annealing the target distribution into a sequence of easier-to-learn targets, we replace the difficulty of learning one final distribution by sequential learning of less difficult targets. This mitigates inaccurate Fisher divergence estimation and the resulting training failure. Figure 1 gives a brief summary of how AFIS works: the Annealed Fisher Implicit Sampler is trained progressively along the annealed distributions.
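The linear interpolation of scores along the annealing path can be written in a few lines. The sketch below assumes a standard Normal prior (score $-x$), a hypothetical linear schedule $\lambda_k = k/K$, and an illustrative Gaussian target centered at $(4, 4)$. The paper only requires a pre-defined schedule with $\lambda_0 = 0$ and $\lambda_K = 1$, so the specific schedule and target here are assumptions.

```python
def annealed_score(x, target_score, lam):
    """Score of q_k for log q_k = lam * log q + (1 - lam) * log N(0, I),
    applied coordinate-wise to a list x; the prior score is -x."""
    target = target_score(x)
    return [lam * t + (1.0 - lam) * (-xi) for t, xi in zip(target, x)]

def linear_schedule(k, num_levels):
    """Hypothetical schedule lambda_k = k / K, with lambda_0 = 0 and lambda_K = 1."""
    return k / num_levels

# Illustrative target: isotropic unit-variance Gaussian centered at (4, 4).
def target_score(x):
    return [-(xi - 4.0) for xi in x]

x = [1.0, -2.0]
K = 10
for k in (0, 5, 10):
    print(k, annealed_score(x, target_score, linear_schedule(k, K)))
```

At $k = 0$ the function returns the prior score $-x$, and at $k = K$ it returns the target score, so the S2D loss can be applied to each intermediate $q_k$ simply by swapping in this score.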

3.3. MONTE CARLO CORRECTION

A deterministic sampler suffers from the mode-connection issue: a deterministic transform cannot fully disconnect two modes, as studied in Wu et al. (2020). This issue limits the use of a purely deterministic sampler. Recent works show that combining stochastic corrections with deterministic transforms can improve sampling performance (Wu et al., 2020; Song et al., 2020; Song & Ermon, 2019). MCMC (Hastings, 1970; Roberts & Rosenthal, 1998; Xifara et al., 2014; Neal, 2011) is a commonly used family of stochastic transforms. By running MCMC, one can approximately sample from an un-normalized target distribution, so a few MCMC steps are a natural choice for stochastic corrections. In particular, after training the sampler, we take the generated samples $x = G_\theta(z)$, $z \sim p_Z(z)$, as initialization and run several MCMC steps as corrections to spread the samples for better diversity. Both energy-based and score-based MCMC can be used. We take Langevin MC as an illustration and put more details of the MC corrections in Appendix C. Note that our method is not limited to these MC corrections.

Langevin Dynamic Correction. A set of particles reaches $q(x)$ as its stationary distribution if it is driven by the Langevin dynamics with local updates

$$dX_t = \frac{1}{2}\nabla_{X_t}\log q(X_t)\,dt + dW_t,$$

where $W_t$ is standard Brownian motion. The discretized Langevin correction is given by

$$X^{(t+1)} = X^{(t)} + \frac{\epsilon}{2}\nabla \log q(X^{(t)}) + \sqrt{\epsilon}\,Z^{(t)},$$

where $Z^{(t)} \sim N(0, I)$. The Fokker-Planck equation shows that, under certain conditions, $q(x)$ is the only stationary distribution of the above diffusion. About 20 update steps are sufficient for a good correction effect in practice. The combination of a deterministic sampler and a stochastic correction in fact gives faster mixing for MCMC. The deterministic sampler places particles coarsely near the target's high-density modes. After that, MCMC helps the particles spread better around each mode.
In particular, we show that the Langevin mixing time can be controlled by the Fisher divergence between the sampler distribution and the target. Taking advantage of flexible neural network architectures, AFIS can in principle be trained to match the target score to a desired precision. Theorem 1 shows that the mixing time of the Langevin correction can be reduced by a well-trained sampler.

Theorem 1. Assume the target potential $\log q(x)$ is smooth and satisfies

$$\lim_{\|x\|_2 \to +\infty}\left(\frac{\|\nabla \log q(x)\|_2^2}{2} - \Delta \log q(x)\right) = +\infty.$$

Assume the generated distribution $p$ induced by the AFIS $x = G(z)$ is trained to match the target in Fisher divergence up to precision $\delta$, i.e., $D_F(p, q) \le \delta$. Then there exist a positive constant $\lambda$ and a dimension-free positive constant $C$, both depending only on the target distribution $q(x)$, such that under the Langevin diffusion with initial distribution $p_0 = p$,

$$dX_t = \frac{1}{2}\nabla \log q(X_t)\,dt + dW_t,$$

the diffusion time

$$T^* = \max\left\{0,\ \frac{1}{2\lambda}\left[C + \log\frac{\delta}{\epsilon}\right]\right\}$$

is enough to control the KL divergence between the corrected distribution $p_T$ and the target $q$ under tolerance $\epsilon$.

The above theorem says that the better the AFIS is trained (the smaller $\delta$), the shorter the MC correction time needed to achieve the same tolerance in terms of KL divergence. We provide the detailed proof in Appendix D.
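The discrete Langevin correction from this section can be sketched as follows. The example is an assumption-laden toy rather than the paper's setup: the target is a one-dimensional standard Normal (score $-x$), the "sampler output" is a deliberately over-concentrated Gaussian cloud, and the step size 0.1 is an arbitrary illustrative choice. The 20 steps follow the number mentioned above.

```python
import math
import random

def langevin_correction(particles, score, step_size, num_steps, rng):
    """Apply x <- x + (eps / 2) * score(x) + sqrt(eps) * z, z ~ N(0, 1), to each particle."""
    out = list(particles)
    for _ in range(num_steps):
        out = [x + 0.5 * step_size * score(x)
               + math.sqrt(step_size) * rng.gauss(0.0, 1.0) for x in out]
    return out

def variance(xs):
    """Population variance of a list of floats."""
    mean = sum(xs) / len(xs)
    return sum((x - mean) ** 2 for x in xs) / len(xs)

rng = random.Random(0)
# Stand-in for sampler output: particles too concentrated around the mode.
initial = [rng.gauss(0.0, 0.1) for _ in range(5000)]
# Illustrative target q = N(0, 1), whose score is -x.
corrected = langevin_correction(initial, lambda x: -x,
                                step_size=0.1, num_steps=20, rng=rng)
print(variance(initial), variance(corrected))
```

After the correction the particle variance moves from about 0.01 toward the target variance of 1, illustrating how a few unadjusted Langevin steps spread samples around a mode. Because the discretization is unadjusted, the stationary distribution is only approximately $q$ for a finite step size.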

3.4. COMBINING ALL: THE ANNEALED FISHER IMPLICIT SAMPLER

Combining the S2D loss, the annealing technique, and MC corrections, we obtain our final sampler: the Annealed Fisher Implicit Sampler (AFIS) with MC corrections. Figure 4 compares samples from the trained samplers on the Double Well distribution, a commonly used bi-variate test target with two separated modes. The figure shows that the AFIS with a few steps of MC correction gives the best samples. The AFIS without MCMC correction cannot fully separate the two disjoint modes. The FIS (without annealing) fails to learn the two modes. The FSD-NS (or the Direct Method) also fails to train. In conclusion, the experiments show that the S2D loss, the annealing technique, and the MC correction all contribute to successful learning.

4.1. AFIS FOR SYNTHETIC TARGET

As a sanity check, we apply AFIS to some toy target distributions used in Hu et al. (2018); Rezende & Mohamed (2015). The annealing path $p_\lambda(x) \propto \exp\big(\lambda \log p_{target}(x) + (1 - \lambda)\log p_{prior}(x)\big)$ starts from a Normal distribution when $\lambda = 0$ and ends with the target when $\lambda = 1$. Let $M$ be the maximum number of iterations and $t$ the current training iteration. We set $\lambda_t$ to grow linearly from 0 to 1 for $t < 9M/10$, and train the sampler with the true target $\log q(x)$ for the remaining $M/10$ iterations. The annealed path lowers the difficulty of learning to sample, resulting in a relatively accurate update direction for the current sampler. The sampler is guided along the annealed path towards the target. We defer the detailed experiment settings and more results to Appendix E.1.

Specifically, we visualize the sample results on three distributions with hard-to-sample characteristics such as multi-modality and periodicity, as shown in Figure 5. It shows that samples from our AFIS+MC method closely match all target distributions. For quantitative comparison, we calculate the Maximum Mean Discrepancy (MMD) between pure HMC samples and each sampler's samples. The FSD-NS does not converge during training, so we omit it from the comparison. Since the task focuses on training implicit samplers, we do not compare against explicit samplers. Table 1 summarizes the results of the MMD evaluation of all samplers. On all datasets, our AFIS consistently performs better than FIS. With additional MC correction steps, we always get lower MMD compared to the pure AFIS method.

Table 1: MMD (with rbf kernel) evaluation for synthetic targets. Additional 10 Langevin MC correction steps are used in AFIS+MC sampler. The lower the metric, the better the sampler.
Target           banana             double well        t1                 t2                 t3
FIS(ours)        1.12e-2±1.07e-3    3.51e-1±3.06e-3    4.54e-2±4.10e-3    7.48e-2±1.81e-3    5.18e-2±2.68e-3
AFIS(ours)       7.07e-4±1.72e-4    1.07e-2±1.37e-3    3.31e-3±1.13e-3    4.64e-2±2.53e-3    2.65e-2±1.91e-3
AFIS+MC(ours)    2.45e-4±1.20e-4    5.99e-3±1.33e-3    2.15e-3±8.37e-4    3.61e-2±2.67e-3    2.20e-2±1.97e-3
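The MMD metric reported in Table 1 compares sampler output with reference HMC samples through a kernel. The sketch below shows one common way to compute it and is not the paper's evaluation code: it assumes one-dimensional samples, a Gaussian RBF kernel with a hypothetical bandwidth of 1, and the biased (V-statistic) estimator of the squared MMD. The paper does not specify the bandwidth or estimator here, so these are assumptions.

```python
import math
import random

def rbf_kernel(x, y, bandwidth):
    """Gaussian RBF kernel exp(-|x - y|^2 / (2 * bandwidth^2)) for 1-D points."""
    return math.exp(-((x - y) ** 2) / (2.0 * bandwidth ** 2))

def mmd_squared(xs, ys, bandwidth=1.0):
    """Biased (V-statistic) estimate of the squared MMD between two sample sets."""
    k_xx = sum(rbf_kernel(a, b, bandwidth) for a in xs for b in xs) / len(xs) ** 2
    k_yy = sum(rbf_kernel(a, b, bandwidth) for a in ys for b in ys) / len(ys) ** 2
    k_xy = sum(rbf_kernel(a, b, bandwidth) for a in xs for b in ys) / (len(xs) * len(ys))
    return k_xx + k_yy - 2.0 * k_xy

rng = random.Random(0)
reference = [rng.gauss(0.0, 1.0) for _ in range(300)]  # stand-in for HMC samples
good = [rng.gauss(0.0, 1.0) for _ in range(300)]       # sampler matching the target
bad = [rng.gauss(2.0, 1.0) for _ in range(300)]        # sampler with a shifted mean
print(mmd_squared(reference, good), mmd_squared(reference, bad))
```

A sampler whose output matches the reference distribution yields a value near zero, while the shifted sampler yields a clearly larger value, which is the sense in which lower is better in the tables.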

4.2. BAYESIAN REGRESSION

We also test our implicit sampler on Bayesian regression tasks as in Song et al. (2017). HMC is a good baseline for such tasks, as pointed out in Neklyudov et al. (2020); Neklyudov & Welling (2022). Inference in the Bayesian logistic regression model aims to sample from the posterior distribution. We compare FIS (no annealing), AFIS, and AFIS+MC on the Australian, German, and Heart datasets. To evaluate sample quality, we run HMC as a baseline to obtain approximate samples from the target distributions and calculate the Maximum Mean Discrepancy between samples from the implicit samplers and the HMC baseline. Table 2 shows the results of the Bayesian inference experiments. Unlike FSD-NS, which always fails during training, our generators can generate high-quality samples. Moreover, the annealing technique improves sample quality on all three datasets, and the MC correction steps further improve it on Australian while giving results within the reported error on German and Heart. Experimental details can be found in Appendix E.2.

Table 2: MMD (with rbf kernel) evaluation for posterior sampling. Additional 10 Langevin MC correction steps are used in AFIS+MC sampler. The lower the metric, the better the sampler.

Posterior        Australian         German             Heart
FIS(ours)        7.99e-3±2.81e-4    1.91e-4±6.48e-6    9.84e-5±1.08e-5
AFIS(ours)       6.30e-3±2.50e-4    2.42e-6±4.02e-7    3.66e-5±1.08e-5
AFIS+MC(ours)    2.16e-3±1.08e-4    2.46e-6±3.97e-7    3.64e-5±1.07e-5

5. CONCLUSION

We have presented a novel approach for training an implicit sampler to sample from an un-normalized density. Our approach minimizes the Fisher Divergence with the aid of an asynchronous score network. We show theoretically that our method can accurately minimize the Fisher Divergence for the implicit sampler; to the best of our knowledge, it is the first method to do so. In addition, our approach uses both the annealing technique and stochastic corrections for improved sampling performance. We also prove faster mixing for the MC correction. We test our approach on commonly used synthetic targets and Bayesian regression benchmarks and observe strong performance.

A PROOF OF PROPOSITION 1

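We provide the proof of Proposition 1 here.

Proof. With fixed $p$ and known target $q$, the optimal test function $f^*$ has the representation

$$f^* = \arg\max_f L(f),$$

where the functional $L(f)$ has the integral representation

$$\begin{aligned}
L(f) &= \mathbb{E}_{x \sim p}\big[\langle \nabla_x \log q(x), f(x)\rangle + \langle \nabla_x, f(x)\rangle - \lambda f^T(x)f(x)\big] \\
&= \int \Big(p(x)\langle \nabla_x \log q(x), f(x)\rangle + p(x)\langle \nabla_x, f(x)\rangle - \lambda p(x) f^T(x)f(x)\Big)\,dx \\
&= \int l(x, f, \nabla f)\,dx.
\end{aligned}$$

Here $l(x, f, \nabla f) = p(x)\langle \nabla_x \log q(x), f(x)\rangle + p(x)\langle \nabla_x, f(x)\rangle - \lambda p(x) f^T(x)f(x)$. By the Euler-Lagrange equation, the optimal function $f$ satisfies

$$\frac{\partial l}{\partial f} - \frac{d}{dx}\Big(\frac{\partial l}{\partial f'}\Big) + \frac{\partial^2}{\partial x^2}\Big(\frac{\partial l}{\partial f''}\Big) = 0.$$

By calculation, we have

$$\frac{\partial l}{\partial f}(x) = p(x)\nabla_x \log q(x) - 2\lambda p(x) f(x), \qquad \frac{d}{dx}\Big(\frac{\partial l}{\partial f'}\Big)(x) = \nabla_x p(x), \qquad \frac{\partial l}{\partial f''}(x) = 0.$$

So the optimal $f^*$ satisfies the Euler-Lagrange equation

$$p(x)\nabla_x \log q(x) - 2\lambda p(x) f(x) - \nabla_x p(x) = 0.$$

Since $l$ is concave in $f$ (the only nonlinear term is $-\lambda p(x) f^T(x) f(x)$), this stationary point is the maximizer. Dividing both sides by $p(x)$ and noting that $\nabla_x p(x)/p(x) = \nabla_x \log p(x)$, the equation becomes

$$f^*(x) = \frac{1}{2\lambda}\big[\nabla_x \log q(x) - \nabla_x \log p(x)\big].$$

Next consider the optimal $s^*$. It is obtained by minimizing the Score Matching objective, which is equivalent to minimizing the Fisher divergence between $p$ and the family induced by $s$, so the optimum is $s^*(x) = \nabla_x \log p(x)$. Substituting $s^*$ for $\nabla_x \log p(x)$ in the formula for $f^*$, we have

$$f^*(x) = \frac{1}{2\lambda}\big[\nabla_x \log q(x) - s^*(x)\big].$$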

B PROOF OF PROPOSITION 2

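In this section, we prove that the S2D loss and the Fisher Divergence share exactly the same parameter gradient.

Proof. Let $p_\theta$ denote the sampler's distribution, $s_\theta$ the true but unknown score function of the sampler, and $q$ the known un-normalized target. For the rest of the proof, the notation $\|x\|$ represents the $L^2$ norm of a vector $x \in \mathbb{R}^{D_X}$ in $D_X$-dimensional Euclidean space. Recall that the Fisher Divergence is defined as
$$L_{FD}(\theta) = E_{x\sim p_\theta} \|\nabla_x \log q(x) - s_\theta(x)\|_2^2.$$
Thus the gradient of the Fisher Divergence with respect to the sampler parameter is
$$\begin{aligned}
\frac{\partial}{\partial\theta} E_{p_\theta} \|\nabla_x \log q(x) - s_\theta(x)\|^2
&= \frac{\partial}{\partial\theta} \int \|\nabla_x \log q(x) - s_\theta(x)\|_2^2\, p_\theta(x)\, dx \\
&= \int \|\nabla_x \log q(x) - s_\theta(x)\|_2^2\, \frac{\partial}{\partial\theta} p_\theta(x)\, dx + \int p_\theta(x)\, \frac{\partial}{\partial\theta} \|\nabla_x \log q(x) - s_\theta(x)\|_2^2\, dx \\
&= E_{p_\theta}\Big[ \|\nabla_x \log q(x) - s_\theta(x)\|^2\, \frac{\partial}{\partial\theta} \log p_\theta(x) \Big] + E_{p_\theta}\Big[ 2\big(s_\theta(x) - \nabla_x \log q(x)\big)^T \frac{\partial}{\partial\theta} s_\theta(x) \Big] \\
&= (1) + (2).
\end{aligned}$$
The first term can be estimated with
$$\begin{aligned}
(1) &= \int \|\nabla_x \log q(x) - s_\theta(x)\|^2\, \frac{\partial}{\partial\theta} p_\theta(x)\, dx \\
&= \frac{\partial}{\partial\theta} \int \mathrm{sg}\big[\|\nabla_x \log q(x) - s_\theta(x)\|^2\big]\, p_\theta(x)\, dx \\
&= \frac{\partial}{\partial\theta} E_{p_\theta}\, \mathrm{sg}\big[\|\nabla_x \log q(x) - s_\theta(x)\|^2\big].
\end{aligned}$$
Here $\mathrm{sg}$ denotes the stop-gradient operator with respect to the parameter $\theta$. $\mathrm{sg}[f_\theta]$ removes the dependence of the function $f$ on $\theta$, meaning that one can only evaluate $f_\theta(x)$ point-wise but cannot obtain the $\theta$ gradient of $f_\theta(x)$. Since we stop the gradient of the function $\|\nabla_x \log q(x) - s_\theta(x)\|^2$, we can use another score network $s_\phi$ to approximate $s_\theta$ point-wise, regardless of its dependence on $\theta$.

Next we consider the second term, which becomes
$$\begin{aligned}
(2) &= E_{p_\theta}\Big[ 2\big(s_\theta(x) - \nabla_x \log q(x)\big)^T \frac{\partial}{\partial\theta} s_\theta(x) \Big] \\
&= E_{p_\theta}\Big[ 2\big(s_\theta(x) - \nabla_x \log q(x)\big)^T \frac{\partial}{\partial\theta} \nabla_x \log p_\theta(x) \Big] \\
&= 2\int p_\theta(x)\big(s_\theta(x) - \nabla_x \log q(x)\big)^T \frac{\partial}{\partial\theta} \frac{\partial}{\partial x} \log p_\theta(x)\, dx \\
&= 2\int p_\theta(x)\big(s_\theta(x) - \nabla_x \log q(x)\big)^T \frac{\partial}{\partial\theta} \Big[ \frac{1}{p_\theta(x)} \frac{\partial p_\theta(x)}{\partial x} \Big] dx \\
&= 2\int \big(s_\theta(x) - \nabla_x \log q(x)\big)^T \frac{\partial}{\partial\theta} \frac{\partial}{\partial x} p_\theta(x)\, dx - 2\int p_\theta(x)\big(s_\theta(x) - \nabla_x \log q(x)\big)^T \frac{\partial \log p_\theta(x)}{\partial x} \frac{\partial \log p_\theta(x)}{\partial\theta}\, dx \\
&= (3) + (4).
\end{aligned}$$
Looking at (3), we have
$$\begin{aligned}
(3) &= 2\int \big(s_\theta(x) - \nabla_x \log q(x)\big)^T \frac{\partial}{\partial\theta} \frac{\partial}{\partial x} p_\theta(x)\, dx \\
&= 2\frac{\partial}{\partial\theta} \int \mathrm{sg}\big[s_\theta(x) - \nabla_x \log q(x)\big]^T \frac{\partial}{\partial x} p_\theta(x)\, dx \\
&= 2\frac{\partial}{\partial\theta} \frac{\partial}{\partial\epsilon} \int p_\theta(x + \epsilon v)\, dx \,\Big|_{\epsilon = 0}, \qquad v = \mathrm{sg}\big[s_\theta(x) - \nabla_x \log q(x)\big], \\
&= 2\frac{\partial}{\partial\theta} \frac{\partial}{\partial\epsilon} 1 = 0.
\end{aligned}$$
The last equality holds because $\int p_\theta(x + \epsilon v)\, dx = 1$ for all $v, \theta, \epsilon$. If we view $\epsilon$ as a shift-strength parameter, the above equality recovers the first-order Bartlett identity (Bartlett, 1953).

Next we turn to term (4). Note that
$$\begin{aligned}
(4) &= -2\int p_\theta(x)\big(s_\theta(x) - \nabla_x \log q(x)\big)^T \frac{\partial \log p_\theta(x)}{\partial x} \frac{\partial \log p_\theta(x)}{\partial\theta}\, dx \\
&= -2\int \big(s_\theta(x) - \nabla_x \log q(x)\big)^T \frac{\partial \log p_\theta(x)}{\partial x} \frac{\partial p_\theta(x)}{\partial\theta}\, dx \\
&= -2\frac{\partial}{\partial\theta} \int \mathrm{sg}\Big[\big(s_\theta(x) - \nabla_x \log q(x)\big)^T \frac{\partial \log p_\theta(x)}{\partial x}\Big]\, p_\theta(x)\, dx \\
&= -2\frac{\partial}{\partial\theta} E_{p_\theta(x)}\, \mathrm{sg}\Big[\big(s_\theta(x) - \nabla_x \log q(x)\big)^T \frac{\partial \log p_\theta(x)}{\partial x}\Big] \\
&= -2\frac{\partial}{\partial\theta} E_{p_\theta(x)}\, \mathrm{sg}\big[s_\theta(x) - \nabla_x \log q(x)\big]^T \mathrm{sg}\Big[\frac{\partial \log p_\theta(x)}{\partial x}\Big] \\
&= -2\frac{\partial}{\partial\theta} E_{p_\theta(x)}\, \mathrm{sg}\big[s_\theta(x) - \nabla_x \log q(x)\big]^T \mathrm{sg}\big[s_\theta(x)\big].
\end{aligned}$$
Combining all of the above, we calculate the parameter derivative as
$$\begin{aligned}
\frac{\partial}{\partial\theta} E_{p_\theta} \|\nabla_x \log q(x) - s_\theta(x)\|^2
&= (1) + (2) = (1) + (3) + (4) \\
&= \frac{\partial}{\partial\theta} E_{p_\theta}\, \mathrm{sg}\big[\|\nabla_x \log q(x) - s_\theta(x)\|^2\big] + 0 - 2\frac{\partial}{\partial\theta} E_{p_\theta(x)}\, \mathrm{sg}\big[s_\theta(x) - \nabla_x \log q(x)\big]^T \mathrm{sg}\big[s_\theta(x)\big] \\
&= \frac{\partial}{\partial\theta} E_{p_\theta}\Big[ \mathrm{sg}\big[\|\nabla_x \log q(x)\|^2\big] - \mathrm{sg}\big[\|s_\theta(x)\|^2\big] \Big],
\end{aligned}$$
where the last step expands the square: $\|a - s\|^2 - 2(s - a)^T s = \|a\|^2 - \|s\|^2$ with $a = \nabla_x \log q(x)$ and $s = s_\theta(x)$. Thus the equivalent loss function
$$L_{S2D}(\theta) = E_{p_\theta}\Big[ \mathrm{sg}\big[\|\nabla_x \log q(x)\|^2\big] - \mathrm{sg}\big[\|s_\theta(x)\|^2\big] \Big]$$
shares the same parameter gradient as the Fisher Divergence, which is itself intractable. Because of the stop-gradient operator, we only need the sampler's score function $s_\theta$ and its gradient with respect to $x$, so we can estimate $s_\theta(x)$ with another score network $s_\phi(x)$ trained on samples consistently drawn from the sampler. With the above objective function, we can minimize the Fisher Divergence between $p_\theta$ and $q$.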
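The gradient identity can be checked numerically on a toy case where everything is available in closed form. The sketch below is an illustration under stated assumptions, not the paper's implementation: the sampler is the 1-D reparametrized Gaussian $x = \mu + \sigma z$ with $z \sim N(0, 1)$, the target is $q = N(0, 1)$, and the exact sampler score $-(x - \mu)/\sigma^2$ stands in for a perfectly trained score network. In this case the Fisher Divergence equals $\mu^2 + (1/\sigma - \sigma)^2$. The function names are hypothetical.

```python
import random


def fisher_divergence_grad(mu, sigma):
    """Closed-form gradient of D_F = mu^2 + (1/sigma - sigma)^2 w.r.t. (mu, sigma)."""
    return 2.0 * mu, 2.0 * (sigma - 1.0 / sigma ** 3)


def s2d_grad_estimate(mu, sigma, n_samples=200_000, seed=0):
    """Monte Carlo gradient of the S2D loss through the reparametrized sample only.

    The loss integrand is ||grad log q(x)||^2 - ||s(x)||^2, where the score
    function s is frozen (stop-gradient): its parameters are treated as constants
    and theta-gradients flow only through x = mu + sigma * z.
    """
    rng = random.Random(seed)
    g_mu = g_sigma = 0.0
    for _ in range(n_samples):
        z = rng.gauss(0.0, 1.0)
        x = mu + sigma * z
        # d/dx of x^2 - (x - mu)^2 / sigma^4, with grad log q(x) = -x and
        # frozen score s(x) = -(x - mu) / sigma^2.
        dl_dx = 2.0 * x - 2.0 * (x - mu) / sigma ** 4
        g_mu += dl_dx          # dx/dmu = 1
        g_sigma += dl_dx * z   # dx/dsigma = z
    return g_mu / n_samples, g_sigma / n_samples
```

For example, at $(\mu, \sigma) = (0.5, 1.5)$ the two functions should agree up to Monte Carlo error, which is the behaviour the proof above predicts.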

C INTRODUCTION TO METROPOLIS-HASTINGS AND HAMILTONIAN CORRECTION

Assume the target distribution is $p(x)$. Metropolis-Hastings (MH) MCMC requires a proposal distribution $q(\tilde{x}|x)$ to propose candidate samples $\tilde{x} \sim q(\tilde{x}|x)$. The Markov chain then accepts the candidate sample with probability
$$r = \min\Big\{ \frac{p(\tilde{x})\, q(x|\tilde{x})}{p(x)\, q(\tilde{x}|x)},\ 1 \Big\}.$$
Under some conditions, the chain eventually reaches $p(x)$ as its stationary distribution. The proposal distribution can be symmetric or non-symmetric. A conditional Gaussian $q(\tilde{x}|x) = \mathcal{N}(x, \sigma^2)$ is a common choice. Proposals based on the score function, $q(\tilde{x}|x) = \mathcal{N}(x + \frac{\epsilon}{2}\nabla_x \log p(x), \sigma^2)$, are also popular (Xifara et al., 2014).

If one considers an auxiliary state space of $(x, v)$ and executes the proposal in that space, the MC scheme is called Hamiltonian Monte Carlo (HMC). Given the current sample $X^{(t)}$, HMC samples a momentum vector from an auxiliary distribution $V^{(t)} \sim p(v) \propto \exp(-v^T M^{-1} v/2)$. The joint sample $(X^{(t)}, V^{(t)})$ is then updated by running Hamiltonian dynamics in the joint space via
$$\frac{dX_t}{dt} = \frac{\partial H}{\partial V}, \qquad \frac{dV_t}{dt} = -\frac{\partial H}{\partial X}.$$
Here $H(x, v) = -\log p(x) + \frac{1}{2} v^T M^{-1} v$ is the Hamiltonian of the mechanical system. HMC has the advantages that it mixes well for high-dimensional targets and that, because it travels in the joint space, it is less easily trapped in local minima. The leapfrog integrator is usually a practical choice for the numerical updates (Neal, 2011). To make the Markov chain satisfy detailed balance, an additional Metropolis correction is needed for the Hamiltonian proposal. In short, HMC iteratively accepts a new position and momentum pair $(\tilde{x}, \tilde{v})$ with probability $\min\{1, \exp(H(x, v) - H(\tilde{x}, \tilde{v}))\}$, where $(\tilde{x}, \tilde{v}) = \mathrm{LeapFrog}(x, v)$ is the approximate Hamiltonian proposal.
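The MH acceptance rule with a score-based proposal can be sketched as follows. This is a minimal 1-D illustration, not the correction procedure used in the experiments: it assumes the proposal variance equals the step size ($\sigma^2 = \epsilon$), uses an un-normalized standard normal as a toy target, and all function names are hypothetical.

```python
import math
import random


def log_gauss(x, mean, var):
    """Log density of the 1-D Gaussian N(mean, var) evaluated at x."""
    return -0.5 * math.log(2.0 * math.pi * var) - 0.5 * (x - mean) ** 2 / var


def langevin_mh_step(x, log_p, grad_log_p, eps, rng):
    """One MH step with proposal N(x + eps/2 * grad log p(x), eps)."""
    mean_fwd = x + 0.5 * eps * grad_log_p(x)
    x_new = rng.gauss(mean_fwd, math.sqrt(eps))
    mean_bwd = x_new + 0.5 * eps * grad_log_p(x_new)
    # log of p(x_new) q(x | x_new) / (p(x) q(x_new | x)); the normalizing
    # constant of p cancels, so an un-normalized log density is enough.
    log_ratio = (log_p(x_new) + log_gauss(x, mean_bwd, eps)
                 - log_p(x) - log_gauss(x_new, mean_fwd, eps))
    accept = rng.random() < math.exp(min(0.0, log_ratio))
    return (x_new if accept else x), accept


def run_chain(x0, n_steps, eps=0.5, seed=0):
    """Run the chain on the toy target log p(x) = -x^2 / 2 (assumption)."""
    rng = random.Random(seed)
    log_p = lambda x: -0.5 * x * x
    grad_log_p = lambda x: -x
    x, chain, n_accepted = x0, [], 0
    for _ in range(n_steps):
        x, accepted = langevin_mh_step(x, log_p, grad_log_p, eps, rng)
        chain.append(x)
        n_accepted += accepted
    return chain, n_accepted / n_steps
```

Because the proposal is not symmetric, both the forward and the backward proposal densities enter the acceptance ratio; dropping them would bias the chain.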

D PROOF OF THEOREM 1

We give the proof of Theorem 1 here. To begin with, we give a lemma that bounds the KL divergence by the Fisher divergence, as shown in Yamano (2021).

Lemma 2. For fixed $q$, there exists a dimension-free positive constant $c$ such that for every distribution $p$ that is both integrable and log-integrable with respect to $q$ and has the same support as $q$, we have
$$D_{KL}(p, q) \le \frac{c}{2} D_F(p, q).$$

Proof of lemma. For every first-order smooth function $f$, assume both $|f|^2$ and $\|\nabla f\|_2^2$ are integrable with respect to $q$. The log-Sobolev inequality (Gross, 1975) shows that there exists a dimension-free positive constant $c$ such that
$$\int |f|^2 \log |f|\, q(x)\, dx \le c \int \|\nabla f\|^2 q(x)\, dx + \|f\|_2^2 \log \|f\|_2^2.$$
Here $\|f\|_2^2 = \int |f|^2 q(x)\, dx$. Setting $f = \sqrt{p/q}$, the left-hand side becomes
$$\mathrm{LHS} = \frac{1}{2} \int (p/q) \log(p/q)\, q\, dx = \frac{1}{2} E_p \log(p/q) = \frac{1}{2} D_{KL}(p, q).$$
For the right-hand side, we have
$$\nabla \frac{\sqrt{p}}{\sqrt{q}} = \frac{1}{2} \Big[ \frac{\nabla p}{\sqrt{p}\sqrt{q}} - \frac{\nabla q}{q} \frac{\sqrt{p}}{\sqrt{q}} \Big] = \frac{1}{2} \Big[ \sqrt{\frac{p}{q}} \frac{\nabla p}{p} - \sqrt{\frac{p}{q}} \frac{\nabla q}{q} \Big] = \frac{1}{2} \sqrt{\frac{p}{q}} \big[ \nabla \log p - \nabla \log q \big].$$
Thus the first term on the right-hand side is
$$c \int \|\nabla f\|^2 q(x)\, dx = c \int \Big\| \nabla \sqrt{\frac{p}{q}} \Big\|^2 q(x)\, dx = \frac{c}{4} \int \|\nabla \log p - \nabla \log q\|^2 p\, dx = \frac{c}{4} E_p \|\nabla \log p - \nabla \log q\|^2 = \frac{c}{4} D_F(p, q).$$
Note that $\|f\|_2^2 = \int |f|^2 q(x)\, dx = \int (p/q)\, q\, dx = \int p\, dx = 1$, so the second term on the right-hand side vanishes. Combining both sides, we conclude
$$\frac{1}{2} D_{KL}(p, q) \le \frac{c}{4} D_F(p, q) + 0,$$
so we have $D_{KL}(p, q) \le \frac{c}{2} D_F(p, q)$.

The above lemma shows that the KL divergence is upper bounded by the Fisher divergence, which is the quantity we use to train the sampler. With this lemma, we can calculate the mixing time for the Langevin correction in the proof below.

Proof. Assume the target satisfies
$$\lim_{\|x\|_2 \to +\infty} \Big( \frac{\|\nabla \log q(x)\|_2^2}{2} - \Delta \log q(x) \Big) = +\infty.$$
Then there exists a constant $\lambda > 0$ such that the Poincaré inequality holds for each $f \in C^1(\mathbb{R}^d) \cap L^2(q)$ with $E_q f = 0$ (Theorem 4.3 in Pavliotis (2014)):
$$\lambda \|f\|_{L^2(q)}^2 \le \|\nabla f\|_{L^2(q)}^2.$$
Let $p_0$ denote the distribution of the learned sampler, which is trained so that $D_F(p_0, q) \le \delta$. By the lemma, the KL divergence between the initial distribution $p_0$ and the target $q$ is bounded by the Fisher divergence with a dimension-free constant $c$:
$$D_{KL}(p_0, q) \le \frac{c}{2} D_F(p_0, q) \le \frac{c}{2} \delta < +\infty.$$
With the Poincaré inequality in place, the KL divergence along the Langevin diffusion $dX_t = \frac{1}{2} \nabla \log q(X_t)\, dt + dW_t$ decays exponentially fast, as in Theorem 4.6 in Pavliotis (2014):
$$D_{KL}(p_t, q) \le \exp(-2\lambda t)\, D_{KL}(p_0, q) \le \exp(-2\lambda t)\, \frac{c}{2} D_F(p_0, q) \le \exp(-2\lambda t)\, \frac{c}{2} \delta.$$
Thus, if we want $D_{KL}(p_t, q)$ to be below a tolerance $\epsilon$, the diffusion time $t$ only needs to satisfy
$$t \ge \frac{1}{2\lambda} \Big[ \log\Big(\frac{c}{2}\Big) + \log\Big(\frac{\delta}{\epsilon}\Big) \Big] = \frac{1}{2\lambda} \Big[ C + \log\Big(\frac{\delta}{\epsilon}\Big) \Big],$$
where we set $C = \log(\frac{c}{2})$ to be another constant. The diffusion time must be non-negative, so we take
$$T^* = \max\Big\{ 0,\ \frac{1}{2\lambda} \Big[ C + \log\Big(\frac{\delta}{\epsilon}\Big) \Big] \Big\},$$
which finishes the proof.
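To make the bound concrete, the short sketch below evaluates $T^* = \max\{0, \frac{1}{2\lambda}[\log(c/2) + \log(\delta/\epsilon)]\}$ for given constants. The function name and the example values of $c$, $\lambda$, $\delta$, and $\epsilon$ are hypothetical; in practice the log-Sobolev and Poincaré constants of a target are rarely known exactly.

```python
import math


def mixing_time_bound(c, lam, delta, eps):
    """Diffusion time T* after which the KL bound exp(-2*lam*t) * (c/2) * delta <= eps.

    c:     log-Sobolev-type constant from the lemma (assumed known here)
    lam:   Poincare constant of the target (assumed known here)
    delta: Fisher divergence reached by the trained sampler
    eps:   KL tolerance
    """
    t = (math.log(c / 2.0) + math.log(delta / eps)) / (2.0 * lam)
    return max(0.0, t)


# A better-trained sampler (smaller delta) needs a shorter correction time.
t_rough = mixing_time_bound(c=2.0, lam=0.5, delta=1e-1, eps=1e-3)
t_good = mixing_time_bound(c=2.0, lam=0.5, delta=1e-2, eps=1e-3)
```

The bound depends on $\delta$ only logarithmically, and it drops to zero once $\frac{c}{2}\delta \le \epsilon$, that is, when the sampler alone already meets the tolerance.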

E.1 SYNTHETIC TARGET

For the toy 2-dimensional data experiments, we use a 3-layer MLP with 200 hidden units in each layer as the sampler. The activation of the sampler is the LeakyReLU non-linearity with a 0.2 coefficient. The score network is a 3-layer MLP with 200 hidden units in each layer, and its activation is the GELU non-linearity. When reporting the numbers in Table 1, we compute MMD metrics based on a total of 2000 samples. We run 20 independent experiments for each target and algorithm to calculate the mean and standard deviation. Figure 6 visualizes the capabilities of the FIS, AFIS, and AFIS+MC samplers for matching three 2-dimensional target energy functions.

E.2 BAYESIAN REGRESSION

For the high-dimensional data experiments, we also use 3-layer MLPs as the sampler and the score network. The activation of the sampler is the LeakyReLU non-linearity with a 0.2 coefficient. The activation of the score network is the GELU non-linearity. We use 400 hidden units in each layer for the Australian and Heart distributions and 600 hidden units for the German distribution. When reporting the numbers in Table 2, we compute the MMD metric based on a total of 2000 samples. We run 20 independent experiments for each target and algorithm to calculate the mean and standard deviation. For the basic settings of the Bayesian regression problems, readers can refer to Song et al. (2017).



Figure 1: Illustration of proposed Annealed Fisher Implicit Sampler.

Figure 2: Direct method fails for simple Banana distribution while S2D loss succeeds.

Figure 3: S2D loss and Fisher Divergence. The S2D loss shares the same parameter gradient as Fisher Divergence if sampler's score is estimated perfectly asynchronously. Thus minimizing the S2D loss to update the sampler is equivalent to minimizing the Fisher Divergence between sampler and target.

Figure 4: Sample comparison on Double Well targets. (a) real samples; (b) samples from trained AFIS with 5 steps of HMC correction; (c) samples from trained AFIS; (d) samples from trained FIS without annealing; (e) samples from trained FSD-NS. All samplers and score networks use the same architecture.

Figure 5: Target and AFIS+MC samples.

Figure 6: Comparison between samples generated by FIS, AFIS and AFIS+MC on three 2D energy functions.

ETHICS STATEMENT

Our work proposes an approach to train an implicit sampler by minimizing the Fisher Divergence between the sampler and the target distribution. Since the research concerns a fundamental methodology in machine learning, we do not foresee obvious negative consequences of the methodology.

REPRODUCIBILITY STATEMENT

We provide details of our approach and sampler in the Appendix, together with complete proofs of all theoretical results. We also provide the Python code for the implementation. We state that our research is reproducible.

F FULL AFIS ALGORITHM

This section gives the full Annealed Fisher Implicit Sampler training algorithm.

Algorithm 2: Annealed Fisher Implicit Sampler training algorithm
Input: un-normalized target $\log q(x)$, annealed schedule $\{\lambda_k\}_{k=1}^{K}$, prior distribution $\log q_{prior}(x)$; latent distribution $p_Z(z)$, implicit sampler $G_\theta$, score network $s_\phi$, mini-batch size $B$, max iteration $M$.
Randomly initialize $(\theta^{(0)}, \phi^{(0)})$.
for k in 1:K do
    # anneal the target
    set $\log q_k(x) = \lambda_k \log q(x) + (1 - \lambda_k) \log q_{prior}(x)$
    for t in 1:M do
        # update score network parameter
        Get mini-batch from sampler.
        Calculate score matching objective.
        Minimize $L_{SM}(\phi)$ to get $\phi^{(t+1)}$.
        # update sampler parameter
        Get mini-batch latent code $z_i \sim p_Z(z)$, $i = 1, \dots, B$.
        Use re-parametrization trick to calculate S2D loss for sampler.
        Minimize $L_{S2D}(\theta)$ to get $\theta^{(t+1)}$.
    end
end
return $(\theta, \phi)$.
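The annealing step of the algorithm can be sketched in a few lines. The algorithm only requires a schedule $\{\lambda_k\}_{k=1}^{K}$; the linear schedule, the 1-D double-well target, and the standard normal prior below are illustrative assumptions, and all function names are hypothetical.

```python
def linear_schedule(num_stages):
    """Assumed schedule: lambda_k = k / K for k = 1..K, ending at the true target."""
    return [k / num_stages for k in range(1, num_stages + 1)]


def annealed_log_density(log_q, log_q_prior, lam):
    """Return log q_k(x) = lam * log q(x) + (1 - lam) * log q_prior(x)."""
    return lambda x: lam * log_q(x) + (1.0 - lam) * log_q_prior(x)


def log_q_double_well(x):
    """Toy un-normalized bimodal target with modes near x = -2 and x = 2."""
    return -((x * x - 4.0) ** 2) / 4.0


def log_q_prior_gauss(x):
    """Toy un-normalized standard normal prior."""
    return -0.5 * x * x


# Build the sequence of intermediate targets; the inner score-matching and
# S2D updates of the algorithm would be run against each one in turn.
stages = [annealed_log_density(log_q_double_well, log_q_prior_gauss, lam)
          for lam in linear_schedule(5)]
```

Early stages stay close to the unimodal prior and later stages approach the bimodal target, which is the bridging behaviour the annealed path is designed to provide.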

