LMC: FAST TRAINING OF GNNS VIA SUBGRAPH-WISE SAMPLING WITH PROVABLE CONVERGENCE

Abstract

Message passing-based graph neural networks (GNNs) have achieved great success in many real-world applications. However, training GNNs on large-scale graphs suffers from the well-known neighbor explosion problem, i.e., the exponential growth of node dependencies with the number of message passing layers. Subgraph-wise sampling methods, a promising class of mini-batch training techniques, discard messages outside the mini-batches in backward passes to avoid the neighbor explosion problem, at the expense of gradient estimation accuracy. This poses significant challenges to their convergence analysis and convergence speeds, which severely limits their reliable deployment in real-world applications. To address this challenge, we propose a novel subgraph-wise sampling method with a convergence guarantee, namely Local Message Compensation (LMC). To the best of our knowledge, LMC is the first subgraph-wise sampling method with provable convergence. The key idea of LMC is to retrieve the discarded messages in backward passes based on a message passing formulation of backward passes. Through efficient and effective compensation for the discarded messages in both forward and backward passes, LMC computes accurate mini-batch gradients and thus accelerates convergence. We further show that LMC converges to first-order stationary points of GNNs. Experiments on large-scale benchmark tasks demonstrate that LMC significantly outperforms state-of-the-art subgraph-wise sampling methods in terms of efficiency.

1. INTRODUCTION

Graph neural networks (GNNs) are powerful frameworks that generate node embeddings for graphs via the iterative message passing (MP) scheme (Hamilton, 2020). At each MP layer, GNNs aggregate messages from each node's neighborhood and then update node embeddings based on the aggregation results. This scheme has achieved great success in many real-world applications involving graph-structured data, such as search engines (Brin & Page, 1998), recommendation systems (Fan et al., 2019), materials engineering (Gostick et al., 2016), molecular property prediction (Moloi & Ali, 2005; Kearnes et al., 2016), and combinatorial optimization (Wang et al., 2023).

However, the iterative MP scheme poses challenges to training GNNs on large-scale graphs. A common approach to scale deep models to arbitrarily large data with limited GPU memory is to approximate full-batch gradients by mini-batch gradients. Nevertheless, for graph-structured data, computing the loss over a mini-batch of nodes and the corresponding mini-batch gradients is expensive due to the well-known neighbor explosion problem. Specifically, the embedding of a node at the k-th MP layer recursively depends on the embeddings of its neighbors at the (k-1)-th MP layer, so the complexity grows exponentially with the number of MP layers.

To deal with the neighbor explosion problem, recent works propose various sampling techniques to reduce the number of nodes involved in message passing (Ma & Tang, 2021). For example, node-wise (Hamilton et al., 2017; Chen et al., 2018a) and layer-wise (Chen et al., 2018b; Zou et al., 2019; Huang et al., 2018) sampling methods recursively sample neighbors over MP layers to estimate node embeddings and the corresponding mini-batch gradients. Unlike the recursive fashion, subgraph-wise sampling methods (Chiang et al., 2019; Zeng et al., 2020; Fey et al., 2021; Zeng et al., 2021) adopt a
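As a toy illustration (not part of the original paper), the recursive dependency underlying the neighbor explosion problem can be sketched in a few lines of Python: the set of input nodes that a target node's embedding depends on grows by one hop per MP layer, which is exponential on graphs with branching neighborhoods. The graph and helper below are hypothetical.

```python
from collections import defaultdict

def receptive_field(adj, targets, num_layers):
    """Return the set of nodes whose input features the embeddings of
    `targets` depend on after `num_layers` rounds of message passing."""
    field = set(targets)
    for _ in range(num_layers):
        # Each MP layer pulls in one more hop of neighbors.
        field = field | {nbr for v in field for nbr in adj[v]}
    return field

# Hypothetical toy graph: a complete binary tree on 15 nodes (undirected).
adj = defaultdict(set)
for child in range(1, 15):
    parent = (child - 1) // 2
    adj[parent].add(child)
    adj[child].add(parent)

# Receptive field of the root grows exponentially with the layer count.
sizes = [len(receptive_field(adj, {0}, k)) for k in range(4)]
print(sizes)  # → [1, 3, 7, 15]
```

On this tree each extra layer roughly doubles the number of dependencies, mirroring the exponential cost that sampling methods are designed to avoid.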

