DT+GNN: A FULLY EXPLAINABLE GRAPH NEURAL NETWORK USING DECISION TREES

Abstract

We propose a new Decision Tree Graph Neural Network (DT+GNN) architecture for Graph Neural Network (GNN) explanation. Existing post-hoc explanation methods highlight important inputs, but fail to reveal how a GNN uses these inputs. In contrast DT+GNN is fully explainable: Humans can inspect and understand the decision making of DT+GNN at every step. DT+GNN internally uses a novel GNN layer that is restricted to categorical state spaces for nodes and messages. After training with gradient descent, we can distill these layers into decision trees. These trees are further pruned using our newly proposed method to ensure they are small and easy to interpret. DT+GNN can also compute node-level importance scores like the existing explanation methods. We demonstrate on real-world GNN benchmarks that DT+GNN has competitive classification accuracy and computes competitive explanations. Furthermore, we leverage DT+GNN's full explainability to inspect the decision processes in synthetic and real-world datasets with surprising results. We make these inspections accessible through an interactive web tool.

1. INTRODUCTION

Graph Neural Networks (GNNs) have been successful in applying machine learning techniques to many graph-based domains [5; 20; 56; 28]. However, current GNNs are black-box models, and their lack of human interpretability limits their use in many application areas. This motivated the adaptation of existing deep learning explanation methods [4; 25; 41] to GNNs, as well as the creation of new GNN-specific methods [51]. Such methods allow us to understand what parts of the input were important for making a prediction. However, it usually remains difficult or even impossible for users to understand why these parts were important. As a motivating example, consider the PROTEINS example in Figure 1a. Previous methods classify the example as an enzyme and mark the two yellow nodes in Figure 1a as important. But why are the two yellow nodes important? Is it because they are connected? Do we need exactly two yellow nodes? Or could there be three as well? Existing explanation methods that compute importance scores do not answer such questions, and so humans have to try to infer the decision rules by carefully studying dozens of examples.
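To make the contrast concrete, the two-tree rule that DT+GNN finds for PROTEINS (thresholds taken from Figure 1) can be written down as a few lines of ordinary code. This is only an illustrative re-implementation of the stated rule, not the paper's code; the function name and input encoding are our own choices.

```python
def is_enzyme(node_types):
    """Illustrative restatement of the two-tree DT+GNN rule for PROTEINS.

    node_types: a list with one entry per node, each 'helix', 'sheet',
    or 'turn'. The thresholds (8 and 3) are the ones reported in
    Figure 1; everything else here is a sketch.
    """
    sheets = sum(1 for t in node_types if t == "sheet")
    non_sheets = len(node_types) - sheets
    # Layer (c) collapses helix/turn into "no sheet"; layer (b) then
    # classifies by counting: enzyme if there are few non-sheet nodes
    # OR few sheet nodes.
    return non_sheets <= 8 or sheets <= 3
```

An importance score alone cannot express this rule; the counting thresholds are exactly what the decision trees make visible.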
We overcome this limitation with a new fully explainable GNN architecture called Decision Tree GNN (DT+GNN). Figures 1b and 1c show how DT+GNN describes the decision process (including all thresholds) on the PROTEINS dataset. This explanation goes beyond understanding individual examples: DT+GNN explains how the GNN solves the whole dataset. For PROTEINS, DT+GNN can explain the decision with two simple decision trees. But DT+GNN can also fully explain datasets such as Tree-Cycle [59]. These more advanced datasets require the multi-layer reasoning typical for GNNs, including detecting neighborhood patterns and counting degrees. In summary, our contributions are as follows:

• We propose a new differentiable Diff-DT+GNN layer. While traditional GNNs are based on synchronous message passing [32], our new layer is inspired by a simplified distributed computing model known as the stone age model [15]. In this model, nodes use a small categorical space for states and messages. We argue that the stone age model is more suitable for interpretation while retaining high theoretical expressiveness.

• We distill all Multi-Layer Perceptrons (MLPs) that Diff-DT+GNN uses internally into decision trees. The resulting model, DT+GNN (Decision Tree Graph Neural Network), consists of a series of decision tree layers. It is fully explainable, because one can simply follow the decision trees to understand a decision.

• We propose a way to collectively prune the decision trees in DT+GNN without compromising accuracy. This leads to smaller trees, which further improves explainability. We can also use these trees to compute node-level importance scores similar to existing explanation methods.

• We evaluate our architecture on established GNN explanation benchmarks and real-world graph datasets. We show that our models are competitive in classification accuracy with traditional GNNs and competitive in explanation accuracy with existing explanation methods. We further validate that the proposed pruning method considerably reduces tree sizes. We also leverage DT+GNN's full explainability to discover problems in existing explanation benchmarks and to find interesting insights into real-world datasets.

• We provide a user interface for DT+GNN.foot_0 This tool allows for the interactive exploration of the DT+GNN decision process on the datasets examined in this paper. We provide a manual for the interface in Appendix A.
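The stone-age-inspired layer described above can be sketched in a few lines: each node holds a categorical state, and its update sees only the counts of each state among its neighbors. This is a simplified sketch under our own assumptions (the function names, the adjacency-list representation, and the count-vector encoding are illustrative, not the paper's implementation):

```python
from collections import Counter

def dt_gnn_layer(states, adj, update_fn, num_states):
    """One DT+GNN-style layer over categorical node states.

    states:     list of categorical states, one per node (0..num_states-1)
    adj:        adjacency list, adj[v] = neighbors of node v
    update_fn:  maps (own_state, neighbor_state_counts) -> new state;
                in DT+GNN this role is played by a distilled decision tree
    """
    new_states = []
    for v, nbrs in enumerate(adj):
        counts = Counter(states[u] for u in nbrs)
        # A node only observes how many neighbors are in each state.
        counts_vec = [counts.get(s, 0) for s in range(num_states)]
        new_states.append(update_fn(states[v], counts_vec))
    return new_states
```

For example, with `update_fn = lambda s, c: 1 if c[1] >= 1 else 0`, one layer marks every node that has at least one neighbor in state 1, which is exactly the kind of human-readable, countable rule a decision tree can represent.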

2.1. EXPLANATION METHODS FOR GNNS

In recent years, several methods for providing GNN explanations have been proposed. These methods highlight which parts of the input are important in a GNN decision, usually by assigning importance scores to nodes and edges or by finding similar predictions. These scores and examples can assist humans in finding patterns that might reveal the GNN decision process. The existing explanation methods can be roughly sorted into the following six groups:



https://interpretable-gnn.netlify.app/



Figure 1: (a) An example graph from the PROTEINS dataset. In this dataset, nodes are secondary structural elements of amino acids. Each node has one of three types: helix (input 0), sheet (input 1), or turn (input 2). Figure (a) shows an example graph consisting mostly of helixes (green) and two sheets (yellow). The final layer (b) learns to decide whether a protein is an enzyme based on how many sheet and "no sheet" nodes the graph has in the input layer. Whether nodes are helixes or turns does not seem to matter; consequently, the input layer (c) learns to only distinguish between sheet and no sheet. Together, both trees (b, c) imply: a protein is an enzyme if it has at most 8 nodes that are not sheets, or if it has at most 3 sheets. This DT+GNN explanation is consistent with previous human explanations [16].

Gradient based. Baldassarre & Azizpour [4] and Pope et al. [41] show that it is possible to adapt gradient-based methods known from computer vision, for example Grad-CAM [48], to graphs. Gradients can be computed on node features and edges [47].

Mutual-Information based. Ying et al. [59] measure the importance of edges and node features. Edges are masked with continuous values. Instead of gradients, the authors use the mutual information between the inputs and the prediction to quantify importance. Luo et al. [35] follow a similar idea.
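The core idea behind the gradient-based family can be shown on a toy model. For a linear "GNN" with score $y = \mathbf{1}^\top (A X w)$, the gradient of the output with respect to the node features has a closed form, and input-times-gradient gives one saliency score per node. This is a generic sketch of the saliency idea, not any cited paper's actual method; the function name and the linear model are our own assumptions.

```python
import numpy as np

def gradient_node_importance(A, X, w):
    """Gradient-based node importance for a toy linear model y = 1^T (A X w).

    A: (n, n) adjacency matrix, X: (n, d) node features, w: (d,) weights.
    For this model dy/dX[i, j] = (sum_k A[k, i]) * w[j], so the gradient
    can be written in closed form instead of using autodiff.
    """
    grad = np.outer(A.T.sum(axis=1), w)   # dy/dX, shape (n, d)
    return np.abs(grad * X).sum(axis=1)   # input-times-gradient, per node
```

Methods like Grad-CAM replace the closed-form gradient with automatic differentiation through a deep model, but the resulting scores have the same limitation discussed in the introduction: they say which nodes mattered, not why.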

