IMPROVING THE RECONSTRUCTION OF DISENTANGLED REPRESENTATION LEARNERS VIA MULTI-STAGE MODELLING

Anonymous

Abstract

Current autoencoder-based disentangled representation learning methods achieve disentanglement by penalizing the (aggregate) posterior to encourage statistical independence of the latent factors. This approach introduces a trade-off between disentangled representation learning and reconstruction quality, since the model does not have enough capacity to learn the correlated latent variables that capture the detail present in most image data. To overcome this trade-off, we present a novel multi-stage modelling approach in which the disentangled factors are first learned using a preexisting disentangled representation learning method (such as β-TCVAE); then, the low-quality reconstruction is improved with another deep generative model that is trained to model the missing correlated latent variables, adding detail while maintaining conditioning on the previously learned disentangled factors. Taken together, our multi-stage modelling approach results in a single, coherent probabilistic model that is theoretically justified by the principle of d-separation and can be realized with a variety of model classes, including likelihood-based models such as variational autoencoders, implicit models such as generative adversarial networks, and tractable models such as normalizing flows or mixtures of Gaussians. We demonstrate that our multi-stage model achieves much higher reconstruction quality than current state-of-the-art methods with equivalent disentanglement performance across multiple standard benchmarks.

1. INTRODUCTION

Deep generative models (DGMs) such as variational autoencoders (VAEs) (Kingma & Welling, 2014; Rezende et al., 2014) and generative adversarial networks (GANs) (Goodfellow et al., 2014) have recently enjoyed great success at modeling high-dimensional data such as natural images. As the name suggests, DGMs leverage deep learning to model a data-generating process. The underlying assumption is that the high-dimensional observations X ∈ R^D can be meaningfully described by a small set of latent factors H ∈ R^K, where K < D. More precisely, the observation (X = x) is assumed to be generated by first sampling a set of low-dimensional factors h from a simple prior distribution p(H) and then sampling x ∼ p_θ(X|h). DGMs realize p_θ through a deep neural network, also known as the decoder or the generative network. VAE-based DGMs use another deep neural network (called the encoder or the inference network) to parameterize an approximate posterior q_φ(H|x). The model and variational posterior parameters are learned by maximizing an evidence lower bound (ELBO) on the log-marginal likelihood of the data under the model p_θ(X). Learning disentangled factors h ∼ q_φ(H|x) that are semantically meaningful representations of the observation x is highly desirable because such interpretable representations can arguably be advantageous for a variety of downstream tasks (Locatello et al., 2018), including classification, detection, reinforcement learning, transfer learning, and image synthesis from textual descriptions (Bengio et al., 2013; LeCun et al., 2015; Lake et al., 2017; van Steenkiste et al., 2019; Reed et al., 2016; Zhang et al., 2016). While a formal definition of disentangled representation (DR) remains elusive, it is understood to mean that manipulating only one of the factors while holding the rest constant changes only one semantically meaningful aspect of the observation (e.g., the pose of an object in an image).
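As a concrete illustration of the ELBO described above, the KL term between a diagonal Gaussian posterior q_φ(H|x) = N(μ, diag(σ²)) and a standard normal prior has a well-known closed form. The sketch below is a minimal numpy illustration under a Gaussian observation model; the function names and the toy setup are ours, not the paper's implementation:

```python
import numpy as np

def gaussian_kl(mu, log_var):
    """Closed-form KL( N(mu, diag(exp(log_var))) || N(0, I) ),
    the regularization term of the VAE ELBO."""
    return 0.5 * np.sum(np.exp(log_var) + mu**2 - 1.0 - log_var)

def elbo(x, x_recon, mu, log_var, sigma_x=1.0):
    """Single-sample ELBO estimate with a Gaussian observation model,
    up to an additive constant: E_q[log p(x|h)] - KL(q(h|x) || p(h))."""
    recon_log_lik = -0.5 * np.sum((x - x_recon) ** 2) / sigma_x**2
    return recon_log_lik - gaussian_kl(mu, log_var)

# A posterior that exactly matches the prior contributes zero KL,
# so a perfect reconstruction with such a posterior yields ELBO = 0
# (up to the dropped constant).
assert np.isclose(gaussian_kl(np.zeros(4), np.zeros(4)), 0.0)
```

Up-weighting `gaussian_kl` in this objective (as in β-VAE and its successors) is exactly the kind of posterior regularization discussed below, which buys independence of the latent factors at the cost of reconstruction quality.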
Prior work in unsupervised DR learning focuses on the objective of learning statistically independent latent factors as a means of obtaining DR. The underlying assumption is that the latent variables H can be partitioned into independent components C (i.e., the disentangled factors) and correlated components Z. An observation (X = x) is assumed to be generated by sampling the low-dimensional factors h = (c, z) from p(H) and then sampling x ∼ p_θ(X|c, z) (Figure 2a). A series of works starting from Higgins et al. (2017) enforces statistical independence of the latent factors H via regularization, up-weighting terms in the ELBO that penalize the (aggregate) posterior towards being factorized over all or some of the latent dimensions (Kumar et al., 2017; Kim & Mnih, 2018; Chen et al., 2018) (see Section 2 for details).

While the aforementioned models show promising results, they suffer from a trade-off between DR learning and reconstruction quality. If the latent space is heavily regularized, leaving insufficient capacity for the correlated variables, then the reconstruction quality is diminished, signaling that the learned representation is sub-optimal. Because the correlated variables are functionally ignored under high levels of regularization, an observation (X = x) can be thought of as being generated by sampling only the independent latent factors c from p(C) and then sampling x ∼ p_θ(X|c) (Figure 2b). This is a departure from the original data-generating hypothesis that x is sampled from a distribution dependent on both the independent and correlated latent variables. On the other hand, if the correlated variables are not well constrained, the model can use them to achieve a high-quality reconstruction while ignoring the independent variables (the disentangled latent factors). This phenomenon is referred to as the "shortcut problem" and has been discussed in previous works (Szabó et al., 2018; Lezama, 2018). Overcoming this trade-off is essential for using these models in real-world applications such as realistic, controlled image synthesis (Lee et al., 2020; Lezama, 2018). In this paper, we propose a new graphical model for DR learning (Figure 2c) that allows for learning disentangled factors while also correctly realizing the data-generating hypothesis that an observation is generated from both independent and correlated factors.

Importantly, the graphical model in Figure 2c is d-separated, meaning that changes in the correlated latent variables Z do not influence the independent latent variables C. Generating an observation (X = x) from this model is done by sampling the independent factors c from p(C), sampling a low-quality reconstruction y ∼ p_θ(Y|c), sampling the correlated factors z from p(Z), and finally sampling x ∼ p_θ(X|z, y). The final reconstruction x depends on both z and c; however, any regularization needed to extract the independent factors c no longer diminishes the model capacity available for the correlated factors z. To realize the proposed graphical model, we introduce MS-VAE, a multi-stage DGM that is implemented as follows: first, the disentangled representation C is learned using an existing DR learning method such as β-TCVAE (Chen et al., 2018). Since the learned factors C are regularized to be statistically independent, leaving insufficient capacity for correlated factors, the reconstruction Y has diminished quality. We then train another DGM to improve the low-quality reconstruction Y by learning the missing correlated factors Z. This is achieved during training by feeding the reconstruction Y into the decoder of the second DGM and modulating the hidden activations of each layer as a function of the latent factors Z (using Feature-wise Linear Modulation (Perez et al., 2018)). Through this training paradigm, MS-VAE preserves conditioning on the disentangled factors while dramatically improving the reconstruction quality. A schematic of MS-VAE is shown in Figure 2d and example images from each stage are shown in Figure 1. The reconstruction from β-TCVAE (Figure 1b) is improved by MS-VAE (Figure 1c) to better approximate the target (Figure 1a) while maintaining conditioning on the disentangled factors, e.g., azimuth. To summarize our contributions:

Figure 1: Image reconstruction using β-TCVAE (Figure 1b) and MS-VAE (Figure 1c). MS-VAE is able to take the blurry output of the underlying β-TCVAE model and learn to render a much better approximation of the target while maintaining the pose of the original image (Figure 1a).
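The FiLM conditioning used by the second-stage decoder can be sketched in a few lines of numpy. The weight matrices `W_gamma` and `W_beta` and the identity-at-zero parameterization of the scale are our own illustrative assumptions, not details taken from the paper:

```python
import numpy as np

rng = np.random.default_rng(0)

def film(h, z, W_gamma, W_beta):
    """Feature-wise Linear Modulation (Perez et al., 2018): scale and
    shift the hidden activations h channel-wise as a function of z.
    gamma is parameterized as 1 + W_gamma @ z so that z = 0 leaves the
    activations unchanged (identity modulation)."""
    gamma = 1.0 + W_gamma @ z   # per-channel scale, shape (channels,)
    beta = W_beta @ z           # per-channel shift, shape (channels,)
    return gamma[:, None] * h + beta[:, None]

channels, width, z_dim = 8, 16, 4
h = rng.standard_normal((channels, width))        # activations from decoding Y
W_gamma = 0.1 * rng.standard_normal((channels, z_dim))
W_beta = 0.1 * rng.standard_normal((channels, z_dim))

# With z = 0 the modulation is the identity, so the conditioning on the
# disentangled factors carried by Y passes through untouched.
assert np.allclose(film(h, np.zeros(z_dim), W_gamma, W_beta), h)
```

In this toy parameterization the correlated factors z can only perturb activations that are already driven by the low-quality reconstruction Y, which mirrors the intent of the multi-stage design: z adds detail without overriding the disentangled conditioning.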

