GATED RELATIONAL GRAPH ATTENTION NETWORKS

Abstract

Relational Graph Neural Networks (GNNs), like all GNNs, suffer from a drop in performance when training deeper networks, which may be caused by vanishing gradients, over-parameterization, and over-smoothing. Previous works have investigated methods that improve the training of deeper GNNs, including normalization techniques and various types of skip connections within a node. However, learning long-range patterns in multi-relational graphs using GNNs remains an under-explored topic. In this work, we propose a novel relation-aware GNN architecture based on the Graph Attention Network that uses gated skip connections to improve long-range modeling between nodes and uses a more scalable vector-based approach for parameterizing relations. We perform extensive experimental analysis on synthetic and real data, focusing explicitly on learning long-range patterns. The proposed method significantly outperforms several commonly used GNN variants when used in deeper configurations and stays competitive with existing architectures in a shallow setup.

1. INTRODUCTION

In this work, we focus on learning long-range patterns in multi-relational graphs using graph neural networks (GNNs), which heavily relies on the ability to train deep networks [1]. However, GNNs suffer from decreasing performance when the number of layers is increased. Zhao & Akoglu (2020) point out that this may be due to (1) over-fitting, (2) vanishing gradients, and (3) over-smoothing (the phenomenon where node representations become less distinguishable from each other when more layers are used). Recently, several works have investigated over-fitting (Vashishth et al., 2020), over-smoothing (Li et al., 2018; Chen et al., 2019; Zhao & Akoglu, 2020; Rong et al., 2019; Yang et al., 2020), over-squashing (Alon & Yahav, 2020), and possible vanishing gradient (Hochreiter & Schmidhuber, 1997; Pascanu et al., 2013; He et al., 2016) problems in GNNs (Li et al., 2019a; 2020; Rahimi et al., 2018). One simple but effective technique to improve the training of deeper GNNs is using skip connections, for example, implemented by the Gated Recurrent Unit (GRU) (Cho et al.) in the Gated GNN (GGNN) (Li et al., 2015). Such connections can improve learning in deep GNNs as they avoid vanishing gradients towards lower-layer representations of the same node and reduce over-smoothing (Hamilton, 2020). However, such vertical skip connections are not sufficient to enable learning long-range patterns. In addition, in relational GNNs, such as the Relational Graph Convolutional Network (RGCN) (Schlichtkrull et al., 2018), training difficulties may arise from the methods used for integrating relation information, which may suffer from over-parameterization and impact backpropagation. In this work, we develop a novel GNN architecture for multi-relational graphs that reduces the vanishing gradient and over-parameterization problems that occur with existing methods and improves generalization when learning long-range patterns using deeper networks.
Several changes are proposed to the Graph Attention Network (GAT) (Veličković et al., 2018), including a modified attention mechanism, an alternative GRU-based update function, and a gated relation-aware message function. An extensive experimental study is conducted that (1) shows that different existing relation-aware GNNs fail to learn simple patterns in a simple synthetic sequence-based graph classification task, (2) presents a comparison and ablation study on a synthetic node classification task, and (3) shows that our architecture is competitive with existing ones on an entity classification task using real-world data from previous work.

2. BACKGROUND

Many popular GNNs can be formulated within a message passing (MP) framework (Gilmer et al., 2017). Consider a multi-relational graph G = (V, E), where E consists of edges e specified by a triple (u, v, r) that defines a directed edge from u to v, labeled with edge type (relation) r. A GNN maps each node v ∈ V onto representation vectors h_v^{(1)}, ..., h_v^{(K)} by repeatedly aggregating the representations of the immediate neighbours of every node and updating node representations in every step of the encoding process, each associated with one of the K layers of the GNN. Relational GNNs must also take into account the edge types between nodes in the graph G. In the following sections, we use the message passing framework where a single GNN layer/step is decomposed into a three-step process:

    h_v^{(k)} = φ(h_v^{(k-1)}, γ({μ(h_u^{(k-1)}, r)}_{(u,v,r) ∈ E(·,v)})),    (1)

where μ(·) computes a "message" along a graph edge using the neighbour representation h_u^{(k-1)} and the edge type r of an edge from u to v, γ(·) aggregates the incoming messages into a single vector, and φ(·) computes a new representation for node v. E(·, v) denotes the set of all edges in G that end in v. After applying Eq. 1 K times to each node of the graph, the final node representations h_v^{(K)} can be used for different tasks, such as graph or node classification.
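As an illustration, one message-passing step can be sketched in plain Python/NumPy. The function and variable names below are illustrative, not from the paper; μ is a relation-specific linear map, γ a sum, and φ a tanh update over the node's own state plus the aggregate:

```python
import numpy as np

def mp_layer(h, edges, W_rel, W_self):
    """One message-passing step: h[v] <- phi(h[v], gamma({mu(h[u], r)})).

    h:      dict node -> D-dim vector (previous-layer representations)
    edges:  list of (u, v, r) triples, each a directed edge u -> v of type r
    W_rel:  dict relation -> (D, D) matrix, used by the message function mu
    W_self: (D, D) matrix applied to the node's own previous state inside phi
    """
    D = len(next(iter(h.values())))
    agg = {v: np.zeros(D) for v in h}      # gamma: sum-aggregate incoming messages
    for u, v, r in edges:
        agg[v] += W_rel[r] @ h[u]          # mu: relation-specific linear message
    # phi: combine the node's own state with its aggregated neighbourhood
    return {v: np.tanh(W_self @ h[v] + agg[v]) for v in h}
```

Applying this function K times to the same node dictionary yields the final representations h_v^{(K)} of Eq. 1.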
One commonly used message function is a relation-specific linear transformation, used by RGCNs and GGNNs: μ_MM(h_u, r) = W_r h_u, where W_r ∈ R^{D×D} is a parameter matrix associated with edge type r. Moreover, GGNNs implement the update function φ(·) using a GRU: h_v^{(k)} = GRU(h_v^{(k-1)}, h̃_v^{(k-1)}), where the GRU's input argument h̃_v^{(k-1)} = γ({μ(h_u^{(k-1)}, r)}_{(u,v,r) ∈ E(·,v)}) is the vector for the aggregated neighbourhood of v.
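A minimal sketch of such a GRU-based update φ, assuming a standard GRU cell with separate gate matrices (the parameterization and names here are generic assumptions, not the paper's exact formulation):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_update(h_prev, m, Wz, Uz, Wr, Ur, Wh, Uh):
    """GGNN-style update: h_v^(k) = GRU(h_v^(k-1), m), where m is the
    sum-aggregated relational message for node v (gamma of W_r h_u terms)."""
    z = sigmoid(Wz @ m + Uz @ h_prev)            # update gate
    r = sigmoid(Wr @ m + Ur @ h_prev)            # reset gate
    h_tilde = np.tanh(Wh @ m + Uh @ (r * h_prev))  # candidate state
    return (1.0 - z) * h_prev + z * h_tilde      # gated (additive) skip connection
```

The last line is the gated skip connection: when z is close to 0, the node's previous state passes through almost unchanged, which is what protects the gradient towards lower layers.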

3. MOTIVATION: BREADTH-WISE BACKPROPAGATION

Many methods have recently been proposed to mitigate vanishing gradients in GNNs, using techniques such as residual connections (Li et al., 2019a; 2020), "jumping knowledge" connections (Xu et al., 2018), and DenseNets (Huang et al., 2017), which have effects similar to the Gated GNN of Li et al. (2015). However, these techniques address only backpropagation in the depth of the network (i.e., vertically towards lower-level features of the same node). As we will now discuss, this is not sufficient for learning long-range patterns since breadth-wise backpropagation (i.e., horizontally towards the neighbours) in such GNNs may still result in learning problems and vanishing gradients. Consider a simple graph convolutional network (GCN) with the general update of the hidden units

    h_v^{(k)} = σ(W^{(k)}(h_v^{(k-1)} + Σ_{(u,v) ∈ E} h_u^{(k-1)})),

where σ is a non-linearity such as a ReLU. Following the classical arguments, functions of this form suffer from vanishing or exploding gradients because of the stacking of linear transformations (W) and non-linearities (σ). The backpropagation path from the top-level features h_v^{(K)} of a node to its initial features h_v^{(0)} passes through the composition σ(W^{(K)}(σ(W^{(K-1)} ... (σ(W^{(1)} h_v^{(0)})) ...))). With many layers (i.e., when K is large), the gradient magnitude diminishes at every step because of the multiplication with the derivative of the activation function, which for some choices (e.g., the sigmoid, whose derivative is at most 1/4) is everywhere smaller than 1. Additionally, depending on the values of the weights W, repeated multiplication with the weights may cause the gradients to either vanish or explode. This becomes especially problematic when sharing weights between layers, which can be useful to combat over-parameterization in the deeper networks necessary to encode multiple hops. These problems can be mitigated by the choice of activation function (such as using a ReLU), good weight initializations, and normalization layers.
LSTMs (Hochreiter & Schmidhuber, 1997), GRUs (Cho et al.), ResNets (He et al., 2016), Highway Networks (Srivastava et al., 2015), and DenseNets (Huang et al., 2017) propose an alternative solution by introducing some form of skip connection around one or more layers, such that every layer contributes an update additively. This ensures that the gradient to deeper layers does not vanish, and also that the gradient w.r.t. features (or states) in different layers (or timesteps) is similar (GRUs, Highway Networks) or the same (ResNets). In fact, He et al. (2016) point out that the improvement they obtain is not due to gradient magnitudes and attribute the effect to the fact that modeling the residual error is an easier task for the network.
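In the same toy scalar setting, an additive skip around each layer keeps the gradient from collapsing; a sketch under the same illustrative assumptions:

```python
import numpy as np

def residual_gradient(w, K, x=1.0):
    """Gradient through the residual chain a_k = a_{k-1} + sigmoid(w * a_{k-1}).
    Each layer's local derivative is 1 + w*s*(1-s): the '+1' comes from the
    skip path, so the product no longer shrinks geometrically with depth."""
    sigmoid = lambda t: 1.0 / (1.0 + np.exp(-t))
    grad, a = 1.0, x
    for _ in range(K):
        s = sigmoid(w * a)
        grad *= 1.0 + w * s * (1.0 - s)   # skip path contributes the constant 1
        a = a + s
    return grad
```

For positive w every per-layer factor is at least 1, so the gradient stays bounded away from zero regardless of depth, in contrast to the purely stacked chain above.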



[1] We need at least K GNN layers to capture information that is K hops away.

