COUNTERFACTUAL SELF-TRAINING

Abstract

Unlike traditional supervised learning, in many settings only partial feedback is available: we observe outcomes only for the chosen actions, not the counterfactual outcomes associated with the other alternatives. Such settings encompass a wide variety of applications, including pricing, online marketing, and precision medicine. A key challenge is that observational data are influenced by the historical policies deployed in the system, yielding a biased data distribution. We approach this task as a domain adaptation problem and propose a self-training algorithm which imputes discrete-valued outcomes for the finitely many unseen actions in the observational data to simulate a randomized trial. We offer a theoretical motivation for this approach by providing an upper bound on the generalization error, defined on a randomized trial, under the self-training objective. We empirically demonstrate the effectiveness of the proposed algorithms on both synthetic and real datasets.

1. INTRODUCTION

Counterfactual inference (Pearl et al., 2000) attempts to address a question central to many applications: what would the outcome have been had an alternative action been chosen? It may be selecting relevant ads to engage users in online marketing (Li et al., 2010), determining prices that maximize profit in revenue management (Bertsimas & Kallus, 2016), or designing the most effective personalized treatment for a patient in precision medicine (Xu et al., 2016). With observational data, we have access to past actions, their outcomes, and possibly some context, but in many cases not the complete knowledge of the historical policy which gave rise to the actions (Shalit et al., 2017). Consider a pricing setting in the form of targeted promotions. We might record a customer's information (context), the promotion offered (action), and whether an item was purchased (outcome), but we do not know why a particular promotion was selected. Unlike traditional supervised learning, we only observe feedback for the chosen action in observational data, not the outcomes associated with other alternatives (i.e., in the pricing example, we do not observe what would have occurred if a different promotion had been offered). In contrast to the gold standard of a randomized controlled trial, observational data are influenced by the historical policy deployed in the system, which may over- or under-represent certain actions, yielding a biased data distribution. A naive but widely used approach is to train a machine learning model directly on observational data and use it for prediction. This is often referred to as the direct method (DM) (Dudík et al., 2014). Failing to account for the bias introduced by the historical policy often results in a model with high accuracy on the data it was trained on that performs considerably worse under a different policy.
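As a concrete illustration, the direct method can be sketched as follows: fit an outcome model on logged (context, action, outcome) triples and then select the action with the highest predicted outcome. All variable names, the logistic outcome model, and the biased logging policy below are illustrative assumptions, not the paper's setup.

```python
# Minimal sketch of the direct method (DM): fit an outcome model on logged
# (context, action, outcome) triples, then predict outcomes for every action.
# The data-generating process here is an illustrative assumption.
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
n, n_actions = 1000, 3
X = rng.normal(size=(n, 2))               # customer context
# Biased logging policy: the action depends on context, so some
# (context, action) regions are over-represented in the logs.
logged_a = (X[:, 0] > 0).astype(int) * 2  # only actions 0 and 2 get logged
y = (rng.random(n) < 1 / (1 + np.exp(-(X[:, 1] - 0.5 * logged_a)))).astype(int)

# DM: one model over (context, one-hot action) pairs, trained on factual data only.
A = np.eye(n_actions)[logged_a]
model = LogisticRegression().fit(np.hstack([X, A]), y)

def dm_policy(x):
    """Choose the action with the highest predicted outcome for context x."""
    feats = np.hstack([np.tile(x, (n_actions, 1)), np.eye(n_actions)])
    return int(model.predict_proba(feats)[:, 1].argmax())
```

Note that action 1 never appears in the logs, so the model's predictions for it are pure extrapolation; this is exactly the failure mode the biased logging policy induces.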
For example, in the pricing setting, if historically most customers who received high promotion offers fit a certain profile, then a model based on the direct method may fail to produce reliable predictions for these customers when low offers are given. To overcome the limitations of the direct method, Shalit et al. (2017); Johansson et al. (2016); Lopez et al. (2020) cast counterfactual learning as a domain adaptation problem, where the source domain is observational data and the target domain is a randomized trial whose assignment of actions follows a uniform distribution for a given context. The key idea is to map contextual features to an embedding space and jointly learn a representation that encourages similarity between these two domains, leading to better counterfactual inference. The embedding is generally learned by a neural network, and the estimation of the domain gap is usually slow to compute.

Figure 1: Illustration of the proposed Counterfactual Self-Training (CST) framework. There are two sales records (observational data) shown in the table, i.e., Customer A was offered $2 and bought an item; Customer B was offered $1 and did not buy. The question marks in the tables represent the counterfactual outcomes which we do not observe. For all these unseen counterfactual outcomes, pseudo-labels (colored in red in the tables) are imputed by a model and are used to augment the observational data. The model is subsequently updated by training on both the imputed counterfactual data and the factual data. This iterative training procedure continues until convergence.

In this paper, while we also view counterfactual inference as a domain adaptation problem between observational data and an ideal randomized trial, we take a different approach: instead of estimating the domain gap between the two distributions via an embedding, we explicitly simulate a randomized trial by imputing pseudo-labels for the unobserved actions in the observational data.
The optimization process iteratively updates the pseudo-labels and a model that is trained on both the factual and the counterfactual data, as illustrated in Figure 1. As this method works in a self-supervised fashion (Zou et al., 2018; Amini & Gallinari, 2002), we refer to our proposed framework as Counterfactual Self-Training (CST). The contributions of our paper are as follows. First, we propose a novel self-training algorithm for counterfactual inference. To the best of our knowledge, this is the first application of a self-training algorithm to learning from observational data. Moreover, in contrast to existing domain adaptation methods for counterfactual inference, CST is flexible and can work with a wide range of machine learning algorithms, not only neural networks. Second, we offer a theoretical motivation for our approach by providing an upper bound on the generalization error defined on a randomized trial under the self-training objective. In other words, we show that the counterfactual self-training algorithm helps minimize the risk on the target domain. Our theoretical bounds suggest generating pseudo-labels with random imputation, a methodological departure from traditional self-training algorithms, which impute hard labels. Third, we present comprehensive experiments on several synthetic datasets and three counterfactual learning datasets converted from multi-label classification tasks to evaluate our method against state-of-the-art baselines. In all experiments, CST shows competitive or superior performance against all baselines. Moreover, our algorithm is easy to optimize, with much faster training than the other baselines.
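The iterative procedure just described can be sketched as follows. The choice of logistic regression as the base model, the binary outcome, the fixed iteration count, and the data format are all illustrative assumptions; pseudo-labels are drawn at random from the model's predicted outcome distribution, in line with the random imputation suggested by the theory, rather than as hard labels.

```python
# Hedged sketch of a counterfactual self-training (CST) loop: impute random
# pseudo-labels for every unobserved (context, action) pair, refit the model
# on factual + imputed data, and repeat. Binary outcomes are assumed.
import numpy as np
from sklearn.linear_model import LogisticRegression

def cst(X, logged_a, y, n_actions, n_iters=5, seed=0):
    rng = np.random.default_rng(seed)
    onehot = np.eye(n_actions)
    # Factual features: (context, one-hot logged action).
    Xf = np.hstack([X, onehot[logged_a]])
    # Counterfactual features: every action for every context...
    Xc = np.vstack([np.hstack([X, np.tile(onehot[a], (len(X), 1))])
                    for a in range(n_actions)])
    # ...keeping only the (context, action) pairs that were NOT logged.
    mask = np.concatenate([logged_a != a for a in range(n_actions)])
    Xc = Xc[mask]

    model = LogisticRegression().fit(Xf, y)  # warm start on factual data only
    for _ in range(n_iters):
        p = model.predict_proba(Xc)[:, 1]
        # Random imputation: sample pseudo-labels from the predicted distribution.
        pseudo = (rng.random(len(Xc)) < p).astype(int)
        model = LogisticRegression().fit(np.vstack([Xf, Xc]),
                                         np.concatenate([y, pseudo]))
    return model
```

Training on the union of factual and imputed data mimics a randomized trial in which every action is represented for every context, which is the simulation the framework aims for.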

2. RELATED WORK

Counterfactual policy optimization has received much attention in the machine learning community in recent years (Swaminathan & Joachims, 2015a; Joachims et al., 2018; Shalit et al., 2017; Lopez et al., 2020; Kallus, 2019; Kallus & Zhou, 2018; Wang et al., 2019). Most of the proposed algorithms fall into two categories: counterfactual risk minimization (CRM) and the direct method (DM). The two can be combined to construct doubly robust estimators (Dudík et al., 2014) to further improve efficiency. CRM, also known as off-policy learning or batch learning from bandit feedback, typically utilizes inverse propensity weighting (IPW) (Rosenbaum, 1987; Rosenbaum & Rubin, 1983) to account for the bias in the data. Swaminathan & Joachims (2015a) introduce the CRM principle with a variance regularization term derived from an empirical Bernstein bound (Maurer & Pontil, 2009) for finite samples. In order to reduce the variance of the IPW
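The IPW estimate that CRM methods build on can be sketched as follows; this assumes the logging propensities are known, and the function and variable names are illustrative.

```python
# Illustrative inverse propensity weighting (IPW) estimate of a target
# policy's value from logged bandit feedback. Reweighting by the inverse
# of the logging propensity corrects for actions the logging policy
# over- or under-represented.
import numpy as np

def ipw_value(outcomes, logged_a, propensities, target_a):
    """V_hat = mean( 1{target action == logged action} * outcome / propensity )."""
    match = (target_a == logged_a).astype(float)
    return float(np.mean(match * outcomes / propensities))
```

The estimator is unbiased when propensities are correct, but its variance grows as propensities approach zero, which motivates the variance regularization discussed above.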




