RELATIONAL ATTENTION: GENERALIZING TRANSFORMERS FOR GRAPH-STRUCTURED TASKS

Abstract

Transformers flexibly operate over sets of real-valued vectors representing task-specific entities and their attributes, where each vector might encode one word-piece token and its position in a sequence, or some piece of information that carries no position at all. But as set processors, standard transformers are at a disadvantage when reasoning over more general graph-structured data, where nodes represent entities and edges represent relations between entities. To address this shortcoming, we generalize transformer attention to consider and update edge vectors in each transformer layer. We evaluate this relational transformer on a diverse array of graph-structured tasks, including the large and challenging CLRS Algorithmic Reasoning Benchmark. There, it dramatically outperforms state-of-the-art graph neural networks expressly designed to reason over graph-structured data. Our analysis demonstrates that these gains are attributable to relational attention's inherent ability to leverage the greater expressivity of graphs over sets.

1. INTRODUCTION

Figure 1: The relational transformer (RT) outperforms baseline GNNs on a set of 30 distinct graph-structured tasks from CLRS-30, averaged by algorithm class.

Graph-structured problems turn up in many domains, including knowledge bases (Hu et al., 2021; Bordes et al., 2013), communication networks (Leskovec et al., 2010), citation networks (McCallum et al., 2000), and molecules (Debnath et al., 1991; Zhang et al., 2020b). One example is predicting the bioactive properties of a molecule, where the atoms of the molecule are the nodes of the graph and the bonds are the edges. Along with their ubiquity, graph-structured problems vary widely in difficulty. For example, certain graph problems can be solved with a simple multi-layer perceptron, while others are quite challenging and require explicit modeling of relational characteristics.

Graph Neural Networks (GNNs) are designed to process graph-structured data, including the graph's (possibly directed) edge structure and (in some cases) features associated with the edges. In particular, they learn to represent graph features by passing messages between neighboring nodes and edges, and updating the node and (optionally) edge vectors. Importantly, GNNs typically restrict message passing to operate over the edges in the graph.

Standard transformers lack the relational inductive biases (Battaglia et al., 2018) that are explicitly built into the most commonly used GNNs. Instead, the transformer fundamentally consumes unordered sets of real-valued vectors, injecting no other assumptions. This allows entities carrying domain-specific attributes (like position) to be encoded as vectors for input to the same transformer architecture applied to different domains. Transformers have produced impressive results in a wide variety of domains. Many of the domains transformers succeed in consist of array-structured data, such as text or images.
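The message-passing scheme described above can be illustrated with a minimal numpy sketch. This is not the paper's implementation: the sum aggregation, the single linear message function, and all names here are illustrative assumptions chosen to keep the example short.

```python
import numpy as np

def message_passing_step(nodes, edges, adj, W_msg, W_upd):
    """One GNN message-passing step with sum aggregation (illustrative sketch).

    nodes: (n, d) array of node vectors.
    edges: dict mapping (i, j) -> (d_e,) edge vector, edge pointing from j to i.
    adj:   list of directed edges (i, j); messages flow only over these edges.
    W_msg: (d + d_e, d) message weights; W_upd: (2d, d) update weights.
    Real GNNs use richer, learned message/update functions.
    """
    agg = np.zeros_like(nodes)
    for (i, j) in adj:
        # message from sender j to receiver i, conditioned on the edge vector
        agg[i] += np.concatenate([nodes[j], edges[(i, j)]]) @ W_msg
    # update each node from its current state and its aggregated messages
    return np.tanh(np.concatenate([nodes, agg], axis=-1) @ W_upd)
```

Note that, as stated above, messages are restricted to the listed edges: nodes with no incoming edges receive a zero aggregate and are updated from their own state alone.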
By contrast, graph data is centrally concerned with pairwise relations between entities, represented as edges and edge attributes. Graphs are more general and expressive than sets, in the sense that a set is a special case of a graph: one without edges. So it is not immediately obvious how graph data can be processed by transformers in a way that preserves relational information.

Transformers have been successfully applied to graph-structured tasks in one of two broad ways. Certain works, most recently TokenGT (Kim et al., 2022), encode graphs as sets of real-valued vectors passed to a standard transformer. Other works change the transformer architecture itself to consider relational information, e.g. by introducing relative position vectors to transformer attention. We discuss these and many other such approaches in Section 4.

Our novel contribution is relational attention, a mathematically elegant extension of transformer attention, which incorporates edge vectors as first-class model components. We call the resulting transformer architecture the Relational Transformer (RT). As a native graph-to-graph model, RT does not rely on special encoding schemes to input or output graph data. We find that RT outperforms baseline GNNs on a large and diverse set of difficult graph-structured tasks. In particular, RT establishes dramatically improved state-of-the-art performance (Figure 1) over baseline GNNs on the challenging CLRS-30 benchmark (Veličković et al., 2022), which comprises 30 different algorithmic tasks in a framework for probing the reasoning abilities of graph-to-graph models.

To summarize our main contributions:

• We introduce the relational transformer for application to arbitrary graph-structured tasks, and make the implementation available at https://github.com/CameronDiao/relational-transformer.

• We evaluate the reasoning power of RT on a wide range of challenging graph-structured tasks, achieving new state-of-the-art results on CLRS-30.
• We enhance the CLRS-30 framework to support evaluation of a broader array of models (Section 5.1.2).

• We improve the performance of CLRS-30 baseline models by adding multi-layer functionality and tuning their hyperparameters (Section 5.1.1).
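The core idea of attention that "incorporates edge vectors as first-class model components" can be sketched in a few lines of numpy. The particular factorization below, in which edge-derived terms are added to the attention keys and values, is one plausible single-head illustration under our own assumptions, not the paper's exact definition of relational attention; all function and variable names are hypothetical.

```python
import numpy as np

def relational_attention(nodes, edges, Wq, Wk, Wv, Wek, Wev):
    """Single-head attention over a fully connected graph where every
    pairwise edge vector modulates both the attention logit and the value.

    nodes: (n, d) node vectors; edges: (n, n, d_e), edges[i, j] on edge j -> i.
    Wq, Wk, Wv: (d, d) node projections; Wek, Wev: (d_e, d) edge projections.
    """
    n, d = nodes.shape
    q = nodes @ Wq                        # (n, d) queries
    k = nodes @ Wk                        # (n, d) keys
    v = nodes @ Wv                        # (n, d) values
    ek = edges @ Wek                      # (n, n, d) edge-conditioned keys
    ev = edges @ Wev                      # (n, n, d) edge-conditioned values
    # logits[i, j]: query i scored against key j plus the edge (i, j) key
    logits = (q[:, None, :] * (k[None, :, :] + ek)).sum(-1) / np.sqrt(d)
    w = np.exp(logits - logits.max(-1, keepdims=True))
    w = w / w.sum(-1, keepdims=True)      # softmax over senders j
    # attended output mixes node values and edge-conditioned values
    return (w[:, :, None] * (v[None, :, :] + ev)).sum(axis=1)
```

With all-zero edge vectors this reduces to standard scaled dot-product attention, which makes concrete the sense in which a set is the edge-free special case of a graph.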

2. GRAPH NEURAL NETWORKS

We introduce the graph-to-graph model formalism used in the rest of this paper, inspired by Battaglia et al. (2018). The input graph is a directed, attributed graph G = (N, E), where N is an unordered set of node vectors n_i ∈ R^{d_n}, with i denoting the i-th node, and E is a set of edge vectors e_ij ∈ R^{d_e}, where the directed edge (i, j) points from node j to node i. Each layer l in the model accepts a graph G^l as



Figure 2: Categories of GNNs and Transformers, compared in terms of transformer machinery and edge vector incorporation. Model categories tested in our experiments are marked in bold.
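The graph formalism introduced above, G = (N, E) with node vectors n_i and directed edge vectors e_ij, maps directly onto a small container type. The sketch below is one possible representation under our own naming assumptions, not an excerpt from the paper's codebase.

```python
import numpy as np
from dataclasses import dataclass

@dataclass
class Graph:
    """Directed, attributed graph G = (N, E).

    nodes: (|N|, d_n) array; row i holds the node vector n_i.
    edges: dict mapping (i, j) -> edge vector e_ij in R^{d_e},
           where the directed edge (i, j) points from node j to node i.
    """
    nodes: np.ndarray
    edges: dict

    def in_edges(self, i):
        # edge vectors arriving at node i, keyed by sender j
        return {j: e for (r, j), e in self.edges.items() if r == i}
```

A layer of a graph-to-graph model can then be viewed as a function taking one such `Graph` and returning another with updated node (and possibly edge) vectors.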

