RELATIONAL ATTENTION: GENERALIZING TRANSFORMERS FOR GRAPH-STRUCTURED TASKS

Abstract

Transformers flexibly operate over sets of real-valued vectors representing task-specific entities and their attributes, where each vector might encode one word-piece token and its position in a sequence, or some piece of information that carries no position at all. But as set processors, standard transformers are at a disadvantage in reasoning over more general graph-structured data, where nodes represent entities and edges represent relations between entities. To address this shortcoming, we generalize transformer attention to consider and update edge vectors in each transformer layer. We evaluate this relational transformer on a diverse array of graph-structured tasks, including the large and challenging CLRS Algorithmic Reasoning Benchmark. There, it dramatically outperforms state-of-the-art graph neural networks expressly designed to reason over graph-structured data. Our analysis demonstrates that these gains are attributable to relational attention's inherent ability to leverage the greater expressivity of graphs over sets.

1. INTRODUCTION

Figure 1: The relational transformer (RT) outperforms baseline GNNs on a set of 30 distinct graph-structured tasks from CLRS-30, averaged by algorithm class.

Graph-structured problems turn up in many domains, including knowledge bases (Hu et al., 2021; Bordes et al., 2013), communication networks (Leskovec et al., 2010), citation networks (McCallum et al., 2000), and molecules (Debnath et al., 1991; Zhang et al., 2020b). One example is predicting the bioactive properties of a molecule, where the atoms of the molecule are the nodes of the graph and the bonds are the edges. Along with their ubiquity, graph-structured problems vary widely in difficulty. For example, certain graph problems can be solved with a simple multi-layer perceptron, while others are quite challenging and require explicit modeling of relational characteristics.

Graph Neural Networks (GNNs) are designed to process graph-structured data, including the graph's (possibly directed) edge structure and (in some cases) features associated with the edges. In particular, they learn to represent graph features by passing messages between neighboring nodes and edges, and updating the node and (optionally) edge vectors. Importantly, GNNs typically restrict message passing to operate over the edges in the graph.

Standard transformers lack the relational inductive biases (Battaglia et al., 2018) that are explicitly built into the most commonly used GNNs. Instead, the transformer fundamentally consumes unordered sets of real-valued vectors, injecting no other assumptions. This allows entities carrying domain-specific attributes (like position) to be encoded as vectors for input to the same transformer architecture applied to different domains. Transformers have produced impressive results in a wide variety of domains, starting with machine translation (Vaswani et al., 2017), then quickly impacting language modeling (Devlin et al., 2019) and text generation (Brown et al., 2020).
They are revolutionizing image processing (Dosovitskiy et al., 2021) and are being applied to a growing variety of settings including reinforcement learning (RL), both online (Loynd et al., 2020; Parisotto et al., 2020) and offline (Chen et al., 2021; Janner et al., 2021).

Many of the domains transformers succeed in consist of array-structured data, such as text or images. By contrast, graph data is centrally concerned with pairwise relations between entities, represented as edges and edge attributes. Graphs are more general and expressive than sets, in the sense that a set is a special case of a graph, one without edges. So it is not immediately obvious how graph data can be processed by transformers in a way that preserves relational information. Transformers have been successfully applied to graph-structured tasks in one of two broad ways. Certain works, most recently TokenGT (Kim et al., 2022), encode graphs as sets of real-valued vectors passed to a standard transformer. Other works change the transformer architecture itself to consider relational information, e.g. by introducing relative position vectors to transformer attention. We discuss these and many other such approaches in Section 4.

Our novel contribution is relational attention, a mathematically elegant extension of transformer attention, which incorporates edge vectors as first-class model components. We call the resulting transformer architecture the Relational Transformer (RT). As a native graph-to-graph model, RT does not rely on special encoding schemes to input or output graph data. We find that RT outperforms baseline GNNs on a large and diverse set of difficult graph-structured tasks. In particular, RT establishes dramatically improved state-of-the-art performance (Figure 1) over baseline GNNs on the challenging CLRS-30 (Veličković et al., 2022), which comprises 30 different algorithmic tasks in a framework for probing the reasoning abilities of graph-to-graph models.
To summarize our main contributions:

• We introduce the relational transformer for application to arbitrary graph-structured tasks, and make the implementation available at https://github.com/CameronDiao/relational-transformer.
• We evaluate the reasoning power of RT on a wide range of challenging graph-structured tasks, achieving new state-of-the-art results on CLRS-30.
• We enhance the CLRS-30 framework to support evaluation of a broader array of models (Section 5.1.2).
• We improve the performance of CLRS-30 baseline models by adding multi-layer functionality and tuning their hyperparameters (Section 5.1.1).

2. GRAPH NEURAL NETWORKS

We introduce the graph-to-graph model formalism used in the rest of this paper, inspired by Battaglia et al. (2018). The input graph is a directed, attributed graph G = (N, E), where N is an unordered set of node vectors n_i ∈ ℝ^{d_n}, with i denoting the i-th node. E is a set of edge vectors e_ij ∈ ℝ^{d_e}, where directed edge (i, j) points from node j to node i. Each layer l in the model accepts a graph G^l as input, processes the graph's features, then outputs graph G^{l+1} with the same structure as G^l, but with potentially updated node and edge vectors. Certain tasks may also include a single global vector with the input or output graph, but we omit those details from our formalism since they are not what distinguishes the approaches described below.

Each layer l of a graph-to-graph model is comprised of two update functions ϕ and an aggregation function ⊕:

$$n_i^{l+1} = \phi_n\Big(n_i^l,\ \bigoplus_{j \in L_i} \psi_m\big(e_{ij}^l, n_i^l, n_j^l\big)\Big) \quad (1)$$

$$e_{ij}^{l+1} = \phi_e\big(e_{ij}^l, e_{ji}^l, n_i^{l+1}, n_j^{l+1}\big) \quad (2)$$

where L_i denotes the set of node i's neighbors (optionally including node i), and ψ_m denotes a message function. The baseline GNNs listed below let ϕ_e be the identity function, such that e_ij^{l+1} = e_ij^l for all edges (i, j). Furthermore, they use a permutation-invariant aggregation function for ⊕.

The following details are from Veličković et al. (2022): In Deep Sets (Zaheer et al., 2017), the only edges are self-connections, so each L_i is a singleton set containing node i. In Graph Attention Networks (GAT) (Veličković et al., 2018; Brody et al., 2022), ⊕ is self-attention, and the message function ψ_m merely extracts the sender features: ψ_m(e_ij^l, n_i^l, n_j^l) = W_m n_j^l, where W_m is a weight matrix. In Message Passing Neural Networks (MPNN) (Gilmer et al., 2017), edges lie between any pair of nodes and ⊕ is the max pooling operation.
In Pointer Graph Networks (PGN) (Veličković et al., 2020), edges are constrained by the adjacency matrix and ⊕ is the max pooling operation.
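Equations 1 and 2 can be made concrete with a small sketch. Below is a minimal NumPy implementation of one MPNN-style message-passing step (max-pooling aggregation, identity edge update). The specific parameterization of the message and update maps, and the tanh non-linearity, are illustrative assumptions rather than any baseline's actual code.

```python
import numpy as np

def mpnn_layer(nodes, edges, adj, W_msg, W_upd):
    """One message-passing step in the graph-to-graph formalism.

    nodes: (N, d_n) node vectors; edges: (N, N, d_e) with edges[i, j] = e_ij;
    adj: (N, N) 0/1 mask with adj[i, j] = 1 if edge (i, j) exists.
    Messages psi_m(e_ij, n_i, n_j) are a linear map of the concatenated
    inputs; the aggregation is max pooling, as in MPNN/PGN.
    """
    N, d_n = nodes.shape
    # Build the message inputs for every (receiver i, sender j) pair.
    n_i = np.repeat(nodes[:, None, :], N, axis=1)        # (N, N, d_n)
    n_j = np.repeat(nodes[None, :, :], N, axis=0)        # (N, N, d_n)
    msg_in = np.concatenate([edges, n_i, n_j], axis=-1)  # (N, N, d_e + 2*d_n)
    messages = msg_in @ W_msg                            # (N, N, d_m)
    # Mask out non-edges, then max-pool over each node's neighborhood.
    messages = np.where(adj[:, :, None] == 1, messages, -np.inf)
    agg = messages.max(axis=1)                           # (N, d_m)
    # Local update phi_n on [n_i, aggregated message].
    return np.tanh(np.concatenate([nodes, agg], axis=-1) @ W_upd)
```

Restricting `adj` recovers PGN-style masking; a fully connected `adj` recovers the MPNN setting described above.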

3. RELATIONAL TRANSFORMER

We aim to design a mathematically elegant extension of transformer attention, which incorporates edge vectors as first-class model components. This goal leads us to the following design criteria:

1. Preserve all of the transformer's original machinery (though still not fully understood), for its empirically established advantages.
2. Introduce directed edge vectors to represent relations between entities.
3. Condition transformer attention on the edge vectors.
4. Extend the transformer layer to consume edge vectors and produce updated edge vectors.
5. Preserve the transformer's O(N²) computational complexity.

3.1. RELATIONAL ATTENTION

(See Appendix A for a mathematical overview of transformers.) In addition to accepting node vectors representing entity features (as do all transformers), RT also accepts edge vectors representing relation features, which may include edge-presence flags from an adjacency matrix. But RT operates over a fully connected graph, unconstrained by any input adjacency matrix.

Transformer attention projects QKV vectors from each node vector, then computes a dot-product between each pair of vectors q_i and k_j. This dot-product determines the degree to which node i attends to node j. Relational attention's central innovation (illustrated in Figure 3) is to condition the QKV vectors on the directed edge e_ij between the nodes, by concatenating that edge vector with each node vector prior to the linear transformations:

$$q_{ij} = [n_i, e_{ij}]\, W^Q \qquad k_{ij} = [n_j, e_{ij}]\, W^K \qquad v_{ij} = [n_j, e_{ij}]\, W^V \quad (3)$$

where each weight matrix W is now of size ℝ^{(d_n + d_e) × d_n}, and d_e is the edge vector size. To implement this efficiently and exactly, we split each weight matrix W into two separate matrices for projecting node and edge vectors, project the edge vector to three embeddings, then add those to the node's usual attention vectors:

$$q_{ij} = n_i W_n^Q + e_{ij} W_e^Q \qquad k_{ij} = n_j W_n^K + e_{ij} W_e^K \qquad v_{ij} = n_j W_n^V + e_{ij} W_e^V \quad (4)$$

While we have described relational attention in terms of fully connected self-attention, it applies equally to restricted forms of attention such as causal attention, cross-attention, or even restricted GAT-like attention that passes messages only over the edges present in a graph's adjacency matrix. Relational attention is compatible with multi-head attention, and leaves the transformer's O(N²) complexity unchanged. Our implementation maintains the high GPU utilization that makes transformers efficient.
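A minimal single-head sketch of relational attention in the split form of equation 4, in NumPy. The residual connections, LayerNorm, and feed-forward stages of the full layer are omitted for brevity, so this shows only the attention computation itself.

```python
import numpy as np

def relational_attention(nodes, edges, Wq_n, Wq_e, Wk_n, Wk_e, Wv_n, Wv_e):
    """Single-head relational attention (equation 4), a sketch.

    nodes: (N, d_n); edges: (N, N, d_e) with edges[i, j] = e_ij, the
    directed edge from sender j to receiver i. Returns the (N, d_n)
    aggregated messages; LayerNorm/FFN stages are omitted.
    """
    N, d_n = nodes.shape
    # Pairwise QKV tensors, each of shape (N, N, d_n): the edge vector
    # e_ij shifts the receiver's query and the sender's key and value.
    q = (nodes @ Wq_n)[:, None, :] + edges @ Wq_e
    k = (nodes @ Wk_n)[None, :, :] + edges @ Wk_e
    v = (nodes @ Wv_n)[None, :, :] + edges @ Wv_e
    # Per-pair dot products q_ij . k_ij, softmaxed over senders j.
    scores = (q * k).sum(-1) / np.sqrt(d_n)              # (N, N)
    alpha = np.exp(scores - scores.max(axis=1, keepdims=True))
    alpha = alpha / alpha.sum(axis=1, keepdims=True)
    return (alpha[:, :, None] * v).sum(axis=1)           # (N, d_n)
```

Because each split weight pair stacks into one concatenated matrix (e.g. W^Q = [W_n^Q; W_e^Q]), this form is exactly equivalent to the concatenation in equation 3.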

3.2. EDGE UPDATES

To update edge vectors in each layer of processing, we follow the general pattern used by transformers to update node vectors: first aggregate messages into one, then use the result to perform a local update. In self-attention, each node attends to all nodes in the graph. But having each of the N² edges attend to N nodes would raise computational complexity to O(N³). Instead, we restrict the edge's aggregation function to gather messages only from its immediate locale, which consists of its two adjoining nodes, itself, and the directed edge running in the opposite direction:

$$e_{ij}^{l+1} = \phi_e\big(e_{ij}^l, e_{ji}^l, n_i^{l+1}, n_j^{l+1}\big) \quad (5)$$

We compute the aggregated message m_ij^l by first concatenating the four neighbor vectors, then applying a single linear transformation to the concatenated vector, followed by a ReLU non-linearity:

$$m_{ij}^l = \mathrm{ReLU}\big(\mathrm{concat}\big(e_{ij}^l, e_{ji}^l, n_i^{l+1}, n_j^{l+1}\big)\, W_4\big) \quad (6)$$

where W_4 ∈ ℝ^{(2d_e + 2d_n) × d_{eh1}}, and the non-linear ReLU operation takes the place of the non-linear softmax in regular attention. The remainder of the edge update function is essentially identical to the transformer node update function:

$$u_{ij}^l = \mathrm{LayerNorm}\big(m_{ij}^l W_5 + e_{ij}^l\big) \qquad e_{ij}^{l+1} = \mathrm{LayerNorm}\big(\mathrm{ReLU}(u_{ij}^l W_6)\, W_7 + u_{ij}^l\big) \quad (7)$$

where W_5 ∈ ℝ^{d_{eh1} × d_e}, W_6 ∈ ℝ^{d_e × d_{eh2}}, W_7 ∈ ℝ^{d_{eh2} × d_e}, and d_{eh1} and d_{eh2} are the hidden layer sizes of the edge feed-forward networks.

Node vectors are updated before edge vectors within each RT layer, so that a node aggregates information from the entire graph before its adjoining edges use that information in their local updates. This is another instance of the aggregate-then-update pattern employed by both transformers and GNNs for node vectors.
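The edge update of equations 5-7 can be sketched in a few lines of NumPy. The LayerNorm here omits learned gain and bias parameters for brevity; treat the exact normalization details as an assumption.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """Parameter-free layer normalization over the last axis (a sketch)."""
    mu = x.mean(axis=-1, keepdims=True)
    sd = x.std(axis=-1, keepdims=True)
    return (x - mu) / (sd + eps)

def edge_update(edges, nodes_new, W4, W5, W6, W7):
    """RT edge update (equations 5-7), without learned LayerNorm gains.

    edges: (N, N, d_e) with edges[i, j] = e_ij; nodes_new: (N, d_n),
    the node vectors already updated in this layer.
    """
    N = nodes_new.shape[0]
    e_ij = edges
    e_ji = edges.transpose(1, 0, 2)                 # reverse-direction edge
    n_i = np.repeat(nodes_new[:, None, :], N, axis=1)
    n_j = np.repeat(nodes_new[None, :, :], N, axis=0)
    # Aggregate the edge's local neighborhood: linear map + ReLU (eq 6).
    m = np.maximum(np.concatenate([e_ij, e_ji, n_i, n_j], -1) @ W4, 0.0)
    # Transformer-style residual update (eq 7).
    u = layer_norm(m @ W5 + edges)
    return layer_norm(np.maximum(u @ W6, 0.0) @ W7 + u)
```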

4. PRIOR WORK

We categorize, then discuss prior works based on their use of transformer machinery. We further divide each category based on edge vector incorporation, highlighting works that use full node-edge round-tripping, here defined as the process in which edge vectors directly condition node updates and node vectors directly condition edge updates. See Figure 2 for visual comparisons.

Encoding Graphs to Sets

For clarity, we define the transformer architecture to exclude any modules that encode inputs to the transformer or decode its outputs. Several prior works have made progress on the challenge of applying standard transformers (Vaswani et al., 2017) to graph-structured tasks by representing graphs as sets of tokens, e.g. with positional encodings. TokenGT (Kim et al., 2022), for example, treats all nodes and edges of a graph as independent tokens, augmented with token-wise embeddings. Graphormer (Ying et al., 2021) and Graph-BERT (Zhang et al., 2020a) introduce structural encodings that are applied to each node prior to transformer processing. GraphTrans (Wu et al., 2021) and ReFormer (Yang et al., 2022) perform initial convolutions or message passing before the transformer module.

Relative Position Vectors

A number of prior works have modified self-attention to implement relative positional encodings (Ying et al., 2021; Cai & Lam, 2020b; Shaw et al., 2018; Hellendoorn et al., 2020; Dai et al., 2019). To compare these formulations with RT, we expand relational attention's dot-product into four terms as follows:

$$q_{ij} k_{ij}^\top = \big([n_i, e_{ij}]\, W^Q\big)\big([n_j, e_{ij}]\, W^K\big)^\top \quad (8)$$

$$= \big(n_i W_n^Q + e_{ij} W_e^Q\big)\big(n_j W_n^K + e_{ij} W_e^K\big)^\top \quad (9)$$

$$= \big(n_i W_n^Q + e_{ij} W_e^Q\big)\Big(\big(n_j W_n^K\big)^\top + \big(e_{ij} W_e^K\big)^\top\Big) \quad (10)$$

$$= n_i W_n^Q \big(n_j W_n^K\big)^\top + n_i W_n^Q \big(e_{ij} W_e^K\big)^\top + e_{ij} W_e^Q \big(n_j W_n^K\big)^\top + e_{ij} W_e^Q \big(e_{ij} W_e^K\big)^\top \quad (11)$$

The transformer of Vaswani et al. (2017) employs only the first term. The transformer of Shaw et al. (2018) adds part of the second term, leaving out one weight matrix, W_e^K, and GREAT (Hellendoorn et al., 2020) adds the entire third term. The Transformer-XL (Dai et al., 2019) and the Graph Transformer of Cai & Lam (2020b) use parts of all four terms, but leave out two of the eight weight matrices, W_e^Q and W_e^K. Each work above uses edge vectors only for relative positional information. RT employs all four terms, allows the edge vectors to represent arbitrary information depending on the task, and updates the edge vectors in each layer of computation.

Transformers With Restricted Node-Edge Communication

A few prior graph-to-graph transformers restrict round-trip communication between nodes and edges. EGT (Hussain et al., 2022) performs a form of node-edge round-tripping but introduces single-scalar bottlenecks (per-edge, per-head). The Graph Transformer of Dwivedi & Bresson (2021), SAT (Chen et al., 2022), and SAN (Kreuzer et al., 2021) condition node attention coefficients on edge vectors, but they do not explicitly condition node vector updates on edge vectors. GRPE (Cai & Lam, 2020a) also conditions node attention on edge vectors, e.g. by adding edge vectors to the node value vectors. But the edge vectors themselves are not explicitly updated using node vectors. Bergen et al. (2021) introduce the Edge Transformer, which replaces standard transformer attention with a triangular attention mechanism that takes edge vectors into account and updates the edge vectors in each layer. This differs from RT in three important respects. First, triangular attention is a completely novel form of attention, unlike relational attention, which is framed as a natural extension of standard transformer attention. Second, triangular attention ignores node vectors altogether, and thereby requires node input features and node output predictions to be somehow mapped onto edges.
And third, triangular attention's computational complexity is O(N³) in the number of nodes, unlike RT's relational attention, which maintains the O(N²) transformer complexity. Like the Edge Transformer, Nodeformer (Wu et al., 2022) employs a novel form of attention, but in this case with O(N) complexity. Nodeformer does not perform node-edge round-tripping, and introduces single-scalar bottlenecks per edge. Another transformer, GraphGPS (Rampasek et al., 2022), is described by the authors as an MPNN+Transformer hybrid, which does support full node-edge round-tripping. Unlike RT, GraphGPS represents a significant departure from the standard transformer architecture.
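The four-term expansion of relational attention's dot-product (Section 4, Relative Position Vectors) can be checked numerically; the sketch below confirms that the full product equals the sum of the node-node, node-edge, edge-node, and edge-edge terms for random weights.

```python
import numpy as np

def four_term_expansion(n_i, n_j, e_ij, Wq_n, Wq_e, Wk_n, Wk_e):
    """Return (full dot product q_ij . k_ij, sum of the four expanded terms)."""
    q_ij = n_i @ Wq_n + e_ij @ Wq_e
    k_ij = n_j @ Wk_n + e_ij @ Wk_e
    full = q_ij @ k_ij
    terms = ((n_i @ Wq_n) @ (n_j @ Wk_n)       # node-node (standard attention)
             + (n_i @ Wq_n) @ (e_ij @ Wk_e)    # node-edge
             + (e_ij @ Wq_e) @ (n_j @ Wk_n)    # edge-node
             + (e_ij @ Wq_e) @ (e_ij @ Wk_e))  # edge-edge
    return full, terms
```

Prior formulations that drop W_e^Q or W_e^K correspond to zeroing out the terms that contain those matrices.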

Highly Modified Transformers

Finally, attentional GNNs, such as GAT (Veličković et al., 2018), GATv2 (Brody et al., 2022), Edgeformer (Jin et al., 2023), kgTransformer (Liu et al., 2022), Relational Graph Transformer (Feng et al., 2022), HGT (Hu et al., 2020), and Simple-HGN (Lv et al., 2021), aggregate features across neighborhoods based on transformer-style attention coefficients. However, unlike transformers, attentional GNNs only compute attention over input edge vectors, and (except in Edgeformer and kgTransformer) the edge vectors are not updated in each layer. In particular, kgTransformer, Relational Graph Transformer, HGT, and Simple-HGN modify transformer attention to consider heterogeneous structures in the graph data, such that the model can differentiate between types of nodes and edges.

Other GNNs

MPNN (Gilmer et al., 2017) is a popular GNN that accepts entire edge vectors as input, as do some other works such as MXMNet (Zhang et al., 2020c) and G-MPNN (Yadati, 2020). But apart from EGNN (Gong & Cheng, 2019) and Censnet (Jiang et al., 2019), relatively few GNNs update the edge vectors themselves. None of these GNNs use the full transformer machinery, and in general many GNNs are designed for specific settings, such as quantum chemistry. Unlike any of these prior works, RT preserves all of the original transformer machinery, while adding full bidirectional conditioning of node and edge vector updates.

5. EXPERIMENTS

We evaluate RT against common GNNs on the diverse set of graph-structured tasks provided by CLRS-30 (Veličković et al., 2022) , which was designed to measure the reasoning abilities of neural networks. This is a common motivation for tasking neural networks to execute algorithms (Zaremba & Sutskever, 2014; Kaiser & Sutskever, 2015; Trask et al., 2018; Kool et al., 2019) . RT outperforms baseline GNNs by wide margins, especially on tasks that require processing of node relations (Section 5.1.5). We further evaluate RT against GNNs on the end-to-end shortest paths task provided by Tang et al. (2020) , where again RT outperforms the baselines (Section 5.2). Our final experiment (Appendix B) evaluates RT against a standard transformer on a reinforcement learning task where no graph structure is provided. We find that RT decreases error rates of the RL agent significantly.

5.1. STEP-BY-STEP REASONING

In CLRS-30, each step in a task is framed as a graph-to-graph problem, even for algorithms that may seem unrelated to graphs. To give an example, for list sorting algorithms, each input list element is treated as a separate node, and predecessor links are added to order the elements. Task data is organized into task inputs, task outputs, and 'hints', which are intermediate inputs and outputs for the intervening steps of an algorithm. Data is comprised of combinations of node, edge, and/or global features, which can be of five possible types: scalars, categoricals, masks, binary masks, or pointers.

CLRS-30 employs an encode-process-decode framework for model evaluation. Input features are encoded using linear layers, then passed to the model (called a processor) being tested. The model performs one step of computation on the inputs, then outputs a set of node vectors, which are passed back to the model as input on the next step. On each step, the model's output node vectors are decoded by the framework (using linear layers) and compared with the targets (either hints or final outputs) to compute training losses or evaluation scores.

Certain CLRS-30 tasks provide a global vector with each input graph. As per CLRS-30 specifications, the baseline GNNs handle global vectors by including them as messages in each update step. They do not propagate global vectors through the steps of the algorithm. RT can use two different methods for handling global vectors, and we evaluate both in Section 5.1.4.
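The encode-process-decode loop can be summarized with a small sketch. The callables `encode`, `processor`, and `decode` here are hypothetical stand-ins, not the CLRS-30 framework's actual API; the point is only the shape of the step-by-step control flow.

```python
import numpy as np

def run_task(processor, encode, decode, task_inputs, n_steps):
    """Minimal encode-process-decode loop in the spirit of CLRS-30.

    `encode` maps raw task inputs to node vectors, `processor` performs
    one step of computation on them, and `decode` produces per-step
    predictions that would be compared against hints or final outputs.
    These are illustrative placeholders, not the benchmark's interfaces.
    """
    nodes = encode(task_inputs)            # initial node vectors
    predictions = []
    for _ in range(n_steps):
        nodes = processor(nodes)           # one step of computation
        predictions.append(decode(nodes))  # compared to hints / outputs
    return predictions
```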

5.1.1. BASELINE GNNS

We began by reproducing the published results of key baseline models on the eight representative tasks (one per algorithm class) listed in Figure 3 of Veličković et al. (2022) and in our Table 13. For several of the following experiments, we refer to these as the 8 core tasks. Our results on these tasks agree closely with the published CLRS-30 results (Table 11). See Appendix D for details of our train/test protocol. We chose not to include Memnet in our experiments given our focus on standard GNNs, and given Memnet's poor performance in the original CLRS-30 experiments. Missing details made it impossible to reproduce the published GAT results on CLRS-30.

The published CLRS-30 results show sharp drops in out-of-distribution (OOD) performance for all models. For instance, MPNN's average evaluation score drops from 96.63% on the validation set to 51.02% on the test set. We note that small training datasets can induce overfitting even in models that are otherwise capable of generalizing to OOD test sets.

We found in early experiments that RT obtained far better results than those of the CLRS-30 baseline GNNs. So, to mitigate this spurious form of overfitting and to further enhance the performance of the baseline GNNs, we extended them to support multiple layers (rounds of message passing) per algorithmic step, and thoroughly tuned their hyperparameters (see Appendix C). This significantly improved the baseline results (see Table 10). See Table 12 for results comparing CLRS-30 baseline model performances with and without our proposed changes. On certain tasks, baseline score variance increased along with mean scores. For example, on the Jarvis' March task, tuning raised the score of MPNN from 22.99 ± 3.87% to 59.31 ± 29.3%. See Appendix G for detailed analysis of score variance.

5.1.2. ENHANCEMENTS OF CLRS-30'S FRAMEWORK

Most GNNs consume and produce node vectors, and many also consume edge vectors or edge types. However, relatively few GNNs (and none of the CLRS-30 baseline models) are designed to output edge vectors. Because of this, the CLRS-30 framework does not support edge vector outputs from a processor network. To test models such as RT that have these abilities, we extended the CLRS-30 framework to accept edge and global vectors from the processor at each step, and pass these vectors back to the processor as input on the next step. The framework handles node vectors as usual.

In framing algorithms as graph-to-graph tasks, CLRS-30 relies heavily on what it terms a node pointer, which is conceptually equivalent to a directed edge pointing from one node to another. Since the CLRS-30 baseline models do not output edge vectors, a decoder in the CLRS-30 framework uses the model's output node vectors to create node pointers. But for models like RT that output edge vectors, it is more natural to decode node pointers from those edge vectors alone. To better support such models, we added a flag to enable this modified behavior in the CLRS-30 framework.
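Decoding node pointers from edge vectors can be sketched as a linear readout over each directed edge followed by an argmax over candidate targets. This is an illustrative assumption about the mechanism, not the CLRS-30 decoder's actual parameterization; `w` is a hypothetical learned readout vector.

```python
import numpy as np

def decode_pointers_from_edges(edges, w):
    """Decode node pointers from output edge vectors (a sketch).

    edges: (N, N, d_e) with edges[i, j] = e_ij; w: (d_e,) hypothetical
    readout vector. Returns, for each node i, the index j of the node
    it points to, chosen as the highest-scoring outgoing edge.
    """
    scores = edges @ w          # (N, N) pointer logits
    return scores.argmax(axis=1)
```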

5.1.3. MAIN RESULTS

After tuning hyperparameters for all models (Appendix C), we evaluated RT against the six baseline GNNs on all CLRS-30 tasks, using 20 seeds. The full results are presented in Table 19 . RT outperforms the top-scoring baseline model (MPNN) by 11% overall. As bolded in the table, RT scores the highest on 11 out of 30 tasks. RT is also the best-performing model on 6 of 8 algorithmic classes (Table 1 ), and scores the highest when results are averaged over those classes (Figure 1 ). See Table 8 for the algorithm-class mappings. For convenience, Figure 1 includes the prior results (labeled as MPNN-pr and DeepSets-pr) for single-layer GNNs from Veličković et al. (2022) . In summary, RT significantly outperforms all baseline GNN models over the CLRS-30 tasks.

5.1.4. ABLATIONS

Using only the 8 core tasks (except where noted), we perform several ablations to analyze the factors behind RT's solid performance on CLRS-30.

Transformer - We compare RT to a standard, set-based transformer (Vaswani et al., 2017) by disabling edge vectors and features in RT. Table 13 shows that performance collapses by almost 40% without edge vectors and relational attention, even after re-tuning its hyperparameters.

Layers - The tuned RT uses three layers of computation per algorithmic step. When restricted to a single layer, performance drops drastically (Table 14), even after re-tuning the other hyperparameters. However, single-layer RT still outperforms the top-scoring MPNN by 10.69% on the 8 core tasks, suggesting that relational attention improves expressivity even when restricted to a single layer of computation.

Global vector - Many CLRS-30 tasks provide a global feature vector as input to the processor model. We designed RT to handle this global vector by either concatenating it to each input node vector, or by passing it to a dedicated core node (Loynd et al., 2020; Guo et al., 2019). Hyperparameter tuning chose concatenation instead of the core node option, so concatenation was used in all experiments. But in this ablation, the core-node method obtained slightly higher test scores on 7 tasks that use global vectors as inputs to the processor (Table 16). The score difference of 0.08% was marginal, providing no empirical basis to prefer one method of handling the global vector over the other.

Node pointer decoding - We assess the impact of the flag we added to the CLRS-30 framework, which can be used to decode node pointers from edge vectors only. Compared to using the original decoding procedure, using the flag improved performance by a small amount (0.30%) (Table 15).

Disabling edge updates - We disable edge updates in RT such that RT relies solely on relational attention to process input features.
Table 17 shows the resulting drop in performance, from 81.30% to 53.99%. This indicates that edge updates are crucial for RT's learning of relational characteristics in the graph data. As a final note, RT without edge updates still outperforms the transformer by 11.65%, demonstrating the effectiveness of relational attention even without updated edges.

5.1.5. ALGORITHMIC ANALYSIS

We investigate the reasoning power of RT based on its test performance on specific algorithmic tasks. We only provide possible explanations here, in line with previous work (Veličković et al., 2022; Veličković et al., 2020; Xu et al., 2020).

Underperformance - The greedy class is one of two where RT is outperformed (by just one other model). The two greedy tasks, activity selector and task scheduling, require selecting node entities that minimize some metric at each step. For example, in task scheduling, the optimal solution involves repeatedly selecting the task with the smallest processing time. The selection step is aligned with max pooling in GNNs: Veličković et al. (2020) demonstrate how max pooling aligns with making discrete decisions over neighborhoods. Here, each neighborhood represents a set of candidate entities to be selected from. MPNN, PGN-u, and PGN-m all perform max pooling at each step of message passing. On the other hand, RT performs soft attention pooling, which does not align with the discrete decision-making required to execute greedy algorithms. This may explain RT's underperformance on activity selector and task scheduling, as well as Prim's and Kruskal's.

Overperformance - RT overwhelmingly beats baseline GNNs on dynamic programming (DP) tasks. This is surprising, considering that GNNs have been proven to align well with dynamic programming routines (Dudzik & Veličković, 2022). To explain RT's overperformance, we consider 1) edge updates and 2) relational attention. For 1), Ibarz et al. (2022) observe that several algorithms in CLRS-30, especially those categorized as DP, require edge-based reasoning, where edges store values and update those values based on other edges' values. These algorithms do not use node representations in their update functions, yet the baseline GNNs can only learn these update functions using message passing between node representations.
On the other hand, RT directly supports edge-based reasoning by representing and updating edges. The hypothesis that RT actually uses this ability is supported by the fact that RT beats baseline GNNs on most of the 6 edge-centric tasks (Find Maximum Subarray, Insertion Sort, Matrix Chain Order, and Optimal BST), though not on the other 2 (Dijkstra and Floyd-Warshall). For 2), recall from Section 3.1 that relational attention is an extension of standard transformer attention. Standard attention itself is a specific instance of the message-passing function described in Dudzik & Veličković (2022), which is part of the authors' framework for aligning GNNs and DP routines. To see how this is the case, the reader can compare our equation 12 with equation 1 in Dudzik & Veličković (2022). From this comparison, and from 1), we expect RT to perform well on both of the edge-centric DP tasks, Matrix Chain Order and Optimal BST. We find that RT obtains the best scores on both of them, by 5.24% and 3.66% respectively.

5.2. END-TO-END SHORTEST PATHS

We evaluate RT on the end-to-end shortest paths task of Tang et al. (2020), training on graphs of sizes in the range (4, 34) and evaluating models on graphs of size 100. Furthermore, all models use 30 layers. Results are reported using the relative loss metric introduced by the authors, defined as |y − ŷ|/|y| given a label y and a prediction ŷ. We compare RT to their two baselines, Graph Convolution Network or GCN (Kipf & Welling, 2017) and GAT (Veličković et al., 2018). Results are averaged across 20 random seeds. RT outperforms both baselines with an average relative loss of 0.22, compared to GCN's 0.45 and GAT's 0.28 (Figure 4). These results were obtained without using the iterative module proposed by Tang et al. (2020) that introduces a stopping criterion to message passing computations.
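The relative loss metric of Tang et al. (2020), |y − ŷ|/|y|, is straightforward to implement; the sketch below averages it over a batch of predictions and assumes nonzero labels.

```python
import numpy as np

def relative_loss(y, y_hat):
    """Relative loss |y - y_hat| / |y|, averaged over elements.

    Assumes labels are nonzero (true shortest-path lengths are positive).
    """
    y = np.asarray(y, dtype=float)
    y_hat = np.asarray(y_hat, dtype=float)
    return float(np.mean(np.abs(y - y_hat) / np.abs(y)))
```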

6. CONCLUSION AND FUTURE WORK

We propose the relational transformer (RT), an elegant extension of the standard transformer to operate on graphs instead of sets. It incorporates edge information through relational attention in a principled and computationally efficient way. Our experimental results demonstrate that RT performs consistently well across a diverse range of graph-structured tasks. Specifically, RT outperforms baseline GNNs on CLRS-30 by wide margins, and also outperforms baseline models on end-to-end algorithmic reasoning. RT even boosts transformer performance on the Sokoban task, where graph structure is entirely hidden and must be discovered by the RL agent.

Beyond establishing RT's state-of-the-art results on CLRS-30, we enhance performance of the CLRS-30 baseline models, and contribute extensions to the CLRS-30 framework, broadening the scope of models that can be evaluated on the benchmark tasks. All of these improvements make CLRS-30 tasks and baselines more appealing for evaluating current and future models.

In general, comparing GNNs with transformer-based models like RT on common benchmarks is an important challenge for the community. We have made progress on that challenge by rigorously evaluating RT against standard baseline GNNs on the large and challenging CLRS-30 benchmark, but we leave experiments with other transformer-based approaches for future work. One difficulty is the fact that the CLRS-30 framework is written in Jax, and few if any of these transformers have Jax implementations available. But we have improved the CLRS benchmark itself to make such comparisons more practical in the future.

In future work, we aim to fully leverage the richness of CLRS-30 to more thoroughly investigate RT's capabilities. For example, recent extensions to the CLRS framework (Ibarz et al., 2022) allow us to task RT with executing several algorithms simultaneously, which requires knowledge transfer between algorithms.
We also plan to evaluate RT on a wider range of real-world graph settings, such as the molecular domain using the large-scale QM9 dataset (Wu et al., 2018) . Finally, we aim to relax the locality bottleneck of RT's edge updates by allowing edges to attend to other edges directly in a computationally efficient manner.

A TRANSFORMERS

We describe the transformer architecture introduced by Vaswani et al. (2017). This description also applies to most transformer variants proposed over the years. Layer superscripts are employed to distinguish input vectors from output vectors, and are often omitted for vectors inside the same layer. Although the transformer is a set-to-set model, it can be described using our graph-to-graph formalism as limited to computation over nodes only. Each transformer layer is a function passing updated node vectors to the next layer. A single transformer layer can therefore be expressed as a modified version of equation 1:

$$n_i^{l+1} = \phi_n\Big(n_i^l,\; \sum_{j \in L_i} a(n_i^l, n_j^l)\, \psi_m(n_j^l)\Big)$$

where $a(n_i^l, n_j^l)$ computes the attentional coefficient $\alpha_{ij}^l$ applied by node $i$ to the value vector $v_j^l$, which is computed by $\psi_m(n_j^l)$, a linear transformation:

$$\alpha_{ij}^l = a(n_i^l, n_j^l), \qquad v_j^l = \psi_m(n_j^l) = n_j^l W_V$$

The attentional coefficients applied by node $i$ to the set of all nodes $j$ form a probability distribution $a_i$ computed by the softmax function over scaled vector dot products:

$$a_i = \mathrm{softmax}_j\!\left(\frac{q_i k_j^\top}{\sqrt{d_n}}\right)$$

The QKV vectors introduced above are linear transformations of node vectors:

$$q_i = n_i W_Q, \qquad k_j = n_j W_K, \qquad v_j = n_j W_V$$

where $W_Q, W_K, W_V \in \mathbb{R}^{d_n \times d_n}$ and $d_n$ is the node vector size. These $W$ matrices (like all other trainable parameters in $\phi_n$) are not shared between transformer layers.
The aggregation function sums the incoming messages from all nodes $L_i$ in the completely connected graph:

$$m_i^l = \sum_{j \in L_i} a(n_i^l, n_j^l)\, \psi_m(n_j^l) = \sum_j \alpha_{ij}^l v_j^l$$

This aggregated message $m_i^l$ is then passed to the local update function $\phi_n$ (shared by all nodes), which is the following stack of linear layers, skip connections, layer normalization, and a ReLU activation function:

$$u_i^l = \mathrm{LayerNorm}\big(m_i^l W_1 + n_i^l\big)$$
$$n_i^{l+1} = \mathrm{LayerNorm}\big(\mathrm{ReLU}(u_i^l W_2)\, W_3 + u_i^l\big)$$

where $W_1 \in \mathbb{R}^{d_n \times d_n}$, $W_2 \in \mathbb{R}^{d_n \times d_{nh}}$, $W_3 \in \mathbb{R}^{d_{nh} \times d_n}$, and $d_{nh}$ is the hidden layer size of the feed-forward network. This overview of the transformer architecture has focused on the fully connected case of self-attention. For brevity we have omitted the details of multi-head attention, bias vectors, and the stacking of vectors into matrices for maximal GPU utilization.
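The layer described above can be sketched in a few lines of NumPy. This is an illustrative single-head version, omitting the multi-head machinery and bias vectors just as the description does; the weight matrices are assumed to be given (e.g. already trained), and the learned LayerNorm gain/bias are dropped for brevity.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))  # numerically stable softmax
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x, eps=1e-5):
    # LayerNorm without the learned gain/bias, for brevity
    return (x - x.mean(axis=-1, keepdims=True)) / (x.std(axis=-1, keepdims=True) + eps)

def transformer_layer(N, Wq, Wk, Wv, W1, W2, W3):
    """One single-head transformer layer over node vectors N: (num_nodes, d_n)."""
    d_n = N.shape[-1]
    Q, K, V = N @ Wq, N @ Wk, N @ Wv               # q_i, k_j, v_j
    A = softmax(Q @ K.T / np.sqrt(d_n))            # alpha_ij; each row sums to 1
    M = A @ V                                      # aggregated messages m_i
    U = layer_norm(M @ W1 + N)                     # u_i: first residual block
    return layer_norm(np.maximum(U @ W2, 0.0) @ W3 + U)  # n_i^{l+1}
```

Each line corresponds directly to one of the equations above, with vectors stacked into matrices so the whole layer is a handful of matrix products.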

B SOKOBAN EXPERIMENTS

In the experiments described in the main text, the model received edge feature vectors as inputs. The question we pose here is whether RT's latent edge vectors can improve reasoning ability even on tasks with graph structure that is hidden, rather than passed to the model in the form of edge vectors. We use the Sokoban (Guez et al., 2019) reinforcement learning task to investigate. In Sokoban, the agent must push four yellow boxes onto the red targets within 120 time steps.

Hyperparameter values considered in the Sokoban search:

- learning rate: …e-5, 6.3e-5, 1e-4, 1.6e-4, 2.5e-4, 4e-4, 6.3e-4, 1e-3, 1.6e-3, 2.5e-3, 4e-3, 6.3e-3, 1e-2
- d_n = d_e = d_g: 45, 64, 90, 128, 180, 256, 360, 512
- nb heads: 3, 4, 6, 8, 10, 12, 16
- head size: 8, 12, 16, 24, 32, 45, 64
- d_nh: 4, 6, 8, 12, 16, 24, 32, 45, 64, 90
- d_eh1: 12, 16, 24, 32, 45, 64
- d_eh2: 4, 6, 8, 12, 16, 24, 32, 45
- ptr from edges: false, true
- graph vec: core, cat

E TEST RESULTS ON CLRS-30

Test performance of all tuned models evaluated on CLRS-30 may be found in Tables 19 (mean test micro-F1 score) and 18 (standard deviation). On certain tasks, model performances varied widely across runs; Figure 5 displays the Naïve String Matcher histograms. We see that RT's twenty runs on this task obtain a wide range of scores, including a spike in the range from 80% to 100%, while none of the baseline model runs surpass a score of 20%. This explains RT's high variance on this task compared to the other models. In summary, in a case like this higher variance is a consequence of higher performance. The same behavior can be inferred without the histograms by examining Table 18 and Table 19, which show that RT's high variance on Naïve String Matcher is associated with far higher mean performance on the task than the baseline models. The same pattern is apparent for RT's next two highest-variance tasks, Topological Sort and Quickselect.

For a broader view of variance, Figure 6 displays the model score distributions over all CLRS-30 tasks. We see that RT's distribution is weighted more heavily than the other models in the >80% range, and underweighted in the <20% range. Meanwhile, all model distributions are spread quite widely over the five range bins. This spread is quantified numerically by the last row of Table 18, where RT's overall standard deviation is shown to be one of the smallest.

Table 18: Standard deviations over 20 seeds for all tuned models on all algorithms. Each value in the row "Over All Runs" is not an average of the per-algorithm standard deviations, but the standard deviation across all runs.
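The distinction drawn in the Table 18 caption matters: the standard deviation across all pooled runs can differ greatly from the average of per-algorithm standard deviations, because the pooled figure also captures between-task spread. A toy NumPy check with made-up scores:

```python
import numpy as np

# Hypothetical scores (%) for two tasks, three runs each.
scores = {"task_a": [10.0, 11.0, 12.0], "task_b": [90.0, 91.0, 92.0]}

avg_of_stds = np.mean([np.std(v) for v in scores.values()])   # within-task spread only
pooled_std = np.std(np.concatenate([np.asarray(v) for v in scores.values()]))
# avg_of_stds stays below 1, while pooled_std is about 40:
# pooling exposes the large between-task differences.
```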



See the final comment in https://github.com/deepmind/clrs/issues/92



Figure 2: Categories of GNNs and Transformers, compared in terms of transformer machinery and edge vector incorporation. Model categories tested in our experiments are marked in bold.

Figure 3: (Left) Standard transformer attention conditions the QKV computation on node vectors. (Right) Relational attention conditions this computation on the intervening edge vector as well.

END-TO-END ALGORITHMIC REASONING

CLRS-30 evaluates reasoning ability by examining how closely a model can emulate each step of an algorithm. But we may also evaluate reasoning ability by training a model to execute an algorithm end-to-end (Xu et al., 2020). We use the task provided by Tang et al. (2020) to evaluate RT in this way. Specifically, we task RT with finding the shortest path distance between two nodes in an undirected lobster graph. The node features are one-hot encodings distinguishing the source node, the destination node, and all remaining nodes. The edge features are binary presence values.
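A rough sketch of how such a training example could be constructed. The generator below is an illustrative stand-in, not the code of Tang et al. (2020); the names random_lobster_adj, shortest_hops, and node_features are ours.

```python
import numpy as np
from collections import deque

def random_lobster_adj(backbone, rng):
    """Toy lobster graph: a backbone path where each backbone node may grow
    a leg of one or two extra nodes. Returns a binary adjacency matrix."""
    edges = [(i, i + 1) for i in range(backbone - 1)]
    n = backbone
    for i in range(backbone):
        if rng.random() < 0.5:              # optional first leg node
            edges.append((i, n)); n += 1
            if rng.random() < 0.5:          # optional second leg node
                edges.append((n - 1, n)); n += 1
    adj = np.zeros((n, n))
    for a, b in edges:
        adj[a, b] = adj[b, a] = 1.0         # undirected, binary presence
    return adj

def shortest_hops(adj, src, dst):
    """BFS hop count between src and dst: the end-to-end regression label."""
    dist = {src: 0}
    q = deque([src])
    while q:
        u = q.popleft()
        if u == dst:
            return dist[u]
        for v in np.flatnonzero(adj[u]):
            if v not in dist:
                dist[v] = dist[u] + 1
                q.append(v)
    return None

def node_features(n, src, dst):
    """One-hot node features over {source, destination, remaining}."""
    feats = np.zeros((n, 3))
    for i in range(n):
        feats[i, 0 if i == src else (1 if i == dst else 2)] = 1.0
    return feats
```

Because the legs are dead ends, the shortest path between two backbone nodes always runs along the backbone, which makes the BFS label easy to verify by hand.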

Figure 4: Average Relative Loss on Shortest Paths

Main Results

We use the same experiment settings as Tang et al. (2020) for the shortest path task. Importantly, we use their method for testing graph size generalizability by training models on graphs of size [4, 34) and evaluating models on graphs of size 100. Furthermore, all models use 30 layers. Results are reported using the relative loss metric introduced by the authors, defined as |y − ŷ|/|y| given a label y and a prediction ŷ. We compare RT to their two baselines, Graph Convolution Network or GCN (Kipf & Welling, 2017) and GAT (Veličković et al., 2018). Results are averaged across 20 random seeds. RT outperforms both baselines with an average relative loss of 0.22, compared to GCN's 0.45 and GAT's 0.28 (Figure 4). These results were obtained without using the iterative module proposed by Tang et al. (2020) that introduces a stopping criterion to message passing computations.
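For concreteness, the relative loss metric can be written directly from its definition. The function name is ours; it assumes nonzero labels, which holds for positive path distances.

```python
import numpy as np

def relative_loss(y, y_hat):
    """Mean of |y - y_hat| / |y| over a batch of labels y."""
    y, y_hat = np.asarray(y, dtype=float), np.asarray(y_hat, dtype=float)
    return float(np.mean(np.abs(y - y_hat) / np.abs(y)))
```

For example, predicting distances 5 and 4 for true distances 4 and 5 gives a relative loss of (1/4 + 1/5)/2 = 0.225.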

Figure 5: Histograms describing the variance in model results on Naïve String Matcher, across runs

Table 1: Mean test scores of all tuned models on the eight CLRS-30 algorithm classes.

We expanded the training datasets by 10x to 10,000 examples generated from the same canonical random seed of 1, and evaluated the effects on the 8 core tasks. As shown in Table 9, expanding the training set significantly boosts the performance of all baseline models. For all of the other experiments in this work, we use the larger training sets of 10,000 examples. Following Veličković et al. (2022), we compute results for two separate PGN models, masked (PGN-m) and unmasked (PGN-u), then select the best result on each task to compute the average shown for the combination PGN-c model (which is called PGN in the CLRS-30 results). Note therefore that PGN-c does not represent a single model. But it does represent the performance that would be achievable by a PGN model that adaptively learned when to use masking.

Table 3: Hyperparameter values considered for CLRS-30 experiments.

Table 4: Number of trainable parameters in each model tested on CLRS-30, including the framework's encoding and decoding layers, on the reference algorithm Bellman Ford.

Table 5: Training speed in examples per second on a T4 GPU, on the reference algorithm Bellman Ford.



lists the tuned hyperparameter values for RT on the Sokoban task, along with the sets of values considered in that search.

D TRAIN/TEST PROTOCOL

Except where specifically noted, our experiments follow the exact train/test protocol defined by CLRS-30. CLRS-30 provides canonical datasets (training, validation, and test) which can also be generated from specific random seeds: 1, 2, 3. The graphs in the training and validation datasets contain 16 nodes, while the test graphs are of size 64 to evaluate the out-of-distribution (OOD) generalization of models. Training is performed on a random sequence of 320,000 examples drawn with replacement from the train set, where each example trajectory contains a variable number of reasoning steps. During training, the model is evaluated on the validation set after every 320 examples. At the end of training, the model with the highest validation score is evaluated on the test set. The average test micro-F1 score is reported for all results. The published CLRS-30 results are reported as averages over 3 random run seeds, but we use 20 seeds in all of our experiments.
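The checkpoint-selection loop implied by this protocol can be sketched as follows. All names here are illustrative stand-ins, not the CLRS-30 framework's API.

```python
def run_protocol(model, train_step, evaluate, sample_batch,
                 total_examples=320_000, eval_every=320):
    """Train on a stream of examples drawn with replacement, evaluate on the
    validation split every `eval_every` examples, and report the test score
    of the checkpoint with the best validation score."""
    best_score, best_model = float("-inf"), model
    for _ in range(total_examples // eval_every):
        model = train_step(model, sample_batch(eval_every))
        score = evaluate(model, "valid")
        if score > best_score:
            best_score, best_model = score, model
    return evaluate(best_model, "test")
```

The point of the sketch is that test performance is reported for the best-on-validation checkpoint, not for the final one.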

Table 8: Algorithms and their respective classes and task types. Taking hint targets into consideration, there are 3 task types in CLRS-30: node (N), edge (E), and graph (G).

Table 9: Average test scores of untuned baseline models trained on the original small datasets versus the expanded training datasets, for the 8 core tasks.

Test score improvements from hyperparameter tuning on the 8 core tasks. The drop in score for PGN-m was likely the result of tuning hyperparameters (for all models) on shorter training runs than were used for evaluation.

Reproduction of baseline model results on the 8 core algorithms.

Test score improvements from augmenting the CLRS-30 baseline models.

Table 13: Standard transformer ablation (with re-tuned hyperparameters) evaluated on the 8 core algorithms.

Table 14: Single-layer RT ablation (with re-tuned hyperparameters) evaluated on the 8 core algorithms.

Table 15: Ablation of the node pointer decoding procedure evaluated on the 8 core algorithms.

Table 16: Ablation of RT core node (vs. concatenation of the global vector) evaluated on 7 representative algorithms that use global input features or hints.

Table 17: Ablation of RT's edge update procedure evaluated on the 8 core algorithms.

Table 19: Average test scores of all tuned models on all algorithms.

ACKNOWLEDGMENTS

We wish to thank our many collaborators for their valuable feedback, including Roland Fernandez, who also provided the XT ML tool that made research at this scale possible.


Humans solving these puzzles tend to plan out which boxes will go onto which targets. Assuming that a successful agent will learn a similar strategy, representing each box-to-target pair as a directed relation, we hypothesize that RT is more capable than a standard transformer at reasoning over such pairwise relations. For evaluation, we use RT in place of the standard transformer originally used by the Working Memory Graph (WMG) RL agent (Loynd et al., 2020), then train both modified and unmodified agents on the Sokoban task for 10 million time steps.

Main Results

We find that using RT reduces the agent's final error rate by a relative 15%, from 34% to 29% of the puzzles. This improvement is much larger than the confidence intervals (0.3% and 0.7% standard error, respectively). Our results support the hypothesis that RT can learn even hidden graph structure. One alternative explanation would be that the extra trainable parameters added by RT simply improved the expressivity of WMG. But this seems unlikely since hyperparameter tuning of the unmodified WMG agent (Loynd et al., 2020) converged to an intermediate model size, rather than a larger model for more expressivity.
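As a quick arithmetic check of the figures quoted above ("relative 15%" measures the error reduction relative to the baseline error rate):

```python
baseline_err, rt_err = 0.34, 0.29          # final error rates from the text
relative_reduction = (baseline_err - rt_err) / baseline_err
# relative_reduction is about 0.147, i.e. roughly a relative 15% reduction
```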

C HYPERPARAMETERS

To tune the hyperparameters of RT and the CLRS-30 baseline GNNs, we used Distributed Grid Descent (DGD) (Loynd et al., 2020), a self-guided form of random search. Each search was terminated after model performance converged to a stable value. Then 20 additional runs were executed, using the winning hyperparameter configuration, to obtain results free from selection bias. All tuning runs used the CLRS-30 protocol described in Appendix D, except for the following details:

1. To prevent tuning on the canonical datasets or any fixed datasets at all, the dataset generation seeds were randomized at the start of each run.
2. To reduce variance, the minimum evaluation dataset size was raised from 32 to 100.
3. To mitigate the computational costs, all models were tuned on only the 8 core algorithms, and each training run was shortened to 32,000 examples.

Very similar procedures were used to tune RT hyperparameters for the other (non-CLRS-30) experiments. For all experiments, all untuned hyperparameter values were chosen to match the settings of the corresponding baseline models.

C.1 CLRS-30 EXPERIMENTS

Table 3 reports the sets of values considered in those searches. For example, d_eh1 was tuned to 64 and d_eh2 to 12, each from the candidate set 6, 8, 12, 16, 24, 32, 45, 64, 90, 128, 180, 256, 360, 512. The runtime sizes of the corresponding models are found in Table 4, along with their training speeds in Table 5.

On certain tasks in which baseline GNN performances improved, we also observed higher variability in results. Following Veličković et al. (2022), we report the best performance between PGN-u and PGN-m on every task in the column titled PGN-c (for PGN-combination). Note therefore that PGN-c does not represent a single model. We bold the results of the best-performing single model for each specific task. RT was the best-performing on 11 out of 30 of the tasks. Overall, MPNN was the top-scoring baseline model, winning on 8 tasks. Deep Sets won on 4 tasks.
PGN-u won on 1 task, and PGN-m won on 5 tasks. Finally, GAT-v2 won on 1 task.

Table 1 shows test performances of all tuned models across 8 algorithm classes. As mentioned in Section 5.1.3, RT performed best in 6 out of 8 of the classes, and second-best on greedy algorithms and sorting.

Table 8 shows which algorithms belong to which classes, and which task types. Each algorithm may correspond to multiple task types if the hint and final outputs of that algorithm differ.

Table 9 shows test performance of the baseline GNN models trained on datasets of varying size, without hyperparameter tuning. Specifically, we compare test scores between the models trained on datasets of size 1,000 and datasets of size 10,000. We note that the results in the first row agree closely with those in the CLRS-30 paper (Table 11).

F CLRS-30 BENCHMARK ABLATIONS

We perform our ablation studies on the 8 core CLRS-30 tasks. Table 13 compares the test performance of the standard transformer to the performance of RT. Table 14 shows the drop in test performance that resulted from restricting RT to one layer. The original, tuned RT had 3 layers, and is labeled RT-3. Table 15 shows the marginal improvement that resulted from decoding node vectors using only edge vectors. Table 16 shows the effects that handling global input vectors through a core node (vs. concatenation with the input node vectors) has on RT test performance. RT with core node only won on 3 out of 7 tasks, but had the higher average test performance by 0.08%. Finally, Table 17 shows the drop in test performance that resulted from removing edge updates from RT.

G CLRS-30 MODEL VARIANCE

