STABLE TARGET FIELD FOR REDUCED VARIANCE SCORE ESTIMATION IN DIFFUSION MODELS

Abstract

Diffusion models generate samples by reversing a fixed forward diffusion process. Despite already providing impressive empirical results, these diffusion model algorithms can be further improved by reducing the variance of the training targets in their denoising score-matching objective. We argue that the source of such variance lies in the handling of intermediate noise-variance scales, where multiple modes in the data affect the direction of reverse paths. We propose to remedy the problem by incorporating a reference batch which we use to calculate weighted conditional scores as more stable training targets. We show that the procedure indeed helps in the challenging intermediate regime by reducing (the trace of) the covariance of training targets. The new stable targets can be seen as trading bias for reduced variance, where the bias vanishes with increasing reference batch size. Empirically, we show that the new objective improves the image quality, stability, and training speed of various popular diffusion models across datasets with both general ODE and SDE solvers. When used in combination with EDM (Karras et al., 2022), our method yields a current SOTA FID of 1.90 with 35 network evaluations on the unconditional CIFAR-10 generation task. The code is available at https://github.com/Newbeeer/stf

1. INTRODUCTION

Diffusion models (Sohl-Dickstein et al., 2015; Song & Ermon, 2019; Ho et al., 2020) have recently achieved impressive results on a wide spectrum of generative tasks, such as image generation (Nichol et al., 2022; Song et al., 2021b), 3D point cloud generation (Luo & Hu, 2021), and molecular conformer generation (Shi et al., 2021; Xu et al., 2022a). These models can be subsumed under a unified framework in the form of Itô stochastic differential equations (SDEs) (Song et al., 2021b). The models learn time-dependent score fields via score matching (Hyvärinen & Dayan, 2005), which then guide the reverse SDE during generative sampling. Popular instances of diffusion models include the variance-exploding (VE) and variance-preserving (VP) SDEs (Song et al., 2021b). Building on these formulations, EDM (Karras et al., 2022) provides the best performance to date.

We argue that, despite these impressive empirical results, the current training scheme of diffusion models can be further improved. In particular, the variance of the training targets in the denoising score-matching (DSM) objective can be large and lead to suboptimal performance. To better understand the origin of this instability, we decompose the score field into three regimes. Our analysis shows that the phenomenon arises primarily in the intermediate regime, which is characterized by multiple modes or data points exerting comparable influence on the scores. In other words, in this regime, the sources of the noisy examples generated in the course of the forward process become ambiguous. We illustrate the problem in Figure 1(a), where each stochastic update of the score model is based on disparate targets.

We propose a generalized version of the denoising score-matching objective, termed the Stable Target Field (STF) objective. The idea is to include an additional reference batch of examples that is used to calculate weighted conditional scores as targets.
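For concreteness, the denoising score-matching objective discussed above can be written in its standard form (following the notation of Song et al., 2021b; the weighting function $\lambda(t)$ and the symbol choices here are the usual conventions, stated for completeness rather than taken from this section):

```latex
\mathcal{L}_{\text{DSM}}(\theta)
  = \mathbb{E}_{t}\,\lambda(t)\,
    \mathbb{E}_{x(0) \sim p_0}\,
    \mathbb{E}_{x(t) \mid x(0)}
    \left[
      \big\lVert s_\theta\big(x(t), t\big)
        - \nabla_{x(t)} \log p_{0t}\big(x(t) \mid x(0)\big) \big\rVert_2^2
    \right]
```

The conditional score $\nabla_{x(t)} \log p_{0t}(x(t) \mid x(0))$ is the per-example training target; its variability across the many plausible sources $x(0)$ of a given noisy sample $x(t)$ is exactly the instability described above.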
We apply self-normalized importance sampling to aggregate the contribution of each example in the reference batch. Although this process can substantially reduce the variance of training targets (Figure 1(b)), especially in the intermediate regime, it does introduce some bias. However, we show that both the bias and the trace-of-covariance of the STF training targets shrink to zero as we increase the size of the reference batch. To the best of our knowledge, STF is the first technique to accelerate the training process of diffusion models. We further demonstrate the performance gain with increasing reference batch size, highlighting the negative effect of large variance.

Our contributions are summarized as follows: (1) We detail the instability of the current diffusion model training objective in a principled and quantitative manner, characterizing a region of the forward process, termed the intermediate phase, where the score-learning targets are most variable (Section 3). (2) We propose a generalized score-matching objective, the stable target field, which provides more stable training targets (Section 4). (3) We analyze the behavior of the new objective and prove that it is asymptotically unbiased, and that under mild conditions it reduces the trace-of-covariance of the training targets in the intermediate phase by a factor pertaining to the reference batch size (Section 5). (4) We support the theoretical arguments empirically and show that the proposed STF objective improves the performance, stability, and training speed of score-based methods; in particular, it achieves the current state-of-the-art FID score on the CIFAR-10 benchmark when combined with EDM (Section 6).
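The contrast between the two kinds of targets can be illustrated with a minimal NumPy sketch (this is our own toy illustration, not the paper's implementation: the function names, the VE-style Gaussian perturbation kernel $\mathcal{N}(x_0, \sigma^2 I)$, the two-mode data, and the simplification of drawing sources from the data marginal are all assumptions made for exposition):

```python
import numpy as np

def dsm_target(x_t, x0, sigma):
    """DSM training target: the score of the Gaussian perturbation kernel
    N(x0, sigma^2 I), evaluated at the noisy sample x_t."""
    return (x0 - x_t) / sigma**2

def stf_target(x_t, ref_batch, sigma):
    """STF-style training target: self-normalized importance-weighted average
    of the conditional scores over a reference batch (which, per the method,
    contains the true source of x_t)."""
    # log p_t(x_t | x_j) for each reference example, up to a shared constant
    logits = -np.sum((x_t - ref_batch) ** 2, axis=-1) / (2 * sigma**2)
    w = np.exp(logits - logits.max())
    w /= w.sum()                                   # softmax weights
    cond_scores = (ref_batch - x_t) / sigma**2     # per-example conditional scores
    return w @ cond_scores

# Two-mode toy data: at an intermediate noise scale (sigma comparable to the
# mode separation) a noisy sample is an ambiguous mixture of both modes.
rng = np.random.default_rng(0)
modes = np.array([[-1.0, 0.0], [1.0, 0.0]])
data = modes[rng.integers(2, size=4096)] + 0.05 * rng.standard_normal((4096, 2))
sigma = 1.0
x_t = sigma * rng.standard_normal(2)  # a fixed noisy query point near the origin

# Spread of DSM targets over plausible sources (simplified: sources drawn
# from the data marginal rather than the exact posterior) vs. spread of STF
# targets over resampled reference batches of size 256.
dsm = np.array([dsm_target(x_t, x0, sigma) for x0 in data])
stf = np.array([
    stf_target(x_t,
               np.vstack([x0[None], data[rng.integers(len(data), size=255)]]),
               sigma)
    for x0 in data[:256]
])
tr_dsm = np.trace(np.cov(dsm.T))
tr_stf = np.trace(np.cov(stf.T))
print(f"trace-of-covariance  DSM: {tr_dsm:.3f}  STF: {tr_stf:.5f}")
```

On this toy problem the STF-style targets vary only through the resampled reference batch, so their trace-of-covariance is orders of magnitude below that of the DSM targets, mirroring the variance reduction argued for above.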

2. BACKGROUND ON DIFFUSION MODELS

In diffusion models, the forward process is an SDE with no learned parameters, of the form

$$dx = f(x, t)\,dt + g(t)\,dw,$$

where $x \in \mathbb{R}^d$ with $x(0) \sim p_0$ being the data distribution, $t \in [0, 1]$, $f : \mathbb{R}^d \times [0, 1] \to \mathbb{R}^d$, $g : [0, 1] \to \mathbb{R}$,



For simplicity, we focus on the version where the diffusion coefficient g(t) is independent of x(t).



Figure 1: Illustration of the differences between the DSM objective and our proposed STF objective. The "destroyed" images (in blue box) are close to each other while their sources (in red box) are not. Although the true score in expectation is the weighted average of the $v_i$, the individual training updates of the DSM objective have high variance, which our STF objective reduces significantly by including a large reference batch (yellow box).

and $w \in \mathbb{R}^d$ is the standard Wiener process. The forward process gradually transforms the data distribution into a known prior as time goes from 0 to 1. Sampling from diffusion models is done via the corresponding reverse-time SDE (Anderson, 1982):

$$dx = \left[ f(x, t) - g(t)^2 \nabla_x \log p_t(x) \right] dt + g(t)\,d\bar{w},$$

where $\bar{w}$ is a standard Wiener process running backward in time and $p_t$ denotes the marginal distribution of $x(t)$.
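To make the forward process concrete, the following is a minimal Euler–Maruyama simulation of the SDE above (a sketch under stated assumptions: the VE noise schedule $\sigma(t) = \sigma_{\min}(\sigma_{\max}/\sigma_{\min})^t$ and all parameter values are illustrative choices, not prescribed by this section):

```python
import numpy as np

def euler_maruyama_forward(x0, n_steps=500, sigma_min=0.01, sigma_max=50.0,
                           seed=0):
    """Simulate dx = f(x, t) dt + g(t) dw with the variance-exploding (VE)
    choice f = 0 and g(t) = sigma(t) * sqrt(2 * log(sigma_max / sigma_min)),
    where sigma(t) = sigma_min * (sigma_max / sigma_min)**t.  With this g,
    x(t) given x(0) is distributed as N(x(0), (sigma(t)^2 - sigma_min^2) I).
    """
    rng = np.random.default_rng(seed)
    x = np.array(x0, dtype=float)
    dt = 1.0 / n_steps
    log_ratio = np.log(sigma_max / sigma_min)
    for i in range(n_steps):
        t = i * dt
        sigma_t = sigma_min * (sigma_max / sigma_min) ** t
        g = sigma_t * np.sqrt(2.0 * log_ratio)  # diffusion coefficient g(t)
        # Euler-Maruyama step: drift f = 0, diffusion term g(t) sqrt(dt) N(0, I)
        x = x + g * np.sqrt(dt) * rng.standard_normal(x.shape)
    return x

# Starting from a point mass at 0, x(1) should be close to N(0, sigma_max^2 I),
# i.e. the data is "destroyed" into an (approximately) known prior.
x1 = euler_maruyama_forward(np.zeros(4000))
print(f"empirical std at t=1: {x1.std():.1f} (sigma_max = 50.0)")
```

Because the VE drift is zero, the marginal variance simply accumulates $\int_0^t g(s)^2\,ds = \sigma(t)^2 - \sigma_{\min}^2$, which is why the empirical standard deviation at $t = 1$ lands near $\sigma_{\max}$.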

