LEARNING DISENTANGLED REPRESENTATIONS WITH THE WASSERSTEIN AUTOENCODER

Abstract

Disentangled representation learning has undoubtedly benefited from objective function surgery. However, a delicate balancing act of tuning is still required in order to trade off reconstruction fidelity versus disentanglement. Building on previous successes of penalizing the total correlation in the latent variables, we propose TCWAE (Total Correlation Wasserstein Autoencoder). Working in the WAE paradigm naturally enables the separation of the total-correlation term, thus providing disentanglement control over the learned representation, while offering more flexibility in the choice of reconstruction cost. We propose two variants using different KL estimators and perform extensive quantitative comparisons on data sets with known generative factors, showing competitive results relative to state-of-the-art techniques. We further study the trade off between disentanglement and reconstruction on more-difficult data sets with unknown generative factors, where the flexibility of the WAE paradigm in the reconstruction term improves reconstructions.

1. INTRODUCTION

Learning representations of data is at the heart of deep learning; the ability to interpret those representations empowers practitioners to improve the performance and robustness of their models (Bengio et al., 2013; van Steenkiste et al., 2019) . In the case where the data is underpinned by independent latent generative factors, a good representation should encode information about the data in a semantically meaningful manner with statistically independent latent variables encoding for each factor. Bengio et al. (2013) define a disentangled representation as having the property that a change in one dimension corresponds to a change in one factor of variation, while being relatively invariant to changes in other factors. While many attempts to formalize this concept have been proposed (Higgins et al., 2018; Eastwood & Williams, 2018; Do & Tran, 2019) , finding a principled and reproducible approach to assess disentanglement is still an open problem (Locatello et al., 2019) . Recent successful unsupervised learning methods have shown how simply modifying the ELBO objective, either re-weighting the latent regularization terms or directly regularizing the statistical dependencies in the latent, can be effective in learning disentangled representation. Higgins et al. (2017) and Burgess et al. (2018) control the information bottleneck capacity of Variational Autoencoders (VAEs, (Kingma & Welling, 2014; Rezende et al., 2014) ) by heavily penalizing the latent regularization term. Chen et al. (2018) perform ELBO surgery to isolate the terms at the origin of disentanglement in β-VAE, improving the reconstruction-disentanglement trade off. Esmaeili et al. (2018) further improve the reconstruction capacity of β-TCVAE by introducing structural dependencies both between groups of variables and between variables within each group. 
Alternatively, directly regularizing the aggregated posterior to the prior with density-free divergences (Zhao et al., 2019) or moments matching (Kumar et al., 2018) , or simply penalizing a high Total Correlation (TC, (Watanabe, 1960) ) in the latent (Kim & Mnih, 2018) has shown good disentanglement performances. In fact, information theory has been a fertile ground to tackle representation learning. Achille & Soatto (2018) re-interpret VAEs from an Information Bottleneck view (Tishby et al., 1999) , re-phrasing it as a trade off between sufficiency and minimality of the representation, regularizing a pseudo TC between the aggregated posterior and the true conditional posterior. Similarly, Gao et al. (2019) use the principle of total Correlation Explanation (CorEX) (Ver Steeg & Galstyan, 2014 ) and maximize the mutual information between the observation and a subset of anchor latent points. Maximizing the mutual information (MI) between the observation and the latent has been broadly used (van den Oord et al., 2018; Hjelm et al., 2019; Bachman et al., 2019; Tschannen et al., 2020) , showing encouraging results in representation learning. However, Tschannen et al. (2020) argued that MI maximization alone cannot explain the disentanglement performances of these methods. Building on the Optimal Transport (OT) problem (Villani, 2008) , Tolstikhin et al. (2018) introduced the Wasserstein Autoencoder (WAE), an alternative to VAE for learning generative models. Similarly to VAE, WAE maps the data into a (low-dimensional) latent space while regularizing the averaged encoding distribution. This is in contrast with VAEs where the posterior is regularized at each data point, and allows the encoding distribution to capture significant information about the data while still matching the prior when averaged over the whole data set. 
Interestingly, by directly regularizing the aggregated posterior, WAE hints at more explicit control over the way the information is encoded, and thus at better disentanglement. The reconstruction term of the WAE allows for any cost function on the observation space, opening the door to better-suited reconstruction terms, for example when working with continuous RGB data sets, where the Euclidean distance or any other metric on the observation space can result in more accurate reconstructions of the data. In this work, following the success of regularizing the TC in disentanglement, we propose to use the Kullback-Leibler (KL) divergence as the latent regularization function in the WAE. We introduce the Total Correlation WAE (TCWAE) with an explicit dependency on the TC of the aggregated posterior. Using two different estimators for the KL terms, we perform an extensive comparison with successful methods on a number of data sets. Our results show that TCWAEs achieve competitive disentanglement performance while improving modelling performance by allowing flexibility in the choice of reconstruction cost.

2.1. TOTAL CORRELATION

The TC of a random vector $Z \in \mathcal{Z}$ under $P$ is defined by

$$\mathrm{TC}(Z) \triangleq \sum_{d=1}^{d_Z} H_{p_d}(Z_d) - H_p(Z)$$

where $p_d(z_d)$ is the marginal density over only $z_d$ and $H_p(Z) \triangleq -\mathbb{E}_p \log p(Z)$ is the Shannon differential entropy, which encodes the information contained in $Z$ under $P$. Since $\sum_{d=1}^{d_Z} H_{p_d}(Z_d) \ge H_p(Z)$, with equality when the marginals $Z_d$ are mutually independent, the TC is non-negative and can be interpreted as the loss of information when assuming mutual independence of the $Z_d$; namely, it measures the mutual dependence of the marginals. Thus, in the context of disentanglement learning, we seek a low TC of the aggregated posterior, $p(z) = \int_{\mathcal{X}} p(z|x)\, p(x)\, dx$, which forces the model to encode the data into statistically independent latent codes. High MI between the data and the latent is then obtained when the posterior, $p(z|x)$, manages to capture relevant information from the data.
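As an illustration of the definition, the TC has a closed form when $Z$ is Gaussian, because both entropy terms are known analytically. The sketch below assumes a zero-mean Gaussian with a given covariance matrix; the function name `gaussian_total_correlation` is ours and is not part of the method.

```python
import numpy as np

def gaussian_total_correlation(cov):
    """TC (in nats) of a Gaussian vector with covariance `cov`."""
    cov = np.asarray(cov, dtype=float)
    # sum_d H(Z_d) - H(Z) = 0.5 * (sum_d log cov_dd - log det cov);
    # the (2 * pi * e) constants cancel between the two entropy terms.
    _, logdet = np.linalg.slogdet(cov)
    return 0.5 * (np.sum(np.log(np.diag(cov))) - logdet)

# Independent coordinates: the TC vanishes.
independent = np.diag([1.0, 2.0, 0.5])
# Two unit-variance coordinates with correlation 0.8: the TC is positive.
correlated = np.array([[1.0, 0.8], [0.8, 1.0]])

tc_independent = gaussian_total_correlation(independent)
tc_correlated = gaussian_total_correlation(correlated)
print(tc_independent, tc_correlated)
```

For two unit-variance coordinates with correlation $\rho$, the expression reduces to $-\frac{1}{2}\log(1-\rho^2)$, which grows without bound as the coordinates become more dependent.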

2.2. TOTAL CORRELATION IN ELBO

We consider latent generative models $p_\theta(x) = \int_{\mathcal{Z}} p_\theta(x|z)\, p(z)\, dz$ with prior $p(z)$ and decoder network $p_\theta(x|z)$, parametrized by $\theta$. VAEs approximate the intractable posterior $p(z|x)$ by introducing an encoding distribution (the encoder), $q_\phi(z|x)$, and learning $\theta$ and $\phi$ simultaneously by optimizing the variational lower bound, or ELBO, defined in Eq. 3:

$$\mathcal{L}_{\mathrm{ELBO}}(\theta, \phi) \triangleq \mathbb{E}_{p_{\mathrm{data}}(X)} \Big[ \mathbb{E}_{q_\phi(Z|X)} [\log p_\theta(X|Z)] - \mathrm{KL}\big(q_\phi(Z|X) \,\|\, p(Z)\big) \Big] \le \mathbb{E}_{p_{\mathrm{data}}(X)} \log p_\theta(X) \tag{3}$$

Following Hoffman & Johnson (2016), we can decompose the KL term in Eq. 3 as:

$$\frac{1}{N} \sum_{n=1}^{N} \mathrm{KL}\big(q_\phi(Z|x_n) \,\|\, p(Z)\big) = \underbrace{\mathrm{KL}\big(q(Z, N) \,\|\, q(Z)p(N)\big)}_{\text{(i) index-code MI}} + \underbrace{\mathrm{KL}\big(q(Z) \,\|\, p(Z)\big)}_{\text{(ii) marginal KL}} \tag{4}$$

where $p(n) = \frac{1}{N}$, $q(z|n) = q(z|x_n)$, $q(z, n) = q(z|n)p(n)$ and $q(z) = \sum_{n=1}^{N} q(z|n)\, p(n)$. Term (i) is the index-code mutual information and represents the MI between the data and the latent under the joint distribution $q(z, n)$; term (ii) is the marginal KL matching the aggregated posterior to the prior. While the discussion on the impact of a high index-code MI on disentanglement learning is still open, the marginal KL term plays an important role in disentanglement. Indeed, it pushes the encoder network to match the prior when averaged, as opposed to matching the prior for each data point. Combined with a factorized prior $p(z) = \prod_d p_d(z_d)$, as is often the case, the aggregated posterior is forced to factorize and align with the axes of the prior. More specifically, the marginal KL term in Eq. 4 can be decomposed as the sum of a TC term and a dimensionwise-KL term:

$$\mathrm{KL}\big(q(Z) \,\|\, p(Z)\big) = \mathrm{TC}\big(q(Z)\big) + \sum_{d=1}^{d_Z} \mathrm{KL}\big(q_d(Z_d) \,\|\, p_d(Z_d)\big) \tag{5}$$

Thus, maximizing the ELBO implicitly minimizes the TC of the aggregated posterior, pushing the aggregated posterior to disentangle, as Higgins et al. (2017) and Burgess et al. (2018) observed when strongly penalizing the KL term in Eq. 3. Chen et al. (2018) leverage the KL decomposition in Eq. 5 by restricting the heavy latent penalization to the TC only. However, the index-code MI term in Eq. 4 seems to have little to no role in disentanglement (see the ablation study of Chen et al. (2018)), while potentially harming reconstruction performance (Hoffman & Johnson, 2016).
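The decomposition of the marginal KL into a TC term and a dimensionwise-KL term can be checked numerically in a case where every term is available in closed form. The sketch below assumes a Gaussian aggregated posterior $q = \mathcal{N}(m, \Sigma)$ and a standard normal prior; this Gaussian assumption and the helper names are ours, chosen only to make the identity easy to verify.

```python
import numpy as np

def kl_to_standard_normal(mean, cov):
    """KL( N(mean, cov) || N(0, I) ) in closed form."""
    d = mean.shape[0]
    _, logdet = np.linalg.slogdet(cov)
    return 0.5 * (np.trace(cov) + mean @ mean - d - logdet)

def total_correlation(cov):
    """TC of N(mean, cov): KL between the joint and the product of its marginals."""
    _, logdet = np.linalg.slogdet(cov)
    return 0.5 * (np.sum(np.log(np.diag(cov))) - logdet)

def dimensionwise_kl(mean, cov):
    """Sum over d of KL( N(mean_d, cov_dd) || N(0, 1) )."""
    var = np.diag(cov)
    return np.sum(0.5 * (var + mean ** 2 - 1.0 - np.log(var)))

rng = np.random.default_rng(0)
a = rng.normal(size=(4, 4))
cov = a @ a.T + 0.5 * np.eye(4)   # a random positive-definite covariance
mean = rng.normal(size=4)

lhs = kl_to_standard_normal(mean, cov)
rhs = total_correlation(cov) + dimensionwise_kl(mean, cov)
print(lhs, rhs)  # the two values agree up to floating-point error
```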

3. WAE NATURALLY GOOD AT DISENTANGLING?

In this section we introduce the OT problem and the WAE objective, and discuss the compelling properties of WAEs for representation learning. Mirroring β-TCVAE decomposition, we derive the TCWAE objective.

3.1. WAE

The Kantorovich formulation of the OT between the true-but-unknown data distribution $P_D$ and the model distribution $P_\theta$, for a given cost function $c$, is defined by:

$$\mathrm{OT}_c(P_D, P_\theta) = \inf_{\Gamma \in \mathcal{P}(P_D, P_\theta)} \int_{\mathcal{X} \times \mathcal{X}} c(x, \tilde{x})\, \gamma(x, \tilde{x})\, dx\, d\tilde{x} \tag{6}$$

where $\mathcal{P}(P_D, P_\theta)$ is the space of all couplings of $P_D$ and $P_\theta$; namely, the space of joint distributions $\Gamma$ on $\mathcal{X} \times \mathcal{X}$ whose densities $\gamma$ have marginals $p_D$ and $p_\theta$. Tolstikhin et al. (2018) derive the WAE objective by restricting this space and relaxing the hard constraint on the marginal using a soft constraint with a Lagrange multiplier (see Appendix A for more details):

$$W_{D,c}(\theta, \phi) \triangleq \mathbb{E}_{p_D(x)}\, \mathbb{E}_{q_\phi(z|x)}\, \mathbb{E}_{p_\theta(\tilde{x}|z)}\, c(x, \tilde{x}) + \lambda\, D\big(q(Z) \,\|\, p(Z)\big) \tag{7}$$

where $D$ is any divergence function and $\lambda$ a relaxation parameter. The decoder, $p_\theta(x|z)$, and the encoder, $q_\phi(z|x)$, are optimized simultaneously by dropping the closed-form minimization over the encoder network, with standard stochastic gradient descent methods. Similarly to the ELBO, objective 7 consists of a reconstruction cost term and a latent regularization term, which prevents the latent codes from drifting away from the prior. However, WAE explicitly penalizes the aggregated posterior. This motivates, following Section 2.2, the use of WAE in disentanglement learning. Rubenstein et al. (2018) have shown promising disentanglement performance without modifying objective 7. Another important difference lies in the functional form of the reconstruction cost. Indeed, WAE allows any cost function in the reconstruction term; in particular, it allows for cost functions better suited to the data at hand and for the use of deterministic decoder networks (Tolstikhin et al., 2018; Frogner et al., 2015). This can potentially result in an improved reconstruction-disentanglement trade off, as we empirically find in Sections 4.2 and 4.1.

3.2. TCWAE

In this section, for notational simplicity, we drop the explicit dependency of the distributions on their respective parameters. Following Section 2.2 and Eq. 5, we choose the divergence function, $D$, in Eq. 7, to be the KL divergence and assume a factorized prior (e.g. $p(z) = \mathcal{N}(0_{d_Z}, I_{d_Z})$), obtaining the same decomposition as in Eq. 5. Re-weighting each term in Eq. 5 with hyper-parameters $\beta$ and $\gamma$, and plugging into Eq. 7, we obtain our TCWAE objective:

$$W_{TC} \triangleq \mathbb{E}_{p(x_n)}\, \mathbb{E}_{q(z|x_n)}\, \mathbb{E}_{p(\tilde{x}_n|Z)}\, c(x_n, \tilde{x}_n) + \beta\, \mathrm{KL}\Big(q(Z) \,\Big\|\, \prod_{d=1}^{d_Z} q_d(Z_d)\Big) + \gamma \sum_{d=1}^{d_Z} \mathrm{KL}\big(q_d(Z_d) \,\|\, p_d(Z_d)\big) \tag{8}$$

Given the positivity of the KL divergence, the TCWAE objective in Eq. 8 is an upper bound of the WAE objective of Eq. 7 with $\lambda = \min(\beta, \gamma)$. Eq. 8 can be directly related to the β-TCVAE objective of Chen et al. (2018):

$$-\mathcal{L}_{\beta\text{-}TC} \triangleq \mathbb{E}_{p(x_n)}\, \mathbb{E}_{q(z|x_n)} \big[-\log p(x_n|Z)\big] + \beta\, \mathrm{KL}\Big(q(Z) \,\Big\|\, \prod_{d=1}^{d_Z} q_d(Z_d)\Big) + \gamma \sum_{d=1}^{d_Z} \mathrm{KL}\big(q_d(Z_d) \,\|\, p_d(Z_d)\big) + \alpha\, I_q\big(q(Z, N);\, q(Z)p(N)\big) \tag{9}$$

As already mentioned, the main differences are the absence of the index-code MI term and a different reconstruction cost function. Setting $\alpha = 0$ in Eq. 9 makes the two latent regularizations match but breaks the inequality in Eq. 3. Matching the two reconstruction terms would be possible if we could find a ground cost function $c$ such that $\mathbb{E}_{p(\tilde{x}_n|Z)}\, c(x_n, \tilde{x}_n) = -\log p(x_n|Z)$.
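The upper-bound claim can be spelled out in one step. Writing $\mathrm{TC}(q(Z))$ for the KL divergence between $q(Z)$ and the product of its marginals, and using only the non-negativity of the two KL terms together with the TC plus dimensionwise-KL decomposition of the marginal KL:

```latex
\begin{aligned}
\beta\,\mathrm{TC}\big(q(Z)\big) + \gamma \sum_{d=1}^{d_Z} \mathrm{KL}\big(q_d(Z_d)\,\|\,p_d(Z_d)\big)
  &\ge \min(\beta,\gamma)\Big[\mathrm{TC}\big(q(Z)\big) + \sum_{d=1}^{d_Z} \mathrm{KL}\big(q_d(Z_d)\,\|\,p_d(Z_d)\big)\Big] \\
  &= \min(\beta,\gamma)\,\mathrm{KL}\big(q(Z)\,\|\,p(Z)\big).
\end{aligned}
```

Adding the shared reconstruction term to both sides gives the TCWAE objective on the left and the WAE objective with $D = \mathrm{KL}$ and $\lambda = \min(\beta, \gamma)$ on the right.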

3.3. ESTIMATORS

While being grounded in and motivated by information theory and earlier works on disentanglement, using the KL as the latent divergence function, as opposed to other sample-based divergences (Tolstikhin et al., 2018; Patrini et al., 2018), presents its own challenges. Indeed, the KL terms are intractable; in particular, we need estimators to approximate the entropy terms. We propose to use two estimators, one based on importance-weighted sampling (Chen et al., 2018), the other on adversarial estimation using the density-ratio trick (Kim & Mnih, 2018).

TCWAE-MWS

Chen et al. (2018) propose to estimate the intractable terms $\mathbb{E}_q \log q(Z)$ and $\mathbb{E}_{q_d} \log q_d(Z_d)$ in the KL terms of Eq. 8 with Minibatch-Weighted Sampling (MWS). Considering a batch of observations $\{x_1, \ldots, x_{N_{batch}}\}$, they sample the latent codes $z_i \sim q(z|x_i)$ and compute:

$$\mathbb{E}_{q(z)} \log q(z) \approx \frac{1}{N_{batch}} \sum_{i=1}^{N_{batch}} \log \Bigg[ \frac{1}{N \times N_{batch}} \sum_{j=1}^{N_{batch}} q(z_i|x_j) \Bigg] \tag{10}$$

This estimator, while being easily computed from samples, is a biased estimator of $\mathbb{E}_q \log q(Z)$. Chen et al. (2018) also proposed an unbiased version, Minibatch-Stratified Sampling (MSS). However, they found that it did not result in improved performance, and thus, like Chen et al. (2018), we choose to use the simpler MWS estimator. We call the resulting algorithm TCWAE-MWS. Other sample-based estimators of the entropy or the KL divergence have been proposed (Rubenstein et al., 2019; Esmaeili et al., 2018). However, we choose the solution of Chen et al. (2018) for 1) its simplicity and 2) the similarities between the TCWAE and β-TCVAE objectives.
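A minimal NumPy sketch of how the MWS approximation can be turned into a TC estimate is given below. It assumes Gaussian encoders with diagonal covariance, parametrized by a mean and a log-variance per dimension; the function names, array layout and the toy batch are illustrative choices of ours, not the implementation used in the experiments.

```python
import numpy as np

def log_gaussian(z, mean, logvar):
    """Element-wise log N(z; mean, exp(logvar))."""
    return -0.5 * (np.log(2.0 * np.pi) + logvar + (z - mean) ** 2 / np.exp(logvar))

def logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along `axis`."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))

def mws_total_correlation(z, mean, logvar, dataset_size):
    """MWS estimate of KL(q(z) || prod_d q_d(z_d)) from one batch.

    z, mean, logvar: arrays of shape (batch, d_Z), with z[i] ~ q(z | x_i).
    """
    batch = z.shape[0]
    # log q_d(z_{i,d} | x_j) for every pair (i, j): shape (batch, batch, d_Z).
    log_q_ij = log_gaussian(z[:, None, :], mean[None, :, :], logvar[None, :, :])
    log_norm = np.log(dataset_size * batch)
    # log q(z_i): sum over dimensions, then log-sum over the batch index j.
    log_qz = logsumexp(log_q_ij.sum(axis=2), axis=1) - log_norm
    # sum_d log q_d(z_{i,d}): log-sum over j per dimension, then sum over d.
    log_qz_marginals = (logsumexp(log_q_ij, axis=1) - log_norm).sum(axis=1)
    return np.mean(log_qz - log_qz_marginals)

# Toy batch: 64 codes in a 3-dimensional latent space, hypothetical data set size.
rng = np.random.default_rng(0)
mean = rng.normal(size=(64, 3))
logvar = np.full((64, 3), -1.0)
z = mean + np.exp(0.5 * logvar) * rng.normal(size=(64, 3))
tc_hat = mws_total_correlation(z, mean, logvar, dataset_size=1000)
print(tc_hat)
```

Working in log space with a log-sum-exp keeps the estimate stable when individual densities $q(z_i|x_j)$ underflow, which is common when $z_i$ lies far from the code of $x_j$.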

TCWAE-GAN

A different approach, similar in spirit to the WAE-GAN originally proposed by Tolstikhin et al. (2018), uses adversarial training to estimate the intractable terms in Eq. 8. The density-ratio trick (Nguyen et al., 2008; Sugiyama et al., 2011) estimates the KL divergence as:

$$\mathrm{KL}\Big(q(z) \,\Big\|\, \prod_{d=1}^{d_Z} q_d(z_d)\Big) \approx \mathbb{E}_{q(z)} \log \frac{D(z)}{1 - D(z)} \tag{11}$$

where $D$ plays the same role as the discriminator in GANs and outputs an estimate of the probability that $z$ is sampled from $q(z)$ and not from $\prod_{d=1}^{d_Z} q_d(z_d)$. Given that we can easily sample from $q(z)$, we can use Monte-Carlo sampling to estimate the expectation in Eq. 11. The discriminator $D$ is adversarially trained alongside the decoder and encoder networks. We call this adversarial version TCWAE-GAN.
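The two ingredients of the density-ratio estimate can be sketched as follows. Two assumptions are made here and are not stated in the text above: samples from the product of marginals are obtained by shuffling each latent dimension independently across the batch, as done in FactorVAE (Kim & Mnih, 2018), and the discriminator outputs a logit that is passed through a sigmoid. The discriminator itself is not implemented; `logits` stands in for its outputs on samples from $q(z)$.

```python
import numpy as np

def permute_dims(z, rng):
    """Shuffle each latent dimension independently across the batch.

    The result approximates samples from the product of marginals of q(z).
    """
    z_perm = np.empty_like(z)
    for d in range(z.shape[1]):
        z_perm[:, d] = z[rng.permutation(z.shape[0]), d]
    return z_perm

def sigmoid(logits):
    return 1.0 / (1.0 + np.exp(-logits))

def tc_from_discriminator(logits):
    """Monte-Carlo estimate of E_q log(D / (1 - D)) with D = sigmoid(logits)."""
    d = sigmoid(logits)
    return np.mean(np.log(d) - np.log1p(-d))

rng = np.random.default_rng(0)
z = rng.normal(size=(8, 3))          # stand-in for codes sampled from q(z)
z_perm = permute_dims(z, rng)        # stand-in for the product of marginals
logits = rng.normal(size=8)          # stand-in for discriminator outputs on z
tc_hat = tc_from_discriminator(logits)
print(tc_hat)
```

Because $\log \frac{D}{1-D}$ is the inverse of the sigmoid, the estimate is simply the batch mean of the discriminator logits, so in practice the sigmoid need not be evaluated at all.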

4. EXPERIMENTS

We perform a series of quantitative and qualitative experiments, starting with an ablation study on the impact of using different latent regularization functions in WAEs, followed by a quantitative comparison of the disentanglement performance of our methods with existing ones on toy data sets, before moving to a qualitative assessment of our method on more challenging data sets. Details of the data sets, the experimental setup and the network architectures are given in Appendix B. In all the experiments we fix the ground-cost function of the WAE-based methods to be the squared Euclidean distance: $c(x, y) = \|x - y\|^2_{L_2}$.

4.1. QUANTITATIVE ANALYSIS: DISENTANGLEMENT ON TOY DATA SETS

Ablation study of the latent divergence function. We compare the impact of the different latent regularization functions in WAE-MMD (Tolstikhin et al., 2018), TCWAE-MWS and TCWAE-GAN. We take β = γ in the TCWAE objectives, isolating the impact of the different latent divergence functions used in the TCWAE and the original WAE. We train the methods with β ∈ {1, 2, 4, 6, 8, 10} and report the results in Figure 1 for the NoisydSprites data set (Locatello et al., 2019). As expected, the higher the penalization on the latent regularization (high β), the poorer the reconstructions. The trade off between reconstruction and latent regularization is more sensitive for TCWAE-GAN, where a relatively modest improvement in latent regularization results in a large deterioration of reconstruction performance, while TCWAE-MWS is less sensitive. This is better illustrated in Figure 1c, with a much higher slope for TCWAE-GAN than for TCWAE-MWS. WAE seems to be relatively little impacted by the latent penalization weight. We note in Figure 1b the bias of the MWS estimator (Chen et al., 2018). Finally, we plot the reconstruction versus the MMD between the aggregated posterior and the prior for all the models in Figure 1d. Interestingly, TCWAEs actually achieve a lower MMD (left part of the plot) even though they are not trained with that regularization function. However, as expected given that the TCWAEs do not optimize the reconstruction-MMD trade off, the WAE achieves a better reconstruction (bottom part of the plot).

Disentanglement performances. We compare our methods with β-TCVAE (Chen et al., 2018), FactorVAE (Kim & Mnih, 2018) and the original WAE-MMD (Tolstikhin et al., 2018) on the dSprites (Matthey et al., 2017), NoisydSprites (Locatello et al., 2019), ScreamdSprites (Locatello et al., 2019) and smallNORB (LeCun et al., 2004) data sets, whose ground-truth generative factors are known and given in Table 3, Appendix B.1. We use three different disentanglement metrics to assess the disentanglement performance: the Mutual Information Gap (MIG, Chen et al. (2018)), the FactorVAE metric (Kim & Mnih, 2018) and the Separated Attribute Predictability score (SAP, Kumar et al. (2018)). We follow Locatello et al. (2019) for the implementation of these metrics. We use the Mean Square Error (MSE) of the reconstructions to assess the reconstruction performance of the methods. For each model, we use six different values for each parameter, resulting in thirty-six different models for TCWAEs, and six for the remaining methods (see Appendix B.1 for more details).

Mirroring the benchmark methods, we first tune γ in the TCWAEs, regularizing the dimensionwise-KL, and subsequently focus on the role of the TC term in the disentanglement performance. The heat maps of the different scores for each method and data set are given in Figures 5, 6, 7 and 8 in Appendix C. As expected, while β controls the trade off between reconstruction and disentanglement, γ affects the range achievable when tuning β. In particular, for γ > 1, Figures 5, 6, 7 and 8 show that better disentanglement is obtained without much deterioration in reconstruction.

Table 1 reports the results, averaged over 5 random runs, for the four different data sets. For each method, we report the best β, taken to be the one achieving the best overall ranking on the four different metrics (see Appendix ?? for details). Note that the performance of WAE on the dSprites data set, both in terms of reconstruction and disentanglement, was significantly worse and not meaningful; thus, in order to avoid unfair extra tuning of the parameters, we chose not to include it. TCWAEs achieve competitive performance across all the data sets, with top scores in several metrics. In particular, the squared Euclidean distance seems to improve the trade off and performs better than the cross-entropy with color images (NoisydSprites, ScreamdSprites), but less so with black and white images (dSprites). See Appendix C for more results on the different data sets.

Finally, we visualise the reconstruction-disentanglement trade off by plotting the different disentanglement metrics against the MSE in Figure 3. As expected, when the TC regularization weight is increased, the reconstruction deteriorates while the disentanglement improves up to a certain point. Then, when too much penalization is put on the TC term, the poor quality of the reconstructions prevents any disentanglement in the generative factors. Reflecting the results of Table 1, TCWAE-MWS seems to perform better (the top-left corner represents better reconstruction and disentanglement). TCWAE-GAN presents better reconstruction but slightly lower disentanglement performance (bottom-left corner).

We train our methods on 3Dchairs (Aubry et al., 2014) and CelebA (Liu et al., 2015), whose generative factors are not known, and qualitatively find that TCWAEs achieve good disentanglement. Inspection of the reconstructions and samples in Appendix D shows that FactorVAE in fact struggles to generalize and to learn a smooth latent manifold.

5. CONCLUSION

Leveraging the surgery of the KL regularization term of the ELBO objective, we design a new disentanglement method based on the WAE objective, whose latent divergence function is taken to be the KL divergence between the aggregated posterior and the prior. The WAE framework naturally enables the latent regularization to depend explicitly on the TC of the aggregated posterior, a quantity previously associated with disentanglement. Using two different estimators of the KL terms, we show that our methods achieve competitive disentanglement on toy data sets. Moreover, the flexibility in the choice of the reconstruction cost function offered by the WAE framework makes our method more compelling when working with more challenging data sets.

A WAE DERIVATION

We recall the Kantorovich formulation of the OT between the true-but-unknown data distribution $P_D$ and the model distribution $P_\theta$, with given cost function $c$:

$$\mathrm{OT}_c(P_D, P_\theta) = \inf_{\Gamma \in \mathcal{P}(P_D, P_\theta)} \int_{\mathcal{X} \times \mathcal{X}} c(x, \tilde{x})\, \gamma(x, \tilde{x})\, dx\, d\tilde{x} \tag{12}$$

where $\mathcal{P}(P_D, P_\theta)$ is the space of all couplings of $P_D$ and $P_\theta$:

$$\mathcal{P}(P_D, P_\theta) = \Big\{ \Gamma \;\Big|\; \int_{\mathcal{X}} \gamma(x, \tilde{x})\, d\tilde{x} = p_D(x),\; \int_{\mathcal{X}} \gamma(x, \tilde{x})\, dx = p_\theta(\tilde{x}) \Big\} \tag{13}$$

Tolstikhin et al. (2018) first restrict the space of couplings to the joint distributions of the form:

$$\gamma(x, \tilde{x}) = \int_{\mathcal{Z}} p_\theta(\tilde{x}|z)\, q(z|x)\, p_D(x)\, dz \tag{14}$$

where $q(z|x)$, for $x \in \mathcal{X}$, plays the same role as the variational distribution in variational inference. While the marginal constraint on $x$ (first constraint in Eq. 13) is satisfied by construction in Eq. 14, the second marginal constraint (that over $\tilde{x}$, giving $p_\theta$ in Eq. 13) is not guaranteed. A sufficient condition is to have, for all $z \in \mathcal{Z}$:

$$\int_{\mathcal{X}} q(z|x)\, p_D(x)\, dx = p(z) \tag{15}$$

Secondly, Tolstikhin et al. (2018) relax the constraint in Eq. 15 using a soft constraint with a Lagrange multiplier:

$$W_c(P_D, P_\theta) = \inf_{q(Z|X)} \int_{\mathcal{X} \times \mathcal{X}} c(x, \tilde{x})\, \gamma(x, \tilde{x})\, dx\, d\tilde{x} + \lambda\, D\big(q(Z) \,\|\, p(Z)\big) \tag{16}$$

where $D$ is any divergence function, $\lambda$ a relaxation parameter, $\gamma$ is defined in Eq. 14 and $q(Z)$ is the aggregated posterior as defined in Section 2. Finally, they drop the closed-form minimization over the variational distribution $q(z|x)$ to obtain the WAE objective, as defined in Section 3.1:

$$W_{D,c}(\theta, \phi) \triangleq \mathbb{E}_{p_D(X)}\, \mathbb{E}_{q_\phi(z|x)}\, \mathbb{E}_{p_\theta(\tilde{x}|z)}\, c(x, \tilde{x}) + \lambda\, D\big(q(Z) \,\|\, p(Z)\big) \approx \mathbb{E}_{p(x_n)}\, \mathbb{E}_{q_\phi(z|x_n)}\, \mathbb{E}_{p_\theta(\tilde{x}_n|z)}\, c(x_n, \tilde{x}_n) + \lambda\, D\big(q(Z) \,\|\, p(Z)\big)$$

We use a batch size of 64 in Section 4.2, while in the main experiments of Section 4.1 we take a batch size of 100. In the ablation study of Section 4.1, we use a bigger batch size of 256 in order to reduce the impact of the bias of the MWS estimator (Chen et al. (2018), however, show that there is very little impact on the performance of the MWS when using a smaller batch size). For all experiments, we use the Adam optimizer (Kingma & Ba, 2015) with a learning rate of 0.0005, beta1 of 0.9, beta2 of 0.999 and epsilon of 0.0008, and train for 300,000 iterations.

For all the data sets of Section 4.1, we take the latent dimension $d_Z = 10$, while we use $d_Z = 16$ for 3Dchairs and $d_Z = 32$ for CelebA. We use Gaussian encoders with diagonal covariance matrix in all the models and deterministic decoder networks when possible (WAE-based methods). The different parameter values used for each experiment are given in Table 4. γ is chosen such that the resulting method achieves the best score $s$ when averaging over all the β values, where the score is defined as the sum of the rankings on each individual metric:

$$s = r_{MSE} + \sum_{\text{metric}} r_{\text{metric}}$$

where $r_{MSE}$ denotes the ranking of the MSE (lower is better) and $r_{\text{metric}}$, for metric in {MIG, FactorVAE, SAP}, is the ranking of the disentanglement performance as measured by the given metric (higher is better). β is then chosen such that the resulting method, with the previously found γ, achieves the best overall score $s$ defined above.

In Section 4.1, we use a validation run to select the parameter values and report the MSE and FID scores on a test run. MSE values are computed on a test set of size 10,000 with a batch size of 1,000, while we follow Heusel et al. (2017) for the FID implementation: we first compute the activation statistics of the feature maps on the full test set for both the reconstructions (respectively, the samples) and the true observations. We then compute the Fréchet distance between two Gaussians with the computed statistics.
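The ranking-based score $s$ can be illustrated with a small example. The numbers below are made up for illustration and are not results from the experiments; the sketch also assumes that rank 1 is the best entry for each metric, so that the lowest total $s$ is the best score, and it ignores ties.

```python
import numpy as np

def rank(values, higher_is_better):
    """Return 1-based ranks, rank 1 being the best entry (ties not handled)."""
    values = np.asarray(values, dtype=float)
    order = np.argsort(-values if higher_is_better else values)
    ranks = np.empty(len(values), dtype=int)
    ranks[order] = np.arange(1, len(values) + 1)
    return ranks

# Hypothetical scores for three candidate settings (illustrative numbers only).
mse = [10.0, 12.0, 15.0]
metrics = {
    "MIG": [0.10, 0.30, 0.20],
    "FactorVAE": [0.60, 0.80, 0.70],
    "SAP": [0.05, 0.07, 0.06],
}

# s = r_MSE + sum over metrics of r_metric.
score = rank(mse, higher_is_better=False)
for values in metrics.values():
    score = score + rank(values, higher_is_better=True)

best = int(np.argmin(score))
print(score, best)
```

In this toy case the second setting wins: it is only second on MSE but first on all three disentanglement metrics, which is the kind of compromise the summed ranking is designed to reward.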

B.2 MODELS ARCHITECTURES

The Gaussian encoder network, $q_\phi(z|x)$, and the decoder network, $p_\theta(x|z)$, are parametrized by neural networks as follows:

$$p_\theta(x|z) = \begin{cases} \delta_{f_\theta(z)} & \text{if WAE-based method,} \\ \mathcal{N}\big(\mu_\theta(z), \sigma^2_\theta(z)\big) & \text{otherwise,} \end{cases} \qquad q_\phi(z|x) = \mathcal{N}\big(\mu_\phi(x), \sigma^2_\phi(x)\big)$$

where $f_\theta$, $\mu_\theta$, $\sigma^2_\theta$, $\mu_\phi$ and $\sigma^2_\phi$ are the outputs of convolutional neural networks. All the experiments use the architectures of Locatello et al. (2019), except for CelebA, where we use an architecture inspired by Tolstikhin et al. (2018). The details of the architectures are given in Table 5. All the discriminator networks, $D$, are fully connected networks and share the same architecture, given in Table 5. The optimisation setup for the discriminator is given in Table 6:

| Parameter | Value |
|---|---|
| Learning rate | 1e-4 (Section 4.1) / 1e-5 (Section 4.2) |
| beta 1 | 0.5 |
| beta 2 | 0.9 |
| epsilon | 1e-08 |

For each method, we plot the distribution (over five random runs) of the different metrics for different β values.



Figure 1: Reconstruction and latent regularization terms as functions of β for the NoisydSprites data set. (a): reconstruction error. (b): latent regularization term (MMD for WAE, KL for TCWAE). (c): reconstruction error against latent regularization. (d): reconstruction error against MMD. Shaded regions show ± one standard deviation.

Figure 2: Latent traversals for each model on smallNORB. The parameters are the same as the ones reported in Tables 1 and 7. Each row $i$ corresponds to the traversal of the latent $z_i$, while the columns correspond to a step in that traversal. The rows are ordered by increasing $\frac{1}{N_{test}} \sum_{\text{testset}} \mathrm{KL}\big(q(z_i|x) \,\|\, p(z_i)\big)$ and the traversal range is [-2, 2].

Figure 3: Disentanglement versus reconstruction on the ScreamdSprites data set. Annotations at each point are values of β. Points with low reconstruction error and high scores (top-left corner) represent better models.

Figure 4: Latent traversals for TCWAE-MWS and TCWAE-GAN. Each line corresponds to one input data point while each subplot corresponds to one latent factor. We vary evenly the encoded latent codes in the interval [-4, 4].

Figure 9: Violin plots of the different scores versus γ on dSprites.

Figure 10: Violin plots of the different scores versus γ on NoisydSprites.

Figure 11: Violin plots of the different scores versus γ on ScreamdSprites.

Figure 12: Violin plots of the different scores versus γ on smallNORB.

Figure 13: Samples and reconstructions for each model on NoisydSprites. (a) Reconstructions. Top row: input data; from second-to-top to bottom row: WAE, TCWAE-MWS, TCWAE-GAN, β-TCVAE, FactorVAE. (b) Samples. From top to bottom row: WAE, TCWAE-MWS, TCWAE-GAN, β-TCVAE, FactorVAE. Parameters are the ones reported in Tables 1 and 7.

Figure 14: Same as Figure 13 but for ScreamdSprites.

Reconstruction and disentanglement scores (± one standard deviation).

MSE and FID scores for the different data sets. Details of the methodology are given in Appendix B.

We train and compare our methods on four different data sets, two with known ground-truth generative factors (see Table 3): dSprites (Matthey et al., 2017) with 737,280 binary, 64 × 64 images and smallNORB (LeCun et al., 2004) with 48,600 greyscale, 64 × 64 images; and two with unknown ground-truth generative factors: 3Dchairs (Aubry et al., 2014) with 86,366 RGB, 64 × 64 images and CelebA (Liu et al., 2015) with 202,599 RGB, 64 × 64 images.

Ground-truth generative-factors of the dSprites and smallNORB data sets.

Hyper parameters values ranges used in the different Sections.

We follow Locatello et al. (2019) for the architectures in all the experiments except for CelebA, where we follow Tolstikhin et al. (2018) (details of the network architectures are given in Section B.2). We use a (positive) mixture of Inverse MultiQuadratic (IMQ) kernels and the associated reproducing kernel Hilbert space to compute the MMD when it is needed (WAE and ablation study of Section 4.1).
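For reference, an MMD estimate with a mixture of IMQ kernels, $k(x, y) = \sum_C \frac{C}{C + \|x - y\|^2}$, can be sketched as below. The kernel scales, the unbiased (off-diagonal) form of the within-sample terms and the function names are assumptions made for illustration; the text above does not specify them.

```python
import numpy as np

def imq_kernel(x, y, scales):
    """Mixture of inverse multiquadratic kernels: sum_C C / (C + ||x - y||^2)."""
    sq_dist = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=2)
    return sum(c / (c + sq_dist) for c in scales)

def mmd_imq(z_q, z_p, scales=(0.1, 1.0, 10.0)):
    """Squared-MMD estimate between two sample sets of equal size n.

    The `scales` values are hypothetical defaults, not taken from the paper.
    """
    n = z_q.shape[0]
    k_qq = imq_kernel(z_q, z_q, scales)
    k_pp = imq_kernel(z_p, z_p, scales)
    k_qp = imq_kernel(z_q, z_p, scales)
    # Within-sample terms exclude the diagonal (i == j).
    off_diag = 1.0 - np.eye(n)
    within = np.sum((k_qq + k_pp) * off_diag) / (n * (n - 1))
    return within - 2.0 * np.mean(k_qp)

rng = np.random.default_rng(0)
prior = rng.normal(size=(200, 2))
matched = rng.normal(size=(200, 2))        # same distribution as the prior
shifted = rng.normal(size=(200, 2)) + 3.0  # clearly mismatched codes
mmd_same = mmd_imq(matched, prior)
mmd_shift = mmd_imq(shifted, prior)
print(mmd_same, mmd_shift)
```

The estimate stays close to zero when the two sample sets come from the same distribution and becomes clearly positive once the codes drift away from the prior.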

Network architectures

(a) Locatello et al. (2019) architectures

| Encoder | Decoder | Discriminator |
|---|---|---|
| CONV. 4 × 4 × 32 stride 2 ReLU | FC 256 ReLU | FC 1000 ReLU |
| CONV. 4 × 4 × 32 stride 2 ReLU | FC 4 × 4 × 64 ReLU | FC 1000 ReLU |
| CONV. 4 × 4 × 64 stride 2 ReLU | CONV. 4 × 4 × 64 stride 2 ReLU | FC 1000 ReLU |
| CONV. 4 × 4 × 64 stride 2 ReLU | CONV. 4 × 4 × 32 stride 2 ReLU | FC 1000 ReLU |
| FC 256 ReLU | CONV. 4 × 4 × 32 stride 2 ReLU | FC 1000 ReLU |
| FC 2 × d_Z | CONV. 4 × 4 × c stride 2 | FC 1000 ReLU |
| | | FC 2 |

(b) CelebA architectures

| Encoder | Decoder | Discriminator |
|---|---|---|
| | Input: d_Z | Input: d_Z |
| CONV. 4 × 4 × 32 stride 2 BN ReLU | FC 8 × 8 × 256 BN ReLU | FC 1000 ReLU |
| CONV. 4 × 4 × 64 stride 2 BN ReLU | CONV. 4 × 4 × 128 stride 2 BN ReLU | FC 1000 ReLU |
| CONV. 4 × 4 × 128 stride 2 BN ReLU | CONV. 4 × 4 × 64 stride 2 BN ReLU | FC 1000 ReLU |
| CONV. 4 × 4 × 256 stride 2 BN ReLU | CONV. 4 × 4 × 32 stride 2 BN ReLU | FC 1000 ReLU |
| FC 2 × d_Z | CONV. 4 × 4 × c | FC 1000 ReLU |
| | | FC 1000 ReLU |
| | | FC 2 |

FactorVAE discriminator setup

γ values for methods for each data set.

