NBDT: NEURAL-BACKED DECISION TREES

Abstract

Machine learning applications such as finance and medicine demand accurate and justifiable predictions, barring most deep learning methods from use. In response, previous work combines decision trees with deep learning, yielding models that (1) sacrifice interpretability for accuracy or (2) sacrifice accuracy for interpretability. We forgo this dilemma by jointly improving accuracy and interpretability using Neural-Backed Decision Trees (NBDTs). NBDTs replace a neural network's final linear layer with a differentiable sequence of decisions and a surrogate loss. This forces the model to learn high-level concepts and lessens reliance on highly uncertain decisions, yielding (1) accuracy: NBDTs match or outperform modern neural networks on CIFAR and ImageNet, and generalize better to unseen classes by up to 16%. Furthermore, our surrogate loss improves the original model's accuracy by up to 2%. NBDTs also afford (2) interpretability: they improve human trust by clearly identifying model mistakes and assisting in dataset debugging. Code and pretrained NBDTs are at github.com/alvinwan/neural-backed-decision-trees.

1. INTRODUCTION

Many computer vision applications (e.g. medical imaging and autonomous driving) require insight into the model's decision process, complicating applications of deep learning models, which are traditionally black boxes. Recent efforts in explainable computer vision attempt to address this need and can be grouped into two categories: (1) saliency maps and (2) sequential decision processes. Saliency maps retroactively explain model predictions by identifying which pixels most affected the prediction. However, by focusing on the input, saliency maps fail to capture the model's decision-making process. For example, saliency offers no insight into a misclassification when the model is "looking" at the right object for the wrong reasons. Alternatively, we can gain insight into the model's decision process by breaking up predictions into a sequence of smaller, semantically meaningful decisions, as in rule-based models like decision trees. However, existing efforts to fuse deep learning and decision trees suffer from (1) significant accuracy loss relative to contemporary models (e.g., residual networks), (2) reduced interpretability due to accuracy optimizations (e.g., impure leaves and ensembles), and (3) tree structures that offer limited insight into the model's credibility. To address these issues, we propose Neural-Backed Decision Trees (NBDTs) to jointly improve both (1) accuracy and (2) interpretability of modern neural networks, utilizing decision rules that preserve (3) properties like sequential, discrete decisions; pure leaves; and non-ensembled predictions. These properties in unison enable unique insights, as we show. We acknowledge that there is no universally accepted definition of interpretability (Lundberg et al., 2020; Doshi-Velez & Kim, 2017; Lipton, 2016), so to show interpretability, we adopt a definition offered by Poursabzi-Sangdeh et al.
(2018): a model is interpretable if a human can validate its prediction, determining when the model has made a sizable mistake. We picked this definition for its importance to downstream benefits we can evaluate, specifically (1) model or dataset debugging and (2) improving human trust. To accomplish this, NBDTs replace the final linear layer of a neural network with a differentiable oblique decision tree. Unlike its predecessors (i.e. decision trees, hierarchical classifiers), an NBDT uses a hierarchy derived from model parameters, does not employ a hierarchical softmax, and can be created from any existing classification neural network without architectural modifications. These improvements tailor the hierarchy to the network rather than overfitting to the feature space, lessen the decision tree's reliance on highly uncertain decisions, and encourage accurate recognition of high-level concepts. These benefits culminate in a joint improvement of accuracy and interpretability. Our contributions:

1. We propose a tree supervision loss, yielding NBDTs that match, outperform, and out-generalize modern neural networks (WideResNet, EfficientNet) on ImageNet, TinyImageNet200, and CIFAR100. Our loss also improves the original model by up to 2%.
2. We propose alternative hierarchies for oblique decision trees, called induced hierarchies, built using pre-trained neural network weights, that outperform both data-based hierarchies (e.g. built with information gain) and existing hierarchies (e.g. WordNet) in accuracy.
3. We show NBDT explanations are more helpful to the user when identifying model mistakes, are preferred when using the model to assist in challenging classification tasks, and can be used to identify ambiguous ImageNet labels.

2. RELATED WORKS

Transfer to Explainable Models. Prior to the recent success of deep learning, decision trees were state-of-the-art on a wide variety of learning tasks and the gold standard for interpretability. Despite this, study at the intersection of neural networks and decision trees dates back three decades: neural networks were seeded with decision tree weights (Banerjee, 1990; 1994; Ivanova & Kubat, 1995a;b), and decision trees were created from neural network queries (Krishnan et al., 1999; Boz, 2000; Dancey et al., 2004; Craven & Shavlik, 1996; 1994), similar to distillation (Hinton et al., 2015).
The modern analog of both sets of work (Humbird et al., 2018; Siu, 2019; Frosst & Hinton, 2017) evaluates on feature-sparse, sample-sparse regimes such as the UCI datasets (Dua & Graff, 2017) or MNIST (LeCun et al., 2010) and performs poorly on standard image classification tasks.

Hybrid Models. Recent work produces hybrid decision tree and neural network models that scale up to datasets like CIFAR10 (Krizhevsky, 2009), CIFAR100 (Krizhevsky, 2009), TinyImageNet (Le & Yang, 2015), and ImageNet (Deng et al., 2009). One category of models organizes the neural network into a hierarchy, dynamically selecting branches to run inference (Veit & Belongie, 2018; McGill & Perona, 2017; Teja Mullapudi et al., 2018; Redmon & Farhadi, 2017; Murdock et al., 2016). However, these models use impure leaves, resulting in uninterpretable, stochastic paths. Other approaches fuse deep learning into each decision tree node: an entire neural network (Murthy et al., 2016), several layers (Murdock et al., 2016; Roy & Todorovic, 2016), a linear layer (Ahmed et al., 2016), or some other parameterization of neural network output (Kontschieder et al., 2015). These models see reduced interpretability by using k-way decisions with large k (via depth-2 trees) (Ahmed et al., 2016; Guo et al., 2018) or employing an ensemble (Kontschieder et al., 2015; Ahmed et al., 2016), which is often referred to as a "black box" (Carvalho et al., 2019; Rudin, 2018).

Hierarchical Classification (Silla & Freitas, 2011). One set of approaches directly uses a preexisting hierarchy over classes, such as WordNet (Redmon & Farhadi, 2017; Brust & Denzler, 2019; Deng et al.). However, conceptual similarity is not indicative of visual similarity. Other models build a hierarchy using the training set directly, via a classic data-dependent metric like Gini impurity (Alaniz & Akata, 2019) or information gain (Rota Bulo & Kontschieder, 2014; Biçici et al., 2018).
These models are instead prone to overfitting, per Tanno et al. (2019). Finally, several works introduce hierarchical surrogate losses (Wu et al., 2017; Deng et al., 2012), such as hierarchical softmax (Mohammed & Umaashankar, 2018), but as the authors note, these methods quickly suffer from major accuracy loss with more classes or higher-resolution images (e.g. beyond CIFAR10). We demonstrate that hierarchical classifiers attain higher accuracy without a hierarchical softmax.
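To make the induced-hierarchy idea concrete: one way to derive a hierarchy from model parameters rather than from the data is to agglomeratively pair the most similar rows of the network's final fully-connected layer (one weight vector per class). The sketch below is a simplified, hypothetical rendition of that idea, not the released implementation; the function name and the greedy cosine-similarity merging are our own assumptions.

```python
import numpy as np

def induce_hierarchy(class_weights):
    """Build a binary hierarchy over classes by repeatedly merging the
    pair of nodes whose representative vectors are most similar (by
    cosine similarity). Leaves are class indices; an inner node's
    vector is the average of its children's. Illustrative sketch only."""
    # Each active node is (subtree, representative vector); leaves are (i,).
    nodes = [((i,), w.astype(float)) for i, w in enumerate(class_weights)]
    while len(nodes) > 1:
        best, best_sim = None, -np.inf
        for a in range(len(nodes)):
            for b in range(a + 1, len(nodes)):
                va, vb = nodes[a][1], nodes[b][1]
                sim = va @ vb / (np.linalg.norm(va) * np.linalg.norm(vb))
                if sim > best_sim:
                    best, best_sim = (a, b), sim
        a, b = best
        # Merge the most similar pair into a new inner node.
        merged = ((nodes[a][0], nodes[b][0]), (nodes[a][1] + nodes[b][1]) / 2)
        nodes = [n for i, n in enumerate(nodes) if i not in (a, b)] + [merged]
    return nodes[0][0]  # nested tuples of class indices
```

Because classes whose weight vectors point in similar directions become siblings, intermediate nodes tend to correspond to higher-level visual concepts, which is what tailors the hierarchy to the network instead of overfitting it to the feature space.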

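Given such a hierarchy, the sequential decision rule described in the introduction, replacing the final linear layer's single argmax with a root-to-leaf sequence of discrete decisions, can be sketched as follows. The nested-tuple tree encoding and per-node representative vectors are illustrative assumptions, not the paper's exact interface.

```python
import numpy as np

def nbdt_infer(x, tree, node_vectors):
    """Hard NBDT-style inference: starting at the root, repeatedly
    descend to the child whose representative vector has the largest
    inner product with the featurized sample x, until reaching a leaf
    (a class index). `tree` is a nested tuple of class indices and
    `node_vectors` maps each subtree to its vector. Hedged sketch."""
    node = tree
    while isinstance(node[0], tuple):  # inner nodes contain subtrees
        # Each step is one discrete, inspectable decision.
        node = max(node, key=lambda child: float(x @ node_vectors[child]))
    return node[0]  # leaf: predicted class index
```

Every prediction thus decomposes into a sequence of discrete decisions over pure leaves, which is exactly what lets a human inspect where along the path a mistake was made.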

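Finally, the tree supervision loss can be thought of as the standard cross-entropy over class logits plus a cross-entropy term at each inner node on the sample's ground-truth root-to-leaf path, so that every intermediate decision, not just the final class, is supervised. The NumPy sketch below shows one plausible form under an assumed fixed per-node weight; the paper's exact formulation and weighting schedule may differ.

```python
import numpy as np

def softmax_ce(logits, target):
    """Cross-entropy of softmax(logits) against class index `target`."""
    z = logits - logits.max()  # stabilize before exponentiating
    return float(-(z[target] - np.log(np.exp(z).sum())))

def tree_supervision_loss(leaf_logits, label, node_logits, node_targets, weight=0.5):
    """Standard cross-entropy over the network's class logits, plus a
    weighted cross-entropy at every inner node on the ground-truth
    path. `node_logits` holds per-node child scores along that path and
    `node_targets` the correct child index at each node. The constant
    `weight` is an assumption for illustration."""
    loss = softmax_ce(leaf_logits, label)
    for logits, target in zip(node_logits, node_targets):
        loss += weight * softmax_ce(np.asarray(logits), target)
    return loss
```

Penalizing each intermediate decision is what pushes the backbone to recognize high-level concepts correctly, rather than only getting the final class right.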