TOWARDS INTERPRETABLE DEEP REINFORCEMENT LEARNING WITH HUMAN-FRIENDLY PROTOTYPES

Abstract

Despite the recent success of deep learning models in research settings, their application in sensitive domains remains limited because of their opaque decision-making processes. Rising to this challenge, researchers have proposed various eXplainable AI (XAI) techniques designed to calibrate trust in, and improve the understandability of, black-box models, with the vast majority of work focused on supervised learning. Here, we focus on making an "interpretable-by-design" deep reinforcement learning agent which is forced to use human-friendly prototypes in its decisions, thus making its reasoning process clear. Our proposed method, dubbed Prototype-Wrapper Network (PW-Net), wraps around any neural agent backbone, and results indicate that it does not worsen performance relative to black-box models. Most importantly, we found in a user study that PW-Nets supported better trust calibration and task performance relative to standard interpretability approaches and black-boxes.

1. INTRODUCTION

Deep reinforcement learning (RL) models have achieved state-of-the-art results in Go (Silver et al., 2016), Chess (Silver et al., 2017), Atari (Mnih et al., 2015), self-driving cars (Kiran et al., 2021), and robotic control (Kober et al., 2013). However, the usage of these agents in truly sensitive domains is limited due to the opaque nature of such systems. Extracting a deep model's rationale in a human-interpretable format remains a challenging problem, but doing so would be highly useful for troubleshooting an agent's actions and, by extension, its possible failure states (Hayes & Shah, 2017). One popular approach is post-hoc explanation, which gives after-the-fact rationales for model predictions, mostly through some form of saliency map (Bach et al., 2015) or exemplar (Kenny & Keane, 2021). However, whilst popular, these approaches may be incomplete or unsuitable for explanation (Slack et al., 2020; Zhou et al., 2022), and recent work has instead started to focus on pre-hoc interpretability (Rudin, 2019). The core idea behind this latter paradigm is to design inherently explainable models whose decision-making process can be clearly seen and understood, allowing users to calibrate their trust and predict the system's capabilities. In this paper, we present (to the best of our knowledge) the first general, inherently interpretable, well-performing deep RL algorithm that uses an intuitive exemplar-based approach for decision making. Specifically, we train a "wrapper" model called the Prototype-Wrapper Network (PW-Net) that can be added to any pre-trained agent, making it interpretable-by-design and offering the same intuitive reasoning process as popular pre-hoc methods (Li et al., 2018).
Crucially, however, when using PW-Nets the main advantages of post-hoc methods remain intact: the black-box model's performance is not lost, and it does not need to be retrained from scratch, a property we demonstrate across multiple domains that are notoriously difficult for XAI.

2. RELATED WORK

This paper builds upon recent work on prototype-based neural networks for interpretable supervised learning. Such networks are interpretable by design because they utilize prototypes in their forward pass, classifying test instances based upon their proximity to those prototypes and thus allowing users to intuitively understand predictions. Perhaps the first notable example was by Li et al. (2018), who learned prototypes in latent space and classified test instances using their L2 distance to each prototype. This was followed up by Chen et al. (2019), who used image "parts" rather than the whole instance, and helped spawn many follow-up works in NLP (Ming et al., 2019), fairness (Tucker & Shah, 2022), and computer vision tasks (Davoudi & Komeili, 2021; Donnelly et al., 2022). These prior works establish the value of prototype-based neural nets, but they focus on traditional classification tasks; we are similarly inspired by such methods but seek to build interpretable agents in RL settings. Whilst we are not the first to build interpretable RL models (Vouros, 2022; Milani et al., 2022), almost all prior research uses interpretable proxy models (such as trees) to imitate agents in symbolic domains. However, these techniques do not apply to the richer domains with high-dimensional inputs (such as raw pixels) that we focus on in this work. To date, most work in these deep RL settings has focused on post-hoc approximations involving attention weights (Zambaldi et al., 2018; Mott et al., 2019) or trees (Liu et al., 2018), but these methods do not allow transparency of the agent's actions or intent (Rudin et al., 2022). Another interesting approach distills recurrent neural network (RNN) policies into finite-state machines (Danesh et al., 2021; Koul et al., 2018), but it does not always reveal easily analyzable results and is restricted to RNNs.
Perhaps the most relevant approach is that of Annasamy & Sycara (2019), who learn exemplars to explain Atari games, but their method only works in discrete action spaces and suffers a performance gap relative to black-box counterparts. In contrast to these approaches, PW-Nets are designed to be human-interpretable, generalize to any neural-based agent with any action space, do not lose model performance relative to black-box agents, and do not rely on post-hoc approximations prone to error. User studies in deep RL in which participants try to predict an agent's actions have shown mixed results (Anderson et al., 2020), with the focus now moving towards arguably more useful tasks such as identifying defective models (Olson et al., 2021). In the prototype literature, studies have generally focused on how "similar" test instances look to the prototypes used (Das et al., 2020; Rymarczyk et al., 2021), but this does not evaluate whether the explanation is useful in downstream applications, or whether it appropriately calibrates user trust (Sanneman & Shah, 2022). In contrast to these studies, we ask users to simulate model behaviour and predict failure (and success) cases for the agent, a useful application that should illustrate whether trust in the agent's abilities is appropriately calibrated.
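To make the underlying mechanism concrete, the following is a minimal numpy sketch (with hypothetical names and dimensions, not any cited paper's actual implementation) of the Li et al. (2018)-style scheme described above: a test instance is encoded into latent space, its L2 distance to each learned prototype is computed, and a linear layer over those distances produces class scores.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: 4-dim latent space, 3 learned prototypes, 2 classes.
LATENT, N_PROTO, N_CLASS = 4, 3, 2

prototypes = rng.normal(size=(N_PROTO, LATENT))  # learned prototype vectors
W_out = rng.normal(size=(N_CLASS, N_PROTO))      # linear layer over distances

def classify(z):
    """Classify a latent code z by its L2 distance to each prototype."""
    dists = np.linalg.norm(prototypes - z, axis=1)  # shape (N_PROTO,)
    return W_out @ dists                            # class scores

z_test = rng.normal(size=LATENT)
scores = classify(z_test)
assert scores.shape == (N_CLASS,)
```

Because every class score is a linear function of the distances, each prediction can be read off as "this instance is (dis)similar to these prototypes", which is the interpretability property PW-Net transfers to the RL setting.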

3. METHOD

Section 3.1 first details our assumptions of a Markov environment and a trained neural agent, before Section 3.2 describes the proposed Prototype-Wrapper Network (PW-Net), which creates a "wrapper" around the agent to make a new end-to-end interpretable policy that reasons with human-defined prototypes. Finally, Section 3.3 gives performance guarantees showing that the system can always closely approximate the performance of a black-box agent.

3.1. MARKOV FRAMEWORK

Our technique assumes access to a neural agent pretrained in a Markov Decision Process (MDP). An MDP is defined by the tuple (S, A, T, R, γ) (Sutton & Barto, 2018). S is the set of states; A ⊆ R^M is the set of M-dimensional actions; T : S × A → S is the probabilistic transition between states due to actions. We note that actions in RL, unlike outputs of simple classifiers, may be multi-dimensional (e.g., a car must control both steering and acceleration at the same time). We therefore dub each dimension in R^M a separate "action". Lastly, γ and R define the discount factor and reward function, respectively. In standard RL training, the goal is to find the policy π : S → A that maximizes the expected discounted reward. Our approach only requires access to a pre-trained black-box policy π_bb; any existing method from prior art may be used to generate such a policy (Sutton et al., 2000; Williams, 1992). Assuming a neural-net instantiation of this policy with a final linear layer, we may decompose π_bb into an encoder f_enc and a last linear layer with weights W and bias b as follows: π_bb(s) = W f_enc(s) + b.
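This decomposition can be illustrated with a minimal numpy sketch (toy dimensions and a toy one-hidden-layer network, chosen purely for illustration, not the agents used in the paper): the "encoder" f_enc is everything up to the final linear layer, and the full policy is recovered by applying that last layer to the encoder's output.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical dimensions: 8-dim state, 16-dim latent code, M = 2 action
# dimensions (e.g., steering and acceleration).
S_DIM, H_DIM, M = 8, 16, 2

# A toy one-hidden-layer "black-box" policy network.
W_enc = rng.normal(size=(H_DIM, S_DIM))
W, b = rng.normal(size=(M, H_DIM)), rng.normal(size=M)

def f_enc(s):
    """Encoder: everything up to (but excluding) the final linear layer."""
    return np.tanh(W_enc @ s)

def pi_bb(s):
    """Full black-box policy: encoder followed by the last linear layer."""
    return W @ f_enc(s) + b

s = rng.normal(size=S_DIM)
a = pi_bb(s)  # an M-dimensional action vector
assert np.allclose(a, W @ f_enc(s) + b)
```

PW-Net only needs access to f_enc; the wrapper described next replaces the role of the final layer (W, b) with prototype-based reasoning.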

3.2. PROTOTYPE-WRAPPER NETWORK

Our primary contribution is a prototype wrapper neural net model, PW-Net, that converts black-box neural models into prototype-based agents for RL by forcing them to use human-understandable prototypes.
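The wrapper idea can be sketched as follows (a simplified numpy stand-in with hypothetical names and a plain Gaussian similarity, not the authors' exact architecture): the pre-trained encoder's latent code is compared against the latent codes of human-defined prototype states, and a trainable linear layer maps the resulting similarity scores to the M action dimensions, so every action is expressed in terms of proximity to the prototypes.

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical sizes: 16-dim latent code, 4 human-chosen prototypes,
# M = 2 action dimensions.
H_DIM, N_PROTO, M = 16, 4, 2

# Latent codes of human-defined prototype states, e.g. the pre-trained
# encoder applied to hand-picked observations ("sharp left turn", "brake").
prototypes = rng.normal(size=(N_PROTO, H_DIM))

# Trainable linear head mapping prototype similarities to actions.
W_p = rng.normal(size=(M, N_PROTO))

def pw_net(z):
    """Wrapper head: choose actions from similarities to the prototypes.

    z stands in for the pre-trained black-box encoder's output f_enc(s).
    """
    sq_dists = ((prototypes - z) ** 2).sum(axis=1)  # squared L2 per prototype
    sims = np.exp(-sq_dists)                        # similarity in (0, 1]
    return W_p @ sims                               # M-dimensional action

z = rng.normal(size=H_DIM)  # stands in for f_enc(s)
action = pw_net(z)
assert action.shape == (M,)
```

Because the action is a linear function of prototype similarities, each decision can be explained by pointing at the human-chosen prototype states the current observation most resembles.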

