OPTIMAL TRANSPORT GRAPH NEURAL NETWORKS

Abstract

Current graph neural network (GNN) architectures naively average or sum node embeddings into an aggregated graph representation, potentially losing structural or semantic information. Here we introduce OT-GNN, a model that computes graph embeddings using parametric prototypes that highlight key facets of different graph aspects. Towards this goal, we are (to our knowledge) the first to successfully combine optimal transport (OT) with parametric graph models. Graph representations are obtained from Wasserstein distances between the set of GNN node embeddings and "prototype" point clouds learned as free parameters. We theoretically prove that, unlike traditional sum aggregation, our function class on point clouds satisfies a fundamental universal approximation theorem. Empirically, we address an inherent optimization issue, the collapse of the prototype point clouds, by proposing a noise contrastive regularizer that steers the model towards truly exploiting the optimal transport geometry. Finally, we consistently report better generalization performance on several molecular property prediction tasks, while also obtaining smoother graph representations.

1. INTRODUCTION

Recently, there has been considerable interest in developing learning algorithms for structured data such as graphs. For example, molecular property prediction has many applications in chemistry and drug discovery (Yang et al., 2019; Vamathevan et al., 2019). Historically, graphs were decomposed into features such as molecular fingerprints, or turned into non-parametric graph kernels (Vishwanathan et al., 2010; Shervashidze et al., 2011). More recently, learned representations via graph neural networks (GNNs) have achieved state-of-the-art results on graph prediction tasks (Duvenaud et al., 2015; Defferrard et al., 2016; Kipf & Welling, 2017; Yang et al., 2019). Despite these successes, graph neural networks are often underutilized in whole-graph prediction tasks such as molecule property prediction. Specifically, while GNNs produce node embeddings for each atom in the molecule, these are typically aggregated via simple operations such as a sum or average, turning the molecule into a single vector prior to classification or regression. As a result, some of the information naturally extracted by node embeddings may be lost. Departing from this simple aggregation step, Togninalli et al. (2019) recently proposed a kernel function over graphs that directly compares non-parametric node embeddings as point clouds through optimal transport (Wasserstein distance). Their non-parametric model yields better empirical performance than popular graph kernels, but this idea has not been extended to the more challenging parametric case, where optimization difficulties have to be reconciled with the combinatorial aspects of optimal transport solvers. Motivated by these observations and drawing inspiration from prior work on prototype learning (Appendix F), we introduce a new class of GNNs where the key representational step consists of comparing each input graph to a set of abstract prototypes (Figure 1).
These prototypes play the role of basis functions; they are stored as point clouds, as if they had been encoded from real graphs. Each input graph is first encoded into a set of node embeddings using any existing GNN architecture. We then compare the resulting embedding point cloud to the prototype embedding sets. Formally, the distance between two point clouds is measured by their optimal transport (Wasserstein) distance. As abstract basis functions, the prototypes can be understood as keys that highlight property values associated with different structural features of graphs. In contrast to previous kernel methods, the prototypes are learned together with the GNN parameters in an end-to-end manner.

Figure 1: Our OT-GNN prototype model computes graph embeddings from Wasserstein distances between (a) the set of GNN node embeddings and (b) prototype embedding sets. These distances are then used as the molecular representation (c) for supervised tasks, e.g. property prediction.

We assume that a few prototypes, e.g. some functional groups, highlight key facets or structural features of graphs relevant to the downstream task at hand. We express graphs by relating them to these abstract prototypes, represented as free point cloud parameters. Our notion of prototypes is inspired by the vast prior work on prototype learning, which we review in Appendix F. In our case, prototypes are not required to be the mean of a cluster of data; instead, they are entities living in the data embedding space that capture information helpful for the task under consideration. The closest analogues are the centers of radial basis function networks (Chen et al., 1991; Poggio & Girosi, 1990), but we also draw inspiration from learning vector quantization (Kohonen, 1995) and prototypical networks (Snell et al., 2017). Our model improves upon traditional aggregation by explicitly tapping into the full set of node embeddings without first collapsing them into a single vector.
We theoretically prove that, unlike standard GNN aggregation, our model defines a class of set functions that is a universal approximator. Introducing prototype point clouds as free parameters trained using combinatorial optimal transport solvers creates a challenging optimization problem. Indeed, as the models are trained end-to-end, the primary signal is initially available only in aggregate form. If trained as is, the prototypes often collapse to single points, reducing the Wasserstein distance between point clouds to a Euclidean comparison of their means. To counter this effect, we introduce a contrastive regularizer that effectively prevents the model from collapsing (Section 3.2), and we demonstrate its merits empirically.

Our contributions. First, we introduce an efficiently trainable class of graph neural networks enhanced with optimal transport (OT) primitives for computing graph representations based on relations with abstract prototypes. Second, we are the first to successfully train parametric graph models together with combinatorial OT distances, despite optimization difficulties. A key element is our noise contrastive regularizer, which prevents the model from collapsing back to standard summation and thus lets it fully exploit the OT geometry. Third, we provide a theoretical justification of the increased representational power compared to the standard GNN aggregation method. Finally, our model shows consistent empirical improvements over the previous state of the art on molecular datasets, while also yielding smoother graph embedding spaces.

2.1. DIRECTED MESSAGE PASSING NEURAL NETWORKS (D-MPNN)

We briefly recall here the simplified D-MPNN (Dai et al., 2016) architecture, which was adapted for state-of-the-art molecular property prediction by Yang et al. (2019). This model takes as input a directed graph $G = (V, E)$, with node and edge features denoted by $x_v$ and $e_{vw}$ respectively, for $v, w$ in the vertex set $V$ and $v \to w$ in the edge set $E$. The parameters of D-MPNN are the matrices $\{W_i, W_m, W_o\}$. It keeps track of messages $m^t_{vw}$ and hidden states $h^t_{vw}$ for each step $t$, defined as follows. An initial hidden state is set to $h^0_{vw} := \mathrm{ReLU}(W_i\,\mathrm{cat}(x_v, e_{vw}))$, where "cat" denotes concatenation. Then, the updates are:

$m^{t+1}_{vw} = \sum_{k \in N(v) \setminus \{w\}} h^t_{kv}, \qquad h^{t+1}_{vw} = \mathrm{ReLU}(h^0_{vw} + W_m m^{t+1}_{vw}), \qquad (1)$

where $N(v) = \{k \in V \mid (k, v) \in E\}$ denotes $v$'s incoming neighbors. After $T$ steps of message passing, node embeddings are obtained by summing edge embeddings:

$m_v = \sum_{w \in N(v)} h^T_{vw}, \qquad h_v = \mathrm{ReLU}(W_o\,\mathrm{cat}(x_v, m_v)). \qquad (2)$

A final graph embedding is then obtained as $h = \sum_{v \in V} h_v$, which is usually fed to a multilayer perceptron (MLP) for classification or regression.
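The D-MPNN updates described above can be sketched as follows. This is a minimal pure-Python illustration, not the Chemprop implementation: the helper names (`dmpnn`, `matvec`, `vadd`, `rand_matrix`), the random weights and the toy three-atom graph are all hypothetical, and we assume (as for molecules) that each bond is stored in both directions, so that the edge $(v, w)$ used in the node readout exists whenever $w \in N(v)$.

```python
import random


def relu(v):
    """Element-wise ReLU on a list of floats."""
    return [max(0.0, a) for a in v]


def matvec(W, v):
    """Dense matrix (list of rows) times vector."""
    return [sum(w * a for w, a in zip(row, v)) for row in W]


def vadd(a, b):
    """Element-wise sum of two vectors."""
    return [x + y for x, y in zip(a, b)]


def dmpnn(x, e, edges, W_i, W_m, W_o, T):
    """Toy D-MPNN forward pass.

    x: dict node -> feature list; e: dict (v, w) -> edge feature list;
    edges: list of directed edges (v, w), assumed symmetric.
    Returns (dict of node embeddings h_v, graph embedding h = sum_v h_v).
    """
    hidden = len(W_m)
    incoming = {v: [] for v in x}  # N(v) = {k | (k, v) in E}
    for k, v in edges:
        incoming[v].append(k)
    # h^0_vw = ReLU(W_i cat(x_v, e_vw)); list "+" is concatenation
    h0 = {(v, w): relu(matvec(W_i, x[v] + e[(v, w)])) for v, w in edges}
    h = dict(h0)
    for _ in range(T):
        new_h = {}
        for v, w in edges:
            # m^{t+1}_vw = sum of h^t_kv over k in N(v) \ {w}
            m = [0.0] * hidden
            for k in incoming[v]:
                if k != w:
                    m = vadd(m, h[(k, v)])
            new_h[(v, w)] = relu(vadd(h0[(v, w)], matvec(W_m, m)))
        h = new_h
    node_emb = {}
    for v in x:
        # m_v = sum of h^T_vw over w in N(v) (relies on symmetric edges)
        m_v = [0.0] * hidden
        for w in incoming[v]:
            m_v = vadd(m_v, h[(v, w)])
        node_emb[v] = relu(matvec(W_o, x[v] + m_v))
    graph_emb = [sum(col) for col in zip(*node_emb.values())]
    return node_emb, graph_emb


# Hypothetical toy graph: a 3-atom chain with bonds stored in both directions.
rng = random.Random(0)


def rand_matrix(rows, cols):
    return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]


d_x, d_e, hidden, d_out = 3, 2, 4, 5
x = {v: [rng.random() for _ in range(d_x)] for v in range(3)}
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
e = {vw: [rng.random() for _ in range(d_e)] for vw in edges}
W_i = rand_matrix(hidden, d_x + d_e)
W_m = rand_matrix(hidden, hidden)
W_o = rand_matrix(d_out, d_x + hidden)
node_emb, graph_emb = dmpnn(x, e, edges, W_i, W_m, W_o, T=3)
```

The final sum over node embeddings is exactly the aggregation step that the prototype model below replaces.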

2.2. OPTIMAL TRANSPORT GEOMETRY

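Optimal Transport (Peyré et al., 2019) is a mathematical framework that defines distances or similarities between objects such as probability distributions, either discrete or continuous, as the cost of an optimal transport plan from one to the other. Consider two point clouds $X = \{x_i\}_{i=1}^n$ and $Y = \{y_j\}_{j=1}^m$ in $\mathbb{R}^d$, each carrying uniform mass. A transport plan (or coupling) from $X$ to $Y$ is a matrix $T \in \mathbb{R}_+^{n \times m}$ such that $T\mathbf{1}_m = \frac{1}{n}\mathbf{1}_n$ and $T^\top \mathbf{1}_n = \frac{1}{m}\mathbf{1}_m$. Intuitively, the marginal constraints mean that $T$ preserves the mass from $X$ to $Y$. We denote the set of such couplings as $C_{XY}$. Given a cost function $c$ on $\mathbb{R}^d \times \mathbb{R}^d$, its associated Wasserstein discrepancy is defined as

$\mathcal{W}(X, Y) = \min_{T \in C_{XY}} \sum_{ij} T_{ij}\, c(x_i, y_j). \qquad (3)$

We further describe the shape of optimal transport plans for point clouds of the same size in Appendix B.3.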
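As a small worked example of Eq. (3), the sketch below computes the Wasserstein discrepancy between two point clouds of equal size $n$ under the squared Euclidean cost. It assumes uniform masses and equal sizes, and relies on Proposition 2 (Appendix B.3) to restrict the search to rescaled permutation matrices; enumerating all $n!$ permutations is exact but only practical for tiny clouds, whereas real solvers (e.g. the network simplex solver in POT) handle general sizes. The function names are hypothetical.

```python
import itertools


def sq_dist(a, b):
    """Squared Euclidean cost c(x, y) = ||x - y||_2^2."""
    return sum((p - q) ** 2 for p, q in zip(a, b))


def transport_cost(T, X, Y, cost):
    """W_T(X, Y) = sum_ij T_ij c(x_i, y_j) for a given coupling T."""
    return sum(T[i][j] * cost(X[i], Y[j])
               for i in range(len(X)) for j in range(len(Y)))


def wasserstein_equal_size(X, Y, cost=sq_dist):
    """Exact W(X, Y) of Eq. (3) for uniform clouds with |X| == |Y| == n.

    By Proposition 2 some rescaled permutation matrix is optimal, so
    enumerating the n! permutations is exact (tiny n only).
    """
    n = len(X)
    if len(Y) != n:
        raise ValueError("this sketch only handles point clouds of equal size")
    best_val, best_T = None, None
    for perm in itertools.permutations(range(n)):
        # rescaled permutation matrix: T_ij = 1/n if j == perm[i], else 0
        T = [[1.0 / n if j == perm[i] else 0.0 for j in range(n)]
             for i in range(n)]
        val = transport_cost(T, X, Y, cost)
        if best_val is None or val < best_val:
            best_val, best_T = val, T
    return best_val, best_T


X = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
Y = [[0.0, 1.1], [0.1, 0.0], [1.0, 0.2]]
W, T = wasserstein_equal_size(X, Y)
# x_0 -> y_1, x_1 -> y_2, x_2 -> y_0, so W = (0.01 + 0.04 + 0.01) / 3 = 0.02
```

Each row and column of the returned plan sums to $1/n$, which is exactly the marginal constraint defining $C_{XY}$.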

3.1. ARCHITECTURE ENHANCEMENT

Reformulating standard architectures. As mentioned at the end of Section 2.1, the final graph embedding $h = \sum_{v \in V} h_v$ obtained by aggregating node embeddings is usually fed to an MLP performing a matrix multiplication whose $i$-th component is $(Rh)_i = \langle r_i, h\rangle$, where $r_i$ is the $i$-th row of matrix $R$. Replacing $\langle\cdot,\cdot\rangle$ by a distance or kernel $k(\cdot,\cdot)$ allows the processing of more general graph representations than just vectors in $\mathbb{R}^d$, such as point clouds or adjacency tensors.

From a single point to a point cloud. We propose to replace the aggregated graph embedding $h = \sum_{v \in V} h_v$ by the point cloud of unaggregated node embeddings $H = \{h_v\}_{v \in V}$, and the inner products $\langle h, r_i\rangle$ by the following Wasserstein discrepancy:

$\mathcal{W}(H, Q_i) := \min_{T \in C_{HQ_i}} \sum_{vj} T_{vj}\, c(h_v, q_i^j), \qquad (4)$

where $Q_i = \{q_i^j\}_{j \in \{1, \dots, N\}}$, for $i \in \{1, \dots, M\}$, are $M$ prototype point clouds, each consisting of $N$ embeddings that are free trainable parameters, and the cost is chosen as $c = \|\cdot - \cdot\|_2^2$ or $c = -\langle\cdot,\cdot\rangle$. Note that both options yield identical optimal transport plans: since $\|h_v - q_i^j\|_2^2 = \|h_v\|_2^2 + \|q_i^j\|_2^2 - 2\langle h_v, q_i^j\rangle$ and the marginals of $T$ are fixed, the first two terms contribute the same constant for every coupling.

Greater representational power. We show in Section 4 that this kernel has a strictly greater representational power than the kernel corresponding to the standard inner product on top of a sum aggregation, in the sense of distinguishing between different point clouds.

Final architecture. Finally, the vector of all Wasserstein distances in Eq. (4) becomes the input to a final MLP with a single scalar as output, which can then be used as the prediction for various downstream tasks. This model is depicted in Figure 1 and motivated theoretically in Section 4.1.
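To make Eq. (4) concrete, the sketch below computes the feature vector $[\mathcal{W}(H, Q_1), \dots, \mathcal{W}(H, Q_M)]$ that is fed to the final MLP (the MLP itself is omitted), and checks that the costs $\|\cdot - \cdot\|_2^2$ and $-\langle\cdot,\cdot\rangle$ select the same optimal plan. For simplicity it assumes $N = |V|$, so that, by Proposition 2 (Appendix B.3), a brute-force search over permutations is exact; the function names and toy embeddings are hypothetical.

```python
import itertools


def sq_dist(a, b):
    """c(h, q) = ||h - q||_2^2"""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def neg_dot(a, b):
    """c(h, q) = -<h, q>"""
    return -sum(x * y for x, y in zip(a, b))


def best_permutation(H, Q, cost):
    """(W(H, Q), optimal matching) for equal-size uniform clouds.

    Brute force over permutations, exact by Proposition 2.
    perm[v] is the index of the prototype point matched to node v.
    """
    n = len(H)
    best = None
    for perm in itertools.permutations(range(n)):
        val = sum(cost(H[v], Q[perm[v]]) for v in range(n)) / n
        if best is None or val < best[0]:
            best = (val, perm)
    return best


def prototype_features(H, prototypes, cost=sq_dist):
    """Vector [W(H, Q_1), ..., W(H, Q_M)] that is fed to the final MLP."""
    return [best_permutation(H, Q, cost)[0] for Q in prototypes]


# Toy node embeddings (|V| = 3) and M = 2 prototypes with N = 3 points each.
H = [[0.5, -0.2], [1.0, 0.3], [-0.4, 0.8]]
prototypes = [
    [[0.0, 0.0], [1.0, 1.0], [-1.0, 0.5]],
    [[2.0, -1.0], [0.3, 0.3], [0.0, 1.0]],
]
features = prototype_features(H, prototypes)
same_plan = all(best_permutation(H, Q, sq_dist)[1] == best_permutation(H, Q, neg_dot)[1]
                for Q in prototypes)
```

Because the two costs differ only by a plan-independent constant and a factor of 2, `same_plan` holds; only the feature values themselves differ between ProtoW-L2 and ProtoW-Dot.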

3.2. CONTRASTIVE REGULARIZATION

What would happen to $\mathcal{W}(H, Q_i)$ if all points $q_i^j$ belonging to point cloud $Q_i$ were to collapse to the same point $q_i$? All transport plans would yield the same cost, giving for $c = -\langle\cdot,\cdot\rangle$ (using that each row of $T$ sums to $1/|V|$):

$\mathcal{W}(H, Q_i) = -\sum_{vj} T_{vj}\langle h_v, q_i\rangle = -\frac{1}{|V|}\sum_{v}\langle h_v, q_i\rangle = -\langle h, q_i\rangle/|V|. \qquad (5)$

In this scenario, our proposition would simply over-parametrize the standard Euclidean model.

Figure 3: [...] contrastive regularization for runs using the same random seed. Points in the prototypes tend to cluster and collapse more when no regularization is used, suggesting that the optimal transport plan is no longer uniquely discriminative. Prototype 1 (red) is enlarged for clarity: without regularization it is clumped together (left), but with regularization it is distributed across the space (right).

A first obstacle and its cause. Our first empirical trials with OT-enhanced GNNs showed that a model trained with only the Wasserstein component would sometimes perform similarly to the Euclidean baseline in both training and validation settings, in spite of its greater representational power. Further investigation revealed that the Wasserstein model would naturally displace the points in each of its prototype point clouds in such a way that the optimal transport plan $T$, obtained by maximizing $\sum_{vj} T_{vj}\langle h_v, q_i^j\rangle$, was not discriminative, i.e. many other transport plans would yield a similar Wasserstein cost. Indeed, as shown in Eq. (5), if each point cloud collapses to its mean, then the Wasserstein geometry collapses to Euclidean geometry, and any transport plan yields the same Wasserstein cost. Further explanations are provided in Appendix A.1 and Figure 3.
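A quick numerical check of Eq. (5): once every point of $Q_i$ equals the same $q_i$, all rescaled permutation plans (a subset of $C_{HQ_i}$, used here for illustration) give exactly the cost $-\langle h, q_i\rangle/|V|$, whereas a spread-out prototype yields costs that depend on the plan. The embeddings and prototype points below are arbitrary toy values, and `plan_costs` is a hypothetical helper.

```python
import itertools


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def plan_costs(H, Q):
    """Cost sum_vj T_vj * (-<h_v, q_j>) for every rescaled permutation plan T."""
    n = len(H)
    return [-sum(dot(H[v], Q[p[v]]) for v in range(n)) / n
            for p in itertools.permutations(range(n))]


H = [[0.5, -0.2], [1.0, 0.3], [-0.4, 0.8]]   # toy node embeddings h_v
q = [0.7, -1.1]
collapsed = [q, q, q]                          # every q_i^j equals q_i
spread = [[0.0, 0.0], [1.0, 1.0], [-1.0, 0.5]]

h = [sum(col) for col in zip(*H)]              # h = sum_v h_v
euclidean = -dot(h, q) / len(H)                # right-hand side of Eq. (5)
costs_collapsed = plan_costs(H, collapsed)     # all equal to `euclidean`
costs_spread = plan_costs(H, spread)           # depend on the plan
```

In the collapsed case, the "optimal" plan is therefore arbitrary, which is exactly the non-discriminative situation the regularizer below is designed to avoid.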

Contrastive regularization.

To address this difficulty, we consider adding a regularizer that encourages the model to displace its prototype point clouds such that the optimal transport plans are discriminative against chosen contrastive transport plans. Namely, consider a point cloud $Y$ of node embeddings and let $T_i$ be an optimal transport plan obtained in the computation of $\mathcal{W}(Y, Q_i)$. For each $T_i$, we then build a set $\mathrm{Neg}(T_i) \subset C_{YQ_i}$ of noisy/contrastive transport plans. If we denote by $\mathcal{W}_T(X, Y) := \sum_{kl} T_{kl}\, c(x_k, y_l)$ the Wasserstein cost obtained for a particular transport plan $T$, then our contrastive regularization consists in maximizing the term

$\sum_i \log \frac{e^{-\mathcal{W}_{T_i}(Y, Q_i)}}{e^{-\mathcal{W}_{T_i}(Y, Q_i)} + \sum_{T \in \mathrm{Neg}(T_i)} e^{-\mathcal{W}_T(Y, Q_i)}}, \qquad (6)$

which can be interpreted as the log-likelihood that the correct transport plan $T_i$ is (as it should be) a better minimizer of $\mathcal{W}_T(Y, Q_i)$ than its negative samples. This can be considered as an approximation of $\log \Pr(T_i \mid Y, Q_i)$, where the partition function is approximated by our selection of negative examples, as done e.g. by Nickel & Kiela (2017). Its effect is shown in Figure 3.

Remarks. The selection of negative examples should balance a trade-off: (i) the set should not be too large, for computational efficiency, while (ii) it should contain sufficiently meaningful and challenging contrastive samples. Details about the choice of contrastive samples are given in Appendix A.2. Note that replacing the set $\mathrm{Neg}(T_i)$ with a singleton $\{T\}$ for a contrastive random variable $T$ lets us rewrite Eq. (6) as $\sum_i \log \sigma(\mathcal{W}_T - \mathcal{W}_{T_i})$, reminiscent of noise contrastive estimation (Gutmann & Hyvärinen, 2010).
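One summand of the regularizer in Eq. (6) can be computed stably with a log-sum-exp, as sketched below. The helper names are hypothetical, and in training the costs would be computed from the current embeddings and prototypes rather than given as constants; the usage checks the singleton case against $\log\sigma(\mathcal{W}_T - \mathcal{W}_{T_i})$.

```python
import math


def log_sum_exp(values):
    """log(sum(exp(v))) computed stably."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def contrastive_term(w_correct, w_negatives):
    """One summand of Eq. (6).

    w_correct: W_{T_i}(Y, Q_i), the cost of the optimal plan T_i.
    w_negatives: the costs W_T(Y, Q_i) for T in Neg(T_i).
    Returns log(e^{-w_correct} / (e^{-w_correct} + sum_T e^{-W_T})); the
    regularizer maximizes the sum of these terms over prototypes i.
    """
    logits = [-w_correct] + [-w for w in w_negatives]
    return -w_correct - log_sum_exp(logits)


def log_sigmoid(z):
    """Numerically stable log(1 / (1 + e^{-z}))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


# Hypothetical costs: optimal plan vs. three negative plans.
w_opt = 0.41
negatives = [0.9, 1.3, 0.55]
term = contrastive_term(w_opt, negatives)
singleton = contrastive_term(w_opt, [0.9])   # equals log_sigmoid(0.9 - 0.41)
```

The term is always negative and approaches 0 as the negative plans become more costly than $T_i$, so maximizing it pushes the prototypes towards configurations where the optimal plan is clearly better than its contrastive alternatives.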

3.3. OPTIMIZATION & COMPLEXITY

Backpropagating gradients through optimal transport (OT) has been the subject of recent research investigations: Genevay et al. (2017) explain how to unroll and differentiate through the Sinkhorn procedure solving OT, which was extended by Schmitz et al. (2018) to Wasserstein barycenters. More recently, however, Xu (2019) proposed to simply invoke the envelope theorem (Afriat, 1971) to support the idea of keeping the optimal transport plan fixed during the back-propagation of gradients through Wasserstein distances. For the sake of simplicity and training stability, we resort to the latter procedure: keeping $T$ fixed during back-propagation. We discuss complexity in Appendix C.
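With the plan $T$ held fixed, the gradient of $\mathcal{W}_T(H, Q_i)$ with respect to a prototype point under the squared Euclidean cost is $\sum_v T_{vj}\, 2(q_i^j - h_v)$. The sketch below (pure Python, hypothetical names, equal-size toy clouds so that brute-force OT is exact by Proposition 2) compares this fixed-plan gradient with central finite differences of the full OT distance, which re-solves the transport problem at each perturbation; when the optimal plan is unique, the two agree, as the envelope theorem predicts.

```python
import itertools


def sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def wasserstein_l2(H, Q):
    """(W_L2(H, Q), optimal matching) for equal-size clouds; brute force (Prop. 2)."""
    n = len(H)
    best = None
    for p in itertools.permutations(range(n)):
        val = sum(sq_dist(H[v], Q[p[v]]) for v in range(n)) / n
        if best is None or val < best[0]:
            best = (val, p)
    return best


def grad_fixed_plan(H, Q, perm):
    """d/dq_j of sum_vj T_vj ||h_v - q_j||^2 with T held fixed.

    T is the rescaled permutation T_vj = 1/n if j == perm[v], so
    d/dq_j = sum_v T_vj * 2 (q_j - h_v).
    """
    n = len(H)
    grads = [[0.0] * len(q) for q in Q]
    for v in range(n):
        j = perm[v]
        for k in range(len(Q[j])):
            grads[j][k] += (2.0 / n) * (Q[j][k] - H[v][k])
    return grads


def finite_diff_grad(H, Q, eps=1e-6):
    """Central differences of the full OT distance, re-solving OT each time."""
    grads = []
    for j in range(len(Q)):
        row = []
        for k in range(len(Q[j])):
            Qp = [list(q) for q in Q]
            Qm = [list(q) for q in Q]
            Qp[j][k] += eps
            Qm[j][k] -= eps
            row.append((wasserstein_l2(H, Qp)[0] - wasserstein_l2(H, Qm)[0]) / (2 * eps))
        grads.append(row)
    return grads


H = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]      # toy node embeddings
Q = [[0.1, 0.2], [0.9, -0.1], [0.2, 1.3]]     # toy prototype points
_, perm = wasserstein_l2(H, Q)
analytic = grad_fixed_plan(H, Q, perm)
numeric = finite_diff_grad(H, Q)
```

The design choice this illustrates is that, away from ties between plans, no gradient needs to flow through the OT solver itself; only the cost matrix is differentiated.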

4. THEORETICAL ANALYSIS

In this section we show that the standard architecture lacks a fundamental property of universal approximation of functions defined on point clouds, and that our proposed architecture recovers this property. We will denote by $\mathcal{X}_d^n$ the set of point clouds $X = \{x_i\}_{i=1}^n$ of size $n$ in $\mathbb{R}^d$.

4.1. UNIVERSALITY

As seen in Section 3.1, we have replaced the sum aggregation, followed by the Euclidean inner product, by Wasserstein discrepancies. How does this affect the function class and representations? A common framework used to analyze the geometry inherited from similarities and discrepancies is that of kernel theory. A kernel $k$ on a set $\mathcal{X}$ is a symmetric function $k : \mathcal{X} \times \mathcal{X} \to \mathbb{R}$, which can measure either similarities or discrepancies. An important property of a given kernel on a space $\mathcal{X}$ is whether simple functions defined on top of this kernel can approximate any continuous function on the same space. This property, called universality, is crucial for regressing unknown target functions.

Universal kernels. A kernel $k$ defined on $\mathcal{X}_d^n$ is said to be universal if the following holds: for any compact subset $\mathcal{X} \subset \mathcal{X}_d^n$, the set of functions of the form $\sum_{j=1}^m \alpha_j\, \sigma(k(\cdot, \theta_j) + \beta_j)$ is dense in the set $C(\mathcal{X})$ of continuous functions from $\mathcal{X}$ to $\mathbb{R}$, w.r.t. the sup norm $\|\cdot\|_{\infty, \mathcal{X}}$, $\sigma$ denoting the sigmoid. Although the notion of universality does not indicate how easy it is in practice to learn the correct function, it at least guarantees the absence of a fundamental bottleneck in the model using this kernel. In the following, we compare the aggregating kernel $\mathrm{agg}(X, Y) := \langle \sum_i x_i, \sum_j y_j\rangle$ (used by popular GNN models) with the Wasserstein kernels

$\mathcal{W}_{L2}(X, Y) := \min_{T \in C_{XY}} \sum_{ij} T_{ij}\, \|x_i - y_j\|_2^2, \qquad \mathcal{W}_{dot}(X, Y) := \max_{T \in C_{XY}} \sum_{ij} T_{ij}\, \langle x_i, y_j\rangle.$

Theorem 1. We have that: 1. The aggregation kernel agg is not universal. 2. The Wasserstein kernel $\mathcal{W}_{L2}$ is universal.

Proof: See Appendix B.1. Universality of the $\mathcal{W}_{L2}$ kernel comes from the fact that its square root defines a metric, and from the axiom of separation of distances: if $d(x, y) = 0$ then $x = y$.

Implications. Theorem 1 implies that, from the point of view of universal approximation on point clouds, our proposed OT-GNN model is strictly more powerful than state-of-the-art GNN models that use summation or averaging of node embeddings.
Note, however, that this only allows us to distinguish graphs that have distinct multisets of node embeddings, e.g. all Weisfeiler-Lehman distinguishable graphs in the case of graph convolutional networks. In practice, the form of the aforementioned functions that possess universal approximation capabilities indicates how one should leverage the vector of Wasserstein distances to prototypes to perform graph classification, e.g. using a multilayer perceptron (MLP) on top.
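The first part of Theorem 1 can be illustrated with two point clouds that have the same sum: agg cannot tell them apart against any probe cloud, whereas $\mathcal{W}_{L2}$ separates them. The clouds are toy examples, the helper names are hypothetical, and $\mathcal{W}_{L2}$ is computed by brute force over permutations, which is exact for equal sizes by Proposition 2.

```python
import itertools


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def agg(X, Y):
    """Aggregation kernel <sum_i x_i, sum_j y_j>."""
    return dot([sum(c) for c in zip(*X)], [sum(c) for c in zip(*Y)])


def w_l2(X, Y):
    """Exact W_L2 for equal-size uniform clouds, brute force over permutations."""
    n = len(X)
    return min(sum(sum((a - b) ** 2 for a, b in zip(X[i], Y[p[i]])) for i in range(n)) / n
               for p in itertools.permutations(range(n)))


X = [[1.0, 0.0], [-1.0, 0.0]]   # sum is (0, 0)
Y = [[0.0, 1.0], [0.0, -1.0]]   # sum is also (0, 0)
probes = [[[0.3, -0.7], [2.0, 0.5]], [[-1.0, 1.0], [0.0, 0.0]], X, Y]

# agg gives identical values for X and Y against every probe ...
agg_equal = all(agg(X, Z) == agg(Y, Z) for Z in probes)
# ... whereas W_L2 separates them: W_L2(X, X) = 0 but W_L2(X, Y) = 2.
separation = (w_l2(X, X), w_l2(X, Y))
```

Any function of the form $\sum_j \alpha_j\,\sigma(\mathrm{agg}(\cdot, \theta_j) + \beta_j)$ thus takes the same value on $X$ and $Y$, which is the obstruction used in Appendix B.1.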

4.2. DEFINITENESS

For the sake of simplified mathematical analysis, similarity kernels are often required to be positive definite (p.d.), which corresponds to discrepancy kernels being conditionally negative definite (c.n.d.). Although such a property has the benefit of yielding the mathematical framework of Reproducing Kernel Hilbert Spaces, it essentially implies linearity, i.e. the possibility to embed the geometry defined by that kernel in a linear vector space. We now show that, interestingly, the Wasserstein kernel we use does not satisfy this property, and hence constitutes an interesting instance of a universal, non-p.d. kernel. Let us recall these notions.

Kernel definiteness. A kernel $k$ is positive definite (p.d.) on $\mathcal{X}$ if for all $n \in \mathbb{N}^*$, $x_1, \dots, x_n \in \mathcal{X}$ and $c_1, \dots, c_n \in \mathbb{R}$, we have $\sum_{ij} c_i c_j k(x_i, x_j) \ge 0$. It is conditionally negative definite (c.n.d.) on $\mathcal{X}$ if for all $n \in \mathbb{N}^*$, $x_1, \dots, x_n \in \mathcal{X}$ and $c_1, \dots, c_n \in \mathbb{R}$ such that $\sum_i c_i = 0$, we have $\sum_{ij} c_i c_j k(x_i, x_j) \le 0$. These two notions relate to each other via the following result (Boughorbel et al., 2005):

Proposition 1. Let $k$ be a symmetric kernel on $\mathcal{X}$, let $x_0 \in \mathcal{X}$ and define the kernel

$\tilde{k}(x, y) := -\frac{1}{2}\left[k(x, y) - k(x, x_0) - k(y, x_0) + k(x_0, x_0)\right].$

Then $\tilde{k}$ is p.d. if and only if $k$ is c.n.d. Example: $k = \|\cdot - \cdot\|_2^2$ and $x_0 = 0$ yield $\tilde{k} = \langle\cdot,\cdot\rangle$.

One can easily show that agg also defines a p.d. kernel, and that $\mathrm{agg}(\cdot, \cdot) \le n^2\, \mathcal{W}_{dot}(\cdot, \cdot)$, since the uniform coupling $T_{ij} = 1/n^2$ is feasible. However, the Wasserstein kernel is not p.d., as previously stated in different variants (e.g. Vert, 2008) and as recalled by the theorem below, for which we give a novel proof in Appendix B.2.

Theorem 2. We have that: 1. The (similarity) Wasserstein kernel $\mathcal{W}_{dot}$ is not positive definite; 2. The (discrepancy) Wasserstein kernel $\mathcal{W}_{L2}$ is not conditionally negative definite.
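The example following Proposition 1, and the bound on agg read with $\mathcal{W} = \mathcal{W}_{dot}$, can both be checked numerically, as sketched below on random toy clouds of size $n = 3$. This only checks instances and is not a proof; the helper names are hypothetical, and $\mathcal{W}_{dot}$ is computed by brute force over permutations (Proposition 2, Appendix B.3).

```python
import itertools
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def tilde_kernel(k, x0):
    """The kernel k-tilde of Proposition 1, built from k and a base point x0."""
    return lambda x, y: -0.5 * (k(x, y) - k(x, x0) - k(y, x0) + k(x0, x0))


def w_dot(X, Y):
    """Exact W_dot (max over couplings) for equal sizes, via permutations (Prop. 2)."""
    n = len(X)
    return max(sum(dot(X[i], Y[p[i]]) for i in range(n)) / n
               for p in itertools.permutations(range(n)))


def agg(X, Y):
    return dot([sum(c) for c in zip(*X)], [sum(c) for c in zip(*Y)])


rng = random.Random(0)


def cloud(n, d):
    return [[rng.uniform(-1.0, 1.0) for _ in range(d)] for _ in range(n)]


# Example of Proposition 1: k = ||. - .||^2 with x0 = 0 gives the inner product.
k_tilde = tilde_kernel(sq_dist, [0.0, 0.0, 0.0])
x, y = [0.2, -1.0, 0.5], [1.5, 0.3, -0.4]

# agg(X, Y) <= n^2 W_dot(X, Y), since T_ij = 1/n^2 is a feasible coupling.
n = 3
pairs = [(cloud(n, 2), cloud(n, 2)) for _ in range(20)]
bound_holds = all(agg(X, Y) <= n ** 2 * w_dot(X, Y) + 1e-9 for X, Y in pairs)
```

The bound follows because $\mathrm{agg}(X, Y) = \sum_{ij}\langle x_i, y_j\rangle$ is $n^2$ times the cost of the uniform coupling, which can never exceed the maximum over $C_{XY}$.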

5.1. EXPERIMENTAL SETUP

We experiment on 4 benchmark molecular property prediction datasets (Yang et al., 2019), including both regression (ESOL, Lipophilicity) and classification (BACE, BBBP) tasks. These datasets cover different complex chemical properties (ESOL: water solubility; LIPO: octanol/water distribution coefficient; BACE: inhibition of human β-secretase; BBBP: blood-brain barrier penetration). We provide results for our implementations of 4 widely used graph-based models. Fingerprint + MLP applies an MLP over input features that are hashed graph structures (called a molecular fingerprint). GIN is the Graph Isomorphism Network from Xu et al. (2019), a variant of a GNN. GAT is the Graph Attention Network from Veličković et al. (2017), which uses multi-head attention layers to propagate information. Since neither the original GIN nor the original GAT accounts for edge features, we adapt both algorithms to our setting; more details about our implementation of these models can be found in Appendix D.2. Finally, Chemprop D-MPNN (Yang et al., 2019) is a graph network that exhibits state-of-the-art performance for molecular representation learning across multiple classification and regression datasets. Empirically, we find that this baseline is indeed the best performing, so we use it to obtain node embeddings in all our prototype models. Its architecture is described in Section 2.1. The different variants of our OT-GNN prototype model are described below: ProtoW-L2/Dot is the model that treats point clouds as point sets and computes the Wasserstein distances to each point cloud (using either the L2 distance or the (minus) dot product as cost function) as the molecular embedding.
ProtoS-L2 is a special case of ProtoW-L2 in which each point cloud has a single point: instead of Wasserstein distances, we simply compute Euclidean distances between the aggregated graph embedding and these points. We omit a dot-product variant here, as that model is mathematically equivalent to the GNN model. We use the POT library (Flamary & Courty, 2017) to compute Wasserstein distances with the network simplex algorithm (Earth Mover's Distance), which we find empirically to be faster than the Sinkhorn algorithm for our datasets. We define the cost matrix by taking the pairwise L2 or negative dot product distances. As mentioned in Section 3.3, we fix the transport plan and only backpropagate through the cost matrix, for computational efficiency. Additionally, to account for the variable size of each input graph, we multiply the OT distance between two point clouds by their respective sizes. To avoid the problem of point clouds collapsing, we employ the contrastive regularizer defined in Section 3.2. In Table 1, lower RMSE is better, while higher AUC is better; Wasserstein models are by default trained with contrastive regularization as described in Section 3.2 and outperform those trained without it.

5.2.1. REGRESSION AND CLASSIFICATION

Results are shown in Table 1. Our prototype models outperform state-of-the-art GNN/D-MPNN baselines on all 4 property prediction tasks. Moreover, the prototype models using the Wasserstein distance (ProtoW-L2/Dot) achieve better performance on 3 out of 4 datasets than the prototype model using only Euclidean distances (ProtoS-L2). This suggests that the Wasserstein distance confers greater discriminative power than traditional aggregation methods.

5.2.2. NOISE CONTRASTIVE REGULARIZER

Without any constraints, the Wasserstein prototype model often collapses the set of points in a point cloud into a single point. As mentioned in Section 3.2, we use a contrastive regularizer to force the model to meaningfully distribute point clouds in the embedding space. We show 2D embeddings in Figure 3, illustrating that without contrastive regularization, prototype point clouds are often displaced close to their mean, while regularization forces them to spread out. Quantitative results in Table 1 also highlight the benefit of this regularization.

5.2.3. LEARNED EMBEDDING SPACE: QUALITATIVE AND QUANTITATIVE RESULTS

We further examine the learned embedding space of the best baseline (i.e. D-MPNN) and of our best Wasserstein model. We claim that our models learn smoother latent representations. For each pair of test data points on the ESOL dataset, we compute the difference between their embedding vectors and the difference between their labels. We then compute two correlation measures between these pairwise differences: the Spearman rank correlation coefficient (ρ) and the Pearson correlation coefficient (r). This is reminiscent of evaluation tasks that correlate word embedding similarity with human labels (Luong et al., 2013). Our ProtoW-L2 achieves better ρ and r scores than the D-MPNN model (Table 2), indicating that our Wasserstein model constructs more meaningful embeddings with respect to the label distribution. Indeed, Figure 4 plots the pairwise scores for the D-MPNN model (left) and the ProtoW-L2 model (right). Our ProtoW-L2 model, trained to optimize distances in the embedding space, produces more meaningful representations with respect to the label of interest. Moreover, as qualitatively shown in Figure 5, our model provides more robust molecular embeddings than the baseline, in the following sense: a small perturbation of a molecular embedding corresponds to a small change in the predicted property value, a desirable phenomenon that rarely holds for the baseline D-MPNN model.
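The correlation analysis can be sketched as follows. The exact pairwise quantities are not fully specified above, so we assume the Euclidean distance between embedding vectors and the absolute difference between labels; the embeddings and labels are toy values and all names are hypothetical. Spearman's ρ is computed as the Pearson correlation of ranks, with tied values receiving their average rank.

```python
import math


def pearson(a, b):
    """Pearson correlation coefficient of two equal-length lists."""
    ma, mb = sum(a) / len(a), sum(b) / len(b)
    cov = sum((x - ma) * (y - mb) for x, y in zip(a, b))
    va = sum((x - ma) ** 2 for x in a)
    vb = sum((y - mb) ** 2 for y in b)
    return cov / math.sqrt(va * vb)


def ranks(values):
    """Ranks starting at 1; tied values get the average of their ranks."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    r = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        avg = (i + j) / 2 + 1
        for k in range(i, j + 1):
            r[order[k]] = avg
        i = j + 1
    return r


def spearman(a, b):
    return pearson(ranks(a), ranks(b))


def pairwise_scores(embeddings, labels):
    """Embedding distances and absolute label differences over all test pairs."""
    dists, diffs = [], []
    n = len(labels)
    for i in range(n):
        for j in range(i + 1, n):
            dists.append(math.dist(embeddings[i], embeddings[j]))
            diffs.append(abs(labels[i] - labels[j]))
    return dists, diffs


embeddings = [[0.0], [1.0], [2.5], [4.0]]   # toy 1-D embeddings
labels = [0.0, 1.1, 2.4, 4.2]               # toy property values
dists, diffs = pairwise_scores(embeddings, labels)
rho = spearman(dists, diffs)
r = pearson(dists, diffs)
```

Higher values of both coefficients mean that molecules close in embedding space also tend to have close labels, which is the sense of "smoother" used above.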

5.2.4. ADDITIONAL EXPERIMENTAL DETAILS

Model sizes. Using an MPNN hidden dimension of 200 and a final output MLP hidden dimension of 100, the numbers of parameters of the models are given in Table 4. The fingerprint dimension was 2048, which explains why the MLP has a large number of parameters. The D-MPNN model is much smaller than the GIN and GAT models because, unlike them, it shares parameters between layers. Our prototype models are even smaller than the D-MPNN model because they do not require the large MLP at the end; instead, we compute distances to a few small prototypes (the point clouds use a small number of parameters overall). The dimensions of the prototype embeddings are also smaller than the node embedding dimensions of the D-MPNN and other baselines. We did not see significant improvements in quality by increasing any of these hyperparameter values. Remarkably, our model outperforms all the baselines while using between 1.5 and 10 times fewer parameters.

Runtimes. We report the average total training time (the number of epochs may vary depending on the early stopping criterion), as well as the average training epoch time, for the D-MPNN and our prototype models in Table 3. Our method is between 1 and 7.1 times slower than the D-MPNN baseline, mostly due to the frequent calls to the Earth Mover's Distance OT solver.

6. RELATED WORK

[...] (Hammond et al., 2011). Different such architectures were later unified into the message passing neural networks framework by Gilmer et al. (2017) and applied to molecular property prediction. A directed variant of message passing from Dai et al. (2016) was later used by Yang et al. (2019) to improve the state of the art in molecular property prediction on a wide variety of datasets. Other applications include recommender systems (Ying et al., 2018a). Inspired by DeepSets (Zaheer et al., 2017), Xu et al. (2019) suggest a simplified architecture called GIN, which can theoretically discriminate between any different local neighborhoods.

Other recent approaches modify the sum aggregation of node embeddings in the GCN architecture with the aim of preserving more information (Kondor et al., 2018; Pei et al., 2020). This category also includes the growing class of hierarchical graph pooling methods, which typically use either deterministic and non-differentiable node clustering (Defferrard et al., 2016; Jin et al., 2018) or differentiable pooling (Ying et al., 2018b; Noutahi et al., 2019; Gao & Ji, 2019). However, these methods still struggle with small labelled graphs such as molecules, where global and local node interconnections cannot easily be cast as a hierarchical interaction.

Other recent geometry-inspired GNNs include adaptations to non-Euclidean spaces (Liu et al., 2019; Chami et al., 2019; Bachmann et al., 2019). We additionally discuss related work on prototype learning in Appendix F.

7. CONCLUSION

We propose OT-GNN: one of the first parametric graph models that leverages optimal transport to learn graph representations. It learns abstract prototypes as free parametric point clouds that highlight different aspects of the graph. Empirically, we outperform popular baselines in different molecular property prediction tasks, while the learned representations also exhibit stronger correlation with the target labels. Finally, our universal approximation results provide theoretical support for the merits of our model.

A FURTHER DETAILS ON CONTRASTIVE REGULARIZATION

A.1 MOTIVATION

One may speculate that it was locally easier for the model to extract valuable information if it behaved like the Euclidean component, preventing it from exploring other regions of the optimization landscape. To better understand this situation, consider the scenario in which a subset of points in a prototype point cloud "collapses", i.e. these points become close to each other (see Figure 3), thus sharing similar distances to all the node embeddings of real input graphs. The submatrix of the optimal transport matrix corresponding to these collapsed points can then be replaced by any other submatrix with the same marginals (i.e. the same two vectors obtained by summing rows or columns) without noticeably changing the cost, meaning that the optimal transport matrix is not discriminative. In general, we want to avoid any two rows or columns of the Wasserstein cost matrix being proportional. An additional problem of point collapsing is that gradient-based learning methods cannot escape it: the gradients of these collapsed points would become and remain identical, so nothing would encourage them to separate in the future.

A.2 ON THE CHOICE OF CONTRASTIVE SAMPLES

Our experiments were conducted with ten negative samples for each correct transport plan. Five of them were obtained by initializing a matrix with i.i.d. entries drawn uniformly from [0, 10) and performing around five Sinkhorn iterations (Cuturi, 2013) to make the matrix (approximately) satisfy the marginal constraints. The other five were obtained by randomly permuting the columns of the correct transport plan. The latter choice has the desirable effect of penalizing the collapse of the points of a prototype point cloud $Q_i$ onto the same point. Indeed, the rows of $T_i \in C_{HQ_i}$ index points in $H$, while its columns index points in $Q_i$: permuting the columns reassigns node embeddings to different prototype points, and if these points have collapsed, the permuted plan costs as much as $T_i$, which the regularizer penalizes.
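The two kinds of negative samples can be generated as sketched below. The helper names and the fixed number of rescaling iterations are illustrative assumptions: the scaling simply alternates row and column normalization of the random matrix (Sinkhorn-Knopp iterations, with no cost or entropy term), which is one way to push it towards the marginal constraints.

```python
import random


def sinkhorn_scaling(M, row_mass, col_mass, iters=5):
    """Alternately rescale rows and columns of a positive matrix so that rows
    sum to row_mass and columns to col_mass (Sinkhorn-Knopp iterations)."""
    M = [row[:] for row in M]
    for _ in range(iters):
        row_sums = [sum(row) for row in M]
        M = [[x * row_mass / s for x in row] for row, s in zip(M, row_sums)]
        col_sums = [sum(row[j] for row in M) for j in range(len(M[0]))]
        M = [[x * col_mass / s for x, s in zip(row, col_sums)] for row in M]
    return M


def negative_plans(T, rng, n_random=5, n_perm=5, iters=5):
    """Contrastive plans for an optimal plan T (rows: nodes, columns: prototype points)."""
    n, m = len(T), len(T[0])
    negs = []
    for _ in range(n_random):
        # i.i.d. entries drawn uniformly from [0, 10)
        M = [[10.0 * rng.random() for _ in range(m)] for _ in range(n)]
        negs.append(sinkhorn_scaling(M, 1.0 / n, 1.0 / m, iters))
    for _ in range(n_perm):
        # permute the columns, i.e. reassign nodes to other prototype points
        cols = list(range(m))
        rng.shuffle(cols)
        negs.append([[row[c] for c in cols] for row in T])
    return negs


n = 3
T_opt = [[1.0 / n if j == i else 0.0 for j in range(n)] for i in range(n)]  # toy optimal plan
negs = negative_plans(T_opt, random.Random(0))
```

Since the last scaling step normalizes columns, the random negatives satisfy the column marginals exactly and the row marginals approximately, while the permuted plans satisfy both exactly.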

B THEORETICAL RESULTS

B.1 PROOF OF THEOREM 1

1. Let us first justify why agg is not universal. Consider a function $f \in C(\mathcal{X})$ such that there exist $X, Y \in \mathcal{X}$ satisfying both $f(X) \neq f(Y)$ and $\sum_k x_k = \sum_l y_l$. Clearly, any function of the form $\sum_i \alpha_i\, \sigma(\mathrm{agg}(W_i, \cdot) + \theta_i)$ would take equal values on $X$ and $Y$ and hence would not approximate $f$ arbitrarily well.

2.

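To justify that $\mathcal{W}_{L2}$ is universal, we take inspiration from the proof of universality of neural networks (Cybenko, 1989). Notation. Denote by $M(\mathcal{X})$ the space of finite, signed regular Borel measures on $\mathcal{X}$. Definition. We say that $\sigma$ is discriminatory w.r.t. a kernel $k$ if, for a measure $\mu \in M(\mathcal{X})$, $\int_{\mathcal{X}} \sigma(k(Y, X) + \theta)\, d\mu(X) = 0$ for all $Y \in \mathcal{X}_d^n$ and $\theta \in \mathbb{R}$ implies that $\mu \equiv 0$. We start by recalling a lemma from the original paper on the universality of neural networks by Cybenko (1989). Lemma. If $\sigma$ is discriminatory w.r.t. $k$, then $k$ is universal. Proof: Let $S$ be the subset of functions of the form $\sum_{i=1}^m \alpha_i\, \sigma(k(\cdot, Q_i) + \theta_i)$ for any $\theta_i \in \mathbb{R}$, $Q_i \in \mathcal{X}_d^n$ and $m \in \mathbb{N}^*$, and denote by $\bar{S}$ the closure of $S$ in $C(\mathcal{X})$. Assume by contradiction that [...]

B.2 PROOF OF THEOREM 2

[...] 2. This comes from Proposition 1. Choose $k = \mathcal{W}_{L2}$ and $x_0 = 0$, the trivial point cloud made of $n$ copies of the zero vector. Any coupling moves each $x_i$ with total mass $1/n$, so $\mathcal{W}_{L2}(X, 0) = \frac{1}{n}\sum_i \|x_i\|_2^2$, and expanding the squared norm gives $\mathcal{W}_{L2}(X, Y) = \frac{1}{n}\sum_i \|x_i\|_2^2 + \frac{1}{n}\sum_j \|y_j\|_2^2 - 2\,\mathcal{W}_{dot}(X, Y)$; plugging these into the definition of $\tilde{k}$ in Proposition 1 yields $\tilde{k} = \mathcal{W}_{dot}$. Since $\tilde{k}$ is not positive definite by the first point of the theorem, $k$ is not conditionally negative definite by Proposition 1.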

B.3 SHAPE OF THE OPTIMAL TRANSPORT PLAN FOR POINT CLOUDS OF SAME SIZE

The result below describes the shape of optimal transport plans for point clouds of the same size. For the sake of curiosity, we also illustrate in Figure 2 the optimal transport for point clouds of different sizes. We note that non-square transport plans seem to remain relatively sparse as well, which is in line with our empirical observations.

Proposition 2. For $X, Y \in \mathcal{X}_d^n$, there exists a rescaled permutation matrix $\frac{1}{n}(\delta_{i\sigma(j)})_{1 \le i, j \le n}$ which is an optimal transport plan, i.e.

$\mathcal{W}_{L2}(X, Y) = \frac{1}{n}\sum_{j=1}^n \|x_{\sigma(j)} - y_j\|_2^2, \qquad \mathcal{W}_{dot}(X, Y) = \frac{1}{n}\sum_{j=1}^n \langle x_{\sigma(j)}, y_j\rangle. \qquad (10)$

Proof. It is well known from Birkhoff's theorem that every square doubly-stochastic matrix is a convex combination of permutation matrices; hence the couplings in $C_{XY}$, which are doubly-stochastic matrices rescaled by $1/n$, form a convex compact set whose extremal points are the rescaled permutation matrices. Since the Wasserstein cost for a given transport plan $T$ is a linear function of $T$, it is both convex and concave, and hence it is maximized (resp. minimized) over this set at one of its extremal points, namely at one of the rescaled permutations, yielding the desired result.
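Proposition 2 can be checked numerically: no dense coupling should beat the best rescaled permutation. The sketch below draws dense couplings by Sinkhorn-style row and column normalization of random positive matrices (an arbitrary sampling choice, not a uniform sample from $C_{XY}$) and compares their cost with a brute-force search over permutations; the point clouds are random toy data and the helper names are hypothetical.

```python
import itertools
import random


def sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def plan_cost(T, X, Y):
    """W_T(X, Y) for the squared Euclidean cost."""
    return sum(T[i][j] * sq_dist(X[i], Y[j])
               for i in range(len(X)) for j in range(len(Y)))


def best_permutation_cost(X, Y):
    """Minimum over rescaled permutation plans (the extremal couplings)."""
    n = len(X)
    return min(sum(sq_dist(X[i], Y[p[i]]) for i in range(n)) / n
               for p in itertools.permutations(range(n)))


def random_coupling(n, rng, iters=500):
    """Dense matrix with row and column sums 1/n, by Sinkhorn-style normalization."""
    T = [[rng.random() + 0.1 for _ in range(n)] for _ in range(n)]
    for _ in range(iters):
        rows = [sum(row) for row in T]
        T = [[x / (n * s) for x in row] for row, s in zip(T, rows)]
        col = [sum(T[i][j] for i in range(n)) for j in range(n)]
        T = [[T[i][j] / (n * col[j]) for j in range(n)] for i in range(n)]
    return T


rng = random.Random(1)
X = [[rng.uniform(-1.0, 1.0) for _ in range(2)] for _ in range(4)]
Y = [[rng.uniform(-1.0, 1.0) for _ in range(2)] for _ in range(4)]
best = best_permutation_cost(X, Y)
dense_costs = [plan_cost(random_coupling(4, rng), X, Y) for _ in range(50)]
```

Every sampled dense coupling costs at least as much as the best permutation, as the linearity argument in the proof predicts.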

C COMPLEXITY

C.1 WASSERSTEIN

Computing the Wasserstein optimal transport plan between two point clouds consists of minimizing a linear function under linear constraints. It can be performed either exactly, using network simplex or interior point methods as done by Pele & Werman (2009), in time $\tilde{O}(n^3)$, or approximately up to $\varepsilon$ via the Sinkhorn algorithm (Cuturi, 2013) in time $\tilde{O}(n^2/\varepsilon^3)$. More recently, Dvurechensky et al. (2018) proposed an algorithm solving OT up to $\varepsilon$ with time complexity $\tilde{O}(\min\{n^{9/4}/\varepsilon, n^2/\varepsilon^2\})$ via a primal-dual method inspired by accelerated gradient descent. In our experiments, we used the Python Optimal Transport (POT) library (Flamary & Courty, 2017). We noticed empirically that the Earth Mover Distance (EMD) solver yielded faster and more accurate solutions than Sinkhorn on our datasets, because the graphs and point clouds were small enough (fewer than 30 elements). However, Sinkhorn may take the lead for larger graphs. A significant speed-up could potentially be obtained by rewriting the POT library to solve OT problems in batches on GPUs. In our experiments, we ran all jobs on CPUs.
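For reference, the Sinkhorn alternative mentioned above can be sketched in a few lines. This is a minimal, standard-library illustration of entropy-regularized OT (Cuturi, 2013), not POT's implementation and not what was used in our experiments (which relied on the exact EMD solver). The regularization value `eps` is an illustrative assumption; smaller values approach the exact plan but converge more slowly.

```python
import math


def sinkhorn(C, a, b, eps=0.1, n_iter=1000):
    """Entropy-regularized OT plan via Sinkhorn iterations.

    C: n x m cost matrix, a: source weights (length n), b: target weights
    (length m), eps: entropic regularization strength. Returns the n x m plan
    diag(u) K diag(v) with K = exp(-C / eps).
    """
    n, m = len(a), len(b)
    K = [[math.exp(-C[i][j] / eps) for j in range(m)] for i in range(n)]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(n_iter):
        # Alternately rescale rows and columns so the marginals match a and b.
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]


# Point clouds of different sizes, uniform weights, squared Euclidean cost.
X = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
Y = [(0.1, 0.0), (0.0, 0.9)]
C = [[(x0 - y0) ** 2 + (x1 - y1) ** 2 for (y0, y1) in Y] for (x0, x1) in X]
T = sinkhorn(C, [1 / 3] * 3, [1 / 2] * 2, eps=0.5)
approx_cost = sum(T[i][j] * C[i][j] for i in range(3) for j in range(2))
print(approx_cost)
```

Each iteration costs $O(nm)$, which is why Sinkhorn scales better than exact solvers on large point clouds, at the price of a blurred (non-sparse) plan.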

D FURTHER EXPERIMENTAL DETAILS

D.1 SETUP OF EXPERIMENTS

Each dataset is split randomly 5 times into 80%:10%:10% train, validation and test sets. For each of the 5 splits, we run each model 5 times to reduce the variance within each data split (resulting in each model being run 25 times). Hyperparameters are searched separately for each data split, so that model performance is measured on a completely unseen test set and there is no data leakage across data splits; we then report the average performance over all splits. The models are trained for 150 epochs with a batch size of 16, with early stopping if the validation error has not improved for 50 epochs. We train the models using the Adam optimizer with a learning rate of 5e-4. For the prototype models, we use different learning rates for the GNN and the point clouds (5e-4 and 5e-3, respectively), because empirically we find that the gradients are much smaller for the point clouds.

Table 5: The parameters for our models (the prototype models all use the same GNN base model), and the values that we used for hyperparameter search. When there is only a single value in the search list, it means we did not search over this value, and used the specified value for all models.
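The 5 splits × 5 runs evaluation protocol described above can be sketched as follows. This is an illustrative, standard-library version under stated assumptions: `make_splits`, `run_protocol` and the `train_and_score` callback are hypothetical names, and the callback is assumed to perform the per-split hyperparameter search on the validation set before reporting a test score.

```python
import random


def make_splits(n_items, n_splits=5, fractions=(0.8, 0.1, 0.1), seed=0):
    """Return n_splits random (train, val, test) index partitions (80%:10%:10%)."""
    splits = []
    for s in range(n_splits):
        idx = list(range(n_items))
        random.Random(seed + s).shuffle(idx)  # a different permutation per split
        n_train = int(fractions[0] * n_items)
        n_val = int(fractions[1] * n_items)
        splits.append((idx[:n_train],
                       idx[n_train:n_train + n_val],
                       idx[n_train + n_val:]))
    return splits


def run_protocol(n_items, train_and_score, n_splits=5, n_runs=5):
    """Average test score over n_splits x n_runs runs.

    `train_and_score` (hypothetical) selects hyperparameters on `val` for this
    split only, trains with the given seed and returns a score on `test`.
    """
    scores = []
    for train, val, test in make_splits(n_items, n_splits):
        for run in range(n_runs):
            scores.append(train_and_score(train, val, test, seed=run))
    return sum(scores) / len(scores), len(scores)


# Toy callback (hypothetical): the "score" is simply the test-set size.
mean_score, n_total = run_protocol(1000, lambda tr, va, te, seed: len(te))
print(mean_score, n_total)  # 100.0 25
```

If the models are implemented in PyTorch (an assumption; the framework is not stated here), the two learning rates would typically be expressed as two parameter groups passed to a single `torch.optim.Adam` optimizer, each with its own `lr`.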

D.2 BASELINE MODELS

Both the GIN (Xu et al., 2019) and GAT (Veličković et al., 2017) models were originally used for graphs without edge features. We therefore adapt both algorithms to our use case, in which edge features are a critical aspect of the prediction task. Here, we detail the exact architectures that we use for both models, starting with the notation common to both. Each example is defined by a set of vertices and edges $(V, E)$. Let $v_i \in V$ denote the $i$-th node in the graph, and let $e_{ij} \in E$ denote the edge between nodes $(i, j)$. Let $h^{(k)}_{v_i}$ be the feature representation of node $v_i$ at layer $k$. Let $h_{e_{ij}}$ be the input features of the edge between nodes $(i, j)$; these are static because we only update nodes. Let $\mathcal{N}(v_i)$ denote the set of neighbors of node $i$, not including node $i$ itself, and let $\tilde{\mathcal{N}}(v_i)$ denote the set of neighbors of node $i$ together with the node itself.

GIN

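The update rule for GIN is then defined as:
$$h^{(k)}_{v_i} = \mathrm{MLP}^{(k)}\Big( (1 + \epsilon^{(k)})\, h^{(k-1)}_{v_i} + \sum_{v_j \in \mathcal{N}(v_i)} \big[ h^{(k-1)}_{v_j} + W^{(k)}_b h_{e_{ij}} \big] \Big)$$
As with the original model, the final embedding $h_G$ is defined as the concatenation of the summed node embeddings for each layer:
$$h_G = \mathrm{CONCAT}\Big( \sum_{v_i} h^{(k)}_{v_i} \;\Big|\; k = 0, 1, 2, \dots, K \Big)$$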
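To make the edge-aware GIN update and readout concrete, here is a minimal standard-library sketch of one layer and of the concatenated sum readout. It is an illustration, not our training code: vectors are plain lists, `mlp` is a hypothetical callable standing in for $\mathrm{MLP}^{(k)}$ (a ReLU in the example), and edges are assumed to be stored in both directions.

```python
def matvec(W, x):
    """Matrix-vector product W x, with W given as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def vadd(a, b):
    """Element-wise sum of two vectors."""
    return [x + y for x, y in zip(a, b)]


def gin_layer(h, edges, W_b, eps, mlp):
    """One edge-aware GIN update.

    h:     list of node vectors h^{(k-1)}_{v_i}
    edges: dict (i, j) -> edge features h_{e_ij}, stored in both directions
    W_b:   matrix mapping edge features to the node-feature size
    mlp:   callable standing in for MLP^{(k)} (hypothetical)
    """
    new_h = []
    for i, h_i in enumerate(h):
        agg = [(1.0 + eps) * x for x in h_i]
        for (src, j), h_e in edges.items():  # O(|V||E|) scan, fine for a sketch
            if src == i and j != i:  # v_j in N(v_i), node itself excluded
                agg = vadd(agg, vadd(h[j], matvec(W_b, h_e)))
        new_h.append(mlp(agg))
    return new_h


def gin_readout(layers):
    """h_G = CONCAT(sum_i h^{(k)}_{v_i} for k = 0..K)."""
    h_G = []
    for h in layers:
        total = [0.0] * len(h[0])
        for v in h:
            total = vadd(total, v)
        h_G.extend(total)
    return h_G


def relu(z):
    """Toy one-layer 'MLP' used only for illustration."""
    return [max(0.0, x) for x in z]


h0 = [[1.0, 0.0], [0.0, 1.0]]
edges = {(0, 1): [1.0], (1, 0): [1.0]}
h1 = gin_layer(h0, edges, W_b=[[1.0], [0.0]], eps=0.0, mlp=relu)
print(h1, gin_readout([h0, h1]))  # [[2.0, 1.0], [2.0, 1.0]] [1.0, 1.0, 4.0, 2.0]
```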

GAT

For our implementation of the GAT model, we compute the attention score $\alpha^{(k)}_{ij}$ for each pair of nodes as follows:
$$\alpha^{(k)}_{ij} = \frac{\exp\Big(\mathrm{LeakyReLU}\Big(a^{(k)\top} \big[ W^{(k)}_1 h^{(k-1)}_{v_i} \,\big\|\, W^{(k)}_2 \big(h^{(k-1)}_{v_j} + W^{(k)}_b h_{e_{ij}}\big) \big]\Big)\Big)}{\sum_{v_l \in \tilde{\mathcal{N}}(v_i)} \exp\Big(\mathrm{LeakyReLU}\Big(a^{(k)\top} \big[ W^{(k)}_1 h^{(k-1)}_{v_i} \,\big\|\, W^{(k)}_2 \big(h^{(k-1)}_{v_l} + W^{(k)}_b h_{e_{il}}\big) \big]\Big)\Big)}$$
$\{W^{(k)}_1, W^{(k)}_2, W^{(k)}_b\}$ are layer-specific feature transforms, $\|$ denotes concatenation, and $a^{(k)}$ is a vector that computes the final attention score for each pair of nodes. Note that here we compute attention across all of a node's neighbors as well as the node itself. The updated node embeddings are as follows:
$$h^{(k)}_{v_i} = \sum_{v_j \in \tilde{\mathcal{N}}(v_i)} \alpha^{(k)}_{ij}\, h^{(k-1)}_{v_j}$$
The final graph embedding is a simple sum aggregation of the node representations of the last layer ($h_G = \sum_{v_i} h^{(K)}_{v_i}$). As in Veličković et al. (2017), we also extend this formulation to use multiple attention heads.
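A single-head version of this attention update can be sketched with the standard library alone. Assumptions not stated in the text: the LeakyReLU negative slope is set to 0.2 (the value commonly used for GAT), features for the self-pair $(i, i)$ must be supplied explicitly (e.g. a zero vector), and the multi-head extension is omitted. All names are hypothetical.

```python
import math


def matvec(W, x):
    """Matrix-vector product W x, with W given as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def leaky_relu(x, slope=0.2):
    """LeakyReLU; the slope value is an assumption."""
    return x if x > 0 else slope * x


def gat_layer(h, nbrs, edge_feat, W1, W2, Wb, a):
    """Single-head GAT update with edge features.

    nbrs[i]:           indices of N~(v_i): the neighbours of i plus i itself
    edge_feat[(i, j)]: h_{e_ij}, including the self-pair (i, i) (assumed given)
    """
    new_h = []
    for i in range(len(h)):
        left = matvec(W1, h[i])
        scores = []
        for j in nbrs[i]:
            msg = [x + y for x, y in zip(h[j], matvec(Wb, edge_feat[(i, j)]))]
            z = left + matvec(W2, msg)  # list concatenation: [W1 h_i || W2 (...)]
            scores.append(leaky_relu(sum(ak * zk for ak, zk in zip(a, z))))
        m = max(scores)  # subtract the max for a numerically stable softmax
        exps = [math.exp(s - m) for s in scores]
        Z = sum(exps)
        alpha = [e / Z for e in exps]
        out = [0.0] * len(h[i])
        for w, j in zip(alpha, nbrs[i]):  # weighted sum of h^{(k-1)}_{v_j}
            out = [o + w * x for o, x in zip(out, h[j])]
        new_h.append(out)
    return new_h


eye = [[1.0, 0.0], [0.0, 1.0]]
h = [[1.0, 0.0], [0.0, 1.0]]
nbrs = {0: [0, 1], 1: [0, 1]}
edge_feat = {(i, j): [0.0] for i in range(2) for j in range(2)}
out = gat_layer(h, nbrs, edge_feat, eye, eye, [[0.0], [0.0]], a=[0.0, 0.0, 1.0, 0.0])
print(out)
```

In this toy example the attention vector only looks at the first coordinate of $W_2 h_{v_j}$, so node 0 receives the larger weight $e/(1+e)$ from every query node.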

E WHAT TYPES OF MOLECULES DO PROTOTYPES CAPTURE?

To better understand whether the learned prototypes offer interpretability, we examined the ProtoW-Dot model trained with NC regularization (weight 0.1). For each of the 10 learned prototypes, we computed the set of test molecules that are closer to this prototype, in terms of the corresponding Wasserstein distance, than to any other prototype. While one prototype is closest to the majority of molecules, other prototypes are more interpretable, as shown in Figure 6.

F RELATED WORK ON PROTOTYPE LEARNING

Learning prototypes to solve machine learning tasks has been extensively studied. Generalized learning vector quantization (GLVQ) methods (Kohonen, 1995; Sato & Yamada, 1995) perform classification by assigning to each data point the class of its closest prototype, measured using some distance function, typically the Euclidean distance. Each class has a set of prototypes that are optimized jointly such that the closest wrong prototype is pushed away while the closest correct prototype is brought closer. Several extensions of GLVQ (Hammer & Villmann, 2002; Schneider et al., 2009; Bunte et al., 2012) introduce feature weights and parameterized input transformations to leverage more flexible and adaptive metric spaces. Nevertheless, such models are limited to classification tasks and might suffer from extreme gradient sparsity. More similar to our work are radial basis function (RBF) networks (Chen et al., 1991), which perform classification or regression based on RBF kernel similarities to all prototypes. One such similarity vector is computed for each data point and used together with a shared linear output layer to obtain the final predictions. Prototypes are typically set in an unsupervised fashion, e.g. via k-means clustering or the Orthogonal Least Squares learning algorithm, rather than learned by backpropagation as in our case. Combining the non-parametric power of kernel methods with the flexibility of deep learning models has resulted in even more expressive and scalable similarity functions that have been conveniently trained with backpropagation by maximizing the likelihood of a Gaussian process (Wilson et al., 2016).

Figure 6: The closest molecules to some particular prototypes in terms of the corresponding Wasserstein distance. One can observe that some prototypes are closer to insoluble molecules containing rings (Prototype 2), while others prefer more soluble molecules (Prototype 1).
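The RBF-network prediction described above (Chen et al., 1991) reduces to a similarity vector followed by a shared linear layer. The standard-library sketch below illustrates only that forward pass; the bandwidth `gamma`, the weights and the prototype locations are hypothetical values, and how the prototypes are obtained (k-means, Orthogonal Least Squares, or backpropagation) is left out.

```python
import math


def rbf_features(x, prototypes, gamma=1.0):
    """Similarity of x to each prototype: exp(-gamma * ||x - p||^2)."""
    return [math.exp(-gamma * sum((a - b) ** 2 for a, b in zip(x, p)))
            for p in prototypes]


def rbf_predict(x, prototypes, weights, bias, gamma=1.0):
    """Shared linear output layer on top of the similarity vector."""
    phi = rbf_features(x, prototypes, gamma)
    return sum(w * f for w, f in zip(weights, phi)) + bias


protos = [[0.0, 0.0], [10.0, 10.0]]
y = rbf_predict([0.0, 0.0], protos, weights=[2.0, 3.0], bias=0.5)
print(y)  # 2.5: the far-away prototype contributes essentially nothing
```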
Recently, learning parametric data embeddings and prototypes was also investigated for few-shot and zero-shot classification scenarios (Snell et al., 2017). Finally, using distances to prototypes as opposed to p.d. kernels, while less common, has been analyzed in the past (Duin & Pękalska, 2012; Snell et al., 2017). In contrast with the above line of work, our research focuses on learning parametric prototypes for graphs, trained jointly with graph embedding functions, for both graph classification and regression problems. Prototypes are modeled as sets (point clouds) of embeddings, while graphs are represented by sets of unaggregated node embeddings obtained using graph neural network models. Dissimilarities between prototypes and graph embeddings are then quantified via set distances computed using optimal transport. Additional challenges arise from the combinatorial nature of Wasserstein distances between sets, hence our discussion on introducing the noise contrastive regularizer.
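The contrast drawn above (unaggregated node embeddings compared to prototype point clouds via optimal transport) can be illustrated with a toy forward pass. This standard-library sketch rests on simplifying assumptions: the graph and every prototype have the same number of points, so that Proposition 2 allows exact enumeration over permutations, and the linear head and all names are hypothetical. It is not the training code; for clouds of different sizes, an exact solver such as POT's `ot.emd` would be needed.

```python
from itertools import permutations


def w_l2_equal_size(X, P):
    """W_L2 between two point clouds of the same size n (uniform weights).

    By Proposition 2 an optimal plan is a rescaled permutation, so for tiny n
    all n! permutations can be enumerated exactly.
    """
    n = len(X)
    return min(
        sum(sum((a - b) ** 2 for a, b in zip(X[s[j]], P[j])) for j in range(n)) / n
        for s in permutations(range(n))
    )


def prototype_representation(node_embeddings, prototypes):
    """Graph representation: one Wasserstein distance per prototype point cloud."""
    return [w_l2_equal_size(node_embeddings, P) for P in prototypes]


def predict(node_embeddings, prototypes, weights, bias):
    """Hypothetical linear head on top of the prototype-distance vector."""
    r = prototype_representation(node_embeddings, prototypes)
    return sum(w * x for w, x in zip(weights, r)) + bias


nodes = [[0.0, 0.0], [1.0, 1.0]]         # unaggregated node embeddings
prototypes = [[[1.0, 1.0], [0.0, 0.0]],  # same points in a different order
              [[0.0, 0.0], [0.0, 0.0]]]
print(prototype_representation(nodes, prototypes))  # [0.0, 1.0]
```

Unlike a sum or mean of the node embeddings, the first distance is zero only because the prototype contains the same set of points, regardless of their order.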



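where $\sigma(\cdot)$ is the sigmoid function, for $m \in \mathbb{N}^*$, $\alpha_j, \beta_j \in \mathbb{R}$ and $\theta_j \in \mathcal{X}_{n,d}$.

3. With respect to the topology defined by the sup norm $\|f\|_{\infty,\mathcal{X}} := \sup_{X \in \mathcal{X}} |f(X)|$.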



Figure 2: We illustrate, for a given 2D point cloud, the optimal transport plan obtained from minimizing the Wasserstein costs; c(•, •) denotes the Euclidean distance. A higher dotted-line thickness illustrates a greater mass transport.

Figure 3: 2D embeddings of prototypes and of real molecule samples with (right) and without (left) contrastive regularization for runs using the same random seed. Points in the prototypes tend to cluster and collapse more when no regularization is used, suggesting that the optimal transport plan no longer remains uniquely discriminative. Prototype 1 (red) is enlarged for clarity: without regularization it is clumped together (left), but with regularization it is distributed across the space (right).

Figure 4: Comparison of the correlation between graph embedding distances (X axis) and label distances (Y axis) on the ESOL dataset.

Figure 5: 2D heatmaps of t-SNE (Maaten & Hinton, 2008) projections of molecular embeddings (before the last linear layer) w.r.t. their associated predicted labels. Heat colors are interpolations based only on the test molecules from each dataset. Comparing (a) vs (b) and (c) vs (d), we can observe a smoother space for our model compared to the D-MPNN baseline (see also main text).

2. More details about the experimental setup are presented in Appendix D.1.

Results on the property prediction datasets. The best model is in bold, the second best is underlined.



Our Proto-W-L2 model yields smoother heatmaps.

Training times for each model and dataset.

Number of parameters per model. Corresponding hyperparameters are in appendix D.1.

The molecular datasets used for these experiments are small (ranging from 1k to 4k data points), so this is a fair method of comparison, and it is indeed what is done in other works on molecular property prediction (Yang et al., 2019).

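Since $\bar{S} \neq C(\mathcal{X})$, by the Hahn-Banach theorem there exists a bounded linear functional $L$ on $C(\mathcal{X})$ such that $L(h) = 0$ for all $h \in \bar{S}$ and $L(h') \neq 0$ for some $h' \in C(\mathcal{X})$. By the Riesz representation theorem, this bounded linear functional is of the form
$$L(h) = \int_{\mathcal{X}} h(X)\, d\mu(X), \qquad h \in C(\mathcal{X}),$$
for some $\mu \in M(\mathcal{X})$. In particular, since every function $X \mapsto \sigma(k(X, Y) + \theta)$ lies in $S$ and $k$ is symmetric, $\int_{\mathcal{X}} \sigma(k(Y, X) + \theta)\, d\mu(X) = 0$ for all $Y \in \mathcal{X}_{n,d}$ and $\theta \in \mathbb{R}$. Since $\sigma$ is discriminatory w.r.t. $k$, this implies that $\mu = 0$ and hence $L \equiv 0$, which contradicts $L(h') \neq 0$. Hence $\bar{S} = C(\mathcal{X})$, i.e. $S$ is dense in $C(\mathcal{X})$ and $k$ is universal.

Now let us look at the part of the proof that is new. [...] for $\theta \geq 0$ and $\emptyset$ for $\theta < 0$. By the Lebesgue Bounded Convergence Theorem we have: [...] From the Hahn decomposition theorem, there exist disjoint Borel sets $P, N$ such that $\mathcal{X} = P \cup N$ and $\mu = \mu^+ - \mu^-$, where $\mu^+(A) := \mu(A \cap P)$ and $\mu^-(A) := -\mu(A \cap N)$ for any Borel set $A$, with $\mu^+, \mu^-$ being positive measures. Since $\mu^+$ and $\mu^-$ coincide on all balls of a finite-dimensional metric space, they coincide everywhere (Hoffmann-Jørgensen, 1976), and hence $\mu \equiv 0$. Combining the previous lemmas with $k = \mathcal{W}_{L2}$ concludes the proof.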

B.2 PROOF OF THEOREM 2

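1. We build a counterexample. We consider 4 point clouds of size $n = 2$ and dimension $d = 2$. First, define $u_i = (\lfloor i/2 \rfloor, i \bmod 2)$ for $i \in \{0, \dots, 3\}$. Then take $X_1 = \{u_0, u_1\}$, $X_2 = \{u_0, u_2\}$, $X_3 = \{u_0, u_3\}$ and $X_4 = \{u_1, u_2\}$. On the one hand, if $\mathcal{W}_{dot}(X_i, X_j) = 0$, then every vector of one point cloud is orthogonal to every vector of the other, which only happens for $\{i, j\} = \{1, 2\}$. On the other hand, $\mathcal{W}_{dot}(X_i, X_j) = 1$ only for $i = j = 3$ or $i = j = 4$; all other entries equal $1/2$. This yields the following Gram matrix
$$\big(\mathcal{W}_{dot}(X_i, X_j)\big)_{1 \leq i,j \leq 4} = \begin{pmatrix} 1/2 & 0 & 1/2 & 1/2 \\ 0 & 1/2 & 1/2 & 1/2 \\ 1/2 & 1/2 & 1 & 1/2 \\ 1/2 & 1/2 & 1/2 & 1 \end{pmatrix},$$
whose determinant is $-1/16$, which implies that this matrix has a negative eigenvalue.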
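The counterexample can be checked mechanically with exact rational arithmetic. In this standard-library sketch we assume, consistently with the second point of the theorem, that the kernel in point 1 is $\mathcal{W}_{dot}$. The sketch computes it by enumerating permutations (valid by Proposition 2 for equal-size clouds) and evaluates the determinant of the resulting Gram matrix by Gaussian elimination over fractions. The helper names are hypothetical.

```python
from fractions import Fraction
from itertools import permutations


def w_dot(X, Y):
    """W_dot for equal-size clouds with uniform weights (Proposition 2):
    maximum over permutations s of (1/n) * sum_j <x_{s(j)}, y_j>."""
    n = len(X)
    return max(
        Fraction(sum(sum(a * b for a, b in zip(X[s[j]], Y[j])) for j in range(n)), n)
        for s in permutations(range(n))
    )


def det(M):
    """Determinant by Gaussian elimination over exact rationals."""
    M = [[Fraction(x) for x in row] for row in M]
    n, d = len(M), Fraction(1)
    for c in range(n):
        p = next((r for r in range(c, n) if M[r][c] != 0), None)
        if p is None:
            return Fraction(0)
        if p != c:  # a row swap flips the sign of the determinant
            M[c], M[p] = M[p], M[c]
            d = -d
        d *= M[c][c]
        for r in range(c + 1, n):
            f = M[r][c] / M[c][c]
            M[r] = [a - f * b for a, b in zip(M[r], M[c])]
    return d


u = [(i // 2, i % 2) for i in range(4)]  # u_0..u_3 = (0,0), (0,1), (1,0), (1,1)
clouds = [[u[0], u[1]], [u[0], u[2]], [u[0], u[3]], [u[1], u[2]]]
G = [[w_dot(Xi, Xj) for Xj in clouds] for Xi in clouds]
print(det(G))  # -1/16
```

Since the Gram matrix is symmetric and its determinant (the product of its eigenvalues) is negative, at least one eigenvalue is negative, so $\mathcal{W}_{dot}$ is not positive definite.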

