STABLE TARGET FIELD FOR REDUCED VARIANCE SCORE ESTIMATION IN DIFFUSION MODELS

Abstract

Diffusion models generate samples by reversing a fixed forward diffusion process. Despite already providing impressive empirical results, these algorithms can be further improved by reducing the variance of the training targets in their denoising score-matching objective. We argue that the source of such variance lies in the handling of intermediate noise-variance scales, where multiple modes in the data affect the direction of reverse paths. We propose to remedy the problem by incorporating a reference batch, which we use to calculate weighted conditional scores as more stable training targets. We show that the procedure indeed helps in the challenging intermediate regime by reducing (the trace of) the covariance of training targets. The new stable targets can be seen as trading bias for reduced variance, where the bias vanishes with increasing reference batch size. Empirically, we show that the new objective improves the image quality, stability, and training speed of various popular diffusion models across datasets with both general ODE and SDE solvers. When used in combination with EDM (Karras et al., 2022), our method yields a current SOTA FID of 1.90 with 35 network evaluations on the unconditional CIFAR-10 generation task. The code is available at https://github.com/Newbeeer/stf

1. INTRODUCTION

Diffusion models (Sohl-Dickstein et al., 2015; Song & Ermon, 2019; Ho et al., 2020) have recently achieved impressive results on a wide spectrum of generative tasks, such as image generation (Nichol et al., 2022; Song et al., 2021b), 3D point cloud generation (Luo & Hu, 2021) and molecular conformer generation (Shi et al., 2021; Xu et al., 2022a). These models can be subsumed under a unified framework in the form of Itô stochastic differential equations (SDE) (Song et al., 2021b). The models learn time-dependent score fields via score-matching (Hyvärinen & Dayan, 2005), which then guide the reverse SDE during generative sampling. Popular instances of diffusion models include variance-exploding (VE) and variance-preserving (VP) SDE (Song et al., 2021b). Building on these formulations, EDM (Karras et al., 2022) provides the best performance to date.

We argue that, despite achieving impressive empirical results, the current training scheme of diffusion models can be further improved. In particular, the variance of training targets in the denoising score-matching (DSM) objective can be large and lead to suboptimal performance. To better understand the origin of this instability, we decompose the score field into three regimes. Our analysis shows that the phenomenon arises primarily in the intermediate regime, which is characterized by multiple modes or data points exerting comparable influences on the scores. In other words, in this regime, the sources of the noisy examples generated in the course of the forward process become ambiguous. We illustrate the problem in Figure 1(a), where each stochastic update of the score model is based on disparate targets.

We propose a generalized version of the denoising score-matching objective, termed the Stable Target Field (STF) objective. The idea is to include an additional reference batch of examples that are used to calculate weighted conditional scores as targets.
We apply self-normalized importance sampling to aggregate the contribution of each example in the reference batch. Although this process can substantially reduce the variance of training targets (Figure 1(b)), especially in the intermediate regime, it does introduce some bias. However, we show that the bias together with the trace-of-covariance of the STF training targets shrinks to zero as we increase the size of the reference batch.

Experimentally, we show that our STF objective achieves new state-of-the-art performance on CIFAR-10 unconditional generation when incorporated into EDM (Karras et al., 2022). The resulting FID score (Heusel et al., 2017) is 1.90 with 35 network evaluations. STF also improves the FID/Inception scores for other variants of score-based models, i.e., VE and VP SDEs (Song et al., 2021b), in most cases. In addition, it enhances the stability of converged score-based models on CIFAR-10 and CelebA $64^2$ across random seeds, and helps avoid generating noisy images in VE. STF accelerates the training of score-based models (3.6× speed-up for VE on CIFAR-10) while obtaining comparable or better FID scores. To the best of our knowledge, STF is the first technique to accelerate the training process of diffusion models. We further demonstrate the performance gain with increasing reference batch size, highlighting the negative effect of large variance.

Our contributions are summarized as follows: (1) We detail the instability of the current diffusion model training objective in a principled and quantitative manner, characterizing a region in the forward process, termed the intermediate phase, where the score-learning targets are most variable (Section 3). (2) We propose a generalized score-matching objective, stable target field, which provides more stable training targets (Section 4).
(3) We analyze the behavior of the new objective and prove that it is asymptotically unbiased and reduces the trace-of-covariance of the training targets by a factor pertaining to the reference batch size in the intermediate phase under mild conditions (Section 5). (4) We illustrate the theoretical arguments empirically and show that the proposed STF objective improves the performance, stability, and training speed of score-based methods. In particular, it achieves the current state-of-the-art FID score on the CIFAR-10 benchmark when combined with EDM (Section 6).

2. BACKGROUND ON DIFFUSION MODELS

In diffusion models, the forward process [footnote 1] is an SDE with no learned parameter, in the form of:

$$dx = f(x, t)\,dt + g(t)\,dw,$$

where $x \in \mathbb{R}^d$ with $x(0) \sim p_0$ being the data distribution, $t \in [0, 1]$, $f : \mathbb{R}^d \times [0, 1] \to \mathbb{R}^d$, $g : [0, 1] \to \mathbb{R}$, and $w \in \mathbb{R}^d$ is the standard Wiener process. It gradually transforms the data distribution to a known prior as time goes from 0 to 1. Sampling of diffusion models is done via a corresponding reverse-time SDE (Anderson, 1982):

$$dx = \left[ f(x, t) - g(t)^2 \nabla_x \log p_t(x) \right] d\bar{t} + g(t)\,d\bar{w},$$

where $\bar{\cdot}$ denotes time traveling backward from 1 to 0. Song et al. (2021b) propose a probability flow ODE that induces the same marginal distribution $p_t(x)$ as the SDE:

$$dx = \left[ f(x, t) - \tfrac{1}{2} g(t)^2 \nabla_x \log p_t(x) \right] d\bar{t}.$$

Both formulations progressively recover $p_0$ from the prior $p_1$. We estimate the score of the transformed data distribution at time $t$, $\nabla_x \log p_t(x)$, via a neural network, $s_\theta(x, t)$. Specifically, the training objective is a weighted sum of the denoising score-matching objective (Vincent, 2011):

$$\min_\theta \; \mathbb{E}_{t \sim q_t(t)} \lambda(t)\, \mathbb{E}_{x \sim p_0} \mathbb{E}_{x(t) \sim p_{t|0}(\cdot|x)} \left[ \left\| s_\theta(x(t), t) - \nabla_{x(t)} \log p_{t|0}(x(t)|x) \right\|_2^2 \right],$$

where $q_t$ is the distribution for the time variable, e.g., $U[0, 1]$ for VE/VP (Song et al., 2021b) and a log-normal distribution for EDM (Karras et al., 2022), $\lambda(t) = \sigma_t^2$ is the positive weighting function that keeps the time-dependent loss at the same magnitude (Song et al., 2021b), and $p_{t|0}(x(t)|x)$ is the transition kernel denoting the conditional distribution of $x(t)$ given $x$ [footnote 2]. Specifically, diffusion models "destroy" data according to a diffusion process utilizing Gaussian transition kernels, which result in $p_{t|0}(x(t)|x) = \mathcal{N}(\mu_t, \sigma_t^2 I)$. Recent works (Xu et al., 2022b; Rissanen et al., 2022) have also extended the underlying principle from the diffusion process to more general physical processes where the training objective is not necessarily score-related.
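As a concrete reading of the objective above, the following minimal NumPy sketch evaluates a single-sample denoising score-matching loss. It assumes a Gaussian kernel with mean $\mu_t = x$ (as in VE-type kernels) and the weighting $\lambda(t) = \sigma_t^2$; the function names `perturb`, `dsm_target`, and `dsm_loss` are illustrative and do not come from the paper's code.

```python
import numpy as np

def perturb(x, sigma_t, rng):
    """Sample x(t) ~ N(x, sigma_t^2 I) (assumes the kernel mean is x itself)."""
    return x + sigma_t * rng.standard_normal(x.shape)

def dsm_target(x_t, x, sigma_t):
    """Conditional score of N(x(t); x, sigma_t^2 I): (x - x(t)) / sigma_t^2."""
    return (x - x_t) / sigma_t**2

def dsm_loss(score_fn, x, sigma_t, rng):
    """Single-sample DSM loss weighted by lambda(t) = sigma_t^2."""
    x_t = perturb(x, sigma_t, rng)
    residual = score_fn(x_t, sigma_t) - dsm_target(x_t, x, sigma_t)
    return sigma_t**2 * np.sum(residual**2)
```

Here `score_fn` stands in for the network $s_\theta$; any callable taking the perturbed sample and the noise level can be plugged in.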

3. UNDERSTANDING THE TRAINING TARGET IN SCORE-MATCHING OBJECTIVE

The vanilla denoising score-matching objective at time $t$ is:

$$\ell_{\mathrm{DSM}}(\theta, t) = \mathbb{E}_{p_0(x)} \mathbb{E}_{p_{t|0}(x(t)|x)} \left[ \left\| s_\theta(x(t), t) - \nabla_{x(t)} \log p_{t|0}(x(t)|x) \right\|_2^2 \right], \tag{2}$$

where the network is trained to fit the individual targets $\nabla_{x(t)} \log p_{t|0}(x(t)|x)$ at $(x(t), t)$, that is, the "influence" exerted by clean data $x$ on $x(t)$. We can swap the order of the sampling process by first sampling $x(t)$ from $p_t$ and then $x$ from $p_{0|t}(\cdot|x(t))$. Thus, $s_\theta$ has a closed-form minimizer:

$$s^*_{\mathrm{DSM}}(x(t), t) = \mathbb{E}_{p_{0|t}(x|x(t))} \left[ \nabla_{x(t)} \log p_{t|0}(x(t)|x) \right] = \nabla_{x(t)} \log p_t(x(t)). \tag{3}$$

The score field is a conditional expectation of $\nabla_{x(t)} \log p_{t|0}(x(t)|x)$ with respect to the posterior distribution $p_{0|t}$. In practice, a Monte Carlo estimate of this target can have high variance (Owen, 2013; Elvira & Martino, 2021). In particular, when multiple modes of the data distribution have comparable influences on $x(t)$, $p_{0|t}(\cdot|x(t))$ is a multi-modal distribution, as also observed in Xiao et al. (2022). Thus the targets $\nabla_{x(t)} \log p_{t|0}(x(t)|x)$ vary considerably across different $x$, and this can strongly affect the estimated score at $(x(t), t)$, resulting in slower convergence and worse performance in practical stochastic gradient optimization (Wang et al., 2013).

To quantitatively characterize the variations of individual targets at different times, we propose a metric, the average trace-of-covariance of training targets at time $t$:

$$V_{\mathrm{DSM}}(t) = \mathbb{E}_{p_t(x(t))} \left[ \mathrm{Tr}\left( \mathrm{Cov}_{p_{0|t}(x|x(t))} \left( \nabla_{x(t)} \log p_{t|0}(x(t)|x) \right) \right) \right] = \mathbb{E}_{p_t(x(t))} \mathbb{E}_{p_{0|t}(x|x(t))} \left[ \left\| \nabla_{x(t)} \log p_{t|0}(x(t)|x) - \nabla_{x(t)} \log p_t(x(t)) \right\|_2^2 \right].$$

Under the VE SDE, whose transition kernel is $p_{t|0}(x(t)|x) = \mathcal{N}\left(x, \sigma_m^2 (\sigma_M / \sigma_m)^{2t} I\right)$ for some $\sigma_m$ and $\sigma_M$ (Song et al., 2021b), $V_{\mathrm{DSM}}(t)$ exhibits similar phase behavior across $t$ in both toy and realistic cases (Figure 2(b)). Moreover, $V_{\mathrm{DSM}}(t)$ reaches its maximum value in the intermediate phase, demonstrating the large variations of individual targets. We defer more details to Appendix C.
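To make the trace-of-covariance metric tangible, the sketch below estimates $V_{\mathrm{DSM}}$ at a given noise level by Monte Carlo. It assumes that $p_0$ is the uniform distribution over a small finite dataset and that the kernel is $\mathcal{N}(x, \sigma^2 I)$, so the posterior $p_{0|t}$ is a softmax over the data points. The helper names are hypothetical and the setup is a toy illustration, not the paper's evaluation protocol.

```python
import numpy as np

def posterior_weights(x_t, data, sigma):
    """p_{0|t}(x_i | x(t)) over a finite dataset with uniform p_0."""
    logits = -np.sum((x_t - data) ** 2, axis=1) / (2 * sigma**2)
    logits -= logits.max()              # stabilise the softmax
    w = np.exp(logits)
    return w / w.sum()

def v_dsm(data, sigma, num_samples, rng):
    """Monte Carlo estimate of the average trace-of-covariance of DSM targets."""
    total = 0.0
    for _ in range(num_samples):
        x = data[rng.integers(len(data))]                   # x ~ p_0
        x_t = x + sigma * rng.standard_normal(x.shape)      # x(t) ~ p_{t|0}(.|x)
        w = posterior_weights(x_t, data, sigma)
        targets = (data - x_t) / sigma**2                   # one target per data point
        mean = w @ targets                                  # posterior mean = true score
        total += w @ np.sum((targets - mean) ** 2, axis=1)  # Tr(Cov) under the posterior
    return total / num_samples
```

On a dataset of two well-separated points, the estimate is close to zero for very small and very large `sigma` and larger in between, which mirrors the three-phase picture.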

4. TREATING SCORE AS A FIELD

The vanilla denoising score-matching approach (Equation 3) can be viewed as a Monte Carlo estimator, i.e.,

$$\nabla_{x(t)} \log p_t(x(t)) = \mathbb{E}_{p_{0|t}(x|x(t))} \left[ \nabla_{x(t)} \log p_{t|0}(x(t)|x) \right] \approx \frac{1}{n} \sum_{i=1}^{n} \nabla_{x(t)} \log p_{t|0}(x(t)|x_i),$$

where $x_i$ is sampled from $p_{0|t}(\cdot|x(t))$ and $n = 1$. The variance of a Monte Carlo estimator is proportional to $\frac{1}{n}$, so we propose to use a larger batch ($n$) to counter the high variance problem described in Section 3. Since sampling directly from the posterior $p_{0|t}$ is not practical, we first apply importance sampling with the proposal distribution $p_0$. Specifically, we sample a large reference batch $B_L = \{x_i\}_{i=1}^{n} \sim p_0^n$ and get the following approximation:

$$\nabla_{x(t)} \log p_t(x(t)) \approx \frac{1}{n} \sum_{i=1}^{n} \frac{p_{0|t}(x_i|x(t))}{p_0(x_i)} \nabla_{x(t)} \log p_{t|0}(x(t)|x_i).$$

The importance weights can be rewritten as $p_{0|t}(x|x(t)) / p_0(x) = p_{t|0}(x(t)|x) / p_t(x(t))$. However, this basic importance sampling estimator has two issues. The weights now involve an unknown normalization factor $p_t(x(t))$, and the ratio between the prior and posterior distribution can be large in high-dimensional spaces. To remedy these problems, we appeal to self-normalization techniques (Hesterberg, 1995) to further stabilize the training targets:

$$\nabla_{x(t)} \log p_t(x(t)) \approx \sum_{i=1}^{n} \frac{p_{t|0}(x(t)|x_i)}{\sum_{j=1}^{n} p_{t|0}(x(t)|x_j)} \nabla_{x(t)} \log p_{t|0}(x(t)|x_i). \tag{5}$$

We term this new training target in Equation 5 the Stable Target Field (STF). In practice, we sample the reference batch $B_L = \{x_i\}_{i=1}^{n}$ from $p_0^n$ and obtain $x(t)$ by applying the transition kernel to the "first" training data $x_1$. Taken together, the new STF objective becomes:

$$\ell_{\mathrm{STF}}(\theta, t) = \mathbb{E}_{\{x_i\}_{i=1}^{n} \sim p_0^n} \mathbb{E}_{x(t) \sim p_{t|0}(\cdot|x_1)} \left[ \left\| s_\theta(x(t), t) - \sum_{k=1}^{n} \frac{p_{t|0}(x(t)|x_k)}{\sum_{j=1}^{n} p_{t|0}(x(t)|x_j)} \nabla_{x(t)} \log p_{t|0}(x(t)|x_k) \right\|_2^2 \right]. \tag{6}$$

When $n = 1$, STF reduces to the vanilla denoising score-matching objective (Equation 2). When $n > 1$, STF incorporates a reference batch to stabilize training targets.
Intuitively, the new weighted target assigns larger weights to clean data with higher influence on $x(t)$, i.e., higher transition probability $p_{t|0}(x(t)|x)$. Similar to our analysis in Section 3, we can again swap the sampling process in Equation 6 so that, for a perturbation $x(t)$, we sample the reference batch $B_L = \{x_i\}_{i=1}^{n}$ from $p_{0|t}(\cdot|x(t))\, p_0^{n-1}$, where the first element involves the posterior, and the rest follow the data distribution. Thus, the minimizer of the new objective (Equation 6) is (derivation can be found in Appendix B.1)

$$s^*_{\mathrm{STF}}(x(t), t) = \mathbb{E}_{x_1 \sim p_{0|t}(\cdot|x(t))} \mathbb{E}_{\{x_i\}_{i=2}^{n} \sim p_0^{n-1}} \left[ \sum_{k=1}^{n} \frac{p_{t|0}(x(t)|x_k)}{\sum_{j} p_{t|0}(x(t)|x_j)} \nabla_{x(t)} \log p_{t|0}(x(t)|x_k) \right]. \tag{7}$$

Note that although STF significantly reduces the variance, it introduces bias: the minimizer is no longer the true score. Nevertheless, in Section 5, we show that the bias converges to 0 as $n \to \infty$, while reducing the trace-of-covariance of the training targets by a factor of $n$ when $p_{0|t} \approx p_0$. We further instantiate the STF objective (Equation 6) with transition kernels in the form of $p_{t|0}(x(t)|x) = \mathcal{N}(x, \sigma_t^2 I)$, which includes EDM (Karras et al., 2022), VP (through reparameterization) and VE (Song et al., 2021b):

$$\mathbb{E}_{x_1 \sim p_{0|t}(\cdot|x(t))} \mathbb{E}_{\{x_i\}_{i=2}^{n} \sim p_0^{n-1}} \left[ \left\| s_\theta(x(t), t) - \frac{1}{\sigma_t^2} \sum_{k=1}^{n} \frac{\exp\left( -\frac{\|x(t) - x_k\|_2^2}{2\sigma_t^2} \right)}{\sum_{j} \exp\left( -\frac{\|x(t) - x_j\|_2^2}{2\sigma_t^2} \right)} (x_k - x(t)) \right\|_2^2 \right].$$

In the training procedure (Algorithm 1), the stable target field of $B_L$ at each perturbed sample $x_i(t_i)$ is computed as

$$v_{B_L}(x_i(t_i)) = \sum_{x \in B_L} \frac{p_{t_i|0}(x_i(t_i)|x)}{\sum_{y \in B_L} p_{t_i|0}(x_i(t_i)|y)} \nabla_{x_i(t_i)} \log p_{t_i|0}(x_i(t_i)|x),$$

the loss is $L(\theta) = \frac{1}{|B|} \sum_{i=1}^{|B|} \lambda(t_i) \left\| s_\theta(x_i(t_i), t_i) - v_{B_L}(x_i(t_i)) \right\|_2^2$, and the model parameter is updated by $\theta = \theta - \eta \nabla L(\theta)$. The procedure returns $s_\theta$ after the final iteration.
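For the Gaussian kernel $\mathcal{N}(x, \sigma_t^2 I)$, the self-normalized weights are a softmax over negative scaled squared distances, so the STF target can be computed in a few lines. The sketch below is a minimal NumPy illustration under that assumption; `stf_target` is a hypothetical name, and the max-subtraction is a standard numerical-stability step rather than something prescribed by the text.

```python
import numpy as np

def stf_target(x_t, ref_batch, sigma_t):
    """Self-normalised weighted score target for the kernel N(x, sigma_t^2 I).

    x_t:       perturbed sample, shape (d,)
    ref_batch: reference batch B_L, shape (n, d)
    """
    sq_dist = np.sum((x_t - ref_batch) ** 2, axis=1)
    logits = -sq_dist / (2 * sigma_t**2)
    logits -= logits.max()                  # avoid overflow/underflow in exp
    weights = np.exp(logits)
    weights /= weights.sum()                # self-normalisation
    # Weighted average of the conditional scores (x_k - x(t)) / sigma_t^2.
    return (weights @ (ref_batch - x_t)) / sigma_t**2
```

With a reference batch of size one, the function returns the ordinary DSM target $(x_1 - x(t)) / \sigma_t^2$, matching the $n = 1$ case described above.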

5. ANALYSIS

In this section, we analyze the theoretical properties of our approach. In particular, we show that the new minimizer $s^*_{\mathrm{STF}}(x(t), t)$ (Equation 7) converges to the true score asymptotically (Section 5.1). Then, we show that the proposed STF reduces the trace-of-covariance of training targets proportionally to the reference batch size in the intermediate phase, under mild conditions (Section 5.2).

5.1. ASYMPTOTIC BEHAVIOR

Although in general $s^*_{\mathrm{STF}}(x(t), t) \neq \nabla_{x(t)} \log p_t(x(t))$, the bias shrinks toward 0 with an increasing $n$. In the following theorem we show that the minimizer of the STF objective at $(x(t), t)$, i.e., $s^*_{\mathrm{STF}}(x(t), t)$, is asymptotically normal when $n \to \infty$.

Theorem 1. Suppose $\forall t \in [0, 1]$, $0 < \sigma_t < \infty$, then

$$\sqrt{n} \left( s^*_{\mathrm{STF}}(x(t), t) - \nabla_{x(t)} \log p_t(x(t)) \right) \xrightarrow{d} \mathcal{N}\left( 0, \frac{\mathrm{Cov}\left( \nabla_{x(t)} p_{t|0}(x(t)|x) \right)}{p_t(x(t))^2} \right). \tag{8}$$

We defer the proof to Appendix B.2. The theorem states that, for commonly used transition kernels, $s^*_{\mathrm{STF}}(x(t), t) - \nabla_{x(t)} \log p_t(x(t))$ converges to a zero-mean normal, and a larger reference batch size ($n$) leads to smaller asymptotic variance. As can be seen in Equation 8, when $n \to \infty$, $s^*_{\mathrm{STF}}(x(t), t)$ highly concentrates around the true score $\nabla_{x(t)} \log p_t(x(t))$.

5.2. TRACE OF COVARIANCE

We now highlight the small variations of the training targets in the STF objective compared to the DSM. As done in Section 3, we study the trace-of-covariance of training targets in STF:

$$V_{\mathrm{STF}}(t) = \mathbb{E}_{p_t(x(t))} \left[ \mathrm{Tr}\left( \mathrm{Cov}_{p_{0|t}(\cdot|x(t))\, p_0^{n-1}} \left( \sum_{k=1}^{n} \frac{p_{t|0}(x(t)|x_k)}{\sum_{j} p_{t|0}(x(t)|x_j)} \nabla_{x(t)} \log p_{t|0}(x(t)|x_k) \right) \right) \right].$$

In the following theorem we compare $V_{\mathrm{STF}}$ with $V_{\mathrm{DSM}}$. In particular, we upper bound $V_{\mathrm{STF}}(t)$ as follows.

Theorem 2. Suppose $\forall t \in [0, 1]$, $0 < \sigma_t < \infty$, then

$$V_{\mathrm{STF}}(t) \le \frac{1}{n-1} \left( V_{\mathrm{DSM}}(t) + \frac{\sqrt{3}\, d}{\sigma_t^2} \sqrt{ \mathbb{E}_{p_t(x(t))} D_f\left( p_0(x) \,\|\, p_{0|t}(x|x(t)) \right) } \right) + O\left( \frac{1}{n^2} \right),$$

where $D_f$ is an $f$-divergence with

$$f(y) = \begin{cases} (1/y - 1)^2 & (y < 1.5) \\ 8y/27 - 1/3 & (y \ge 1.5). \end{cases}$$

Further, when $n \gg d$ and $p_{0|t}(x|x(t)) \approx p_0(x)$ for all $x(t)$, $V_{\mathrm{STF}}(t) \lessapprox \frac{V_{\mathrm{DSM}}(t)}{n-1}$.

We defer the proof to Appendix B.3. The second term that involves the $f$-divergence $D_f$ is necessary to capture how the coefficients, i.e., $p_{t|0}(x(t)|x_k) / \sum_j p_{t|0}(x(t)|x_j)$ used to calculate the weighted score target, vary across different samples $x(t)$. This term decreases monotonically as a function of $t$. In Phase 1, $p_{0|t}(x|x(t))$ differs substantially from $p_0(x)$ and the divergence term $D_f$ dominates. In contrast to the upper bound, both $V_{\mathrm{STF}}(t)$ and $V_{\mathrm{DSM}}(t)$ have minimal variance at small values of $t$ since the training target is always dominated by one $x$. The theorem has more relevance in Phase 2, where the divergence term decreases to a value comparable to $V_{\mathrm{DSM}}(t)$. In this phase, we empirically observe that the ratio of the two terms in the upper bound ranges from 10 to 100. Thus, when we use a large reference batch size (in thousands), the theorem implies that STF offers a considerably lower variance (by a factor of 10 or more) relative to the DSM objective. In Phase 3, the second term vanishes to 0, as $p_t \approx p_{t|0}$ with large $t$ for commonly used transition kernels. As a result, STF reduces the average trace-of-covariance of the training targets by at least $n - 1$ times in the far field.

Together, we demonstrate that the STF targets have diminishing bias (Theorem 1) and are much more stable during training (Theorem 2). These properties make the STF objective more favorable for diffusion model training with stochastic gradient optimization.
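The piecewise function $f$ in Theorem 2 can be checked numerically. The sketch below implements it and, as an assumption for illustration, a discrete $f$-divergence under the common convention $D_f(P \,\|\, Q) = \sum_i q_i f(p_i / q_i)$ for strictly positive probability vectors; the paper's exact convention is given in its appendix and may differ.

```python
import numpy as np

def f(y):
    """Piecewise generator from Theorem 2; expects y > 0."""
    y = np.asarray(y, dtype=float)
    return np.where(y < 1.5, (1.0 / y - 1.0) ** 2, 8.0 * y / 27.0 - 1.0 / 3.0)

def f_divergence(p, q):
    """Discrete D_f(P || Q) = sum_i q_i * f(p_i / q_i) (assumed convention)."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    return float(np.sum(q * f(p / q)))
```

Both branches give $1/9$ at $y = 1.5$, so $f$ is continuous there, and $f(1) = 0$ makes the divergence vanish when the two distributions coincide.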

6. EXPERIMENTS

In this section, we first empirically validate our theoretical analysis in Section 5, especially for variance reduction in the intermediate phase (Section 6.1). Next, we show that the STF objective improves various diffusion models on image generation tasks in terms of image quality (Section 6.2). In particular, STF achieves state-of-the-art performance on top of EDM. In addition, we demonstrate that STF accelerates the training of diffusion models (Section 6.3), and improves the convergence speed and final performance with an increasing reference batch size (Section 6.4).

6.1. VARIANCE REDUCTION IN THE INTERMEDIATE PHASE

The proposed Algorithm 1 utilizes a large reference batch to calculate the stable target field instead of the individual target. In addition to the theoretical analysis in Section 5, we provide a further empirical study to characterize the intermediate phase and verify the variance reduction effects of STF. Apart from $V(t)$, we also quantify the average divergence between the posterior $p_{0|t}(\cdot|x(t))$ and the data distribution $p_0$ at time $t$ (introduced in Theorem 2):

$$D(t) = \mathbb{E}_{p_t(x(t))} \left[ D_f\left( p_{0|t}(x|x(t)) \,\|\, p_0(x) \right) \right].$$

Intuitively, the number of high-density modes in $p_{0|t}(\cdot|x(t))$ grows as $D(t)$ decreases. To investigate their behaviors, we construct two synthetic datasets: (1) a 64-dimensional mixture of two Gaussian components (Two Gaussians), and (2) a subset of 4096 images of CIFAR-10 (CIFAR-10-4096). Figure 3(a) and Figure 3(b) show the behaviors of $V_{\mathrm{DSM}}(t)$ and $D(t)$ on Two Gaussians and CIFAR-10-4096. In both settings, $V_{\mathrm{DSM}}(t)$ reaches its peak in the intermediate phase (Phase 2), while $D(t)$ gradually decreases over time. These results agree with our theoretical understanding from Section 3. In Phases 2 and 3, several modes of the data distribution have noticeable influences on the scores, but only in Phase 2 are the influences much more distinct, leading to high variations of the individual target $\nabla_{x(t)} \log p_{t|0}(x(t)|x)$, $x \sim p_{0|t}(\cdot|x(t))$.

6.2. IMAGE GENERATION

We demonstrate the effectiveness of the new objective on image generation tasks. We consider the CIFAR-10 (Krizhevsky et al., 2009) and CelebA 64 × 64 (Yang et al., 2015) datasets. We set the reference batch size $n$ to 4096 (CIFAR-10) and 1024 (CelebA $64^2$). We choose the current state-of-the-art score-based method EDM (Karras et al., 2022) as the baseline, and replace the DSM objective with our STF objective during training. We also apply STF to two other popular diffusion models, VE/VP SDEs (Song et al., 2021b). For a fair comparison, we directly adopt the architectures and the hyper-parameters in Karras et al. (2022) and Song et al. (2021b) for EDM and VE/VP respectively. In particular, we use the improved NCSN++/DDPM++ models (Karras et al., 2022) in the EDM scheme.

To highlight the stability issue, we train three models with different seeds for VE on CIFAR-10. We provide more experimental details in Appendix D.1.

Numerical Solver. The reverse-time ODE and SDE in score-based models are compatible with any general-purpose solvers. We use the adaptive RK45 solver (Dormand & Prince, 1980; Song et al., 2021b) for VE/VP and the popular DDIM solver (Song et al., 2021a) for VP. We adopt Heun's 2nd order method (Heun) and the time discretization proposed by Karras et al. (2022) for EDM. For SDEs, we apply the predictor-corrector (PC) sampler used in Song et al. (2021b). We denote the methods in an objective-sampler format, i.e., A-B, where A $\in$ {DSM, STF} and B $\in$ {RK45, PC, DDIM, Heun}. We defer more details to Appendix D.2.

Results. For quantitative evaluation of the generated samples, we report the FID scores (Heusel et al., 2017) (lower is better) and Inception scores (Salimans et al., 2016) (higher is better). We measure the sampling speed by the average NFE (number of function evaluations). We also include the results of several popular generative models (Karras et al., 2020; Ho et al., 2020; Song & Ermon, 2019; Xu et al., 2022b) for reference. Table 1 and Table 2 report the sample quality and the sampling speed on unconditional generation of CIFAR-10 and CelebA $64^2$. Our main findings are: (1) STF achieves new state-of-the-art FID scores for unconditional generation on the CIFAR-10 benchmark. As shown in Table 1, the STF objective obtains an FID of 1.90 when incorporated with the EDM scheme. To the best of our knowledge, this is the lowest FID score on the unconditional CIFAR-10 generation task. In addition, the STF objective consistently improves EDM across the two architectures. (2) The STF objective improves the performance of different diffusion models. We observe that the STF objective improves the FID/Inception scores of VE/VP/EDM on CIFAR-10, for most ODE and SDE samplers.
STF consistently provides performance gains for VE across datasets. Remarkably, our objective achieves much better sample quality using ODE samplers for VE, with an FID score gain of 3.39 on CIFAR-10, and 2.22 on CelebA $64^2$. For VP, STF provides better results on the popular DDIM sampler, while suffering from a slight performance drop when using the RK45 sampler. (3) The STF objective stabilizes the converged VE model with the RK45 sampler. In Appendix E.1, we report the standard deviations of performance metrics for converged models with different seeds on CIFAR-10 with VE. We observe that models trained with the STF objective give more consistent results, with a smaller standard deviation of the metrics used. We further provide generated samples in Appendix F. One interesting observation is that when using the RK45 sampler for VE on CIFAR-10, the generated samples from the STF objective do not contain noisy images, unlike the vanilla DSM objective.

Variance-reduction techniques in neural network training can help to find better optima and achieve faster convergence rates (Wang et al., 2013; Defazio et al., 2014; Johnson & Zhang, 2013). In Figure 4, we report the FID scores every 50k iterations during the course of training. Since our goal is to investigate relative performance during the training process, and because the FID scores computed on 1k samples are strongly correlated with the full FID scores on 50k samples (Song & Ermon, 2020), we report FID scores on 1k samples for faster evaluations. We apply ODE samplers for FID evaluation, and measure the training time on two NVIDIA A100 GPUs. For a fair comparison, we report the average FID scores of models trained by the DSM and STF objectives on VE versus the wall-clock training time (h).

6.4. EFFECTS OF THE REFERENCE BATCH SIZE

According to our theory (Theorem 2), the upper bound of the trace-of-covariance of the STF target decreases proportionally to the reference batch size. Here we study the effects of the reference batch size ($n$) on model performance during training. The FID scores are evaluated on 1k samples using the RK45 sampler. As shown in Figure 5, models converge faster and produce better samples as $n$ increases. This suggests that smaller variations of the training targets can indeed speed up training and improve the final performance of diffusion models.

7. RELATED WORK

Different phases of diffusion models. The idea of diffusion models having different phases has been explored in prior works, though the motivations and definitions vary (Karras et al., 2022; Choi et al., 2022).

Importance sampling. The technique of importance sampling has been widely adopted in the machine learning community, for example in debiasing generative models (Grover et al., 2019), counterfactual learning (Swaminathan & Joachims, 2015) and reinforcement learning (Metelli et al., 2018). Prior works using importance sampling to improve generative model training include reweighted wake-sleep (RWS) (Bornschein & Bengio, 2014) and importance weighted autoencoders (IWAE) (Burda et al., 2015). RWS views the original wake-sleep algorithm (Hinton et al., 1995) as importance sampling with one latent variable, and proposes to sample multiple latents to obtain gradient estimates with lower bias and variance. IWAE utilizes importance sampling with multiple latents to achieve greater flexibility of encoder training and a tighter log-likelihood lower bound compared to the standard variational autoencoder (Kingma & Welling, 2013; Rezende et al., 2014).

Variance reduction for Fisher divergence. One popular approach to score-matching is to minimize the Fisher divergence between true and predicted scores (Hyvärinen & Dayan, 2005). Wang et al. (2020) link the Fisher divergence to denoising score-matching (Vincent, 2011) and study the large variance problem (in $O(1/\sigma_t^4)$) of the Fisher divergence when $t \to 0$. They utilize a control variate to reduce the variance. However, this is typically not a concern for current diffusion models, as the time-dependent objective can be viewed as multiplying the Fisher divergence by $\lambda(t) = \sigma_t^2$, resulting in a finite-variance objective even when $t \to 0$.

8. CONCLUSION

We identify large target variance as a significant training issue affecting diffusion models. We define three phases with distinct behaviors, and show that the high-variance targets appear in the intermediate phase. As a remedy, we present a generalized score-matching objective, Stable Target Field (STF), whose formulation is analogous to self-normalized importance sampling via a large reference batch. Although it is no longer an unbiased estimator, our proposed objective is asymptotically unbiased and reduces the trace-of-covariance of the training targets, which we demonstrate theoretically and empirically. We show the effectiveness of our method on image generation tasks, and show that STF improves the performance, stability, and training speed over various state-of-the-art diffusion models. Future directions include a principled study on the effect of different reference batch sampling procedures. Our approach samples the reference batch uniformly from the whole dataset, $\{x_i\}_{i=2}^{n} \sim p_0^{n-1}$, so we expect that training diffusion models with a reference batch of more samples in the neighborhood of $x_1$ (the sample from which $x(t)$ is perturbed) would lead to an even better estimation of the score field. Moreover, the three-phase analysis can effectively capture the behaviors of other physics-inspired generative models, such as PFGM (Xu et al., 2022b) or the more advanced PFGM++ (Xu et al., 2023). Therefore, we anticipate that STF can enhance the performance and stability of these models further.



Footnote 1: For simplicity, we focus on the version where the diffusion coefficient $g(t)$ is independent of $x(t)$.

Footnote 2: We omit "(0)" from $x(0)$ when there is no ambiguity.



(b)), especially in the intermediate regime,

Figure 1: Illustration of differences between the DSM objective and our proposed STF objective. The "destroyed" images (in blue box) are close to each other while their sources (in red box) are not. Although the true score in expectation is the weighted average of $v_i$, the individual training updates of the DSM objective have a high variance, which our STF objective reduces significantly by including a large reference batch (yellow box).

Figure 2: (a): Illustration of the three phases in a two-mode distribution. (b): Estimated $V_{\mathrm{DSM}}(t)$ for two distributions. We normalize the maximum value to 1 for illustration purposes.

We use $V_{\mathrm{DSM}}(t)$ to define three successive phases relating to the behavior of training targets. As shown in Figure 2(a), the three phases partition the score field into near, intermediate, and far regimes (Phases 1 to 3 respectively). Intuitively, $V_{\mathrm{DSM}}(t)$ peaks in the intermediate phase (Phase 2), where multiple distant modes in the data distribution have comparable influences on the same noisy perturbations, resulting in unstable targets. In Phase 1, the posterior $p_{0|t}$ concentrates around one single mode, hence the low variation. In Phase 3, the targets remain similar across modes since $\lim_{t \to 1} p_{t|0}(x(t)|x) \approx p_1$ for commonly used transition kernels. We validate this argument empirically in Figure 2(b), which shows the estimated $V_{\mathrm{DSM}}(t)$ for a mixture of two Gaussians as well as a subset of the CIFAR-10 dataset (Krizhevsky et al., 2009) for a more realistic setting. Here we use the VE SDE.

Figure 3: (a, b): $V_{\mathrm{DSM}}(t)$ and $D(t)$ versus $t$. We normalize the maximum values to 1 for illustration purposes. (c, d): $V_{\mathrm{STF}}(t)$ with a varying reference batch size $n$.

Figure 3(c) and Figure 3(d) further show the relationship between $V_{\mathrm{STF}}(t)$ and the reference batch size $n$. Recall that when $n = 1$, STF degenerates to the individual target and $V_{\mathrm{STF}}(t) = V_{\mathrm{DSM}}(t)$. We observe that $V_{\mathrm{STF}}(t)$ decreases when enlarging $n$. In particular, the predicted relation $V_{\mathrm{STF}}(t) \lessapprox V_{\mathrm{DSM}}(t)/(n-1)$ in Theorem 2 holds for the Two Gaussians dataset where $D_f$ is small. On the high-dimensional dataset CIFAR-10-4096, the stable target field can still greatly reduce the training target variance with large reference batch sizes $n$.

6.3. ACCELERATING TRAINING OF DIFFUSION MODELS

Figure 4: FID and generated samples throughout training on (a) CIFAR-10 and (b) CelebA $64^2$.

Figure 5: FID scores in the training course with varying reference batch size.

The STF objective achieves better FID scores with the same training time, although the calculation of the target field by the reference batch introduces slight overhead (Algorithm 1). In Figure 4(a), we show that the STF objective drastically accelerates the training of diffusion models on CIFAR-10. The STF objective achieves comparable FID scores with 3.6× less training time (25h versus 90h). For the CelebA $64^2$ dataset, the training time improvement is less significant than on CIFAR-10. Our hypothesis is that the STF objective is more effective when there are multiple well-separated modes in the data distribution, e.g., the ten classes in CIFAR-10, where the DSM objective suffers from relatively larger variations in the intermediate phase. In addition, the converged models have better final performance when paired with STF on both datasets.

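To aggregate the time-dependent STF objective over $t$, we sample the time variable $t$ from the training distribution $q_t$ and apply the weighting function $\lambda(t)$. Together, the final training objective for STF is $\mathbb{E}_{t \sim q_t(t)}[\lambda(t)\, \ell_{\mathrm{STF}}(\theta, t)]$. We summarize the training process in Algorithm 1. The small batch size $|B|$ is the same as the normal batch size in the vanilla training process. We defer specific use cases of STF objectives combined with various popular diffusion models to Appendix A.

Algorithm 1
Input: Training iteration $T$, initial model $s_\theta$, dataset $D$, learning rate $\eta$.
for $t = 1 \ldots T$ do
    Sample a large reference batch $B_L$ from $D$, and subsample a small batch $B = \{x_i\}_{i=1}^{|B|}$
    Uniformly sample the time $\{t_i\}_{i=1}^{|B|} \sim q_t(t)^{|B|}$
    Obtain the batch of perturbed samples $\{x_i(t_i)\}_{i=1}^{|B|}$ by applying the transition kernel $p_{t|0}$ on $B$
    Calculate the stable target field of $B_L$ for all $x_i(t_i)$:
        $v_{B_L}(x_i(t_i)) = \sum_{x \in B_L} \frac{p_{t_i|0}(x_i(t_i)|x)}{\sum_{y \in B_L} p_{t_i|0}(x_i(t_i)|y)} \nabla_{x_i(t_i)} \log p_{t_i|0}(x_i(t_i)|x)$
    Calculate the loss: $L(\theta) = \frac{1}{|B|} \sum_{i=1}^{|B|} \lambda(t_i) \| s_\theta(x_i(t_i), t_i) - v_{B_L}(x_i(t_i)) \|_2^2$
    Update the model parameter: $\theta = \theta - \eta \nabla L(\theta)$
end for
return $s_\theta$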
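The loss computation inside one iteration of Algorithm 1 can be sketched in NumPy as follows. This is an illustration under stated assumptions rather than the authors' implementation: it uses the VE kernel $\mathcal{N}(x, \sigma(t)^2 I)$ with $\sigma(t) = \sigma_m (\sigma_M / \sigma_m)^t$, draws $t$ from $U[0, 1]$, takes $\lambda(t) = \sigma(t)^2$, and picks illustrative default values for `sigma_min` and `sigma_max`. The names `ve_sigma` and `stf_batch_loss` are hypothetical, and the parameter update is left to whichever autodiff framework trains `score_fn`.

```python
import numpy as np

def ve_sigma(t, sigma_min=0.01, sigma_max=50.0):
    """VE noise level sigma_m * (sigma_M / sigma_m)^t (default values are assumptions)."""
    return sigma_min * (sigma_max / sigma_min) ** t

def stf_batch_loss(score_fn, dataset, ref_size, batch_size, rng):
    """Weighted STF loss for one iteration; dataset has shape (N, d)."""
    idx = rng.choice(len(dataset), size=ref_size, replace=False)
    ref = dataset[idx]                            # large reference batch B_L
    batch = ref[:batch_size]                      # small batch B, subsampled from B_L
    t = rng.uniform(0.0, 1.0, size=batch_size)    # t_i ~ q_t = U[0, 1]
    sigma = ve_sigma(t)                           # shape (|B|,)
    x_t = batch + sigma[:, None] * rng.standard_normal(batch.shape)

    # Pairwise squared distances between perturbed samples and B_L: (|B|, |B_L|).
    sq = np.sum((x_t[:, None, :] - ref[None, :, :]) ** 2, axis=-1)
    logits = -sq / (2 * sigma[:, None] ** 2)
    logits -= logits.max(axis=1, keepdims=True)   # stabilise the softmax
    w = np.exp(logits)
    w /= w.sum(axis=1, keepdims=True)

    # Stable target field; uses sum_k w_k (x_k - x(t)) = w @ ref - x(t).
    targets = (w @ ref - x_t) / sigma[:, None] ** 2
    pred = score_fn(x_t, sigma)
    return np.mean(sigma**2 * np.sum((pred - targets) ** 2, axis=1))
```

The pairwise-distance array has $|B| \times |B_L| \times d$ entries before reduction, so for image-sized data one would likely compute it in chunks or through an inner-product expansion; the dense form is kept here for readability.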

Table 1: CIFAR-10 sample quality (FID, Inception) and number of function evaluations (NFE).

Table 2: FID and NFE on CelebA $64^2$.

Karras et al. (2022) argue that the training targets are difficult and unnecessary to learn in the very near field (small $t$ in our Phase 1), whereas the training targets are always dissimilar to the true targets in the intermediate and far field (our Phase 2 and Phase 3). As a result, their solution is sampling $t$ with a log-normal distribution to emphasize the relevant region (relatively large $t$ in our Phase 1). In contrast, we focus on reducing large training target variance in the intermediate and far field, and propose STF to better estimate the true target (cf. Karras et al. (2022)). Choi et al. (2022) identify a key region where the model learns perceptually rich contents, and determine the training weights $\lambda(t)$ based on the signal-to-noise ratio (SNR) at different $t$. As SNR is monotonically decreasing over time, the resulting up-weighted region does not match our Phase 2 characterization. In general, our STF method reduces the training target variance in the intermediate field and is complementary to previous improvements of diffusion models.

ACKNOWLEDGEMENTS

We are grateful to Benson Chen for reviewing an early draft of this paper. We would like to thank Hao He and the anonymous reviewers for their valuable feedback. YX and TJ acknowledge support from MIT-DSTA Singapore collaboration, from NSF Expeditions grant (award 1918839) "Understanding the World Through Code", and from MIT-IBM Grand Challenge project. ST and TJ also acknowledge support from the ML for Pharmaceutical Discovery and Synthesis Consortium (MLPDS).

