JOINT EDGE-MODEL SPARSE LEARNING IS PROVABLY EFFICIENT FOR GRAPH NEURAL NETWORKS

Abstract

Due to the significant computational challenge of training large-scale graph neural networks (GNNs), various sparse learning techniques have been exploited to reduce memory and storage costs. Examples include graph sparsification, which samples a subgraph to reduce the amount of data aggregation, and model sparsification, which prunes the neural network to reduce the number of trainable weights. Despite the empirical successes in reducing the training cost while maintaining the test accuracy, the theoretical generalization analysis of sparse learning for GNNs remains elusive. To the best of our knowledge, this paper provides the first theoretical characterization of joint edge-model sparse learning from the perspective of sample complexity and convergence rate in achieving zero generalization error. It proves analytically that both sampling important nodes and pruning the lowest-magnitude neurons can reduce the sample complexity and improve convergence without compromising the test accuracy. Although the analysis is centered on two-layer GNNs with structural constraints on data, the insights are applicable to more general setups and are supported by experiments on both synthetic data and practical citation datasets.

1. INTRODUCTION

Graph neural networks (GNNs) can represent graph-structured data effectively and find applications in object detection (Shi & Rajkumar, 2020; Yan et al., 2018), recommendation systems (Ying et al., 2018; Zheng et al., 2021), relational learning (Schlichtkrull et al., 2018), and machine translation (Wu et al., 2020; 2016). However, training GNNs directly on large-scale graphs such as scientific citation networks (Hull & King, 1987; Hamilton et al., 2017; Xu et al., 2018), social networks (Kipf & Welling, 2017; Sandryhaila & Moura, 2014; Jackson, 2010), and symbolic networks (Riegel et al., 2020) becomes computationally challenging or even infeasible, resulting from both the exponential aggregation of neighboring features and the excessive model complexity. For example, training a two-layer GNN on the Reddit dataset (Tailor et al., 2020), which contains 232,965 nodes with an average degree of 492, can cost twice as much computation as training ResNet-50 on ImageNet (Canziani et al., 2016). Approaches to accelerate GNN training fall into two paradigms: (i) sparsifying the graph topology (Hamilton et al., 2017; Chen et al., 2018; Perozzi et al., 2014; Zou et al., 2019), and (ii) sparsifying the network model (Chen et al., 2021b; You et al., 2022). Sparsifying the graph topology means selecting a subgraph in place of the original graph to reduce the computation of neighborhood aggregation. One can either use a fixed subgraph that preserves, e.g., the graph topology (Hübler et al., 2008), the graph shift operator (Adhikari et al., 2017; Chakeri et al., 2016), or the degree distribution (Leskovec & Faloutsos, 2006; Voudigari et al., 2016; Eden et al., 2018), or apply sampling algorithms, such as edge sparsification (Hamilton et al., 2017) or node sparsification (Chen et al., 2018; Zou et al., 2019), to select a different subgraph in each iteration.
Sparsifying the network model means reducing the complexity of the neural network, including removing the non-linear activation (Wu et al., 2019; He et al., 2020), quantizing neuron weights (Tailor et al., 2020; Bahri et al., 2021) and the outputs of intermediate layers (Liu et al., 2021), pruning the network (Frankle & Carbin, 2019), or knowledge distillation (Yang et al., 2020; Hinton et al., 2015; Yao et al., 2020; Jaiswal et al., 2021). The two sparsification frameworks can also be combined, as in the joint edge sampling and network pruning of Chen et al. (2021b); You et al. (2022). Despite many empirical successes in accelerating GNN training without sacrificing test accuracy, the theoretical evaluation of training GNNs with sparsification techniques remains largely unexplored. Most theoretical analyses center on the expressive power of sampled graphs (Hamilton et al., 2017; Cong et al., 2021; Chen et al., 2018; Zou et al., 2019; Rong et al., 2019) or pruned networks (Malach et al., 2020; Zhang et al., 2021; da Cunha et al., 2022). However, there is limited generalization analysis, i.e., analysis of whether the learned model performs well on testing data. Most existing generalization analyses are limited to two-layer cases, even for the simplest feed-forward neural networks (NNs); see, e.g., (Zhang et al., 2020a; Oymak & Soltanolkotabi, 2020; Huang et al., 2021; Shi et al., 2022). The linearized model cannot justify the advantages of using multi-layer (G)NNs and network pruning. As far as we know, there is no finite-sample generalization analysis for the joint sparsification, even for two-layer GNNs.

Contributions.
This paper provides the first theoretical generalization analysis of joint topology-model sparsification in training GNNs, including (1) explicit bounds on the required number of known labels, referred to as the sample complexity, and on the convergence rate of stochastic gradient descent (SGD) toward a model that predicts the unknown labels accurately; and (2) a quantitative proof that joint topology and model sparsification is a win-win strategy, improving the learning performance from both the sample-complexity and convergence-rate perspectives. We consider the following problem setup to establish our theoretical analysis: node classification with a one-hidden-layer GNN, assuming that some node features are class-relevant (Shi et al., 2022) and determine the labels, while the remaining node features are class-irrelevant and contain no information about the labels; the label of a node is affected by the class-relevant features of its neighbors. This structural constraint on the data captures the phenomenon that some nodes are more influential than others, as in social networks (Chen et al., 2018; Veličković et al., 2018), as well as the case where the graph contains redundant information (Zheng et al., 2020). Specifically, the sample complexity is quadratic in (1 − β)/α, where α ∈ (0, 1] is the probability of sampling nodes with class-relevant features (a larger α means class-relevant features are sampled more frequently), and β ∈ [0, 1) is the fraction of neurons pruned from the network model by a magnitude-based pruning method such as that of Frankle & Carbin (2019). The number of SGD iterations needed to reach a desirable model is linear in (1 − β)/α. Therefore, our results formally prove that graph sampling reduces both the sample complexity and the number of iterations, and does so more significantly when nodes with class-relevant features are sampled more frequently.
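To make the pruning parameter β concrete, the sketch below zeroes out the fraction β of hidden neurons with the smallest weight magnitude, in the spirit of the magnitude-based pruning of Frankle & Carbin (2019). This is a minimal illustration, not the paper's exact construction; the function name and the use of row-wise L2 norms as the magnitude criterion are our own choices.

```python
import numpy as np

def magnitude_prune(W, beta):
    """Zero out the fraction `beta` of neurons (rows of W) with the
    smallest L2 norm; the remaining (1 - beta) fraction is kept."""
    K = W.shape[0]                         # number of hidden neurons
    norms = np.linalg.norm(W, axis=1)      # per-neuron weight magnitude
    num_pruned = int(beta * K)
    pruned = np.argsort(norms)[:num_pruned]  # smallest-magnitude neurons
    mask = np.ones(K, dtype=bool)
    mask[pruned] = False                   # False = pruned
    return W * mask[:, None], mask

rng = np.random.default_rng(0)
W = rng.normal(size=(8, 5))                # 8 hidden neurons, 5-dim input
W_pruned, mask = magnitude_prune(W, beta=0.25)
```

With β = 0.25 and K = 8 neurons, 2 neurons are zeroed out and 6 survive; the analysis above says the sample complexity then scales with ((1 − β)/α)².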
The intuition is that importance sampling helps the algorithm learn the class-relevant features more efficiently and thus reduces both the sample requirement and the convergence time. The same learning improvement is also observed as the pruning rate increases, as long as β does not exceed a threshold close to 1.

Consider an undirected graph G(V, E), where V is the set of nodes and E is the set of edges. Let R denote the maximum node degree. For any node v ∈ V, let x_v ∈ R^d and y_v ∈ {+1, −1} denote its input feature and its corresponding label, respectively. Given all node features {x_v}_{v∈V} and the partially known labels {y_v}_{v∈D} for nodes in a subset D ⊂ V, the semi-supervised node classification problem aims to predict the unknown labels y_v for all v ∈ V \ D.
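The importance-sampling step discussed above can be sketched as follows: each node aggregates features from at most r neighbors drawn according to importance weights, so that nodes carrying class-relevant features (larger weights) are sampled more often. This is an illustrative sketch only; the mean aggregator, the function name, and the per-node weight vector `probs` are our assumptions, not the paper's exact sampling scheme.

```python
import numpy as np

def sample_and_aggregate(X, neighbors, probs, r, rng):
    """One GNN aggregation step with importance sampling: for each node v,
    draw up to `r` of its neighbors with probability proportional to
    `probs`, then average the sampled features with x_v (mean aggregator)."""
    out = np.zeros_like(X)
    for v, nbrs in enumerate(neighbors):
        nbrs = np.asarray(nbrs)
        p = probs[nbrs] / probs[nbrs].sum()   # renormalize over v's neighbors
        k = min(r, len(nbrs))                 # sample at most r neighbors
        sampled = rng.choice(nbrs, size=k, replace=False, p=p)
        out[v] = (X[v] + X[sampled].sum(axis=0)) / (k + 1)
    return out

# Toy graph: node 0 is adjacent to nodes 1 and 2; uniform importance weights.
X = np.eye(3)
neighbors = [[1, 2], [0], [0]]
probs = np.ones(3)
out = sample_and_aggregate(X, neighbors, probs, r=2, rng=np.random.default_rng(0))
```

Raising the weight `probs[u]` of a node u with class-relevant features increases α, the probability that such features enter the aggregation, which is exactly the quantity the sample-complexity bound improves with.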

Footnote: The analysis can be extended to multi-class classification; see Appendix I.

To the best of our knowledge, only Li et al. (2022) and Allen-Zhu et al. (2019a) go beyond two layers, by considering three-layer GNNs and NNs, respectively. However, Li et al. (2022) requires a strong assumption, which cannot be justified empirically or theoretically, that the sampled graph indeed preserves the mapping from data to labels. Moreover, Li et al. (2022) and Allen-Zhu et al. (2019a) focus on a linearized model around the initialization, and the learned weights only stay near the initialization (Allen-Zhu & Li, 2022).

Figure 1: Illustration of node classification in the GNN

