DISCRETE PREDICTOR-CORRECTOR DIFFUSION MODELS FOR IMAGE SYNTHESIS

Abstract

We introduce Discrete Predictor-Corrector diffusion models (DPC), extending predictor-corrector samplers in Gaussian diffusion models to the discrete case. Predictor-corrector samplers are a class of samplers for diffusion models, which improve on ancestral samplers by correcting the sampling distribution of intermediate diffusion states using MCMC methods. In DPC, the Langevin corrector, which does not have a direct counterpart in discrete space, is replaced with a discrete MCMC transition defined by a learned corrector kernel. The corrector kernel is trained to make the correction steps achieve asymptotic convergence, in distribution, to the correct marginal of the intermediate diffusion states. Equipped with DPC, we revisit recent transformer-based non-autoregressive generative models through the lens of discrete diffusion, and find that DPC can alleviate the compounding decoding error due to the parallel sampling of visual tokens. Our experiments show that DPC improves upon existing discrete latent space models for class-conditional image generation on ImageNet, and outperforms continuous diffusion models and GANs, according to standard metrics and user preference studies.

1. INTRODUCTION

Generative Adversarial Networks (GANs) are the leading model class for a wide variety of content creation tasks (Goodfellow et al., 2014; Brock et al., 2018; Karras et al., 2020). Recently, however, likelihood-based models, such as diffusion models (Dhariwal & Nichol, 2021; Ho et al., 2020; 2022) and generative transformers (Ramesh et al., 2021; Esser et al., 2021b; Chang et al., 2022), have started rivaling GANs, offering an alternative training paradigm with superior training stability and improved generation diversity. In particular, ADM (Dhariwal & Nichol, 2021) and CDM (Ho et al., 2022) presented diffusion models attaining better perceptual quality on the class-conditional ImageNet benchmark than BigGAN (Brock et al., 2018). Sampling speed, however, remains a bottleneck hindering the practical application of diffusion models: they can be orders of magnitude slower than GANs, needing up to hundreds of steps to synthesize a single image at inference time. Recently, discrete diffusion has been receiving attention as a promising direction for achieving an improved trade-off between generation quality and efficiency. Like the continuous (Gaussian) diffusion process (Sohl-Dickstein et al., 2015; Song & Ermon, 2019; Ho et al., 2020; Gu et al., 2022), these models incrementally corrupt training data until a known base distribution is reached, and this corruption process is reversed when sampling from the learned model. Unlike continuous diffusion models, the corruption is applied in a latent, possibly low-dimensional, discrete space. The image generation quality of discrete diffusion models is still inferior to that of continuous diffusion models. For example, the state-of-the-art discrete diffusion model, VQ-Diffusion (Gu et al., 2022), still notably underperforms CDM (Ho et al., 2022) and BigGAN (Brock et al., 2018) on ImageNet without guidance from external classifiers.
Contemporarily, non-autoregressive transformers (Chang et al., 2022; Gu et al., 2022; Zhang et al., 2021; Lezama et al., 2022) have demonstrated promising performance in both perceptual image quality and efficiency on the ImageNet benchmark. In particular, a non-autoregressive transformer named MaskGIT (Chang et al., 2022) achieves generation quality comparable to the leading diffusion model ADM on ImageNet, while enjoying two orders-of-magnitude faster inference: it reduces ADM's ∼500 sampling steps to only 18, making it possible to generate an image within 0.1 seconds on a TPU device. Non-autoregressive transformers are trained to predict masked visual tokens, inspired by masked language models such as BERT (Devlin et al., 2019). The decoding process, similar to non-autoregressive sampling techniques from machine translation (Ghazvininejad et al., 2019; Kong et al., 2020), predicts all missing tokens in parallel, starting from a fully masked sequence and then following an iterative refinement schedule. By viewing non-autoregressive transformers through the lens of discrete diffusion, our analysis yields new insights into a critical issue with these models, namely the compounding decoding error, which causes a mismatch between the inference and training distributions of the intermediate latents produced during the parallel sampling process. To tackle the compounding decoding error, we propose Discrete Predictor-Corrector (DPC) diffusion models, which introduce iterative refinement of the intermediate diffusion states. DPC learns a corrector kernel that, coupled with the reverse diffusion predictor, forms a discrete Markov Chain Monte Carlo (MCMC) predictor-corrector algorithm that has the correct marginal distribution of the intermediate latents as its limiting distribution.
The proposed DPC model is a new type of discrete diffusion model, distinct from both non-autoregressive transformers and conventional discrete diffusion models. Compared to non-autoregressive transformers (Chang et al., 2022; Lezama et al., 2022; Gu et al., 2022), DPC introduces new techniques to perform multi-step correction of intermediate states in the sampling process. Furthermore, it provides a theoretical underpinning for this model class, and shows that, in the ideal case, it can completely correct the train/test mismatch between the distributions of intermediate states. DPC advances conventional discrete diffusion by introducing a new discrete MCMC correction kernel, which can be considered a discrete analogue of the Langevin corrector proposed by Song et al. (2021) for the continuous case. We empirically validate that DPC achieves a good quality-vs-efficiency trade-off on two tasks: class-conditional image generation on the ImageNet dataset and unconditional generation on the Places2 (Zhou et al., 2017) dataset. The results show that DPC performs favorably against both non-autoregressive transformers (Chang et al., 2022; Lezama et al., 2022; Esser et al., 2021b) and discrete diffusion baselines (Gu et al., 2022) while maintaining fast inference. In particular, without the help of external classifiers, DPC outperforms the state-of-the-art continuous diffusion models (i.e., ADM (Dhariwal & Nichol, 2021) and CDM (Ho et al., 2022)) in FID (Heusel et al., 2017), thereby establishing a new state of the art on the high-resolution (512×512) image synthesis task on ImageNet. When leveraging an external pre-trained classifier and upsampling, DPC produces state-of-the-art class-conditional generation, yielding better Inception Score (IS) and FID than state-of-the-art GANs (e.g., StyleGAN-XL (Sauer et al., 2022)) and continuous diffusion models (e.g., ADM (Dhariwal & Nichol, 2021)).
Furthermore, we present user preference studies that confirm the perceptual quality provided by DPC.

2.1. DISCRETE DIFFUSION MODELS

Let $x_0 \in \{1, \ldots, K\}^N$ be a vector of discrete data, such as an image or a set of image tokens obtained by Vector-Quantized (VQ) encoding with a dictionary of $K$ elements. Using a discrete diffusion process $q(x_{t+1} \mid x_t)$, we can sample a sequence of latent variables $x_1, x_2, \ldots, x_T$ such that the final latent $x_T$ has a simple known and fixed distribution $p(x_T)$. Starting from a sample from $p(x_T)$, we can then reverse this process using a learned reverse diffusion model $p_\theta(x_{t-1} \mid x_t)$, eventually producing a sample $x_0$. One instance of such a discrete diffusion process $q$ is the absorbing-state diffusion process from Austin et al. (2021), in which we set $x_t = x_0 \odot m_t$, where $m_t$ is a vector of binary masks that starts out as all ones, $m_0 = \mathbf{1}$, and ends up as all zeros, $m_T = \mathbf{0}$. In between, we gradually evolve $m_t$ by randomly setting more and more of its elements to zero according to $m_t \sim q(m_t \mid m_{t-1})$. The reverse diffusion model $p_\theta(x_{t-1} \mid x_t, m_t)$ can be constructed by first sampling a new mask $m_{t-1} \sim q(m_{t-1} \mid m_t)$, and then sampling new values for $x_{t-1}$ at the newly unmasked positions, i.e. sampling $x^i_{t-1} \sim p_\theta(x^i_{t-1} \mid x_t)$ for those $i$ for which $m^i_{t-1} = 1$ and $m^i_t = 0$.
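The absorbing-state forward process above can be sketched in a few lines of numpy. This is a minimal illustration under our own conventions (token 0 reserved for [MASK], exact rounding of the mask count); the function name and shapes are ours, not the paper's.

```python
import numpy as np

def forward_mask_diffusion(x0, gamma_t, rng):
    """One draw of the absorbing-state forward process x_t = x0 * m_t.

    x0      : (N,) int array of tokens in {1..K}; 0 is reserved for [MASK].
    gamma_t : fraction of positions masked at time t (0 at t=0, 1 at t=T).
    Returns the corrupted x_t and the binary keep-mask m_t.
    """
    N = x0.shape[0]
    # Absorb exactly round(gamma_t * N) randomly chosen positions into [MASK].
    n_masked = int(round(gamma_t * N))
    masked_idx = rng.choice(N, size=n_masked, replace=False)
    m_t = np.ones(N, dtype=np.int64)
    m_t[masked_idx] = 0
    return x0 * m_t, m_t

rng = np.random.default_rng(0)
x0 = rng.integers(1, 1025, size=1024)   # e.g. a 32x32 grid of 10-bit tokens
x_t, m_t = forward_mask_diffusion(x0, gamma_t=0.75, rng=rng)
```

Masked positions carry the absorbing token 0, while unmasked positions keep their original values, matching $x_t = x_0 \odot m_t$.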

2.2. ORDER-AGNOSTIC AUTOREGRESSIVE MODELS

When sampling from a discrete diffusion model with a sufficiently large number of steps $T$, only a single element of $x$ is masked/unmasked in each transition, so we only need to sample a single $x^i_{t-1}$ at each step of the reverse diffusion process. In this case, $p_\theta$ is an order-agnostic autoregressive model, as introduced by Uria et al. (2014) and discussed in more detail by Hoogeboom et al. (2022). As shown in Uria et al. (2014), these models can be trained efficiently by maximizing the likelihood $\sum_{i : m^i_t = 0} \log p_\theta(x^i_{t-1} = x^i_0 \mid x_t, m_t)$, where the sum is over all possible orderings of the data that are consistent with $m_t$: this has lower variance than just maximizing the likelihood for the single ordering that was actually sampled by $q(\cdot)$, but is otherwise equivalent. Alternatively, it can be interpreted as denoising the corrupted data $x_t$ towards the original clean data $x_0$, trained by maximizing the log-likelihood $\log \bar{p}_\theta(x_0 \mid x_t, m_t)$ for the whole vector $x_0$ at once, under a one-step model $\bar{p}_\theta$ defined by $\bar{p}_\theta(x_0 \mid x_t, m_t) = \prod_i p_\theta(x^i_0 \mid x_t, m_t)$, where we use the $\bar{p}_\theta$ notation to distinguish this factorized one-step denoising model from the full multi-step generative model $p_\theta(x_0) = \sum_{x_{t>0}} \prod_t p_\theta(x_{t-1} \mid x_t)$. In the following we use $\bar{p}_\theta$ for all distributions affected by using the one-step model in the generative process.
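The factorized denoising objective can be made concrete with a small numpy sketch (our own minimal illustration; the `log_softmax` helper, toy shapes, and 0-indexed tokens are assumptions, not the paper's implementation).

```python
import numpy as np

def log_softmax(z):
    # Numerically stable log-softmax over the last axis.
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def one_step_denoising_loss(logits, x0, m_t):
    """Cross-entropy of the factorized one-step model, summed over the
    masked positions only: -sum_{i : m_t^i = 0} log p_theta(x0^i | x_t, m_t).

    logits : (N, K) unnormalized token scores from the denoiser.
    x0     : (N,) ground-truth tokens (0-indexed here for simplicity).
    m_t    : (N,) keep-mask; loss is computed only where m_t == 0.
    """
    logp = log_softmax(logits)
    token_logp = logp[np.arange(len(x0)), x0]
    return float(-(token_logp * (m_t == 0)).sum())

# Uniform logits over K=4 tokens: each of the two masked positions
# contributes log 4 to the loss.
loss = one_step_denoising_loss(np.zeros((3, 4)), np.array([0, 1, 2]),
                               np.array([1, 0, 0]))
```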

2.3. NON-AUTOREGRESSIVE TRANSFORMERS FOR IMAGE SYNTHESIS

Transformer-based models generate images in two stages (Esser et al., 2021b; Ramesh et al., 2021). First, the image is quantized into a grid of discrete tokens by a VQ-autoencoder (van den Oord et al., 2017). In the second stage, an autoregressive transformer decoder (Vaswani et al., 2017; Chen et al., 2020a) is learned on the flattened token sequence to generate image tokens sequentially using autoregressive decoding. Finally, the generated codes are mapped back to pixel space using the VQ-decoder learned in the first stage (Esser et al., 2021b). Recently, non-autoregressive transformers (Chang et al., 2022; Gu et al., 2022; Zhang et al., 2021) have been proposed to improve the second stage, adapted from non-autoregressive machine translation (Ghazvininejad et al., 2019; Kong et al., 2020). An example of such a model is the Masked Generative Image Transformer (MaskGIT) (Chang et al., 2022), which follows the masked modeling of BERT (Devlin et al., 2019) (equation 2). During decoding, MaskGIT starts with all image tokens masked out. In each inference step, it uses parallel decoding, i.e. it predicts all tokens simultaneously while only keeping the ones with the highest prediction scores. The remaining tokens are masked out and re-predicted in the next iteration. The mask ratio decreases according to a cosine function, so that all tokens are generated within a few refinement iterations.
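The parallel decoding loop described above can be sketched as follows. This is a minimal sketch, not MaskGIT's actual implementation: `denoise_logits` is a stand-in for the trained transformer, confidence is taken as the argmax probability, and the keep count follows a cosine curve that reaches the full sequence at the final step.

```python
import numpy as np

def maskgit_decode(denoise_logits, N, steps=8):
    """MaskGIT-style parallel decoding (sketch).

    denoise_logits : stand-in for the transformer, (x, m) -> (N, K) logits
                     over tokens {1..K}. Decoding starts fully masked
                     (token 0); each step predicts every token in parallel,
                     keeps the most confident per the cosine schedule, and
                     re-masks the rest for the next iteration.
    """
    x = np.zeros(N, dtype=np.int64)              # all [MASK]
    m = np.zeros(N, dtype=np.int64)              # 1 = committed token
    for s in range(1, steps + 1):
        logits = denoise_logits(x, m)
        z = np.exp(logits - logits.max(-1, keepdims=True))
        probs = z / z.sum(-1, keepdims=True)
        pred = probs.argmax(-1) + 1              # tokens live in {1..K}
        conf = probs.max(-1)
        conf[m == 1] = np.inf                    # committed tokens stay kept
        n_keep = int(np.ceil((1.0 - np.cos(s / steps * np.pi / 2)) * N))
        keep = np.argsort(-conf)[:n_keep]
        new_m = np.zeros(N, dtype=np.int64)
        new_m[keep] = 1
        x = np.where(new_m == 1, np.where(m == 1, x, pred), 0)
        m = new_m
    return x

# Toy stand-in denoiser that always prefers token (i % 4) + 1 at position i:
def toy_logits(x, m):
    L = np.zeros((12, 4))
    L[np.arange(12), np.arange(12) % 4] = 5.0
    return L

tokens = maskgit_decode(toy_logits, N=12, steps=6)
```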

2.4. COMPOUNDING DECODING ERROR IN NON-AUTOREGRESSIVE TRANSFORMERS

To sample from an autoregressive model, one would typically sample one element $i$ at a time, which is costly when the total number of elements $N$ in $x$ is large. By ignoring the dependencies between the $x^i_{t-1}$, one can accelerate this process by sampling several $x^i_{t-1}$ in parallel at each transition. This idea has been explored in discrete diffusion models (Gu et al., 2022; Hoogeboom et al., 2022), and is exploited by non-autoregressive generative transformers (Chang et al., 2022). As suggested in prior work (Austin et al., 2021; Gu et al., 2022; Savinov et al., 2021; Lezama et al., 2022), there is a close relation between non-autoregressive transformers and discrete diffusion models. For example, the masked-modeling training of the state-of-the-art MaskGIT (Chang et al., 2022) can be modeled by the discrete diffusion process with the absorbing state ([MASK]) (Austin et al., 2021). MaskGIT's parallel sampling may then be understood as a reverse diffusion process using the one-step model in (2). However, parallel sampling introduces errors, as the factorized $\bar{p}_\theta$ does not exactly match the original generative model $p_\theta$ for the joint distribution of the elements of $x$. As a result, the marginal inference distribution $\bar{p}_\theta(x_t)$ deviates from the training distribution $q(x_t)$ as sampling progresses. We refer to this as a compounding decoding error, as small differences between $q(x_{t-1} \mid x_t)$ and $\bar{p}_\theta(x_{t-1} \mid x_t)$ can accumulate into large differences between $q(x_t)$ and $\bar{p}_\theta(x_t)$ after many sampling steps (cf. the teacher-forcing error in autoregressive models).
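A two-token toy case makes the factorization error tangible. This illustration is ours, not the paper's: it shows only why parallel (factorized) sampling can place mass on joint configurations that have probability zero under the data distribution.

```python
import numpy as np

# Toy data distribution: two perfectly correlated binary tokens,
# q(x1=x2=0) = q(x1=x2=1) = 0.5. The ideal denoiser's per-token marginals
# are p(x_i=1) = 0.5, so sampling both tokens in parallel from the
# factorized model puts mass 0.25 on each of the four combinations --
# half of it on (0,1) and (1,0), which have probability zero under q.
# Sampling the tokens one at a time makes no such error.
rng = np.random.default_rng(0)
parallel = rng.integers(0, 2, size=(100_000, 2))       # factorized sampling
mismatch = (parallel[:, 0] != parallel[:, 1]).mean()   # ~0.5 under the
                                                       # factorized model,
                                                       # exactly 0 under q
```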

2.5. PREDICTOR-CORRECTOR SAMPLERS

Predictor-corrector samplers were proposed by Song et al. (2021) for use in Gaussian diffusion models: before applying each time transition in the generative model $p_\theta(x_{t-1} \mid x_t)$, these samplers correct the distribution of $x_t$ by applying one or more steps of Langevin MCMC at constant timestep $t$, thereby bringing the distribution of $x_t$ closer to $q(x_t)$. The corrector has the form $x_t \leftarrow x_t + \epsilon_t\, s_\theta(x_t, t) + \sqrt{2\epsilon_t}\, z$, with $z \sim \mathcal{N}(0, I)$, where $s_\theta(x_t, t) \approx \nabla_x \log q_t(x_t)$ is the learned score function and $\epsilon_t$ is a step size.
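The Langevin corrector update can be sketched directly from the formula. The sanity check below uses a toy target $q = \mathcal{N}(0, I)$, whose score $-x$ we plug in by hand; this is our own illustration, not the samplers of Song et al. (2021).

```python
import numpy as np

def langevin_corrector_step(x, score_fn, t, eps, rng):
    """One Langevin corrector step:
    x <- x + eps * s_theta(x, t) + sqrt(2 * eps) * z,  z ~ N(0, I)."""
    z = rng.standard_normal(x.shape)
    return x + eps * score_fn(x, t) + np.sqrt(2.0 * eps) * z

# Sanity check: for q = N(0, 1) the score is s(x) = -x; repeated corrector
# steps pull 10,000 independent chains from x = 5 toward the target.
rng = np.random.default_rng(0)
x = np.full(10_000, 5.0)
for _ in range(2_000):
    x = langevin_corrector_step(x, lambda v, t: -v, t=0, eps=0.01, rng=rng)
```

After burn-in the empirical mean and variance are close to the target's 0 and 1 (the finite step size leaves a small $O(\epsilon)$ bias).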

3. DISCRETE PREDICTOR-CORRECTOR SAMPLERS

In this section, we address the critical issue of compounding decoding error due to parallel sampling in non-autoregressive transformers. We propose to mitigate the gap between $q(x_t)$ and $\bar{p}_\theta(x_t)$ by applying a learned discrete corrector MCMC step that iteratively improves the intermediate diffusion states. This allows our models to exploit the computational advantages of parallel sampling in non-autoregressive transformers, while maintaining high synthesis quality. To reduce the compounding decoding error in $\bar{p}_\theta(x_t)$, we adjust the sampled $x_t$ to more closely resemble samples from $q(x_t)$ before applying the time transition from $x_t$ to $x_{t-1}$. In the discrete diffusion case, however, there is no direct counterpart to the score function that gives the rate of change required to improve the likelihood of a sample (Campbell et al., 2022). Instead, we propose to utilize an MCMC corrector step of the following form:

$\hat{x}_0 \sim \bar{p}_\theta(\hat{x}_0 \mid x_t, m_t)$, (3)
$x_t \sim p_\phi(x_t \mid \hat{x}_0, t)$, (4)

where $p_\phi$ is a learned corrector distribution, trained so that when applying this procedure multiple times, the distribution of $x_t$ converges to $q(x_t)$, making this a valid corrector kernel (Section 3.1). Since $x_t = \hat{x}_0 \odot m_t$, the corrector distribution $p_\phi$ can equivalently be stated in terms of the mask $m_t$, which is resampled at every corrector step, $m_t \sim p_\phi(m_t \mid \hat{x}_0, t)$. The mask is a collection of binary variables, of which a known number $k$ are equal to one at a given time step $t$, and the others are zero. Thus, the corrector distribution $p_\phi(m_t \mid \hat{x}_0, t)$ is a categorical distribution over all possible binary masks with $k$ non-zero elements.
We model this distribution using the Plackett-Luce model (Plackett, 1975), which gives

$p_\phi(m_t \mid \hat{x}_0, t) = \sum_{c^k(m_t)} \prod_{i=1}^{k} \frac{\exp[l^{c^k_i}_\phi(\hat{x}_0, t)]}{\sum_{j=i}^{k} \exp[l^{c^k_j}_\phi(\hat{x}_0, t)] + \sum_{l : m^l_t = 0} \exp[l^l_\phi(\hat{x}_0, t)]}$, (5)

where $l_\phi(\hat{x}_0, t)$ is the $N$-dimensional output of a neural network corrector model, with elements $l^i_\phi(\hat{x}_0, t)$ representing the relative log-odds (logits) of $m^i_t$ being positive for the given $\hat{x}_0$, and where we sum over all possible permutations $c^k$ of the indices of the $k$ non-zero elements of the mask $m_t$; in numpy notation, c_k = np.permute(np.nonzero(m_t)). We provide more background on this model in Appendix A. In practice, we can sample from the mask model by forming perturbed logits $\tilde{l}^i = l^i_\phi(\hat{x}_0, t) + G^i$, where $G^i \sim \mathrm{Gumbel}(0, 1)$, and setting the mask to 1 for the top-$k$ elements of $\tilde{l}$, with $k$ determined by $t$. This is known as the Gumbel-top-$k$ trick (Kool et al., 2019).
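The Gumbel-top-$k$ sampling step is a few lines of numpy. This is a minimal sketch under our own conventions; the `beta` argument anticipates the Gumbel scale used later as a sampling temperature, and the function name is ours.

```python
import numpy as np

def sample_mask_gumbel_topk(logits, k, beta=1.0, rng=None):
    """Sample a k-hot mask from the Plackett-Luce model over masks via the
    Gumbel-top-k trick: perturb each logit with i.i.d. Gumbel(0, beta)
    noise and set the mask to 1 at the top-k perturbed positions."""
    rng = rng or np.random.default_rng(0)
    g = rng.gumbel(loc=0.0, scale=beta, size=logits.shape)
    topk = np.argsort(-(logits + g))[:k]
    m = np.zeros(logits.shape[0], dtype=np.int64)
    m[topk] = 1
    return m

# With a low temperature, positions whose logits dominate by a wide margin
# are selected essentially deterministically.
logits = np.array([10.0, 10.0, 10.0, 0.0, 0.0, 0.0])
m = sample_mask_gumbel_topk(logits, k=3, beta=0.1)
```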

3.1. DISCRETE PREDICTOR-CORRECTOR TRAINING

The learned corrector function $p_\phi$ is trained to achieve detailed balance between the MCMC corrector steps (3)-(4) and the training marginal $q(x_t)$. Let $J(x_t' \mid x_t)$ be the transition kernel of the Markov chain given by (3) and (4). Then, this Markov chain has $q(x_t)$ as its stationary distribution if

$p_\phi(x_t' \mid \hat{x}_0) = \frac{\bar{p}_\theta(\hat{x}_0 \mid x_t')\, q(x_t')}{Z(\hat{x}_0)}$, with $Z(\hat{x}_0) = \sum_{x_t'} \bar{p}_\theta(\hat{x}_0 \mid x_t')\, q(x_t')$. (6)

Assuming (6), the detailed balance condition of $J(x_t' \mid x_t)$ with respect to $q(x_t)$ follows from:

$q(x_t)\, J(x_t' \mid x_t) = q(x_t) \sum_{\hat{x}_0} p_\phi(x_t' \mid \hat{x}_0)\, \bar{p}_\theta(\hat{x}_0 \mid x_t)$ (7)
$= q(x_t) \sum_{\hat{x}_0} \frac{\bar{p}_\theta(\hat{x}_0 \mid x_t')\, q(x_t')}{Z(\hat{x}_0)}\, \bar{p}_\theta(\hat{x}_0 \mid x_t)$ (8)
$= q(x_t)\, q(x_t') \sum_{\hat{x}_0} \frac{\bar{p}_\theta(\hat{x}_0 \mid x_t')\, \bar{p}_\theta(\hat{x}_0 \mid x_t)}{Z(\hat{x}_0)} = q(x_t')\, J(x_t \mid x_t')$,

where the last step is derived by applying the same operations in the other direction. The goal of training the corrector kernel is then to obtain $p_\phi(x_t \mid \hat{x}_0) \propto \bar{p}_\theta(\hat{x}_0 \mid x_t)\, q(x_t)$ (equation 6), which can be done by minimizing the KL divergence

$\mathrm{KL}\!\left(\bar{p}_\theta(\hat{x}_0 \mid x_t)\, q(x_t) \,\big\|\, p_\phi(x_t \mid \hat{x}_0)\, Z(\hat{x}_0)\right) = -\mathbb{E}_{(\hat{x}_0, x_t) \sim \bar{p}_\theta(\hat{x}_0 \mid x_t)\, q(x_t)} \log p_\phi(x_t \mid \hat{x}_0) + C$, (9)

where $C$ collects terms that do not depend on $\phi$. Importantly, the denoiser $\bar{p}_\theta$ is pre-trained and its parameters are frozen during the training of the corrector $p_\phi$. Thus, the corrector distribution can be trained by first obtaining a masked sample from the training distribution $q(x_t)$, and then using the denoising model $\bar{p}_\theta(\hat{x}_0 \mid x_t)$ (with frozen weights $\theta$) to obtain a corresponding $\hat{x}_0$, after which $p_\phi(x_t \mid \hat{x}_0)$ can be trained by maximum likelihood. Since $x_t$ is completely determined by $\hat{x}_0$ and the mask $m_t$, this is equivalent to training $p_\phi$ to maximize the likelihood of the sampled mask $m_t$. Under our chosen model in (5), evaluating the likelihood for all mask elements jointly would require summing over all orderings consistent with the mask, which is computationally expensive.
A simpler training objective is binary cross-entropy applied independently to each mask element, using $p_\phi(m^i_t = 1 \mid \hat{x}_0, t) \approx \sigma(l^i_\phi(\hat{x}_0, t))$, where $\sigma$ is the sigmoid function. In our experiments we find this simplified loss to be very effective. We discuss this choice in greater depth in Appendix A. Following MaskGIT, we use the cosine schedule for determining the masking ratio during both training and inference. Let $T$ denote the time horizon. For $S$ sampling iterations, we set the masking ratio to $\gamma(t) = 1 - \cos\!\left(\frac{t}{S} \cdot \frac{\pi}{2}\right)$, for $t \in \{1, \ldots, T/S\}$. During training of both the generator $p_\theta$ and the corrector $p_\phi$, we first sample a time $t \sim \mathcal{U}[0, 1]$ before using it to compute $\gamma(t)$. Likewise, during inference, we compute $\gamma(t)$ from the current decoding step. In practice, this scheduling gives more weight to the transitions with lower masking rates, in a similar fashion to the cosine weighting scheme of other diffusion models (e.g., Hoogeboom et al. (2021)).
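The cosine masking-ratio schedule can be written down directly. A minimal sketch, assuming the normalized-time convention implied by the text ($t \in [0, 1]$, sampled uniformly at training and set to the decoding progress at inference):

```python
import numpy as np

def cosine_mask_ratio(t):
    """gamma(t) = 1 - cos(pi/2 * t) for normalized time t in [0, 1]:
    nothing is masked at t = 0 and everything at t = 1. Because gamma
    changes slowly near t = 0, uniformly sampled times concentrate
    training on transitions with low masking rates."""
    return 1.0 - np.cos(np.pi / 2.0 * t)
```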

3.2. EFFICIENT IMPLEMENTATION USING SHORTCUT TIME TRANSITIONS

Our predictor-corrector sampler consists of taking one predictor step, $x_{t-1} \sim p_\theta(x_{t-1} \mid x_t)$ (step 1), followed by one or multiple corrector steps: $\hat{x}_0 \sim \bar{p}_\theta(\hat{x}_0 \mid x_{t-1})$ (step 2) followed by $x_{t-1} \sim p_\phi(x_{t-1} \mid \hat{x}_0)$ (step 3). Under the exact diffusion distribution $q$, we have that $\sum_{x_{t-1}} q(x_0 \mid x_{t-1})\, q(x_{t-1} \mid x_t) = q(x_0 \mid x_t)$. Since our model $\bar{p}_\theta(x_0 \mid x_t)$ is trained to approximate $q(x_0 \mid x_t)$ as well as possible, we then also have, to a close approximation, $\sum_{x_{t-1}} \bar{p}_\theta(\hat{x}_0 \mid x_{t-1})\, p_\theta(x_{t-1} \mid x_t) \approx \bar{p}_\theta(\hat{x}_0 \mid x_t)$. Hence, steps 1 and 2 of our predictor-corrector sampler can be combined into a single step $\hat{x}_0 \sim \bar{p}_\theta(\hat{x}_0 \mid x_t)$. Using this result, the shortcut time transition becomes $\hat{x}_0 \sim \bar{p}_\theta(\hat{x}_0 \mid x_t)$ followed by $x_{t-1} \sim p_\phi(x_{t-1} \mid \hat{x}_0)$. This avoids one sampling step and offers a more efficient way of implementing DPC. In the special case of a single corrector step, this implementation of DPC also closely resembles the non-autoregressive transformer Token-Critic (Lezama et al., 2022).

4) GAN models: BigGAN (Brock et al., 2018) and StyleGAN-XL (Sauer et al., 2022). We use FID (Heusel et al., 2017), Inception Score (IS) (Salimans et al., 2016), and Precision vs. Recall (Kynkäänniemi et al., 2019) to evaluate perceptual quality. We also report the number of neural function evaluations (NFE) required by each method. ImageNet generation is evaluated against the training set, and Places2 against the validation set, following Chang et al. (2022). Numbers for the baseline models are quoted from their respective papers, except for MaskGIT, for which we used a provided model. In addition, we conduct user preference studies to further evaluate the perceptual quality and diversity of the models' samples.
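The shortcut time transition of Section 3.2 can be sketched as follows. This is our own minimal sketch, not the paper's implementation: `denoise_fn` and `corrector_logits_fn` are stand-ins for the trained transformers, and the sketch folds the repeated corrector steps and the final mask size `k_next` into one loop.

```python
import numpy as np

def dpc_shortcut_transition(x_t, denoise_fn, corrector_logits_fn, k_next,
                            n_corr=1, beta=0.6, rng=None):
    """Shortcut predictor-corrector transition: sample x0_hat ~ p_theta(.|x),
    then resample the mask with the learned corrector p_phi via the
    Gumbel-top-k trick; repeat n_corr times."""
    rng = rng or np.random.default_rng(0)
    x = x_t
    for _ in range(n_corr):
        x0_hat = denoise_fn(x)                  # x0_hat ~ p_theta(. | x)
        l = corrector_logits_fn(x0_hat)         # keep-logits l_phi(x0_hat, t)
        g = rng.gumbel(scale=beta, size=l.shape)
        keep = np.argsort(-(l + g))[:k_next]
        m = np.zeros(len(x0_hat), dtype=np.int64)
        m[keep] = 1
        x = x0_hat * m                          # x_{t-1} = x0_hat * m_{t-1}
    return x, m

# Toy stand-ins: the "denoiser" always predicts tokens 1..8, the "corrector"
# strongly prefers to keep the first four positions.
denoise = lambda x: np.arange(1, 9)
keep_logits = lambda x0: np.where(np.arange(8) < 4, 10.0, 0.0)
x_next, m_next = dpc_shortcut_transition(np.zeros(8, dtype=np.int64),
                                         denoise, keep_logits, k_next=4,
                                         n_corr=3, beta=0.01)
```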
Refinement Schedule We evaluate two versions of DPC, varying the number of correction steps $c(t)$ at each intermediate state $t$, where we take $c(t)$ to include the corrector step in the shortcut time transition (Section 3.2):
• DPC-full(C): $c(t) = C$.
• DPC-light(C): $c(t) = \min\!\left(s_C(t, \tau_1),\, s_C(T - t, T - \tau_2)\right)$, with $s_C(t, \tau) = \max(1,\, C + \min(0,\, t - \tau))$ a slanted step function, making $c(t)$ a trapezoid with the top vertices at $t = \tau_1$ and $t = \tau_2$.
In DPC-light, $c(t)$ concentrates more correction steps around $t = \tau_1$ to $t = \tau_2$. To further reduce the cost, we stop the reverse process at $t = T - 5$. Table 5 presents results varying $C$, and shows that more DPC correction steps effectively lead to improved sampling performance. Ablation studies for $\tau_1$, $\tau_2$ and $\beta$ are included in Appendix B.
Upsampling in Discrete Latent Space Following the success of cascaded upsampling approaches in continuous diffusion models (Dhariwal & Nichol, 2021; Ho et al., 2022), we experiment with a cascaded super-resolution stage in the discrete latent space, and observe significant improvements compared to single-resolution modeling (Table 2). Specifically, we train an upsampling discrete denoising model to model the sequence of high-resolution visual tokens, conditioned on the low-resolution sequence. The upsampling model is trained with the objective of (2), using a cosine masking schedule. We refer to Appendix C for further details on the upsampling stage.
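The DPC-light trapezoid schedule can be written down directly from its definition (a minimal sketch; the function name is ours):

```python
def correction_steps(t, T, C, tau1, tau2):
    """Number of corrector steps at time t for DPC-light: a trapezoid that
    ramps from 1 up to C, with top vertices at t = tau1 and t = tau2.
    Uses the slanted step s_C(t, tau) = max(1, C + min(0, t - tau))."""
    s = lambda u, tau: max(1, C + min(0, u - tau))
    return min(s(t, tau1), s(T - t, T - tau2))
```

With the paper's defaults ($T = 18$, $C = 5$, $\tau_1 = 4$, $\tau_2 = 5$), the schedule spends a single corrector step at the ends of the chain and the full $C$ steps on the plateau between $\tau_1$ and $\tau_2$.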

Implementation Details

The input images are discretized into a set of 10-bit integers using a VQGAN encoder (Esser et al., 2021b), e.g. a 512×512 image is represented as a grid of 32×32 integer indices over a codebook of 1024 elements. For class-conditional generation, we prepend a class token to the sequence of visual tokens. The denoiser $\bar{p}_\theta$ and corrector $p_\phi$ are transformer models in which, following MaskGIT (Chang et al., 2022) and Token-Critic (Lezama et al., 2022), the denoiser has 24 layers and 16 heads. Importantly, we used the same denoiser $\bar{p}_\theta$ for MaskGIT, Token-Critic and the proposed DPC, trained using maximum likelihood of the one-step model in (2). Thus, the comparison between these methods is restricted to the sampling scheme. The corrector is a 20-layer, 12-head transformer. The denoiser was trained for 600 epochs and the corrector for 300 epochs on 32 TPU v4 chips with batch size 256. The VQGAN encoder was trained at 256×256 resolution on the same datasets. The upsampling stage is a 16-layer transformer and was trained for 300 epochs to denoise the encodings of 512×512 images conditioned on the encodings of ground-truth 256×256 images. For DPC, unless otherwise noted, we use $T = 18$ diffusion steps, $C = 5$, and $\tau_1 = 4$, $\tau_2 = 5$. We control the sampling temperature of $p_\phi$ with the scale parameter $\beta$ of the Gumbel noise, which is set to $\beta = 0.6$ for DPC-full and $\beta = 0.5$ for DPC-light.

4.2. MAIN RESULTS

Quantitative Evaluation Table 1 presents the quantitative comparison for class-conditional generation on ImageNet. To compare the base modeling capacity of each method, no external pre-trained classifiers or upsampling stages are used during training or sampling. As shown, the proposed DPC yields the best IS and FID scores at both 256×256 and 512×512 resolutions. DPC-light(5) achieves a good balance between fidelity and the number of function evaluations (corrections + time transitions). For example, DPC-light(5) outperforms the previous best discrete diffusion model, VQ-Diffusion, by a large margin in FID while using fewer diffusion steps. It improves the quality of the non-autoregressive transformers MaskGIT and Token-Critic with a reasonable increase in NFE (36), which is still less than that of the continuous diffusion models (250). We show DPC's performance for unconditional generation on the Places2 dataset in Table 4. In Table 2, we examine the comparison when class-conditional generation is aided by pre-trained classifiers used as gradient guidance (G) or rejection samplers (R), or when leveraging an upsampling stage (U) for cascaded generation. Using external classifiers improves generation quality but makes it difficult to compare models, given the distinct classifier architectures being used, such as DeiT (Touvron et al., 2021) in StyleGAN-XL, a UNet with CLIP attention (Radford et al., 2021) in ADM, and ResNet (He et al., 2016) in MaskGIT. Nevertheless, we present results for this setup by using a rejection sampling scheme based on a ResNet-50 classifier pre-trained on ImageNet, with a specified acceptance rate of 25%.

Table 3: Proportion of times our DPC-light(5) model is preferred over other state-of-the-art class-conditional generation models in the user studies conducted on the ImageNet (512×512) benchmark.
We apply rejection sampling to samples generated at 256×256 resolution and then apply 6 upsampling decoding iterations to obtain the final 512×512 samples. For reference, ADM+G+U (Dhariwal & Nichol, 2021) uses 128×128 generation followed by upsampling to 512×512. DPC achieves the best FID and Inception Score under this setup among the models we consider, and is highly competitive when using only generation and upsampling.
User Preference Study To further understand the perceptual differences between the compared models, we conduct two user preference studies on Amazon Mechanical Turk to verify the visual quality and diversity of our model's samples. We compare our DPC-light(5), without classifier rejection or upsampling, to StyleGAN-XL, ADM with guidance and upsampling, and MaskGIT and Token-Critic without classifier rejection. The compared images are generated using the public models of StyleGAN-XL, ADM, and MaskGIT, obtained from their official websites. For the quality test, we present two randomly sampled images of the same class side by side (one from DPC-light and one from the compared model), and ask the graders to select the one that looks more realistic. For the diversity test, we show graders two sets of six randomly sampled images each, for four different classes, and ask them to select the set that looks more diverse. The quality and diversity tests are carried out on each of the 1,000 ImageNet classes, with 14 different graders, for a total of 14,000 comparisons. The comparison to each method is organized into 250 task groups. As shown in Table 3, DPC is preferred for quality more often than all the compared methods. It is noteworthy that the compared models in Table 3 represent the best published class-conditional models on the ImageNet benchmark, including GAN, diffusion, and transformer models. The differences in the quality test are statistically significant at the p-value level of 0.05.
DPC also appears to be preferred for diversity over StyleGAN-XL and ADM+G+U.

Qualitative Evaluation

In Figure 1, we show random samples from our ImageNet 512×512 DPC-light(5) model, and from the publicly available implementations of StyleGAN-XL and ADM with classifier guidance and upsampling. These are the same models used in the user studies. We refer to Appendix D and the supplementary material for more comprehensive qualitative comparisons.

Diffusion Models

The most prominent type of diffusion model operating entirely in continuous space is the Gaussian diffusion model (Sohl-Dickstein et al., 2015; Song & Ermon, 2019; Ho et al., 2020; Song et al., 2021; Kingma et al., 2021; Tzen & Raginsky, 2019), which is formulated as a forward and a learned reverse process that are both parameterized as conditional Gaussian distributions. These models have attained high-quality generation results in image and audio tasks (Dhariwal & Nichol, 2021; Ho et al., 2022; Saharia et al., 2021b;a; Nichol et al., 2021; Ramesh et al., 2022; Chen et al., 2020b; Kong et al., 2021b; Lugmayr et al., 2022).
Discrete Diffusion Recently, D3PM (Austin et al., 2021) and OA-ARDM (Hoogeboom et al., 2022) applied discrete diffusion (Sohl-Dickstein et al., 2015; Hoogeboom et al., 2021) to image modeling, focusing on density estimation and compression of raw image pixels. Campbell et al. (2022) introduce a framework to model discrete diffusion in continuous time, and demonstrate that a combination of the denoising model and the forward noising process is a valid corrector. In contrast to Campbell et al. (2022), our learned corrector can directly target the most unlikely tokens, achieving efficient sampling with one order of magnitude fewer steps. While these methods limit their analysis to low-resolution images (e.g., 32×32), VQ-Diffusion (Gu et al., 2022) is the state-of-the-art discrete diffusion model for ImageNet image synthesis, also employing a mask-and-replace diffusion strategy to predict missing VQGAN tokens iteratively. Different from our model, VQ-Diffusion may still suffer from the compounding decoding error discussed in this paper. Our results show that our model outperforms VQ-Diffusion in both quality and efficiency on the ImageNet benchmark, while further scaling discrete diffusion models from 256×256 to 512×512 resolution.
Non-autoregressive transformers for image synthesis While early works (van den Oord et al., 2016; Salimans et al., 2017; Parmar et al., 2018; Chen et al., 2020a) modeled images directly in pixel space, various recent works (Ramesh et al., 2021; Esser et al., 2021b;a; Ding et al., 2022) use autoregressive transformers over the discrete latent space provided by a VQ-VAE or VQGAN. Closest to our work is a class of non-autoregressive transformers (Chang et al., 2022; Zhang et al., 2021; Lezama et al., 2022; Kong et al., 2021a) that have recently demonstrated improved quality and efficiency over conventional autoregressive transformer models such as that of Esser et al. (2021b). There are two major differences from these non-autoregressive transformers. First, we propose new techniques to perform multi-step correction which, as shown in our experiments, are essential for achieving a desirable quality-vs-efficiency trade-off. Second, we show that, in the ideal case, the proposed method can correct the train/test mismatch in the distributions of intermediate states. Finally, we further demonstrate the applicability of cascaded upsampling in the discrete latent space.

6. CONCLUSION

Parallel sampling from generative transformers dramatically improves efficiency compared to full autoregressive sampling. However, from the perspective of discrete diffusion models, it exacerbates deviations from the ideal sampling distribution that accumulate in the marginal distributions of intermediate diffusion states. To mitigate this compounding decoding error, we proposed DPC, a new discrete diffusion model based on a learned MCMC corrector kernel that refines the samples of these intermediate states. Empirically, we demonstrated that the repeated application of the learned corrector improves the samples of a non-autoregressive transformer, as measured on the standard ImageNet class-conditional generation task. User preference studies showed that DPC is competitive with state-of-the-art generative vision transformers, GANs, and continuous diffusion models.

A LEARNING THE MASK MODEL

Let $\{l^i\}_{i=1\ldots N}$ be a set of $N$ logits, given by the output of a neural network taking $\hat{x}_0, t$ as input, i.e. $l^i = l^i_\phi(\hat{x}_0, t)$. We perturb these logits by adding i.i.d. Gumbel noise, i.e. $\tilde{l}^i = l^i + G^i$ with $G^i \sim \mathrm{Gumbel}(0, 1)$, and we rank-order these perturbed logits to get a ranking $c$. Then the Plackett-Luce model (Plackett, 1975) states that the probability of obtaining an ordering $c$ is given by

$$p(c \mid \hat{x}_0) = \prod_{i=1}^{N} \frac{\exp(l^{c_i})}{\sum_{j=i}^{N} \exp(l^{c_j})}.$$

If we only care about the top-$k$ ranking $c^k = \{c_1, \ldots, c_k\}$, then the probability of that partial ranking can be obtained as

$$p(c^k \mid \hat{x}_0) = \prod_{i=1}^{k} \frac{\exp(l^{c^k_i})}{\sum_{j=i}^{k} \exp(l^{c^k_j}) + R}, \quad \text{with } R = \sum_{j=k+1}^{N} \exp(l^{c_j}).$$

We define our mask $m_t$ to be a vector with one boolean element $m^i_t$ for each element $\tilde{l}^i$. If $\tilde{l}^i$ is among the $k_t$ largest elements we set $m^i_t = 1$, and otherwise we set $m^i_t = 0$. Using numpy notation, we now have np.sort(c_k) = np.nonzero(m_t). For sampling a mask $m_t$ we do not care about the exact order $c^k$, but only about the elements in this top-$k$ ordering. The probability of sampling a mask $m_t$ can thus be found by summing over all the partial rankings $c^k$ that are consistent with this mask:

$$p(m_t \mid \hat{x}_0) = \sum_{c^k(m_t)} \prod_{i=1}^{k} \frac{\exp(l^{c^k_i})}{\sum_{j=i}^{k} \exp(l^{c^k_j}) + R} \tag{13}$$
$$= \exp\Big(\sum_{i:\, m^i_t = 1} l^i\Big) \sum_{c^k(m_t)} \prod_{i=1}^{k} \frac{1}{\sum_{j=i}^{k} \exp(l^{c^k_j}) + R}, \quad \text{with } R = \sum_{l:\, m^l_t = 0} \exp\big(l^l_\phi(\hat{x}_0, t)\big).$$

Evaluating the mask likelihood in (13) requires summing over all partial orderings $c^k$ consistent with the mask $m_t$, which is computationally expensive. Next we discuss two alternative formulations that are more computationally efficient. Under the Plackett-Luce model, the logits $l^i, l^j$ encode the relative log-odds of $i$ ending higher than $j$ in the sampled ranking $c$, that is, $p(\tilde{l}^i > \tilde{l}^j) = \sigma(l^i - l^j)$. (When applied to pairs of elements like this, the Plackett-Luce model is also known as the Bradley-Terry model (Bradley & Terry, 1952; Hunter, 2004; Huang et al., 2006).)
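The Gumbel-perturbation sampling procedure described at the start of this section can be sketched in a few lines of NumPy. This is an illustrative sketch, not the paper's implementation; the function name and shapes are our own.

```python
import numpy as np

def sample_topk_mask(logits, k, rng):
    """Draw a top-k mask from the Plackett-Luce model by perturbing the
    logits with i.i.d. Gumbel(0, 1) noise and keeping the k largest."""
    gumbel = -np.log(-np.log(rng.uniform(size=logits.shape)))
    perturbed = logits + gumbel
    # Indices of the k largest perturbed logits form the partial ranking c_k;
    # the mask records only membership in the top-k, discarding the order.
    topk = np.argpartition(-perturbed, k)[:k]
    mask = np.zeros(logits.shape, dtype=bool)
    mask[topk] = True
    return mask

rng = np.random.default_rng(0)
mask = sample_topk_mask(np.zeros(16), 4, rng)
assert mask.sum() == 4  # exactly k_t elements are masked
```

Taking the top-k of Gumbel-perturbed logits is exactly a draw of a partial ranking from the Plackett-Luce distribution, which is why this simple procedure is consistent with the likelihoods above.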
Thus, one alternative is to consider the relative scores of masked and unmasked elements, which yields the following pairwise objective:

$$\mathcal{L}_{\text{pairwise}} = \mathbb{E}_{\substack{(x_t, m_t) \sim q(x_t, m_t) \\ \hat{x}_0 \sim p_\theta(\hat{x}_0 \mid x_t)}} \; \sum_{i:\, m^i_t = 1} \; \sum_{j:\, m^j_t = 0} -\log \sigma\big(l^i_\phi(\hat{x}_0, t) - l^j_\phi(\hat{x}_0, t)\big), \tag{15}$$

where the outer sum runs over the $k$ masked elements, the inner sum over the $N - k$ unmasked elements, and $\sigma$ is the sigmoid function. Another, simpler loss function, which we found to work well in practice, is binary cross-entropy applied independently to each mask element. The reasoning here is that, if the cutoff between ending in the top-$k$ and ending outside of it is represented by $\bar{l}$, we may reasonably approximate the probability of ending in the top-$k$ as $p(i \in c^k) \approx \sigma(l^i - \bar{l})$. Furthermore, since the Plackett-Luce distribution is invariant to shifts in the logits $l^i$, the absolute level of $\bar{l}$ can be chosen at will. We therefore choose $\bar{l} = 0$, resulting in $p_\phi(m^i_t = 1 \mid \hat{x}_0, t) \approx \sigma\big(l^i_\phi(\hat{x}_0, t)\big)$, which we fit against the marginal distribution of observed mask elements $q(m^i_t)$ using maximum likelihood:

$$\mathcal{L}_{\text{factorized}} = \mathbb{E}_{\substack{(x_t, m_t) \sim q(x_t, m_t) \\ \hat{x}_0 \sim p_\theta(\hat{x}_0 \mid x_t)}} \; -\sum_{i=1}^{N} \Big[ m^i_t \cdot \log \sigma\big(l^i_\phi(\hat{x}_0, t)\big) + (1 - m^i_t) \cdot \log\big(1 - \sigma\big(l^i_\phi(\hat{x}_0, t)\big)\big) \Big]. \tag{16}$$

Table 6 shows a quantitative comparison of the corrector models trained with $\mathcal{L}_{\text{pairwise}}$ and $\mathcal{L}_{\text{factorized}}$. We train both models for the same number of training steps (200 epochs), and sweep the sampling hyperparameter β for the best FID. We found the factorized loss to be more effective, and we use this loss in all the experiments reported in the main manuscript.

Model       FID ↓   IS ↑    Prec ↑  Rec ↑
factorized  6.41    209.5   0.78    0.45
pairwise    10.89   231.6   0.77    0.38

Table 6: Performance comparison between corrector models trained with the factorized loss (16) and the pairwise loss (15), for the same number of training steps (200 epochs), on ImageNet 512×512. In both cases we use the DPC-full(5) sampling schedule.
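For reference, the factorized objective can be written, for a single sample, as a plain per-element binary cross-entropy. The sketch below is illustrative, not the training code used in the paper.

```python
import numpy as np

def factorized_loss(logits, mask):
    """Per-element binary cross-entropy with p(m_i = 1) = sigmoid(l_i),
    i.e. the top-k cutoff fixed at 0, for a single sample."""
    p = 1.0 / (1.0 + np.exp(-logits))  # sigmoid of the corrector logits
    eps = 1e-12                        # avoid log(0)
    return -np.sum(mask * np.log(p + eps) + (1 - mask) * np.log(1 - p + eps))

# Confident, correct logits give a near-zero loss.
loss = factorized_loss(np.array([8.0, -8.0]), np.array([1.0, 0.0]))
assert loss < 1e-2
```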

B ABLATION STUDIES

In Figure 2 we show the resulting FID and Inception Score when varying the sampling temperature parameter β. We observe that the resulting sampling quality, as measured by FID and Inception Score, is highly sensitive to the value of β, and that the optimal β value also depends on the number of correction steps C and the refinement schedule c(t). In Figure 3, we present an ablation study on the shape of the trapezoidal function c(t) for the DPC-light model. We vary the starting point of the top side of the trapezoid, τ1, and its width, τ2 - τ1. We observe slightly better metrics when concentrating the trapezoid around the first steps of the decoding process.
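For concreteness, one possible parameterization of a trapezoidal schedule c(t) is sketched below. The exact parameterization used in the paper is not specified in this appendix, so the linear ramps here are an assumption; only the τ1/τ2 plateau structure comes from the text.

```python
def trapezoid_schedule(t, T, tau1, tau2, c_max):
    """Hypothetical trapezoidal corrector schedule: ramp up to c_max at tau1,
    hold until tau2, then ramp back down to zero at t = T."""
    if t <= tau1:
        return round(c_max * t / max(tau1, 1))
    if t <= tau2:
        return c_max
    return round(c_max * (T - t) / max(T - tau2, 1))

# The plateau gets the full number of correction steps; the ends get none.
assert trapezoid_schedule(0, 18, 3, 7, 5) == 0
assert trapezoid_schedule(5, 18, 3, 7, 5) == 5
assert trapezoid_schedule(18, 18, 3, 7, 5) == 0
```

Shifting τ1 and τ2 toward small t concentrates correction steps at the start of decoding, which is the regime the ablation found beneficial for FID.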

C UPSAMPLING IN DISCRETE SPACE

Figure 3: Ablation study for the refinement schedule parameter τ1. We show the effect of varying the location of the top-left vertex of the trapezoid, τ1, and its width, τ2 - τ1, on FID and Inception Score, for DPC-light(5). We observe that concentrating more refinement steps towards the beginning of the decoding process is beneficial for FID, and around t = 4 for Inception Score. In all experiments the number of diffusion steps is T = 18.

We perform upsampling in the discrete latent space by learning to generate the sequence of visual tokens corresponding to a high-resolution image $x_0^H$, given the sequence of visual tokens from a low-resolution downsampled version $x_0^L$. For the experiments in the paper we use the 512×512 and 256×256 resolutions, and use the same VQ-encoder for both resolutions, yielding a sequence of visual tokens at each resolution. The upsampler $p_u$ is trained with the masked token prediction objective

$$\mathcal{L}_{\text{upsampling}} = \mathbb{E}_{(x_0^L, x_0^H) \sim q}\Big[-\log \prod_i p_u\big(x_0^H(i) \mid x_t^H, x_0^L, m_t\big)\Big],$$

where the mask rate in $m_t$ follows a cosine scheduling function, as for the denoiser model $p_\theta$. We use a transformer to model the high-resolution sequence, with cross-attention to the low-resolution sequence in every layer and independent positional encodings for each resolution. During inference, we first generate an image (a sequence of visual tokens $\hat{x}_0^L$) at 256×256 resolution. We then perform 6 decoding steps with the upsampling model $p_u$, following the cosine scheduling function. For simplicity, when using DPC with upsampling, we only apply corrector steps at the low resolution, and use the prediction confidence in $p_u$ to keep/reject tokens (similarly to MaskGIT (Chang et al., 2022)) in each upsampling decoding iteration. Following observations in cascaded continuous diffusion, where the low-resolution images are augmented with random perturbations to reduce exposure bias (Ho et al., 2022), we experimented with randomly perturbing the low-resolution sequence.
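Per example, the upsampling objective above reduces to a token-wise cross-entropy over the high-resolution sequence. A minimal sketch (names and shapes are illustrative, not from the paper's code):

```python
import numpy as np

def upsampling_nll(log_probs, high_res_tokens):
    """Negative log-likelihood -log prod_i p_u(x0_H(i) | x_t^H, x0_L, m_t)
    for one example. `log_probs` has shape (seq_len, vocab_size) and is
    assumed to already condition on the masked high-resolution sequence,
    the low-resolution sequence, and the mask."""
    token_ll = log_probs[np.arange(len(high_res_tokens)), high_res_tokens]
    return -token_ll.sum()

# Uniform predictions over a vocabulary of 4 give seq_len * log(4).
lp = np.full((3, 4), np.log(0.25))
nll = upsampling_nll(lp, np.array([0, 1, 2]))
assert np.isclose(nll, 3 * np.log(4))
```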
To randomly perturb the low-resolution sequence, we randomly mask a subset of tokens and run the partially masked sequence through the low-resolution denoiser pθ. Table 7 shows the results obtained for different strengths of perturbation (masking rate). These results suggest that, in the discrete case, low-resolution augmentation does not improve the upsampling stage. Thus, for simplicity, we did not use any augmentation for the experiments in Section 4.
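The perturbation we experimented with can be sketched as follows; `denoise_fn` stands in for a one-step pass of the low-resolution denoiser, and all names and shapes here are illustrative assumptions.

```python
import numpy as np

def perturb_low_res(tokens, mask_rate, mask_token, denoise_fn, rng):
    """Mask a random subset of low-resolution tokens and replace the masked
    positions with the predictions of a one-step denoiser (no corrector)."""
    drop = rng.uniform(size=tokens.shape) < mask_rate
    perturbed = tokens.copy()
    perturbed[drop] = mask_token
    predictions = denoise_fn(perturbed)   # single denoising pass
    perturbed[drop] = predictions[drop]
    return perturbed

rng = np.random.default_rng(0)
tokens = np.arange(8)
out = perturb_low_res(tokens, 0.0, -1, lambda x: x, rng)
assert np.array_equal(out, tokens)  # mask_rate 0 leaves the sequence intact
```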

D FURTHER QUALITATIVE RESULTS

ImageNet We present further 512×512 ImageNet class-conditional samples in Figure 4, comparing the proposed DPC-light(5) to StyleGAN-XL (Sauer et al., 2022) and ADM with Guidance and Upsampling (Dhariwal & Nichol, 2021). In Figure 5 we further show random samples of random classes for each of the three methods. We provide further qualitative results in the Supplementary Material.

Places2 In Figure 6 we show samples from the Places2 dataset (Zhou et al., 2017), generated by the proposed DPC-light(5) (FID 16.2) and by MaskGIT (FID 26.3), as well as original images from the dataset for reference.

E FURTHER IMPLEMENTATION DETAILS

All the transformers used in this work have embedding dimension 768 and hidden dimension 3,072, learnable positional embeddings (Devlin et al., 2019), LayerNorm (Ba et al., 2016), and truncated normal initialization (stddev = 0.02). The following training hyperparameters were used for both transformers: dropout rate 0.1, and the Adam optimizer (Kingma & Ba, 2014) with β1 = 0.9 and β2 = 0.96. We used RandomResizeAndCrop for data augmentation, and the denoiser is trained with label smoothing set to 0.1.

F USER PREFERENCE STUDY DETAILS

In Figure 7 we present a screenshot of the quality user preference study. Users were presented with two generated images of the same class, one sampled from our method and one from the compared method, and were shown the following text prompt:

Here are two images generated from the same input category using different methods. Please use the radio buttons above to choose the more realistic looking one. It would be great if you could check both sets of images carefully, as they may contain subtle artifacts that are not immediately obvious. Thank you so much!

In Figure 8 we show a screenshot of the diversity user preference study. Users were presented with two grids of random samples, one for our method and one for the compared method. Each grid contained 6 images for each of 4 classes. The users were shown the following text prompt:

In each tab, we have two sets of images generated from the same input categories using different methods. Please use the radio buttons above to choose the image set with more diverse content, lighting and background. Note that you must select one in each tab before you submit. It would be great if you could check both sets of images carefully, as they may contain subtle differences that are not immediately obvious. Thank you so much!

Figure 8: Screenshot for the diversity user preference study.



We use the subscript θ for models of the reverse process and φ for the corrector model, which goes in the forward direction. The random variables described by the models are determined by what is inside the brackets.



Figure 1: Random samples from ImageNet 512×512 class-conditional generation for selected classes: 'tusker', 'koala', 'model T' and 'steam locomotive'. Left: StyleGAN-XL (NFE = 1). Center: ADM + Classifier Guidance + Upsampling (NFE = 250 × 3). Right: DPC-light(5) (NFE = 66).

Figure 2: Ablation study for sampling temperature parameter β. We show the effect of β on FID and Inception Score for models DPC-full(5), DPC-full(3), DPC-light(5) and DPC-light(3). We notice a high sensitivity to β and a strong correlation between its effect on FID and Inception Score.

Figure 4: Random samples from ImageNet 512×512 class-conditional generation for selected classes: 'lorikeet', 'wallaby', 'flat-coated retriever', 'groenendael', 'golfcart' and 'racing car'. Left: StyleGAN-XL (NFE = 1). Center: ADM + Classifier Guidance + Upsampling (NFE = 250 × 3). Right: DPC-light(5) (NFE = 66).

Figure 5: Random samples from random classes of ImageNet 512×512 class-conditional generation. Left: StyleGAN-XL (NFE = 1). Center: ADM + Classifier Guidance + Upsampling (NFE = 250 × 3). Right: DPC-light(5) (NFE = 66).

Figure 6: Unconditional image generation in the Places2 dataset (Zhou et al., 2017) at 512 × 512 resolution. Random samples from left: Ground Truth, center: MaskGIT (FID 23.3, NFE = 66), right: DPC-light(5) (FID 16.2, NFE = 66).

Figure 7: Screenshot for the quality user preference study.

Method        Quality           Diversity
StyleGAN-XL   77.6% ± 0.42      54.4% ± 0.50
ADM+G+U       67.6% ± 0.42      53.2% ± 0.50
MaskGIT       68.8% ± 0.46      54.0% ± 0.50
Token-Critic  68.0% ± 0.46      48.4% ± 0.50


Table 7: Effect of the random perturbation applied to the low-resolution sequence when training the upsampling stage. The perturbation consists of masking a portion of the tokens and replacing them with the predictions of the low-resolution denoising model. We use the one-step denoising model (without corrector) for this experiment.

ACKNOWLEDGEMENTS

We would like to thank Ming-Hsuan Yang and Douglas Eck for helpful comments during early stages of this work. We also thank the anonymous reviewers for their insightful comments and constructive feedback that helped to improve this paper.

