EFFICIENTLY DISENTANGLE CAUSAL REPRESENTATIONS

Abstract

In this paper, we propose a novel approach to efficiently learning disentangled representations with causal mechanisms, based on the difference of conditional probabilities between the original and new distributions. We approximate this difference with the model's generalization ability, so that it fits in the standard machine learning framework and can be computed efficiently. In contrast to the state-of-the-art approach, which relies on the learner's adaptation speed to the new distribution, the proposed approach only requires evaluating the generalization ability of the model. We provide a theoretical explanation for the advantage of the proposed method, and our experiments show that the proposed technique is 1.9-11.0× more sample efficient and 9.4-32.4× faster than the previous method on various tasks. The source code is provided in the supplementary material.

1. INTRODUCTION

Causal reasoning is a fundamental tool that has shown great impact in different disciplines (Rubin & Waterman, 2006; Ramsey et al., 2010; Rotmensch et al., 2017), with roots in the work of David Hume in the eighteenth century (Hume, 2003) and in classical AI (Pearl, 2003). Causality has mainly been studied from a statistical perspective (Pearl, 2009; Peters et al., 2016; Greenland et al., 1999; Pearl, 2018), with Judea Pearl's work on the causal calculus leading its statistical development. More recently, there has been growing interest in integrating statistical causal techniques into machine learning to leverage their benefits. Welling raises a particular question about how to disentangle correlation from causation in machine learning settings so as to take advantage of the sample efficiency and generalization abilities of causal reasoning (Welling, 2015). Although machine learning has achieved important results on a variety of tasks like computer vision and games over the past decade (e.g., Mnih et al. (2015); Silver et al. (2017); Szegedy et al. (2017); Hudson & Manning (2018)), current approaches can struggle to generalize when the test data distribution differs substantially from the training distribution (as is common in real applications). Further, these successful methods are typically "data hungry", requiring an abundance of labeled examples to perform well across data distributions. In statistical settings, encoding the causal structure in models has been shown to have significant efficiency advantages. In support of the advantages of encoding causal mechanisms, Bengio et al. (2020) recently introduced an approach to disentangling causal relationships in end-to-end machine learning by comparing the adaptation speeds of separate models that encode different causal structures. With this as the baseline, in this paper we propose a more efficient approach to learning disentangled representations with causal mechanisms, based on the difference of conditional probabilities between the original and new distributions.

The key idea is to approximate this difference with the model's generalization ability, so that it fits in the standard machine learning framework and can be computed efficiently. In contrast to the state-of-the-art baseline approach, which relies on the learner's adaptation speed to the new distribution, the proposed approach only requires evaluating the generalization ability of the model. Our method is based on the same assumption as the baseline: that the conditional distribution P(B|A) does not change between the training and transfer distributions. This assumption can be illustrated with an atmospheric physics example of learning P(A, B), where A (Altitude) causes B (Temperature) (Peters et al., 2016). The marginal probability of A can change depending on, for instance, the place (from Switzerland to a less mountainous country like the Netherlands), but P(B|A) remains unchanged.


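As a minimal sketch of this invariance assumption (not the paper's implementation; the linear mechanism, variable names, and numbers below are hypothetical choices for illustration), the following Python snippet simulates an A → B mechanism, shifts only the marginal P(A) between the training and transfer distributions, and compares how models fit in the causal direction (B given A) and the anti-causal direction (A given B) generalize:

```python
import numpy as np

rng = np.random.default_rng(0)

def sample(n, a_mean):
    # A causes B: the mechanism B = -A + noise fixes P(B|A),
    # while the marginal P(A) can be shifted independently.
    a = rng.normal(a_mean, 1.0, n)     # P(A): changes between distributions
    b = -a + rng.normal(0.0, 1.0, n)   # P(B|A): invariant mechanism
    return a, b

# Training distribution (high altitudes, e.g., Switzerland)
a_tr, b_tr = sample(100_000, a_mean=3.0)
# Transfer distribution (low altitudes, e.g., the Netherlands)
a_te, b_te = sample(100_000, a_mean=0.0)

# Causal direction: predict B from A, fit on training data only
w_c = np.polyfit(a_tr, b_tr, 1)
mse_causal = np.mean((b_te - np.polyval(w_c, a_te)) ** 2)

# Anti-causal direction: predict A from B, fit on training data only
w_ac = np.polyfit(b_tr, a_tr, 1)
mse_anticausal = np.mean((a_te - np.polyval(w_ac, b_te)) ** 2)

print(mse_causal, mse_anticausal)
```

Because P(B|A) is invariant, the causal-direction model transfers with an error close to the mechanism's noise variance, while the anti-causal model, whose fit implicitly depends on the shifted marginal P(A), degrades on the transfer distribution. In the spirit of the approach described above, it is this kind of zero-shot generalization gap, rather than the adaptation speed used by the baseline, that signals the correct causal structure.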