TIME SERIES COUNTERFACTUAL INFERENCE WITH HIDDEN CONFOUNDERS

Abstract

We present augmented counterfactual ordinary differential equations (ACODEs), a new approach to counterfactual inference on time series data with a focus on healthcare applications. ACODEs model interventions in continuous time with differential equations, augmented by auxiliary confounding variables to reduce inference bias. Experiments on tumor growth simulation and sepsis patient treatment response show that ACODEs outperform other methods like counterfactual Gaussian processes, recurrent marginal structural networks, and time series deconfounders in the accuracy of counterfactual inference. The learned auxiliary variables also reveal new insights into causal interventions and hidden confounders.

1. INTRODUCTION

Decision makers want to know how to produce desired outcomes and act accordingly, which requires an understanding of cause and effect. In this paper, we consider applications in healthcare, where time series data on past features and outcomes are now widely available. Causality in time series has long been studied in statistics (Box et al., 2008), and allows more powerful analysis than methods for time-independent data, such as instrumental variable regression (Stock & Trebbi, 2003). However, temporal causality in statistics and econometrics focuses mainly on passively discovering time lag structure (Eichler, 2012). In contrast, decision-making applications need concrete interventions, which are more amenable to an interventionist approach to causality (Woodward, 2005; Pearl, 2009). To give one example, electronic health records (EHRs) in healthcare provide an accessible history of a patient's disease progression over time, together with treatment records and their results. To identify effective treatments, a doctor may want to ask counterfactual questions (Johansson et al., 2016), like "Would this patient have lower blood sugar had she received a different medication?" Through such counterfactual analysis, medical professionals may hope to discover new cures and improve existing treatments. Similar situations arise in other use cases. For example, a user interface designer may want to ask "Would the user have clicked on this ad had it been in a different color?" and substantiate the answer with counterfactual inference on clickstream data or other records of user behavior.

Counterfactual inference in time series has been studied under the assumption that all possible causal variables are observed (Soleimani et al., 2017; Schulam & Saria, 2017; Lim, 2018). In practice, however, this assumption of perfect observability is not testable and is too strong for many real-world scenarios (Bica et al., 2020).
For example, there are many ways to treat cancer in general, but each patient requires a bespoke treatment plan based on the unique characteristics of the case, such as drug resistance and toxic response (Vlachostergios & Faltas, 2018; Kroschinsky et al., 2017; Bica et al., 2020). However, these factors are also likely to be unmeasurable in practice, or simply not recorded in EHRs. Detecting these hidden confounding variables is therefore crucial to avoid bias in the estimation of treatment effects.

The challenge introduced by confounders in counterfactual inference was first studied in the static setting. Wang & Blei (2019) developed a two-step method that estimates confounders with latent factor models, then infers potential outcomes with bias adjustment. However, confounders in time series can have their own dynamics, and can themselves be affected by the history of interventions. Subsequently, Bica et al. (2020) introduced recurrent neural networks (RNNs) into the factor model to estimate the dynamics of confounders. However, this method only works in the discrete-time setting with a fixed time step, because of how RNNs are structured.

In this paper, we consider the continuous-time setting, which is more flexible in practice and provides more insight into the underlying mechanisms (Chen et al., 2018; Rubanova et al., 2019). The continuous-time setting is particularly important for healthcare, where treatments vary over time and time series are often irregularly sampled or partially observed (Soleimani et al., 2017). The classical modeling approach to dynamics uses ordinary differential equations (ODEs) dx(t)/dt = f(x(t)), encoding domain expertise about the underlying mechanisms in the explicit specification of f. In contrast, Chen et al. (2018) introduced neural ODEs, which parameterize f with neural networks and thus allow dynamics to be described by arbitrarily complicated functions.
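The contrast between a hand-specified f and a network-parameterized f can be illustrated with a minimal sketch. It uses only the Python standard library, and everything in it is an assumption made for illustration: the helper names (make_mlp_vector_field, rk4_step, solve_at), the fixed-step Runge-Kutta solver, the linear-decay example, and the random, untrained network weights. It is not the solver or architecture of Chen et al. (2018), who use adaptive solvers and train f by differentiating through the solution.

```python
import math
import random


def make_mlp_vector_field(dim, hidden=8, seed=0):
    """Return f(x) computed by a one-hidden-layer tanh network.

    The weights are random and untrained (an illustrative assumption);
    in a neural ODE they would be fitted to data.
    """
    rng = random.Random(seed)
    w1 = [[rng.gauss(0.0, 0.5) for _ in range(dim)] for _ in range(hidden)]
    w2 = [[rng.gauss(0.0, 0.5) for _ in range(hidden)] for _ in range(dim)]

    def f(x):
        h = [math.tanh(sum(w * xi for w, xi in zip(row, x))) for row in w1]
        return [sum(w * hi for w, hi in zip(row, h)) for row in w2]

    return f


def rk4_step(f, x, h):
    """One classical fourth-order Runge-Kutta step for dx/dt = f(x)."""
    def shifted(scale, k):  # x + scale * k, elementwise
        return [xi + scale * ki for xi, ki in zip(x, k)]

    k1 = f(x)
    k2 = f(shifted(0.5 * h, k1))
    k3 = f(shifted(0.5 * h, k2))
    k4 = f(shifted(h, k3))
    return [xi + h / 6.0 * (a + 2 * b + 2 * c + d)
            for xi, a, b, c, d in zip(x, k1, k2, k3, k4)]


def solve_at(f, x0, times, max_step=0.01):
    """Return the state at each observation time; gaps may be irregular."""
    x, states = list(x0), [list(x0)]
    for t_prev, t_next in zip(times, times[1:]):
        n = max(1, math.ceil((t_next - t_prev) / max_step))
        h = (t_next - t_prev) / n
        for _ in range(n):
            x = rk4_step(f, x, h)
        states.append(list(x))
    return states


times = [0.0, 0.3, 0.35, 1.2, 2.0]  # irregularly spaced observation times

# Classical modeling: f is written down from domain knowledge (linear decay).
classical = solve_at(lambda x: [-0.5 * x[0]], [1.0], times)

# Neural ODE: the same solver, but f is a neural network.
neural = solve_at(make_mlp_vector_field(dim=2), [1.0, -1.0], times)
```

Because the solver integrates between arbitrary time points, the same code handles irregular sampling without any change, which is the practical advantage over a fixed-step recurrent model.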
Several extensions handle more complicated issues such as irregular sampling or switching dynamics (Jia & Benson, 2019; Kidger et al., 2020). However, these methods cannot be directly applied to time series counterfactual inference, as they focus on initial value problems, which cannot describe interventions without explicit modification of f (Kidger et al., 2020). Furthermore, these existing methods can only handle hidden variables by explicitly describing their dynamics and interdependency with interventions, which limits their utility when confounders exist.

Our contributions. We propose augmented counterfactual ODEs (ACODEs) to predict how a continuous-time time series will evolve under a sequence of interventions. Our method augments the observed time series with additional dimensions to represent confounders. We then construct counterfactual ODEs based on the neural ODE framework to model the effects of incoming interventions. The ACODE model has three key features. First, it accounts for the presence of confounders, which reduces prediction bias. Second, it can continuously incorporate incoming interventions using neural ODEs and supports irregularly sampled time series. Third, it demonstrates state-of-the-art performance against competitive baselines for counterfactual inference, both in a simulation of tumor growth and on real-world time series of sepsis patients' treatment response. Moreover, the ACODE provides an interface between machine learning and the dominant modeling paradigm of differential equations, which allows well-understood domain knowledge to be applied to time series counterfactual inference. To the best of our knowledge, this is the first method for counterfactual inference with confounders in the continuous-time setting.
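The idea of augmenting the observed state with extra dimensions and feeding interventions into the vector field can be sketched with a toy example. This is not the ACODE architecture: the dynamics below are hand-written rather than learned, the single auxiliary variable u, the coefficients, the forward-Euler integrator, and the names dose_at, vector_field, and rollout are all hypothetical choices made for illustration. In the proposed model, the vector field would be a neural network and the auxiliary dimensions would be learned from data.

```python
import math


def dose_at(schedule, t):
    """Piecewise-constant intervention: the latest dose started at or before t."""
    dose = 0.0
    for start, value in schedule:  # schedule is assumed sorted by start time
        if start <= t:
            dose = value
    return dose


def vector_field(z, a):
    """Hand-written toy dynamics on the augmented state z = (x, u).

    x is the observed series; u is an auxiliary, unobserved dimension.
    """
    x, u = z
    dx = 0.3 * x - a * (1.0 - u) * x  # growth minus a treatment effect weakened by u
    du = 0.2 * a * (1.0 - u)          # past treatment drives the auxiliary variable
    return dx, du


def rollout(z0, schedule, times, max_step=0.001):
    """Forward-Euler rollout; returns the augmented state at each requested time."""
    z, states = tuple(z0), [tuple(z0)]
    for t_prev, t_next in zip(times, times[1:]):
        n = max(1, math.ceil((t_next - t_prev) / max_step))
        h = (t_next - t_prev) / n
        for i in range(n):
            a = dose_at(schedule, t_prev + i * h)
            dz = vector_field(z, a)
            z = tuple(zi + h * dzi for zi, dzi in zip(z, dz))
        states.append(z)
    return states


times = [0.0, 0.5, 1.0, 1.7, 3.0]  # irregular observation times
z0 = (1.0, 0.0)                     # observed x = 1, auxiliary u = 0

factual = rollout(z0, [], times)                   # never treated
counterfactual = rollout(z0, [(1.0, 1.0)], times)  # dose 1.0 from t = 1.0 onward

for t, zf, zc in zip(times, factual, counterfactual):
    print(f"t={t:.1f}  x factual={zf[0]:.3f}  x counterfactual={zc[0]:.3f}")
```

The two rollouts share the same initial state and differ only in the intervention schedule, so they coincide until the first dose and diverge afterwards. The auxiliary variable u also depends on the treatment history, which mirrors the point made above that confounders in time series can be affected by past interventions.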

2. RELATED WORK

Time series counterfactual inference stems from causal inference (Pearl, 2009; Eichler, 2012). A large body of pioneering work in causal inference focuses on causal relations, such as structural causal models (Pearl, 2019) and Granger causality (Eichler, 2007). Counterfactual inference, on the other hand, focuses on estimating the effects of actionable interventions, which is a pervasive problem in healthcare (Hoover, 2018). In the literature, the causal effect of an intervention is defined as the difference between the counterfactual outcomes with and without the intervention (Pearl, 2009). Originating from the literature on observational studies (Shadish et al., 2002), Rubin's potential outcome framework has been a popular language for formalizing counterfactuals and the estimation of intervention effects (Rubin, 2005; Imbens & Rubin, 2015).

The problem of hidden confounders in counterfactual inference was first studied in the static setting. Wang & Blei (2019) developed theory for adjusting the bias introduced by hidden confounders in observational data. They found that the dependencies among these multiple confounders can be used to infer latent variables that act as substitutes for the hidden confounders. In this paper, we consider hidden confounders in the time series setting, which is much more complicated than the static setting, not only because the hidden confounders may evolve over time, but also because they may be affected by previous interventions. On the other hand, most existing work on time series counterfactual inference, including counterfactual Gaussian processes (CGPs) (Schulam & Saria, 2017) and recurrent marginal structural networks (RMSNs) (Lim, 2018), assumes that there are no hidden confounders, i.e., that all variables affecting the intervention plan and the potential outcomes are observed. This assumption is not testable in practice and does not hold in many cases. Recently, Bica et al. (2020) applied the idea of latent factor models from Wang & Blei (2019) to the deconfounding of time series. However, their proposed method is based on recurrent neural networks, which work only with discrete, regularly spaced time series.

Differential equations have been introduced into causal and counterfactual inference in previous studies. Rubenstein et al. (2018) showed that equilibrium states of a first-order ODE system can be described with a deterministic structural causal model, even with non-constant interventions. This

