CDT: CASCADING DECISION TREES FOR EXPLAINABLE REINFORCEMENT LEARNING
Anonymous authors
Paper under double-blind review

Abstract

Deep Reinforcement Learning (DRL) has recently achieved significant advances in various domains. However, explaining the policies of RL agents remains an open problem, in part because of the difficulty of explaining the decisions of neural networks. Recently, a line of work has used decision-tree-based models to learn explainable policies: soft decision trees (SDTs) and discretized differentiable decision trees (DDTs) have been shown to achieve good performance while yielding explainable policies. In this work, we further improve tree-based explainable RL in both performance and explainability. Our proposal, Cascading Decision Trees (CDTs), applies representation learning on the decision path to allow richer expressivity. Empirical results show that, whether CDTs are used as policy function approximators or as imitation learners to explain black-box policies, they achieve better performance with more succinct and explainable models than SDTs. As a second contribution, our study reveals a limitation of explaining black-box policies via imitation learning with tree-based explainable models: its inherent instability.

1. INTRODUCTION

Explainable Artificial Intelligence (XAI), and especially Explainable Reinforcement Learning (XRL) (Puiutta and Veith, 2020), has attracted increasing attention recently. How to interpret the action choices of reinforcement learning (RL) policies remains a critical challenge, especially given the increasing trend of applying RL in domains where transparency and safety matter (Cheng et al., 2019; Junges et al., 2016). Currently, many state-of-the-art DRL agents use neural networks (NNs) as their function approximators. While NNs are considered strong function approximators (yielding better performance), RL agents built on top of them generally lack interpretability (Lipton, 2018). Indeed, interpreting the behavior of NNs themselves remains an open problem in the field (Montavon et al., 2018; Albawi et al., 2017). In contrast, traditional decision trees (DTs), with hard decision boundaries, are usually regarded as models with readable interpretations for humans, since humans can follow the decision-making process by visualizing the decision path. However, DTs may suffer from weak expressivity and therefore low accuracy. An early approach to reduce the hardness of DTs was the soft/fuzzy DT (abbreviated as SDT) proposed by Suárez and Lutsko (1999). Recently, differentiable SDTs (Frosst and Hinton, 2017) have shown both improved interpretability and better function approximation, lying between traditional DTs and neural networks. Differentiable DTs have been adopted for interpreting RL policies in two slightly different settings: an imitation learning setting (Coppens et al., 2019; Liu et al., 2018), in which imitators with interpretable models are learned from RL agents with black-box models, and a full RL setting (Silva et al., 2019), where the policy is directly represented as an interpretable model, e.g., a DT.
However, the DTs in these methods only partition the raw feature space without representation learning, which can lead to complicated combinations of partitions and hinder both model interpretability and scalability. Even worse, some methods use axis-aligned partitions (univariate decision nodes) (Wu et al., 2017; Silva et al., 2019), which have much lower model expressivity. In this paper, we propose Cascading Decision Trees (CDTs), which strike a balance between model interpretability and accuracy; that is, they perform adequate representation learning while remaining built on interpretable components (e.g., linear models). Our experiments show that CDTs achieve both a significantly smaller number of parameters (and a more compact tree structure) and better performance than related works. The experiments are conducted on RL tasks, in both imitation-learning and RL settings. We also demonstrate that the imitation-learning approach is less reliable for interpreting RL policies with DTs, since the imitating DTs can differ markedly across runs, which also leads to divergent feature importances and tree structures.
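To make the expressivity gap of axis-aligned partitions concrete, here is a small sketch (ours, not from this paper; the data and function names are illustrative). Labels generated by a linear combination of two features can be recovered exactly by a single multivariate (oblique) split, but not by any single axis-aligned threshold:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.uniform(-1, 1, size=(500, 2))
y = (X[:, 0] + X[:, 1] > 0).astype(int)   # separable only by a feature combination

def best_axis_aligned_split(X, y):
    """Best accuracy attainable by a single univariate threshold (axis-aligned node)."""
    best = 0.0
    for d in range(X.shape[1]):
        for t in X[:, d]:
            pred = (X[:, d] > t).astype(int)
            best = max(best, (pred == y).mean(), ((1 - pred) == y).mean())
    return best

def linear_split_acc(X, y, w):
    """Accuracy of a single multivariate node: sign of a linear model of the features."""
    pred = (X @ w > 0).astype(int)
    return (pred == y).mean()

acc_axis = best_axis_aligned_split(X, y)
acc_linear = linear_split_acc(X, y, np.array([1.0, 1.0]))
print(f"best axis-aligned split: {acc_axis:.2f}, linear split: {acc_linear:.2f}")
```

With this data the linear node is exact, while the best axis-aligned node plateaus around 75% accuracy, which is why a hard DT needs many more axis-aligned nodes to approximate the same boundary.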

2. RELATED WORKS

A series of works over the past two decades has developed differentiable DTs (Irsoy et al., 2012; Laptev and Buhmann, 2014). Recently, Frosst and Hinton (2017) proposed distilling an SDT from a neural network; their approach was only tested on MNIST digit classification. Wu et al. (2017) further proposed the tree regularization technique, which favors models whose decision boundaries are close to compact DTs, for achieving interpretability. To further boost the prediction accuracy of tree-based models, two main extensions of the single SDT have been proposed: (1) ensembles of trees, and (2) unification of NNs and DTs. Ensembling decision trees is a common technique for increasing prediction accuracy or robustness, and it can be incorporated into SDTs (Rota Bulo and Kontschieder, 2014; Kontschieder et al., 2015; Kumar et al., 2016), giving rise to neural decision forests. Since more than one tree must be considered during inference, this can complicate interpretability; a common solution is to transform the decision forest into a single tree (Sagi and Rokach, 2020). As for the unification of NNs and DTs, Laptev and Buhmann (2014) propose convolutional decision trees for feature learning from images. Adaptive Neural Trees (ANTs) (Tanno et al., 2018) incorporate representation learning into the decision nodes of a differentiable tree via nonlinear transformations such as convolutional neural networks (CNNs). The nonlinear transformations of an ANT, both in the routing functions of its decision nodes and in its feature spaces, secure its prediction performance in classification tasks on the one hand, but hinder its potential interpretability on the other. Wan et al. (2020) propose the neural-backed decision tree (NBDT), which transfers the final fully connected layer of an NN into a DT with induced hierarchies for ease of interpretation, but shares the convolutional backbone with standard deep NNs, yielding state-of-the-art performance on CIFAR10 and ImageNet classification. However, these advanced methods either employ multiple trees with a multiplicative number of model parameters, or heavily incorporate deep learning models such as CNNs into the DTs; their interpretability is severely hindered by this model complexity.

To interpret an RL agent, Coppens et al. (2019) propose distilling the RL policy into a differentiable DT by imitating a pre-trained policy. Similarly, Liu et al. (2018) apply an imitation learning framework, but to the Q-value function of the RL agent; they also propose Linear Model U-trees (LMUTs), which allow linear models in leaf nodes. Silva et al. (2019) propose applying differentiable DTs directly as function approximators for either the Q function or the policy in RL; they apply a discretization process and a rule-list tree structure to simplify the trees and improve interpretability. The VIPER method proposed by Bastani et al. (2018) also distills a policy represented as an NN into a DT policy with theoretically verifiable capability, but only for imitation learning settings and nonparametric DTs.

Our proposed CDT is distinguished from the other main categories of differentiable-DT methods for XRL in the following ways: (i) Compared with SDTs (Frosst and Hinton, 2017), partitions in a CDT happen not only in the original input space but also in transformed spaces, by leveraging intermediate features. This is well documented in recent works (Kontschieder et al., 2015; Xiao, 2017; Tanno et al., 2018) to improve model capacity, and it can be further extended to hierarchical representation learning with advanced feature learning modules such as CNNs (Tanno et al., 2018). (ii) Compared with the work of Coppens et al. (2019), space partitions are not limited to axis-aligned ones (which limit the expressivity of trees of a given depth) but are achieved with linear models of features as the routing functions. Moreover, linear models are not a restriction but merely an example: other interpretable transformations are also allowed in our CDT method. (iii) Compared with ANTs (Tanno et al., 2018), our CDT method unifies the decision-making processes based on different intermediate features into a single decision-making tree, following a low-rank decomposition of a large matrix via linear models. It thus greatly improves model simplicity for achieving interpretability. For a motivating example on model simplicity and interpretability in DTs, see Appendix A.
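The cascading structure in (iii) can be sketched roughly as follows. This is a minimal numpy illustration of our reading of the idea, not the authors' implementation; all names are hypothetical and parameters are randomly initialized. A soft feature-learning node mixes two candidate linear feature extractors, and a separate decision-making node then partitions the learned feature space:

```python
import numpy as np

rng = np.random.default_rng(1)
obs_dim, feat_dim, n_actions = 4, 2, 2

# Feature-learning node: a soft gate over two candidate linear feature extractors.
gate_w = rng.normal(size=obs_dim)
feat_W = rng.normal(size=(2, feat_dim, obs_dim))   # two leaves, each a linear map

# Decision-making node operating on the learned features (shared across paths).
dec_w = rng.normal(size=feat_dim)
leaf_logits = rng.normal(size=(2, n_actions))      # action preferences at leaves

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def cdt_forward(x):
    # Soft routing in the feature-learning tree selects (a mixture of) feature maps.
    p = sigmoid(gate_w @ x)
    feats = p * (feat_W[0] @ x) + (1 - p) * (feat_W[1] @ x)
    # The decision-making tree then partitions the learned feature space.
    q = sigmoid(dec_w @ feats)
    return q * leaf_logits[0] + (1 - q) * leaf_logits[1]

x = rng.normal(size=obs_dim)
probs = np.exp(cdt_forward(x))
probs /= probs.sum()                               # softmax over action logits
print(probs)
```

The point of the decomposition is that both stages remain linear (hence readable), while their cascade expresses partitions that a single tree over raw inputs would need many more nodes to represent.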

3.1. SOFT DECISION TREE (SDT)

An SDT is a differentiable DT with a probabilistic decision boundary at each node. Consider a DT of depth D: each inner node of the SDT is parameterized by a weight vector (with the bias as an additional element) that defines a soft routing function over its children.
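A minimal numpy sketch of such soft routing (ours, with illustrative names): each inner node routes an input left or right with a sigmoid probability of a linear function of the input, and each leaf's probability is the product of the routing probabilities along its root-to-leaf path:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def leaf_probs(x, W, b, node=0):
    """Probability of reaching each leaf of a complete soft decision tree.

    Inner nodes use heap indexing (children of node i are 2i+1 and 2i+2);
    node i routes left with probability sigmoid(W[i] @ x + b[i]), so each
    leaf's probability is the product of routing probabilities on its path.
    """
    if node >= len(W):                        # past the last inner node: a leaf
        return np.array([1.0])
    p = sigmoid(W[node] @ x + b[node])        # soft (probabilistic) boundary
    left = leaf_probs(x, W, b, 2 * node + 1)
    right = leaf_probs(x, W, b, 2 * node + 2)
    return np.concatenate([p * left, (1 - p) * right])

rng = np.random.default_rng(0)
W = rng.normal(size=(3, 4))                   # depth-2 tree: 3 inner nodes, 4 leaves
b = rng.normal(size=3)
probs = leaf_probs(rng.normal(size=4), W, b)
print(probs)                                  # four leaf probabilities summing to 1
```

Because every routing probability is differentiable in the node weights, the whole tree can be trained by gradient descent, which is what distinguishes an SDT from a hard DT.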

