LEARNING TO INDUCE CAUSAL STRUCTURE

Abstract

One of the fundamental challenges in causal induction is to infer the underlying graph structure given observational and/or interventional data. Most existing causal induction algorithms operate by generating candidate graphs and evaluating them using either score-based methods (including continuous optimization) or independence tests. In our work, we instead treat the inference process as a black box and design a neural network architecture that learns the mapping from both observational and interventional data to graph structures via supervised training on synthetic graphs. The learned model generalizes to new synthetic graphs, is robust to train-test distribution shifts, and achieves state-of-the-art performance on naturalistic graphs for low sample complexity.

1. INTRODUCTION

The problem of discovering the causal relationships that govern a system by observing its behavior, either passively (observational data) or by manipulating some of its variables (interventional data), lies at the core of many important challenges spanning scientific disciplines, including medicine, biology, and economics. Using the graphical formalism of causal Bayesian networks (CBNs) (Koller & Friedman, 2009; Pearl, 2009), this problem can be framed as inducing the graph structure that best represents the relationships. Most approaches to causal structure induction (or causal discovery) are based on an unsupervised learning paradigm in which the structure is inferred directly from the system observations, either by ranking candidate structures according to some metric (score-based approaches) or by determining the presence of an edge between pairs of variables using conditional independence tests (constraint-based approaches) (Drton & Maathuis, 2017; Glymour et al., 2019; Heinze-Deml et al., 2018a;b) (see Fig. 1(a)). The unsupervised paradigm, however, poses some challenges: score-based approaches carry the high computational cost of explicitly considering all possible structures, and it is difficult to devise metrics that balance goodness of fit against constraints for differentiating causal from purely statistical relationships (e.g., sparsity of the structure or simplicity of the generation mechanism); constraint-based methods are sensitive to failures of independence tests and require faithfulness, a property that does not hold in many real-world scenarios (Koski & Noble, 2012; Mabrouk et al., 2014). Recently, supervised learning methods based on observational data have been introduced as an alternative to unsupervised approaches (Lopez-Paz et al., 2015a;b; Li et al., 2020). In this work, we extend the supervised learning paradigm to also use interventional data, enabling greater flexibility.
We propose a model that is first trained on synthetic data generated from different CBNs to learn a mapping from data to graph structures, and is then used to induce the structures underlying datasets of interest (see Fig. 1(b)). The model is a novel variant of a transformer neural network that receives as input a dataset of observational and interventional samples corresponding to the same CBN and outputs a prediction of the CBN graph structure. The mapping from the dataset to the underlying structure is achieved through an attention mechanism that alternates between attending to different variables in the graph and to different samples from a variable. The output is produced by a decoder that operates as an autoregressive generative model on the inferred structure. Our approach can be viewed as a form of meta-learning, whereby the relationship between datasets and the causal structures underlying them is learned rather than built in. We call this supervised approach causal structure induction via attention (CSIvA): a model is presented with data and structures as training pairs and learns a mapping between them. A seeming requirement of a supervised approach is that the distributions of the training and test data match or highly overlap; obtaining real-world training data with known causal structure that matches test data from multiple domains is extremely challenging. We show that meta-learning enables the model to generalize well to data from naturalistic CBNs even when trained on synthetic data with relatively few assumptions, where naturalistic CBNs are graphs that correspond to causal relationships found in nature, such as graphs from the bnlearn repository (www.bnlearn.com/bnrepository).
We show that our model can learn a mapping from datasets to structures and achieves state-of-the-art performance on classic benchmarks such as the Sachs, Asia and Child datasets (Lauritzen & Spiegelhalter, 1988; Sachs et al., 2005; Spiegelhalter & Cowell, 1992), despite never being trained directly on such data. Our contributions can be summarized as follows:

• We tackle causal structure induction with a supervised approach (CSIvA) that maps datasets composed of both observational and interventional samples to structures.
• We introduce a variant of the transformer architecture whose attention mechanism is structured to discover relationships among variables across samples.
• We show that CSIvA generalizes to novel structures, whether or not training and test distributions match. Most importantly, training on synthetic data transfers effectively to naturalistic CBNs.
• We show that CSIvA significantly outperforms state-of-the-art causal structure induction methods such as DCDI (Brouillard et al., 2020) and ENCO (Lippe et al., 2021), both on various types of synthetic CBNs and on naturalistic CBNs.

2. BACKGROUND

In this section we give some background on causal Bayesian networks and transformer neural networks, which form the main ingredients of our approach (see Appendix A.1 for more details).

Causal Bayesian networks (CBNs). A causal Bayesian network (Koller & Friedman, 2009; Pearl, 2009) is a pair M = ⟨G, p⟩, where G is a directed acyclic graph (DAG) whose nodes X_1, ..., X_N represent random variables and whose edges express causal dependencies among them, and where p is a joint distribution over all nodes that factorizes as p(X_1, ..., X_N) = ∏_{n=1}^N p(X_n | pa(X_n)), where pa(X_n) denotes the parents of X_n, i.e. the nodes with an edge onto X_n (the direct causes of X_n). An input to the transformer neural network is formed by a dataset D = {x^s}_{s=1}^S, where x^s := (x^s_1, ..., x^s_N)^T is either an observational data sample or an interventional data sample obtained by performing an intervention on a randomly selected node in G. Observational data samples are samples from p(X_1, ..., X_N). Except where otherwise noted, in all experimental settings we considered hard interventions on a node X_{n′}, which consist of replacing the conditional probability distribution (CPD) p(X_{n′} | pa(X_{n′})) with a delta function δ_{X_{n′}=x} forcing X_{n′} to take on value x. Additional experiments were also performed using soft interventions, which consist of replacing p(X_{n′} | pa(X_{n′})) with a different CPD p′(X_{n′} | pa(X_{n′})). An interventional data sample is thus a sample from δ_{X_{n′}=x} ∏_{n=1, n≠n′}^N p(X_n | pa(X_n)) in the first case, and from p′(X_{n′} | pa(X_{n′})) ∏_{n=1, n≠n′}^N p(X_n | pa(X_n)) in the second case. The structure of G can be represented by an adjacency matrix A, defined by setting entry A_{k,l} to 1 if there is an edge from X_l to X_k and to 0 otherwise. Therefore, the n-th row of A, denoted A_{n,:}, indicates the parents of X_n, while the n-th column, denoted A_{:,n}, indicates the children of X_n.

Transformer neural network.
A transformer (Devlin et al., 2018; Vaswani et al., 2017) is a neural network equipped with layers of self-attention, which makes it well suited to modeling structured data. In traditional applications, attention accounts for the sequentially ordered nature of the data, e.g. modeling a sentence as a stream of words. In our case, each input to the transformer is a dataset of observational and interventional samples corresponding to the same CBN. Attention is thus used to account for the structure induced by the CBN graph and by having different samples from the same node. Transformers are permutation invariant with respect to the positions of the input elements, ensuring that the graph structure prediction does not depend on node and sample position.
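To make the CBN sampling described earlier in this section concrete, here is a toy sketch (not the paper's code; the chain graph X1 → X2 → X3, the Bernoulli CPD parameters, and the intervention target are all illustrative assumptions) of drawing observational samples and hard-interventional samples:

```python
# Toy CBN X1 -> X2 -> X3 with Bernoulli CPDs; a hard intervention do(X2 = 1)
# replaces the CPD of X2 with a point mass, while all other CPDs are untouched.
import numpy as np

rng = np.random.default_rng(0)

def sample(intervened=None, value=None):
    """Draw one sample (x1, x2, x3); optionally force node `intervened` to `value`."""
    x1 = rng.binomial(1, 0.3)
    x2 = value if intervened == 2 else rng.binomial(1, 0.9 if x1 else 0.2)
    x3 = value if intervened == 3 else rng.binomial(1, 0.8 if x2 else 0.1)
    return x1, x2, x3

observational = [sample() for _ in range(5)]
interventional = [sample(intervened=2, value=1) for _ in range(5)]
assert all(s[1] == 1 for s in interventional)  # do(X2 = 1) holds in every sample
```

In the model's input, each such sample would additionally carry a flag marking it as observational (-1) or as an intervention on a given node index.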

3. CAUSAL STRUCTURE INDUCTION VIA ATTENTION (CSIVA)

Figure 2: Our model architecture and the structure of the input and output at training time. The input is a dataset D = {x^s := (x^s_1, ..., x^s_N)^T}_{s=1}^S of S samples from a CBN and its adjacency matrix A. The output is a prediction Â of A.

Our approach is to treat causal structure induction as a supervised learning problem, training a neural network to map observational and interventional data to the graph structure of the underlying CBN. Obtaining diverse real-world data with known causal relationships in amounts sufficient for supervised training is not feasible. The key contribution of this work is to introduce a method that uses synthetic data generated from CBNs with different graph structures and associated CPDs, and that is robust to shifts between the training and test data distributions.

3.1. SUPERVISED APPROACH

We learn a distribution over graphs conditioned on observational and interventional data as follows. We generate training data from a joint distribution t(G, D) between a graph G and a dataset D comprising S observational and interventional samples from a CBN associated with G. We first sample a set of graphs {G^i}_{i=1}^I with nodes X^i_1, ..., X^i_N from a common distribution t(G) as described in Section 4 (to simplify notation, in the remainder of the paper we omit the graph index i when referring to nodes), and then associate random CPDs to the graphs, also as described in Section 4. This results in a set of CBNs {M^i}_{i=1}^I. For each CBN M^i, we then create a dataset D^i = {x^s}_{s=1}^S, where each element x^s := (x^s_1, ..., x^s_N)^T is either an observational data sample or an interventional data sample obtained by performing an intervention on a randomly selected node in G^i. Our model defines a distribution t(G | D; Θ) over graphs conditioned on observational and interventional data, parametrized by Θ. Specifically, t(A | D; Θ) has the following autoregressive form: t(A | D; Θ) = ∏_{l=1}^{N²} σ(A_l; Â_l = f_Θ(A_{1,...,(l-1)}, D)), where σ(·; ρ) is the Bernoulli distribution with parameter ρ, and f_Θ is an encoder-decoder architecture (explained in Section 3.2) taking as input the previous elements of the adjacency matrix A (represented here as an array of N² elements) and D. The model is trained via maximum likelihood estimation (MLE), i.e. Θ* = argmin_Θ L(Θ), where L(Θ) = -E_{(G,D)∼t}[ln t(G | D; Θ)], which corresponds to the usual cross-entropy (CE) loss for the Bernoulli distribution. Training uses stochastic gradient descent (SGD), with each gradient update performed on a pair (D^i, A^i). The data-sampling distribution t(G, D) and the MLE objective uniquely determine the target distribution learned by the model: in the infinite-capacity case, t(· | D; Θ*) = t(· | D).
To see this, it suffices to note that the MLE objective can be written as L(Θ) = E_{D∼t}[KL(t(· | D) ∥ t(· | D; Θ))] + c, where KL is the Kullback-Leibler divergence and c is a constant (the expected entropy of t(· | D)). In the finite-capacity case, the distribution t(· | D; Θ*) defined by the model is only an approximation of t(· | D).
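As a minimal illustration of the objective, the per-example loss is the cross-entropy between the true flattened adjacency matrix and the per-entry Bernoulli parameters (here, `probs` is a hypothetical stand-in for the outputs of f_Θ; the numbers are toy values):

```python
# Cross-entropy (negative log-likelihood) of a flattened adjacency matrix
# under per-entry Bernoulli parameters, as in the MLE objective above.
import numpy as np

def cross_entropy_loss(A_flat, probs):
    probs = np.clip(probs, 1e-9, 1 - 1e-9)  # guard against log(0)
    return -np.sum(A_flat * np.log(probs) + (1 - A_flat) * np.log(1 - probs))

A_flat = np.array([0, 1, 0, 0])          # flattened 2x2 adjacency matrix
probs  = np.array([0.1, 0.9, 0.2, 0.1])  # model's per-entry Bernoulli parameters
loss = cross_entropy_loss(A_flat, probs)
assert loss > 0  # loss shrinks as probs concentrate on the true entries
```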

3.2. MODEL ARCHITECTURE

The function f_Θ defining the model's probabilities is built from two transformer networks: an encoder transformer and a decoder transformer (which we refer to as "encoder" and "decoder" for short). At training time, the encoder receives as input a dataset D^i and outputs a representation that summarizes the relationships between nodes in G^i. The decoder then recursively outputs predictions of the elements of the adjacency matrix A^i, using as input the elements previously predicted and the encoder output. This is shown in Fig. 2 (where we omit the index i, as in the remainder of the section). At test time, we obtain deterministic predictions of the adjacency-matrix elements by taking the argmax of each Bernoulli distribution and feeding these predictions back as inputs to the decoder.

3.2.1. ENCODER

Our encoder is structured as an (N + 1) × (S + 1) lattice. The N × S part of the lattice formed by the first N rows and first S columns receives a dataset D = {(x^s_1, ..., x^s_N)^T}_{s=1}^S. This is unlike standard transformers, which typically receive a single data sample (e.g., a sequence of words in neural machine translation) rather than a set of data samples. Row N + 1 of the lattice specifies whether each data sample is observational, through value -1, or interventional, through an integer value in {1, ..., N} indicating the intervened node. The goal of the encoder is to infer causal relationships between nodes by examining the set of samples. The transformer performs this inference in multiple stages, each represented by one transformer layer, such that each layer yields an (N + 1) × (S + 1) lattice of representations. The transformer is designed to deposit its summary representation of the causal structure in column S + 1.

Embedding of the input. Each data-sample element x^s_n is embedded into a vector of dimensionality H. Half of this vector is allocated to embed the value x^s_n itself, while the other half embeds the unique identity of the node X_n. The value embedding is obtained by passing x^s_n, whether discrete or continuous, through an MLP encoder specific to node X_n. (Using an MLP for a discrete variable is a slightly inefficient implementation of a node value embedding, but it ensures that the architecture is general.) We use a node-specific embedding because the values of each node may have very different interpretations and meanings. The node identity embedding is obtained using a standard 1D transformer positional embedding over node indices. For column S + 1 of the input, the value embedding is a vector of zeros.

Alternating attention. Traditional transformers discover relationships among the elements of a data sample arranged in a one-dimensional sequence.
With our two-dimensional lattice, the transformer could in principle operate over the entire lattice at once to discover relationships among both nodes and samples: given an encoding that indicates position (n, s) in the lattice, the model could discover stable relationships among nodes over samples. However, the inductive bias encouraging the model to leverage the lattice structure would be weak. We also want the model to be invariant to sample ordering, which is desirable because the samples are i.i.d. We therefore arrange our transformer in alternating layers. In the first layer of each pair, attention operates across all nodes of a single sample (x^s_1, ..., x^s_N)^T to encode the relationships among two or more nodes. In the second layer of the pair, attention operates across all samples for a given node (x^1_n, ..., x^S_n) to encode information about the distribution of node values. Alternating attention in transformers was also used by Kossen et al. (2021).

Encoder summary. The encoder produces a summary vector e^sum_n with H elements for each node X_n, which captures essential information about the node's behavior and its interactions with other nodes. The summary representation is formed independently for each node and combines information across the S samples (the columns of the lattice). This is achieved with a method often used with transformers: a weighted average based on how informative each sample is, where the weighting is computed by standard key-value attention, with queries formed from the embeddings in column S + 1 and keys and values formed from the embeddings in columns 1, ..., S.
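The alternating scheme can be sketched at the shape level as follows (a single-head illustration with queries = keys = values for brevity; this is an assumption-level sketch, not the paper's implementation):

```python
# Alternating attention on an (N, S, H) lattice: one self-attention pass across
# nodes within each sample, then one across samples within each node.
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X):
    """Plain single-head self-attention over the first axis: (L, H) -> (L, H)."""
    H = X.shape[-1]
    return softmax(X @ X.T / np.sqrt(H), axis=-1) @ X

def alternating_layer(lattice):
    N, S, H = lattice.shape
    # Attend across nodes, separately per sample (first layer of the pair).
    lattice = np.stack([self_attention(lattice[:, s]) for s in range(S)], axis=1)
    # Attend across samples, separately per node (second layer of the pair).
    lattice = np.stack([self_attention(lattice[n]) for n in range(N)], axis=0)
    return lattice

out = alternating_layer(np.random.default_rng(0).normal(size=(4, 6, 8)))
assert out.shape == (4, 6, 8)
```

Because the second pass treats samples symmetrically up to the attention weights, permuting the S samples permutes intermediate activations without changing the per-node information that is aggregated downstream.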

3.2.2. DECODER

The decoder uses the summary information from the encoder to generate a prediction of the adjacency matrix A of the underlying G. It operates sequentially, at each step producing a binary prediction Â_{k,l} of A_{k,l}, proceeding row by row. The decoder is an autoregressive transformer: each prediction Â_{k,l} is based on all previously predicted elements of A, as well as the summary produced by the encoder. Our method does not enforce acyclicity. Although this could in principle yield cycles in the predicted graph, in practice we observe strong performance regardless (Section 6.4), likely because the graphs typically studied for training and evaluation (e.g., ER-1 and ER-2) are very sparse. Nevertheless, one could likely improve the results, e.g., by post-processing (Lippe et al., 2021) or by extending the method with an accept-reject algorithm (Castelletti & Mascaro, 2022; Li et al., 2022).

Auxiliary loss. Autoregressive decoding of the flattened N × N adjacency matrix can be difficult for the decoder to learn on its own. To provide additional inductive bias, we added an auxiliary task of predicting the parents A_{n,:} and children A_{:,n} of node X_n from the encoder summary e^sum_n. This is achieved using an MLP to learn a mapping f_n such that f_n(e^sum_n) = (Â_{n,:}, Â^T_{:,n}). While this prediction is redundant with the operation of the decoder, it short-circuits the autoregressive decoder and provides a strong training signal that supports proper training.
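The row-by-row test-time decoding can be sketched as follows, where `predict_prob` is a hypothetical stand-in for the decoder transformer conditioned on the encoder summary (the toy probabilities are illustrative):

```python
# Greedy autoregressive decoding of a flattened N x N adjacency matrix:
# each entry is the argmax of its Bernoulli distribution given prior entries.
import numpy as np

def decode_adjacency(predict_prob, N):
    flat = []
    for _ in range(N * N):
        p = predict_prob(flat)             # P(A_l = 1 | previously decoded entries)
        flat.append(1 if p >= 0.5 else 0)  # deterministic argmax at test time
    return np.array(flat).reshape(N, N)

# Toy stand-in that puts an edge only in the very first entry.
A_hat = decode_adjacency(lambda prev: 0.9 if not prev else 0.1, N=3)
assert A_hat.shape == (3, 3) and A_hat.sum() == 1
```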

4. SYNTHETIC DATA

In this section, we discuss identifiability and describe how the synthetic data were generated.

Identifiability. The dataset D^i associated with CBN M^i is given by D^i = {x^s}_{s=1}^S, where x^s := (x^s_1, ..., x^s_N)^T is either an observational or an interventional data sample obtained by performing a hard intervention on a randomly selected node in G^i. As discussed in Eberhardt et al. (2006), in the limit of an infinite amount of such single-node interventional data, G^i is identifiable. Note that identifiability here means the ability to recover the exact graph given the data. As our model defines a distribution over graphs, its predictions are meaningful even when the amount of data is insufficient for identifiability: in this case, the model samples graphs that are compatible with the given data. Empirically, we found that our model makes reasonable predictions even with a small number of samples per intervention and improves as more samples are observed (Section 6.4).

Graph distribution. We specified a distribution over G in terms of the number of nodes N (graph size) and the number of edges (graph density) in G. As shown in Zheng et al. (2018); Yu et al. (2019); Ke et al. (2020a), larger and denser graphs are more challenging to learn. We varied N from 5 to 80. We used the Erdős-Rényi (ER) and scale-free (SF) models to vary density and evaluated our model on ER-1 and ER-2 graphs, as in Yu et al. (2019); Brouillard et al. (2020); Scherrer et al. (2021). We generated an adjacency matrix A by first sampling a lower-triangular matrix, which guarantees a DAG, and then permuting the order of the nodes to ensure random node ordering.

Conditional probability distributions. We performed ancestral sampling on the underlying CBN. We considered both continuous and discrete nodes.
For continuous nodes, we generated data using three methods, following setups from previous work: (i) linear models (linear data) (Zheng et al., 2018; Yu et al., 2019), (ii) nonlinear additive noise models (ANM data) (Brouillard et al., 2020; Lippe et al., 2021), and (iii) nonlinear models with non-additive noise implemented using neural networks (NN data) (Brouillard et al., 2020; Lippe et al., 2021). For discrete nodes, we generated data using two different conditional-probability-table generators: MLP (MLP data) and Dirichlet (Dirichlet data). For the MLP generator, following past work (Ke et al., 2020a; Scherrer et al., 2021), we used a randomly initialized network. The Dirichlet generator filled in the rows of a conditional probability table by sampling a categorical distribution from a Dirichlet prior with symmetric parameter α. (We remind the reader that this generative procedure is performed before the node ordering is randomized for presentation to the learning model.) Refer to Appendix A.2 for details.
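The discrete recipe of this section — DAG sampling via a permuted lower-triangular matrix, Dirichlet CPT generation, and ancestral sampling — can be sketched end to end as follows (graph size, edge probability, node arity and α are illustrative assumptions, not the paper's settings):

```python
# (1) sample a DAG; (2) fill each node's conditional probability table (CPT)
# with rows drawn from a symmetric Dirichlet(alpha); (3) ancestral sampling.
import numpy as np

rng = np.random.default_rng(0)

def sample_dag(N, p=0.3):
    A = np.tril(rng.binomial(1, p, size=(N, N)), k=-1)  # acyclic by construction
    perm = rng.permutation(N)                           # randomize node ordering
    return A[np.ix_(perm, perm)]                        # A[n, :] marks parents of node n

def dirichlet_cpts(A, n_values=3, alpha=0.25):
    """One CPT per node: one categorical row per parent configuration."""
    return [rng.dirichlet([alpha] * n_values, size=n_values ** int(A[n].sum()))
            for n in range(len(A))]

def topological_order(A):
    order, visited = [], set()
    def visit(n):
        if n not in visited:
            visited.add(n)
            for par in np.flatnonzero(A[n]):  # parents come before their children
                visit(par)
            order.append(n)
    for n in range(len(A)):
        visit(n)
    return order

def ancestral_sample(A, cpts, n_values=3):
    x = np.zeros(len(A), dtype=int)
    for n in topological_order(A):
        config = 0
        for par in np.flatnonzero(A[n]):      # encode parent configuration as an index
            config = config * n_values + x[par]
        x[n] = rng.choice(n_values, p=cpts[n][config])
    return x

A = sample_dag(5)
cpts = dirichlet_cpts(A)
x = ancestral_sample(A, cpts)
assert x.shape == (5,)
assert all(np.allclose(t.sum(axis=1), 1.0) for t in cpts)  # rows are distributions
```

Smaller α concentrates each CPT row on a few values, giving more informative (less uniform) conditionals, which matches the role of α discussed in the experiments.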

5. RELATED WORK

Methods for inferring causal graphs can broadly be categorized into score-based, constraint-based, and asymmetry-based methods. Score-based methods search through the space of candidate graphs and rank them based on some scoring function (Chickering, 2002; Cooper & Yoo, 1999; Goudet et al., 2017; Hauser & Bühlmann, 2012; Heckerman et al., 1995; Tsamardinos et al., 2006; Huang et al., 2018; Zhu et al., 2019). Recently, Zheng et al. (2018); Yu et al. (2019); Lachapelle et al. (2019) framed the structure search as a continuous optimization problem. Some score-based methods use a mix of continuous and discrete optimization (Bengio et al., 2019; Zhu et al., 2019; Ke et al., 2020a; Lippe et al., 2021; Scherrer et al., 2021). Constraint-based methods (Monti et al., 2019; Spirtes et al., 2000; Sun et al., 2007; Zhang et al., 2012; Zhu et al., 2019) infer the DAG by analyzing conditional independencies in the data. Eaton & Murphy (2007) use dynamic programming techniques. Asymmetry-based methods (Shimizu et al., 2006; Hoyer et al., 2009; Peters et al., 2011; Daniusis et al., 2012; Budhathoki & Vreeken, 2017; Mitrovic et al., 2018) use asymmetries between cause and effect to estimate the causal structure. (2016) propose a modeling framework that leverages existing methods.

Data type                 RCC   DAG-EQ   CSIvA
Observational              ✓       ✓       ✓
Interventional             ✗       ✗       ✓
Linear dependencies        ✓       ✓       ✓
Non-linear dependencies    ✓       ✗       ✓

Table 1: Data-type comparison between CSIvA and other supervised approaches to causal structure induction (RCC (Lopez-Paz et al., 2015a;b) and DAG-EQ (Li et al., 2020)).

Learning-based methods have also been proposed (Bengio et al., 2019; Goudet et al., 2018; Guyon, 2013; 2014; Kalainathan et al., 2018; Ke et al., 2020a;b; Lachapelle et al., 2022; Lopez-Paz et al., 2015b; Wang et al., 2021b; Zhu et al., 2019). In particular, Zhu et al. (2019); Wang et al. (2021b) use transformers.
These works learn only part of the causal induction pipeline, such as the scoring function, and hence differ significantly from our work, which uses an end-to-end supervised learning approach to map datasets to graphs. Neural network methods equipped with learned masks exist (Douglas et al., 2017; Goyal et al., 2021; Ivanov et al., 2018; Li et al., 2019; Yoon et al., 2018), but only a few have been adapted to causal inference. Several transformer models (Goyal et al., 2022; Kossen et al., 2021; Müller et al., 2021) have been proposed for learning to map datasets to targets; however, none have been applied to causal discovery, although Löwe et al. (2022) propose a neural-network-based approach for causal discovery on time-series data. A few supervised learning approaches have been proposed, framing the task as a kernel mean embedding classification problem (Lopez-Paz et al., 2015a;b), operating directly on covariance matrices (Li et al., 2020), or framing it as a binary classification problem of identifying v-structures (Dai et al., 2021). These models accept observational data only (see Table 1), and because causal identifiability in general requires both observational and interventional data, our model is in principle more powerful.

6. EXPERIMENTS

We report on a series of experiments of increasing difficulty for our supervised approach to causal structure induction. First, we examined whether CSIvA generalizes well on synthetic data for which the training and test distributions are identical (Section 6.1); this tests whether the model can learn to map from a dataset to a structure at all. Second, we examined generalization to an out-of-distribution (OOD) test distribution and determined which hyperparameters of the synthetic data-generating process are most robust to OOD testing (Section 6.2). Third, we trained CSIvA using the hyperparameters from the second experiment and evaluated it on a different type of OOD test distribution derived from several naturalistic CBNs (Section 6.3). This experiment is the most important test of our hypothesis that the causal structure of synthetic datasets can be a useful proxy for discovering causal structure in realistic settings. Lastly, we performed a set of ablation studies to analyze the performance of CSIvA under different settings (Section 6.4). All models are trained for 500k iterations; please refer to Appendix A.3 for details on the hyperparameters.

Comparisons to baselines. We compared CSIvA to a range of methods considered state-of-the-art in the literature, from classic to neural-network-based causal discovery baselines. For both in-distribution and OOD experiments, we compare to four strong baselines: DAG-GNN (Yu et al., 2019), non-linear ICP (Heinze-Deml et al., 2018b), DCDI (Brouillard et al., 2020), and ENCO (Lippe et al., 2021). For OOD experiments on naturalistic graphs, we compare to five additional baselines (Chickering, 2002; Hauser & Bühlmann, 2012; Zheng et al., 2018; Gamella & Heinze-Deml, 2020; Li et al., 2020). Because some methods use only observational data and others do not scale well to large N, we could not compare to all methods in all experiments.
See Appendix A.4 for further discussion of alternative methods and conditions under which they can be used.

6.1. IN-DISTRIBUTION EXPERIMENTS

We begin by investigating whether CSIvA can learn to map data to structures when the training and test distributions are identical. In this setting, our supervised approach has an advantage over unsupervised ones, as it can learn about the training distribution and leverage this knowledge at test time. We evaluate CSIvA on graphs with N ≤ 80 nodes. We examined performance on data of increasing difficulty, starting with the linear case (continuous data) before moving to non-linear cases (ANM, NN, MLP and Dirichlet data). See Fig. 4 for comparisons between CSIvA and strong baselines on all data types (N = 30), showing that CSIvA significantly outperforms the baselines for a wide range of data types. For all but the Dirichlet data, we summarize only the high-level results here; detailed results and plots for all experiments can be found in Appendix A.5. Results on continuous linear data are presented in Tables 4 and 5, results on continuous ANM data in Table 6, and results on continuous NN data in Table 7 in the appendix.

The Dirichlet data requires setting the value of the parameter α; hence, we ran two sets of experiments on this data. In the first set, we investigated how different values of α impact learning in CSIvA. As shown in Table 10 in the appendix, CSIvA performs well on all data with α ≤ 0.5, achieving a Hamming distance H < 2.5 in all cases. CSIvA still performs well when α = 1.0, achieving H < 5 on size-10 graphs. Learning with α > 1 is more challenging. This is not surprising, as α > 1 tends to generate more uniform conditional distributions, which are not informative of the causal relationships between nodes. In the second set of experiments, we compared CSIvA to DCDI and ENCO, the strongest-performing baselines, on graphs of size up to N = 80.
To limit the number of experiments, we set α = 0.1, as this concentrates the conditional probabilities on non-uniform distributions, which is more likely to hold in realistic settings. As shown in Fig. 3, our model significantly outperforms DCDI and ENCO on all graphs (up to N = 80). As with other data types, the difference becomes more apparent as the graph size grows (N ≥ 50), since larger graphs are more challenging to learn.

6.2. OUT-OF-DISTRIBUTION EXPERIMENTS

Varying α. We next trained and evaluated on data generated from Dirichlet distributions with α ∈ {0.1, 0.25, 0.5}. Results for ER-1 graphs with N = 7 are reported in Table 19. No single training value of α performs consistently well across all values of α in the test data. Nevertheless, α = 0.25 is a balanced trade-off and generalizes well across test data with 0.1 ≤ α ≤ 0.5.

6.3. SIM-TO-REAL EXPERIMENTS

In this final set of experiments, we evaluated CSIvA's ability to generalize from training on MLP and Dirichlet data to evaluation on the widely used Sachs (Sachs et al., 2005), Asia (Lauritzen & Spiegelhalter, 1988) and Child (Spiegelhalter & Cowell, 1992) benchmarks. We emphasize that all hyperparameters for the MLP and Dirichlet data and for the learning procedure were chosen a priori; only after the architecture and parameters were finalized was the model tested on these benchmarks. Furthermore, to keep the setup simple, we trained on data sampled from a single set of hyperparameters instead of a broader mixture: the findings in Section 6.2 suggest that ER-2 graphs with α = 0.25 work well overall, and hence these were chosen. Results are reported in the corresponding tables.

Acyclicity. We analyzed the generated (sampled) graphs for acyclicity under several settings and found that none of the generated graphs contains any cycles. For details, refer to Appendix A.7.3.

Visualization of generated graphs. We visualize some of the generated graphs in Figures 7 and 8 in the Appendix. Note that the samples are randomly drawn, and all are acyclic.

Identifiability upon seeing interventional data. Here, we investigate how varying the proportion of interventional data impacts the performance of the proposed model. As shown in Figure 9 (Appendix), the model's performance improves (Hamming distance decreases) almost monotonically as the amount of interventional data increases from 0% to 100%. This is a clear indication that our model extracts information from interventional data to identify the graph structure. For more details on the results and how the experiments were conducted, please refer to Section A.7.4.

Amount of training cases. We conducted experiments to better understand how much synthetic data is needed for different graph sizes. For smaller graphs (N ≤ 25), there is only a small improvement from using more than 10k training cases.
For larger graphs (N > 25), performance consistently improves as the model uses more training data (more details in Appendix A.7.6).

Ablation studies. We investigated the role of various components in CSIvA and found that all components play an important role in the model's performance. For details, see Appendix A.7.

Computation time. Our models are trained for 500k iterations. Training can take up to 8 hours, and inference then takes minutes. The training time of CSIvA can be amortized over multiple test cases, whereas the training time of the baseline models is proportional to the number of test cases. For more details, please refer to Appendix A.7.9.

7. DISCUSSION

In this paper, we have presented a novel approach to causal graph structure inference. Our method learns from synthetic data in order to obtain a strong learning signal (in the form of explicit supervision), using a novel transformer-based architecture that directly analyzes the data and computes a distribution over candidate graphs. Through a comprehensive and detailed set of experiments, we demonstrated that even though it is trained only on synthetic data, our model generalizes to out-of-distribution datasets and robustly outperforms comparison methods under a wide range of conditions. A direction for future work is to use the proposed framework for learning causal structure from raw visual data. This could be useful, e.g., in a reinforcement learning setting in which an agent interacts with the environment by observing low-level pixel data (Ahmed et al., 2020; Ke et al., 2021; Wang et al., 2021a).

A.1 TRANSFORMER NEURAL NETWORKS

The transformer architecture, introduced in Vaswani et al. (2017), is a multi-layer neural network architecture using stacked self-attention and point-wise, fully connected layers. The classic transformer architecture has an encoder and a decoder, but the encoder and decoder do not necessarily have to be used together.

Scaled dot-product attention. The attention mechanism lies at the core of the transformer architecture. Transformers use a particular form of attention called scaled dot-product attention, which allows the model to flexibly learn to weigh the inputs depending on the context. The input to this QKV attention consists of a set of query, key and value vectors. The queries and keys share the same dimensionality d_k; the values often have a different dimensionality d_v. Transformers compute the dot products of each query with all keys, divide each by √d_k, and apply a softmax function to obtain the weights on the values. In practice, transformers compute the attention function on a set of queries simultaneously, packed together into a matrix Q; the keys and values are likewise packed into matrices K and V. The matrix of outputs is computed as

Attention(Q, K, V) = softmax(QK^⊤ / √d_k) V.

Encoder. The encoder is responsible for processing and summarizing the information in the inputs. It is composed of a stack of N identical layers, each with two sub-layers: a multi-head self-attention mechanism, followed by a simple, position-wise fully connected feed-forward network. Transformers employ a residual connection (He et al., 2016) around each of the two sub-layers, followed by layer normalization (Ba et al., 2016). That is, the output of each sub-layer is LayerNorm(x + Sublayer(x)), where Sublayer(x) is the function implemented by the sub-layer itself.

Decoder. The decoder is responsible for transforming the information summarized by the encoder into the outputs. It is also composed of a stack of N identical layers, with a small difference in the layer structure: in addition to the two sub-layers of each encoder layer, a decoder layer contains a third sub-layer, which performs multi-head attention over the output of the encoder stack. As in the encoder, residual connections are employed around each of the sub-layers, followed by layer normalization. The self-attention sub-layer in the decoder stack is also modified to prevent positions from attending to subsequent positions. This masking, combined with the fact that the output embeddings are offset by one position, ensures that the predictions for position i can depend only on the known outputs at positions less than i.
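The scaled dot-product attention formula above can be sketched in a few lines of NumPy. This is an illustrative single-head implementation of the standard formula, not the paper's code; the function name and the 2-D matrix shapes are our simplification:

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Compute softmax(Q K^T / sqrt(d_k)) V, as in Vaswani et al. (2017)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)               # (n_queries, n_keys) similarity logits
    scores -= scores.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over keys
    return weights @ V                            # weighted sum of value vectors
```

Note that with all-zero queries the scores are uniform, so each output row is simply the mean of the value vectors, which is a convenient sanity check.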

A.2 SYNTHETIC DATA

In Section 4, we introduced the types of methods used to generate the conditional probability distributions. In this section, we discuss the details of these methods. We evaluate our model on datasets generated by 5 different methods, which cover both continuous- and discrete-valued data with a varying degree of difficulty. For continuous data, we generated the data using three different methods: linear, non-linear additive noise model (ANM) and non-linear non-additive noise neural network model (NN). For discrete data, we generated data using two different methods: MLP and Dirichlet. Let X be an N × S matrix representing S samples of a CBN with N nodes and weighted adjacency matrix A, and let Z be a random matrix with elements drawn from N(0, 0.1). We describe how each type of data is generated.

• For linear data, we follow the setup in Zheng et al. (2018) and Yu et al. (2019). Specifically, we generated data as X_n,: = A_n,: X + Z_n,:. The biases were initialized using U[-0.5, 0.5], and the individual weights were initialized using a truncated normal distribution with standard deviation 1.5. For nodes with interventions, values are sampled from the uniform distribution U[-1, 1].

• For additive noise models (ANM), we follow the setup in (Brouillard et al., 2020; Lippe et al., 2021). We generated the data as X_n,: = F_n,:(X) + 0.4 · Z_n,:, where F is a fully connected neural network whose weights are randomly initialized from N(0, 1). The network has one hidden layer with 10 hidden units, and the activation function is a leaky ReLU with a slope of 0.25. The noise variables are sampled from N(0, σ²), where σ² ∼ U[1, 2]. For nodes with interventions, values are sampled from N(2, 1).

• For non-additive noise neural network (NN) models, we also follow the setup in (Brouillard et al., 2020; Lippe et al., 2021). We generate the data as X_n,: = F_n,:(X, Z_n,:), where F is a fully connected neural network whose weights are randomly initialized from N(0, 1). The network has one hidden layer with 20 hidden units, and the activation function is tanh. The noise variables are sampled from N(0, σ²), where σ² ∼ U[1, 2].

A.3 TRAINING DETAILS

For all of our experiments (unless otherwise stated), our model was trained on I = 15,000 (for graphs with N ≤ 20) or I = 20,000 (for graphs with N > 20) pairs {(D_i, A_i)}, i = 1, …, I, where each dataset D_i contained S = 1500 observational and interventional samples. For experiments on discrete data, a data-sample element x_s could take values in {1, 2, 3}. Details of the data generating process can be found in Section 4. For evaluation in Sections 6.1 and 6.2, our model was tested on I′ = 128 (different from the training) pairs {(D_i′, A_i′)}, i′ = 1, …, I′, where each dataset D_i′ contained S = 1500 observational and interventional samples. For the Asia, Sachs and Child benchmarks, our model was still tested on I′ = 128 pairs; however, A_i′ = A_j′, since there is only a single adjacency matrix in each benchmark. We present test results averaging performance over the 128 datasets and 3 random seeds, on graphs of size up to N = 80. The model was trained for 500,000 iterations using the Adam optimizer (Kingma & Ba, 2014) with a learning rate of 1e-4; refer to Table 3 for the full list of hyperparameters. We parameterized our architecture such that inputs to the encoder were embedded into 128-dimensional vectors. The encoder transformer has 10 layers and 8 attention heads per layer; the final attention step for summarization has 8 attention heads. The decoder is a smaller transformer with only 2 layers and 8 attention heads per layer. Discrete inputs were encoded using an embedding layer before being passed into our model.
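The linear and ANM generators described above can be sketched as follows, assuming nodes are in topological order so that the adjacency matrix is strictly lower-triangular. The function names and the parent-masking scheme are our simplifications, not the paper's code:

```python
import numpy as np

rng = np.random.default_rng(0)

def leaky_relu(x, slope=0.25):
    return np.where(x > 0, x, slope * x)

def sample_linear(A, n_samples, noise_std=0.1):
    # X_n,: = A_n,: X + Z_n,: sampled node by node in topological order
    # (assumes A is strictly lower-triangular, i.e. parents precede children)
    N = A.shape[0]
    X = np.zeros((N, n_samples))
    for n in range(N):
        X[n] = A[n] @ X + rng.normal(0.0, noise_std, n_samples)
    return X

def sample_anm(A, n_samples, hidden=10):
    # X_n,: = f_n(parents) + 0.4 * Z_n,: with a random one-hidden-layer
    # leaky-ReLU MLP f_n; Z_n ~ N(0, sigma^2), sigma^2 ~ U[1, 2]
    N = A.shape[0]
    X = np.zeros((N, n_samples))
    for n in range(N):
        W1 = rng.normal(0, 1, (hidden, N)) * (A[n] != 0)  # zero out non-parents of node n
        W2 = rng.normal(0, 1, hidden)
        sigma2 = rng.uniform(1, 2)
        X[n] = W2 @ leaky_relu(W1 @ X) + 0.4 * rng.normal(0, np.sqrt(sigma2), n_samples)
    return X
```

On a two-node chain, the child's samples are strongly correlated with the parent's, which is the signal the structure learner exploits.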

A.4 COMPARISONS TO BASELINES

In Section 6.1, we compare CSIvA to four strong baselines from the literature, ranging from classic causal discovery methods to neural-network-based ones: DAG-GNN (Yu et al., 2019), non-linear ICP (Heinze-Deml et al., 2018b), DCDI (Brouillard et al., 2020) and ENCO (Lippe et al., 2021). Non-linear ICP, DCDI and ENCO can handle both observational and interventional data, while DAG-GNN can only use observational data. Non-linear ICP could not scale to graphs larger than 20 nodes; therefore, we compare to DCDI (Brouillard et al., 2020) and ENCO (Lippe et al., 2021) on larger graphs. All baselines are unsupervised methods, i.e. they are not tuned to a particular training dataset but instead rely on a general-purpose algorithm. We also compared to an all-absent model corresponding to a zero adjacency matrix, which acts as a sanity-check baseline. We also considered other methods (Chickering, 2002; Hauser & Bühlmann, 2012; Zhang et al., 2012; Gamella & Heinze-Deml, 2020), but only present a comparison with DCDI, ENCO, non-linear ICP and DAG-GNN, as these have been shown to be strong-performing models in other works (Ke et al., 2020a; Lippe et al., 2021; Scherrer et al., 2021). For Section 6.3, we also compared to additional baselines from Chickering (2002); Hauser & Bühlmann (2012); Zheng et al. (2018); Gamella & Heinze-Deml (2020); Li et al. (2020). Note that the methods from Chickering (2002); Zheng et al. (2018); Yu et al. (2019); Li et al. (2020) take observational data only. DAG-GNN outputs several candidate graphs based on different scores, such as the evidence lower bound or the negative log-likelihood, and DCDI can be run in two different settings (DCDI-G and DCDI-DSF); in both cases, we chose the best result to compare to our model. Note that non-linear ICP does not work on discrete data, i.e. on the MLP and Dirichlet data; therefore, a small amount of Gaussian noise N(0, 0.1) was added to these data in order for the method to run.

A.5 DETAILED RESULTS FOR SECTION 6.1

Detailed results for the experiments in Section 6.1 are presented in the tables below.
Results for comparisons between our model CSIvA and the baselines non-linear ICP (Heinze-Deml et al., 2018b), DAG-GNN (Yu et al., 2019), DCDI (Brouillard et al., 2020) and ENCO (Lippe et al., 2021) are shown in Table 4 for smaller graphs with N ≤ 20 and Table 5 for larger graphs with 20 < N ≤ 80. For smaller graphs with N ≤ 20, all models that take interventional data perform significantly better than DAG-GNN (Yu et al., 2019), which only takes observational data. CSIvA achieves a Hamming distance H < 7 on all evaluated graphs up to size 20. Similar to previous findings (Yu et al., 2019; Ke et al., 2020a), larger and denser graphs are more challenging to learn. Non-linear ICP achieves fairly good performance on smaller graphs (N ≤ 10); however, its performance drops quickly as the size of the graph increases (N > 10). Also, note that non-linear ICP cannot scale to graphs larger than N = 20. It further required a modification to the dataset, wherein multiple samples were collected from the same modified graph after a point intervention (20 samples per intervention), while other methods only sampled once per intervention; without this modification, the method achieved near-chance performance. For larger graphs with 20 < N ≤ 80, we compare to the strongest baselines, DCDI and ENCO. Results are found in Table 5. CSIvA significantly outperforms both DCDI and ENCO, and the difference becomes more apparent as the size of the graphs grows (N ≥ 50).

Table 5: Results on Linear data on larger graphs. Hamming distance H for learned and ground-truth edges on synthetic graphs, compared to other methods. The number of variables varies from 30 to 80, expected degree = 1 or 2. We compare to DCDI (Brouillard et al., 2020) and ENCO (Lippe et al., 2021), as they are the best-performing baselines able to handle larger graphs.

A.5.2 RESULTS ON ANM DATA

For additive noise non-linear model (ANM) data, we compare to the strongest baseline models, DCDI and ENCO, on graphs with N ≤ 80. Results are found in Table 6. CSIvA achieves a Hamming distance H < 11 on all graphs of size up to 80 and significantly outperforms the strong baseline models DCDI and ENCO on all graphs. This difference becomes more apparent on larger graphs (N ≥ 50).

Table 6: Hamming distance H for learned and ground-truth edges on synthetic graphs, compared to other methods. The number of variables varies from 30 to 80, expected degree = 1 or 2. We compare to DCDI (Brouillard et al., 2020) and ENCO (Lippe et al., 2021), as they are the best-performing baselines able to handle larger graphs.

A.5.3 RESULTS ON NN DATA

Results for comparisons between CSIvA and the strongest baselines, DCDI and ENCO, on non-additive noise non-linear (NN) data are found in Table 7. CSIvA achieves a Hamming distance H < 11 on all graphs. Hence, CSIvA significantly outperforms DCDI and ENCO on all graphs, and the differences grow with the size of the graph. Additionally, we generated NN data with scale-free (SF) graphs with degree 4 and 6 for N ≤ 50 and compared our model to the strongest baselines DCDI and ENCO on these data. The results are found in Table 15.

Table 7: Results on continuous non-linear non-additive noise model (NN) data with larger Erdős-Rényi (ER) graphs. Hamming distance H for learned and ground-truth edges on synthetic graphs, compared to other methods. The number of variables varies from 30 to 80, expected degree = 1 or 2. We compare to DCDI (Brouillard et al., 2020) and ENCO (Lippe et al., 2021), as they are the best-performing baselines able to handle larger graphs.

A.5.4 RESULTS ON MLP DATA

Results for comparisons between CSIvA, non-linear ICP (Heinze-Deml et al., 2018b) and DAG-GNN (Yu et al., 2019) on MLP data are shown in Table 9. MLP data is non-linear and hence more challenging than continuous linear data. Our model CSIvA significantly outperforms non-linear ICP and DAG-GNN, and the difference becomes more apparent as the graphs grow larger and denser.

Table 8: Results on discrete MLP data with larger graphs. Hamming distance H for learned and ground-truth edges on synthetic graphs, compared to other methods. The number of variables varies from 30 to 80, expected degree = 1 or 2. We compare to DCDI (Brouillard et al., 2020) and ENCO (Lippe et al., 2021), as they are the best-performing baselines able to handle larger graphs.

Table 9: Results on MLP data. Hamming distance H for learned and ground-truth edges on synthetic graphs, compared to other methods, averaged over 128 sampled graphs (± standard deviation).
The number of variables varies from 5 to 20, expected degree = 1 or 2, and the dimensionality of the variables is fixed to 3. We compared to DAG-GNN (Yu et al., 2019), a strong baseline that uses observational data, and to non-linear ICP (Heinze-Deml et al., 2018b), a strong baseline that uses interventional data. Note that for (Heinze-Deml et al., 2018b), the method required nodes to be causally ordered and Gaussian noise N(0, 0.1) to be added. "Abs" denotes the all-absent baseline, a model that outputs an all-zero adjacency matrix.

A.5.5 RESULTS ON DIRICHLET DATA

We run two sets of experiments on Dirichlet data. The first set of experiments is aimed at understanding how different α values impact the performance of our model. In the second set of experiments, we compare the performance of our model to four strong baselines: DCDI (Brouillard et al., 2020), ENCO (Lippe et al., 2021), non-linear ICP (Heinze-Deml et al., 2018b) and DAG-GNN (Yu et al., 2019). The number of variables varies from 5 to 10, expected degree = 1 or 2, the dimensionality of the variables is fixed to 3, and α is fixed to 1.0. We compare to the strongest causal-induction method that uses observational data (Yu et al., 2019) and the strongest that uses interventional data (Heinze-Deml et al., 2018b).

CSIvA compares favorably to non-linear ICP and DAG-GNN, which achieve significantly higher Hamming distances (H = 9.3 and H = 9.5, respectively) on larger and denser graphs. Refer to Table 11 for the complete set of results. For larger graphs (N > 20), we compare CSIvA to DCDI (Brouillard et al., 2020) and ENCO (Lippe et al., 2021), as they scale to larger graphs and are the strongest baselines. The results are illustrated in Figure 3 and detailed in Table 12. As the graphs get larger, the performance of the baselines DCDI and ENCO drops significantly; for dense and larger graphs (ER-2, N = 80), the baseline models achieve near all-absent performance, while our model performs significantly better, achieving an almost 30% gain in terms of structural Hamming distance.

Table 12: Results on Dirichlet data with larger graphs. Hamming distance H for learned and ground-truth edges on synthetic graphs, compared to other methods. The number of variables varies from 30 to 80, expected degree = 1 or 2, and α is fixed to 0.1. We compare to the approaches from (Brouillard et al., 2020) and (Lippe et al., 2021), as they are the strongest baselines able to handle larger graphs.

A.5.6 RESULTS ON SCALE-FREE GRAPHS

We further evaluated our model's performance on scale-free (SF) graphs, running experiments for both discrete and continuous variables. For discrete variables, we trained and tested our model on Dirichlet data with α = 0.1 on SF graphs with degree 4 and 6, similar to the setup in Zheng et al. (2018). Results on Dirichlet data are reported in Table 14 and results on NN data are found in Table 15. Our model significantly outperforms the strong baseline models DCDI (Brouillard et al., 2020) and ENCO (Lippe et al., 2021) in all cases, regardless of the type of data, the size of the graph and the sparsity of the graph.
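The two graph families can be sampled along the following lines. This is a common construction consistent with the description above, not the paper's code; the edge-probability choice for ER graphs and the degree smoothing in the preferential-attachment step are our assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_er_dag(N, exp_degree):
    # Erdős-Rényi DAG: include edge j -> i (for j < i) with probability p,
    # with p chosen so the expected number of parents per node is about exp_degree
    p = exp_degree / (N - 1)
    A = (rng.random((N, N)) < p).astype(int)
    return np.tril(A, k=-1)  # keep only edges from earlier to later nodes -> acyclic

def sample_sf_dag(N, m):
    # Scale-free DAG via preferential attachment: node i draws up to m parents
    # among earlier nodes, with probability proportional to their (smoothed) degree
    A = np.zeros((N, N), dtype=int)
    deg = np.ones(N)  # +1 smoothing so degree-zero nodes can still be chosen
    for i in range(1, N):
        k = min(m, i)
        probs = deg[:i] / deg[:i].sum()
        parents = rng.choice(i, size=k, replace=False, p=probs)
        A[i, parents] = 1  # row i lists the parents of node i
        deg[parents] += 1
        deg[i] += k
    return A
```

Both constructions are acyclic by design, since every edge points from an earlier node to a later one in the fixed ordering.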

A.5.7 SOFT INTERVENTIONS

Soft interventions on Dirichlet data are generated as follows. For an intervention on variable X_i, we first sample the α_i for the Dirichlet distribution from the uniform distribution U{0.1, 0.3, 0.5, 0.7, 0.9}; then, using the sampled α_i, we sample the conditional probabilities from the Dirichlet distribution with the new α_i. The results for CSIvA on soft interventions are shown in Table 13. CSIvA achieves a Hamming distance H < 5 on all graphs of size N ≤ 20. These results are a strong indication that our model still works well under imperfect interventions.
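The resampling step above can be sketched as follows; the helper name and the row-per-parent-configuration CPT layout are our illustrative choices, not the paper's code:

```python
import numpy as np

rng = np.random.default_rng(0)

def soft_intervene_cpt(n_parent_configs, K=3):
    # Draw a fresh alpha_i for the intervened variable from {0.1, 0.3, 0.5, 0.7, 0.9},
    # then resample every row of its conditional probability table (one row per
    # configuration of the parents) from Dirichlet(alpha_i * 1_K)
    alpha_i = rng.choice([0.1, 0.3, 0.5, 0.7, 0.9])
    cpt = rng.dirichlet(alpha_i * np.ones(K), size=n_parent_configs)
    return alpha_i, cpt
```

Each row of the returned table is a valid categorical distribution over the K values of the intervened variable.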

A.5.8 UNKNOWN INTERVENTIONS

Until now, we have considered cases where the target of the intervention is known; we now consider the case where the target is unknown, also referred to as unknown interventions. Again, to limit the number of experiments to run, we focus on Dirichlet data with α = 0.1 and α = 0.25. The model does not know the target of the intervention, and all other training procedures remain exactly the same. The results are shown in Table 16, where we compare how well CSIvA performs on known and unknown interventions. The performance for known and unknown interventions is almost the same for sparser graphs (ER-1). The difference increases slightly for denser graphs with higher values of α, for example ER-2 graphs with α = 0.25. This is to be expected, as denser graphs with higher α values are more challenging to learn. The biggest difference between known and unknown interventions is less than 1.0 in Hamming distance. These results are a clear indication that our model performs well even when the intervention target is unknown.

A.7.2 ADDITIONAL EVALUATIONS

We evaluated our model on additional evaluation metrics to better understand its performance: the balanced scoring function (BSF), F1 score, precision, recall and area under the ROC curve (AUROC). We evaluated our model on both discrete and continuous data; for discrete data, we used Dirichlet data with α = 0.1, and for continuous data, we used non-linear non-additive noise (NN) data. Results on Dirichlet data are in Table 22; results on NN data are in Table 23. For Dirichlet data, our model achieved strong results: the AUROC is greater than 0.95 in all cases, and the F1 score is always above 0.90. For continuous data, our model achieved even stronger results, scoring between 0.98 and 0.99 on all metrics. This indicates that the predicted edges were almost always correct and that the correct edges also received high probability.

We evaluated CSIvA on different numbers of samples (100, 200, 500, 1000, 1500) per CBN. To limit the number of experiments to run, we focus on Dirichlet data sampled from N = 10 graphs. During the i-th iteration of training, the model takes as input a dataset with l_i samples, where l_i is sampled from {100, 200, 500, 1000, 1500}. During test time, we run separate evaluations of the model on test datasets with different numbers of samples; the results are reported in Figure 10. The model's performance improves as it observes up to 1000 samples for ER-1 graphs, whereas having 1500 samples gives slightly better results than 1000 samples for ER-2 graphs.

Table 24: Results on Dirichlet data trained with varying amounts of training datasets. Hamming distance H for learned and ground-truth edges on synthetic graphs, compared to other methods. The number of variables varies from 10 to 50, expected degree = 1 or 2, and α is fixed to 0.1.
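The structural metrics used throughout, Hamming distance between adjacency matrices plus precision/recall-style scores, can be computed as in the following sketch (our helper names, not the paper's code):

```python
import numpy as np

def hamming_distance(A_pred, A_true):
    # Number of entries where the predicted and ground-truth adjacency matrices disagree,
    # i.e. the count of extra plus missing edges.
    return int(np.sum(A_pred != A_true))

def precision_recall_f1(A_pred, A_true):
    # Treat each possible edge as a binary prediction and score it.
    tp = np.sum((A_pred == 1) & (A_true == 1))  # correctly predicted edges
    fp = np.sum((A_pred == 1) & (A_true == 0))  # spurious edges
    fn = np.sum((A_pred == 0) & (A_true == 1))  # missed edges
    precision = tp / max(tp + fp, 1)
    recall = tp / max(tp + fn, 1)
    f1 = 2 * precision * recall / max(precision + recall, 1e-12)
    return precision, recall, f1
```

For example, predicting one correct edge plus one spurious edge against a single-edge ground truth gives a Hamming distance of 1, precision 0.5, and recall 1.0.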

A.7.7 INTERVENTIONS ON A FEW VARIABLES

We performed additional experiments to understand the performance of our model when only a few variables are intervened on. We varied the percentage of variables to intervene on from 10% to 100%.



DeepMind, Mila, Polytechnique Montreal, University of Montreal, Google Research, Brain Team, Corresponding author: nke@google.com



Figure 1: (A). Standard unsupervised approach to causal structure induction: Algorithms use a predefined scoring metric or statistical independence tests to select the best candidate structures. (B). Our supervised approach to causal structure induction (via attention; CSIvA): A model is presented with data and structures as training pairs and learns a mapping between them.

Peters et al. (2016); Ghassami et al. (2017); Rojas-Carulla et al. (2018); Heinze-Deml et al. (2018a) exploit invariance across environments. Mooij et al. (

(a) Results on N = 30 and ER-1 graphs. (b) Results on N = 30 and ER-2 graphs.

Figure 4: Hamming distance H between predicted and ground-truth adjacency matrices on 5 different data types: Linear, ANM, NN, MLP and Dirichlet data, compared to DCDI (Brouillard et al., 2020) and ENCO (Lippe et al., 2021) on N = 30 graphs, averaged over 128 sampled graphs. The dotted line indicates the value of the all-absent baseline. CSIvA significantly outperforms all other baselines for all data types.

Figure 3: Results on Dirichlet data for 30 ≤ N ≤ 80. Hamming distance H between predicted and ground-truth adjacency matrices, averaged over 128 sampled graphs. CSIvA significantly outperforms DCDI and ENCO, both of which are very strong baselines. The difference in performance increases with N .

Results on MLP data are shown in Fig. 5(b) (Appendix). For all data types and all graphs, CSIvA significantly outperforms non-linear ICP, DAG-GNN, DCDI and ENCO. Differences become more apparent with larger graph sizes (N ≥ 10) and denser graphs (ER-2 vs ER-1).

These benchmarks are CBNs from the bnlearn repository, with N = 11, N = 8 and N = 20 nodes respectively. We followed the established protocol of Ke et al. (2020a); Lippe et al. (2021); Scherrer et al. (2021) of sampling observational and interventional data from the CBNs provided by the repository. These experiments are the most important test of our hypothesis that the causal structure of synthetic datasets can be a useful proxy for discovering causal structure in realistic settings.






(a) Results on Linear data on N ≤ 20 graphs. (b) Results on MLP data on N ≤ 20 graphs.

Figure 5: Hamming distance H between predicted and ground-truth adjacency matrices on the Linear and MLP data for N ≤ 20 graphs, compared to DAG-GNN (Yu et al., 2019) and non-linear ICP (Heinze-Deml et al., 2018b), averaged over 128 sampled graphs. CSIvA significantly outperforms all other baselines.


Results on Dirichlet data. Hamming distance H (lower is better) for learned and ground-truth edges on synthetic graphs, averaged over 128 sampled graphs. Our model achieved a Hamming distance of less than 2.5 for Dirichlet data with α ≤ 0.5. "Abs" denotes the all-absent baseline, a model that outputs an all-zero adjacency matrix.

Figure 6: Results on Dirichlet data on N ≤ 20 graphs. Hamming distance H between predicted and ground-truth adjacency matrices on Dirichlet data, averaged over 128 sampled graphs.


Figure 7: This figure visualizes samples that our model generated on test data. The model was trained and tested on MLP data of size 5 with ER-1 graphs. The samples are randomly chosen. Green edges are edges that our model correctly generated; red edges are edges that our model missed; and blue edges are edges that our model generated but that are not in the ground-truth graph. As shown above, our model generates the correct graph almost all of the time, only occasionally generating 1 or 2 incorrect edges per graph.

Figure 8: This figure visualizes samples that our model generated on test data. The model was trained and tested on MLP data of size 5 with ER-2 graphs. The samples are randomly chosen. Green edges are edges that our model correctly generated; red edges are edges that our model missed; and blue edges are edges that our model generated but that are not in the ground-truth graph. As shown above, our model generates the correct graph almost all of the time, only occasionally generating 1 or 2 incorrect edges per graph.

In this section, we evaluate how the amount of training data impacts the performance of our model (CSIvA). To limit the number of experiments to run, we focus on Dirichlet data with α = 0.1. The model receives M training datasets, where M is sampled uniformly from {5000, 10000, 15000, 20000}.

Soft interventions and unknown interventions. We additionally evaluated CSIvA on soft interventions, and on hard interventions for which the model did not know the intervention node (unknown interventions). In both cases, we focus on Dirichlet data with α = 0.1 and α = 0.25 for ER-1 and ER-2 graphs of size N ≤ 20. Results for the performance of CSIvA on soft interventions are shown in Table 13 in the Appendix. CSIvA performed well, achieving Hamming distance H < 5 on all graphs. For details about the data generation process, refer to Appendix Section A.5.7. Results for the performance of CSIvA on unknown interventions are shown in Appendix Table 16. CSIvA achieves strong performance, with Hamming distance H < 7 on all graphs. Note that CSIvA loses almost no performance despite not knowing the intervention; the biggest difference in Hamming distance between known and unknown interventions on all graphs is less than 1.0. For more details, refer to Section A.5.8 in the Appendix.



Results on Asia, Sachs and Child data: Hamming distance H between predicted and ground-truth adjacency matrices. *To maintain computational tractability, the number of parents considered was limited to 3.

For nodes with interventions, values are sampled from N(2, 1).

• For MLP data, the neural network has two fully connected layers of hidden dimensionality 32. Following past work (Ke et al., 2020a; Scherrer et al., 2021; Lippe et al., 2021), we used a randomly initialized network. For nodes with interventions, values are randomly and independently sampled from U{1, . . . , K}, where K indicates the number of categories of the discrete variable.

• For Dirichlet data, the generator filled in the rows of a conditional probability table by sampling a categorical distribution from a Dirichlet prior with symmetric parameter α. Values of α smaller than 1 encourage lower-entropy distributions; values of α greater than 1 provide less information about the causal relationships among variables. As for MLP data, for nodes with interventions, values are randomly and independently sampled from U{1, . . . , K}, where K indicates the number of categories of the discrete variable.

Table 3: Hyperparameters used in all of our experiments.
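The MLP generator for discrete data can be sketched as follows. The one-hot encoding of parent values, the tanh activation and the standard-normal initialization are assumptions of this sketch; the paper only fixes two layers of hidden dimensionality 32:

```python
import numpy as np

rng = np.random.default_rng(0)

def mlp_conditional_probs(parent_values, K=3, hidden=32):
    # One-hot encode the discrete parent values, push them through a randomly
    # initialized two-layer network, and softmax to obtain P(X_n | parents).
    x = np.eye(K)[parent_values].ravel()       # concatenated one-hot parent encoding
    W1 = rng.normal(0, 1, (hidden, x.size))
    W2 = rng.normal(0, 1, (K, hidden))
    logits = W2 @ np.tanh(W1 @ x)
    p = np.exp(logits - logits.max())          # stable softmax
    return p / p.sum()
```

Sampling a child value then amounts to drawing from the returned categorical distribution for the observed parent configuration.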

Results on Continuous linear data. Hamming distance H for learned and ground-truth edges on synthetic graphs, compared to other methods, averaged over 128 sampled graphs. The number of variables varies from 5 to 20, expected degree = 1 or 2, and the values of the variables are drawn from N(0, 0.1). Note that for (Heinze-Deml et al., 2018b), the method required nodes to be causally ordered, and 20 repeated samples were taken per intervention, as interventions were continuously valued. "Abs" denotes the all-absent baseline, a model that outputs an all-zero adjacency matrix. DCDI and ENCO are the strongest-performing baselines. CSIvA outperforms all baselines (including DCDI and ENCO).

Results on continuous non-linear additive noise model (ANM) data with larger graphs.

Results on Dirichlet data. Hamming distance H for learned and ground-truth edges on synthetic graphs, compared to other methods, averaged over 128 sampled graphs (± standard deviation)


Results on imperfect interventions. Hamming distance H for learned and ground-truth edges on synthetic graphs, compared to other methods, averaged over 128 sampled graphs (± standard deviation). The number of variables varies from 10 to 20, expected degree = 1 or 2, the dimensionality of the variables is fixed to 3, and α is fixed to 0.1 or 0.25.

Results on Dirichlet data on Scale-Free graphs. Hamming distance H for learned and ground-truth edges on synthetic graphs, compared to other methods. The number of variables varies from 30 to 80, expected degree = 4 or 6, and α is fixed to 0.1. We compare to the approaches from Brouillard et al. (2020) and Lippe et al. (2021), as they are the strongest baselines able to handle larger graphs.

Results on continuous non-linear non-additive noise model (NN) data on Scale-Free (SF) graphs. Hamming distance H for learned and ground-truth edges on synthetic graphs, compared to other methods. The number of variables varies from 30 to 80, expected degree = 4 or 6, and α is fixed to 0.1. We compare to the approaches from Brouillard et al. (2020) and Lippe et al. (2021), as they are the strongest baselines able to handle larger graphs.

Results on unknown interventions compared to known interventions. Hamming distance H between learned and ground-truth edges on synthetic graphs, for known vs unknown interventions, averaged over 128 sampled graphs (± standard deviation). The number of variables varies from 15 to 20, expected degree = 1 or 2, α = 0.1 or 0.25, and the dimensionality of the variables is fixed to 3.

A.6.1 VARYING GRAPH DENSITY

Results for the experiments on varying graph density in Section 6.2 are shown in Table 17 and Table 18. Our model generalizes well to OOD test distributions where the graphs vary in terms of density.

A.6.2 VARYING CONDITIONAL DISTRIBUTION

The results on varying α values in Section 6.2 are found in Table 19. Our model generalizes well to OOD test distributions, where the conditional probabilities can vary in terms of α values for the Dirichlet distributions.

Results on varying graph density for MLP data: Hamming distance H between predicted and ground-truth adjacency matrices.


Results on Dirichlet data evaluated with additional evaluation metrics. Hamming distance H for learned and ground-truth edges on synthetic graphs, compared to other methods. The number of variables varies from 10 to 20, expected degree = 1 or 2, and α is fixed to 0.1.


We also evaluate how well our model generalizes when the model used for generating the graphs during training differs from the one used at test time. The models are trained on Erdős-Rényi (ER) graphs and tested on Scale-Free (SF) graphs. To limit the number of experiments to run, we evaluate our model on Dirichlet data with α = 0.1, running experiments on both N = 10 and N = 20 graphs. The results are shown in Table 20 and Table 21.

A.7 ABLATION STUDIES AND ANALYSES

In order to better understand the performance of our model, we performed further analyses and ablations. This section contains full results for Section 6.4 and aims to answer the following questions:
• What do the graphs sampled from our model look like?
• How does our model perform under additional evaluation metrics of structural accuracy?
• Are the graphs generated by our model acyclic?
• Does interventional data improve identifiability in our model?
• How does varying the number of samples impact the performance of our model?
• How does varying the amount of training data impact the performance of our model?
• How do the different components of our model (such as sample-level attention and the auxiliary loss) impact its performance?
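One of the questions above, whether sampled graphs are acyclic, can be checked directly on an adjacency matrix. A small sketch of such a check (our own Kahn-style sink-elimination routine, not the paper's implementation):

```python
import numpy as np

def is_acyclic(adj):
    """Check acyclicity of a directed graph given its adjacency matrix
    by repeatedly removing sink nodes; if no sink exists among the
    remaining nodes, the graph contains a cycle."""
    adj = np.asarray(adj, dtype=int)
    remaining = list(range(adj.shape[0]))
    while remaining:
        # Sinks: nodes with no outgoing edges among the remaining ones.
        sinks = [i for i in remaining if adj[i, remaining].sum() == 0]
        if not sinks:
            return False  # every remaining node has an out-edge: cycle
        remaining = [i for i in remaining if i not in sinks]
    return True

dag = np.array([[0, 1, 1],
                [0, 0, 1],
                [0, 0, 0]])
cyc = np.array([[0, 1, 0],
                [0, 0, 1],
                [1, 0, 0]])  # 0 -> 1 -> 2 -> 0
print(is_acyclic(dag), is_acyclic(cyc))  # -> True False
```

Since the model emits adjacency matrices directly, a check of this kind can be applied to every sampled graph at evaluation time.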

A.7.1 VISUALIZATION OF SAMPLED GRAPHS

We visualized samples that our model generated on the test data. The samples are shown in Figure 8.

A.7.8 AUXILIARY LOSS

We also conducted ablation studies to understand the effect of the auxiliary loss in the objective function; results are reported in Table 27. The model with the auxiliary loss is the vanilla CSIvA model, whereas the one without is CSIvA trained without the auxiliary loss objective. Experiments are conducted on Dirichlet data with α = 0.1. The model gains a small amount of performance (H < 1.0) from the auxiliary loss, and the difference becomes more apparent as the size of the graph grows, indicating that the auxiliary loss plays an important role as part of the objective function.

A.7.9 COMPUTATION TIME

All the models with the proposed method are trained for 500k iterations, and all models have converged by then (in fact, they have all converged by 400k). Once the model is trained, evaluation/inference is fast. All models are evaluated on 128 test datasets (graphs).

