AdaWAC: ADAPTIVELY WEIGHTED AUGMENTATION CONSISTENCY REGULARIZATION FOR VOLUMETRIC MEDICAL IMAGE SEGMENTATION

Abstract

Sample reweighting is an effective strategy for learning from training data drawn from a mixture of subpopulations. However, existing reweighting algorithms do not fully exploit the particular data distribution encountered in volumetric medical image segmentation, where the training images are uniformly distributed but their associated labels fall into two subpopulations, "label-sparse" and "label-dense", depending on whether the image occurs near the beginning/end of the volumetric scan or in the middle. For this setting, we propose AdaWAC, an adaptive weighting algorithm that assigns label-dense samples to the supervised cross-entropy loss and label-sparse samples to unsupervised consistency regularization. We provide a convergence guarantee for AdaWAC by appealing to the theory of online mirror descent on saddle point problems. Moreover, we empirically demonstrate that AdaWAC not only enhances segmentation performance and sample efficiency but also improves robustness to subpopulation shift in labels.

1. INTRODUCTION

Modern machine learning has been revolutionizing the field of medical imaging, especially computer-aided diagnosis with Computed Tomography (CT) and Magnetic Resonance Imaging (MRI) scans. While the success of most classical learning algorithms (e.g., empirical risk minimization (ERM)) builds upon the assumption that training samples are independently and identically distributed (i.i.d.), real-world volumetric medical images rarely fit into this picture. Specifically for medical image segmentation, as instantiated in Figure 1, the segmentation labels corresponding to different cross-sections of organs within a given volume tend to have distinct distributions: the slices toward the beginning/end of the volume contain no target organs and thus have very few, if any, segmentation labels (which we refer to as "label-sparse"), whereas segmentation labels are prolific in the slices toward the middle of the volume ("label-dense"). Such discrepancy in labels results in distinct difficulty levels as measured by the training cross-entropy (Wang et al., 2021b) and has motivated various training schedulers (Tullis & Benjamin, 2011; Tang et al., 2018; Hacohen & Weinshall, 2019).

Motivated by the separation between label-sparse and label-dense samples, we explore the following questions in this work: What is the effect of the separation between sparse and dense labels on segmentation? Can we leverage such separation to improve segmentation accuracy? We first formulate the mixture of label-sparse and label-dense samples as a subpopulation shift in the conditional distribution of labels given images, P(y|x). As illustrated in Figure 1, this subpopulation shift induces a separation in supervised cross-entropy between sparse and dense labels despite the uniformity of the data images.
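As a toy illustration of this label-sparse/label-dense split (not code from the paper; the helper names and the zero-density threshold are our own assumptions), the per-slice fraction of labeled pixels cleanly separates the empty end slices of a volume from the organ-bearing middle slices:

```python
import numpy as np

def slice_label_density(volume_labels):
    """Fraction of labeled (non-background) pixels in each 2D slice of a
    volumetric segmentation mask with shape (num_slices, H, W)."""
    return (volume_labels > 0).reshape(volume_labels.shape[0], -1).mean(axis=1)

def split_by_density(volume_labels, threshold=0.0):
    """Hypothetical helper: indices of label-sparse vs. label-dense slices,
    where a slice is 'sparse' if its labeled fraction is <= threshold."""
    density = slice_label_density(volume_labels)
    return np.flatnonzero(density <= threshold), np.flatnonzero(density > threshold)

# Toy volume: empty slices at both ends, a labeled "organ" in the middle.
labels = np.zeros((6, 4, 4), dtype=int)
labels[2:4, 1:3, 1:3] = 1
sparse, dense = split_by_density(labels)
# sparse -> slices [0, 1, 4, 5]; dense -> slices [2, 3]
```

In practice the split need not be computed explicitly from the labels; the point of the paper is that it emerges from the losses themselves, as Figure 1 shows.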
Prior works address the subpopulation shift issue with hard thresholding algorithms such as the Trimmed Loss Estimator (Shen & Sanghavi, 2019), MKL-SGD (Shah et al., 2020), Ordered SGD (Kawaguchi & Lu, 2020), and the quantile-based Kaczmarz algorithm (Haddock et al., 2020). However, these trimmed-loss-based methods discard the samples from some subpopulations (e.g., samples whose labels are deemed corrupted based on their losses) at each iteration, losing the information carried by the discarded data and reducing sample efficiency. Relaxing the hard thresholding operator to soft thresholding has been proposed to incorporate the information from both subpopulations (Wang et al., 2018; Sagawa et al., 2020). However, lowering the weights assigned to some subpopulations according to the properties of their labels reduces the importance of the data and the labels simultaneously, suggesting that we may further improve learning efficiency by exploiting the uniformity of the data and the separation of the labels separately.

Meanwhile, augmentation consistency regularization has been studied mainly in the semi-supervised setting (et al., 2019; Zhao et al., 2019; Li et al., 2020; Wang et al., 2021a; Zhang et al., 2021; Zhou et al., 2021; Basak et al., 2022). By contrast, we explore its potency in the fully supervised setting, leveraging the information in all image inputs regardless of their label subpopulations. Moreover, in light of the uniformity of unsupervised consistency across slices throughout each volume, the augmentation consistency of the encoder layer outputs serves as a natural reference for separating samples from different subpopulations. Accordingly, we introduce the weighted augmentation consistency (WAC) regularization: a minimax formulation that not only incorporates consistency regularization but also leverages it as a reference for reweighting the cross-entropy and augmentation consistency terms of different samples.
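To make the minimax structure concrete, here is a minimal sketch, in our own notation with hypothetical helper names, of a per-sample weighted objective in the spirit of WAC: each sample contributes a convex combination of its cross-entropy and consistency terms, and the inner maximization over the weights is separable across samples (the paper's exact formulation appears in Section 3):

```python
import numpy as np

def wac_objective(ce_losses, ac_losses, beta):
    """Batch objective: sample i contributes beta[i] * CE_i + (1 - beta[i]) * AC_i,
    with weights beta[i] in [0, 1] set by the inner maximization."""
    return np.mean(beta * ce_losses + (1.0 - beta) * ac_losses)

def inner_max_beta(ce_losses, ac_losses):
    """The inner max over beta in [0, 1]^n is separable: it puts all weight on
    whichever term is currently larger for each sample. Under the separation in
    Figure 1, label-dense samples (large CE) get the cross-entropy term and
    label-sparse samples (small CE) get the consistency term."""
    return (ce_losses >= ac_losses).astype(float)

ce = np.array([2.0, 0.1])  # a label-dense slice, a label-sparse slice
ac = np.array([0.5, 0.5])
beta = inner_max_beta(ce, ac)        # -> [1.0, 0.0]
loss = wac_objective(ce, ac, beta)   # -> mean(2.0, 0.5) = 1.25
```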
At the saddle point, the WAC regularization automatically separates samples from different label subpopulations, assigning all weight to the consistency terms for label-sparse samples and all weight to the cross-entropy terms for label-dense samples. Furthermore, as an algorithm for solving the minimax problem posed by the WAC regularization, we propose AdaWAC, an adaptive weighting scheme inspired by a mirror-descent-based algorithm for distributionally robust optimization (Sagawa et al., 2020). By adaptively adjusting the weights between the cross-entropy and consistency terms of different samples, AdaWAC comes with both a convergence guarantee and empirical success. Overall, we summarize the main contributions of this work as follows:

• We cast the discrepancy between the sparse and dense labels within each volume as a subpopulation shift in the conditional distribution P(y|x) (Section 2).

• We propose the WAC regularization, which uses the consistency of encoder layer outputs (in a UNet architecture) as a natural reference to incentivize separation between samples with sparse and dense labels (Section 3), along with an adaptive weighting algorithm, AdaWAC, for solving the WAC regularization problem with a convergence guarantee (Section 4).

• We empirically demonstrate the potency of AdaWAC not only in enhancing segmentation performance and sample efficiency but also in improving distributional robustness (Section 5).
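The adaptive weighting step can be sketched as a multiplicative update on each per-sample weight, treating (β_i, 1 − β_i) as a two-point distribution over the cross-entropy and consistency terms and taking a KL mirror-descent step on it. This is an illustrative approximation with our own step size and function names, not the paper's exact update:

```python
import numpy as np

def adawac_weight_update(beta, ce_loss, ac_loss, lr=0.1):
    """One multiplicative step on a per-sample weight beta in (0, 1):
    the weight of each term grows exponentially with that term's current
    loss, then the pair is renormalized -- a KL mirror-descent step on
    the two-point distribution (beta, 1 - beta)."""
    w_ce = beta * np.exp(lr * ce_loss)
    w_ac = (1.0 - beta) * np.exp(lr * ac_loss)
    return w_ce / (w_ce + w_ac)

# A label-dense sample (CE consistently larger than the consistency term)
# drifts toward beta -> 1, i.e., toward pure cross-entropy training.
beta = 0.5
for _ in range(50):
    beta = adawac_weight_update(beta, ce_loss=2.0, ac_loss=0.1)
# beta is now close to 1
```

Under this dynamic, samples whose cross-entropy stays below their consistency term drift toward β → 0 instead, recovering the saddle-point separation described above.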



Figure 1: Evolution of cross-entropy losses versus consistency regularization terms for slices across one training volume (Case 40) in the Synapse dataset (Section 5) during training.

