THINKING FOURTH DIMENSIONALLY: TREATING TIME AS A RANDOM VARIABLE IN EBMS

Abstract

Recent years have seen significant progress in techniques for learning high-dimensional distributions. Many modern methods, from diffusion models to Energy-Based Models (EBMs), adopt a coarse-to-fine approach. This is often done by introducing a series of auxiliary distributions that gradually change from the data distribution to some simple distribution (e.g., white Gaussian noise). Methods in this category separately learn each auxiliary distribution (or transition between pairs of consecutive distributions) and then use the learned models sequentially to generate samples. In this paper, we offer a simple way to generalize this idea by treating the "time" index of the series as a random variable and framing the problem as that of learning a single joint distribution of "time" and samples. We show that this joint distribution can be learned using any existing EBM method and that doing so leads to improved results. As an example, we demonstrate this approach using contrastive divergence (CD) in its most basic form. On CIFAR-10 and CelebA (32 × 32), this method outperforms previous CD-based methods in terms of Inception Score (IS) and Fréchet Inception Distance (FID).

1. INTRODUCTION

Probability density estimation is among the most fundamental tasks in unsupervised learning. It is used in a wide array of applications, from image restoration and manipulation (Nichol et al., 2021; Du et al., 2021; Lugmayr et al., 2020; Kawar et al., 2021; 2022) to out-of-distribution detection (Du & Mordatch, 2019; Grathwohl et al., 2019; Zisselman & Tamar, 2020). However, directly fitting an explicit probability model to high-dimensional data is hard, particularly when the data samples concentrate around a low-dimensional manifold, as is often the case with visual data. One way to circumvent this obstacle is to use coarse-to-fine approaches. In fact, in one form or another, coarse-to-fine strategies have been used with great success in most types of generative models (both implicit and explicit), including generative adversarial networks (GANs) (Karras et al., 2018), variational autoencoders (VAEs) (Vahdat & Kautz, 2020), energy-based models (EBMs) (Gao et al., 2018; Zhao et al., 2020), score matching (Song & Ermon, 2019; Li et al., 2019) and diffusion models (Sohl-Dickstein et al., 2015; Ho et al., 2020). The coarse-to-fine idea is commonly implemented by introducing a series of auxiliary distributions that gradually transition from the data distribution to some simple known distribution that is smoothly spread in space (e.g., a standard normal distribution). This construction is illustrated in Fig. 1a for two-dimensional data. The index running over the series of distributions is typically referred to as "time". This reflects either the diffusion-like sequential manner in which samples are generated for training (from fine to coarse) (Sohl-Dickstein et al., 2015; Ho et al., 2020) or the annealing-like sequential order in which samples are generated from the model at test time (from coarse to fine) (Song & Ermon, 2019).
Methods that use this construction attempt to learn each of the distributions in the series (or each transition rule between pairs of consecutive distributions) separately from the other distributions.¹ In this paper, we explore a more general approach for exploiting the coarse-to-fine structure, which can be used in conjunction with almost any explicit distribution learning algorithm and leads to improved results. The key idea is to gather the samples from all auxiliary distributions and view them as coming from a single joint distribution. More specifically, we treat the "time" index of the series as a random variable, $t$, and the samples from all auxiliary distributions as samples of a random vector $z$. This allows learning a single model for the joint distribution $p_{z,t}(z,t)$, using pairs of samples $(z,t)$ (see Fig. 1b). To understand the benefit of this joint modeling, note that many of the individual distributions $p_{z|t}(z|t)$ occupy only small regions of the space. Thus, when training a separate model for each $t$, each model is accurate over a different region in space, which can lead to inaccuracies at test time when switching between models. In contrast, here we learn the joint distribution $p_{z,t}(z,t)$, either directly (Sec. 3.4) or by decomposing the problem in the reverse direction and learning $p_z(z)$ and $p_{t|z}(t|z)$ (Sec. 3.3). Thus, during training, our unified model is exposed to samples from the entire space, leading to better stitching of the different parts. Once a model is trained using our approach, it can be used similarly to existing methods by extracting the auxiliary distributions $p_{z|t}(z|t)$ and sampling from them one after the other, from coarse to fine. It can also be used in alternative ways, as we discuss in Sec. 3.5.

Figure 1: (a) The standard approach: coarse-to-fine distribution learning methods introduce a series of auxiliary distributions ($p_{v_0}(z), \dots, p_{v_4}(z)$ in this example) that gradually transition from the data distribution (a 2D spiral) to some simple distribution (here a Gaussian). These methods learn each auxiliary distribution (or pair of consecutive distributions) separately. (b) Our approach: we treat the "time" index of the series as a random variable, $t$, and the samples from all distributions as samples from a single random vector $z$. We then train the model to learn the joint distribution $p_{z,t}(z,t)$ using samples $(z,t)$.

To illustrate the strength of our approach, we apply it together with the vanilla contrastive divergence (CD) method (Hinton, 2002) on the CIFAR10 (Krizhevsky et al., 2009) and CelebA (Liu et al., 2015) (32 × 32) datasets. It is important to note that although the vanilla CD method is theoretically justified (Yair & Michaeli, 2020), it fails when directly applied to high-dimensional visual data (Gao et al., 2018). This is because it provides good estimates only near the data manifold. To date, good results have been obtained only with persistent contrastive divergence (PCD) (Tieleman, 2008; Du & Mordatch, 2019), which maintains a buffer of past samples. With our approach, on the other hand, plain CD not only succeeds in learning the distribution, but also improves upon all previous PCD-based techniques in terms of Inception Score (IS) and Fréchet Inception Distance (FID).

2. RELATED WORK

The idea of learning an explicit generative model by using an auxiliary coarse-to-fine series of distributions has been used in many works. We briefly mention its use within popular models. Song & Ermon (2019) constructed a series of distributions by adding increasing amounts of white Gaussian noise to the training samples. They learned the gradients of the log-densities of these distributions using denoising score matching (Vincent, 2011), and used the trained model to solve various generative tasks using gradient-based simulated annealing. Diffusion models (Sohl-Dickstein et al., 2015; Ho et al., 2020; Song et al., 2020; Dhariwal & Nichol, 2021) are currently the state-of-the-art methods for image generation. In this approach, the series of distributions is defined through a diffusion process that begins with samples from the training set and gradually transforms the distribution into white Gaussian noise. A series of models (denoisers) is then trained to capture the reverse process, allowing one to generate samples from each distribution given the preceding one. This process can also be viewed as starting from a noisy image and repeatedly removing small portions of the noise until reaching a clean image (Ho et al., 2020). Gao et al. (2020b) used a construction similar to that of diffusion models, but with far fewer distributions. They employed a conditional version of maximum likelihood for training an EBM for each distribution in the series. This was done by using short MCMC chains to generate adversarial samples from a selected distribution based on samples from consecutive distributions. The adversarial samples are then used along with true samples from the selected distribution to calculate the optimization objective. Rhodes et al. (2020) considered an arbitrary series of distributions that transitions between the dataset and some reference distribution. A series of classifiers is then trained to discriminate between each pair of consecutive distributions.
The fact that the ratio between consecutive distributions can be extracted from these classifiers is then used to compute the target distribution in a telescopic manner.
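To make the telescoping idea concrete in the notation of Sec. 3.1 (a series $p_{v_0} = p_n, \dots, p_{v_{T-1}} = p_x$), the target log-density splits into the known reference log-density plus a sum of log-ratios between consecutive distributions. For a logistic classifier trained with balanced classes to separate $p_{v_{t+1}}$ from $p_{v_t}$, the optimal logit equals the corresponding log-ratio, which is what makes the classifiers usable for this purpose. This is a sketch of the standard identity, not a reproduction of the exact formulation of Rhodes et al. (2020):

```latex
\log p_x(x) \;=\; \log p_n(x) \;+\; \sum_{t=0}^{T-2} \log \frac{p_{v_{t+1}}(x)}{p_{v_t}(x)},
\qquad
\log \frac{p_{v_{t+1}}(x)}{p_{v_t}(x)} \;=\; \log \frac{D_t^{*}(x)}{1 - D_t^{*}(x)}
```

Here $D_t^{*}(x)$, a symbol introduced only for this sketch, denotes the Bayes-optimal probability that $x$ was drawn from $p_{v_{t+1}}$ rather than from $p_{v_t}$.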

3.1. INTRODUCING THE JOINT DISTRIBUTION

Given a dataset of i.i.d. samples $\{x_i\}$, we would like to learn a model for the distribution $p_x$ from which the samples were drawn. To do so, we first select some known distribution $p_n$ as a reference. A possible choice is white Gaussian noise. We then introduce a series of $T$ auxiliary distributions $\{p_{v_t}\}_{t=0}^{T-1}$ that begin with the reference distribution, $p_{v_0} = p_n$, and end with the distribution of the data, $p_{v_{T-1}} = p_x$. The main requirements from these distributions are that we know how to draw samples from them and that the effective overlap between consecutive distributions be large (i.e., consecutive distributions are close to each other). One common construction relies on linear combinations of samples from $p_x$ and $p_n$. That is, given a sample from the dataset, $x$, and a sample from the reference distribution, $n$, an intermediate sample $v_t$ can be generated as

$v_t = \alpha_t x + \beta_t n$,   (1)

with some predefined coefficients $\{\alpha_t, \beta_t\}_{t=0}^{T-1}$ that monotonically transition from $(\alpha_0, \beta_0) = (0, 1)$ to $(\alpha_{T-1}, \beta_{T-1}) = (1, 0)$. Up to this point, this is the standard coarse-to-fine construction. Here, however, we proceed to view the samples from all distributions $\{p_{v_t}\}_{t=0}^{T-1}$ as coming from a single joint distribution. Specifically, we introduce two random variables. The first is a random "time" index $t \sim p_t$, which is used to select an auxiliary distribution from the series. A simple choice is a uniform distribution over $\{0, \dots, T-1\}$ (in Sec. 3.5 we discuss how $p_t$ can be modified at test time to aid the sample generation process). Using $t$, we define a second random variable, $z = v_t$. Namely, $z$ is a random draw from one of the auxiliary distributions, where the index is randomly chosen according to $t$, so that $p_{z|t}(z|t) = p_{v_t}(z)$. Given these two variables, we define our auxiliary problem as that of learning the joint distribution $p_{z,t}$ based on samples of $(z,t)$. Learning the joint distribution can be done in several ways:

1. Using the decomposition $p_{z,t}(z,t) = p_{z|t}(z|t)\,p_t(t)$. Since $p_t$ is known, this only requires learning $p_{z|t}$, which is the standard approach.² Particularly, for each $t \in \{0, \dots, T-1\}$, the distribution $p_{z|t}(\cdot|t)$ is learned using an optimization problem that is defined only in terms of samples corresponding to that $t$. This approach is illustrated in Fig. 1a.
2. Using the decomposition $p_{z,t}(z,t) = p_{t|z}(t|z)\,p_z(z)$. Here we need to learn both terms, which are illustrated in Figs. 2b and 2a.
   (a) Learning $p_{t|z}$ amounts to solving a classification problem (i.e., learning to predict the discrete "time" variable $t \in \{0, \dots, T-1\}$ from a sample $z$). This can be done using the cross-entropy loss.
   (b) The distribution $p_z$ can be learned, e.g., using any EBM learning method. Importantly, this is a significantly simpler task than learning $p_{z|t}$ (especially for the finer $t$'s), as $p_z$ is a smoother function that is more spread in space.
3. Directly learning $p_{z,t}$, which appears in Fig. 1b. This requires slightly adapting existing methods, as this is a joint distribution over two domains with different properties.

It should be noted that although each of these options suffices on its own for extracting $p_{z,t}$, it is also possible to combine these losses in order to further improve the training (e.g., a loss on $p_{t|z}$ in addition to the standard loss on $p_{z|t}$). Our key observation is that substantially improved results are obtained when using the losses of Option 2 and/or 3 in conjunction with, or instead of, the standard loss on $p_{z|t}$ (Option 1). This is because the losses on $p_{t|z}$, $p_z$, and $p_{z,t}$ are optimized using samples from all auxiliary distributions. In contrast, the standard approach of training a model per $t$ using a loss on $p_{z|t}$ involves only samples from that particular $t$, and these samples come from a very restricted region in space. In Sections 3.3 and 3.4 we show how Options 2 and 3, respectively, can be used in conjunction with the CD method.
Option 2 leads to better results than Option 3, and improves upon all previous CD-based methods.
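The sampling construction above can be sketched in a few lines of dependency-free Python. The linear schedule $\alpha_t = t/(T-1)$, $\beta_t = 1 - \alpha_t$ and the Gaussian reference with standard deviation `sigma_n` are assumptions for illustration (the text only requires monotone coefficients with the stated endpoints); `make_schedule` and `sample_joint` are hypothetical helper names.

```python
import random

def make_schedule(T):
    """Assumed linear schedule: alpha_t rises from 0 to 1 and beta_t = 1 - alpha_t,
    so (alpha_0, beta_0) = (0, 1) and (alpha_{T-1}, beta_{T-1}) = (1, 0)."""
    alphas = [t / (T - 1) for t in range(T)]
    betas = [1.0 - a for a in alphas]
    return alphas, betas

def sample_joint(dataset, alphas, betas, sigma_n=1.0, rng=random):
    """Draw one pair (z, t) from the joint p_{z,t}:
    t ~ Uniform{0, ..., T-1}, then z = v_t = alpha_t * x + beta_t * n as in (1),
    with x drawn from the dataset and n ~ N(0, sigma_n^2 I) from the reference p_n."""
    t = rng.randrange(len(alphas))              # random "time" index
    x = rng.choice(dataset)                     # data sample
    n = [rng.gauss(0.0, sigma_n) for _ in x]    # reference (noise) sample
    z = [alphas[t] * xi + betas[t] * ni for xi, ni in zip(x, n)]
    return z, t
```

A training batch is then a list of independent calls (e.g. 64 calls to `sample_joint` with the same `alphas, betas`), so every batch mixes samples from all auxiliary distributions, which is the property the losses of Options 2 and 3 rely on.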

3.2. THE PARAMETRIC MODEL

As opposed to the common approach in which $t$ is fed as input to the model (Song & Ermon, 2019; Rhodes et al., 2020; Ho et al., 2020), here we propose to have the model output a vector of length $T$ that contains the values of $\log p_{z,t}(z,t)$ for all $t \in \{0, \dots, T-1\}$. Namely, we use a parametric model $f_\theta$ (a neural network in our experiments), which accepts $z$ as input and outputs

$f_\theta(z) = [\log \hat p_{z,t;\theta}(z, t=0), \log \hat p_{z,t;\theta}(z, t=1), \dots, \log \hat p_{z,t;\theta}(z, t=T-1)]^\top$.   (3)

Here $\hat p$ denotes the model's estimate of the probability density $p$. This is illustrated in Fig. 3. In this design, each element of the output vector is effectively an EBM for the corresponding auxiliary distribution (here we define the log probabilities without the minus sign). This is related to the observation in (Grathwohl et al., 2019), which draws the connection between classifiers and EBMs. However, as opposed to EBMs, here we strive to learn a normalized distribution. This can be done by using an optimization process that mixes samples from different values of $t$, along with using the fact that $p_{z|t}(z|0) = p_n(z)$ in order to fix the first element of the output vector of $f_\theta(z)$ to be $\log p_n(z)$. Defining the model this way allows efficiently computing $\log \hat p_{z;\theta}(z)$ and $\log \hat p_{t|z;\theta}(t|z)$, as

$\log \hat p_{z;\theta}(z) = \mathrm{logSumExp}(f_\theta(z))$,   (4)
$\log \hat p_{t|z;\theta}(t|z) = \mathrm{logSoftmax}(f_\theta(z))_t$.   (5)

These terms can be used both during training (for applying losses on these distributions) and, as we discuss in Sec. 3.5, at test time (for generating samples from the model). In order to aid the training of the model, we make use of an observation made by Rhodes et al. (2020), stating that it is beneficial to represent a ratio of distributions as a product of many small ratio terms. We apply this idea to our model by adding a cumulative-summation layer, $g(x)_i = \sum_{j=0}^{i} x_j$, before the output of the network.
This additional layer does not impact the representation power of the network as it is an invertible linear operation that can, in theory, be absorbed into the preceding linear layer. However, this layer does affect the optimization of the model, as it modifies the initialization of the effective linear operation preceding the output, as well as the optimization dynamics. In practice, we found this implicit bias to be crucial for the success of the model when applied to high-dimensional distributions.
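A minimal, framework-free sketch of this output head: given the length-$T$ vector $f_\theta(z)$, (4) and (5) are a log-sum-exp and a log-softmax, and the cumulative-summation layer turns raw outputs into running sums. In a real model these would be tensor operations (e.g. PyTorch's `torch.logsumexp`, `torch.log_softmax` and `torch.cumsum`); plain Python lists are used here only to keep the example self-contained, and the helper names are hypothetical.

```python
import math
from itertools import accumulate

def cumsum_head(raw):
    """Cumulative-summation layer: g(x)_i = sum_{j <= i} x_j."""
    return list(accumulate(raw))

def log_sum_exp(v):
    """Numerically stable log(sum(exp(v)))."""
    m = max(v)
    return m + math.log(sum(math.exp(x - m) for x in v))

def log_marginal(f):
    """log p_z(z) = logSumExp(f(z)), Eq. (4)."""
    return log_sum_exp(f)

def log_posterior_t(f):
    """log p_{t|z}(t|z) = logSoftmax(f(z))_t for every t, Eq. (5)."""
    lse = log_sum_exp(f)
    return [x - lse for x in f]
```

With the cumulative sum in place, raw output $j \ge 1$ equals $f_\theta(z)_j - f_\theta(z)_{j-1} = \log[\hat p_{z,t;\theta}(z,j)/\hat p_{z,t;\theta}(z,j-1)]$, so the network predicts the small log-ratio terms between consecutive auxiliary distributions rather than the joint log-densities themselves.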

3.3. TRAINING THE MODEL USING CD+CE

In this section we describe how our suggested model can be trained using a CD loss on the marginal distribution p z together with a cross-entropy (CE) loss on p t|z . The overall loss in this method is therefore L CD (θ) + L CE (θ), and we refer to it as CD+CE (see Alg. 1).

Contrastive divergence

The CD method uses an MCMC process to generate contrastive samples at each training iteration. Specifically, let us denote by $\mathcal{T}_{\hat p_{z;\theta}}$ the transition operator of an MCMC process (an operator that performs a single MCMC step) designed to draw samples from $\hat p_{z;\theta}$. Note that in each training iteration, the MCMC process operates with respect to the current model estimate, $\hat p_{z;\theta}$. To generate a contrastive sample, the MCMC process is initialized with a sample from the dataset, $\tilde z_0 = z$, and is run for $K$ steps, $\tilde z_{k+1} = \mathcal{T}_{\hat p_{z;\theta}}(\tilde z_k)$. This results in a contrastive sample $\tilde z = \tilde z_K$. The CD loss is then defined as

$L_{CD}(\theta) = \mathbb{E}[\log \hat p_{z;\theta}(\tilde z) - \log \hat p_{z;\theta}(z)] = \mathbb{E}[\mathrm{logSumExp}(f_\theta(\tilde z)) - \mathrm{logSumExp}(f_\theta(z))]$,

where the expectation is over draws of contrastive samples (first term) and samples from the dataset (second term). A popular choice for an MCMC process over continuous distributions is Langevin dynamics, in which the transition operator is given by

$\mathcal{T}_{\hat p_{z;\theta}}(\tilde z_k) = \tilde z_k + \frac{\mu^2}{2} \nabla_z \log \hat p_{z;\theta}(\tilde z_k) + \mu\varepsilon = \tilde z_k + \frac{\mu^2}{2} \nabla_z \mathrm{logSumExp}(f_\theta(\tilde z_k)) + \mu\varepsilon$,

where $\varepsilon \sim \mathcal{N}(0, I)$ and $\mu$ is the step size. To keep the MCMC process accurate, we use a Metropolis-Hastings rejection step (Hastings, 1970) as part of the transition operator.

Cross entropy

Learning $p_{t|z}$ can be achieved by minimizing the standard cross-entropy loss over the outputs of our model,

$L_{CE}(\theta) = -\mathbb{E}[\log \hat p_{t|z;\theta}(t|z)] = -\mathbb{E}[\mathrm{logSoftmax}(f_\theta(z))_t]$.

This is equivalent to training a classifier to predict the discrete "time" index $t$ given a sample $z$.
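The Metropolis-adjusted Langevin transition and the per-sample CD+CE objective can be sketched as follows, assuming the log-density and its gradient are supplied as callables. In practice the gradient of $\mathrm{logSumExp}(f_\theta(z))$ with respect to $z$, and the loss gradient with respect to $\theta$, would come from automatic differentiation; neither is shown here. `mala_step` and `cd_ce_loss` are hypothetical names.

```python
import math
import random

def log_sum_exp(v):
    m = max(v)
    return m + math.log(sum(math.exp(x - m) for x in v))

def mala_step(z, log_p, grad_log_p, mu, rng=random):
    """One Metropolis-adjusted Langevin step targeting exp(log_p).

    Proposal: z' = z + (mu^2 / 2) * grad log p(z) + mu * eps, eps ~ N(0, I),
    followed by a Metropolis-Hastings accept/reject test. Returns (next_z, accepted).
    """
    def drift(v):
        return [vi + 0.5 * mu * mu * gi for vi, gi in zip(v, grad_log_p(v))]

    def log_q(dst, src):
        # log N(dst; drift(src), mu^2 I), dropping the constant that cancels in the ratio
        return -sum((d - m) ** 2 for d, m in zip(dst, drift(src))) / (2.0 * mu * mu)

    prop = [m + mu * rng.gauss(0.0, 1.0) for m in drift(z)]
    log_accept = log_p(prop) + log_q(z, prop) - log_p(z) - log_q(prop, z)
    if rng.random() < math.exp(min(0.0, log_accept)):
        return prop, True
    return z, False

def cd_ce_loss(f_data, f_contrast, t):
    """Per-sample CD+CE objective of Alg. 1, given model outputs f(z) and f(z_tilde):
    logSumExp(f(z_tilde)) - logSumExp(f(z)) - logSoftmax(f(z))_t."""
    lse_data = log_sum_exp(f_data)
    return log_sum_exp(f_contrast) - lse_data - (f_data[t] - lse_data)
```

For the model marginal, `log_p` would be `lambda z: log_sum_exp(f(z))` for a model `f` returning the vector in (3).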

3.4. TRAINING THE MODEL USING JOINTCD

Our suggested model can also be trained using a CD loss on the joint distribution $p_{z,t}$. We refer to this method as JointCD (see Alg. 2). The distribution $p_{z,t}$ is visualized in Fig. 1b. To apply CD on $p_{z,t}$, we need a transition operator $\mathcal{T}_{\hat p_{z,t;\theta}}$ of an MCMC process that is designed to draw samples from $\hat p_{z,t;\theta}$. Given such an operator, we can initialize the process with a sample from the dataset, $(\tilde z_0, \tilde t_0) = (z, t)$, and run $(\tilde z_{k+1}, \tilde t_{k+1}) = \mathcal{T}_{\hat p_{z,t;\theta}}(\tilde z_k, \tilde t_k)$ for $K$ steps to generate a contrastive sample $(\tilde z, \tilde t) = (\tilde z_K, \tilde t_K)$. The problem is that popular MCMC techniques, like Langevin dynamics and Hamiltonian Monte Carlo (HMC), are relevant only for continuous distributions, whereas in our case $p_{z,t}$ is a mixed distribution ($z$ is continuous and $t$ is discrete). Nevertheless, as we show in App. A, any continuous MCMC process can be extended to work on our joint mixed distribution simply by performing a step on $z$ using the continuous MCMC operator $\mathcal{T}_{\hat p_{z;\theta}}$ and then sampling $t$ from $\hat p_{t|z;\theta}$. Namely, a single MCMC step in our case takes the form

$\mathcal{T}_{\hat p_{z,t;\theta}}(\tilde z_k, \tilde t_k)$:  $\tilde z_{k+1} = \mathcal{T}_{\hat p_{z;\theta}}(\tilde z_k)$,  $\tilde t_{k+1} \sim \hat p_{t|z;\theta}(\cdot|\tilde z_{k+1})$.

We show in App. A that this process obeys the detailed balance criterion, and thus its stationary distribution is indeed $\hat p_{z,t;\theta}$, as desired. Note that the intermediate $\tilde t_k$ values are not required for sampling the next step. Therefore, in practice, we draw only the last one, $\tilde t_K$. We illustrate this joint MCMC process in Fig. 4.

Algorithm 1: CD+CE
while not converged do
    Sample $t \sim p_t$, $z \sim p_{z|t=t}$
    $\tilde z \leftarrow z$
    for 1 to K do
        $\tilde z \leftarrow \mathcal{T}_{\hat p_{z;\theta}}(\tilde z)$
    end
    Take gradient step on $\log \hat p_{z;\theta}(\tilde z) - \log \hat p_{z;\theta}(z) - \log \hat p_{t|z;\theta}(t|z)$, computing the densities using (4), (5)
end

Algorithm 2: JointCD
while not converged do
    Sample $t \sim p_t$, $z \sim p_{z|t=t}$
    $\tilde z \leftarrow z$
    for 1 to K do
        $\tilde z \leftarrow \mathcal{T}_{\hat p_{z;\theta}}(\tilde z)$
    end
    Sample $\tilde t$ from $\hat p_{t|z=\tilde z;\theta}$
    Take gradient step on $\log \hat p_{z,t;\theta}(\tilde z, \tilde t) - \log \hat p_{z,t;\theta}(z, t)$, computing the densities using (3)
end
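The joint step amounts to a continuous move on $z$ followed by a categorical draw of $t$ from the softmax of the model outputs. The sketch below assumes that any $z$-transition leaving the model marginal invariant (for instance a Metropolis-adjusted Langevin step) is passed in as the hypothetical callable `step_z`, and, as in Alg. 2, it draws $\tilde t$ only once, at the end of the chain.

```python
import math
import random

def sample_t_given_z(f_z, rng=random):
    """Draw t ~ p_{t|z}(. | z) = softmax(f(z)) from the vector of joint log-densities."""
    m = max(f_z)
    weights = [math.exp(v - m) for v in f_z]  # unnormalized softmax, shifted for stability
    return rng.choices(range(len(f_z)), weights=weights, k=1)[0]

def joint_contrastive_sample(z, f, step_z, K, rng=random):
    """JointCD contrastive sample (z_tilde, t_tilde).

    step_z: hypothetical continuous MCMC transition w.r.t. the model marginal p_z.
    f:      model mapping z to [log p(z, t) for t = 0..T-1], as in (3).
    Intermediate t values do not influence the z-chain, so only the last one is drawn.
    """
    for _ in range(K):
        z = step_z(z)
    return z, sample_t_given_z(f(z), rng)
```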

3.5. SAMPLING FROM THE MODEL

Sampling from the trained model can be done via simulated annealing. In this approach, one runs an MCMC process while constantly replacing the underlying distribution, starting from a simple smooth distribution and gradually refining it into the target distribution. In our context, this is commonly done by running through the series of learned auxiliary distributions, starting with $p_{z|t=0}$ and gradually increasing $t$ until reaching $p_{z|t=T-1}$ (which equals $p_x$). This algorithm is outlined in Alg. 3 in App. B, and illustrated in Fig. 5a. We note that our approach allows viewing simulated annealing as a special case of a more general sampling scheme. Specifically, we can interpret simulated annealing as running through a series of distributions $p^{(n)}_{z'}$, where $p^{(n)}_{z'|t'}(z|t) = p_{z|t}(z|t)$ for all $n$ and $p^{(n)}_{t'}(t) = \delta(t-n)$ (here $\delta(\cdot)$ denotes the Kronecker delta). We can therefore generalize this method by using any sequence of distributions $\{p^{(n)}_{t'}(t)\}$ whose centers of mass gradually move from small values of $t$ to larger ones. In this generalized setting, we have

$p^{(n)}_{z',t'}(z,t) = p_{z|t}(z|t)\,p^{(n)}_{t'}(t) = p_{z,t}(z,t)\,p^{(n)}_{t'}(t)/p_t(t)$,

from which $p^{(n)}_{z'}$ can be extracted by summation over $t$. The resulting generalized simulated annealing algorithm is outlined in Alg. 4 in App. B. One particular choice is $p^{(n)}_{t'} = \mathrm{Uniform}\{t(n), \dots, T-1\}$, where $t(n)$ is a linear function growing from $0$ to $T-1$. For this choice, $p^{(n)}_{z'}$ is essentially the mean of all the auxiliary distributions from $t = t(n)$ to $t = T-1$. We refer to this variant as soft simulated annealing and exemplify it in Fig. 5b.
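The reweighting that defines $p^{(n)}_{z'}$ can be computed directly from the model outputs in (3). The sketch below evaluates $\log p^{(n)}_{z'}(z)$ for the soft schedule; flooring $t(n)$ to an integer is an assumption (the text only says $t(n)$ grows linearly), and both function names are hypothetical.

```python
import math

def soft_sa_log_weights(n, N, T):
    """log p^{(n)}_{t'}(t) for soft simulated annealing: uniform over {t(n), ..., T-1},
    where t(n) grows linearly from 0 (n = 0) to T-1 (n = N).
    Flooring t(n) to an integer is an assumption of this sketch."""
    t_start = int((T - 1) * n / N)
    count = T - t_start
    return [-math.log(count) if t >= t_start else -math.inf for t in range(T)]

def annealed_log_density(f_z, log_w, log_p_t):
    """log p^{(n)}_{z'}(z) = log sum_t exp(f(z)_t + log p^{(n)}_{t'}(t) - log p_t(t)),
    i.e. the joint reweighted by p^{(n)}_{t'} / p_t and summed over t."""
    terms = [fz + lw - lp for fz, lw, lp in zip(f_z, log_w, log_p_t) if lw != -math.inf]
    m = max(terms)
    return m + math.log(sum(math.exp(x - m) for x in terms))
```

Choosing `log_w` as a one-hot vector (0 at $t = n$, $-\infty$ elsewhere) recovers $\log p_{z|t}(z|n)$, i.e. basic simulated annealing; a Langevin step at stage $n$ would use the gradient of this quantity with respect to $z$.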

4. EXPERIMENTS

We now illustrate the efficacy of our method on a toy problem as well as on the CIFAR10 and CelebA datasets. Code for all experiments will be released upon acceptance of the paper.

4.1. TOY MODEL

The toy problem appearing in Fig. 5c involves data lying on a 2D shifted spiral, where the shifting has been introduced to aid the visualization. To apply our method, we selected a Gaussian distribution centered at the origin as a reference distribution and defined 256 auxiliary distributions, according to (1) . We then used JointCD to train a neural network. For the MCMC process, we used Metropolis-Hastings adjusted Langevin dynamics with K = 3 steps. The details of the network and the training process can be found in App. C. The resulting learned model is shown in Fig. 5a .

4.2. CELEBA & CIFAR10

We trained models using both CD+CE and JointCD on CIFAR10 and on CelebA at 32 × 32 resolution. Here we used 1024 auxiliary distributions, also defined according to the linear interpolation in (1). As our parametric model, we used ResNet18 (He et al., 2016) with minor changes. Here as well, we used adjusted Langevin dynamics, but during training we gradually increased the lengths of the …

Method                                   IS      FID
… (Karras et al., 2020)                  9.83    2.92
Normalizing flows
  Residual flow (Chen et al., 2019)      –       46.37
  FCE (Gao et al., 2020a)                –       37.3
Score based
  NCSN-v2 (Song & Ermon, 2020)           8.4     10.87
  DDPM (Ho et al., 2020)                 9.46    3.17
EBMs (ML based)
  CoopNets (Xie et al., 2018)            6.55    33.61
  Multi-grid EBM (Gao et al., 2018)      6.56    40.01
  CF-EBM (Zhao et al., 2020)             –       16.71
  EBM-DRL (Gao et al., 2020b)            8…

For generating samples, we used simulated annealing. We found that on CIFAR10, the CD+CE method performs better than JointCD. The generation process is visualized in Fig. 6a, and generated samples are shown in Fig. 6b. These results are all from the CD+CE model. Results from the JointCD method can be found in App. D. We used the Inception Score (Salimans et al., 2016) and FID (Heusel et al., 2017) to evaluate the CIFAR10 results.³ With the CD+CE method, we achieved an Inception Score of 8.5 and an FID of 23.7. As can be seen in Table 1, these results improve upon previous CD-based techniques.

5. CONCLUSION

We presented new methods for harnessing coarse-to-fine series of distributions for learning EBMs. Our approach views the "time" index of the series as a random variable and defines an auxiliary task of learning the joint distribution of "time" and samples. We illustrated how using this approach in conjunction with the CD method leads to substantially improved results. One limitation of our method is that it requires relatively long training times. However, we believe that with further hyperparameter tuning and better architectural choices, this can be somewhat alleviated in the future. Our joint modeling approach can, in principle, be used within other generative models, but we leave those directions for future research.

A WHY JOINTCD OBEYS DETAILED BALANCE

A sufficient condition for an MCMC process to have a desired stationary distribution is that it maintains a relation known as detailed balance between the conditional distribution of the transitions and the underlying distribution from which we would like to sample. Specifically, given an underlying distribution $p_z(z)$ and an MCMC process with conditional probability $p_{\tilde z|z}(\tilde z|z)$ of stepping from $z$ to $\tilde z$, the detailed balance criterion is given by

$\frac{p_{\tilde z|z}(\tilde z|z)\,p_z(z)}{p_{\tilde z|z}(z|\tilde z)\,p_z(\tilde z)} = 1$.

There exist various MCMC processes that operate on continuous random variables and obey (usually approximately) the detailed balance criterion. In order to apply one of these existing MCMC processes to the joint distribution $p_{z,t}$, we suggest the following way to take a step from $(z,t)$ to $(\tilde z, \tilde t)$. In each step, we begin by ignoring $t$ and generate $\tilde z$ using the continuous MCMC process on $z$ with respect to the marginal distribution $p_z$. We then sample $\tilde t$ based on $\tilde z$ according to the conditional distribution $p_{t|z}$. That is,

$p_{\tilde z,\tilde t|z,t}(\tilde z,\tilde t|z,t) = p_{t|z}(\tilde t|\tilde z)\,p_{\tilde z|z}(\tilde z|z)$.   (11)

As long as the continuous MCMC process obeys the detailed balance criterion, so does the suggested process with respect to $p_{z,t}$. This can be seen from derivation (12).
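The argument can be sanity-checked numerically on a small discrete stand-in for the continuous $z$ (an assumption made only so that every transition probability can be enumerated). A Metropolis kernel with a uniform proposal plays the role of the reversible continuous MCMC operator, followed by a redraw of $t$ from $p_{t|z}$ as in (11). The helper names are hypothetical.

```python
import itertools

def metropolis_kernel(p):
    """Metropolis kernel with a uniform proposal on a finite state space; reversible w.r.t. p."""
    S = len(p)
    P = [[0.0] * S for _ in range(S)]
    for i in range(S):
        for j in range(S):
            if i != j:
                P[i][j] = (1.0 / S) * min(1.0, p[j] / p[i])
        P[i][i] = 1.0 - sum(P[i])  # rejected proposals stay put
    return P

def joint_kernel_is_reversible(p_zt, tol=1e-12):
    """Check detailed balance for the construction in (11): step z with a kernel that is
    reversible w.r.t. the marginal p_z, then redraw t ~ p_{t|z}(. | new z)."""
    Z, T = len(p_zt), len(p_zt[0])
    p_z = [sum(row) for row in p_zt]
    p_t_given_z = [[p_zt[z][t] / p_z[z] for t in range(T)] for z in range(Z)]
    P = metropolis_kernel(p_z)
    states = list(itertools.product(range(Z), range(T)))
    for (z, t), (z2, t2) in itertools.product(states, repeat=2):
        forward = p_zt[z][t] * P[z][z2] * p_t_given_z[z2][t2]
        backward = p_zt[z2][t2] * P[z2][z] * p_t_given_z[z][t]
        if abs(forward - backward) > tol:
            return False
    return True
```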

B THE SAMPLING ALGORITHMS

We outline in Alg. 3 the process for generating samples from the model using the common simulated annealing, and in Alg. 4 the process of using the generalized simulated annealing described in Sec. 3.5.

For the toy model, we used a network of 4 residual blocks containing fully connected layers with a width of 256. As a reference distribution, we used a white Gaussian distribution with zero mean and STD $\sigma_n = 0.3$. We found it beneficial to have the STD of the reference distribution slightly larger than that of the data. We trained the model using JointCD (Alg. 2) with Metropolis-Hastings adjusted Langevin dynamics. The Langevin step size was adaptively adjusted during the run to maintain an average acceptance rate of 60% in the Metropolis-Hastings adjustment stage. This was done by keeping an array of individual step sizes, one for each value of $t$. We found the step size of each $t$ to converge at an early stage of the training to about $0.4\beta_t\sigma_n$, where …
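The per-$t$ step-size adaptation can be sketched as a simple controller. The multiplicative update rule, the learning rate and the initial value are assumptions of this sketch (the text only states that one step size is kept per value of $t$ and adapted toward a 60% acceptance rate); `PerTimeStepSize` is a hypothetical name.

```python
import math

class PerTimeStepSize:
    """Hypothetical per-t step-size controller targeting a fixed MH acceptance rate.
    The log-space update below is an assumed rule, not necessarily the one used here."""

    def __init__(self, T, init=0.1, target=0.6, lr=0.05):
        self.log_mu = [math.log(init)] * T  # one step size per "time" index
        self.target = target
        self.lr = lr

    def get(self, t):
        return math.exp(self.log_mu[t])

    def update(self, t, accept_rate):
        # grow the step when accepting too often, shrink it when rejecting too often
        self.log_mu[t] += self.lr * (accept_rate - self.target)
```

In a training loop, `get(t)` would presumably set the Langevin step size used for chains associated with index $t$, and `update(t, rate)` would be called with the acceptance rate observed for that $t$.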



¹ It is common to represent all models by a single neural network that accepts the "time" index as input. But for each "time" step, the network is exposed only to samples from the corresponding distribution.
² Strictly speaking, when using EBM-based methods, each $p_{z|t}$ is learned up to an unknown normalization constant, preventing the computation of $p_{z,t}$. Methods using this technique only use the individual $p_{z|t}$ at test time.
³ Using the PyTorch implementation from https://pypi.org/project/pytorch-gan-metrics/, which has been shown to reproduce the scores of the original implementation with an error smaller than 0.2%.



Figure 2: (a) Samples from the marginal distribution p z . (b) The conditional distribution p t|z for the three red points of z in (a).

Figure 3: Our parametric model consumes a sample z and outputs a vector of log probabilities indicating the joint likelihoods of z and t for all values of t = 0, . . . , T -1.

Figure 4: We present three MCMC processes running over the joint distribution of the toy model. Note that the intermediate t values are used neither in CD+CE nor in JointCD, and are shown here only for completeness. Particularly, in CD+CE the MCMC runs over the marginal distribution of z, so that t plays no role, whereas in JointCD we draw t only once for the final point.

Figure 5: We depict intermediate samples along with the underlying distribution from the basic simulated annealing (a) and the soft simulated annealing (b) processes applied to the learned spiral toy model. The ground truth distribution is shown in panel (c).

(a) Sample generation using simulated annealing from a model trained with CD+CE. (b) Images generated from models trained on CelebA and CIFAR10 using CD+CE.

Figure 6: Results on CIFAR10 and CelebA.

Algorithm 3: Basic Simulated Annealing
Draw a sample $\tilde z$ from $p_n = p_{z|t=0}$
for t = 0 to T−1 do
    for n = 0 to N do
        $\tilde z \leftarrow \mathcal{T}_{\hat p_{z|t=t;\theta}}(\tilde z)$
    end
end

Algorithm 4: General Simulated Annealing
Draw a sample $\tilde z$ from $p_n = p_{z|t=0}$
for n = 0 to N do
    …
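Alg. 3 can be sketched as a double loop around any per-$t$ transition. To make it runnable without a trained network, the transition below is a hypothetical toy: unadjusted Langevin (a simplification of the Metropolis-adjusted variant used in the paper) on closed-form 1-D Gaussian conditionals $p_{z|t} = \mathcal{N}(\text{mean\_t}(t), \text{var\_t}(t))$, with the step scaled by the conditional variance.

```python
import math
import random

def basic_simulated_annealing(z0, T, N, step):
    """Alg. 3: sweep t = 0, ..., T-1 and apply N+1 MCMC steps w.r.t. p_{z|t} at each t.
    step(z, t) is any MCMC transition targeting the (learned) conditional p_{z|t=t};
    z0 should be a draw from the reference p_n = p_{z|t=0}."""
    z = z0
    for t in range(T):
        for _ in range(N + 1):  # "for n = 0 to N" is inclusive
            z = step(z, t)
    return z

def toy_langevin_step(mean_t, var_t, eps, rng=random):
    """Hypothetical toy transition: unadjusted Langevin on N(mean_t(t), var_t(t)),
    i.e. z + (mu^2 / 2) * grad log p + mu * noise with mu^2 = eps * var_t(t)."""
    def step(z, t):
        h = eps * var_t(t)                    # h plays the role of mu^2
        grad = -(z - mean_t(t)) / var_t(t)    # grad of log N(z; mean, var)
        return z + 0.5 * h * grad + math.sqrt(h) * rng.gauss(0.0, 1.0)
    return step
```

For example, with data $\mathcal{N}(2, 0.1^2)$, $p_n = \mathcal{N}(0, 1)$ and the linear interpolation in (1), the conditionals are $\mathcal{N}(2\alpha_t,\ 0.01\alpha_t^2 + \beta_t^2)$, and chains started from $p_n$ end up concentrated around 2.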

Figure 7: Generated images from CelebA using a model trained with CD+CE.

Table 1: FID and Inception Scores for methods trained unconditionally on CIFAR10.

It is worth noting that, due to the need to take a large number of MCMC steps before each gradient step, training is slow and took 7 days on 4 RTX 2080 Ti GPUs. The full details of the network and the training process can be found in App. C.

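Derivation (12) for App. A:

$\frac{p_{\tilde z,\tilde t|z,t}(\tilde z,\tilde t|z,t)\,p_{z,t}(z,t)}{p_{\tilde z,\tilde t|z,t}(z,t|\tilde z,\tilde t)\,p_{z,t}(\tilde z,\tilde t)} = \frac{p_{t|z}(\tilde t|\tilde z)\,p_{\tilde z|z}(\tilde z|z)\,p_{z,t}(z,t)}{p_{t|z}(t|z)\,p_{\tilde z|z}(z|\tilde z)\,p_{z,t}(\tilde z,\tilde t)}$   (12)
$= \frac{p_{t|z}(\tilde t|\tilde z)\,p_{\tilde z|z}(\tilde z|z)\,p_{t|z}(t|z)\,p_z(z)}{p_{t|z}(t|z)\,p_{\tilde z|z}(z|\tilde z)\,p_{t|z}(\tilde t|\tilde z)\,p_z(\tilde z)}$
$= \frac{p_{\tilde z|z}(\tilde z|z)\,p_z(z)}{p_{\tilde z|z}(z|\tilde z)\,p_z(\tilde z)} = 1$.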

