ENABLING COUNTERFACTUAL SURVIVAL ANALYSIS WITH BALANCED REPRESENTATIONS

Abstract

Balanced representation learning methods have been applied successfully to counterfactual inference from observational data. However, approaches that account for survival outcomes are relatively limited. Survival data are frequently encountered across diverse medical applications, e.g., drug development, risk profiling, and clinical trials, and such data are also relevant in fields like manufacturing (e.g., for equipment monitoring). When the outcome of interest is time-to-event, special precautions for handling censored events need to be taken, as ignoring censored outcomes may lead to biased estimates. We propose a theoretically grounded unified framework for counterfactual inference applicable to survival outcomes. Further, we formulate a nonparametric hazard ratio metric for evaluating average and individualized treatment effects. Experimental results on real-world and semi-synthetic datasets, the latter of which we introduce, demonstrate that the proposed approach significantly outperforms competitive alternatives in both survival-outcome predictions and treatment-effect estimation.

1. INTRODUCTION

Survival analysis, or time-to-event analysis, focuses on modeling the time until a future event, such as death or failure, and on investigating its relationship with covariates or predictors of interest. Specifically, we may be interested in the causal effect of a given intervention or treatment on survival time. A typical question may be: will a given therapy increase the chances of survival of an individual or population? Such causal inquiries on survival outcomes are common in the fields of epidemiology and medicine (Robins, 1986; Hammer et al., 1996; Yusuf et al., 2016). As an important current example, the COVID-19 pandemic is creating a demand for methodological development to address such questions, particularly when evaluating the effectiveness of a potential vaccine or therapeutic outside randomized controlled trial settings. Traditional causal survival analysis is typically carried out in the context of a randomized controlled trial (RCT), where the treatment assignment is controlled by researchers. Although RCTs are the gold standard for causal inference, they are usually long-term engagements, expensive, and limited in sample size. Alternatively, observational data with comprehensive information about patients, such as electronic health records (EHRs), constitute a more accessible but also more challenging source for estimating causal effects (Häyrinen et al., 2008; Jha et al., 2009). Such observational data may be used to augment and verify RCT findings after a particular treatment has been approved and is in use (Gombar et al., 2019; Frankovich et al., 2011; Longhurst et al., 2014). Moreover, the wealth of information in observational data also allows for the estimation of the individualized treatment effect (ITE), namely, the causal effect of an intervention at the individual level. In this work, we develop a novel framework for counterfactual time-to-event prediction to estimate the ITE for survival or time-to-event outcomes from observational data.
Estimating the causal effect for survival outcomes in observational data poses two principal challenges. First, the treatment assignment mechanism is not known a priori. Therefore, there may be variables, known as confounders, that affect both the treatment and the survival time and lead to selection bias (Bareinboim & Pearl, 2012), i.e., the covariate distributions differ across treatment groups. In this work, we focus on selection bias due to confounding, although other sources may also be considered. For instance, patients who are severely ill are likely to receive more aggressive therapy; however, their health status may also influence survival. Traditional survival analysis neglects such bias, leading to incorrect causal estimates. Second, the exact time-to-event is not always observed, i.e., sometimes we only know that an event has not occurred up to a certain point in time. This is known as the censoring problem. Moreover, censoring might be informative, depending on the characteristics of the individuals and their treatment assignments, so proper adjustment is required for accurate causal estimation (Cole & Hernán, 2004; Díaz, 2019). Traditional causal survival-analysis approaches typically model the effect of the treatment or covariates (not time or survival) in a parametric manner. Two commonly used models are the Cox proportional hazards (CoxPH) model (Cox, 1972) and the accelerated failure time (AFT) model (Wei, 1992), which presume a linear relationship between the covariates and the log hazard (CoxPH) or the log survival time (AFT). Further, individual-level weighting has been employed with these models to account for confounding bias (Austin, 2007; 2014; Hernán et al., 2005). For instance, probability weighting schemes that account for both selection bias and covariate-dependent censoring have been considered for adjusted survival curves (Cole & Hernán, 2004; Díaz, 2019).
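To make the weighting idea concrete, the sketch below computes inverse-probability-of-treatment-weighted (IPTW) Kaplan-Meier curves, in the spirit of the adjusted survival curves cited above. It assumes the propensity score e(x) = P(A = 1 | x) is known (in practice it must be estimated) and that censoring is non-informative given treatment; handling covariate-dependent censoring would additionally require inverse-probability-of-censoring weights, which this sketch omits. The function names (`weighted_kaplan_meier`, `survival_at`, `iptw_curves`) and the toy simulation are illustrative assumptions, not part of any library or of the method proposed in this work.

```python
import numpy as np

def weighted_kaplan_meier(times, events, weights):
    """Weighted Kaplan-Meier estimate of S(t) at each distinct observed event time.

    times:   observed times, i.e. min(event time, censoring time)
    events:  1 if the event was observed, 0 if the individual was censored
    weights: per-individual weights (all ones recovers the ordinary Kaplan-Meier)
    """
    times = np.asarray(times, dtype=float)
    events = np.asarray(events, dtype=int)
    weights = np.asarray(weights, dtype=float)
    event_times = np.unique(times[events == 1])
    surv, s = [], 1.0
    for t in event_times:
        at_risk = weights[times >= t].sum()                      # weighted risk set at t
        n_events = weights[(times == t) & (events == 1)].sum()   # weighted events at t
        s *= 1.0 - n_events / at_risk
        surv.append(s)
    return event_times, np.array(surv)

def survival_at(curve, t):
    """Evaluate a right-continuous step survival curve at time t."""
    event_times, surv = curve
    idx = np.searchsorted(event_times, t, side="right") - 1
    return 1.0 if idx < 0 else float(surv[idx])

def iptw_curves(times, events, treatment, propensity):
    """Per-arm curves, weighting treated by 1/e(x) and controls by 1/(1 - e(x))."""
    times, events = np.asarray(times), np.asarray(events)
    treatment = np.asarray(treatment)
    e = np.asarray(propensity, dtype=float)
    w = np.where(treatment == 1, 1.0 / e, 1.0 / (1.0 - e))
    return {a: weighted_kaplan_meier(times[treatment == a], events[treatment == a],
                                     w[treatment == a])
            for a in (0, 1)}

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n = 5000
    x = rng.normal(size=n)                      # one confounder, e.g. illness severity
    e = 1.0 / (1.0 + np.exp(-x))                # true propensity P(A=1 | x), assumed known
    a = rng.binomial(1, e)                      # sicker patients are treated more often
    t_event = rng.exponential(1.0 / np.exp(x))  # hazard exp(x); treatment has no effect
    t_cens = rng.exponential(2.0, size=n)       # censoring independent of x and A
    obs = np.minimum(t_event, t_cens)
    delta = (t_event <= t_cens).astype(int)
    # A constant propensity gives equal weights within each arm, i.e. unadjusted curves.
    naive = iptw_curves(obs, delta, a, np.full(n, 0.5))
    adjusted = iptw_curves(obs, delta, a, e)
    for name, curves in (("naive", naive), ("IPTW", adjusted)):
        print(name, {arm: round(survival_at(curves[arm], 1.0), 3) for arm in (0, 1)})
```

In this toy simulation, treatment has no effect but sicker patients are treated more often, so the unweighted treated curve should sit below the control curve, whereas the two IPTW-adjusted curves should roughly coincide.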
Moreover, such probability weighting schemes have been applied to causal survival analysis under time-varying treatment and confounding (Robins, 1986; Hernán et al., 2000). See van der Laan & Robins (2003); Tsiatis (2007); Van der Laan & Rose (2011); Hernán & Robins (2020) for an overview. The linear specification of these models makes them interpretable but compromises their flexibility, making it difficult to adapt them to high-dimensional data or to capture complex interactions among covariates. Importantly, these methods lack a counterfactual prediction mechanism, which is key for ITE estimation (see Section 2). Fortunately, recent advances in machine learning, such as representation learning and generative modeling, have enabled causal inference methods to handle high-dimensional data and to characterize complex interactions effectively. For instance, there has been recent interest in tree-based (Chipman et al., 2010; Wager & Athey, 2018) and neural-network-based (Shalit et al., 2017; Zhang et al., 2020) approaches. For pre-specified time horizons, the nonparametric Random Survival Forest (RSF) (Ishwaran et al., 2008) and Bayesian Additive Regression Trees (BART) (Chipman et al., 2010) have been extended to causal survival analysis. RSF has been extended to causal survival forests with weighted bootstrap inference (Shen et al., 2018; Cui et al., 2020), while BART has been extended to survival outcomes in Surv-BART (Sparapani et al., 2016) and AFT-BART (Henderson et al., 2020). See Hu et al. (2020) for an extensive investigation of tree-based causal survival methods. Alternatively, when estimating the ITE, neural-network-based methods regularize the transformed covariates, or representations, of individuals to have balanced distributions across treatment groups, thus accounting for confounding bias and improving ITE prediction.
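The balancing idea can be illustrated with a small sketch. Below, a hypothetical encoder `representation` (a fixed random `tanh` layer standing in for a learned network) maps covariates to a representation, and `linear_mmd` measures imbalance as the squared Euclidean distance between the treated and control means of that representation, i.e., the maximum mean discrepancy with a linear kernel, one simple integral probability metric. The function `balanced_objective` and its weight `alpha` are likewise illustrative assumptions that only show how such a penalty would be added to a factual prediction loss; they are not the objective proposed in this work.

```python
import numpy as np

def representation(x, weights):
    """Hypothetical encoder Phi(x) = tanh(x @ W); a stand-in for a learned network."""
    return np.tanh(x @ weights)

def linear_mmd(phi, treatment):
    """Squared distance between treated and control means of Phi (linear-kernel MMD)."""
    treatment = np.asarray(treatment)
    diff = phi[treatment == 1].mean(axis=0) - phi[treatment == 0].mean(axis=0)
    return float(diff @ diff)

def balanced_objective(factual_loss, phi, treatment, alpha=1.0):
    """Illustrative objective: factual loss plus alpha times the imbalance penalty."""
    return factual_loss + alpha * linear_mmd(phi, treatment)

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n, d, k = 2000, 5, 3
    x = rng.normal(size=(n, d))
    phi = representation(x, rng.normal(size=(d, k)))
    # Confounded assignment: the first covariate drives treatment.
    a_conf = rng.binomial(1, 1.0 / (1.0 + np.exp(-2.0 * x[:, 0])))
    # Randomized assignment, as in an RCT.
    a_rct = rng.binomial(1, 0.5, size=n)
    print("imbalance (confounded):", linear_mmd(phi, a_conf))
    print("imbalance (randomized):", linear_mmd(phi, a_rct))
```

Under confounded assignment the penalty should typically be noticeably larger than under randomized assignment; training an encoder to reduce it, alongside the prediction loss, is what pushes the representations of the two groups toward balance.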
However, most approaches employing representation learning techniques for counterfactual inference deal with continuous or binary outcomes rather than time-to-event outcomes with censoring (informative or non-informative). Hence, a principled generalization to counterfactual survival analysis is needed. In this work, we leverage balanced (latent) representation learning to estimate the ITE via counterfactual prediction of survival outcomes in observational studies. We develop a framework to predict event times from a low-dimensional transformation of the original covariate space. To address the specific challenges associated with counterfactual survival analysis, we make the following contributions:

• We develop an optimization objective incorporating adjustments for informative censoring, as well as a balanced regularization term bounding the generalization error for ITE prediction. For the latter, we repurpose a recently proposed bound (Shalit et al., 2017) for our time-to-event scenario.
• We propose a generative model for event times that relaxes the restrictive linear and parametric assumptions of standard survival models, thus allowing for more flexible modeling. Our approach can also provide nonparametric uncertainty quantification for ITE predictions.
• We provide survival-specific evaluation metrics, including a new nonparametric hazard ratio estimator, and discuss how to perform model selection for survival outcomes. The proposed model demonstrates superior performance relative to commonly used baselines on real-world and semi-synthetic datasets.
• We introduce a survival-specific semi-synthetic dataset and demonstrate an approach for leveraging prior randomized experiments in longitudinal studies for model validation.

