SIMPLE AND EFFECTIVE VAE TRAINING WITH CALIBRATED DECODERS

Abstract

Variational autoencoders (VAEs) provide an effective and simple method for modeling complex distributions. However, training VAEs often requires considerable hyperparameter tuning to determine the optimal amount of information retained by the latent variable. We study the impact of calibrated decoders, which learn the uncertainty of the decoding distribution and can determine this amount of information automatically, on VAE performance. While many methods for learning calibrated decoders have been proposed, many of the recent papers that employ VAEs rely on heuristic hyperparameters and ad-hoc modifications instead. We perform the first comprehensive comparative analysis of calibrated decoders and provide recommendations for simple and effective VAE training. Our analysis covers a range of datasets and several single-image and sequential VAE models. We further propose a simple but novel modification to the commonly used Gaussian decoder, which computes the prediction variance analytically. We observe empirically that heuristic modifications are not necessary with our method.

1. INTRODUCTION

Deep density models based on the variational autoencoder (VAE) (Kingma & Welling, 2014; Rezende et al., 2014) have found ubiquitous use in probabilistic modeling and representation learning, as they are both conceptually simple and able to scale to very complex distributions and large datasets. These VAE techniques are used for tasks such as future frame prediction (Castrejon et al., 2019), image segmentation (Kohl et al., 2018), generating speech (Chung et al., 2015) and music (Dhariwal et al., 2020), as well as model-based reinforcement learning (Hafner et al., 2019a). However, in practice, many of these approaches require careful manual tuning of the balance between two terms that correspond to distortion and rate from information theory (Alemi et al., 2017). This balance trades off fidelity of reconstruction and quality of samples from the model: a model with low rate would not contain enough information to reconstruct the data, while allowing the model to have high rate might lead to unrealistic samples from the prior as the KL-divergence constraint becomes weaker (Alemi et al., 2017; Higgins et al., 2017). While a proper variational lower bound does not expose any free parameters to control this tradeoff, many prior works heuristically introduce a weight on the prior KL-divergence term, often denoted β. Usually, β needs to be tuned for every dataset and model variant as a hyperparameter, which slows down development and can lead to poor performance, as finding the optimal value is often prohibitively computationally expensive. Moreover, using β ≠ 1 precludes the appealing interpretation of the VAE objective as a bound on the data likelihood, and is undesirable for applications like density modeling.
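To make the tradeoff concrete, the β-weighted objective discussed above can be sketched as a negative ELBO with a distortion term and a β-scaled rate term. The sketch below assumes a Gaussian decoder with fixed unit variance and a diagonal-Gaussian posterior; the function name and numpy formulation are illustrative, not from the paper.

```python
import numpy as np

def beta_vae_loss(x, x_mean, z_mu, z_logvar, beta=1.0):
    """Negative beta-weighted ELBO (up to additive constants).

    distortion: reconstruction NLL of a unit-variance Gaussian decoder,
    which reduces to half the squared error.
    rate: KL( N(z_mu, exp(z_logvar)) || N(0, I) ) in closed form.
    beta = 1 recovers a proper bound on the data likelihood;
    beta != 1 is the heuristic reweighting described in the text.
    """
    distortion = 0.5 * np.sum((x - x_mean) ** 2)
    rate = 0.5 * np.sum(np.exp(z_logvar) + z_mu ** 2 - 1.0 - z_logvar)
    return distortion + beta * rate
```

Because the decoder variance is fixed rather than learned, β is the only knob controlling how much information the latent retains, which is exactly why it must be re-tuned per dataset.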
While many architectures for calibrating decoders have been proposed in the literature (Kingma & Welling, 2014; Kingma et al., 2016; Dai & Wipf, 2019), more applied work typically employs VAEs with uncalibrated decoding distributions, such as Gaussian distributions without a learned variance, where the decoder only outputs the mean parameter (Castrejon et al., 2019; Denton & Fergus, 2018; Lee et al., 2019; Babaeizadeh et al., 2018; Lee et al., 2018; Hafner et al., 2019b; Pong et al., 2019; Zhu et al., 2017; Pavlakos et al., 2019), or uses other ad-hoc modifications to the objective (Sohn et al., 2015; Henaff et al., 2019). Indeed, it is well known that attempting to learn the variance in a Gaussian decoder may lead to numerical instability (Rezende & Viola, 2018; Dai & Wipf, 2019), and naïve approaches often lead to poor results. As a result, it remains unclear whether practical empirical performance of VAEs actually benefits from calibrated decoders or not.
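As a contrast to learning the decoder variance by gradient descent, the analytic approach mentioned in the abstract can be sketched as follows: for a Gaussian with a given mean, the maximum-likelihood variance is simply the mean squared error, so a shared σ² can be set in closed form. This is a minimal illustrative sketch of that idea, assuming a single shared variance across dimensions; the function name is hypothetical.

```python
import numpy as np

def calibrated_gaussian_nll(x, x_mean):
    """Gaussian decoder NLL with the variance computed analytically.

    sigma2 is the MLE of a shared variance given the predicted mean
    (the mean squared error), so no variance parameter needs to be
    learned or tuned. Constant terms in 2*pi are dropped.
    Illustrative sketch, not the paper's exact implementation.
    """
    d = x.size
    sigma2 = np.mean((x - x_mean) ** 2)
    return 0.5 * d * np.log(sigma2) + 0.5 * np.sum((x - x_mean) ** 2) / sigma2
```

Note that with the MLE variance plugged in, the second term is always d/2, so the effective reconstruction loss becomes proportional to the log of the squared error, which automatically rebalances distortion against the KL term without a hand-tuned β.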

