VECTOR QUANTIZED WASSERSTEIN AUTO-ENCODER

Abstract

Learning deep discrete latent representations offers the promise of better symbolic and summarized abstractions that are more useful for downstream tasks. Recent work on the Vector Quantized Variational Auto-Encoder (VQ-VAE) has made substantial progress in this direction. However, VQ-VAE quantizes latent representations using an online k-means algorithm, which suffers from poor initialization and non-stationary clusters. To improve the clustering quality of the latent representations, we propose the Vector Quantized Wasserstein Auto-Encoder (VQ-WAE), which builds on the clustering viewpoint of the Wasserstein (WS) distance. Specifically, we endow the codewords with a discrete distribution and learn a deterministic decoder that transports the codeword distribution to the data distribution by minimizing the WS distance between them. We further develop theory connecting this formulation to the clustering viewpoint of the WS distance, which yields a better and more controllable clustering solution. Finally, we empirically evaluate our method on several well-known benchmarks, where it achieves better qualitative and quantitative performance than the baselines in terms of codebook utilization and image reconstruction/generation.

1. INTRODUCTION

Learning compact yet expressive representations from large-scale, high-dimensional unlabeled data is an important and long-standing task in machine learning (Kingma & Welling, 2013; Chen et al., 2020; Chen & He, 2021; Zoph et al., 2020). Among the many kinds of methods, the Variational Auto-Encoder (VAE) (Kingma & Welling, 2013) and its variants (Tolstikhin et al., 2017; Alemi et al., 2016; Higgins et al., 2016; Voloshynovskiy et al., 2019) have shown great success in unsupervised representation learning. Although these continuous representation learning methods have been applied successfully to problems ranging from images (Pathak et al., 2016; Goodfellow et al., 2014; Kingma et al., 2016) to video and audio (Reed et al., 2017; Oord et al., 2016; Kalchbrenner et al., 2017), in some contexts input data are more naturally modeled and encoded as discrete symbols rather than continuous ones. For example, discrete representations are a natural fit for complex reasoning, planning, and predictive learning (Van Den Oord et al., 2017). This motivates the need for learning discrete representations that preserve the insightful characteristics of the input data. The Vector Quantized Variational Auto-Encoder (VQ-VAE) (Van Den Oord et al., 2017) is a pioneering generative model that successfully combines the VAE framework with discrete latent representations. In particular, vector quantized models learn a compact discrete representation with a deterministic encoder-decoder architecture in a first stage, and subsequently apply this highly compressed representation to various downstream tasks, including image generation (Esser et al., 2021), cross-modal translation (Kim et al., 2022), and image recognition (Yu et al., 2021).
While VQ-VAE has been widely applied to representation learning in many areas (Henter et al., 2018; Baevski et al., 2020; Razavi et al., 2019; Kumar et al., 2019; Dieleman et al., 2018; Yan et al., 2021), it is known to suffer from codebook collapse, i.e., low codebook usage: most of the embedded latent vectors are quantized to just a few discrete codewords, while the other codewords are rarely used, or dead, due to poor initialization of the codebook, which reduces the information capacity of the bottleneck (Roy et al., 2018; Takida et al., 2022; Yu et al., 2021). To mitigate this issue, additional training heuristics have been proposed, such as the exponential moving average (EMA) update (Van Den Oord et al., 2017; Razavi et al., 2019), the soft expectation maximization (EM) update (Roy et al., 2018), and codebook reset (Dhariwal et al., 2020; Williams et al., 2020). Notably, the soft EM update (Roy et al., 2018) connects the EMA update with an EM algorithm and softens the EM algorithm with a stochastic posterior. Codebook reset randomly reinitializes unused or rarely used codewords to one of the encoder outputs (Dhariwal et al., 2020) or to points near frequently used codewords (Williams et al., 2020). Takida et al. (2022) hypothesize that deterministic quantization is the cause of codebook collapse and extend the standard VAE with stochastic quantization and a trainable posterior categorical distribution, showing that annealing the stochasticity of the quantization process significantly improves codebook utilization. Separately, the WS distance has been applied successfully to generative models and continuous representation learning (Arjovsky et al., 2017; Gulrajani et al., 2017; Tolstikhin et al., 2017) owing to its nice properties and rich theory. It is natural to ask: "Can we take advantage of the intuitive properties of the WS distance and its mature theory to learn highly compact yet expressive discrete representations?"
Toward this question, we develop theory connecting the WS distance, generative models, and deep discrete representation learning. In particular, (a) we first endow the codebook with a discrete distribution and propose learning a "deterministic decoder transporting the codeword distribution to the data distribution" by minimizing the WS distance between them; (b) to devise a trainable algorithm, we develop Theorem 3.1, which equivalently turns this WS minimization into pushing the data distribution forward to the codeword distribution by minimizing a WS distance between "the latent representation and codeword distributions"; (c) more interestingly, Corollary 3.1 proves that when minimizing the WS distance between the latent representation and codeword distributions, the codewords tend to move flexibly to the clustering centroids of the latent representations, with control over the proportion of latent representations associated with each centroid. We argue, and empirically demonstrate, that learning the codewords through the clustering viewpoint of the WS distance yields more controllable and better centroids than the simple k-means used in VQ-VAE (cf. Sections 3.1 and 5.2). Our method, called the Vector Quantized Wasserstein Auto-Encoder (VQ-WAE), applies the WS distance to learn a more controllable codebook, which improves codebook utilization. We conduct comprehensive experiments to support our key contributions by comparing with VQ-VAE (Van Den Oord et al., 2017) and SQ-VAE (Takida et al., 2022), a recent method that improves codebook utilization. The results show that VQ-WAE achieves better codebook utilization with higher codebook perplexity, leading to lower (compared with VQ-VAE) or comparable (compared with SQ-VAE) reconstruction error, together with a significantly lower reconstruction Fréchet Inception Distance (FID) (Heusel et al., 2017).
In general, a better quantizer in stage 1 naturally benefits stage-2 downstream tasks (Yu et al., 2021; Zheng et al., 2022). To demonstrate this, we conduct comprehensive experiments on four benchmark datasets for both unconditional and class-conditional generation. The results indicate that, using the codebooks learned by VQ-WAE, we can generate better images with lower FID scores.

2. VECTOR QUANTIZED VARIATIONAL AUTO-ENCODER

Given a training set $D = \{x_1, \ldots, x_N\} \subset \mathbb{R}^V$, VQ-VAE (Van Den Oord et al., 2017) aims to learn a codebook formed by a set of codewords $C = [c_k]_{k=1}^K \in \mathbb{R}^{K \times D}$ on the latent space $Z \subset \mathbb{R}^D$, an encoder $f_e$ (i.e., $p(z \mid x)$) that maps data examples to the codewords, and a decoder $f_d$ (i.e., $q(x \mid z)$) that accurately reconstructs the data examples from the codewords. Given a data example $x$, the encoder associates $x$ with the codeword $\tilde f_e(x) = c$, defined as $c = \operatorname{argmin}_{c_k \in C} d_z(f_e(x), c_k)$, where $d_z$ is a metric on the latent space. The objective function of VQ-VAE is

$$\mathbb{E}_{x \sim P_x}\Big[d_x\big(f_d(\tilde f_e(x)), x\big) + d_z\big(\mathrm{sg}(f_e(x)), C\big) + \beta\, d_z\big(f_e(x), \mathrm{sg}(C)\big)\Big],$$

where $P_x = \frac{1}{N}\sum_{n=1}^N \delta_{x_n}$ is the empirical data distribution, sg denotes the stop-gradient operator, $d_x$ is a cost metric on the data space, $\beta$ is set between 0.1 and 2.0 (Van Den Oord et al., 2017), and $d_z(f_e(x), C) = \sum_{c \in C} d_z(f_e(x), c)$. The purpose of VQ-VAE training is to group the latent representations into clusters and adjust the codewords to be the centroids of these clusters.

We present the theoretical development of our VQ-WAE framework, which connects the viewpoints of the WS distance, generative models, and deep discrete representation learning, in Section 3.1. Specifically, we propose learning a "deterministic decoder transporting the codeword distribution to the data distribution" by minimizing the WS distance between them (Figure 1a (Top)). We then turn this WS minimization into pushing the data distribution forward to the codeword distribution by minimizing a WS distance between "the latent representation and codeword distributions" (Figure 1a (Bottom)). We prove that when minimizing the WS distance between the latent representation and codeword distributions, the codewords tend to move flexibly to the clustering centroids of the latent representations, with control over the proportion of latent representations associated with each centroid (Figure 1b).
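The nearest-codeword rule $\tilde f_e(x) = \operatorname{argmin}_{c_k \in C} d_z(f_e(x), c_k)$ used by VQ-VAE can be sketched in a few lines. This is a minimal illustration, assuming $d_z$ is the Euclidean distance and that codewords are plain Python lists; `quantize` is a hypothetical helper name, not part of any VQ-VAE codebase.

```python
import math
from typing import List, Sequence, Tuple


def quantize(z: Sequence[float], codebook: List[List[float]]) -> Tuple[int, List[float]]:
    """Return (k, c_k) where c_k minimizes the Euclidean distance d_z(z, c_k).

    Ties are broken in favour of the smallest index k.
    """
    k = min(range(len(codebook)), key=lambda j: math.dist(z, codebook[j]))
    return k, codebook[k]


codebook = [[0.0, 0.0], [1.0, 1.0], [4.0, 0.0]]   # K = 3 codewords in R^2
z = [0.9, 1.2]                                      # stands in for an encoder output f_e(x)
print(quantize(z, codebook))                        # (1, [1.0, 1.0])
```

In a real model the same rule is applied to every latent vector in a batch; only the assignment is shown here, not the stop-gradient losses.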
Based on the theoretical development, we devise a practical algorithm for VQ-WAE in Section 3.2.

Given a training set $D = \{x_1, \ldots, x_N\} \subset \mathbb{R}^V$, we wish to learn a codebook $C = \{c_k\}_{k=1}^K \subset \mathbb{R}^D$ on a latent space $Z$ and an encoder that maps each data example to the codebook while preserving the insightful characteristics of the data. We first endow the codewords with a discrete distribution $P_{c,\pi} = \sum_{k=1}^K \pi_k \delta_{c_k}$, where $\delta$ is the Dirac delta function and the weights satisfy $\pi \in \Delta_{K-1} = \{\pi' \in \mathbb{R}^K : \pi' \geq 0, \|\pi'\|_1 = 1\}$. We aim to learn a decoder function $f_d : Z \to X$ (i.e., a mapping from the latent space $Z \subset \mathbb{R}^D$ to the data space $X$), the codebook $C$, and the weights $\pi$ to minimize

$$\min_{C,\pi} \min_{f_d} W_{d_x}\big(f_d \# P_{c,\pi}, P_x\big), \tag{1}$$

where $P_x = \frac{1}{N}\sum_{n=1}^N \delta_{x_n}$ is the empirical data distribution and $d_x$ is a cost metric on the data space.

We interpret the optimization problem (OP) in Eq. (1) as follows. Given a discrete distribution $P_{c,\pi}$ on the codewords, we use the decoder $f_d$ to map the codebook $C$ to the data space and regard $W_{d_x}(f_d \# P_{c,\pi}, P_x)$ as the codebook-data distortion w.r.t. $f_d$. We then learn $f_d$ to minimize the codebook-data distortion given $P_{c,\pi}$, and finally adjust the codebook $C$ and the weights $\pi$ to minimize the optimal codebook-data distortion. To give more intuition for the OP in Eq. (1), we introduce the following lemma.

Lemma 3.1. Let $C^* = \{c_k^*\}_k$, $\pi^*$, and $f_d^*$ be the optimal solution of the OP in Eq. (1). Assume $K < N$. Then $C^*$, $\pi^*$, and $f_d^*$ are also the optimal solution of the following OP:

$$\min_{f_d} \min_{C,\pi} \min_{\sigma \in \Sigma_\pi} \sum_{n=1}^N d_x\big(x_n, f_d(c_{\sigma(n)})\big), \tag{2}$$

where $\Sigma_\pi$ is the set of assignment functions $\sigma : \{1, \ldots, N\} \to \{1, \ldots, K\}$ such that the cardinalities $|\sigma^{-1}(k)|$, $k = 1, \ldots, K$, are proportional to $\pi_k$, $k = 1, \ldots, K$.

Lemma 3.1 states that, for the optimal solution $C^*$, $\pi^*$, and $f_d^*$ of the OP in Eq. (1), the points $\{f_d^*(c_k^*)\}_{k=1}^K$ are the optimal clustering centroids of the clustering solution that minimizes the distortion.
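To make the inner minimization over $\sigma \in \Sigma_\pi$ in Lemma 3.1 concrete, the sketch below solves it by brute force for fixed points standing in for $f_d(c_k)$. It is a toy illustration under stated assumptions: 1-D data, $d_x(a, b) = |a - b|$, and integer cluster sizes $N\pi_k$; `constrained_assignment` is a hypothetical name, and the enumeration is exponential in $N$, so it is not how one would train a model.

```python
import itertools
from typing import Callable, Sequence, Tuple


def constrained_assignment(
    points: Sequence[float],
    centroids: Sequence[float],
    counts: Sequence[int],
    cost: Callable[[float, float], float],
) -> Tuple[float, Tuple[int, ...]]:
    """Brute-force min over sigma in Sigma_pi of sum_n cost(x_n, centroids[sigma(n)]).

    counts[k] plays the role of |sigma^{-1}(k)| = N * pi_k (integers summing to N).
    Dividing the returned cost by N gives the Wasserstein value in Lemma 3.1.
    """
    if sum(counts) != len(points):
        raise ValueError("counts must sum to the number of points")
    labels = [k for k, m in enumerate(counts) for _ in range(m)]
    best_cost, best_sigma = float("inf"), tuple(labels)
    # sorted(): deterministic tie-breaking across runs
    for sigma in sorted(set(itertools.permutations(labels))):
        total = sum(cost(x, centroids[k]) for x, k in zip(points, sigma))
        if total < best_cost:
            best_cost, best_sigma = total, sigma
    return best_cost, best_sigma


def l1(a: float, b: float) -> float:
    """Assumed data-space cost d_x."""
    return abs(a - b)


points = [0.0, 1.0, 9.0, 10.0]     # data examples x_n (1-D for readability)
centroids = [0.5, 9.5]             # stand-ins for the decoded codewords f_d(c_k)
print(constrained_assignment(points, centroids, [2, 2], l1))   # (2.0, (0, 0, 1, 1))
print(constrained_assignment(points, centroids, [3, 1], l1))   # (10.0, (0, 0, 0, 1))
```

Changing the cluster sizes from (2, 2) to (3, 1), i.e., changing $\pi$, forces a different assignment and a higher distortion, which is the proportion constraint that Lemma 3.1 attaches to the clustering.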
Inspired by the Wasserstein Auto-Encoder (Tolstikhin et al., 2017), we establish the following theorem to connect the OP in (1) with the latent space.

Theorem 3.1. We can equivalently turn the optimization problem in (1) into

$$\min_{C,\pi,f_d} \; \min_{\tilde f_e : \tilde f_e \# P_x = P_{c,\pi}} \mathbb{E}_{x \sim P_x}\Big[d_x\big(f_d(\tilde f_e(x)), x\big)\Big], \tag{3}$$

where $\tilde f_e$ is a deterministic discrete encoder mapping a data example $x$ directly to the codebook.

Theorem 3.1 differs from Theorem 1 in Tolstikhin et al. (2017) in two respects. First, we learn both the codebook $C$ and the weights $\pi$. Second, we seek a deterministic discrete encoder $\tilde f_e$ that maps a data example $x$ directly to a codeword, which is consistent with vector quantization and serves our further derivations, whereas Theorem 1 in Tolstikhin et al. (2017) involves a probabilistic (stochastic) encoder mapping to a continuous latent distribution (i.e., a larger search space). More importantly, our proof is entirely different from that in Tolstikhin et al. (2017) (all proof details are given in Appendix A).

Because $\tilde f_e$ is a deterministic discrete encoder, it is not directly trainable. To make the problem trainable, we replace $\tilde f_e$ by a continuous encoder $f_e : X \to Z$ and arrive at the OP

$$\min_{C,\pi} \min_{f_d, f_e} \Big\{\mathbb{E}_{x \sim P_x}\big[d_x\big(f_d(Q_C(f_e(x))), x\big)\big] + \lambda\, W_{d_z}\big(f_e \# P_x, P_{c,\pi}\big)\Big\}, \tag{4}$$

where $Q_C(f_e(x)) = \operatorname{argmin}_{c \in C} d_z(f_e(x), c)$ is a quantization operator that returns the codeword closest to $f_e(x)$, and $\lambda > 0$. We can rigorously prove that the optimization problems in (3) and (4) are equivalent under mild conditions (Theorem 3.2), which justifies solving the OP in (4) as our final tractable formulation.

Theorem 3.2. If we seek $f_d$ and $f_e$ in a family with infinite capacity (e.g., the family of all measurable functions), the three OPs in (1), (3), and (4) are equivalent.

Moreover, the OP in (4) admits a meaningful interpretation. By minimizing $W_{d_z}(f_e \# P_x, P_{c,\pi})$ w.r.t. $C$ and $\pi$, we aim to learn codewords that are the clustering centroids of $f_e \# P_x$, according to the clustering viewpoint of optimal transport (OT) shown in Corollary 3.1. Similar to VQ-VAE, we quantize $f_e(x)$ to the closest codeword using $Q_C(f_e(x)) = \operatorname{argmin}_{c \in C} d_z(f_e(x), c)$ and try to reconstruct $x$ from this codeword.

Corollary 3.1. Consider minimizing the second term of (4), $\min_{f_e, C} W_{d_z}(f_e \# P_x, P_{c,\pi})$, for a given $\pi$, and assume $K < N$. Its optimal solution $f_e^*$ and $C^*$ is also the optimal solution of the OP

$$\min_{f_e, C} \min_{\sigma \in \Sigma_\pi} \sum_{n=1}^N d_z\big(f_e(x_n), c_{\sigma(n)}\big), \tag{5}$$

where $\Sigma_\pi$ is the set of assignment functions $\sigma : \{1, \ldots, N\} \to \{1, \ldots, K\}$ such that the cardinalities $|\sigma^{-1}(k)|$, $k = 1, \ldots, K$, are proportional to $\pi_k$, $k = 1, \ldots, K$.

Corollary 3.1 clarifies the aim of minimizing the second term $W_{d_z}(f_e \# P_x, P_{c,\pi})$ in (4): we adjust the encoder $f_e$ and the codebook $C$ so that the codewords become the clustering centroids of the latent representations $\{f_e(x_n)\}_n$, minimizing the codebook-latent distortion (see Figure 1 (Right)). Moreover, at the optimal solution, the optimal assignment function $\sigma^*$, which indicates how the latent representations (or data examples) are associated with the clustering centroids (i.e., the codewords), has a valuable property: the cardinalities $|(\sigma^*)^{-1}(k)|$, $k = 1, \ldots, K$, are proportional to $\pi_k$, $k = 1, \ldots, K$.

Remark: Recall the codebook collapse issue, i.e., most embedded latent vectors are quantized to just a few discrete codewords while the other codewords are rarely used. Corollary 3.1 gives us two important properties that underpin our VQ-WAE framework: (1) we can control the number of latent representations assigned to each codeword by adjusting $\pi$, so that every codeword with $\pi_k > 0$ is utilized; and (2) the codewords become the clustering centroids of their associated latent representations, minimizing the codebook-latent distortion.
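The contrast drawn in the Remark, between unconstrained nearest-codeword assignment and an assignment whose cluster sizes follow $\pi$, can be seen on a toy example. The sketch below holds the latent codes fixed (it does not train $f_e$) and performs one alternating step for the inner problem of (5): a size-constrained assignment followed by moving each codeword to its cluster's optimal center. It assumes 1-D latents with $d_z(a, b) = |a - b|$ (so the optimal center is a median), uniform $\pi$, and brute-force enumeration; all function names are hypothetical, and this is a heuristic illustration rather than the VQ-WAE training procedure.

```python
import itertools
import statistics
from typing import List, Sequence, Tuple


def nearest_counts(z: Sequence[float], codebook: Sequence[float]) -> List[int]:
    """Codeword usage counts under plain nearest-codeword assignment (VQ-VAE style)."""
    counts = [0] * len(codebook)
    for v in z:
        counts[min(range(len(codebook)), key=lambda k: abs(v - codebook[k]))] += 1
    return counts


def balanced_assignment(
    z: Sequence[float], codebook: Sequence[float], counts: Sequence[int]
) -> Tuple[int, ...]:
    """Brute-force argmin over sigma with |sigma^{-1}(k)| = counts[k] of sum_n |z_n - c_sigma(n)|."""
    labels = [k for k, m in enumerate(counts) for _ in range(m)]
    return min(
        sorted(set(itertools.permutations(labels))),
        key=lambda s: sum(abs(v - codebook[k]) for v, k in zip(z, s)),
    )


def update_codewords(z: Sequence[float], sigma: Sequence[int], K: int) -> List[float]:
    """Move each codeword to the median of its cluster (optimal for d_z = |a - b| in 1-D)."""
    return [statistics.median([v for v, k in zip(z, sigma) if k == j]) for j in range(K)]


z = [0.0, 0.2, 0.4, 10.0]                         # latent codes f_e(x_n), 1-D for readability
codebook = [0.1, 9.0]
print(nearest_counts(z, codebook))                # [3, 1]: usage is unbalanced
sigma = balanced_assignment(z, codebook, [2, 2])  # uniform pi with N = 4, K = 2
print(sigma)                                      # (0, 0, 1, 1)
print(update_codewords(z, sigma, 2))              # approximately [0.1, 5.2]
```

Under the uniform-size constraint the second codeword is pulled toward the data and serves half of the latent codes, whereas nearest-codeword assignment leaves it serving a single point.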

3.2. PROPOSED FRAMEWORK

One of the crucial aims of learning meaningful and well-distributed codewords by solving the OP in (4) is to use each individual codeword efficiently. Specifically, we wish the latent representations to be associated more uniformly with the codewords. Corollary 3.1 shows that the number of latent representations associated with the $k$-th codeword is proportional to $\pi_k$; we therefore fix $\pi$ as the uniform distribution (i.e., $P_{c,\pi} = \sum_{k=1}^K \frac{1}{K}\delta_{c_k}$) so that all codewords are utilized equally by the model, thereby boosting the perplexity, i.e., the codebook usage.

We now present the practical method based on the OP in (4) with $P_{c,\pi} = \sum_{k=1}^K \frac{1}{K}\delta_{c_k}$. At each iteration, we sample a mini-batch $x_1, \ldots, x_B$ with empirical distribution $P_b$ and compute

$$W_{d_z}\big(f_e \# P_b, P_{c,\pi}\big) = \min_{P \in \Gamma(\mathbf{1}_B/B,\, \mathbf{1}_K/K)} \langle P, D_{c,x} \rangle,$$

where $\mathbf{1}_B/B$ is the vector of atom masses of $f_e \# P_b$, $\mathbf{1}_K/K$ is the vector of atom masses of $P_{c,\pi}$, $\Gamma(\mathbf{1}_B/B, \mathbf{1}_K/K)$ is the set of feasible transportation plans, and $D_{c,x} = [d_z(f_e(x_i), c_k)]_{i,k} \in \mathbb{R}^{B \times K}$ is the cost matrix. The pseudocode of VQ-WAE is summarized in Algorithm 1. We use the copy gradient trick (Van Den Oord et al., 2017) to back-propagate the reconstruction term from the decoder to the encoder, while the Wasserstein regularization term $W_{d_z}(f_e \# P_b, P_{c,\pi})$ can be optimized directly without further manipulation. The term $W_{d_z}(f_e \# P_b, P_{c,\pi})$ is used only in the training phase.

Algorithm 1 (VQ-WAE)
…
4: Encode: $z_i = f_e(x_i)$, for $i = 1, \ldots, B$
5: Quantize: $c_i = \operatorname{argmin}_{c_k \in C} d_z(z_i, c_k)$, for $i = 1, \ldots, B$ // nearest-neighbor assignment
6: Decode: $\tilde x_i = f_d(c_i)$, for $i = 1, \ldots, B$
7: Optimize $f_e$, $f_d$, and $C$ by minimizing the objective in (4): $\frac{1}{B}\sum_{i=1}^B d_x(x_i, \tilde x_i) + \lambda\, W_{d_z}(f_e \# P_b, P_{c,\pi})$, where $W_{d_z}(f_e \# P_b, P_{c,\pi}) = \min_{P \in \Gamma(\mathbf{1}_B/B, \mathbf{1}_K/K)} \langle P, D_{c,x} \rangle$
8: end for
9: Return: the optimal $f_e$, $f_d$, and $C$.
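The mini-batch term $\min_{P \in \Gamma(\mathbf{1}_B/B, \mathbf{1}_K/K)} \langle P, D_{c,x} \rangle$ is an exact OT problem between $B$ latent codes with mass $1/B$ each and $K$ codewords with mass $1/K$ each. The sketch below approximates it with entropic regularization (Sinkhorn iterations), a swapped-in technique that the text above does not specify. It assumes $d_z$ is Euclidean, evaluates only the value and the plan (no gradients), and uses the hypothetical name `minibatch_ot_to_uniform_codebook`; smaller `eps` brings the result closer to the exact OT value.

```python
import math
from typing import List, Sequence, Tuple


def minibatch_ot_to_uniform_codebook(
    z: Sequence[Sequence[float]],
    codebook: Sequence[Sequence[float]],
    eps: float = 0.05,
    iters: int = 1000,
) -> Tuple[float, List[List[float]]]:
    """Entropic (Sinkhorn) approximation of W_{d_z}(f_e # P_b, P_{c,pi}) with uniform pi.

    Rows carry mass 1/B (latents z_i = f_e(x_i)); columns carry mass 1/K (codewords).
    Returns (<P, D>, P). Linear-domain Sinkhorn is fine at this toy scale; a
    log-domain variant is safer when exp(-D / eps) can underflow.
    """
    B, K = len(z), len(codebook)
    D = [[math.dist(zi, ck) for ck in codebook] for zi in z]   # cost matrix D_{c,x}
    G = [[math.exp(-d / eps) for d in row] for row in D]       # Gibbs kernel
    a, b = 1.0 / B, 1.0 / K
    u, v = [1.0] * B, [1.0] * K
    for _ in range(iters):
        u = [a / sum(G[i][k] * v[k] for k in range(K)) for i in range(B)]  # fit row masses
        v = [b / sum(G[i][k] * u[i] for i in range(B)) for k in range(K)]  # fit column masses
    P = [[u[i] * G[i][k] * v[k] for k in range(K)] for i in range(B)]
    cost = sum(P[i][k] * D[i][k] for i in range(B) for k in range(K))
    return cost, P


z = [[0.0], [0.2], [0.4], [10.0]]                 # mini-batch latents f_e(x_i), B = 4
codebook = [[0.1], [5.2]]                         # K = 2 codewords, uniform pi
w, P = minibatch_ot_to_uniform_codebook(z, codebook)
print(round(w, 2))                                # 2.45 (exact OT value here is 9.8 / 4)
print([round(sum(col), 6) for col in zip(*P)])    # [0.5, 0.5]: each codeword receives mass 1/K
```

The column constraint is what enforces uniform codeword usage inside each mini-batch: even the codeword far from most latents must receive half of the mass.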

4. RELATED WORK

The Variational Auto-Encoder (VAE) was first introduced by Kingma & Welling (2013) for learning continuous representations. However, learning discrete latent representations has proved much more challenging, because the gradients required to train such models are difficult to evaluate accurately through the discrete sampling step. To make the gradients tractable, one possible solution is to apply the Gumbel-Softmax reparameterization trick (Jang et al., 2016) to the VAE, which yields stochastic gradient estimates for updating the model. Although this estimator has low variance, it is biased. Another possible solution is the REINFORCE algorithm (Williams, 1992), which is unbiased but has high variance. The two techniques can also be combined in a complementary way (Tucker et al., 2017). To enable learning discrete latent codes, VQ-VAE (Van Den Oord et al., 2017) uses a deterministic encoder/decoder and encourages the codewords to become the clustering centroids of the latent representations. In addition, the copy gradient trick (Bengio, 2013) is used to back-propagate gradients from the decoder to the encoder. Several works extend VQ-VAE, notably Roy et al. (2018) and Wu & Flierl (2020). In particular, Roy et al. (2018) use the Expectation Maximization (EM) algorithm in the bottleneck stage to train VQ-VAE and improve the quality of the generated images. However, maintaining the stability of this approach requires collecting a large number of samples in the latent space. Wu & Flierl (2020) … Gulrajani et al. (2017) employ the gradient penalty trick to improve the stability of WGAN. In terms of theoretical development, the work most closely related to ours is the Wasserstein Auto-Encoder (Tolstikhin et al., 2017), which learns continuous latent representations that preserve the characteristics of the input data.

5. EXPERIMENTS

In this section, we conduct extensive experiments to show the effectiveness of our proposed method compared with other approaches. Datasets: we empirically evaluate VQ-WAE in comparison with VQ-VAE (Van Den Oord et al., 2017), the baseline method, and the recently proposed SQ-VAE (Takida et al., 2022), the state-of-the-art method for improving codebook usage, on four benchmark datasets, CIFAR10 (Van Den Oord et al., 2017), MNIST (Deng, 2012), SVHN (Netzer et al., 2011), and CelebA (Liu et al., 2015), and on the high-resolution image dataset FFHQ (Karras et al., 2019). Implementation: for a fair comparison, we use the same architectures and hyper-parameters for all methods. In the primary setting, we use a codeword (discrete latent) dimensionality of 64 and a codebook size of |C| = 512 for all datasets except FFHQ, for which we use a codeword dimensionality of 256 and |C| = 1024. The hyper-parameters {β, τ, λ} are set as in the original papers, i.e., β = 0.25 for VQ-VAE and VQ-GAN (Esser et al., 2021), τ = 1e-5 for SQ-VAE, and λ = 1 for our VQ-WAE. The details of the experimental settings are presented in Appendix C.

5.1. RESULTS ON BENCHMARK DATASETS

To quantitatively assess the quality of the reconstructed images, we report the most common evaluation metrics: pixel-level peak signal-to-noise ratio (PSNR), patch-level structural similarity index (SSIM), feature-level LPIPS (Zhang et al., 2018), and dataset-level Fréchet Inception Distance (FID) (Heusel et al., 2017). We report test-set reconstruction results on four datasets in Table 1. To measure codebook utilization, we use the perplexity score, defined as $\exp\big(-\sum_{k=1}^K p_{c_k} \log p_{c_k}\big)$, where $p_{c_k} = N_{c_k} / \sum_{i=1}^K N_{c_i}$ is the probability of the $k$-th codeword being used and $N_{c_i}$ is the number of latent representations associated with the codeword $c_i$. By this formula, the perplexity reaches its maximum value $|C|$ when the codeword usage distribution is uniform, i.e., when all codewords are used equally by the model. We compare VQ-WAE with VQ-VAE, SQ-VAE, and VQ-GAN for image reconstruction in Table 1. All instantiations of our model significantly outperform the baseline VQ-VAE under the same compression ratio and network architecture. While the latest state-of-the-art SQ-VAE achieves slightly better scores on the traditional pixel- and patch-level metrics, our method achieves much better rFID scores, which evaluate image quality at the dataset level. Notably, VQ-WAE significantly increases the perplexity of the learned codebook, indicating substantially better codebook usage, which in turn contributes to better reconstruction quality. Finally, to complete the assessment, qualitative results are visualized in Figure 4 (Appendix B).
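The perplexity score can be computed directly from the codeword indices assigned to a set of latent representations. The following is a minimal sketch of the formula $\exp(-\sum_k p_{c_k}\log p_{c_k})$; the function name `codebook_perplexity` and the toy assignment lists are illustrative assumptions, not taken from the paper's code.

```python
import math
from collections import Counter
from typing import Iterable


def codebook_perplexity(assignments: Iterable[int]) -> float:
    """exp(-sum_k p_k log p_k), where p_k = N_{c_k} / sum_i N_{c_i}.

    `assignments` holds the codeword index chosen for each latent representation;
    unused codewords have p_k = 0 and contribute nothing to the sum.
    """
    counts = Counter(assignments)
    total = sum(counts.values())
    entropy = -sum((n / total) * math.log(n / total) for n in counts.values())
    return math.exp(entropy)


uniform = [k for k in range(4) for _ in range(25)]   # every codeword used 25 times
collapsed = [0] * 97 + [1, 2, 3]                      # one codeword dominates
print(codebook_perplexity(uniform))    # approximately 4.0, the maximum |C| for K = 4
print(codebook_perplexity(collapsed))  # approximately 1.18
```

A collapsed codebook yields a perplexity close to 1 even though all four codewords appear at least once, which is why perplexity is a stricter usage measure than counting non-empty codewords.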

5.2. DETAILED ANALYSIS

We run a number of ablations to analyze the properties of VQ-VAE, SQ-VAE, and VQ-WAE, in order to assess whether VQ-WAE can simultaneously achieve (i) efficient codebook usage and (ii) reasonable latent representations. We examine the codebook utilization of the three methods with codebook sizes {64, 128, 256, 512} on the MNIST and CIFAR10 datasets. In particular, we present the reconstruction performance for the different settings in Table 2 and the histogram of latent representations over the codebook in Figure 2.

5.2.1. CODEBOOK USAGE

As discussed in Sections 3.1 and 3.2, the number of used centroids reflects the capacity of the latent representations, i.e., how much information is preserved in the latent space. By explicitly requiring the numbers of latent representations associated with the codewords to be uniform (i.e., fixing π in (4) as the uniform distribution) in the Wasserstein regularization term, VQ-WAE is able to maximize the information in the codebook, hence improving reconstruction capacity. Figure 2 shows that the latent distribution of VQ-WAE over the codebook is nearly uniform and that the codebook perplexity almost reaches its optimal value (i.e., the perplexity nearly equals the corresponding codebook size) in the different settings. We also observe that as the codebook size increases, the codebook perplexity of VQ-WAE also increases, leading to better reconstruction performance (Table 2), in line with the analysis of Wu & Flierl (2018). SQ-VAE also has good codebook utilization, as its perplexity is proportional to the codebook size. However, its codebook utilization becomes less efficient when the codebook is large, especially on the low-texture dataset (i.e., MNIST). In contrast, the codebook usage of VQ-VAE is less efficient: its codebook usage histogram contains many zero entries, indicating that some codewords are never used (Figure 2). Furthermore, Table 2 shows that the reconstruction performance of VQ-VAE is unstable across codebook sizes.

5.2.2. VISUALIZATION OF LATENT REPRESENTATION

To better understand the codebook's representation power, we use t-SNE (van der Maaten & Hinton, 2008) to visualize the latent representations learned by VQ-VAE, SQ-VAE, and VQ-WAE on the MNIST dataset with two codebook sizes, 64 and 512. Figure 3 shows the latent distributions of the different classes in the latent space, where samples are colored according to their class labels. Figure 3c shows that the representations of different classes learned by VQ-WAE are well clustered (i.e., each class concentrates in a single cluster) and clearly separated from the other classes. In contrast, the representations of some classes in VQ-VAE and SQ-VAE are spread over several clusters or mixed with each other (Figure 3a, b). Moreover, the class clusters of SQ-VAE are not compact and tend to overlap. These results suggest that the representations learned by VQ-WAE preserve the similarity relations of the data space better than those of the other models.

5.2.3. IMAGE GENERATION

As discussed in the previous section, VQ-WAE is able to fully utilize its codebook, leading to meaningful and diverse codewords that naturally improve image generation. To confirm this, we perform image generation on the benchmark datasets. Since the decoder reconstructs images directly from the discrete embeddings, we only need to model a prior distribution over the discrete latent space (i.e., the codebook) to generate images. We employ a conventional autoregressive model, the CNN-based PixelCNN (Van den Oord et al., 2016), to estimate a prior distribution over the discrete latent space of VQ-VAE, SQ-VAE, and VQ-WAE on CIFAR10, MNIST, SVHN, and CelebA. The details of the generation settings are presented

6. CONCLUSION

In this paper, inspired by the nice properties and mature theory of the WS distance, which have allowed it to be applied successfully to generative models and continuous representation learning, we propose the Vector Quantized Wasserstein Auto-Encoder (VQ-WAE), which endows the codewords with a discrete distribution and learns a deterministic decoder that transports the codeword distribution to the data distribution by minimizing the WS distance between them. We then develop theoretical analysis showing that this WS minimization is equivalent to another OP that pushes the data distribution forward to the codeword distribution, which can be realized by minimizing a WS distance between the latent representation and codeword distributions. Comprehensive experiments show that VQ-WAE utilizes the codebook more efficiently than the baselines, leading to better reconstructed and generated image quality.

7. REPRODUCIBILITY STATEMENT

We provide the implementation of our framework in the supplementary material.

APPENDIX

This appendix is organized as follows:
• In Section A, we present all proofs for the theory developed in the main paper.
• In Section B, we present additional experimental results on the high-quality image dataset FFHQ.
• In Section C, we present the experimental settings and implementation details of VQ-WAE.

A THEORETICAL DEVELOPMENT

Given a training set $D = \{x_1, \ldots, x_N\} \subset \mathbb{R}^V$, we wish to learn a codebook $C = \{c_k\}_{k=1}^K \subset \mathbb{R}^D$ on a latent space $Z$ and an encoder that maps each data example to the codebook while preserving the insightful characteristics of the data. We first endow the codewords with a discrete distribution $P_{c,\pi} = \sum_{k=1}^K \pi_k \delta_{c_k}$, where $\delta$ is the Dirac delta function and $\pi \in \Delta_{K-1}$. We aim to learn a decoder function $f_d : Z \to X$ (i.e., a mapping from the latent space $Z \subset \mathbb{R}^D$ to the data space $X$), the codebook $C$, and the weights $\pi$ to minimize

$$\min_{C,\pi} \min_{f_d} W_{d_x}\big(f_d \# P_{c,\pi}, P_x\big), \tag{7}$$

where $P_x = \frac{1}{N}\sum_{n=1}^N \delta_{x_n}$ is the empirical data distribution and $d_x$ is a cost metric on the data space.

Lemma A.1. (Lemma 3.1 in the main paper) Let $C^* = \{c_k^*\}_k$, $\pi^*$, and $f_d^*$ be the optimal solution of the OP in Eq. (7). Assume $K < N$. Then $C^*$, $\pi^*$, and $f_d^*$ are also the optimal solution of the following OP:

$$\min_{f_d} \min_{C,\pi} \min_{\sigma \in \Sigma_\pi} \sum_{n=1}^N d_x\big(x_n, f_d(c_{\sigma(n)})\big),$$

where $\Sigma_\pi$ is the set of assignment functions $\sigma : \{1, \ldots, N\} \to \{1, \ldots, K\}$ such that the cardinalities $|\sigma^{-1}(k)|$, $k = 1, \ldots, K$, are proportional to $\pi_k$, $k = 1, \ldots, K$.

Proof of Lemma A.1. It is clear that $f_d \# P_{c,\pi} = \sum_{k=1}^K \pi_k \delta_{f_d(c_k)}$. Therefore, we reach the following OP:

$$\min_{C,\pi} \min_{f_d} W_{d_x}\Big(\frac{1}{N}\sum_{n=1}^N \delta_{x_n}, \sum_{k=1}^K \pi_k \delta_{f_d(c_k)}\Big).$$

By using the Monge formulation, we have

$$W_{d_x}\Big(\frac{1}{N}\sum_{n=1}^N \delta_{x_n}, \sum_{k=1}^K \pi_k \delta_{f_d(c_k)}\Big) = \min_{T : T\#P_x = f_d\#P_{c,\pi}} \mathbb{E}_{x \sim P_x}\big[d_x(x, T(x))\big] = \frac{1}{N} \min_{T : T\#P_x = f_d\#P_{c,\pi}} \sum_{n=1}^N d_x\big(x_n, T(x_n)\big).$$

Since $T\#P_x = f_d\#P_{c,\pi}$, we have $T(x_n) = f_d(c_k)$ for some $k$.
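Additionally, the cardinalities $|T^{-1}(f_d(c_k))|$, $k = 1, \ldots, K$, are proportional to $\pi_k$, $k = 1, \ldots, K$. Denoting by $\sigma : \{1, \ldots, N\} \to \{1, \ldots, K\}$ the map such that $T(x_n) = f_d(c_{\sigma(n)})$ for all $n = 1, \ldots, N$, we have $\sigma \in \Sigma_\pi$. It follows that

$$W_{d_x}\Big(\frac{1}{N}\sum_{n=1}^N \delta_{x_n}, \sum_{k=1}^K \pi_k \delta_{f_d(c_k)}\Big) = \frac{1}{N} \min_{\sigma \in \Sigma_\pi} \sum_{n=1}^N d_x\big(x_n, f_d(c_{\sigma(n)})\big).$$

Finally, the optimal solution of the OP in Eq. (7) is also the optimal solution of the OP in Lemma A.1.

Theorem A.1. (Theorem 3.1 in the main paper) We can equivalently turn the OP in (7) into

$$\min_{C,\pi,f_d} \; \min_{\tilde f_e : \tilde f_e \# P_x = P_{c,\pi}} \mathbb{E}_{x \sim P_x}\Big[d_x\big(f_d(\tilde f_e(x)), x\big)\Big], \tag{10}$$

where $\tilde f_e$ is a deterministic discrete encoder mapping a data example $x$ directly to the codebook.

Proof of Theorem A.1. We first prove that the OP in (7) is equivalent to

$$\min_{C,\pi,f_d} \; \min_{\tilde f_e : \tilde f_e \# P_x = P_{c,\pi}} \mathbb{E}_{x \sim P_x,\, c \sim \tilde f_e(x)}\big[d_x(f_d(c), x)\big],$$

where $\tilde f_e$ is a stochastic discrete encoder mapping a data example $x$ directly to the codebook. To this end, we prove that

$$W_{d_x}\big(f_d \# P_{c,\pi}, P_x\big) = \min_{\tilde f_e : \tilde f_e \# P_x = P_{c,\pi}} \mathbb{E}_{x \sim P_x,\, c \sim \tilde f_e(x)}\big[d_x(f_d(c), x)\big]. \tag{12}$$

Let $\tilde f_e$ be a stochastic discrete encoder such that $\tilde f_e \# P_x = P_{c,\pi}$ (i.e., $x \sim P_x$ and $c \sim \tilde f_e(x)$ imply $c \sim P_{c,\pi}$). Let $\gamma_{d,c}$ be the joint distribution of $(x, c)$ with $x \sim P_x$ and $c \sim \tilde f_e(x)$, and let $\gamma_{fc,d}$ be the joint distribution of $(x, x')$ where $x \sim P_x$, $c \sim \tilde f_e(x)$, and $x' = f_d(c)$. Then $\gamma_{fc,d} \in \Gamma(f_d \# P_{c,\pi}, P_x)$, i.e., it admits $f_d \# P_{c,\pi}$ and $P_x$ as its marginal distributions. We also have

$$\mathbb{E}_{x \sim P_x,\, c \sim \tilde f_e(x)}\big[d_x(f_d(c), x)\big] = \mathbb{E}_{(x,c) \sim \gamma_{d,c}}\big[d_x(f_d(c), x)\big] \overset{(a)}{=} \mathbb{E}_{(x,x') \sim \gamma_{fc,d}}\big[d_x(x, x')\big] \geq \min_{\gamma \in \Gamma(f_d \# P_{c,\pi}, P_x)} \mathbb{E}_{(x,x') \sim \gamma}\big[d_x(x, x')\big] = W_{d_x}\big(f_d \# P_{c,\pi}, P_x\big).$$

The equality (a) holds because $(\mathrm{id}, f_d) \# \gamma_{d,c} = \gamma_{fc,d}$. Therefore,

$$\min_{\tilde f_e : \tilde f_e \# P_x = P_{c,\pi}} \mathbb{E}_{x \sim P_x,\, c \sim \tilde f_e(x)}\big[d_x(f_d(c), x)\big] \geq W_{d_x}\big(f_d \# P_{c,\pi}, P_x\big).$$

For the reverse inequality, let $\gamma_{fc,d} \in \Gamma(f_d \# P_{c,\pi}, P_x)$ be an arbitrary coupling, and let $\gamma_{fc,c} \in \Gamma(f_d \# P_{c,\pi}, P_{c,\pi})$ be the deterministic coupling such that $c \sim P_{c,\pi}$ and $x' = f_d(c)$ imply $(x', c) \sim \gamma_{fc,c}$.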
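Using the gluing lemma (see Lemma 5.5 in Santambrogio (2015)), there exists a joint distribution $\alpha \in \Gamma(P_{c,\pi}, f_d \# P_{c,\pi}, P_x)$ whose two-dimensional marginals are $\gamma_{fc,d}$ and $\gamma_{fc,c}$. Denoting by $\gamma_{d,c} \in \Gamma(P_x, P_{c,\pi})$ the marginal of $\alpha$ over $(x, c)$, we have

$$\mathbb{E}_{(x,x') \sim \gamma_{fc,d}}\big[d_x(x, x')\big] = \mathbb{E}_{(c,x',x) \sim \alpha}\big[d_x(x, x')\big] = \mathbb{E}_{(x,c) \sim \gamma_{d,c},\, x' = f_d(c)}\big[d_x(x, x')\big] = \mathbb{E}_{(x,c) \sim \gamma_{d,c}}\big[d_x(f_d(c), x)\big] = \mathbb{E}_{x \sim P_x,\, c \sim \tilde f_e(x)}\big[d_x(f_d(c), x)\big] \geq \min_{\tilde f_e : \tilde f_e \# P_x = P_{c,\pi}} \mathbb{E}_{x \sim P_x,\, c \sim \tilde f_e(x)}\big[d_x(f_d(c), x)\big],$$

where $\tilde f_e(x) = \gamma_{d,c}(\cdot \mid x)$; this encoder is feasible because the second marginal of $\gamma_{d,c}$ is $P_{c,\pi}$, so $\tilde f_e \# P_x = P_{c,\pi}$. Taking the minimum over $\gamma_{fc,d}$ gives

$$W_{d_x}\big(f_d \# P_{c,\pi}, P_x\big) = \min_{\gamma_{fc,d} \in \Gamma(f_d \# P_{c,\pi}, P_x)} \mathbb{E}_{(x,x') \sim \gamma_{fc,d}}\big[d_x(x, x')\big] \geq \min_{\tilde f_e : \tilde f_e \# P_x = P_{c,\pi}} \mathbb{E}_{x \sim P_x,\, c \sim \tilde f_e(x)}\big[d_x(f_d(c), x)\big].$$

This completes the proof of the equality in Eq. (12), which means that the OP in (7) is equivalent to

$$\min_{C,\pi,f_d} \; \min_{\tilde f_e : \tilde f_e \# P_x = P_{c,\pi}} \mathbb{E}_{x \sim P_x,\, c \sim \tilde f_e(x)}\big[d_x(f_d(c), x)\big]. \tag{13}$$

We now further prove that the above OP is equivalent to

$$\min_{C,\pi,f_d} \; \min_{\tilde f_e : \tilde f_e \# P_x = P_{c,\pi}} \mathbb{E}_{x \sim P_x}\Big[d_x\big(f_d(\tilde f_e(x)), x\big)\Big], \tag{14}$$

where $\tilde f_e$ is a deterministic discrete encoder mapping a data example $x$ directly to the codebook. The OP in (14) is clearly a special case of that in (13) in which we restrict the search to deterministic discrete encoders. Given the optimal solution $C^{*1}$, $\pi^{*1}$, $f_d^{*1}$, and $\tilde f_e^{*1}$ of the OP in (13), we show how to construct an optimal solution for the OP in (14). Let $C^{*2} = C^{*1}$ and $f_d^{*2} = f_d^{*1}$. Given $x \sim P_x$, define $\tilde f_e^{*2}(x) = \operatorname{argmin}_{c \in C^{*2}} d_x\big(f_d^{*2}(c), x\big)$. Thus, $\tilde f_e^{*2}$ is a deterministic discrete encoder mapping a data example $x$ directly to the codebook. We define $\pi_k^{*2} = \Pr\big(\tilde f_e^{*2}(x) = c_k : x \sim P_x\big)$, $k = 1, \ldots, K$, so that $\tilde f_e^{*2} \# P_x = P_{c^{*2},\pi^{*2}}$. Because $\tilde f_e^{*2}$ picks, for each $x$, a codeword of minimal cost, we have

$$\mathbb{E}_{x \sim P_x}\Big[d_x\big(f_d^{*2}(\tilde f_e^{*2}(x)), x\big)\Big] \leq \mathbb{E}_{x \sim P_x,\, c \sim \tilde f_e^{*1}(x)}\big[d_x(f_d^{*1}(c), x)\big].$$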
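The derandomization step above (replace a stochastic discrete encoder by $x \mapsto \operatorname{argmin}_{c} d_x(f_d(c), x)$ and read off the induced weights $\pi^{*2}$) can be checked numerically on a toy example. The sketch assumes 1-D data, $d_x(a, b) = |a - b|$, and a hand-picked stochastic encoder given as a row-stochastic table `q`; the function name `derandomize` is hypothetical.

```python
from typing import List, Sequence, Tuple


def derandomize(
    xs: Sequence[float],
    decoded: Sequence[float],
    q: Sequence[Sequence[float]],
) -> Tuple[float, float, List[int], List[float]]:
    """Compare a stochastic encoder with the deterministic encoder from the proof.

    q[n][k] = Pr(c_k | x_n) for a stochastic discrete encoder; decoded[k] = f_d(c_k).
    The deterministic encoder sends x_n to argmin_k d_x(f_d(c_k), x_n).
    d_x(a, b) = |a - b| on 1-D data is an illustrative assumption.
    Returns (stochastic cost, deterministic cost, assignment, induced weights pi).
    """
    N, K = len(xs), len(decoded)
    stochastic = sum(q[n][k] * abs(decoded[k] - xs[n]) for n in range(N) for k in range(K)) / N
    assign = [min(range(K), key=lambda k: abs(decoded[k] - x)) for x in xs]
    deterministic = sum(abs(decoded[k] - x) for x, k in zip(xs, assign)) / N
    pi = [assign.count(k) / N for k in range(K)]   # pi_k = Pr(encoder(x) = c_k), x ~ P_x
    return stochastic, deterministic, assign, pi


xs = [0.0, 1.0, 4.0]
decoded = [0.5, 4.0]
q = [[0.8, 0.2], [0.5, 0.5], [0.1, 0.9]]            # each row sums to 1
s, d, assign, pi = derandomize(xs, decoded, q)
print(assign, pi)   # [0, 0, 1] [0.6666666666666666, 0.3333333333333333]
print(d <= s)       # True: the pointwise minimum never exceeds the average
```

The inequality holds for any row-stochastic `q`, because for every $x_n$ the minimum cost over codewords is at most the $q$-weighted average cost; this is exactly the step used in the proof.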
Furthermore, because $C^{*2}, \pi^{*2}, f_d^{*2}$, and $\bar f_e^{*2}$ form a feasible solution of the OP in (14), and hence of the OP in (13), for which $C^{*1}, \pi^{*1}, f_d^{*1}$, and $\bar f_e^{*1}$ are optimal, we have

$$\mathbb{E}_{x\sim P_x}\big[d_x\big(f_d^{*2}(\bar f_e^{*2}(x)), x\big)\big] \geq \mathbb{E}_{x\sim P_x,\, c\sim\bar f_e^{*1}(x)}\big[d_x\big(f_d^{*1}(c), x\big)\big].$$

This means that

$$\mathbb{E}_{x\sim P_x}\big[d_x\big(f_d^{*2}(\bar f_e^{*2}(x)), x\big)\big] = \mathbb{E}_{x\sim P_x,\, c\sim\bar f_e^{*1}(x)}\big[d_x\big(f_d^{*1}(c), x\big)\big],$$

and, since every feasible solution of (14) is also feasible for (13), $C^{*2}, \pi^{*2}, f_d^{*2}$, and $\bar f_e^{*2}$ form an optimal solution of the OP in (14).

Here, $\bar f_e$ is a deterministic discrete encoder mapping data example $x$ directly to the codebooks. To make it trainable, we replace $\bar f_e$ by a continuous encoder $f_e: \mathcal{X}\to\mathcal{Z}$ and arrive at the following OP:

$$\min_{C,\pi}\ \min_{f_d, f_e}\Big\{\mathbb{E}_{x\sim P_x}\big[d_x\big(f_d(Q_C(f_e(x))), x\big)\big] + \lambda W_{d_z}(f_e\#P_x, P_{c,\pi})\Big\}, \tag{15}$$

where $Q_C(f_e(x)) = \arg\min_{c\in C} d_z(f_e(x), c)$ is a quantization operator that returns the codeword closest to $f_e(x)$, and $\lambda > 0$.

We now state and prove the following lemma, which is needed for the proof of Theorem A.2.

Lemma A.2. Consider $C, \pi, f_d$, and $f_e$ as a feasible solution of the OP in (15). Let $\bar f_e(x) = \arg\min_{c} d_z(f_e(x), c) = Q_C(f_e(x))$. Then $\bar f_e$ is a Borel measurable function.

Proof of Lemma A.2. We denote the set $A_k$ on the latent space as $A_k = \{z : d_z(z, c_k) < d_z(z, c_j),\ \forall j\neq k\} = \{z : Q_C(z) = c_k\}$. $A_k$ is known as a Voronoi cell w.r.t. the metric $d_z$. If $d_z$ is continuous, $A_k$ is a measurable set. Given a Borel measurable set $B$, we prove that $\bar f_e^{-1}(B)$ is a Borel measurable set on the data space. Let $B\cap\{c_1, \dots, c_K\} = \{c_{i_1}, \dots, c_{i_m}\}$; we prove that $\bar f_e^{-1}(B) = \bigcup_{j=1}^{m} f_e^{-1}(A_{i_j})$. Indeed, take $x\in\bar f_e^{-1}(B)$; then $\bar f_e(x)\in B$, implying that $\bar f_e(x) = Q_C(f_e(x)) = c_{i_j}$ for some $j = 1, \dots, m$. This means that $f_e(x)\in A_{i_j}$ for some $j = 1, \dots, m$. Therefore, $\bar f_e^{-1}(B)\subset\bigcup_{j=1}^{m} f_e^{-1}(A_{i_j})$. Now take $x\in\bigcup_{j=1}^{m} f_e^{-1}(A_{i_j})$. Then $f_e(x)\in A_{i_j}$ for some $j = 1, \dots, m$, hence $\bar f_e(x) = Q_C(f_e(x)) = c_{i_j}$ for some $j = 1, \dots, m$.
Thus, $\bar f_e(x)\in B$, or equivalently $x\in\bar f_e^{-1}(B)$, implying $\bar f_e^{-1}(B)\supset\bigcup_{j=1}^{m} f_e^{-1}(A_{i_j})$. Finally, we reach $\bar f_e^{-1}(B) = \bigcup_{j=1}^{m} f_e^{-1}(A_{i_j})$, which concludes our proof because $f_e$ is a measurable function and the $A_{i_j}$ are measurable sets.

Theorem A.2 (Theorem 3.2 in the main paper). If we seek $f_d$ and $f_e$ in a family with infinite capacity (e.g., the space of all measurable functions), the three OPs of interest in (7), (10), and (15) are equivalent.

Proof of Theorem A.2. Given the optimal solution $C^{*1}, \pi^{*1}, f_d^{*1}$, and $f_e^{*1}$ of the OP in (15), we construct an optimal solution for the OP in (10). Let $C^{*2} = C^{*1}$ and $f_d^{*2} = f_d^{*1}$. We next define

$$\bar f_e^{*2}(x) = \arg\min_{c} d_z\big(f_e^{*1}(x), c\big) = Q_{C^{*1}}\big(f_e^{*1}(x)\big) = Q_{C^{*2}}\big(f_e^{*1}(x)\big),$$

and $\pi_k^{*2} = \Pr\big(\bar f_e^{*2}(x) = c_k : x\sim P_x\big)$, $k = 1, \dots, K$. We prove that $C^{*2}, \pi^{*2}, f_d^{*2}$, and $\bar f_e^{*2}$ form an optimal solution of the OP in (10). By this definition, $\bar f_e^{*2}\#P_x = P_{c^{*2},\pi^{*2}}$ and hence $W_{d_z}\big(\bar f_e^{*2}\#P_x, P_{c^{*2},\pi^{*2}}\big) = 0$. Therefore, we need to verify the following:

(i) $\bar f_e^{*2}$ is a Borel measurable function.

(ii) Given a feasible solution $C, \pi, f_d$, and $\bar f_e$ of (10), we have

$$\mathbb{E}_{x\sim P_x}\big[d_x\big(f_d^{*2}(\bar f_e^{*2}(x)), x\big)\big] \leq \mathbb{E}_{x\sim P_x}\big[d_x\big(f_d(\bar f_e(x)), x\big)\big]. \tag{16}$$

We first prove (i). It follows directly from applying Lemma A.2 to $C^{*1}, \pi^{*1}, f_d^{*1}$, and $f_e^{*1}$. We next prove (ii). We derive

$$\begin{aligned}
&\mathbb{E}_{x\sim P_x}\big[d_x\big(f_d^{*2}(\bar f_e^{*2}(x)), x\big)\big] + \lambda W_{d_z}\big(\bar f_e^{*2}\#P_x, P_{c^{*2},\pi^{*2}}\big) = \mathbb{E}_{x\sim P_x}\big[d_x\big(f_d^{*2}(\bar f_e^{*2}(x)), x\big)\big] \\
&\quad= \mathbb{E}_{x\sim P_x}\big[d_x\big(f_d^{*1}(Q_{C^{*2}}(f_e^{*1}(x))), x\big)\big] = \mathbb{E}_{x\sim P_x}\big[d_x\big(f_d^{*1}(Q_{C^{*1}}(f_e^{*1}(x))), x\big)\big] \\
&\quad\leq \mathbb{E}_{x\sim P_x}\big[d_x\big(f_d^{*1}(Q_{C^{*1}}(f_e^{*1}(x))), x\big)\big] + \lambda W_{d_z}\big(f_e^{*1}\#P_x, P_{c^{*1},\pi^{*1}}\big).
\end{aligned}\tag{17}$$

Moreover, because $\bar f_e\#P_x = P_{c,\pi}$, which is a discrete distribution over the set of codewords $C$, we obtain $Q_C(\bar f_e(x)) = \bar f_e(x)$.
Note that $C, \pi, f_d$, and $\bar f_e$ also form a feasible solution of (15), because $\bar f_e$ is a specific encoder mapping from the data space to the latent space. Hence, by the optimality of $C^{*1}, \pi^{*1}, f_d^{*1}$, and $f_e^{*1}$ for (15),

$$\mathbb{E}_{x\sim P_x}\big[d_x\big(f_d(Q_C(\bar f_e(x))), x\big)\big] + \lambda W_{d_z}\big(\bar f_e\#P_x, P_{c,\pi}\big) \geq \mathbb{E}_{x\sim P_x}\big[d_x\big(f_d^{*1}(Q_{C^{*1}}(f_e^{*1}(x))), x\big)\big] + \lambda W_{d_z}\big(f_e^{*1}\#P_x, P_{c^{*1},\pi^{*1}}\big).$$

Noting that $\bar f_e\#P_x = P_{c,\pi}$ (so the $W_{d_z}$ term on the left vanishes) and $Q_C(\bar f_e(x)) = \bar f_e(x)$, we arrive at

$$\mathbb{E}_{x\sim P_x}\big[d_x\big(f_d(\bar f_e(x)), x\big)\big] \geq \mathbb{E}_{x\sim P_x}\big[d_x\big(f_d^{*1}(Q_{C^{*1}}(f_e^{*1}(x))), x\big)\big] + \lambda W_{d_z}\big(f_e^{*1}\#P_x, P_{c^{*1},\pi^{*1}}\big). \tag{18}$$

Combining the inequalities in (17) and (18), we obtain Inequality (16):

$$\mathbb{E}_{x\sim P_x}\big[d_x\big(f_d^{*2}(\bar f_e^{*2}(x)), x\big)\big] \leq \mathbb{E}_{x\sim P_x}\big[d_x\big(f_d(\bar f_e(x)), x\big)\big].$$

This concludes our proof.

where $\Sigma_\pi$ is the set of assignment functions $\sigma: \{1, \dots, N\}\to\{1, \dots, K\}$ such that the cardinalities $|\sigma^{-1}(k)|$, $k = 1, \dots, K$, are proportional to $\pi_k$, $k = 1, \dots, K$.

• For the CIFAR10, MNIST, and SVHN datasets, the models have an encoder with two convolutional layers of stride 2 and filter size 4 × 4 with ReLU activation, followed by 2 residual blocks, each containing a 3 × 3, stride-1 convolutional layer with ReLU activation followed by a 1 × 1 convolution. The decoder is similar, with two of these residual blocks followed by two deconvolutional layers.
• For the CelebA dataset, the models have an encoder with two convolutional layers of stride 2 and filter size 4 × 4 with ReLU activation, followed by 6 residual blocks, each containing a 3 × 3, stride-1 convolutional layer with ReLU activation followed by a 1 × 1 convolution. The decoder is similar, with two of these residual blocks followed by two deconvolutional layers.
• For the high-quality image dataset FFHQ, we use the well-known VQGAN framework (Esser et al., 2021) as the baseline. We only replace the regularization module of VQ-VAE, i.e., the last two terms of its objective function, $d_z(\mathrm{sg}(f_e(x)), C) + \beta d_z(f_e(x), \mathrm{sg}(C))$, with our proposed Wasserstein regularization $\lambda W_{d_z}(f_e\#P_x, P_{c,\pi})$ in Eq. (4) for VQ-WAE.
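To see which spatial resolution the encoder described above produces, the sketch below applies the standard convolution output-size formula, floor((H + 2p - k)/s) + 1, to the two stride-2, 4 × 4 layers. Padding p = 1 for the strided layers, "same" padding inside the residual blocks, and square 32 × 32 or 64 × 64 inputs are assumptions that the text does not state; under them, a 32 × 32 input yields the 8 × 8 latent grid used by the generation model. The helper names are hypothetical.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a 2-D convolution along one axis."""
    return (size + 2 * padding - kernel) // stride + 1


def encoder_latent_size(size: int, n_down: int = 2, kernel: int = 4,
                        stride: int = 2, padding: int = 1) -> int:
    """Apply n_down strided 4x4 convolutions (padding=1 is an assumption).

    Residual blocks (3x3 stride-1 conv with padding 1, then 1x1 conv) are
    assumed to preserve the spatial size, so they are not modelled here.
    """
    for _ in range(n_down):
        size = conv_out(size, kernel, stride, padding)
    return size


# Residual-block convolutions keep the size under "same" padding.
assert conv_out(8, kernel=3, stride=1, padding=1) == 8
assert conv_out(8, kernel=1, stride=1, padding=0) == 8

print(encoder_latent_size(32))  # 32 -> 16 -> 8
print(encoder_latent_size(64))  # 64 -> 32 -> 16
```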
Additionally, we employ the POT library (Flamary et al., 2021) to compute the WS distance for simplicity. However, VQ-WAE does not require the optimal transport map from the WS distance in (6) to update the model. Therefore, we can employ a wide range of speed-up algorithms to solve the optimization problem (OP) in (6), such as the Sinkhorn algorithm (Cuturi, 2013) or the entropic regularized dual form (Genevay et al., 2016). Hyper-parameters: following Takida et al. (2022), we adopt the Adam optimizer for training with a learning rate of e-4, a batch size of 32, an embedding dimension of 64, and a codebook size |C| = 512 for all datasets except FFHQ, which uses an embedding dimension of 256 and |C| = 1024. Finally, we train the models for 100 epochs on CIFAR10, MNIST, SVHN, and FFHQ, and for 70 epochs on CelebA.
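Because only the value of the WS distance is needed, an entropic approximation can stand in for an exact solver. The following is a minimal NumPy sketch of the Sinkhorn iterations (Cuturi, 2013) between the uniform empirical distribution of a batch of latent vectors and the codeword distribution $P_{c,\pi}$. The squared Euclidean ground cost, the regularization strength `eps`, and the iteration count are illustrative assumptions, not the settings used in the experiments; in practice one would rely on POT's solvers (e.g., `ot.sinkhorn2`) instead of this hand-written loop.

```python
import numpy as np


def sinkhorn_ws(z, codebook, pi, eps=0.05, n_iters=500):
    """Entropic OT between (1/B) sum_i delta_{z_i} and sum_k pi_k delta_{c_k}.

    Hypothetical helper; the ground cost is squared Euclidean (assumption).
    Returns (cost, plan) where cost = <plan, M> for the Sinkhorn plan.
    """
    B = z.shape[0]
    a = np.full(B, 1.0 / B)              # uniform weights on batch latents
    b = np.asarray(pi, dtype=float)      # codeword weights pi_k (sum to 1)
    # Pairwise cost M[i, k] = ||z_i - c_k||^2
    M = ((z[:, None, :] - codebook[None, :, :]) ** 2).sum(-1)
    K = np.exp(-M / eps)                 # Gibbs kernel
    u = np.ones(B)
    v = np.ones(len(b))
    for _ in range(n_iters):
        u = a / (K @ v)                  # rescale rows toward marginal a
        v = b / (K.T @ u)                # rescale columns toward marginal b
    plan = u[:, None] * K * v[None, :]
    return float((plan * M).sum()), plan


# Four latents around two codewords with pi = (1/2, 1/2).
z = np.array([[-1.0, 0.0], [-0.9, 0.1], [1.0, 0.0], [1.1, -0.1]])
codebook = np.array([[-1.0, 0.0], [1.0, 0.0]])
cost, plan = sinkhorn_ws(z, codebook, pi=[0.5, 0.5], eps=0.05)
print(cost)
```

As `eps` decreases, the cost of the Sinkhorn plan approaches the exact WS distance, at the price of slower convergence and a higher risk of numerical underflow in `K`; log-domain variants are the usual remedy.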

C.2 GENERATION MODEL

Implementation: It is worth noting that we use the codebooks learned by the reported VQ-models to extract codeword indices, and we employ PixelCNN (Van den Oord et al., 2016) with the same generation setting for all of VQ-VAE, SQ-VAE, and VQ-WAE. In particular, we feed PixelCNN with the "pixel" values of the 8 × 8 1-channel latent space for CIFAR10, MNIST, and SVHN, and of the 16 × 16 1-channel latent space for CelebA. Hyper-parameters: we adopt the Adam optimizer with a learning rate of 3e-4 and a batch size of 32. Finally, we train PixelCNN over the "pixel" values of the 8 × 8 1-channel latent space for 100 epochs.
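PixelCNN models the grid of codeword indices autoregressively in raster-scan order, which it enforces by masking its convolution kernels. The sketch below builds the standard type-A and type-B masks from Van den Oord et al. (2016) for a single-channel input such as the index maps described above. The kernel size and the helper name are illustrative assumptions; the sketch does not describe the exact PixelCNN configuration used in these experiments.

```python
import numpy as np


def pixelcnn_mask(kernel_size: int, mask_type: str) -> np.ndarray:
    """Binary mask for a single-channel masked convolution (hypothetical helper).

    A kernel position is kept (1) if it lies in a row above the centre, or in
    the centre row strictly left of the centre. Type 'B' also keeps the centre;
    type 'A' (used in the first layer) does not, so an index never sees itself.
    """
    if mask_type not in ("A", "B"):
        raise ValueError("mask_type must be 'A' or 'B'")
    mask = np.zeros((kernel_size, kernel_size), dtype=np.int64)
    centre = kernel_size // 2
    mask[:centre, :] = 1            # all rows above the centre row
    mask[centre, :centre] = 1       # left part of the centre row
    if mask_type == "B":
        mask[centre, centre] = 1    # later layers may use the centre position
    return mask


print(pixelcnn_mask(5, "A"))
```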



E.g., if σ is the nearest assignment, $\sigma^{-1}(k) = \{f_e(x) \mid k = \arg\min_{k'} d_z(f_e(x), c_{k'})\}$ is the set of latent representations that are quantized to the $k$-th codeword.
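The nearest assignment is exactly the quantizer $Q_C$. The minimal NumPy sketch below assumes $d_z$ is the Euclidean distance and computes $\sigma(n) = \arg\min_k d_z(f_e(x_n), c_k)$ for a batch of latents, the quantized vectors, and the cardinalities $|\sigma^{-1}(k)|$. Dividing the counts by $N$ and comparing them with $\pi_k$ shows how far the nearest assignment is from the proportions the WS regularizer targets. The function name is hypothetical.

```python
import numpy as np


def nearest_assignment(z: np.ndarray, codebook: np.ndarray):
    """Return sigma (nearest codeword index per latent), Q_C(z), and the
    cardinalities |sigma^{-1}(k)|. Hypothetical helper; Euclidean d_z assumed."""
    # d_z(z_n, c_k) for all pairs, shape (N, K)
    dists = np.linalg.norm(z[:, None, :] - codebook[None, :, :], axis=-1)
    sigma = dists.argmin(axis=1)                  # sigma(n) in {0, ..., K-1}
    quantized = codebook[sigma]                   # Q_C(z_n) = c_{sigma(n)}
    counts = np.bincount(sigma, minlength=len(codebook))
    return sigma, quantized, counts


z = np.array([[0.1, 0.0], [0.2, 0.1], [0.9, 1.0], [1.1, 0.9], [0.0, 0.2]])
codebook = np.array([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]])
sigma, quantized, counts = nearest_assignment(z, codebook)
print(sigma)            # [0 0 1 1 0]
print(counts / len(z))  # empirical proportions; codeword 2 is never used
```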



Figure 1: (a) Illustration of our VQ-WAE derivation. We start from the minimization of the WS distance on the data space in (1) and turn it into minimizing the reconstruction error in (2) and the WS distance on the latent space in (3). (b) Visualization of the embedding space with WS regularization. The encoder outputs $f_e(x)$ are distributed and moved to the codewords $c_k$ such that the cardinalities $|\sigma^{-1}(k)|$ (i.e., the numbers of latent representations assigned to the $k$-th codeword) are proportional to $\pi_k$. At the same time, the codewords tend to move flexibly toward the cluster centroids of the latent representations (cf. Corollary 3.1).

we sample a mini-batch $x_1, \dots, x_B$ and then solve the OP in (4) by updating $f_d$, $f_e$, and $C$ based on this mini-batch as follows. Let $P_b = \frac{1}{B}\sum_{i=1}^{B}\delta_{x_i}$ denote the empirical distribution over the current batch, so that $f_e\#P_b$ is the empirical distribution of the embedded vectors. We then learn the optimal transport plan $P^*$ by solving:

Algorithm 1 VQ-WAE
1: Initialize: encoder $f_e$, decoder $f_d$, and codebook $C$.
2: for iter in iterations do
3: Sample a mini-batch of samples $x_1, \dots, x_B$ forming the empirical batch distribution $P_b$
4: …
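Only the first steps of Algorithm 1 survive here. As a hedged illustration of what one iteration evaluates, the sketch below computes a mini-batch objective of the form in (15): a reconstruction error through the quantizer plus λ times an entropic OT estimate of $W_{d_z}(f_e\#P_b, P_{c,\pi})$. The linear encoder and decoder, the squared Euclidean costs, λ, uniform π, and the Sinkhorn settings are all hypothetical placeholders rather than the paper's networks or hyper-parameters, and the gradient updates of $f_e$, $f_d$, and $C$ (normally done with an autodiff framework) are omitted.

```python
import numpy as np


def sinkhorn_cost(M, a, b, eps=0.5, n_iters=300):
    """Entropic OT cost <P, M> for cost matrix M and marginals a, b."""
    K = np.exp(-M / eps)
    u, v = np.ones(len(a)), np.ones(len(b))
    for _ in range(n_iters):
        u = a / (K @ v)
        v = b / (K.T @ u)
    plan = u[:, None] * K * v[None, :]
    return float((plan * M).sum())


def vqwae_batch_objective(x, W_enc, W_dec, codebook, pi, lam=1.0):
    """Hypothetical forward pass with placeholder linear maps:
    f_e(x) = x @ W_enc and f_d(c) = c @ W_dec (not the paper's networks)."""
    z = x @ W_enc                                           # f_e(x_i)
    d2 = ((z[:, None, :] - codebook[None, :, :]) ** 2).sum(-1)
    zq = codebook[d2.argmin(axis=1)]                        # Q_C(f_e(x_i))
    recon = ((zq @ W_dec - x) ** 2).sum(-1).mean()          # d_x term
    a = np.full(len(x), 1.0 / len(x))                       # weights of f_e # P_b
    ws = sinkhorn_cost(d2, a, np.asarray(pi, dtype=float))  # W_{d_z} estimate
    return recon + lam * ws, recon, ws


rng = np.random.default_rng(0)
x = rng.normal(size=(16, 4))            # mini-batch x_1..x_B with B = 16
W_enc = rng.normal(size=(4, 2)) * 0.5
W_dec = rng.normal(size=(2, 4)) * 0.5
codebook = rng.normal(size=(3, 2))      # K = 3 codewords in a 2-D latent space
pi = np.full(3, 1.0 / 3)                # uniform pi (assumption)
total, recon, ws = vqwae_batch_objective(x, W_enc, W_dec, codebook, pi)
print(total, recon, ws)
```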

Figure 2: Latent distribution over the codebook on the test set.

Figure 3: The t-SNE feature visualization on the MNIST dataset.

$\ldots\, d_x\big(x_n, f_d(c_{\sigma(n)})\big)$, which directly implies the conclusion.

Theorem A.1 (Theorem 3.1 in the main paper). We can equivalently turn the optimization problem in (7) into $\min_{C,\pi,f_d}\ \min_{\bar f_e:\, \bar f_e\#P_x = P_{c,\pi}}$

Corollary A.1 (Corollary 3.1 in the main paper). Consider minimizing the second term $\min_{f_e, C} W_{d_z}(f_e\#P_x, P_{c,\pi})$ in (15) given $\pi$, and assume $K < N$. Its optimal solution $f_e^*$ and $C^*$ is also an optimal solution of the following OP:

$$\min_{f_e, C}\ \min_{\sigma\in\Sigma_\pi}\sum_{n=1}^{N} d_z\big(f_e(x_n), c_{\sigma(n)}\big).$$
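Corollary A.1 reads as a clustering problem with cardinality constraints. The brute-force sketch below, feasible only for tiny $N$ and assuming $d_z$ is the squared Euclidean distance (an illustrative choice), fixes the latents $f_e(x_n)$, enumerates all $\sigma\in\Sigma_\pi$ for $N = 4$, $K = 2$, and $\pi = (1/2, 1/2)$, and then moves each codeword to the mean of its assigned latents. For squared Euclidean cost this centroid step cannot increase the objective, which mirrors the intuition in Figure 1(b) that codewords move toward cluster centroids. All function names are hypothetical.

```python
import itertools

import numpy as np


def balanced_assignments(N, counts):
    """Yield every sigma: {0..N-1} -> {0..K-1} with |sigma^{-1}(k)| = counts[k]."""
    K = len(counts)
    for sigma in itertools.product(range(K), repeat=N):
        if all(sigma.count(k) == counts[k] for k in range(K)):
            yield np.array(sigma)


def cost(z, codebook, sigma):
    """sum_n d_z(z_n, c_{sigma(n)}) with squared Euclidean d_z (assumption)."""
    return float(((z - codebook[sigma]) ** 2).sum())


def best_balanced_assignment(z, codebook, counts):
    """Exhaustive minimisation over Sigma_pi (exponential in N)."""
    return min(balanced_assignments(len(z), counts),
               key=lambda s: cost(z, codebook, s))


z = np.array([[0.0, 0.0], [0.2, 0.0], [3.0, 3.0], [3.2, 3.0]])
codebook = np.array([[1.0, 1.0], [2.0, 2.0]])
counts = [2, 2]                          # N * pi_k with pi = (1/2, 1/2)

sigma = best_balanced_assignment(z, codebook, counts)
before = cost(z, codebook, sigma)
# Centroid step: move each codeword to the mean of its assigned latents.
new_codebook = np.stack([z[sigma == k].mean(axis=0) for k in range(len(counts))])
after = cost(z, new_codebook, sigma)
print(sigma, before, after)
```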

imposes noise on the latent codes and uses a Bayesian estimator to optimize the quantizer-based representation. The introduced bottleneck Bayesian estimator outputs the posterior mean of the centroids to the decoder and performs soft quantization of the noisy latent codes, whose latent representations preserve the similarity relations of the data space. Recently, Takida et al. (2022) extended the standard VAE with stochastic quantization and a trainable posterior categorical distribution, showing that annealing the stochasticity of the quantization process significantly improves codebook utilization.

Reconstruction performance (↓: the lower the better; ↑: the higher the better).
Dataset | Model | Latent Size | SSIM ↑ | PSNR ↑ | LPIPS ↓ | rFID ↓ | Perplexity ↑

Distortion and Perplexity with different codebook sizes.
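Perplexity is reported as a measure of codebook utilization. A common definition in the VQ-VAE literature, assumed here because the text does not spell it out, is the exponential of the entropy of the empirical codeword-usage distribution: it equals $K$ when all $K$ codewords are used equally often and 1 when a single codeword absorbs every latent. A minimal NumPy sketch with a hypothetical function name:

```python
import numpy as np


def codebook_perplexity(indices: np.ndarray, num_codewords: int) -> float:
    """exp(-sum_k p_k log p_k), where p_k is the fraction of latents assigned
    to codeword k; unused codewords contribute 0 (0 * log 0 := 0)."""
    counts = np.bincount(indices.ravel(), minlength=num_codewords)
    p = counts / counts.sum()
    p = p[p > 0]
    return float(np.exp(-(p * np.log(p)).sum()))


uniform = np.arange(512).repeat(4)      # all 512 codewords used 4 times each
collapsed = np.zeros(2048, dtype=int)   # every latent mapped to codeword 0
print(codebook_perplexity(uniform, 512), codebook_perplexity(collapsed, 512))
```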

FID scores of generated images.

… .2 of the supplementary material. The quantitative results in Table 3 indicate that the codebook of VQ-WAE leads to better generation ability than those of VQ-VAE and SQ-VAE.


Proof of Corollary A.1. By the Monge definition, we have … It also follows that …

B VISUALIZATION OF RECONSTRUCTION RESULTS

Figure 4: Reconstruction results for the FFHQ dataset.

Qualitative assessment: We present reconstructed samples from FFHQ (high-resolution images) for qualitative evaluation. The high-level semantic features and colors of the input images are clearly better preserved by VQ-WAE than by the baseline. In particular, VQGAN often produces repeated artifact patterns in image synthesis (see the man's hair in the second column of Figure 4), while VQ-WAE does not. This is because the VQGAN codebook lacks diversity, which is analyzed further in Section 5.2.1. Consequently, the quantization operator maps similar patches to the same quantization index and ignores the variation among these patches (e.g., VQGAN reconstructs the background in the third column of Figure 4 as a woman's hair).

C EXPERIMENTAL SETTINGS

C.1 VQ-MODEL

Implementation: For a fair comparison, we use the same architecture and hyperparameters for both VQ-VAE and VQ-WAE. Specifically, we construct the VQ-VAE and VQ-WAE models as follows:

