MANYDG: MANY-DOMAIN GENERALIZATION FOR HEALTHCARE APPLICATIONS

Abstract

Vast amounts of health data are continuously collected for each patient, providing opportunities to support diverse healthcare predictive tasks such as seizure detection and hospitalization prediction. Existing models are mostly trained on other patients' data and evaluated on new patients, and many of them suffer from poor generalizability. One key reason is overfitting to the unique information related to patient identities and their data collection environments, referred to as patient covariates in this paper. These patient covariates usually do not contribute to predicting the targets but are often difficult to remove. As a result, they can bias the model training process and impede generalization. Most existing domain generalization methods for healthcare applications assume a small number of domains. In this paper, considering the diversity of patient covariates, we propose a new setting that treats each patient as a separate domain (leading to many domains). We develop a new domain generalization method, ManyDG, that can scale to such many-domain problems. Our method identifies the patient domain covariates by mutual reconstruction and removes them via an orthogonal projection step. Extensive experiments show that ManyDG can boost generalization performance on multiple real-world healthcare tasks (e.g., 3.7% Jaccard improvement on MIMIC drug recommendation) and support realistic but challenging settings such as insufficient data and continuous learning.

1. INTRODUCTION

The remarkable ability of function approximation (Hornik et al., 1989) can be a double-edged sword for deep neural networks (DNNs). Standard empirical risk minimization (ERM) models that minimize the average training loss are vulnerable in applications where the model tends to learn spurious correlations (Liu et al., 2021; Wiles et al., 2022) between covariates (as opposed to the causal factors (Gulrajani & Lopez-Paz, 2021)) and the targets during training. Thus, these models often fail on new data that do not exhibit the same spurious correlations. In clinical settings, Perone et al. (2019); Koh et al. (2021); Castro et al. (2020) have shown that a prediction model trained on a set of hospitals often does not work well for another hospital due to covariate shifts, e.g., different devices, clinical procedures, and patient populations.

Unlike previous settings, this paper targets the patient covariate shift problem in healthcare applications. Our motivation is that patient-level data unavoidably contain some unique personalized characteristics (namely, patient covariates), for example, different collection environments (Zhao et al., 2017) and patient identities, which are usually independent of the prediction targets, such as sleep stages (Zhao et al., 2017) and seizure disease types (Hirsch et al., 2021). These patient covariates can cause spurious correlations when learning a DNN model. For example, in the electroencephalogram (EEG) sleep staging task, patient EEG recordings may be collected from different environments, e.g., in a sleep lab (Terzano et al., 2002) or at home (Kemp et al., 2000). These measurement differences (Zhao et al., 2017) can induce noise and biases. It is also common that the label distribution per patient differs significantly: a patient with insomnia (Rémi et al., 2019) could have more awake stages than ordinary people, and elders tend to have fewer rapid eye movement (REM) stages than teenagers (Ohayon et al., 2004). Furthermore, these patient covariates can be even more harmful when dealing with insufficient training data.

In healthcare applications, recent domain generalization (DG) models are usually developed in cross-institute settings to remove each hospital's unique covariates (Wang et al., 2020; Zhang et al., 2021; Gao et al., 2021; Reps et al., 2022). Broader domain generalization methods have been built with various techniques, including style-based data augmentations (Nam et al., 2021; Kang et al., 2022; Volpi et al., 2018), episodic meta-learning strategies (Li et al., 2018a; Balaji et al., 2018; Li et al., 2019), and domain-invariant feature learning (Shao et al., 2019) by heuristic metrics (Muandet et al., 2013; Ghifary et al., 2016) or adversarial learning (Zhao et al., 2017). Most of them (Chang et al., 2019; Zhou et al., 2021) limit their scope to CNN-based models (Bayasi et al., 2022) and batch normalization architectures (Li et al., 2017) on image classification tasks.

Unlike previous works that assume a small number of domains, this paper considers a new setting for learning generalizable models from many more domains. To the best of our knowledge, we are the first to propose the many-domain generalization problem: in healthcare applications, we handle the diversity of patient covariates by modeling each patient as a separate domain. For this new many-domain problem, we develop an end-to-end domain generalization method, which is trained on pairs of samples from the same patient and optimized via a Siamese-type architecture. Our method combines mutual reconstruction and orthogonal projection to explicitly remove the patient covariates for learning (patient-invariant) label representations. We summarize our main contributions below:

• We propose a new many-domain generalization problem for healthcare applications by treating each patient as a domain. Our setting is challenging: we handle many more domains (e.g., 2,702 in the seizure detection task) compared to previous works (e.g., six domains in Li et al. (2020)).
• We propose a many-domain generalization method, ManyDG, for the new setting, motivated by a latent data generative model and a factorized prediction model. Our method explicitly captures the domain and the domain-invariant label representations via orthogonal projection.
• We evaluate ManyDG on four healthcare datasets and two realistic and common clinical settings: (i) insufficient labeled data and (ii) continuous learning on newly available data. Our method achieves consistently higher performance than the best baseline (e.g., 3.7% Jaccard improvement on the benchmark MIMIC drug recommendation task).

Code is available at https://github.com/ycq091044/ManyDG.
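The orthogonal projection step referenced above can be sketched in a few lines. This is a minimal illustration under assumed interfaces, not the paper's exact implementation: it supposes the encoder yields a feature vector `z` and a learned patient-domain direction `d` for each sample, and removes the component of `z` that lies along `d`, leaving a feature orthogonal to the patient-specific direction.

```python
import numpy as np

def remove_domain_component(z: np.ndarray, d: np.ndarray) -> np.ndarray:
    """Project z onto the orthogonal complement of the domain direction d,
    discarding the part of the feature aligned with the patient covariates."""
    d_unit = d / np.linalg.norm(d)
    return z - np.dot(z, d_unit) * d_unit

# Toy example (values are illustrative):
d = np.array([1.0, 0.0, 0.0])   # hypothetical patient-domain direction
z = np.array([2.0, 3.0, 1.0])   # encoder feature for one sample
z_inv = remove_domain_component(z, d)
# z_inv is orthogonal to d, i.e., np.dot(z_inv, d) is 0
```

After this projection, the residual feature carries no component along the (assumed) patient-domain direction, which is the intuition behind learning patient-invariant label representations.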

2. RELATED WORK

Domain Generalization (DG) The main difference between DG (Wang et al., 2022) (also called multi-source domain generalization, MSDG) and domain adaptation (Wilson & Cook, 2020) is that the former cannot access the test data during training and is thus more challenging. This paper focuses on the DG setting, for which recent methods have mostly been developed on image classification. They can be broadly categorized into three clusters (Yao et al., 2022): (i) style augmentation methods (Nam et al., 2021; Kang et al., 2022) assume that each domain is associated with a particular style, and they mostly manipulate the style of raw data (Volpi et al., 2018; Zhou et al., 2020b) or the statistics of feature representations (the means and standard deviations) (Li et al., 2017; Zhou et al., 2021) while enforcing the model to predict the same class label; (ii) domain-invariant feature learning (Li et al., 2018b; Shao et al., 2019) aims to remove the domain information from the high-level features by heuristic objectives (such as the MMD metric (Muandet et al., 2013)).

Our paper considers a new setting by modeling each patient as a separate domain. Thus, we deal with many more domains compared to all previous works (Peng et al., 2019). This challenging setting also motivates us to capture the domain information explicitly in our proposed method.

Domain Generalization in Healthcare Applications Many healthcare tasks aim to predict some clinical targets of interest for one patient during each hospital visit, such as estimating the risk of developing Parkinson's disease (Makarious et al., 2022) and sepsis (Gao et al., 2021), predicting the
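As a toy illustration of the feature-statistics manipulation mentioned in cluster (i) above (re-styling one sample's features with another domain's mean and standard deviation, in the spirit of the cited style-augmentation methods; the function and variable names here are illustrative, not from any cited work):

```python
import numpy as np

def swap_feature_style(x: np.ndarray, ref: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Normalize features x, then re-scale them with the mean/std of ref,
    so x adopts the 'style' statistics of another domain's features."""
    mu_x, sig_x = x.mean(), x.std() + eps
    mu_r, sig_r = ref.mean(), ref.std() + eps
    return (x - mu_x) / sig_x * sig_r + mu_r

src = np.array([1.0, 2.0, 3.0])      # features from one domain
ref = np.array([10.0, 20.0, 30.0])   # features from another domain
restyled = swap_feature_style(src, ref)
# restyled keeps src's normalized shape but carries ref's mean and std
```

Training on such re-styled features while keeping the original class label is the basic mechanism these augmentation methods rely on to discourage style-specific shortcuts.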

