DISCRETE PREDICTOR-CORRECTOR DIFFUSION MODELS FOR IMAGE SYNTHESIS

Abstract

We introduce Discrete Predictor-Corrector diffusion models (DPC), extending predictor-corrector samplers in Gaussian diffusion models to the discrete case. Predictor-corrector samplers are a class of samplers for diffusion models, which improve on ancestral samplers by correcting the sampling distribution of intermediate diffusion states using MCMC methods. In DPC, the Langevin corrector, which does not have a direct counterpart in discrete space, is replaced with a discrete MCMC transition defined by a learned corrector kernel. The corrector kernel is trained to make the correction steps achieve asymptotic convergence, in distribution, to the correct marginal of the intermediate diffusion states. Equipped with DPC, we revisit recent transformer-based non-autoregressive generative models through the lens of discrete diffusion, and find that DPC can alleviate the compounding decoding error due to the parallel sampling of visual tokens. Our experiments show that DPC improves upon existing discrete latent space models for class-conditional image generation on ImageNet, and outperforms continuous diffusion models and GANs, according to standard metrics and user preference studies.

1. INTRODUCTION

Generative Adversarial Networks (GANs) are the leading model class for a wide variety of content creation tasks (Goodfellow et al., 2014; Brock et al., 2018; Karras et al., 2020). Recently, however, likelihood-based models, such as diffusion models (Dhariwal & Nichol, 2021; Ho et al., 2020; 2022) and generative transformers (Ramesh et al., 2021; Esser et al., 2021b; Chang et al., 2022), have started rivaling GANs, offering an alternative training paradigm with superior training stability and improved generation diversity. In particular, ADM (Dhariwal & Nichol, 2021) and CDM (Ho et al., 2022) presented diffusion models attaining better perceptual quality on the class-conditional ImageNet benchmark than BigGAN (Brock et al., 2018). Sampling speed, however, is still a bottleneck hindering the practical application of diffusion models. These models can be orders of magnitude slower than GANs, due to the need to take up to hundreds of steps to synthesize a single image during inference. Recently, discrete diffusion has been receiving attention as a promising direction for achieving an improved trade-off between generation quality and efficiency. Like the continuous (Gaussian) diffusion process (Sohl-Dickstein et al., 2015; Song & Ermon, 2019; Ho et al., 2020; Gu et al., 2022), these models incrementally corrupt training data until a known base distribution is reached, and this corruption process is reversed when sampling from the learned model. Unlike continuous diffusion models, the corruption is applied in a latent, possibly low-dimensional, discrete space. The image generation quality of discrete diffusion models is still inferior to that of continuous diffusion models. For example, the state-of-the-art discrete diffusion model (i.e., VQ-Diffusion (Gu et al., 2022)) still notably underperforms CDM (Ho et al., 2022) and BigGAN (Brock et al., 2018) on ImageNet without guidance from external classifiers.
Contemporarily, non-autoregressive transformers (Chang et al., 2022; Gu et al., 2022; Zhang et al., 2021; Lezama et al., 2022) have demonstrated promising performance in both perceptual image quality and efficiency on the ImageNet benchmark. In particular, a non-autoregressive transformer model named MaskGIT (Chang et al., 2022) achieves generation quality comparable to the leading diffusion model ADM on ImageNet, while enjoying two orders-of-magnitude faster inference speed. It reduces the number of sampling steps from the ∼500 required by ADM to only 18, making it possible to generate an image within 0.1 second on a TPU device. Non-autoregressive transformers are trained to predict masked visual tokens, inspired by masked language models such as BERT (Devlin et al., 2019). The decoding process, similar to non-autoregressive sampling techniques from machine translation (Ghazvininejad et al., 2019; Kong et al., 2020), predicts all missing tokens in parallel, starting from a fully masked sequence, and subsequently following an iterative refinement schedule. By viewing non-autoregressive transformers through the lens of discrete diffusion, our analysis yields new insights into a critical issue with these models, namely the compounding decoding error, which causes a mismatch between the inference and training distributions of the intermediate latents produced during the parallel sampling process. To tackle the compounding decoding error, we propose Discrete Predictor-Corrector (DPC) diffusion models that introduce iterative refinement of the intermediate diffusion states. DPC learns a corrector kernel that, coupled with the reverse diffusion predictor, forms a discrete Markov Chain Monte Carlo (MCMC) predictor-corrector algorithm that has the correct marginal distribution of the intermediate latents as its limiting distribution.
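The parallel decoding loop described above can be sketched as follows. This is a minimal, hypothetical illustration of MaskGIT-style confidence-based decoding: the `predictor` function is a random stand-in for the masked-token transformer, and the codebook size, token count, step count, and cosine schedule are illustrative assumptions, not values from this paper.

```python
import numpy as np

rng = np.random.default_rng(0)
K, N, STEPS = 1024, 256, 18  # codebook size, tokens per image, decoding steps (illustrative)

def predictor(x):
    """Stand-in for the masked-token transformer: returns per-position
    probabilities over the codebook (random here, for illustration only)."""
    logits = rng.standard_normal((len(x), K))
    p = np.exp(logits - logits.max(axis=1, keepdims=True))
    return p / p.sum(axis=1, keepdims=True)

x = np.zeros(N, dtype=int)  # 0 denotes MASK; start from a fully masked sequence
for t in range(STEPS):
    probs = predictor(x)
    pred = probs.argmax(axis=1) + 1  # predicted tokens lie in {1, ..., K}
    conf = probs.max(axis=1)
    conf[x > 0] = np.inf             # already-decoded tokens are kept fixed
    # cosine schedule: number of tokens that remain masked after this step
    n_mask = int(np.floor(N * np.cos(0.5 * np.pi * (t + 1) / STEPS)))
    # commit all predictions in parallel, then re-mask the least confident ones
    x = np.where(x > 0, x, pred)
    if n_mask > 0:
        x[np.argsort(conf)[:n_mask]] = 0

assert (x > 0).all()  # all N tokens decoded after STEPS parallel steps
```

Note that every step conditions on tokens committed in earlier steps; a confidently wrong early commitment is never revisited, which is exactly the compounding decoding error that DPC's corrector steps are designed to mitigate.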
The proposed DPC model is a new type of discrete diffusion model that is distinct from both non-autoregressive transformers and conventional discrete diffusion models. Compared to non-autoregressive transformers (Chang et al., 2022; Lezama et al., 2022; Gu et al., 2022), DPC introduces new techniques to perform multi-step correction of intermediate states in the sampling process. Furthermore, it provides a theoretical underpinning for this model class, and shows that, in the ideal case, it can completely correct the train/test mismatch between the distributions of intermediate states. DPC advances conventional discrete diffusion by introducing a new discrete MCMC correction kernel, which can be considered a discrete analogue to the Langevin corrector proposed by Song et al. (2021) for the continuous case. We empirically validate that DPC achieves a good quality-vs-efficiency trade-off on two tasks: class-conditional image generation on the ImageNet dataset and unconditional generation on the Places2 (Zhou et al., 2017) dataset. The results show that DPC performs favorably against both non-autoregressive transformers (Chang et al., 2022; Lezama et al., 2022; Esser et al., 2021b) and discrete diffusion baselines (Gu et al., 2022) while maintaining a fast inference speed. In particular, without the help of external classifiers, DPC outperforms the state-of-the-art continuous diffusion models (i.e., ADM (Dhariwal & Nichol, 2021) and CDM (Ho et al., 2022)) in FID (Heusel et al., 2017), thereby establishing a new state of the art on the high-resolution (512×512) image synthesis task on ImageNet. When leveraging an external pre-trained classifier and upsampling, DPC produces state-of-the-art class-conditional generation, yielding better Inception Score (IS) and FID than state-of-the-art GANs (e.g., StyleGAN-XL (Sauer et al., 2022)) and continuous diffusion models (e.g., ADM (Dhariwal & Nichol, 2021)).
Furthermore, we present user preference studies that confirm the perceptual quality provided by DPC.

2. BACKGROUND AND PROBLEM STATEMENT

2.1 DISCRETE DIFFUSION MODELS

Let x_0 ∈ {1, . . . , K}^N be a vector of discrete data, such as an image or a set of image tokens obtained by Vector-Quantized (VQ) encoding with a dictionary of K elements. Using a discrete diffusion process q(x_{t+1}|x_t), we can sample a sequence of latent variables x_1, x_2, . . . , x_T such that the final latent x_T has a simple, known, and fixed distribution p(x_T). Starting from a sample from p(x_T), we can then reverse this process using a learned reverse diffusion model p_θ(x_{t-1}|x_t), eventually producing a sample x_0. One instance of such a discrete diffusion process q is the absorbing state diffusion process from (Austin et al., 2021), in which we set x_t = x_0 ⊙ m_t, where m_t is a vector of binary masks that starts out as all ones, m_0 = 1, and ends up as all zeros, m_T = 0. In between, we gradually evolve m_t by randomly setting more and more of its elements to zero according to m_t ∼ q(m_t|m_{t-1}). The reverse diffusion model p_θ(x_{t-1}|x_t, m_t) can be constructed by first sampling a new mask q(m_{t-1}|m_t), and then sampling new values for x_{t-1} corresponding to those entries of m_{t-1} that are newly unmasked, i.e. sampling x^i_{t-1} ∼ p_θ(x^i_{t-1}|x_t) for those i for which m^i_{t-1} = 1 and m^i_t = 0.
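The forward (corruption) half of the absorbing-state process can be sketched as follows. This is a minimal illustration under assumed values: the codebook size K, token count N, step count T, and the cosine masking schedule are hypothetical choices, not specified by the text above; MASK is encoded as token 0 so that x_t = x_0 ⊙ m_t is a simple elementwise product.

```python
import numpy as np

rng = np.random.default_rng(0)

K = 1024  # codebook size (hypothetical)
N = 256   # number of tokens, e.g. a 16x16 VQ grid (hypothetical)
T = 18    # number of diffusion steps (hypothetical)

# x_0: discrete tokens in {1, ..., K}; the absorbing MASK state is encoded as 0.
x0 = rng.integers(1, K + 1, size=N)

def visible_fraction(t, T):
    """Fraction of tokens still unmasked at step t (cosine schedule, an assumed choice)."""
    return np.cos(0.5 * np.pi * t / T)

# Forward process: m_0 = 1 everywhere, m_T = 0 everywhere, and x_t = x_0 * m_t.
m = np.ones(N, dtype=int)
states = [x0 * m]
for t in range(1, T + 1):
    keep = int(round(visible_fraction(t, T) * N))
    visible = np.flatnonzero(m)
    # randomly absorb (mask) tokens until only `keep` remain visible
    drop = rng.choice(visible, size=max(len(visible) - keep, 0), replace=False)
    m[drop] = 0
    states.append(x0 * m)

assert states[-1].sum() == 0  # x_T is fully masked, i.e. the known base distribution
```

The reverse model p_θ(x_{t-1}|x_t, m_t) then walks this trajectory backwards, sampling a less-masked m_{t-1} and filling in token values only at the newly unmasked positions.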

