SYMMETRIC WASSERSTEIN AUTOENCODERS

Abstract

Leveraging the framework of Optimal Transport, we introduce a new family of generative autoencoders with a learnable prior, called Symmetric Wasserstein Autoencoders (SWAEs). We propose to symmetrically match the joint distributions of the observed data and the latent representation induced by the encoder and the decoder. The resulting algorithm jointly optimizes the modelling losses in both the data and the latent spaces, with the loss in the data space leading to a denoising effect. With the symmetric treatment of the data and the latent representation, the algorithm implicitly preserves the local structure of the data in the latent space. To further improve the latent representation, we incorporate a reconstruction loss into the objective, which significantly benefits both generation and reconstruction. We empirically show the superior performance of SWAEs over state-of-the-art generative autoencoders in terms of classification, reconstruction, and generation.

1. INTRODUCTION

Deep generative models have emerged as powerful frameworks for modelling complex data. Widely used families of such models include Generative Adversarial Networks (GANs) and Variational Autoencoders (VAEs).



Nevertheless, these methods typically involve a sophisticated objective function that either depends on unstable adversarial training or on a challenging approximation of the mutual information. In this paper, we leverage Optimal Transport (OT) (Villani, 2008; Peyré et al., 2019) to symmetrically match the encoding and the decoding distributions. The OT optimization is generally challenging, particularly in high dimensions, and we address this difficulty by transforming the OT cost into a simpler form amenable to efficient numerical implementation. Owing to the symmetric treatment of the observed data and the latent representation, the local structure of the data can be implicitly preserved in the latent space. However, we found that with the symmetric treatment alone, the performance of the generative model may not be satisfactory. To improve the generative model, we additionally include a reconstruction loss in the objective, which is shown to significantly benefit the quality of generation and reconstruction. Our contributions can be summarized as follows. Firstly, we propose a new family of generative autoencoders, called Symmetric Wasserstein Autoencoders (SWAEs). Secondly, we adopt a learnable latent prior, parameterized as a mixture of the conditional priors given learnable pseudo-inputs, which prevents SWAEs from over-regularizing the latent variables. Thirdly, we empirically perform an ablation study of SWAEs in terms of KNN classification, denoising, reconstruction, and sample generation. Finally, we empirically verify, using benchmark tasks, the superior performance of SWAEs over several state-of-the-art generative autoencoders.

2. SYMMETRIC WASSERSTEIN AUTOENCODERS

In this section we introduce a new family of generative autoencoders, called Symmetric Wasserstein Autoencoders (SWAEs).

2.1. OT FORMULATION

Denote the random vector at the encoder as e = (x_e, z_e) ∈ X × Z, which contains both the observed data x_e ∈ X and the latent representation z_e ∈ Z. We call the distribution p(e) = p(x_e)p(z_e|x_e) the encoding distribution, where p(x_e) represents the data distribution and p(z_e|x_e) characterizes an inference model. Similarly, denote the random vector at the decoder as d = (x_d, z_d) ∈ X × Z, which consists of both the latent prior z_d ∈ Z and the generated data x_d ∈ X. We call the distribution p(d) = p(z_d)p(x_d|z_d) the decoding distribution, where p(z_d) represents the prior distribution and p(x_d|z_d) characterizes a generative model. The objective of VAEs is equivalent to minimizing the (asymmetric) KL divergence between the encoding distribution p(e) and the decoding distribution p(d) (see Appendix A.1). To address this limitation of VAEs, we first propose to treat the data and the latent representation symmetrically instead of asymmetrically, by minimizing the p-th Wasserstein distance between p(e) and p(d), leveraging Optimal Transport (OT) (Villani, 2008; Peyré et al., 2019). OT provides a Lagrangian framework for comparing two distributions, seeking the minimum cost of transporting one distribution to another. We focus on the primal problem of OT, whose Kantorovich formulation (Peyré et al., 2019) is given by:

W_c(p(e), p(d)) = \inf_{\Gamma \in \mathcal{P}(e \sim p(e),\, d \sim p(d))} \mathbb{E}_{(e,d) \sim \Gamma}[c(e, d)],    (1)

In particular, it can be proved that the p-th Wasserstein distance is a metric, hence symmetric, and metrizes weak convergence (see, e.g., (Santambrogio, 2015)). Optimization of equation 1 is computationally prohibitive, especially in high dimensions (Peyré et al., 2019). To provide an efficient solution, we restrict ourselves to a deterministic encoder and decoder. Specifically, at the encoder we have the latent representation z_e = E(x_e) with the function E : X → Z, and at the decoder we have the generated data x_d = D(z_d) with the function D : Z → X.
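To build intuition for the Kantorovich formulation in equation 1: in one dimension, with two equal-size empirical samples, the optimal coupling simply matches sorted samples, so the p-th Wasserstein distance has a closed form. The following minimal numpy sketch is purely illustrative (it is not the estimator used in the paper, which works with high-dimensional joint distributions):

```python
import numpy as np

def wasserstein_p_1d(x, y, p=2):
    """Empirical p-Wasserstein distance between two equal-size 1-D samples.

    In 1-D the optimal coupling pairs sorted samples, so
    W_p^p = mean(|sort(x) - sort(y)|^p).
    """
    x, y = np.sort(np.asarray(x, float)), np.sort(np.asarray(y, float))
    assert x.shape == y.shape, "equal sample sizes assumed for this sketch"
    return np.mean(np.abs(x - y) ** p) ** (1.0 / p)

rng = np.random.default_rng(0)
x = rng.normal(size=1000)
print(wasserstein_p_1d(x, x))        # 0.0 (a distribution is at distance 0 from itself)
print(wasserstein_p_1d(x, x + 3.0))  # ~3.0 (a constant shift moves mass by exactly 3)
```

The second call illustrates why W_p is a metric: shifting a distribution by a constant c yields W_p = |c| for every p, whereas, e.g., the KL divergence between two distributions with disjoint supports is infinite.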
It turns out that under the deterministic condition, instead of searching for an optimal coupling in high dimension, we can find a proper conditional distribution p(z_d|x_e) with the marginal p(z_d).

Theorem 1. Given the deterministic encoder E : X → Z and the deterministic decoder D : Z → X, the OT problem in equation 1 can be transformed into the following:

W_c(p(e), p(d)) = \inf_{p(z_d|x_e)} \mathbb{E}_{p(x_e)} \mathbb{E}_{p(z_d|x_e)}[c(e, d)],    (2)

where the observed data follows the distribution p(x_e) and the prior follows the distribution p(z_d).

The proof of Theorem 1 extends that of Theorem 1 in (Tolstikhin et al., 2018), and is provided in Appendix A.2. If X × Z is the Euclidean space endowed with the L_p norm, then the expression in equation 2 equals the following:

W_c(p(e), p(d)) = \inf_{p(z_d|x_e)} \mathbb{E}_{p(x_e)} \mathbb{E}_{p(z_d|x_e)}\left[ \|x_e - D(z_d)\|_p^p + \|E(x_e) - z_d\|_p^p \right],    (3)

where in the objective we call the first term the x-loss and the second term the z-loss. With the above transformation, we decompose the loss in the joint space into losses in both the data and the latent spaces. Such a decomposition is crucial and allows us to treat the data and the latent representation symmetrically. The x-loss, i.e., ‖x_e - D(z_d)‖_p^p, represents the discrepancy in the data space, and can be interpreted from two different perspectives. Firstly, since D(z_d) represents the generated data, the x-loss essentially minimizes the dissimilarity between the observed data and the generated data. Secondly, the x-loss is closely related to the objective of Denoising Autoencoders (DAs) (Vincent et al., 2008; 2010). In particular, DAs aim to minimize the discrepancy between the observed data and a partially destroyed version of the observed data. The corrupted data can be obtained by means of a stochastic mapping from the original data (e.g., by adding noise). By contrast, the x-loss can be explained in the same way, with the generated data interpreted as the corrupted data.
This is because the prior z_d in D(z_d) is sampled from the conditional distribution p(z_d|x_e), which depends on the observed data x_e. Consequently, the generated data D(z_d), obtained by feeding z_d to the decoder, is stochastically related to the observed data x_e. With this insight, just like the objective of DAs, the x-loss can lead to a denoising effect. The z-loss, i.e., ‖E(x_e) - z_d‖_p^p, represents the discrepancy in the latent space. The whole objective in equation 3 hence simultaneously minimizes the discrepancy in the data and the latent spaces. Observe that in equation 3, E(x_e) is the latent representation of x_e at the encoder, while z_d can be thought of as the latent representation of D(z_d) at the decoder. With this connection, the optimization of equation 3 can preserve the local data structure in the latent space. More specifically, since x_e and D(z_d) are stochastically dependent, roughly speaking, if two data samples are close to each other in the data space, their corresponding latent representations are also expected to be close. This is due to the symmetric treatment of the data and the latent representation. In Figure 1 we illustrate this effect and compare SWAE with VAE.

Comparison with WAEs (Tolstikhin et al., 2018). The objective in equation 3 minimizes the OT cost between the joint distributions of the data and the latent, i.e., W_c(p(e), p(d)), while the objective of WAEs (Tolstikhin et al., 2018) minimizes the OT cost between the marginal distributions of the data, i.e., W_c(p(x_e), p(x_d)), where p(x_d) is the marginal data distribution induced by the decoding distribution p(d). The problem of WAEs is first formulated as an optimization with the constraint p(z_e) = p(z_d), where p(z_e) is the marginal distribution induced by the encoding distribution p(e), and then relaxed by adding a regularizer.
With the deterministic decoder, the final optimization problem of WAEs is as follows:

\inf_{p(z_e|x_e)} \mathbb{E}_{p(x_e)} \mathbb{E}_{p(z_e|x_e)}[c(x_e, D(z_e))] + \lambda \mathcal{D}(p(z_e), p(z_d)),    (4)

where D(·, ·) denotes some divergence measure. Comparing equation 4 to equation 3, we can see that both methods decompose the loss into losses in the data and the latent spaces. The difference is that in equation 4 the first term reflects the reconstruction loss in the data space and the second term represents a distribution-based dissimilarity in the latent space, while in equation 3 the x-loss is closely related to denoising and the generation quality, and the z-loss measures a sample-based dissimilarity. Moreover, equation 4 is optimized over the posterior p(z_e|x_e) with a fixed prior p(z_d), while equation 3 is optimized over the conditional prior p(z_d|x_e) with a potentially learnable prior.

2.2. IMPROVEMENT OF LATENT REPRESENTATION

The objective in equation 3 only seeks to match the encoding and the decoding distributions. Besides the encoder and decoder structures, there is no explicit constraint on the correlation between the data and the latent representation within each joint distribution. The lack of such a constraint typically results in a low quality of reconstruction (Dumoulin et al., 2017; Li et al., 2017a). Therefore, we incorporate a reconstruction-based loss into the objective, associated with a controllable coefficient. Additionally, since the dimension of the latent space is usually much smaller than that of the data space, we associate a weighting parameter to balance these two types of losses. Overall, the objective function can be represented as follows:

\inf_{p(z_d|x_e)} \mathbb{E}_{p(x_e)} \mathbb{E}_{p(z_d|x_e)}\left[ \beta \|x_e - D(z_d)\|_p^p + (1 - \beta) \|x_e - D(z_e)\|_p^p + \alpha \|E(x_e) - z_d\|_p^p \right],    (5)

where ‖x_e - D(z_e)‖_p^p denotes the reconstruction loss, and β (0 < β < 1) and α (α > 0) are the weighting parameters. The weighting parameter β controls the trade-off between the x-loss and the reconstruction loss, and a smaller value of β generally leads to better reconstruction. To achieve a good trade-off between generation and reconstruction, β needs to be carefully chosen. We will perform an ablation study of SWAEs and show the importance of including the reconstruction loss in the objective in Section 3.
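As a concrete sketch, the batch estimate of equation 5 with p = 2 (the squared-L2 cost used later in Algorithm 1) can be written as follows. Here `encoder` and `decoder` stand in for the deterministic maps E and D; the toy linear maps at the bottom are illustrative placeholders, not the paper's networks:

```python
import numpy as np

def swae_loss(x_e, z_d, encoder, decoder, beta=0.5, alpha=1.0):
    """Batch estimate of the SWAE objective (equation 5 with p = 2).

    x_e : (N, dx) observed data; z_d : (N, dz) samples from the conditional prior.
    """
    z_e = encoder(x_e)                                    # latent representation E(x_e)
    x_loss = np.sum((x_e - decoder(z_d)) ** 2, axis=1)    # ||x_e - D(z_d)||_2^2
    rec_loss = np.sum((x_e - decoder(z_e)) ** 2, axis=1)  # ||x_e - D(E(x_e))||_2^2
    z_loss = np.sum((z_e - z_d) ** 2, axis=1)             # ||E(x_e) - z_d||_2^2
    return np.mean(beta * x_loss + (1 - beta) * rec_loss + alpha * z_loss)

# Toy deterministic maps (placeholders for the encoder/decoder networks).
W = np.array([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]])  # dx = 3, dz = 2
encoder = lambda x: x @ W      # E: R^3 -> R^2
decoder = lambda z: z @ W.T    # D: R^2 -> R^3
x = np.ones((4, 3))
z = encoder(x)                 # taking z_d = E(x_e) makes the z-loss vanish
loss = swae_loss(x, z, encoder, decoder)
```

With z_d = E(x_e) the z-loss is zero and the x-loss coincides with the reconstruction loss, so the objective reduces to the reconstruction error regardless of β; in training, z_d is instead sampled from the conditional prior, which is what couples the three terms.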

2.3. ALGORITHM

Similar to many VAE-based generative models, we assume that the encoder, the decoder, and the conditional prior are parameterized by deep neural networks. Unlike canonical VAEs, where the prior distribution is simple and given in advance, the proposed method adopts a learnable prior. The benefits of a learnable prior, e.g., avoiding over-regularization and hence improving the quality of the latent representation, have been revealed in several recent works (Hoffman & Johnson, 2016; Tomczak & Welling, 2018; Atanov et al., 2019; Klushyn et al., 2019). Obviously, the conditional prior is related to the marginal prior via \mathbb{E}_{p(x_e)}[p(z_d|x_e)] = p(z_d). This indicates a way to design the prior as a mixture of the conditional distributions, i.e., p^*(z_d) = \frac{1}{N} \sum_{n=1}^{N} p(z_d|x_{e,n}), where x_{e,1}, ..., x_{e,N} are the training samples. To avoid over-fitting, similar to (Tomczak & Welling, 2018), we replace the training samples with learnable pseudo-inputs and parameterize the prior distribution p(z_d) as p_γ(z_d) = \frac{1}{K} \sum_{k=1}^{K} p_γ(z_d|u_k), where γ denotes the parameters of the conditional prior network, u_k ∈ X are the learnable pseudo-inputs, and K is the number of pseudo-inputs. We emphasize that the conditional prior p(z_d|x_e) (or its approximation p(z_d|u_k)) is used to obtain the marginal prior p(z_d), while the posterior p(z_e|x_e) is used for inference. In the experiments, we parameterize the conditional prior as a Gaussian distribution. We call the proposed generative model Symmetric Wasserstein Autoencoders (SWAEs) as we treat the observed data and the latent representation symmetrically. We summarize the training algorithm in Algorithm 1 and show the network architecture in Figure 2. As an example, we define the cost function c(·, ·) as the squared L2 norm.

Algorithm 1: Symmetric Wasserstein Autoencoders (SWAEs)
Require: The number of pseudo-inputs K; the weighting parameters β and α.
Initialize the parameters φ, θ, and γ of the encoder network, the decoder network, and the conditional prior network, respectively.
while (φ, θ, γ, {u_k}) not converged do
  1. Sample {x_{e,1}, ..., x_{e,N}} from the training dataset.
  2. Find the closest pseudo-input u^{(n)} to each training sample from the set {u_1, ..., u_K}.
  3. Sample z_{d,n} from the conditional prior p_γ(z_d|u^{(n)}) for n = 1, ..., N.
  4. Update (φ, θ, γ, {u_k}) by descending the cost function
     \frac{1}{N} \sum_{n=1}^{N} \left[ \beta \|x_{e,n} - D(z_{d,n})\|_2^2 + (1 - \beta) \|x_{e,n} - D(E(x_{e,n}))\|_2^2 + \alpha \|E(x_{e,n}) - z_{d,n}\|_2^2 \right].
end while

Since we use the pseudo-inputs instead of the training samples in the conditional prior, given each training sample we need to find the closest pseudo-input in Step 2. To measure similarity, we can use, e.g., the L2 norm or the cosine similarity. Since the dimension of the latent space is usually much smaller than that of the data space, to reduce the search time we can alternatively perform Step 2 in the latent space as an approximation. Specifically, we can find the closest latent representation to E(x_{e,n}) from the set {E(u_1), ..., E(u_K)} so as to obtain the corresponding closest pseudo-input. In our experiments we found that this approximation results in little performance degradation, which we attribute to the preservation of the local structure explained earlier.
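Steps 2 and 3 of Algorithm 1 can be sketched as follows. The `prior_net` here is a hypothetical stand-in for the conditional prior network p_γ(z_d|u), which (as stated above) outputs the mean and log-variance of a Gaussian; the pseudo-inputs and dimensions are toy values for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)

def nearest_pseudo_inputs(x_batch, pseudo_inputs):
    """Step 2: index of the closest pseudo-input (L2 norm) for each sample."""
    # (N, K) matrix of squared distances between samples and pseudo-inputs
    d2 = ((x_batch[:, None, :] - pseudo_inputs[None, :, :]) ** 2).sum(-1)
    return d2.argmin(axis=1)

def sample_conditional_prior(u, prior_net):
    """Step 3: draw z_d ~ p_gamma(z_d | u) via the Gaussian reparameterization,
    where prior_net maps a pseudo-input to a mean and a log-variance."""
    mu, logvar = prior_net(u)
    return mu + np.exp(0.5 * logvar) * rng.standard_normal(mu.shape)

# Toy setup: K = 2 pseudo-inputs in a 3-D data space, 2-D latent space.
U = np.array([[0.0, 0.0, 0.0], [10.0, 10.0, 10.0]])
prior_net = lambda u: (u[:, :2] * 0.1, np.zeros((u.shape[0], 2)))  # hypothetical net
x = np.array([[0.1, 0.0, 0.0], [9.8, 10.1, 10.0]])
idx = nearest_pseudo_inputs(x, U)            # each sample matched to its closest u_k
z_d = sample_conditional_prior(U[idx], prior_net)
```

Step 4 then plugs `z_d` into the cost of Algorithm 1 (cf. the `swae_loss` sketch in Section 2.2) and backpropagates through the encoder, decoder, conditional prior network, and pseudo-inputs.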

3. EXPERIMENTAL RESULTS

In this section, we compare the performance of the proposed SWAE with several contemporary generative autoencoders, namely VAE (Kingma & Welling, 2014), WAE-GAN (Tolstikhin et al., 2018) , WAE-MMD (Tolstikhin et al., 2018) , VampPrior (Tomczak & Welling, 2018) , and MIM (Livne et al., 2019) , using four benchmark datasets: MNIST, Fashion-MNIST, Coil20, and CIFAR10 with a subset of classes (denoted as CIFAR10-sub).

3.1. EXPERIMENTAL SETUP

The design of neural network architectures is orthogonal to that of the algorithm objective, and can greatly affect algorithm performance (Vahdat & Kautz, 2020). Since MIM has the same network architecture as VampPrior, for a fair comparison we also build SWAE as well as VAE on the VampPrior network architecture. In particular, VampPrior adopts a hierarchical latent structure with convolutional layers (i.e., convHVAE (L = 2)), where a gating mechanism is utilized as an element-wise non-linearity. The building block of the network structure of VAE and SWAE is the same as that of VampPrior, except that the latent structure is non-hierarchical. Different from SWAE, the prior of VampPrior and MIM is designed as a mixture of the posteriors (instead of the conditional priors as in SWAE) conditioned on the learnable pseudo-inputs. The pseudo-inputs in SWAE, VampPrior, and MIM are initialized with the training samples. For VampPrior and MIM, the number of pseudo-inputs K is carefully chosen via the validation set. Unlike these two algorithms, for SWAE we found that increasing K improves performance. The setup of K for SWAE, VampPrior, and MIM on all datasets can be found in Appendix A.3. For SWAE, we set the weighting parameter α to 1 in all cases; in Step 2 we use the L2 norm as the similarity measure in the data space. WAE-GAN and WAE-MMD are WAE-based models, where the divergence measure in the latent space is based on a GAN and the maximum mean discrepancy (MMD), respectively. The network structure of WAE-GAN and WAE-MMD is the same as that used in (Tolstikhin et al., 2018). The prior of VAE, WAE-GAN, and WAE-MMD is set as an isotropic Gaussian. A detailed description of the datasets, the applied network architectures, and the training parameters can be found in Appendix A.3.

3.2. LATENT REPRESENTATION

The latent representation is expected to capture salient features of the observed data and be useful for downstream applications. The considered datasets are all associated with labels. In the experiment we use the latent representation for K-Nearest Neighbor (KNN) classification and compare the classification accuracy of 5-NN in Table 1, where dim-z denotes the dimension of the latent space. The results of 3-NN and 10-NN are similar to those of 5-NN and are thus omitted. We found that the classification results of all algorithms on CIFAR10 are unsatisfactory with the current networks (accuracy was around 0.3 to 0.4; this may be due to the limited expressive power of the shallow network architectures used), so instead we create a subset of CIFAR10 (CIFAR10-sub) which contains 3 classes: bird, cat, and ship. Since the prior of VAE, WAE-GAN, and WAE-MMD is an isotropic Gaussian, setting dim-z greater than the intrinsic dimensionality of the observed data would force p(z_e) to lie on a manifold in the latent space (Tolstikhin et al., 2018). This makes it impossible to match the marginal p(z_e) with the prior p(z_d) and thus leads to an unsatisfactory latent representation. This concern is confirmed particularly on Fashion-MNIST, where the classification accuracy of VAE and WAE-GAN drops dramatically when dim-z is increased. For SWAE, we consider two cases: β = 1 (i.e., without the reconstruction loss) and β = 0.5. The classification accuracy of SWAE (β = 1) is comparable to that of SWAE (β = 0.5), and both are generally superior to the benchmarks across different values of dim-z. To further show the structure of the latent representation, we project it to 2D using t-SNE (Maaten & Hinton, 2008) as the visualization tool. As an example, we show the projection of the latent representation on MNIST in Figure 3.
We can see that SWAEs keep the local structure of the observed data in the latent space and lead to tight clusters, which is consistent with our expectation as explained in Section 2.1.
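The 5-NN evaluation above amounts to a majority vote over the nearest latent codes. A minimal numpy sketch of this evaluation protocol (the two-cluster data below is synthetic, for illustration only):

```python
import numpy as np
from collections import Counter

def knn_classify(z_train, y_train, z_test, k=5):
    """k-NN on latent representations: majority vote over the k nearest codes (L2)."""
    preds = []
    for z in z_test:
        d2 = ((z_train - z) ** 2).sum(axis=1)   # squared distances to all training codes
        nn = np.argsort(d2)[:k]                 # indices of the k nearest neighbours
        preds.append(Counter(y_train[nn]).most_common(1)[0][0])
    return np.array(preds)

# Two well-separated latent clusters -> 5-NN recovers the labels.
rng = np.random.default_rng(0)
z_tr = np.vstack([rng.normal(0, 0.1, (20, 2)), rng.normal(5, 0.1, (20, 2))])
y_tr = np.array([0] * 20 + [1] * 20)
z_te = np.array([[0.0, 0.0], [5.0, 5.0]])
print(knn_classify(z_tr, y_tr, z_te))  # [0 1]
```

This also makes the connection to local-structure preservation concrete: k-NN accuracy on the latent codes is high exactly when nearby data points receive nearby representations.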

3.3. GENERATION AND RECONSTRUCTION

To generate new data, latent samples are first drawn from the marginal prior distribution p(z_d) based on the conditional priors p(z_d|u_k), and then fed to the decoder. We put the generated images of all methods in Appendix A.4, and show the Fréchet Inception Distance (FID) (Heusel et al., 2017), which is commonly used for evaluating the quality of generated images, in Table 2. For SWAEs, we observe that the reconstruction loss term is crucial for improving the generation quality, as SWAE (β = 1) generally does not achieve the lowest FID. On MNIST and Fashion-MNIST, the FID of the best SWAE (indicated as β*) is slightly higher than that of WAE-GAN, but lower than all the other benchmarks. The visual difference between SWAE (β*) and WAE-GAN on MNIST and Fashion-MNIST is, however, negligible. In Section 2.1, we compare the formulation of SWAEs (β = 1) with WAE. In particular, the objective of WAE includes a distribution-based dissimilarity in the latent space, while the z-loss in SWAEs measures a sample-based dissimilarity. On Coil20 and CIFAR10-sub, SWAE (β*) achieves the lowest FID and generates new images that are visually much better than those generated by the benchmarks. In Table 3, we compare the reconstruction loss, defined as ‖x_e - D(z_e)‖_2^2, on the four datasets. As expected, increasing the value of dim-z can reduce the reconstruction loss, but the reduction becomes marginal when dim-z is large enough. Additionally, since a smaller value of β puts more emphasis on the reconstruction-based loss, the quality of reconstruction is generally better. We observe that SWAE (β = 0.5) results in the lowest reconstruction loss in all cases. The reconstructed images of all methods are provided in Appendix A.4 for reference. Without the reconstruction loss in the objective, the reconstruction quality of SWAE (β = 1) can be unsatisfactory (e.g., on CIFAR10-sub).

3.4. DENOISING EFFECT WITH SWAE (β = 1)

As discussed in Section 2.1, the x-loss is closely related to the objective of Denoising Autoencoders (DAs). After training, we feed noisy images, obtained by adding zero-mean Gaussian noise with standard deviation 0.3 to the clean test samples, to the encoder. In Figure 4, as an example, we show the reconstructed images on Fashion-MNIST. Since the reconstruction loss is highly related to the dimension of the latent space, for a fair comparison we set dim-z to 80 for all methods. We observe that only SWAE (β = 1) can recover clean images. This observation confirms the denoising effect induced by the x-loss; the resultant latent representation is thus robust to partial destruction of the observed data.
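The corruption step of this experiment is simple to reproduce. A short sketch of the noise model (the clipping to [0, 1] is our assumption for keeping pixels in a valid range; the paper does not state it):

```python
import numpy as np

rng = np.random.default_rng(0)

def corrupt(images, std=0.3):
    """Add zero-mean Gaussian noise with standard deviation 0.3, as in the
    denoising experiment; clipping to [0, 1] is an assumed post-processing step."""
    noisy = images + rng.normal(0.0, std, size=images.shape)
    return np.clip(noisy, 0.0, 1.0)

clean = rng.uniform(size=(8, 28 * 28))   # stand-in for flattened test images
noisy = corrupt(clean)
# denoised = decoder(encoder(noisy))  # SWAE (beta = 1) is expected to recover clean images
```

The denoised output is then compared visually against the clean test samples, as in Figure 4.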

4. RELATED WORK

The objective of VAEs uses the asymmetric KL divergence between the encoding and the decoding distributions (see Appendix A.1). To improve VAEs, (Livne et al., 2019; Chen et al., 2018; Pu et al., 2017b) propose symmetric divergence measures instead of the asymmetric KL divergence in VAE-based generative models. For example, MIM (Livne et al., 2019) adopts the Jensen-Shannon (JS) divergence between the encoding and the decoding distributions, together with a regularizer maximizing the mutual information between the data and the latent representation. Due to the difficulty of estimating the mutual information and the unavailability of the data distribution, an upper bound of the desired loss is proposed. AS-VAE (Pu et al., 2017b) and the follow-up work (Chen et al., 2018) propose a symmetric form of the KL divergence optimized with adversarial training. These methods typically involve a difficult objective, either depending on (unstable) adversarial training or containing a mutual information term that requires further approximation. In contrast, the proposed SWAEs yield a simple objective and do not involve adversarial training. Compared to VAEs, GANs lack an efficient inference model and thus cannot provide the latent representation corresponding to a given observation. To bridge the gap between VAEs and GANs, recent works attempt to integrate an inference mechanism into GANs by symmetrically treating the observed data and the latent representation, i.e., the discriminator is trained to discriminate the joint samples in both the data and the latent spaces. In particular, the JS divergence between the encoding and the decoding distributions is deployed in ALI (Dumoulin et al., 2017) and BiGANs (Donahue et al., 2017). To address the non-identifiability issue in ALI (e.g., unfaithful reconstruction), ALICE (Li et al., 2017a) later proposes to regularize ALI using conditional entropy.
Generative modelling is closely related to minimizing a dissimilarity measure between two distributions. As opposed to many other commonly adopted dissimilarity measures, e.g., the JS and the KL divergences, the Wasserstein distances that arise from the OT problem provide a weaker notion of distance between probability distributions (see (Santambrogio, 2015; Peyré et al., 2019; Kolouri et al., 2017) for more background on OT). This is crucial because in many applications the observed data are essentially supported on a low-dimensional manifold. In such cases, common dissimilarity measures may fail to provide a useful gradient for training. Consequently, the Wasserstein distances have received a surge of attention for learning generative models (Arjovsky et al., 2017; Balaji et al., 2019; Sanjabi et al., 2018; Kolouri et al., 2019; Patrini et al., 2019; Tolstikhin et al., 2018; Deshpande et al., 2019; Nguyen et al., 2020). In particular, the VAE-based models (Tolstikhin et al., 2018; Kolouri et al., 2019; Patrini et al., 2019) are all based on minimizing the OT cost between the marginal distributions in the data space, differing in how the divergence in the latent space is measured: (Tolstikhin et al., 2018) proposes GAN-based and MMD-based divergences, (Kolouri et al., 2019) adopts the sliced Wasserstein distance, and (Patrini et al., 2019) exploits the Sinkhorn divergence. Unlike these works, our proposed SWAEs directly minimize the OT cost between the joint distributions of the observed data and the latent representation, with the inclusion of a reconstruction loss for further improving the generative model.

5. CONCLUSION AND FUTURE WORK

We contributed a novel family of generative autoencoders, termed Symmetric Wasserstein Autoencoders (SWAEs), under the framework of OT. We proposed to symmetrically match the encoding and the decoding distributions, with the inclusion of a reconstruction loss for further improving the generative model. We conducted empirical studies on benchmark tasks to confirm the superior performance of SWAEs over state-of-the-art generative autoencoders. We believe that symmetrically aligning the encoding and the decoding distributions with a proper regularizer is crucial to improving the performance of generative models. To further enhance the performance of SWAEs, it is worthwhile to exploit other methods for the prior design, e.g., flow-based approaches (Rezende & Mohamed, 2015; Dinh et al., 2014; 2016), and other forms of the reconstruction loss, e.g., the cross entropy.

A.1 OBJECTIVE OF VAES

The objective of VAEs is the evidence lower bound:

\mathbb{E}_{p(x_e)} \mathbb{E}_{p(z_e|x_e)}[\log p(x_d|z_e)] - D_{KL}(p(z_e|x_e) \| p(z_d)).    (6)

It can also be shown that the objective of VAEs is equivalent to minimizing the KL divergence (or maximizing the negative KL divergence) between the encoding and the decoding distributions (Livne et al., 2019; Esmaeili et al., 2019; Pu et al., 2017b; Chen et al., 2018):

-D_{KL}(p(x_e, z_e) \| p(x_d, z_d)) = \mathbb{E}_{p(x_e, z_e)}\left[\log \frac{p(x_d, z_d)}{p(z_e|x_e)}\right] - \mathbb{E}_{p(x_e)}[\log p(x_e)].    (7)

The right-hand side of equation 7 differs from equation 6 only by a constant, which is the entropy of the observed data.

A.2 PROOF OF THEOREM 1

The proof extends that of Theorem 1 in (Tolstikhin et al., 2018). In particular, (Tolstikhin et al., 2018) aims to minimize the OT cost of the marginal distributions p(x_e) and p(x_d), and the proof there is based on the joint probability of three random variables: the observed data, the generated data, and the latent representation. In contrast, we propose to minimize the OT cost of the joint distributions of the observed data and the latent representation induced by the encoder and the decoder. As a result, our proof is based on the joint distribution of four random variables (x_e, z_e, x_d, z_d) ∈ X × Z × X × Z. We assume that the joint distribution p(x_e, z_e, x_d, z_d) satisfies the following three conditions: 1. e = (x_e, z_e) ∼ p(x_e)p(z_e|x_e);



where P(e ∼ p(e), d ∼ p(d)), the set of couplings between e and d, denotes the set of joint distributions of e and d with the marginals p(e) and p(d), respectively, and c(e, d) : (X × Z) × (X × Z) → [0, +∞] denotes the cost function. When (X × Z, d) is a metric space and the cost function is c(e, d) = d^p(e, d) for some p ≥ 1, W_p, the p-th root of W_c, is defined as the p-th Wasserstein distance.

Figure 1: Latent representations of 100 GMM samples (mode 5 and dimension 10) with dim-z = 2.The indexes of these latent representations are sorted based on the distance to a target sample in the data space, i.e., Index 0 is associated with the target sample and Index 100 is associated with the furthest sample to the target in the data space. With SWAE (left) data samples that are close in the data space are also close in the latent space, while VAE (right) cannot preserve such correspondence.

Figure 2: Network architecture of SWAEs. To generate new data latent samples are first drawn from the marginal prior p(z d ) based on the conditional priors p(z d |u k ), and then fed to the decoder.

Figure 3: Projection of the latent representation to 2D via t-SNE on MNIST. dim-z = 80 for all methods.

Figure 4: Denoising effect: reconstructed images on Fashion-MNIST. dim-z = 80 for all methods.

2. d = (x_d, z_d) ∼ p(z_d)p(x_d|z_d); and 3. x_d ⊥⊥ x_e | z_d (conditional independence). The first two conditions specify the encoder and the decoder, respectively, and the last condition indicates that, given the latent prior, the generated data and the observed data are independent. Denote the set of such joint distributions as P(x_e, z_e, x_d, z_d). Obviously, we have P(x_e, z_e, x_d, z_d) ⊆ P(e ∼ p(e), d ∼ p(d)) due to the third condition. If the decoder is deterministic, p(x_d|z_d) is a Dirac distribution, and thus P(x_e, z_e, x_d, z_d) = P(e ∼ p(e), d ∼ p(d)). With this result, we can rewrite the objective of the underlying OT problem as follows:

W_c(p(e), p(d)) = \inf_{\Gamma \in P(x_e, z_e, x_d, z_d)} \mathbb{E}_{(e,d) \sim \Gamma}[c(e, d)]
= \inf_{\Gamma \in P(x_e, z_e, z_d)} \mathbb{E}_{(x_e, z_e, z_d) \sim \Gamma}[c(e, d)]    (8)
= \inf_{p(z_e|x_e),\, p(z_d|x_e, z_e)} \mathbb{E}_{p(x_e)} \mathbb{E}_{p(z_e|x_e)} \mathbb{E}_{p(z_d|x_e, z_e)}[c(e, d)]    (9)
= \inf_{p(z_d|x_e)} \mathbb{E}_{p(x_e)} \mathbb{E}_{p(z_d|x_e)}[c(e, d)],    (10)

where in equation 8, P(x_e, z_e, z_d) denotes the set of joint distributions of (x_e, z_e, z_d) induced by P(x_e, z_e, x_d, z_d); equation 8 holds due to the deterministic decoder, and equation 10 holds due to the deterministic encoder.

Figure 5: Latent representation on MNIST; dim-z = 2 for all methods. With more compressed latent representation, the classification accuracy generally decreases except VAE (5-NN accuracy with dim-z = 2: SWAE(β = 1)(0.75), SWAE(β = 0.5)(0.86), VAE(0.84), WAE-GAN(0.81), WAE-MMD(0.81), VampPrior(0.82), and MIM(0.87)). Also, the quality of reconstruction and generation decreases when dim-z = 2.

Figure 6: Generated new samples on MNIST. dim-z = 8 for all methods.

Figure 7: Generated new samples on Fashion-MNIST. dim-z = 8 for all methods.

Figure 8: Generated new samples on Coil20. dim-z = 80 for all methods.

Figure 9: Generated new samples on CIFAR10-sub. dim-z = 512 for all methods.

Figure 10: Generated new samples on CelebA. dim-z = 512 for all methods. VampPrior has the best generation quality visually. SWAE (β* = 0.5) is better than SWAE (β = 1) and SWAE (β = 0).

Figure 11: Reconstructed images on MNIST. dim-z = 80 for all methods. As expected, for SWAEs a smaller β leads to a higher quality of reconstruction.

Figure 12: Reconstructed images on Fashion-MNIST. dim-z = 80 for all methods.

Figure 13: Reconstructed images on Coil20. dim-z = 80 for all methods. For SWAEs, the difference of the reconstruction error for different values of β is insignificant, and the reconstructed images look visually the same.

Figure 14: Reconstructed images on CIFAR10-sub. dim-z = 512 for all methods. Excluding the reconstruction loss in the objective, the reconstruction of SWAE (β = 1) is blurry.

Figure 15: Reconstructed images on CelebA. dim-z = 512 for all methods. The average reconstruction error over three seeds is as follows: SWAE (β = 1): 301.11 ± 4.85, SWAE (β = 0.5): 37.78 ± 6.71, SWAE (β = 0): 30.14 ± 0.18, VAE: 38.64 ± 5.39, WAE-GAN: 23.63 ± 5.79, WAE-MMD: 28.64 ± 2.85, VampPrior: 37.96 ± 0.71, and MIM: 36.40 ± 2.84. As expected, for SWAE a smaller value of β leads to a lower reconstruction loss.

Figure 16: Denoising effect: reconstructed images on MNIST. dim-z = 80 for all methods. SWAE (β = 1), WAE-GAN, and WAE-MMD can recover clean images. However, for WAE-GAN and WAE-MMD, we can still see some noisy dots around the digits.

Figure 17: Denoising effect: reconstructed images on Coil20. dim-z = 80 for all methods. Except WAE-GAN and WAE-MMD, the other methods can produce clear images.

Classification accuracy of 5-NN (averaged over 5 trials). The standard deviation is generally less than 0.01 and is omitted in the table. The pseudo-inputs in SWAE, VampPrior, and MIM are initialized with the training samples.
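The 5-NN evaluation behind this table can be sketched as follows. This is a minimal plain-Python illustration of majority-vote k-nearest-neighbour classification on latent codes; the function names and toy data are our own, not the paper's code:

```python
import math
from collections import Counter

def knn_predict(train_z, train_y, query, k=5):
    """Classify `query` by majority vote among its k nearest training
    latent codes under Euclidean distance."""
    dists = sorted((math.dist(z, query), y) for z, y in zip(train_z, train_y))
    votes = Counter(y for _, y in dists[:k])
    return votes.most_common(1)[0][0]

# Toy 2-D latent codes with two well-separated classes.
train_z = [(0.0, 0.0), (0.1, 0.1), (0.2, 0.0), (5.0, 5.0), (5.1, 5.2), (4.9, 5.0)]
train_y = [0, 0, 0, 1, 1, 1]
print(knn_predict(train_z, train_y, (0.05, 0.05)))  # → 0
print(knn_predict(train_z, train_y, (5.0, 5.1)))    # → 1
```

In the paper's setting, `train_z` would be encoder outputs for the training images and the reported number is the accuracy of such predictions over the test set.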



Reconstruction loss (averaged over 5 trials).

… .14 | 23.20 ± 0.04 | 24.34 ± 0.07 | 26.86 ± 0.37 | 24.76 ± 0.31 | 24.05 ± 0.10 | 24.04 ± 0.10
dim-z = 40 | 26.29 ± 0.13 | 6.93 ± 0.05 | 18.40 ± 0.08 | 16.06 ± 0.15 | 13.78 ± 0.77 | 17.32 ± 0.09 | 18.14 ± 0.33
dim-z = 80 | 26.10 ± 0.09 | 1.25 ± 0.02 | 18.50 ± 0.11 | 10.78 ± 0.11 | 9.63 ± 0.05 | 17.42 ± 0.06 | 17.29 ± 0.20

… .04 | 71.03 ± 0.06 | 72.56 ± 0.02 | 78.17 ± 1.41 | 74.50 ± 0.60 | 72.20 ± 0.04 | 72.34 ± 0.03
dim-z = 40 | 73.39 ± 0.08 | 57.90 ± 0.25 | 69.85 ± 0.04 | 74.84 ± 0.23 | 75.86 ± 0.41 | 68.67 ± 0.07 | 70.22 ± 0.87
dim-z = 80 | 73.35 ± 0.08 | 44.30 ± 0.71 | 69.90 ± 0.08 | 70.74 ± 1.16 | 71.28 ± 3.80 | 68.54 ± 0.10 | 69.10 ± 0.13
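The entries above are per-image reconstruction errors averaged over an evaluation set. As a generic illustration (using a per-image squared error summed over pixels; the paper's exact reconstruction metric may differ, and the toy data are our own), the averaging can be sketched as:

```python
def reconstruction_error(x, x_rec):
    """Per-image squared error, summed over pixels (one generic choice
    of reconstruction metric; the paper's exact loss may differ)."""
    return sum((a - b) ** 2 for a, b in zip(x, x_rec))

def average_reconstruction_loss(images, reconstructions):
    """Average the per-image error over the evaluation set."""
    errs = [reconstruction_error(x, r) for x, r in zip(images, reconstructions)]
    return sum(errs) / len(errs)

# Toy 4-pixel "images" and their reconstructions.
imgs = [[0.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0]]
recs = [[0.0, 1.0, 0.5, 0.0], [1.0, 0.5, 0.0, 0.0]]
loss = average_reconstruction_loss(imgs, recs)  # → 0.25
```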

Pascal Vincent, Hugo Larochelle, Isabelle Lajoie, Yoshua Bengio, and Pierre-Antoine Manzagol. Stacked denoising autoencoders: Learning useful representations in a deep network with a local denoising criterion. Journal of Machine Learning Research, 11(Dec):3371-3408, 2010.

$\mathbb{E}_{p(x_e)}\,\mathbb{E}_{p(z_e \mid x_e)}\left[\log p(x_d \mid z_e)\right] - D_{\mathrm{KL}}\!\left(p(z_e \mid x_e)\,\|\,p(z_d)\right).$

A.3 DATASETS AND NETWORK ARCHITECTURES

In this section, we briefly describe the datasets, the network architectures, and the hyperparameters used in our training algorithm.

• MNIST: The dataset includes 70,000 binarized images of numeric digits from 0 to 9, each of size 28 × 28, with 7,000 images per class. The training set contains 50,000 images, the validation set contains 10,000 images used to choose the best model based on the loss function, and the test set contains 10,000 images.

• Fashion-MNIST: The dataset includes 70,000 binarized images of fashion products in 10 classes. It has the same image size and the same training/validation/test split as MNIST.

Network architecture of SWAE: The building block of the SWAE network, adopted from VampPrior, is called GatedConv2d. GatedConv2d contains two convolutional layers with a gating mechanism used as an element-wise non-linearity. The arguments of GatedConv2d() denote the number of input channels, the number of output channels, the kernel size, the stride, and the padding, respectively. The conditional prior network outputs the mean and the log-variance of a Gaussian distribution, from which the latent prior is sampled.

• The structure of the encoder network: GatedConv2d(1,32,7,1,3)-GatedConv2d(32,32,3,2,1)-GatedConv2d(32,64,5,1,2)-GatedConv2d(64,64,3,2,1)-GatedConv2d(64,6,3,1,1), followed by one fully-connected layer with no activation function.

• The structure of the conditional prior network: The GatedConv2d layers are the same as those in the encoder network, followed by two fully-connected layers.
One produces the mean, and the other produces the log-variance with the Hardtanh activation function.

• The structure of the decoder network: Two fully-connected layers with the gating mechanism, followed by GatedConv2d(1,64,3,1,1)-GatedConv2d(64,64,3,1,1)-GatedConv2d(64,64,3,1,1)-GatedConv2d(64,64,3,1,1), followed by a convolutional layer with the Sigmoid activation function.

The algorithm is trained with Adam using learning rate 0.001, β1 = 0.9, and β2 = 0.999.

Setup of the number of pseudo-inputs K: As suggested in (Tomczak & Welling, 2018; Livne et al., 2019), we set K in VampPrior and MIM to 500 on MNIST and Fashion-MNIST, and we found K = 500 is also suitable for VampPrior and MIM on Coil20, CIFAR10-sub, and Celeba. Unlike VampPrior and MIM, for SWAE we found that increasing K improves the performance; we therefore set K to 4000 on MNIST, Fashion-MNIST, CIFAR10-sub, and Celeba. Coil20 is a relatively small dataset, and we set K to 500 for SWAE, VampPrior, and MIM.
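The gating mechanism inside GatedConv2d works as follows: the two parallel convolutional layers produce feature maps h and g, and the block outputs h ⊙ sigmoid(g), so g learns to gate h element-wise. A minimal plain-Python sketch of the gating step only (the convolutions themselves are omitted, and the toy feature maps are our own):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def gated_activation(h, g):
    """Element-wise gating h ⊙ sigmoid(g), where h and g stand in for the
    outputs of the two parallel convolutional layers in GatedConv2d."""
    return [[hv * sigmoid(gv) for hv, gv in zip(hr, gr)]
            for hr, gr in zip(h, g)]

# Toy 2x2 feature maps standing in for the two conv outputs.
h = [[1.0, -2.0], [0.5, 3.0]]
g = [[0.0, 10.0], [-10.0, 0.0]]
out = gated_activation(h, g)
# sigmoid(0) = 0.5, so out[0][0] == 0.5; sigmoid(-10) ≈ 0, so out[1][0] ≈ 0,
# i.e. that entry of h is almost entirely gated off.
```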

A.4 MORE EXPERIMENTAL RESULTS

In this section, we present additional experimental results from the comparison with the benchmark methods.

