MBRAIN: A MULTI-CHANNEL SELF-SUPERVISED LEARNING FRAMEWORK FOR BRAIN SIGNALS

Abstract

Brain signals are important quantitative data for understanding the physiological activities and diseases of the human brain. Meanwhile, rapidly developing deep learning methods offer a wide range of opportunities for better modeling brain signals, which has attracted considerable research effort recently. Most existing studies focus on supervised learning methods, which, however, require costly clinical labels. In addition, the huge difference in the clinical patterns of brain signals measured by invasive (e.g., SEEG) and non-invasive (e.g., EEG) methods leads to the lack of a unified method. To handle the above issues, in this paper, we propose a self-supervised learning (SSL) framework for brain signals that can be applied to pre-train either SEEG or EEG data. Intuitively, brain signals, generated by the firing of neurons, are transmitted among different connecting structures in the human brain. Inspired by this, we propose to learn the implicit spatial and temporal correlations between different channels (i.e., contacts of the electrode, corresponding to different brain areas) as the cornerstone for uniformly modeling different types of brain signals. Specifically, we capture the temporal correlation by designing a delayed-time-shift prediction task; we represent the spatial correlation by a graph structure, which is built with the proposed multi-channel CPC, whose goal is to maximize the mutual information between each channel and its correlated ones. We further theoretically prove that our design leads to a better predictive representation and, based on this result, propose the instantaneous-time-shift prediction task. Finally, a replace-discriminative-learning task is designed to preserve the characteristics of each channel. Extensive seizure detection experiments on large-scale real-world EEG and SEEG datasets demonstrate that our model outperforms several state-of-the-art time series SSL and unsupervised models.

1. INTRODUCTION

Brain signals are foundational quantitative data for the study of the human brain in the field of neuroscience. The patterns of brain signals can greatly help us understand the normal physiological function of the brain and the mechanisms of related diseases. Brain signals have many applications, such as cognitive research (Ismail & Karwowski, 2020; Kuanar et al., 2018), emotion recognition (Song et al., 2020; Chen et al., 2019), neurological disorders (Alturki et al., 2020; Yuan et al., 2019), and so on. Brain signals can be measured by noninvasive or invasive methods (Paluszek et al., 2015). Noninvasive methods, like electroencephalography (EEG), cannot simultaneously provide high temporal and spatial resolution along with deep-brain information, but they are easier to implement without any surgery. Invasive methods like stereoelectroencephalography (SEEG) require extra surgery to insert the recording devices, but have access to more precise data with a higher signal-to-noise ratio. For both EEG and SEEG data, there are multiple electrodes with contacts (also called channels) that are sampled at a fixed frequency to record brain signals. Recently, discoveries in the field of neuroscience have inspired advances in deep learning techniques, which in turn promote neuroscience research. According to the literature, most deep learning-based studies of brain signals focus on supervised learning (Shoeibi et al., 2021; Rasheed et al., 2020; Zhang et al., 2021; Craik et al., 2019), which relies on a large number of clinical labels. However, obtaining accurate and reliable clinical labels is costly. In the meantime, the emergence of self-supervised learning (SSL) and its great success (Chen & He, 2021; Grill et al., 2020; He et al., 2020; Brown et al., 2020; Devlin et al., 2018; Raffel et al., 2020; Van den Oord et al., 2018) has made it a predominant learning paradigm in the absence of labels.
Therefore, some recent studies have introduced SSL methods to extract representations of brain signal data. For example, Banville et al. (2021) directly applies general SSL tasks to pre-train EEG data, including relative position prediction (Doersch et al., 2015), temporal shuffling (Misra et al., 2016), and contrastive predictive coding (Van den Oord et al., 2018). Mohsenvand et al. (2020) designs data augmentation methods and extends the self-supervised model SimCLR (Chen et al., 2020) from computer vision to EEG data. In contrast to the numerous works investigating EEG, few studies focus on SEEG data. Martini et al. (2021) proposes a self-supervised learning model for real-time epilepsy monitoring in multimodal scenarios with SEEG data and video recordings. Despite these advances in representation learning for brain signals, two main issues remain to be overcome. Firstly, almost all existing methods are designed for a particular type of brain signal data, and a unified method for handling both EEG and SEEG data is lacking. The challenge mainly lies in the different clinical patterns of brain signals that need to be measured in different ways. On the one hand, EEG collects noisy and coarse brain signals on the scalp; in contrast, SEEG collects deeper signals with more stereoscopic spatial information, which reveals more significant differences between brain areas (Perucca et al., 2014). On the other hand, unlike EEG, which has gold-standard collection locations, the monitored areas of SEEG vary greatly between patients, leading to different numbers and positions of channels. Therefore, finding the commonalities of EEG and SEEG data in order to design a unified framework is challenging. The other issue is the gap between existing methods and real-world applications. In clinical scenarios, doctors typically locate brain lesions by analyzing the signal patterns of each channel and their holistic correlations.
A straightforward way to achieve this goal is to model each channel separately with single-channel time series models, which, however, cannot exploit correlations between brain areas (Davis et al., 2020; Lynn & Bassett, 2019). As for existing multivariate time series models, most of them can only capture implicit correlation patterns (Zerveas et al., 2021; Chen & Shi, 2021), whereas doctors require explicit correlations for identifying lesions. Moreover, although some graph-based methods have been proposed to explicitly learn correlations, they focus on giving an overall prediction for all channels at a time and overlook prediction on a specific channel (Zhang et al., 2022; Shang et al., 2021). Therefore, how to explicitly capture the spatial and temporal correlations while giving channel-wise predictions is another issue to be overcome. In this paper, our main contribution is a multi-channel self-supervised learning framework, MBrain, which can be generally applied to learn representations of both EEG and SEEG data. In addition, we pay special attention to its application to seizure detection. Based on domain knowledge and data observations, we propose to learn the correlation graph between channels as the common cornerstone for both types of brain signals. In particular, we employ Contrastive Predictive Coding (CPC) (Van den Oord et al., 2018) as the backbone model of our framework by extending it to handle multi-channel data with theoretically guaranteed effectiveness. Based on the multi-channel CPC, we propose the instantaneous time shift task to learn the spatial correlations between channels, while the delayed time shift task and the replace discriminative task are designed to capture temporal correlation patterns and to preserve the characteristics of each channel, respectively.
Extensive experiments show that MBrain outperforms several state-of-the-art baselines on large-scale real-world EEG and SEEG datasets for the seizure detection task.

2. PRELIMINARY: THEORETICAL ANALYSIS OF MULTI-CHANNEL CPC

We employ Contrastive Predictive Coding (CPC) (Van den Oord et al., 2018) as the basis of our framework. The pretext task of CPC is to predict low-level local representations from the high-level global contextual representation c_t at the t-th time step. Theoretically, the optimal InfoNCE loss proposed by CPC with N-1 negative samples, L_N^opt, is bounded below in terms of the mutual information between the contextual semantic distribution p(c_t) and the raw data distribution p(x_{t+k}), i.e., L_N^opt ≥ -I(x_{t+k}; c_t) + log N, where k is the prediction step size. CPC was originally designed for single-channel sequence data only, and there are two natural ways to extend single-channel CPC to a multi-channel version. The first is to use CNNs with multiple kernels to encode all channels simultaneously, which cannot offer explicit correlation patterns for doctors to identify lesions. The second is to train a shared CPC that regards all channels as one, which has no ability to capture correlation patterns. Taking these limitations into consideration, we propose multi-channel CPC in this paper. Our motivation is to explicitly aggregate the semantic information of multiple channels
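To make the InfoNCE objective concrete, the following NumPy sketch computes the loss for one prediction step over a batch of N samples, where each sample's future representation is the positive for its own context and the remaining N-1 rows serve as negatives. The function name info_nce_loss and the step-specific linear predictor W_k are illustrative assumptions following the original CPC paper, not the implementation used in this work:

```python
import numpy as np

def info_nce_loss(c_t, z_future, W_k):
    """InfoNCE loss for one prediction step k (illustrative sketch).

    c_t:      (N, D_c) contextual representations, one per sample
    z_future: (N, D_z) encoded future representations z_{t+k}; for each
              row, the other N-1 rows act as negative samples
    W_k:      (D_c, D_z) step-specific linear predictor (an assumed
              parameterization, following the original CPC paper)
    """
    scores = (c_t @ W_k) @ z_future.T              # (N, N) similarity scores
    scores -= scores.max(axis=1, keepdims=True)    # numerical stability
    log_softmax = scores - np.log(np.exp(scores).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_softmax))          # positives on the diagonal
```

The bound above can be sanity-checked against this sketch: if the context carries no information about the future (e.g., W_k = 0, so all scores are equal), the loss is exactly log N, matching I(x_{t+k}; c_t) = 0; any informative predictor pushes the loss below log N.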

