PROTOGNN: PROTOTYPE-ASSISTED MESSAGE PASSING FRAMEWORK FOR NON-HOMOPHILOUS GRAPHS

Abstract

Many well-known Graph Neural Network (GNN) models assume that the underlying graphs are homophilous, i.e. nodes share similar features and labels with their neighbours. Such models rely on message passing that iteratively aggregates neighbours' features, and they often suffer performance degradation on non-homophilous graphs, where useful information is hardly available in the local neighbourhood. In addition, earlier studies show that in some cases GNNs are even outperformed by a Multi-Layer Perceptron, indicating insufficient exploitation of node feature information. Motivated by these two limitations, we propose ProtoGNN, a novel message-passing framework that augments existing GNNs by effectively combining node features with structural information. ProtoGNN learns multiple prototypes for each class from the raw node features using the slot-attention mechanism. These prototype representations are then transferred onto the structural node features through explicit message passing to all non-training nodes, irrespective of distance. This form of message passing, from training nodes to class prototypes to non-training nodes, also serves as a shortcut that bypasses local graph neighbourhoods and captures global information. ProtoGNN is a generic framework that can be applied to any existing GNN backbone to improve node representations when node features are strong and local graph information is scarce. We demonstrate through extensive experiments that ProtoGNN improves the performance of various GNN backbones and achieves state-of-the-art results on several non-homophilous datasets.

1. INTRODUCTION

Graph Neural Networks (GNNs) have emerged as prominent models for learning representations on graph-structured data. GNNs iteratively update each node's embedding based on its own features and those of its local neighbours on the graph (Kipf & Welling, 2016; Defferrard et al., 2016; Gilmer et al., 2017). This form of iterative message passing provides a strong structural inductive bias that assumes the presence of useful information in local neighbourhoods. The assumption holds for homophilous graphs, in which connected nodes tend to share similar features and labels. However, many real-world graphs are structured in a heterophilous way, wherein nodes tend to be dissimilar to their local neighbours. Such non-homophilous graphs arise in fraud detection (Shi et al., 2022; Pandit et al., 2007), molecular biology (Ye et al., 2022) and certain social networks (Lim et al., 2021). Recent studies show that many GNNs (e.g. GCN (Kipf & Welling, 2016)) fail to learn well on non-homophilous graphs (Zhu et al., 2020b; 2021), and are even outperformed by simple architectures such as the Multi-Layer Perceptron (MLP), which ignores graph structure and leverages only node features. This indicates two limitations of these GNNs. First, node feature information is not sufficiently exploited and may be diluted by iterative message passing. Second, the prevailing structural inductive bias that exploits local neighbourhood information may not be helpful on non-homophilous graphs, since useful information is often scarce within the local neighbourhood. It is therefore essential to go beyond the local neighbourhood for useful information. Existing works either explicitly capture multi-hop information (Zhu et al., 2020b) or try to leverage globally available correlations with distant nodes (Suresh et al., 2021; Li et al., 2022). However, such methods often exceed the linear complexity of GNN message passing (Suresh et al., 2021).
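To make the local-aggregation assumption concrete, the following is a minimal sketch of one round of message passing in PyTorch. The function name `message_passing_step` and the simple mean-aggregation rule are illustrative assumptions on our part, not the update rule of any particular model cited above:

```python
import torch

def message_passing_step(x, edge_index):
    """One round of mean-aggregation message passing (illustrative sketch).

    x: (N, d) node features; edge_index: (2, E) directed edges (src, dst).
    Each node averages its in-neighbours' features together with its own.
    """
    src, dst = edge_index
    # Sum incoming neighbour features per destination node.
    agg = torch.zeros_like(x)
    agg.index_add_(0, dst, x[src])
    # In-degree of each node, for the mean (self counts as one).
    deg = torch.zeros(x.size(0)).index_add_(0, dst, torch.ones(src.size(0)))
    return (x + agg) / (1.0 + deg).unsqueeze(1)
```

Stacking k such rounds confines each node's receptive field to its k-hop neighbourhood, which is precisely why this scheme struggles when useful information lies outside the local neighbourhood.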
More importantly, since relevant information is often scattered among a vast number of distant nodes, identifying and extracting it is challenging. To address the two limitations, we draw inspiration from prototypical networks (Snell et al., 2017), an efficient and effective paradigm that constructs a single representation per class, called a prototype, to assist learning in data-scarce settings. Learning node representations on graph structures with class prototypes can facilitate information flow from distant nodes and capture global correlations efficiently. We observe that in many datasets, nodes from the same class form multiple clusters (see Fig. 1(a)). If we use a single prototype per class, the prototype representation would be inaccurate or insufficient. For example, the single prototype may learn to represent the average of different clusters. Alternatively, it may learn to represent one of the dominant clusters in the feature space, thereby failing to capture information about the remaining clusters. Hence a single prototype per class may lead to suboptimal performance.

In this work, we propose ProtoGNN, a novel message-passing framework that augments existing GNNs and overcomes the two limitations on heterophilous graphs by facilitating efficient information transfer from distant nodes. In ProtoGNN, we first disentangle the two sources of information in the graph into a node feature view and a structure view. From the node features of the training nodes, we construct multiple prototypes for each class, independent of graph structure, by adapting the slot-attention mechanism (Locatello et al., 2020). These class prototypes are learnt from the node feature space via multiple rounds of attention. This provides an effective way of soft-clustering the node features, as well as the model capacity needed for learning different clusters within the same class.
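A minimal sketch of how multiple prototypes per class could be learnt with slot-attention-style iterated updates is shown below. The function `learn_prototypes`, its weighted-mean update, and all hyperparameters are our illustrative assumptions, not ProtoGNN's exact implementation:

```python
import torch
import torch.nn.functional as F

def learn_prototypes(feats, slots, n_iters=3, temp=1.0):
    """Soft-cluster training-node features into per-class prototype slots.

    feats: (N, d) features of the training nodes belonging to one class.
    slots: (K, d) initial prototype slots for that class.
    Returns updated (K, d) prototypes after several attention rounds.
    """
    for _ in range(n_iters):
        # Softmax over the K slots: slots compete to explain each node,
        # as in slot attention.
        logits = feats @ slots.t() / temp                        # (N, K)
        attn = F.softmax(logits, dim=1)
        # Weighted-mean update: each slot aggregates the features of
        # the nodes softly assigned to it.
        weights = attn / (attn.sum(dim=0, keepdim=True) + 1e-8)  # (N, K)
        slots = weights.t() @ feats                              # (K, d)
    return slots
```

Because the softmax is taken over the slot axis, the K prototypes compete for nodes, which softly partitions the class into clusters; the number of slots per class and the number of iterations would be tunable hyperparameters.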
Alternatively, this process can be viewed as message passing from training nodes to prototype nodes, as illustrated in Fig. 1(b). The learned prototype representations are then transferred onto all node representations in the structure view to make them more discriminative. To enhance the effectiveness of ProtoGNN, we design two regularizers. First, to exploit the presence of the feature and structure views, we enforce cross-view compatibility between the prototypes (feature view) and node embeddings (structure view) through a hierarchical compatibility function that handles multiple prototypes per class. Second, to exploit the full power of the prototypes, we regularize them to be distinct from each other by encouraging mutual orthogonality through an orthogonality loss.

The learning of prototypes introduces artificial edges from the training nodes of each class to all non-training nodes via prototype nodes. This serves as a shortcut in message passing that bypasses local graph neighbourhoods and captures global information, while maintaining linear complexity with respect to the backbone GNN. Additionally, since the prototypes are learnt from the node feature space, they preserve strong feature information that might otherwise be diluted by traditional message passing. Overall, ProtoGNN is a generic framework with the following advantages: (1) It is orthogonal to existing GNN backbones and can be applied to any of them to improve node representations on heterophilous graphs. (2) It is efficient, capturing global node correlations with only O(n) additional edges. (3) It is useful in graphs where distant information is helpful and node feature information is strong, even in homophilous settings when the label rate is low. (4) It protects node features from dilution by message passing. We conduct extensive experiments
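Two of the ingredients just described, prototype-to-node transfer and the orthogonality regularizer, can be sketched as follows. The function names, the residual attention update in `transfer_to_nodes`, and the squared off-diagonal Gram penalty in `orthogonality_loss` are our illustrative choices, not the paper's exact formulations:

```python
import torch
import torch.nn.functional as F

def orthogonality_loss(protos):
    """Encourage the prototypes to be mutually orthogonal.

    protos: (K, d). Penalizes the off-diagonal entries of the Gram
    matrix of the row-normalized prototypes (cosine similarities).
    """
    p = F.normalize(protos, dim=1)
    gram = p @ p.t()                                        # (K, K)
    off = gram - torch.eye(protos.size(0), device=protos.device)
    return (off ** 2).mean()

def transfer_to_nodes(node_emb, protos, temp=1.0):
    """Transfer prototype information onto structural node embeddings
    via attention (prototype-to-node message passing), with a
    residual connection so the original embedding is preserved.
    """
    attn = F.softmax(node_emb @ protos.t() / temp, dim=1)   # (N, K)
    return node_emb + attn @ protos                         # (N, d)
```

Driving the off-diagonal Gram entries towards zero pushes the K prototypes to span distinct directions, so that different clusters of the same class are captured by different prototypes rather than collapsing onto one.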



Figure 1: (a) Visualization of node features for the Penn94 and Twitch Gamers datasets. The plots show clustering patterns of the nodes in the feature space. (b) Illustration of information transfer from training nodes to prototypes. Yellow and blue denote different classes. The prototype nodes learn multiple representations per class from the training nodes and transfer the information to non-training nodes through attention.

