OPTIMAL TRANSPORT GRAPH NEURAL NETWORKS

Abstract

Current graph neural network (GNN) architectures naively average or sum node embeddings into an aggregated graph representation, potentially losing structural or semantic information. We here introduce OT-GNN, a model that computes graph embeddings using parametric prototypes, each highlighting a different facet of graph structure. Towards this goal, we are (to our knowledge) the first to successfully combine optimal transport (OT) with parametric graph models. Graph representations are obtained from Wasserstein distances between the set of GNN node embeddings and "prototype" point clouds that act as free parameters. We theoretically prove that, unlike traditional sum aggregation, our function class on point clouds satisfies a fundamental universal approximation theorem. Empirically, we address an inherent optimization issue, whereby prototypes collapse, by proposing a noise contrastive regularizer that steers the model towards truly exploiting the optimal transport geometry. Finally, we consistently report better generalization performance on several molecular property prediction tasks, while exhibiting smoother graph representations.

1. INTRODUCTION

Recently, there has been considerable interest in developing learning algorithms for structured data such as graphs. For example, molecular property prediction has many applications in chemistry and drug discovery (Yang et al., 2019; Vamathevan et al., 2019). Historically, graphs were decomposed into features such as molecular fingerprints, or turned into non-parametric graph kernels (Vishwanathan et al., 2010; Shervashidze et al., 2011). More recently, learned representations via graph neural networks (GNNs) have achieved state-of-the-art performance on graph prediction tasks (Duvenaud et al., 2015; Defferrard et al., 2016; Kipf & Welling, 2017; Yang et al., 2019). Despite these successes, the representational power of GNNs is often underutilized in whole-graph prediction tasks such as molecular property prediction. Specifically, while GNNs produce node embeddings for each atom in the molecule, these are typically aggregated via simple operations such as a sum or average, turning the molecule into a single vector prior to classification or regression. As a result, some of the information naturally extracted by node embeddings may be lost.

Departing from this simple aggregation step, Togninalli et al. (2019) recently proposed a kernel function over graphs that directly compares non-parametric node embeddings as point clouds through optimal transport (the Wasserstein distance). Their non-parametric model empirically outperforms popular graph kernels, but the idea has not been extended to the more challenging parametric case, where optimization difficulties have to be reconciled with the combinatorial aspects of optimal transport solvers.

Motivated by these observations, and drawing inspiration from prior work on prototype learning (appendix F), we introduce a new class of GNNs whose key representational step consists of comparing each input graph to a set of abstract prototypes (fig. 1). These prototypes play the role of basis functions; they are stored as point clouds, as if they had been encoded from actual graphs. Each input graph is first encoded into a set of node embeddings using any existing GNN architecture. The resulting embedding point cloud is then compared to the prototype embedding sets, where the distance between two point clouds is measured by their Wasserstein distance. The prototypes, as abstract basis functions, can be understood as keys that highlight property values associated with different graph structural features. In contrast to previous kernel methods, the prototypes are learned together with the GNN parameters in an end-to-end manner.
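To make this readout concrete, the following is a minimal sketch of a prototype-based aggregation in PyTorch. It is illustrative rather than the authors' implementation: the names (sinkhorn_distance, PrototypeReadout), the hyperparameters (entropic regularization eps, the number and size of prototypes), and the use of entropy-regularized Sinkhorn iterations in place of an exact OT solver are all assumptions made for this sketch.

import torch
import torch.nn as nn

def sinkhorn_distance(x, y, eps=0.1, n_iters=100):
    # Entropy-regularized Wasserstein distance between two point clouds
    # x: (n, d) and y: (m, d), with uniform marginal weights.
    cost = torch.cdist(x, y, p=2) ** 2           # pairwise squared-Euclidean costs
    cost = cost / cost.max().detach()            # rescale costs for numerical stability
    a = torch.full((x.size(0),), 1.0 / x.size(0))
    b = torch.full((y.size(0),), 1.0 / y.size(0))
    K = torch.exp(-cost / eps)                   # Gibbs kernel
    u = torch.ones_like(a)
    for _ in range(n_iters):                     # Sinkhorn fixed-point updates
        u = a / (K @ (b / (K.t() @ u)))
    v = b / (K.t() @ u)
    plan = u.unsqueeze(1) * K * v.unsqueeze(0)   # approximate transport plan
    return (plan * cost).sum()                   # transport cost under that plan

class PrototypeReadout(nn.Module):
    # Replaces sum/mean pooling: the graph representation is the vector of
    # Wasserstein distances between the node embeddings and each learnable
    # prototype point cloud.
    def __init__(self, n_prototypes=4, points_per_prototype=8, dim=64):
        super().__init__()
        self.prototypes = nn.ParameterList(
            [nn.Parameter(torch.randn(points_per_prototype, dim))
             for _ in range(n_prototypes)]
        )

    def forward(self, node_embeddings):          # (num_nodes, dim), from any GNN
        return torch.stack(
            [sinkhorn_distance(node_embeddings, p) for p in self.prototypes]
        )

node_embeddings = torch.randn(12, 64)            # e.g. 12 atoms with 64-dim embeddings
graph_repr = PrototypeReadout()(node_embeddings) # 4-dim graph representation

The resulting distance vector can then be fed to a final predictor (e.g. a linear layer) for classification or regression; since gradients flow through the transport costs to both the prototypes and the GNN encoder, the whole model can be trained end-to-end as described above.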

