AdaWAC: ADAPTIVELY WEIGHTED AUGMENTATION CONSISTENCY REGULARIZATION FOR VOLUMETRIC MEDICAL IMAGE SEGMENTATION

Abstract

Sample reweighting is an effective strategy for learning from training data that comes from a mixture of subpopulations. However, existing reweighting algorithms do not fully take advantage of the particular type of data distribution encountered in volumetric medical image segmentation, where the training images are uniformly distributed but their associated labels fall into two subpopulations, "label-sparse" and "label-dense", depending on whether the image occurs near the beginning/end of the volumetric scan or near the middle. For this setting, we propose AdaWAC, an adaptive weighting algorithm that assigns label-dense samples to the supervised cross-entropy loss and label-sparse samples to unsupervised consistency regularization. We provide a convergence guarantee for AdaWAC by appealing to the theory of online mirror descent on saddle point problems. Moreover, we empirically demonstrate that AdaWAC not only enhances segmentation performance and sample efficiency but also improves robustness to the subpopulation shift in labels.

1. INTRODUCTION

Modern machine learning has been revolutionizing the field of medical imaging, especially in computer-aided diagnosis with Computed Tomography (CT) and Magnetic Resonance Imaging (MRI) scans. While the successes of most classical learning algorithms (e.g., empirical risk minimization (ERM)) build upon the assumption that training samples are independently and identically distributed (i.i.d.), real-world volumetric medical images rarely fit into this picture. Specifically for medical image segmentation, as instantiated in Figure 1, the segmentation labels corresponding to different cross-sections of organs within a given volume tend to have distinct distributions. That is, the slices toward the beginning/end of the volume that contain no target organs have very few, if any, segmentation labels (which we refer to as "label-sparse"), whereas segmentation labels are prolific in the slices toward the middle of the volume ("label-dense"). Such discrepancy in labels results in distinct difficulty levels measured by the training cross-entropy (Wang et al., 2021b) and leads to various training schedulers (Tullis & Benjamin, 2011; Tang et al., 2018; Hacohen & Weinshall, 2019).

Motivated by the separation between label-sparse and label-dense samples, we explore the following questions in this work: What is the effect of the separation between sparse and dense labels on segmentation? Can we leverage such separation to improve segmentation accuracy? We first formulate the mixture of label-sparse and label-dense samples as a subpopulation shift in the conditional distribution of labels given images, P(y|x). As illustrated in Figure 1, such subpopulation shift induces a separation in supervised cross-entropy between sparse and dense labels despite the uniformity of the data images.
Prior works address the subpopulation shift issue with hard thresholding algorithms such as the Trimmed Loss Estimator (Shen & Sanghavi, 2019), MKL-SGD (Shah et al., 2020), Ordered SGD (Kawaguchi & Lu, 2020), and the quantile-based Kaczmarz algorithm (Haddock et al., 2020). However, these trimmed-loss-based methods discard the samples from some subpopulations (e.g., samples with label corruption estimated by their losses) at each iteration, which loses the information in the discarded data and reduces sample efficiency. Relaxing the hard thresholding operator to soft thresholding has been proposed to incorporate the information from both subpopulations (Wang et al., 2018; Sagawa et al., 2020). However, lowering the weights assigned to some subpopulations of data according to the properties of their labels reduces the importance of the data and the labels simultaneously, suggesting that we may further improve learning efficiency by exploiting the uniformity of data and the separation of labels separately.

Instead of thresholding out or down-weighting the label-sparse samples, we exploit the image inputs of these samples via augmentation consistency regularization. Consistency regularization (Bachman et al., 2014; Laine & Aila, 2016; Sohn et al., 2020) aims to learn the proximity between data augmentations of the same samples; crucially, this setup does not involve the data labels, and hence consistency regularization has become an essential strategy for utilizing unlabeled data. For medical imaging tasks, consistency regularization has been extensively studied in the semi-supervised learning setting (Bortsova et al., 2019; Zhao et al., 2019; Li et al., 2020; Wang et al., 2021a; Zhang et al., 2021; Zhou et al., 2021; Basak et al., 2022). By contrast, we will explore its potency in the fully supervised setting, leveraging the untapped information in all image inputs, regardless of their label subpopulations.
Moreover, in light of the uniformity of unsupervised consistency across slices throughout each volume, the augmentation consistency of the encoder-layer outputs serves as a natural reference for separating samples from different subpopulations. Accordingly, we introduce the weighted augmentation consistency (WAC) regularization, a minimax formulation that not only incorporates the consistency regularization but also leverages it as a reference for reweighting the cross-entropy and augmentation consistency terms corresponding to different samples. At the saddle point, the WAC regularization automatically separates samples from different label subpopulations by assigning all weight to the consistency terms for label-sparse samples and all weight to the cross-entropy terms for label-dense samples. Furthermore, as an algorithm for solving the minimax problem posed by the WAC regularization, we propose AdaWAC, an adaptive weighting scheme inspired by a mirror-descent-based algorithm for distributionally-robust optimization (Sagawa et al., 2020). By adaptively adjusting the weights between the cross-entropy and consistency terms of different samples, AdaWAC comes with both a convergence guarantee and empirical success.

Overall, we summarize the main contributions of this work as follows:
• We cast the discrepancy between the sparse and dense labels within each volume as a subpopulation shift in the conditional distribution P(y|x) (Section 2).
• We propose the WAC regularization, which uses the consistency of encoder-layer outputs (in a UNet architecture) as a natural reference to incentivize separation between samples with sparse and dense labels (Section 3), along with an adaptive weighting algorithm, AdaWAC, for solving the WAC regularization problem with a convergence guarantee (Section 4).
• We empirically demonstrate the potency of AdaWAC not only in enhancing segmentation performance and sample efficiency but also in improving distributional robustness (Section 5).

1.1. RELATED WORK

Sample reweighting. Sample reweighting is a popular strategy for coping with subpopulation shifts in training data, where different weights are assigned to samples from different subpopulations. In particular, distributionally-robust optimization (DRO) (Ben-Tal et al., 2013; Duchi et al., 2016; Duchi & Namkoong, 2018; Sagawa et al., 2020) considers a collection of training-sample groups from different distributions where, with the explicit grouping of samples, the goal is to minimize the worst-case loss over the groups. Without prior knowledge of sample grouping, importance sampling (Needell et al., 2014; Zhao & Zhang, 2015; Alain et al., 2015; Loshchilov & Hutter, 2015; Gopal, 2016; Katharopoulos & Fleuret, 2018), iterative trimming (Kawaguchi & Lu, 2020; Shen & Sanghavi, 2019), and empirical-loss-based reweighting (Wu et al., 2022) are commonly incorporated into the stochastic optimization process for adaptive reweighting and separation of samples from different subpopulations.

Consistency regularization. Consistency regularization (Bachman et al., 2014; Laine & Aila, 2016; Sohn et al., 2020; Berthelot et al., 2019) is a popular way to exploit data augmentations that encourages the model to learn the vicinity among augmentations of the same sample, under the assumption that data augmentations generally preserve the semantic information in data. For medical imaging, consistency regularization is generally leveraged as a semi-supervised learning tool (Bortsova et al., 2019; Zhao et al., 2019; Li et al., 2020; Wang et al., 2021a; Zhang et al., 2021; Zhou et al., 2021; Basak et al., 2022). In efforts to incorporate consistency regularization in medical image segmentation with augmentation-sensitive labels, Li et al. (2020) encourage transformation consistency between predictions with augmentations applied to the image inputs and to the segmentation outputs. Basak et al. (2022) penalize inconsistent segmentation outputs between teacher-student models, with MixUp (Zhang et al., 2017) applied to the image inputs of the teacher model and the segmentation outputs of the student model. Instead of enforcing consistency in the segmentation output space as above, our algorithm leverages the insensitivity of sparse labels to augmentations and encourages consistent encodings (in the latent space of encoder outputs) on label-sparse samples.

2. PROBLEM SETUP

Notations. We denote [K] = {1, . . . , K} for any K ∈ N. For an arbitrary tensor, we adapt the syntax of Python slicing on the subscript (except counting from 1) to represent its elements and subtensors. For example, x_[i,j] denotes the (i, j)-entry of the two-dimensional tensor x, and x_[i,:] denotes its i-th row. Let I{·} be the indicator function onto {0, 1} such that, for any event e, I{e} = 1 if e is true and 0 otherwise. For any distribution P and n ∈ N, let P^n denote the joint distribution of n samples drawn i.i.d. from P. We say that an event happens with high probability (w.h.p.) if it takes place with probability 1 − Ω(poly(n))^{-1}.

2.1. PIXEL-WISE CLASSIFICATION WITH SPARSE AND DENSE LABELS

We consider volumetric medical image segmentation as a pixel-wise multi-class classification problem where we aim to learn a pixel-wise classifier h : X → [K]^d that serves as a good approximation of the ground truth h* : X → [K]^d. Recall from Figure 1 the separation of cross-entropy losses during training between samples with different fractions of non-background labels. We refer to a sample (x, y) ∈ X × [K]^d as label-sparse if most pixels in y are labeled as background, so that the cross-entropy loss on (x, y) converges rapidly in the early stage of training. Otherwise, we say that (x, y) is label-dense. Formally, we describe such variation as a subpopulation shift in the label distribution.

Definition 1 (Mixture of label-sparse and label-dense distributions). Let P_0 and P_1 be the distributions of label-sparse and label-dense samples with distinct conditional distributions P_0(y|x) and P_1(y|x), respectively, but the same marginal distribution P(x), such that P_i(x, y) = P_i(y|x) P(x) for i = 0, 1. For ξ ∈ [0, 1], we define a data distribution P_ξ where each sample (x, y) ∼ P_ξ is drawn from P_1 with probability ξ and from P_0 with probability 1 − ξ.

We aim to learn a pixel-wise classifier from a function class H with h_θ(x)_[j] = argmax_{k∈[K]} f_θ(x)_[j,k] for all j ∈ [d], where the underlying function f_θ ∈ F, parameterized by some θ ∈ F_θ, admits an encoder-decoder structure:

    F = { f_θ = ψ_θ ∘ ϕ_θ | ϕ_θ : X → Z, ψ_θ : Z → [0, 1]^{d×K} }.    (1)

Here ϕ_θ and ψ_θ correspond to the encoder and decoder functions, respectively; (F_θ, ⟨·,·⟩_F) denotes the inner product space of parameters with the induced norm ∥·∥_F and dual norm ∥·∥_{F,*}; and (Z, ϱ) is a latent metric space. To learn from segmentation labels, we consider the averaged cross-entropy loss:

    ℓ_CE(θ; (x, y)) = −(1/d) Σ_{j=1}^d Σ_{k=1}^K I{y_[j] = k} · log f_θ(x)_[j,k] = −(1/d) Σ_{j=1}^d log f_θ(x)_[j, y_[j]].    (2)

We assume proper learning: there exists θ* ∈ ⋂_{ξ∈[0,1]} argmin_{θ∈F_θ} E_{(x,y)∼P_ξ}[ℓ_CE(θ; (x, y))] that is invariant with respect to ξ.
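As a concrete sketch of the averaged cross-entropy in Equation (2), the loss over the d pixels of a single slice can be computed as follows. This is a minimal NumPy illustration; the function and array names are ours, not from the paper's released code.

```python
import numpy as np

def averaged_cross_entropy(probs, labels):
    """Averaged pixel-wise cross-entropy, Equation (2).

    probs:  (d, K) array of predicted class probabilities f_theta(x),
            each row summing to 1.
    labels: (d,) integer array with labels[j] in {0, ..., K-1}.
    """
    d = labels.shape[0]
    # Second equality in Equation (2): pick the probability of the true
    # class at each pixel, then average the negative log-likelihoods.
    picked = probs[np.arange(d), labels]
    return -np.log(picked).mean()

# Toy slice with d = 4 pixels and K = 2 classes (0 = background): a model
# biased toward background already fits the sparse labels well, which is
# why sparse labels are "easy" in the sense of Remark 1.
probs = np.array([[0.9, 0.1], [0.9, 0.1], [0.9, 0.1], [0.9, 0.1]])
sparse_labels = np.array([0, 0, 0, 0])
dense_labels = np.array([0, 1, 1, 1])
```

On this toy example, the loss on the label-sparse slice (−log 0.9 ≈ 0.105) is far below the loss on the label-dense slice, mirroring the separation observed in Figure 1.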

2.2. AUGMENTATION CONSISTENCY REGULARIZATION

Despite the invariance of f_θ* to P_ξ on the population loss, in practice we have a finite number of training samples, and the predominance of label-sparse samples in the training set introduces difficulties due to class imbalance. As a specific extreme scenario for the pixel-wise classifier with encoder-decoder structure (Equation (1)), when label-sparse samples are predominant (ξ ≪ 1), a decoder function ψ_θ that predicts every pixel as background can achieve a near-optimal cross-entropy loss regardless of the encoder function ϕ_θ, considerably compromising test performance (cf. Table 1).

To encourage legitimate encoding even in the absence of sufficient dense labels, we leverage unsupervised consistency regularization on the encoder function ϕ_θ based on data augmentations. Let A be a distribution over transformations on X where, for any x ∈ X, each A ∼ A (A : X → X) induces an augmentation A(x) of x that perturbs low-level information in x. We aim to learn an encoder function ϕ_θ : X → Z that filters out low-level information from x and therefore provides similar encodings for augmentations of the same sample. Recalling the metric ϱ on Z, for a given scaling hyperparameter λ_AC > 0, we measure the similarity between augmentations with a consistency regularization term on ϕ_θ(·): for any (A_1, A_2) ∼ A²,

    ℓ_AC(θ; x, A_1, A_2) ≜ λ_AC · ϱ(ϕ_θ(A_1(x)), ϕ_θ(A_2(x))).    (3)

For the n training samples {(x_i, y_i)}_{i∈[n]} ∼ P_ξ^n, we consider n pairs of data augmentation transformations {(A_{i,1}, A_{i,2})}_{i∈[n]} ∼ A^{2n}. In the basic version, we encourage similar encodings ϕ_θ(·) of the augmentation pairs (A_{i,1}(x_i), A_{i,2}(x_i)) for all i ∈ [n] via consistency regularization:

    min_{θ∈F_{θ*}(γ)} (1/n) Σ_{i=1}^n [ℓ_CE(θ; (x_i, y_i)) + ℓ_AC(θ; x_i, A_{i,1}, A_{i,2})].    (4)

We enforce consistency on ϕ_θ(·) in light of the encoder-decoder architecture: the encoder is generally designed to abstract essential information and filter out low-level non-semantic perturbations (e.g., those introduced by augmentations), while the decoder recovers the low-level information for pixel-wise classification. Therefore, with different A_1, A_2 ∼ A, the encoder output ϕ_θ(·) tends to be more consistent than the other intermediate layers, especially for label-dense samples.
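The regularizer in Equation (3) touches only the encoder. A minimal sketch follows, taking the latent metric ϱ to be the Euclidean distance; the paper leaves ϱ abstract, so this is one admissible choice, and all names are ours.

```python
import numpy as np

def consistency_loss(encoder, x, aug1, aug2, lam_ac=1.0):
    """l_AC(theta; x, A1, A2) from Equation (3): scaled distance between
    the encodings of two augmented views of the same input x."""
    z1 = encoder(aug1(x))
    z2 = encoder(aug2(x))
    return lam_ac * np.linalg.norm(z1 - z2)

# Toy check on a 1-D "image": sorting is invariant to mirroring, so an
# encoder that sorts its input incurs zero consistency loss, while the
# identity encoder does not.
x = np.array([3.0, 1.0, 2.0])
mirror = lambda v: v[::-1]
identity = lambda v: v
invariant_encoder = np.sort
```

The toy check mirrors the design rationale above: an encoder that discards the low-level perturbation introduced by the augmentation pays no regularization penalty.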

3. WEIGHTED AUGMENTATION CONSISTENCY (WAC) REGULARIZATION

As motivation, we begin with a key observation about the averaged cross-entropy.

Remark 1 (Separation of averaged cross-entropy loss on P_0 and P_1). As demonstrated in Figure 1, the sparse labels from P_0 tend to be much easier to learn than the dense ones from P_1, leading to considerable separation of averaged cross-entropy losses on the sparse and dense labels after a sufficient number of training epochs: ℓ_CE(θ; (x, y)) ≪ ℓ_CE(θ; (x′, y′)) for most label-sparse samples (x, y) ∼ P_0 and label-dense samples (x′, y′) ∼ P_1.

Although Equation (4) with consistency regularization alone can boost segmentation accuracy during testing (cf. Table 4), it does not take the separation between label-sparse and label-dense samples into account. In Section 5, we will empirically demonstrate that proper exploitation of such separation, like the formulation introduced below, brings compelling further improvements. Concretely, we formalize the notion of separation between P_0 and P_1 based on the consistency regularization term (Equation (3)) with the following assumption.

Assumption 1 (N-separation between P_0 and P_1). Given a sufficiently small γ > 0, let F_{θ*}(γ) = {θ ∈ F_θ | ∥θ − θ*∥_F ≤ γ} be a compact and convex neighborhood of well-trained pixel-wise classifiers. We say that P_0 and P_1 are N-separated over F_{θ*}(γ) if there exists ω > 0 such that both
(i) ℓ_CE(θ; (x, y)) < ℓ_AC(θ; x, A_1, A_2) for all θ ∈ F_{θ*}(γ) given (x, y) ∼ P_0, and
(ii) ℓ_CE(θ; (x, y)) > ℓ_AC(θ; x, A_1, A_2) for all θ ∈ F_{θ*}(γ) given (x, y) ∼ P_1,
hold with probability 1 − Ω(N^{1+ω})^{-1} over ((x, y), (A_1, A_2)) ∼ P_ξ × A².

This assumption is motivated by the empirical observation that the perturbation in ϕ_θ(·) induced by A is more uniform across P_0 and P_1 than the averaged cross-entropy, as instantiated in Figure 3.

Under Assumption 1, up to a proper scaling hyperparameter λ_AC, the consistency regularization (Equation (3)) can separate the averaged cross-entropy losses (Equation (2)) of N label-sparse and label-dense samples with probability 1 − Ω(N^ω)^{-1} by the union bound (as explained formally in Appendix A). In particular, a larger N corresponds to stronger separation between P_0 and P_1.

With Assumption 1, we introduce a minimax formulation that incentivizes the separation of label-sparse and label-dense samples automatically by introducing a flexible weight β_[i] ∈ [0, 1] that balances ℓ_CE(θ; (x_i, y_i)) and ℓ_AC(θ; x_i, A_{i,1}, A_{i,2}) for each of the n samples:

    (θ^WAC, β̄) ∈ argmin_{θ∈F_{θ*}(γ)} argmax_{β∈[0,1]^n} L^WAC(θ, β) ≜ (1/n) Σ_{i=1}^n L_i^WAC(θ, β),
    where L_i^WAC(θ, β) ≜ β_[i] · ℓ_CE(θ; (x_i, y_i)) + (1 − β_[i]) · ℓ_AC(θ; x_i, A_{i,1}, A_{i,2}).    (5)

With convex and continuous loss and regularization terms (formally in Proposition 1), Equation (5) has a saddle point where β̄ separates the label-sparse and label-dense samples under Assumption 1.

Proposition 1 (Formal proof in Appendix A). Assume that ℓ_CE(θ; (x, y)) and ℓ_AC(θ; x, A_1, A_2) are convex and continuous in θ for all (x, y) ∈ X × [K]^d and (A_1, A_2) ∼ A², and that F_{θ*}(γ) ⊂ F_θ is compact and convex. If P_0 and P_1 are n-separated (Assumption 1), then there exist β̄ ∈ {0, 1}^n and θ^WAC ∈ argmin_{θ∈F_{θ*}(γ)} L^WAC(θ, β̄) such that

    min_{θ∈F_{θ*}(γ)} L^WAC(θ, β̄) = L^WAC(θ^WAC, β̄) = max_{β∈[0,1]^n} L^WAC(θ^WAC, β).

Further, β̄ separates the label-sparse and label-dense samples, β̄_[i] = I{(x_i, y_i) ∼ P_1}, w.h.p.

In other words, for n samples drawn from a mixture of the n-separated P_0 and P_1, at the saddle point, Equation (5) automatically identifies the label-sparse samples with β̄_[i] = 0, learning more from the unsupervised consistency regularization, and the label-dense ones with β̄_[i] = 1, emphasizing the supervised averaged cross-entropy loss.
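At the heart of Equation (5) is a per-sample convex combination of the two losses. Since the inner maximization over β_[i] ∈ [0, 1] is linear in β_[i], its best response sits at a vertex of the interval, which is exactly how the saddle point separates the subpopulations. A small sketch (function names are ours):

```python
def wac_sample_loss(beta_i, ce_loss, ac_loss):
    """L_i^WAC from Equation (5): beta-weighted mix of l_CE and l_AC."""
    return beta_i * ce_loss + (1.0 - beta_i) * ac_loss

def best_response_beta(ce_loss, ac_loss):
    """Inner max over beta_i in [0, 1]: the objective is linear in beta_i,
    so the maximizer is 1 when l_CE > l_AC (label-dense under Assumption 1)
    and 0 when l_CE < l_AC (label-sparse)."""
    return 1.0 if ce_loss > ac_loss else 0.0
```

For a label-dense sample with ℓ_CE = 2.0 > ℓ_AC = 0.5, the best response is β_[i] = 1 and the sample contributes its cross-entropy; for a label-sparse sample with ℓ_CE = 0.1 < ℓ_AC = 0.5, the best response is β_[i] = 0 and only the consistency term remains.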

4. ADAPTIVELY WEIGHTED AUGMENTATION CONSISTENCY (AdaWAC)

Remark 2 (Connection to hard thresholding algorithms). The saddle point of Equation (5) is closely related to hard thresholding algorithms like Ordered SGD (Kawaguchi & Lu, 2020) and iterative trimmed loss (Shen & Sanghavi, 2019). In each iteration, these algorithms update the model only on a proper subset of training samples based on the (ranking of) current empirical risks. Compared to hard thresholding algorithms, (i) Equation (5) additionally leverages the otherwise unused samples (e.g., label-sparse samples) for unsupervised consistency regularization on data augmentations, which is known to improve generalization and feature learning even in supervised settings (Yang et al., 2022; Shen et al., 2022); and (ii) it does not require prior knowledge of the sample subpopulations (e.g., ξ for P_ξ), which is essential for hard thresholding algorithms. Equation (5) further facilitates a more flexible optimization process: as we will empirically show in Table 2, despite the close relation between Equation (5) and the hard thresholding algorithms (Remark 2), such updating strategies may be suboptimal for solving Equation (5).

Algorithm 1: Adaptively Weighted Augmentation Consistency (AdaWAC)
1: Input: training samples {(x_i, y_i)}_{i∈[n]} ∼ P_ξ^n, augmentations {(A_{i,1}, A_{i,2})}_{i∈[n]} ∼ A^{2n}, maximum number of iterations T ∈ N, learning rates η_θ, η_β > 0, pretrained initialization of the pixel-wise classifier θ_0 ∈ F_{θ*}(γ).
2: Initialize the sample weights β_0 = (1/2, . . . , 1/2) ∈ [0, 1]^n.
3: for t = 1, . . . , T do
4:   Sample i_t ∼ [n] uniformly.
5:   b ← ((β_{t−1})_[i_t], 1 − (β_{t−1})_[i_t])
6:   b_[1] ← b_[1] · exp(η_β · ℓ_CE(θ_{t−1}; (x_{i_t}, y_{i_t})))
7:   b_[2] ← b_[2] · exp(η_β · ℓ_AC(θ_{t−1}; x_{i_t}, A_{i_t,1}, A_{i_t,2}))
8:   β_t ← β_{t−1}; (β_t)_[i_t] ← b_[1] / ∥b∥_1
9:   θ_t ← θ_{t−1} − η_θ · [(β_t)_[i_t] · ∇_θ ℓ_CE(θ_{t−1}; (x_{i_t}, y_{i_t})) + (1 − (β_t)_[i_t]) · ∇_θ ℓ_AC(θ_{t−1}; x_{i_t}, A_{i_t,1}, A_{i_t,2})]
10: end for

Inspired by the breakthrough made by Sagawa et al. (2020) in the distributionally-robust optimization (DRO) setting, where gradient updating on weights is shown to enjoy better convergence guarantees than hard thresholding, in Algorithm 1 we introduce an adaptive weighting algorithm for solving Equation (5) based on online mirror descent. In contrast to the commonly used stochastic gradient descent (SGD), the flexibility of online mirror descent in choosing the associated norm space not only allows gradient updates on sample weights but also grants distinct learning dynamics to the sample weights β_t and the model parameters θ_t, which leads to the following convergence guarantee.

Proposition 2 (Formally in Proposition 3, proof in Appendix B, assumptions instantiated in Example 1). Assume that ℓ_CE(θ; (x, y)) and ℓ_AC(θ; x, A_1, A_2) are convex and continuous in θ for all (x, y) ∈ X × [K]^d and (A_1, A_2) ∼ A², and that F_{θ*}(γ) ⊂ F_θ is convex and compact. If there exist (i) C_{θ,*} > 0 such that (1/n) Σ_{i=1}^n ∥∇_θ L_i^WAC(θ, β)∥²_{F,*} ≤ C²_{θ,*} and (ii) C_{β,*} > 0 such that (1/n) Σ_{i=1}^n max{ℓ_CE(θ; (x_i, y_i)), ℓ_AC(θ; x_i, A_{i,1}, A_{i,2})}² ≤ C²_{β,*} for all θ ∈ F_{θ*}(γ) and β ∈ [0, 1]^n, then with η_θ = η_β = √(2 / (5T (γ² C²_{θ,*} + 2n C²_{β,*}))), Algorithm 1 provides

    E[ max_{β∈[0,1]^n} L^WAC(θ̄_T, β) − min_{θ∈F_{θ*}(γ)} L^WAC(θ, β̄_T) ] ≤ 2 √(5 (γ² C²_{θ,*} + 2n C²_{β,*}) / T),

where θ̄_T = (1/T) Σ_{t=1}^T θ_t and β̄_T = (1/T) Σ_{t=1}^T β_t.

In addition to the convergence guarantee, Algorithm 1 also demonstrates superior performance over hard thresholding algorithms on segmentation problems in practice (Table 2). An intuitive explanation is that, instead of filtering out all label-sparse samples via hard thresholding, adaptive weighting allows the model to learn from some sparse labels in the early epochs while smoothly down-weighting ℓ_CE for these samples, since learning sparse labels tends to be easier (Remark 1). With the learned model tested on a mixture of label-sparse and label-dense samples, learning sparse labels at the early stage is crucial for accurate segmentation.
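Lines 4-9 of Algorithm 1 can be sketched as a single update on one sampled index: an exponentiated (mirror-ascent) update on the two-way weight, followed by a weighted gradient step on θ. The sketch below uses NumPy and abstract per-sample loss/gradient callables in place of a network; all names are illustrative, not the authors' implementation.

```python
import numpy as np

def adawac_step(theta, beta, i, ce, ac, grad_ce, grad_ac, eta_theta, eta_beta):
    """One AdaWAC iteration (Algorithm 1, lines 5-9) on sample i.

    ce(theta, i), ac(theta, i): per-sample loss values l_CE and l_AC;
    grad_ce, grad_ac: their gradients with respect to theta.
    """
    # Lines 5-7: multiplicative update of the (CE, AC) weight pair.
    b = np.array([beta[i] * np.exp(eta_beta * ce(theta, i)),
                  (1.0 - beta[i]) * np.exp(eta_beta * ac(theta, i))])
    # Line 8: renormalize so the pair stays on the probability simplex.
    beta = beta.copy()
    beta[i] = b[0] / b.sum()
    # Line 9: weighted gradient step on the model parameters.
    theta = theta - eta_theta * (beta[i] * grad_ce(theta, i)
                                 + (1.0 - beta[i]) * grad_ac(theta, i))
    return theta, beta
```

Note the direction of the weight update: the CE weight of sample i grows when its ℓ_CE exceeds its ℓ_AC and shrinks otherwise, which is precisely the smooth down-weighting of easy, sparse labels described above.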

5. EXPERIMENTS

In this section, we investigate the proposed AdaWAC algorithm (Algorithm 1) on different medical image segmentation tasks with different UNet-like architectures. We first demonstrate the performance improvements brought by AdaWAC in terms of sample efficiency and robustness to subpopulation shift (Table 1). Then, we verify the empirical advantage of AdaWAC over the closely related hard thresholding algorithms discussed in Remark 2 (Table 2). Our ablation study (Table 4) further illustrates the indispensability of both sample reweighting and consistency regularization, the deliberate combination of which leads to the superior performance of AdaWAC.

5.1. SEGMENTATION PERFORMANCE OF AdaWAC WITH TRANSUNET

Segmentation on Synapse. Figure 2 visualizes the segmentation predictions on 6 Synapse test slices given by models trained via AdaWAC (ours) and via the baseline (ERM+SGD) with TransUNet (Chen et al., 2021). We observe that AdaWAC provides more accurate predictions on the segmentation boundaries and captures small organs better than the baseline.

Visualization of AdaWAC. As shown in Figure 3, with ℓ_CE(θ_t; (x_i, y_i)) (Equation (2)) of label-sparse versus label-dense slices weakly separated in the early epochs, the model further learns to distinguish ℓ_CE(θ_t; (x_i, y_i)) of label-sparse/label-dense slices during training. By contrast, ℓ_AC(θ_t; x_i, A_{i,1}, A_{i,2}) (Equation (3)) remains mixed for all slices throughout the entire training process. As a result, the CE weights of label-sparse slices are much smaller than those of label-dense ones, pushing AdaWAC to learn more image representations but less pixel classification for slices with sparse labels, and more pixel classification for slices with dense labels.

Sample efficiency and distributional robustness. We further compare training on the full Synapse training set (full) against three of its subsets (Table 1). Specifically, (i) half-slice contains only the slices with even indices in each volume; (ii) half-vol consists of 9 volumes uniformly sampled from the total 18 volumes in full, where different volumes tend to have distinct ξs (i.e., ratios of label-dense samples); (iii) half-sparse takes the first half of the slices in each volume, most of which tend to be label-sparse (i.e., the ξs are made to be small). As shown in Table 1, the model trained with AdaWAC on half-slice generalizes as well as a baseline model trained on full, if not better. Moreover, the half-vol and half-sparse experiments illustrate the robustness of AdaWAC to subpopulation shift. Furthermore, such sample efficiency and distributional robustness of AdaWAC extend to the more widely used UNet architecture; we defer the detailed results and discussions on UNet to Appendix E.1.

Comparison with hard thresholding algorithms.
Table 2 illustrates the empirical advantage of AdaWAC over hard thresholding algorithms, as suggested in Remark 2. In particular, we consider the following baselines: (i) trim-train learns only from slices with at least one non-background pixel and trims the rest in each iteration on the fly; (ii) trim-ratio ranks the cross-entropy losses ℓ_CE(θ_t; (x_i, y_i)) in each iteration (mini-batch) and trims the samples with the lowest cross-entropy losses at a fixed ratio, namely the ratio of all-background slices in the full training set (1 − 1280/2211 ≈ 0.42); (iii) ACR further incorporates the augmentation consistency regularization directly via addition of ℓ_AC(θ_t; x_i, A_{i,1}, A_{i,2}) without reweighting; (iv) pseudo-AdaWAC simulates the sample weights β at the saddle point, learning via ℓ_CE(θ_t; (x_i, y_i)) on slices with at least one non-background pixel and via ℓ_AC(θ_t; x_i, A_{i,1}, A_{i,2}) otherwise. We notice that naive incorporation of ACR brings little observable boost to the hard-thresholding methods. Therefore, the deliberate combination via reweighting in AdaWAC is essential for the performance improvement.

Segmentation on ACDC. Performance improvements granted by AdaWAC are also observed on the ACDC dataset (Table 3). We defer detailed visualization of ACDC segmentation to Appendix E.
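For concreteness, the trim-ratio baseline amounts to dropping a fixed fraction of the lowest-CE-loss samples in each mini-batch before the update. A minimal sketch (our own naming, not the authors' code):

```python
import numpy as np

def trim_ratio_keep(ce_losses, trim_ratio):
    """Indices kept by one trim-ratio step: discard the trim_ratio fraction
    of samples with the lowest cross-entropy losses."""
    n = len(ce_losses)
    n_trim = int(np.floor(trim_ratio * n))
    order = np.argsort(ce_losses)      # ascending by CE loss
    return np.sort(order[n_trim:])     # surviving sample indices
```

With a trim ratio of about 0.42, as in the text, roughly the all-background slices are discarded at every step; unlike AdaWAC, the information in their image inputs is lost entirely.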

5.2. ABLATION STUDY

On the influence of consistency regularization. To illustrate the role of consistency regularization in AdaWAC, we consider the reweight-only scenario with λ_AC = 0, so that ℓ_AC(θ_t; x_i, A_{i,1}, A_{i,2}) ≡ 0 and b_[2] (Algorithm 1, line 7) remains intact. With zero consistency regularization in AdaWAC, reweighting alone brings little improvement (Table 4).

On the influence of sample reweighting. We then investigate the effect of sample reweighting under different reweighting learning rates η_β (recall Algorithm 1): (i) ACR-only for η_β = 0 (equivalent to the naive addition of ℓ_AC(θ_t; x_i, A_{i,1}, A_{i,2})), (ii) AdaWAC-0.01 for η_β = 0.01, and (iii) AdaWAC-1.0 for η_β = 1.0. As Table 4 implies, when reweighting is removed from AdaWAC, augmentation consistency regularization alone improves DSC only slightly, from 76.28 (baseline) to 77.89 (ACR-only), whereas AdaWAC boosts DSC to 79.12 (AdaWAC-1.0) with a proper choice of η_β.

6. DISCUSSION

In this paper, we exploit the non-uniformity of labels commonly observed in volumetric medical image segmentation via AdaWAC, a deliberate combination of adaptive weighting and augmentation consistency regularization. By casting the separation between sparse and dense segmentation labels as a subpopulation shift in the label distribution, we leverage unsupervised consistency regularization on the encoder-layer outputs (of UNet architectures) as a natural reference to distinguish label-sparse and label-dense samples via comparisons against the supervised averaged cross-entropy losses. We formulate such comparisons as the weighted augmentation consistency (WAC) regularization problem and propose an adaptive weighting scheme, AdaWAC, for the iterative and smooth separation of samples from different subpopulations with a convergence guarantee. Our experiments demonstrate the empirical effectiveness of AdaWAC not only in improving segmentation performance and sample efficiency but also in enhancing robustness to the subpopulation shift in labels.



Footnotes:
1. Although the sparsity of non-background pixels in the segmentation label is a key feature of label-sparse samples (as the name suggests), the unknown cut-off on such sparsity leaves it only a sufficient condition for the rapid convergence of the cross-entropy loss (Figure 1). Instead of drawing the distinction via the sparsity of non-background pixels, we formalize a natural separation between label-sparse and label-dense samples in Assumption 1, based on which our algorithm can distinguish different samples spontaneously.
2. We note that although Assumption 1 can be rather strong, it is only required for the high-probability separation guarantee of label-sparse and label-dense samples in Proposition 1, not for the adaptive weighting algorithm introduced in Section 4 or for the experiments in practice.
3. With pretrained initialization, we assume that the optimization algorithm always probes within F_{θ*}(γ).
4. Following convention, we use * in the subscript to denote dual spaces. For instance, recalling the parameter space F_θ characterized by the norm ∥·∥_F from Section 2.1, we use ∥·∥_{F,*} to denote its dual norm, while C_{θ,*} and C_{β,*} upper bound the dual norms of the gradients with respect to θ and β.
5. We release our code anonymously at https://anonymous.4open.science/r/adawac-F5F8.
6. Synapse: https://www.synapse.org/#!Synapse:syn3193805/wiki/217789
7. ACDC: https://www.creatis.insa-lyon.fr/Challenge/acdc/
8. Such sampling is equivalent to doubling the time interval between two consecutive scans, or halving the scanning frequency in practice, resulting in a halved sample size.



Figure 1: Evolution of cross-entropy losses versus consistency regularization terms for slices across one training volume (Case 40) in the Synapse dataset (Section 5) during training.

Experiment setup. We conduct experiments on two volumetric medical image segmentation tasks: abdominal CT segmentation on the Synapse multi-organ dataset (Synapse) and cine-MRI segmentation on the Automated Cardiac Diagnosis Challenge dataset (ACDC), with two UNet-like architectures: TransUNet (Chen et al., 2021) and UNet (Ronneberger et al., 2015) (deferred to Appendix E.1). For the main experiments with TransUNet in Section 5, we follow the official implementation in (Chen et al., 2021) and use ERM+SGD as the baseline. We evaluate segmentations with two standard metrics: the average Dice-similarity coefficient (DSC) and the average 95th-percentile Hausdorff distance (HD95). Dataset and implementation details are deferred to Appendix D. Given the sensitivity of medical image semantics to perturbations, our experiments only involve simple augmentations (i.e., rotation and mirroring) adapted from (Chen et al., 2021).

Figure 2: Visualization of segmentation predictions against the ground truth (in grayscale) on Synapse. Top to bottom: ground truth, ours (AdaWAC), baseline.

Figure 3: ℓ_CE(θ_t; (x_i, y_i)) (top), CE weights β_t (middle), and ℓ_AC(θ_t; x_i, A_{i,1}, A_{i,2}) (bottom) over the entire Synapse training process. The x-axis indexes slices 0-2211; the y-axis enumerates epochs 0-150. Individual volumes (cases) are partitioned by black lines, while the purple lines separate slices with/without non-background pixels.

Table 1: AdaWAC with TransUNet trained on the full Synapse and its subsets.

Table 2: AdaWAC versus hard thresholding algorithms with TransUNet on Synapse.

Table 3: AdaWAC with TransUNet trained on ACDC.

Table 4: Ablation study of AdaWAC with TransUNet trained on Synapse.

