LEARNING DISCRETE REPRESENTATION WITH OPTIMAL TRANSPORT QUANTIZED AUTOENCODERS

Abstract

Vector quantized variational autoencoder (VQ-VAE) has recently emerged as a powerful generative model for learning discrete representations. Like other vector quantization methods, VQ-VAE is prone to codebook collapse during training, i.e. only a fraction of the codes are used, which limits its reconstruction quality. To mitigate this, VQ-VAE training often relies on carefully designed heuristics that encourage higher codebook usage. In this paper, we propose a simple yet effective approach that overcomes this issue through optimal transport, which regularizes the quantization by explicitly assigning an equal number of samples to each code. The proposed approach, named OT-VAE, enforces full utilization of the codebook without requiring heuristics such as stop-gradient, exponential moving average, or codebook reset. We empirically validate our approach on three data modalities: images, speech, and 3D human motion. Across all modalities, OT-VAE achieves better reconstruction with higher perplexity than other VQ-VAE variants on several datasets. In particular, OT-VAE achieves state-of-the-art results on the AIST++ dataset for 3D dance generation. Our code will be released upon publication.

1. INTRODUCTION

Unsupervised generative modeling aims at generating samples that follow the same distribution as the observed data. Recent deep generative models have shown impressive performance in generating various data modalities such as images, text, and audio, owing to the huge number of parameters in their models. Well-known examples include VQ-GAN (Esser et al., 2021) for high-resolution image synthesis, DALLE (Ramesh et al., 2021) for realistic image generation from a description in natural language, and Jukebox (Dhariwal et al., 2020) for music generation. Notably, all these models are based, at least partly, on Vector Quantized Variational Autoencoders (VQ-VAE) (Van Den Oord et al., 2017). The success of VQ-VAE is mostly attributed to its ability to learn discrete, rather than continuous, latent representations, and to its decoupling of learning the discrete representation from learning the prior. Since the quality of the discrete representation is essential to the quality of the generation, our work improves discrete representation learning for arbitrary data modalities.

VQ-VAE is a variant of VAEs (Kingma & Welling, 2014) that first encodes the input data to a discrete variable in a latent space, and then decodes the latent variable back to a sample in the input space. The discrete representation of the latent variable is obtained by vector quantization, generally through a nearest-neighbor lookup in a learnable codebook. A new sample is then generated by decoding a discrete latent variable sampled from an approximate prior, which is learned on the space of the encoded discrete latent variables in a decoupled fashion using any autoregressive model (Van Den Oord et al., 2017). Despite its promising results on many tasks involving complex data modalities, the naive training scheme of VQ-VAE used in (Van Den Oord et al., 2017) often suffers from codebook collapse (Takida et al., 2022), i.e. only a fraction of the codes are effectively used, which largely limits the quality of the discrete latent representations. To address this, many techniques and variants have been proposed, such as stop-gradient along with the commitment and embedding losses (Van Den Oord et al., 2017), exponential moving average (EMA) codebook updates (Van Den Oord et al., 2017), codebook reset (Williams et al., 2020), and a stochastic variant (SQ-VAE) (Takida et al., 2022).
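To make the contrast concrete, the following NumPy sketch compares the standard nearest-neighbor quantization of VQ-VAE with a balanced assignment computed by Sinkhorn iterations on the entropic optimal transport problem. This is an illustrative toy implementation, not the paper's actual method: the function names, the choice of entropic regularization `eps`, and the hard rounding by `argmax` are our assumptions for exposition.

```python
import numpy as np

def _logsumexp(a, axis):
    """Numerically stable log-sum-exp along an axis, keeping dims."""
    m = a.max(axis=axis, keepdims=True)
    return m + np.log(np.exp(a - m).sum(axis=axis, keepdims=True))

def nearest_neighbor_quantize(z, codebook):
    """Standard VQ-VAE lookup: each encoder output maps to its closest code."""
    # z: (N, D) encoder outputs; codebook: (K, D) code vectors
    d = ((z[:, None, :] - codebook[None, :, :]) ** 2).sum(-1)  # (N, K) squared distances
    return d.argmin(axis=1)

def sinkhorn_quantize(z, codebook, eps=0.05, n_iter=500):
    """Balanced assignment via Sinkhorn iterations (log domain):
    each code receives ~N/K samples, which prevents codebook collapse."""
    N, K = z.shape[0], codebook.shape[0]
    d = ((z[:, None, :] - codebook[None, :, :]) ** 2).sum(-1)  # transport costs
    log_p = -d / eps                                           # Gibbs kernel, log domain
    for _ in range(n_iter):
        log_p -= _logsumexp(log_p, axis=1)                     # rows: unit mass per sample
        log_p -= _logsumexp(log_p, axis=0) - np.log(N / K)     # cols: N/K mass per code
    return log_p.argmax(axis=1)                               # hard assignment by rounding

# Toy demo: four samples clustered near the first of two codes.
z = np.array([[0.0], [0.1], [0.2], [0.3]])
codebook = np.array([[0.0], [3.0]])
print(nearest_neighbor_quantize(z, codebook))  # every sample maps to code 0 (collapse)
print(sinkhorn_quantize(z, codebook))          # balanced: two samples per code
```

In the demo, nearest-neighbor lookup sends all four samples to code 0 and code 1 goes unused, whereas the equal-mass column constraint forces the two samples farthest from code 0 onto code 1, so the full codebook is utilized.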

