DYNAMICVAE: DECOUPLING RECONSTRUCTION ERROR AND DISENTANGLED REPRESENTATION LEARNING

Abstract

This paper challenges the common assumption that the weight β in β-VAE should be larger than 1 in order to effectively disentangle latent factors. We demonstrate that β-VAE with β < 1 can not only attain good disentanglement but also significantly improve reconstruction accuracy via dynamic control. The paper removes the inherent trade-off between reconstruction accuracy and disentanglement for β-VAE. Existing methods, such as β-VAE and FactorVAE, assign a large weight to the KL-divergence term in the objective function, leading to high reconstruction errors for the sake of better disentanglement. To mitigate this problem, ControlVAE has recently been developed to dynamically tune the KL-divergence weight in an attempt to control the trade-off to a more favorable point. However, ControlVAE fails to eliminate the conflict between the need for a large β (for disentanglement) and the need for a small β (for smaller reconstruction error). Instead, we propose DynamicVAE, which maintains a different β at different stages of training, thereby decoupling disentanglement and reconstruction accuracy. In order to evolve the weight β along a trajectory that enables such decoupling, DynamicVAE leverages a modified incremental PI (proportional-integral) controller, a variant of the proportional-integral-derivative (PID) control algorithm, and employs a moving average as well as a hybrid annealing method to evolve the value of the KL-divergence smoothly in a tightly controlled fashion. We theoretically prove the stability of the proposed approach. Evaluation results on three benchmark datasets demonstrate that DynamicVAE significantly improves reconstruction accuracy while achieving disentanglement comparable to the best of existing methods. The results verify that our method can separate disentangled representation learning and reconstruction, removing the inherent tension between the two.

1. INTRODUCTION

The goal of disentangled representation learning is to encode input data into a low-dimensional space that preserves information about the salient factors of variation, so that each dimension of the representation corresponds to a distinct factor in the data (Bengio et al., 2013; Locatello et al., 2020; van Steenkiste et al., 2019). Learning disentangled representations benefits a variety of downstream tasks (Higgins et al., 2018; Lake et al., 2017; Locatello et al., 2019c;a; Denton et al., 2017; Mathieu et al., 2019), including abstract visual reasoning (van Steenkiste et al., 2019), zero-shot transfer learning (Burgess et al., 2018; Lake et al., 2017; Higgins et al., 2017a), and image generation (Nie et al., 2020), to name a few. Owing to this central importance in various downstream applications, there is abundant literature on learning disentangled representations. Roughly speaking, existing methods fall into two lines. The first comprises supervised methods (Chen & Batmanghelich, 2019; Locatello et al., 2019c; Shu et al., 2019; Bouchacourt et al., 2018; Nie et al., 2020; Yang et al., 2015), where external supervision (e.g., data generative factors) is available during training to guide the learning of disentangled representations. The second comprises unsupervised methods (Chen et al., 2016; 2018; Burgess et al., 2018; Kim & Mnih, 2018; Denton et al., 2017; Kumar et al., 2018; Fraccaro et al., 2017), which substantially reduce the need for external supervision. For this reason, this paper focuses on unsupervised disentangled representation learning.

One major challenge of unsupervised disentanglement learning is the trade-off between the reconstruction quality of the input signal and the degree of disentanglement in the latent representations. Take β-VAE and its variants (Burgess et al., 2018; Chen et al., 2018; Higgins et al., 2017a) as an example.
These methods assign a large, fixed weight β in the objective function to improve disentanglement at the cost of reconstruction quality, which is highly correlated with accuracy in downstream tasks (van Steenkiste et al., 2019; Locatello et al., 2020). To improve reconstruction quality, researchers have proposed a dynamic learning approach, ControlVAE (Shao et al., 2020), which dynamically adjusts the weight on the KL term in the VAE objective to better balance the quality of disentangled representation learning against reconstruction error. However, while ControlVAE allows better control of this trade-off, it does not eliminate it: one objective is still achieved at the expense of the other. The contribution of this paper, compared to the above state of the art, lies in demonstrating that, with the proper design, the trade-off between disentangled representation learning and reconstruction error can be eliminated entirely. Both objectives can be attained at the same time in a decoupled fashion, without affecting each other. More specifically, we observe that if β is kept high at the beginning of training and then lowered later in the process, the two objectives are decoupled, allowing each to be optimized independently. To the authors' knowledge, this work is the first to attain such decoupled optimization of both disentanglement quality and reconstruction error.

Our Contributions: In this paper, we propose a novel unsupervised disentangled representation learning method, dubbed DynamicVAE, that reduces the weight of β-VAE from a large value (β > 1) (Burgess et al., 2018; Higgins et al., 2017a) to a small value (β ≤ 1), achieving not only good disentanglement but also high reconstruction accuracy via dynamic control. We summarize the main contributions of this paper as follows.
• We propose a new model, DynamicVAE, that leverages an incremental PI controller and a moving average to evolve the desired KL-divergence along a trajectory that enables decoupling of two objectives: high-quality disentanglement and low reconstruction error.
• We provide theoretical conditions on the parameters of the PI controller that guarantee the stability of DynamicVAE.
• We experimentally demonstrate that our approach turns the weight of β-VAE from β > 1 into β ≤ 1, achieving higher reconstruction quality with disentanglement comparable to prior approaches (e.g., FactorVAE). Our results thus verify that the proposed method indeed decouples disentanglement and reconstruction accuracy, so that neither hurts the other's performance.
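As a toy illustration of the decoupling idea (not the paper's actual mechanism, which sets β via feedback control), a two-phase schedule that holds β high early for disentanglement and drops it to β ≤ 1 later for reconstruction might look like the following; the function name, step threshold, and β values are hypothetical:

```python
def two_phase_beta(step, warmup_steps=10_000, beta_high=100.0, beta_low=0.5):
    """Hypothetical two-phase beta schedule illustrating the decoupling idea:
    a large beta early in training encourages disentanglement, and a small
    beta (<= 1) afterwards lets the model minimize reconstruction error."""
    return beta_high if step < warmup_steps else beta_low
```

In DynamicVAE itself, the transition between the two regimes is driven by a PI controller tracking a KL-divergence set point rather than a fixed step count.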

2. PRELIMINARIES

β-VAE and its Variants: β-VAE (Higgins et al., 2017b; Chen et al., 2018) is a popular unsupervised method for learning disentangled representations of the data generative factors (Bengio et al., 2013). Compared to the original VAE, β-VAE introduces an extra hyperparameter β (β > 1) as the weight of the KL term in the VAE objective:

$\mathcal{L}_\beta = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - \beta \, D_{\mathrm{KL}}(q_\phi(z|x) \,\|\, p(z)).$

In order to discover more disentangled factors, practitioners further add a constraint on the total information capacity, C, to control the capacity of the latent channels (Burgess et al., 2018) to transmit information. The constrained problem can be formulated as

$\mathcal{L}_\beta = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - \beta \, |D_{\mathrm{KL}}(q_\phi(z|x) \,\|\, p(z)) - C|,$

where β is a large, fixed hyperparameter. As a result, when the weight β is large (e.g., 100), the algorithm tends to optimize the second (capacity) term, leading to much higher reconstruction error.
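As a concrete sketch, the capacity-constrained objective above (negated, so that it is minimized) can be written in PyTorch as follows, assuming a Bernoulli decoder and a standard-normal prior; the function name and default values are illustrative:

```python
import torch
import torch.nn.functional as F

def beta_vae_loss(recon_x, x, mu, log_var, beta=100.0, capacity=0.0):
    """Capacity-constrained beta-VAE objective (negated for minimization).

    recon_x, x: decoder output and input batch, values in [0, 1]
    mu, log_var: encoder outputs parameterizing q(z|x) = N(mu, diag(exp(log_var)))
    """
    # Reconstruction term: -E_q[log p(x|z)], Bernoulli likelihood per sample
    recon_loss = F.binary_cross_entropy(recon_x, x, reduction="sum") / x.size(0)
    # Analytic KL(q(z|x) || N(0, I)), averaged over the batch
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) / x.size(0)
    # beta * |KL - C| capacity penalty (Burgess et al., 2018)
    return recon_loss + beta * (kl - capacity).abs()
```

Setting `capacity=0.0` and dropping the absolute value recovers the plain β-VAE objective; with a large fixed `beta`, gradients are dominated by the penalty term, which is exactly the source of the high reconstruction error discussed above.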

PID Control Algorithm:

PID is a simple yet effective control algorithm that can stabilize a system's output at a desired value via feedback control (Stooke et al., 2020; Åström et al., 2006). The PID algorithm computes an error, e(t), between a set point (in this case, the desired KL-divergence) and the current value of the controlled variable (in this case, the actual KL-divergence), and then applies a correction in the direction that reduces this error. The correction is a weighted sum of three terms: one proportional to the error (P), one equal to the integral of the error (I), and one equal to the derivative of the error (D); hence the name PID. The derivative term is not recommended for noisy systems, such as ours, which reduces the algorithm to PI control. The canonical form of a PI controller is

$u(t) = K_p \, e(t) + K_i \int_0^t e(\tau) \, d\tau,$

where $K_p$ and $K_i$ denote the proportional and integral gains, respectively.
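To make the feedback loop concrete, here is a minimal sketch of an incremental (velocity-form) PI controller with a moving average over the observed KL, in the spirit of the mechanism described above. The class name, gain values, window size, and sign convention are illustrative assumptions, not the paper's exact design:

```python
from collections import deque

class IncrementalPI:
    """Sketch of an incremental PI controller for the KL weight beta.

    The velocity form updates only the *change* in beta each step:
        delta_beta(t) = Kp * (e(t) - e(t-1)) + Ki * e(t),
    where e(t) is the smoothed KL error. Gains and bounds are hypothetical.
    """
    def __init__(self, kp=0.01, ki=0.001, beta_init=100.0,
                 beta_min=0.0, beta_max=100.0, window=10):
        self.kp, self.ki = kp, ki
        self.beta = beta_init
        self.beta_min, self.beta_max = beta_min, beta_max
        self.kl_history = deque(maxlen=window)  # moving-average buffer
        self.prev_error = 0.0

    def step(self, kl_setpoint, kl_observed):
        # Moving average smooths the noisy per-batch KL before feedback
        self.kl_history.append(kl_observed)
        kl_smoothed = sum(self.kl_history) / len(self.kl_history)
        # Sign convention: error is positive when KL exceeds its set point,
        # so beta grows to push the KL back down (and shrinks otherwise)
        error = kl_smoothed - kl_setpoint
        self.beta += self.kp * (error - self.prev_error) + self.ki * error
        # Clamp beta to a valid range
        self.beta = min(max(self.beta, self.beta_min), self.beta_max)
        self.prev_error = error
        return self.beta
```

Because only the increment of β is computed, the controller needs no explicit accumulator for the integral term and the output evolves smoothly from its previous value, which is why the incremental form suits the gradual β trajectory described above.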

