CAUSALLY-GUIDED REGULARIZATION OF GRAPH ATTENTION IMPROVES GENERALIZABILITY

Abstract

Graph attention networks estimate the relational importance of node neighbors to aggregate relevant information over local neighborhoods for a prediction task. However, the inferred attentions are vulnerable to spurious correlations and connectivity in the training data, hampering the generalizability of models. We introduce CAR, a general-purpose regularization framework for graph attention networks. Embodying a causal inference approach based on invariant prediction, CAR aligns the attention mechanism with the causal effects of active interventions on graph connectivity in a scalable manner. CAR is compatible with a variety of graph attention architectures, and we show that it systematically improves generalizability on various node classification tasks. Our ablation studies indicate that CAR homes in on the aspects of graph structure most pertinent to the prediction (e.g., homophily), and does so more effectively than alternative approaches. Finally, we also show that CAR enhances the interpretability of attention coefficients by accentuating node-neighbor relations that point to causal hypotheses.

1. INTRODUCTION

Graphs encode rich relational information that can be leveraged in learning tasks across a wide variety of domains. Graph neural networks (GNNs) can learn powerful node-, edge-, or graph-level representations by aggregating a node's representation with those of its neighbors. The specifics of a GNN's neighborhood aggregation scheme are critical to its effectiveness on a prediction task. For instance, graph convolutional networks (GCNs) aggregate information via a simple averaging or max-pooling of neighbor features, and are prone to underperform in many real-world scenarios where uninformative or noisy connections exist between nodes (Kipf & Welling, 2017; Hamilton et al., 2017). Graph-based attention mechanisms combat these issues by quantifying the relevance of node-neighbor relations and softly selecting neighbors in the aggregation step accordingly (Velickovic et al., 2018; Brody et al., 2022; Shi et al., 2021). This process of attending to select neighbors has contributed to significant performance gains for GNNs across a variety of tasks (Zhou et al., 2018; Veličković, 2022). Similar to the use of attention in natural language processing and computer vision, attention in graph settings also enables the interpretability of model predictions via the examination of attention coefficients (Serrano & Smith, 2019). However, graph attention mechanisms can be misled by spurious edges and correlations in how they attend to node neighbors, which manifests as a failure to generalize to unseen data (Knyazev et al., 2019). One approach to improving GNNs' generalizability is to regularize attention coefficients so that they are more robust to spurious correlations and connections in the training data. Previous work has focused on $L_0$ regularization of attention coefficients to enforce sparsity (Ye & Ji, 2021) or has co-optimized a link prediction task using attention (Kim & Oh, 2021).
Since these regularization strategies are formulated independently of the primary prediction task, they align the attention mechanism with some intrinsic property of the input graph without regard for the training objective. We take a different approach and consider the question: "What is the importance of a specific edge to the prediction task?" Our answer comes from the perspective of regularization: we introduce CAR, a causal attention regularization framework that is broadly suitable for graph attention network architectures (Figure 1). Intuitively, an edge in the input graph is important to a prediction task if removing it leads to substantial degradation in the prediction performance of the GNN. The key conceptual advance of this work is to scalably leverage active interventions on node neighborhoods (i.e., deletion of specific edges) to align graph attention training with the causal impact of these interventions on task performance. Theoretically, our approach is motivated by the invariant prediction framework for causal inference (Peters et al., 2016; Wu et al., 2022). While some efforts have previously been made to infuse notions of causality into GNNs, these causal approaches have been largely limited to using causal effects from pre-trained models as features for a separate model (Feng et al., 2021; Knyazev et al., 2019) or decoupling causal from non-causal effects (Sui et al., 2021). We apply CAR on three graph attention architectures across eight node classification tasks, finding that it consistently improves test loss and accuracy. CAR is able to fine-tune graph attention by improving its alignment with task-specific homophily. Correspondingly, we found that as graph heterophily increases, the margin of CAR's outperformance widens. In contrast, a non-causal approach that directly regularizes with respect to label similarity generalizes less well.
On the ogbn-arxiv network, we investigated the citations up/down-weighted by CAR and found that they broadly group into three intuitive themes. Our causal approach can thus enhance the interpretability of attention coefficients, and we provide a qualitative analysis of this improved interpretability. We also present preliminary results demonstrating the applicability of CAR to graph pruning tasks. Due to the size of industrially relevant graphs, it is common to use GCNs or sampling-based approaches on them. There, attention coefficients learned by CAR on sampled subnetworks may guide graph rewiring of the full network to improve the results obtained with convolutional techniques.

2.1. GRAPH ATTENTION NETWORKS

Attention mechanisms have been effectively used in many domains by enabling models to dynamically attend to the specific parts of an input that are relevant to a prediction task (Chaudhari et al., 2021). In graph settings, attention mechanisms compute the relevance of edges in the graph for a prediction task. A neighbor aggregation operator then uses this information to weight the contribution of each edge (Lee et al., 2019a; Li et al., 2016; Lee et al., 2019b). The approach for computing attention is similar in many graph attention mechanisms. A graph attention layer takes as input a set of node features $h = \{\vec{h}_1, \ldots, \vec{h}_N\}$, $\vec{h}_i \in \mathbb{R}^F$, where $N$ is the number of nodes. The graph attention layer uses these node features to compute attention coefficients for each edge: $\alpha_{ij} = a(W\vec{h}_i, W\vec{h}_j)$, where $a : \mathbb{R}^{F'} \times \mathbb{R}^{F'} \to (0, 1)$ is the attention mechanism function and the attention coefficient $\alpha_{ij}$ for an edge indicates the importance of node $i$'s input features to node $j$. For a node $j$, these attention coefficients are then used to compute a linear combination of its neighbors' features: $\vec{h}'_j = \sum_{i \in \mathcal{N}(j)} \alpha_{ij} W \vec{h}_i$, subject to $\sum_{i \in \mathcal{N}(j)} \alpha_{ij} = 1$. For multi-headed attention, each of the $K$ heads first independently calculates its own attention coefficients $\alpha^{(k)}_{ij}$ with its head-specific attention mechanism $a^{(k)}(\cdot, \cdot)$, after which the head-specific outputs are averaged. In this paper, we focus on three widely used graph attention architectures: the original graph attention network (GAT) (Velickovic et al., 2018), a modified version of this original network (GATv2) (Brody et al., 2022), and the Graph Transformer network (Shi et al., 2021). The three architectures and their equations for computing attention are presented in Appendix A.1.
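As a concrete illustration, the aggregation step above can be sketched in a few lines of Python. The dot-product scoring function here is a simplified stand-in for the architecture-specific attention mechanisms $a(\cdot, \cdot)$ of Appendix A.1, and the learned projection $W$ is folded into the input features for brevity; this is an illustrative sketch, not any of the three architectures verbatim.

```python
import math

def softmax(xs):
    """Normalize raw attention scores so they sum to 1 over a neighborhood."""
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def attend(h, neighbors, score):
    """For each node j, aggregate its neighbors' features weighted by
    attention coefficients alpha_ij with sum_i alpha_ij = 1 over N(j)."""
    h_out = {}
    for j, nbrs in neighbors.items():
        scores = [score(h[i], h[j]) for i in nbrs]
        alphas = softmax(scores)              # sum_i alpha_ij = 1
        dim = len(h[j])
        agg = [0.0] * dim
        for a, i in zip(alphas, nbrs):
            for d in range(dim):
                agg[d] += a * h[i][d]         # h'_j = sum_i alpha_ij * h_i
        h_out[j] = agg
    return h_out

# Toy example: node 2 attends over in-neighbors 0 and 1 with dot-product scoring.
h = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
neighbors = {2: [0, 1]}
dot = lambda a, b: sum(x * y for x, y in zip(a, b))
out = attend(h, neighbors, dot)
```

In this toy example both neighbors score equally, so node 2 receives the uniform average of their features. Multi-head attention would repeat this computation $K$ times with independent scoring functions and average the outputs.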

2.2. CAUSAL ATTENTION REGULARIZATION: AN INVARIANCE PREDICTION FORMULATION

CAR is motivated by the invariant prediction (IP) formulation of causal inference (Peters et al., 2016; Wu et al., 2022). The central insight of this formulation is that, given sub-models that each contain a different set of predictor variables, the underlying causal model of a system comprises the set of all sub-models for which the predicted class distributions are equivalent, up to a noise term. This approach is capable of providing statistically rigorous estimates of both the causal effect strength of predictor variables and confidence intervals. With CAR, our core insight is that the graph structure itself, in addition to the set of node features, belongs to the set of predictor variables. This is equivalent to the intuition that relevant edges for a particular task should not only be assigned high attention coefficients but also be important to the predictive accuracy of the model (Figure 1). The removal of these relevant edges from the graph should cause the predictions that rely on them to substantially worsen.

Figure 1: Schematic of CAR. Graph attention networks learn the relative importance of each node-neighbor for a given prediction task. However, their inferred attention coefficients can be miscalibrated due to noise, spurious correlations, or confounders (e.g., node size here). Our causal approach directly intervenes on a sampled subset of edges and supervises an auxiliary task that aligns an edge's causal importance to the task with its attention coefficient.

We leverage the residual formulation of IP to formalize this intuition. This formulation assumes that we can generate sub-models for different sets of predictor variables, each corresponding to a separate experiment $e \in \mathcal{E}$. For each sub-model $S^e$, we compute the predictions $Y^e = g(G^e_S, X^e_S, \epsilon^e)$, where $G$ is the graph structure, $X$ is the set of features associated with $G$, $\epsilon^e$ is the noise distribution, and $S$ is the set of predictor variables corresponding to $S^e$.
We next compute the residuals $R = Y - Y^e$. IP requires that we perform a hypothesis test on the means of the residuals, with the generic approach being to perform an F-test for each sub-model against the null hypothesis. The relevant assumptions ($\epsilon^e \sim F_\epsilon$, and $\epsilon^e \perp\!\!\!\perp S^e$ for all $e \in \mathcal{E}$) are satisfied if and only if the conditionals $Y^e \mid S^e$ and $Y^f \mid S^f$ are identical for all experiments $e, f \in \mathcal{E}$. We use an edge intervention-based strategy that corresponds precisely to this IP-based formulation. However, we differ from the standard IP formulation in how we estimate the final causal model. While IP provides a method to explicitly construct an estimator of the true causal model (by taking the intersection of all models for which the null hypothesis was rejected), we rely on intervention-guided regularization of graph attention coefficients as a way to aggregate sub-models while balancing model complexity and runtime considerations. In our setting, each sub-model corresponds to a set of edge interventions and, thus, a slightly different graph structure. The same GNN architecture is trained on each of these sub-models. Given a set of experiments $\mathcal{E} = \{e\}$ with sub-models $S^e$, outputs $Y^e$, and errors $\epsilon^e$, we regularize the attention coefficients to align with sub-model errors, thus learning a GNN architecture primarily from causal sub-models. Incorporating this regularization as an auxiliary task, we seek to minimize the following loss:

$$L = L_p + \lambda L_c \quad (1)$$

The full loss function $L$ consists of the loss associated with the prediction, $L_p$; the loss associated with the causal attention task, $L_c$; and a causal regularization strength hyperparameter $\lambda$ that mediates the contribution of the regularization loss to the objective.
For the prediction loss, we have $L_p = \frac{1}{N} \sum_{n=1}^{N} \ell_p(\hat{y}^{(n)}, y^{(n)})$, where $N$ is the size of the training set, $\ell_p(\cdot, \cdot)$ is the loss function for the given prediction task, $\hat{y}^{(n)}$ is the prediction for entity $n$, and $y^{(n)}$ is the ground-truth value for entity $n$. We seek to align the attention coefficient for an edge with the causal effect of removing that edge through the use of the following loss function:

$$L_c = \frac{1}{R} \sum_{r=1}^{R} \frac{1}{|S^{(r)}|} \sum_{(n,i,j) \in S^{(r)}} \ell_c\!\left(\alpha^{(n)}_{ij}, c^{(n)}_{ij}\right) \quad (2)$$

Here, $n$ represents a single entity for which we aim to make a prediction. For a node prediction task, the entity $n$ corresponds to a node; in a graph prediction task, $n$ corresponds to a graph. In this paper, we assume that all edges are directed and, if necessary, decompose an undirected edge into two directed edges. In each mini-batch, we generate $R$ separate sub-models, each of which consists of a set of edge interventions $S^{(r)}$, $r = 1, \ldots, R$. Each edge intervention in $S^{(r)}$ is represented by a tuple $(n, i, j)$ denoting a selected edge $(i, j)$ for an entity $n$. Note that in a node classification task, $n$ is the same as $j$ (i.e., the node with the incoming edge). More details on the edge intervention procedure and causal effect calculations can be found in the sections below. The causal effect $c^{(n)}_{ij}$ scores the impact of deleting edge $(i, j)$ through a likelihood ratio test. This causal effect is compared to the edge's attention coefficient $\alpha^{(n)}_{ij}$ via the loss function $\ell_c(\cdot, \cdot)$. A detailed algorithm for CAR is provided in Appendix A.2.

Edge intervention procedure: We sample a set of edges in each round $r$ such that the prediction for each entity will strictly be affected by at most one edge intervention in that round, ensuring effect independence. For example, in a node classification task for a model with one GNN layer, a round of edge interventions entails removing only a single incoming edge for each node being classified.
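Before detailing the intervention procedure further, the combined objective of Equations (1)-(2) can be sketched as follows. The squared-error stand-in for $\ell_c$ and all variable names are our illustrative choices, not the paper's exact implementation.

```python
def car_loss(pred_losses, rounds, lam):
    """Sketch of L = L_p + lambda * L_c (Eqs. 1-2).

    pred_losses: per-entity prediction losses ell_p(y_hat, y).
    rounds: R rounds; each round is a list of (alpha_ij, c_ij)
        pairs for the intervened edges in S^(r).
    lam: causal regularization strength lambda.
    """
    L_p = sum(pred_losses) / len(pred_losses)
    ell_c = lambda alpha, c: (alpha - c) ** 2      # stand-in for ell_c
    L_c = sum(
        sum(ell_c(a, c) for a, c in r) / len(r) for r in rounds
    ) / len(rounds)
    return L_p + lam * L_c

# One round with two intervened edges whose attention coefficients match
# their causal effects exactly: only the prediction loss remains.
loss = car_loss([0.4, 0.6], [[(0.7, 0.7), (0.3, 0.3)]], lam=1.0)
```

When attention and causal effect disagree, the second term penalizes the gap, pulling the attention mechanism toward the intervention-derived importances.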
In the graph property prediction case, only one edge will be removed from each graph per round. Because a model with one GNN layer only aggregates information over a 1-hop neighborhood, the removal of each edge will only affect the model's prediction for that edge's target node. To select a set of edges in the L-layer GNN case, edges are sampled from the 1-hop neighborhood of each node being classified, and sampled edges that lie in more than one target node's L-hop neighborhood are removed from consideration as intervention candidates. This edge intervention selection procedure is crucial, as it enables the causal effect $c^{(n)}_{ij}$ of each selected intervention on an entity $n$ to be calculated independently of other interventions on the graph. Moreover, by selecting only one intervention per entity within each round, we can parallelize the computation of these causal effects across all entities per round instead of iteratively evaluating just one intervention for the entire graph per entity, significantly aiding scalability.

Calculating task-specific causal effects: We quantify the causal effect of an intervention at edge $(i, j)$ on entity $n$ through an approximate likelihood-ratio test:

$$c^{(n)}_{ij} = \sigma\!\left(\left(\rho^{(n)}_{ij}\right)^{d(n)} - 1\right), \quad \text{where} \quad \rho^{(n)}_{ij} = \frac{\ell_p\!\left(\hat{y}^{(n)}_{\setminus ij}, y^{(n)}\right)}{\ell_p\!\left(\hat{y}^{(n)}, y^{(n)}\right)} \quad (3)$$

Here, $\hat{y}^{(n)}_{\setminus ij}$ is the prediction for entity $n$ upon removal of edge $(i, j)$, and $d(n)$ is the node degree for node classification and the number of edges in the graph for graph classification. Interventions will likely have smaller effects on higher-degree nodes due to the increased likelihood of there being multiple edges that are relevant and beneficial for predictions for such nodes. In graph classification, graphs with many edges are likewise less affected by interventions. Exponentiating the relative intervention effect $\rho^{(n)}_{ij}$ by $d(n)$ is intended to adjust for these biases.
We have experimented both with and without raising $\rho$ to the power of $d$ and have found empirically better results with the exponent. We do not have a rigorous explanation as to why, although we suspect the correlated nature of edges implies some amount of shrinkage is necessary. The link function $\sigma : \mathbb{R} \to (0, 1)$ maps its input to the support of the distribution of attention coefficients. The predictions $\hat{y}^{(n)}$ and $\hat{y}^{(n)}_{\setminus ij}$ are generated from the graph attention network being trained with CAR. We emphasize that $\hat{y}^{(n)}$ and $\hat{y}^{(n)}_{\setminus ij}$ are both computed during the same training run, rather than in two separate runs. The main additional cost of CAR comes from checking that sampled interventions do not overlap, which scales with the size of the L-hop neighborhood (roughly $b^L$, where $L$ is the depth of the GNN and $b$ is the mean in-degree). We note that, in practice, GNNs usually have $L \leq 2$ layers, since greater depth increases the computational cost and the risk of oversmoothing (Li et al., 2018; Chen et al., 2020; Topping et al., 2022). The contribution of this intervention overlap search step is therefore minimal for most GNNs. We are also able to mitigate much of the computational cost by parallelizing these searches. Further speedups can be achieved by identifying non-overlapping interventions as a preprocessing step. In summary, we have found CAR to have a modest effect on the scalability of graph attention methods, increasing training runtimes by only 1.5-2 fold (Appendix A.3).
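The causal-effect score of Equation (3) can be sketched directly. We use a standard sigmoid as the link function $\sigma$; the loss values and degree passed in are placeholders for quantities produced during training.

```python
import math

def sigmoid(x):
    """Link function sigma: R -> (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))

def causal_effect(loss_without_edge, loss_with_edge, degree):
    """Eq. (3): degree-adjusted likelihood ratio mapped into (0, 1).

    loss_without_edge: ell_p after deleting edge (i, j).
    loss_with_edge: ell_p on the intact graph.
    degree: d(n), the in-degree (node task) or edge count (graph task).
    """
    rho = loss_without_edge / loss_with_edge   # relative intervention effect
    return sigmoid(rho ** degree - 1.0)        # c = sigma(rho^d - 1)

# Removing a helpful edge doubles the loss for a degree-1 node -> score > 0.5.
helpful = causal_effect(2.0, 1.0, 1)
# An irrelevant edge leaves the loss unchanged (rho = 1) -> score of exactly 0.5.
neutral = causal_effect(1.0, 1.0, 3)
```

Note how the score behaves at the extremes: edges whose removal hurts the prediction are pushed above 0.5, edges whose removal helps fall below it, and the exponent $d(n)$ amplifies small relative effects on high-degree entities, compensating for the dilution described above.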

3. RELATED WORK

The performance gains associated with graph attention networks have led to a number of efforts to enhance and better understand graph attention mechanisms (Lee et al., 2019a). One category of methods aims to improve the expressive power of graph attention mechanisms by supplementing a given prediction objective with a supervised attention mechanism (Feng et al., 2021) or a self-supervised connectivity prediction approach (Kim & Oh, 2021). A related set of methods leverages signals from interventions on graphs from a causal perspective to aid in training GNNs. One class of techniques performs interventions on nodes (Knyazev et al., 2019; Feng et al., 2021). CAL (Sui et al., 2021), a method designed for graph property prediction tasks, performs abstract interventions on representations of entire graphs instead of specific nodes or edges to identify the causally attended subgraph for a given task. Specifically, it uses these interventions to achieve robustness rather than directly leveraging information about the effect of these interventions on model predictions. Our intervention-oriented framework can also be understood as a graph structure perturbation method. Perturbation methods can be broadly split into three categories: graph data augmentation (Zhao et al., 2022), structural graph rewiring (Rong et al., 2020), and geometric graph rewiring (Topping et al., 2022). Inspired by the success of data augmentation approaches in computer vision, graph data augmentation methods seek to generate new training samples through different augmentation techniques. One of the earliest such methods, DropEdge (Rong et al., 2020), reduces overfitting by deleting edges selected uniformly at random. Other methods build on DropEdge and select edges according to additional constraints, including geometrical invariants (Gao et al., 2021), target probability distributions (Park et al., 2021), and information criteria (Suresh et al., 2021).
EERM, a powerful invariance-based approach by Wu et al. (2022), takes a graph-editing approach to learn GNNs that are robust to distribution shifts in the data. Structural graph rewiring instead seeks to enforce structural priors such as sparsity or homophily during the graph alteration phase. Examples of these rewiring priors include fairness (Kose & Shen, 2022; Spinelli et al., 2022), temporal structure (Wang et al., 2021), predicted homophily (Chen et al., 2020), sparsity (Jin et al., 2020; Zheng et al., 2020), and information transfer efficiency (Klicpera et al., 2019). Geometric approaches instead view the graph as a discrete geometry and alter the connectivity according to the balanced Forman curvature (Topping et al., 2022), stochastic discrete Ricci flows (Bober et al., 2022), commute times (Arnaiz-Rodríguez et al., 2022), or algebraic connectivity (Arnaiz-Rodríguez et al., 2022). All of these approaches are designed for use either in self-supervised learning or in a task-agnostic fashion, and consider the input graph independently of the task at hand. To summarize, CAR introduces a combination of advances not previously reported: applicability to diverse attention architectures; task-based supervised regularization (rather than task-agnostic or self-supervised regularization) that leads to improved generalization; and a causal approach that scalably and directly relates an edge's importance to its attention coefficient, enhancing interpretability.

4.1. EXPERIMENTAL SETUP

We assessed the effectiveness of CAR by comparing the performance of a diverse range of models trained with and without CAR on 8 node classification datasets. Specifically, we aimed to assess the consistency of CAR's outperformance over matching baseline models across various graph attention mechanism and hyperparameter choices. Accordingly, we evaluated numerous combinations of such configurations (48 settings for each dataset and graph attention mechanism), rather than testing only a limited set of optimized hyperparameter configurations. The configurable model design and hyperparameter choices that we evaluated include the graph attention mechanism (GAT, GATv2, or Graph Transformer), the number of graph attention layers L = {1, 2}, the number of attention heads K = {1, 3, 5}, the number of hidden dimensions F ′ = {10, 25, 100, 200}, and the regularization strength λ = {0.1, 0.5, 1, 5}. See Appendix A.4 for details on the network architecture, hyperparameters, and training configurations.

4.2. NODE CLASSIFICATION

Datasets and evaluation: We used a total of 8 real-world node classification datasets of varying sizes and degrees of homophily: Cora, CiteSeer, PubMed, ogbn-arxiv, Chameleon, Squirrel, Cornell, and Wisconsin. Each model was evaluated according to its accuracy on a held-out test set. We also evaluated the test cross-entropy loss, as it accounts for the full distribution of per-class predictions rather than just the highest-valued prediction considered in accuracy calculations. See Appendix A.5 for details on dataset statistics, splits, and references.

Generalization performance: We compared both the test accuracy and the test loss of each model when trained with and without CAR. Across all model architecture and hyperparameter choices for a dataset, we applied the one-tailed paired Wilcoxon signed-rank test to quantify the overall outperformance of models trained with CAR against models trained without it. CAR resulted in higher test accuracy in 7 of the 8 node classification datasets and a lower test loss in all 8 datasets (Figure 2, p < 0.05). We report the relative performance when averaging over all hyperparameter choices in Table 1 (test loss), Table 7 (test accuracy, in Appendix A.6), and Appendix A.7. We also observed that even small values of R were effective (Appendix A.8). We believe that this is due to there being an adequate number of causal-effect examples for regularization: even in the minimum R = 1 case, we have roughly one causal-effect example per training example. We also compared CAR with Sui et al. (2021)'s CAL, a method for graph property prediction that relies on an alternative formulation of causal attention with interventions on implicit representations. We adapted CAL for node classification by removing its final pooling layer. CAR-trained models substantially outperformed CAL (Appendix A.9), suggesting that CAR's direct edge-intervention approach results in better generalization.
Taken together, these results highlight the consistency of performance gains achieved with CAR and its broad applicability across graph attention architectures and hyperparameters.
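The paired comparison underlying these results can be reproduced with an off-the-shelf test. The sketch below uses SciPy's `wilcoxon` with `alternative="less"` to ask whether CAR-trained models achieve systematically lower test loss than their matched baselines; the loss values are illustrative numbers of our own, not results from the paper.

```python
from scipy.stats import wilcoxon

# Paired test losses for the same 10 (architecture, hyperparameter)
# configurations, trained without and with CAR (illustrative values).
base_loss = [0.92, 0.85, 1.10, 0.77, 0.95, 1.02, 0.88, 0.81, 0.99, 0.91]
car_loss = [0.91, 0.83, 1.07, 0.73, 0.90, 0.96, 0.81, 0.73, 0.90, 0.81]

# One-tailed paired Wilcoxon signed-rank test.
# H1: the CAR losses are systematically smaller than the baseline losses.
stat, p = wilcoxon(car_loss, base_loss, alternative="less")
```

Because every paired difference here favors CAR, the one-sided p-value is small; in the paper's evaluation this test is applied per dataset across all 48 configurations.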

4.3. MODEL INVESTIGATION AND ABLATION STUDIES

Impact of regularization strength: We explored 4 settings of λ: {0.1, 0.5, 1, 5}. For 6 of the 8 node classification datasets, CAR models trained with the higher causal regularization strengths (λ ∈ {1, 5}) demonstrated significantly larger reductions in test loss (p < 0.05, one-tailed Welch's t-test) compared to those trained with weaker regularization (λ ∈ {0.1, 0.5}). Notably, all four of the datasets with lower homophily (Chameleon, Squirrel, Cornell, and Wisconsin) displayed significantly larger reductions in test loss with the higher regularization strengths, suggesting that stronger regularization may contribute to improved generalization in such settings (Appendix A.10).

Connection to spurious correlations and homophily: Graph attention networks that are prone to spurious correlations mistakenly attend to parts of the graph that are irrelevant to their prediction task. To evaluate whether CAR reduces the impact of such spurious correlations, we assessed whether models trained with CAR more effectively prioritized relevant edges. For node classification tasks, the relevant neighbors of a given node are expected to be those that share the same label as that node. We therefore used the label agreement between nodes connected by an edge as a proxy for the edge's ground-truth relevance. We assigned a reference attention coefficient $e_{ij}$ to an edge based on label agreement in the target node's neighborhood: $e_{ij} = \hat{e}_{ij} / \sum_{k \in \mathcal{N}_j} \hat{e}_{kj}$ (Kim & Oh, 2021). Here, $\hat{e}_{ij} = 1$ if nodes $i$ and $j$ have the same label and $\hat{e}_{ij} = 0$ otherwise, and $\mathcal{N}_j$ denotes the in-neighbors of node $j$. We then calculated the KL divergence of an edge's attention coefficient $\alpha_{ij}$ from its reference attention coefficient $e_{ij}$, and we summarize a model's ability to identify relevant edges as the mean of these KL divergence values across the edges in the held-out test set.
We compared these mean KL divergences between baseline models trained without CAR and models trained with CAR across the same broad range of model architecture and hyperparameter choices described above. We found that CAR-trained models consistently yielded lower mean KL divergence values than models trained without CAR for 6 of 8 node classification datasets (Figure 3, p < 0.05, one-tailed paired Wilcoxon signed-rank test). Notably, this enhanced prioritization of relevant edges was achieved without explicitly optimizing label agreement during training; it is an inherent manifestation of aligning attention with the node classification task's causal effects. Low-homophily graphs are associated with greater proportions of task-irrelevant edges and thus may introduce more spurious correlations (Zheng et al., 2022). We reasoned that CAR's relative effectiveness should be greater in such settings. We assessed this by evaluating CAR on 33 synthetic Cora datasets with varying levels of homophily (Zhu et al., 2020). We observed that CAR-trained models outperformed baseline models most substantially in low-homophily settings, with performance gains generally increasing as homophily decreases (Appendix A.11). Altogether, these results demonstrate that CAR more accurately prioritizes the edges that are most relevant to the desired task, and they highlight its utility in the low-homophily settings most prone to spurious correlations.
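The label-agreement metric above can be sketched concretely. The function names and the direction of the divergence, KL(e || α) per target node, are our reading of the description; the paper's exact implementation may differ in detail.

```python
import math

def reference_attention(labels, in_nbrs, j):
    """Reference coefficients e_ij: uniform over same-label in-neighbors of j."""
    e_hat = [1.0 if labels[i] == labels[j] else 0.0 for i in in_nbrs]
    total = sum(e_hat)
    return [e / total for e in e_hat] if total else e_hat

def kl_divergence(e, alpha, eps=1e-12):
    """KL(e || alpha) over one node's in-edges; eps guards against log(0)."""
    return sum(p * math.log(p / (q + eps)) for p, q in zip(e, alpha) if p > 0)

# Node 2 has one same-label in-neighbor (0) and one different-label one (1).
labels = {0: "A", 1: "B", 2: "A"}
in_nbrs_of_2 = [0, 1]
e = reference_attention(labels, in_nbrs_of_2, 2)   # mass on the same-label edge

kl_good = kl_divergence(e, [0.9, 0.1])   # attention mostly on the relevant edge
kl_bad = kl_divergence(e, [0.1, 0.9])    # attention mostly on the spurious edge
```

A model that concentrates attention on same-label edges gets a low divergence, so averaging these values over the test set gives the edge-prioritization score compared between CAR-trained and baseline models.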

Comparison to homophily-based regularization

We next assessed whether the performance gains from our causal approach could be replicated by a non-causal approach that systematically aligns attention coefficients with a generic measure of homophily. To do so, we performed an ablation study in which we replaced the causal effects $c^{(n)}_{ij}$ computed from the network being trained (Equation 3) with an alternative score derived from a homophily-based classification scheme. Briefly, this scheme entails assigning a prediction to each node based on the counts of its neighbors' labels (see Appendix A.12 for details). By regularizing attention with respect to this homophily-based classification scheme, the attention mechanism of a network is guided not by causal effects associated with the network but rather by this separate measure of homophily. We compared CAR-trained models with those trained with homophily-based regularization by evaluating the consistency of their test loss improvements relative to the baseline models without regularization. We used the one-tailed paired Wilcoxon signed-rank test to evaluate the significance of the regularized models' test loss improvements across the set of model architecture and hyperparameter choices, focusing on models trained with the higher λ ∈ {1, 5} regularization strengths. Models trained when regularizing attention coefficients with the homophily-based scheme underperformed those trained with CAR in 7 of the 8 datasets (Table 2). Interestingly, the homophily-based regularization showed an overall gain in performance relative to the baseline models (i.e., those trained without any regularization), suggesting that even non-specific regularization can be somewhat useful for training attention. Overall, these results demonstrate the effectiveness of using CAR to improve generalization performance for many node classification tasks.

Additional applications: To explore the utility of CAR as a general-purpose framework, we employed CAR for graph pruning.
CAR directly uses local pruning (i.e., edge interventions) to guide the training of graph attention in a manner that down-weights task-irrelevant edges. As such, we reasoned that attention coefficients produced by CAR-trained models could be used to prune task-irrelevant edges (see Appendix A.13 for more details). In this approach, we used CAR's edge pruning procedure as a pre-processing step for training and inference with GCNs, which are more scalable than graph attention networks (Rossi et al., 2020) but indiscriminate in how they aggregate information over node neighborhoods. We found that CAR-guided pruning improved the test accuracy of GCNs, outperforming vanilla GCNs trained and evaluated on the full graph as well as GCNs trained and evaluated on graphs pruned with baseline graph attention mechanisms. These preliminary results open the door for further exploration of CAR's utility on such tasks.
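A minimal sketch of this pre-processing step follows. The keep-top-fraction rule is an assumption on our part for illustration; the paper's exact pruning criterion is described in its Appendix A.13.

```python
def prune_edges(edges, attention, keep_fraction=0.8):
    """Keep the most-attended edges from a CAR-trained model.

    edges: list of directed edges (i, j).
    attention: dict mapping (i, j) -> learned coefficient alpha_ij.
    keep_fraction: fraction of edges to retain for the downstream GCN.
    """
    ranked = sorted(edges, key=lambda e: attention[e], reverse=True)
    n_keep = max(1, int(len(ranked) * keep_fraction))
    return ranked[:n_keep]

# Four in-edges of node 2, with attention from a (hypothetical) CAR model.
edges = [(0, 2), (1, 2), (3, 2), (4, 2)]
attention = {(0, 2): 0.5, (1, 2): 0.3, (3, 2): 0.15, (4, 2): 0.05}
kept = prune_edges(edges, attention, keep_fraction=0.5)
```

The sparsified edge list `kept` would then be fed to a GCN for training and inference, trading the attention network's per-edge weighting for a cheaper hard selection.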

4.4. INTERPRETING ATTENTION COEFFICIENTS: QUALITATIVE ANALYSIS

In addition to providing a robust way to increase both model quality and generalization, we explored the interpretability of CAR attention coefficients in a post hoc analysis. Here, we evaluated the edge-wise difference of the attention coefficients between our method and a baseline GAT applied to the ogbn-arxiv dataset. In this dataset, nodes represent arXiv papers, edges are citation links, and the prediction task is to classify papers into their subject areas. We manually reviewed citations that were up/down-weighted by CAR-trained models and observed that these citations broadly fit into one of three categories: (i) down-weighting self-citations, (ii) down-weighting popular "anchor" papers, or (iii) up-weighting topically narrow papers with few citations. In the first case, we found that CAR down-weights edges associated with self-citations (Table 3). The causal story here is clear: machine learning is a fast-moving field, with authors moving into the field and changing specialties as those specialties are born. Because of this, the narrative arc that a set of authors constructs to present an idea can include citations to their own previous work from different sub-fields. While these citations help situate the work within the broader literature and can provide background that readers might find valuable, they are not relevant to the subject area prediction task. In the second case, we found that CAR down-weights edges directed towards popular or otherwise seminal "anchor" papers (Appendix A.14, Table 11). These papers tend to be included in introductions to provide references for common concepts or methods, such as Adam, ResNet, and ImageNet. They are also widely cited across subject areas and hence have little bearing on the subject area prediction task. Notably, CAR does not simply learn to ignore edges from high-degree nodes.
For the Word2Vec paper, we observed notable increases in attention coefficients for edges connecting it to multiple highly related papers, including a 2.5 × 10^6% increase for Efficient Graph Computation for Node2Vec and a 2.0 × 10^6% increase for Multi-Dimensional Explanation of Reviews. In the final case, we observed that CAR up-weighted edges directed towards deeply related but largely unnoticed papers (Appendix A.14, Table 12). In our manual exploration of the data, we observed that these papers are those closest in topic to the citing work, the kind often found only after a thorough literature review. Such edges should play a key role in predicting a paper's topic and should be up-weighted.

5. CONCLUSION

We introduced CAR, an invariance principle-based causal regularization scheme that can be applied to graph attention architectures. Unlike other invariance-based approaches (Wu et al., 2022), our focus is on scalably improving overall generalization rather than handling distribution shifts. To that end, we introduce an efficient scheme to intervene directly on multiple edges in parallel. Applying CAR to both homophilic and heterophilic node classification tasks, we found accuracy improvements and loss reductions in almost all settings. Our ablation studies showed that CAR aligns attention with task-specific homophily, and does so more effectively than a homophily-based regularizer. A qualitative review also suggested that the attention-weight changes produced by CAR are intuitive and interpretable. Understanding and improving what GNNs learn remains a major open problem and an active area of research. For instance, Zheng et al. (2022) have discussed the challenges that GNNs face when handling low-homophily graphs or when different tasks could be specified on the same underlying graph (e.g., predicting citation year vs. topic in ogbn-arxiv). Towards this, our method provides a principled and scalable approach to align attention coefficients with the relevant task. Our work bridges two families of techniques: attention regularization and causal interventions. The synthesis of these techniques is not only a promising direction for enhancing the performance and interpretability of graph attention but also opens the door to applying similar techniques to general GNNs without attention. Lastly, while our graph pruning results are preliminary, they also suggest a promising direction for future work on scaling CAR-based insights to web-scale graphs.

6. REPRODUCIBILITY STATEMENT

To ensure the reproducibility of the results in this paper, we have included the source code for our method as supplementary materials. The datasets used in this paper are all publicly available, and we also use the publicly available train/validation/test splits for these datasets. We provide details on these datasets in the Appendix and have provided references to them in both the main text and the Appendix. In addition, we have provided detailed descriptions of the experimental setup, model training schemes, model architecture design choices, and hyperparameter choices in the "Experimental Setup" section as well as in Appendix A.4.

A APPENDIX

A.1 GRAPH ATTENTION ARCHITECTURE VARIANTS

Table 4: Attention coefficient calculation across graph attention architecture variants

GAT (Velickovic et al., 2018): e_ij = LeakyReLU(a^T [W h_i || W h_j])
GATv2 (Brody et al., 2022): e_ij = a^T LeakyReLU(W h_i || W h_j)
Graph Transformer (Shi et al., 2021): e_ij = (W h_i)^T (W h_j) / √F′

A.2 CAR ALGORITHM

Algorithm 1: CAR Framework
Input: training set D_train, validation set D_val, model M, regularization strength λ
repeat
    for each mini-batch B_k = {n_j^(k)}_{j=1..b_k} do
        Prediction loss: L_p ← (1 / |B_k|) Σ_{n ∈ B_k} ℓ_p(ŷ^(n), y^(n))
        procedure EDGE INTERVENTION
            Causal attention loss: L_c ← 0
            for round r ← 1 to R do
                Set of edge interventions S^(r) ← {}
                for each entity n_j^(k) in B_k do
                    Sample edge (i, j) ~ E_{n_j^(k)}            ▷ E_{n_j^(k)} is the set of edges related to entity n_j^(k)
                    if (i, j) is independent of S^(r) then       ▷ See "Edge intervention procedure"
                        S^(r) ← S^(r) ∪ {(n_j^(k), i, j)}        ▷ Add edge to the set of edge interventions
                        Compute causal effect c_ij^(n) ← σ(ρ_ij^(n) (d^(n))^(-1))   ▷ Equation 5
                    end if
                end for
                L_c ← L_c + (1/R) (1/|S^(r)|) Σ_{(n,i,j) ∈ S^(r)} ℓ_c(α_ij^(n), c_ij^(n))   ▷ Equation 3
            end for
        end procedure
        Total loss: L ← L_p + λ L_c
        Update model parameters to minimize L
    end for
until convergence    ▷ We use convergence of the validation prediction loss.

The GNN model used for node classification tasks takes as input the original node features x_i ∈ R^d and applies a non-linear projection to these features to yield a set of hidden features h_i = LeakyReLU(W_1 x_i + b_1), where W_1 ∈ R^{F×d} and b_1 ∈ R^F. These hidden features are then passed through L graph attentional layers of the same chosen architecture, yielding new hidden features per node of the same dimensionality, h′_i ∈ R^F. The pre- and post-graph-attention hidden features are then concatenated

A.3 RUNTIME STATISTICS

as [h_i || h′_i], after which a final linear layer and softmax transformation σ_softmax(·) are applied to produce the prediction output ŷ_i = σ_softmax(W_2 LeakyReLU([h_i || h′_i]) + b_2). Here, ŷ_i ∈ R^C, W_2 ∈ R^{2F×C}, and b_2 ∈ R^C, where C is the number of classes in the classification task. Models were implemented in PyTorch and PyTorch Geometric (Fey & Lenssen, 2019). Self-loops were not included in the graph attention layers; otherwise, default PyTorch Geometric parameter settings were used for the graph attention layers.

A.4.2 TRAINING DETAILS

We used cross-entropy loss for the prediction loss ℓ p (•, •) and binary cross-entropy loss for the causal regularization loss ℓ c (•, •). The link function σ(•) was chosen to be the sigmoid function with temperature T = 0.1. Unless otherwise specified, we performed R = 5 rounds of edge interventions per mini-batch when training with CAR. All models were trained using the Adam optimizer with a learning rate of 0.01 and mini-batch size of 10,000. Each dataset was partitioned into training, validation, and test splits in line with previous work (Appendix A.5), and early stopping was applied during training with respect to the validation loss. Training was performed on a single NVIDIA Tesla T4 GPU.
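As a concrete illustration, the causal regularization term can be sketched in plain Python. The sketch below uses the sigmoid link with temperature T = 0.1 and a binary cross-entropy loss, as described above; the scalar `rho` passed in is a simplified stand-in for the causal-effect score of an edge intervention (Equation 5), so this illustrates the shape of the loss rather than the exact implementation.

```python
import math

def sigmoid(x, T=0.1):
    """Sigmoid link function with temperature T (smaller T -> sharper)."""
    return 1.0 / (1.0 + math.exp(-x / T))

def bce(p, q, eps=1e-12):
    """Binary cross-entropy between prediction p and target q, with clamping."""
    p = min(max(p, eps), 1.0 - eps)
    return -(q * math.log(p) + (1.0 - q) * math.log(1.0 - p))

def car_loss(interventions, lam=1.0):
    """Average causal-attention loss over one round of edge interventions.

    `interventions` is a list of (alpha_ij, rho_ij) pairs: alpha_ij is the
    attention coefficient on edge (i, j), and rho_ij is a hypothetical
    stand-in for the effect of deleting that edge on the prediction loss.
    """
    total = 0.0
    for alpha, rho in interventions:
        c = sigmoid(rho)          # map the causal effect into (0, 1)
        total += bce(alpha, c)    # align attention with the causal effect
    return lam * total / len(interventions)
```

Attention that agrees with the causal effects (high attention on edges whose removal hurts the prediction) yields a lower regularization loss than attention that disagrees with them.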

A.5 DATASETS

We provide overviews of the various node classification datasets along with accompanying statistics in Table 6. For all datasets, we use the publicly available train/validation/test splits that accompany these datasets.

Planetoid: The Cora, CiteSeer, and PubMed datasets are citation networks from Yang et al. (2016). Nodes represent documents and directed edges represent citation links. Nodes are featurized as bag-of-words representations of their respective documents. The prediction task is to classify a given paper into its respective subject area.

ogbn-arxiv: The ogbn-arxiv dataset is a citation network of computer science arXiv papers indexed by MAG (Hu et al., 2020). Nodes represent papers, and a node's features are the mean embeddings of the words in its paper's title and abstract. Edges are directed and represent a citation of one paper by another. The prediction task is to predict the subject area of a given arXiv paper.

Wikipedia: The Chameleon and Squirrel datasets are Wikipedia networks from Rozemberczki et al. (2021), in which nodes represent web pages and edges represent hyperlinks between them. Nodes are featurized as bag-of-words representations of important nouns in their respective Wikipedia pages. The average monthly traffic of each web page is converted into a category, and the prediction task is to assign a given page to its corresponding category.

WebKB: The Cornell and Wisconsin datasets are networks of web pages from computer science departments, in which nodes represent web pages and edges are hyperlinks between them. Node features are bag-of-words representations of the respective web pages, and the prediction task is to assign a given web page to the category that describes its content.
A.9 COMPARISON OF CAR-TRAINED MODELS WITH CAL MODELS

CAL is an approach for identifying causally attended subgraphs for graph prediction tasks that leverages causal interventions on graph representations to achieve robustness of model predictions (Zhao et al., 2021). While CAL and CAR share the goal of enhancing graph attention using concepts from causal theory, CAL performs causal interventions via abstract perturbations of graph representations, whereas we propose an edge intervention strategy that enables causal effects to be computed scalably. In addition, CAL is designed to identify causally attended subgraphs for graph property prediction tasks, while our work primarily focuses on node classification tasks. Furthermore, CAL uses interventions to achieve robustness and does not directly leverage the effects of interventions on model predictions during training. Despite these differences, we sought to determine whether the causal principles underlying CAL could be effectively applied to the various node classification tasks evaluated in our paper. We modified the CAL architecture to make it suitable for node prediction tasks by removing the final pooling layer, directly upstream of the classifier, that aggregates node representations within each graph, thus enabling node-level prediction. We evaluated the CausalGAT model from CAL using all combinations of the following hyperparameter choices: F′ = {128, 256}, K = {1, 2, 4}, L = {1, 2}, λ1 = {0.2, 0.4, 0.6, 0.8, 1}, and λ2 = {0.2, 0.4, 0.6, 0.8, 1}, where F′ is the number of hidden dimensions, K is the number of attention heads, L is the number of GNN layers, and λ1 and λ2 are CAL-specific hyperparameters. For each dataset, we report the maximum test accuracy observed for the CAL CausalGAT model across all combinations of these hyperparameter choices.
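The exhaustive sweep over hyperparameter combinations, reporting the maximum test accuracy, amounts to a simple grid search. A minimal sketch, in which `evaluate` is a hypothetical stand-in for training and testing one model configuration:

```python
from itertools import product

def best_test_accuracy(evaluate, grid):
    """Sweep every combination in a hyperparameter grid and return the
    best test accuracy together with the configuration that achieved it.

    `grid` maps hyperparameter names to lists of candidate values;
    `evaluate(cfg)` returns the test accuracy for one configuration.
    """
    best_acc, best_cfg = float("-inf"), None
    for values in product(*grid.values()):
        cfg = dict(zip(grid.keys(), values))
        acc = evaluate(cfg)
        if acc > best_acc:
            best_acc, best_cfg = acc, cfg
    return best_acc, best_cfg
```

In the paper's setting, the grid would contain the F′, K, L, λ1, and λ2 choices listed above, and `evaluate` would train a CausalGAT model and return its test accuracy.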
We compare these test accuracies from the CAL CausalGAT models with the test accuracies from CAR-trained models averaged over all hyperparameter choices, which also appear above in Appendix A.6.

To assess the relationship between the effectiveness of CAR and the homophily of a dataset, we obtained a set of synthetic Cora datasets from Zhu et al. (2020). These synthetic datasets are modified versions of the original Cora dataset that feature varying levels of edge homophily h, defined as the fraction of edges in a graph that connect nodes with the same class label:

h = |{(u, v) ∈ E : y_u = y_v}| / |E|,

where E is the set of edges and y_u and y_v are the class labels of nodes u and v. We evaluated 33 synthetic Cora datasets spanning 11 different settings of h, each represented by 3 replicate datasets. For each of these datasets, we performed a similar analysis as above, evaluating the consistency of improvements in test loss using CAR across a number of graph attention and hyperparameter choices. We evaluated the GAT, GATv2, and Transformer graph attention layers along with all combinations of the following hyperparameter choices: F′ = {100}, λ = {1, 5}, K = {1, 3, 5}, L = {1, 2}. We then performed a one-tailed paired Wilcoxon rank-sum test to quantify the consistency of CAR-trained models' improvement in test loss over baseline models trained without CAR.

ij values from the original implementation of CAR. Otherwise, the training procedure for a network trained with this neighbor voting scheme is exactly the same as for training with CAR. We note that, for a given node, calculating the intervention-affected prediction ŷ^(n)_\ij simply entails updating the normalized counts of the class labels from the node's remaining neighbors after the intervention.
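The edge homophily h defined above is straightforward to compute from an edge list and a node-to-label mapping; a minimal sketch:

```python
def edge_homophily(edges, labels):
    """Fraction of edges connecting nodes that share a class label.

    `edges` is an iterable of (u, v) pairs and `labels` maps each node
    to its class label -- the quantity h defined in the text.
    """
    edges = list(edges)
    same = sum(1 for u, v in edges if labels[u] == labels[v])
    return same / len(edges)
```

For the synthetic Cora datasets, applying this to each replicate recovers the h setting used to generate it.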
A.13 CAR-GUIDED GRAPH REWIRING

While graph attention networks have demonstrated notable performance gains, their graph attention layers currently limit their use in large-scale applications compared to GCNs, for which a number of advances in scalability have been made (Rossi et al., 2020). To combine the advantages of CAR in graph attention with the scalability of GCNs, we explored a graph rewiring approach based on CAR-guided edge pruning. For a given dataset, we first use a trained graph attention network to assign an attention weight to each edge in the training and validation sets, after which edges with attention weights below a threshold α_T are pruned. A GCN is then trained on the rewired training set, with early stopping imposed with respect to the validation loss on the rewired validation set. The trained GCN is then evaluated on a similarly rewired test set. We use a similar network architecture for the GCN as for the various graph attention networks described in Appendix A.4, with the graph attention layers replaced by graph convolutional layers. We set the number of hidden dimensions in the GCN models to F′ = 100. For the Chameleon dataset, we identified the hyperparameter settings that yielded the highest validation accuracy for each of the one-layer GAT, GATv2, and Transformer CAR-trained models. We then trained GCN models on graphs pruned based on each of these models' attention mechanisms. We also pruned graphs using the counterparts of these models trained without CAR and trained another set of GCN models on these pruned graphs. We evaluated the test accuracy of the GCN models across various attention thresholds (Figure 7). We observed that training and evaluating GCN models on pruned graphs enhanced test accuracy compared to baseline GCN models trained and evaluated on the original graph.
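The pruning step itself reduces to a threshold filter over edges. A minimal sketch, where `attention` is a hypothetical mapping from each edge to the weight assigned by a trained graph attention network:

```python
def prune_by_attention(edges, attention, alpha_t):
    """Return the rewired edge list: keep only edges whose learned
    attention weight is at least the threshold alpha_t."""
    return [e for e in edges if attention[e] >= alpha_t]
```

The rewired training, validation, and test graphs are each produced by applying this filter before the GCN is trained or evaluated.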
Furthermore, we compared GCN models trained and evaluated on CAR-guided pruned graphs against similar GCN models trained and evaluated on graphs pruned without CAR by computing the area under the curve (AUC) associated with the test accuracies at various attention thresholds. Each AUC was calculated as the area below its models' test accuracies line and above the baseline GCN models' test accuracy. CAR-guided graph pruning was associated with higher AUC values across the three graph attention mechanisms, demonstrating the potential for CAR's utility in graph pruning tasks.
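The AUC comparison can be computed with a trapezoidal rule over the accuracy-versus-threshold curve. The sketch below clips the curve at the baseline, which is an assumption about how segments falling below the baseline accuracy are handled:

```python
def auc_above_baseline(thresholds, accuracies, baseline):
    """Trapezoidal area between a test-accuracy curve (accuracy as a
    function of the attention threshold) and a flat baseline accuracy.
    Segments below the baseline contribute zero (an assumption)."""
    area = 0.0
    for k in range(1, len(thresholds)):
        dx = thresholds[k] - thresholds[k - 1]
        y0 = max(accuracies[k - 1] - baseline, 0.0)
        y1 = max(accuracies[k] - baseline, 0.0)
        area += 0.5 * (y0 + y1) * dx  # trapezoid over one threshold interval
    return area
```

A larger value indicates that the pruned-graph GCN outperforms the baseline GCN over a wider range of attention thresholds.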



CAR increases the computational cost of training in two ways: (i) additional evaluations of the loss function due to the causal effect calculations, and (ii) searches to ensure independent interventions. CAR performs O(RN) interventions, where R is the number of interventions per entity and N is the number of entities in the training set. Because our edge intervention procedure ensures that the sampled interventions per round are independent, the causal effects of these interventions can be computed in parallel. In addition, if edge interventions were sampled uniformly across the graph, ensuring the independence of each intervention would require an L-layer-deep BFS with time complexity O(b^L), resulting in a worst-case time complexity of O(RNb^L), where L is the number of graph attention layers and b is the branching factor (maximum node degree) of the graph.
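The independence requirement can be met greedily: within each round, a sampled edge is accepted only if its receptive field is disjoint from those of the interventions already chosen. In the sketch below, `affected_nodes` is a hypothetical stand-in for the L-layer BFS that returns the set of nodes an intervention on a given edge could influence:

```python
import random

def sample_independent_interventions(candidate_edges, affected_nodes,
                                     rounds=5, seed=0):
    """Greedily build, per round, a set of edge interventions whose
    receptive fields do not overlap, so their causal effects can be
    computed in parallel.

    `affected_nodes(edge)` returns the set of nodes whose predictions
    an intervention on `edge` could affect.
    """
    rng = random.Random(seed)
    all_rounds = []
    for _ in range(rounds):
        touched, chosen = set(), []
        for edge in rng.sample(candidate_edges, len(candidate_edges)):
            region = affected_nodes(edge)
            if touched.isdisjoint(region):  # independence check
                touched |= region
                chosen.append(edge)
        all_rounds.append(chosen)
    return all_rounds
```

Each round's accepted interventions are mutually independent by construction, which is what allows their causal effects to be evaluated in one parallel pass.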

Figure 2: Test accuracy and negative loss on 8 node classification datasets. Each point corresponds to a comparison between a baseline model trained without CAR and an identical model trained with CAR. The point size represents the magnitude of the λ value chosen for the CAR-trained model. p-values are computed from one-tailed paired Wilcoxon signed-rank tests evaluating the improvement of CAR-trained models over the baseline models.

Figure 3: Coherence between attention coefficients and label agreement for CAR-trained models, compared to baseline. Lower KL divergence implies greater coherence. Point colors and sizes have the same meaning as in Figure 2. p-values are computed from one-tailed paired Wilcoxon signed-rank tests evaluating the improvement of CAR-trained models over the baseline models.

Figure 4: Percent change in test accuracy for models trained with CAR. Each boxplot represents all combinations of the three graph attention layers (GAT, GATv2, Transformer) and the following sets of hyperparameter choices: λ ∈ {1, 5}, L = {1, 2}, K = {3}, F ′ = {100, 200}.

Figure 5: Test loss reduction for CAR-trained models across regularization strengths. p-values are computed from one-tailed t-tests evaluating the significance of the test loss reductions for the λ ∈ {1, 5} CAR-trained models being greater than those of the λ ∈ {0.1, 0.5} CAR-trained models.

Figure 6: Generalization performance of CAR-trained models compared to baseline across various levels of edge homophily (−log10(p), one-tailed paired Wilcoxon rank-sum test).

Table: Test loss on 8 node classification datasets
Transformer: 3.43 ± 1.61, 9.78 ± 8.27, 1.79 ± 0.88, 1.50 ± 0.03, 1.38 ± 0.06, 1.50 ± 0.02, 1.22 ± 0.25, 1.03 ± 0.28
Transf. + CAR: 1.71 ± 0.50, 3.92 ± 2.25, 1.60 ± 0.59, 1.50 ± 0.03, 1.33 ± 0.04, 1.48 ± 0.03, 1.07 ± 0.12, 0.86 ± 0.13

Table captions (table bodies not preserved in this extract):
Ablation
Table 3: Down-weighted self-citations
Training times for node classification datasets
Table 6: Node classification dataset statistics
Test accuracy on 8 node classification datasets
Average percent change in test accuracy
Average percent change in test loss

Table: Test accuracy on 8 node classification datasets compared to CAL
(row label truncated): ± 0.12, 0.49 ± 0.06, 0.71 ± 0.01, 0.56 ± 0.01, 0.50 ± 0.01, 0.37 ± 0.01, 0.67 ± 0.06, 0.78 ± 0.03
GATv2 + CAR: 0.61 ± 0.12, 0.50 ± 0.06, 0.71 ± 0.01, 0.56 ± 0.01, 0.51 ± 0.01, 0.37 ± 0.01, 0.70 ± 0.03, 0.79 ± 0.02
Transformer + CAR: 0.65 ± 0.04, 0.47 ± 0.09, 0.72 ± 0.01, 0.56 ± 0.01, 0.50 ± 0.01, 0.37 ± 0.01, 0.71 ± 0.02, 0.79 ± 0.01

