LEARNING TO INDUCE CAUSAL STRUCTURE

Abstract

One of the fundamental challenges in causal induction is to infer the underlying graph structure given observational and/or interventional data. Most existing causal induction algorithms operate by generating candidate graphs and evaluating them using either score-based methods (including continuous optimization) or independence tests. In our work, we instead treat the inference process as a black box and design a neural network architecture that learns the mapping from both observational and interventional data to graph structures via supervised training on synthetic graphs. The learned model generalizes to new synthetic graphs, is robust to train-test distribution shifts, and achieves state-of-the-art performance on naturalistic graphs in the low-sample regime.

1. INTRODUCTION

The problem of discovering the causal relationships that govern a system by observing its behavior, either passively (observational data) or by manipulating some of its variables (interventional data), lies at the core of many important challenges spanning scientific disciplines, including medicine, biology, and economics. Using the graphical formalism of causal Bayesian networks (CBNs) (Koller & Friedman, 2009; Pearl, 2009), this problem can be framed as inducing the graph structure that best represents the relationships. Most approaches to causal structure induction (or causal discovery) are based on an unsupervised learning paradigm in which the structure is inferred directly from the system observations, either by ranking different structures according to some metric (score-based approaches) or by determining the presence of an edge between pairs of variables using conditional independence tests (constraint-based approaches) (Drton & Maathuis, 2017; Glymour et al., 2019; Heinze-Deml et al., 2018a;b) (see Fig. 1(a)). The unsupervised paradigm, however, poses some challenges. Score-based approaches are burdened with the high computational cost of explicitly considering all possible structures and with the difficulty of devising metrics that balance goodness of fit against constraints for differentiating causal from purely statistical relationships (e.g. sparsity of the structure or simplicity of the generation mechanism). Constraint-based methods are sensitive to failures of independence tests and require faithfulness, a property that does not hold in many real-world scenarios (Koski & Noble, 2012; Mabrouk et al., 2014). Recently, supervised learning methods based on observational data have been introduced as an alternative to unsupervised approaches (Lopez-Paz et al., 2015a;b; Li et al., 2020). In this work, we extend the supervised learning paradigm to also use interventional data, enabling greater flexibility.
We propose a model that is first trained on synthetic data generated by different CBNs to learn a mapping from data to graph structures, and is then used to induce the structures underlying datasets of interest (see Fig. 1(b)). The model is a novel variant of a transformer neural network that receives as input a dataset consisting of observational and interventional samples corresponding to the same CBN and outputs a prediction of the CBN's graph structure. The mapping from the dataset to the underlying structure is achieved through an attention mechanism that alternates between attending to different variables in the graph and to different samples of a variable. The output is produced by a decoder mechanism that operates as an autoregressive generative model over the inferred structure. Our approach can be viewed as a form of meta-learning, whereby the relationship between datasets and the causal structures underlying them is learned rather than built in. A requirement of a supervised approach would seem to be that the distributions of the training and test data match or highly overlap; obtaining real-world training data with known causal structure that matches test data from multiple domains is extremely challenging. We show that meta-learning enables the model to generalize well to data from naturalistic CBNs even when trained on synthetic data with relatively few assumptions, where naturalistic CBNs are graphs that correspond to causal relationships that exist in nature, such as graphs from the bnlearn repository (www.bnlearn.com/bnrepository). We show that our model can learn a mapping from datasets to structures and achieves state-of-the-art performance on classic benchmarks such as the Sachs, Asia, and Child datasets (Lauritzen & Spiegelhalter, 1988; Sachs et al., 2005; Spiegelhalter & Cowell, 1992), despite never being trained directly on such data.
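The alternating attention described above can be sketched in a few lines of numpy. This is an illustrative toy, not the paper's implementation: learned query/key/value projections, multiple heads, residual connections, and layer normalization are all omitted, and the names (`attend`, `h`) are hypothetical.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attend(h, axis):
    """Single-head self-attention along one axis of a
    (samples, variables, features) tensor. For brevity the queries,
    keys, and values are the inputs themselves (no learned weights)."""
    h = np.moveaxis(h, axis, -2)               # bring target axis next to features
    scores = h @ h.swapaxes(-1, -2) / np.sqrt(h.shape[-1])
    out = softmax(scores, axis=-1) @ h
    return np.moveaxis(out, -2, axis)

# Hypothetical dataset encoding: S samples, N variables, d features each.
S, N, d = 16, 5, 8
h = np.random.default_rng(0).normal(size=(S, N, d))

# One "block" alternates attention over variables (within a sample)
# and over samples (within a variable), mirroring the described mechanism.
h = attend(h, axis=1)   # attend across the N variables
h = attend(h, axis=0)   # attend across the S samples
print(h.shape)          # (16, 5, 8)
```

The key design point is that the per-sample pass can relate variables to one another, while the per-variable pass aggregates evidence about each variable across the whole dataset; stacking such blocks lets information about candidate edges accumulate over samples.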
Our contributions can be summarized as follows:
• We tackle causal structure induction with a supervised approach (CSIvA) that maps datasets composed of both observational and interventional samples to structures.
• We introduce a variant of the transformer architecture whose attention mechanism is structured to discover relationships among variables across samples.
• We show that CSIvA generalizes to novel structures, whether or not training and test distributions match. Most importantly, training on synthetic data transfers effectively to naturalistic CBNs.
• We show that CSIvA significantly outperforms state-of-the-art causal structure induction methods such as DCDI (Brouillard et al., 2020) and ENCO (Lippe et al., 2021) on various types of synthetic CBNs as well as on naturalistic CBNs.

2. BACKGROUND

In this section we give some background on causal Bayesian networks and transformer neural networks, which form the main ingredients of our approach (see Appendix A.1 for more details).

Causal Bayesian networks (CBNs). A causal Bayesian network (Koller & Friedman, 2009; Pearl, 2009) is a pair M = ⟨G, p⟩, where G is a directed acyclic graph (DAG) whose nodes X_1, ..., X_N represent random variables and whose edges express causal dependencies among them, and where p is a joint distribution over all nodes that factorizes as p(X_1, ..., X_N) = \prod_{n=1}^{N} p(X_n | pa(X_n)), where pa(X_n) denotes the parents of X_n, i.e. the nodes with an edge onto X_n (the direct causes of X_n). An input to the transformer neural network is formed by a dataset D = {x^s}_{s=1}^{S}, where x^s := (x^s_1, ..., x^s_N)^T is either an observational data sample or an interventional data sample obtained by performing an intervention on a randomly selected node in G. Observational data samples are samples from p(X_1, ..., X_N). Except where otherwise noted, in all experimental settings we considered hard interventions on a node X_{n'}, which consist in replacing the conditional probability distribution (CPD) p(X_{n'} | pa(X_{n'})) with a delta function \delta_{X_{n'}=x} forcing X_{n'} to take on the value x. Additional experiments were also performed using soft interventions, which consist in replacing p(X_{n'} | pa(X_{n'})) with a different CPD p'(X_{n'} | pa(X_{n'})). An interventional data sample is a sample from \delta_{X_{n'}=x} \prod_{n=1, n \neq n'}^{N} p(X_n | pa(X_n)) in the first case, and a sample from p'(X_{n'} | pa(X_{n'})) \prod_{n=1, n \neq n'}^{N} p(X_n | pa(X_n)) in the second case. The structure of G can be repre-
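As a concrete illustration of these definitions, the following sketch draws observational samples and hard-interventional samples from a toy three-node chain CBN X1 → X2 → X3. The graph, the CPD tables, and the function name `sample` are invented for illustration and are not from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample(n_samples, intervene_on=None, forced_value=None):
    """Ancestral sampling from the chain X1 -> X2 -> X3.
    A hard intervention on node index `intervene_on` replaces its CPD
    with a delta function at `forced_value`."""
    data = np.zeros((n_samples, 3), dtype=int)
    for s in range(n_samples):
        x1 = rng.random() < 0.3                     # p(X1=1) = 0.3 (no parents)
        if intervene_on == 0:
            x1 = forced_value
        x2 = rng.random() < (0.8 if x1 else 0.1)    # p(X2=1 | X1)
        if intervene_on == 1:
            x2 = forced_value
        x3 = rng.random() < (0.7 if x2 else 0.2)    # p(X3=1 | X2)
        if intervene_on == 2:
            x3 = forced_value
        data[s] = [x1, x2, x3]
    return data

obs = sample(1000)                                   # observational samples
do_x2 = sample(1000, intervene_on=1, forced_value=1) # samples under do(X2 = 1)

# Intervening on X2 shifts the distribution of its descendant X3
# but leaves its ancestor X1 unchanged.
print(obs[:, 0].mean(), do_x2[:, 0].mean())   # similar
print(obs[:, 2].mean(), do_x2[:, 2].mean())   # different
```

A dataset D mixing rows of `obs` and `do_x2` (with the intervened node indicated) is exactly the kind of input the transformer receives: the contrast between the two regimes is what carries the directional information that purely observational data cannot provide.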



DeepMind, Mila, Polytechnique Montreal, University of Montreal, Google Research, Brain Team, Corresponding author: nke@google.com



Figure 1: (A). Standard unsupervised approach to causal structure induction: Algorithms use a predefined scoring metric or statistical independence tests to select the best candidate structures. (B). Our supervised approach to causal structure induction (via attention; CSIvA): A model is presented with data and structures as training pairs and learns a mapping between them.

