OPTIMAL TRANSPORT GRAPH NEURAL NETWORKS

Abstract

Current graph neural network (GNN) architectures naively average or sum node embeddings into an aggregated graph representation, potentially losing structural or semantic information. We here introduce OT-GNN, a model that computes graph embeddings using parametric prototypes that highlight key structural facets of a graph. Towards this goal, we are (to our knowledge) the first to successfully combine optimal transport (OT) with parametric graph models. Graph representations are obtained from Wasserstein distances between the set of GNN node embeddings and "prototype" point clouds treated as free parameters. We theoretically prove that, unlike traditional sum aggregation, our function class on point clouds satisfies a fundamental universal approximation theorem. Empirically, we address an inherent collapse issue in optimization by proposing a noise contrastive regularizer that steers the model towards truly exploiting the optimal transport geometry. Finally, we consistently report better generalization performance on several molecular property prediction tasks, while exhibiting smoother graph representations.

1. INTRODUCTION

Recently, there has been considerable interest in developing learning algorithms for structured data such as graphs. For example, molecular property prediction has many applications in chemistry and drug discovery (Yang et al., 2019; Vamathevan et al., 2019). Historically, graphs were decomposed into features such as molecular fingerprints, or turned into non-parametric graph kernels (Vishwanathan et al., 2010; Shervashidze et al., 2011). More recently, learned representations via graph neural networks (GNNs) have achieved state-of-the-art results on graph prediction tasks (Duvenaud et al., 2015; Defferrard et al., 2016; Kipf & Welling, 2017; Yang et al., 2019). Despite these successes, graph neural networks are often underutilized in whole-graph prediction tasks such as molecular property prediction. Specifically, while GNNs produce node embeddings for each atom in the molecule, these are typically aggregated via simple operations such as a sum or average, turning the molecule into a single vector prior to classification or regression. As a result, some of the information naturally extracted by node embeddings may be lost. Departing from this simple aggregation step, Togninalli et al. (2019) recently proposed a kernel function over graphs that directly compares non-parametric node embeddings as point clouds through optimal transport (the Wasserstein distance). Their non-parametric model yields better empirical performance than popular graph kernels, but this idea has not been extended to the more challenging parametric case, where optimization difficulties have to be reconciled with the combinatorial aspects of optimal transport solvers. Motivated by these observations, and drawing inspiration from prior work on prototype learning (appendix F), we introduce a new class of GNNs where the key representational step consists of comparing each input graph to a set of abstract prototypes (fig. 1).
These prototypes play the role of basis functions; they are stored as point clouds as if they had been encoded from actual real graphs. Each input graph is first encoded into a set of node embeddings using any existing GNN architecture. We then compare the resulting embedding point cloud to the prototype embedding sets. Formally, the distance between two point clouds is measured by their optimal transport (Wasserstein) distance. The prototypes, as abstract basis functions, can be understood as keys that highlight property values associated with different graph structural features. In contrast to previous kernel methods, the prototypes are learned together with the GNN parameters in an end-to-end manner.

[Figure 1: Our OT-GNN prototype model computes graph embeddings from Wasserstein distances between (a) the set of GNN node embeddings and (b) prototype embedding sets. These distances are then used as the molecular representation (c) for supervised tasks, e.g. property prediction.]

We assume that a few prototypes, e.g. some functional groups, highlight key facets or structural features of graphs relevant to a particular downstream task. We express graphs by relating them to these abstract prototypes, represented as free point cloud parameters. Our notion of prototypes is inspired by the vast prior work on prototype learning, which we review in appendix F. In our case, prototypes are not required to be the mean of a cluster of data; instead, they are entities living in the data embedding space that capture information helpful for the task under consideration. The closest analogy is the centers of radial basis function networks (Chen et al., 1991; Poggio & Girosi, 1990), but we also draw on learning vector quantization approaches (Kohonen, 1995) and prototypical networks (Snell et al., 2017). Our model improves upon traditional aggregation by explicitly tapping into the full set of node embeddings without first collapsing them to a single vector.
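Concretely, comparing a graph's node-embedding point cloud to a prototype point cloud amounts to solving an optimal transport problem, for which an entropy-regularized solver such as Sinkhorn's algorithm is a common choice. The sketch below is a minimal illustration of that computation, assuming uniform weights on both clouds and a squared-Euclidean ground cost; the function name and regularization value are our own choices, not the paper's implementation.

```python
import numpy as np

def sinkhorn_wasserstein(X, P, reg=0.1, n_iters=200):
    """Entropy-regularized Wasserstein distance between two point clouds.

    X: (n, d) node embeddings of one graph; P: (m, d) a prototype point
    cloud. Uniform weights are assumed on both clouds. Illustrative
    sketch only, not the paper's actual implementation.
    """
    n, m = X.shape[0], P.shape[0]
    # Pairwise squared-Euclidean ground cost matrix.
    C = ((X[:, None, :] - P[None, :, :]) ** 2).sum(-1)
    K = np.exp(-C / reg)                      # Gibbs kernel
    a, b = np.full(n, 1.0 / n), np.full(m, 1.0 / m)
    u = np.ones(n)
    for _ in range(n_iters):                  # Sinkhorn fixed-point iterations
        v = b / (K.T @ u)
        u = a / (K @ v)
    T = u[:, None] * K * v[None, :]           # approximate transport plan
    return float((T * C).sum())               # transport cost
```

The graph representation in fig. 1(c) would then be the vector of such distances from one graph to each of the learned prototype point clouds. Note that small `reg` values approximate the exact Wasserstein distance more closely but can underflow numerically; log-domain implementations are preferred in practice.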
We theoretically prove that, unlike standard GNN aggregation, our model defines a class of set functions that is a universal approximator. Introducing prototype point clouds as free parameters trained using combinatorial optimal transport solvers creates a challenging optimization problem. Indeed, as the models are trained end-to-end, the primary signal is initially available only in aggregate form. If trained as is, the prototypes often collapse to single points, reducing the Wasserstein distance between point clouds to Euclidean comparisons of their means. To counter this effect, we introduce a contrastive regularizer which effectively prevents the model from collapsing (Section 3.2), and we demonstrate its merits empirically.

Our contributions. First, we introduce an efficiently trainable class of graph neural networks enhanced with optimal transport (OT) primitives for computing graph representations based on relations with abstract prototypes. Second, we are the first to successfully train parametric graph models together with combinatorial OT distances, despite optimization difficulties. A key element is our noise contrastive regularizer, which prevents the model from collapsing back to standard summation and thus fully exploits the OT geometry. Third, we provide a theoretical justification of the increased representational power compared to the standard GNN aggregation method. Finally, our model shows consistent empirical improvements over the previous state-of-the-art on molecular datasets, while also yielding smoother graph embedding spaces.
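The exact form of the noise contrastive regularizer is given in Section 3.2, which lies outside this excerpt. Purely to make the idea concrete, the hypothetical sketch below implements a generic hinge-style contrastive term: Wasserstein distances from real graph embeddings to the prototypes are pushed below distances from random noise point clouds, which discourages degenerate (collapsed) prototypes for which all distances look alike. The function name, margin form, and choice of negatives are our assumptions, not the paper's formulation.

```python
import numpy as np

def contrastive_regularizer(d_real, d_noise, margin=1.0):
    """Hinge-style noise contrastive term (illustrative only; not the
    paper's exact regularizer).

    d_real:  Wasserstein distances between real graph embeddings and
             the prototype point clouds.
    d_noise: distances between random noise point clouds and the same
             prototypes, used as negatives.
    """
    d_real = np.asarray(d_real, dtype=float)
    d_noise = np.asarray(d_noise, dtype=float)
    # Penalize whenever a real-graph distance is not at least `margin`
    # smaller than the corresponding noise distance.
    return float(np.maximum(0.0, margin + d_real - d_noise).mean())
```

A collapsed prototype assigns similar distances to real and noise clouds, so this term stays positive and its gradient pushes the prototype points apart again.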

2. PRELIMINARIES

2.1 DIRECTED MESSAGE PASSING NEURAL NETWORKS (D-MPNN)

We briefly review the simplified D-MPNN (Dai et al., 2016) architecture, which was adapted for state-of-the-art molecular property prediction by Yang et al. (2019). This model takes as input a directed graph G = (V, E), with node and edge features denoted by x_v and e_{vw} respectively, for v, w in the vertex set V and v → w in the edge set E. The parameters of D-MPNN are the matrices {W_i, W_m, W_o}. It keeps track of messages m^t_{vw} and hidden states h^t_{vw} for each step t, defined as follows. An initial hidden state is set to

h^0_{vw} := ReLU(W_i cat(x_v, e_{vw})),

where "cat" denotes concatenation. Then, the updates are:

m^{t+1}_{vw} = \sum_{k ∈ N(v) \ {w}} h^t_{kv},    h^{t+1}_{vw} = ReLU(h^0_{vw} + W_m m^{t+1}_{vw}),
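These update equations can be sketched directly in plain NumPy. The dict-of-directed-edges layout below is purely illustrative (the reference implementation by Yang et al. batches edge states into tensors), but it follows the equations term by term:

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def init_hidden(x, e, W_i):
    """h^0_{vw} = ReLU(W_i cat(x_v, e_{vw})) for every directed edge (v, w)."""
    return {(v, w): relu(W_i @ np.concatenate([x[v], feat]))
            for (v, w), feat in e.items()}

def dmpnn_step(h, h0, W_m):
    """One D-MPNN message-passing update:
        m^{t+1}_{vw} = sum over k in N(v) \\ {w} of h^t_{kv}
        h^{t+1}_{vw} = ReLU(h^0_{vw} + W_m m^{t+1}_{vw})
    Edge states are dicts keyed by directed edge (v, w); this layout is
    an illustrative assumption, not the reference implementation.
    """
    h_new = {}
    for (v, w), h0_vw in h0.items():
        # Accumulate messages from incoming directed edges (k, v), k != w.
        m = np.zeros_like(h0_vw)
        for (k, dst) in h:
            if dst == v and k != w:
                m = m + h[(k, dst)]
        h_new[(v, w)] = relu(h0_vw + W_m @ m)
    return h_new
```

Running the step repeatedly propagates information along directed edges while excluding the reverse edge, which prevents a message from immediately flowing back to its source.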

