LMC: FAST TRAINING OF GNNS VIA SUBGRAPH-WISE SAMPLING WITH PROVABLE CONVERGENCE

Abstract

Message passing-based graph neural networks (GNNs) have achieved great success in many real-world applications. However, training GNNs on large-scale graphs suffers from the well-known neighbor explosion problem, i.e., the dependencies of nodes grow exponentially with the number of message passing layers. Subgraph-wise sampling methods, a promising class of mini-batch training techniques, discard messages outside the mini-batches in backward passes to avoid the neighbor explosion problem at the expense of gradient estimation accuracy. This poses significant challenges to their convergence analysis and convergence speeds, which seriously limits their reliable real-world applications. To address this challenge, we propose a novel subgraph-wise sampling method with a convergence guarantee, namely Local Message Compensation (LMC). To the best of our knowledge, LMC is the first subgraph-wise sampling method with provable convergence. The key idea of LMC is to retrieve the discarded messages in backward passes based on a message passing formulation of backward passes. With efficient and effective compensations for the discarded messages in both forward and backward passes, LMC computes accurate mini-batch gradients and thus accelerates convergence. We further show that LMC converges to first-order stationary points of GNNs. Experiments on large-scale benchmark tasks demonstrate that LMC significantly outperforms state-of-the-art subgraph-wise sampling methods in terms of efficiency.

1. INTRODUCTION

Graph neural networks (GNNs) are powerful frameworks that generate node embeddings for graphs via the iterative message passing (MP) scheme (Hamilton, 2020). At each MP layer, GNNs aggregate messages from each node's neighborhood and then update node embeddings based on the aggregation results. This scheme has achieved great success in many real-world applications involving graph-structured data, such as search engines (Brin & Page, 1998), recommendation systems (Fan et al., 2019), materials engineering (Gostick et al., 2016), molecular property prediction (Moloi & Ali, 2005; Kearnes et al., 2016), and combinatorial optimization (Wang et al., 2023).

However, the iterative MP scheme poses challenges to training GNNs on large-scale graphs. A common approach to scaling deep models to arbitrarily large data with limited GPU memory is to approximate full-batch gradients by mini-batch gradients. Nevertheless, for graph-structured data, computing the loss on a mini-batch of nodes and the corresponding mini-batch gradients is expensive due to the well-known neighbor explosion problem. Specifically, the embedding of a node at the k-th MP layer recursively depends on the embeddings of its neighbors at the (k-1)-th MP layer, so the complexity grows exponentially with the number of MP layers. To deal with the neighbor explosion problem, recent works propose various sampling techniques to reduce the number of nodes involved in message passing (Ma & Tang, 2021). For example, node-wise (Hamilton et al., 2017; Chen et al., 2018a) and layer-wise (Chen et al., 2018b; Zou et al., 2019; Huang et al., 2018) sampling methods recursively sample neighbors over MP layers to estimate node embeddings and the corresponding mini-batch gradients.
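To make the neighbor explosion concrete, the following sketch counts how many nodes a single seed node's embedding depends on as the number of MP layers grows. The tree-like toy graph and the function name `receptive_field` are illustrative assumptions, not constructions from the paper; the point is only that the k-layer receptive field is the k-hop neighborhood.

```python
def receptive_field(adj, seeds, num_layers):
    """Nodes whose input features influence the seed embeddings after
    `num_layers` rounds of message passing, i.e., the k-hop neighborhood."""
    frontier = set(seeds)
    field = set(seeds)
    for _ in range(num_layers):
        frontier = {v for u in frontier for v in adj[u]} - field
        field |= frontier
    return field

# Toy tree-like graph: node i is connected to 2*i+1 and 2*i+2 (undirected).
n = 1023
edges = [(i, c) for i in range(n) for c in (2 * i + 1, 2 * i + 2) if c < n]
adj = {i: set() for i in range(n)}
for u, v in edges:
    adj[u].add(v)
    adj[v].add(u)

for k in range(1, 6):
    # The dependency set roughly doubles with every extra MP layer.
    print(k, len(receptive_field(adj, {0}, k)))
```

On this graph the receptive field of node 0 grows as 3, 7, 15, ... with the layer count, which is the exponential blow-up that sampling methods try to avoid.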
Unlike the recursive fashion, subgraph-wise sampling methods (Chiang et al., 2019; Zeng et al., 2020; Fey et al., 2021; Zeng et al., 2021) adopt a cheap and simple one-shot sampling fashion, i.e., they sample the same subgraph, constructed from a mini-batch, for all MP layers. By discarding messages outside the mini-batches, subgraph-wise sampling methods restrict message passing to the mini-batches, so the complexity grows linearly with the number of MP layers. Moreover, subgraph-wise sampling methods are applicable to a wide range of GNN architectures, as they directly run GNNs on the subgraphs constructed from the sampled mini-batches (Fey et al., 2021). Because of these advantages, subgraph-wise sampling methods have recently drawn increasing attention.

Despite the empirical success of subgraph-wise sampling methods, discarding messages outside the mini-batches sacrifices gradient estimation accuracy, which poses significant challenges to their convergence behaviors. First, recent works (Chen et al., 2018a; Cong et al., 2020) demonstrate that inaccurate mini-batch gradients seriously hurt the convergence speeds of GNNs. Second, in Section 7.3, we demonstrate that many subgraph-wise sampling methods struggle to match full-batch performance under the small batch sizes typically used to avoid running out of GPU memory in practice. These issues seriously limit the real-world applications of GNNs. In this paper, we propose a novel subgraph-wise sampling method with a convergence guarantee, namely Local Message Compensation (LMC), which uses efficient and effective compensations to correct the biases of mini-batch gradients and thus accelerates convergence. To the best of our knowledge, LMC is the first subgraph-wise sampling method with provable convergence.
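The one-shot fashion can be sketched as follows: every MP layer runs on the same induced subgraph, and messages from out-of-batch neighbors are simply dropped. This is a minimal illustration with mean aggregation and no learnable weights; the function name `subgraph_forward` and the toy graph are assumptions for exposition, not the paper's implementation.

```python
import numpy as np

def subgraph_forward(x, adj, batch, num_layers):
    """One-shot subgraph-wise sampling: all MP layers run on the subgraph
    induced by `batch`; messages from out-of-batch neighbors are discarded."""
    nodes = sorted(batch)
    idx = {v: i for i, v in enumerate(nodes)}
    h = x[nodes].astype(float)  # features restricted to the mini-batch
    for _ in range(num_layers):
        new_h = np.empty_like(h)
        for v in nodes:
            # Only in-batch neighbors contribute (plus a self message).
            nbrs = [idx[u] for u in adj[v] if u in idx] + [idx[v]]
            new_h[idx[v]] = h[nbrs].mean(axis=0)
        h = new_h
    return h  # embeddings for the mini-batch only

x = np.arange(4, dtype=float).reshape(4, 1)   # path graph 0-1-2-3
adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
out = subgraph_forward(x, adj, {0, 1}, 1)
# Node 1's message from out-of-batch node 2 is silently dropped,
# which is exactly the source of the gradient bias discussed above.
print(out)
```

Note the cost per layer depends only on the batch and its internal edges, hence the linear (rather than exponential) growth in the number of MP layers.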
Specifically, we first propose unbiased mini-batch gradients for the one-shot sampling fashion, which helps decompose the gradient computation errors into two components: the bias from the discarded messages and the variance of the unbiased mini-batch gradients. Second, based on a message passing formulation of backward passes, we retrieve the messages discarded by existing subgraph-wise sampling methods when approximating the unbiased mini-batch gradients. Finally, we propose efficient and effective compensations for the discarded messages that combine incomplete up-to-date messages with messages generated from historical information in previous iterations, avoiding the exponentially growing time and memory consumption. An appealing feature of the resulting mechanism is that it effectively corrects the biases of mini-batch gradients, leading to accurate gradient estimation and faster convergence. We further show that LMC converges to first-order stationary points of GNNs. Notably, the convergence of LMC relies only on the interactions between mini-batch nodes and their 1-hop neighbors, without the recursive expansion of neighborhoods to aggregate information far away from the mini-batches. Experiments on large-scale benchmark tasks demonstrate that LMC significantly outperforms state-of-the-art subgraph-wise sampling methods in terms of efficiency. Moreover, under small batch sizes, LMC outperforms the baselines and matches the prediction performance of full-batch methods.
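A highly simplified sketch of the forward-pass compensation idea: in-batch neighbors send exact up-to-date messages, while each out-of-batch 1-hop neighbor is compensated with a convex combination of its incomplete up-to-date embedding (aggregated from its in-batch neighbors only) and its stored historical embedding. Everything here is an illustrative assumption: the names `lmc_layer` and `compensate`, the mean aggregator, and the fixed mixing weight `beta` all simplify the paper's actual mechanism, which also compensates backward passes and is what the convergence analysis covers.

```python
import numpy as np

def compensate(u, h_batch, hist, adj, idx, beta):
    """Out-of-batch node u: mix its incomplete up-to-date embedding
    (from in-batch neighbors only) with its historical embedding."""
    in_batch = [idx[w] for w in adj[u] if w in idx]
    if not in_batch:
        return hist[u]  # no fresh information; fall back to history
    up_to_date = h_batch[in_batch].mean(axis=0)
    return beta * up_to_date + (1.0 - beta) * hist[u]

def lmc_layer(h_batch, hist, adj, batch, beta=0.5):
    """One MP layer with compensated messages for out-of-batch neighbors."""
    nodes = sorted(batch)
    idx = {v: i for i, v in enumerate(nodes)}
    out = np.empty_like(h_batch)
    for v in nodes:
        msgs = [h_batch[idx[v]]]                     # self message
        for u in adj[v]:
            if u in idx:
                msgs.append(h_batch[idx[u]])         # exact in-batch message
            else:
                msgs.append(compensate(u, h_batch, hist, adj, idx, beta))
        out[idx[v]] = np.mean(msgs, axis=0)
    hist[nodes] = out  # refresh stored history for in-batch nodes
    return out

h_batch = np.array([[0.0], [1.0]])                   # mini-batch {0, 1}
hist = np.array([[0.0], [1.0], [2.0], [3.0]])        # historical embeddings
adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}         # path graph 0-1-2-3
out = lmc_layer(h_batch, hist, adj, {0, 1})
```

Crucially, only 1-hop neighbors of the mini-batch are ever touched, so no recursive neighborhood expansion occurs, in line with the convergence argument sketched above.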

2. RELATED WORK

In this section, we discuss works related to our proposed method.

Subgraph-wise Sampling Methods. Subgraph-wise sampling methods sample a mini-batch and then construct the same subgraph based on it for different MP layers (Ma & Tang, 2021). For example, Cluster-GCN (Chiang et al., 2019) and GraphSAINT (Zeng et al., 2020) construct the subgraph induced by a sampled mini-batch. They encourage connections between the sampled nodes by graph clustering methods (e.g., METIS (Karypis & Kumar, 1998) and Graclus (Dhillon et al., 2007)), or by edge-, node-, or random-walk-based samplers. GNNAutoScale (GAS) (Fey et al., 2021) and MVS-GNN (Cong et al., 2020) use historical embeddings to generate messages outside a sampled subgraph, maintaining the expressiveness of the original GNNs.

Recursive Graph Sampling Methods. Both node-wise and layer-wise sampling methods recursively sample neighbors over MP layers and then construct different computation graphs for each MP layer. Node-wise sampling methods (Hamilton et al., 2017; Chen et al., 2018a) aggregate messages from a small subset of sampled neighbors at each MP layer to reduce the base of the exponentially growing dependencies. To avoid the exponentially growing computation, layer-wise sampling methods (Chen et al., 2018b; Zou et al., 2019; Huang et al., 2018) independently sample nodes for each MP layer and then use importance sampling to reduce variance, resulting in a constant sample size per MP layer.

Pre-Processing Methods. Another line of work on scalable graph neural networks develops pre-processing methods, which aggregate the raw input features and then feed the pre-processed features into subsequent models. As the aggregation has no parameters, these methods can use stochastic gradient descent to train the subsequent models without the neighbor explosion problem. While

