DATA-DRIVEN LEARNING OF GEOMETRIC SCATTERING NETWORKS

Abstract

Many popular graph neural network (GNN) architectures, often considered the current state of the art, rely on encoding graph structure via smoothness or similarity between neighbors. While this approach performs well on a surprising number of standard benchmarks, the efficacy of such models does not translate consistently to more complex domains, such as biochemical graph data. We argue that these more complex domains require priors that encourage the learning of longer-range features, rather than the oversmoothed signals produced by standard GNN architectures. Here, we propose an alternative GNN architecture based on a relaxation of recently proposed geometric scattering transforms, which consist of cascades of graph wavelet filters. Our learned geometric scattering (LEGS) architecture adaptively tunes these wavelets and their scales to encourage band-pass features to emerge in learned representations. This results in a simplified GNN with significantly fewer learned parameters than competing methods. We demonstrate the predictive performance of our method on several biochemistry graph classification benchmarks, as well as the descriptive quality of its learned features in biochemical graph data exploration tasks. Our results show that the proposed LEGS network matches or outperforms popular GNNs, as well as the original geometric scattering construction, while retaining certain mathematical properties of its handcrafted (nonlearned) design.
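As background for the wavelet cascade mentioned above, geometric scattering transforms are commonly built from diffusion wavelets defined through a lazy random walk operator, with band-pass filters obtained by subtracting successive dyadic powers of that operator. The sketch below is an illustrative NumPy implementation on a toy graph, not the paper's code; the 4-node cycle graph, the specific scales, and the variable names are assumptions for demonstration only:

```python
import numpy as np

# Toy adjacency matrix: 4-node cycle graph (illustrative assumption).
A = np.array([
    [0, 1, 0, 1],
    [1, 0, 1, 0],
    [0, 1, 0, 1],
    [1, 0, 1, 0],
], dtype=float)

D_inv = np.diag(1.0 / A.sum(axis=1))
P = 0.5 * (np.eye(4) + A @ D_inv)  # lazy random walk diffusion operator

def wavelet(P, j):
    """Diffusion wavelet at dyadic scale 2^j: Psi_j = P^(2^(j-1)) - P^(2^j).

    The difference of two low-pass diffusion powers acts as a band-pass
    filter, capturing variation at an intermediate diffusion scale.
    """
    return (np.linalg.matrix_power(P, 2 ** (j - 1))
            - np.linalg.matrix_power(P, 2 ** j))

# Apply band-pass filters at scales j = 1..3 to a delta signal on node 0.
x = np.array([1.0, 0.0, 0.0, 0.0])
for j in range(1, 4):
    print(j, np.abs(wavelet(P, j) @ x))
```

In a handcrafted scattering transform the dyadic scales are fixed; the relaxation described in the abstract instead learns the diffusion scales from data.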

1. INTRODUCTION

Geometric deep learning has recently emerged as an increasingly prominent branch of machine learning in general, and deep learning in particular (Bronstein et al., 2017). It is based on the observation that many of the impressive achievements of neural networks come in applications where the data has an intrinsic geometric structure that can be used to inform network design and training procedures. For example, in computer vision, convolutional neural networks use the spatial organization of pixels to define convolutional filters that hierarchically aggregate local information at multiple scales, which in turn encode shape and texture information in data- and task-driven representations. Similarly, in time-series analysis, recurrent neural networks leverage memory mechanisms based on the temporal organization of input data to collect multiresolution information from local subsequences, which can be interpreted geometrically via tools from dynamical systems and spectral analysis. While these examples leverage only Euclidean spatiotemporal structure in data, they exemplify the potential benefits of incorporating information about intrinsic data geometry into neural network design and processing. Indeed, recent advances have further generalized the use of geometric information in neural network design to non-Euclidean structures, with particular interest in graphs that represent data geometry, either given directly as input or constructed as an approximation of a data manifold. At the core of geometric deep learning is the use of graph neural networks (GNNs) in general, and graph convolutional networks (GCNs) in particular, which ensure that neuron activations follow the geometric organization of input data by propagating information across graph neighborhoods (Bruna et al., 2014; Defferrard et al., 2016; Kipf & Welling, 2016; Hamilton et al., 2017; Xu et al., 2019; Abu-El-Haija et al., 2019).
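To make the neighborhood-propagation mechanism concrete, the following is a minimal NumPy sketch of a single GCN layer in the style of the propagation rule of Kipf & Welling (2016); the toy path graph, the random features and weights, and the ReLU nonlinearity are illustrative assumptions, not part of the present work:

```python
import numpy as np

# Toy adjacency matrix: 4-node path graph (illustrative assumption).
A = np.array([
    [0, 1, 0, 0],
    [1, 0, 1, 0],
    [0, 1, 0, 1],
    [0, 0, 1, 0],
], dtype=float)

def gcn_layer(A, X, W):
    """One GCN propagation step: H = ReLU(D^(-1/2) (A + I) D^(-1/2) X W).

    Each node's representation becomes a degree-normalized average of its
    own and its neighbors' features, followed by a linear map and ReLU --
    the smoothing behavior discussed in the text.
    """
    A_hat = A + np.eye(A.shape[0])             # add self-loops
    d = A_hat.sum(axis=1)
    D_inv_sqrt = np.diag(d ** -0.5)
    A_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt   # symmetric normalization
    return np.maximum(A_norm @ X @ W, 0.0)     # ReLU nonlinearity

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 3))  # node features (random, for illustration)
W = rng.normal(size=(3, 2))  # layer weights (learned in practice)
H = gcn_layer(A, X, W)
print(H.shape)
```

Stacking such layers repeatedly averages over growing neighborhoods, which is precisely the low-pass behavior that leads to the oversmoothing issue raised next.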
However, recent work has shown the difficulty of generalizing these methods to more complex structures, identifying common problems and phrasing them in terms of oversmoothing (Li et al., 2018), oversquashing (Alon & Yahav, 2020), or under-reaching (Barceló et al., 2020). Using graph signal processing terminology from Kipf & Welling (2016), these issues

