LEARNING DOMAIN-AGNOSTIC REPRESENTATION FOR DISEASE DIAGNOSIS

Abstract

In clinical environments, image-based diagnosis should be robust to multi-center samples. A natural route toward this goal is to capture only clinically disease-related features. However, such disease-related features are often entangled with center effects, preventing robust transfer to unseen centers/domains. To disentangle disease-related features, we first leverage structural causal modeling to explicitly model disease-related features and center effects that can be provably disentangled from each other. Guided by this, we propose a novel Domain Agnostic Representation Model (DarMo) based on a variational auto-encoder. To facilitate disentanglement, we design domain-agnostic and domain-aware encoders that respectively capture disease-related features and varied center effects, the latter by incorporating a domain-aware batch normalization layer. Besides, we constrain the disease-related features to predict both the disease label and clinical attributes, by incorporating a Graph Convolutional Network (GCN) into our decoder. The effectiveness and utility of our method are demonstrated by its superior performance over existing methods on both public and in-house datasets.

1. INTRODUCTION

A major barrier to deploying current deep learning systems for medical imaging diagnosis lies in their non-robustness to distributional shifts between internal and external cohorts (Castro et al., 2020; Ma et al., 2022; Lu et al., 2022), which commonly exist among multiple healthcare centers (e.g., hospitals) due to differences in image acquisition protocols. For example, image appearance can vary considerably with scanner models, parameter settings, and data preprocessing, as shown in Fig. 1 (a, b, c). Such a shift can deteriorate the performance of trained models, as manifested by a nearly 6.7% AUC drop of the empirical risk minimization (ERM) method from internal cohorts (source domain, in distribution) to external cohorts (unseen domain, out of distribution), as shown in Fig. 1 (bar graph). To resolve this problem, existing studies have proposed learning task-related features (Castro et al., 2020; Kather et al., 2022; Wang et al., 2021b) from multiple environments of data. Although the learned representations can capture lesion-related information, it is not guaranteed that such features are disentangled from the center effect, i.e., variations in image distributions due to domain differences in acquisition protocols (Fang et al., 2020; Du et al., 2019; Garg et al., 2021). The mixture of such variations leads to biases in learned features and final predictions. Therefore, a key question for robustness is: in what way can disease-related features be disentangled from the center effect? Recently, Sun et al. (2021) showed that task-related features can be disentangled from others, but this requires that the input X and the output Y be generated simultaneously. However, this requirement is often not satisfied in disease-prediction scenarios; e.g., Y can refer to ground-truth disease labels acquired from pathological examination, which can affect lesion patterns in the image X. To achieve this disentanglement, we build our model in Fig.
2(b) via structural causal modeling (SCM), which can effectively encode prior knowledge beyond data with hidden variables and causal relations. As shown, we introduce v_ma and v_mi to denote, respectively, the macroscopic and microscopic parts of the disease-related features that are often employed in clinical diagnosis. Specifically, the macroscopic features encode morphology-related attributes (Surendiran & Vadivel, 2012) of lesion areas, as summarized by the American College of Radiology (ACR) (Sickles et al., 2013); the microscopic features are hard to observe but reflect subtle patterns of lesions. Taking the mammogram in Fig. 2(a) as an illustration, the macroscopic features refer to the margins, shapes, and spiculations of the masses, while the microscopic features refer to the textures and the curvatures of contours (Ding et al., 2020a). As these disease-related patterns vary between malignant and benign cases, they are determined by the disease status Y, and correspondingly we have Y → (v_ma, v_mi) in Fig. 2(b). Besides, v_ma differs from v_mi in that it is related to clinical attributes A that are easy to observe from the image. In addition to disease-related features, we also introduce v_d to account for domain gaps arising from the center effect in the image. Note that given the image X (i.e., conditioned on X), v_d is correlated with (v_ma, v_mi), making them entangled with each other. This entanglement can cause bias and thus unstable prediction behavior when transferring to unseen centers/domains. Equipped with this causal modeling, we observe that the distributional shift of data is mainly accounted for by the variation of v_d across domains. Moreover, we can theoretically prove that when this variation is diverse enough, the disease-related features can be disentangled from the center effect. To the best of our knowledge, we are the first to prove that this disentanglement is possible in the imaging-diagnosis literature.
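The causal structure described above can be illustrated with a toy ancestral-sampling sketch. This is not from the paper: all mechanisms (linear-Gaussian noise, dimensions, weight matrices) are invented purely to make the graph Y → (v_ma, v_mi), v_ma → A, domain → v_d, (v_ma, v_mi, v_d) → X concrete.

```python
import numpy as np

rng = np.random.default_rng(0)
DIM = 4                                   # hypothetical latent dimension
W_A = rng.normal(size=(DIM, 3))           # toy mechanism v_ma -> attributes A
MU_D = rng.normal(size=(3, DIM))          # per-domain mean of v_d (center effect)

def sample_scm(domain):
    """One ancestral sample from the SCM sketched in Fig. 2(b)."""
    y = int(rng.integers(0, 2))                       # disease status Y
    v_ma = y + rng.normal(0.0, 0.1, DIM)              # Y -> v_ma (morphology)
    v_mi = 0.5 * y + rng.normal(0.0, 0.1, DIM)        # Y -> v_mi (subtle patterns)
    a = W_A.T @ v_ma                                  # v_ma -> clinical attributes A
    v_d = MU_D[domain] + rng.normal(0.0, 0.1, DIM)    # domain -> v_d (center effect)
    x = np.concatenate([v_ma, v_mi, v_d])             # (v_ma, v_mi, v_d) -> image X
    return x, y, a

x, y, a = sample_scm(domain=1)
```

Under this sketch, shifting the domain index changes only the distribution of v_d, which mirrors the paper's point that distributional shift is mainly accounted for by the variation of v_d across domains.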
Inspired by this result, we propose a disentangling learning framework, dubbed the Domain Agnostic Representation Model (DarMo), to disentangle disease-related features for prediction. Specifically, we adopt a variational auto-encoder framework and decompose the encoder into domain-agnostic and domain-aware branches, which respectively encode the disease-related information (v_ma, v_mi) and the domain effect v_d. To account for the variation of v_d across domains, we incorporate a domain-aware batch normalization (BN) layer into the domain-aware encoder to capture the effect of each domain. To capture disease-related information, we use disease labels to supervise (v_ma, v_mi) and additionally constrain v_ma to reconstruct clinical attributes with a Graph Convolutional Network (GCN) that models relations among attributes. To verify the utility and effectiveness of our method, we evaluate it on mammogram benign/malignant classification. Here, the clinical attributes are those related to the masses, which are summarized in the ACR (Sickles et al., 2013) and are easy to obtain. We consider four datasets (one public and three in-house) collected from different sources. The results on unseen domains show that our method can outperform others by 6.2%. Besides, our learned disease-related features successfully encode information on the lesion areas.
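One plausible reading of the domain-aware BN layer is a bank of per-domain BatchNorm modules, so each training domain's feature statistics are normalized separately. The sketch below is our illustrative interpretation, not the paper's released implementation; class and argument names are hypothetical.

```python
import torch
import torch.nn as nn

class DomainAwareBN(nn.Module):
    """Sketch of a domain-aware batch-norm layer: one BatchNorm2d per
    training domain, so the domain-aware encoder can track each center's
    feature statistics separately."""

    def __init__(self, num_features: int, num_domains: int):
        super().__init__()
        self.bns = nn.ModuleList(
            nn.BatchNorm2d(num_features) for _ in range(num_domains)
        )

    def forward(self, x: torch.Tensor, domain_idx: int) -> torch.Tensor:
        # Route the batch through the BN belonging to its source domain.
        return self.bns[domain_idx](x)

layer = DomainAwareBN(num_features=8, num_domains=3)
feat = torch.randn(4, 8, 16, 16)       # (batch, channels, H, W)
out = layer(feat, domain_idx=1)        # shape preserved: (4, 8, 16, 16)
```

Because each domain keeps its own running statistics and affine parameters, center-specific appearance shifts are absorbed by the selected BN branch rather than leaking into the shared, domain-agnostic features.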



Figure 1: Domain differences among multiple centers (Cases a, b, c) and AUC evaluation of our method vs. ERM (training by empirical risk minimization) on internal/external cohorts. Cases a, b, c: similar cases from different centers (red rectangles: lesion areas). Bar graph: on the external cohort (unseen domain), ERM shows a large AUC drop, whereas our proposed method remains stable.

