BEYOND GNNS: A SAMPLE-EFFICIENT ARCHITECTURE FOR GRAPH PROBLEMS

Abstract

Despite their popularity in learning problems over graph-structured data, existing Graph Neural Networks (GNNs) have inherent limitations for fundamental graph problems such as shortest paths, k-connectivity, minimum spanning tree, and minimum cut. In all these instances, it is known that one needs GNNs of high depth, scaling polynomially with the number of nodes n, to provably encode the solution space. This in turn affects their statistical efficiency, requiring a significant amount of training data to obtain networks with good performance. In this work we propose a new hybrid architecture to overcome this limitation. Our proposed architecture, which we call GNN+ networks, combines multiple parallel low-depth GNNs with simple pooling layers built from low-depth fully connected networks. We provably demonstrate that for many graph problems, the solution space can be encoded by GNN+ networks using depth that scales only poly-logarithmically in the number of nodes. This significantly reduces the amount of training data needed, which we establish via improved generalization bounds. Finally, we empirically demonstrate the effectiveness of our proposed architecture on a variety of graph problems.

1. INTRODUCTION

In recent years Graph Neural Networks (GNNs) have become the predominant paradigm for learning problems over graph-structured data (Hamilton et al., 2017; Kipf & Welling, 2016; Veličković et al., 2017). Computation in GNNs is performed by each node sending and receiving messages along the edges of the graph, and aggregating messages from its neighbors to update its own embedding vector. After a few rounds of message passing, the computed node embeddings are aggregated to compute the final output (Gilmer et al., 2017). The analogy to message passing leads to a simple and elegant architecture for learning functions on graphs. On the other hand, from a theoretical and practical perspective, we also need these architectures to be sample efficient, i.e., learnable from a small number of training examples, where each training example corresponds to a graph. Recent works have shown that generalization in GNNs depends on the depth of the architecture, i.e., the number of rounds of message passing, as well as the embedding size of each node in the graph (Garg et al., 2020). However, this requirement is in fundamental conflict with the message-passing framework. In particular, using GNNs to compute several fundamental graph problems such as shortest paths, minimum spanning tree, and minimum cut necessarily requires the product of the depth of the GNN and the embedding size to scale as √n, where n is the size of the graph (Loukas, 2020). This in turn places a significant statistical burden when learning these fundamental problems on large-scale graphs. The above raises the following question: Can one develop sample-efficient architectures for graph problems while retaining the simplicity of the message-passing framework?
Several recent works have tried to address this question by proposing extensions to the basic GNN framework that augment message-passing rounds with various pooling operations to capture more global structure (Ying et al., 2018; Simonovsky & Komodakis, 2017; Fey et al., 2018). While these works demonstrate an empirical advantage over GNNs, we currently do not know of a general neural architecture that is versatile enough to provably encode the solution space of a variety of graph problems, such as shortest paths and minimum spanning trees, while being significantly superior to GNNs in terms of statistical efficiency. In this work we propose a theoretically principled architecture, called GNN+ networks, for learning graph problems. While the basic GNN framework is inspired by classical message-passing models studied in distributed computing, we borrow from two fundamental paradigms in graph algorithm design, namely sketching and parallel computation, to design GNN+ networks. Combining these two powerful paradigms yields a new neural architecture that simultaneously achieves low depth and low embedding size for many fundamental graph problems. Consequently, our proposed GNN+ architecture has a significantly smaller number of parameters, which provably leads to better statistical efficiency than GNNs. Before we present our improved architecture, we briefly describe the standard GNN framework.

Model for GNNs. In this work we study GNNs that fall within the message-passing framework and, using notation from previous work, we denote such networks as GNN_mp (Loukas, 2020). A GNN_mp network operates in the AGGREGATE and COMBINE model (Gilmer et al., 2017), which captures many popular variants such as GraphSAGE, Graph Convolutional Networks (GCNs), and GIN networks (Hamilton et al., 2017; Kipf & Welling, 2016; Xu et al., 2019a). Given a graph G = (V, E), let x_i^{(k)} denote the feature representation of node i at layer k.
Then we have

a_i^{(k-1)} = AGGREGATE({x_j^{(k-1)} : j ∈ N(i)}),    (1)
x_i^{(k)} = COMBINE(x_i^{(k-1)}, a_i^{(k-1)}).    (2)

Here N(i) is the set of neighbors of node i. Typically the aggregation and combination are performed via simple one- or two-layer fully connected networks (FNNs), also known as multilayer perceptrons (MLPs); in the rest of the paper we use the two terms interchangeably.

GNN+ Networks. Our proposed GNN+ networks consist of one or more layers of the GNN+ block shown in Figure 1. The GNN+ block comprises r parallel GNN_mp networks followed by s parallel fully connected network modules for pooling, where r and s are hyperparameters of the architecture. Importantly, we restrict the r GNN_mp modules to share the same set of weights; hence the parallel GNN_mp modules differ only in how their node embeddings are initialized. Furthermore, we restrict each GNN_mp to be of low depth: for degree-d graphs of diameter D over n nodes, we restrict each GNN_mp to depth O((d + D) · polylog(n)). Similarly, we require the s fully connected networks to be of depth O((d + D) · polylog(n)) and to share their weights. We connect the outputs of the GNN_mp modules to the fully connected pooling networks in a sparse manner and restrict the input size of each fully connected network to O((d + D) · polylog(n)). Stacking L layers of GNN+ blocks results in a GNN+ network that is highly parameter efficient, with O((d + D)L · polylog(n)) parameters in total. For such a network we define the depth as the total number of message-passing rounds plus the number of MLP layers used across all L stacks. Since we restrict the MLPs and GNN_mp blocks inside a GNN+ network to be of low depth, we will often abuse notation and refer to a GNN+ architecture with L stacks of GNN+ blocks as a depth-L architecture.
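The structure above can be sketched in a few lines of NumPy. This is a minimal illustration, not the paper's implementation: function names are hypothetical, AGGREGATE is simplified to a neighbor sum, COMBINE to a single ReLU layer, and the sparse wiring between the GNN_mp outputs and the pooling MLPs is simplified to a dense concatenation.

```python
import numpy as np

def gnn_mp_layer(X, adj, W_comb):
    """One AGGREGATE/COMBINE round of a GNN_mp (simplified sketch)."""
    # AGGREGATE: sum the embeddings of each node's neighbors N(i)
    A = adj @ X
    # COMBINE: one-layer MLP (ReLU) applied to [x_i ; a_i]
    H = np.concatenate([X, A], axis=1)
    return np.maximum(H @ W_comb, 0.0)

def gnn_plus_block(inits, adj, gnn_weights, pool_weights):
    """Hypothetical GNN+ block: r parallel weight-shared GNN_mp
    networks (differing only in node-embedding initialization),
    followed by weight-shared fully connected pooling layers."""
    outs = []
    for X in inits:                  # r parallel copies, same gnn_weights
        for W in gnn_weights:        # low-depth message passing
            X = gnn_mp_layer(X, adj, W)
        outs.append(X)
    Z = np.concatenate(outs, axis=1)  # wiring to the pooling stage
    for W in pool_weights:            # shared-weight pooling MLP
        Z = np.maximum(Z @ W, 0.0)
    return Z
```

Note that because the r GNN_mp copies reuse `gnn_weights` and the pooling stage reuses `pool_weights`, the parameter count is independent of r, which is the source of the parameter efficiency discussed above.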
Our proposed design lets us alternate between local computations involving multiple parallel GNN blocks and global post-processing stages, while remaining sample efficient due to the enforced parameter sharing. We will show via several applications that optimal or near-optimal solutions to many popular graph problems can indeed be computed by a GNN+ architecture. Below we briefly summarize our main results.

Overview of Results. To demonstrate the generality of our proposed GNN+ architecture, we study several fundamental graph problems and show how to construct efficient GNN+ networks that compute optimal or near-optimal solutions. In particular, we focus on degree-d graphs, i.e., graphs of maximum node degree d, with n nodes and diameter D, and construct GNN+ networks of depth polylog(n) with O((D + d) · polylog(n)) total parameters.

Shortest Paths. The first problem we consider is the fundamental graph problem of computing (approximate) all-pairs shortest paths in undirected graphs. Given a graph G = (V, E), let d_G(u, v) be the shortest-path distance between nodes u and v. We say that an output {d̂_G(u, v) : u, v ∈ V} is an



Figure 1: The basic GNN + block.

