GENERALIZATION BOUNDS AND ALGORITHMS FOR ESTIMATING THE EFFECT OF MULTIPLE TREATMENTS AND DOSAGE

Anonymous authors

Paper under double-blind review

Abstract

Estimating conditional treatment effects has been a longstanding challenge for fields of study such as epidemiology or economics that require a treatment-dosage pair to make decisions, but may not be able to run randomized trials to precisely quantify its effect. This may be due to financial restrictions or ethical considerations. In the context of representation learning, there is an extensive literature relating model architectures with regularization techniques to solve this problem using observational data. However, theoretically motivated loss functions and bounds on generalization errors only exist in selected circumstances, such as in the presence of binary treatments. In this paper, we introduce new bounds on the counterfactual generalization error in the context of multiple treatments and continuous dosage parameters, which subsume existing results. This result guides, in a principled manner, the definition of new learning objectives that can be used to train representation learning algorithms. Empirically, we show new state-of-the-art performance across several benchmark datasets for this problem, including in comparison to doubly-robust estimation methods.

1. INTRODUCTION

Treatment effect estimation is the problem of predicting the effect of an intervention (e.g. a treatment-dosage pair) on an outcome of interest to guide decision-making. The challenge for prediction models is to learn this map from observational data, which is formally generated from a different structural causal model in which treatment assignment varies according to an individual's covariates, instead of being fixed by the decision-maker. Counterfactuals define the outcome that would have been observed had the assigned treatment been different. For concreteness, consider designing a policy for the administration of chemotherapy regimens: not all cancer patients in the available data are equally likely to be offered the same type and dosage, with varied factors, e.g. age or wealth, involved in the decision-making process. A new treatment combination for a given patient is therefore a data point that is invariably under-represented in the empirical distribution of the data. Treatment effect estimation is studied under a wide range of assumptions, including experimental designs that feature ignorability (Imbens, 2000; Imai & Van Dyk, 2004), multiple treatments, sequential decision-making problems, and different generative models encoded in general causal graphs (Pearl, 2009). There is a growing literature on several parts of this problem in the field of machine learning that attempts to define loss functions that are conducive to learning representations of covariates predictive of both observed and counterfactual outcomes. Existing methods can be broadly categorized by the theoretical guarantees that inspire their training objectives, driven either by bounds on the generalization error or by double-robustness guarantees. In the first line of research, Shalit et al. (2017); Johansson et al.
(2020) showed in the binary treatment setting that the counterfactual error, which by design is not computable from data, can instead be bounded by the in-sample error plus a term that quantifies the difference in distributions between treated and untreated populations, leading to a differentiable loss function that can be used to train expressive neural networks. Several papers used this insight to investigate different neural network architectures for this problem. For example, Johansson et al. (2016) proposed to use separate feed-forward prediction heads on top of a common representation, Zhang et al. (2022) use transformers, and De Brouwer et al. (2022); Seedat et al. (2022) use neural differential equations. In turn, doubly-robust estimators combine expressive function approximators and inverse probability weighting, leveraging statistical non-parametric asymptotic guarantees of both estimators (Funk et al., 2011; Kennedy, 2016; 2020). In particular, when the direct estimate of the outcome is biased, such as when using nonparametric or high-dimensional regression, the doubly robust estimator weights the model residuals by inverse propensity weights in order to remove the bias. Its convergence and consistency for treatment effect estimation require only that one of the two estimators is consistent. In principle, any consistent function approximator could be used, which in the context of neural networks has led to several adaptations of loss functions and architectures. For example, Shi et al. (2019) adapted the architecture of Johansson et al. (2016) for this purpose by introducing targeted regularization, and Nie et al. (2020) proposed varying coefficient networks in the context of continuously-valued dosage parameters. In both cases, however, the authors provide guarantees for population average treatment effect estimation, in contrast with conditional average treatment effect estimation.
Despite the generality of these results, no guarantees and no theoretically motivated loss functions exist for learning representations for counterfactual estimation in the general setting of multiple treatment types and/or continuous treatment values or dosages. The challenge in the context of representation learning is that there is no notion of treatment group, as each individual gets assigned a potentially different and unique treatment value. Lack of overlap in finite samples, and the resulting large estimation variance for counterfactual predictions, are exacerbated in this setting to the point that adjustments for distributional differences are, in principle, not applicable. In particular, the intuition for reducing variance by regularization deviates from previous proposals, which regularize representations of covariates to match distributions among groups with different treatment types (Shalit et al., 2017), as a potentially infinite set of counterfactuals for each individual must be considered. Even the analysis of multiple categorical treatments is currently an open question: while pairwise comparisons between treatment-specific distributions could be implemented in principle, it is not computationally tractable to do so in practice. At this moment, only heuristic neural network architectures for this problem have been proposed, including Dose Response networks that consist of multi-task layers for dosage sub-intervals defined on top of a common representation (Schwab et al., 2020), variants of generative adversarial networks (Bica et al., 2020), and varying coefficient networks (Nie et al., 2020). In this paper, we investigate the design of representation learning-based algorithms for predicting (conditional average) treatment effects in the context of multiple treatments and continuous dosage parameters. Our analysis starts by extending definitions of loss and generalization error to this broader setting, over all possible treatment-dosage pairs.
We then show, using the definition of integral probability metrics, that the generalization error can be bounded by a quantity that is computable from data and that involves the factual error and a term that quantifies the statistical dependence between the pair of treatment-dosage random variables and observed confounders. In principle, any treatment space on which we can define a probability measure is consistently accounted for, which gives well-defined bounds on the generalization error for treatments with multiple types and continuous values. In particular, our bound includes as a special case existing guarantees in the binary treatment case (Shalit et al., 2017). This bound suggests new training objectives for learning representations conducive to counterfactual estimation. Moreover, such objectives are tractable: they avoid both combinatorial numbers of pairwise comparisons and the binning of dosage values into different sub-intervals. A further contribution is an extensive numerical comparison of methods driven by bounds on the generalization error (which typically target conditional average treatment effects) and methods driven by doubly-robust guarantees (which typically target average treatment effects). We do so independently of the adopted neural architecture, which provides the first analysis of different objectives for the problem of treatment effect estimation with multiple, continuously-valued treatments. We hope these results can give some insight into the trade-offs of different approaches to this problem and demonstrate the ability of representation learning techniques to tackle wider ranging scenarios within treatment effect estimation.

2. BACKGROUND

We start by introducing the notation and definitions used throughout the paper. In particular, we use capital letters for random variables ($X$), small letters for their values ($x$), bold letters for sets of variables ($\mathbf{X}$) and their values ($\mathbf{x}$), and $\Omega$ for the spaces where they are defined ($\Omega_X$) if not explicitly stated. To simplify notation, we consistently use the shorthand $P(x)$ to represent probabilities or densities $P(X = x)$ and similarly $P(y \mid x)$ to represent $P(Y = y \mid X = x)$. For three sets of variables $\mathbf{X}, \mathbf{Y}, \mathbf{Z}$, the conditional independence statement "$\mathbf{X}$ is conditionally independent of $\mathbf{Y}$ given $\mathbf{Z} = \mathbf{z}$" is written as $\mathbf{X} \perp\!\!\!\perp \mathbf{Y} \mid \mathbf{Z}$.

We use the semantics of the Rubin-Neyman potential outcomes framework; see e.g. Section 2 in Rubin (2005). We assume that for an individual with observed covariates $x \in \Omega_X$ and tuple $T = (W, S)$ defining the treatment type out of $k$ distinct treatments $W \in \Omega_W = \{w_1, \dots, w_k\}$ and dosage parameter $S \in \Omega_S = \mathbb{R}$, there is a corresponding potential outcome $Y_t$ that would have been observed had the assigned treatment been $T = t$. With observational data, only one of these potential outcomes is observed for each unit, depending on the treatment assignment. We refer to the unobserved potential outcomes as counterfactuals. Let $Y$ denote the observed outcome and $\Omega_T = \Omega_W \times \Omega_S$ the space of treatment tuples.

Assumption 1 (Unconfoundedness). $Y_T \perp\!\!\!\perp T \mid X$.

Assumption 2 (Overlap). For any $x \in \Omega_X$ such that $P(x) > 0$, we have $1 > P(t \mid x) > 0$ for each $t \in \Omega_T$.

Assumption 3 (Consistency). The observed outcome is the potential outcome, as a function of treatment, when the treatment is set to the observed exposure, i.e. $Y = Y_t$ if $T = t$ for any $t \in \Omega_T$.

Proposition 1 (Identifiability). Under Assumptions 1 and 3, and for any $t_1, t_2 \in \Omega_T$,
$$\mathbb{E}[Y_{t_1} - Y_{t_2} \mid x] = \mathbb{E}[Y \mid x, t_1] - \mathbb{E}[Y \mid x, t_2],$$
which is composed entirely of observable quantities and can be estimated from data given Assumption 2.
We refer to the quantity $\mathbb{E}[Y_{t_1} - Y_{t_2} \mid x]$ as the conditional (or individual, if the conditioning set identifies a unit) treatment effect (CATE), and to $\mathbb{E}[Y_{t_1} - Y_{t_2}]$ as the average, or population, treatment effect (ATE). Our results rely on defining representation functions $\phi : \Omega_X \to \Omega_R$, where $\Omega_R$ is the representation space, that preserve unconfoundedness, overlap, and the identifiability of the treatment effect. For this purpose, it is sufficient to assume $\phi$ to be injective.

Corollary 1 (Identifiability given representation). Under the assumption that the representation $\phi$ is injective, $1 > P(t \mid \phi(x)) > 0$ and $Y_T \perp\!\!\!\perp T \mid \phi(X)$; that is, unconfoundedness and overlap hold conditional on features $\phi(x)$.

Without loss of generality we assume that $\Omega_R$ is the image of $\Omega_X$ under $\phi$. We also write $P$ to denote the distribution induced by $\phi$ over $\Omega_R$, and let $h : \Omega_R \times \Omega_T \to \Omega_Y$ be a prediction function defined over $\Omega_R$. Next, we define two complementary loss functions: one is the standard machine learning loss, which we call the factual error, on the estimation at the observed treatment type and dosage tuple; the other is the counterfactual error, an average error over all other treatment assignment options, made for an individual with a particular treatment type and dosage tuple.

Definition 1. For a given loss function $L : \Omega_Y \times \Omega_Y \to \mathbb{R}_+$, the expected factual and counterfactual losses of $h$ and $\phi$ at treatment $t \in \Omega_T$ are defined as
$$L_F(t) = \int_{\Omega_X} \int_{\Omega_Y} L(y_t, h(\phi(x), t))\, P(y_t \mid x)\, P(x \mid t)\, dx\, dy_t,$$
$$L_{CF}(t) = \mathbb{E}_{t' \sim P} \int_{\Omega_X} \int_{\Omega_Y} L(y_t, h(\phi(x), t))\, P(y_t \mid x)\, P(x \mid t')\, dx\, dy_t. \tag{3}$$

The counterfactual error defines the average error made for the counterfactual prediction at treatment tuple $t = (w, s)$ on all individuals that are observed to be assigned a different treatment $t' \neq t$. This definition extends the binary treatment case to assess the quality of counterfactual predictions at $t \in \Omega_T$.
Similarly, we define an average measure of factual and counterfactual performance over all possible treatment options $t \in \Omega_T$.

Definition 2. The average factual and counterfactual errors over all treatment options are defined as
$$L_F = \int_{\Omega_T} L_F(t)\, P(t)\, dt, \qquad L_{CF} = \int_{\Omega_T} L_{CF}(t)\, P(t)\, dt.$$

Next we define the error made on the estimation of a counterfactual contrast for a given pair of treatments, instead of an average over all counterfactual treatment options.

Definition 3. Let the treatment effect between two different treatment tuples $t_1, t_2 \in \Omega_T$ be given by $\tau_{(t_1,t_2)}(x) = \mathbb{E}[Y_{t_1} \mid x] - \mathbb{E}[Y_{t_2} \mid x]$. The error in treatment effect estimation is then defined as
$$L_{(t_1,t_2)} := \int_{\Omega_X} L\big(\tau_{(t_1,t_2)}(x), \hat{\tau}_{(t_1,t_2)}(x)\big)\, P(x)\, dx,$$
where $\hat{\tau}_{(t_1,t_2)} : \Omega_X \to \Omega_Y$ denotes its estimate.

3. REPRESENTATION LEARNING FOR COUNTERFACTUAL ESTIMATION

As is apparent, in the presence of multiple treatments and continuously-valued dosages there is no notion of treatment group, as each individual gets assigned a potentially different and unique treatment value. The intuition for reducing variance by regularization deviates from previous proposals (Shalit et al., 2017), as a potentially infinite set of counterfactuals for each individual must be considered. The following theorem shows that the average counterfactual error defined in Def. 2 can be bounded by terms that are explicitly computable from observational data.

Theorem 1 (Bound on average counterfactual generalization error). Under the assumption that $\phi$ is injective, it holds that
$$L_{CF} \le L_F + \lambda \cdot \sup_{g \in \Omega_g} \left| \int_{\Omega_T} \int_{\Omega_R} g(r, t) \cdot \big(P(r)P(t) - P(r, t)\big)\, dr\, dt \right|. \tag{6}$$
Here $\Omega_g$ is a space of functions $g : \Omega_R \times \Omega_T \to \mathbb{R}$ expressive enough to include $\frac{1}{\lambda} \int_{\Omega_Y} L(y_t, h(\phi(x), t))\, P(y_t \mid x)\, dy_t$ as a function of $\phi(x)$ and $t$, where $\lambda > 0$ depends on the choice of representation function $\phi$.

This theorem states that the average counterfactual error is upper-bounded by the factual error plus a term that quantifies the dependence between the treatment tuple $T$ and the covariates $X$. As the treatment tuple contains multiple treatment types $w$ as well as continuous dosages $s$, this single bound is valid for multiple treatment values as well as continuous dosages.

Bias-variance tradeoff. By Prop. 1, minimizing the factual loss $L_F$ alone would yield unbiased counterfactual estimates, but the variance in the estimation of counterfactuals for treatment-dosage pairs that are not heavily represented in observational data will be high. This contributes to larger generalization error and is captured by the supremum in the second term of Eq. (6). In particular, the supremum quantifies an imbalance in the association of $T$ and $R$ by using distributional distances between joint distributions and the product of marginals.
$|P(r)P(t) - P(r, t)|$ is large if not all treatment and feature combinations are evenly represented in the data. This observation recovers an interesting intuition if $\Omega_g$ is chosen to be expressive enough, namely that
$$\sup_{g \in \Omega_g} \left| \int_{\Omega_T} \int_{\Omega_R} g(r, t) \cdot \big(P(r)P(t) - P(r, t)\big)\, dr\, dt \right| = 0$$
if and only if the representation is independent of treatment assignment, i.e. $\phi(X) \perp\!\!\!\perp T$. This extreme case leads to lower variance, as counterfactuals for a treatment-dosage tuple have the same effective sample size as that of the observational data. The hyperparameter $\lambda$ controls the tradeoff between the bias and the variance of the counterfactuals. Two choices for $\Omega_g$ we consider are the space of functions in a universal Reproducing Kernel Hilbert Space (RKHS) with characteristic kernels (Sriperumbudur et al., 2011), which recovers the well-known Hilbert-Schmidt Independence Criterion (Gretton et al., 2007), and the space of Lipschitz functions with Lipschitz constant bounded by 1, which recovers the Wasserstein distance (Villani, 2009).

Binary treatment case. One insight from Thm. 1 is that bias in the treatment assignment in the context of general treatment choices, such as multiple treatment types or continuously-valued treatments, takes the form of high statistical dependence between random variables, which is more general than differences between distributions. In particular, differences in distributions between treatment groups, as defined by Shalit et al. (2017) in the binary treatment case, can be formulated as statistical dependence between random variables. The following corollary recovers the generalization bound of Shalit et al. (2017) as a special case.

Corollary 2. Let $\Omega_T = \{0, 1\}$. Then, by Thm. 1,
$$L_{CF} \le L_F + \lambda \cdot \sup_{g \in \Omega_g} \left| \int_{\Omega_R} g(r) \cdot \big(P(r \mid T = 1) - P(r \mid T = 0)\big)\, dr \right|,$$
which is equivalent to (Shalit et al., 2017, Lemma 1).
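As an illustration of the dependence term in Thm. 1 under the RKHS choice of $\Omega_g$, the following minimal sketch computes a biased empirical HSIC statistic, $\mathrm{trace}(KHLH)/n^2$, between paired samples of representations and treatments. The Gaussian kernel, the fixed bandwidth, and all function names are assumptions made for illustration only and are not taken from the implementation used in this paper.

```python
import math
import random


def gaussian_gram(points, bandwidth=1.0):
    """Gram matrix with entries exp(-||p_i - p_j||^2 / (2 * bandwidth^2))."""
    gram = []
    for p in points:
        row = []
        for q in points:
            sq_dist = sum((a - b) ** 2 for a, b in zip(p, q))
            row.append(math.exp(-sq_dist / (2.0 * bandwidth ** 2)))
        gram.append(row)
    return gram


def double_center(gram):
    """Return H K H, where H = I - (1/n) 1 1^T removes row and column means."""
    n = len(gram)
    row_means = [sum(row) / n for row in gram]
    col_means = [sum(gram[i][j] for i in range(n)) / n for j in range(n)]
    grand_mean = sum(row_means) / n
    return [[gram[i][j] - row_means[i] - col_means[j] + grand_mean
             for j in range(n)] for i in range(n)]


def hsic(representations, treatments, bandwidth=1.0):
    """Biased empirical HSIC, trace(K H L H) / n^2, for paired samples."""
    n = len(representations)
    k_centered = double_center(gaussian_gram(representations, bandwidth))
    l_gram = gaussian_gram(treatments, bandwidth)
    # trace(K H L H) = trace((H K H) L) = sum_ij (HKH)_ij * L_ji.
    trace = sum(k_centered[i][j] * l_gram[j][i]
                for i in range(n) for j in range(n))
    return trace / n ** 2


if __name__ == "__main__":
    rng = random.Random(0)
    reps = [(rng.uniform(-1.0, 1.0),) for _ in range(50)]
    dependent_t = [(r[0],) for r in reps]  # treatment fully determined by r
    independent_t = [(rng.uniform(-1.0, 1.0),) for _ in reps]
    print("dependent:", hsic(reps, dependent_t))
    print("independent:", hsic(reps, independent_t))
```

In this toy setting the statistic is expected to be markedly larger when the treatment is a function of the representation than when the two are drawn independently, and it vanishes (up to rounding) when the treatment is constant, in line with the independence characterisation discussed above.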
We next show a similar result that gives generalization bounds for the treatment effect comparing two specific treatment options, instead of an average over all possible counterfactual options, which may be of interest in applications that specifically compare two treatment options.

Theorem 2. Let $t_1, t_2 \in \Omega_T$ be two treatment tuples to be compared. Then,
$$\frac{L_{(t_1,t_2)}}{2} \le L_F(t_1) + \sup_{g \in \Omega_g} \left| \int_{\Omega_R} g(r) \cdot \big(P(r) - P(r \mid T = t_1)\big)\, dr \right| + L_F(t_2) + \sup_{g \in \Omega_g} \left| \int_{\Omega_R} g(r) \cdot \big(P(r) - P(r \mid T = t_2)\big)\, dr \right| - \sigma_{Y_{t_1}} - \sigma_{Y_{t_2}},$$
where $\sigma_{Y_{t_1}}$ and $\sigma_{Y_{t_2}}$ stand for the variances of the random variables $Y_{t_1}$ and $Y_{t_2}$, respectively, under the distribution $P(x)$.

3.1. ARCHITECTURES AND ALGORITHMS FOR COUNTERFACTUAL ESTIMATION

This section discusses the architectures of the representation and prediction functions used, as well as training objectives to leverage the generalization bound in Thm. 1. The training objective that we define can be used with any neural network architecture that parameterizes a representation function $\phi_\eta : \Omega_X \to \Omega_R$ and a separate prediction function $h_\theta : \Omega_R \times \Omega_T \to \Omega_Y$ with sets of parameters $\eta$ and $\theta$ respectively. Following the discussion in Sec. 3, we learn a representation $\phi$ and prediction function $h$ to minimize a trade-off between predictive accuracy and imbalance in the representation space using the following objective:
$$\min_{\theta, \eta} \sum_{n=1}^{N} \Big( y^{(n)} - h_\theta\big(\phi_\eta(x^{(n)}), t^{(n)}\big) \Big)^2 + \gamma \cdot \mathrm{IPM}_{\Omega_g}(\phi(X), T), \tag{9}$$
where $\gamma \ge 0$ is a hyperparameter, $N$ is the number of samples, and
$$\mathrm{IPM}_{\Omega_g}(\phi(X), T) := \sup_{g \in \Omega_g} \left| \int_{\Omega_T} \int_{\Omega_R} g(r, t) \cdot \big(P(r)P(t) - P(r, t)\big)\, dr\, dt \right|$$
is the integral probability metric for a chosen space of functions $\Omega_g$. Concretely, we wish to increase the predictive accuracy while making the representation as independent of the treatment as possible. We consider the Hilbert-Schmidt Independence Criterion (HSIC) and the Wasserstein distance as choices for the integral probability metric. In practice, the HSIC can be approximated with a finite data sample using (Gretton et al., 2007, Eq. (3)). For the Wasserstein distance, we simulate a sample with joint distribution $P(r)P(t)$ by randomly permuting the observed treatment-dosage pair across individuals to generate a sample $\{(r^{(n)}, t^{(\sigma(n))}) : n = 1, \dots, N\}$, where $\sigma : \{1, \dots, N\} \to \{1, \dots, N\}$ is a bijective function. The original data $\{(r^{(n)}, t^{(n)}) : n = 1, \dots, N\}$ is drawn from the distribution $P(r, t)$. The two empirical distributions are compared using the arguments in (Cuturi & Doucet, 2014). Both these regularization terms are differentiable and all parameters can be updated using stochastic gradient descent. Each treatment type $w$ corresponds to a separate prediction network head, i.e.
$h_\theta := \{h^{(w)}_\theta\}_{w \in \Omega_W}$, while the representation layer is common across all treatment types. In particular, this implies that each sample $(x^{(n)}, w^{(n)}, s^{(n)}, y^{(n)})$ is used to update only the prediction network $h^{(w^{(n)})}_\theta$ corresponding to the observed treatment $w^{(n)}$, while all samples are used to update the representation layer $\phi_\eta$. A sketch of this training routine is given in Fig. 1. The following network architectures for the prediction functions $\{h^{(w)}\}_{w \in \Omega_W}$ have been proposed in the literature.

Dose Response Networks (DRNet). Schwab et al. (2020) propose Dose Response Networks for predicting the effect of dosage on an outcome of interest. The architecture takes the form of a multitask network with a shared set of layers and multiple task-specific heads. In this context, the range of dosage values is split into separate bins and each of them is associated with a separate head. Each task-specific network in addition takes the dosage value as input, but crucially the parameterization of the prediction function is common to all dosages belonging to the same sub-interval. For example, the range of dosage values for treatment type $w$ could be divided into 5 sub-intervals, thus using 5 task-specific heads $h^{(w)}_\theta = (h^{(w,1)}_\theta, \dots, h^{(w,5)}_\theta)$, with $h^{(w,i)}_\theta : \Omega_R \times \Omega_T \to \Omega_Y$, $i = 1, \dots, 5$. To some extent this approach accounts for the heterogeneity in the dose-response function, but it remains limited by the binning choice and may be vulnerable to abrupt changes in the prediction at the dosage values that separate two bins, as demonstrated by Nie et al. (2020).

Varying Coefficient networks (VCNet). Varying Coefficient networks (Nie et al., 2020) are proposed for dose-response estimation, but a multi-task architecture can be designed as a special case. In particular, the authors define the parameters $\theta$ for each prediction network $h^{(w)}_\theta := h^{(w)}_{\theta(s)} : \Omega_R \to \Omega_Y$ to be functions $\theta(s) = (\theta_1(s), \dots, \theta_{d_\theta}(s))$ of dosage themselves, where $d_\theta$ is the total number of parameters. Each scalar parameter $\theta_i : \Omega_S \to \mathbb{R}$ is given by a linear combination $\theta_i(s) = \sum_{l=1}^{L} \alpha_{i,l}\, \psi_l(s)$ of polynomial basis functions $\{\psi_l\}_{l=1}^{L}$ defined on the space of dosage values $\Omega_S$. The coefficients $\{\alpha_{i,l} : i = 1, \dots, d_\theta,\ l = 1, \dots, L\}$ define the trainable parameters, and the map $h^{(w)}_\theta(r, s) := h^{(w)}_{\theta(s)}(r, s)$ is differentiable with respect to them. For example, DRNets are recovered by choosing $\{\psi_l\}_{l=1}^{L}$ to be a piecewise-constant spline basis of the form $\mathbb{1}(s_i \le s < s_j)$ with different knots $s_i, s_j$. More general choices can be made, such as B-splines, that lead to continuous dose-response curves. Because the dosage enters through the network parameters rather than as an additional covariate, dosage information is not lost in high-dimensional representations, which in practice has been shown to lead to better counterfactual prediction performance.
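To make the varying-coefficient construction concrete, the sketch below evaluates $\theta_i(s) = \sum_l \alpha_{i,l}\, \psi_l(s)$ for a piecewise-constant basis (the binning behaviour of DRNet) and for a polynomial basis, and plugs the result into a single linear head. The linear head, the knot placement, and all function names are hypothetical simplifications chosen for illustration; they are not the architecture of Nie et al. (2020).

```python
def indicator_basis(knots):
    """Piecewise-constant basis: psi_l(s) = 1 if knots[l] <= s < knots[l + 1], else 0."""
    def basis(s):
        return [1.0 if lo <= s < hi else 0.0
                for lo, hi in zip(knots[:-1], knots[1:])]
    return basis


def polynomial_basis(degree):
    """Polynomial basis: psi_l(s) = s ** l for l = 0, ..., degree."""
    def basis(s):
        return [s ** l for l in range(degree + 1)]
    return basis


def varying_coefficients(alpha, basis, s):
    """theta_i(s) = sum_l alpha[i][l] * psi_l(s) for every scalar parameter i."""
    psi = basis(s)
    return [sum(a * p for a, p in zip(row, psi)) for row in alpha]


def linear_head(alpha, basis, r, s):
    """Toy prediction head h(r, s) = <theta(s), r> with dosage-dependent weights."""
    theta = varying_coefficients(alpha, basis, s)
    return sum(w * x for w, x in zip(theta, r))


if __name__ == "__main__":
    # Two dosage bins [0, 0.5) and [0.5, 1.0): each column of alpha acts as
    # the weight vector of one bin-specific head.
    bins = indicator_basis([0.0, 0.5, 1.0])
    alpha = [[1.0, 3.0], [2.0, 4.0]]
    representation = [1.0, 1.0]
    for dose in (0.2, 0.7):
        print(dose, linear_head(alpha, bins, representation, dose))
```

With the indicator basis the prediction jumps at the knot between bins, whereas a smooth basis such as `polynomial_basis` makes the weights, and hence the prediction, vary continuously with dosage.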

4. EXPERIMENTS

This section conducts controlled experiments on synthetic and semi-synthetic datasets previously used in the literature. Overall, we found that simulation results support our generalization guarantees with different architectures benefiting from the proposed regularization strategy using both the HSIC and Wasserstein distances.

4.1. BASELINES AND METRICS

We consider several baselines for comparison, including different neural network architectures without regularization and with doubly-robust regularization techniques. In particular, we consider a standard multilayer perceptron (MLP) that optimises the (factual) squared error loss objective to learn the weights of the network, a standard VCNet (Nie et al., 2020), and DRNet (Schwab et al., 2020). In the context of doubly-robust optimization, Shi et al. (2019); Nie et al. (2020) propose to learn a joint representation $\phi(x)$ that is conducive to both counterfactual estimation, $h_1 : \Omega_R \times \Omega_T \to \Omega_Y$, and propensity score estimation, $h_2 : \Omega_R \to \Omega_T$, by using a loss function that trades off the two objectives, e.g.,
$$\frac{1}{N} \sum_{n=1}^{N} \Big( y^{(n)} - h_1\big(\phi(x^{(n)}), t^{(n)}\big) \Big)^2 + \alpha \cdot \mathrm{CrossEntropy}\Big( h_2\big(\phi(x^{(n)})\big), t^{(n)} \Big). \tag{10}$$
If $h_1$ and $h_2$ are consistent estimators of the outcome and propensity scores respectively, and satisfy the non-parametric estimating equation
$$\frac{1}{N} \sum_{n=1}^{N} \mu\big(y^{(n)}, t^{(n)}, x^{(n)}; \hat{h}_1, \hat{h}_2, \hat{\epsilon}\big) = 0, \tag{11}$$
where $\epsilon$ denotes a perturbation term that is optimized and where (in the binary treatment case for simplicity)
$$\mu(y, t, x; h_1, h_2, \epsilon) = h_1(x, 1) - h_1(x, 0) + \left( \frac{t}{h_2(x)} - \frac{1 - t}{1 - h_2(x)} \right) \cdot \big(y - h_1(x, t)\big) - \epsilon,$$
then the resulting estimator will have desirable asymptotic properties for average treatment effect estimation (Shi et al., 2019; Kennedy, 2016). We consider $h_1$ parameterized by both VCNets and DRNets. Algorithms trained to minimize Eq. (10) are denoted VCNet-PS and DRNet-PS, and algorithms trained to minimize both Eqs. (10) and (11) (also known as Targeted Regularization) are referred to as VCNet-TR and DRNet-TR. Finally, we consider Generalized Propensity Scores (GPS) (Imbens, 2000; Imai & Van Dyk, 2004), which fit a linear model using inverse propensity scores. Our proposed methods are labeled DRNet-HSIC, DRNet-Wass, VCNet-HSIC, and VCNet-Wass, which combine existing architectures with the proposed regularization methods.
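For the binary-treatment score $\mu$ in this subsection, once fitted values of $h_1$ and $h_2$ are held fixed, the perturbation $\epsilon$ that solves the estimating equation is simply the sample mean of the uncentred scores. The sketch below illustrates this under that simplifying assumption (it does not reproduce the joint optimisation used in targeted regularization); the function names and the list-based interface are hypothetical.

```python
def aipw_score(y, t, h1_treated, h1_control, propensity):
    """Uncentred score h1(x,1) - h1(x,0) + (t/h2 - (1-t)/(1-h2)) * (y - h1(x,t))."""
    if not 0.0 < propensity < 1.0:
        raise ValueError("propensity must lie strictly between 0 and 1")
    # Outcome-model prediction at the treatment that was actually observed.
    h1_observed = h1_treated if t == 1 else h1_control
    weight = t / propensity - (1 - t) / (1.0 - propensity)
    return h1_treated - h1_control + weight * (y - h1_observed)


def solve_epsilon(ys, ts, h1_treated, h1_control, propensities):
    """Value of epsilon for which the average of mu over the sample equals zero."""
    scores = [aipw_score(y, t, m1, m0, e)
              for y, t, m1, m0, e in zip(ys, ts, h1_treated, h1_control, propensities)]
    return sum(scores) / len(scores)


if __name__ == "__main__":
    # Toy data: binary treatment, a deliberately poor outcome model (all zeros)
    # and a correct constant propensity of 0.5.
    ys = [3.0, 3.0, 1.0, 1.0]
    ts = [1, 1, 0, 0]
    zeros = [0.0] * 4
    print(solve_epsilon(ys, ts, zeros, zeros, [0.5] * 4))
```

The toy example uses a wrong outcome model together with a correct propensity, which is the situation in which the inverse-propensity-weighted residual term is meant to correct the outcome model's bias.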
We include details on network architectures, hyperparameter optimisation and computational time in Appendix C. For performance comparisons, we consider the Mean Integrated Squared Error (MISE),
$$\mathrm{MISE} = \frac{1}{N} \frac{1}{k} \sum_{n=1}^{N} \sum_{w \in \Omega_W} \mathbb{E}\left[ \Big( y^{(n)}_{(w,s)} - \hat{y}^{(n)}_{(w,s)} \Big)^2 \right],$$
where we use the notation $y^{(n)}_{(w,s)}$ and $\hat{y}^{(n)}_{(w,s)}$ for the true and predicted outcome for individual $n$ given treatment-dosage pairs $(w, s) \in \Omega_T$, and the expectation is taken with respect to the dosage parameter, i.e. $\mathbb{E}\big[y^{(n)}_{(w,s)}\big] = \int_{\Omega_S} y^{(n)}_{(w,s)}\, P(s)\, ds$. Intuitively, MISE measures how well an algorithm estimates the individual-level dose response and thus accounts for the heterogeneity in treatment response. In contrast, the Average Mean Squared Error (AMSE) evaluates population average counterfactual prediction by taking sums and integrals before comparisons between predicted and true outcomes. We define and evaluate AMSE in Appendix D.
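A minimal sketch of how the MISE above could be approximated numerically is given below. It assumes, purely for illustration, that the dosage distribution $P(s)$ is uniform on $[0, 1]$ and approximates the integral with a midpoint rule; the callables `true_outcome` and `predicted_outcome` and the grid size are hypothetical and not part of the evaluation code of this paper.

```python
import math


def mise(true_outcome, predicted_outcome, covariates, treatments, n_grid=100):
    """Midpoint-rule approximation of MISE, assuming P(s) is uniform on [0, 1].

    true_outcome(x, w, s) and predicted_outcome(x, w, s) return the true and
    predicted outcome of the individual with covariates x under treatment
    type w at dosage s.
    """
    midpoints = [(g + 0.5) / n_grid for g in range(n_grid)]
    total = 0.0
    for x in covariates:
        for w in treatments:
            # Integrated squared error over dosage for this individual and treatment.
            squared_errors = [(true_outcome(x, w, s) - predicted_outcome(x, w, s)) ** 2
                              for s in midpoints]
            total += sum(squared_errors) / n_grid
    return total / (len(covariates) * len(treatments))


if __name__ == "__main__":
    def truth(x, w, s):
        return x + w + s

    def dosage_blind_model(x, w, s):
        # A model that ignores the dosage entirely.
        return x + w

    value = mise(truth, dosage_blind_model, covariates=[0.0, 1.0], treatments=[0, 1])
    print("MISE:", value, "root MISE:", math.sqrt(value))
```

Because the squared error is integrated over dosage separately for each individual and treatment type before averaging, a model that ignores dosage is penalised even if its population-average predictions happen to be accurate.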

4.2. DATASETS

The nature of the treatment-effect estimation problem does not allow for meaningful evaluation on real-world datasets, simply because we never observe a counterfactual for a given unit. There are, however, established synthetic and semi-synthetic datasets that have been used by Schwab et al. (2020); Bica et al. (2020); Nie et al. (2020). Following these proposals, we use:

• Fully synthetic. A data generating mechanism with a total of 6 randomly generated covariates and a single treatment with dosage ranging from 0 to 1, which involves complex functions for both treatment assignment and outcome, as defined by Nie et al. (2020).

• IHDP-continuous. The original semi-synthetic IHDP dataset from Hill (2011) contains binary treatments with 747 observations on 25 covariates. We adapt this dataset to the continuous dosage context by changing the treatment assignment and outcome function. We generate these in a similar way to Nie et al. (2020).

• News. The News dataset consists of 3000 randomly sampled news items from the NY Times corpus (Newman, 2008), which was originally introduced as a benchmark in the binary treatment setting. We generate a continuously-valued treatment and corresponding outcome in a similar way to Bica et al. (2020).

In each of our experiments we generate 50 independent realizations from each of the above datasets (20 for News), with samples split into train/validation/test sets with ratios 0.6/0.2/0.2. Further details on the data generating mechanisms, as well as on network architectures, hyperparameter tuning and training times, are provided in Appendix C.

4.3. EFFECTIVENESS OF REGULARISATION

Our first experiment tests the effectiveness of the proposed regulariser by evaluating counterfactual prediction performance as a function of $\gamma$, which determines the influence of the independence constraint in feature space in Eq. (9). We consider both DRNet and VCNet architectures, with both HSIC and Wasserstein regularizers, on the Synthetic and IHDP-continuous datasets. Fig. 2 compares MISE for these models with varying values of $\gamma$ relative to $\gamma = 0$ (without regularisation). Both datasets include confounding factors which induce bias or imbalance in the treatment assignment $T$ for different covariate subgroups $X$. On both plots we observe that training with the proposed regularization term ($\gamma > 0$), which explicitly corrects for this imbalance, confers an advantage over $\gamma = 0$ for the purpose of predicting counterfactuals. The gain for some $\gamma > 0$ is consistent across different neural network architectures and across different datasets, which illustrates our generalization guarantees and also shows that some form of regularisation may be broadly applicable in practice.

4.4. PERFORMANCE COMPARISONS

In this section we conduct a wide-ranging comparison against the benchmark prediction algorithms using the three data generating mechanisms described in Section 4.2. Table 1 reports average values and standard deviations of $\sqrt{\mathrm{MISE}}$ over 50 (20) realizations of the Synthetic and IHDP-continuous (News) datasets. On average, the proposed regularization technique, using either the HSIC or Wasserstein distances between distributions, outperforms all other regularization techniques on both choices of neural network architecture. Several trends are worth discussing in more detail. Existing representation learning algorithms that optimise doubly-robust objectives are not always optimal. The results show that, in terms of the MISE, our regularisation based on counterfactual generalisation outperforms doubly robust methods. This can be explained by the fact that doubly robust methods have guarantees when estimating average treatment effects, and not individual or conditional treatment effects. The proposed regularization techniques, with guarantees for the counterfactual generalization error, are instead designed for good performance in conditional average treatment effect estimation and often perform substantially better in terms of MISE. We believe that this discrepancy is due to the doubly robust methods discarding information that helps predict the individual outcome, resulting in worse MISE performance. This also emphasizes the fact that estimating average counterfactuals and individual counterfactuals can require different objectives. Indeed, the cross-entropy term in Eq. (10) encourages the representation to retain information that is predictive of the treatment; hence, it encourages the discarding of information that is predictive of the outcome but not of the treatment, which is simply noise when predicting the treatment.
On average there is also a significant gain from considering more expressive neural network architectures; for instance, DRNet outperforms MLP and VCNet outperforms DRNet on all metrics and data generating mechanisms. Finally, we note that GPS requires a matrix inversion, which was not feasible to compute on the high-dimensional News dataset.

5. CONCLUSION

In this paper, we investigate the task of estimating the conditional average causal effect of dosage from a combination of observational data and assumptions on the causal relationships in the underlying system. When these assumptions hold, we give new bounds on the counterfactual generalization error in the context of multiple treatment types and continuously-valued dosage parameters that subsume generalization guarantees from the binary treatment case. Using this result, we provide new learning objectives that can be used to guide the training of representation learning algorithms. We show empirically new state-of-the-art performance results across several benchmark datasets for this problem. To our knowledge, this is the first paper exploring representation learning and regularization for conditional average counterfactual estimation in the context of multiple, continuous-valued treatments in a principled manner. We hope these results can demonstrate the ability of representation learning techniques to tackle wider ranging scenarios within treatment effect estimation.

A RELATED WORK ON DOUBLY ROBUST ESTIMATION OF THE AVERAGE TREATMENT EFFECT

Thm. 1 suggests that the imbalance in the distribution of $X$ across treatment-dosage pairs is relevant for the expected generalization error of fitted models. Estimators inspired by the semi-parametric literature, known as doubly robust estimators (Van Der Laan & Rubin, 2006; Chernozhukov et al., 2017), instead try to optimize average treatment effects (ATE), e.g. $E_X E[Y_1 \mid x] - E_X E[Y_0 \mid x]$, by constructing a prediction function $h_1 : \Omega_X \times \Omega_T \to \Omega_Y$, a propensity score function $h_2 : \Omega_X \to \Omega_T$, and a perturbation term $\epsilon$, satisfying the non-parametric estimating equation,

$$\frac{1}{N} \sum_{n=1}^{N} \mu\big(y^{(n)}, t^{(n)}, x^{(n)}; \hat h_1, \hat h_2, \hat\epsilon\big) = 0, \tag{14}$$

where (in the binary treatment case for simplicity),

$$\mu(y, t, x; h_1, h_2, \epsilon) = h_1(x, 1) - h_1(x, 0) + \left(\frac{t}{h_2(x)} - \frac{1 - t}{1 - h_2(x)}\right) \cdot \big(y - h_1(x, t)\big) - \epsilon. \tag{15}$$

$h_1(x, t)$ is an estimator of $E[Y_t \mid x]$, while $h_2(x)$ is an estimator of the probability of treatment $P(t \mid x)$ and $\epsilon \in \Omega_T$ is a perturbation term that is optimized. In the literature, a common estimation approach is to rely on (task-agnostic) fitted models $\hat h_1$ and $\hat h_2$, and then choose $\epsilon$ so that this equation is satisfied. If $h_1$ and $h_2$ are consistent estimators of the outcome and propensity scores respectively, and satisfy Eq. (14), the resulting estimator of the ATE will have desirable asymptotic properties (Shi et al., 2019; Kennedy, 2016). However, as these guarantees are on the average treatment effects, they do not necessarily guarantee accurate estimates of conditional treatment effects. In the context of neural networks, Shi et al. (2019) and Nie et al. (2020) propose to learn a joint representation $\phi(x)$ that is conducive to both counterfactual estimation, $h_1 : \Omega_R \times \Omega_T \to \Omega_Y$, and propensity score estimation, $h_2 : \Omega_R \to \Omega_T$, by using a loss function that trades off the two objectives, e.g.,

$$\frac{1}{N} \sum_{n=1}^{N} \Big(y^{(n)} - h_1\big(\phi(x^{(n)}), t^{(n)}\big)\Big)^2 + \alpha \cdot \mathrm{CrossEntropy}\Big(h_2\big(\phi(x^{(n)})\big), t^{(n)}\Big), \tag{16}$$

as in (Nie et al., 2020, Eq. (1)) or (Shi et al., 2019, Eq. (2.2)).
The motivation is that: "If the average treatment effect is identifiable conditioning on the propensity score [...] it suffices to adjust for only the information in x that is relevant for predicting the treatment", see (Shi et al., 2019, Theorem 2.1). Intuitively, the cross-entropy term in Eq. (16) encourages the representation to retain information that is predictive of the treatment. Hence, it encourages the discarding of information that is predictive of the outcome but not the treatment, which is simply noise when predicting the treatment. Variables that affect the outcome and not the treatment are referred to as effect modifiers in the literature, see e.g. (Hernán & Robins, 2010). By definition, the treatment effect varies across different conditioning sets of these effect modifiers. As effect modifiers are responsible for the heterogeneity of treatment effects, it is necessary to condition on them to obtain accurate conditional treatment effects. Thus, to compute conditional average or "individualized" treatment effects, such representations may be too restrictive because they tend to ignore effect modifiers. In contrast, our regularizer explicitly penalizes the dependence between the representation and the treatment distributions. Loosely speaking, we discard covariate information predictive of the treatment, but information predictive of the outcome is retained. Hence, our regularizer should preserve these effect modifiers, leading to more accurate estimates of conditional treatment effects. We conclude that, in general, optimal average treatment effects do not necessarily imply optimal conditional average treatment effects as measured by the expected losses in Definitions 1 and 2.$^2$ We verify this intuition in our experiments.
$^2$ Definitions 1 and 2 also involve averages, but they make a head-to-head comparison between observed outcomes and predicted outcomes for each individual $(x, t)$ in the term $\int_{\Omega_Y} L(y_t, h(\phi(x), t)) P(y_t \mid x)\, dy_t$ (which is then averaged across individuals), instead of averaging predicted counterfactuals across the whole population before comparison with average true outcomes across different dosage levels.
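To make the role of the perturbation term concrete, the following minimal sketch solves the estimating equation of this appendix for $\epsilon$ in the binary treatment case, given already fitted models. The names `solve_epsilon`, `h1` and `h2` are illustrative assumptions and do not come from the original implementation; since $\mu$ is linear in $\epsilon$, the root is simply the sample mean of the remaining terms.

```python
from typing import Callable, Sequence


def solve_epsilon(
    y: Sequence[float],
    t: Sequence[int],
    x: Sequence[float],
    h1: Callable[[float, int], float],
    h2: Callable[[float], float],
) -> float:
    """Return the epsilon that sets the empirical mean of mu to zero.

    Binary-treatment case:
        mu = h1(x, 1) - h1(x, 0)
             + (t / h2(x) - (1 - t) / (1 - h2(x))) * (y - h1(x, t)) - epsilon.
    mu is linear in epsilon, so the root is the sample mean of the other
    terms. Hypothetical helper, for illustration only.
    """
    total = 0.0
    for y_n, t_n, x_n in zip(y, t, x):
        p = h2(x_n)  # estimated probability of treatment, assumed in (0, 1)
        weight = t_n / p - (1 - t_n) / (1 - p)
        total += h1(x_n, 1) - h1(x_n, 0) + weight * (y_n - h1(x_n, t_n))
    return total / len(y)


# Toy usage with made-up models: outcome y = x + 2t, constant propensity 0.5.
y = [0.1, 2.3, 0.4, 2.6]
t = [0, 1, 0, 1]
x = [0.1, 0.3, 0.4, 0.6]
eps = solve_epsilon(y, t, x, h1=lambda x_n, t_n: x_n + 2 * t_n, h2=lambda x_n: 0.5)
```

In this toy example the outcome model fits the data exactly, so the correction term vanishes and the solution reduces to the average of $h_1(x, 1) - h_1(x, 0)$. The value is a population-level quantity, which illustrates why guarantees on it say little about individual-level predictions.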

B PROOFS

Theorem 1 (Generalization bound for the average counterfactual error). Under the assumption that $\phi$ is one-to-one, it holds that,

$$L_{CF} \le L_F + \lambda \cdot \sup_{g \in \Omega_g} \left| \int_{\Omega_T} \int_{\Omega_R} g(r, t) \cdot \big(P(r)P(t) - P(r, t)\big)\, dr\, dt \right|. \tag{17}$$

$\Omega_g$ defines a space of functions $g : \Omega_R \times \Omega_T \to \mathbb{R}$ expressive enough to include $\int_{\Omega_Y} L(y_t, h(\phi(x), t)) P(y_t \mid x)\, dy_t / \lambda$ as a function of $\phi(x)$ and $t$, where $\lambda > 0$ depends on the choice of representation function $\phi$.

Proof. Let $\psi : \Omega_R \to \Omega_X$ be the inverse of $\phi$ and let $l_{h,\phi}(x, t) := \int_{\Omega_Y} L(y_t, h(\phi(x), t)) P(y_t \mid x)\, dy_t$. The following derivations show the claim.

$$\begin{aligned}
L_{CF} - L_F &= \int_{\Omega_T} \int_{\Omega_X} l_{h,\phi}(x, t) P(x) P(t)\, dx\, dt - \int_{\Omega_T} \int_{\Omega_X} l_{h,\phi}(x, t) P(x \mid t) P(t)\, dx\, dt \\
&= \int_{\Omega_T} \int_{\Omega_X} l_{h,\phi}(x, t) \cdot \big(P(x)P(t) - P(x, t)\big)\, dx\, dt \\
&= \int_{\Omega_T} \int_{\Omega_R} l_{h,\phi}(\psi(r), t) \cdot \big(P(\psi(r))P(t) - P(\psi(r), t)\big)\, d\psi(r)\, dt \\
&= \int_{\Omega_T} \int_{\Omega_R} l_{h,\phi}(\psi(r), t) \cdot \big(P(r)P(t) - P(r, t)\big) J_\psi J_\psi^{-1}\, dr\, dt \\
&\le \lambda \cdot \sup_{g \in \Omega_g} \left| \int_{\Omega_T} \int_{\Omega_R} g(r, t) \cdot \big(P(r)P(t) - P(r, t)\big)\, dr\, dt \right|.
\end{aligned}$$

For the third equality, the distribution $P$ over $\Omega_R \times \Omega_T$ can be obtained by the standard change of variables formula, using the determinant of the Jacobian of $\psi(r)$, denoted $J_\psi$, giving $P(\psi(r), t) = P(r, t) J_\psi$ (which cancels with the inverse Jacobian that appears after the change of variables in the differential term). The last inequality comes from the assumption that $l_{h,\phi}(x, t)/\lambda \in \Omega_g$, which is justified and extensively discussed in Shalit et al. (2017).

To prove Thm. 2, we will use the following lemma.

\begin{align}
\cdots\; &= 2\big(L_{CF}(t_1) - \sigma_{Y_{t_1}}\big) + 2\big(L_{CF}(t_2) - \sigma_{Y_{t_2}}\big) \tag{27} \\
&\le 2\left(L_F(t_1) + \sup_{g \in \Omega_g} \left| \int_{\Omega_R} g(r) \cdot \big(P(r) - P(r \mid T = t_1)\big)\, dr \right| - \sigma_{Y_{t_1}}\right) \tag{28} \\
&\quad + 2\left(L_F(t_2) + \sup_{g \in \Omega_g} \left| \int_{\Omega_R} g(r) \cdot \big(P(r) - P(r \mid T = t_2)\big)\, dr \right| - \sigma_{Y_{t_2}}\right). \nonumber
\end{align}

The first inequality holds by the fact that $(a + b)^2 \le 2a^2 + 2b^2$ for any $a, b \in \mathbb{R}$. The second equality holds by Lemma 2 and the last inequality holds by the same arguments used in Theorem 1.
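For completeness, the elementary inequality invoked for the first inequality follows from expanding a square. The short derivation below is a standard argument, added here as a reading aid rather than taken from the original proof.

```latex
2a^2 + 2b^2 - (a + b)^2 = a^2 - 2ab + b^2 = (a - b)^2 \ge 0
\quad \Longrightarrow \quad
(a + b)^2 \le 2a^2 + 2b^2 \qquad \text{for all } a, b \in \mathbb{R}.
```

Because the bound holds for every pair of real numbers, it can be applied pointwise to an integrand and then integrated against a non-negative density without changing the direction of the inequality.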

C EXPERIMENTAL DETAILS

C.1 DATA GENERATING MECHANISMS

This section describes the data generating mechanisms used in our experiments.

Synthetic. We generate synthetic data similar to Nie et al. (2020). With covariates $x \in \mathbb{R}^6$ all drawn from a uniform distribution between 0 and 1, we generate the continuous dosages and outcomes as follows,

\begin{align}
\tilde s \mid x &= \frac{10 \sin(\max(x_1, x_2, x_3)) + \max(x_3, x_4, x_5)^3}{1 + (x_1 + x_5)^2} + \sin(0.5 x_3)\big(1 + \exp(x_4 - 0.5 x_3)\big) \tag{30} \\
&\quad + x_3^2 + 2 \sin(x_4) + 2 x_5 - 6.5 + N(0, 0.25), \nonumber \\
y \mid x, s &= \cos(2\pi(s - 0.5)) \left(s^2 + \frac{4 \max(x_1, x_6)^3}{1 + 2 x_3^2} \sin(x_4)\right) + N(0, 0.25), \nonumber
\end{align}

where $s = (1 + \exp(-\tilde s))^{-1}$.

IHDP Continuous. The IHDP dataset contains 25 covariates with binary treatments and continuous outcomes (Hill, 2011). Disregarding the original treatments and outcomes, we use the covariates to generate new continuous dosages and outcomes to test our method. We follow the data generating procedure of Nie et al. (2020), namely:

\begin{align}
\tilde s \mid x &= \frac{2 x_1}{1 + x_2} + \frac{2 \max(x_3, x_5, x_6)}{0.2 + \min(x_3, x_5, x_6)} + 2 \tanh\left(5\, \frac{\sum_{i \in I} x_i - c_2}{|I|}\right) - 4 + N(0, 0.25), \tag{32} \\
y \mid x, s &= \frac{\sin(3\pi s)}{1.2 - s} \left(\tanh\left(5\, \frac{\sum_{i \in J} x_i - c_1}{|J|}\right) + \exp\left(\frac{0.2 (x_1 - x_6)}{0.5 + \min(x_2, x_3, x_5)}\right)\right) + N(0, 0.25), \tag{33}
\end{align}

$$c_1 = \mathbb{E}_{p(x)}\left[\frac{\sum_{i \in J} x_i}{|J|}\right], \qquad c_2 = \mathbb{E}_{p(x)}\left[\frac{\sum_{i \in I} x_i}{|I|}\right],$$

where $s = (1 + \exp(-\tilde s))^{-1}$, $I = \{16, 17, 18, 19, 20, 21, 22, 23, 24, 25\}$, and $J = \{4, 7, 8, 9, 10, 11, 12, 13, 14, 15\}$.

News. This dataset contains words sampled from 5000 news articles (Newman, 2008). The covariates are word counts. We generate continuous dosages and outcomes by following the data generation method described in Bica et al. (2020). We first sample three vectors $v'_i \sim N(0, 1)$, with $v_i = v'_i / \|v'_i\|_2$ for $i = 1, 2, 3$. Then, dosages are drawn from a Beta distribution:

$$s \sim \mathrm{Beta}(2, \beta), \qquad \beta = \max\left(1, \left|\frac{2 x^\top v_2}{x^\top v_1}\right|\right).$$

Finally, outcomes are sampled according to:

\begin{align}
y' &= \exp\left(\left|\frac{x^\top v_2}{x^\top v_1}\right| - 0.3\right), \tag{37} \\
y &= 2\left(\max(-2, \min(2, y')) + 20\, x^\top v_3 * \big(4 (s - 0.5)^2\big) * \sin\left(\frac{\pi s}{2}\right)\right) + N(0, 0.25). \nonumber
\end{align}
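As a reading aid, the Synthetic mechanism described above can be sketched in a few lines of standard-library Python. This is a minimal sketch under stated assumptions, not the authors' code: the function name `sample_synthetic` is hypothetical, and $N(0, 0.25)$ is read as a variance of 0.25 (standard deviation 0.5), which the text does not specify.

```python
import math
import random


def sample_synthetic(rng: random.Random):
    """Draw one (x, s, y) triple from the Synthetic mechanism.

    Assumption: N(0, 0.25) denotes variance 0.25, i.e. standard
    deviation 0.5; the text does not say which convention is meant.
    """
    x = [rng.random() for _ in range(6)]  # x_1..x_6 ~ Uniform(0, 1)
    x1, x2, x3, x4, x5, x6 = x
    noise_sd = 0.5

    # Pre-sigmoid dosage.
    s_raw = (
        (10 * math.sin(max(x1, x2, x3)) + max(x3, x4, x5) ** 3)
        / (1 + (x1 + x5) ** 2)
        + math.sin(0.5 * x3) * (1 + math.exp(x4 - 0.5 * x3))
        + x3 ** 2 + 2 * math.sin(x4) + 2 * x5 - 6.5
        + rng.gauss(0.0, noise_sd)
    )
    s = 1.0 / (1.0 + math.exp(-s_raw))  # squash the dosage into (0, 1)

    # Outcome as a function of covariates and the squashed dosage.
    y = (
        math.cos(2 * math.pi * (s - 0.5))
        * (s ** 2 + 4 * max(x1, x6) ** 3 / (1 + 2 * x3 ** 2) * math.sin(x4))
        + rng.gauss(0.0, noise_sd)
    )
    return x, s, y


rng = random.Random(0)
data = [sample_synthetic(rng) for _ in range(100)]
```

Because the dosage depends on the same covariates that drive the outcome, samples drawn this way exhibit the covariate imbalance across dosage levels that the regularizers are meant to address.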

C.2 ARCHITECTURES AND TRAINING DETAILS

In both VCNet and DRNet, the representation part of the network $\phi_\eta$ and the prediction heads $h_\theta$ have two layers each, with 50 hidden units and ReLU activations. Following Nie et al. (2020), we use B-splines of degree 2 with knots placed at $\{1/3, 2/3\}$ for VCNet and 5 regression heads for DRNet. For the MLP model, we use a 4-layer network so that its approximation power is similar, to ensure a fair comparison. We optimise the networks using Adam (Kingma & Ba, 2014) with a weight decay of 0.005 for regularisation and a batch size of 1000. The learning rate is chosen within the set $\{0.01, 0.005, 0.001, 0.0005, 0.0001, 0.00005\}$ using the procedure outlined in C.3. Each dataset is split into train/validation/test sets with ratios 0.6/0.2/0.2. To avoid overfitting, we stop training if the validation loss has not improved after 50 epochs.

Propensity score regularization (-PS methods). In addition to the representation net $\phi : \Omega_X \to \Omega_R$ and the prediction net $h_1 : \Omega_R \times \Omega_T \to \Omega_Y$, propensity score regularized methods also include a separate head $h_2 : \Omega_R \to \Omega_T$. Parameters are tuned by minimizing the loss

$$L_{PS}(\phi, h_1, h_2) = \frac{1}{N} \sum_{n=1}^{N} \Big(y^{(n)} - h_1\big(\phi(x^{(n)}), t^{(n)}\big)\Big)^2 + \alpha \cdot \mathrm{CrossEntropy}\Big(h_2\big(\phi(x^{(n)})\big), t^{(n)}\Big). \tag{39}$$

In our experiments, $h_2$ is modelled through a softmax layer over a grid of 10 bins. Average treatment effects are then estimated by considering an additional perturbation term following Shi et al. (2019) and Nie et al. (2020). $\alpha$ is treated as a hyperparameter and chosen within the set $\{0.5, 1\}$ using the procedure detailed in C.3. The implementation follows the publicly available code of Nie et al. (2020).

Targeted regularization (-TR methods). Methods labeled -TR use the functional targeted regularization approach presented in Nie et al. (2020), which optimizes the loss function

$$L_{TR}(\phi, h_1, h_2, \epsilon_N) = L_{PS}(\phi, h_1, h_2) + \frac{\beta}{N} \sum_{n=1}^{N} \left(y^{(n)} - h_1\big(\phi(x^{(n)}), t^{(n)}\big) - \frac{\epsilon_N(t^{(n)})}{h_2\big(\phi(x^{(n)})\big)}\right)^2,$$

where $\epsilon_N(\cdot) = \sum_{j=1}^{J_N} a_j \psi_j(\cdot)$ is modelled through $J_N$ spline basis functions $\psi_j$ of degree 2. The number of basis functions may change with the sample size $N$. Following Nie et al. (2020), we select the learning rate for $\epsilon_N(\cdot)$, $\beta$, and the number of spline knots within the sets $\{0.001, 0.0001\}$, $\{20, 10, 5\}/\sqrt{N}$ and $\{5, 10, 20\}$, respectively. Again, the implementation follows the publicly available code of Nie et al. (2020).

IPM regularization (-HSIC and -Wass methods). IPM regularized methods minimize the proposed loss in Equation (13) in the main body of this paper, where $\gamma$ is selected within the set $\{10^{i/6},\ i = -18, -17, \ldots, 9, 10\}$ using the procedure in C.3. The implementation of the Wasserstein distance regulariser follows the one available at https://github.com/clinicalml/cfrnet/blob/master/cfr/util.py#L166. The HSIC regulariser is computed according to (Greenfeld & Shalit, 2020, Eq. 3), using two RBF kernels with length-scales $\{0.05, 0.1, 0.5, 1, 5, 10, 50, 100, 500\}$.

Generalised Propensity Score (GPS). We use our own implementation following Hirano & Imbens (2004).
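To illustrate the kind of dependence penalty used by the -HSIC methods, the sketch below computes the standard biased empirical HSIC estimate, $\mathrm{tr}(KHLH)/(n-1)^2$, between a batch of representations and a batch of treatments. It is a plain-Python sketch for small batches, not the implementation used in the experiments. Summing the RBF kernels over the listed length-scales, and the function names `rbf_gram`, `center` and `hsic`, are assumptions made for illustration, and the estimator is not guaranteed to match the cited equation term by term.

```python
import math
from typing import List, Sequence

LENGTH_SCALES = (0.05, 0.1, 0.5, 1, 5, 10, 50, 100, 500)


def rbf_gram(points: Sequence[Sequence[float]], scales=LENGTH_SCALES) -> List[List[float]]:
    """Gram matrix of a sum of RBF kernels (summing over scales is an assumption)."""
    n = len(points)
    gram = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            sq = sum((a - b) ** 2 for a, b in zip(points[i], points[j]))
            gram[i][j] = sum(math.exp(-sq / (2 * ls ** 2)) for ls in scales)
    return gram


def center(gram: List[List[float]]) -> List[List[float]]:
    """Return H K H with H = I - 11^T / n (double centering)."""
    n = len(gram)
    row = [sum(r) / n for r in gram]
    col = [sum(gram[i][j] for i in range(n)) / n for j in range(n)]
    total = sum(row) / n
    return [[gram[i][j] - row[i] - col[j] + total for j in range(n)] for i in range(n)]


def hsic(r: Sequence[Sequence[float]], t: Sequence[Sequence[float]]) -> float:
    """Biased empirical HSIC: trace(K H L H) / (n - 1)^2."""
    n = len(r)
    k_centered = center(rbf_gram(r))
    l_gram = rbf_gram(t)
    # trace(K H L H) = trace((H K H) L) = sum_ij (HKH)[i][j] * L[j][i]
    trace = sum(k_centered[i][j] * l_gram[j][i] for i in range(n) for j in range(n))
    return trace / (n - 1) ** 2


# Toy check: a representation that equals the treatment is strongly dependent,
# while a constant treatment carries no dependence at all.
points = [[i / 10] for i in range(10)]
constant = [[0.3] for _ in range(10)]
dependent_score = hsic(points, points)
independent_score = hsic(points, constant)
```

In a training loop, a term of this form, multiplied by $\gamma$, would be added to the factual prediction loss; a differentiable array library would replace the nested loops in practice.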

C.3 HYPER-PARAMETERS TUNING

We use grid search to tune the hyper-parameters. Namely, we generate a dataset for each hyper-parameter setting, randomly split it into train/test sets with a ratio of 0.8/0.2, and choose the hyper-parameter values giving the best MISE score on the test set.
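The selection procedure can be sketched as a plain loop over candidate settings. In the sketch below, the learning-rate and $\gamma$ grids are the ones quoted in C.2; the joint search over both, the function name `grid_search`, and the `evaluate` callback (which would generate a dataset, train on the 0.8 split and return the test MISE) are assumptions made for illustration.

```python
import itertools

# Grids quoted in Appendix C.2; gamma spans 10^(i/6) for i = -18, ..., 10.
LEARNING_RATES = [0.01, 0.005, 0.001, 0.0005, 0.0001, 0.00005]
GAMMAS = [10 ** (i / 6) for i in range(-18, 11)]


def grid_search(evaluate, learning_rates=LEARNING_RATES, gammas=GAMMAS):
    """Return (best_score, best_lr, best_gamma) under a user-supplied scorer.

    `evaluate(lr, gamma)` is a hypothetical callback returning the MISE on
    the held-out test split; lower is better.
    """
    best = None
    for lr, gamma in itertools.product(learning_rates, gammas):
        score = evaluate(lr, gamma)
        if best is None or score < best[0]:
            best = (score, lr, gamma)
    return best


# Toy scorer standing in for a real training run (purely illustrative).
best = grid_search(lambda lr, gamma: (lr - 0.001) ** 2 + (gamma - 1.0) ** 2)
```

With the toy scorer, the search returns the grid point closest to its minimum; in practice each call to `evaluate` would be a full training run, so the grid size directly drives the tuning cost.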

C.4 RUN TIME COMPARISONS

Table 2 reports the computational time (in seconds) required by the algorithms compared in the experimental section for 2000 training epochs. These results are machine- and algorithm-specific but do serve as a relative comparison of run times for different neural network architectures and regularization techniques. In general, Wasserstein IPM regularisation is more computationally efficient than both TR and HSIC-based IPM regularisation.

where $y^{(n)}(w, s)$ and $\hat y^{(n)}(w, s)$ stand for the true and predicted outcome for individual $n$ given treatment-dosage pairs $(w, s) \in \Omega_T$, and $\mathbb{E} g(s) = \int_{\Omega_S} g(s) P(s)\, ds$. The AMSE measures the accuracy of the population-level dose response. As the doubly robust methods discard effect modifiers, which are useful for accurate predictions, but have theoretical guarantees for the average treatment effects, we expect these methods to achieve a better AMSE. On the other hand, as the regularizers proposed in this work provide guarantees on the counterfactual error, we expect models trained with them to achieve a better MISE score. In Table 3 we show mean performance for both MISE (as in the main body of this paper) and AMSE, with the objective of contrasting the proposed methods, designed for optimal conditional average counterfactual prediction, i.e. MISE, with doubly-robust methods, designed for population average counterfactual prediction, i.e. AMSE. Overall, we note that the proposed regularization technique, using either the HSIC or Wasserstein distances between distributions, is competitive across all datasets and metrics, including AMSE. Across all datasets, there is a clear trend showing that regularizing for optimal generalization performance in terms of the MISE with HSIC leads to good population average performance as well, as measured by AMSE. Doubly-robust methods (-PS, -TR) are designed for optimality in estimation of the AMSE, and across all datasets they are either optimal or competitive compared to all other algorithms.
It is interesting to note that the performance achieved by the proposed regularisation techniques (HSIC, Wasserstein) is very close to the optimum AMSE while, in contrast, doubly robust methods often perform significantly worse in terms of MISE than the optima achieved by HSIC and Wasserstein regularisation. This further confirms the intuition presented in Appendix A: estimating average counterfactuals and individual counterfactuals can require different objectives.
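The difference between the two metrics can be illustrated with a small sketch. It assumes, as a simplification, that MISE averages squared errors over individuals and over a dosage grid, that AMSE compares population-averaged curves on the same grid, and that a uniform grid stands in for $P(s)$; the function names and the toy numbers are hypothetical and are not experimental results.

```python
from typing import Sequence


def mise(y_true: Sequence[Sequence[float]], y_pred: Sequence[Sequence[float]]) -> float:
    """Mean over individuals and dosage grid points of the squared error."""
    n, k = len(y_true), len(y_true[0])
    return sum(
        (y_true[i][j] - y_pred[i][j]) ** 2 for i in range(n) for j in range(k)
    ) / (n * k)


def amse(y_true: Sequence[Sequence[float]], y_pred: Sequence[Sequence[float]]) -> float:
    """Squared error between population-average curves, averaged over the grid."""
    n, k = len(y_true), len(y_true[0])
    total = 0.0
    for j in range(k):
        mean_true = sum(row[j] for row in y_true) / n
        mean_pred = sum(row[j] for row in y_pred) / n
        total += (mean_true - mean_pred) ** 2
    return total / k


# Two individuals (rows), two dosage levels (columns). The predictor ignores
# individual heterogeneity and returns the population curve for everyone.
y_true = [[1.0, 2.0], [3.0, 4.0]]
y_pred = [[2.0, 3.0], [2.0, 3.0]]
individual_error = mise(y_true, y_pred)
population_error = amse(y_true, y_pred)
```

Here the predictor is perfect at the population level while every individual prediction is off by one unit, which mirrors the observation that a good AMSE does not imply a good MISE.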



The remark has been made that injectivity of the representation is difficult to enforce (Zhang et al., 2020; Johansson et al., 2019). An algorithmic solution, discussed by Zhang et al. (2020), is to include a decoder from the representation to the input domain and a reconstruction loss in the training objective to encourage solutions with invertible latent representations. A reconstruction loss and encoder-decoder architecture can be included on top of the regularization terms proposed in this paper.



Figure 1: Sketch of the architecture.

Figure 2: Out-of-sample MISE error versus IPM regularization, relative to the error at $\gamma = 0$ (no regularisation), on 50 realizations of Synthetic (a) and IHDP-continuous (b) datasets. Average values (dot markers) and one standard deviation (shaded areas) are shown.

$\{(w, s) : w \in \Omega_W, s \in \mathbb{R}\}$ denote the set of all treatment options. The goal is to derive estimates of the expected potential outcomes for a given set of input covariates: $E[Y_t \mid x]$, for any value of $t$ and $x$. Under the following standard assumptions (Rubin, 2005), it is well understood that the treatment effect between two selected treatment options $t_1$ and $t_2$ reduces to a contrast of conditional distributions, presented in Prop. 1 below.

Assumption 1 (Unconfoundedness). The treatment assignment, $t = (w, s) \in \Omega_T$, and potential outcomes, $Y_t$, are conditionally independent given the covariates $x$, i.e. $Y_T \perp\!\!\!\perp T \mid X$.

Table 1: Average values and standard deviations (within brackets) of $\sqrt{\text{MISE}}$ across 50 (20) realizations of Synthetic, IHDP-continuous (News) datasets. Bold notation highlights the best-performing algorithm on each dataset.

Lemma 2. For convenience, we write $m(t, x) := E[Y_t \mid x]$ and we define its estimate given a prediction function $f : \Omega_T \times \Omega_X \to \Omega_Y$ by $f(t, x)$. If $L$ is the square loss, it then holds that, […]

The third term in the third equality evaluates to zero because $m(t, x) := \int_{\Omega_Y} y_t P(y_t \mid x)\, dy_t$, and we have defined the variance of $Y_t$ with respect to the distribution $P(x)$ as $\sigma_{Y_t} := \int (y_t - m(t, x))^2 P(y_t \mid x) P(x)\, dy_t\, dx$.

Theorem 2 (Generalization bound for selected treatment tuples $t_1$ and $t_2$). Let $t_1, t_2 \in \Omega_T$ be two treatment tuples to be compared. Then,

\begin{align}
L_{(t_1, t_2)}/2 \le{} & L_F(t_1) + \sup_{g \in \Omega_g} \left| \int_{\Omega_R} g(r) \cdot \big(P(r) - P(r \mid T = t_1)\big)\, dr \right| \nonumber \\
& + L_F(t_2) + \sup_{g \in \Omega_g} \left| \int_{\Omega_R} g(r) \cdot \big(P(r) - P(r \mid T = t_2)\big)\, dr \right| - \sigma_{Y_{t_1}} - \sigma_{Y_{t_2}}, \tag{24}
\end{align}

where $\sigma_{Y_{t_1}}$ and $\sigma_{Y_{t_2}}$ stand for the variance of the random variables $Y_{t_1}$ and $Y_{t_2}$, respectively, under the distribution $P(x)$.

\begin{align}
& \cdots\, \big(f(t_1, x) - f(t_2, x) - m(t_1, x) + m(t_2, x)\big)^2 P(x)\, dx \nonumber \\
& \cdots\, \big(f(t_2, x) - m(t_2, x)\big)^2 P(x)\, dx \tag{26}
\end{align}

Table 2: Computational times (in seconds) required for 2000 epochs of training. Averages and standard deviations (within brackets) over 10 runs (5 for News dataset) are reported.

Table 3: Average values and standard deviations (within brackets) of $\sqrt{\text{MISE}}$ and $\sqrt{\text{AMSE}}$ across 50 (20) realizations of Synthetic, IHDP-continuous (News) datasets. Bold notation highlights the best-performing algorithm on each dataset.

D FURTHER EXPERIMENTS

This section includes experiments comparing performance with respect to Average Mean Squared Error (AMSE), in addition to Mean Integrated Squared Error (MISE),


Moreover, in terms of AMSE, it still holds that neural network architectures with better expressiveness to model heterogeneous dose-response curves perform better on all datasets.

