AUTOMATED DATA AUGMENTATIONS FOR GRAPH CLASSIFICATION

Abstract

Data augmentations are effective in improving the invariance of learning machines. We argue that the core challenge of data augmentations lies in designing data transformations that preserve labels. This is relatively straightforward for images, but much more challenging for graphs. In this work, we propose GraphAug, a novel automated data augmentation method that aims to compute label-invariant augmentations for graph classification. Instead of using uniform transformations as in existing studies, GraphAug uses an automated augmentation model to avoid compromising critical label-related information of the graph, thereby producing label-invariant augmentations in most cases. To ensure label-invariance, we develop a training method based on reinforcement learning to maximize an estimated label-invariance probability. Experiments show that GraphAug outperforms previous graph augmentation methods on various graph classification tasks.

1. INTRODUCTION

Many real-world objects, such as molecules and social networks, can be naturally represented as graphs. Developing effective classification models for such graph-structured data is highly desirable but challenging. Recently, advances in deep learning have significantly accelerated progress in this direction. Graph neural networks (GNNs) (Gilmer et al., 2017), a class of deep neural network models specifically designed for graphs, have been widely applied to many graph representation learning and classification tasks, such as molecular property prediction (Wang et al., 2022b; Liu et al., 2022; Wang et al., 2022a; 2023; Yan et al., 2022). However, just like deep models on images, GNN models can easily overfit and fail to achieve satisfactory performance on small datasets.

To address this issue, data augmentations can be used to generate more data samples. An important property of desirable data augmentations is label-invariance, which requires that label-related information not be compromised during the augmentation process. This is relatively easy and straightforward to achieve for images (Taylor & Nitschke, 2018), since commonly used image augmentations, such as flipping and rotation, preserve almost all information of the original images. However, ensuring label-invariance is much harder for graphs, because even a minor modification of a graph may change its semantics and thus its label. Currently, the most commonly used graph augmentations (You et al., 2020) are based on random modification of nodes and edges in the graph, but they do not explicitly consider the importance of label-invariance.

In this work, we propose GraphAug, a novel graph augmentation method that can produce label-invariant augmentations with an automated learning model. GraphAug uses a learnable model to automate augmentation category selection and graph transformations. It optimizes the model to maximize an estimated label-invariance probability through reinforcement learning.
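To make the baseline concrete, the uniform random modifications described above (in the spirit of You et al. (2020), though not their exact implementation) can be sketched as follows. The graph representation (node list plus edge list) and the `ratio` parameter are illustrative choices, not taken from the paper:

```python
import random

def drop_nodes(nodes, edges, ratio=0.1, seed=None):
    """Uniformly drop a fraction of nodes, removing their incident edges."""
    rng = random.Random(seed)
    n_drop = max(1, int(len(nodes) * ratio))
    dropped = set(rng.sample(sorted(nodes), n_drop))
    kept_nodes = [v for v in nodes if v not in dropped]
    kept_edges = [(u, v) for (u, v) in edges
                  if u not in dropped and v not in dropped]
    return kept_nodes, kept_edges

def perturb_edges(nodes, edges, ratio=0.1, seed=None):
    """Uniformly remove a fraction of edges and add the same number of
    random new edges, keeping the edge count unchanged."""
    rng = random.Random(seed)
    edges = list(edges)
    n_change = max(1, int(len(edges) * ratio))
    existing = set(rng.sample(edges, len(edges) - n_change))
    while len(existing) < len(edges):
        u, v = rng.sample(sorted(nodes), 2)
        if (u, v) not in existing and (v, u) not in existing:
            existing.add((u, v))
    return list(nodes), sorted(existing)
```

Because these transformations are applied uniformly at random, they may delete a node or edge that carries label-critical information; GraphAug replaces this blind sampling with a learned, label-invariance-aware policy.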
Experimental results show that GraphAug outperforms prior graph augmentation methods on multiple graph classification tasks. The codes of GraphAug are available in DIG (Liu et al., 2021) library.

2. BACKGROUND AND RELATED WORK

2.1. GRAPH CLASSIFICATION WITH NEURAL NETWORKS

In this work, we study the problem of graph classification. Let G = (V, E, X) be an undirected graph, where V is the set of nodes and E is the set of edges. The node feature matrix of the graph G is X ∈ R^{|V|×d}, where the i-th row of X denotes the d-dimensional feature vector of the i-th node in G. For a graph classification task with k categories, the objective is to learn a classification model f : G → y ∈ {1, ..., k} that can predict the categorical label of G.

Recently, GNNs (Kipf & Welling, 2017; Veličković et al., 2018; Xu et al., 2019; Gilmer et al., 2017; Gao & Ji, 2019) have shown great success in various graph classification problems. Most GNNs use the message passing mechanism to learn graph node embeddings. Formally, the message passing for any node v ∈ V at the ℓ-th layer of a GNN model can be described as

    h_v^ℓ = UPDATE(h_v^{ℓ-1}, AGG({m_{jv}^ℓ : j ∈ N(v)})),    (1)

where N(v) denotes the set of all nodes connected to the node v in the graph G, h_v^ℓ is the embedding outputted from the ℓ-th layer for v, and m_{jv}^ℓ is the message propagated from the node j to the node v at the ℓ-th layer, usually a function of h_v^{ℓ-1} and h_j^{ℓ-1}. The aggregation function AGG(•) maps the messages from all neighboring nodes to a single vector, and the function UPDATE(•) updates h_v^{ℓ-1} to h_v^ℓ using this aggregated message vector. Assuming that the GNN model has L layers, the graph representation h_G is computed by a global pooling function READOUT over all node embeddings as

    h_G = READOUT({h_v^L : v ∈ V}).    (2)

Afterwards, h_G is fed into a multi-layer perceptron (MLP) model to compute the probability that G belongs to each of the categories {1, ..., k}.

Despite the success of GNNs, a major challenge in many graph classification problems is data scarcity. For example, GNNs have been extensively used to predict molecular properties from graph structures of molecules.
However, manually labeling molecules usually requires expensive wet lab experiments, so the amount of labeled molecule data is often not large enough for expressive GNNs to achieve satisfactory prediction accuracy. In this work, we address this data scarcity challenge with data augmentations. We focus on designing advanced graph augmentation strategies that generate additional data samples by transforming existing samples in the dataset.
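The message passing and readout scheme described above can be sketched with dense adjacency matrices. This is a minimal illustration, not a real GNN: MESSAGE is the identity, AGG is a neighbor sum via the adjacency matrix, UPDATE is a linear map with ReLU, and READOUT is mean pooling; all learnable parameters are replaced by fixed placeholders:

```python
import numpy as np

def gnn_layer(h, adj):
    """One simplified message-passing layer.

    h   : (n, d) node embedding matrix from the previous layer
    adj : (n, n) binary adjacency matrix of the undirected graph

    Each node sums its neighbors' embeddings (AGG), then updates its own
    embedding with a linear map and ReLU (UPDATE). W is a stand-in for a
    learned weight matrix.
    """
    n, d = h.shape
    W = np.eye(d)                         # placeholder for learned weights
    m = adj @ h                           # summed neighbor messages per node
    return np.maximum(0.0, (h + m) @ W)   # UPDATE(h, aggregated messages)

def readout(h):
    """Global mean pooling over node embeddings to get h_G."""
    return h.mean(axis=0)
```

Stacking L such layers and feeding `readout(h)` into an MLP classifier yields the standard GNN graph-classification pipeline; real models (e.g., GIN or GCN) differ only in their learned MESSAGE, AGG, and UPDATE functions.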

2.2. DATA AUGMENTATIONS

Data augmentations have been demonstrated to be effective in improving the performance of image and text classification. For images, various transformation or distortion techniques have been proposed to generate artificial samples, such as flipping, cropping, color shifting (Krizhevsky et al., 2012), scaling, rotation, and elastic distortion (Sato et al., 2015; Simard et al., 2003). For text, useful augmentation techniques include synonym replacement, positional swaps (Ratner et al., 2017a), and back translation (Sennrich et al., 2016). These techniques have been widely used to reduce overfitting and improve robustness when training deep neural network models.

In addition to hand-crafted augmentations, automating the selection of augmentations with learnable neural network models has been an emerging research area. Ratner et al. (2017b) select and compose multiple image data augmentations using an LSTM (Hochreiter & Schmidhuber, 1997) model, and propose to keep the model from producing out-of-distribution samples through adversarial training. Cubuk et al. (2019) propose AutoAugment, which adopts a reinforcement learning based method to search for optimal augmentations that maximize classification accuracy. To speed up training and reduce computational cost, many methods have been proposed to improve AutoAugment through either faster search mechanisms (Ho et al., 2019; Lim et al., 2019) or advanced optimization methods (Hataya et al., 2020; Li et al., 2020; Zhang et al., 2020).
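The idea of selecting and composing augmentations can be illustrated with a toy pool of label-preserving image transforms. This is a hedged sketch of policy-free random composition only; the transform pool and the `random_augment` helper are invented for illustration, and learned approaches such as AutoAugment would replace the uniform sampling with a trained policy that scores each choice:

```python
import random

# Toy pool of simple transforms on 2D images represented as lists of
# pixel rows; placeholders for real augmentations like crop or rotate.
def flip_horizontal(img):
    return [row[::-1] for row in img]

def flip_vertical(img):
    return img[::-1]

def transpose(img):
    return [list(col) for col in zip(*img)]

AUGMENTATIONS = [flip_horizontal, flip_vertical, transpose]

def random_augment(img, n_ops=2, seed=None):
    """Sample a short sequence of transforms and apply them in order."""
    rng = random.Random(seed)
    for op in rng.choices(AUGMENTATIONS, k=n_ops):
        img = op(img)
    return img
```

For images, every transform in such a pool preserves the label almost by construction; the next subsection explains why no comparably safe pool exists for graphs.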

2.3. DATA AUGMENTATIONS FOR GRAPHS

While image augmentations have been extensively studied, designing augmentations for graphs is much more challenging. Images are Euclidean data formed by pixel values organized in matrices. Thus,

