UNBIASED REPRESENTATION OF ELECTRONIC HEALTH RECORDS FOR PATIENT OUTCOME PREDICTION

Abstract

Fairness is an emerging focus in building trustworthy artificial intelligence (AI) models. One cause of unfairness is algorithmic bias toward different groups of samples: a biased model may benefit certain groups while disfavoring others. Leaving the fairness problem unresolved can therefore have a significant negative impact, especially in healthcare applications. We propose a masked triple attention transformer encoder (MTATE), which integrates both domain-specific and domain-invariant representations to learn unbiased and fair data representations of different subpopulations. Specifically, MTATE includes multiple domain classifiers and uses three attention mechanisms to effectively learn the representations of diverse subpopulations. In experiments on real-world healthcare data, MTATE outperformed the compared models in both overall performance and fairness.

1. INTRODUCTION

Electronic Health Record (EHR) based clinical risk prediction using temporal machine learning (ML) and deep learning (DL) models helps clinicians provide precise and timely interventions to high-risk patients and better allocate hospital resources (Xiao et al., 2018; Shamout et al., 2020). Nevertheless, a long-standing issue that hinders ML and DL model deployment is the concern about model fairness (Gianfrancesco et al., 2018; Ahmad et al., 2020). Fairness in AI/DL refers to a model's ability to make a prediction or decision without bias against any individual or group (Mehrabi et al., 2021). A biased model typically exhibits two behaviors: it performs significantly better in certain populations than in others (Parikh et al., 2019), and it makes inequitable decisions toward different groups (Panch et al., 2019). Clinical decision-making based on biased predictions may delay treatment plans for patients in minority groups or misspend healthcare resources where treatment is unnecessary (Gerke et al., 2020). The data distribution shift across different domains is one of the major reasons a model could be biased (Adragna et al., 2020). To address the fairness issue, domain adaptation methods have been developed. The main idea is to learn invariant hidden features across different domains, such that a model performs similarly no matter which domain a test case belongs to. Pioneering domain adaptation models, including DANN (Ganin et al., 2016), VARADA (Purushotham et al., 2017), and VRNN (Chung et al., 2015), learn invariant hidden features by adding a domain classifier and using a gradient reversal layer to maximize the domain classifier's loss. In return, the learned hidden features become indistinguishable across domains.
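The gradient reversal mechanism described above can be illustrated with a minimal NumPy sketch. This is not the implementation of any of the cited models; the class name and `lambda_` coefficient are illustrative. The layer is an identity in the forward pass, but negates (and scales) the gradient flowing back from the domain classifier, so the feature extractor is pushed to *maximize* the domain loss while the classifier minimizes it.

```python
import numpy as np

class GradientReversal:
    """Identity in the forward pass; scales gradients by -lambda_ in the
    backward pass. Inserted between the feature extractor and the domain
    classifier in adversarial domain adaptation (a hypothetical sketch)."""

    def __init__(self, lambda_=1.0):
        self.lambda_ = lambda_

    def forward(self, x):
        # The domain classifier sees the extracted features unchanged.
        return x

    def backward(self, grad_output):
        # The feature extractor receives the reversed, scaled gradient,
        # driving it toward domain-indistinguishable representations.
        return -self.lambda_ * grad_output

grl = GradientReversal(lambda_=0.5)
features = np.array([1.0, 2.0])
assert np.allclose(grl.forward(features), features)  # identity forward

grad_from_domain_clf = np.array([0.2, -0.4])
grad_to_extractor = grl.backward(grad_from_domain_clf)
```

In an autograd framework this would be a custom function with these two rules; the sketch only makes the sign flip explicit.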
The recent MS-ADS model (Khoshnevisan & Chi, 2021) has shown robust performance across minority racial groups by maximizing the distance between the globally-shared representation and the individual local representations of every domain, which effectively consolidates the invariant globally-shared representations across domains. However, it is difficult to align large domain shifts or to model complex shifts across multiple overlapping domains. Alternatively, the data distribution shift problem can be addressed with domain-specific bias correction approaches. A recent study showed that features strongly associated with the outcome of interest can be subpopulation-specific (Chouldechova & Roth, 2018), indicating that lumping together all features from patients with different backgrounds might bury unique domain-specific information.

In summary, domain adaptation and domain-specific bias correction address the same fairness issue under different assumptions about the relationship between latent representations and the prediction outcome. The former holds that performance variation across domains would benefit from invariant feature representations, while the latter favors domain-specific representations. It remains unclear whether domain-invariant or domain-specific data representations should be used for a given prediction task. To better address the fairness issue, we propose an adaptive multi-task learning algorithm, called MTATE (Masked Triple Attention Transformer Encoder), to automatically learn and select optimal and fair data representations instead of explicitly choosing between domain adaptation and domain-specific bias correction. Under this setting, purely invariant and purely domain-specific representations are special cases in which one approach dominates the data representation.
MTATE generates multiple masked representations of the same data, attended by both time-wise attention and multiple feature-wise attentions in parallel, where each masked representation corresponds to a specific domain classification task. For example, one domain classifier splits the patient cohort into subpopulations defined by race, while another focuses on gender. The learned EHR representations can be domain-specific, domain-invariant, or a mix of the two, as reflected by the domain classification loss values: a low loss indicates the representation is domain-specific, while a high loss indicates it is domain-invariant. The model then computes representation-wise attention for each individual test case, yielding a personalized data representation for downstream predictive tasks. The overall framework of MTATE is shown in Figure 1. The primary goal of MTATE is to learn an unbiased representation for fair and precise patient outcome prediction in a real-world healthcare setting. To demonstrate its effectiveness, we focus on rolling mortality prediction for patients with Acute Kidney Injury requiring Dialysis (AKI-D), a severe complication for critically ill patients with a high in-hospital mortality rate (Lee & Son, 2020). Clinical risk classification for AKI-D patients is challenging due to complex subphenotypes and treatment exposures (Neyra & Nadkarni, 2021; Vaara et al., 2022). There is an urgent need for actionable approaches that account for patients' backgrounds and subpopulations, enabling personalized medicine and improving patient-centered outcomes (Chang et al., 2022).
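The representation-wise attention step described above can be sketched as follows. This is a minimal NumPy illustration of the general idea, not the authors' implementation: `reps` stands for the K masked representations (one per domain classification task) and `query` for a patient-specific query vector; both names and the scaled dot-product scoring are assumptions.

```python
import numpy as np

def softmax(scores):
    """Numerically stable softmax over a 1-D array of scores."""
    e = np.exp(scores - scores.max())
    return e / e.sum()

def attend_representations(reps, query):
    """Mix K candidate representations into one personalized vector.

    reps  : (K, d) array, one masked representation per domain task
    query : (d,) patient-specific query vector
    Returns the weighted mixture (d,) and the attention weights (K,).
    """
    # Scaled dot-product scores between each representation and the query.
    scores = reps @ query / np.sqrt(reps.shape[1])
    weights = softmax(scores)
    # Convex combination: the test case decides how much each
    # (domain-specific or domain-invariant) representation contributes.
    return weights @ reps, weights

# Toy example: two 2-D representations, query aligned with the first.
reps = np.array([[1.0, 0.0],
                 [0.0, 1.0]])
query = np.array([1.0, 0.0])
mixed, weights = attend_representations(reps, query)
```

Because the weights form a convex combination, purely domain-invariant or purely domain-specific representations are recovered as limiting cases where one weight approaches 1.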



Figure 1: Overall framework of the masked triple attention transformer encoder (MTATE). HR, SBP, and sCr stand for heart rate, systolic blood pressure, and serum creatinine, respectively. x_t represents all clinical features at time t; f_i represents the values of feature i at all time points; Z'_i represents the data representation learned from the i-th feature relevance attention module.

