DOMAIN GENERALIZATION WITH SMALL DATA

Abstract

In this work, we propose to tackle the problem of domain generalization in the context of insufficient samples. Instead of extracting latent feature embeddings with deterministic models, we propose to learn a domain-invariant representation within a probabilistic framework by mapping each data point into a probabilistic embedding. Specifically, we first extend the empirical maximum mean discrepancy (MMD) to a novel probabilistic MMD that can measure the discrepancy between mixture distributions (i.e., source domains) consisting of a series of latent distributions rather than latent points. Moreover, instead of imposing the contrastive semantic alignment (CSA) loss on pairs of latent points, we propose a novel probabilistic CSA loss that encourages positive probabilistic embedding pairs to be closer while pulling negative ones apart. Benefiting from the representation captured by probabilistic models, our proposed method can marry the measurement of the distribution over distributions (i.e., global-perspective alignment) with distribution-based contrastive semantic alignment (i.e., local-perspective alignment). Extensive experimental results on three challenging medical datasets show the effectiveness of our proposed method in the context of insufficient data compared with state-of-the-art baseline methods.
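As background for the probabilistic extension summarized above, the standard empirical MMD between two sampled domains can be sketched as follows. This is a minimal NumPy sketch using an RBF kernel; the function names and the bandwidth parameter `gamma` are illustrative choices, not taken from the paper.

```python
import numpy as np

def rbf_kernel(X, Y, gamma=1.0):
    """RBF (Gaussian) kernel matrix between rows of X and rows of Y."""
    sq_dists = (
        np.sum(X**2, axis=1)[:, None]
        + np.sum(Y**2, axis=1)[None, :]
        - 2.0 * X @ Y.T
    )
    return np.exp(-gamma * sq_dists)

def mmd2(X, Y, gamma=1.0):
    """Biased empirical estimate of the squared MMD between samples X and Y."""
    k_xx = rbf_kernel(X, X, gamma).mean()
    k_yy = rbf_kernel(Y, Y, gamma).mean()
    k_xy = rbf_kernel(X, Y, gamma).mean()
    return k_xx + k_yy - 2.0 * k_xy
```

Minimizing such a discrepancy across source-domain features is the usual route to domain-invariant representations; the paper's probabilistic MMD replaces each latent point with a latent distribution.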

1. INTRODUCTION

Nowadays, we have witnessed many successes of machine learning techniques in a variety of tasks related to computer vision and natural language processing, such as face recognition Li et al. (2022b), object detection Zaidi et al. (2022), and speech recognition Mridha et al. (2022). Despite these achievements, the assumption widely adopted by most existing methods, i.e., that training and testing data are identically and independently distributed, may not always hold in practical applications Zhou et al. (2022); Liu et al. (2022). In real-world scenarios, it is quite common for the distributions of training and testing data to differ, owing to sophisticated environments. For example, as a result of differences in device vendors and staining methods, histopathological images of breast cancer acquired from different healthcare centers exhibit significant domain gaps (a.k.a. domain shift; see Figure 1 for more details), which may lead to catastrophic deterioration of performance Qi et al. (2020). To address this issue, domain generalization (DG) is developed to learn a model from multiple related yet different domains (a.k.a. source domains) that is able to generalize well on an unseen testing domain (a.k.a. target domain). Recently, researchers have proposed quite a few domain generalization approaches, such as data augmentation with randomization Yue et al. (2019), data generation with stylization Verma et al. (2019); Zhou et al. (2021), and meta-learning-based training schemes Li et al. (2018a); Kim et al. (2021), among which representation learning-based methods are among the most popular. These representation learning-based methods Balaji et al. (2019) aim to learn a domain-invariant feature representation.

To be specific, if the discrepancy between source domains in feature space can be minimized, the model is expected to generalize better on the unseen target domain, owing to the learned domain-invariant and transferable feature representation Ben-David et al. (2006). For instance, the classical contrastive semantic alignment (CSA) loss proposed by Motiian et al. (2017) encourages positive sample pairs (with the same label) from different domains to be closer while pulling negative pairs (with different labels) apart. Dou et al. (2019) extended the CSA loss to jointly consider a local class alignment loss (for point-wise domain alignment) and a global class alignment loss (for distribution-wise alignment).
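The point-wise CSA objective described above can be sketched as follows. This is a minimal NumPy sketch; the hinge form and the `margin` parameter follow the standard contrastive formulation and are illustrative rather than the paper's exact loss.

```python
import numpy as np

def csa_loss(feat_a, labels_a, feat_b, labels_b, margin=1.0):
    """Contrastive semantic alignment between two source domains.

    Pulls same-class cross-domain embedding pairs together and pushes
    different-class pairs at least `margin` apart (hinge penalty).
    """
    # Pairwise Euclidean distances between all cross-domain pairs.
    diffs = feat_a[:, None, :] - feat_b[None, :, :]
    dists = np.sqrt((diffs**2).sum(axis=-1) + 1e-12)
    same = labels_a[:, None] == labels_b[None, :]
    # Positive pairs: same label, different domain -> minimize distance.
    pos = (dists**2)[same].mean() if same.any() else 0.0
    # Negative pairs: different labels -> penalize if closer than margin.
    neg = (np.maximum(0.0, margin - dists)**2)[~same].mean() if (~same).any() else 0.0
    return pos + neg
```

The paper's probabilistic CSA replaces the point-to-point distances above with distances between probabilistic embeddings, i.e., between distributions rather than feature vectors.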

