INTERPRETABLE GEOMETRIC DEEP LEARNING VIA LEARNABLE RANDOMNESS INJECTION

Abstract

Point cloud data is ubiquitous in scientific fields. Recently, geometric deep learning (GDL) has been widely applied to solve prediction tasks on such data. However, GDL models are often complicated and hardly interpretable, which raises concerns for scientists who plan to deploy these models in scientific analysis and experiments. This work proposes a general mechanism, learnable randomness injection (LRI), which allows building inherently interpretable models based on general GDL backbones. LRI-induced models, once trained, can detect the points in the point cloud data that carry information indicative of the prediction label. We also propose four datasets from real scientific applications, covering the domains of high-energy physics and biochemistry, to evaluate the LRI mechanism. Compared with previous post-hoc interpretation methods, the points detected by LRI align much better and more stably with the ground-truth patterns that have actual scientific meaning. LRI is grounded in the information bottleneck principle, and thus LRI-induced models are also more robust to distribution shifts between training and test scenarios.

1. INTRODUCTION

The measurements of many scientific research objects can be represented as point clouds, i.e., sets of featured points in some geometric space. For example, in high-energy physics (HEP), particles generated in collision experiments leave spatial signals on the detectors they pass through (Guest et al., 2018); in biology, a protein is often measured and represented as a collection of amino acids with spatial locations (Wang et al., 2004; 2005). Geometric quantities of such point cloud data often encode important properties of the research object, and by analyzing them researchers may hope to make new scientific discoveries (Tusnady & Simon, 1998; Aad et al., 2012). Recently, machine learning techniques have been employed to accelerate the procedure of scientific discovery (Butler et al., 2018; Carleo et al., 2019). For geometric data as above, geometric deep learning (GDL) (Bronstein et al., 2017; 2021) has shown great promise and has been applied to fields such as HEP (Shlomi et al., 2020; Qu & Gouskos, 2020), biochemistry (Gainza et al., 2020; Townshend et al., 2021), and others. However, geometric data in practice is often irregular and high-dimensional. Consider a collision event in HEP that generates hundreds to thousands of particles, or a protein that consists of tens to hundreds of amino acids: although each particle or amino acid is located in a low-dimensional space, the whole set of points is extremely irregular and high-dimensional. Current research on GDL therefore primarily focuses on designing neural network (NN) architectures that can handle this data challenge. GDL models have to preserve the symmetries of the system and incorporate the inductive biases reflected by geometric principles to guarantee their prediction quality (Cohen & Welling, 2016; Bogatskiy et al., 2020), and therefore often involve dedicatedly designed, complex NN architectures.
Albeit with outstanding prediction performance, the complexity behind GDL models makes them hardly interpretable. However, in many scientific applications, interpretable models are needed (Roscher et al., 2020). For example, in drug discovery, compared with merely predicting the binding affinity of a protein-ligand pair, it is more useful to know which groups of amino acids determine the affinity and where the binding site may be, as such knowledge may guide future research directions (Gao et al., 2018; Karimi et al., 2019; 2020). Moreover, scientists tend to trust only interpretable models in many scenarios, e.g., most applications in HEP, where data from real experiments lack labels and models have to be trained on simulation data (Nachman & Shimmin, 2019). Here, model interpretation is used to verify whether a model indeed captures patterns that match scientific principles instead of spurious correlations between the simulation environment and labels. Unfortunately, to the best of our knowledge, there have been no studies on interpretable GDL models, let alone their applications to scientific problems. Some previous post-hoc methods may be extended to interpret a pre-trained GDL model, but they suffer from limitations reviewed in Sec. 2. Moreover, recent works (Rudin, 2019; Laugel et al., 2019; Bordt et al., 2022; Miao et al., 2022) have shown that the data patterns detected by post-hoc methods are often inconsistent across interpretation methods and pre-trained models, and may hardly offer reliable scientific insights. To fill the gap, this work studies interpretable GDL models. Inspired by the recent work (Miao et al., 2022), we first propose a general mechanism named Learnable Randomness Injection (LRI) that allows building inherently interpretable GDL models based on a broad range of GDL backbones.
We then propose four datasets from real-world scientific applications in HEP and biochemistry and provide an extensive comparison between LRI-induced GDL models and previous post-hoc interpretation approaches (after being adapted to GDL models) on these datasets. Our LRI mechanism provides model interpretation by detecting the subset of points in the point cloud that is most likely to determine the label of interest. The idea of LRI is to inject learnable randomness into each point: while the model is trained for label prediction, the randomness injected into points that are important for the prediction gets reduced. The convergent amounts of randomness on the points thus reveal their importance for prediction. Specifically in GDL, as the importance of a point may be indicated either by its existence in the system or by its geometric location, we propose to inject two types of randomness: Bernoulli randomness (in the framework LRI-Bernoulli) to test the existence importance of points, and Gaussian randomness on geometric features (in the framework LRI-Gaussian) to test the location importance of points. Moreover, by properly parameterizing such Gaussian randomness, we may tell, for each point, along which directions perturbing its location affects the prediction more. With such fine-grained geometric information, we may, for instance, estimate the direction of the particle velocity when analyzing particle collision data in HEP. LRI is theoretically sound, as it essentially uses a variational objective derived from the information bottleneck principle (Tishby et al., 2000). LRI-induced models also show better robustness to distribution shifts between training and test scenarios, which gives scientists more confidence in applying them in practice.
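The two randomness-injection schemes can be sketched in a few lines of PyTorch. The sketch below is illustrative only and not the authors' implementation: the module names (`LRIBernoulli`, `LRIGaussian`), the small per-point scorer networks, the concrete (Gumbel-sigmoid) relaxation used to keep Bernoulli sampling differentiable, and the unit-variance Gaussian prior are all assumptions made for the example. The key ingredients it does reflect are (i) learnable per-point randomness and (ii) a KL regularizer toward an uninformative prior, so that only prediction-relevant points retain low randomness.

```python
import torch
import torch.nn as nn

class LRIBernoulli(nn.Module):
    """Illustrative sketch: learn a per-point keep-probability and inject
    (relaxed) Bernoulli randomness; a KL term to a fixed Bernoulli prior
    plays the role of the information-bottleneck regularizer."""
    def __init__(self, feat_dim, hidden=32, prior_p=0.5, temp=1.0):
        super().__init__()
        self.scorer = nn.Sequential(
            nn.Linear(feat_dim, hidden), nn.ReLU(), nn.Linear(hidden, 1))
        self.prior_p, self.temp = prior_p, temp

    def forward(self, x):
        # x: (num_points, feat_dim) per-point features
        logits = self.scorer(x).squeeze(-1)
        p = torch.sigmoid(logits)                      # keep probability
        if self.training:
            # Gumbel-sigmoid (concrete) relaxation of Bernoulli sampling
            u = torch.rand_like(p).clamp(1e-6, 1 - 1e-6)
            noise = torch.log(u) - torch.log(1 - u)
            mask = torch.sigmoid((logits + noise) / self.temp)
        else:
            mask = p
        eps = 1e-6  # KL(Bern(p) || Bern(prior_p)), averaged over points
        kl = (p * torch.log(p / self.prior_p + eps)
              + (1 - p) * torch.log((1 - p) / (1 - self.prior_p) + eps)).mean()
        return x * mask.unsqueeze(-1), mask, kl

class LRIGaussian(nn.Module):
    """Illustrative sketch: learn per-point, per-direction Gaussian noise
    on the coordinates; directions whose perturbation hurts prediction
    end up with small learned variance."""
    def __init__(self, feat_dim, coord_dim=3, hidden=32):
        super().__init__()
        self.scorer = nn.Sequential(
            nn.Linear(feat_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, coord_dim))  # per-direction log-variance

    def forward(self, feats, coords):
        log_var = self.scorer(feats)
        std = torch.exp(0.5 * log_var)
        noisy = coords + std * torch.randn_like(coords) if self.training else coords
        # KL to a fixed N(., I) prior: 0.5 * (sigma^2 - log sigma^2 - 1)
        kl = 0.5 * (std.pow(2) - log_var - 1).sum(-1).mean()
        return noisy, std, kl
```

In training, the KL terms would be added to the prediction loss with a trade-off weight; at test time, the learned keep-probabilities (or inverse variances) serve directly as per-point importance scores.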
We note that one obstacle to studying interpretable GDL models is the lack of valid datasets that contain both classification labels and scientifically meaningful patterns to verify the quality of interpretation. Therefore, another significant contribution of this work is the preparation of four benchmark datasets grounded in real-world scientific applications to facilitate interpretable GDL research. These datasets cover important applications in HEP and biochemistry. We illustrate the four datasets in Fig. 1 and briefly introduce them below; more detailed descriptions can be found in Appendix C.

• ActsTrack is a particle tracking dataset in HEP that is used to reconstruct the properties, such as the kinematics, of a charged particle given a set of position measurements from a tracking detector. Tracking is an indispensable step in analyzing HEP experimental data, and particle tracking is also used in medical applications such as proton therapy (Schulte et al., 2004; Thomson, 2013; Ai et al., 2022). Our task is formulated differently from traditional track reconstruction tasks: we predict the existence of a z → µµ decay and use the set of points from the µ's to verify model interpretation, which can be used to reconstruct µ tracks. ActsTrack also provides a controllable environment (e.g., magnetic field strength) to study fine-grained geometric patterns.



Figure 1: Illustrations of the four scientific datasets in this work to study interpretable GDL models.
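Since each dataset pairs a classification label with a ground-truth pattern (e.g., the hits from the two µ tracks in ActsTrack), interpretation quality can be checked by asking how many of the points an interpretation method ranks highest actually belong to that pattern. The helper below is a hypothetical illustration of such a check (the function name and the top-k formulation are assumptions for the example, not the paper's exact evaluation protocol).

```python
import numpy as np

def interpretation_precision(scores, gt_mask, k):
    """Fraction of the top-k highest-scoring points that belong to the
    ground-truth pattern.

    scores  : per-point importance scores from an interpretation method
    gt_mask : boolean array marking points in the ground-truth pattern
    k       : number of top-ranked points to inspect
    """
    topk = np.argsort(scores)[::-1][:k]   # indices of the k highest scores
    return float(np.mean(gt_mask[topk]))  # hit rate among the top-k points
```

For example, if a method scores four points as [0.9, 0.1, 0.8, 0.2] and only the first and last belong to the pattern, its precision at k=2 is 0.5: the top-ranked point is a hit, the second-ranked one is not.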

Availability

Our code and datasets are available at https://github.com/Graph-COM/LRI.

