ENABLING COUNTERFACTUAL SURVIVAL ANALYSIS WITH BALANCED REPRESENTATIONS

Abstract

Balanced representation learning methods have been applied successfully to counterfactual inference from observational data. However, approaches that account for survival outcomes are relatively limited. Survival data are frequently encountered across diverse medical applications, e.g., drug development, risk profiling, and clinical trials, and are also relevant in fields such as manufacturing (e.g., for equipment monitoring). When the outcome of interest is a time-to-event, special precautions for handling censored events need to be taken, as ignoring censored outcomes may lead to biased estimates. We propose a theoretically grounded unified framework for counterfactual inference applicable to survival outcomes. Further, we formulate a nonparametric hazard ratio metric for evaluating average and individualized treatment effects. Experimental results on real-world and semi-synthetic datasets, the latter of which we introduce, demonstrate that the proposed approach significantly outperforms competitive alternatives in both survival-outcome prediction and treatment-effect estimation.

1. INTRODUCTION

Survival analysis or time-to-event studies focus on modeling the time of a future event, such as death or failure, and investigate its relationship with covariates or predictors of interest. Specifically, we may be interested in the causal effect of a given intervention or treatment on survival time. A typical question may be: will a given therapy increase the chances of survival of an individual or population? Such causal inquiries on survival outcomes are common in the fields of epidemiology and medicine (Robins, 1986; Hammer et al., 1996; Yusuf et al., 2016). As an important current example, the COVID-19 pandemic is creating a demand for methodological development to address such questions, specifically, when evaluating the effectiveness of a potential vaccine or therapeutic outside randomized controlled trial settings. Traditional causal survival analysis is typically carried out in the context of a randomized controlled trial (RCT), where the treatment assignment is controlled by researchers. Though they are the gold standard for causal inference, RCTs are usually long-term engagements, expensive and limited in sample size. Alternatively, the availability of observational data with comprehensive information about patients, such as electronic health records (EHRs), constitutes a more accessible but also more challenging source for estimating causal effects (Häyrinen et al., 2008; Jha et al., 2009). Such observational data may be used to augment and verify an RCT, after a particular treatment is approved and in use (Gombar et al., 2019; Frankovich et al., 2011; Longhurst et al., 2014). Moreover, the wealth of information from observational data also allows for the estimation of the individualized treatment effect (ITE), namely, the causal effect of an intervention at the individual level. In this work, we develop a novel framework for counterfactual time-to-event prediction to estimate the ITE for survival or time-to-event outcomes from observational data.
Estimating the causal effect for survival outcomes in observational data presents two principal challenges. First, the treatment assignment mechanism is not known a priori. Therefore, there may be variables, known as confounders, affecting both the treatment and survival time, which lead to selection bias (Bareinboim & Pearl, 2012), i.e., the covariate distributions across treatment groups are not the same. For instance, patients who are severely ill are likely to receive more aggressive therapy; however, their health status may also inevitably influence survival. In this work, we focus on selection bias due to confounding, but other sources may also be considered. Traditional survival analysis neglects such bias, leading to incorrect causal estimation. Second, the exact time-to-event is not always observed, i.e., sometimes we only know that an event has not occurred up to a certain point in time. This is known as the censoring problem. Moreover, censoring might be informative depending on the characteristics of the individuals and their treatment assignments, thus proper adjustment is required for accurate causal estimation (Cole & Hernán, 2004; Díaz, 2019). Traditional causal survival-analysis approaches typically model the effect of the treatment or covariates (not time or survival) in a parametric manner. Two commonly used models are the Cox proportional hazards (CoxPH) model (Cox, 1972) and the accelerated failure time (AFT) model (Wei, 1992), which presume a linear relationship between the covariates and survival probability. Further, proper weighting for each individual has been employed to account for confounding bias in these models (Austin, 2007; 2014; Hernán et al., 2005). For instance, probability weighting schemes that account for both selection bias and covariate-dependent censoring have been considered for adjusted survival curves (Cole & Hernán, 2004; Díaz, 2019).
Moreover, such probability weighting schemes have been applied to causal survival analysis under time-varying treatment and confounding (Robins, 1986; Hernán et al., 2000). See van der Laan & Robins (2003); Tsiatis (2007); Van der Laan & Rose (2011); Hernán & Robins (2020) for an overview. Such linear specifications make these models interpretable but compromise their flexibility, making it difficult to adapt them to high-dimensional data or to capture complex interactions among covariates. Importantly, these methods lack a counterfactual prediction mechanism, which is key for ITE estimation (see Section 2). Fortunately, recent advances in machine learning, such as representation learning or generative modeling, have enabled causal inference methods to handle high-dimensional data and to characterize complex interactions effectively. For instance, there has been recent interest in tree-based (Chipman et al., 2010; Wager & Athey, 2018) and neural-network-based (Shalit et al., 2017; Zhang et al., 2020) approaches. For pre-specified time-horizons, the nonparametric Random Survival Forest (RSF) (Ishwaran et al., 2008) and Bayesian additive regression trees (BART) (Chipman et al., 2010) have been extended to causal survival analysis. RSF has been applied to causal survival forests with weighted bootstrap inference (Shen et al., 2018; Cui et al., 2020), while BART has been extended to account for survival outcomes in Surv-BART (Sparapani et al., 2016) and AFT-BART (Henderson et al., 2020). See Hu et al. (2020) for an extensive investigation of causal survival tree-based methods. Alternatively, when estimating the ITE, neural-network-based methods propose to regularize the transformed covariates or representations for an individual to have balanced distributions across treatment groups, thus accounting for the confounding bias and improving ITE prediction.
However, most approaches employing representation learning techniques for counterfactual inference deal with continuous or binary outcomes, instead of time-to-event outcomes with censoring (informative or non-informative). Hence, a principled generalization to the context of counterfactual survival analysis is needed. In this work we leverage balanced (latent) representation learning to estimate ITE via counterfactual prediction of survival outcomes in observational studies. We develop a framework to predict event times from a low-dimensional transformation of the original covariate space. To address the specific challenges associated with counterfactual survival analysis, we make the following contributions:
• We develop an optimization objective incorporating adjustments for informative censoring, as well as a balanced regularization term bounding the generalization error for ITE prediction. For the latter, we repurpose a recently proposed bound (Shalit et al., 2017) for our time-to-event scenario.
• We propose a generative model for event times to relax restrictive survival linear and parametric assumptions, thus allowing for more flexible modeling. Our approach can also provide nonparametric uncertainty quantification for ITE predictions.
• We provide survival-specific evaluation metrics, including a new nonparametric hazard ratio estimator, and discuss how to perform model selection for survival outcomes. The proposed model demonstrates superior performance relative to the commonly used baselines in real-world and semi-synthetic datasets.
• We introduce a survival-specific semi-synthetic dataset and demonstrate an approach for leveraging prior randomized experiments in longitudinal studies for model validation.

2. PROBLEM FORMULATION

We first introduce the basic setup for performing causal survival analysis in observational studies. Suppose we have $N$ units, with $N_1$ units being treated and $N_0$ in the control group ($N = N_1 + N_0$). For each unit (individual), we have covariates $X$, which can be heterogeneous, e.g., a mixture of categorical and continuous covariates which, in the context of medicine, may include labs, vitals, procedure codes, etc. We also have a treatment indicator $A$, where $A = 0$ for the controls and $A = 1$ for the treated, as well as the outcome (event) of interest $T$. Under the potential-outcomes framework (Rubin, 2005), let $T_0$ and $T_1$ be the potential event times for a given subject under control and treatment, respectively. In practice we only observe one realization of the potential outcomes, i.e., the factual outcome $T = T_A$, while the counterfactual outcome $T_{1-A}$ is unobserved. In survival analysis, the problem becomes more difficult because we do not always observe the exact event time for each individual, but rather the time up to which we are certain that the event has not occurred; specifically, we have a (right) censoring problem, most likely due to loss to follow-up. We denote the censoring time as $C$ and the censoring indicator as $\delta \in \{0, 1\}$. The actual observed time is $Y = \min(T_A, C)$, i.e., the outcome is observed (non-censored) if $T_A < C$, in which case $\delta = 1$. In this work, we are interested in the expected difference between $T_1$ and $T_0$ conditioned on $X$ for a given unit (individual), which is commonly known as the individualized treatment effect (ITE). Specifically, we wish to perform inference on the conditional distributions of $T_1$ and $T_0$, i.e., $p(T_1|X)$ and $p(T_0|X)$, respectively, as shown in Figure 1a. In practice, we observe $N$ realizations of $(Y, \delta, X, A)$ for observed time, censoring indicator, covariates and treatment indicator, respectively; hence, from an observational study the dataset takes the form $\mathcal{D} = \{(y_i, \delta_i, x_i, a_i)\}_{i=1}^{N}$.
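As a concrete illustration of the observation model above, the following minimal sketch builds the observed pair $(y, \delta)$ from potential event times, censoring times and treatment indicators. The function name `observe` and the toy values are hypothetical and serve only to illustrate $Y = \min(T_A, C)$; real data never contain both potential outcomes.

```python
import numpy as np

def observe(t0, t1, c, a):
    """Return observed time y and event indicator delta for each unit."""
    t0, t1, c, a = (np.asarray(v) for v in (t0, t1, c, a))
    t_factual = np.where(a == 1, t1, t0)   # factual outcome T = T_A
    y = np.minimum(t_factual, c)           # Y = min(T_A, C)
    delta = (t_factual < c).astype(int)    # 1 if the event is observed, 0 if censored
    return y, delta

# Hypothetical toy values for three units.
t0 = [5.0, 2.0, 9.0]    # potential event times under control
t1 = [7.0, 3.0, 4.0]    # potential event times under treatment
c = [6.0, 10.0, 3.0]    # censoring times
a = [1, 0, 1]           # treatment indicators
y, delta = observe(t0, t1, c, a)
print(y, delta)
```

Only the second unit has an observed event here; the first and third are censored, so their recorded time is the censoring time rather than the event time.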
Below, we discuss several common choices of estimands in survival analysis.

Estimands of Interest

We begin by considering survival analysis in the absence of an intervening treatment choice, $A$. Let $F(t|x) \triangleq P(T \le t|X = x)$ be the cumulative distribution function of the event (failure) time, $t$, given a realization of the covariates, $x$. Survival analysis is primarily concerned with characterization of the survival function conditioned on covariates, $S(t|x) \triangleq 1 - F(t|x)$, and the hazard function or risk score, $\lambda(t|x)$, defined below. $S(t|x)$ is a monotonically decreasing function indicating the probability of survival up to time $t$. The hazard function measures the instantaneous probability of the event occurring in the interval $[t, t + \Delta t]$ given $T > t$, as $\Delta t \to 0$. From standard definitions (Kleinbaum & Klein, 2010), the relationship between the cumulative and hazard functions is formulated as

$$\lambda(t|x) = \lim_{dt \to 0} \frac{P(t < T < t + dt \mid X = x)}{P(T > t \mid X = x)\, dt} = -\frac{d \log S(t|x)}{dt} = \frac{f(t|x)}{S(t|x)}. \tag{1}$$

From (1) we see that $f(t|x) \triangleq P(T = t|X = x) = \lambda(t|x) S(t|x)$ is the conditional event time density function (Kleinbaum & Klein, 2010). Given the binary treatment $A$, we are interested in its impact on the survival time. For ITE estimation, we are also interested in the difference between the two potential outcomes $T_1$, $T_0$. Let $S_A(t|x)$ and $\lambda_A(t|x)$ denote the survival and hazard functions for the potential outcomes $T_A$, i.e., $T_1$ and $T_0$. Several common estimands of interest include (Zhao et al., 2012; Trinquart et al., 2016): difference in expected lifetime, $\mathrm{ITE}(t, x) = \int_0^{t_{\max}} \{S_1(t|x) - S_0(t|x)\}\, dt = \mathbb{E}\{T_1 - T_0 \mid X = x\}$; difference in survival function, $\mathrm{ITE}(t, x) = S_1(t|x) - S_0(t|x)$; and hazard ratio, $\mathrm{ITE}(t, x) = \lambda_1(t|x) / \lambda_0(t|x)$. The inference difficulties associated with the above estimands from observational data are two-fold. First, there are confounders affecting both the treatment assignment and outcomes, which give rise to selection bias, i.e., the treatment and control covariate distributions are not necessarily the same.
Also, we do not have direct knowledge of the conditional treatment assignment mechanism, i.e., $P(A = a|X = x)$, also known as the propensity score. Let $\perp\!\!\!\perp$ denote statistical independence. For estimands to be identifiable from observational data, we make two assumptions: (i) $\{T_1, T_0\} \perp\!\!\!\perp A \mid X$, i.e., no unobserved confounders or ignorability, and (ii) overlap in the covariate support, $0 < P(A = 1|X = x) < 1$ almost surely if $p(X = x) > 0$. Second, the censoring mechanism is also unknown and may lead to bias without proper adjustment. We consider two censoring mechanisms in our work: (i) conditionally independent or informative censoring, $T \perp\!\!\!\perp C \mid X, A$, and (ii) random or non-informative censoring, $T \perp\!\!\!\perp C$. Note that for informative censoring, we also have to consider potential censoring times $C_1$ and $C_0$ and their conditionals $p(C_1|X)$ and $p(C_0|X)$, respectively. Figure 1 shows causal graphs illustrating these modeling assumptions.
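To make the three estimands concrete, the sketch below evaluates them numerically for a single covariate profile. It assumes, purely for illustration, that both potential outcomes have constant hazards (exponential event times); the rates, the time grid and the function names are hypothetical and are not taken from the experiments.

```python
import numpy as np

def survival_exponential(t, rate):
    """Survival function S(t) = exp(-rate * t) of an exponential event time."""
    return np.exp(-rate * t)

def restricted_mean_difference(s1, s0, t):
    """Trapezoidal approximation of the integral of S_1(t) - S_0(t) over the grid t."""
    diff = s1 - s0
    return float(np.sum(0.5 * (diff[1:] + diff[:-1]) * np.diff(t)))

t = np.linspace(0.0, 10.0, 2001)   # time grid on [0, t_max] with t_max = 10
rate0, rate1 = 0.5, 0.25           # hypothetical constant hazards (control, treated)
s0 = survival_exponential(t, rate0)
s1 = survival_exponential(t, rate1)

lifetime_diff = restricted_mean_difference(s1, s0, t)   # difference in expected lifetime
survival_diff = survival_exponential(2.0, rate1) - survival_exponential(2.0, rate0)
hazard_ratio = rate1 / rate0                            # constant in t for this toy case
print(lifetime_diff, survival_diff, hazard_ratio)
```

In this toy case treatment halves the hazard, so the hazard ratio is below one while the two difference-based estimands are positive; all three point to a beneficial effect but on different scales.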

3. MODELING

To overcome the above challenges and adjust for observational biases, we propose a unified framework for counterfactual survival analysis (CSA). Specifically, we repurpose the counterfactual bound in Shalit et al. (2017) for our time-to-event scenario and introduce a nonparametric approach for stochastic survival outcome predictions. Below we formulate a theoretically grounded and unified approach for estimating (i) the encoder function $r = \Phi(x)$, which deterministically maps covariates $x$ to their corresponding latent representation $r \in \mathbb{R}^d$, and (ii) two stochastic time-to-event generative functions, $h_A(\cdot)$, to implicitly draw samples from both potential outcome conditionals, $t_a \sim p_{h,\Phi}(T_A|X = x)$ for $A \in \{1, 0\}$, where $t_a$ denotes a sample from $p_{h,\Phi}(T_A|X = x)$ for $A = a$. Further, we formulate a general extension that accounts for informative censoring by introducing two stochastic censoring generative functions, $\nu_A(\cdot)$, to draw samples for potential censoring times $c_a \sim p_{\nu,\Phi}(C_A|X = x)$. The model-specifying functions, $\{h_A(\cdot), \nu_A(\cdot), \Phi(\cdot)\}$, are parameterized via neural networks. See the Supplementary Material (SM) for details. Figure 1a summarizes our modeling approach.

Accounting for selection bias

We wish to estimate the potential outcomes, i.e., event times, which are sampled from distributions parameterized by functions $\{h_A(\cdot), \Phi(\cdot)\}$, i.e.,

$$t \sim p_{h,\Phi}(T \mid X = x, A = a), \tag{2}$$
$$t_a \sim p_{h,\Phi}(T_a \mid X = x). \tag{3}$$

We obtain (3) from (2) via the strong ignorability assumption, i.e., $\{T_0, T_1\} \perp\!\!\!\perp A \mid X$ (consistent with the causal graphs in Figures 1b and 1c) and $0 < P(A = a|X = x) < 1$, and the consistency assumption, i.e., $T = T_A \mid A = a$. A similar argument can be made for informative censoring based on Figure 1c, so we can also write $c_a \sim p_{\nu,\Phi}(C_A|X = x)$. Given (3), model functions $\{h_A(\cdot), \Phi(\cdot)\}$, and $\nu_A(\cdot)$ for informative censoring, can be learned by leveraging standard statistical optimization approaches that minimize a loss hypothesis $\mathcal{L}$ given samples from the empirical distribution $(y, \delta, x, a) \sim p(Y, \delta, X, A)$, i.e., from dataset $\mathcal{D}$. Specifically, we write $\mathcal{L} = \mathbb{E}_{(y,\delta,x,a) \sim p(Y,\delta,X,A)}[\ell_{h,\Phi}(t_a, y, \delta)]$, where $\ell_{h,\Phi}(t_a, y, \delta)$ is a loss function that measures the agreement of $t_a \sim p_{h,\Phi}(T_A|X = x)$ (and $c_a \sim p_{\nu,\Phi}(C_A|X = x)$ for informative censoring) with ground truth $\{y, \delta\}$, the observed time and censoring indicator, respectively. For some parametric formulations of the event time distribution $p_{h,\Phi}(T_A|X = x)$, e.g., exponential, Weibull, log-Normal, etc., and provided the censoring mechanism is non-informative, $-\ell_{h,\Phi}(t_a, y, \delta)$ is the closed form log likelihood. Specifically, $-\ell_{h,\Phi}(t_a, y, \delta) \triangleq \log p_{h,\Phi}(T_a|X = x) = \delta \cdot \log f_{h,\Phi}(t_a|x) + (1 - \delta) \cdot \log S_{h,\Phi}(t_a|x)$, which implies that the conditional event time density and survival functions can be calculated in closed form from transformations $\{h_A(\cdot), \Phi(\cdot)\}$ of $x$. See the SM for parametric examples of $\mathcal{L}$ accounting for informative censoring. We further define the expected loss for a given realization of covariates $x$ and treatment assignment $a$ over observed times $y$ (censored and non-censored) and the censoring indicator $\delta$ as $\zeta_{h,\Phi}(x, a) \triangleq \mathbb{E}_{(y,\delta,x) \sim p(Y,\delta|X)}\, \ell_{h,\Phi}(t_a, y, \delta)$, as in Shalit et al. (2017).
For a given subject with covariates $x$ and treatment assignment $a$, we wish to minimize both the factual and counterfactual losses, $\mathcal{L}_F$ and $\mathcal{L}_{CF}$, respectively, by decomposing $\mathcal{L} = \mathcal{L}_F + \mathcal{L}_{CF}$ as follows

$$\mathcal{L}_F = \mathbb{E}_{(x,a) \sim p(A,X)}\, \zeta_{h,\Phi}(x, a), \qquad \mathcal{L}_{CF} = \mathbb{E}_{(x,a) \sim p(1-A,X)}\, \zeta_{h,\Phi}(x, a). \tag{4}$$

Let $u \triangleq P(A = 1)$ denote the marginal probability of treatment assignment. We can readily decompose the losses in (4) according to treatment assignments. The decomposed factual loss is $\mathcal{L}_F = u \cdot \mathcal{L}_F^{A=1} + (1 - u) \cdot \mathcal{L}_F^{A=0}$ and, similarly, the decomposed counterfactual loss is $\mathcal{L}_{CF} = (1 - u) \cdot \mathcal{L}_{CF}^{A=1} + u \cdot \mathcal{L}_{CF}^{A=0}$. In practice, only factual outcomes are observed; hence, for a non-randomized non-controlled experiment, we cannot obtain an unbiased estimate of $\mathcal{L}_{CF}$ from data due to selection bias (or confounding). Therefore, we bound $\mathcal{L}_{CF}$ and $\mathcal{L}$ below following Shalit et al. (2017).

Corollary 1 Assume $\Phi(\cdot)$ is an invertible map, and $\alpha^{-1} \zeta_{h,\Phi}(x, a) \in G$, where $G$ is a family of functions, $p_\Phi^{A=a} \triangleq p_\Phi(R|A = a)$ is the latent distribution for group $A = a$, and $\alpha > 0$ is a constant. Then, we have:

$$\begin{aligned} \mathcal{L}_{CF} &\le (1 - u) \cdot \mathcal{L}_F^{A=1} + u \cdot \mathcal{L}_F^{A=0} + \alpha \cdot \mathrm{IPM}_G(p_\Phi^{A=1}, p_\Phi^{A=0}), \\ \mathcal{L} &\le \mathcal{L}_F^{A=1} + \mathcal{L}_F^{A=0} + \alpha \cdot \mathrm{IPM}_G(p_\Phi^{A=1}, p_\Phi^{A=0}). \end{aligned} \tag{5}$$

The integral probability metric (IPM) (Müller, 1997; Sriperumbudur et al., 2012) measures the distance between two probability distributions $p$ and $q$ defined over $M$, i.e., the latent space of $R$. Formally, $\mathrm{IPM}_G(p, q) \triangleq \sup_{g \in G} \left| \int_M g(m) (p(m) - q(m))\, dm \right|$, where $G$ is a class of real-valued bounded measurable functions $g : M \to \mathbb{R}$ (Shalit et al., 2017). Therefore, model functions $\{h_A(\cdot), \Phi(\cdot)\}$ can be learned by minimizing the upper bound in (5), consisting of (i) only factual losses under both treatment assignments and (ii) an IPM regularizer enforcing latent distributional equivalence between the treatment groups. Note that if the data originate from an RCT, it follows (by construction) that $\mathrm{IPM}_G(p_\Phi^{A=1}, p_\Phi^{A=0}) = 0$.
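One well-known instance of an IPM is the 1-Wasserstein distance, obtained when $G$ is the set of 1-Lipschitz functions. The sketch below estimates an entropy-regularized version of it between two batches of latent representations with plain Sinkhorn iterations. It is a generic illustration rather than the implementation used in this work; the function name, the regularization strength `reg`, the iteration count and the toy data are all assumptions.

```python
import numpy as np

def sinkhorn_distance(r1, r0, reg=0.5, n_iters=200):
    """Entropy-regularized optimal transport cost between two empirical samples.

    r1, r0: arrays of shape (n, d) and (m, d) holding latent representations
    of the treated and control groups, each point carrying uniform mass.
    """
    n, m = r1.shape[0], r0.shape[0]
    cost = np.linalg.norm(r1[:, None, :] - r0[None, :, :], axis=-1)  # pairwise Euclidean cost
    kernel = np.exp(-cost / reg)                                     # Gibbs kernel
    a = np.full(n, 1.0 / n)
    b = np.full(m, 1.0 / m)
    u = np.ones(n)
    for _ in range(n_iters):          # alternate scaling to match both marginals
        v = b / (kernel.T @ u)
        u = a / (kernel @ v)
    plan = u[:, None] * kernel * v[None, :]   # approximate transport plan
    return float(np.sum(plan * cost))

rng = np.random.default_rng(0)
treated = rng.normal(size=(50, 2))
control_balanced = rng.normal(size=(50, 2))                   # same latent distribution
control_shifted = control_balanced + np.array([3.0, 0.0])     # imbalanced latent distribution
d_balanced = sinkhorn_distance(treated, control_balanced)
d_shifted = sinkhorn_distance(treated, control_shifted)
print(d_balanced, d_shifted)
```

The regularizer is larger when the two groups occupy different regions of the latent space, which is the imbalance the bound in (5) penalizes. For small `reg` or large distances the kernel can underflow, in which case a log-domain variant would be the safer choice.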
Accounting for censoring bias

Below we formulate an approach for estimating functions $h_A(\cdot)$ and $\nu_A(\cdot)$ for synthesizing (sampling) non-censored $t_a \sim p_{h,\Phi}(T_A|X = x)$ and censored $c_a \sim p_{\nu,\Phi}(C_A|X = x)$ times, respectively. While some parametric assumptions for $p_{h,\Phi}(T_A|X = x)$ yield easy-to-evaluate closed forms for $S_{h,\Phi}(t_a|x)$ that can be used as the likelihood for censored observations, they are restrictive and have been shown to generate unrealistic, high-variance samples (Chapfuwa et al., 2018). So motivated, we seek a nonparametric likelihood-based approach that can model a flexible family of distributions, with an easy-to-sample approach for event times $t_a \sim p_{h,\Phi}(T_a|X = x)$. We model the event time generation process as a neural-network-based nonlinear transformation of a source of randomness, $p(\epsilon)$, e.g., Gaussian or uniform. In the experiments we use a planar flow formulation parameterized by $\{U_h, W_h, b_h\}$ (Rezende & Mohamed, 2015); however, other specifications can also be used. Note that Miscouridou et al. (2018) have previously leveraged normalizing flows for survival analysis; however, our approach is very different in that it focuses on (i) formulating a counterfactual survival analysis framework that accounts for informative or non-informative censoring mechanisms and confounding, and (ii) modeling event times as a continuous variable instead of discretizing them. Specifically, we transform the source of randomness, $\epsilon$, using a single layer specification as follows

$$\tilde{\epsilon}_h = \epsilon + U_h \tanh(W_h \epsilon + b_h), \quad \epsilon \sim \mathrm{Uniform}(0, 1), \quad t_a = h_A(r, \tilde{\epsilon}_h), \quad r = \Phi(x), \tag{6}$$

where $\{U_h, W_h\} \in \mathbb{R}^{d \times d}$, $\{b_h, \epsilon\} \in \mathbb{R}^d$, $d$ is the dimensionality of the normalizing flow; each component of $\epsilon$ is drawn independently from $\mathrm{Uniform}(0, 1)$, and $\tilde{\epsilon}_h$ may be viewed as a skip connection with stochasticity in $\epsilon$. Further, $h_A(r, \tilde{\epsilon}_h)$ and $\Phi(x)$ are time-to-event generative and encoding functions, respectively, parameterized as neural networks.
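The single-layer transformation of the noise source can be sketched in a few lines of NumPy. In the sketch below, `planar_noise` implements $\tilde{\epsilon} = \epsilon + U \tanh(W \epsilon + b)$ row-wise over a batch, while `toy_event_time` is a hypothetical stand-in for the generative network $h_A$ (a softplus of a linear map, chosen only so that sampled times are positive). The parameter values and the stand-in for $r = \Phi(x)$ are random placeholders, not learned quantities.

```python
import numpy as np

def planar_noise(eps, U, W, b):
    """Apply eps_tilde = eps + U tanh(W eps + b) to each row of eps."""
    return eps + np.tanh(eps @ W.T + b) @ U.T

def toy_event_time(r, eps_tilde, weights):
    """Hypothetical stand-in for h_A(r, eps_tilde): softplus of a linear map."""
    z = np.concatenate([r, eps_tilde], axis=1) @ weights
    return np.logaddexp(0.0, z)   # softplus keeps sampled times positive

rng = np.random.default_rng(0)
d, batch = 4, 8
U = 0.1 * rng.normal(size=(d, d))
W = 0.1 * rng.normal(size=(d, d))
b = np.zeros(d)

r = rng.normal(size=(batch, d))                # placeholder for r = Phi(x)
eps = rng.uniform(0.0, 1.0, size=(batch, d))   # each component drawn from Uniform(0, 1)
eps_tilde = planar_noise(eps, U, W, b)
t_a = toy_event_time(r, eps_tilde, rng.normal(size=2 * d))
print(t_a)
```

Drawing several independent `eps` for the same `r` yields several event-time samples for one individual, which is what makes sample-based summaries such as means or quantiles of $t_a$ possible.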
For simplicity, the dimensions of $r$ and $\epsilon$ are set to $d$; however, they can be set independently if desired. In practice, we are interested in generating realistic event-time samples; therefore, we account for both censored and non-censored observations by adopting the objective from Chapfuwa et al. (2018), formulated as

$$\mathcal{L}_F^{\mathrm{CSA}} \triangleq \mathbb{E}_{(y,\delta,x,a) \sim p(Y,\delta,X,A),\, \epsilon \sim p(\epsilon)} \left[ \delta \cdot |y - t_a| + (1 - \delta) \cdot \max(0, y - t_a) \right], \tag{7}$$

where the first term encourages sampled event times $t_a$ to be close to $y$, the ground truth for observed events, i.e., $\delta = 1$, while the second term penalizes $t_a$ for being smaller than the censoring time when $\delta = 0$. Further, the expectation is taken over samples (a minibatch) from the empirical distribution $p(Y, \delta, X, A)$.

Informative censoring

We model informative censoring similarly to (7) but mirroring the censoring indicators, to encourage accurate censoring time samples $c_a$ for $\delta = 0$ while penalizing $c_a$ for being smaller than $y$ for $\delta = 1$ (observed events). Specifically, we set an independent source of randomness as in (6) but parameterized by $\{U_\nu, W_\nu, b_\nu\}$, and censoring generative functions $\nu_A(r, \tilde{\epsilon}_\nu)$, parameterized as neural networks, with $c_a \sim p_{\nu,\Phi}(C_A|X = x)$. The censoring loss is formulated as

$$\ell_c(\nu, \Phi) = \mathbb{E}_{(y,\delta,x,a) \sim p(Y,\delta,X,A),\, \epsilon \sim p(\epsilon)} \left[ (1 - \delta) \cdot |y - c_a| + \delta \cdot \max(0, y - c_a) \right]. \tag{8}$$

Further, we introduce an additional time-order-consistency loss that enforces the correct order of the sampled times relative to the censoring indicator, i.e., $c_a < t_a$ if $\delta = 0$ and $t_a < c_a$ if $\delta = 1$, thus

$$\ell_{TC}(h, \nu, \Phi) = \mathbb{E}_{(\delta,x,a) \sim p(\delta,X,A),\, \epsilon \sim p(\epsilon)} \left[ \delta \cdot \max(0, t_a - c_a) + (1 - \delta) \cdot \max(0, c_a - t_a) \right]. \tag{9}$$

Note that $\ell_{TC}(h, \nu, \Phi)$ does not depend on the observed event times but only on the censoring indicators. Finally, the consolidated CSA loss for informative censoring (CSA-INFO), $\mathcal{L}_F^{\mathrm{CSA\text{-}INFO}}$, aggregates (7), (8) and (9). For the IPM regularization loss in (5), we optimize the dual formulation of the Wasserstein distance via regularized optimal transport (Villani, 2008; Cuturi, 2013).
Consequently, we only require $\alpha^{-1} \zeta_{h,\Phi}(x, a)$ to be 1-Lipschitz (Shalit et al., 2017), and $\alpha$ is selected by grid search on the validation set using only factual data (details below).
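The sample-based losses (7), (8) and (9) introduced in this section are simple enough to check by hand. The following sketch evaluates them on a toy minibatch; the function names and the numbers are hypothetical, and the expectation is replaced by a minibatch average with one sampled $t_a$ and $c_a$ per unit.

```python
import numpy as np

def csa_factual_loss(y, delta, t):
    """Eq. (7): absolute error for observed events, one-sided penalty when censored."""
    return float(np.mean(delta * np.abs(y - t) + (1 - delta) * np.maximum(0.0, y - t)))

def censoring_loss(y, delta, c):
    """Eq. (8): same form as (7) with the roles of the censoring indicator mirrored."""
    return float(np.mean((1 - delta) * np.abs(y - c) + delta * np.maximum(0.0, y - c)))

def time_order_loss(delta, t, c):
    """Eq. (9): penalize t > c for observed events and c > t for censored ones."""
    return float(np.mean(delta * np.maximum(0.0, t - c) + (1 - delta) * np.maximum(0.0, c - t)))

# Hypothetical minibatch of three units.
y = np.array([2.0, 5.0, 4.0])       # observed times
delta = np.array([1, 0, 1])         # 1 = event observed, 0 = censored
t = np.array([3.0, 4.0, 4.0])       # sampled event times t_a
c = np.array([5.0, 5.0, 3.0])       # sampled censoring times c_a

loss_t = csa_factual_loss(y, delta, t)
loss_c = censoring_loss(y, delta, c)
loss_tc = time_order_loss(delta, t, c)
print(loss_t, loss_c, loss_tc)
```

For the second unit, which is censored at $y = 5$, the sampled event time $t_a = 4$ falls before the censoring time and is therefore penalized in (7), whereas any $t_a \ge 5$ would incur no loss.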

4. METRICS

We propose a comprehensive evaluation approach that accounts for both factual and causal metrics. Factual survival outcome predictions are evaluated according to standard survival metrics that measure diverse performance characteristics, such as concordance index (C-Index) (Harrell Jr et al., 1984), mean coefficient of variation (COV) and calibration slope (C-slope) (Chapfuwa et al., 2020). See the SM for more details on these metrics. For causal metrics, defined below, we introduce a nonparametric hazard ratio (HR) between treatment outcomes, and adopt the conventional precision in estimation of heterogeneous effect (PEHE) and average treatment effect (ATE) performance metrics (Hill, 2011). Note that PEHE and ATE require ground truth counterfactual event times, which are only available for (semi-)synthetic data. For HR, we compare our findings with those independently reported in the literature from gold-standard RCT data.

Nonparametric Hazard Ratio

In a medical setting, the population hazard ratio $\mathrm{HR}(t)$ between treatment groups is considered informative and has thus been widely used in drug development and RCTs (Yusuf et al., 2016; Mihaylova et al., 2012). For example, $\mathrm{HR}(t) < 1$, $> 1$, or $\approx 1$ indicate positive, negative and neutral population treatment effects at time $t$, respectively. Moreover, $\mathrm{HR}(t)$ naturally accounts for both censored and non-censored outcomes. Standard approaches for computing $\mathrm{HR}(t)$ rely on the restrictive proportional hazard assumption from CoxPH (Cox, 1972), which is formulated as a semi-parametric linear model $\lambda(t|a) = \lambda_b(t) \exp(a\beta)$. However, the assumption of a constant (time-independent) covariate effect is often violated in practice (see Figure 2b). For CoxPH, the marginal HR between treatment and control can be obtained from the regression coefficient $\beta$ learned via maximum likelihood without the need for specifying the baseline hazard $\lambda_b(t)$: $\mathrm{HR}_{\mathrm{CoxPH}}(t) = \frac{\lambda(t|a=1)}{\lambda(t|a=0)} = \exp(\beta)$.
So motivated, we propose a nonparametric, model-free approach for computing $\mathrm{HR}(t)$, in which we assume neither a parametric form for the event time distribution nor the proportional hazard assumption from CoxPH. This approach only relies on samples from the conditional event time density functions, $f(t_1|x)$ and $f(t_0|x)$, via $t_a = h_A(\cdot)$ from (6).

Definition 1 We define the nonparametric marginal hazard ratio and its approximation, $\widehat{\mathrm{HR}}(t)$, as

$$\mathrm{HR}(t) = \frac{\lambda_1(t)}{\lambda_0(t)} = \frac{S_0(t)}{S_1(t)} \cdot \frac{S_1'(t)}{S_0'(t)}, \qquad \widehat{\mathrm{HR}}(t) = \frac{\hat{S}_0^{\mathrm{PKM}}(t)}{\hat{S}_1^{\mathrm{PKM}}(t)} \cdot \frac{m_1(t)}{m_0(t)}, \tag{10}$$

where for $\mathrm{HR}(t)$ we leveraged (1) to obtain (10) and $S'(t) \triangleq dS(t)/dt$. The nonparametric assumption for $S(t)$ makes the computation of $S'(t)$ challenging. Provided that $S(t)$ is a monotonically decreasing function, for simplicity, we fit a linear function $S(t) = m \cdot t + c$ and set $S'(t) \approx m$. Note that the linear model is only used for estimating $S'(t)$ from the nonparametric estimate of $S(t)$. The approximation $\widehat{\mathrm{HR}}(t)$ in (10) is then computed from $\hat{S}_a^{\mathrm{PKM}}(t)$ and $m_a$. A similar formulation for the conditional, $\widehat{\mathrm{HR}}(t|x)$, can also be derived. See the SM for full details on the evaluation of $\widehat{\mathrm{HR}}(t)$ and $\widehat{\mathrm{HR}}(t|x)$. Note that for some AFT- or CoxPH-based parametric formulations, $\mathrm{HR}(t|x)$ can be readily evaluated because $f(t_a|x)$ and $S(t_a|x)$ are available in closed form. In the experiments, we will use $\mathrm{HR}(t)$ to compare different approaches against results reported in RCTs (see Tables 1 and 3). Further, we will use $\mathrm{HR}(t|x)$ to illustrate stratified treatment effects (see Figure 2). Note that though a neural-based survival recommender system (Katzman et al., 2018) has been previously used to estimate $\mathrm{HR}(t|x)$, that approach does not account for confounding or informative censoring and is thus susceptible to bias.
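The idea behind the sample-based hazard ratio can be sketched as follows. This is a simplified illustration under stated assumptions, not the estimator detailed in the SM: the survival function is estimated by the plain empirical fraction of samples exceeding $t$ (in place of $\hat{S}^{\mathrm{PKM}}$), the slope is obtained from a least-squares line fitted on a local window around $t$, and the window width, grid size and exponential toy samples are hypothetical choices.

```python
import numpy as np

def empirical_survival(samples, grid):
    """Fraction of sampled event times exceeding each grid point."""
    samples = np.asarray(samples)
    return np.array([(samples > t).mean() for t in grid])

def local_slope(samples, t, half_width, n_grid):
    """Slope m of a least-squares line S(t) = m * t + c fitted around t."""
    grid = np.linspace(t - half_width, t + half_width, n_grid)
    slope, _intercept = np.polyfit(grid, empirical_survival(samples, grid), 1)
    return slope

def hazard_ratio(samples1, samples0, t, half_width=0.5, n_grid=21):
    """Sample-based HR(t) = (S_0(t) / S_1(t)) * (m_1(t) / m_0(t))."""
    s1 = empirical_survival(samples1, [t])[0]
    s0 = empirical_survival(samples0, [t])[0]
    m1 = local_slope(samples1, t, half_width, n_grid)
    m0 = local_slope(samples0, t, half_width, n_grid)
    return (s0 / s1) * (m1 / m0)

# Toy check: exponential event times with hazards 0.25 (treated) and 0.5 (control),
# for which the true hazard ratio is 0.5 at every t.
rng = np.random.default_rng(0)
samples1 = rng.exponential(scale=1 / 0.25, size=200_000)
samples0 = rng.exponential(scale=1 / 0.5, size=200_000)
hr_hat = hazard_ratio(samples1, samples0, t=2.0)
print(hr_hat)
```

Because both slopes are negative, their ratio is positive, and the estimate stays meaningful as long as enough samples fall in the window around $t$; very late time points, where few samples survive, would give noisy slopes.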

Precision in Estimation of Heterogeneous Effect (PEHE)

A general individualized estimation error is formulated as $\epsilon_{\mathrm{PEHE}} = \mathbb{E}_X[(\mathrm{ITE}(x) - \widehat{\mathrm{ITE}}(x))^2]$, where $\mathrm{ITE}(x)$ is the ground truth, $\widehat{\mathrm{ITE}}(x) = \mathbb{E}_T[\gamma(T_1) - \gamma(T_0) \mid X = x]$ and $\gamma(\cdot)$ is a deterministic transformation. In our experiments, $\gamma(\cdot)$ is the average over samples from $t_a \sim p_{h,\Phi}(T_A|X = x)$. Alternative estimands, e.g., thresholding survival times $\gamma(T_A) = I\{T_A > \tau\}$, can also be considered as described above.

Average Treatment Effect (ATE)

The population treatment effect estimation error is defined as $\epsilon_{\mathrm{ATE}} = |\mathrm{ATE} - \widehat{\mathrm{ATE}}|$, where $\mathrm{ATE} = \mathbb{E}_X[\mathrm{ITE}(x)]$ (ground truth) and $\widehat{\mathrm{ATE}} = \mathbb{E}_X[\widehat{\mathrm{ITE}}(x)]$.
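Both error metrics reduce to a few array operations once sampled potential event times are available. The sketch below is a minimal illustration with hypothetical toy numbers; it takes the transformation to be the average over samples and uses the squared-error form of PEHE stated above (some works report its square root instead).

```python
import numpy as np

def ite_from_samples(t1_samples, t0_samples):
    """Estimated ITE per individual: mean sampled T_1 minus mean sampled T_0.

    Both inputs have shape (n_individuals, n_draws).
    """
    return t1_samples.mean(axis=1) - t0_samples.mean(axis=1)

def pehe(ite_true, ite_hat):
    """Mean squared difference between true and estimated individual effects."""
    return float(np.mean((ite_true - ite_hat) ** 2))

def ate_error(ite_true, ite_hat):
    """Absolute difference between true and estimated average effects."""
    return float(abs(np.mean(ite_true) - np.mean(ite_hat)))

# Hypothetical toy data: three individuals, two sampled times per arm.
ite_true = np.array([1.0, 2.0, 3.0])
t1_samples = np.array([[2.0, 4.0], [5.0, 5.0], [6.0, 8.0]])
t0_samples = np.array([[1.0, 1.0], [2.0, 4.0], [5.0, 5.0]])

ite_hat = ite_from_samples(t1_samples, t0_samples)
print(pehe(ite_true, ite_hat), ate_error(ite_true, ite_hat))
```

The toy numbers are chosen so that the individual errors cancel on average: the ATE error is zero while the PEHE is not, which shows why the two metrics are reported together.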

5. EXPERIMENTS

We describe the baselines and datasets that will be used to evaluate the proposed counterfactual survival analysis methods (CSA and CSA-INFO). PyTorch code including the new semi-synthetic dataset (see below) will be made publicly available. Throughout the experiments, we use the standard $\mathrm{HR}(t)$ for CoxPH-based methods and (10) for all others. The bound in (5) is sensitive to $\alpha$; thus we propose approximating proxy counterfactual outcomes $\{Y_{CF}, \delta_{CF}\}$ for the validation set according to the covariate Euclidean nearest neighbour (NN) from the training set. We select the $\alpha$ that minimizes the validation loss $\mathcal{L} = \mathcal{L}_F + \mathcal{L}_{CF}$ from the set $\{0, 0.1, 1, 10, 100\}$.

Baselines

We consider the following competitive baseline approaches: (i) propensity weighted CoxPH (Schemper et al., 2009; Buchanan et al., 2014; Rosenbaum & Rubin, 1983); (ii) IPM (5) regularized AFT (log-Normal and Weibull) models; (iii) an IPM (5) regularized deterministic semi-supervised regression (SR) model with the accuracy objective from Chapfuwa et al. (2018), as a contrast for the proposed stochastic predictors (CSA and CSA-INFO); and (iv) survival Bayesian additive regression trees (Surv-BART) (Sparapani et al., 2016). For CoxPH, we consider three normalized weighting schemes: (i) inverse probability weighting (IPW) (Horvitz & Thompson, 1952; Cao et al., 2009), where $\mathrm{IPW}_i = \frac{a_i}{\hat{e}_i} + \frac{1 - a_i}{1 - \hat{e}_i}$; (ii) overlapping weights (OW) (Crump et al., 2006; Li et al., 2018), where $\mathrm{OW}_i = a_i \cdot (1 - \hat{e}_i) + (1 - a_i) \cdot \hat{e}_i$; and (iii) the standard RCT uniform assumption. A simple linear logistic model, $\hat{e}_i = \sigma(x_i; w)$, is used to approximate the unknown propensity score $P(A = 1|X = x)$. See the SM for details of the baselines.

Table 3: Performance comparisons on FRAMINGHAM data, with 95% HR(t) confidence interval. Test set NN assignment of $y_{CF}$ and $\delta_{CF}$ yields biased HR(t) = 1.23 (1.17, 1.25), while previous large scale longitudinal RCT studies estimated HR(t) = 0.75 (0.64, 0.88) (Yusuf et al., 2016).
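The weighting schemes used with the CoxPH baseline follow directly from the estimated propensity scores. The sketch below computes them for a toy sample; the function name, the toy propensity values and the choice to normalize the weights so that they sum to one are assumptions made for illustration, since the exact normalization is described in the SM.

```python
import numpy as np

def propensity_weights(a, e_hat, scheme="ipw"):
    """Per-unit weights from treatment indicators a and propensity estimates e_hat."""
    a = np.asarray(a, dtype=float)
    e_hat = np.asarray(e_hat, dtype=float)
    if scheme == "ipw":        # inverse probability weighting
        w = a / e_hat + (1 - a) / (1 - e_hat)
    elif scheme == "ow":       # overlapping weights
        w = a * (1 - e_hat) + (1 - a) * e_hat
    elif scheme == "uniform":  # RCT-style uniform weights
        w = np.ones_like(e_hat)
    else:
        raise ValueError(f"unknown scheme: {scheme}")
    return w / w.sum()         # assumed normalization: weights sum to one

# Hypothetical toy sample: treatment indicators and estimated propensity scores.
a = np.array([1, 0, 1, 0])
e_hat = np.array([0.8, 0.2, 0.5, 0.5])

w_ipw = propensity_weights(a, e_hat, "ipw")
w_ow = propensity_weights(a, e_hat, "ow")
w_uniform = propensity_weights(a, e_hat, "uniform")
print(w_ipw, w_ow, w_uniform)
```

Both schemes give more weight to units whose observed treatment was less predictable from their covariates. IPW weights become extreme as propensity estimates approach zero or one, whereas overlapping weights stay bounded, which is one reason both schemes are worth comparing.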
Datasets

We consider the following datasets: (i) FRAMINGHAM is an EHR-based longitudinal cardiovascular cohort study that we use to evaluate the effect of statins on future coronary heart disease outcomes (Benjamin et al., 1994); (ii) ACTG is a longitudinal RCT study comparing monotherapy with Zidovudine or Didanosine with combination therapy in HIV patients (Hammer et al., 1996); and (iii) ACTG-SYNTHETIC is a semi-synthetic dataset based on ACTG covariates. We simulate potential outcomes according to a Gompertz-Cox distribution (Bender et al., 2005), with selection bias from a simple logistic model for $P(A = 1|X = x)$ and an AFT-based censoring mechanism. The generative process is detailed in the SM. Table 2 summarizes the datasets according to (i) covariates of size $p$; (ii) proportion of non-censored events, treated units, and missing entries in the $N \times p$ covariate matrix; and (iii) time range $t_{\max}$ for both censored and non-censored events. Missing entries are imputed with the median for continuous covariates and the mode for categorical covariates.

Quantitative Results

Experimental results for two datasets in Tables 1 and 3 demonstrate that CSA-INFO is clearly the best performing approach. Specifically, on FRAMINGHAM (Table 3), its $\mathrm{HR}(t)$ reverses the biased observational treatment effect to demonstrate a positive treatment effect from statins, which is consistent with prior large RCT longitudinal findings (Yusuf et al., 2016).

Qualitative Results

Figure 2a demonstrates that CSA-INFO matches the ground truth population hazard, $\mathrm{HR}(t)$, better than alternative methods on ACTG-SYNTHETIC data. See the SM for ACTG and FRAMINGHAM. Figure 2b shows sub-population log hazard ratios for four patient clusters obtained via hierarchical clustering on the individual log hazard ratios, $\log \mathrm{HR}(t|x)$, of the test set of FRAMINGHAM data. Interestingly, these clusters stratify treatment effects into positive (2), negative (1 and 3), and neutral (4) sub-populations. Moreover, the estimated density of median $\log \mathrm{HR}(t|x)$ values in Figure 2c illustrates that nearly 70% of the test set individuals have $\log \mathrm{HR}(t|x) < 0$, and thus may benefit from taking statins. Further, we isolated the extreme top and bottom quantiles, $\mathrm{HR}(t|x) < 0.024$ and $\mathrm{HR}(t|x) > 1.916$, respectively, of the median $\log \mathrm{HR}(t|x)$ values for the test set of FRAMINGHAM, as shown in Figure 2c. After comparing their covariates, we found that individuals with the following characteristics may benefit from taking statins: young, male, diabetic, without prior history (CAD, PAD, stroke or MI), high BMI, cholesterol, triglycerides, fasting glucose, and low high-density lipoprotein. There seems to be consensus that diabetics and high-cholesterol patients benefit from statins (Cheung et al., 2004; Wilt et al., 2004). See the SM for additional results.

6. CONCLUSIONS

We have proposed a unified counterfactual inference framework for survival analysis. Our approach adjusts for bias from two unknown sources, namely, confounding due to covariate-dependent selection bias and censoring (informative or non-informative). Relative to competitive alternatives, we demonstrate superior performance for both survival-outcome prediction and treatment-effect estimation across three diverse datasets, including a semi-synthetic dataset which we introduce. Moreover, we formulate a model-free nonparametric hazard ratio metric for comparing treatment effects or leveraging prior randomized real-world experiments in longitudinal studies.

Yao Zhang, Alexis Bellot, and Mihaela van der Schaar. Learning overlapping representations for the estimation of individualized treatment effects. In AISTATS, 2020.

Lihui Zhao, Lu Tian, Hajime Uno, Scott D Solomon, Marc A Pfeffer, Jerald S Schindler, and Lee Jen Wei. Utilizing the integrated difference of two survival functions to quantify the treatment contrast for designing, monitoring, and analyzing a comparative clinical study. Clinical Trials, 2012.



Figure 1: (a) Illustration of the proposed counterfactual survival analysis (CSA). Covariates X = x are mapped into a latent representation r via the deterministic mapping r = Φ(x). The potential outcomes are sampled from t_a ∼ p(T_A | X = x) for A = a via the stochastic mapping h_A(r, ε̃), where randomness is induced with a flow-based transformation, ε̃, of a simple distribution p(ε), e.g., uniform or Gaussian. (b) and (c) show the proposed causal graphs for non-informative and informative censoring, respectively.
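The sampling path in panel (a) can be sketched in NumPy as below: a shared deterministic encoder, a transformed noise draw, and one stochastic head per treatment arm. This is a minimal illustration, not the architecture used in the paper. The single-layer encoder, the planar-flow form of the noise transformation, the linear heads, the softplus output, and all parameter names are assumptions.

```python
import numpy as np

def phi(x, W):
    # Deterministic mapping r = Phi(x); one tanh layer is an assumption.
    return np.tanh(x @ W)

def planar_flow(eps, u, w, b):
    # Assumed flow: eps_tilde = eps + u * tanh(w * eps + b), invertible if u * w >= -1.
    return eps + u * np.tanh(w * eps + b)

def sample_potential_outcomes(x, params, rng, n_samples=100):
    r = phi(x, params["W"])                                # (n, d), shared by both arms
    eps = rng.uniform(size=(n_samples, x.shape[0]))        # simple distribution p(eps)
    eps_tilde = planar_flow(eps, **params["flow"])         # transformed noise
    samples = {}
    for a in (0, 1):
        head = params["heads"][a]                          # treatment-specific head h_a
        pre = r @ head["v"] + head["c"] * eps_tilde        # (n_samples, n)
        samples[a] = np.logaddexp(0.0, pre)                # softplus keeps times positive
    return samples

rng = np.random.default_rng(0)
x_demo = rng.normal(size=(8, 5))
params = {
    "W": rng.normal(size=(5, 3)),
    "flow": {"u": 0.5, "w": 1.0, "b": 0.0},
    "heads": {0: {"v": rng.normal(size=3), "c": 1.0},
              1: {"v": rng.normal(size=3), "c": 1.0}},
}
samples = sample_potential_outcomes(x_demo, params, rng, n_samples=50)
```

Drawing many noise samples per individual yields an empirical distribution over each potential outcome, from which individualized survival curves and hazard ratios can be summarized.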

Model functions {h_A(·), Φ(·), ν_A(·)} are learned by minimizing the bound (5) via stochastic gradient descent on minibatches from D, with L_F^{CSA} for non-informative censoring and L_F^{CSA-INFO} for informative censoring. Further, for the IPM regularization loss in (

Figure 2: (a) Inferred population HR(t) compared against ground truth (EMP) on ACTG-SYNTHETIC data. CSA-INFO-based (b) cluster-specific average log HR(t|x) curves and (c) estimated density of median log HR(t|x) values on the test set of the FRAMINGHAM dataset. Cluster assignments were obtained via hierarchical clustering of individualized log HR(t|x) traces.

Performance comparisons on ACTG-SYNTHETIC data, with 95% HR(t) confidence interval. The ground-truth test-set hazard ratio is HR(t) = 0.52 (0.39, 0.71).

Bias from S(t) can be reduced by considering more complex function approximations for S(t), e.g., polynomial or spline. For the nonparametric estimation of S(t), we leverage the model-free, population point-estimate-based nonparametric Kaplan-Meier (Kaplan & Meier, 1958) estimator of the survival function, Ŝ_PKM(t), in Chapfuwa et al. (2020) to marginalize both factual and counterfactual predictions given covariates x. The approximated hazard ratio, ĤR(t), is thus obtained by combining the approximations Ŝ_PKM

Table 2: Summary statistics of the datasets.

The results illustrate that AFT-based methods have high variance and are inferior in calibration and C-Index to accuracy-based methods (SR, CSA, CSA-INFO). Surv-BART is the least calibrated method, but has low variance. CSA-INFO and CSA outperform all other methods across all factual metrics; CSA-INFO is better calibrated and has low variance, but has a slightly lower C-Index than CSA. Note that we fit CoxPH using the entire dataset; since it does not support counterfactual inference, we do not present factual metrics. By properly adjusting for both informative censoring and selection bias, CSA-INFO significantly outperforms all methods in treatment-effect estimation according to HR(t) and PEHE across non-RCT datasets, while remaining comparable to AFT-Weibull on the RCT dataset (see the SM). Further, RCT-based results on ACTG data in the SM illustrate comparable HR(t) across all models except for AFT-log-Normal and Surv-BART, which overestimate risk, and SR, which underestimates it. For non-RCT datasets (ACTG-SYNTHETIC and FRAMINGHAM), CoxPH-OW has a clear advantage over all other CoxPH-based methods, mostly credited to the well-behaved bounded propensity weights ∈ [0, 1]. Interestingly, the FRAMINGHAM observational data exhibits a common paradox: without proper adjustment for selection and censoring bias, naive approaches would result in a counter-intuitive treatment effect from statins. However, there is severe confounding from covariates such as age, BMI, diabetes, CAD, PAD, MI, and stroke, which influence both treatment likelihood and survival time.
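To make the remark about bounded propensity weights concrete, the sketch below contrasts inverse probability weights with overlap weights, assuming that "OW" denotes standard overlap weighting (treated units weighted by 1 − e(x), controls by e(x)). The function names and the toy propensity values are illustrative and are not taken from the paper.

```python
import numpy as np

def ipw_weights(a, e):
    # Inverse probability weights: unbounded as e(x) approaches 0 or 1.
    return np.where(a == 1, 1.0 / e, 1.0 / (1.0 - e))

def overlap_weights(a, e):
    # Overlap weights: treated units get 1 - e(x), controls get e(x),
    # so every weight lies in [0, 1].
    return np.where(a == 1, 1.0 - e, e)

a_demo = np.array([1, 0, 1])             # treatment indicators (toy values)
e_demo = np.array([0.01, 0.50, 0.99])    # propensity scores (toy values)
w_ipw = ipw_weights(a_demo, e_demo)
w_ow = overlap_weights(a_demo, e_demo)
```

In this toy case, the treated unit with a propensity of 0.01 receives an inverse probability weight of 100 but an overlap weight of 0.99, which illustrates why a bounded scheme can be more stable when some propensities are extreme.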

