PREDICTING CELLULAR RESPONSES WITH VARIATIONAL CAUSAL INFERENCE AND REFINED RELATIONAL INFORMATION

Abstract

Predicting the responses of a cell under perturbations may bring important benefits to drug discovery and personalized therapeutics. In this work, we propose a novel graph variational Bayesian causal inference framework to predict a cell's gene expressions under counterfactual perturbations (perturbations that this cell did not factually receive), leveraging information representing biological knowledge in the form of gene regulatory networks (GRNs) to aid individualized cellular response predictions. Aiming at a data-adaptive GRN, we also developed an adjacency matrix updating technique for graph convolutional networks and used it to refine GRNs during pre-training, which generated more insights on gene relations and enhanced model performance. Additionally, we propose a robust estimator within our framework for the asymptotically efficient estimation of the marginal perturbation effect, an estimation that has not been carried out in previous works. With extensive experiments, we demonstrate the advantage of our approach over state-of-the-art deep learning models for individual response prediction.

1. INTRODUCTION

Studying a cell's response to genetic, chemical, and physical perturbations is fundamental to understanding various biological processes and can lead to important applications such as drug discovery and personalized therapies. Cells respond to exogenous perturbations at different levels, including epigenetic (DNA methylation and histone modifications), transcriptional (RNA expression), translational (protein expression), and post-translational (chemical modifications on proteins). The availability of single-cell RNA sequencing (scRNA-seq) datasets has led to the development of several methods for predicting single-cell transcriptional responses (Ji et al., 2021). These methods fall into two broad categories. The first category (Lotfollahi et al., 2019; 2020; Rampášek et al., 2019; Russkikh et al., 2020; Lotfollahi et al., 2021a) approaches the problem of predicting single-cell gene expression response without explicitly modeling the gene regulatory network (GRN), which is widely hypothesized to be the structural causal model governing transcriptional responses of cells (Emmert-Streib et al., 2014). Notably among those studies, CPA (Lotfollahi et al., 2021a) uses an adversarial autoencoder framework designed to decompose the cellular gene expression response into latent components representing perturbations, covariates and basal cellular states. CPA extends the classic idea of decomposing high-dimensional gene expression response into perturbation vectors (Clark et al., 2014; 2015), which can be used for finding connections among perturbations (Subramanian et al., 2017). However, while CPA's adversarial approach encourages latent independence, it does not have any supervision on the counterfactual outcome construction and thus does not explicitly imply that the counterfactual outcomes would resemble the observed response distribution. Existing self-supervised counterfactual construction frameworks such as GANITE (Yoon et al., 2018) also suffer from this problem.
The second class of methods explicitly models the regulatory structure to leverage the wealth of regulatory relationships among genes contained in GRNs (Kamimoto et al., 2020). By bringing the benefits of deep learning to graph data, graph neural networks (GNNs) offer a versatile and powerful framework to learn from complex graph data (Bronstein et al., 2017). GNNs are the de facto way of including relational information in many health-science applications including molecule/protein property prediction (Guo et al., 2022; Ioannidis et al., 2019; Strokach et al., 2020; Wu et al., 2022a; Wang et al., 2022), perturbation prediction (Roohani et al., 2022) and RNA-sequence analysis (Wang et al., 2021). In previous work, Cao & Gao (2022) developed GLUE, a framework leveraging a fine-grained GRN with nodes corresponding to features in multi-omics datasets to improve multimodal data integration and response prediction. GEARS (Roohani et al., 2022) uses GNNs to model the relationships among observed and perturbed genes to predict cellular response. These studies demonstrated that relation graphs are informative for predicting cellular responses. However, GLUE does not handle perturbation response prediction, and GEARS's approach of randomly mapping subjects from the control group to subjects in the treatment group is not designed for response prediction at an individual level (it cannot account for heterogeneity of cell states). GRNs can be derived from high-throughput experimental methods mapping chromosome occupancy of transcription factors, such as chromatin immunoprecipitation sequencing (ChIP-seq) and assay for transposase-accessible chromatin using sequencing (ATAC-seq). However, GRNs from these approaches are prone to false positives due to experimental inaccuracies and the fact that transcription factor occupancy does not necessarily translate to regulatory relationships (Spitz & Furlong, 2012).
Alternatively, GRNs can be inferred from gene expression data such as RNA-seq (Maetschke et al., 2014). It is well accepted that integrating both ChIP-seq and RNA-seq data can produce more accurate GRNs (Mokry et al., 2012; Jiang & Mortazavi, 2018; Angelini & Costa, 2014). GRNs are also highly context-specific: different cell types can have very distinctive GRNs, mostly due to their different epigenetic landscapes (Emerson, 2002; Davidson, 2010). Hence, a GRN derived from the most relevant biological system is necessary to accurately infer the expression of individual genes within such a system. In this work, we employed a novel variational Bayesian causal inference framework to construct the gene expressions of a cell under counterfactual perturbations by explicitly balancing individual features embedded in its factual outcome and marginal response distributions of its cell population. We integrated a gene relation graph into this framework, derived the corresponding variational lower bound and designed an innovative model architecture to rigorously incorporate relational information from GRNs in model optimization. Additionally, we propose an adjacency matrix updating technique for graph convolutional networks (GCNs) in order to impute and refine the initial relation graph generated by ATAC-seq prior to training the framework. With this technique, we obtained updated GRNs that discovered more relevant gene relations (and discarded insignificant gene relations in this context) and enhanced model performance. In addition, we propose an asymptotically efficient estimator for estimating the average effect of perturbations under a given cell type within our framework. Such marginal inference is of great biological interest because scRNA-seq experimental results are typically averaged over many cells, yet robust estimations have not been carried out in previous works on predicting cellular responses. We tested our framework on three benchmark datasets from Srivatsan et al.
(2020), Schmidt et al. (2022) and a novel CROP-seq genetic knockout screen that we release with this paper. Our model achieved state-of-the-art results on out-of-distribution predictions on differentially-expressed genes, a task commonly used in previous works on perturbation predictions. In addition, we carried out ablation studies to demonstrate the advantage of using refined relational information and to better understand the contributions of framework components.

2. PROPOSED METHOD

In this section we describe our proposed model, Graph Variational Causal Inference (graphVCI), and a relation graph refinement technique. A list of all notations can be found in Appendix A.

2.1. COUNTERFACTUAL CONSTRUCTION FRAMEWORK

We define outcome Y : Ω → R^n to be an n-dimensional random vector (e.g. gene expressions), X : Ω → E_X to be an m-dimensional mix of categorical and real-valued covariates (e.g. cell types, donors, etc.), and T : Ω → E_T to be an r-dimensional categorical or real-valued treatment (e.g. drug perturbation) on a probability space (Ω, Σ, P). We seek to construct an individual's counterfactual outcome under counterfactual treatment a ∈ E_T from two major sources of information. One source is the individual features embedded in the high-dimensional outcome Y. The other source is the response distribution of similar subjects (subjects that have the same covariates as this individual) that indeed received treatment a. We employ the variational causal inference (Wu et al., 2022b) framework to combine these two sources of information. In this framework, a covariate-dependent feature vector Z : Ω → R^d dictates the outcome distribution along with treatment T; counterfactuals Y′ and T′ are formulated as separate variables apart from Y and T, with a conditional outcome distribution p(Y′|Z, T′ = a) identical to its factual counterpart p(Y|Z, T = a) at any treatment level a. The learning objective is described as a combined likelihood of the individual-specific treatment effect p(Y′|Y, X, T, T′) (first source) and the traditional covariate-specific treatment effect p(Y|X, T) (second source):

J(D) = log p(Y′|Y, X, T, T′) + log p(Y|X, T)    (1)

where D = (X, Z, T, T′, Y, Y′). Additionally, we assume that there is a graph structure G = (V, E) that governs the relations between the dimensions of Y through latent Z, where V ∈ R^{n×v} is the node feature matrix and E ∈ {0, 1}^{n×n} is the node adjacency matrix. For example, in the case of a single-cell perturbation dataset where Y is the expression counts of genes, V is the gene feature matrix and E is the GRN that governs gene relations. See Figure 1 for a visualization of the causal diagram.
The objective under this setting is thus formulated as

J(D̄) = log p(Y′|Y, G, X, T, T′) + log p(Y|G, X, T)    (2)

where D̄ = (G, D). The counterfactual outcome Y′ is always unobserved, but the following theorem provides us a roadmap for the stochastic optimization of this objective.

Theorem 1. Suppose that D̄ = (G, X, Z, T, T′, Y, Y′) follows a causal structure defined by the Bayesian network in Figure 1. Then J(D̄) has the following evidence lower bound:

J(D̄) ≥ E_{p(Z|Y,G,X,T)} [ log p(Y|Z, T) + log p(Y′|G, X, T′) ] − KL[ p(Z|Y, G, X, T) ∥ p(Z|Y′, G, X, T′) ].

Proof of the theorem can be found in Appendix B. We estimate p(Z|Y, G, X, T) and p(Y|Z, T) (as well as p(Y′|Z, T′)) with a neural network encoder q_ϕ and decoder p_θ, and optimize the following weighted approximation of the variational lower bound:

J(θ, ϕ) = E_{q_ϕ(Z|Y,G,X,T)} log p_θ(Y|Z, T) + ω₁ · log p̄(Ỹ′_{θ,ϕ}|G, X, T′) − ω₂ · KL[ q_ϕ(Z|Y, G, X, T) ∥ q_ϕ(Z|Ỹ′_{θ,ϕ}, G, X, T′) ]    (4)

where ω₁, ω₂ are scaling coefficients; Ỹ′_{θ,ϕ} ∼ E_{q_ϕ(Z|Y,G,X,T)} p_θ(Y′|Z, T′); and p̄ is the covariate-specific model fit of the outcome distribution (notice that p(Y′|G, X, T′ = a) = p(Y|G, X, T = a) for any a). In our case where covariates are limited and discrete (cell types and donors), we simply let p̄(Y|G, X, T) be the smoothened empirical distribution of Y stratified by X and T (notice that G is fixed across subjects). Generally, one can train a discriminator p̄(·|Y, G, X, T) with the adversarial approach (Goodfellow et al., 2014) and use log p̄(1|Ỹ′_{θ,ϕ}, G, X, T′) in place of log p̄(Ỹ′_{θ,ϕ}|G, X, T′) if p̄(Y|G, X, T) is hard to fit. See Figure 2 for an overview of the model structure. Note that the decoder estimates the conditional outcome distribution of Y′, in which case T′ need not necessarily be sampled according to a certain true distribution p(T′|X) during optimization.

We refer to the negative of the first term in Equation 4 as the reconstruction loss, the negative of the second term as the distribution loss, and the third term (taken with positive sign) as the KL-divergence. As discussed in Wu et al. (2022b), the negative KL-divergence term in the objective function encourages the preservation of individuality in counterfactual outcome constructions.
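As a concrete illustration, the three loss terms can be assembled as follows for diagonal-Gaussian latent posteriors (a minimal NumPy sketch; the function names and the Gaussian parametrization are our assumptions, not the paper's implementation):

```python
import numpy as np

def gaussian_kl(mu_q, logvar_q, mu_p, logvar_p):
    # KL[N(mu_q, diag exp(logvar_q)) || N(mu_p, diag exp(logvar_p))],
    # summed over latent dimensions
    var_q, var_p = np.exp(logvar_q), np.exp(logvar_p)
    return 0.5 * np.sum(
        logvar_p - logvar_q + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0
    )

def objective(recon_loglik, dist_loglik, z_params, z_cf_params, w1=1.0, w2=1.0):
    """Weighted lower bound of Equation 4: reconstruction log-likelihood of Y,
    plus w1 times the distribution log-likelihood of the constructed
    counterfactual, minus w2 times the KL between the factual and
    counterfactual latent posteriors."""
    kl = gaussian_kl(*z_params, *z_cf_params)
    return recon_loglik + w1 * dist_loglik - w2 * kl
```

Note that identical factual and counterfactual posteriors incur zero KL penalty, which is exactly how the objective rewards counterfactual constructions that preserve an individual's latent state.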

Marginal Effect Estimation

Although perturbation responses at single-cell resolution offer a microscopic view of the biological landscape, it is often fundamental to estimate the average population effect of a perturbation in a given cell type. Hence in this work, we developed a robust estimator for the causal parameter Ψ(p) = E_p(Y′|X = c, T′ = a), the marginal effect of treatment a within a covariate group c. We propose the following estimator, which is asymptotically efficient when E_{p_θ}(Y′|Ẑ_{k,ϕ}, T′_k = a) is estimated consistently and some other regularity conditions (Van Der Laan & Rubin, 2006) hold:

Ψ̂_{θ,ϕ} = (1/n_{a,c}) Σ_{k=1_{a,c}}^{n_{a,c}} [ Y_k − E_{p_θ}(Y′|Ẑ_{k,ϕ}, T′_k = a) ] + (1/n_c) Σ_{k=1_c}^{n_c} E_{p_θ}(Y′|Ẑ_{k,ϕ}, T′_k = a)    (5)

where (Y_k, X_k, T_k) are the observed variables of the k-th individual; Ẑ_{k,ϕ} ∼ q_ϕ(Z|Y_k, G, X_k, T_k); (1_c, …, n_c) are the indices of the observations having X = c and (1_{a,c}, …, n_{a,c}) are the indices of the observations having both T = a and X = c.
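For discrete X and T, the estimator in Equation 5 is straightforward to compute from model outputs; below is an illustrative NumPy sketch (function and argument names are ours):

```python
import numpy as np

def robust_marginal_effect(y, cf_mean, t, x, a, c):
    """One-step estimate of E[Y' | X = c, T' = a] (Equation 5).

    y: (N, n) observed expressions; cf_mean: (N, n) model conditional means
    E[Y' | Z_k, T'_k = a] with Z_k drawn from the encoder posterior;
    t, x: (N,) treatment and covariate labels per cell.
    """
    in_c = (x == c)
    in_ac = in_c & (t == a)
    # bias-correction term: residuals on cells that factually received a within group c
    correction = (y[in_ac] - cf_mean[in_ac]).mean(axis=0)
    # plug-in term: model predictions averaged over all cells of covariate group c
    plug_in = cf_mean[in_c].mean(axis=0)
    return correction + plug_in
```

When the model's conditional mean is constant, the estimate collapses to the empirical mean of the treated cells in group c, which illustrates the bias-correcting role of the first term.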

2.2. INCORPORATING RELATIONAL INFORMATION

Since the elements in the outcome Y are not independent, we aim to design a framework that can exploit predefined relational knowledge among them. In this section, we demonstrate our model structure for encoding and aggregating the relation graph G within the graphVCI framework. Denote deterministic models as f_•, and probabilistic models (outputting probability distributions) as q_• and p_•. We construct the feature vector Z as an aggregation of two latent representations:

Z = (Z_G, Z_H)    (6)
Z_H ∼ q_H(Z_M, aggr_G(Z_G))    (7)
Z_M = f_M(Y, X, T)    (8)
Z_G = f_G(G)    (9)

where aggr_G is a node aggregation operation such as sum, max or mean. The optimization of q_ϕ can then be performed by optimizing the MLP encoder f_{M,ϕ₁} : R^{n+m+r} → R^d, the GNN encoder f_{G,ϕ₂} : R^{n×v} → R^{n×d_G} and the encoding aggregator q_{H,ϕ₃} : R^{d+d_G} → ([0, 1]^d)^Σ, where ϕ = (ϕ₁, ϕ₂, ϕ₃). We designed this construction so that Z possesses a node-level graph embedding Z_G that enables more involved decoding techniques than generic MLP vector decoding, while the calculation of the term KL[q_ϕ(Z|Y, G, X, T) ∥ q_ϕ(Z|Y′, G, X, T′)] can still be reduced to the KL-divergence between the conditional distributions of the two graph-level vector representations Z_H and Z′_H (since Z_G is deterministic and invariant across subjects). In decoding, we construct

Y = ∥_{i=1}^{n} f_H(Z_{G(i)}, Y_M)^⊤ · Y_M    (10)
Y_M ∼ p_M(Z_H, T)    (11)

where ∥ represents vector concatenation, and optimize p_θ by optimizing the MLP decoder p_{M,θ₁} : R^{d+r} → ([0, 1]^d)^Σ and the decoding aggregator f_{H,θ₂} : R^{d+d_G} → R^d, where θ = (θ₁, θ₂). The decoding aggregator maps the graph embedding Z_{G(i)} of the i-th node along with the feature vector Y_M to the outcome Y_{(i)} of the i-th dimension, for which we use an attention mechanism and let Y_{(i)} = att_{θ₂}(Z_{G(i)}, Y_M)^⊤ · Y_M, where att(Q_{(i)}, K) gives the attention score for each of K's feature dimensions given a querying node embedding vector Q_{(i)} (i.e., a row vector of Q).
One can simply use a key-independent attention mechanism Y_{(i)} = att_{θ₂}(Z_{G(i)})^⊤ · Y_M if GPU memory is of concern. See Figure 3 for a visualization of the encoding and decoding models. A complexity analysis of the model can be found in Appendix D.
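To make the decoding aggregation concrete, here is a minimal NumPy sketch of the key-dependent attention readout Y_{(i)} = att(Z_{G(i)}, Y_M)^⊤ · Y_M (the projection matrices and their shapes are our assumptions, not the paper's exact parametrization):

```python
import numpy as np

def softmax(s, axis=-1):
    e = np.exp(s - s.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention_decode(z_g, y_m, w_q, w_k):
    """Key-dependent attention decoding aggregator.

    z_g: (n, d_g) node-level graph embeddings (one row per gene);
    y_m: (d,) decoded feature vector; w_q: (d_g, h), w_k: (d, h)
    hypothetical query/key projections. Each gene's outcome is an
    attention-weighted combination of y_m's feature dimensions.
    """
    queries = z_g @ w_q                 # (n, h): one query per gene
    keys = y_m[:, None] * w_k           # (d, h): keys depend on Y_M's values
    att = softmax(queries @ keys.T)     # (n, d): scores over feature dims
    return att @ y_m                    # (n,): one expression value per gene
```

The key-independent variant mentioned above drops the dependence on y_m in the keys (keys = w_k), trading some expressiveness for memory.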

2.3. RELATION GRAPH REFINEMENT

Since GRNs are often highly context-specific (Oliver, 2000; Romero et al., 2012), and experimental methods such as ATAC-seq and ChIP-seq are prone to false positives, we provide an option to impute and refine a prior GRN by learning from the expression dataset of interest. We propose an approach to update the adjacency matrix during GCN training while maintaining sparse graph operations, and use it in an auto-regressive pre-training task to obtain an updated GRN. Let g(·) be a GCN with learnable edge weights between all nodes. We aim to acquire an updated adjacency matrix by thresholding the updated edge weights after optimization. In practice, such GCNs on complete graphs performed poorly on our task and had scalability issues. Hence we apply dropouts to edges in favor of the initial graph: edges present in Ẽ = |E − I| + I (I is the identity matrix) are assigned a low dropout rate r_l, and edges not present in Ẽ a high dropout rate r_h (r_l ≪ r_h). A graph convolutional layer of g(·) is then given as

H_{l+1} = σ(softmax_r(M ⊙ L) H_l Θ_l)    (12)

where L ∈ R^{n×n} is a dense matrix containing the logits of the edge weights and M ∈ {0, 1}^{n×n} is a sparse mask matrix resampled in every iteration, with element M_{i,j} dropped at rate r_l if Ẽ_{i,j} = 1 and at rate r_h if Ẽ_{i,j} = 0; ⊙ is element-wise matrix multiplication; softmax_r is the row-wise softmax operation; H_l ∈ R^{n×d_l} and H_{l+1} ∈ R^{n×d_{l+1}} are the latent representations of V after the l-th and (l+1)-th layers; Θ_l ∈ R^{d_l×d_{l+1}} is the weight matrix of the (l+1)-th layer; and σ is a non-linear function. The updated adjacency matrix Ê ∈ R^{n×n} is acquired by rescaling and thresholding the unnormalized weight matrix W = exp(L) after optimizing g(·):

Ê_{i,j} = sgn( (1 + W_{i,j}^{-1})^{-1} − α )    (13)

where α is a threshold level. We define W̄ ∈ R^{n×n} to be the rescaled weight matrix having W̄_{i,j} = (1 + W_{i,j}^{-1})^{-1}.
With this design, the convolution layer operates on sparse graphs, which benefits performance and scalability, while each edge absent from the initial graph still has an opportunity to come into existence in the updated graph. We use this approach to obtain an updated GRN Ê prior to the main model training presented in the previous sections. We train g(·) : R^{n×(1+v+m)} → R^{n×1} on a simple node-level prediction task where the output of the i-th node is the expression Y_{(i)} of the i-th dimension, and the input of the i-th node is a combination O′_{(i)} = (Y_{(i)}, V_{(i)}, X) of the expression Y_{(i)}, the gene features V_{(i)} of the i-th gene and the cell covariates X. Essentially, we require the model to predict the expression of a gene (node) from its neighbors in the graph. This task is an effective way to learn potential connections in a gene regulatory network, as regulatory genes should be predictive of their targets (Kamimoto et al., 2020). The objective is a lasso-style combination of reconstruction loss and edge penalty:

J(g) = −∥g(O′) − Y∥_{L²} − ω · ∥W̄∥_{L¹}    (14)

where ω is a scaling coefficient. Note that although V_{(i)} is not cell-specific and X is not gene-specific, the combination (V_{(i)}, X) forms a unique representation for each gene of each cell type. We employ such a dummy task since node-level prediction with strong self-connections grants interpretability, but an additional penalty on the diagonal of W̄ can be applied if one wishes to weaken the self-connections. Otherwise, although self-connections are the strongest signals in this task, dropouts and data correlations will still prevent them from being the only notable signals. See Figure 4 for an example of an updated GRN in practice.
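A minimal sketch of one refinement layer and the thresholding step follows. One detail worth noting: with W = exp(L), the rescaled weight (1 + W_{i,j}^{-1})^{-1} is exactly the sigmoid of the logit L_{i,j}. The names and the dense-mask implementation are our simplifications of the sparse version described above:

```python
import numpy as np

rng = np.random.default_rng(0)

def refine_layer(h, logits, e_tilde, theta, r_l=0.05, r_h=0.9):
    """One graph convolution with learnable edge logits and prior-biased
    edge dropout: edges present in e_tilde are dropped at the low rate r_l,
    absent edges at the high rate r_h."""
    keep = rng.random(e_tilde.shape) < np.where(e_tilde == 1, 1 - r_l, 1 - r_h)
    w = np.where(keep, np.exp(logits), 0.0)
    w = w / (w.sum(axis=1, keepdims=True) + 1e-12)  # row-wise softmax over kept edges
    return np.tanh(w @ h @ theta)

def threshold_graph(logits, alpha=0.2):
    """Updated adjacency (Equation 13): rescale W = exp(L) to (0, 1) via
    (1 + W^{-1})^{-1} = sigmoid(L), then threshold at alpha."""
    w_bar = 1.0 / (1.0 + np.exp(-logits))
    return (w_bar > alpha).astype(int)
```

In the sigmoid view, the threshold α acts directly on a per-edge probability-like score, so tuning α trades off sparsity of Ê against recall of weaker candidate edges.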

3. EXPERIMENTS

We tested our framework on three datasets in experiments. We employ the publicly available sci-Plex dataset from Srivatsan et al. (2020) (Sciplex) and the CRISPRa dataset from Schmidt et al. (2022) (Marson). Sci-Plex is a method for pooled screening that relies on nuclear hashing, and the Sciplex dataset consists of three cancer cell lines (A549, MCF7, K562) treated with 188 compounds. Marson contains perturbations of 73 unique genes where the intervention served to increase the expression of those genes. In addition, we open-source in this work a new dataset (L008) designed to showcase the power of our model in conjunction with modern genomics.

L008 dataset. We used the CROP-seq platform (Shifrut et al., 2018) to knock out 77 genes related to the interferon gamma signaling pathway in CD4+ T cells. They include genes at multiple steps of the interferon gamma signaling pathway such as JAK1, JAK2 and STAT1. We hope that by including multiple such genes, machine learning models will learn the signaling pathway in more detail.

Baseline. We compare our framework to three state-of-the-art self-supervised models for individual counterfactual outcome generation: CEVAE (Louizos et al., 2017), GANITE (Yoon et al., 2018) and CPA (Lotfollahi et al., 2021b), along with the non-graph version of our framework, VCI (Wu et al., 2022b). To give an idea of how well these models are doing, we also compare them to a generic autoencoder (AE) with covariates and treatment as additional inputs, which serves as an ablation study for all other baseline models. For this generic approach, we simply plug in counterfactual treatments instead of factual treatments at test time.

3.1. OUT-OF-DISTRIBUTION PREDICTIONS

We evaluate our model and benchmarks on a widely accepted and biologically meaningful metric: the R² (coefficient of determination) of the average prediction against the true average on the out-of-distribution (OOD) set (see Appendix E.2), computed on all genes and on differentially-expressed (DE) genes (see Appendix E.1). Same as Lotfollahi et al. (2021a), we calculate the R² for each perturbation of each covariate level (e.g. each cell type of each donor), then take the average and denote it as R̄². Table 1 shows the mean and standard deviation of R̄² for each model over 5 independent runs. Training setups can be found in Appendix E.3. As can be seen from these experiments, our variational Bayesian causal inference framework with a refined relation graph achieved a significant advantage over other models on all genes of the OOD set, and a remarkable advantage on DE genes on Marson. Note that losses were evaluated on all genes during training and DE genes were not specifically optimized in these runs.

We also examined the predicted distribution of gene expression for various genes and compared it to experimental results. Figure 5 shows an analysis of the CRISPRa dataset where MAP4K1 and GATA3 were overexpressed in CD8+ T cells (Schmidt et al., 2022), but these cells were not included in the model's training set. Nevertheless, the model's prediction for the distribution of gene expression frequently matches the ground truth, and quantitative agreement can be obtained from Table 1.

Figure 5: Model predictions versus true distributions for overexpression of genes in CRISPRa experiments (Schmidt et al., 2022). For two perturbations in CD8+ T cells, (a) MAP4K1 overexpression and (b) GATA3 overexpression, we plot the distribution of gene expressions for unperturbed cells ("Control"), the model's prediction of perturbed gene expressions using unperturbed cells as factual inputs ("Pred"), and the true gene expressions for perturbed cells ("True"). The predicted distributional shift relative to control often matches the direction of the true shift.

For graphVCI, we used the key-dependent attention (see Section 2.2) for the decoding aggregator in all runs, and there were a few interesting observations in these experiments. Firstly, the key-independent attention is more sensitive to the quality of the GRN and exhibited a more significant difference in model performance with the graph refinement technique compared to the key-dependent attention. Secondly, with graphVCI and the key-dependent attention, we were able to get stable performances across runs while setting ω₁ much higher than that of VCI.

§ Autoencoder with covariates and treatment as additional inputs. ¶ GANITE's counterfactual block. GANITE's counterfactual generator does not scale with a combination of high-dimensional outcome and multi-level treatment, hence we made the same adaptation as Wu et al. (2022b).
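The stratified R̄² metric can be computed as in the following sketch (array names are ours; strata with no cells are skipped):

```python
import numpy as np

def r2_bar(pred, true, pert, cov):
    """Mean R^2 of average predicted vs. average true expression, computed
    per (perturbation, covariate-level) stratum as in Section 3.1.

    pred, true: (N, n) predicted / measured expressions;
    pert, cov: (N,) perturbation and covariate labels per cell."""
    scores = []
    for p in np.unique(pert):
        for c in np.unique(cov):
            idx = (pert == p) & (cov == c)
            if not idx.any():
                continue
            y_hat = pred[idx].mean(axis=0)   # average prediction in stratum
            y = true[idx].mean(axis=0)       # true average in stratum
            ss_res = ((y - y_hat) ** 2).sum()
            ss_tot = ((y - y.mean()) ** 2).sum()
            scores.append(1.0 - ss_res / ss_tot)
    return float(np.mean(scores))
```

Averaging over strata rather than pooling all cells prevents abundant perturbation-covariate combinations from dominating the score.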

3.2. GRAPH EVALUATIONS

In this section, we perform ablation studies and analysis to validate the claim that the adjacency updating procedure improves the quality of the GRN. We first examine the performance impact of the refined GRN over the original GRN derived from ATAC-seq data. For this purpose, we used the key-independent attention decoding aggregator and conducted 5 separate runs with the original GRN and the refined GRN on the Marson dataset (Fig 6a). We found that the refined GRN helps graphVCI to learn faster and achieve better predictive performance on gene expression, suggesting that the updated graph contains more relevant gene connections for the model to predict counterfactual expressions. Note that there is a difference in the length of the curves and bands in Fig 6a because we applied early stopping criteria to runs, similar to CPA. Next, we compared the learned edge weights from our method to known genetic interactions from the ChEA transcription factor targets database (Lachmann et al., 2010). We treat the edge weights following graph refinement as a probability of an interaction and treat the known interactions in the database as ground truth for two key TFs, STAT1 and IRF1. We found that the refined GRN obtained by graph refinement is able to place higher weights on known targets than a more naive version of the GRN based solely on gene-gene correlation in the same dataset (Fig 6b). The improvement is particularly noticeable in the high-precision regime where we expect the ground truth data to be more accurate, since a database based on ChIP-seq is expected to contain false positives.

4. CONCLUSION

In this paper, we developed a theoretically grounded novel model architecture that combines deep graph representation learning with variational Bayesian causal inference for the prediction of single-cell perturbation effects. We proposed an adjacency matrix updating technique producing refined relation graphs prior to model training, which was able to discover relations more relevant to the data and enhance model performance. Experimental results showcased the advantage of our framework compared to state-of-the-art methods. In addition, we included ablation studies and biological analysis to generate sensible insights. Further studies could be conducted regarding complete out-of-distribution prediction, i.e. the prediction of treatment effect when the treatment is completely held out, by rigorously incorporating treatment relations into our framework on top of outcome relations.

D COMPLEXITY ANALYSIS

Overall, the time complexity of VCI compared to a generic framework like a VAE (or the time complexity of graphVCI compared to a generic GNN framework like GLUE (Cao & Gao, 2022)) does not over-scale on any factor of any parameter. To put it simply, the workflow of VCI is just twice the forward passes of a VAE for a batch of inputs, with an additional distribution loss which is implemented on the same scale O(r) as the other losses. Compared to VCI, graphVCI additionally has a few graph operations. We give a thorough analysis of the time complexity of each layer in our experiments below:

• The generic method AE has 2 MLP layers with O(rd) operations and 4 layers with O(d²) operations, where d is the number of hidden neurons.

• VCI has the same layer sizes: 2 MLP layers with O(rd) operations and 4 layers with O(d²) operations, but every layer is forward-passed twice.

• graphVCI has 1 GNN layer with O(rv d_g² + p d_g r²) operations, where v is the number of gene features, d_g is the number of hidden neurons for the GNN (usually a lot less than d since v ≪ r) and p is the sparsity of the GRN (number of connections divided by r², usually around 1%), followed by layers that are forward-passed twice: 1 MLP layer and 1 dot-product operation each with O(rd) operations, 1 MLP layer with O(r d_g (d_g + d)) operations and 3 MLP layers with O(d²) operations.

So the terms of concern compared to AE and VCI are O(rv d_g²), O(p d_g r²) and O(r d_g d). Since p is around 1% and r is 2000 in our experiment, p d_g r² ≈ 20 d_g r, hence the second term is comparably smaller than the other two. Therefore, as long as d_g is set reasonably small compared to d, the graph approach is reasonably scaled compared to pure MLP approaches.
We note that this is a limitation of our approach and of GNN approaches in general: if the number of genes r to be considered is much higher than 2000, or if the number of gene features v per gene is high (which requires d_g to be higher), GNN methods do not scale favorably compared to MLP methods.

E EXPERIMENTS DETAILS E.1 DIFFERENTIALLY-EXPRESSED GENES

To evaluate the predictions on the genes that were substantially affected by the perturbations, we select sets of 50 differentially-expressed genes associated with each perturbation and separately report performance on these genes. The same procedure was carried out by Lotfollahi et al. (2021a) .
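As an illustration, one simple way to select such a set, ranking genes by the absolute difference of mean expression between perturbed and control cells, is sketched below (the exact statistic used in the paper follows Lotfollahi et al. (2021a) and may differ):

```python
import numpy as np

def top_de_genes(expr_pert, expr_ctrl, k=50):
    """Indices of the k genes most differentially expressed under one
    perturbation, ranked here by absolute difference of mean expression
    versus control (an illustrative statistic, not necessarily the
    paper's exact procedure)."""
    delta = np.abs(expr_pert.mean(axis=0) - expr_ctrl.mean(axis=0))
    return np.argsort(delta)[::-1][:k]
```

Reporting R̄² on this subset focuses the evaluation on the genes that actually respond to the perturbation, rather than on the many genes whose expression is unchanged.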



† https://github.com/yulun-rayn/graphVCI



Figure 1: The causal relation diagram. Each individual has a feature state Z following a conditional distribution p(Z|G, X). Treatment T (or counterfactual treatment T′) along with Z determines outcome Y (or counterfactual outcome Y′). In the causal diagram, white nodes are observed and dark grey nodes are unobserved; dashed relations are optional (case-dependent). In the context of this paper, graph G is a deterministic variable that is invariant across all individuals.

Figure 2: Model workflow -variational causal perspective. In a forward pass, the graphVCI encoder takes graph G (e.g. gene relation graph), outcome Y (e.g. gene expressions), covariates X (e.g. cell types, donors, etc.) and treatment T (e.g. drug perturbation) as inputs and generates latent Z; (Z, T ) and (Z, T ′ ) where T ′ is a randomly sampled counterfactual treatment are separately passed into the graphVCI decoder to attain reconstruction of Y and construction of counterfactual outcome Y ′ ; Y ′ is then passed back into the encoder along with G, X, T ′ to attain counterfactual latent Z ′ . The objective consists of the reconstruction loss of Y , the distribution loss of Y ′ and the KL-divergence between the conditional distributions of Z and Z ′ .

Figure 3: Model architecture -graph attentional perspective. Structure of the graphVCI encoder and decoder defined by Equations (6) to (11). Note that in the case of single-cell perturbation datasets, the graph inputs are fixed across samples and graph attention can essentially be reduced to weighted graph convolution.

Figure 4: An example of an updated gene regulatory network U (where U_{i,j} = |W̄_{i,j} − α| with α = 0.2) after refining the original ATAC-seq-based network (Kamimoto et al., 2020) using the Schmidt et al. (2022) dataset. Source nodes are shown as rows and targets are shown as columns for key immune-related genes. The learned edge weights in (b) recapitulate known biology such as STAT1 regulating IRF1. Also note that while many of the edges are present in the original ATAC-seq data from (a), we see some novel edges in (b), such as IFNg regulating MYC (Ramana et al., 2000).

Figure 6: (a) Graph refinement improves model training. We learned a GRN with graph refinement as described in Section 2.3, and edges were retained using a threshold of α = 0.3. Following the refinement, the model is better able to reconstruct all genes and differentially expressed (DE) genes during training, as can be seen in a graph of R² vs. number of training epochs. (b) We examine whether the edges learned in refinement are accurate by comparing to a database of targets for two important TFs, STAT1 and IRF1. Refined edges between these TFs and their targets agree well with the interactions in the database, according to a precision-recall curve where we treat the database as the true label and edge weights as a probability of interaction. We also compare the refined edges ("GCN") to a gene-gene correlation benchmark ("Corr") and find that the refined graph can outperform the benchmark.

Table 1: R̄² of OOD predictions.

By Theorem 2, we propose the following estimator that is asymptotically efficient among regular estimators under some regularity conditions (Van Der Laan & Rubin, 2006):

Ψ̂_{θ,ϕ} = (1/n) Σ_{k=1}^{n} [ I(T_k = a, X_k = c) / (p̂(T_k|X_k) p̂(X_k)) · ( Y_k − E_{p_θ}(Y′|Ẑ_{k,ϕ}, T′_k = a) ) + I(X_k = c) / p̂(X_k) · E_{p_θ}(Y′|Ẑ_{k,ϕ}, T′_k = a) ]

where (Y_k, X_k, T_k) are the observed variables of the k-th individual and Ẑ_{k,ϕ} ∼ q_ϕ(Z|Y_k, G, X_k, T_k); p̂(T|X) is an estimation of the propensity score and p̂(X) is an estimation of the density of X. In the context of this work, X and T are discrete, hence p̂(T|X) and p̂(X) can be estimated by the empirical densities p_n(T|X) and p_n(X). The above estimator then reduces to:

Ψ̂_{θ,ϕ} = (1/n_{a,c}) Σ_{k=1_{a,c}}^{n_{a,c}} [ Y_k − E_{p_θ}(Y′|Ẑ_{k,ϕ}, T′_k = a) ] + (1/n_c) Σ_{k=1_c}^{n_c} E_{p_θ}(Y′|Ẑ_{k,ϕ}, T′_k = a)

where (1_c, …, n_c) are the indices of the observations having X = c and (1_{a,c}, …, n_{a,c}) are the indices of the observations having both T = a and X = c.

ACKNOWLEDGMENTS

We thank Balasubramaniam Srinivasan, Drausin Wulsin, Meena Subramaniam, and Maxime Dhainaut for the insightful discussions. Work by author Robert A. Barton was done prior to joining Amazon.

B PROOF OF THEOREM 1

Proof. By the d-separation (Pearl, 1988) of paths on the causal graph defined in Figure 1, and by Jensen's inequality,

log p(Y′|Y, G, X, T, T′) = log E_{p(Z|Y,G,X,T)} [ p(Y′, Z|Y, G, X, T, T′) / p(Z|Y, G, X, T) ] ≥ E_{p(Z|Y,G,X,T)} log [ p(Y′, Z|Y, G, X, T, T′) / p(Z|Y, G, X, T) ]    (17)

C MARGINAL EFFECT ESTIMATION C.1 EXPERIMENT

To evaluate the marginal estimator Ψ̂_{θ,ϕ} in Equation 5, we compute Ψ̂_{θ,ϕ} for treatment a and covariate level c with samples from the training set and calculate its R² against the true average of the samples with treatment a and covariate level c in the validation set. We record the average R² over all treatment-covariate combinations similar to Section 3.1, and compare it ("robust") to that of the regular empirical mean estimator ("mean"). Table 2 shows the results on Marson (Schmidt et al., 2022), evaluated episodically during training. These runs reflect that the robust estimator was able to produce a more accurate estimation of the covariate-stratified marginal treatment effect E_p(Y′|X = c, T′ = a) with a tighter confidence bound.

C.2 DERIVATION

By Van der Vaart (2000), we derive the efficient influence function of Ψ(p), which provides a means for asymptotically efficient estimation. Theorem 2. Suppose D̄ : Ω → E follows a causal structure defined by the Bayesian network in Figure 1, where the counterfactual conditional distribution p(Y′, T′|Z, X̄) is identical to that of its factual counterpart p(Y, T|Z, X̄). Then Ψ(p) has the following efficient influence function. Let X̄ = (G, X) and let the lowercase of a variable denote the value it takes. By Levy (2019), we have

