MULTI-RATE VAE: TRAIN ONCE, GET THE FULL RATE-DISTORTION CURVE

Abstract

Variational autoencoders (VAEs) are powerful tools for learning latent representations of data and are used in a wide range of applications. In practice, VAEs usually require multiple training rounds to choose the amount of information the latent variable should retain. This trade-off between the reconstruction error (distortion) and the KL divergence (rate) is typically parameterized by a hyperparameter β. In this paper, we introduce Multi-Rate VAE (MR-VAE), a computationally efficient framework for learning the optimal parameters corresponding to a range of β values in a single training run. The key idea is to explicitly formulate a response function, using hypernetworks, that maps β to the optimal parameters. MR-VAEs construct a compact response hypernetwork in which the pre-activations are conditionally gated based on β. We justify the proposed architecture by analyzing linear VAEs and showing that it can represent their response functions exactly. With the learned hypernetwork, MR-VAEs can construct the rate-distortion curve without additional training and can be deployed with significantly less hyperparameter tuning. Empirically, our approach is competitive with, and often exceeds, the performance of training multiple β-VAEs, with minimal computation and memory overhead.

1. INTRODUCTION

Deep latent variable models sample latent factors from a prior distribution and convert them to realistic data points using neural networks. However, computing the model parameters via maximum likelihood estimation is challenging because it requires marginalizing over the latent factors, which is intractable. Variational Autoencoders (VAEs) (Kingma & Welling, 2013; Rezende et al., 2014) formulate a tractable lower bound for the log-likelihood and enable optimization of deep latent variable models by reparameterization of the Evidence Lower Bound (ELBO) (Jordan et al., 1999). VAEs have been applied in many different contexts, including text generation (Bowman et al., 2015), data augmentation (Norouzi et al., 2020), anomaly detection (An & Cho, 2015; Park et al., 2022), future frame prediction (Castrejon et al., 2019), image segmentation (Kohl et al., 2018), and music generation (Roberts et al., 2018).

In practice, VAEs are typically trained with the β-VAE objective (Higgins et al., 2016), which balances the reconstruction error (distortion) and the KL divergence term (rate):

$$
\mathcal{L}_\beta(\phi, \theta) = \underbrace{\mathbb{E}_{p_d(x)}\big[\mathbb{E}_{q_\phi(z \mid x)}[-\log p_\theta(x \mid z)]\big]}_{\text{Distortion}} + \beta \, \underbrace{\mathbb{E}_{p_d(x)}\big[D_{\mathrm{KL}}(q_\phi(z \mid x), p(z))\big]}_{\text{Rate}},
$$

where $p_\theta(x \mid z)$ models the process that generates the data $x$ given the latent variable $z$ (the "decoder") and $q_\phi(z \mid x)$ is the variational distribution (the "encoder"), parameterized by $\theta \in \mathbb{R}^m$ and $\phi \in \mathbb{R}^p$, respectively. Here, $p(z)$ is a prior on the latent variables, $p_d(x)$ is the data distribution, and $\beta > 0$ is the weight on the KL term that trades off between rate and distortion. On the one hand, models with low distortion can reconstruct data points with high quality but may generate unrealistic data points due to large discrepancies between the variational distributions and the prior (Alemi et al., 2018). On the other hand, models with low rate have variational distributions close to the prior but may not have encoded enough useful information to reconstruct the data.
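To make the two terms of the β-VAE objective concrete, the following is a minimal sketch of the loss for a single data point. It rests on assumptions that are not stated in the text: a diagonal Gaussian encoder, a standard normal prior (so the KL term has a closed form), a unit-variance Gaussian decoder, and a single-sample Monte Carlo estimate of the inner expectation. The function names `gaussian_kl`, `gaussian_nll`, and `beta_vae_loss` are hypothetical and chosen for illustration only.

```python
import math


def gaussian_kl(mu, log_var):
    """Closed-form KL between N(mu, diag(exp(log_var))) and N(0, I), in nats.

    Assumption: diagonal Gaussian encoder and standard normal prior.
    """
    return 0.5 * sum(m * m + math.exp(lv) - lv - 1.0
                     for m, lv in zip(mu, log_var))


def gaussian_nll(x, x_hat):
    """Negative log-likelihood of x under N(x_hat, I), in nats.

    Assumption: unit-variance Gaussian decoder whose mean is x_hat.
    """
    sq_err = sum((a - b) ** 2 for a, b in zip(x, x_hat))
    return 0.5 * sq_err + 0.5 * len(x) * math.log(2.0 * math.pi)


def beta_vae_loss(x, x_hat, mu, log_var, beta):
    """Single-point, single-sample estimate of distortion + beta * rate.

    x_hat is the decoder mean for one latent sample z drawn from the
    encoder, so the distortion is a one-sample Monte Carlo estimate.
    Returns (total, distortion, rate).
    """
    distortion = gaussian_nll(x, x_hat)
    rate = gaussian_kl(mu, log_var)
    return distortion + beta * rate, distortion, rate


if __name__ == "__main__":
    x, x_hat = [1.0, 0.0], [0.5, 0.0]
    mu, log_var = [1.0, -1.0], [0.0, 0.0]
    # The same encoder/decoder outputs are penalized differently per beta.
    for beta in (0.1, 1.0, 10.0):
        total, distortion, rate = beta_vae_loss(x, x_hat, mu, log_var, beta)
        print(f"beta={beta}: distortion={distortion:.4f}, "
              f"rate={rate:.4f}, total={total:.4f}")
```

In this sketch the distortion and rate are fixed by the encoder and decoder outputs; β only changes how heavily the rate is penalized in the total, which is why a different β generally leads to different optimal parameters.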
* Correspondence to jbae@cs.toronto.edu.

Hence,

