VARIATIONAL AUTO-ENCODER ARCHITECTURES THAT EXCEL AT CAUSAL INFERENCE

Abstract

This paper provides a generative approach for causal inference using data from observational studies. Inspired by the work of Kingma et al. (2014), we propose a sequence of three architectures (namely Series, Parallel, and Hybrid) that each incorporate their M1 and M2 models as building blocks. Each architecture improves on the previous one in terms of estimating causal effects, culminating in the Hybrid model. The Hybrid model is designed to encourage decomposing the underlying factors of any observational dataset; this, in turn, helps to accurately estimate all treatment outcomes. Our empirical results demonstrate the superiority of all three proposed architectures over both state-of-the-art discriminative and generative approaches in the literature.

1. INTRODUCTION

As one of the main tasks in studying causality (Peters et al., 2017; Guo et al., 2018), the goal of Causal Inference is to figure out how much the value of a certain variable (i.e., the effect) would change had another variable (i.e., the cause) taken a different value. A prominent example is the counterfactual question (Rubin, 1974; Pearl, 2009): "Would this patient have lived longer [and by how much], had she received an alternative treatment?". Such questions are often asked in the context of precision medicine, which attempts to identify which medical procedure t ∈ T will benefit a certain patient x the most, in terms of the treatment outcome y ∈ R (e.g., survival time).

A fundamental problem in causal inference is the unobservability of the counterfactual outcomes (Holland, 1986). That is, for each subject i, any real-world dataset can only contain the outcome of the administered treatment (aka the observed outcome y_i), but not the outcome(s) of the alternative treatment(s) (aka the counterfactual outcome(s)); i.e., y_i^t for t ∈ T \ {t_i}. In other words, the causal effect is never observed (i.e., it is missing from any training data) and cannot be used to train predictive models, nor can it be used to evaluate a proposed model. This makes estimating causal effects a more difficult problem than generalization in the supervised learning paradigm.

In general, most machine learning algorithms can be categorized into two general approaches, which differ in how the input features x and their target values y are modeled (Ng & Jordan, 2002): Discriminative methods focus solely on modeling the conditional distribution p(y|x), with the goal of directly predicting y for each instance x. For prediction tasks, discriminative approaches are often more accurate, since they use the model parameters more efficiently than generative approaches.
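The unobservability of counterfactual outcomes described above can be illustrated with a small synthetic sketch. The data below are hypothetical (the outcome distributions and treatment assignment are invented for illustration only): each subject carries two potential outcomes, y_i^0 and y_i^1, yet the recorded dataset reveals only the one matching the administered treatment t_i, so the individual treatment effect y_i^1 - y_i^0 is never directly available for training or evaluation.

```python
import numpy as np

rng = np.random.default_rng(0)

n = 5
y0 = rng.normal(60.0, 5.0, size=n)       # potential outcome under control (t = 0)
y1 = y0 + rng.normal(10.0, 2.0, size=n)  # potential outcome under treatment (t = 1)
t = rng.integers(0, 2, size=n)           # administered treatment for each subject

# The factual (observed) outcome is the only outcome recorded in a real dataset.
y_observed = np.where(t == 1, y1, y0)

# The individual treatment effect exists only in this simulation;
# it can never be computed from (x, t, y_observed) alone.
ite = y1 - y0

for i in range(n):
    print(f"subject {i}: t={t[i]}, observed y={y_observed[i]:.1f}, "
          f"true ITE={ite[i]:.1f} (hidden in practice)")
```

Causal inference methods must therefore estimate the missing potential outcomes rather than fit them directly, which is what distinguishes the problem from standard supervised learning.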
Most of the current causal inference methods are discriminative, including the Balancing Neural Network (BNN) (Johansson et al., 2016), the CounterFactual Regression Network (CFR-Net) (Shalit et al., 2017) and its extensions (cf. Yao et al., 2018; Hassanpour & Greiner, 2019; 2020), as well as Dragon-Net (Shi et al., 2019). Generative methods, on the other hand, describe the relationship between x and y via their joint probability distribution p(x, y). This, in turn, allows a generative model to answer arbitrary queries, including coping with missing features x using the marginal distribution p(x) or [similar to discriminative models] predicting the unknown target values y via p(y|x). A promising direction forward for causal inference is developing generative models, using either a Generative Adversarial Network (GAN) (Goodfellow et al., 2014) or a Variational Auto-Encoder (VAE) (Kingma & Welling, 2014; Rezende et al., 2014). This has led to two generative approaches for causal inference: GANs for inference of Individualised Treatment Effects (GANITE) (Yoon et al., 2018) and Causal Effect

