DEEP GENERATIVE WASSERSTEIN GRADIENT FLOWS

Abstract

Deep generative modeling is a rapidly advancing field with a wealth of modeling choices developed in the past decade. Among them, Wasserstein gradient flows (WGF) are a powerful and theoretically rich class of methods. However, their application to high-dimensional distributions remains relatively underexplored. In this paper, we present Deep Generative Wasserstein Gradient Flows (DGGF), which constructs a WGF minimizing the entropy-regularized f-divergence between two distributions. We demonstrate how to train the deep density ratio estimator required for the WGF and apply it to the task of generative modeling. Experiments demonstrate that DGGF is able to synthesize high-fidelity images at resolutions up to 128 × 128, directly in data space. We show that DGGF provides an interpretable diagnostic of sample quality by naturally estimating the KL divergence throughout the gradient flow. Finally, we demonstrate DGGF's modularity through composition with external density ratio estimators for conditional generation, as well as for unpaired image-to-image translation, without modifications to the underlying framework.

1. INTRODUCTION

Gradient flow methods are a powerful and general class of techniques with diverse applications ranging from physics (Carrillo et al., 2019; Adams et al., 2011) and sampling (Bernton, 2018) to neural network optimization (Chizat & Bach, 2018) and reinforcement learning (Richemond & Maginnis, 2017; Zhang et al., 2018). In particular, Wasserstein gradient flow (WGF) methods are a popular specialization that model the gradient dynamics on the space of probability measures with respect to the Wasserstein metric; these methods aim to construct the optimal path between two probability measures, a source distribution q(x) and a target distribution p(x), where optimality refers to the path of steepest descent in Wasserstein space. The freedom in choosing q(x) and p(x) when constructing the WGF makes the framework a natural fit for a variety of generative modeling tasks. For data synthesis, we choose q(x) to be a simple distribution that is easy to draw samples from (e.g., a Gaussian), and p(x) to be a complex distribution that we would like to learn (e.g., the distribution of natural images). The WGF then constructs the optimal path from the simple distribution to synthesize data resembling that of the complex distribution. Alternatively, we can choose both p(x) and q(x) to be distributions from different domains of the same modality (e.g., images from separate domains); the WGF then naturally performs domain translation. However, despite this fit and the wealth of theoretical work established over the past decades (Ambrosio et al., 2005; Santambrogio, 2017), applications of WGFs to generative modeling of high-dimensional distributions remain under-explored and limited. A key difficulty is that the 2-Wasserstein distance and divergence functionals are generally intractable.
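To make the data-synthesis setting above concrete, the following is a minimal toy sketch of a discretized gradient flow in data space. It is our own simplified illustration, not the paper's method: instead of a trained deep density ratio estimator, it uses the analytic score of a one-dimensional Gaussian target, and the entropy regularization appears as a Brownian noise term in Langevin-type particle dynamics. All function names and the Gaussian setup are hypothetical.

```python
import numpy as np

def target_score(x, mu=2.0, sigma=1.0):
    """Analytic score grad_x log p(x) of the Gaussian target p = N(mu, sigma^2).
    In a learned flow, this direction would come from a density ratio estimator."""
    return -(x - mu) / sigma**2

def simulate_flow(n_particles=5000, n_steps=1000, eta=0.05, seed=0):
    rng = np.random.default_rng(seed)
    # Samples from a simple source distribution q0 = N(-2, 1).
    x = rng.normal(-2.0, 1.0, size=n_particles)
    for _ in range(n_steps):
        noise = rng.normal(size=n_particles)
        # Forward-Euler step: drift along the target score plus the
        # entropy-regularization term, realized as injected Gaussian noise.
        x = x + eta * target_score(x) + np.sqrt(2.0 * eta) * noise
    return x

samples = simulate_flow()
print(samples.mean(), samples.std())  # particles settle near the target N(2, 1)
```

The step size `eta` trades off discretization error against convergence speed; with a learned, high-dimensional ratio estimator the same particle-update structure applies, with the analytic score replaced by the network's output.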
Existing works rely on complex optimization schemes with constraints that add to model complexity, such as approximating the 2-Wasserstein distance with input convex neural networks (Mokrov et al., 2021), solving dual variational optimization problems via the Fenchel conjugate (Fan et al., 2021), or adopting a particle simulation approach while amortizing sample generation to auxiliary generators (Gao et al., 2019; 2022). In this work, we take a step towards resolving the shortcomings of WGF methods for deep generative modeling. We propose Deep Generative Wasserstein Gradient Flows (DGGF), which is formulated using the gradient flow of entropy-regularized f-divergences (Fig. 1). As this formulation involves

