VECTOR QUANTIZED WASSERSTEIN AUTO-ENCODER

Abstract

Learning deep discrete latent representations offers the promise of better symbolic and summarized abstractions that are more useful for subsequent downstream tasks. Recent work on the Vector Quantized Variational Auto-Encoder (VQ-VAE) has made substantial progress in this direction. However, VQ-VAE quantizes latent representations with an online k-means algorithm, which suffers from poor initialization and non-stationary clusters. To strengthen the clustering quality of the latent representations, we propose the Vector Quantized Wasserstein Auto-Encoder (VQ-WAE), developed from the clustering viewpoint of the Wasserstein (WS) distance. Specifically, we endow the codewords with a discrete distribution and learn a deterministic decoder that transports the codeword distribution to the data distribution by minimizing the WS distance between them. We further develop theory connecting this objective to the clustering viewpoint of the WS distance, which yields a better and more controllable clustering solution. Finally, we empirically evaluate our method on several well-known benchmarks, where it outperforms the baselines both qualitatively and quantitatively in terms of codebook utilization and image reconstruction/generation.

1. INTRODUCTION

Learning compact yet expressive representations from large-scale, high-dimensional unlabeled data is an important and long-standing task in machine learning (Kingma & Welling, 2013; Chen et al., 2020; Chen & He, 2021; Zoph et al., 2020). Among many different kinds of methods, the Variational Auto-Encoder (VAE) (Kingma & Welling, 2013) and its variants (Tolstikhin et al., 2017; Alemi et al., 2016; Higgins et al., 2016; Voloshynovskiy et al., 2019) have shown great success in unsupervised representation learning. Although these continuous representation learning methods have been applied successfully to various problems ranging from images (Pathak et al., 2016; Goodfellow et al., 2014; Kingma et al., 2016) to video and audio (Reed et al., 2017; Oord et al., 2016; Kalchbrenner et al., 2017), in some contexts input data are more naturally modeled and encoded as discrete symbols rather than continuous ones. For example, discrete representations are a natural fit for complex reasoning, planning, and predictive learning (Van Den Oord et al., 2017). This motivates the need for learning discrete representations that preserve the insightful characteristics of the input data. The Vector Quantized Variational Auto-Encoder (VQ-VAE) (Van Den Oord et al., 2017) is a pioneering generative model that successfully combines the VAE framework with discrete latent representations. In particular, vector quantized models learn a compact discrete representation using a deterministic encoder-decoder architecture in a first stage, and subsequently apply this highly compressed representation to various downstream tasks, including image generation (Esser et al., 2021), cross-modal translation (Kim et al., 2022), and image recognition (Yu et al., 2021).
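To make the quantization step concrete, the following is a minimal NumPy sketch of the nearest-codeword lookup at the heart of vector quantized models: each encoder output is replaced by the closest entry of a learned codebook under Euclidean distance. This is an illustrative simplification, not the authors' implementation; the array shapes and the function name `quantize` are our own choices.

```python
import numpy as np

def quantize(z_e, codebook):
    """Map each encoder output vector to its nearest codeword (L2 distance).

    z_e: (N, D) array of encoder outputs.
    codebook: (K, D) array of learnable codewords.
    Returns the quantized latents (N, D) and the chosen indices (N,).
    """
    # Pairwise squared distances between encoder outputs and codewords: (N, K)
    d = ((z_e[:, None, :] - codebook[None, :, :]) ** 2).sum(axis=-1)
    idx = d.argmin(axis=1)   # index of the nearest codeword per input
    z_q = codebook[idx]      # replace each z_e with its nearest codeword
    return z_q, idx

rng = np.random.default_rng(0)
codebook = rng.normal(size=(8, 4))   # K = 8 codewords of dimension D = 4
z_e = rng.normal(size=(16, 4))       # a batch of 16 encoder outputs
z_q, idx = quantize(z_e, codebook)
```

Because `argmin` is non-differentiable, VQ-VAE training in practice copies gradients from the quantized latents straight through to the encoder outputs (the straight-through estimator).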
While VQ-VAE has been widely applied to representation learning in many areas (Henter et al., 2018; Baevski et al., 2020; Razavi et al., 2019; Kumar et al., 2019; Dieleman et al., 2018; Yan et al., 2021), it is known to suffer from codebook collapse, i.e., low codebook usage: most embedded latent vectors are quantized to just a few discrete codewords, while the remaining codewords are rarely used, or dead, due to poor initialization of the codebook. This reduces the information capacity of the bottleneck (Roy et al., 2018; Takida et al., 2022; Yu et al., 2021). To mitigate this issue, additional training heuristics have been proposed, such as the exponential moving average (EMA) update (Van Den Oord et al., 2017; Razavi et al., 2019), the soft expectation maximization (EM) update (Roy et al., 2018), and codebook reset (Dhariwal et al., 2020; Williams et al., 2020). Notably, the soft expectation maximization (EM) update (Roy et al., 2018) connects the EMA update
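The EMA heuristic mentioned above replaces gradient updates of the codebook with exponentially decayed running averages of the assigned encoder outputs, which behaves like an online k-means step. Below is a hedged NumPy sketch of one such update; the function name, the smoothing scheme, and the default `decay` value follow common practice but are our own illustrative choices, not a verbatim reproduction of any cited implementation.

```python
import numpy as np

def ema_update(codebook, cluster_size, cluster_sum, z_e, idx,
               decay=0.99, eps=1e-5):
    """One EMA-style codebook update, as used to stabilize VQ-VAE training.

    codebook: (K, D) current codewords.
    cluster_size: (K,) running count of vectors assigned to each codeword.
    cluster_sum: (K, D) running sum of vectors assigned to each codeword.
    z_e: (N, D) encoder outputs of the current batch.
    idx: (N,) nearest-codeword index for each row of z_e.
    """
    K, D = codebook.shape
    one_hot = np.eye(K)[idx]              # (N, K) hard assignments
    n = one_hot.sum(axis=0)               # vectors per codeword in this batch
    s = one_hot.T @ z_e                   # (K, D) per-codeword sum of vectors
    # Decayed running statistics (the "exponential moving average")
    cluster_size = decay * cluster_size + (1.0 - decay) * n
    cluster_sum = decay * cluster_sum + (1.0 - decay) * s
    # Laplace smoothing keeps rarely used codewords from dividing by ~zero
    total = cluster_size.sum()
    smoothed = (cluster_size + eps) / (total + K * eps) * total
    codebook = cluster_sum / smoothed[:, None]
    return codebook, cluster_size, cluster_sum
```

A codeword that receives no assignments for many batches sees its running count decay toward zero, which is exactly the "dead codeword" failure mode the collapse literature describes.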

