AdaWAC: ADAPTIVELY WEIGHTED AUGMENTATION CONSISTENCY REGULARIZATION FOR VOLUMETRIC MEDICAL IMAGE SEGMENTATION

Abstract

Sample reweighting is an effective strategy for learning from training data drawn from a mixture of subpopulations. However, existing reweighting algorithms do not fully exploit the particular data distribution encountered in volumetric medical image segmentation, where the training images are uniformly distributed but their associated labels fall into two subpopulations, "label-sparse" and "label-dense", depending on whether a slice occurs near the beginning/end of the volumetric scan or in the middle. For this setting, we propose AdaWAC, an adaptive weighting algorithm that assigns label-dense samples to a supervised cross-entropy loss and label-sparse samples to an unsupervised consistency regularization. We provide a convergence guarantee for AdaWAC by appealing to the theory of online mirror descent on saddle point problems. Moreover, we empirically demonstrate that AdaWAC not only enhances segmentation performance and sample efficiency but also improves robustness to subpopulation shift in the labels.

1. INTRODUCTION

Modern machine learning is revolutionizing the field of medical imaging, especially computer-aided diagnosis with Computed Tomography (CT) and Magnetic Resonance Imaging (MRI) scans. While the success of most classical learning algorithms (e.g., empirical risk minimization (ERM)) builds on the assumption that training samples are independently and identically distributed (i.i.d.), real-world volumetric medical images rarely fit this picture. Specifically for medical image segmentation, as illustrated in Figure 1, the segmentation labels corresponding to different cross-sections of organs within a given volume tend to have distinct distributions. That is, the slices toward the beginning/end of the volume that contain no target organs have very few, if any, segmentation labels (we refer to these as "label-sparse"), whereas segmentation labels are prolific in the slices toward the middle of the volume ("label-dense"). Such discrepancy in labels results in distinct difficulty levels as measured by the training cross-entropy (Wang et al., 2021b) and has motivated various training schedulers (Tullis & Benjamin, 2011; Tang et al., 2018; Hacohen & Weinshall, 2019).

Motivated by the separation between label-sparse and label-dense samples, we explore the following questions in this work: What is the effect of the separation between sparse and dense labels on segmentation? Can we leverage this separation to improve segmentation accuracy? We first formulate the mixture of label-sparse and label-dense samples as a subpopulation shift in the conditional distribution of labels given images, P(y|x). As illustrated in Figure 1, such a subpopulation shift induces a separation in supervised cross-entropy between sparse and dense labels despite the uniformity of the data images.

Prior works address the subpopulation shift issue with hard-thresholding algorithms such as the Trimmed Loss Estimator (Shen & Sanghavi, 2019), MKL-SGD (Shah et al., 2020), Ordered SGD (Kawaguchi & Lu, 2020), and the quantile-based Kaczmarz algorithm (Haddock et al., 2020). However, these trimmed-loss-based methods discard the samples from some subpopulations (e.g., samples with label corruption as estimated by their losses) at each iteration, which results in a loss of information from the discarded samples.
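To make the contrast with hard thresholding concrete, the adaptive weighting idea described above can be sketched as a per-sample convex combination of the two losses, with the weights updated multiplicatively. This is only a minimal illustrative sketch, not the paper's algorithm: the function name `adawac_step`, the learning rate, and the simple two-point-simplex mirror-descent geometry are assumptions made for exposition; the actual method operates on image slices with encoder-feature-based consistency regularization and comes with a saddle-point convergence analysis.

```python
import numpy as np

def adawac_step(weights, ce_losses, cr_losses, lr=0.1):
    """One exponentiated-gradient (mirror descent) step on per-sample
    weights beta_i in (0, 1) that trade off supervised cross-entropy
    against unsupervised consistency regularization:

        L_i(beta_i) = beta_i * CE_i + (1 - beta_i) * CR_i.

    The weights ascend this objective (the inner max of a saddle-point
    problem), so samples whose cross-entropy dominates their
    consistency loss drift toward the supervised term, while
    low-cross-entropy (label-sparse) samples drift toward the
    consistency term. No sample is ever discarded outright.
    """
    # Multiplicative update on the 2-point simplex (beta_i, 1 - beta_i):
    up = weights * np.exp(lr * ce_losses)          # mass on the CE term
    down = (1.0 - weights) * np.exp(lr * cr_losses)  # mass on the CR term
    return up / (up + down)                        # renormalize to (0, 1)

# Toy usage: sample 0 has large cross-entropy (label-dense behavior),
# sample 1 has large consistency loss (label-sparse behavior).
w = adawac_step(np.array([0.5, 0.5]),
                ce_losses=np.array([2.0, 0.1]),
                cr_losses=np.array([0.1, 2.0]))
```

After one step, the first weight moves above 0.5 and the second below it, illustrating how the multiplicative update softly routes samples between the two loss terms instead of hard-thresholding them out of training.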

