MBRAIN: A MULTI-CHANNEL SELF-SUPERVISED LEARNING FRAMEWORK FOR BRAIN SIGNALS

Abstract

Brain signals are important quantitative data for understanding physiological activities and diseases of the human brain. Meanwhile, rapidly developing deep learning methods offer a wide range of opportunities for better modeling of brain signals, which has attracted considerable research effort recently. Most existing studies focus on supervised learning methods, which, however, require high-cost clinical labels. In addition, the huge difference in the clinical patterns of brain signals measured by invasive (e.g., SEEG) and non-invasive (e.g., EEG) methods leads to the lack of a unified method. To handle the above issues, in this paper we propose to study a self-supervised learning (SSL) framework for brain signals that can be applied to pre-train either SEEG or EEG data. Intuitively, brain signals, generated by the firing of neurons, are transmitted among different connecting structures in the human brain. Inspired by this, we propose to learn implicit spatial and temporal correlations between different channels (i.e., contacts of the electrode, corresponding to different brain areas) as the cornerstone for uniformly modeling different types of brain signals. Specifically, we capture the temporal correlation by designing the delayed-time-shift prediction task; we represent the spatial correlation by a graph structure, which is built with the proposed multi-channel CPC, whose goal is to maximize the mutual information between each channel and its correlated ones. We further theoretically prove that our design leads to a better predictive representation and propose the instantaneous-time-shift prediction task based on it. Finally, the replace-discriminative-learning task is designed to preserve the characteristics of each channel. Extensive experiments of seizure detection on both EEG and SEEG large-scale real-world datasets demonstrate that our model outperforms several state-of-the-art time series SSL and unsupervised models.

1. INTRODUCTION

Brain signals are foundational quantitative data for the study of the human brain in the field of neuroscience. The patterns of brain signals can greatly help us understand the normal physiological function of the brain and the mechanisms of related diseases. There are many applications of brain signals, such as cognitive research (Ismail & Karwowski, 2020; Kuanar et al., 2018), emotion recognition (Song et al., 2020; Chen et al., 2019), neurological disorders (Alturki et al., 2020; Yuan et al., 2019) and so on. Brain signals can be measured by noninvasive or invasive methods (Paluszek et al., 2015). Noninvasive methods, like electroencephalography (EEG), cannot simultaneously offer high temporal and spatial resolution along with deep-brain information, but they are easier to implement without any surgery. Invasive methods like stereoelectroencephalography (SEEG) require extra surgeries to insert the recording devices, but have access to more precise data with a higher signal-to-noise ratio. For both EEG and SEEG data, there are multiple electrodes with contacts (also called channels) that are sampled at a fixed frequency to record brain signals. Recently, discoveries in the field of neuroscience have inspired advances in deep learning techniques, which in turn promote neuroscience research. According to the literature, most deep learning-based studies of brain signals focus on supervised learning (Shoeibi et al., 2021; Rasheed et al., 2020; Zhang et al., 2021; Craik et al., 2019), which relies on a large number of clinical labels. However, obtaining accurate and reliable clinical labels comes at a high cost. In the meantime, the emergence of self-supervised learning (SSL) and its great success (Chen & He, 2021; Grill et al., 2020; He et al., 2020; Brown et al., 2020; Devlin et al., 2018; Raffel et al., 2020; Van den Oord et al., 2018) have made it a predominant learning paradigm in the absence of labels.
Therefore, some recent studies have introduced SSL to extract representations of brain signal data. For example, Banville et al. (2021) directly applies general SSL tasks to pre-train EEG data, including relative position prediction (Doersch et al., 2015), temporal shuffling (Misra et al., 2016) and contrastive predictive coding (Van den Oord et al., 2018). Mohsenvand et al. (2020) designs data augmentation methods and extends the self-supervised model SimCLR (Chen et al., 2020) from computer vision to EEG data. In contrast to the numerous works investigating EEG, few studies focus on SEEG data. Martini et al. (2021) proposes a self-supervised learning model for real-time epilepsy monitoring in multimodal scenarios with SEEG data and video recordings. Despite these advances in representation learning of brain signals, two main issues remain to be overcome. Firstly, almost all existing methods are designed for a particular type of brain signal data, and there is a lack of a unified method for handling both EEG and SEEG data. The challenge mainly lies in the different clinical patterns of brain signals measured in different ways. On the one hand, EEG collects noisy and rough brain signals on the scalp; in contrast, SEEG collects deeper signals with more stereo spatial information, which indicates more significant differences between brain areas (Perucca et al., 2014). On the other hand, unlike EEG with a gold-standard collection location, the monitoring areas of SEEG vary greatly between patients, leading to different numbers and positions of channels. Therefore, how to find the commonalities of EEG and SEEG data in order to design a unified framework is challenging. Another issue is the gap between existing methods and real-world applications. In clinical scenarios, doctors typically locate brain lesions by analyzing the signal patterns of each channel and their holistic correlations.
A straightforward way to achieve this goal is to model each channel separately with single-channel time series models, which, however, cannot exploit correlations between brain areas (Davis et al., 2020; Lynn & Bassett, 2019). As for existing multivariate time series models, most can only capture implicit correlation patterns (Zerveas et al., 2021; Chen & Shi, 2021), whereas explicit correlations are required by doctors to identify lesions. Moreover, although some graph-based methods have been proposed to explicitly learn correlations, they focus on giving an overall prediction for all channels at a time but overlook the prediction on one specific channel (Zhang et al., 2022; Shang et al., 2021). Therefore, how to explicitly capture the spatial and temporal correlations while giving channel-wise predictions is another issue to be overcome. In this paper, our main contribution is to propose a multi-channel self-supervised learning framework, MBrain, which can be generally applied to learning representations of both EEG and SEEG data. In addition, we pay special attention to its application in seizure detection. Based on domain knowledge and data observations, we propose to learn the correlation graph between channels as the common cornerstone for both types of brain signals. In particular, we employ Contrastive Predictive Coding (CPC) (Van den Oord et al., 2018) as the backbone model of our framework by extending it to handle multi-channel data with theoretically guaranteed effectiveness. Based on the multi-channel CPC, we propose the instantaneous time shift task to learn the spatial correlations between channels, while the delayed time shift task and the replace discriminative task are designed to capture temporal correlation patterns and to preserve the characteristics of each channel, respectively.
Extensive experiments show that MBrain outperforms several state-of-the-art baselines on large-scale real-world EEG and SEEG datasets for the seizure detection task.

2. PRELIMINARY: THEORETICAL ANALYSIS OF MULTI-CHANNEL CPC

We employ Contrastive Predictive Coding (CPC) (Van den Oord et al., 2018) as the basis of our framework. The pretext task of CPC is to predict low-level local representations from the high-level global contextual representation $c_t$ at the $t$-th time step. Theoretically, the optimal InfoNCE loss proposed by CPC with $N-1$ negative samples, $\mathcal{L}_N^{opt}$, is a lower bound related to the mutual information between the contextual semantic distribution $p(c_t)$ and the raw data distribution $p(x_{t+k})$, i.e., $\mathcal{L}_N^{opt} \geq -I(x_{t+k}; c_t) + \log N$, where $k$ is the prediction step size. CPC is originally designed for single-channel sequence data only, and there are two natural ways to extend single-channel CPC to a multi-channel version. The first is to use CNNs with multiple kernels to encode all channels simultaneously, which cannot offer explicit correlation patterns for doctors to identify lesions. The second is to train a shared CPC that regards all channels as one, which has no ability to capture correlation patterns. Taking these considerations together, we propose multi-channel CPC in this paper. Our motivation is to explicitly aggregate the semantic information of multiple channels to predict the local representations of one channel. Formally, we propose the following proposition as our basic starting point.

Proposition 1. Introducing the contextual information of the correlated channels increases the amount of mutual information with the raw data of the target channel:
$$I(x^i_{t+k}; \Phi(c_t)) = I(x^i_{t+k}; c^i_t, \Phi(\{c^j_t\}_{j \neq i})) \geq I(x^i_{t+k}; c^i_t),$$
where $i$ and $j$ are indexes of the channels, and $\Phi(\cdot)$ represents some kind of aggregation function, which has no formal constraints other than the need to retain information of the target channel.

Proof. We use the chain rule of mutual information to obtain:
$$I(x^i_{t+k}; c^i_t, \Phi(\{c^j_t\}_{j \neq i})) = I(x^i_{t+k}; c^i_t) + I(x^i_{t+k}; \Phi(\{c^j_t\}_{j \neq i}) \mid c^i_t).$$
According to the non-negativity of conditional mutual information, we complete the proof.

It seems natural that the predictive ability of multiple channels is stronger than that of a single channel, which is also consistent with the assumption of Granger causality (Granger, 1969) to some extent. Therefore, we choose to approximate the more informative $I(x^i_{t+k}; \Phi(c_t))$ to obtain more expressive representations. Specifically, following InfoNCE, we define our loss function $\mathcal{L}_N$ as
$$\mathcal{L}_N = -\sum_i \mathbb{E}_{X^i}\left[\log \frac{f_k(x^i_{t+k}, \Phi(c_t))}{\sum_{x_j \in X^i} f_k(x_j, \Phi(c_t))}\right],$$
where $X^i$ denotes the data sample set consisting of one positive sample and $N-1$ negative samples of the $i$-th channel. We then establish the relationship between $\mathcal{L}_N$ and $\sum_i I(x^i_{t+k}; \Phi(c_t))$.

Theorem 1. Given a sample set for each channel $X^i = \{x^i_1, \ldots, x^i_N\}$, $i = 1, \ldots, n$, consisting of one positive sample from $p(x^i_{t+k} \mid \Phi(c_t))$ and $N-1$ negative samples from $\frac{1}{n}\sum_j p(x^j_{t+k})$, where $n$ is the number of channels, the optimal $\mathcal{L}_N^{opt}$ is a lower bound related to $\sum_i I(x^i_{t+k}; \Phi(c_t))$:
$$\mathcal{L}_N^{opt} \geq \sum_i \left[-I(x^i_{t+k}; \Phi(c_t)) + \log N\right].$$

Proof. The optimal $f_k(x_{t+k}, \Phi(c_t))$ is proportional to $p(x^i_{t+k} \mid \Phi(c_t)) / (\frac{1}{n}\sum_j p(x^j_{t+k}))$, the same as in single-channel CPC. We can directly replace the data distributions in the proof of single-channel CPC (see details in Appendix B) to obtain the inequality below:
$$\mathcal{L}_N^{opt} \geq \sum_i \left(\mathbb{E}_{X^i}\left[\log \frac{\frac{1}{n}\sum_j p(x^j_{t+k})}{p(x^i_{t+k} \mid \Phi(c_t))}\right] + \log N\right) = \mathbb{E}_{X^1, X^2, \ldots, X^n}\left[\log \frac{\left[\frac{1}{n}\sum_j p(x^j_{t+k})\right]^n}{\prod_j p(x^j_{t+k} \mid \Phi(c_t))}\right] + n \log N. \quad (4)$$

According to Jensen's inequality, we have $\frac{1}{n}\sum_j \log p(x^j_{t+k}) \leq \log\left(\frac{1}{n}\sum_j p(x^j_{t+k})\right)$. Exponentiating both sides, we obtain
$$\prod_j p(x^j_{t+k}) \leq \left[\frac{1}{n}\sum_j p(x^j_{t+k})\right]^n. \quad (5)$$

With the help of equation 5, we can further lower-bound equation 4:
$$\mathcal{L}_N^{opt} \geq \mathbb{E}_{X^1, X^2, \ldots, X^n}\left[\log \frac{\prod_j p(x^j_{t+k})}{\prod_j p(x^j_{t+k} \mid \Phi(c_t))}\right] + n \log N = \sum_i \left[-I(x^i_{t+k}; \Phi(c_t)) + \log N\right]. \quad (6)$$

This completes the proof.
We next analyze the advantages of multi-channel CPC over single-channel CPC. First, our loss function $\mathcal{L}_N$ leads to a better predictive representation: assuming the optimal loss for each channel has the same $\log N$ gap as in single-channel CPC, we approximate the more informative objective $I(x^i_{t+k}; \Phi(c_t))$. Second, under the same GPU memory budget, more channels mean a smaller batch size can be accommodated; however, we can randomly sample negative samples across all channels, which increases the diversity of negative samples. To narrow the approximation gap, equation 5 should be considered: the equality in this inequality holds if and only if the samples from each channel follow the same distribution. In fact, for many large-scale multi-channel time series datasets (e.g., the brain signal data used in this paper), after normalizing each channel, all channels exhibit close-to-normal distributions, leading to small gaps in equation 5.
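To make the objective above concrete, the following is a minimal numpy sketch of the multi-channel InfoNCE loss for a single channel and one prediction step, with the negatives pooled across all channels as discussed. The function name, shapes, and bilinear score are illustrative assumptions, not the paper's implementation:

```python
import numpy as np

def logsumexp(scores):
    """Numerically stable log-sum-exp over a 1D array of scores."""
    m = scores.max()
    return m + np.log(np.exp(scores - m).sum())

def multichannel_infonce(c_agg, z_pos, z_negs, W):
    """InfoNCE loss for one channel and one prediction step.

    c_agg  : (d,)     aggregated context Phi(c_t) for the target channel
    z_pos  : (d,)     the positive future sample z_{t+k} of that channel
    z_negs : (N-1, d) negative samples, drawn across *all* channels
    W      : (d, d)   bilinear score matrix for step k
    """
    # bilinear scores; index 0 is the positive sample
    scores = np.concatenate(([z_pos @ W @ c_agg], z_negs @ W @ c_agg))
    # negative log softmax probability assigned to the positive sample
    return logsumexp(scores) - scores[0]
```

The loss is small when the positive sample scores far above the pooled negatives, and it is always non-negative since the log-sum-exp dominates any single score.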

3. PROPOSED METHOD

In this section, we introduce the details of our proposed self-supervised learning framework MBrain. Regarding the commonality between EEG and SEEG, we are inspired by the synergistic effect of brain function and nerve cells, that is, different connectivity patterns correspond to different brain states (Lynn & Bassett, 2019). In particular, nerve cells spontaneously generate traveling waves and spread them out (Davis et al., 2020), maintaining characteristics such as shape during the process. Therefore, the degree of channel similarity implies different propagation patterns of traveling waves, reflecting differences in connectivity patterns to some extent. Both EEG and SEEG brain signals follow this inherent physiological mechanism. Therefore, we propose to extract the correlation graph structure between channels (brain areas) as the cornerstone for unifying EEG and SEEG data (Section 3.1). Next, we introduce three self-supervised learning tasks to model brain signals in Section 3.2. We propose the instantaneous time shift task, based on multi-channel CPC, and the delayed time shift task to capture the spatial and temporal correlation patterns. The replace discriminative task is further designed to preserve the characteristics of each channel.

Notations. For both EEG and SEEG data, there are multiple electrodes with $C$ channels. We use $X = \{x_l \in \mathbb{R}^C\}_{l=1}^L$ to represent raw time series data with $L$ time points, and $i$ and $j$ to denote channel indexes. $Y_{l,i} \in \{0, 1\}$ is the label for the $l$-th time point and $i$-th channel. We use a $W$-length window with no overlap to obtain the time segments $S = \{s_t\}_{t=1}^{|S|}$ (see details in Appendix A). The label corresponding to the $t$-th time segment and the $i$-th channel is denoted as $Y^s_{t,i}$.

3.1. LEARNING CORRELATIONS BETWEEN CHANNELS

As mentioned above, the correlation patterns between different brain areas can, to a large extent, help us distinguish brain activities in downstream tasks. Taking the seizure detection task as an example, when seizures occur, more rapid and significant propagation of spike-and-wave discharges appears (Proix et al., 2018), which greatly enhances the correlation between channels. This phenomenon is also verified by the data observations in Appendix C, which supports treating correlation graph structure learning as the common cornerstone of our framework. However, correlations between brain regions are difficult to observe and record directly. Therefore, for each time step $t$, our goal is to learn the structure of the correlation graph with adjacency matrix $A_t$, where nodes indicate channels and weighted edges denote the correlations between channels. Considering that the brain is in a normal and stable state most of the time, we first define the coarse-grained correlation graph as the prior graph for a particular individual: $A^{coarse}(i, j) = \mathbb{E}_{s_t}[\mathrm{Cosine}(s_{t,i}, s_{t,j})]$, where the expectation averages over the correlation matrices, each computed within a single time segment $s_t$, and $\mathrm{Cosine}(\cdot, \cdot)$ denotes the cosine similarity function. Next, based on $A^{coarse}$, for each pair of channels we further model their fine-grained short-term correlation within each time segment. We assume that the fine-grained correlations follow a Gaussian distribution element-wise, whose location parameters are the elements of $A^{coarse}$ and whose scale parameters are learned from the data. By means of the reparameterization trick, the short-term correlation matrix of the $t$-th time segment is sampled from the learned Gaussian distribution:
$$\sigma_t(i, j) = \mathrm{SoftPlus}(\mathrm{MLP}(c^{self}_{t,\tau,i}, c^{self}_{t,\tau,j})), \qquad n_t(i, j) \sim \mathcal{N}(0, 1), \quad (9)$$
$$A^{fine}_t(i, j) = A^{coarse}(i, j) + \sigma_t(i, j) \times n_t(i, j).$$
$\mathrm{SoftPlus}(\cdot)$ is a commonly used activation function to ensure the learned standard deviation is positive. $c^{self}_{t,\tau}$ is the contextual representation of raw time segments extracted by the encoders (see details in Section 3.2). To remove the spurious correlations caused by low-frequency signals and to enhance sparsity, which is a common assumption in neuroscience (Yu et al., 2017), we filter the edges by a threshold-based function to obtain the final correlation graph structure $A_t$:
$$A_t(i, j) = \begin{cases} A^{fine}_t(i, j), & A^{fine}_t(i, j) \geq \theta_1, \\ 0, & A^{fine}_t(i, j) < \theta_1. \end{cases}$$

Figure 1: Overview of MBrain. The leftmost part is the raw multi-channel brain signals. We use an encoder to map the raw data into a low-dimensional representation space. To capture the spatial and temporal correlation patterns, we propose three SSL tasks to guide the encoder to learn informative and distinguishable representations.
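As a rough numpy illustration of this graph construction, the sketch below averages per-segment cosine matrices into the coarse graph, then samples and thresholds a fine-grained graph. Here `sigma` is passed in as a fixed matrix for simplicity; in MBrain it is produced by SoftPlus(MLP(·)) on the contextual representations:

```python
import numpy as np

def cosine(a, b):
    """Cosine similarity between two 1D signals."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-8))

def coarse_graph(segments):
    """A_coarse: cosine similarity matrices averaged over segments.

    segments : (num_segments, C, W) raw windows for one individual
    """
    n, C, _ = segments.shape
    A = np.zeros((C, C))
    for s in segments:
        A += np.array([[cosine(s[i], s[j]) for j in range(C)] for i in range(C)])
    return A / n

def fine_graph(A_coarse, sigma, theta1, rng):
    """Sample A_fine via the reparameterization trick, then threshold.

    sigma is a fixed (C, C) matrix in this sketch, standing in for the
    learned SoftPlus(MLP(.)) scale parameters of the model.
    """
    A_fine = A_coarse + sigma * rng.standard_normal(A_coarse.shape)
    return np.where(A_fine >= theta1, A_fine, 0.0)  # sparsify edges below theta_1
```

The thresholding yields a sparse graph in which every surviving edge weight is at least $\theta_1$.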

3.2. SELF-SUPERVISED LEARNING TASKS FOR BRAIN SIGNALS

To capture the correlation patterns in space and time, we propose two self-supervised tasks: the instantaneous time shift task, which is based on multi-channel CPC and captures short-term correlations focusing on spatial patterns, and the delayed time shift task for temporal patterns on broader time scales.

Instantaneous Time Shift. For spatial patterns, we aim to leverage the contextual information of correlated channels to better predict future data of the target channel. Therefore, we apply multi-channel CPC and utilize the fine-grained graph structure $A_t$ obtained in Section 3.1 as the correlations between channels. We first use a non-linear encoder $g_{enc}$ (a 1D-CNN with $d$ kernels) to map the observed time segments to local latent $d$-dimensional representations $z_t = g_{enc}(s_t) \in \mathbb{R}^{T \times C \times d}$ for each channel separately, where $T$ is the sequence length after down-sampling by $g_{enc}$. Then an autoregressive model $g_{ar}$ is utilized to summarize the historical $\tau$-length local information of each channel itself to obtain the respective contextual representations: $c^{self}_{t,\tau} = g_{ar}(z_{t,1}, \cdots, z_{t,\tau})$. In this step, we only extract the contextual information of all channels independently. Based on the graph structure $A_t$, we instantiate the aggregate function $\Phi(\cdot)$ in equation 4 as a GNN due to its natural message-passing ability on a graph. Here we use a one-layer directed GCN (Yun et al., 2019) to illustrate the process:
$$c^{other}_{t,\tau,i} = \mathrm{ReLU}\left(\frac{\sum_{j \neq i} A_t(i, j) \cdot c^{self}_{t,\tau,j}}{\sum_{j \neq i} A_t(i, j)} \cdot \Theta\right),$$
where $\Theta$ is the learnable matrix. Since we only aggregate other channels' information, the self-loop in the GCN is removed here.
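A minimal numpy sketch of this aggregation step; the function name and shapes are illustrative, and $\Theta$ is a random placeholder for the weight matrix learned jointly in the real model:

```python
import numpy as np

def aggregate_other(A_t, c_self, Theta):
    """One-layer directed GCN aggregation with self-loops removed.

    A_t    : (C, C) correlation graph for the current segment
    c_self : (C, d) per-channel contextual representations
    Theta  : (d, h) weight matrix (learned in the real model)
    Returns c_other : (C, h), each row the ReLU of a weighted mean of
    the *other* channels' contexts.
    """
    A = A_t.astype(float).copy()
    np.fill_diagonal(A, 0.0)                      # aggregate only j != i
    denom = A.sum(axis=1, keepdims=True) + 1e-8   # per-target normalization
    return np.maximum((A @ c_self) / denom @ Theta, 0.0)
```

With a two-channel graph where each channel points only at the other, the output for channel 0 is simply the ReLU-transformed context of channel 1, and vice versa.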
Finally, combining $c^{self}_{t,\tau}$ and $c^{other}_{t,\tau}$ into the global representation $c_{t,\tau}$, the model predicts the local representations $k_1$ steps away, $z_{t,\tau+k_1}$, based on the multi-channel CPC loss:
$$c_{t,\tau} = \mathrm{Concat}(c^{self}_{t,\tau}, c^{other}_{t,\tau}), \qquad \mathcal{L}_1 = \mathcal{L}_N = -\mathbb{E}_{t,i,k_1}\left[\log \frac{\exp(c^{\top}_{t,\tau,i} W_{k_1} z_{t,\tau+k_1,i})}{\sum_{z_j \in X^i_t} \exp(c^{\top}_{t,\tau,i} W_{k_1} z_j)}\right],$$
where $X^i_t$ denotes the random noise set including one positive sample $z_{t,\tau+k_1,i}$ and $N-1$ negative samples, and $W_{k_1}$ is the learnable bilinear score matrix of the $k_1$-th prediction step.

Delayed Time Shift. For brain areas far apart, there exists delayed brain signal propagation, which is confirmed by the data observations shown in Appendix C. We should therefore consider these significant temporal correlations across several time steps. Our motivation is that if a simple classifier can easily predict whether two time segments are highly correlated, the segment representations will be significantly different from those with weaker correlations. We thus define the delayed time shift task to encourage more distinguishable segment representations. Similarly to the instantaneous time shift task, we first compute the cosine similarity matrix between time segments across several time steps based on the raw data. For the $i$-th channel in the $t$-th time segment, the long-term correlation matrix $B^i_t$ is computed as $B^i_t(k_2, j) = \mathrm{Cosine}(s_{t,i}, s_{t+k_2,j})$, where $j$ traverses all channels including the target channel, and $k_2$ traverses at most $K_2$ prediction steps. We then construct pseudo labels $Y^i_t$ according to $B^i_t$ to encourage segment representations with higher correlations to be closer. A predefined threshold $\theta_2$ is set to assign the pseudo labels:
$$Y^i_t(k_2, j) = \begin{cases} 1, & B^i_t(k_2, j) \geq \theta_2, \\ 0, & B^i_t(k_2, j) < \theta_2. \end{cases}$$

With the pseudo labels, we define the cross-entropy loss of the delayed time shift prediction task:
$$h_t = \mathrm{Pooling}(c^{self}_{t,1}, \cdots, c^{self}_{t,T}), \qquad p = \mathrm{Softmax}(\mathrm{MLP}(\mathrm{Concat}(h_{t,i}, h_{t+k_2,j}))),$$
$$\mathcal{L}_2 = -\mathbb{E}_{t,i,k_2,j}\left[Y^i_t(k_2, j) \log p + (1 - Y^i_t(k_2, j)) \log(1 - p)\right],$$
where $p$ is the predicted probability that the two segments are highly correlated. In practice, we randomly choose 50% of the labels from each $Y^i_t$ for efficient training.

Replace Discriminative Learning. Consistently exploiting correlations across all channels will weaken the specificity between channels. However, there are significant differences in the physiological signal patterns of different brain areas recorded by the channels, so retaining the characteristics of each channel cannot be ignored when modeling brain signals. For this purpose, we further design the replace discriminative learning task. Following BERT (Devlin et al., 2018), we randomly replace $r\%$ of the local representations throughout $z_t$ by $\hat{z}_t$, which is sampled from any of the $T$ sequences and any of the $C$ channels in $z_t$. We use the notation $I(\hat{z}_t)$ to represent the new local representations after replacement and the corresponding channel indexes of $\hat{z}_t$ in the original sequence. We generate the pseudo labels $Y_t$ of the task as below:
$$Y_t(\tau, i) = \begin{cases} 1, & I(\hat{z}_{t,\tau,i}) \neq i, \\ 0, & I(\hat{z}_{t,\tau,i}) = i, \end{cases}$$
where $\tau$ and $i$ traverse the $T$ sequences and $C$ channels of $\hat{z}_t$. After obtaining $\hat{z}_t$, we feed it into the autoregressive model to get the new contextual representations $\hat{c}_t = g_{ar}(\hat{z}_t)$. Finally, a simple discriminator implemented by an MLP is utilized to classify whether $\hat{c}_t$ is replaced by other channels or not:
$$\mathcal{L}_3 = -\mathbb{E}_{t,\tau,i}\left[Y_t(\tau, i) \log q + (1 - Y_t(\tau, i)) \log(1 - q)\right],$$
where $q$ is the predicted probability that $\hat{c}_{t,\tau,i}$ is replaced. As the accuracy of discrimination increases, the channel representations output by the autoregressive model become easier to distinguish. Therefore, this task preserves the unique characteristics of each channel.
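The replacement-and-labeling step can be sketched as follows. This is a hypothetical numpy helper for illustration; the real model applies it to encoder outputs and trains an MLP discriminator on the resulting contextual representations:

```python
import numpy as np

def replace_and_label(z, r, rng):
    """Randomly replace r% of the local representations and build pseudo labels.

    z : (T, C, d) local representations of one window.
    A position gets label 1 only when its replacement was drawn from a
    *different* channel (same-channel replacements keep label 0).
    """
    T, C, d = z.shape
    z_hat = z.copy()
    Y = np.zeros((T, C), dtype=int)
    n_replace = int(round(r / 100.0 * T * C))
    for idx in rng.choice(T * C, size=n_replace, replace=False):
        tau, i = divmod(int(idx), C)
        src_tau = int(rng.integers(T))   # random source time step
        src_i = int(rng.integers(C))     # random source channel
        z_hat[tau, i] = z[src_tau, src_i]
        Y[tau, i] = int(src_i != i)
    return z_hat, Y
```

A discriminator trained on `(z_hat, Y)` must then tell, position by position, whether a representation still comes from its own channel.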
Combining the multi-task loss functions in equations 15, 20 and 22, we jointly train MBrain with $\mathcal{L} = (1 - \lambda_1 - \lambda_2)\mathcal{L}_1 + \lambda_1 \mathcal{L}_2 + \lambda_2 \mathcal{L}_3$. After the SSL stage, the segment representations $h_t$ obtained from equation 18 are used for downstream tasks.

4.2. EXPERIMENTAL SETUP

To demonstrate the effectiveness of MBrain, we first define the seizure detection task and perform experiments on the EEG and SEEG datasets respectively. To examine the generalization of our model, we further conduct experiments corresponding to two clinically feasible schemes on the SEEG dataset. The ablation study, hyperparameter analysis, variance and further experimental results are shown in Appendices G, H and I. We also present a case study of the correlation graph learned in Section 3.1 in Appendix F to confirm its ability to identify different brain states.

Task 1 (Seizure Detection). Given a time-ordered set of $I_S$ consecutive time segments with the index of the first segment being $t_0$, $S = \{s_{t_0}, \ldots, s_{t_0 + I_S}\}$, models predict the labels $\hat{Y}^s_{t,i}$ for all time segments in $S$ (i.e., $t = t_0, \ldots, t_0 + I_S$) and all channels in each segment (i.e., $i = 1, \ldots, C$).

Seizure detection experiment. In this experiment, we first perform self-supervised learning of the model on unlabeled data. Then the segment representations learned by the model are used for the downstream seizure detection task (see details in Appendix D). During the training phase of the downstream task, the encoder of the SSL model is fine-tuned with a very low learning rate. For the EEG data, there is no overlap between the patients of the training and testing sets. In addition, since EEG labels are coarse-grained, the segment representations in each 12-second EEG clip are pooled into one representation and then used for seizure detection. For the SEEG data, considering the difference in the number and location of implanted electrodes between patients, we conduct experiments for each patient independently and report the average performance.

Transfer learning experiments. To meet practical clinical needs, we design two clinically feasible schemes for the SEEG data.
The first is the domain generalization experiment, that is, training the model on data of existing patients and directly predicting on data of unknown patients. More specifically, we follow the "3-1-1" setting, where 3 patients are used for training, 1 patient for validation and 1 patient for testing. We conduct experiments for all combinations, pick the best result for each patient, and report the average results over all patients. The second is the domain adaptation experiment (Motiian et al., 2017). In this experiment, we first perform SSL on one patient (i.e., the source domain) and then fine-tune using part of the labeled data from another patient (i.e., the target domain) in the testing set. Finally, we perform seizure detection on the remaining data of the target domain. We report the results on 12 cross-domain scenarios covering all combinations of four patients with typical seizure patterns in the SEEG dataset.

4.3. RESULTS OF SEIZURE DETECTION EXPERIMENT

The average performance of seizure detection on the SEEG dataset is presented in Table 1. Since the positive-negative sample ratio of the SEEG dataset is imbalanced, the F-score is a more appropriate metric for evaluating models than precision or recall alone. Especially in clinical applications, where doctors care about finding as many seizures as possible, we choose the F1-score and F2-score for the experiment. Overall, MBrain improves the F1-score by 28.92% and the F2-score by 26.85% on the SEEG dataset compared to the best baseline result, demonstrating that MBrain can learn informative representations from SEEG data. For the EEG experiment, following Tang et al. (2022), we also add the Area Under the Receiver Operating Characteristic (AUROC) metric. Our model is designed to learn a representation for each channel, while there is only one label for an EEG clip. Therefore, a pooling operation is required to aggregate the representations output by our model over channels and time segments for seizure detection. This setting makes the performance improvement of our model less significant than that in the SEEG experiment. Nevertheless, MBrain still outperforms all baselines on F1-score, F2-score and AUROC, with increases of 2.26%, 9.23% and 2.74%, respectively.

4.4. RESULTS OF DOMAIN GENERALIZATION EXPERIMENT

In this experiment, we compare against the baseline models that perform well in Table 1, because the results in Table 1 represent, to some extent, an upper bound on the performance of these models. We point out that although GTS and MBrain are both graph-based models, GTS cannot be trained on multiple patients, since the DCRNN (Li et al., 2018) encoder used in GTS can only process graphs with a fixed number of nodes. In contrast, our model is designed to learn the correlations between each pair of nodes and utilizes an inductive GNN, so it can easily handle graphs with different numbers of input nodes. In general, the performance of the models under the domain generalization setting decreases significantly (by 40.37% on average in terms of F2-score) compared with the seizure detection experiment. The drop in the recall metric is more pronounced, confirming that the distribution shift between patients in SEEG data is more significant than that in EEG. This results from the fact that different brain regions and different types of epileptic waves have different physiological properties and patterns. MBrain in this experiment still improves the F1-score and F2-score by 21.77% and 27.83% respectively, compared to the best baseline. The results prove that MBrain has superior generalization ability, benefiting from the rational inductive assumptions of our model design.

4.5. RESULTS OF DOMAIN ADAPTATION EXPERIMENT

Due to the long recording period, clinical SEEG data contains tens or even hundreds of seizures, allowing our model to use a subset of the data to fine-tune and then predict on the remaining data. Therefore, as a clinical alternative to domain generalization, we conduct a compromise domain adaptation experiment. Table 3 shows the performance of the domain adaptation (DA) experiment for four patients with typical seizure patterns, identified by doctors, from the SEEG dataset. More specifically, we train MBrain on one patient and fine-tune it on each of the other three patients.
B→A denotes that the SSL model is trained on Patient-B and the seizure detection experiment is then performed with the SSL model fine-tuned on Patient-A. The "Max-base" and "Non-DA" rows correspond to the performance of the best baseline and of MBrain, respectively, in the scenarios A→A, B→B, C→C and D→D. Compared with the condition in which the self-supervised model and the downstream model are both trained on the same patient, the F2-scores of all 12 cross-domain scenarios drop by less than 15%. Additionally, in all cross-domain scenarios, MBrain beats the best baseline in the corresponding scenario without domain adaptation. It is worth noting that the D→C scenario outperforms the corresponding "Non-DA" result. A possible reason is that the signal patterns of Patient-D are more significant and recognizable than those of Patient-C, so an SSL model trained on a higher-quality source domain can better distinguish signal states when performing downstream tasks on the target domain. Overall, by fine-tuning on only a subset of the target domain, the domain adaptation experiment lets MBrain achieve performance competitive with Table 1. The results suggest that MBrain captures inherent features and outputs representations that generalize between patients, since we fine-tune the SSL model with a very low learning rate (1e-6). From the perspective of pre-training, the SSL model trained on the source patient provides good initial parameters for the fine-tuning stage on the target patient.

5. CONCLUSION AND DISCUSSION

In this paper, we propose a general multi-channel SSL framework, MBrain, which can be applied to learn representations of both EEG and SEEG brain signals. Based on domain knowledge and data observations, we successfully use the correlation graph between channels as the cornerstone of our model. The proposed instantaneous and delayed time shift tasks help capture the correlation patterns of brain signals spatially and temporally. Extensive experiments on seizure detection with large-scale real-world datasets demonstrate the superior performance and generalization ability of MBrain. However, there are still some limitations of our work. For example, negative sampling in multi-channel CPC consumes considerable memory and time. Besides, we lack a more automatic way to determine the time range of long-term temporal patterns. As for future work, we plan to collect more types of brain signals and extend MBrain to more downstream tasks.

6. REPRODUCIBILITY STATEMENT

We provide the source code of our model MBrain in the Supplementary Material. Implementation details of MBrain can be found in Appendix D, and default hyperparameters can be found in the code. Users can run MBrain on their own datasets by following the same database generation method described in Section 4.1.

7. ETHICS STATEMENT

This paper proposes a novel self-supervised learning framework, MBrain, for brain signals, and conducts experiments on real-world large-scale EEG and SEEG datasets. The EEG dataset is public, and the SEEG dataset is non-public but anonymized. Overall, this work inherits some of the risks of existing works using the EEG dataset and does not introduce any new ethical or social concerns for the SEEG dataset.

A PRELIMINARIES

Brain signal data. For both EEG and SEEG data, multiple electrodes with C contacts in total are sampled at a fixed frequency to record the brain signals. We also call these contacts channels. At every sampling point, each channel records the potential value of the brain region in which it is located, constituting multi-channel time series data. A complete record file contains L time points in total, denoted X = {x_l ∈ R^C}_{l=1}^{L}. In the remainder of this paper, we use i and j to denote channel indices, e.g., x_l = {x_{l,i}}_{i=1}^{C}. For every x_{l,i}, we assign a binary label Y_{l,i} ∈ {0, 1} according to the start and end times of epileptic brain signals marked by doctors: time points in the epileptic state have positive labels (Y_{l,i} = 1), while zero labels (Y_{l,i} = 0) represent normal data.

Preprocessing. Following existing time series works (Zhu, 2017; Bagnall et al., 2017; Schäfer, 2015) with the common preprocessing of segmentation, we use a W-length window to divide the original brain signal data X into time segments S = {s_t ∈ R^{W×C}}_{t=1}^{|S|} without overlapping. The number of segments is |S| = ⌊L/W⌋. The segment label is obtained from the labels of all time points within the segment, i.e., Y^s_{t,i} = max{Y_{t×W+1,i}, ..., Y_{(t+1)×W,i}}.
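The windowing and max-labeling rule above can be sketched in a few lines. The helper below is illustrative only (it is not the paper's released code), and the toy shapes are our own assumptions.

```python
import numpy as np

def segment(X, Y, W):
    """Split a (L, C) recording into non-overlapping W-length windows.

    A segment is labeled positive for a channel if any of its W time
    points is epileptic, i.e. the max over the window (Appendix A).
    """
    L, C = X.shape
    n = L // W  # |S| = floor(L / W); the trailing remainder is dropped
    S = X[: n * W].reshape(n, W, C)                  # S = {s_t in R^{W x C}}
    Y_seg = Y[: n * W].reshape(n, W, C).max(axis=1)  # max over each window
    return S, Y_seg

# toy example: 10 time points, 2 channels, window length 4 -> 2 segments
X = np.arange(20, dtype=float).reshape(10, 2)
Y = np.zeros((10, 2), dtype=int)
Y[5, 0] = 1  # one epileptic time point in channel 0, falling in segment 1
S, Y_seg = segment(X, Y, W=4)
```

Note that the remainder (here the last 2 time points) is discarded, matching |S| = ⌊L/W⌋.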

B SINGLE-CHANNEL CPC

Contrastive Predictive Coding (CPC), a pioneering model for self-supervised contrastive learning, sets its pretext task to predict low-level local representations from high-level global contextual information c_t. In this way, the model avoids learning too many details of the raw data and pays more attention to the contextual semantic information of the sequence. The InfoNCE loss proposed in CPC has become the basic design of contrastive learning loss functions. Specifically, given a raw data sample set X = {x_1, ..., x_N} consisting of one positive sample drawn from p(x_{t+k} | c_t) and N−1 negative samples drawn from the noise distribution p(x_{t+k}), InfoNCE optimizes:

$$\mathcal{L}_N = -\mathbb{E}_X\left[\log \frac{f_k(x_{t+k}, c_t)}{\sum_{x_j \in X} f_k(x_j, c_t)}\right].$$

To obtain the best classification probability of the positive sample under the cross-entropy loss, the optimal f_k(x_{t+k}, c_t) is proportional to p(x_{t+k} | c_t)/p(x_{t+k}). Furthermore, the optimal loss is closely related to mutual information, as below:

$$\mathcal{L}_N^{opt} = -\mathbb{E}_X \log \frac{p(x_{t+k}|c_t)/p(x_{t+k})}{p(x_{t+k}|c_t)/p(x_{t+k}) + \sum_{x_j \in X_{neg}} p(x_j|c_t)/p(x_j)} \geq \mathbb{E}_X \log\left[\frac{p(x_{t+k})}{p(x_{t+k}|c_t)} \cdot N\right] = -I(x_{t+k}; c_t) + \log N.$$

Therefore, minimizing the loss $\mathcal{L}_N$ maximizes a lower bound on the mutual information I(x_{t+k}; c_t) between the future sample and the contextual semantic representation. It turns out that InfoNCE is indeed a well-established loss function for self-supervised contrastive learning.
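The InfoNCE objective above can be written as a log-softmax over one positive and N−1 negative scores. The sketch below is a minimal numpy illustration (scores are taken in log-space, i.e. log f_k); it is not the CPC reference implementation.

```python
import numpy as np

def info_nce(scores):
    """InfoNCE loss for one context c_t.

    scores[0] is the (log-space) score log f_k(x_{t+k}, c_t) of the
    positive sample; scores[1:] are the N-1 negatives. The loss is
    -log softmax evaluated at the positive sample.
    """
    log_denom = np.logaddexp.reduce(scores)  # log sum_j exp(scores[j])
    return log_denom - scores[0]

# When all N scores are equal (the model carries no information about
# which sample is positive), the loss equals log N -- consistent with
# the bound L_N >= -I(x_{t+k}; c_t) + log N at zero mutual information.
loss_uniform = info_nce(np.zeros(8))
```

A confident model that scores the positive far above the negatives drives the loss toward zero.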

C DATA OBSERVATIONS

As Figure 2 shows, for both EEG and SEEG data, the correlation matrices are almost identical on two non-overlapping normal segments from the same patient. In contrast, the correlation matrix in the epileptic state differs greatly from the normal ones. These data observations verify that correlation patterns can help us distinguish different brain states, and support our use of the correlation matrix as the cornerstone for modeling EEG and SEEG data. The data observations shown in Figure 3 and Figure 4 confirm that significant correlations still exist between time segments across several time steps. Unlike the instantaneous time shift, delayed correlations are not stable; this can be concluded from the numerical difference between the averaged correlation matrix over all clips and the matrix of a single randomly sampled clip.


D IMPLEMENTATION DETAILS OF MBrain

The non-linear encoder g_enc used in MBrain is composed of three 1-D convolution layers, and a one-layer LSTM (Hochreiter & Schmidhuber, 1997) is used as the autoregressive model g_ar. The model is optimized with the Adam optimizer (Kingma & Ba, 2015) with a learning rate of 2e-4 and weight decay of 1e-6 in the self-supervised learning stage. For the downstream seizure detection stage, the downstream model is optimized with a learning rate of 5e-4 and weight decay of 1e-6, while the SSL model is fine-tuned with a low learning rate of 1e-6. For the hyperparameters of MBrain, we set θ_1 = 0.5 and θ_2 = 0.5. We set the maximum value of k_1 in the instantaneous time shift task to 8. As Figure 5 shows, we set K_2 = 7 so as to take into account the step with the most significant correlation in the delayed time shift task. Lastly, we build our model using PyTorch 1.8 (Paszke et al., 2019) and train it on a workstation with four NVIDIA TESLA T4 GPUs. For the downstream task, we first use an LSTM (Hochreiter & Schmidhuber, 1997) to encode the segment representations of each channel in chronological order (10-second SEEG clips and 12-second EEG clips) independently. One-layer self-attention (Vaswani et al., 2017) is then applied across all channels within the same time step. Finally, a two-layer MLP classifier predicts whether epilepsy is occurring in each time segment. All baselines share the same downstream model in our experiments.
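The encoder architecture described above (three 1-D convolutions as g_enc, a one-layer LSTM as g_ar) can be sketched as follows. Kernel sizes, strides, and the hidden width are illustrative assumptions — the paper's actual defaults live in its released code, not in the text.

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """A minimal sketch of g_enc / g_ar from Appendix D.

    Hyperparameters here (kernel sizes 8/4/4, strides 4/2/2, hidden
    width) are hypothetical choices for illustration.
    """
    def __init__(self, hidden=256):
        super().__init__()
        self.g_enc = nn.Sequential(          # three 1-D conv layers
            nn.Conv1d(1, hidden, kernel_size=8, stride=4), nn.ReLU(),
            nn.Conv1d(hidden, hidden, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv1d(hidden, hidden, kernel_size=4, stride=2), nn.ReLU(),
        )
        self.g_ar = nn.LSTM(hidden, hidden, num_layers=1, batch_first=True)

    def forward(self, x):
        # x: (batch, 1, W) -- one channel's raw W-length segment
        z = self.g_enc(x)                    # (batch, hidden, T) local reps
        c, _ = self.g_ar(z.transpose(1, 2))  # (batch, T, hidden) context c_t
        return z, c

enc = Encoder(hidden=32)
z, c = enc(torch.randn(2, 1, 1000))  # e.g. a 1 s segment sampled at 1 kHz
```

Each channel is encoded independently, matching the per-channel design of the multi-channel CPC.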

E IMPLEMENTATION DETAILS OF BASELINES

• MiniRocket (Dempster et al., 2021): Rocket (Dempster et al., 2020) is a state-of-the-art supervised time series classification method, based on evaluations on public benchmarks (Bagnall et al., 2017; Tan et al., 2020), which trains a linear classifier on top of features extracted by a flat collection of numerous and various random convolutional kernels. MiniRocket is a variant of Rocket that improves processing time while offering essentially the same accuracy. We use the open source code from https://github.com/angus924/minirocket. For each subject, we use the features obtained through MiniRocket to train an independent logistic regression classifier for each channel and test it on the test set of that channel.
• CPC (Van den Oord et al., 2018): This is a self-supervised learning method based on the contrastive InfoNCE loss. The pretext task of CPC is to predict future local low-level representations, obtained from multi-layer CNNs, using contextual high-level representations obtained from an autoregressive model. This is the backbone model of this paper. We use the open source code of the corrected version from https://github.com/facebookresearch/CPC_audio.
• SimCLR (Chen et al., 2020): This is a simple yet effective framework for contrastive learning of visual representations, and we use time-series-specific augmentations to adapt it to our application. We implemented SimCLR on time series data ourselves, using the same encoder architecture and parameter configuration as TS-TCC. Following TS-TCC, we use scaling (sigma=1.1) as the data augmentation.
• Triplet-Loss (T-Loss) (Franceschi et al., 2019): This approach employs time-based negative sampling and a triplet loss to learn representations of time series segments. We use the default model architecture from the source code provided by the authors (https://github.com/White-Link/UnsupervisedScalableRepresentationLearningTimeSeries). For negative sampling, we use the data of the previous batch as the candidate set of negative samples for the current batch (the negative sample candidate set for the first batch is itself). Since the dataloader is shuffled at the end of each epoch, the set of sampled negative samples does not remain static.
• Time Series Transformer (TST) (Zerveas et al., 2021): This is an unsupervised representation learning framework for multivariate time series that trains a transformer to extract dense vector representations of time series through an input denoising objective. We use the default model architecture from the source code provided by the authors (https://github.com/gzerveas/mvts_transformer).
• GTS (Shang et al., 2021): This is a time series forecasting model that learns a graph structure among multiple time series and forecasts them simultaneously with a GNN; as such, it can learn useful representations from unlabeled time series data. We use the default model architecture from the source code provided by the authors (https://github.com/chaoshangcs/GTS). In the pre-training stage, we divide each time series segment evenly into 10 parts and learn a forecasting model that predicts the next 2 steps from the previous 8 steps. In the downstream stage, we use the representation after step 10 as the representation of the time series segment for the seizure detection task.
• TS-TCC (Eldele et al., 2021): This is an unsupervised time series representation learning framework, applying a temporal contrasting module and a contextual contrasting module to learn robust and discriminative representations. We use the default model architecture from the open source code provided by the authors (https://github.com/emadeldeen24/TS-TCC).
• TS2Vec (Yue et al., 2021): This is a universal representation learning framework for time series that applies hierarchical contrasting to learn scale-invariant representations within augmented context views. We use the default model architecture from the source code provided by the authors (https://github.com/yuezhihan/ts2vec).

F CASE STUDY ON CORRELATION GRAPHS

In this section, we study the correlation graphs between channels learned by MBrain. We randomly sample normal and seizure SEEG clips of one particular patient and visualize their correlation graphs. For the correlation graphs shown in Figure 6, we use the threshold θ_1 defined in Section 3.1 to decide which edges to preserve. The thickness of an edge represents its weight, and the size of a node represents the sum of the weights of all edges linked to that node. It can be observed that during the normal phase the graph is sparser and the edge weights are smaller, indicating weaker correlation between channels. In contrast, during the seizure phase the pattern between channels varies, with the correlation becoming denser and the edge weights larger. Furthermore, in the correlation graph of the seizure phase, edges with larger weights usually connect seizure channels, such as Channel-2, Channel-35 and Channel-38 in Figure 6(b), which can help surgeons better localize the seizure lesions.
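The edge-preservation rule used for the case-study visualization — keep an edge only if its learned correlation weight exceeds the threshold θ_1 — can be sketched as a simple post-processing step. This is a hypothetical illustration of the thresholding, not the paper's plotting code.

```python
import numpy as np

def correlation_graph(corr, theta=0.5):
    """Threshold a learned correlation matrix into a weighted graph.

    An edge (i, j) survives only if corr[i, j] > theta (the theta_1 of
    Section 3.1); self-loops on the diagonal are dropped. Node size in
    the visualization corresponds to the row sum of the result.
    """
    A = np.where(corr > theta, corr, 0.0)
    np.fill_diagonal(A, 0.0)
    return A

# toy 3-channel correlation matrix; theta_1 = 0.5 as in Appendix D
corr = np.array([[1.0, 0.8, 0.2],
                 [0.8, 1.0, 0.6],
                 [0.2, 0.6, 1.0]])
A = correlation_graph(corr, theta=0.5)
strength = A.sum(axis=1)  # per-node sum of incident edge weights
```

Here channel 1 keeps both its edges (0.8 and 0.6), while the weak 0.2 correlation between channels 0 and 2 is pruned.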

G ABLATION STUDY

We study the effectiveness of each component in MBrain. Specifically, we compare MBrain with model variants obtained by removing the following components. We first remove the correlation graph structure learning module from the instantaneous time shift task, degenerating the task to single-channel CPC while still uniformly sampling negative samples across all channels; this variant is denoted "MBrain-Graph". Next, we respectively remove the whole instantaneous time shift task, the delayed time shift task, and the replace discriminative task; these variants are denoted "MBrain-Instant", "MBrain-Delay" and "MBrain-Replace". Finally, we consider preserving only one self-supervised task: "MBrain-onlyInstant", "MBrain-onlyDelay" and "MBrain-onlyReplace" indicate that MBrain performs only the instantaneous time shift task, the delayed time shift task, or the replace discriminative task, respectively. We repeat the experiments five times in the fine-tuning stage of the downstream seizure detection task. Table 4 shows the results of the ablation study on the SEEG dataset. The complete MBrain achieves the best performance on the F1 and F2 metrics, demonstrating the effectiveness of each component of our model design. For "MBrain-Instant", the significant decrease in performance illustrates that capturing the spatial and short-term patterns is quite important and is the key to learning essential representations of multi-channel brain signals. For "MBrain-Graph", the decrease in performance demonstrates that multi-channel CPC greatly helps learn more informative representations. Additionally, the performance of "MBrain-Delay" and "MBrain-Replace" also decreases significantly, illustrating that modeling long-term temporal patterns and preserving the characteristics of channels help learn more distinguishable representations.
When only one self-supervised task is preserved, the instantaneous time shift task is the most important, as expected, while the delayed time shift task and the replace discriminative task contribute similarly to the performance of the complete model. In addition to removing components and tasks from MBrain, we also design ablation experiments to verify the effectiveness of our proposed graph structure learning. We proposed two ideas on how to directly implement multi-channel CPC in Section 2. For the second idea, we have reported the results of a shared CPC regarding all channels as one in the CPC row of Table 1. For the first idea, we design two strategies that combine a multi-channel CNN or an MLP with CPC to learn representations for each channel:
• Directly use a 1-D CNN to encode the whole time series data, with the number of channels evolving as C → 256 → 256 → C × 256, and split the output into C representations, each 256-dimensional; an LSTM is then applied. We then execute the self-supervised task and the downstream task of CPC on the per-channel representations, as MBrain does. This variant is denoted "CPC-Conv".
• Use the contextual representations of all channels as input to an MLP in a fixed order, setting the representation of the target channel to a zero tensor during aggregation. Using the output of the MLP as the aggregated representation of the other channels, we perform subsequent experiments following exactly the same steps as MBrain. We name this variant "CPC-MLP".
We observe that the performance of "CPC-Conv" decreases dramatically. We speculate that this is because the channels are relatively independent, and the correlation between most channels is weak or even non-existent; directly adopting multi-channel convolution may introduce spurious and noisy correlations.
By contrast, our proposed graph structure learning has a sparsity assumption, and the representation extraction of each channel is relatively independent, so it can effectively learn and aggregate more significant information. For "CPC-MLP", we use an MLP to aggregate the representations of the other channels and then concatenate the result with the representation of the target channel to predict future data. Unlike "CPC-Conv", which directly applies multi-channel convolution to the raw data to obtain "mixed" low-level representations, "CPC-MLP", like MBrain, learns the correlation of channels based on "separate" high-level representations. Therefore, the performance of "CPC-MLP" does not drop as dramatically as that of "CPC-Conv". We also observe that "CPC-MLP" outperforms "CPC" on the EEG dataset. This may result from the fact that the number of channels in the EEG dataset is only 19, while that in the SEEG dataset ranges from 52 to 124. Consequently, the ablation results show that the graph structure learning we design has a reasonable and parameter-efficient inductive assumption.

H SENSITIVITY ANALYSIS

Sensitivity analysis on loss weights. Our loss function is defined as L = (1 − λ_1 − λ_2)L_1 + λ_1 L_2 + λ_2 L_3, where L_1, L_2 and L_3 are the losses of the instantaneous time shift prediction task, the delayed time shift prediction task and the replace discriminative task respectively, and λ_1 and λ_2 are hyperparameters that balance the three pre-training tasks. We search both λ_1 and λ_2 over the set {0.1, 0.2, 0.3, 0.4, 0.5} and report the tuning results in terms of F2-score for the seizure detection task on Patient-A from the SEEG dataset. In Figure 7(a) and 7(b), we can see that λ_1 = 0.5 and λ_2 = 0.3 lead to the optimal performance. In addition, MBrain consistently performs better than the best baseline.

Sensitivity analysis on replace ratio. We perform a sensitivity analysis on the replace ratio r% of the replace discriminative task.
We search the replace ratio from 5% to 95% and report the tuning results in terms of F2-score for the seizure detection task on Patient-A from the SEEG dataset. As Figure 8 shows, MBrain achieves its best performance of 71.06±3.41 when the replace ratio is set to 45%, while it attains the smallest standard deviation and the second-best performance of 70.63±1.41 when the replace ratio is set to 15%.
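The three-task loss combination from the sensitivity analysis is a convex mixture and can be sketched in one line; the default weights below are the best values reported in the sweep (λ_1 = 0.5, λ_2 = 0.3).

```python
def total_loss(l1, l2, l3, lam1=0.5, lam2=0.3):
    """L = (1 - lam1 - lam2) * L1 + lam1 * L2 + lam2 * L3.

    l1, l2, l3: losses of the instantaneous time shift, delayed time
    shift, and replace discriminative tasks respectively. The weights
    form a convex combination, so equal per-task losses are preserved.
    """
    return (1 - lam1 - lam2) * l1 + lam1 * l2 + lam2 * l3

# sanity check: equal task losses pass through unchanged
combined = total_loss(1.0, 1.0, 1.0)
```

In a PyTorch training loop the three task losses would simply be combined this way before calling `backward()`.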

I FULL RESULTS

I.1 FULL RESULTS OF SEIZURE DETECTION EXPERIMENT

Significance analysis. Due to the large variance of the experimental results, we further conduct a significance test on the mean values to show that the performance of MBrain is indeed significantly better than that of the baseline models. We emphasize that we are primarily concerned with the F2-score, because in clinical practice doctors focus on finding as many seizures as possible. Thus, we focus on the significance analysis of F2-score for the results of Table 1. We perform the significance analysis with the following procedure: (1) For SEEG data, we first combine the repeated results of each model for all patients into one vector; for EEG data, we directly use the vector of repeated results. Significance analysis is performed on the F2-score vectors of all models pair by pair. (2) We use Levene's test [1] to check whether two populations have homogeneity of variance, which is a critical assumption in significance analysis. (3) Under homogeneity of variance, we conduct an independent two-sample t-test [2] on the two populations. We show the p-values of Table 5 and Table 7 in Figure 9 as heat maps. Considering the symmetry of the significance test, we only show half of the p-value matrix. From these two figures, we can see that for EEG data, MBrain significantly outperforms the other baselines. For SEEG data, although the performance of T-Loss and TS2Vec is not significantly different from that of MBrain, MBrain still achieves a higher mean F2-score.

I.2 FULL RESULTS OF DOMAIN ADAPTATION EXPERIMENT

In this experiment, we keep the self-supervised model trained on the source domain unchanged and repeat the experiments on the target domain five times in the fine-tuning stage of the downstream seizure detection task. The mean and standard deviation results are presented in Table 9.
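The two-step test above (Levene's test for variance homogeneity, then an independent two-sample t-test) maps directly onto `scipy.stats`. The F2-score vectors below are synthetic placeholders, not the paper's results.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
# hypothetical F2-score vectors for two models (repeat runs pooled)
mbrain = rng.normal(0.70, 0.02, size=20)
baseline = rng.normal(0.60, 0.02, size=20)

# step (2): Levene's test for homogeneity of variance
_, p_levene = stats.levene(mbrain, baseline)

# step (3): independent two-sample t-test; assume equal variances
# only when Levene's test does not reject homogeneity at alpha = 0.05
equal_var = p_levene > 0.05
_, p_ttest = stats.ttest_ind(mbrain, baseline, equal_var=equal_var)
```

Setting `equal_var=False` falls back to Welch's t-test when the variance assumption fails, which is the standard remedy.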
It can be observed that the variances of the results under the DA setting are generally higher than those under the Non-DA setting, indicating large variation in signal patterns across patients in the SEEG dataset. Nevertheless, high-quality source patient data can still be used to pre-train an SSL encoder with good average performance on the target patient. In practice, we can make predictions in an ensemble manner by training multiple classifiers.



[1] https://en.wikipedia.org/wiki/Levene's_test
[2] https://en.wikipedia.org/wiki/Student%27s_t-test#Unpaired_and_paired_two-sample_t-tests



Figure 2: The normal and seizure correlation matrices of EEG and SEEG brain signals. The top row is for SEEG and the bottom row is for EEG. For clear presentation, we sample some channels of the SEEG data. The leftmost column contains the base correlation matrices computed on normal data. The middle column shows the matrices after subtracting another normal correlation matrix from the base matrices, and the rightmost column shows the matrices after subtracting seizure correlation matrices from the base matrices.

Figure 4: The correlation matrices of the delayed time shift on EEG data. The top figure shows the correlation matrix averaged over all 12-second EEG clips, and the bottom figure shows the correlation matrix of one randomly sampled 12-second clip. The computation and operations are the same as in Figure 3.

Figure 5: The data observation guiding the choice of hyperparameter K_2. For each channel, we first average its correlations with all channels at each time step; we then average these values over all channels at the same time step. As the figure shows, we empirically choose K_2 = 7.

Figure 6: Case study on correlation graphs learned by MBrain.

Figure 7: Sensitivity analysis on loss weights.

Figure 8: Sensitivity analysis on replace ratio r%.

Figure 9: Significance analysis.

SEEG dataset. For each patient, 52 to 124 channels are used for recording signals. It is worth noting that since SEEG data are collected at a high frequency (1,000 Hz or 2,000 Hz) through multiple channels, our data is massive: in total, we have collected 470 hours of SEEG signals with a total capacity of 550 GB. Professional neurosurgeons helped us label the epileptic segments for each channel. For the self-supervised learning stage, we randomly sample 5,000 10-second SEEG clips for training and validation. For the downstream seizure detection task, we first obtain 5,000 sampled 10-second SEEG clips (80% for training and 20% for validation). For testing, we sample another 2,550 10-second SEEG clips with a positive-to-negative sample ratio of 1:50. We use a 1-second window to segment each clip without overlap, and our target is to make predictions for all channels in each 1-second segment.

EEG dataset. We use the Temple University Hospital EEG Seizure Corpus (TUSZ) v1.5.2 (Shah et al., 2018) as our EEG dataset. It is the largest public EEG seizure database, containing 5,612 EEG recordings, 3,050 annotated seizures from clinical recordings, and eight seizure types. We include the 19 EEG channels of the standard 10-20 system. For experimental efficiency, we generate a smaller dataset from TUSZ. We randomly sample 3,000 12-second EEG clips for self-supervised learning. For the downstream seizure detection task, we first obtain 3,000 sampled 12-second EEG clips (80% for training and 20% for validation). For testing, we sample another 3,900 12-second EEG clips with a positive-to-negative sample ratio of 1:10. It is worth noting that the labels of EEG data are coarse-grained, meaning we only have the label of whether epilepsy occurs within an EEG clip.

Table 1: The average performance of the seizure detection experiment on the SEEG and EEG datasets.

Table 1 also shows the results of the seizure detection experiment on the EEG dataset, following the common evaluation scheme for EEG data.

Table 2: The average performance of the domain generalization experiment on the SEEG dataset.

Table 3: The performance of the domain adaptation experiment on the SEEG dataset in terms of F2-score. The DA row denotes the performance of MBrain in the domain adaptation experiment setting. The Max-base and Non-DA rows represent the best performance of the baselines and of MBrain in the seizure detection experiment. Bold numbers and * indicate the best and the second-best performance.

Table 4: The results of the ablation study.

Table 5: The average performance of the seizure detection experiment on the SEEG dataset.

Table 6: The average performance of the seizure detection experiment on the SEEG dataset with the encoder of the SSL model frozen.

First, we keep the initial parameters of the trained self-supervised model unchanged, repeat the experiments five times in the fine-tuning stage of the downstream seizure detection task, and report the mean and standard deviation results in Table 5 and Table 7. Next, we freeze the entire encoder of the SSL model and only train the downstream model and the classifier in the seizure detection task; the corresponding results are shown in Table 6 and Table 8. The performance of most models drops slightly compared with the results obtained without freezing the SSL model. Nevertheless, our model still achieves competitive performance in F2-score. For the supervised model MiniRocket, which has no learnable representation extractor, the results remain unchanged across the two experimental settings.

Table 7: The average performance of the seizure detection experiment on the EEG dataset.

Table 8: The average performance of the seizure detection experiment on the EEG dataset with the encoder of the SSL model frozen.

Table 9: The performance of the domain adaptation experiment on the SEEG dataset in terms of F2-score. The DA row denotes the performance of MBrain in the domain adaptation experiment setting. The Max-base and Non-DA rows represent the best performance of the baselines and of MBrain in the seizure detection experiment. Bold numbers and * indicate the best and the second-best performance.

