DOMAIN GENERALIZATION WITH SMALL DATA

Abstract

In this work, we propose to tackle the problem of domain generalization in the context of insufficient samples. Instead of extracting latent feature embeddings with deterministic models, we propose to learn a domain-invariant representation within a probabilistic framework by mapping each data point into a probabilistic embedding. Specifically, we first extend the empirical maximum mean discrepancy (MMD) to a novel probabilistic MMD that can measure the discrepancy between mixture distributions (i.e., source domains) consisting of a series of latent distributions rather than latent points. Moreover, instead of imposing the contrastive semantic alignment (CSA) loss on pairs of latent points, a novel probabilistic CSA loss encourages positive probabilistic embedding pairs to be closer while pulling negative ones apart. Benefiting from the representation captured by probabilistic models, our proposed method can marry the measurement of distributions over distributions (i.e., global perspective alignment) with distribution-based contrastive semantic alignment (i.e., local perspective alignment). Extensive experimental results on three challenging medical datasets show the effectiveness of our proposed method in the context of insufficient data compared with state-of-the-art baseline methods.

1. INTRODUCTION

Nowadays, we have witnessed many successes of machine learning techniques in a variety of tasks related to computer vision and natural language processing, such as face recognition Li et al. (2022b), object detection Zaidi et al. (2022), and speech recognition Mridha et al. (2022). Despite these achievements, the assumption widely adopted by most existing methods, i.e., that training and testing data are identically and independently distributed, may not always hold in practical applications Zhou et al. (2022); Liu et al. (2022). In real-world scenarios, it is quite common that the distributions of training and testing data differ, owing to sophisticated acquisition environments. For example, as a result of differences in device vendors and staining methods, histopathological images of breast cancer acquired from different healthcare centers exhibit significant domain gaps (a.k.a. domain shift, see Figure 1 for more details), which may lead to catastrophic deterioration of performance Qi et al. (2020). To address this issue, domain generalization (DG) has been developed to learn a model from multiple related yet different domains (a.k.a. source domains) that is able to generalize well on an unseen testing domain (a.k.a. target domain). Recently, researchers have proposed quite a few domain generalization approaches, such as data augmentation with randomization Yue et al. (2019), data generation with stylization Verma et al. (2019); Zhou et al. (2021), and meta learning-based training schemes Li et al. (2018a); Kim et al. (2021), among which representation learning-based methods are among the most popular. These representation learning-based methods Balaji et al. (2019) aim to learn domain-invariant feature representations.

To be specific, if the discrepancy between source domains in the feature space can be minimized, the model is expected to generalize better on the unseen target domain, owing to the learned domain-invariant and transferable feature representation Ben-David et al. (2006). For instance, the classical contrastive semantic alignment (CSA) loss proposed by Motiian et al. (2017) encourages positive sample pairs (with the same label) from different domains to be closer while pulling negative pairs (with different labels) apart. Dou et al. (2019) introduced a CSA loss that jointly considers a local class alignment loss (for point-wise domain alignment) and a global class alignment loss (for distribution-wise alignment).

In this paper, we propose to learn a domain-invariant representation from multiple source domains to tackle the domain generalization problem in the context of insufficient samples. Instead of extracting latent embeddings (i.e., latent points) based on deterministic models (e.g., convolutional neural networks, CNNs), we propose to leverage a probabilistic framework endowed by variational Bayesian inference to map each data point into a probabilistic embedding (i.e., a latent distribution) for domain generalization.

Specifically, following domain-invariant learning from the global (distribution-wise) perspective, we propose to extend the empirical maximum mean discrepancy (MMD) to a novel probabilistic MMD (P-MMD) that can empirically measure the discrepancy between mixture distributions (a.k.a. distributions over distributions) consisting of a series of latent distributions rather than latent points. From a local perspective, instead of imposing the CSA loss on pairs of latent points, a novel probabilistic contrastive semantic alignment (P-CSA) loss with kernel mean embedding is proposed to encourage positive probabilistic embedding pairs to be closer while pulling negative ones apart. Extensive experimental results on three challenging medical imaging classification tasks, including epithelium-stroma classification on insufficient histopathological images, class-imbalanced skin lesion classification, and spinal cord gray matter segmentation, show that our proposed method can achieve better cross-domain performance in the context of insufficient data compared with state-of-the-art baseline methods.

Figure 1: Histopathological image examples of breast cancer tissue from three different healthcare institutes, including NKI with 626 images, IHC with 645 images, and VGH with 1324 images. There are two different tissue types, epithelium and stroma. Obvious domain gaps (e.g., in tissue density and staining color) can be observed.

2. RELATED WORKS

DOMAIN GENERALIZATION AND ITS APPLICATION IN MEDICAL IMAGE CLASSIFICATION

Existing DG methods can be generally categorized into three streams, namely data augmentation/generation Yue et al. (2019); Graves (2011); Zhou et al. (2021), meta-learning Li et al. (2018a); Kim et al. (2021), and feature representation learning Li et al. (2018b); Gong et al. (2019); Xiao et al. (2021). Among these methods, feature representation learning, which aims to explore invariant feature information that can be shared across domains, has proven to be a widely adopted approach to the problem of DG. For feature representation learning-based DG methods, Li et al.
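To make the notion of probabilistic embeddings concrete, the following is a minimal NumPy sketch, not the paper's actual architecture: a toy stand-in for a CNN backbone followed by two hypothetical heads (`W_mu`, `W_logvar`) that output the mean and log-variance of a diagonal Gaussian embedding, with reparameterized sampling in the style of variational Bayesian inference. All weights, dimensions, and names are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative stand-in for a CNN backbone: a fixed random linear map
# plus a tanh nonlinearity. D_IN and D_LAT are arbitrary choices.
D_IN, D_LAT = 8, 4
W_feat = rng.normal(size=(D_IN, D_IN))
W_mu = rng.normal(size=(D_IN, D_LAT))            # hypothetical mean head
W_logvar = rng.normal(size=(D_IN, D_LAT)) * 0.1  # hypothetical log-variance head

def probabilistic_embed(x):
    """Map a data point x to a diagonal Gaussian embedding N(mu, diag(sigma^2)),
    rather than to a single deterministic latent point."""
    h = np.tanh(x @ W_feat)               # shared features
    mu = h @ W_mu
    sigma = np.exp(0.5 * (h @ W_logvar))  # strictly positive std deviation
    return mu, sigma

def sample_embedding(mu, sigma, n=16):
    """Reparameterized draws z = mu + sigma * eps, with eps ~ N(0, I)."""
    eps = rng.normal(size=(n, mu.shape[0]))
    return mu + sigma * eps

x = rng.normal(size=D_IN)
mu, sigma = probabilistic_embed(x)
Z = sample_embedding(mu, sigma)
```

In the actual method, such draws (or the Gaussian parameters directly) would feed the distribution-level alignment losses; here they only illustrate the point-to-distribution mapping.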

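The empirical MMD that P-MMD builds on can be sketched as follows. The first function is the standard biased estimate of squared MMD between two point samples under an RBF kernel; the second is a hedged illustration of the "distribution over distributions" idea, where each sample is itself a set of draws from a latent distribution and a level-2 kernel K(P, Q) = exp(-gamma2 * MMD^2(P, Q)) compares distributions instead of points. The level-2 kernel choice and all bandwidths are illustrative assumptions, not the paper's exact construction.

```python
import numpy as np

def rbf_kernel(X, Y, gamma=1.0):
    """Gaussian RBF kernel matrix: k(x, y) = exp(-gamma * ||x - y||^2)."""
    d2 = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-gamma * d2)

def mmd2(X, Y, gamma=1.0):
    """Biased empirical estimate of squared MMD between point samples X, Y."""
    return (rbf_kernel(X, X, gamma).mean()
            + rbf_kernel(Y, Y, gamma).mean()
            - 2.0 * rbf_kernel(X, Y, gamma).mean())

def mmd2_over_distributions(Ps, Qs, gamma1=1.0, gamma2=1.0):
    """MMD between two mixtures of distributions: Ps and Qs are lists of
    sample sets, one set per latent distribution. A level-2 kernel built
    on top of MMD compares distributions rather than points."""
    def K(A, B):
        return np.exp(-gamma2 * np.array(
            [[mmd2(P, Q, gamma1) for Q in B] for P in A]))
    return K(Ps, Ps).mean() + K(Qs, Qs).mean() - 2.0 * K(Ps, Qs).mean()

rng = np.random.default_rng(0)
# Two toy "domains", each a list of latent distributions given by draws.
dom_a = [rng.normal(0.0, 1.0, size=(20, 3)) for _ in range(4)]
dom_b = [rng.normal(1.5, 1.0, size=(20, 3)) for _ in range(4)]
gap = mmd2_over_distributions(dom_a, dom_b)  # positive for shifted domains
```

A domain compared with itself yields a gap of zero, while the shifted domain produces a strictly positive value, which is what a domain-alignment objective would drive down.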

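The local alignment can be illustrated in the same spirit. Below is a hedged sketch of a contrastive semantic alignment loss over probabilistic embeddings, where each embedding is represented by a set of draws and the pairwise distance is the empirical MMD (a stand-in for a distance between kernel mean embeddings); the margin value and the squared-hinge form for negative pairs are assumptions borrowed from the classical CSA formulation, not the paper's exact loss.

```python
import numpy as np

def rbf_kernel(X, Y, gamma=1.0):
    d2 = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-gamma * d2)

def mmd(X, Y, gamma=1.0):
    """Empirical MMD: distance between kernel mean embeddings of X and Y."""
    m2 = (rbf_kernel(X, X, gamma).mean()
          + rbf_kernel(Y, Y, gamma).mean()
          - 2.0 * rbf_kernel(X, Y, gamma).mean())
    return np.sqrt(max(m2, 0.0))

def p_csa_loss(pairs, same_class, margin=1.0, gamma=1.0):
    """Contrastive alignment over probabilistic embeddings: positive pairs
    (same class, different domains) are pulled together; negative pairs
    (different classes) are pushed apart up to `margin`."""
    total = 0.0
    for (P, Q), same in zip(pairs, same_class):
        d = mmd(P, Q, gamma)
        total += d ** 2 if same else max(0.0, margin - d) ** 2
    return total / len(pairs)

rng = np.random.default_rng(1)
P = rng.normal(0.0, 1.0, size=(15, 3))  # draws from one probabilistic embedding
Q = rng.normal(0.0, 1.0, size=(15, 3))  # same-class embedding, other domain
R = rng.normal(2.0, 1.0, size=(15, 3))  # different-class embedding
loss = p_csa_loss([(P, Q), (P, R)], [True, False])
```

A positive pair of identical embeddings contributes zero loss, and negative pairs stop contributing once their distance exceeds the margin, which mirrors the pull-together/push-apart behavior described for the P-CSA loss.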