ADVERSARIAL COUNTERFACTUAL ENVIRONMENT MODEL LEARNING

Abstract

A good model for action-effect prediction, i.e., the environment model, is essential for sample-efficient policy learning, in which the agent can take numerous free trials to find good policies. Currently, the model is commonly learned by fitting historical transition data through empirical risk minimization (ERM). However, we discover that simple data fitting can lead to a model that is totally wrong in guiding policy learning due to the selection bias in offline dataset collection. In this work, we introduce weighted empirical risk minimization (WERM) to handle this problem in model learning. A typical WERM method utilizes inverse propensity scores to re-weight the training data to approximate the target distribution. However, during policy training, the data distributions of the candidate policies can be various and unknown. We therefore propose an adversarial weighted empirical risk minimization (AWRM) objective that learns the model with respect to the worst case of the target distributions. We implement AWRM in a sequential decision structure, resulting in the GALILEO model-learning algorithm. We also discover that GALILEO is closely related to adversarial model learning, explaining the empirical effectiveness of the latter. We apply GALILEO to synthetic tasks and verify that it makes accurate predictions on counterfactual data. Finally, we apply GALILEO to real-world offline policy learning tasks and find that it significantly improves policy performance in real-world testing.

1. INTRODUCTION

A good environment model is important for sample-efficient decision-making policy learning techniques like reinforcement learning (RL) (James & Johns, 2016). The agent can take trials with this model to find better policies, so that costly real-world trial-and-error can be reduced (James & Johns, 2016; Yu et al., 2020) or completely avoided (Shi et al., 2019). In this process, the core task of the model is to answer queries on counterfactual data unbiasedly, that is, given states, to correctly answer what would happen if we were to carry out actions unseen in the training data (Levine et al., 2020). The requirement of answering counterfactual queries makes environment model learning essentially different from standard supervised learning (SL), which directly fits the offline dataset. In real-world applications, the offline data is often collected with selection bias, that is, for each state, each action might be chosen unfairly. In the example in Fig. 1(a), to keep the ball following a target line, the behavior policy uses a smaller force when the ball's location is closer to the target line. When a dataset is collected with such selection bias, the association between the (location) states and (force) actions makes it hard for SL to identify the correct causal relationships between the states and actions and the next states. Then, when we query the model with counterfactual data, the predictions might be catastrophic failures: in Fig. 1(c), it mistakenly predicts that smaller forces will increase the ball's next location. Generally speaking, the problem corresponds to the challenge of training the model on one dataset but testing it on another dataset with a shifted distribution (i.e., the dataset generated by counterfactual queries), which is beyond SL's capability as it violates the independent and identically distributed (i.i.d.) assumption.
The problem is widely discussed in causal inference for individual treatment effect (ITE) estimation in many scenarios such as patients' treatment selection (Imbens, 1999; Alaa & van der Schaar, 2018). ITEs are the effects of treatments on individuals, which are measured by treating each individual under a uniform policy and evaluating the effect differences. Practical solutions use weighted empirical risk minimization (WERM) to handle this problem (Jung et al., 2020; Shimodaira, 2000; Hassanpour & Greiner, 2019). In particular, they estimate an inverse propensity score (IPS) to re-weight the training data to approximate the data distribution under a uniform policy. A model is then trained under the re-weighted data distribution. The distribution-shift problem is solved because ITE estimation and model training are under the same distribution. The selection bias can be regarded as an instance of the problem called "distributional shift" in offline model-based RL, which has also received great attention (Levine et al., 2020; Yu et al., 2020; Kidambi et al., 2020; Chen et al., 2021). However, previous methods, in which naive supervised learning is used for environment model learning, ignore the problem in environment model learning and instead handle it by suppressing policy exploration and learning in risky regions.
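To make the IPS idea concrete, the sketch below is a toy confounded-treatment example of our own (a binary treatment with a binary confounder, not the paper's setting or data). A naive comparison of treated and untreated outcomes is biased by the selection mechanism, while re-weighting each sample by 1/P(a|x) recovers the true effect:

```python
import random

rng = random.Random(1)
n = 200_000
xs, ts, ys = [], [], []
for _ in range(n):
    x = rng.random() < 0.5                          # confounder
    p_t = 0.8 if x else 0.2                         # selection bias: x drives the action
    t = rng.random() < p_t
    y = 2.0 * t + 3.0 * x + rng.gauss(0.0, 0.5)     # true treatment effect is +2
    xs.append(x); ts.append(t); ys.append(y)

def mean(v):
    return sum(v) / len(v)

# Naive contrast is confounded: expected ~3.8 here instead of 2.
naive = mean([y for t, y in zip(ts, ys) if t]) - mean([y for t, y in zip(ts, ys) if not t])

# IPS re-weighting: weight each sample by 1 / P(t | x) (self-normalized).
def ips_term(take):
    num = den = 0.0
    for x, t, y in zip(xs, ts, ys):
        if t != take:
            continue
        p = (0.8 if x else 0.2) if take else (0.2 if x else 0.8)
        num += y / p
        den += 1.0 / p
    return num / den

ips = ips_term(True) - ips_term(False)
print(round(naive, 2), round(ips, 2))
```

The same mechanism underlies WERM for model learning: weighting the offline transitions so that the training distribution matches the distribution the model will be queried under.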
Although these methods have made great progress in many tasks, so far, how to learn a better environment model that alleviates the problem for faithful offline policy optimization has rarely been discussed. In this work, for faithful offline policy optimization, we introduce WERM to environment model learning. The extra challenge of model learning for policy optimization is that we have to query the feedback of numerous different policies besides the uniform policy to find a good policy. Thus the target data distributions to re-weight towards can be various and unknown. To solve the problem, we propose an objective called adversarial weighted empirical risk minimization (AWRM). AWRM introduces adversarial policies, whose corresponding counterfactual datasets have the maximal prediction error of the model. In each iteration, the model is learned to achieve prediction risks as small as possible under the adversarial counterfactual dataset. However, the adversarial counterfactual dataset cannot be obtained in the offline setting, so we derive an approximation of the counterfactual data distribution queried by the optimal adversarial policy and use a variational representation to give a tractable solution for learning a model from the approximated data distribution. As a result, we derive a practical approach named Generative Adversarial offLIne counterfactuaL Environment mOdel learning (GALILEO) for AWRM. Fig. 2 shows the difference in the prediction errors of the models learned by these algorithms. We also discover that GALILEO is closely related to existing generative-adversarial model learning techniques, explaining the effectiveness of the latter. Experiments are conducted in two synthetic and two realistic environments. The results in the synthetic environments show that GALILEO can reconstruct correct responses to counterfactual queries. The evaluation results in the two realistic environments also demonstrate that GALILEO significantly improves policy performance in real-world tasks.
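In symbols, the contrast between ERM and AWRM can be written schematically as follows (our own shorthand, which may differ from the paper's formal definition: $\rho_\pi$ is the state-action distribution induced by policy $\pi$, $\pi_\beta$ the behavior policy, $P^*$ the real dynamics, $\Pi$ the set of candidate policies, and $\ell$ a prediction loss):

```latex
% ERM: fit the environment model M under the fixed behavior distribution
\min_{M}\ \mathbb{E}_{(s,a)\sim\rho_{\pi_\beta},\; s'\sim P^*(\cdot\mid s,a)}
  \big[\ell\big(M(s,a),\, s'\big)\big]
% AWRM: fit M under the worst-case distribution over candidate policies
\min_{M}\ \max_{\pi\in\Pi}\
  \mathbb{E}_{(s,a)\sim\rho_{\pi},\; s'\sim P^*(\cdot\mid s,a)}
  \big[\ell\big(M(s,a),\, s'\big)\big]
```

The inner maximization is what cannot be evaluated offline, since $\rho_\pi$ for a counterfactual $\pi$ is never observed; the approximation of this worst-case distribution is what GALILEO provides.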



Figure 1: An example of selection bias and predictions under counterfactual queries. Subfigure (a) shows how the data is collected: a ball located in a 2D plane has position (x_t, y_t) at time t. The ball moves to (x_{t+1}, y_{t+1}) according to x_{t+1} = x_t + 1 and y_{t+1} ∼ N(y_t + a_t, 2). Here, a_t is chosen by a control policy a_t ∼ N((ϕ - y_t)/15, 0.05) parameterized by ϕ, which tries to keep the ball near the line y = ϕ. In Subfigure (a), ϕ is set to 62.5. Subfigure (b) shows the collected training data (grey dashed line) and the two learned models' predictions of the next position of y. All the models discovered the relation that the next y will be smaller with a larger action. However, the truth is not that a larger a_t causes a smaller y_{t+1}, but that the policy selects a small a_t when y_t is close to the target line. When we estimate the response curves by fixing y_t and reassigning the action a_t with other actions a_t + ∆a, where ∆a ∈ [-1, 1] is a variation of the action value, the model of SL exploits the association and gives opposite responses, while for AWRM and its practical implementation GALILEO, the predictions are closer to the ground truths. The result is in Subfigure (c), where the darker a region is, the more samples fall in it.
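The selection-bias effect in Figure 1 can be reproduced in a few lines. The sketch below is our own illustrative script, not the paper's code; it treats the 2 in N(y_t + a_t, 2) as a standard deviation and picks an arbitrary initial position. Rolling out the behavior policy with ϕ = 62.5 yields logged data in which actions are strongly negatively correlated with next positions, even though each unit of force pushes y_{t+1} up by one unit:

```python
import random

def rollout(phi=62.5, steps=20000, seed=0):
    """Roll out the behavior policy a_t ~ N((phi - y_t)/15, 0.05)
    in the dynamics y_{t+1} ~ N(y_t + a_t, 2)."""
    rng = random.Random(seed)
    y, data = 60.0, []
    for _ in range(steps):
        a = rng.gauss((phi - y) / 15.0, 0.05)
        y_next = rng.gauss(y + a, 2.0)
        data.append((y, a, y_next))
        y = y_next
    return data

def pearson(xs, ys):
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys)) / n
    vx = sum((x - mx) ** 2 for x in xs) / n
    vy = sum((y - my) ** 2 for y in ys) / n
    return cov / (vx * vy) ** 0.5

data = rollout()
acts = [a for _, a, _ in data]
nxt = [yn for _, _, yn in data]
# The logged association is negative: large forces co-occur with low next
# positions, because the policy applies large forces only when y is far
# below the target line.
print(pearson(acts, nxt))
```

A model trained by plain ERM on `data` is free to exploit this spurious negative association, which is exactly the failure shown in Subfigure (c).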

Figure 2: An illustration of the prediction error on counterfactual datasets. The prediction risk is measured with mean squared error (MSE). The error of SL is small only on the training data (ϕ = 62.5) but becomes much larger on datasets "far away from" the training data. AWRM-oracle selects the oracle worst counterfactual dataset for training at each iteration (pseudocode is in Alg. 1), which reaches small MSE on all datasets and gives correct response curves (Fig. 1(c)). GALILEO approximates the optimal adversarial counterfactual data distribution based on the training data and the model. Although the MSE of GALILEO is a bit larger than that of SL on the training data, on the counterfactual datasets, the MSE is on the same scale as AWRM-oracle.
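The mechanics of the AWRM-oracle loop can be sketched on the ball task. The following is an illustrative reconstruction under our own simplifications, not the paper's Alg. 1: a linear model class fit in closed form, a small grid of candidate target lines ϕ standing in for the policy class, and reduced dynamics noise so that the demo converges quickly. The oracle repeatedly evaluates the model under every candidate policy, fetches fresh data from the worst one, and refits:

```python
import random

def rollout(phi, n, rng, noise=0.2):
    # Simplified ball dynamics: y' = y + a + eps, behavior a ~ N((phi-y)/15, 0.05).
    y, rows = phi, []
    for _ in range(n):
        a = rng.gauss((phi - y) / 15.0, 0.05)
        y_next = y + a + rng.gauss(0.0, noise)
        rows.append((1.0, y, a, y_next))  # features (bias, y, a) and target y'
        y = y_next
    return rows

def fit(rows):
    # Ordinary least squares via normal equations (3x3 Gauss-Jordan elimination).
    A = [[0.0] * 3 for _ in range(3)]
    b = [0.0] * 3
    for f0, f1, f2, t in rows:
        f = (f0, f1, f2)
        for i in range(3):
            b[i] += f[i] * t
            for j in range(3):
                A[i][j] += f[i] * f[j]
    M = [A[i] + [b[i]] for i in range(3)]
    for c in range(3):
        p = max(range(c, 3), key=lambda r: abs(M[r][c]))
        M[c], M[p] = M[p], M[c]
        for r in range(3):
            if r != c:
                k = M[r][c] / M[c][c]
                M[r] = [u - k * v for u, v in zip(M[r], M[c])]
    return [M[i][3] / M[i][i] for i in range(3)]

def mse(w, rows):
    return sum((w[0] * f0 + w[1] * f1 + w[2] * f2 - t) ** 2
               for f0, f1, f2, t in rows) / len(rows)

rng = random.Random(0)
candidates = [40.0, 50.0, 62.5, 75.0, 85.0]   # candidate target lines phi
train = rollout(62.5, 2000, rng)              # logged data from one behavior policy
w = fit(train)
for _ in range(5):                            # AWRM-oracle loop
    worst = max(candidates, key=lambda phi: mse(w, rollout(phi, 2000, rng)))
    train += rollout(worst, 2000, rng)        # oracle fetches the worst-case dataset
    w = fit(train)
print([round(x, 2) for x in w])               # true coefficients: (0, 1, 1)
```

The oracle needs online access to the environment to generate the worst-case datasets, which is precisely what is unavailable offline; GALILEO replaces these oracle rollouts with an approximation built from the training data and the model.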

