CROSS-NODE FEDERATED GRAPH NEURAL NETWORK FOR SPATIO-TEMPORAL DATA MODELING

Anonymous authors

Paper under double-blind review

Abstract

The vast amount of data generated by networks of sensors, wearables, and Internet of Things (IoT) devices underscores the need for modeling techniques that leverage the spatio-temporal structure of data that must remain decentralized, owing to the need for edge computation and to licensing (data access) issues. While federated learning (FL) has emerged as a framework for training models without direct data sharing and exchange, effectively modeling the complex spatio-temporal dependencies to improve forecasting remains an open problem. On the other hand, state-of-the-art spatio-temporal forecasting models assume unfettered access to the data and neglect constraints on data sharing. To bridge this gap, we propose a federated spatio-temporal model, Cross-Node Federated Graph Neural Network (CNFGNN), which explicitly encodes the underlying graph structure using a graph neural network (GNN)-based architecture under the constraint of cross-node federated learning, which requires that the data generated locally on each node of a network remains decentralized. CNFGNN disentangles the modeling of temporal dynamics, performed on the devices, from the modeling of spatial dynamics, performed on the server, and uses alternating optimization to reduce the communication cost and facilitate computation on the edge devices. Experiments on the traffic flow forecasting task show that CNFGNN achieves the best forecasting performance in both transductive and inductive learning settings with no extra computation cost on edge devices, while incurring a modest communication cost.

1. INTRODUCTION

Modeling the dynamics of spatio-temporal data generated from networks of edge devices or nodes (e.g., sensors, wearable devices, and Internet of Things (IoT) devices) is critical for various applications, including traffic flow prediction (Li et al., 2018; Yu et al., 2018), forecasting (Seo et al., 2019; Azencot et al., 2020), and user activity detection (Yan et al., 2018; Liu et al., 2020). Existing works on spatio-temporal dynamics modeling (Battaglia et al., 2016; Kipf et al., 2018; Battaglia et al., 2018) assume that the model is trained on centralized data gathered from all devices. However, the volume of data generated at these edge devices precludes such centralized processing and calls for decentralized processing, where computation on the edge can significantly reduce latency. In addition, for spatio-temporal forecasting, the edge devices need to leverage the complex inter-dependencies among them to improve prediction performance. Moreover, with increasing concerns about data privacy and the access restrictions imposed by existing licensing agreements, spatio-temporal models must be able to use decentralized data while still leveraging the underlying relationships for improved performance. Although recent works in federated learning (FL) (Kairouz et al., 2019) provide a solution for training a model with decentralized data on multiple devices, these works either do not consider the inherent spatio-temporal dependencies (McMahan et al., 2017; Li et al., 2020b; Karimireddy et al., 2020) or only model them implicitly by imposing the graph structure as a regularization on model weights (Smith et al., 2017). The latter approach inherits the limitations of regularization-based methods: it assumes that graphs only encode the similarity of nodes (Kipf & Welling, 2017), and it cannot operate in settings where only a fraction of the devices are observed during training (the inductive learning setting).
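To make the contrast with explicit graph modeling concrete, the following minimal sketch shows how a graph can enter federated training only through a regularizer on per-node model weights. It assumes a generic Laplacian-smoothness penalty, λ Σ_{i<j} a_ij ||w_i − w_j||², which is a common form of graph regularization and not necessarily the exact objective of Smith et al. (2017); the functions `squared_distance`, `graph_regularizer`, and `regularized_objective` are hypothetical names chosen for illustration. The penalty only pulls the weights of connected nodes toward each other, i.e., it treats edges purely as similarity, and it requires a trained weight vector w_i for every node, so a node unseen during training has no weights to regularize.

```python
from typing import List, Sequence

# Hypothetical illustration: a generic Laplacian-smoothness regularizer,
# not the exact objective of any specific federated multi-task method.
Vector = List[float]


def squared_distance(u: Sequence[float], v: Sequence[float]) -> float:
    """Squared Euclidean distance between two weight vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def graph_regularizer(weights: List[Vector], adj: List[List[float]], lam: float) -> float:
    """Penalty lam * sum_{i<j} a_ij * ||w_i - w_j||^2 over an undirected graph.

    Edges act only as similarity constraints: connected nodes are pulled
    toward having similar local model weights.
    """
    n = len(weights)
    total = 0.0
    for i in range(n):
        for j in range(i + 1, n):  # count each undirected edge once
            total += adj[i][j] * squared_distance(weights[i], weights[j])
    return lam * total


def regularized_objective(local_losses: List[float], weights: List[Vector],
                          adj: List[List[float]], lam: float) -> float:
    """Sum of per-node training losses plus the graph penalty (hypothetical)."""
    return sum(local_losses) + graph_regularizer(weights, adj, lam)


if __name__ == "__main__":
    # Three nodes on a path graph 0 - 1 - 2 with 2-dimensional local weights.
    adj = [[0.0, 1.0, 0.0],
           [1.0, 0.0, 1.0],
           [0.0, 1.0, 0.0]]
    weights = [[1.0, 0.0], [0.0, 0.0], [0.0, 2.0]]
    print(graph_regularizer(weights, adj, lam=0.5))  # 0.5 * (1 + 4) = 2.5
```

In this toy example the edge (0, 1) contributes 1 and the edge (1, 2) contributes 4; nodes 0 and 2 are not adjacent, so their difference is never penalized, regardless of how their data are actually related.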
As a result, there is a need for a spatio-temporal modeling architecture that enables reliable computation on the edge while keeping the data decentralized. To this end, building on recent work in federated learning (Kairouz et al., 2019), we introduce the cross-node federated learning requirement, which ensures that data generated locally at a node remains decentralized. Our architecture, Cross-Node Federated Graph Neural Network (CNFGNN), aims to effectively model the complex spatio-temporal dependencies under the cross-node federated learning constraint. To do so, CNFGNN decomposes the modeling of temporal and spatial dependencies: an encoder-decoder model on each device extracts temporal features from local data, and a Graph Neural Network (GNN)-based model on the server captures spatial dependencies among devices. Compared to existing federated learning techniques that rely on regularization to incorporate spatial relationships, CNFGNN leverages an explicit graph structure through a GNN-based architecture, which leads to performance gains. However, the federated learning (data sharing) constraint means that the GNN cannot be trained in a centralized manner, since each node can only access the data stored on itself. To address this, CNFGNN employs Split Learning (Singh et al., 2019) to train the spatial and temporal modules. Further, to alleviate the high communication cost incurred by Split Learning, we propose an alternating optimization-based training procedure for these modules, which incurs only half the communication overhead of a comparable Split Learning architecture. We also use Federated Averaging (FedAvg) (McMahan et al., 2017) to train a temporal feature extractor shared by all nodes, which leads to improved empirical performance. Our main contributions are as follows:

1. We propose Cross-Node Federated Graph Neural Network (CNFGNN), a GNN-based federated learning architecture that captures complex spatio-temporal relationships among multiple nodes while ensuring that locally generated data remains decentralized, at no extra computation cost on the edge devices.
2. Our modeling and training procedure enables GNN-based architectures to be used in federated learning settings. We achieve this by disentangling the modeling of local temporal dynamics on edge devices from that of spatial dynamics on the central server, and by leveraging an alternating optimization-based procedure that updates the spatial and temporal modules using Split Learning and Federated Averaging to enable effective GNN-based federated learning.
3. We demonstrate on a traffic flow prediction task that, compared to related techniques, CNFGNN achieves the best prediction performance in both transductive and inductive settings, at no extra computation cost on edge devices and with a modest communication cost.
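The decomposition described in this section can be illustrated with a toy, pure-Python sketch of one forward pass. Every component is a hypothetical stand-in chosen for brevity, not the CNFGNN architecture itself: the on-device encoder is an exponential moving average rather than a learned encoder-decoder, the server-side GNN is a single mean-aggregation layer over each node's neighbors, and the decoder is a fixed linear combination. The sketch only shows the data flow under the cross-node constraint: raw sequences never leave a node, only hidden states are sent to the server, and spatial embeddings are sent back. The last function, `fedavg`, is a FedAvg-style aggregation (McMahan et al., 2017) that averages per-node parameter vectors weighted by each node's number of local samples; all function names are assumptions for illustration.

```python
from typing import Dict, List

# Hypothetical toy components illustrating the device/server split;
# they are NOT the CNFGNN architecture itself.


def node_encode(series: List[float], alpha: float) -> float:
    """On-device temporal encoder (stand-in): exponential moving average.

    Runs locally; only the resulting hidden state is sent to the server.
    """
    h = series[0]
    for x in series[1:]:
        h = alpha * x + (1.0 - alpha) * h
    return h


def server_gnn(hidden: List[float], neighbors: Dict[int, List[int]]) -> List[float]:
    """Server-side spatial module (stand-in): one mean-aggregation GNN layer.

    Each node's spatial embedding is the mean of its own hidden state
    and the hidden states of its graph neighbors.
    """
    out = []
    for i, h_i in enumerate(hidden):
        group = [h_i] + [hidden[j] for j in neighbors.get(i, [])]
        out.append(sum(group) / len(group))
    return out


def node_decode(h_local: float, h_spatial: float, w_local: float, w_spatial: float) -> float:
    """On-device decoder (stand-in): combine local and spatial features."""
    return w_local * h_local + w_spatial * h_spatial


def fedavg(params: List[List[float]], num_samples: List[int]) -> List[float]:
    """FedAvg aggregation: sample-size-weighted mean of client parameter vectors."""
    total = sum(num_samples)
    dim = len(params[0])
    return [sum(p[k] * n for p, n in zip(params, num_samples)) / total for k in range(dim)]


if __name__ == "__main__":
    # Three sensors on a path graph 0 - 1 - 2; each keeps its raw series locally.
    local_series = [[1.0, 1.0, 1.0], [2.0, 4.0], [3.0, 3.0]]
    neighbors = {0: [1], 1: [0, 2], 2: [1]}
    hidden = [node_encode(s, alpha=0.5) for s in local_series]  # uploaded to server
    spatial = server_gnn(hidden, neighbors)                      # returned to nodes
    preds = [node_decode(h, s, 0.5, 0.5) for h, s in zip(hidden, spatial)]
    print("hidden:", hidden)
    print("spatial:", spatial)
    print("predictions:", preds)
    # Shared encoder parameters from two nodes holding 1 and 3 samples.
    print("fedavg:", fedavg([[1.0, 0.0], [3.0, 2.0]], [1, 3]))
```

Keeping the server module separate from the per-node encoders is what lets the spatial model use the full graph without ever receiving raw data; in a trainable version, gradients with respect to the spatial embeddings would flow back to each node, which is where Split Learning comes in.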

2. RELATED WORK

Our method draws on graph neural networks, federated learning, and privacy-preserving graph learning. We now discuss related works in these areas and their relation to our work.

Graph Neural Networks (GNNs). GNNs have shown superior performance on various learning tasks with graph-structured data, including graph embedding (Hamilton et al., 2017), node classification (Kipf & Welling, 2017), spatio-temporal data modeling (Yan et al., 2018; Li et al., 2018; Yu et al., 2018), and multi-agent trajectory prediction (Battaglia et al., 2016; Kipf et al., 2018; Li et al., 2020a). Recent GNN models (Hamilton et al., 2017; Ying et al., 2018; You et al., 2019; Huang et al., 2018) also employ sampling strategies that allow them to scale to large graphs. While GNNs benefit from a strong inductive bias (Battaglia et al., 2018; Xu et al., 2019), most works require centralized data during training and inference.

Federated Learning (FL). Federated learning is a machine learning setting in which multiple clients collaboratively train a model on decentralized training data (Kairouz et al., 2019). It requires that the raw data of each client be stored locally, without any exchange or transfer. However, decentralized training comes at the cost of less effective data utilization, due to the heterogeneous distributions of data across clients and the lack of information exchange among them. Various optimization algorithms have been developed for federated learning on non-IID and unbalanced data (McMahan et al., 2017; Li et al., 2020b; Karimireddy et al., 2020). Smith et al. (2017) propose a multi-task learning framework that captures relationships amongst data. While the above works mitigate the caveat of missing

