INTERPRETING GRAPH NEURAL NETWORKS FOR NLP WITH DIFFERENTIABLE EDGE MASKING

Abstract

Graph neural networks (GNNs) have become a popular approach to integrating structural inductive biases into NLP models. However, there has been little work on interpreting them, and specifically on understanding which parts of the graphs (e.g. syntactic trees or co-reference structures) contribute to a prediction. In this work, we introduce a post-hoc method for interpreting the predictions of GNNs which identifies unnecessary edges. Given a trained GNN model, we learn a simple classifier that, for every edge in every layer, predicts whether that edge can be dropped. We demonstrate that such a classifier can be trained in a fully differentiable fashion, employing stochastic gates and encouraging sparsity through the expected L0 norm. We use our technique as an attribution method to analyse GNN models for two tasks, question answering and semantic role labelling, providing insights into the information flow in these models. We show that a large proportion of edges can be dropped without degrading model performance, and that the remaining edges can be analysed to interpret the model's predictions.
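The gating mechanism summarised above can be sketched in a few lines. The following is a minimal illustration, not the authors' implementation: it uses the hard-concrete (stretched and rectified binary concrete) distribution, a common choice for differentiable L0 regularisation, with a single free location parameter per edge in place of the edge classifier. The class name `HardConcreteGate` and the hyperparameters `beta`, `gamma`, and `zeta` are our own naming assumptions.

```python
import math
import torch

class HardConcreteGate(torch.nn.Module):
    """One stochastic gate per edge, sampled from a stretched-and-rectified
    ("hard concrete") distribution, so that exact zeros are reachable while
    the sampler stays differentiable via the reparameterisation trick."""

    def __init__(self, num_edges: int, beta: float = 0.5,
                 gamma: float = -0.1, zeta: float = 1.1):
        super().__init__()
        # One learnable location parameter per edge; in the method sketched
        # in the abstract, a simple classifier over edge representations
        # would produce this value instead.
        self.log_alpha = torch.nn.Parameter(torch.zeros(num_edges))
        self.beta, self.gamma, self.zeta = beta, gamma, zeta

    def forward(self) -> torch.Tensor:
        u = torch.rand_like(self.log_alpha).clamp(1e-6, 1 - 1e-6)
        # Binary-concrete sample in (0, 1) ...
        s = torch.sigmoid(
            (torch.log(u) - torch.log(1 - u) + self.log_alpha) / self.beta)
        # ... stretched to (gamma, zeta) and hard-rectified back to [0, 1],
        # so a sample can be exactly 0 (edge dropped) or exactly 1 (kept).
        return (s * (self.zeta - self.gamma) + self.gamma).clamp(0.0, 1.0)

    def expected_l0(self) -> torch.Tensor:
        # Probability that each gate is non-zero, summed over edges: a
        # differentiable surrogate for the L0 norm of the edge mask.
        return torch.sigmoid(
            self.log_alpha - self.beta * math.log(-self.gamma / self.zeta)
        ).sum()

# Usage sketch: scale messages along edges by the sampled mask, and add
# the expected edge count (weighted) to the task loss.
gate = HardConcreteGate(num_edges=4)
mask = gate()                 # one value in [0, 1] per edge
penalty = gate.expected_l0()  # sparsity term added to the task loss
```

During training, the penalty pushes gates towards exact zeros; at analysis time, the surviving (non-zero) edges are the ones the model is taken to rely on.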

1. INTRODUCTION

Graph Neural Networks (GNNs) have in recent years been shown to provide a scalable and highly performant means of incorporating linguistic information and other structural biases into NLP models. They have been applied to various kinds of representations (e.g., syntactic and semantic graphs, co-reference structures, knowledge bases linked to text, database schemas) and shown effective on a range of tasks, including relation extraction (Zhang et al., 2018; Zhu et al., 2019; Sun et al., 2019a; Guo et al., 2019), question answering (Sorokin & Gurevych, 2018; Sun et al., 2018; De Cao et al., 2019), syntactic and semantic parsing (Marcheggiani & Titov, 2017; Bogin et al., 2019; Ji et al., 2019), summarisation (Fernandes et al., 2019), machine translation (Bastings et al., 2017) and abusive language detection in social networks (Mishra et al., 2019).

While GNNs often yield strong performance, such models are complex, and it can be difficult to understand the 'reasoning' behind their predictions. For NLP practitioners, it is highly desirable to know which linguistic information a given model encodes and how that encoding happens (Jumelet & Hupkes, 2018; Giulianelli et al., 2018; Goldberg, 2019). The difficulty of interpreting GNNs is a barrier to such analysis. Furthermore, this opaqueness decreases user trust, impedes the discovery of harmful biases, and complicates error analysis (Kim, 2015; Ribeiro et al., 2016b; Sun et al., 2019b; Holstein et al., 2019). The latter is a particular issue for GNNs, where seemingly small implementation differences can make or break models (Zaheer et al., 2017; Xu et al., 2019).

In this work, we focus on post-hoc analysis of GNNs. We are especially interested in developing a method for understanding how a GNN uses its input graph. As such, we seek to identify which edges in the graph the GNN relies on, and at which layer they are used.
We formulate some desiderata for an interpretation method, seeking a technique that is:

1. able to identify relevant paths in the input graph, as paths are one of the most natural ways of presenting GNN reasoning patterns to users;
2. sufficiently tractable to be applicable to modern GNN-based NLP models;
3. as faithful (Jacovi & Goldberg, 2020) as possible, providing insights into how the model truly arrives at the prediction.

