ADVERSARIAL COUNTERFACTUAL ENVIRONMENT MODEL LEARNING

Abstract

A good model for action-effect prediction, i.e., an environment model, is essential for sample-efficient policy learning, since it allows the agent to take numerous free trials to find good policies. Currently, the model is commonly learned by fitting historical transition data through empirical risk minimization (ERM). However, we discover that simple data fitting can lead to a model that is totally wrong in guiding policy learning, due to the selection bias in offline dataset collection. In this work, we introduce weighted empirical risk minimization (WERM) to handle this problem in model learning. A typical WERM method utilizes inverse propensity scores to re-weight the training data so that it approximates the target distribution. However, during policy training, the data distributions induced by the candidate policies are diverse and unknown. Thus, we propose an adversarial weighted empirical risk minimization (AWRM) objective that learns the model with respect to the worst case over the target distributions. We implement AWRM in a sequential decision structure, resulting in the GALILEO model learning algorithm. We also discover that GALILEO is closely related to adversarial model learning, explaining the empirical effectiveness of the latter. We apply GALILEO to synthetic tasks and verify that GALILEO makes accurate predictions on counterfactual data. We finally apply GALILEO to real-world offline policy learning tasks and find that GALILEO significantly improves policy performance in real-world testing.

1. INTRODUCTION

A good environment model is important for sample-efficient decision-making policy learning techniques like reinforcement learning (RL) (James & Johns, 2016). The agent can take trials with this model to find better policies, so that costly real-world trial-and-error can be reduced (James & Johns, 2016; Yu et al., 2020) or completely waived (Shi et al., 2019). In this process, the core requirement for the model is to answer queries on counterfactual data unbiasedly, that is, given states, to correctly answer what might happen if we were to carry out actions unseen in the training data (Levine et al., 2020). Requiring counterfactual queries makes environment model learning essentially different from standard supervised learning (SL), which directly fits the offline dataset. In real-world applications, the offline data is often collected with selection bias, that is, for each state, each action might be chosen unfairly. Consider the example in Fig. 1(a): to keep the ball following a target line, a behavior policy will use a smaller force when the ball's location is closer to the target line. When a dataset is collected with such selection bias, the association between the (location) states and the (force) actions makes it hard for SL to identify the correct causal effects of the states and the actions on the next states. Then, when we query the model with counterfactual data, the predictions might be catastrophic failures: in Fig. 1(c), the model incorrectly predicts that smaller forces will increase the ball's next location. Generally speaking, the problem corresponds to the challenge of training the model on one dataset but testing it on another dataset with a shifted distribution (i.e., the dataset generated by counterfactual queries), which is beyond SL's capability as it violates the independent and identically distributed (i.i.d.) assumption.
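This failure mode can be reproduced in a few lines. The sketch below is our own toy reconstruction of the ball example, not the paper's code: the true dynamics say that a larger force always moves the ball further (coefficient C_TRUE = 1.0), but because the behavior policy ties the force deterministically to the state, plain least-squares ERM fits the training data almost perfectly while attributing essentially no effect to the force, so counterfactual force queries are answered wrongly. All names (C_TRUE, TARGET, etc.) are ours.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 2000

# True one-step dynamics: next_pos = pos + C_TRUE * force.
# Pushing harder always moves the ball further; a correct model must recover C_TRUE.
C_TRUE = 1.0
TARGET = 1.0

pos = rng.uniform(0.0, 2.0, size=n)
# Behavior policy with selection bias: the force shrinks as the ball
# approaches the target line, so force is (almost) a function of the state.
force = 0.5 * (TARGET - pos)
next_pos = pos + C_TRUE * force + 0.05 * rng.standard_normal(n)

# Plain ERM: (very lightly regularized) least squares on the biased data.
X = np.column_stack([pos, force, np.ones(n)])
theta = np.linalg.solve(X.T @ X + 1e-6 * np.eye(3), X.T @ next_pos)
a_pos, b_force, bias = theta

# The fit looks excellent on the training distribution...
rmse = np.sqrt(np.mean((X @ theta - next_pos) ** 2))

# ...but the learned force effect is near zero instead of C_TRUE = 1.0,
# because force is collinear with the state under the behavior policy.
print(b_force, rmse)

# Counterfactual query: at pos = 1.0 the behavior policy uses force = 0;
# ask the model what happens under an unseen force = 1.0.
pred_cf = np.array([1.0, 1.0, 1.0]) @ theta
true_cf = 1.0 + C_TRUE * 1.0   # the true next position is 2.0
```

The key point is that no amount of extra i.i.d. data from the same behavior policy fixes this: the training distribution simply contains no information that separates the effect of the force from the effect of the state.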
The problem is widely discussed in causal inference for individual treatment effect (ITE) estimation, in scenarios such as treatment selection for patients (Imbens, 1999; Alaa & van der Schaar, 2018). ITEs are the effects of treatments on individuals, measured by treating each individual under a uniform policy and evaluating the effect differences. Practical solutions use weighted empirical risk minimization (WERM) to handle this problem (Jung et al., 2020; Shimodaira, 2000; Hassanpour & Greiner, 2019). In particular, they estimate an inverse propensity score (IPS) to re-weight the training data so that it approximates the data distribution under a uniform policy. A model is then trained under the re-weighted data distribution. The distribution-shift problem is thereby resolved, since ITE estimation and model training are carried out under the same distribution.
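The mechanics of IPS re-weighting can be illustrated with a small sketch. This is our own illustrative construction (not the paper's setup), and it assumes the behavior policy's propensity is known rather than estimated: each sample is weighted by the target policy's action probability over the behavior policy's, after which the data behaves as if actions had been chosen uniformly, and a weighted least-squares model is fit on that re-weighted distribution.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 20000

# State and a biased behavior policy over two actions {0, 1}:
s = rng.uniform(-1.0, 1.0, n)
p_behavior = 1.0 / (1.0 + np.exp(-3.0 * s))        # P(a=1 | s), strongly state-dependent
a = (rng.uniform(size=n) < p_behavior).astype(float)

# IPS weights toward a uniform target policy pi(a|s) = 0.5:
propensity = np.where(a == 1, p_behavior, 1.0 - p_behavior)
w = 0.5 / propensity

# Check: after re-weighting, actions look uniform even conditioned on the state.
mask = s > 0.5                                      # region where behavior almost always picks a=1
frac_biased = a[mask].mean()                        # heavily skewed toward 1
frac_weighted = np.average(a[mask], weights=w[mask])  # roughly 0.5
print(frac_biased, frac_weighted)

# WERM: weighted least squares for a toy outcome model y ~ theta^T [s, a, 1].
y = 2.0 * s + 1.0 * a + 0.1 * rng.standard_normal(n)
X = np.column_stack([s, a, np.ones(n)])
Xw = X * w[:, None]
theta = np.linalg.solve(Xw.T @ X, Xw.T @ y)         # solves X^T W X theta = X^T W y
```

Note the practical caveat this sketch glosses over: when the behavior policy rarely takes some action, the weights 1/propensity blow up and the effective sample size collapses, which is one motivation for considering worst-case target distributions rather than a single fixed one.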

