MANYDG: MANY-DOMAIN GENERALIZATION FOR HEALTHCARE APPLICATIONS

Abstract

Vast amounts of health data are continuously collected for each patient, providing opportunities to support diverse healthcare predictive tasks such as seizure detection and hospitalization prediction. Existing models are mostly trained on other patients' data and evaluated on new patients, and many suffer from poor generalizability. One key reason is overfitting to information unique to patient identities and their data collection environments, referred to as patient covariates in this paper. These patient covariates usually do not contribute to predicting the targets but are often difficult to remove; as a result, they can bias the model training process and impede generalization. Most existing domain generalization methods for healthcare applications assume a small number of domains. In this paper, considering the diversity of patient covariates, we propose a new setting that treats each patient as a separate domain (leading to many domains). We develop a new domain generalization method, ManyDG 1 , that can scale to such many-domain problems. Our method identifies the patient domain covariates by mutual reconstruction and removes them via an orthogonal projection step. Extensive experiments show that ManyDG can boost generalization performance on multiple real-world healthcare tasks (e.g., a 3.7% Jaccard improvement on MIMIC drug recommendation) and support realistic but challenging settings such as insufficient data and continual learning.
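The orthogonal projection step mentioned above can be sketched generically: given a feature vector and an estimated covariate direction, subtract the component of the feature that lies along that direction, leaving a representation orthogonal to the covariate. The function and variable names below are illustrative assumptions for a minimal NumPy sketch, not the paper's actual implementation.

```python
import numpy as np

def remove_covariate_component(z, d):
    """Illustrative sketch: project feature vector z onto the subspace
    orthogonal to a (hypothetical) estimated covariate direction d,
    removing the component of z that lies along d."""
    d = d / np.linalg.norm(d)      # normalize to a unit covariate direction
    return z - np.dot(z, d) * d    # subtract z's projection onto d

# Toy example: the cleaned feature has no component along the covariate axis.
z = np.array([3.0, 4.0, 0.0])
d = np.array([1.0, 0.0, 0.0])
z_clean = remove_covariate_component(z, d)  # orthogonal to d
```

In practice the covariate direction would come from a learned per-patient representation rather than being given, but the projection arithmetic is the same.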

1. INTRODUCTION

The remarkable ability of function approximation (Hornik et al., 1989) can be a double-edged sword for deep neural networks (DNNs). Standard empirical risk minimization (ERM) models that minimize the average training loss are vulnerable in applications where the model tends to learn spurious correlations (Liu et al., 2021; Wiles et al., 2022) between covariates (as opposed to the causal factors (Gulrajani & Lopez-Paz, 2021)) and the targets during training. Thus, these models often fail on new data that do not exhibit the same spurious correlations. In clinical settings, Perone et al. (2019); Koh et al. (2021); Castro et al. (2020) have shown that a prediction model trained on one set of hospitals often does not work well for another hospital due to covariate shifts, e.g., different devices, clinical procedures, and patient populations.

Unlike previous settings, this paper targets the patient covariate shift problem in healthcare applications. Our motivation is that patient-level data unavoidably contain unique personalized characteristics (namely, patient covariates), such as different collection environments (Zhao et al., 2017) and patient identities, which are usually independent of the prediction targets, e.g., sleep stages (Zhao et al., 2017) and seizure disease types (Hirsch et al., 2021). These patient covariates can cause spurious correlations in learning a DNN model. For example, in the electroencephalogram (EEG) sleep staging task, patient EEG recordings may be collected in different environments, e.g., in a sleep lab (Terzano et al., 2002) or at home (Kemp et al., 2000). These measurement differences (Zhao et al., 2017) can induce noise and biases. It is also common for the label distribution per patient to differ significantly. A patient with

1 Code is available at https://github.com/ycq091044/ManyDG.

