DISENTANGLED GENERATIVE CAUSAL REPRESENTATION LEARNING

Abstract

This paper proposes a Disentangled gEnerative cAusal Representation (DEAR) learning method. Unlike existing disentanglement methods that enforce independence of the latent variables, we consider the general case where the underlying factors of interest can be causally correlated. We show that previous methods with independent priors fail to disentangle causally correlated factors. Motivated by this finding, we propose a new disentangled learning method, DEAR, that enables causal controllable generation and causal representation learning. The key ingredient of this new formulation is to use a structural causal model (SCM) as the prior for a bidirectional generative model. The prior is then trained jointly with a generator and an encoder using a suitable GAN loss incorporating supervision. We provide theoretical justification for the proposed formulation, which guarantees disentangled causal representation learning under appropriate conditions. We conduct extensive experiments on both synthetic and real datasets to demonstrate the effectiveness of DEAR in causal controllable generation, and the benefits of the learned representations for downstream tasks in terms of sample efficiency and distributional robustness.

1. INTRODUCTION

Consider observed data x from a distribution q_x on X ⊆ R^d and a latent variable z from a prior p_z on Z ⊆ R^k. In bidirectional generative models (BGMs), we are typically interested in learning an encoder E : X → Z to infer latent variables and a generator G : Z → X to generate data, achieving both representation learning and data generation. Classical BGMs include the Variational Autoencoder (VAE) (Kingma & Welling, 2014) and BiGAN (Donahue et al., 2017). In representation learning, it has been argued that an effective representation for downstream learning tasks should disentangle the underlying factors of variation (Bengio et al., 2013). In generation, it is highly desirable to control the semantic generative factors by aligning them with the latent variables, as in StyleGAN (Karras et al., 2019). Both goals can be achieved through disentanglement of the latent variable z, which informally means that each dimension of z measures a distinct factor of variation in the data (Bengio et al., 2013). Earlier unsupervised disentanglement methods mostly regularize the VAE objective to encourage independence of the learned representations (Higgins et al., 2017; Burgess et al., 2017; Kim & Mnih, 2018; Chen et al., 2018; Kumar et al., 2018). Later, Locatello et al. (2019) showed that unsupervised learning of disentangled representations is impossible without inductive biases: many existing unsupervised methods are in fact brittle, requiring careful supervised hyperparameter tuning or implicit inductive biases. To promote identifiability, recent work resorts to various forms of supervision (Locatello et al., 2020b; Shu et al., 2020; Locatello et al., 2020a). In this work, we also incorporate supervision on the ground-truth factors, in the form stated in Section 3.2. Most of these existing methods are built on the assumption that the underlying factors of variation are mutually independent.
However, in many real-world cases the semantically meaningful factors of interest are not independent (Bengio et al., 2020). Instead, semantically meaningful high-level variables are often causally correlated, i.e., connected by a causal graph. In this paper, we prove formally that methods with independent priors fail to disentangle causally correlated factors. Motivated by this observation, we propose a new method to learn disentangled generative causal representations, called DEAR. The key ingredient of our formulation is a structural causal model (SCM) (Pearl et al., 2000) as the prior for latent variables in a bidirectional generative model. Given some background knowledge of the binary causal structure, the causal prior is learned jointly with a generator and an encoder using a suitable GAN (Goodfellow et al., 2014) loss. We establish theoretical guarantees for DEAR to learn disentangled causal representations under appropriate conditions. An immediate application of DEAR is causal controllable generation, which can generate data from any desired interventional distribution of the latent factors. Another useful application of disentangled representations is in downstream tasks, where they lead to better sample efficiency (Bengio et al., 2013; Schölkopf et al., 2012). Moreover, it is believed that causal disentanglement is invariant and thus robust under distribution shifts (Schölkopf, 2019; Arjovsky et al., 2019). In this paper, we demonstrate these conjectures in various downstream prediction tasks for the proposed DEAR method, which has a theoretically guaranteed disentanglement property. We summarize our main contributions as follows:
• We formally identify a problem with previous disentangled representation learning methods that use the independent prior assumption, and prove that they fail to disentangle when the underlying factors of interest are causally correlated.
• We propose a new disentangled learning method, DEAR, which integrates an SCM prior into a bidirectional generative model trained with a suitable GAN loss.
• We provide theoretical justification for the identifiability of the proposed formulation.
• We conduct extensive experiments on both synthetic and real data to demonstrate the effectiveness of DEAR in causal controllable generation, and the benefits of the learned representations for downstream tasks in terms of sample efficiency and distributional robustness.

2. OTHER RELATED WORK

GAN-based disentanglement methods. Existing methods, including InfoGAN (Chen et al., 2016) and InfoGAN-CR (Lin et al., 2020), differ from our proposed formulation in two main respects. First, they still assume an independent prior for the latent variables, and thus suffer from the same problem as the VAE-based methods mentioned above. Moreover, the idea of InfoGAN-CR is to encourage each latent code to make changes that are easy to detect, which applies well only when the underlying factors are independent. Second, InfoGAN as a bidirectional generative modeling method further requires variational approximation on top of adversarial training, which is inferior to the principled formulations of BiGAN and AGES (Shen et al., 2020) that we adopt.

Causality with generative models. CausalGAN (Kocaoglu et al., 2018) and a concurrent work (Moraffah et al., 2020) are unidirectional generative models (i.e., generative models that learn only a single mapping from the latent variable to data) built upon cGAN (Mirza & Osindero, 2014). They assign an SCM to the conditional attributes while leaving the latent variables as independent Gaussian noises. A limitation of cGANs is that they always require full supervision on the attributes to apply conditional adversarial training. Moreover, the ground-truth factors are fed directly into the generator as conditional attributes, with no effort to align the dimensions of the latent variables with the underlying factors, so these models do not address disentanglement learning. Their unidirectional nature also makes it impossible to learn representations. In addition, they consider only binary factors, whose resulting semantic interpolations appear non-smooth, as shown in Appendix D. CausalVAE (Yang et al., 2020) assigns the SCM directly to the latent variables, but being built upon iVAE (Khemakhem et al., 2020), it adopts a conditional prior given the ground-truth factors and is thus also limited to the fully supervised setting.

3.1. GENERATIVE MODEL

We first describe the probabilistic framework of disentangled learning with supervision. We follow the commonly assumed two-step data generating process that first samples the underlying generative factors and then, conditional on those factors, generates the data (Kingma & Welling, 2014). During the generation process, the generator induces the generated conditional p_G(x|z) and the generated joint distribution p_G(x, z) = p_z(z) p_G(x|z). During the inference process, the encoder induces the encoded conditional q_E(z|x), which can be a factorized Gaussian, and the encoded joint distribution q_E(x, z) = q_x(x) q_E(z|x).
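This two-step generating process, with a causal rather than independent prior, can be illustrated with a minimal numerical sketch. Everything below is a toy assumption for illustration only: a hypothetical linear-Gaussian SCM over k = 3 factors with graph z1 → z2 and z1 → z3, and a linear map W standing in for the generator G. Sampling z from the SCM prior p_z and then x from p_G(x|z) produces draws from the generated joint p_G(x, z), under which the latent factors are correlated rather than independent.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical binary causal graph over k = 3 factors:
# z1 -> z2 and z1 -> z3, encoded as A[i, j] = 1 iff z_i -> z_j.
A = np.array([[0, 1, 1],
              [0, 0, 0],
              [0, 0, 0]], dtype=float)
k, d = 3, 5

def scm_prior_sample(n):
    """Sample z from a linear-Gaussian SCM: z = A^T z + eps, i.e.
    z = (I - A^T)^{-1} eps with eps ~ N(0, I) (row-vector convention)."""
    eps = rng.standard_normal((n, k))
    return eps @ np.linalg.inv(np.eye(k) - A)

# Toy stand-in for the generator G: a fixed linear map plus small noise,
# inducing the generated conditional p_G(x | z).
W = rng.standard_normal((k, d))

def generated_joint_sample(n):
    """Sample (x, z) ~ p_G(x, z) = p_z(z) p_G(x | z)."""
    z = scm_prior_sample(n)
    x = z @ W + 0.1 * rng.standard_normal((n, d))
    return x, z

x, z = generated_joint_sample(1000)
# Since z1 causally influences z2 and z3, the factors are correlated
# under the prior, unlike under an independent (factorized) prior.
corr = np.corrcoef(z, rowvar=False)
print(corr[0, 1], corr[0, 2])
```

For the graph above, z2 = z1 + eps2 and z3 = z1 + eps3, so the population correlation between z1 and each child is 1/sqrt(2) ≈ 0.71, which the sample estimates recover.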

