MULTI-RATE VAE: TRAIN ONCE, GET THE FULL RATE-DISTORTION CURVE

Abstract

Variational autoencoders (VAEs) are powerful tools for learning latent representations of data, used in a wide range of applications. In practice, VAEs usually require multiple training rounds to choose the amount of information the latent variable should retain. This trade-off between the reconstruction error (distortion) and the KL divergence (rate) is typically parameterized by a hyperparameter β. In this paper, we introduce Multi-Rate VAE (MR-VAE), a computationally efficient framework for learning optimal parameters corresponding to various β in a single training run. The key idea is to explicitly formulate a response function that maps β to the optimal parameters, using hypernetworks. MR-VAEs construct a compact response hypernetwork where the pre-activations are conditionally gated based on β. We justify the proposed architecture by analyzing linear VAEs and showing that it can represent response functions exactly in this setting. With the learned hypernetwork, MR-VAEs can construct the rate-distortion curve without additional training and can be deployed with significantly less hyperparameter tuning. Empirically, our approach is competitive with, and often exceeds, the performance of training multiple β-VAEs, with minimal computation and memory overhead.

1. INTRODUCTION

Deep latent variable models sample latent factors from a prior distribution and convert them to realistic data points using neural networks. However, computing the model parameters via maximum likelihood estimation is challenging because it requires marginalizing over the latent factors, which is intractable. Variational Autoencoders (VAEs) (Kingma & Welling, 2013; Rezende et al., 2014) formulate a tractable lower bound on the log-likelihood and enable optimization of deep latent variable models by reparameterization of the Evidence Lower Bound (ELBO) (Jordan et al., 1999). VAEs have been applied in many different contexts, including text generation (Bowman et al., 2015), data augmentation (Norouzi et al., 2020), anomaly detection (An & Cho, 2015; Park et al., 2022), future frame prediction (Castrejon et al., 2019), image segmentation (Kohl et al., 2018), and music generation (Roberts et al., 2018). In practice, VAEs are typically trained with the β-VAE objective (Higgins et al., 2016), which balances the reconstruction error (distortion) and the KL divergence term (rate):

L_β(ϕ, θ) = E_{p_d(x)}[ E_{q_ϕ(z|x)}[ −log p_θ(x|z) ] ] + β E_{p_d(x)}[ D_KL(q_ϕ(z|x) ∥ p(z)) ],    (1)

where the first term is the distortion and the second is the rate. Here, p_θ(x|z) models the process that generates the data x given the latent variable z (the "decoder") and q_ϕ(z|x) is the variational distribution (the "encoder"), parameterized by θ ∈ R^m and ϕ ∈ R^p, respectively; p(z) is a prior on the latent variables, p_d(x) is the data distribution, and β > 0 is the weight on the KL term that trades off rate against distortion. On the one hand, models with low distortion can reconstruct data points with high quality but may generate unrealistic data points due to large discrepancies between the variational distributions and the prior (Alemi et al., 2018). On the other hand, models with low rates have variational distributions close to the prior but may not have encoded enough useful information to reconstruct the data.
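The objective above can be sketched numerically. The following is a minimal illustration (not the paper's implementation), assuming a diagonal-Gaussian encoder with a standard-normal prior, so the rate has the familiar closed form, and a fixed-variance Gaussian decoder, so the distortion reduces to squared reconstruction error up to constants:

```python
import numpy as np

def beta_vae_loss(x, x_recon, mu, log_var, beta):
    """Sketch of the beta-VAE objective: distortion + beta * rate.

    Assumptions (illustrative, not from the paper): q(z|x) = N(mu,
    diag(exp(log_var))), prior p(z) = N(0, I), and a unit-variance
    Gaussian decoder so -log p(x|z) ~ squared error.
    """
    # Distortion: expected negative log-likelihood ~ squared error.
    distortion = np.mean(np.sum((x - x_recon) ** 2, axis=1))
    # Rate: closed-form KL(N(mu, sigma^2) || N(0, I)), averaged over the batch.
    rate = np.mean(0.5 * np.sum(mu ** 2 + np.exp(log_var) - log_var - 1.0, axis=1))
    return distortion + beta * rate, distortion, rate
```

Sweeping β up-weights the rate term, pushing the posterior toward the prior at the cost of reconstruction quality, which is exactly the trade-off discussed above.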
To avoid retraining for each β, we first formulate response functions ϕ⋆(β) and θ⋆(β) (Gibbons et al., 1992), which map the KL weight β to the optimal encoder and decoder parameters trained with that β. Next, we explicitly construct response functions ϕ_ψ(β) and θ_ψ(β) using hypernetworks (Ha et al., 2016), where ψ ∈ R^h denotes the hypernetwork parameters. Unlike the original VAE framework, which requires retraining the network to find optimal parameters for a particular β, response hypernetworks directly learn this mapping and do not require further retraining. While there is considerable freedom in designing the response hypernetwork, we propose a parameterization that is memory- and cost-efficient yet flexible enough to represent the optimal parameters over a wide range of KL weights. Specifically, in each layer of a VAE, the MR-VAE architecture applies an affine transformation to log β and uses the result to scale the pre-activation. We justify the proposed architecture by analyzing linear VAEs and showing that it can represent the response functions exactly on this simplified model. We further propose a modified objective analogous to Self-Tuning Networks (MacKay et al., 2019; Bae & Grosse, 2020) to optimize response hypernetworks instead of the standard encoder and decoder parameters. Empirically, we trained MR-VAEs to learn rate-distortion curves for image and text reconstruction tasks over a wide range of architectures. Across all tasks and architectures, MR-VAEs found competitive or even improved rate-distortion curves compared to the baseline of retraining the network multiple times with different KL weights. We show a comparison between β-VAE (with and without KL annealing (Bowman et al., 2015)) and MR-VAE with ResNet-based encoders and decoders (He et al., 2016) in Figure 1. MR-VAEs can learn multiple optimal parameters corresponding to various KL weights in a single training run.
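The per-layer gating can be sketched as follows. This is an illustrative reading of "an affine transformation of log β scales the pre-activation": the names W, b, u, v, the elementwise gate, and the ReLU nonlinearity are assumptions, not the paper's exact parameterization.

```python
import numpy as np

def mr_vae_layer(h, W, b, u, v, beta):
    """Sketch of one MR-VAE layer: the pre-activation W @ h + b is scaled
    elementwise by an affine function of log(beta) before the nonlinearity.

    W, b are the ordinary layer weights; u, v are the (assumed) per-unit
    hypernetwork parameters conditioning the layer on beta.
    """
    pre = W @ h + b
    gate = u * np.log(beta) + v  # affine transformation of log beta
    return np.maximum(0.0, gate * pre)  # ReLU assumed for illustration
```

Because only the small vectors u and v are added per layer, evaluating the network at a new β is nearly free, which is what makes tracing the whole rate-distortion curve from one trained model practical.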
Moreover, MR-VAEs do not require KL weight schedules and can be deployed without significant hyperparameter tuning. Our framework is general and can be extended to various existing VAE models. We demonstrate this flexibility by applying MR-VAEs to β-TCVAEs (Chen et al., 2018), where we trade off the reconstruction error against the total correlation instead of the rate. We show that MR-VAEs can be used to evaluate the disentanglement quality over a wide range of β values without having to train β-TCVAEs multiple times.

2. BACKGROUND

Figure 1: VAEs require multiple training runs with different KL weights β to visualize parts of the rate-distortion curve (Pareto frontier). Our proposed Multi-Rate VAEs (MR-VAEs) can learn the full continuous rate-distortion curve in a single run with small memory and computational overhead.

2.1 (β-)VARIATIONAL AUTOENCODERS

Variational Autoencoders jointly optimize encoder parameters ϕ and decoder parameters θ to minimize the β-VAE objective defined in Eqn. 1. While the standard ELBO sets the KL weight to 1, the KL weight β plays an important role in training VAEs and requires careful tuning for various applications (Kohl et al., 2018; Castrejon et al., 2019; Pong et al., 2019). The KL weight also has a close connection to disentanglement quality (Higgins et al., 2016; Burgess et al., 2018; Nakagawa et al., 2021), generalization ability (Kumar & Poole, 2020; Bozkurt et al., 2021), data compression (Zhou et al., 2018; Huang et al., 2020), and posterior collapse (Lucas et al., 2019; Dai et al., 2020; Wang & Ziyin, 2022).

By training multiple VAEs with different values of β, we can obtain different points on a rate-distortion curve (Pareto frontier) from information theory (Alemi et al., 2018). Unfortunately, as rate-distortion curves depend on both the dataset and the architecture, practitioners generally need to tune β for each individual task. In this work, we introduce a modified VAE framework that does not require hyperparameter tuning of β and can learn multiple VAEs with different rates in a single training run. Hence, we call our approach Multi-Rate VAE (MR-VAE).
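The baseline procedure described above can be sketched in a few lines. Here `train_vae` is a hypothetical routine (not from the paper) that trains a VAE to convergence at a given β and returns its final (rate, distortion); the log-spaced grid is a common but assumed choice for sweeping β:

```python
import numpy as np

def rd_curve_by_retraining(train_vae, betas):
    """Baseline: one full training run per KL weight, each yielding a
    single (beta, rate, distortion) point on the rate-distortion curve.
    `train_vae(beta)` is a hypothetical routine returning (rate, distortion).
    """
    return [(float(b),) + tuple(train_vae(b)) for b in betas]

# Log-spaced grid of KL weights, a common choice when sweeping beta.
betas = np.logspace(-2, 2, num=9)
```

The cost of this loop scales linearly with the number of grid points, which is exactly the overhead MR-VAE removes by learning the response function in a single run.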

