DATA-DRIVEN LEARNING OF GEOMETRIC SCATTERING NETWORKS

Abstract

Many popular graph neural network (GNN) architectures, often considered the current state of the art, rely on encoding graph structure via smoothness or similarity between neighbors. While this approach performs well on a surprising number of standard benchmarks, the efficacy of such models does not translate consistently to more complex domains, such as graphs arising in biochemistry. We argue that these more complex domains require priors that encourage the learning of longer-range features, rather than the oversmoothed signals of standard GNN architectures. Here, we propose an alternative GNN architecture based on a relaxation of recently proposed geometric scattering transforms, which consist of cascades of graph wavelet filters. Our learned geometric scattering (LEGS) architecture adaptively tunes these wavelets and their scales to encourage band-pass features to emerge in learned representations. This results in a simplified GNN with significantly fewer learned parameters than competing methods. We demonstrate the predictive performance of our method on several biochemistry graph classification benchmarks, as well as the descriptive quality of its learned features in biochemical graph data exploration tasks. Our results show that the proposed LEGS network matches or outperforms popular GNNs, as well as the original geometric scattering construction, while retaining certain mathematical properties of its handcrafted (nonlearned) design.

1. INTRODUCTION

Geometric deep learning has recently emerged as an increasingly prominent branch of machine learning in general, and deep learning in particular (Bronstein et al., 2017). It is based on the observation that many of the impressive achievements of neural networks come in applications where the data has an intrinsic geometric structure that can be used to inform network design and training procedures. For example, in computer vision, convolutional neural networks use the spatial organization of pixels to define convolutional filters that hierarchically aggregate local information at multiple scales, which in turn encodes shape and texture information in data- and task-driven representations. Similarly, in time-series analysis, recurrent neural networks leverage memory mechanisms based on the temporal organization of input data to collect multiresolution information from local subsequences, which can be interpreted geometrically via tools from dynamical systems and spectral analysis. While these examples only leverage Euclidean spatiotemporal structure in data, they exemplify the potential benefits of incorporating information about intrinsic data geometry in neural network design and processing. Indeed, recent advances have further generalized the utilization of geometric information in neural network design to consider non-Euclidean structures, with particular interest in graphs that represent data geometry, either directly given as input or constructed as an approximation of a data manifold. At the core of geometric deep learning is the use of graph neural networks (GNNs) in general, and graph convolutional networks (GCNs) in particular, which ensure neuron activations follow the geometric organization of input data by propagating information across graph neighborhoods (Bruna et al., 2014; Defferrard et al., 2016; Kipf & Welling, 2016; Hamilton et al., 2017; Xu et al., 2019; Abu-El-Haija et al., 2019).
However, recent work has shown the difficulty of generalizing these methods to more complex structures, identifying common problems and phrasing them in terms of oversmoothing (Li et al., 2018), oversquashing (Alon & Yahav, 2020), or under-reaching (Barceló et al., 2020). Using graph signal processing terminology from Kipf & Welling (2016), these issues can be partly attributed to the limited construction of convolutional filters in many commonly used GCN architectures. Inspired by the filters learned in convolutional neural networks, GCNs consider node features as graph signals and aim to aggregate information from neighboring nodes. For example, Kipf & Welling (2016) presented a typical implementation of a GCN with a cascade of averaging (essentially low-pass) filters. We note that more general variations of GCN architectures exist (Defferrard et al., 2016; Hamilton et al., 2017; Xu et al., 2019), which are capable of representing other filters, but as investigated in Alon & Yahav (2020), they too often have difficulty in learning long-range connections.

Recently, an alternative approach to deep geometric representation learning was presented by generalizing Mallat's scattering transform (Mallat, 2012), originally proposed to provide a mathematical framework for understanding convolutional neural networks, to graphs (Gao et al., 2019; Gama et al., 2019a; Zou & Lerman, 2019) and manifolds (Perlmutter et al., 2018). Similar to traditional scattering, which can be seen as a convolutional network with nonlearned wavelet filters, geometric scattering is defined as a GNN with handcrafted graph filters, typically constructed as diffusion wavelets over the input graph (Coifman & Maggioni, 2006), which are then cascaded with pointwise absolute-value nonlinearities. This wavelet cascade results in permutation-equivariant node features that are typically aggregated via statistical moments over the graph nodes, as explained in detail in Sec. 2, to provide a permutation-invariant graph-level representation. The efficacy of geometric scattering features in graph processing tasks was demonstrated in Gao et al. (2019), with both supervised learning and data exploration applications. Moreover, their handcrafted design enables rigorous study of their properties, such as stability to deformations and perturbations, and provides a clear understanding of the information they extract, which by design (e.g., the cascaded band-pass filters) goes beyond low frequencies to consider richer notions of regularity (Gama et al., 2019b; Perlmutter et al., 2019). However, while graph scattering transforms provide effective universal feature extractors, their rigid handcrafted design does not allow for the automatic task-driven representation learning that naturally arises in traditional GNNs. To address this deficiency, recent work has proposed a hybrid scattering-GCN model (Min et al., 2020) for obtaining node-level representations, which ensembles a GCN with a fixed scattering feature extractor. In Min et al. (2020), integrating channels from both architectures alleviates the well-known oversmoothing problem and outperforms popular GNNs on node classification tasks. Here, we focus on making the geometric scattering transform learnable, in particular its scales. We focus on whole-graph representations, with an emphasis on biochemical molecular graphs, where relatively large diameters and non-planar structures usually limit the effectiveness of traditional GNNs. Instead of the ensemble approach of Min et al. (2020), we propose a native neural network architecture for learned geometric scattering (LEGS), which directly modifies the scattering architecture of Gao et al. (2019) and Perlmutter et al. (2019), via relaxations described in Sec. 3, to allow task-driven adaptation of its wavelet configuration via backpropagation, as implemented in Sec. 4.
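As a concrete illustration of the handcrafted construction described above, the following NumPy sketch builds dyadic diffusion wavelets from a lazy random walk and aggregates the modulus of the filter responses into statistical moments. The lazy-walk operator P = (I + AD^-1)/2, the wavelets Psi_j = P^(2^(j-1)) - P^(2^j), and the unnormalized q-th moments follow the common convention of Gao et al. (2019); the exact construction used here is given in Sec. 2, so this is an assumption-laden sketch rather than this paper's implementation.

```python
import numpy as np

def lazy_walk(A):
    """Lazy random-walk operator P = (I + A D^{-1}) / 2 for adjacency A."""
    deg = A.sum(axis=0)
    return 0.5 * (np.eye(len(A)) + A / deg)  # A/deg divides column j by deg[j]

def dyadic_wavelets(P, J):
    """Diffusion wavelets Psi_j = P^(2^(j-1)) - P^(2^j), for j = 1..J."""
    wavelets, P_prev = [], P          # P_prev holds P^(2^(j-1))
    for _ in range(J):
        P_next = P_prev @ P_prev      # squaring reaches the next dyadic power
        wavelets.append(P_prev - P_next)
        P_prev = P_next
    return wavelets

def scattering_features(A, x, J=3, Q=4):
    """Zeroth-, first-, and second-order scattering moments of node signal x."""
    Psi = dyadic_wavelets(lazy_walk(A), J)
    # Zeroth order: moments of the (modulus of the) raw signal.
    feats = [np.sum(np.abs(x) ** q) for q in range(1, Q + 1)]
    for j in range(J):
        u1 = np.abs(Psi[j] @ x)       # band-pass filter + pointwise modulus
        feats += [np.sum(u1 ** q) for q in range(1, Q + 1)]
        for jp in range(j + 1, J):    # second-order paths with j' > j
            u2 = np.abs(Psi[jp] @ u1)
            feats += [np.sum(u2 ** q) for q in range(1, Q + 1)]
    return np.array(feats)
```

Because the moments sum over nodes and the wavelets are permutation equivariant, relabeling the nodes of the input graph leaves the feature vector unchanged, matching the permutation invariance noted above.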
We note that other recent graph spectrum-based methods approach the learning of long-range connections by approximating the spectrum of the graph with the Lanczos algorithm (Liao et al., 2019), or by learning in block Krylov subspaces (Luan et al., 2019). Such methods are complementary to the work presented here, in that their spectral approximation can also be applied in the computation of geometric scattering when considering very long-range scales (e.g., via a spectral formulation of graph wavelet filters). However, we find that such approximations are not necessary for the datasets considered here and in other recent work focusing on whole-graph tasks, where direct computation of polynomials of the Laplacian is sufficient. The resulting learnable geometric scattering network balances the mathematical properties inherited from the scattering transform (as shown in Sec. 3) with the flexibility enabled by adaptive representation learning. The benefits of our construction over standard GNNs, as well as over pure geometric scattering, are discussed and demonstrated on graph classification and regression tasks in Sec. 5. In particular, we find that our network maintains the robustness to small training sets present in graph scattering while improving performance on biological graph classification and regression tasks, and we show that in tasks where the graphs have a large diameter relative to their size, learnable scattering features improve performance over competing methods.
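To illustrate the low-pass behavior attributed to GCN filters earlier in this section, the following sketch repeatedly applies the symmetrically normalized propagation matrix A_hat = D~^(-1/2) (A + I) D~^(-1/2) of Kipf & Welling (2016) to a node signal; with no learned weights or nonlinearities, the cascade rapidly flattens the signal toward a constant, which is the essence of oversmoothing. The ring graph and the signal are illustrative choices, not taken from the paper.

```python
import numpy as np

# Ring graph C6: adjacency of a 6-node cycle.
n = 6
A = np.zeros((n, n))
for i in range(n):
    A[i, (i + 1) % n] = A[(i + 1) % n, i] = 1.0

# Symmetrically normalized GCN filter: A_hat = D~^(-1/2) (A + I) D~^(-1/2),
# where D~ is the degree matrix of A + I (self-loops added).
A_tilde = A + np.eye(n)
d = A_tilde.sum(axis=1)
A_hat = A_tilde / np.sqrt(np.outer(d, d))

x = np.array([3.0, -1.0, 4.0, 1.0, -5.0, 9.0])  # arbitrary node signal
y = x.copy()
for _ in range(50):      # cascade of averaging layers (no weights, no ReLU)
    y = A_hat @ y        # each application suppresses high-frequency content

# After many layers the signal is nearly constant across nodes,
# so node-distinguishing (long-range, band-pass) information is lost.
print("std before:", np.std(x), "std after:", np.std(y))
```

On a regular graph such as this ring, all non-constant eigencomponents of A_hat have magnitude strictly below 1, so the non-constant part of the signal decays geometrically with depth; band-pass wavelet filters, by contrast, retain exactly this discarded content.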

