LEARNING DIFFUSION BRIDGES ON CONSTRAINED DOMAINS

Abstract

Diffusion models have recently achieved promising results in generative learning. However, because diffusion processes are most naturally defined on the unconstrained Euclidean space R^d, key challenges arise in developing diffusion-based models for learning data on constrained and structured domains. We present a simple and unified framework that can be easily adapted to various types of domains, including product spaces of any type (bounded or unbounded, continuous or discrete, categorical or ordinal, or any mix thereof). In our model, the diffusion process is driven by a drift force that is the sum of two terms: a singular force, designed by Doob's h-transform, that ensures all outcomes of the process belong to the desired domain, and a non-singular neural force field that is trained so that the outcome follows the data distribution statistically. Experiments show that our method performs superbly on generating tabular data, images, semantic segments, and 3D point clouds.

1. INTRODUCTION

Diffusion-based deep generative models, notably score matching with Langevin dynamics (SMLD) (Song & Ermon, 2019, 2020), denoising diffusion probabilistic models (DDPM) (Ho et al., 2020), and their variants (e.g., Song et al., 2020b,a; Kong & Ping, 2021; Song et al., 2021; Nichol & Dhariwal, 2021), have been shown to achieve new state-of-the-art results for image synthesis (Dhariwal & Nichol, 2021; Ramesh et al., 2022; Ho et al., 2022; Liu et al., 2021), audio synthesis (Chen et al., 2020; Kong et al., 2020), point cloud synthesis (Luo & Hu, 2021a,b; Zhou et al., 2021), and many other AI tasks. These methods train a deep neural network as the drift force of a diffusion process that generates data, and have been shown to outperform competitors, mainly GANs and VAEs, in stability and sample diversity (Xiao et al., 2021; Ho et al., 2020; Song et al., 2020b). However, due to the continuous nature of diffusion processes, the standard approaches are restricted to generating unconstrained continuous data in R^d. For generating data constrained to special structured domains, such as discrete or bounded data or mixes thereof, special techniques, e.g., dequantization (Uria et al., 2013; Ho et al., 2019) and multinomial diffusion (Hoogeboom et al., 2021; Austin et al., 2021), need to be developed case by case, and the results still tend to be unsatisfying despite promising recent advances (Hoogeboom et al., 2021; Austin et al., 2021).

This work proposes a simple and unified framework for learning diffusion models on general constrained domains Ω embedded in the Euclidean space R^d. The idea is to learn a continuous R^d-valued diffusion process Z_t on the time interval t ∈ [0, T], with a carefully designed force field, such that the final state Z_T is guaranteed to 1) fall into the desired domain Ω, and 2) follow the data distribution asymptotically.
We achieve both goals by leveraging a key tool in stochastic calculus called Doob's h-transform (Doob, 1984), which provides a formula for deriving diffusion processes whose final states are guaranteed to fall into a specific set or equal a specific value.

Algorithm 1 Learning Diffusion Models on Constrained Domains (a Simple Example)

Input: A dataset $\mathcal{D} := \{x^{(i)}\}$ drawn from a distribution $\Pi^*$ on a domain $\Omega = \{e_1, e_2, \ldots, e_K\}$.

Goal: Learn a diffusion model that terminates at time $T$ and generates samples from $\Pi^*$.

Learning: Solve the optimization below with stochastic gradient descent (or other optimizers):
\[
\theta^* = \arg\min_\theta \int_0^T \mathbb{E}_{x\sim \mathcal{D}}\Big[ \big\| f_\theta(Z_t, t) - \nabla_{Z_t}\log \omega_\Omega(x \mid Z_t, t) \big\|^2 \Big]\, \mathrm{d}t, \tag{1}
\]
where
\[
\omega_\Omega(x \mid z, t) = \frac{\exp\!\big(-\tfrac{\|z - x\|^2}{2(T-t)}\big)}{\sum_{e\in\Omega} \exp\!\big(-\tfrac{\|z - e\|^2}{2(T-t)}\big)}, \qquad
Z_t = \frac{t}{T}\,x + \Big(1 - \frac{t}{T}\Big)x_0 + \sqrt{\frac{t(T-t)}{T}}\,\xi, \tag{2}
\]
with $x$ drawn from the dataset $\mathcal{D}$, $\xi \sim \mathcal{N}(0, I)$, and $x_0$ any initial point.

Sampling: Generate a sample $Z_T$ from
\[
\mathrm{d}Z_t = \bigg( f_{\theta^*}(Z_t, t) + \nabla_{Z_t} \log \sum_{e\in\Omega} \exp\Big(-\frac{\|Z_t - e\|^2}{2(T-t)}\Big) \bigg)\,\mathrm{d}t + \mathrm{d}W_t, \qquad Z_0 = x_0.
\]

Remark: When the domain $\Omega$ is a manifold (e.g., a line or surface) in $\mathbb{R}^d$, simply replace the sum $\sum_{e\in\Omega}$ with the corresponding line or surface (in general, Hausdorff) integral $\int_\Omega$ over $\Omega$. Our simple procedure can be applied to any domain $\Omega$ once a properly defined summation (for discrete sets) or integration (for continuous domains) can be evaluated.

To give a quick overview of the practical intuition without invoking the mathematical theory, we show in Algorithm 1 a simple instance of the framework when the domain is a discrete set $\Omega = \{e_1, \ldots, e_K\}$. The idea is to set up the diffusion model in the form
\[
\mathrm{d}Z_t = \big( f_\theta(Z_t, t) + \nabla_{Z_t} \psi_\Omega(Z_t, t) \big)\,\mathrm{d}t + \mathrm{d}W_t, \qquad
\psi_\Omega(z, t) := \log \sum_{e\in\Omega} \exp\Big(-\frac{\|z - e\|^2}{2(T-t)}\Big), \tag{3}
\]
where the drift is the sum of a non-singular (e.g., bounded) term $f_\theta(z, t)$, a trainable neural force field with parameter $\theta$, and a singular term $\nabla_z \psi_\Omega(z, t)$, which drives $Z_t$ towards the set $\Omega$ as a gradient ascent on $\psi_\Omega(z, t)$. Here $\psi_\Omega(z, t)$ measures the closeness of $z$ to the set $\Omega$, as the log-likelihood of a Gaussian mixture model (GMM) centered on the elements of $\Omega$ with variance $T-t$. When $t$ approaches the terminal time $T$, the variance $T-t$ of the GMM goes to zero and the magnitude of $\nabla_z \psi_\Omega(z, t)$ grows to infinity, hence ensuring that $Z_T$ must belong to $\Omega$.
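To make the effect of the singular drift concrete, the following is a minimal NumPy sketch of the sampling dynamics with the neural term omitted (f_θ ≡ 0): the gradient of ψ_Ω alone already forces the Euler-Maruyama iterates onto (a small neighborhood of) the discrete set Ω. The toy domain, step counts, and function names here are our own illustration, not the paper's implementation.

```python
import numpy as np

def psi_grad(z, points, t, T):
    """Gradient of psi_Omega(z, t) = log sum_e exp(-||z - e||^2 / (2(T - t))).

    Analytically this equals sum_e omega(e | z, t) * (e - z) / (T - t),
    where omega is the softmax weight of Eq (2).
    """
    sq = np.sum((points - z) ** 2, axis=1)        # ||z - e||^2 for each e in Omega
    logits = -sq / (2.0 * (T - t))
    w = np.exp(logits - logits.max())             # numerically stable softmax
    w /= w.sum()
    return (w[:, None] * (points - z)).sum(axis=0) / (T - t)

def sample_bridge(points, T=1.0, n_steps=2000, seed=0):
    """Euler-Maruyama simulation of dZ = grad psi_Omega dt + dW with f_theta = 0."""
    rng = np.random.default_rng(seed)
    dt = T / n_steps
    z = np.zeros(points.shape[1])                 # start at x_0 = 0
    for k in range(n_steps - 1):                  # stop one step short of t = T
        t = k * dt
        z = z + psi_grad(z, points, t, T) * dt + np.sqrt(dt) * rng.normal(size=z.shape)
    return z

# Toy domain Omega = {1, 2, 3, 4} embedded in R^1, as in Figure 1.
Omega = np.array([[1.0], [2.0], [3.0], [4.0]])
z_T = sample_bridge(Omega)
print(min(abs(z_T.item() - e) for e in [1.0, 2.0, 3.0, 4.0]))  # small: Z_T snaps to Omega
```

Without a trained f_θ, the terminal point lands on Ω but with the wrong (prior-like) distribution over its elements; fitting f_θ via Eq (1) is what aligns the terminal distribution with the data.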
In particular, note that
\[
\nabla_z \psi_\Omega(z, t) = \sum_{e\in\Omega} \omega_\Omega(e \mid z, t)\, \frac{e - z}{T - t}, \qquad
\omega_\Omega(e \mid z, t) = \frac{\exp\big(-\tfrac{\|z - e\|^2}{2(T-t)}\big)}{\exp(\psi_\Omega(z, t))},
\]
which grows at an $O(1/(T-t))$ rate as $t \to T$; here $\omega_\Omega(e \mid z, t)$ is the softmax probability measuring the relative closeness of $z$ to the elements $e$ of $\Omega$ (see also Eq (2)). As we show in Section 2.3, once $f_\theta$ is non-singular in the sense of the mild condition $\int_0^T \mathbb{E}[\| f_\theta(Z_t, t)\|^2]\,\mathrm{d}t < +\infty$, the diffusion model in (3) is guaranteed to yield a final state $Z_T$ that belongs to $\Omega$, and hence provides a flexible model family on $\Omega$. Moreover, as shown in Eq (1) in Algorithm 1, the neural field $f_\theta$ can simply be trained to approximate $\nabla_z \log \omega_\Omega(e \mid z, t)$, with $e$ plugged in as the data point that we expect to reach when starting from $z$ at time $t$. Intuitively, the fitted $f_\theta$ increases the relative probability of the observed data points and hence allows us to fit the data distribution. Empirically, diffusion models learned through the $\Omega$-bridge achieve favorable results in generating mixed discrete/continuous tabular data, point clouds on grids, categorical semantic segments, and discrete CIFAR10 images.
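The training step can likewise be sketched in a few lines: draw Z_t from the Brownian bridge of Eq (2), then regress f_θ onto the target ∇_z log ω_Ω(x | z, t) = (x − z)/(T − t) − ∇_z ψ_Ω(z, t). The sketch below computes this target (the network and optimizer loop are left out); function names and the toy setup are our own, hypothetical illustration under these assumptions.

```python
import numpy as np

def bridge_sample(x, x0, t, T, rng):
    """Sample Z_t from the Brownian bridge from x0 at time 0 to x at time T (Eq 2)."""
    xi = rng.normal(size=np.shape(x))
    return (t / T) * np.asarray(x) + (1.0 - t / T) * np.asarray(x0) \
        + np.sqrt(t * (T - t) / T) * xi

def regression_target(x, z, points, t, T):
    """Target grad_z log omega_Omega(x | z, t) that f_theta is trained to fit.

    Since log omega = -||z - x||^2 / (2(T - t)) - psi_Omega(z, t), the gradient is
    (x - z)/(T - t) - grad_z psi_Omega(z, t).
    """
    sq = np.sum((points - z) ** 2, axis=1)
    logits = -sq / (2.0 * (T - t))
    w = np.exp(logits - logits.max())             # softmax weights omega(e | z, t)
    w /= w.sum()
    grad_psi = (w[:, None] * (points - z)).sum(axis=0) / (T - t)
    return (np.asarray(x) - z) / (T - t) - grad_psi

# One Monte Carlo draw of the integrand in Eq (1) for a toy Omega = {1, 2} in R^1.
rng = np.random.default_rng(0)
Omega = np.array([[1.0], [2.0]])
x, x0, t, T = np.array([1.0]), np.array([0.0]), 0.5, 1.0
z_t = bridge_sample(x, x0, t, T, rng)
target = regression_target(x, z_t, Omega, t, T)   # f_theta(z_t, t) regresses onto this
```

Note the sanity check implied by the formula: if Ω contains only the data point x itself, then ω_Ω(x | z, t) ≡ 1 and the target is identically zero, so f_θ has nothing to learn and the singular drift alone reproduces the data.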



Figure 1: An Ω-bridge on the discrete domain Ω = {1, 2, 3, 4}.

Availability: //github.com/gnobitab

