LEARNING DISENTANGLED REPRESENTATIONS WITH THE WASSERSTEIN AUTOENCODER

Abstract

Disentangled representation learning has undoubtedly benefited from objective function surgery. However, a delicate balancing act of tuning is still required in order to trade off reconstruction fidelity against disentanglement. Building on previous successes of penalizing the total correlation in the latent variables, we propose TCWAE (Total Correlation Wasserstein Autoencoder). Working in the WAE paradigm naturally enables the separation of the total-correlation term, thus providing disentanglement control over the learned representation, while offering more flexibility in the choice of reconstruction cost. We propose two variants using different KL estimators and perform extensive quantitative comparisons on data sets with known generative factors, showing competitive results relative to state-of-the-art techniques. We further study the trade-off between disentanglement and reconstruction on more difficult data sets with unknown generative factors, where the flexibility of the WAE paradigm in the reconstruction term improves reconstruction quality.

1. INTRODUCTION

Learning representations of data is at the heart of deep learning; the ability to interpret those representations empowers practitioners to improve the performance and robustness of their models (Bengio et al., 2013; van Steenkiste et al., 2019). When the data is underpinned by independent latent generative factors, a good representation should encode information about the data in a semantically meaningful manner, with statistically independent latent variables encoding for each factor. Bengio et al. (2013) define a disentangled representation as having the property that a change in one dimension corresponds to a change in one factor of variation, while being relatively invariant to changes in other factors. While many attempts to formalize this concept have been proposed (Higgins et al., 2018; Eastwood & Williams, 2018; Do & Tran, 2019), finding a principled and reproducible approach to assess disentanglement is still an open problem (Locatello et al., 2019). Recent successful unsupervised learning methods have shown how simply modifying the ELBO objective, either by re-weighting the latent regularization terms or by directly regularizing the statistical dependencies in the latent, can be effective in learning disentangled representations. Esmaeili et al. (2019) further improve the reconstruction capacity of β-TCVAE (Chen et al., 2018) by introducing structural dependencies both between groups of variables and between variables within each group. Alternatively, directly regularizing the aggregated posterior towards the prior with density-free divergences (Zhao et al., 2019) or moment matching (Kumar et al., 2018), or simply penalizing a high Total Correlation (TC; Watanabe, 1960) in the latent (Kim & Mnih, 2018), has shown good disentanglement performance. In fact, information theory has been fertile ground for tackling representation learning.
Achille & Soatto (2018) re-interpret VAEs from an Information Bottleneck perspective (Tishby et al., 1999), re-phrasing them as a trade-off between sufficiency and minimality of the representation, regularizing a pseudo TC between the aggregated posterior and the true conditional posterior. Similarly, Gao et al. (2019) use the principle of Total Correlation Explanation (CorEx) (Ver Steeg & Galstyan, 2014) and maximize the mutual information between the observation and a subset of anchor latent points. Maximizing the mutual information (MI) between the observation and the latent has been used broadly (van den Oord et al., 2018; Hjelm et al., 2019; Bachman et al., 2019; Tschannen et al., 2020), showing encouraging results in representation learning. However, Tschannen et al. (2020) argue that MI maximization alone cannot explain the disentanglement performance of these methods. Building on the Optimal Transport (OT) problem (Villani, 2008), Tolstikhin et al. (2018) introduced the Wasserstein Autoencoder (WAE), an alternative to the VAE for learning generative models. Like the VAE, the WAE maps the data into a (low-dimensional) latent space while regularizing the averaged encoding distribution. This contrasts with VAEs, where the posterior is regularized at each data point, and allows the encoding distribution to capture significant information about the data while still matching the prior when averaged over the whole data set. Interestingly, by directly regularizing the aggregated posterior, the WAE hints at more explicit control over how information is encoded, and thus at better disentanglement. Moreover, the WAE reconstruction term allows any cost function on the observation space, opening the door to better-suited reconstruction terms; for example, when working with continuous RGB data sets, a metric such as the Euclidean distance on the observation space can yield more accurate reconstructions of the data.
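As a concrete illustration of regularizing the *aggregated* posterior rather than each per-sample posterior, the sketch below estimates a density-free divergence between a batch of encoded samples and a batch of prior samples, using the RBF-kernel MMD of Tolstikhin et al. (2018). The batch sizes, bandwidth, and shift are arbitrary illustrative values, not taken from the paper.

```python
import numpy as np

def rbf(a, b, bandwidth=1.0):
    # RBF kernel matrix between two sample sets of shape (n, d) and (m, d).
    d2 = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return np.exp(-d2 / (2.0 * bandwidth ** 2))

def mmd2(z_q, z_p, bandwidth=1.0):
    # Unbiased squared MMD between aggregated-posterior samples z_q
    # and prior samples z_p (diagonal kernel terms excluded).
    n, m = len(z_q), len(z_p)
    k_qq = rbf(z_q, z_q, bandwidth)
    k_pp = rbf(z_p, z_p, bandwidth)
    k_qp = rbf(z_q, z_p, bandwidth)
    return ((k_qq.sum() - np.trace(k_qq)) / (n * (n - 1))
            + (k_pp.sum() - np.trace(k_pp)) / (m * (m - 1))
            - 2.0 * k_qp.mean())

rng = np.random.default_rng(0)
z_prior = rng.standard_normal((256, 2))          # samples from the prior
z_match = rng.standard_normal((256, 2))          # "encodings" matching the prior
z_shift = rng.standard_normal((256, 2)) + 2.0    # "encodings" that do not

m_match = mmd2(z_match, z_prior)
m_shift = mmd2(z_shift, z_prior)
print(m_match, m_shift)  # the shifted batch incurs a much larger penalty
```

Only the batch-level statistics enter the penalty, so individual posteriors are free to be informative as long as their average matches the prior.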
In this work, following the success of regularizing the TC in disentanglement, we propose to use the Kullback-Leibler (KL) divergence as the latent regularization function in the WAE. We introduce the Total Correlation WAE (TCWAE), whose objective has an explicit dependency on the TC of the aggregated posterior. Using two different estimators for the KL terms, we perform extensive comparisons with successful methods on a number of data sets. Our results show that TCWAEs achieve competitive disentanglement performance while improving modelling performance by allowing flexibility in the choice of reconstruction cost.
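For intuition on the kind of KL/TC estimator such objectives rely on, here is a simplified, hypothetical sketch of a minibatch-weighted estimate of the total correlation of the aggregated posterior, in the spirit of Chen et al. (2018). It treats the minibatch as the whole data set (ignoring the dataset-size correction) and assumes diagonal-Gaussian posteriors; the posterior parameters below are synthetic stand-ins for encoder outputs, not the paper's estimator.

```python
import numpy as np

def log_gauss(z, mu, log_var):
    # Elementwise log N(z; mu, exp(log_var)), broadcast over sample/batch axes.
    return -0.5 * (np.log(2 * np.pi) + log_var
                   + (z - mu) ** 2 / np.exp(log_var))

def tc_minibatch(mu, log_var, rng):
    # TC ≈ mean_n [ log q(z_n) - sum_d log q_d(z_nd) ], with both densities
    # approximated by averaging q(.|x_m) over the minibatch.
    n, d = mu.shape
    z = mu + np.exp(0.5 * log_var) * rng.standard_normal((n, d))
    # log q(z_n | x_m) summed over dims: pairwise tensor of shape (n, n, d).
    lp = log_gauss(z[:, None, :], mu[None, :, :], log_var[None, :, :])
    log_qz = np.logaddexp.reduce(lp.sum(axis=2), axis=1) - np.log(n)
    log_qzd = np.logaddexp.reduce(lp, axis=1) - np.log(n)   # (n, d)
    return np.mean(log_qz - log_qzd.sum(axis=1))

rng = np.random.default_rng(0)
mu = rng.standard_normal((512, 2))                     # independent dimensions
mu_corr = np.stack([mu[:, 0], mu[:, 0] + 0.1 * mu[:, 1]], axis=1)  # entangled
log_var = np.full((512, 2), -2.0)

t_indep = tc_minibatch(mu, log_var, rng)
t_corr = tc_minibatch(mu_corr, log_var, rng)
print(t_indep, t_corr)  # the entangled posterior yields a larger TC estimate
```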

2.1. TOTAL CORRELATION

The TC of a random vector $Z \in \mathcal{Z}$ under $P$ is defined by
$$\mathrm{TC}(Z) \triangleq \sum_{d=1}^{d_Z} H_{p_d}(Z_d) - H_p(Z)$$
where $p_d(z_d)$ is the marginal density over $z_d$ alone and $H_p(Z) \triangleq -\mathbb{E}_p[\log p(Z)]$ is the Shannon differential entropy, which encodes the information contained in $Z$ under $P$. Since $\sum_{d=1}^{d_Z} H_{p_d}(Z_d) \geq H_p(Z)$, with equality when the marginals $Z_d$ are mutually independent, the TC can be interpreted as the loss of information when assuming mutual independence of the $Z_d$; namely, it measures the mutual dependence of the marginals. Thus, in the context of disentanglement learning, we seek a low TC of the aggregated posterior, $p(z) = \int_{\mathcal{X}} p(z|x)\, p(x)\, dx$, which forces the model to encode the data into statistically independent latent codes. High MI between the data and the latent is then obtained when the posterior, $p(z|x)$, captures relevant information from the data.
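To make the definition concrete: for a multivariate Gaussian both entropies are available in closed form, and the TC reduces to $-\tfrac{1}{2}\log\det R$, where $R$ is the correlation matrix. The covariances below are hypothetical examples, not drawn from any model in the paper.

```python
import numpy as np

def gaussian_tc(cov):
    # Total correlation (in nats) of a Gaussian with covariance `cov`:
    # TC = sum_d H(Z_d) - H(Z) = -0.5 * log det(R), R the correlation matrix.
    std = np.sqrt(np.diag(cov))
    corr = cov / np.outer(std, std)
    return -0.5 * np.log(np.linalg.det(corr))

S_indep = np.diag([1.0, 4.0])                 # independent marginals
S_corr = np.array([[1.0, 0.8],
                   [0.8, 1.0]])               # strongly correlated marginals

print(gaussian_tc(S_indep))  # ~0.0: no information lost by factorizing
print(gaussian_tc(S_corr))   # ~0.5108 = -0.5*log(1 - 0.8**2)
```

The independent case gives TC = 0, matching the equality condition above; any off-diagonal correlation makes the TC strictly positive.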

2.2. TOTAL CORRELATION IN ELBO

We consider latent generative models $p_\theta(x) = \int_{\mathcal{Z}} p_\theta(x|z)\, p(z)\, dz$ with prior $p(z)$ and decoder network, $p_\theta(x|z)$, parametrized by $\theta$. VAEs approximate the intractable posterior $p(z|x)$ by introducing an encoding distribution (the encoder), $q_\phi(z|x)$, and learning $\theta$ and $\phi$ simultaneously by optimizing the variational lower bound, or ELBO, defined in Eq. 3:
$$\mathcal{L}_{\mathrm{ELBO}}(\theta, \phi) \triangleq \mathbb{E}_{p_{\mathrm{data}}(X)}\Big[\mathbb{E}_{q_\phi(Z|X)}\big[\log p_\theta(X|Z)\big] - \mathrm{KL}\big(q_\phi(Z|X)\,\|\,p(Z)\big)\Big] \leq \mathbb{E}_{p_{\mathrm{data}}(X)}\big[\log p_\theta(X)\big] \quad (3)$$
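A minimal numerical sketch of the ELBO in Eq. 3 for a single data point, assuming a diagonal-Gaussian encoder, a standard-normal prior, and a Bernoulli decoder. The encoder outputs and the linear decoder stub are hypothetical placeholders; only the structure of the objective (Monte-Carlo reconstruction term minus a closed-form KL) follows the text.

```python
import numpy as np

rng = np.random.default_rng(0)

def kl_diag_gauss_to_std_normal(mu, log_var):
    # Closed-form KL( N(mu, diag(exp(log_var))) || N(0, I) ), in nats.
    return 0.5 * np.sum(np.exp(log_var) + mu ** 2 - 1.0 - log_var)

def elbo(x, mu, log_var, decode, n_samples=64):
    # Monte-Carlo estimate of E_q[log p(x|z)] via the reparametrization trick,
    # minus the KL regularizer of Eq. 3.
    eps = rng.standard_normal((n_samples, mu.size))
    z = mu + np.exp(0.5 * log_var) * eps
    logits = decode(z)                                   # (n_samples, x_dim)
    # Bernoulli log-likelihood with logits: x*l - softplus(l), summed over dims.
    log_lik = np.sum(x * logits - np.logaddexp(0.0, logits), axis=1)
    return log_lik.mean() - kl_diag_gauss_to_std_normal(mu, log_var)

# Hypothetical encoder outputs for one binary observation, and a linear
# "decoder" stub in place of a trained network.
x = np.array([1.0, 0.0, 1.0])
mu, log_var = np.array([0.3, -0.1]), np.array([-1.0, -1.2])
W = rng.standard_normal((2, 3))

elbo_value = elbo(x, mu, log_var, lambda z: z @ W)
print(elbo_value)  # a lower bound on log p(x), hence non-positive here
```

Note that the KL term vanishes exactly when the encoder already equals the prior ($\mu = 0$, $\log\sigma^2 = 0$), which is the degenerate, uninformative solution the reconstruction term pushes against.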



Higgins et al. (2017) and Burgess et al. (2018) control the information bottleneck capacity of Variational Autoencoders (VAEs; Kingma & Welling, 2014; Rezende et al., 2014) by heavily penalizing the latent regularization term. Chen et al. (2018) perform ELBO surgery to isolate the terms at the origin of disentanglement in β-VAE, improving the reconstruction-disentanglement trade-off. Esmaeili et al.

