VECTOR QUANTIZED WASSERSTEIN AUTO-ENCODER

Abstract

Learning deep discrete latent representations offers the promise of better symbolic and summarized abstractions that are more useful to subsequent downstream tasks. Recent work on the Vector Quantized Variational Auto-Encoder (VQ-VAE) has made substantial progress in this direction. However, VQ-VAE quantizes latent representations via the online k-means algorithm, which suffers from poor initialization and non-stationary clusters. To strengthen the clustering quality of the latent representations, we propose the Vector Quantized Wasserstein Auto-Encoder (VQ-WAE), intuitively developed based on the clustering viewpoint of the Wasserstein (WS) distance. Specifically, we endow a discrete distribution over the codewords and learn a deterministic decoder that transports the codeword distribution to the data distribution via minimizing a WS distance between them. We further develop theories connecting this with the clustering viewpoint of the WS distance, allowing us to obtain a better and more controllable clustering solution. Finally, we empirically evaluate our method on several well-known benchmarks, where it achieves better qualitative and quantitative performance than the baselines in terms of codebook utilization and image reconstruction/generation.

1. INTRODUCTION

Learning compact yet expressive representations from large-scale, high-dimensional unlabeled data is an important and long-standing task in machine learning (Kingma & Welling, 2013; Chen et al., 2020; Chen & He, 2021; Zoph et al., 2020). Among many different kinds of methods, the Variational Auto-Encoder (VAE) (Kingma & Welling, 2013) and its variants (Tolstikhin et al., 2017; Alemi et al., 2016; Higgins et al., 2016; Voloshynovskiy et al., 2019) have shown great success in unsupervised representation learning. Although these continuous representation learning methods have been applied successfully to various problems ranging from images (Pathak et al., 2016; Goodfellow et al., 2014; Kingma et al., 2016) to video and audio (Reed et al., 2017; Oord et al., 2016; Kalchbrenner et al., 2017), in some contexts input data are more naturally modeled and encoded as discrete symbols rather than continuous ones. For example, discrete representations are a natural fit for complex reasoning, planning, and predictive learning (Van Den Oord et al., 2017). This motivates the need for learning discrete representations that preserve the insightful characteristics of the input data. The Vector Quantization Variational Auto-Encoder (VQ-VAE) (Van Den Oord et al., 2017) is a pioneering generative model that successfully combines the VAE framework with discrete latent representations. In particular, vector quantized models learn a compact discrete representation using a deterministic encoder-decoder architecture in the first stage, and subsequently apply this highly compressed representation to various downstream tasks, including image generation (Esser et al., 2021), cross-modal translation (Kim et al., 2022), and image recognition (Yu et al., 2021).
While VQ-VAE has been widely applied to representation learning in many areas (Henter et al., 2018; Baevski et al., 2020; Razavi et al., 2019; Kumar et al., 2019; Dieleman et al., 2018; Yan et al., 2021), it is known to suffer from codebook collapse, i.e., low codebook usage: most embedded latent vectors are quantized to just a few discrete codewords, while the other codewords are rarely used, or dead, owing to poor initialization of the codebook. This reduces the information capacity of the bottleneck (Roy et al., 2018; Takida et al., 2022; Yu et al., 2021). To mitigate this issue, additional training heuristics have been proposed, such as the exponential moving average (EMA) update (Van Den Oord et al., 2017; Razavi et al., 2019), the soft expectation maximization (EM) update (Roy et al., 2018), and codebook reset (Dhariwal et al., 2020; Williams et al., 2020). Notably, the soft EM update (Roy et al., 2018) connects the EMA update with an EM algorithm and softens the EM algorithm with a stochastic posterior. Codebook reset randomly reinitializes unused or rarely used codewords to one of the encoder outputs (Dhariwal et al., 2020) or to vectors near codewords of high usage (Williams et al., 2020). Takida et al. (2022) suspect that deterministic quantization is the cause of codebook collapse and extend the standard VAE with stochastic quantization and a trainable posterior categorical distribution, showing that annealing the stochasticity of the quantization process significantly improves codebook utilization. Additionally, the WS distance has been applied successfully to generative models and continuous representation learning (Arjovsky et al., 2017; Gulrajani et al., 2017; Tolstikhin et al., 2017) owing to its nice properties and rich theory. It is therefore natural to ask: "Can we take advantage of the intuitive properties of the WS distance and its mature theory for learning highly compact yet expressive discrete representations?"
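To make the codebook-reset heuristic above concrete, the following is a minimal NumPy sketch that reinitializes rarely used codewords to randomly chosen encoder outputs. This is an illustration only, not the implementation of Dhariwal et al. (2020); the function name, the usage threshold, and the uniform resampling rule are assumptions for exposition.

```python
import numpy as np

def reset_dead_codewords(codebook, usage_counts, encoder_outputs, min_usage=1, rng=None):
    """Reinitialize rarely used ("dead") codewords to random encoder outputs.

    codebook:        (K, D) array of codewords
    usage_counts:    (K,) number of times each codeword was selected recently
    encoder_outputs: (B, D) batch of latent vectors f_e(x)
    """
    if rng is None:
        rng = np.random.default_rng(0)
    dead = usage_counts < min_usage              # mask of codewords considered dead
    n_dead = int(dead.sum())
    if n_dead > 0:
        # replace each dead codeword with a randomly drawn latent vector
        idx = rng.integers(0, encoder_outputs.shape[0], size=n_dead)
        codebook = codebook.copy()
        codebook[dead] = encoder_outputs[idx]
    return codebook, dead
```

In practice such a reset is applied periodically during training, so that dead codewords are pulled back into regions of the latent space that the encoder actually occupies.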
Toward this question, in this paper we develop solid theories connecting the bodies of theory and viewpoints of the WS distance, generative models, and deep discrete representation learning. In particular, (a) we first endow a discrete distribution over the codebook and propose learning a "deterministic decoder transporting the codeword distribution to the data distribution" via minimizing the WS distance between them; (b) to devise a trainable algorithm, we develop Theorem 3.1 to equivalently turn the above WS minimization into push-forwarding the data distribution to the codeword distribution via minimizing a WS distance between "the latent representation and codeword distributions"; (c) more interestingly, our Corollary 3.1 proves that when minimizing the WS distance between the latent representation and codeword distributions, the codewords tend to flexibly move to the clustering centroids of the latent representations, with control over the proportion of latent representations associated with each centroid. We argue and empirically demonstrate that by using the clustering viewpoint of a WS distance to learn the codewords, we obtain more controllable and better centroids than with the simple k-means of VQ-VAE (cf. Sections 3.1 and 5.2). Our method, called Vector Quantized Wasserstein Auto-Encoder (VQ-WAE), applies the WS distance to learn a more controllable codebook, hence improving codebook utilization. We conduct comprehensive experiments to demonstrate our key contributions by comparing with VQ-VAE (Van Den Oord et al., 2017) and SQ-VAE (Takida et al., 2022), a recent work that improves codebook utilization. The experimental results show that our VQ-WAE achieves better codebook utilization with higher codebook perplexity, hence leading to lower (compared with VQ-VAE) or comparable (compared with SQ-VAE) reconstruction error, with a significantly lower reconstruction Fréchet Inception Distance (FID) score (Heusel et al., 2017).
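The clustering viewpoint in (c) can be illustrated with a small optimal-transport computation: prescribing a discrete distribution over the codewords fixes the proportion of latent representations each codeword absorbs. The sketch below uses Cuturi-style entropy-regularized Sinkhorn iterations between a batch of latents (uniform weights) and a codeword distribution; this is a didactic illustration, not the paper's training algorithm, and the function name and the entropic regularization are assumptions.

```python
import numpy as np

def sinkhorn_plan(latents, codewords, codeword_probs, eps=0.5, iters=1000):
    """Entropy-regularized OT plan between a batch of latents (uniform weights)
    and a discrete codeword distribution, via Sinkhorn iterations."""
    B, K = latents.shape[0], codewords.shape[0]
    a = np.full(B, 1.0 / B)          # uniform weights over the latent batch
    b = codeword_probs               # prescribed codeword proportions
    # squared-Euclidean cost matrix between latents and codewords
    C = ((latents[:, None, :] - codewords[None, :, :]) ** 2).sum(-1)
    Kmat = np.exp(-C / eps)          # Gibbs kernel
    u = np.ones(B)
    for _ in range(iters):           # alternate scaling to match both marginals
        v = b / (Kmat.T @ u)
        u = a / (Kmat @ v)
    return u[:, None] * Kmat * v[None, :]   # plan with rows ~ a and columns ~ b
```

The column sums of the returned plan match `codeword_probs`, which is exactly the "controllable proportion of latent representations per centroid" effect that distinguishes this view from plain k-means.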
Generally, a better quantizer in stage 1 naturally contributes to the stage-2 downstream tasks (Yu et al., 2021; Zheng et al., 2022). To further demonstrate this, we conduct comprehensive experiments on four benchmark datasets for both unconditional and class-conditional generation tasks. The experimental results indicate that, from the codebooks of our VQ-WAE, we can generate better images with lower FID scores. In essence, the purpose of VQ-VAE training is to form the latent representations into clusters and to adjust the codewords to be the centroids of these clusters.



2. VECTOR QUANTIZED VARIATIONAL AUTO-ENCODER
Given a training set D = {x_1, ..., x_N} ⊂ R^V, VQ-VAE (Van Den Oord et al., 2017) aims at learning a codebook formed by a set of codewords C = [c_k]_{k=1}^{K} ∈ R^{K×D} on the latent space Z = R^D, an encoder f_e to map the data examples to the codewords, and a decoder f_d (i.e., q(x | z)) to accurately reconstruct the data examples from the codewords. Given a data example x, the encoder f_e (i.e., p(z | x)) associates x with the codeword \bar{f}_e(x) = c_{\bar{k}}, where \bar{k} = argmin_k d_z(f_e(x), c_k) and d_z is a metric on the latent space. The objective function of VQ-VAE is

E_{x ~ P_x} [ d_x(f_d(\bar{f}_e(x)), x) + d_z(sg(f_e(x)), C) + β d_z(f_e(x), sg(C)) ],

where P_x = (1/N) Σ_{n=1}^{N} δ_{x_n} is the empirical data distribution, sg denotes the stop-gradient operator, d_x is a cost metric on the data space, β is set between 0.1 and 2.0 (Van Den Oord et al., 2017), and d_z(f_e(x), C) = Σ_{c ∈ C} d_z(f_e(x), c).
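The quantization step and the three loss terms above can be sketched as follows. This is a minimal NumPy forward pass with squared-Euclidean d_x and d_z and the per-example nearest codeword in place of the full sum over C; in a real implementation the reconstruction gradient is passed straight through the quantization step and sg(·) is realized with detach(), so the codebook and commitment terms, although numerically equal here, drive different parameters.

```python
import numpy as np

def vq_vae_losses(z_e, codebook, x, x_rec, beta=0.25):
    """Forward-pass VQ-VAE loss terms for a batch (NumPy sketch, no gradients).

    z_e:      (B, D) encoder outputs f_e(x)
    codebook: (K, D) codewords
    x, x_rec: (B, V) inputs and decoder reconstructions
    """
    # nearest-codeword assignment: k_bar = argmin_k ||f_e(x) - c_k||^2
    d2 = ((z_e[:, None, :] - codebook[None, :, :]) ** 2).sum(-1)   # (B, K)
    k = d2.argmin(axis=1)
    z_q = codebook[k]                                   # quantized latents c_{k_bar}
    rec_loss = ((x_rec - x) ** 2).mean()                # d_x(f_d(f_e_bar(x)), x)
    codebook_loss = ((z_q - z_e) ** 2).mean()           # d_z(sg(f_e(x)), C): moves codewords
    commit_loss = beta * ((z_e - z_q) ** 2).mean()      # beta * d_z(f_e(x), sg(C)): moves encoder
    return rec_loss + codebook_loss + commit_loss, k
```

The argmin over squared distances is the deterministic quantization whose clustering behavior, as discussed above, resembles online k-means: codewords are pulled toward the mean of the latents assigned to them.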

