VARIATIONAL AUTO-ENCODER ARCHITECTURES THAT EXCEL AT CAUSAL INFERENCE

Abstract

This paper provides a generative approach for causal inference using data from observational studies. Inspired by the work of Kingma et al. (2014), we propose a sequence of three architectures (namely Series, Parallel, and Hybrid), each of which incorporates their M1 and M2 models as building blocks. Each architecture improves over the previous one in terms of estimating causal effects, culminating in the Hybrid model. The Hybrid model is designed to encourage decomposing the underlying factors of any observational dataset; this, in turn, helps to accurately estimate all treatment outcomes. Our empirical results demonstrate the superiority of all three proposed architectures over both state-of-the-art discriminative methods and other generative approaches in the literature.

1. INTRODUCTION

As one of the main tasks in studying causality (Peters et al., 2017; Guo et al., 2018), the goal of Causal Inference is to determine how much the value of a certain variable would change (i.e., the effect) had another certain variable (i.e., the cause) changed its value. A prominent example is the counterfactual question (Rubin, 1974; Pearl, 2009) "Would this patient have lived longer [and by how much], had she received an alternative treatment?". Such questions are often asked in the context of precision medicine, which attempts to identify which medical procedure t ∈ T will benefit a certain patient x the most, in terms of the treatment outcome y ∈ R (e.g., survival time). A fundamental problem in causal inference is the unobservability of the counterfactual outcomes (Holland, 1986). That is, for each subject i, any real-world dataset can only contain the outcome of the administered treatment (aka the observed outcome y_i), but not the outcome(s) of the alternative treatment(s) (aka the counterfactual outcome(s)), i.e., y_i^t for t ∈ T \ {t_i}. In other words, the causal effect is never observed (i.e., it is missing from any training data); it can neither be used to train predictive models nor to evaluate a proposed model. This makes estimating causal effects a more difficult problem than generalization in the supervised learning paradigm.

In general, we can categorize most machine learning algorithms into two general approaches, which differ in how the input features x and their target values y are modeled (Ng & Jordan, 2002): Discriminative methods focus solely on modeling the conditional distribution p(y|x), with the goal of directly predicting y for each instance x. For prediction tasks, discriminative approaches are often more accurate, since they use the model parameters more efficiently than generative approaches.
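The fundamental problem described above can be made concrete with a tiny simulation. In the sketch below (all variable names and values are hypothetical, chosen purely for illustration), each unit has two potential outcomes y0 and y1, but the dataset records only the one corresponding to the administered treatment, so the individual effect y1 − y0 is never observable from the data alone:

```python
import numpy as np

# Minimal synthetic illustration of why causal effects are never directly
# observed: each unit has two potential outcomes (y0, y1), but the data
# reveal only the outcome of the treatment that was actually administered.
rng = np.random.default_rng(0)
n = 5
x = rng.normal(size=n)            # covariates
y0 = x + 1.0                      # potential outcome under control
y1 = x + 3.0                      # potential outcome under treatment
t = rng.integers(0, 2, size=n)    # administered (binary) treatment
y_obs = np.where(t == 1, y1, y0)  # factual outcome: the only one recorded
ite = y1 - y0                     # individual effect: missing from any dataset

print(y_obs)  # available for training
print(ite)    # unavailable; 2.0 for every unit by construction here
```

By construction the true individual effect is 2.0 for every unit, yet no row of the observable data (x, t, y_obs) contains it, which is exactly why causal-effect estimators cannot be trained or evaluated by direct supervision.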
Most of the current causal inference methods are discriminative, including the Balancing Neural Network (BNN) (Johansson et al., 2016), the CounterFactual Regression Network (CFR-Net) (Shalit et al., 2017) and its extensions (cf. Yao et al., 2018; Hassanpour & Greiner, 2019; 2020), as well as Dragon-Net (Shi et al., 2019). Generative methods, on the other hand, describe the relationship between x and y by their joint probability distribution p(x, y). This, in turn, allows a generative model to answer arbitrary queries, including coping with missing features x using the marginal distribution p(x) or [similar to discriminative models] predicting the unknown target values y via p(y|x). A promising direction forward for causal inference is developing generative models, using either the Generative Adversarial Network (GAN) (Goodfellow et al., 2014) or the Variational Auto-Encoder (VAE) (Kingma & Welling, 2014; Rezende et al., 2014) framework. This has led to two generative approaches for causal inference: GANs for inference of Individualised Treatment Effects (GANITE) (Yoon et al., 2018) and the Causal Effect VAE (CEVAE) (Louizos et al., 2017). However, neither of the two achieves competitive performance in terms of treatment effect estimation compared to the discriminative approaches.

Although discriminative models have excellent predictive performance, they suffer from two drawbacks: (i) overfitting, and (ii) making highly confident predictions, even for instances that are "far" from the observed training data. Generative models based on Bayesian inference, on the other hand, can handle both of these drawbacks: issue (i) can be mitigated by averaging over the posterior distribution of model parameters; and issue (ii) can be addressed by explicitly providing model uncertainty via the posterior (Gordon & Hernández-Lobato, 2020).
Although exact inference is often intractable, efficient approximations to the parameter posterior distribution are possible through variational methods. Here, we use the Variational Auto-Encoder (VAE) (Kingma & Welling, 2014; Rezende et al., 2014) as the Bayesian inference component of our causal inference method.

Contribution: In this paper, we propose three interrelated Bayesian model architectures (namely Series, Parallel, and Hybrid) that employ the VAE framework to address the task of causal inference for binary treatments. We find that the best performing architecture is the Hybrid model, which is [at least partially] successful in decomposing the underlying factors of any observational dataset. This is a valuable property, as it allows accurately estimating all treatment outcomes. We demonstrate that these models significantly outperform the state of the art in terms of treatment effect estimation on two publicly available benchmarks, as well as on a fully synthetic dataset that allows for detailed performance analyses.
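To ground the VAE component referred to above: the framework of Kingma & Welling (2014) maximizes the evidence lower bound (ELBO), a reconstruction term minus a KL term, using the reparameterization trick. The numpy sketch below computes both terms for a single observation; the linear encoder/decoder maps are hypothetical stand-ins for the neural networks, not the architectures proposed in this paper:

```python
import numpy as np

# Illustrative computation of the two ELBO terms a VAE optimizes:
#   ELBO = E_q[log p(x|z)] - KL( q(z|x) || p(z) ),  with p(z) = N(0, I).
rng = np.random.default_rng(0)
x = rng.normal(size=4)                 # one "observation"

# Encoder q(z|x) = N(mu, sigma^2): hypothetical linear parameterization.
mu = 0.5 * x
log_var = np.full(4, -1.0)
sigma = np.exp(0.5 * log_var)

# Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I),
# so gradients can flow through the sampling step.
eps = rng.normal(size=4)
z = mu + sigma * eps

# Decoder p(x|z) = N(z, I): reconstruction log-likelihood (up to a constant).
recon = -0.5 * np.sum((x - z) ** 2)

# Closed-form KL( N(mu, sigma^2) || N(0, I) ), always non-negative.
kl = -0.5 * np.sum(1.0 + log_var - mu ** 2 - np.exp(log_var))

elbo = recon - kl
print(elbo)
```

In practice the expectation is estimated with one Monte-Carlo sample per data point, as here, and the ELBO is maximized with stochastic gradients over mini-batches.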

CFR-Net

Shalit et al. (2017) considered the binary treatment task and attempted to learn a representation space Φ that reduces selection bias by making Pr( Φ(x) | t = 0 ) and Pr( Φ(x) | t = 1 ) as close to each other as possible, provided that Φ(x) retains enough information that the learned regressors { h_t(Φ(·)) : t ∈ {0, 1} } can generalize well on the observed outcomes. Their objective function includes L( y_i, h_{t_i}(Φ(x_i)) ), which is the loss of predicting the observed outcome for sample i (with covariates x_i), weighted by ω_i = t_i / (2u) + (1 − t_i) / (2(1 − u)), where u = Pr( t = 1 ). This is effectively setting ω_i = 1 / (2 Pr( t_i )), where Pr( t_i ) is the probability of selecting treatment t_i over the entire population.
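The weighting scheme above can be sketched in a few lines of numpy. The treatment vector below is hypothetical; the point is that the two algebraic forms of ω_i coincide and that each treatment group receives equal total weight, which counteracts group imbalance in the factual loss:

```python
import numpy as np

# Sketch of CFR-Net's factual-loss weights:
#   omega_i = t_i / (2u) + (1 - t_i) / (2(1 - u)),  u = Pr(t = 1).
t = np.array([1, 0, 0, 1, 1, 1])      # hypothetical treatment assignments
u = t.mean()                          # empirical Pr(t = 1)
omega = t / (2 * u) + (1 - t) / (2 * (1 - u))

# Equivalent form omega_i = 1 / (2 Pr(t_i)): Pr(t_i) is u for treated
# units and (1 - u) for control units.
p_ti = np.where(t == 1, u, 1 - u)
assert np.allclose(omega, 1 / (2 * p_ti))

print(omega)  # treated units get 1/(2u), control units get 1/(2(1-u))
```

Note that the weights of the treated group and the control group each sum to n/2, so the minority group is up-weighted rather than dominated by the majority group.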

DR-CFR

Hassanpour & Greiner (2020) argued against the standard implicit assumption that all of the covariates X are confounders (i.e., contribute to both treatment assignment and outcome determination). Instead, they proposed a graphical model similar to that in Figure 1 and designed a discriminative causal inference approach accordingly, built on top of CFR-Net. Specifically, their model, named Disentangled Representations for CFR (DR-CFR), includes three representation networks, each trained with constraints to ensure that it corresponds to its respective underlying factor. While the idea behind DR-CFR provides an interesting intuition, it is known that only generative models (and not discriminative ones) can truly identify the underlying data generating mechanism. This paper is a step in that direction.
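The decomposition idea can be sketched schematically. In the toy code below (not the authors' implementation; the linear maps are hypothetical stand-ins for DR-CFR's representation networks), one representation is meant to capture factors driving treatment selection only, one the confounders driving both selection and outcome, and one the outcome-only factors; treatment is then predicted from the first two and outcomes from the last two:

```python
import numpy as np

# Schematic of DR-CFR's three-way decomposition of the covariates.
rng = np.random.default_rng(0)
x = rng.normal(size=(8, 6))           # a batch of 8 units with 6 covariates

# Hypothetical linear "representation networks" standing in for MLPs.
W_sel, W_conf, W_out = (rng.normal(size=(6, 3)) for _ in range(3))
r_sel = x @ W_sel                     # selection-only factors
r_conf = x @ W_conf                   # confounding factors
r_out = x @ W_out                     # outcome-only factors

# Treatment is modeled from the selection and confounding factors;
# outcomes are modeled from the confounding and outcome factors.
treat_input = np.concatenate([r_sel, r_conf], axis=1)
outcome_input = np.concatenate([r_conf, r_out], axis=1)
print(treat_input.shape, outcome_input.shape)
```

The constraints mentioned in the text (omitted here) are what push each representation toward its intended factor; without them, nothing prevents the three maps from capturing overlapping information.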

Dragon-Net

Shi et al. (2019)'s main objective was to estimate the Average Treatment Effect (ATE), which they explain requires a two-stage procedure: (i) fit models that predict the outcomes for both treatments; and (ii) find a downstream estimator of the effect. Their method is based on a classic result on strong ignorability, i.e., Theorem 3 in (Rosenbaum & Rubin, 1983), which states:

(y^1, y^0) ⊥⊥ t | x  and  Pr( t = 1 | x ) ∈ (0, 1)  ⟹  (y^1, y^0) ⊥⊥ t | b(x)  and  Pr( t = 1 | b(x) ) ∈ (0, 1)

where b(x) is a balancing score [1]. They consider the propensity score as a balancing score and argue that only the parts of X relevant for predicting T are required for estimating the causal effect [2]. This theorem only provides a way to match treated and control instances though, i.e., it helps find potential counterfactuals in the alternative group to calculate the ATE. Shi et al. (2019), however, used it to derive minimal representations on which to regress to estimate the outcomes.



[1] That is, X ⊥⊥ T | b(X) (Rosenbaum & Rubin, 1983).
[2] The authors acknowledge that this would hurt the predictive performance for individual outcomes. As a result, this yields inaccurate estimation of Individual Treatment Effects (ITEs).
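The role the propensity score plays in ATE estimation can be illustrated with a small simulation. The sketch below uses inverse-propensity weighting with a known, hypothetical propensity function; this is purely illustrative of why the score suffices for the ATE (Dragon-Net instead learns the propensity jointly with the outcome heads), and all data-generating choices here are assumptions:

```python
import numpy as np

# Toy simulation: with the propensity e(x) = Pr(t = 1 | x) known,
# inverse-propensity weighting of the factual outcomes recovers the ATE.
rng = np.random.default_rng(0)
n = 10_000
x = rng.normal(size=n)
e = 1.0 / (1.0 + np.exp(-x))          # propensity: treatment depends on x
t = (rng.random(n) < e).astype(int)   # confounded treatment assignment
y = 2.0 * t + x + rng.normal(size=n)  # true ATE is 2.0 by construction

# Naive difference of group means is biased because x drives both t and y.
naive = y[t == 1].mean() - y[t == 0].mean()

# IPW estimator: each factual outcome is weighted by the inverse of its
# treatment-assignment probability.
ate_hat = np.mean(t * y / e - (1 - t) * y / (1 - e))

print(naive, ate_hat)                 # ate_hat is close to 2.0 for large n
```

The contrast with the naive estimator (which absorbs the confounding of x) shows what the balancing-score result buys; note, per footnote [2], that collapsing X to the propensity score alone sacrifices the information needed for accurate individual-level outcomes.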

