INTERPRETING GRAPH NEURAL NETWORKS FOR NLP WITH DIFFERENTIABLE EDGE MASKING

Abstract

Graph neural networks (GNNs) have become a popular approach to integrating structural inductive biases into NLP models. However, there has been little work on interpreting them, and specifically on understanding which parts of the graphs (e.g. syntactic trees or co-reference structures) contribute to a prediction. In this work, we introduce a post-hoc method for interpreting the predictions of GNNs which identifies unnecessary edges. Given a trained GNN model, we learn a simple classifier that, for every edge in every layer, predicts whether that edge can be dropped. We demonstrate that such a classifier can be trained in a fully differentiable fashion, employing stochastic gates and encouraging sparsity through the expected L0 norm. We use our technique as an attribution method to analyse GNN models for two tasks, question answering and semantic role labelling, providing insights into the information flow in these models. We show that we can drop a large proportion of edges without deteriorating the performance of the model, and that we can analyse the remaining edges to interpret model predictions.

1. INTRODUCTION

Graph Neural Networks (GNNs) have in recent years been shown to provide a scalable and highly performant means of incorporating linguistic information and other structural biases into NLP models. They have been applied to various kinds of representations (e.g., syntactic and semantic graphs, co-reference structures, knowledge bases linked to text, database schemas) and shown effective on a range of tasks, including relation extraction (Zhang et al., 2018; Zhu et al., 2019; Sun et al., 2019a; Guo et al., 2019), question answering (Sorokin & Gurevych, 2018; Sun et al., 2018; De Cao et al., 2019), syntactic and semantic parsing tasks (Marcheggiani & Titov, 2017; Bogin et al., 2019; Ji et al., 2019), summarisation (Fernandes et al., 2019), machine translation (Bastings et al., 2017) and abusive language detection in social networks (Mishra et al., 2019).

While GNNs often yield strong performance, such models are complex, and it can be difficult to understand the 'reasoning' behind their predictions. For NLP practitioners, it is highly desirable to know which linguistic information a given model encodes and how that encoding happens (Jumelet & Hupkes, 2018; Giulianelli et al., 2018; Goldberg, 2019). The difficulty in interpreting GNNs represents a barrier to such analysis. Furthermore, this opaqueness decreases user trust, impedes the discovery of harmful biases, and complicates error analysis (Kim, 2015; Ribeiro et al., 2016b; Sun et al., 2019b; Holstein et al., 2019). The latter is a particular issue for GNNs, where seemingly small implementation differences can make or break models (Zaheer et al., 2017; Xu et al., 2019).

In this work, we focus on post-hoc analysis of GNNs. We are especially interested in developing a method for understanding how the GNN uses the input graph. As such, we seek to identify which edges in the graph the GNN relies on, and at which layer they are used.
We formulate some desiderata for an interpretation method, seeking a technique that is:

1. able to identify relevant paths in the input graph, as paths are one of the most natural ways of presenting GNN reasoning patterns to users;
2. sufficiently tractable to be applicable to modern GNN-based NLP models;
3. as faithful (Jacovi & Goldberg, 2020) as possible, providing insights into how the model truly arrives at the prediction.

Figure 1: GRAPHMASK uses vertex hidden states and messages at layer k (left: original model) as input to a classifier g that predicts a mask z^(k). We use this to mask the messages of the k-th layer and re-compute the forward pass with modified node states (right: gated model). The classifier g is trained to mask as many hidden states as possible without changing the output of the gated model.

A simple way to perform interpretation is to use erasure search (Li et al., 2016; Feng et al., 2018), an approach wherein attribution happens by searching for a maximal subset of features that can be entirely removed without affecting model predictions. Applied to GNNs, erasure search would involve searching for the largest subgraph which can be completely discarded. Besides faithfulness considerations and conceptual simplicity, discrete attributions would also simplify the comparison of relevance between paths; this contrasts with continuous attribution to edges, where it is not straightforward to extract and visualise important paths. Furthermore, in contrast to techniques based on artificial gradients (Pope et al., 2019; Xie & Lu, 2019; Schwarzenberg et al., 2019), erasure search would provide implementation invariance (Sundararajan et al., 2017). This is important in NLP, as models commonly use highly parametrised decoders on top of GNNs (e.g., Koncel-Kedziorski et al. (2019)). While arguably satisfying criteria (1) and (3) in our desiderata, erasure search unfortunately fails on tractability. In practical scenarios, it is infeasible, and even approximations, which remove one feature at a time (Zintgraf et al., 2017) and underestimate their contribution due to saturation (Shrikumar et al., 2017), remain prohibitively expensive. Our GRAPHMASK aims to meet the above desiderata by achieving the same benefits as erasure search in a scalable manner.
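To make the tractability problem concrete, the following toy sketch (not part of the paper; the `predict` function is a hypothetical stand-in for a trained model) implements exhaustive erasure search over the edges of a graph. The outer loop over edge subsets is exponential in the number of edges, which is exactly why the exhaustive procedure is infeasible in practice:

```python
from itertools import combinations

def erasure_search(edges, predict, original_pred):
    """Brute-force erasure search: find the largest set of edges that can
    be removed while the model's prediction stays unchanged.
    `predict` maps a list of retained edges to a prediction."""
    # Try removing subsets from largest to smallest. This enumerates up to
    # 2^|edges| subsets, which is why erasure search is intractable for
    # realistic graphs and motivates an amortised, differentiable surrogate.
    for r in range(len(edges), 0, -1):
        for removed in combinations(edges, r):
            retained = [e for e in edges if e not in set(removed)]
            if predict(retained) == original_pred:
                return set(removed)
    return set()  # no edge can be removed without changing the prediction

# Toy "model": predicts 1 iff some retained edge touches node 0.
edges = [(0, 1), (1, 2), (2, 3)]
predict = lambda retained: int(any(0 in e for e in retained))
superfluous = erasure_search(edges, predict, predict(edges))
print(sorted(superfluous))  # the two edges not touching node 0 are superfluous
```

Note that even the one-feature-at-a-time approximation mentioned above still requires one forward pass per edge per step, and can miss jointly redundant edges due to saturation.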
That is, our method makes easily interpretable hard choices on whether to retain or discard edges, such that discarded edges have no relevance to model predictions, while remaining tractable and model-agnostic (Ribeiro et al., 2016a). GRAPHMASK can be understood as a differentiable form of subset erasure, where, instead of finding an optimal subset to erase for every given example, we learn an erasure function which predicts for every edge (u, v) at every layer k whether that connection should be retained. Given an example graph G, our method returns for each layer k a subgraph G_S^(k) such that we can faithfully claim that no edges outside G_S^(k) influence the predictions of the model. To enable gradient-based optimisation of our erasure function, we rely on sparse stochastic gates (Louizos et al., 2018; Bastings et al., 2019). In erasure search, optimisation happens individually for each example. This can result in a form of overfitting where even non-superfluous edges are aggressively pruned because a similar prediction could be made using an alternative smaller subgraph; we refer to this problem as hindsight bias. Because our interpretation method relies on a parametrised erasure function rather than on individual per-edge choices, we can address this issue by amortising parameter learning over a training dataset, through a process similar to the readout bottleneck introduced in Schulz et al. (2020). In other words, the decision to drop or keep an edge is made based on the information available in the network (i.e., the representations of the graph nodes), without access to the final prediction (or to the gold standard). As we demonstrate in Section 4, this strategy avoids hindsight bias.
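The stochastic gates referenced above can be illustrated with a minimal, framework-free sketch of the hard-concrete distribution of Louizos et al. (2018). This is not the paper's implementation; the stretch limits and temperature below are common defaults from that work, assumed here for illustration. The key properties are that sampled gates take exact values of 0 or 1 with non-zero probability (so edges are genuinely dropped, not merely down-weighted), and that the probability of a gate being non-zero, the expected L0 penalty, has a closed form that can be added to the training loss to encourage sparsity:

```python
import math
import random

GAMMA, ZETA, BETA = -0.1, 1.1, 0.66  # stretch limits and temperature (assumed defaults)

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def sample_gate(log_alpha, rng):
    """Sample a hard-concrete gate z in [0, 1]. In a real model, log_alpha
    would be produced by the erasure classifier g from node representations,
    and the sample is differentiable w.r.t. log_alpha via reparametrisation."""
    u = rng.uniform(1e-6, 1.0 - 1e-6)
    s = sigmoid((math.log(u) - math.log(1.0 - u) + log_alpha) / BETA)
    s_bar = s * (ZETA - GAMMA) + GAMMA      # stretch to (gamma, zeta)
    return min(1.0, max(0.0, s_bar))        # hard clip -> exact 0s and 1s

def expected_l0(log_alpha):
    """Closed-form probability that the gate is non-zero: the expected L0
    penalty term that encourages the classifier to mask as many edges
    as possible."""
    return sigmoid(log_alpha - BETA * math.log(-GAMMA / ZETA))

rng = random.Random(0)
# A strongly negative log_alpha drives both the sampled gates and the
# expected L0 cost towards zero, i.e. the edge is dropped:
gates = [sample_gate(-5.0, rng) for _ in range(1000)]
print(sum(g == 0.0 for g in gates) / 1000, expected_l0(-5.0))
```

In the full method, each message in the GNN would be multiplied by its gate (with closed gates replaced by a learned baseline vector), and training would minimise the total expected L0 subject to the gated model's output staying close to the original.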



The removal guarantees that all information about the discarded features is ignored by the model. This contrasts with approaches which use heuristics to define feature importance, for example attention-based methods (Serrano & Smith, 2019; Jain & Wallace, 2019) or back-propagation techniques (Bach et al., 2015; Sundararajan et al., 2017). These do not guarantee that the model ignores low-scoring features, and have attracted criticism in recent years (Nie et al.).

