EXPLORING TRANSFORMER BACKBONES FOR HETEROGENEOUS TREATMENT EFFECT ESTIMATION

Abstract

Previous works on Treatment Effect Estimation (TEE) are not in widespread use because they are predominantly theoretical, relying on strong parametric assumptions that are intractable in practical applications. Recent works use Multilayer Perceptrons (MLPs) for modeling causal relationships; however, MLPs lag far behind recent advances in ML methodology, which limits their applicability and generalizability. To extend beyond the single-domain formulation and towards more realistic learning scenarios, we explore model design spaces beyond MLPs, i.e., transformer backbones, which provide flexibility: attention layers govern interactions among treatments and covariates to exploit structural similarities of potential outcomes for confounding control. Through careful model design, we propose Transformers as Treatment Effect Estimators (TransTEE). We show empirically that TransTEE can: (1) serve as a general-purpose treatment effect estimator that significantly outperforms competitive baselines on a variety of challenging TEE problems (e.g., discrete, continuous, structured, or dosage-associated treatments) and is applicable both when covariates are tabular and when they consist of structural data (e.g., texts, graphs); (2) yield multiple advantages: compatibility with propensity score modeling, parameter efficiency, robustness to continuous treatment value distribution shifts, explainability in covariate adjustment, and real-world utility in auditing pre-trained language models.

1. INTRODUCTION

One of the fundamental tasks in causal inference is to estimate treatment effects given covariates, treatments and outcomes. Treatment effect estimation is a central problem of interest in clinical healthcare and social science (Imbens & Rubin, 2015), as well as econometrics (Wooldridge, 2015). Under certain conditions (Rosenbaum & Rubin, 1983), the task can be framed as a particular type of missing data problem, whose structure differs in key ways from supervised learning and entails a more complex set of covariate and treatment representation choices. Previous works in statistics leverage parametric models (Imbens & Rubin, 2015; Wager & Athey, 2018; Künzel et al., 2019; Foster & Syrgkanis, 2019) to estimate heterogeneous treatment effects. To improve their utility, feed-forward neural networks have been adapted for modeling causal relationships and estimating treatment effects (Yoon et al., 2018; Bica et al., 2020b; Schwab et al., 2020; Nie et al., 2021; Curth & van der Schaar, 2021b), in part due to their flexibility in modeling nonlinear functions (Hornik et al., 1989) and high-dimensional input (Johansson et al., 2016). In these models, a specialized network architecture plays a key role in learning representations for counterfactual inference (Alaa & Schaar, 2018; Curth & van der Schaar, 2021b) such that treatment variables and covariates are well distinguished (Shalit et al., 2017). Despite these encouraging results, several key challenges make it difficult to adopt these methods as standard tools for treatment effect estimation. Most current works based on subnetworks do not sufficiently exploit the structural similarities of potential outcomes for heterogeneous TEE,¹ and accounting for them requires complicated regularization, reparametrization, or multi-task architectures that are problem-specific (Curth & van der Schaar, 2021b).
Moreover, they rely heavily on treatment-specific designs and cannot be easily extended beyond the narrow context in which they were originally designed. For example, they have poor practicality and generalizability when high-dimensional structural data (e.g., texts and graphs) are given as input (Kaddour et al., 2021). Besides, these MLP-based models currently lag far behind recent advances in machine learning methodology and are prone to issues of scale, expressivity and flexibility. Specifically, these limitations include parameter inefficiency (Table 1) and brittleness under different scenarios, such as when treatments shift slightly from the training distribution. The above limitations show a pressing need for an effective and practical framework to estimate treatment effects.

In this work, we explore recent advanced models from the deep learning community to improve model design for TEE tasks. Specifically, the core idea of our approach consists of three parts: as an S-learner, TransTEE embeds all treatments and covariates, which avoids multi-task architectures and improves flexibility and robustness to continuous treatment value distribution shifts; attention mechanisms are used for modeling treatment interaction and treatment-covariate interaction. In this way, TransTEE enables adaptive covariate selection (De Luna et al., 2011; VanderWeele, 2019) for inferring causal effects. For example, one can observe in Figure 1 that both pre-treatment covariates and confounders are appropriately adjusted with higher weights, which recovers the "disjunctive cause criterion" (De Luna et al., 2011) that accounts for those two kinds of covariates and helps ensure the plausibility of the conditional ignorability assumption when complete knowledge of the causal graph is not available.
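To make the idea of embedding treatments and covariates and letting attention weigh the covariates more concrete, the following is a minimal sketch, not the TransTEE architecture itself. It assumes a single-head scaled dot-product cross-attention in which one treatment token queries the covariate tokens. All names (`embed_scalar`, `cross_attention`, `predict_outcome`, `init_params`) and the randomly initialized, untrained parameters are hypothetical and chosen only for illustration.

```python
import math
import random


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def embed_scalar(value, weight, bias):
    """Hypothetical embedding: map one scalar to a d-dimensional token."""
    return [value * w + b for w, b in zip(weight, bias)]


def cross_attention(query, keys, values):
    """Single-head scaled dot-product attention of one query over tokens."""
    scale = math.sqrt(len(query))
    weights = softmax([dot(query, k) / scale for k in keys])
    dim = len(values[0])
    context = [sum(w * v[j] for w, v in zip(weights, values))
               for j in range(dim)]
    return context, weights


def predict_outcome(covariates, treatment, params):
    """S-learner style pass: one shared model receives (x, t) jointly."""
    cov_tokens = [embed_scalar(x, w, b)
                  for x, (w, b) in zip(covariates, params["cov"])]
    t_token = embed_scalar(treatment, *params["treat"])
    # The treatment token attends over covariate tokens; the weights show
    # how strongly each covariate contributes for this treatment value.
    context, weights = cross_attention(t_token, cov_tokens, cov_tokens)
    outcome = dot(context, params["head"]) + dot(t_token, params["head_t"])
    return outcome, weights


def init_params(n_covariates, dim, seed=0):
    """Random, untrained parameters (illustration only)."""
    rng = random.Random(seed)

    def vec():
        return [rng.gauss(0.0, 1.0) for _ in range(dim)]

    return {
        "cov": [(vec(), vec()) for _ in range(n_covariates)],
        "treat": (vec(), vec()),
        "head": vec(),
        "head_t": vec(),
    }


params = init_params(n_covariates=3, dim=4)
x = [0.5, -1.2, 2.0]
y1, attn1 = predict_outcome(x, 1.0, params)  # predicted outcome under t = 1
y0, attn0 = predict_outcome(x, 0.0, params)  # predicted outcome under t = 0
effect_estimate = y1 - y0                    # plug-in effect for this unit
```

Because the treatment enters as an embedded input rather than selecting a separate output head, the same function accepts a continuous value such as `predict_outcome(x, 0.37, params)` without architectural changes. With untrained weights, the numbers produced here carry no causal meaning; only the data flow is illustrated.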
This recipe also gives improved versatility when working with heterogeneous treatment types (Figure 2). Our first contribution is to show that transformer backbones, equipped with proper design choices, can be effective and versatile treatment effect estimators under the Rubin-Neyman potential outcomes framework. TransTEE is empirically verified to be (i) a flexible framework applicable to a wide range of TEE settings; (ii) compatible and effective with propensity score modeling; (iii) parameter-efficient; (iv) explainable in covariate adjustment; (v) robust under continuous treatment shifts; (vi) useful for debugging pre-trained language models (LMs) to promote favorable social outcomes. Moreover, comprehensive experiments on six benchmarks with four types of treatments are conducted to verify the effectiveness of TransTEE in estimating treatment effects. We show that TransTEE produces covariate adjustment interpretations and significant performance gains given discrete, continuous or structured treatments on popular benchmarks including IHDP, News, and TCGA. We introduce a new surrogate modeling task to broaden the scope of TEE beyond semi-synthetic evaluation and show that TransTEE is effective in real-world applications such as auditing fair predictions of LMs.
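For readers less familiar with the Rubin-Neyman potential outcomes framework referred to above, its standard quantities can be written as follows. The notation is an assumption made for this illustration (binary treatment T, covariates X, observed outcome Y, conditional average treatment effect τ) and is restricted to the binary case for simplicity. The last line holds only under the usual ignorability and overlap conditions (Rosenbaum & Rubin, 1983).

```latex
\begin{align}
  % Only one potential outcome is observed per unit (missing data view)
  Y &= T\,Y(1) + (1 - T)\,Y(0), \\
  % Conditional average (heterogeneous) treatment effect
  \tau(x) &= \mathbb{E}\bigl[\,Y(1) - Y(0) \mid X = x\,\bigr], \\
  % Identification from observed data under ignorability and overlap
  \tau(x) &= \mathbb{E}\bigl[\,Y \mid X = x,\, T = 1\,\bigr]
           - \mathbb{E}\bigl[\,Y \mid X = x,\, T = 0\,\bigr].
\end{align}
```

The first line is what makes the task a missing data problem: for each unit, the potential outcome under the treatment not received is never observed.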

2. RELATED WORK

Neural Treatment Effect Estimation. There are many recent works on adapting neural networks to learn counterfactual representations for treatment effect estimation (Johansson et al., 2016; Shalit et al., 2017; Louizos et al., 2017; Yoon et al., 2018; Bica et al., 2020b; Schwab et al., 2020; Nie et al., 



¹ For example, E[Y(1) - Y(0) | X] is often of a much simpler form to estimate than either E[Y(1) | X] or E[Y(0) | X], due to inherent similarities between Y(1) and Y(0).
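A small worked case may help illustrate this point. The additive outcome model below is a hypothetical assumption chosen for illustration, with g an arbitrary (possibly highly nonlinear) function of the covariates, τ₀ a constant, and noise terms that have zero conditional mean.

```latex
\begin{align}
  % Assumed outcome model, with E[\varepsilon_t \mid X] = 0 for t \in \{0, 1\}
  Y(t) &= g(X) + t\,\tau_0 + \varepsilon_t, \\
  % Each potential outcome surface inherits the complexity of g
  \mathbb{E}[\,Y(0) \mid X\,] &= g(X), \qquad
  \mathbb{E}[\,Y(1) \mid X\,] = g(X) + \tau_0, \\
  % The shared component g cancels in the difference
  \mathbb{E}[\,Y(1) - Y(0) \mid X\,] &= \tau_0.
\end{align}
```

Here both potential outcome surfaces are as hard to estimate as g, while their difference is a constant, because the component shared by Y(1) and Y(0) cancels.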



Figure 1: A motivating example with a corresponding causal graph. Prev denotes previous infection condition and BP denotes blood pressure. TransTEE adjusts an appropriate covariate set {Prev, BP} with attention which is visualized via a heatmap.

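Table 1: Comparison of existing works and TransTEE in terms of parameter complexity. n is the number of treatments. B_T and B_D are the numbers of branches for approximating continuous treatment and dosage, respectively. Treatment interaction means explicitly modeling the collective effects of multiple treatments. TransTEE is general for all the factors. [Table body only partially recovered; surviving entries: FlexTENet (Curth & van der Schaar, 2021b), O(n); Ours, O(1) in all four recovered columns.]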

