IMPROVING THE RECONSTRUCTION OF DISENTANGLED REPRESENTATION LEARNERS VIA MULTI-STAGE MODELLING

Anonymous

Abstract

Current autoencoder-based disentangled representation learning methods achieve disentanglement by penalizing the (aggregate) posterior to encourage statistical independence of the latent factors. This approach introduces a trade-off between disentangled representation learning and reconstruction quality, since the model does not have enough capacity to learn the correlated latent variables that capture the detail information present in most image data. To overcome this trade-off, we present a novel multi-stage modelling approach in which the disentangled factors are first learned using a pre-existing disentangled representation learning method (such as β-TCVAE); the low-quality reconstruction is then improved with another deep generative model trained to model the missing correlated latent variables, adding detail information while remaining conditioned on the previously learned disentangled factors. Taken together, our multi-stage modelling approach yields a single, coherent probabilistic model that is theoretically justified by the principle of d-separation and can be realized with a variety of model classes, including likelihood-based models such as variational autoencoders, implicit models such as generative adversarial networks, and tractable models such as normalizing flows or mixtures of Gaussians. We demonstrate that our multi-stage model achieves much higher reconstruction quality than current state-of-the-art methods with equivalent disentanglement performance across multiple standard benchmarks.

1. INTRODUCTION

Deep generative models (DGMs) such as variational autoencoders (VAEs) (Kingma & Welling, 2014; Rezende et al., 2014) and generative adversarial networks (GANs) (Goodfellow et al., 2014) have recently enjoyed great success at modelling high-dimensional data such as natural images. As the name suggests, DGMs leverage deep learning to model a data-generating process. The underlying assumption is that the high-dimensional observations X ∈ R^D can be meaningfully described by a small set of latent factors H ∈ R^K, where K < D. More precisely, an observation (X = x) is assumed to be generated by first sampling a set of low-dimensional factors h from a simple prior distribution p(H) and then sampling x ∼ p_θ(X|h). DGMs realize p_θ through a deep neural network, also known as the decoder or generative network. VAE-based DGMs use another deep neural network (called the encoder or inference network) to parameterize an approximate posterior q_φ(H|x). The encoder and decoder parameters are learned by maximizing an evidence lower bound (ELBO) on the log-marginal likelihood of the data under the model, p_θ(X). Learning disentangled factors h ∼ q_φ(H|x) that are semantically meaningful representations of the observation x is highly desirable because such interpretable representations can arguably be advantageous for a variety of downstream tasks (Locatello et al., 2018), including classification, detection, reinforcement learning, transfer learning, and image synthesis from textual descriptions (Bengio et al., 2013; LeCun et al., 2015; Lake et al., 2017; van Steenkiste et al., 2019; Reed et al., 2016; Zhang et al., 2016). While a formal definition of disentangled representation (DR) remains elusive, it is commonly understood to mean that manipulating only one factor while holding the rest constant changes only one semantically meaningful aspect of the observation (e.g. the pose of an object in an image).
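To make the generative process and training objective above concrete, the following is a minimal numpy sketch of a single-sample Monte Carlo estimate of the ELBO for a diagonal-Gaussian VAE. The linear maps `W_enc` and `W_dec` are hypothetical stand-ins for the deep encoder and decoder networks (a real model would learn them by gradient ascent on the ELBO); they are not part of the paper's method.

```python
import numpy as np

rng = np.random.default_rng(0)

D, K = 8, 2  # observation dimension D and latent dimension K, with K < D

# Hypothetical linear encoder/decoder weights standing in for deep networks.
W_enc = rng.normal(size=(K, D)) * 0.1
W_dec = rng.normal(size=(D, K)) * 0.1

def elbo(x):
    """Single-sample Monte Carlo ELBO: E_q[log p_theta(x|h)] - KL(q_phi(H|x) || p(H))."""
    # Encoder: q_phi(H|x) = N(mu, diag(sigma^2)); unit variance kept fixed for simplicity.
    mu = W_enc @ x
    log_sigma = np.zeros(K)
    sigma = np.exp(log_sigma)

    # Reparameterization trick: h = mu + sigma * eps with eps ~ N(0, I).
    eps = rng.normal(size=K)
    h = mu + sigma * eps

    # Decoder: p_theta(X|h) = N(W_dec h, I); Gaussian log-likelihood of x.
    x_hat = W_dec @ h
    log_px_h = -0.5 * np.sum((x - x_hat) ** 2) - 0.5 * D * np.log(2 * np.pi)

    # Closed-form KL divergence between q_phi(H|x) and the standard normal prior p(H).
    kl = 0.5 * np.sum(mu ** 2 + sigma ** 2 - 1.0 - 2.0 * log_sigma)

    return log_px_h - kl

x = rng.normal(size=D)
print(elbo(x))  # a finite scalar lower-bounding log p_theta(x)
```

Maximizing this bound with respect to the encoder and decoder parameters jointly trains both networks; disentanglement methods such as β-TCVAE modify the KL term to further penalize statistical dependence among the latent factors.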
Prior work in unsupervised DR learning focuses on the objective of learning statistically independent latent factors as a means of obtaining DR. The underlying assumption is

