EFFICIENTLY DISENTANGLE CAUSAL REPRESENTATIONS

Abstract

In this paper, we propose a novel approach to efficiently learning disentangled representations with causal mechanisms, based on the difference of conditional probabilities between the original and new distributions. We approximate this difference with the model's generalization ability so that it fits into the standard machine learning framework and can be computed efficiently. In contrast to the state-of-the-art approach, which relies on the learner's speed of adaptation to the new distribution, the proposed approach only requires evaluating the generalization ability of the model. We provide a theoretical explanation for the advantage of the proposed method, and our experiments show that the proposed technique is 1.9-11.0× more sample efficient and 9.4-32.4× quicker than the previous method on various tasks. The source code is included in the supplementary material.

1. INTRODUCTION

Causal reasoning is a fundamental tool that has shown great impact in different disciplines (Rubin & Waterman, 2006; Ramsey et al., 2010; Rotmensch et al., 2017), and it has roots in the work of David Hume in the eighteenth century (Hume, 2003) and in classical AI (Pearl, 2003). Causality has mainly been studied from a statistical perspective (Pearl, 2009; Peters et al., 2016; Greenland et al., 1999; Pearl, 2018), with Judea Pearl's work on the causal calculus leading its statistical development. More recently, there has been growing interest in integrating these statistical techniques into machine learning to leverage their benefits. Welling (2015) raises the particular question of how to disentangle correlation from causation in machine learning settings, in order to take advantage of the sample efficiency and generalization abilities of causal reasoning. Although machine learning has achieved important results on a variety of tasks such as computer vision and games over the past decade (e.g., Mnih et al. (2015); Silver et al. (2017); Szegedy et al. (2017); Hudson & Manning (2018)), current approaches can struggle to generalize when the test data distribution differs substantially from the training distribution, as is common in real applications. Further, these successful methods are typically "data hungry", requiring an abundance of labeled examples to perform well across data distributions. In statistical settings, encoding the causal structure in models has been shown to have significant efficiency advantages. In support of the advantages of encoding causal mechanisms, Bengio et al. (2020) recently introduced an approach to disentangling causal relationships in end-to-end machine learning by comparing the adaptation speeds of separate models that encode different causal structures.
With this as the baseline, in this paper we propose a more efficient approach to learning disentangled representations with causal mechanisms, based on the difference of conditional probabilities between the original and new distributions. The key idea is to approximate this difference with the model's generalization ability so that it fits into the standard machine learning framework and can be computed efficiently. In contrast to the state-of-the-art baseline approach, which relies on the learner's speed of adaptation to the new distribution, the proposed approach only requires evaluating the generalization ability of the model. Our method is based on the same assumption as the baseline: that the conditional distribution P(B|A) does not change between the training and transfer distributions. This assumption can be explained with an atmospheric physics example of learning P(A, B), where A (Altitude) causes B (Temperature) (Peters et al., 2016). The marginal probability of A can change depending on, for instance, the place (e.g., from Switzerland to a less mountainous country like the Netherlands), but P(B|A), as the underlying causal mechanism, does not change. Therefore, the causal structure can be inferred from the robustness of predictive models on out-of-distribution data (Peters et al., 2016; 2017). The proposed method is more efficient because it omits the adaptation process, and it is more robust when the marginal distribution is complicated. We provide a theoretical explanation and experimental verification of the advantage of the proposed method. Our experiments show that the proposed technique is 1.9-11.0× more sample efficient and 9.4-32.4× quicker than measuring adaptation speed on various tasks. We also argue that the proposed approach has fewer hyperparameters and is straightforward to implement within standard machine learning workflows. Our contributions can be summarized as follows.
• We propose an efficient approach to disentangling representations for causal mechanisms by measuring generalization.
• We theoretically prove that the proposed estimators can identify the causal direction and disentangle causal mechanisms.
• We empirically show that the proposed approach is significantly quicker and more sample efficient on various tasks. Sample efficiency is important when the data size in the transfer distribution is small.
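The invariance assumption behind this work, illustrated above with the altitude-temperature example, can be checked in a toy simulation. The sketch below (our own illustration, not the paper's code; the mechanism f and the marginals are arbitrary choices) shifts the marginal over A between two distributions while keeping the mechanism generating B from A fixed, and confirms that the empirical conditional P(B|A) barely moves:

```python
import numpy as np

rng = np.random.default_rng(0)

def sample(marginal_a, n=100_000):
    """Sample (A, B) pairs: A ~ marginal_a, B = A + binary noise (fixed mechanism)."""
    a = rng.choice(len(marginal_a), size=n, p=marginal_a)
    b = a + rng.integers(0, 2, size=n)  # the mechanism P(B|A) never changes
    return a, b

def conditional(a, b, na, nb):
    """Empirical P(B|A) as an (na, nb) table."""
    joint = np.zeros((na, nb))
    np.add.at(joint, (a, b), 1)
    return joint / joint.sum(axis=1, keepdims=True)

# Two different marginals over A (e.g., two countries with different altitudes)
a1, b1 = sample(np.array([0.2, 0.3, 0.5]))
a2, b2 = sample(np.array([0.7, 0.2, 0.1]))

c1 = conditional(a1, b1, 3, 4)
c2 = conditional(a2, b2, 3, 4)
print(np.abs(c1 - c2).max())  # remains small: P(B|A) is invariant across the shift
```

The two estimated conditional tables agree up to sampling noise even though the marginals differ sharply, which is exactly the invariance that both the baseline and the proposed method exploit.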

2. APPROACH

To begin, we review the tasks and the disentangling approach (our baseline) described in previous work (Bengio et al., 2020). The invariance of the conditional distribution P(B|A) for the correct causal direction is the key assumption in their work, and we follow it here as well. Their approach compares the adaptation speed of models on a transfer data distribution, and hence requires significant time for adaptation. We instead propose to learn causal mechanisms by directly measuring the changes in the conditional probabilities P(B|A) and P(A|B) before and after an intervention. Further, we formulate the proposed approach in terms of generalization loss rather than a divergence metric, because loss can be directly measured in standard machine learning workflows, and we show that it is likely to correctly predict the causal direction and disentangle causal mechanisms.
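A minimal sketch of this idea for two discrete variables (our own illustrative code, not the authors' implementation; the ground-truth mechanism and the choice of marginals are assumptions made for the example): fit both conditionals on training data, then compare their generalization losses on transfer data with no adaptation at all. The direction whose conditional is invariant should keep a low loss, while the anti-causal conditional degrades because it depends on the shifted marginal:

```python
import numpy as np

rng = np.random.default_rng(1)
N = 10  # number of discrete values for A and B

def sample(p_a, n=50_000):
    """Ground truth A -> B: A ~ p_a, B = (A + uniform noise in {0,1,2}) mod N."""
    a = rng.choice(N, size=n, p=p_a)
    b = (a + rng.integers(0, 3, size=n)) % N
    return a, b

def fit_conditional(x, y):
    """Empirical P(y|x) table with Laplace smoothing."""
    joint = np.ones((N, N))
    np.add.at(joint, (x, y), 1)
    return joint / joint.sum(axis=1, keepdims=True)

def nll(cond, x, y):
    """Average negative log-likelihood of y given x (generalization loss)."""
    return -np.log(cond[x, y]).mean()

p_train = 0.5 ** np.arange(N)          # skewed marginal over A at training time
p_train /= p_train.sum()
p_transfer = np.ones(N) / N            # intervention: uniform marginal over A

a_tr, b_tr = sample(p_train)
a_te, b_te = sample(p_transfer)

p_b_given_a = fit_conditional(a_tr, b_tr)   # A -> B model
p_a_given_b = fit_conditional(b_tr, a_tr)   # B -> A model

loss_ab = nll(p_b_given_a, a_te, b_te)  # stays near the noise entropy log 3
loss_ba = nll(p_a_given_b, b_te, a_te)  # degrades: P(A|B) shifts with P(A)
print(f"A->B loss {loss_ab:.3f}, B->A loss {loss_ba:.3f}")
print("predict A -> B" if loss_ab < loss_ba else "predict B -> A")
```

No gradient steps are taken on the transfer data; a single evaluation pass separates the two directions, which is the efficiency gain over measuring adaptation speed.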

2.1. CAUSALITY DIRECTION PREDICTION

As the first step towards learning disentangled representations for causal mechanisms, we start with a binary classification task: given two discrete variables A and B, determine whether A causes B, or vice versa. We assume noiseless dynamics and that A and B have no hidden confounders. The training (transfer) data contains samples (a, b) from the training (transfer) distribution P_1 (P_2). The baseline approach defines models that factor the joint distribution P(A, B) according to the two causal directions, P_{A→B}(A, B) = P_{A→B}(B|A) P_{A→B}(A) and P_{B→A}(A, B) = P_{B→A}(A|B) P_{B→A}(B), and compares their speed of adaptation to the transfer distribution. Intuitively, the factorization with the correct causal direction should adapt more quickly to the transfer distribution. Suppose A → B is the ground-truth causal direction. For the correct factorization, they assume that the conditional distribution P_{A→B}(B|A) does not change between the training and transfer distributions, so that only the marginal P_{A→B}(A) needs adaptation. In contrast, for the factorization with the incorrect causal direction, both the conditional P_{B→A}(A|B) and the marginal P_{B→A}(B) need adaptation. They argue that updating the marginal distribution P_{A→B}(A) is likely to have lower sample complexity than updating the conditional distribution P_{B→A}(A|B) (only a part of P_{B→A}(A, B)) because the latter has more parameters. Therefore, the model with the correct factorization will adapt more quickly, and the causal direction can be predicted from adaptation speed. To leverage this observation, the baseline method defines a meta-learning objective. Let L_{A→B} and L_{B→A} be the log-likelihoods of P_{A→B}(A, B) and P_{B→A}(A, B), respectively. Their approach optimizes a regret, R, to acquire an indicator variable γ ∈ R. If γ > 0, the prediction is

