DAVA: DISENTANGLING ADVERSARIAL VARIATIONAL AUTOENCODER

Abstract

The use of well-disentangled representations offers many advantages for downstream tasks, e.g. increased sample efficiency or better interpretability. However, the quality of disentangled representations is often highly dependent on the choice of dataset-specific hyperparameters, in particular the regularization strength. To address this issue, we introduce DAVA, a novel training procedure for variational auto-encoders that alleviates the problem of hyperparameter selection entirely. We compare DAVA to models trained with optimal hyperparameters: without any hyperparameter tuning, DAVA is competitive on a diverse range of commonly used datasets. Underlying DAVA, we identify a necessary condition for unsupervised disentanglement, which we call PIPE. We demonstrate the ability of PIPE to positively predict the performance of downstream models in abstract reasoning, and we thoroughly investigate its correlations with existing supervised and unsupervised metrics. The code is available at github.com/besterma/dava.

1. INTRODUCTION

Real-world data tends to be highly structured, full of symmetries and patterns. This implies that there exists a lower-dimensional set of ground-truth factors that can explain a significant portion of the variation present in real-world data. The goal of disentanglement learning is to recover these factors, so that changes in a single ground-truth factor are reflected only in a single latent dimension of a model (see Figure 1 for an example). Such an abstraction allows for more efficient reasoning (Van Steenkiste et al., 2019) and improved interpretability (Higgins et al., 2017a). It further shows positive effects on zero-shot domain adaptation (Higgins et al., 2017b) and data efficiency (Duan et al., 2020; Schott et al., 2022).

If the generative ground-truth factors are known and labeled data is available, one can train a model in a supervised manner to extract them. But what if the generative factors are unknown and one still wants to profit from the aforementioned benefits for a downstream task? This may be the case when the amount of labeled data for the downstream task is limited or training is computationally expensive. Learning disentangled representations in an unsupervised fashion is generally impossible without the use of some priors (Locatello et al., 2019b). These priors can be present both implicitly in the model architecture and explicitly in the loss function (Tschannen et al., 2018). An example of such a prior in the loss function is a low total correlation between the latent variables of a model (Chen et al., 2018; Kim & Mnih, 2018). Reducing the total correlation has been shown to have a positive effect on disentanglement (Locatello et al., 2019b). Unfortunately, as we show in more detail in this work, how much the total correlation should be reduced to achieve good disentanglement is highly dataset-specific: the optimal hyperparameter setting for one dataset may yield poor results on another.
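To make concrete where such a regularization strength enters a VAE objective, the following is a minimal NumPy sketch of a β-VAE-style loss with a diagonal-Gaussian posterior and a unit-Gaussian prior. The function and variable names are ours, not from the paper; β (here `beta`) is exactly the kind of dataset-specific hyperparameter discussed above.

```python
import numpy as np

def beta_vae_loss(x, x_recon, mu, log_var, beta):
    """beta-VAE-style objective: reconstruction error plus a
    beta-weighted KL divergence from the diagonal-Gaussian posterior
    N(mu, exp(log_var)) to the unit-Gaussian prior N(0, I)."""
    # Mean over the batch of the per-sample squared reconstruction error.
    recon = np.mean(np.sum((x - x_recon) ** 2, axis=1))
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over latent dims.
    kl = np.mean(0.5 * np.sum(mu**2 + np.exp(log_var) - log_var - 1.0, axis=1))
    return recon + beta * kl
```

With a perfect reconstruction and a posterior equal to the prior, the loss is zero; increasing `beta` trades reconstruction quality for a posterior closer to the prior, which is the regularization knob whose optimal value varies across datasets.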
To optimize the regularization strength, we need a way to evaluate disentanglement quality. So how can we identify well-disentangled representations? Evaluating representation quality, even given labeled data, is no easy task. Perhaps as an example of unfortunate nomenclature, the often-used term "ground-truth factor" implies the existence of a canonical set of orthogonal factors. However, there are often multiple equally valid sets of ground-truth factors, such as affine transformations of the coordinate axes spanning a space, different color representations, or various levels of abstraction for group properties. This poses a problem for supervised disentanglement metrics: since they fix the ground-truth factors used to evaluate a representation, they judge models too harshly if the models have learned another, equally valid representation. Furthermore, acquiring labeled data in a practical setting is usually a costly endeavor. These reasons hinder the usability of supervised metrics for model selection.

In this work, we overcome these limitations for both learning and evaluating disentangled representations. Our improvements are based on the following idea: we define two distributions that can be generated by a VAE. Quantifying the distance between these two distributions yields a disentanglement metric that is independent of both the specific choice of ground-truth factors and the reconstruction quality. The further apart these two distributions are, the less disentangled the VAE's latent space is; we show that the similarity of the two distributions is a necessary condition for disentanglement. Furthermore, we can exploit this property at training time by introducing an adversarial loss into the classical training of VAEs. To do this, we add a discriminator network to the training procedure and use the VAE's decoder as the generator. During training, we control the weight of the adversarial loss.
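This excerpt does not specify the two distributions or the discriminator architecture, so the following NumPy sketch illustrates only the generic adversarial mechanism: a (hypothetical, stand-in) logistic discriminator is trained to tell two batches apart, and the generator-side adversarial loss rewards making them indistinguishable. All names and the logistic-regression simplification are ours, not DAVA's actual implementation.

```python
import numpy as np

def sigmoid(z):
    """Logistic function mapping a logit to a probability."""
    return 1.0 / (1.0 + np.exp(-z))

def discriminator_step(w, b, batch_a, batch_b, lr=0.1):
    """One gradient step training a logistic discriminator to assign
    label 1 to batch_a and label 0 to batch_b (binary cross-entropy)."""
    x = np.vstack([batch_a, batch_b])
    y = np.concatenate([np.ones(len(batch_a)), np.zeros(len(batch_b))])
    p = sigmoid(x @ w + b)
    grad_logit = (p - y) / len(y)  # d(BCE)/d(logit) averaged over the batch
    w = w - lr * (x.T @ grad_logit)
    b = b - lr * grad_logit.sum()
    return w, b

def adversarial_loss(w, b, batch_b):
    """Generator-side loss: how confidently the discriminator rejects
    batch_b. Driving this down pushes the two distributions together."""
    p = sigmoid(batch_b @ w + b)
    return -np.mean(np.log(p + 1e-8))
```

In an actual VAE setting, `batch_a` and `batch_b` would be images produced by the decoder under the two distributions, the discriminator would be a neural network, and `adversarial_loss` would be backpropagated through the decoder with a controlled weight.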
We adjust the capacity of the latent-space information bottleneck accordingly, inspired by Burgess et al. (2017). In this way, we allow the model to increase the complexity of its representation as long as it is able to disentangle. This dynamic training procedure solves the problem of dataset-specific hyperparameters and allows our approach to reach competitive disentanglement on a variety of commonly used datasets without hyperparameter tuning. Our contributions are as follows:

• We identify a novel unsupervised aspect of disentanglement called PIPE and demonstrate its usefulness in a metric that correlates with supervised disentanglement metrics as well as with a downstream task.

• We propose an adaptive adversarial training procedure (DAVA) for variational auto-encoders, which solves the common problem that disentanglement performance is highly dependent on dataset-specific regularization strength.



Figure 1: Latent traversals of a single latent dimension (hair fringes) of DAVA trained on CelebA. DAVA visibly disentangles the fringes from all other facial properties.

• We provide extensive evaluations on several commonly used disentanglement datasets to support our claims.

2. RELATED WORK

The β-VAE by Higgins et al. (2017a) is a cornerstone model architecture for disentanglement learning. Its loss function, the evidence lower bound (ELBO), consists of a reconstruction term and a KL-divergence term weighted by β, which forces the aggregated posterior latent distribution to closely match the prior distribution. The KL-divergence term appears to promote disentanglement, as shown in (Rolinek et al., 2019). The β-TCVAE architecture proposed by Chen et al. (2018) further decomposes the KL-divergence term of the ELBO into an index-code mutual information term, a total correlation term, and a dimension-wise KL term. The authors show that it is indeed the total correlation that encourages disentanglement and propose a tractable but biased Monte Carlo estimate of it. Similarly, the FactorVAE architecture (Kim & Mnih, 2018) uses the density-ratio trick with an adversarial network to estimate the total correlation. The AnnealedVAE architecture (Burgess et al., 2017) as well

