SELF-REFLECTIVE VARIATIONAL AUTOENCODER

Abstract

The Variational Autoencoder (VAE) is a powerful framework for learning probabilistic latent variable generative models. However, typical assumptions on the approximate posterior distribution can substantially restrict its capacity for inference and generative modeling. Variational inference based on neural autoregressive models respects the conditional dependencies of the exact posterior, but this flexibility comes at a cost: the resulting models are expensive to train in high-dimensional regimes and can be slow to produce samples. In this work, we introduce an orthogonal solution, which we call self-reflective inference. By redesigning the hierarchical structure of existing VAE architectures, self-reflection ensures that the stochastic flow preserves the factorization of the exact posterior, sequentially updating the latent codes in a manner consistent with the generative model. We empirically demonstrate the advantages of matching the variational posterior to the exact posterior: on binarized MNIST, self-reflective inference achieves state-of-the-art performance without resorting to complex, computationally expensive components such as autoregressive layers. Moreover, we design a variational normalizing flow that employs the proposed architecture, yielding predictive benefits over its purely generative counterpart. Our proposed modification is quite general and complements the existing literature; self-reflective inference can naturally leverage advances in distribution estimation and generative modeling to improve the capacity of each layer in the hierarchy.

1. INTRODUCTION

The advent of deep learning has led to great strides in both supervised and unsupervised learning. One of the most popular recent frameworks for the latter is the Variational Autoencoder (VAE), in which a probabilistic encoder and generator are jointly trained via backpropagation to simultaneously perform sampling and variational inference. Since the introduction of the VAE (Kingma & Welling, 2014), or more generally, the development of techniques for low-variance stochastic backpropagation of Deep Latent Gaussian Models (DLGMs) (Rezende et al., 2014), research has rapidly progressed towards improving their generative modeling capacity and/or the quality of their variational approximation. However, as deeper and more complex architectures are introduced, care must be taken to ensure the correctness of various modeling assumptions, whether explicit or implicit. In particular, when working with hierarchical models it is easy to unintentionally introduce mismatches between the generative and inference models, to the detriment of both. In this work, we demonstrate the existence of such a modeling pitfall common to much of the recent literature on DLGMs. We discuss why this problem emerges, and we introduce a simple, yet crucial, modification to existing architectures that addresses the issue.

Vanilla VAE architectures make strong assumptions about the posterior distribution: specifically, it is standard to assume that the posterior is approximately factorial. More recent research has investigated the effect of such assumptions governing the variational posterior (Wenzel et al., 2020) or prior (Wilson & Izmailov, 2020) in the context of uncertainty estimation in Bayesian neural networks. In many scenarios, these restrictions have been found to be problematic.
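For concreteness, the factorial (mean-field) assumption referred to above is typically instantiated as a diagonal Gaussian variational posterior, which factorizes across the latent dimensions:

```latex
q_\phi(z \mid x) \;=\; \prod_{i=1}^{d} \mathcal{N}\!\left(z_i \;\middle|\; \mu_i(x),\, \sigma_i^2(x)\right),
```

where $\mu(x)$ and $\sigma^2(x)$ are outputs of the encoder network. The exact posterior $p(z \mid x)$ generally exhibits correlations between latent dimensions, which a distribution of this form cannot capture.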
A large body of recent work attempts to improve performance by building a more complex encoder and/or decoder with convolutional layers and more modern architectures (such as ResNets (He et al., 2016)) (Salimans et al., 2015; Gulrajani et al., 2017), or by employing more complex posterior distributions constructed with autoregressive layers (Kingma et al., 2016; Chen et al., 2017). Other work (Tomczak & Welling, 2018; Klushyn et al., 2019a) focuses on refining the prior distribution of the latent codes. Taking a different approach, hierarchical VAEs (Rezende et al., 2014; Gulrajani et al., 2017; Sønderby et al., 2016; Maaløe et al., 2019; Klushyn et al., 2019b) leverage increasingly deep and interdependent layers of latent variables, similar to how subsequent layers in a discriminative network are believed to learn more and more abstract representations. These architectures exhibit superior generative and reconstructive capabilities since they allow for modeling of much richer latent spaces. While the benefits of incorporating hierarchical latent variables are clear, all existing architectures suffer from a modeling mismatch which results in sub-optimal performance: the variational posterior does not respect the factorization of the exact posterior distribution of the generative model. In earlier works on hierarchical VAEs (Rezende et al., 2014), inference proceeds bottom-up, counter to the top-down generative process. To better match the order of dependence of latent variables to that of the generative model, later works (Sønderby et al., 2016; Bachman, 2016) split inference into two stages: first a deterministic bottom-up pass which performs the necessary precomputation for encoding the evidence, followed by a stochastic top-down pass which incorporates the hierarchical latents to form a closer variational approximation to the exact posterior.
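The two-stage inference scheme described above can be sketched as follows. This is a minimal illustrative sketch, not the architecture of any cited paper: the layer sizes, the tanh nonlinearity, and the helper names (`dense`, `infer`) are all assumptions made for exposition.

```python
# Sketch of two-stage hierarchical inference: a deterministic bottom-up
# pass over the evidence, then a stochastic top-down pass over the latents.
import numpy as np

rng = np.random.default_rng(0)

def dense(in_dim, out_dim):
    """A random affine layer with tanh, standing in for a trained network."""
    W = rng.normal(0.0, 0.1, (out_dim, in_dim))
    b = np.zeros(out_dim)
    return lambda h: np.tanh(W @ h + b)

L, x_dim, h_dim, z_dim = 3, 8, 16, 4

# Bottom-up: deterministic feature extractors producing d_1, ..., d_L.
bottom_up = [dense(x_dim if i == 0 else h_dim, h_dim) for i in range(L)]
# Top-down: map (bottom-up feature, latent from the layer above) to
# the mean and log-variance of a Gaussian over the current latent.
top_down = [dense(h_dim + z_dim, 2 * z_dim) for _ in range(L)]

def infer(x):
    # Stage 1: deterministic bottom-up pass precomputes evidence features.
    d, feats = x, []
    for f in bottom_up:
        d = f(d)
        feats.append(d)
    # Stage 2: stochastic top-down pass samples z_L, ..., z_1 so that each
    # latent is conditioned on the latents above it, mirroring the order
    # of dependence in the top-down generative model.
    z, latents = np.zeros(z_dim), []
    for i in reversed(range(L)):
        params = top_down[i](np.concatenate([feats[i], z]))
        mu, log_var = params[:z_dim], params[z_dim:]
        # Reparameterized sample from q(z_i | z_{i+1}, x).
        z = mu + np.exp(0.5 * log_var) * rng.normal(size=z_dim)
        latents.append(z)
    return latents  # ordered z_L, ..., z_1

latents = infer(rng.normal(size=x_dim))
```

Note that even with this corrected ordering, each conditional here is still a diagonal Gaussian; as discussed next, that restriction is precisely what keeps the overall variational posterior from matching the exact posterior.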
Crucially, while these newer architectures ensure that the order of the latent variables mirrors that of the generative model, the overall variational posterior still does not match the exact posterior because of the strong restrictions placed on the variational distribution of each layer.

Contributions. In this work, we propose to restructure common hierarchical VAE architectures with a series of bijective layers which enable communication between the inference and generative networks, refining the latent representations. Concretely, our contributions are as follows:

• We motivate and introduce a straightforward rearrangement of the stochastic flow of the model which addresses the aforementioned modeling mismatch. This modification substantially closes the observed performance gap between models with only simple layers and those with complex autoregressive networks (Kingma et al., 2016; Chen et al., 2017).

• We formally prove that this refinement results in a hierarchical VAE whose variational posterior respects the precise factorization of the exact posterior. To the best of our knowledge, this is the first deep architecture to do so without resorting to computationally expensive autoregressive components or making strong assumptions (e.g., diagonal Gaussian) on the distribution of each layer (Sønderby et al., 2016), assumptions that lead to degraded performance.

• We experimentally demonstrate the benefits of the improved representation capacity of this model, which stems from the corrected factorial form of the posterior. We achieve state-of-the-art performance on MNIST among models without autoregressive layers, and our model performs on par with recent, fully autoregressive models such as Kingma et al. (2016). Due to the simplicity of our architecture, we achieve these results at a fraction of the computational cost in both training and inference.

• We design a hierarchical variational normalizing flow that deploys the suggested architecture in order to recursively update the base distribution and the conditional bijective transformations. This architecture significantly improves upon the predictive performance and data complexity of a Masked Autoregressive Flow (MAF) (Papamakarios et al., 2017) on CIFAR-10.

Finally, it should be noted that our contribution is quite general and can naturally leverage recent advances in variational inference and deep autoencoders (Chen et al., 2017; Kingma et al., 2016; Tomczak & Welling, 2018; Burda et al., 2016; Dai & Wipf, 2019; van den Oord et al., 2016a; Rezende & Viola, 2018), as well as architectural improvements to density estimation (Gulrajani et al., 2017; Dinh et al., 2017; Kingma & Dhariwal, 2018; Durkan et al., 2019; van den Oord et al., 2016b; Gregor et al., 2015). We suspect that combining our model with other state-of-the-art methods could further improve the attained performance, which we leave to future work.

2. VARIATIONAL AUTOENCODERS

A Variational Autoencoder (VAE) (Kingma & Welling, 2014; 2019) is a generative model capable of generating samples x ∈ R^D from a distribution of interest p(x) by utilizing latent variables z drawn from a prior distribution p(z). To perform inference, the marginal likelihood

