DISENTANGLED GENERATIVE CAUSAL REPRESENTATION LEARNING

Abstract

This paper proposes a Disentangled gEnerative cAusal Representation (DEAR) learning method. Unlike existing disentanglement methods that enforce independence of the latent variables, we consider the general case where the underlying factors of interest can be causally correlated. We show that previous methods with independent priors fail to disentangle causally correlated factors. Motivated by this finding, we propose a new disentangled learning method called DEAR that enables causal controllable generation and causal representation learning. The key ingredient of this formulation is a structural causal model (SCM) used as the prior of a bidirectional generative model. The prior is then trained jointly with a generator and an encoder using a suitable GAN loss that incorporates supervision. We provide theoretical justification for the proposed formulation, guaranteeing disentangled causal representation learning under appropriate conditions. We conduct extensive experiments on both synthetic and real datasets to demonstrate the effectiveness of DEAR in causal controllable generation, and the benefits of the learned representations for downstream tasks in terms of sample efficiency and distributional robustness.

1. INTRODUCTION

Consider observed data x from a distribution q_x on X ⊆ R^d and a latent variable z from a prior p_z on Z ⊆ R^k. In bidirectional generative models (BGMs), we are typically interested in learning an encoder E : X → Z to infer latent variables and a generator G : Z → X to generate data, achieving both representation learning and data generation. Classical BGMs include the Variational Autoencoder (VAE) (Kingma & Welling, 2014) and BiGAN (Donahue et al., 2017). In representation learning, it has been argued that an effective representation for downstream learning tasks should disentangle the underlying factors of variation (Bengio et al., 2013). In generation, it is highly desirable to control the semantic generative factors by aligning them with the latent variables, as in StyleGAN (Karras et al., 2019). Both goals can be achieved by disentangling the latent variable z, which informally means that each dimension of z measures a distinct factor of variation in the data (Bengio et al., 2013). Earlier unsupervised disentanglement methods mostly regularize the VAE objective to encourage independence of the learned representations (Higgins et al., 2017; Burgess et al., 2017; Kim & Mnih, 2018; Chen et al., 2018; Kumar et al., 2018). Later, Locatello et al. (2019) showed that unsupervised learning of disentangled representations is impossible without inductive biases: many existing unsupervised methods are actually brittle, requiring careful supervised hyperparameter tuning or implicit inductive biases. To promote identifiability, recent work resorts to various forms of supervision (Locatello et al., 2020b; Shu et al., 2020; Locatello et al., 2020a). In this work, we also incorporate supervision on the ground-truth factors, in the form stated in Section 3.2. Most of these existing methods are built on the assumption that the underlying factors of variation are mutually independent.
However, in many real-world cases the semantically meaningful factors of interest are not independent (Bengio et al., 2020). Instead, semantically meaningful high-level variables are often causally correlated, i.e., connected by a causal graph. In this paper, we prove formally that methods with independent priors fail to disentangle causally correlated factors. Motivated by this observation, we propose a new method to learn disentangled generative causal representations, called DEAR. The key ingredient of our formulation is a structural causal model (SCM) (Pearl et al., 2000) as the prior for the latent variables of a bidirectional generative model. Given some background knowledge of the binary causal structure, the causal prior is learned jointly with a generator and an encoder using a suitable GAN (Goodfellow et al., 2014) loss. We establish theoretical guarantees for DEAR to learn disentangled causal representations under appropriate conditions. An immediate application of DEAR is causal controllable generation, i.e., generating data from any desired interventional distribution of the latent factors. Another useful application of disentangled representations is in downstream tasks, where they lead to better sample complexity (Bengio et al., 2013; Schölkopf et al., 2012). Moreover, it is believed that causal disentanglement is invariant and thus robust under distribution shifts (Schölkopf, 2019; Arjovsky et al., 2019). In this paper, we demonstrate these conjectures in various downstream prediction tasks for the proposed DEAR method, which has a theoretically guaranteed disentanglement property. We summarize our main contributions as follows:
• We formally identify a problem with previous disentangled representation learning methods that use the independent prior assumption, and prove that they fail to disentangle when the underlying factors of interest are causally correlated.
• We propose a new disentangled learning method, DEAR, which integrates an SCM prior into a bidirectional generative model trained with a suitable GAN loss.
• We provide theoretical justification for the identifiability of the proposed formulation.
• Extensive experiments are conducted on both synthetic and real data to demonstrate the effectiveness of DEAR in causal controllable generation, and the benefits of the learned representations for downstream tasks in terms of sample efficiency and distributional robustness.

2. OTHER RELATED WORK

GAN-based disentanglement methods. Existing methods, including InfoGAN (Chen et al., 2016) and InfoGAN-CR (Lin et al., 2020), differ from our proposed formulation in two main respects. First, they still assume an independent prior over the latent variables, and thus suffer from the same problem as the VAE-based methods mentioned above. Moreover, the idea of InfoGAN-CR is to encourage each latent code to make changes that are easy to detect, which applies well only when the underlying factors are independent. Second, InfoGAN as a bidirectional generative modeling method further requires variational approximation in addition to adversarial training, which is inferior to the principled formulation in BiGAN and AGES (Shen et al., 2020) that we adopt.
Causality with generative models. CausalGAN (Kocaoglu et al., 2018) and a work concurrent to ours (Moraffah et al., 2020) are unidirectional generative models (i.e., generative models that learn a single mapping from the latent variable to data) built upon cGAN (Mirza & Osindero, 2014). They assign an SCM to the conditional attributes while leaving the latent variables as independent Gaussian noise. The limitation of a cGAN is that it always requires full supervision on the attributes to apply conditional adversarial training. Moreover, the ground-truth factors are fed directly into the generator as conditional attributes, without any effort to align the dimensions of the latent variables with the underlying factors, so these models do not address disentanglement learning. Their unidirectional nature also makes representation learning impossible. Besides, they only consider binary factors, whose resulting semantic interpolations appear non-smooth, as shown in Appendix D.
CausalVAE (Yang et al., 2020) assigns the SCM directly to the latent variables, but being built upon iVAE (Khemakhem et al., 2020), it adopts a conditional prior given the ground-truth factors and is thus also limited to the fully supervised setting.

3.1. GENERATIVE MODEL

We first describe the probabilistic framework of disentangled learning with supervision. We follow the commonly assumed two-step data generating process that first samples the underlying generative factors and then, conditional on those factors, generates the data (Kingma & Welling, 2014). During generation, the generator induces the generated conditional p_G(x|z) and the generated joint distribution p_G(x, z) = p_z(z) p_G(x|z). During inference, the encoder induces the encoded conditional q_E(z|x), which can be a factorized Gaussian, and the encoded joint distribution q_E(x, z) = q_x(x) q_E(z|x). We consider the following objective for generative modeling:

L_gen = D_KL(q_E(x, z), p_G(x, z)),   (1)

which is equivalent to the evidence lower bound used in VAEs up to a constant, and admits a closed form only with a factorized Gaussian prior, encoder and generator (Shen et al., 2020). Since constraints on the latent space are required to enforce disentanglement, it is desirable that the distribution families of q_E(x, z) and p_G(x, z) be large enough, especially for complex data like images. More general implicit distributions are normally favored over factorized Gaussians in terms of expressiveness (Karras et al., 2019; Mescheder et al., 2017). Minimizing (1) then requires adversarial training, as discussed in detail in Section 4.3.

3.2. SUPERVISED REGULARIZER

To guarantee disentanglement, we incorporate supervision when training the BGM, following an idea similar to Locatello et al. (2020b) but with a different formulation. Specifically, let ξ ∈ R^m be the underlying ground-truth factors of interest of x, following distribution p_ξ, and let y_i be some continuous or discrete observation of the underlying factor ξ_i satisfying ξ_i = E(y_i|x) for i = 1, ..., m. For example, in the case of human face images, y_1 can be the binary label indicating whether a person is young or not, and ξ_1 = E(y_1|x) = P(y_1 = 1|x) is the probability of being young given one image x. Let Ē(x) be the deterministic part of the stochastic transformation E(x), i.e., Ē(x) = E(E(x)|x), which is used for representation learning. We consider the following objective:

L(E, G) = L_gen(E, G) + λ L_sup(E),   (2)

where L_sup = Σ_{i=1}^m E_{x,y}[CE(Ē_i(x), y_i)] if y_i is the binary or bounded continuous label of the i-th factor ξ_i, with CE(l, y) = −[y log σ(l) + (1 − y) log(1 − σ(l))] being the cross-entropy loss and σ(·) the sigmoid function; L_sup = Σ_{i=1}^m E_{x,y}[Ē_i(x) − y_i]^2 if y_i is a continuous observation of ξ_i; and λ > 0 is a coefficient balancing the two terms. We empirically find the choice of λ quite insensitive across tasks and datasets, and hence set λ = 5 in all experiments. Estimating L_gen requires the unlabelled dataset {x_1, ..., x_N}, while estimating L_sup requires a labeled dataset {(x_j, y_j) : j = 1, ..., N_s}, where N_s can be much smaller than N. In contrast, Locatello et al. (2020b) propose the regularizer L_sup = Σ_{i=1}^m E_{x,z}[CE(Ē_i(x), z_i)], involving only the latent variable z, which is part of the generative model, without distinguishing it from the ground-truth factor ξ and its observation y. Hence they do not establish any theoretical justification of disentanglement.
Besides, they adopt a VAE loss for L gen with an independent prior, which suffers from the unidentifiability problem described in the next section.
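As a concrete illustration, the two variants of the supervised regularizer L_sup can be sketched as follows. This is a minimal NumPy sketch; the function names and the batched-array interface are our own, not the paper's implementation.

```python
import numpy as np

def sigmoid(l):
    return 1.0 / (1.0 + np.exp(-l))

def sup_loss(enc_out, y, binary=True):
    """Supervised regularizer L_sup on the m supervised latent dimensions.

    enc_out: (n, m) deterministic encoder outputs E_bar_i(x) for the
             m factors of interest; y: (n, m) observed labels y_i.
    """
    if binary:
        p = sigmoid(enc_out)
        # CE(l, y) = -[y log sigma(l) + (1 - y) log(1 - sigma(l))]
        return -np.mean(y * np.log(p) + (1 - y) * np.log(1 - p))
    # L2 loss when y_i is a continuous observation of xi_i
    return np.mean((enc_out - y) ** 2)
```

In training, this term is evaluated only on the labeled subset of each mini-batch and added to the generative loss with weight λ.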

3.3. UNIDENTIFIABILITY WITH AN INDEPENDENT PRIOR

Intuitively, the above supervised regularizer aims at ensuring some alignment between the factors ξ and the latent variables z. We start with a definition of disentangled representations following this intuition.

Definition 1 (Disentangled representation). Given the underlying factor ξ ∈ R^m of data x, a deterministic encoder E is said to learn a disentangled representation with respect to ξ if for all i = 1, ..., m there exists a 1-1 function g_i such that E_i(x) = g_i(ξ_i). Further, a stochastic encoder E is said to be disentangled wrt ξ if its deterministic part Ē(x) is disentangled wrt ξ.

As stated above, we consider the general case where the underlying factors of interest are causally correlated. The goal then becomes to disentangle the causal factors. Previous methods mostly use an independent prior for z, which contradicts the ground truth. We make this formal through the following proposition, which indicates that the disentangled representation is generally unidentifiable with an independent prior.

Proposition 1. Let E* be any encoder that is disentangled wrt ξ. Let b* = L_sup(E*), a = min_G L_gen(E*, G), and b = min_{(E,G): L_gen = 0} L_sup(E). Assume the elements of ξ are connected by a causal graph whose adjacency matrix A_0 is not a zero matrix. Suppose the prior p_z is factorized, i.e., p_z(z) = Π_{i=1}^k p_i(z_i). Then we have a > 0, and either when b* ≥ b, or when b* < b and λ < a/(b − b*), there exists a solution (E′, G′) such that for any generator G, we have L(E′, G′) < L(E*, G).

All proofs are given in Appendix A. This proposition directly suggests that minimizing (2) favors the solution (E′, G′) over one with a disentangled encoder E*. Thus, with an independent prior, there is no way to identify the disentangled solution unless λ is large enough. However, in real applications it is impossible to estimate the threshold, and too large a λ makes it difficult to learn the BGM.
In the following section we propose a solution to this problem.

4.1. GENERATIVE MODEL WITH A CAUSAL PRIOR

We propose to use a causal model as the prior p_z. Specifically, we use the generalized nonlinear Structural Causal Model (SCM) proposed by Yu et al. (2019):

z = f((I − A^⊤)^{-1} h(ε)) := F_β(ε),   (3)

where A is the weighted adjacency matrix of the directed acyclic graph (DAG) over the k elements of z (i.e., A_ij ≠ 0 if and only if z_i is a parent of z_j), ε denotes the exogenous variables following N(0, I), f and h are element-wise nonlinear transformations, and β = (f, h, A) denotes the set of parameters of f, h and A, with parameter space B. Further, let 1_A = I(A ≠ 0) denote the corresponding binary adjacency matrix, where I is the element-wise indicator function. When f is invertible, (3) is equivalent to

f^{-1}(z) = A^⊤ f^{-1}(z) + h(ε),   (4)

which indicates that the factors z satisfy a linear SCM after the nonlinear transformation f, and enables interventions on latent variables as discussed later. The model structure is presented in Figure 1. Note that, different from our model where z is a latent variable following the prior (3) with the goal of causal disentanglement, Yu et al. (2019) proposed a causal discovery method where the variables z are observed, with the aim of learning the causal structure among them. In causal structure learning, the graph is required to be acyclic. Zheng et al. (2018) propose an equality constraint whose satisfaction ensures acyclicity, and solve the problem with an augmented Lagrangian method, which however leads to optimization difficulties (Ng et al., 2020). In this paper, to avoid dealing with the non-convex constraint and to focus on disentangling, we assume some prior knowledge of the binary causal structure. Specifically, we assume a super-graph of the true binary graph 1_{A*} is given; the best case is the true graph itself, while the worst case is that only the causal ordering is available.
We then learn the weights of the non-zero elements of the prior adjacency matrix, which indicate the sign and scale of the causal effects, jointly with the other parameters using the formulation and algorithm described in the following sections. Incorporating structure learning methods to jointly learn the structure from scratch with identifiability guarantees could be explored in future work. An ablation study regarding this prior knowledge is given in Appendix B. To enable causal controllable generation, we use invertible f and h and describe the mechanism to generate images from interventional distributions of the latent variables. Note that interventions can be formalized as operations that modify a subset of the equations in (4) (Pearl et al., 2000). Suppose we would like to intervene on the i-th dimension of z, i.e., Do(z_i = c), where c is a constant. Once we have latent factors z inferred from data x, i.e., z = E(x), or sampled from the prior p_z, we follow the intervened equations to obtain z′ by ancestral sampling, performing (4) iteratively. We then decode the intervened latent factor z′ to generate the sample G(z′). In Section 5.1 we define the two types of interventions of most interest in applications. Another issue is the choice of latent dimension, for which we propose a composite prior. Recall that m is the number of generative factors that we are interested in disentangling, e.g., all the semantic concepts related to some field, where m tends to be smaller than the total number M of generative factors. The latent dimension k of the generative model should be no less than M to allow sufficient degrees of freedom to generate or reconstruct the data well. Since M is generally unknown in practice, we set a sufficiently large k, at least larger than m, which is a trivial lower bound of M.
The role of the remaining k − m dimensions is to capture other factors necessary for generation, whose structure is not of interest and is not explicitly modeled. We therefore use a prior that is a composition of the causal model for the first m dimensions and another distribution, such as a standard Gaussian, for the remaining k − m dimensions.
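To make the SCM prior and the intervention mechanism concrete, here is a minimal NumPy sketch of sampling from (3) and of performing Do(z_i = c) via the intervened equations (4). The choice f = tanh with h the identity is an assumption for illustration, as are the function names.

```python
import numpy as np

def scm_prior_sample(A, eps, f=np.tanh):
    """Sample z = f((I - A^T)^{-1} h(eps)) from the SCM prior (3),
    with h = identity.  A is the (k, k) weighted adjacency matrix,
    A[i, j] != 0 iff z_i is a parent of z_j."""
    k = A.shape[0]
    u = np.linalg.solve(np.eye(k) - A.T, eps)  # u solves u = A^T u + eps
    return f(u)

def intervene(A, z, i, c, f=np.tanh, f_inv=np.arctanh):
    """Generate z' from Do(z_i = c): recover the exogenous noise implied
    by z (h = identity), clamp z_i = c, and propagate through the
    descendants by performing (4) iteratively (ancestral sampling)."""
    u = f_inv(z)
    eps = u - A.T @ u            # exogenous variables implied by z
    z_new = z.copy()
    z_new[i] = c
    for _ in range(A.shape[0]):  # k iterations suffice since A is acyclic
        z_new = f(A.T @ f_inv(z_new) + eps)
        z_new[i] = c             # the intervened equation keeps z_i = c
    return z_new
```

For a two-node chain z_1 → z_2, intervening on the cause z_1 changes the effect z_2, while intervening on z_2 leaves z_1 untouched, matching the behavior described in Section 5.1.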

4.2. FORMULATION AND IDENTIFIABILITY OF DISENTANGLEMENT

In this section, we present the DEAR formulation and establish its theoretical justification. Compared with the BGM described in Section 3.1, we have one more module to learn: the SCM prior. Thus p_G(x, z) becomes p_{G,F}(x, z) = p_F(z) p_G(x|z), where p_F(z), or p_β(z), denotes the marginal distribution of F_β(ε) with ε ∼ N(0, I). We rewrite the generative loss as

L_gen(E, G, F) = D_KL(q_E(x, z), p_{G,F}(x, z)).   (5)

We then propose the following formulation to learn disentangled generative causal representations:

min_{E,G,F} L(E, G, F) := L_gen(E, G, F) + λ L_sup(E).   (6)

To achieve causal disentanglement, we make two assumptions on the causal model. Assumption 1 supposes a sufficiently large capacity of the SCM in (3) to contain the underlying distribution p_ξ, which is reasonable due to the generality of the nonlinear SCM. Assumption 2 states the identifiability of the true causal structure 1_{A_0} of ξ, which holds given the true causal ordering under basic Markov and causal minimality conditions (Pearl, 2014; Zhang & Spirtes, 2011).

Assumption 1 (SCM capacity). The underlying distribution p_ξ belongs to the distribution family {p_β : β ∈ B}, i.e., there exists β_0 = (f_0, h_0, A_0) such that p_ξ = p_{β_0}.

Assumption 2 (Structure identifiability). For all β = (f, h, A) ∈ B with p_β = p_{β_0}, it holds that 1_A = 1_{A_0}.

The following theorem then guarantees that, under appropriate conditions, the DEAR formulation learns disentangled representations as defined in Definition 1.

Theorem 1. Assume infinite capacity of E and G. Then under Assumptions 1-2, the DEAR formulation (6) learns the disentangled encoder E*. Specifically, we have g_i(ξ_i) = σ^{-1}(ξ_i) if the CE loss is used for the supervised regularizer, and g_i(ξ_i) = ξ_i if the L_2 loss is used.

Note that the identifiability we establish in this paper differs from some previous work on parameter identifiability, e.g., Khemakhem et al. (2020).
We argue that for learning disentangled representations, the form in Definition 1, i.e., the existence but not the uniqueness of the g_i's, is sufficient to identify the relation between the representations and the data. In contrast, parameter identifiability may not be achievable in many cases, such as under over-parametrization. Thus the notion of identifiability discussed here is more realistic for the goal of disentangling. Later we provide direct empirical evidence for the theory through the application to causal controllable generation.

4.3. ALGORITHM

In this section we propose the algorithm to solve formulation (6). The SCM prior p_F(z) and the implicit generated conditional p_G(x|z) deprive (5) of an analytic form. Hence we adopt a GAN method to adversarially estimate the gradient of (5). We parametrize E_φ(x) and G_θ(z) by neural networks. Different from Shen et al. (2020), the prior also involves learnable parameters. The following lemma gives the gradient formulas of (5).

Lemma 1. Let r(x, z) = q(x, z)/p(x, z) and D(x, z) = log r(x, z). Then we have

∇_θ L_gen = −E_{z∼p_β(z)}[s(x, z) ∇_x D(x, z)^⊤ |_{x=G_θ(z)} ∇_θ G_θ(z)],
∇_φ L_gen = E_{x∼q_x}[∇_z D(x, z)^⊤ |_{z=E_φ(x)} ∇_φ E_φ(x)],
∇_β L_gen = −E_ε[s(x, z)(∇_x D(x, z)^⊤ ∇_β G(F_β(ε)) + ∇_z D(x, z)^⊤ ∇_β F_β(ε)) |_{x=G(F_β(ε)), z=F_β(ε)}],   (7)

where s(x, z) = e^{D(x,z)} is the scaling factor. We then estimate the gradients in (7) by training a discriminator D_ψ via empirical logistic regression:

min_{ψ′} (1/|S_e|) Σ_{(x,z)∈S_e} log(1 + e^{−D_{ψ′}(x,z)}) + (1/|S_g|) Σ_{(x,z)∈S_g} log(1 + e^{D_{ψ′}(x,z)}),   (8)

where S_e and S_g are finite samples from q_E(x, z) and p_{G,F}(x, z) respectively, leading to a GAN approach. Based on the above, we propose Algorithm 1 to learn disentangled generative causal representations.

Algorithm 1: Disentangled gEnerative cAusal Representation (DEAR) Learning
1. Sample {x_1, ..., x_n} with labels {y_1, ..., y_{n_s}} for the first n_s samples; sample {ε_1, ..., ε_n} from N(0, I) and generate z_i = F_β(ε_i).
2. Compute the ψ-gradient: (1/n) Σ_{i=1}^n ∇_ψ [log(1 + e^{−D_ψ(x_i, E_φ(x_i))}) + log(1 + e^{D_ψ(G_θ(z_i), z_i)})]; update ψ.
3. Sample {x_1, ..., x_n, y_1, ..., y_{n_s}} and {ε_1, ..., ε_n} as above; generate z_i = F_β(ε_i).
4. Compute the θ-gradient: −(1/n) Σ_{i=1}^n s(G_θ(z_i), z_i) ∇_θ D_ψ(G_θ(z_i), z_i).
5. Compute the φ-gradient: (1/n) Σ_{i=1}^n ∇_φ D_ψ(x_i, E_φ(x_i)) + (1/n_s) Σ_{i=1}^{n_s} ∇_φ L_sup(φ; x_i, y_i).
6. Compute the β-gradient: −(1/n) Σ_{i=1}^n s(G_θ(z_i), z_i) ∇_β D_ψ(G_θ(F_β(ε_i)), F_β(ε_i)).
7. Update the parameters φ, θ, β using the gradients.
Return: φ, θ, β.
Remark: without loss of generality, assume the first N_s samples in the training set and the first n_s samples in each mini-batch have available labels; n_s may vary across iterations.
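The discriminator objective (8) and the density-ratio weight s(x, z) = e^{D(x,z)} used in the gradients of Lemma 1 can be sketched as follows, assuming the discriminator outputs D_ψ(x, z) on the two sample sets are precomputed scalars (a NumPy sketch with hypothetical names):

```python
import numpy as np

def disc_loss(d_encoded, d_generated):
    """Empirical logistic-regression loss for the discriminator D_psi:
    d_encoded holds D_psi(x, z) on samples from q_E(x, z) (set S_e),
    d_generated on samples from p_{G,F}(x, z) (set S_g).  At the
    optimum, D(x, z) = log q(x, z)/p(x, z)."""
    return (np.mean(np.log1p(np.exp(-d_encoded)))
            + np.mean(np.log1p(np.exp(d_generated))))

def scaling_factor(d_generated):
    """s(x, z) = exp(D(x, z)), the density-ratio weight that appears in
    the generator and prior gradients of Lemma 1."""
    return np.exp(d_generated)
```

In practice D_ψ is a neural network trained on mini-batches, and the generator, encoder and prior gradients are formed by backpropagating through D_ψ with the weights above treated as constants.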

5. EXPERIMENTS

We evaluate our methods on two datasets. The first is a synthetic dataset, Pendulum, similar to the one in Yang et al. (2020). As shown in Figure 3, each image is generated by four continuous factors: pendulum angle, light angle, shadow length and shadow position, whose underlying structure is given in Figure 2(a), following physical mechanisms. To make the dataset realistic, we introduce random noise when generating the two effects from the causes, representing measurement error. We further introduce 20% corrupted data whose shadow is randomly generated, mimicking environmental disturbance. The sample sizes of the training, validation and test sets are all 6,724. The second is a real human face dataset, CelebA (Liu et al., 2015), containing 202,599 images with 40 labelled binary attributes. Among them we consider two groups of causally correlated factors, shown in Figure 2(b,c). We believe these two datasets are diverse enough to assess our methods. All details of the experimental setup and architectures are given in Appendix C.
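For illustration, a factor-generating process of this shape (two causes producing two noisy effects, plus a corrupted fraction with random shadows) can be sketched as below. The specific trigonometric mechanism, value ranges and noise scales are our assumptions for illustration, not the exact ones used to build Pendulum.

```python
import numpy as np

def make_pendulum_factors(n, corrupt_frac=0.2, seed=0):
    """Illustrative sketch of a Pendulum-like factor generator: causes
    (pendulum angle, light angle) generate effects (shadow length,
    shadow position) with additive measurement noise; a fraction of
    samples gets a randomly generated shadow (corrupted data)."""
    rng = np.random.default_rng(seed)
    pend = rng.uniform(-np.pi / 4, np.pi / 4, n)      # cause: pendulum angle
    light = rng.uniform(np.pi / 3, 2 * np.pi / 3, n)  # cause: light angle
    # effects = deterministic physical mechanism + measurement noise
    shadow_len = np.abs(np.cos(pend) / np.tan(light)) + rng.normal(0, 0.05, n)
    shadow_pos = np.sin(pend) + 1.0 / np.tan(light) + rng.normal(0, 0.05, n)
    # corrupt ~20% of samples: shadows drawn at random, ignoring the causes
    corrupted = rng.random(n) < corrupt_frac
    shadow_len[corrupted] = rng.uniform(0, 1, corrupted.sum())
    shadow_pos[corrupted] = rng.uniform(-1, 1, corrupted.sum())
    factors = np.stack([pend, light, shadow_len, shadow_pos], axis=1)
    return factors, corrupted
```

The corruption indicator is exactly the target label τ used for the downstream task in Section 5.2.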

5.1. CONTROLLABLE GENERATION

We first investigate the disentanglement performance of our methods through applications in causal controllable generation (CG). Traditional CG methods mainly manipulate independent generative factors (Karras et al., 2019), while we consider the general case where the factors are causally correlated. With a learned SCM as the prior, we can generate images from any desired interventional distribution of the latent factors. For example, we can manipulate only a cause factor while leaving its effects unchanged. Besides, the BGM framework enables controllable generation either from scratch or from a given unlabeled image. We consider two types of intervention. In traditional traversals, we manipulate one dimension of the latent vector while keeping the others fixed to either their inferred or sampled values (Higgins et al., 2017). A causal view of such operations is an intervention on all the variables, setting them as constants with only one of them varying. Another interesting type of interventional distribution is obtained by intervening on only one latent variable, i.e., P_{do(z_i=c)}(z). The proposed SCM prior enables such interventions through the mechanism given in Section 4.1. Figures 3-4 illustrate the causal controllable generation results of the proposed DEAR and the baseline method with an independent prior, S-β-VAE (Locatello et al., 2020b). Results from other baselines, including S-TCVAE, S-FactorVAE (which make essentially no difference due to the independence assumption) and CausalGAN, are given in Appendix D. Note that we do not compare with unsupervised disentanglement methods, for fairness and because of their lack of justification. In each figure, we first infer the latent representations from a test image in block (c). The traditional traversals of the two models are given in blocks (a,b).
We see that in each row, when manipulating one latent dimension, the generated images from our model vary in only a single factor, indicating that our method can disentangle causally correlated factors. It is worth pointing out that we are the first to achieve disentanglement between a cause and its effects, while other methods tend to entangle them. In block (d), we show the results of intervening on the latent variables representing the cause factors, which clearly show that intervening on a cause variable changes its effect variables. Results in Appendix D further show that intervening on an effect node does not influence its cause. Since the underlying factors are causally correlated, previous quantitative metrics for disentanglement no longer apply. We provide more qualitative traversals in Appendix D to show the overall performance. A quantitative metric for causal disentanglement is worth exploring in future work.

5.2. DOWNSTREAM TASK

The previous section verified the good disentanglement performance of DEAR. In this section, equipped with DEAR, we investigate the benefits of the learned disentangled causal representations for sample efficiency and distributional robustness. We first state the downstream tasks. On CelebA, we consider the structure CelebA-Attractive in Figure 2(c). We artificially create a target label τ = 1 if young = 1, gender = 0, receding hairline = 0, make up = 1, chubby = 0, eye bag = 0, and τ = 0 otherwise, indicating one kind of attractiveness as a slim young woman with makeup and thick hair. On the Pendulum dataset, we regard the label of data corruption as the target τ, i.e., τ = 1 if the data is corrupted and τ = 0 otherwise. We consider the downstream task of predicting the target label. In both cases, the factors of interest in Figure 2(a,c) are causally related to τ and are the features that humans would use to do the task. Hence it is conjectured that a disentangled representation of these causal factors tends to be more data efficient and invariant to distribution shifts.

5.2.1. SAMPLE EFFICIENCY

For a BGM, including the previous state-of-the-art supervised disentangling methods S-VAEs (Locatello et al., 2020b) and DEAR, we use the learned encoder to embed the training data into the latent space and train an MLP classifier on the representations to predict the target label. Without an encoder, one normally needs to train a convolutional neural network with raw images as input. Here we adopt ResNet50, the architecture of the BGM encoder, as the baseline classifier. Since the disentangling methods use additional supervision of the generative factors, we consider another baseline that is pretrained using multi-label prediction of the factors on the same training set. To measure sample efficiency, we use the statistical efficiency score, defined as the average test accuracy based on 100 samples divided by the average accuracy based on 10,000/all samples, following Locatello et al. (2019). Table 1 presents the results, showing that DEAR achieves the highest sample efficiency on both datasets. ResNet with raw image inputs has the lowest efficiency, although multi-label pretraining improves its performance to a limited extent. S-VAEs have better efficiency than the ResNet baselines but lower accuracy with more training data, which we attribute mainly to the conflict between the independent prior and the supervised loss indicated in Proposition 1, making the learned representations entangled (as shown in the previous section) and less informative. Besides, we also investigate the performance of DEAR in the semi-supervised setting where only 10% of the labels are available. We find that DEAR with fewer labels has sample efficiency comparable to the fully supervised setting, with a sacrifice in accuracy that is still comparable to the other baselines with more supervision.
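The statistical efficiency score reduces to a simple ratio of average accuracies over runs; a minimal sketch (the helper name is ours):

```python
import numpy as np

def efficiency_score(accs_100, accs_full):
    """Statistical efficiency (Locatello et al., 2019): mean test
    accuracy with 100 labelled samples divided by mean accuracy with
    10,000/all samples, each averaged over random runs."""
    return np.mean(accs_100) / np.mean(accs_full)
```

A score close to 1 indicates that the representation lets the classifier reach near-maximal accuracy with very few labelled examples.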
5.2.2. DISTRIBUTIONAL ROBUSTNESS

Intuitively, spurious attributes are not causally correlated with the target label, but methods based on the independent and identically distributed (IID) assumption, such as empirical risk minimization (ERM), tend to exploit these easily learned spurious correlations for prediction, and hence face performance degradation when such correlations no longer hold at test time. In contrast, causal factors are regarded as invariant and thus robust under such shifts. The previous sections justify, both theoretically and empirically, that DEAR can learn disentangled causal representations. We then apply these representations by training a classifier upon them, which is conjectured to be invariant and robust. Baseline methods include ERM, multi-label ERM that predicts the target label together with all the factors considered in disentangling (so as to use the same amount of supervision), and S-VAEs, which cannot disentangle well in the causal case. Table 2 shows the average and worst-case (Sagawa et al., 2019) test accuracy, assessing both overall classification performance and distributional robustness: we group the test set according to the two binary labels, the target and the spurious attribute, into four groups and regard the one with the worst accuracy as the worst case, which usually exhibits the correlation opposite to that of the training data. We see that the classifiers trained upon DEAR representations outperform the baselines on both metrics. In particular, comparing the worst-case accuracy with the average one, we observe a slump from around 80 to around 60 for the other methods on CelebA, while DEAR suffers only an acceptably small decline. These results support the above conjecture and the benefits of causal disentanglement for distributional robustness.
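The grouping used for the worst-case metric can be computed directly from predictions and the two binary labels; a NumPy sketch with hypothetical names:

```python
import numpy as np

def group_accuracies(pred, target, spurious):
    """Split the test set by the (target, spurious-attribute) label pair
    into four groups and return the per-group accuracy."""
    accs = {}
    for t in (0, 1):
        for s in (0, 1):
            mask = (target == t) & (spurious == s)
            if mask.any():
                accs[(t, s)] = float(np.mean(pred[mask] == target[mask]))
    return accs

def worst_case_accuracy(pred, target, spurious):
    """Worst-group accuracy in the sense of Sagawa et al. (2019): the
    minimum accuracy over the four (target, spurious) groups."""
    return min(group_accuracies(pred, target, spurious).values())
```

The worst group is typically the one whose (target, spurious) combination is rare in training, i.e., where the spurious correlation is reversed.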

6. CONCLUSION

This paper showed that previous methods with the independent latent prior assumption fail to learn disentangled representations when the underlying factors of interest are causally correlated. We then proposed a new disentangled learning method called DEAR with theoretical guarantees. Extensive experiments demonstrated the effectiveness of DEAR in causal controllable generation, and the benefits of the learned representations for downstream tasks.

APPENDIX A PROOFS

A.1 PROOF OF PROPOSITION 1 Proof. On one hand, by the assumption that the elements of ξ are connected by a causal graph whose adjacency matrix is not a zero matrix. Then exist i ∕ = j such that ξ i and ξ j are not independent, indicating that the probability density of ξ cannot be factorized. Since E * is disentangled wrt ξ, by Definition 1, ∀i = 1, . . . , m there exists g i such that E * i (x) = g i (ξ i ). This implies that the probability density of E * (x) is not factorized. On the other hand, notice that the distribution family of the latent prior is {p z : p z is factorized}. Hence the intersection of the marginal distribution families of z and E * (x) is an empty set. Then the joint distribution families of (x, E * (x)) and (G(z), z) also have an empty intersection. We know that L gen (E, G) = 0 implies q E (x, z) = p G (x, z) which contradicts the above. Therefore, we have a = min G L gen (E * , G) > 0. Let (E ′ , G ′ ) be the solution of the optimization problem min {(E,G):Lgen=0} L sup (E). Then we have Proof. Assume E is deterministic. L ′ = L(E ′ , G ′ ) = On one hand, for each i = 1, . . . , m, first consider the cross-entropy loss L sup,i (E) = E (x,y) [CE(E i (x), y i )] = # p(x)p(y i |x)(y i log σ(E i (x))+(1-y i ) log(1-σ(E i (x))))dxdy i , where p(y i |x) is the probability mass function of the binary label y i given x, characterized by P(y i = 1|x) = E(y i |x) and P(y i = 0|x) = 1 -E(y i |x). Let ∂L sup,i ∂σ(E i (x)) = # p(x)p(y i |x) $ y i 1 σ(E i )(1 -σ(E i )) - 1 1 -σ(E i ) % dxdy i = 0. Then we know that E * i (x) = σ -1 (E(y i |x)) = σ -1 (ξ i ) minimizes L sup,i . Consider the L 2 loss L sup,i (φ) = E (x,y) [ Ēi (x) -y i ] 2 = # p(x)p(y i |x))E i (x) -y i ) 2 dxdy i . Let ∂L sup,i ∂E i (x) = 2 # p(x)p(y i |x)(E i (x) -y i )dxdy i = 0. Then we know that E * i (x) = E(y i |x) = ξ i minimizes L sup,i in this case. On the other hand, by Assumption 1 there exists β 0 = (f 0 , h 0 , A 0 ) such that p ξ = p β0 . 
Then the distribution of E*(x) is given by p_{β*} with β* = (g ∘ f_0, h_0, A_0). Assumption 2 ensures that there is no β′ = (f′, h′, A′) such that A′ ≠ A_0 but p_{β′} = p_{β*}. Let F* = F_{β*}. Further, due to the infinite capacity of G, the distribution family of p_{G,F*}(x, z) contains q_{E*}(x, z). Then by minimizing the loss in (6) over G, we can find G* such that p_{G*,F*}(x, z) matches q_{E*}(x, z), and thus L_gen(E*, G*, F*) reaches 0. Hence minimizing L = L_gen + λL_sup, which is the DEAR formulation (6), leads to the solution with E*_i(x) = g_i(ξ_i), where g_i(ξ_i) = σ^{-1}(ξ_i) if the CE loss is used and g_i(ξ_i) = ξ_i if the L_2 loss is used, together with the true binary adjacency matrix. For a stochastic encoder, we establish the disentanglement of its deterministic part as above and follow Definition 1 to obtain the desired result.

Since y′ is arbitrary, the above implies that
$$p'(x, z) = p(x, z) - \Delta^{\top}\big(\bar f_{\beta}(x)^{\top}\nabla_x p(x, z) + f_{\beta}(z)^{\top}\nabla_z p(x, z)\big) - \Delta^{\top}\big(\nabla\cdot \bar f_{\beta}(x) + \nabla\cdot f_{\beta}(z)\big)\,p(x, z) + o(\delta)$$
for all x ∈ R^d, z ∈ R^k and i = 1, ..., l, leading to (10) by taking δ → 0 and noting that p = p_β and p′ = p_{β+∆}. Similarly we can obtain (8) and (9).

Proof of Lemma 1. Recall the objective D_KL(q, p) = ∫ q(x, z) log(q(x, z)/p(x, z)) dx dz, and denote its integrand by ℓ(q, p). Let ℓ′_2(q, p) = ∂ℓ(q, p)/∂p. We have ∇_β ℓ(q(x, z), p(x, z)) = ℓ′_2(q(x, z), p(x, z)) ∇_β p_{θ,β}(x, z), where ∇_β p_{θ,β}(x, z) is computed in Lemma 2. Besides, by the product rule,
$$\nabla_x \cdot \big[\ell'_2(q, p)\,p(x, z)\,\bar f_{\beta}(x)\big] = \ell'_2(q, p)\,p(x, z)\,\nabla\cdot \bar f_{\beta}(x) + \ell'_2(q, p)\,\nabla_x p(x, z) \cdot \bar f_{\beta}(x) + p(x, z)\,\nabla_x \ell'_2(q, p) \cdot \bar f_{\beta}(x),$$
$$\nabla_z \cdot \big[\ell'_2(q, p)\,p(x, z)\,f_{\beta}(z)\big] = \ell'_2(q, p)\,p(x, z)\,\nabla\cdot f_{\beta}(z) + \ell'_2(q, p)\,\nabla_z p(x, z) \cdot f_{\beta}(z) + p(x, z)\,\nabla_z \ell'_2(q, p) \cdot f_{\beta}(z).$$
Thus,
$$\nabla_\beta L_{\mathrm{gen}} = \int \nabla_\beta \ell(q(x, z), p(x, z))\,dx\,dz = \int p(x, z)\big[\nabla_x \ell'_2(q, p) \cdot \bar f_{\beta}(x) + \nabla_z \ell'_2(q, p) \cdot f_{\beta}(z)\big]\,dx\,dz,$$
where we can compute ∇_x ℓ′_2(q, p) = s(x, z)∇_x D(x, z) and ∇_z ℓ′_2(q, p) = s(x, z)∇_z D(x, z). Hence
$$\nabla_\beta L_{\mathrm{gen}} = -\mathbb{E}_{(x,z)\sim p(x,z)}\big[s(x, z)\big(\nabla_x D(x, z)^{\top} \bar f_{\beta}(x) + \nabla_z D(x, z)^{\top} f_{\beta}(z)\big)\big] = -\mathbb{E}_{\epsilon}\Big[s(x, z)\big(\nabla_x D(x, z)^{\top} \nabla_\beta G(F_\beta(\epsilon)) + \nabla_z D(x, z)^{\top} \nabla_\beta F_\beta(\epsilon)\big)\Big|_{x = G(F_\beta(\epsilon)),\ z = F_\beta(\epsilon)}\Big],$$
where the second equality follows from the reparameterization.

Recall that P in Lemma 3 is defined as
$$P = \bigcup_{h \in \{(b-a)/n\,:\,n \in \mathbb{N}^{+}\}} P_h, \qquad P_h = \Big\{\,k + \sum_{i=0}^{(b-a)/h - 1} w_i\,(x - a - ih)\,\mathbb{1}(x \ge a + ih)\ \Big|\ w_i, k \in \mathbb{R}\,\Big\}.$$

Proof of Lemma 3. Since [a, b] is compact, any function f ∈ C[a, b] is uniformly continuous, i.e., for every ε > 0 there exists δ > 0 such that |x − y| < δ implies |f(x) − f(y)| < ε/2. Write [a, b] = ∪_{n=0}^{N−1}[a_n, b_n] with a_n = a + nh, b_n = a + (n+1)h and Nh = b − a, and let g_n be the linear function on [a_n, b_n] with g_n(a_n) = f(a_n) and g_n(b_n) = f(b_n). Assume that h < δ. For any x ∈ [a_n, b_n], we have
$$|f(x) - g_n(x)| \le \min\{|f(x) - f(a_n)| + |g_n(x) - g_n(a_n)|,\ |f(x) - f(b_n)| + |g_n(x) - g_n(b_n)|\} \le |g_n(a_n) - g_n(b_n)| + \min\{|f(x) - f(a_n)|,\ |f(x) - f(b_n)|\} = |f(a_n) - f(b_n)| + \min\{|f(x) - f(a_n)|,\ |f(x) - f(b_n)|\} < \varepsilon.$$

Thus sup_{x∈[a_n,b_n]} |f(x) − g_n(x)| < ε. Define
$$g(x) = \sum_{n=0}^{N-1} g_n(x)\,\mathbb{1}(x \in [a_n, b_n)),$$
with the last interval taken closed; clearly g ∈ P_h ⊂ P, and sup_{x∈[a,b]} |f(x) − g(x)| < ε. Therefore P is dense in C[a, b], and P_h is ε-dense.
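As a quick numerical illustration of Lemma 3 (our own sketch, not part of the proof), the following code builds the piece-wise linear interpolant on a uniform mesh and checks that the sup-norm error shrinks as the mesh size h does. The choice of np.sin and the mesh sizes are arbitrary.

```python
import numpy as np

def pw_linear_interpolant(f, a, b, n):
    """Continuous piece-wise linear interpolant of f on [a, b] built on n
    equal sub-intervals, matching g_n at the knots as in the proof."""
    knots = np.linspace(a, b, n + 1)
    vals = f(knots)
    return lambda x: np.interp(x, knots, vals)

def sup_error(f, g, a, b, grid=10_000):
    """Approximate sup-norm distance d(f, g) on a fine grid."""
    xs = np.linspace(a, b, grid)
    return float(np.max(np.abs(f(xs) - g(xs))))

f = np.sin
err_coarse = sup_error(f, pw_linear_interpolant(f, 0.0, np.pi, 4), 0.0, np.pi)
err_fine = sup_error(f, pw_linear_interpolant(f, 0.0, np.pi, 64), 0.0, np.pi)
# The sup-norm error decreases as h shrinks, illustrating denseness of P.
```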

APPENDIX B LEARNING THE STRUCTURE

As mentioned in Section 4.1, our DEAR algorithm requires prior knowledge of a super-graph of the true graph over the underlying factors of interest. The experiments in the main text all assume that the true graph is given. In this section we investigate the learned weighted adjacency matrix and present an ablation study on different degrees of prior knowledge of the structure.

B.1 GIVEN THE TRUE GRAPH

Figure 5 shows the learned weighted adjacency matrices when the true binary structure is given; the weights show sensible signs and scales consistent with common knowledge. For example, smile and its effect mouth open are positively correlated, and the corresponding element A_03 of the weighted adjacency matrix in (a) is indeed positive. Likewise, gender (the logit of male) and its effect make up are negatively correlated, and A_13 in (b) is indeed negative.

B.2 GIVEN THE TRUE CAUSAL ORDERING

Consider the Pendulum dataset, whose ground-truth structure is given in Figure 2(a). Given the causal ordering pendulum angle, light angle, shadow position, shadow length, we start with a full graph whose elements are randomly initialized around 0, as shown in Figure 6(a). Figure 6 presents the adjacency matrices learned by DEAR at different training epochs, from which we see that the learned structure eventually nearly coincides with the one learned given the true graph, shown in Figure 5(c). This experiment shows the potential of DEAR to incorporate structure learning methods and learn the latent causal structure from scratch, which will be explored in future research.
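The restriction to a causal ordering can be sketched as a masked initialization of the weighted adjacency matrix: edges are allowed only from earlier to later variables in the ordering. This is a minimal illustration of the setup, not the DEAR implementation; the factor names and the 0.01 initialization scale are our assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Factor ordering for Pendulum, following Figure 2(a).
order = ["pendulum_angle", "light_angle", "shadow_position", "shadow_length"]
k = len(order)

# A causal ordering admits edges only from earlier to later variables, so the
# "full graph" consistent with it is a strictly upper-triangular mask on the
# weighted adjacency matrix A (A[i, j] != 0 means factor i causes factor j).
mask = np.triu(np.ones((k, k)), k=1)

# Random initialization around 0, as in Figure 6(a).
A = 0.01 * rng.standard_normal((k, k)) * mask
```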

APPENDIX D ADDITIONAL RESULTS OF CAUSAL CONTROLLABLE GENERATION

In this section we present more qualitative results of causal controllable generation on two datasets using DEAR and baseline methods, including S-VAEs (Locatello et al., 2020b) and CausalGAN (Kocaoglu et al., 2018). We consider three underlying structures on the two datasets: Pendulum in Figure 2(a), CelebA-Smile in Figure 2(b), and CelebA-Attractive in Figure 2(c).



The Pendulum dataset will be released soon as a causal disentanglement benchmark. Note that attractiveness here refers only to the attribute defined by the dataset labels and has nothing to do with the everyday meaning of the word.



Figure 1: Model structure of a bidirectional generative model (BGM) with an SCM prior.

Algorithm 1 (excerpt): DEAR training
Input: training set {x_1, ..., x_N, y_1, ..., y_{N_s}}, initial parameters φ, θ, β, ψ, batch size n
1: while not converged do
2:   for multiple steps do
3:     Sample {x_1, ..., x_n} from the training set and {ε_1, ..., ε_n} from N(0, I)
4:     Generate from the causal prior: z_i = F_β(ε_i), i = 1, ..., n
5:     Update ψ by descending its stochastic gradient
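The sampling steps above can be sketched as follows. All networks are replaced by toy linear maps, so this only illustrates the data flow of one iteration (noise, causal prior, generator), not the actual DEAR networks or update rule.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy dimensions: k latent factors, d-dimensional data, batch size n.
k, d, n = 4, 8, 16

def F_beta(eps, beta):
    # Causal prior transform z = F_beta(eps); a linear map stands in for the SCM.
    return eps @ beta

def G_theta(z, theta):
    # Generator; a linear map stands in for the SAGAN generator.
    return z @ theta

def one_training_iteration(x_batch, beta, theta):
    # Step 3: sample noise from N(0, I)
    eps = rng.standard_normal((n, k))
    # Step 4: generate latents from the causal prior
    z = F_beta(eps, beta)
    # Step 5: the discriminator parameters psi would be updated here by
    # descending the stochastic gradient of the GAN loss comparing
    # (x_batch, E(x_batch)) with (G(z), z); this sketch only returns the pair.
    return G_theta(z, theta), z

beta = rng.standard_normal((k, k))
theta = rng.standard_normal((k, d))
x_fake, z = one_training_iteration(rng.standard_normal((n, d)), beta, theta)
```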

Figure 2: Underlying causal structures.

Figure 3: Results of causal controllable generation on Pendulum.

A.2 PROOF OF THEOREM 1

Lemma 3. For any a, b ∈ R (a < b), the set P of continuous piece-wise linear functions is dense in C[a, b] under the metric d(f, g) = sup_{x∈[a,b]} |f(x) − g(x)|, where P and P_h are defined at the beginning of the proof of Lemma 3.

(a) CelebA-Smile (b) CelebA-Attractive (c) Pendulum

Figure 5: Learned adjacency matrices for different underlying structures.

Figure 7: (a) Architecture of the discriminator D(x, z); (b) A residual block (up scale) in the SAGAN generator where we use nearest neighbor interpolation for Upsampling; (c) A residual block (down scale) in the SAGAN discriminator.

(a) Generator module
Input: z ∈ R^k ∼ N(0, I)
Linear → 4 × 4 × 16ch
ResBlock up 16ch → 16ch
ResBlock up 16ch → 8ch
ResBlock up 8ch → 4ch
Non-Local Block (64 × 64)
ResBlock up 4ch → 2ch
BN, ReLU, 3 × 3 Conv 2ch → 3
Tanh

(b) Discriminator module D_x
Input: RGB image x ∈ R^{64×64×3}
ResBlock down ch → 2ch
Non-Local Block (64 × 64)
ResBlock down 2ch → 4ch
ResBlock down 4ch → 8ch
ResBlock down 8ch → 16ch
ResBlock 16ch → 16ch
ReLU, Global average pooling (f_x)
Linear → 1 (s_x)
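A small sanity check (ours, not from the paper) that the four up-sampling residual blocks in the generator column take the initial 4 × 4 feature map to the 64 × 64 output resolution:

```python
# Each "ResBlock up" doubles the spatial resolution. Starting from the
# 4 x 4 map produced by the initial Linear layer, the four up-sampling
# blocks reach the 64 x 64 output resolution of Table 3.
res = 4
channel_multipliers = [16, 16, 8, 4, 2]  # in units of ch, generator column
for _ in range(4):  # four ResBlock-up stages
    res *= 2
print(res)  # 64
```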


Figure 8: Results of DEAR. Note that the ordering of the representations matches that of the indices in Figure 2. On the left we show the traditional latent traversals (the first type of intervention stated in Section 5.1). On the right we show the results of intervening on one latent variable, from which we see the consequent changes of the others (the second type of intervention). Specifically, intervening on a cause variable influences its effect variables, while intervening on effect variables makes no difference to the causes.

Sample efficiency and test accuracy with different training sample sizes. DEAR-lin and DEAR-nlr denote the model with linear and nonlinear f, respectively. Line 1 is unsupervised; lines 2-3 are semi-supervised; the others are supervised.

Distributional robustness. The worst-case and average test accuracy. We manipulate the training data to inject spurious correlations between the target label and a spurious attribute. On CelebA, we regard mouth open as the spurious factor; on Pendulum, we choose background color ∈ {blue (+), white (−)}. We manipulate the training data so that the target label is strongly correlated with the spurious attribute: for 80% of the examples the target label and the spurious attribute are both positive or both negative, while for the remaining 20% they are opposite. For example, in the manipulated training set, 80% of smiling examples in CelebA have an open mouth, and 80% of corrupted examples in Pendulum are masked with a blue background. The test set does not have these correlations, leading to a distribution shift.
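The 80/20 manipulation can be sketched as follows. `inject_spurious_correlation` is a hypothetical helper of ours, not code from the paper; it draws a binary spurious attribute that agrees with the target label for roughly 80% of examples.

```python
import numpy as np

rng = np.random.default_rng(0)

def inject_spurious_correlation(y, rho=0.8, rng=rng):
    """Given binary target labels y, draw a spurious attribute s that agrees
    with y for a fraction rho of examples (0.8 in the manipulated training
    sets) and disagrees for the rest."""
    agree = rng.random(y.shape) < rho
    return np.where(agree, y, 1 - y)

y = rng.integers(0, 2, size=100_000)   # e.g., smiling / not smiling
s = inject_spurious_correlation(y)     # e.g., mouth open / closed
rate = float(np.mean(s == y))          # close to 0.8 in the training set
```

At test time one would simply draw s independently of y (rho=0.5), producing the distribution shift described above.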


Table 3: SAGAN architecture (k = 100 and ch = 32).

A.3 PROOF OF LEMMA 1

We follow the same proof scheme as in Shen et al. (2020); the only difference lies in the gradient with respect to the prior parameter β. To make this paper self-contained, we restate some proof steps here using our notation. Let ‖·‖ denote the vector 2-norm. For a scalar function h(x, y), let ∇_x h(x, y) denote its gradient with respect to x. For a vector function g(x, y), let ∇_x g(x, y) denote its Jacobian matrix with respect to x. Given a differentiable vector function g(x): R^k → R^k, we use ∇·g(x) to denote its divergence, defined as
$$\nabla \cdot g(x) = \sum_{j=1}^{k} \frac{\partial [g(x)]_j}{\partial [x]_j},$$
where [x]_j denotes the j-th component of x. We know that ∫ ∇·g(x) dx = 0 for all vector functions g such that g(∞) = 0. Given a matrix function w(x) = (w_1(x), ..., w_l(x)): R^k → R^{k×l}, where each w_i(x), i = 1, ..., l, is a k-dimensional differentiable vector function, its divergence is defined as
$$\nabla \cdot w(x) = (\nabla \cdot w_1(x), \ldots, \nabla \cdot w_l(x))^{\top}.$$
To prove Lemma 1, we need the following lemma, which specifies the dynamics of the generator joint distribution p_g(x, z) and the encoder joint distribution p_e(x, z), denoted by p_θ(x, z) and p_φ(x, z) here.

Lemma 2. Using the definitions and notations in Lemma 1, the dynamics (8)-(10) of p_θ(x, z) and p_φ(x, z) hold for all data x and latent variables z.

Proof of Lemma 2. We only prove (10), which is the part distinct from Shen et al. (2020). Let l be the dimension of the parameter β. To simplify notation, let the random vectors be Z = F_β(ε), X = G(Z) ∈ R^d and Y = (X, Z) ∈ R^{d+k}, and let p be the probability density of Y. For each i = 1, ..., l, let ∆ = δe_i, where e_i is the l-dimensional unit vector whose i-th component is one and all the others are zero, and δ is a small scalar. Let Y′ be defined as Y with β replaced by β + ∆, and let p′ be the probability density of Y′. For an arbitrary y′, we expand the density of Y′ to first order in δ.

APPENDIX C IMPLEMENTATION DETAILS

In this section we state the details of the experimental setup and the network architectures used for all experiments.

Preprocessing and hyperparameters. We pre-process the images by taking a center crop of 128 × 128 for CelebA and resizing all images in CelebA and Pendulum to the 64 × 64 resolution. We adopt Adam with β_1 = 0, β_2 = 0.999, and a learning rate of 1 × 10^{-4} for D, 5 × 10^{-5} for E, G and F, and 1 × 10^{-3} for the adjacency matrix A. We use a mini-batch size of 128. For adversarial training in Algorithm 1, we train D once on each mini-batch. The coefficient λ of the supervised regularizer is set to 5. We use the CE supervised loss both for CelebA, with binary observations of the underlying factors, and for Pendulum, with bounded continuous observations; note that the L_2 loss works comparably to the CE loss on Pendulum. In downstream tasks, for BGMs with an encoder, we train a two-layer MLP classifier with 100 hidden nodes using Adam with a learning rate of 1 × 10^{-2} and a mini-batch size of 128. Models were trained for around 150 epochs on CelebA and 600 epochs on Pendulum on an NVIDIA RTX 2080 Ti GPU.

Network architectures. We follow the architectures used in Shen et al. (2020). Specifically, for such realistic data, we adopt the SAGAN (Zhang et al., 2019) architecture for D and G. The D network consists of three modules, as shown in Figure 7 and described in detail in Shen et al. (2020). Details of the networks G and D_x are given in Figure 7 and Table 3. The encoder architecture is ResNet50 (He et al., 2016) followed by a 4-layer MLP of width 1024.

Implementation of the SCM. Recall the nonlinear SCM prior z = f((I − A^⊤)^{-1} h(ε)) from Section 4.1. We find Gaussians expressive enough as unexplained noises, so we set h to the identity mapping. As mentioned in Section 4.1, we require the invertibility of f. We implement both linear and nonlinear versions. For a linear f, we take f(v) = Wv + b, where W and b are learnable weights and biases; note that W is a diagonal matrix to model the element-wise transformation.
Its inverse function can be easily computed as f^{-1}(z) = W^{-1}(z − b). For a non-linear f, we use piece-wise linear functions defined by
$$f^{(i)}(v) = k^{(i)} + \sum_{j=0}^{N_a - 1} w_j^{(i)}\,(v - a_j)\,\mathbb{1}(v \ge a_j),$$
where ·^{(i)} denotes the i-th dimension of a vector or vector-valued function, a_0 < a_1 < ⋯ < a_{N_a} are the points of division, and 1(·) is the indicator function. By the denseness shown in Lemma 3, the family of such piece-wise linear functions is expressive enough to model general element-wise nonlinear invertible transformations.

Experimental details for baseline methods. We reproduce the S-VAEs, including S-VAE, S-β-VAE and S-TCVAE, using E and G with the same architecture as ours, and adopt the same optimization algorithm for training. The coefficient for the independence regularizer is set to 4, since we notice that a larger independence regularizer hurts disentanglement in the correlated case. For the supervised regularizer, we use λ = 1000 to balance the generative model and the supervision. The ERM ResNet is trained using the same optimizer with a learning rate of 1 × 10^{-4}.
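A minimal sketch of such an element-wise piece-wise linear invertible map and its inverse, using np.interp on the interior and linear extrapolation outside the division points. The knots and slopes below are illustrative choices, not the learned parameters; strictly positive slopes guarantee monotonicity and hence invertibility.

```python
import numpy as np

knots = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])   # a_0 < ... < a_{Na}
slopes = np.array([0.5, 1.0, 2.0, 1.0])         # one positive slope per piece
vals = np.concatenate([[0.0], np.cumsum(slopes * np.diff(knots))])

def f(x):
    # Continuous piece-wise linear map, extended linearly outside [a_0, a_{Na}].
    return np.interp(x, knots, vals) + np.where(
        x < knots[0], slopes[0] * (x - knots[0]),
        np.where(x > knots[-1], slopes[-1] * (x - knots[-1]), 0.0))

def f_inv(y):
    # Inverse obtained by swapping the roles of the knots and their values.
    return np.interp(y, vals, knots) + np.where(
        y < vals[0], (y - vals[0]) / slopes[0],
        np.where(y > vals[-1], (y - vals[-1]) / slopes[-1], 0.0))

x = np.linspace(-3.0, 3.0, 101)
roundtrip_ok = np.allclose(f_inv(f(x)), x, atol=1e-8)
```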

