SELECTING TREATMENT EFFECTS MODELS FOR DOMAIN ADAPTATION USING CAUSAL KNOWLEDGE

Anonymous authors
Paper under double-blind review

Abstract

Selecting causal inference models for estimating individualized treatment effects (ITE) from observational data presents a unique challenge, since the counterfactual outcomes are never observed. The problem becomes even harder in the unsupervised domain adaptation (UDA) setting, where we only have access to labeled samples in the source domain but wish to select a model that achieves good performance on a target domain for which only unlabeled samples are available. Existing techniques for UDA model selection are designed for the predictive setting. These methods examine discriminative density ratios between the input covariates in the source and target domains and do not factor in the model's predictions in the target domain. As a result, existing methods assign the same risk score to two models with identical performance on the source domain, even though the two models can perform very differently on the target domain. We leverage the invariance of causal structures across domains to introduce a novel model selection metric designed specifically for ITE models in the UDA setting. In particular, we propose selecting models whose predictions of the effects of interventions satisfy known causal structures in the target domain. Experimentally, our method selects ITE models that are more robust to covariate shifts on several synthetic and real healthcare datasets, including on estimating the effect of ventilation in COVID-19 patients from different geographic locations.

1. INTRODUCTION

Causal inference models for estimating individualized treatment effects (ITE) are designed to provide actionable intelligence as part of decision support systems and, when deployed in mission-critical domains such as healthcare, require safety and robustness above all (Shalit et al., 2017; Alaa & van der Schaar, 2017). In healthcare, the observational data used to train an ITE model often come from a setting where the distribution of patient features differs from that of the deployment (target) environment, for example, when models are transferred across hospitals or countries. It is therefore imperative to select ITE models that are robust to these covariate shifts across disparate patient populations. In this paper, we address the problem of ITE model selection in the unsupervised domain adaptation (UDA) setting, where we observe the responses to treatments for patients in a source domain and wish to select ITE models that can reliably estimate treatment effects in a target domain containing only unlabeled data, i.e., patient features. UDA has been successfully studied in the predictive setting to transfer knowledge from existing labeled data in the source domain to unlabeled target data (Ganin et al., 2016; Tzeng et al., 2017). In this context, several model selection scores have been proposed to select the predictive models that are most robust to covariate shifts between domains (Sugiyama et al., 2007; You et al., 2019). These methods approximate the performance of a model on the target domain (target risk) by weighting its performance on the source validation set (source risk) with known (or estimated) density ratios. However, ITE model selection for UDA differs significantly from selecting predictive models for UDA (Stuart et al., 2013).
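The density-ratio weighting described above can be sketched as follows, in the spirit of importance-weighted validation. This is a minimal illustration, not the exact procedure of the cited works: it assumes numpy, per-sample validation losses already computed on the source domain, and a logistic-regression domain classifier as the estimator of the density ratio p_tgt(x) / p_src(x). The function names `logistic_density_ratio` and `importance_weighted_risk` are hypothetical.

```python
import numpy as np


def logistic_density_ratio(x_src, x_tgt, n_iter=500, lr=0.1):
    """Estimate w(x) = p_tgt(x) / p_src(x) at the source points.

    A logistic-regression domain classifier (label 1 = target, 0 = source)
    is fit by gradient descent; its odds are then rescaled by the sample sizes.
    """
    x = np.vstack([x_src, x_tgt])
    d = np.concatenate([np.zeros(len(x_src)), np.ones(len(x_tgt))])
    feats = np.hstack([x, np.ones((len(x), 1))])  # append a bias column
    beta = np.zeros(feats.shape[1])
    for _ in range(n_iter):
        p = 1.0 / (1.0 + np.exp(-feats @ beta))
        beta -= lr * feats.T @ (p - d) / len(d)  # gradient of the mean log-loss
    logit = np.hstack([x_src, np.ones((len(x_src), 1))]) @ beta
    # P(tgt|x) / P(src|x) = exp(logit) = (n_tgt * p_tgt(x)) / (n_src * p_src(x)).
    return np.exp(logit) * len(x_src) / len(x_tgt)


def importance_weighted_risk(losses, weights):
    """Approximate the target risk as the weighted mean of source validation losses."""
    losses = np.asarray(losses, dtype=float)
    weights = np.asarray(weights, dtype=float)
    return float(np.sum(weights * losses) / len(losses))
```

Because the weights depend only on the covariates, two models with identical per-sample source losses receive identical scores under this kind of estimator.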
Notably, we can only approximate the counterfactual error (Alaa & van der Schaar, 2019), since we observe only the factual outcome for the treatment received and never the counterfactual outcomes under other treatment options (Spirtes et al., 2000). Consequently, existing UDA model selection methods for predictive models, which compute a weighted sum of the validation error as a proxy for the target risk (You et al., 2019), are suboptimal for selecting ITE models: for an ITE model, the validation error is itself only an approximation of the model's ability to estimate counterfactual outcomes on the source domain. To better approximate the target risk, we propose to leverage the invariance of causal graphs across domains and select ITE models whose predictions of the treatment effects also satisfy known or discovered causal relationships. It is well known that causality is a property of the physical world; therefore, the physical (functional) relationships between variables remain invariant across domains (Schölkopf et al., 2012; Bareinboim & Pearl, 2016; Rojas-Carulla et al., 2018; Magliacane et al., 2018). As shown in Figure 1, we assume the existence of an underlying causal graph G that describes the generating process of the observational data. We represent the selection bias present in the source observational data by arrows from the features {X_1, X_2} to the treatment T. In the target domain, we only have access to the patient features, and we want to estimate the patient outcome (Y) under different settings of the treatment (interventions). When performing such an intervention, do(T = t), the causal structure remains unchanged except that the arrows into the treatment node are removed, which yields the interventional causal graph G_T.

Contributions. To the best of our knowledge, we present the first UDA selection method specifically tailored for machine learning models that estimate ITE.
Our ITE model selection score uniquely leverages the estimated patient outcomes under different treatment settings in the target domain by measuring how well these outcomes satisfy the causal relationships in the interventional causal graph G_T. This measure, which we refer to as the causal risk, is computed using a log-likelihood function that quantifies how well the model's predictions fit the underlying causal graph. We provide a theoretical justification for using the causal risk, and we show that our proposed ITE model selection metric for UDA prefers models whose predictions satisfy the conditional independence relationships in G_T and are thus more robust to changes in the distribution of the patient features. We also show experimentally that adding the causal risk to existing state-of-the-art model selection scores for UDA results in selecting ITE models with improved performance on the target domain. Finally, we illustrate model selection for UDA on several real-world datasets, including ventilator assignment for COVID-19 patients.
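To make the interventional graph and a likelihood-based fitness score more concrete, here is a toy sketch. It is not the paper's causal risk: the graph (X_1 → T, X_2 → T, X_1 → Y, T → Y), the linear-Gaussian outcome model, the decision to score only the outcome node, and the helper names `intervene`, `gaussian_loglik`, and `causal_risk_proxy` are all hypothetical assumptions made for illustration. Graph surgery removes every arrow into T, as described for do(T = t); the score then regresses the model's predicted outcomes on target covariates, under each treatment value, on the outcome's parents in G_T and returns the negative log-likelihood (lower is better).

```python
import numpy as np

# Hypothetical DAG for illustration, stored as {node: set of parents}.
# X2 drives treatment assignment (selection bias) but, by assumption here,
# has no direct effect on the outcome Y.
G = {"X1": set(), "X2": set(), "T": {"X1", "X2"}, "Y": {"X1", "T"}}


def intervene(graph, node):
    """Graph surgery for do(node = t): copy the graph and drop all edges into node."""
    g_t = {v: set(parents) for v, parents in graph.items()}
    g_t[node] = set()
    return g_t


def gaussian_loglik(y, design):
    """Maximized log-likelihood of y under a linear-Gaussian regression on design."""
    design = np.hstack([design, np.ones((len(y), 1))])  # add an intercept
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    # Floor the residual variance so an exact fit does not produce log(0).
    var = max(float(np.mean((y - design @ coef) ** 2)), 1e-12)
    return -0.5 * len(y) * (np.log(2.0 * np.pi * var) + 1.0)


def causal_risk_proxy(predict, x, graph, treatments=(0.0, 1.0)):
    """Negative log-likelihood of a model's predicted outcomes under G_T.

    predict(x, t) returns predicted outcomes for target covariates x under
    do(T = t). Root nodes and the intervened T do not depend on the model,
    so only the outcome node's conditional is scored. Lower is better.
    """
    g_t = intervene(graph, "T")
    parents = sorted(g_t["Y"])
    designs, outcomes = [], []
    for t in treatments:
        # Target covariates (columns X1, X2) plus the intervened treatment.
        values = {"X1": x[:, 0], "X2": x[:, 1], "T": np.full(len(x), t)}
        designs.append(np.column_stack([values[p] for p in parents]))
        outcomes.append(np.asarray(predict(x, t), dtype=float))
    return -gaussian_loglik(np.concatenate(outcomes), np.vstack(designs))
```

Under this hypothetical graph, predictions of the form 2·X_1 + 1.5·t respect the implied independence of Y and X_2 given (X_1, T) and score lower than predictions that also depend on X_2. Because the sketch fits a linear-Gaussian conditional, it would also penalize correct but nonlinear predictions; a more faithful version would need a more flexible conditional model.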

2. RELATED WORKS

Our work is related to causal inference and domain adaptation. In this section, we describe existing methods for ITE estimation, UDA model selection in the predictive setting, and domain adaptation from a causal perspective.

ITE models. Recently, a large number of machine learning methods for estimating heterogeneous ITE from observational data have been developed, leveraging ideas from representation learning (Johansson et al., 2016; Shalit et al., 2017; Yao et al., 2018), adversarial training (Yoon et al., 2018), causal random forests (Wager & Athey, 2018), and Gaussian processes (Alaa & van der Schaar, 2017; 2018). Nevertheless, no single model achieves the best performance on all types of observational data (Dorie et al., 2019), and even for the same model, different hyperparameter settings or training iterations yield different performance.

ITE model selection.

Evaluating the performance of ITE models is challenging since counterfactual data are unavailable, and consequently the true causal effects cannot be computed. Several heuristics for estimating model performance have been used in practice (Schuler et al., 2018; Van der Laan & Robins, 2003). Factual model selection computes only the error of the ITE model in estimating the factual patient outcomes. Alternatively, inverse probability of treatment weighting (IPTW) selection weights each sample's factual error by the inverse of its estimated propensity score, which yields an unbiased estimate of the model's error when the propensity scores are correct (Van der Laan & Robins, 2003). Alaa & van der Schaar (2017) propose using influence functions to approximate ITE models' error in predicting both factual and counterfactual outcomes. Influence



Figure 1: Method overview. We propose selecting ITE models whose predictions of the treatment effects on the target domain satisfy the causal relationships in the interventional causal graph G_T.

[Figure 1 panels: Source domain (Train), x ∼ p_μ(X), D_src; Target domain (Test), x ∼ p_π(X), D_tgt; Covariate shifts between domains, p_μ(X) ≠ p_π(X); Causal graph G describing the data generating process over X_1, X_2, T, Y; Intervene on T, do(T = t); Interventional causal graph G_T; Candidate ITE models; Test dataset; Assess fitness of predictions to G_T; Select causal inference model robust to distributional shift in target domain.]

