f-DM: A MULTI-STAGE DIFFUSION MODEL VIA PROGRESSIVE SIGNAL TRANSFORMATION

Abstract

Diffusion models (DMs) have recently emerged as SoTA tools for generative modeling in various domains. Standard DMs can be viewed as an instantiation of hierarchical variational autoencoders (VAEs) where the latent variables are inferred from input-centered Gaussian distributions with fixed scales and variances. Unlike VAEs, this formulation prevents DMs from changing the latent spaces and learning abstract representations. In this work, we propose f-DM, a generalized family of DMs which allows progressive signal transformation. More precisely, we extend DMs to incorporate a set of (hand-designed or learned) transformations, where the transformed input is the mean of each diffusion step. We propose a generalized formulation of DMs and derive the corresponding denoising objective together with a modified sampling algorithm. As a demonstration, we apply f-DM to image generation tasks with a range of functions, including down-sampling, blurring, and learned transformations based on the encoder of pretrained VAEs. In addition, we identify the importance of adjusting the noise levels whenever the signal is sub-sampled and propose a simple rescaling recipe. f-DM can produce high-quality samples on standard image generation benchmarks such as FFHQ, AFHQ, LSUN, and ImageNet with better efficiency and semantic interpretation. Please check our videos at http://jiataogu.me/fdm/.

Figure 1: Visualization of reverse diffusion from f-DMs with various signal transformations. x_t is the denoised output, and z_s is the input to the next diffusion step. We plot the first three channels of the VQVAE latent variables. Low-resolution images are resized to 256² for ease of visualization.

1. INTRODUCTION

Diffusion probabilistic models (DMs, Sohl-Dickstein et al., 2015; Ho et al., 2020; Nichol & Dhariwal, 2021) and score-based generative models (Song et al., 2021b) have become increasingly popular as tools for high-quality image (Dhariwal & Nichol, 2021), video (Ho et al., 2022b), text-to-speech (Popov et al., 2021), and text-to-image (Rombach et al., 2021; Ramesh et al., 2022; Saharia et al., 2022a) synthesis. Despite this empirical success, conventional DMs are restricted to operating in the ambient space throughout the Gaussian noising process. By contrast, common generative models such as VAEs (Kingma & Welling, 2013) and GANs (Goodfellow et al., 2014; Karras et al., 2021) employ a coarse-to-fine process that hierarchically generates high-resolution outputs.

We are interested in combining the best of the two worlds: the expressivity of DMs and the benefit of hierarchical features. To this end, we propose f-DM, a generalized multi-stage framework of DMs that incorporates progressive transformations of the inputs. As an important property of our formulation, f-DM makes no assumptions about the type of transformation, which makes it compatible with many possible designs, ranging from domain-specific functions to generic neural networks. In this work, we consider three representative types of transformations: down-sampling, blurring, and neural-based transformations. What these functions share in common is that they yield increasingly more global, coarse, and/or compact representations, which we believe can lead to better sampling quality as well as reduced computation.

Incorporating arbitrary transformations into DMs also brings immediate modeling challenges. For instance, certain transformations destroy information drastically, and some also change the dimensionality of the signal. For the former, we derive an interpolation-based formulation to smoothly bridge consecutive transformations.
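The hand-designed transformation families named above (down-sampling and blurring) can be sketched in a few lines of NumPy. This is an illustrative sketch, not the paper's exact configuration: the function names, the 2×2 average pooling, and the Gaussian-blur parameters are all assumptions chosen for clarity. The final lines also illustrate why sub-sampling interacts with the noise level: averaging i.i.d. Gaussian noise shrinks its standard deviation, which motivates the noise-rescaling recipe discussed below.

```python
import numpy as np

def downsample(x, factor=2):
    """Average-pool an (H, W, C) image by `factor` along each spatial axis."""
    h, w, c = x.shape
    x = x.reshape(h // factor, factor, w // factor, factor, c)
    return x.mean(axis=(1, 3))

def gaussian_blur(x, sigma=1.0, radius=3):
    """Separable Gaussian blur of an (H, W, C) image with reflect padding."""
    t = np.arange(-radius, radius + 1)
    k = np.exp(-t**2 / (2 * sigma**2))
    k /= k.sum()
    for axis in (0, 1):  # convolve rows, then columns (separable kernel)
        x = np.apply_along_axis(
            lambda v: np.convolve(np.pad(v, radius, mode="reflect"), k, mode="valid"),
            axis, x)
    return x

# 2x2 average pooling averages 4 i.i.d. noise samples, so the noise std
# drops by a factor of 2 -- sub-sampling changes the effective noise level.
rng = np.random.default_rng(0)
noise = rng.standard_normal((256, 256, 3))
print(noise.std(), downsample(noise).std())  # ~1.0 vs ~0.5
```

Both maps produce progressively coarser signals; only down-sampling changes the dimensionality, which is the case that requires noise rescaling.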
For the latter, we verify the importance of rescaling the noise level, and propose a resolution-agnostic signal-to-noise ratio (SNR) as a practical guideline for noise rescaling.

Extensive experiments are performed on image generation benchmarks, including FFHQ, AFHQ, LSUN Bed/Church, and ImageNet. f-DMs consistently match or outperform the baselines, while requiring relatively less computation thanks to the progressive transformations. Furthermore, given a pre-trained f-DM, we can readily manipulate the learned latent space and perform conditional generation tasks (e.g., super-resolution) without additional training.

2. BACKGROUND

Given a datapoint x ∈ R^N, a DM models time-dependent latent variables z = {z_t | t ∈ [0, 1], z_0 = x} based on a fixed signal-noise schedule {α_t, σ_t}:

q(z_t | z_s) = N(z_t; α_{t|s} z_s, σ²_{t|s} I), where α_{t|s} = α_t / α_s, σ²_{t|s} = σ²_t − α²_{t|s} σ²_s, and s < t.

It also defines the marginal distribution q(z_t | x) as:

q(z_t | x) = N(z_t; α_t x, σ²_t I).

By default, we assume the variance-preserving form (Ho et al., 2020), i.e., α²_t + σ²_t = 1, α_0 = σ_1 = 1, and the signal-to-noise ratio (SNR, α²_t / σ²_t) decreases monotonically with t. For generation, a parametric function x_θ is optimized to reverse the diffusion process by denoising z_t = α_t x + σ_t ε to the clean input x with a weighted reconstruction loss L_θ. For example, the "simple loss" proposed in Ho et al. (2020) is equivalent to weighting the residuals by ω_t = α²_t / σ²_t:

L_θ = E_{z_t ∼ q(z_t|x), t ∼ U[0,1]} [ω_t · ‖x_θ(z_t, t) − x‖²₂].  (1)

In practice, θ is parameterized as a U-Net (Ronneberger et al., 2015). As suggested in Ho et al. (2020), predicting the noise ε_θ empirically achieves better performance than predicting x_θ directly, where x_θ(z_t, t) = (z_t − σ_t · ε_θ(z_t, t)) / α_t.
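The variance-preserving forward process and the simple loss described above can be written out concretely. This is a minimal sketch: the cosine schedule is one common choice of {α_t, σ_t} (Nichol & Dhariwal, 2021) used here only for illustration, and `simple_loss` stands in for the ε-prediction objective with a placeholder prediction rather than a trained U-Net.

```python
import numpy as np

def vp_schedule(t):
    """A cosine variance-preserving schedule: alpha_t^2 + sigma_t^2 = 1,
    with alpha_0 = 1 and sigma_1 = 1."""
    return np.cos(0.5 * np.pi * t), np.sin(0.5 * np.pi * t)

def diffuse(x, t, rng):
    """Sample z_t ~ q(z_t | x) = N(z_t; alpha_t * x, sigma_t^2 * I)."""
    alpha, sigma = vp_schedule(t)
    eps = rng.standard_normal(x.shape)
    return alpha * x + sigma * eps, eps

def simple_loss(eps_pred, eps):
    """Ho et al. (2020)'s "simple loss" on the predicted noise; equivalent to
    weighting the x-reconstruction error by omega_t = alpha_t^2 / sigma_t^2."""
    return float(np.mean((eps_pred - eps) ** 2))

rng = np.random.default_rng(0)
x = rng.standard_normal((64, 64, 3))
z_t, eps = diffuse(x, t=0.5, rng=rng)
alpha, sigma = vp_schedule(0.5)

# SNR = alpha_t^2 / sigma_t^2 decreases monotonically in t (here it is ~1 at t=0.5)
print(alpha**2 / sigma**2)
# Given the true noise, x is recovered as x = (z_t - sigma_t * eps) / alpha_t,
# which is exactly the relation between the eps- and x-parameterizations.
print(np.allclose((z_t - sigma * eps) / alpha, x))  # True
```

The last line mirrors the identity connecting ε-prediction and x-prediction: a network that predicts the noise implicitly defines a denoiser, and vice versa.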



Figure 2: (a) a standard DM; (b) a bottom-up hierarchical VAE; (c) our proposed f-DM.

