MODEL PATCHING: CLOSING THE SUBGROUP PERFORMANCE GAP WITH DATA AUGMENTATION

Abstract

Classifiers in machine learning are often brittle when deployed. Particularly concerning are models with inconsistent performance on specific subgroups of a class, e.g., exhibiting disparities in skin cancer classification in the presence or absence of a spurious bandage. To mitigate these performance differences, we introduce model patching, a two-stage framework for improving robustness that encourages the model to be invariant to subgroup differences and to focus on class information shared by subgroups. Model patching first models subgroup features within a class and learns semantic transformations between them, and then trains a classifier with data augmentations that deliberately manipulate subgroup features. We instantiate model patching with CAMEL, which (1) uses a CycleGAN to learn the intra-class, inter-subgroup augmentations, and (2) balances subgroup performance using a theoretically motivated subgroup consistency regularizer, accompanied by a new robust objective. We demonstrate CAMEL's effectiveness on 3 benchmark datasets, with reductions in robust error of up to 33% relative to the best baseline. Lastly, CAMEL successfully patches a model that fails due to spurious features on a real-world skin cancer dataset.

We contribute a theoretical analysis (Section 3) to motivate our end-to-end framework. Our analysis codifies the distributional assumptions underlying the class-subgroup hierarchy and motivates our new consistency regularizer, which has a simple information-theoretic interpretation under this framework. First, we introduce a natural model for the data generating process that decouples an example from its subgroup. Under this model, we prove that our consistency loss in Stage 2, when applied to subgroup augmentations from Stage 1, bounds the mutual information between the classifier output and the subgroup labels. Thus, training with our end-to-end framework forces the classifier to be invariant to subgroup-specific features.
We conduct an extensive empirical study (Section 4) that validates CycleGAN Augmented Model Patching (CAMEL)'s ability to improve subgroup invariance and robustness. We first evaluate CAMEL on a controlled MNIST setup, where it cuts robust error rate to a third of other approaches while learning representations that are far more invariant, as measured by mutual information estimates. On two machine learning benchmarks CelebA and Waterbirds, CAMEL consistently outperforms state-of-the-art approaches that rely on robust optimization, with reductions in subgroup performance gap by up to 10%. Next, we perform ablations on each stage of our framework: (i) replacing the CycleGAN with state-of-the-art heuristic augmentations worsens the subgroup performance gap by 3.35%; (ii) our subgroup consistency regularizer improves robust accuracy by up to 2.5% over prior consistency losses. As an extension, we demonstrate that CAMEL can be used in combination with heuristic augmentations, providing further gains in robust accuracy of 1.5%. Besides CycleGANs, we show that other GAN-based augmentation methods can also be made significantly more robust by combining them with Stage 2 of model patching. Lastly, on the challenging real-world skin cancer dataset ISIC, CAMEL improves robust accuracy by 11.7% compared to a group robustness baseline. Our results suggest that model patching is a promising direction for improving subgroup robustness in real applications. Code for reproducing our results is available on GitHub.

1. INTRODUCTION

Machine learning models typically optimize for average performance and, when deployed, can yield inaccurate predictions on important subgroups of a class. For example, practitioners have noted that on the ISIC skin cancer detection dataset (Codella et al., 2018), classifiers are more accurate on images of benign skin lesions with visible bandages than on benign images where no bandage is present (Bissoto et al., 2019; Rieger et al., 2019). This subgroup performance gap is an undesirable consequence of a classifier's reliance on subgroup-specific features, e.g. spuriously associating colorful bandages with a benign cancer class (Figure 1). A common strategy to side-step this issue is to use manual data augmentation to erase the differences between subgroups, e.g., using Photoshop (Winkler et al., 2019) or image tools (Rieger et al., 2019) to remove markings on skin cancer data before retraining a classifier. However, hand-crafting these augmentations may be impossible if the subgroup differences are difficult to express manually. Ideally, we would automatically learn the features differentiating the subgroups of a class, and then encourage a classifier to be invariant to these features when making its prediction. To this end, we introduce model patching, a framework that encapsulates this solution in two stages:
• Learn inter-subgroup transformations. Isolate features that differentiate subgroups within a class, learning inter-subgroup transformations between them. These transformations change an example's subgroup identity but preserve the class label.
• Train to patch the model. Leverage the transformations as controlled data augmentations that manipulate subgroup features, encouraging the classifier to be robust to their variation.
In the first stage of model patching (Section 2.1), we learn, rather than specify, the differences between the subgroups of a class. We assume that these subgroups are known to the user, e.g.
this is common when users perform error analysis (Oakden-Rayner et al., 2019). Our key insight here is to learn these differences as inter-subgroup transformations that modify the subgroup membership of examples while preserving class membership. Applying these semantic transformations as data augmentations in the second stage allows us to generate "imagined" versions of an example in the other subgroups of its class. This contrasts with conventional data augmentation, where heuristics such as rotations, flips, MixUp or CutOut (DeVries & Taylor, 2017; Zhang et al., 2017) are hand-crafted rather than learned. While these heuristics have been shown to improve robustness (Hendrycks et al., 2019), the invariances they target are not well understood. Even when augmentations are learned (Ratner et al., 2017a), they are used to address data scarcity rather than to manipulate examples to improve robustness in a prescribed way. Model patching is the first framework for data augmentation that directly targets subgroup robustness.

Figure 1: A vanilla model trained on a skin cancer dataset exhibits a subgroup performance gap between images of malignant cancers with and without colored bandages. GradCAM (Selvaraju et al., 2017) illustrates that the vanilla model spuriously associates the colored spot with benign skin lesions. With model patching, the malignancy is predicted correctly for both subgroups.

The goal of the second stage (Section 2.2) is to appropriately use the transformations to remove the classifier's dependence on subgroup-specific features. We introduce two algorithmic innovations that target subgroup robustness: (i) a subgroup robust objective and (ii) a subgroup consistency regularizer. Our subgroup robust objective extends prior work on group robustness (Sagawa et al., 2020) to our subgroup setting, where classes and subgroups form a hierarchy (Figure 2, left).
Our new subgroup consistency regularizer constrains the predictions on original and augmented examples to be similar. While recent work on consistency training (Hendrycks et al., 2019; Xie et al., 2019) has been empirically successful in constructing models that are robust to perturbations, our consistency loss carries theoretical guarantees on the model's robustness. We note that our changes are easy to add on top of standard classifier training. 

2. CAMEL: CYCLEGAN AUGMENTED MODEL PATCHING

In this section, we walk through CAMEL's two-stage framework (Figure 2) in detail. In Section 2.1, we introduce Stage 1 of model patching, learning class-conditional transformations between subgroups. In Section 2.2, Stage 2 uses these transformations as black-box augmentations to train a classifier using our new subgroup robust objective (Section 2.2.1) and consistency regularizer (Section 2.2.2). Section 3 outlines our theoretical analysis of the invariance guarantees of our method. A glossary for all notation is included in Appendix A.

Setup. We consider a classification problem where X ⊂ R^n is the input space and Y = {1, 2, ..., C} is a set of labels over C classes. Each class y ∈ Y may be divided into disjoint subgroups Z_y ⊆ Z. Jointly, there is a distribution P over examples, class labels, and subgroup labels (X, Y, Z). Given a dataset {(x_i, y_i, z_i)}_{i=1}^m, our goal is to learn a class prediction model f_θ : X → Δ^C parameterized by θ, where Δ^C denotes a probability distribution over Y.
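To make the setup concrete, the following is a minimal, illustrative sketch (the names and toy data are ours, not from the paper's released code) of a dataset of (x, y, z) triples indexed by the class-subgroup hierarchy:

```python
from collections import defaultdict

# Toy dataset of (x, y, z) triples: input, class label y, subgroup label z.
# The subgroups Z_y of a class y are simply the z values observed with y.
dataset = [
    ("img0", "benign", "bandage"),
    ("img1", "benign", "no_bandage"),
    ("img2", "malignant", "no_bandage"),
]

def subgroups_by_class(data):
    """Map each class y to its set of subgroups Z_y."""
    z_of = defaultdict(set)
    for _, y, z in data:
        z_of[y].add(z)
    return dict(z_of)

Z = subgroups_by_class(dataset)
```

Here `Z["benign"]` recovers the two subgroups of the benign class, mirroring how the ISIC bandage subgroups are described later in Section 4.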

2.1. STAGE 1: LEARNING INTER-SUBGROUP TRANSFORMATIONS

The goal of the first stage is to learn transformations F_{z→z'} : X_z → X_{z'} that translate examples in subgroup z to subgroup z', for every pair of subgroups z, z' ∈ Z_y in the same class y. Recent work has made impressive progress on such cross-domain generative models, where examples from one domain are translated to another, ideally preserving shared semantics while changing only domain-specific features. In this work, we use the popular CycleGAN model (Zhu et al., 2017) to learn mappings between pairs of subgroups, although we show that it is possible to substitute other models (see Section 4.2.4). Given datasets {x_i^z}_{i=1}^p, {x_i^{z'}}_{i=1}^{p'} from a pair of subgroups z, z' ∈ Z_y, we train a CycleGAN F_{z→z'} to transform between them. When classes have more than two subgroups, pairwise models can be trained between subgroups, or multi-domain models such as StarGAN (Choi et al., 2018) can be used. We include a review of CycleGANs in Appendix C.1. Given these transformations {F_{z→z'}}_{z,z'∈Z_y}, we generate augmented data for every training example (x, y, z) by passing it through all F_{z→z'}, z' ∈ Z_y. We denote these generated examples x̂_{Z_y} := {x̂_{z'}}_{z'∈Z_y}, where x̂_{z'} = F_{z→z'}(x). For convenience, k denotes the number of subgroups |Z_y|. Prior work that uses data augmentation to improve robustness has mostly relied on heuristic augmentations and focused on robustness to out-of-distribution examples (Hendrycks et al., 2019). Throughout, P_z denotes the distribution of examples (x, y) for subgroup z, and α_θ(x, y) = I[(arg max f_θ(x)) = y] denotes a correct prediction on an example (Table 1).
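As a minimal sketch of this augmentation step (with hypothetical string-based stand-ins for the trained CycleGAN generators, and using the identity map within an example's own subgroup as a simplification), generating x̂_{Z_y} amounts to applying every learned transformation for the example's class:

```python
def augment_example(x, y, z, Z_y, F):
    """Return {z2: F_{z->z2}(x)}: one "imagined" example per subgroup of class y.

    F maps subgroup pairs (z, z2) to transformation functions. We use the
    identity for z2 == z as a simplification of the full pipeline.
    """
    return {z2: (x if z2 == z else F[(z, z2)](x)) for z2 in Z_y}

# Stand-in "generators" for a two-subgroup class (the real ones are CycleGANs).
F = {
    ("bandage", "no_bandage"): lambda x: x + "-without-bandage",
    ("no_bandage", "bandage"): lambda x: x + "-with-bandage",
}
x_hats = augment_example("img0", "benign", "bandage",
                         ["bandage", "no_bandage"], F)
```

The resulting dictionary plays the role of x̂_{Z_y} in Stage 2, where the classifier's predictions on these k examples are tied together by the consistency regularizer.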

Table 1: Accuracy metrics and loss functions for ERM, GDRO, and our SGDRO objective.

ERM. Metric: aggregate accuracy E_P[α_θ(x, y)]. Loss: E_P[ℓ(f_θ(x), y)].
GDRO. Metric: robust accuracy min_{z∈Z} E_{P_z}[α_θ(x, y)]. Loss: max_{z∈Z} E_{P_z}[ℓ(f_θ(x), y)].
SGDRO. Metric: subgroup gap |max_{z∈Z_y} E_{P_z}[α_θ(x, y)] − min_{z∈Z_y} E_{P_z}[α_θ(x, y)]|. Loss: E_{y∈Y}[max_{z∈Z_y} E_{P_z}[ℓ(f_θ(x), y)]].

In contrast to these empirical studies, we learn to transform examples rather than specifying augmentations directly, and focus on improving worst-case subgroup robustness. We emphasize that while others have used cross-domain generative models for data augmentation, our novelty lies in targeting invariance to subgroup features using this style of augmentation. Past work has focused on domain adaptation (Huang et al., 2018), few-shot learning (Antoniou et al., 2017), and data scarcity (Bowles et al., 2018; Ratner et al., 2017b), but has not attempted to explicitly control the invariance of the classifier using the learned augmentations. As we describe in our theoretical analysis (Section 3), our use of cross-domain models is a natural consequence of the class-subgroup setting.

2.2. STAGE 2: SUBGROUP ROBUSTNESS WITH DATA AUGMENTATION

The goal of the second stage is to learn a classifier f θ on both the original and augmented data from Stage 1, using our subgroup robust objective (Section 2.2.1) and consistency regularizer (Section 2.2.2). Our robustness objective targets worst-case subgroup robustness, while our consistency regularizer forces the learned classifier to be invariant to subgroup features. Where relevant, we include discussion here on differences to prior work, with an extended related work in Appendix B.

2.2.1. A SUBGROUP ROBUSTNESS OBJECTIVE

We review two established objectives for training classifiers with their associated metrics and loss functions, and introduce our new objective to target subgroup robustness (cf. Table 1 ). Prior work: Empirical Risk Minimization (ERM). The usual training goal is to maximize the aggregate accuracy, optimized using the empirical risk w.r.t. a proxy loss function (Table 1, top) . Prior work: Group Robustness (GDRO). In our setting, aggregate performance is too coarse a measure of risk, since classes have finer-grained groups of interest. This can be accounted for by optimizing the worst-case performance over these groups. Letting P z denote the conditional distribution of examples associated with subgroup z ∈ Z, the robust accuracy can be quantified by measuring the worst-case performance among all groups. This can be optimized by minimizing the corresponding group robust risk (Table 1 , middle right). A stochastic algorithm for this group distributionally robust optimization (GDRO) objective was recently proposed (Sagawa et al., 2020) . Class-conditional Subgroup Robustness (SGDRO). The GDRO objective treats group structure as a flat hierarchy. While this approach accounts for worst-case subgroup performance, it loses the class-subgroup hierarchy of our setting. Tailored to this, we create the SGDRO training objective (Table 1 , bottom right) to optimize class-conditional worst-case subgroup robustness, aggregated over all classes (Figure 2 right). To measure subgroup robustness, we define the subgroup performance gap (Table 1 , bottom left) for a class as the gap between its best and worst performing subgroups. We note that both GDRO and SGDRO assume knowledge of the subgroups, which is a standard assumption in group robustness (Arjovsky et al., 2019; Ganin et al., 2016; Sagawa et al., 2020) .
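The three objectives in Table 1 can be sketched in a few lines. The following is an illustrative numpy version operating on per-example losses (the paper optimizes these objectives with a stochastic algorithm in practice; the function names are ours):

```python
import numpy as np

def erm_loss(losses):
    """ERM: average loss over all examples."""
    return float(losses.mean())

def gdro_loss(losses, z):
    """GDRO: worst-case average loss over all subgroups (flat hierarchy)."""
    return max(float(losses[z == g].mean()) for g in np.unique(z))

def sgdro_loss(losses, y, z):
    """SGDRO: worst-case subgroup loss within each class, averaged over classes."""
    worst_per_class = []
    for c in np.unique(y):
        in_c = y == c
        worst_per_class.append(
            max(float(losses[in_c & (z == g)].mean()) for g in np.unique(z[in_c]))
        )
    return float(np.mean(worst_per_class))

losses = np.array([0.1, 0.9, 0.2, 0.4])
y = np.array([0, 0, 1, 1])   # class labels
z = np.array([0, 1, 0, 1])   # subgroup labels
```

On this toy batch, GDRO and SGDRO happen to coincide, but they differ in general: GDRO takes the maximum over all subgroups jointly, while SGDRO takes the maximum within each class and then averages across classes.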

2.2.2. SUBGROUP INVARIANCE USING A CONSISTENCY REGULARIZER

Standard models can learn to rely on spurious subgroup features when making predictions. Subgroup consistency regularization targets this problem by enforcing consistency on subgroup-augmented data, encouraging the classifier to become invariant to subgroup features. Recall that Stage 2 connects to Stage 1 by receiving augmented data x̂_{Z_y}, representing "imagined" versions of an example x in all other subgroups of its class y. We define the self-consistency loss L_s and translation-consistency loss L_t as follows, where m = (1/k) Σ_{z'∈Z_y} f_θ(x̂_{z'}) denotes the average output distribution on the augmented examples:

L_s(x, x̂_{Z_y}; θ) = (1/k) Σ_{z'∈Z_y} KL(f_θ(x̂_{z'}) ‖ m)    (1)
L_t(x, x̂_{Z_y}; θ) = KL(f_θ(x) ‖ m)    (2)

The self-consistency loss is the more important component, encouraging predictions on augmented examples to be consistent with each other. As these augmented examples correspond to one "imagined" example per subgroup, self-consistency controls dependence on subgroup features. Translation consistency additionally forces predictions on the original example to be similar to those on the average CycleGAN-translated examples, ignoring potential artifacts that the CycleGANs generate. We note that consistency losses have been used before, e.g. UDA (Xie et al., 2019) and AugMix (Hendrycks et al., 2019) use different combinations of KL divergences chosen empirically. Our regularization (1) is tailored to the model patching setting, where it has a theoretical interpretation relating to subgroup invariance (Section 3). We show empirical improvements over these alternate consistency losses in Section 4.2.2.

Overall Objective. The total consistency loss averages over all examples,

L_c(θ) = (1/2) E_{(x,y)∼P}[ L_s(x, x̂_{Z_y}; θ) + L_t(x, x̂_{Z_y}; θ) ].    (3)

Combining our SGDRO robust objective and the consistency loss with the consistency strength hyper-parameter λ yields the final objective, L_CAMEL(θ) = L_SGDRO(θ) + λ L_c(θ).
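As a minimal numerical sketch (illustrative names; softmax outputs represented as numpy arrays with strictly positive entries), the self- and translation-consistency losses above can be computed as:

```python
import numpy as np

def kl(p, q):
    """KL divergence between two discrete distributions (strictly positive)."""
    return float(np.sum(p * np.log(p / q)))

def consistency_losses(f_x, f_x_hats):
    """Self-consistency L_s and translation-consistency L_t.

    f_x: model prediction on the original example.
    f_x_hats: predictions on the k subgroup-augmented examples.
    """
    m = np.mean(f_x_hats, axis=0)                   # average augmented prediction
    L_s = float(np.mean([kl(f, m) for f in f_x_hats]))
    L_t = kl(f_x, m)
    return L_s, L_t

f_x = np.array([0.7, 0.3])
f_hats = [np.array([0.6, 0.4]), np.array([0.8, 0.2])]
L_s, L_t = consistency_losses(f_x, f_hats)
```

In this toy example the prediction on the original example equals the mean of the augmented predictions, so L_t vanishes while L_s penalizes the disagreement between the two augmented predictions.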

3. AN INFORMATION THEORETIC ANALYSIS OF SUBGROUP INVARIANCE

We introduce a framework to analyze our end-to-end approach, showing that it induces subgroup invariances in the model's features. First, we review a common framework for treating robustness over discrete groups that aims to create invariances, or independences between the learned model's features φ(X) and groups Z. We then define a new model for the distributional assumptions underlying the subgroup setting, which allows us to analyze stronger invariance guarantees by minimizing a mutual information (MI) upper bound. Formal definitions and full proofs are deferred to Appendix C. Prior work: Class-conditioned Subgroup Invariance. Prior work (Ganin et al., 2016; Li et al., 2018; Long et al., 2018) uses adversarial training to induce subgroup invariances of the form (φ(X) ⊥ Z) | Y , so that within each class, the model's features φ(X) appear the same across subgroups Z. We call this general approach class-conditional domain adversarial training (CDAT). Although these works are motivated by other theoretical properties, we show that they induce the above invariance by minimizing a variational lower bound of the corresponding mutual information. Lemma 1. CDAT minimizes a lower bound on the mutual information I(φ(X); Z | Y ). Since the model's features matter only insofar as they affect the output, for the rest of this discussion we assume without loss of generality that φ(X) = Ŷ is simply the model's prediction. A Natural Distributional Assumption: Subgroup Invariance on Coupled Sets. Although prior work generally has no requirements on how the data X among the groups Z relate to each other, we note that a common implicit assumption is that there is a "correspondence" between examples among different groups. We codify this distributional assumption explicitly. Informally, we say that every example x belongs to a coupled set [x], containing one example per subgroup in its (x's) class (Figure 3 ) (Appendix C.3, Definition 1). 
[X] is the random variable for coupled sets, i.e. it denotes sampling an example x and looking at its coupled set. Intuitively, the elements x' ∈ [x] represent hidden examples in the world that have identical class features to x and differ only in their subgroup features. These hidden examples may not be present in the train distribution, and model patching "hallucinates" them, allowing models to directly learn relevant class features. Our goal is then invariance on coupled sets, Ŷ ⊥ Z | [X]. Conditioning on the coupled set [X] is finer-grained than conditioning on the class label Y (a coupled set determines its class), so this is a stronger notion of invariance than CDAT permits. Additionally, the losses from the CycleGAN (Stage 1) and consistency regularizer (Stage 2) combine to form an upper bound on the mutual information rather than a lower bound, so that optimizing our loss is more appropriate.

Theorem 1. For a model f_θ with outputs Ŷ, the MI I(Ŷ; Z | [X]) equals the expected Jensen-Shannon divergence (JSD) of predictions on coupled sets, E_{[x]∼[X]}[JSD({f_θ(x')}_{x'∈[x]})]. In the case of k = 2 subgroups per class, this can be upper bounded by the CycleGAN and consistency losses,

I(Ŷ; Z | [X]) ≤ E_{(x,y)∼(X,Y)}[ ( L_s(x, x̂_{Z_y}; θ)^{1/2} + Σ_{z'∈Z_y} L_CG^{z'}(x; θ)^{1/2} )^2 ].

In particular, the global optimum of the trained CAMEL model induces Ŷ ⊥ Z | [X].

The main idea is that the conditional MI I(Ŷ; Z | [X]) can be related to the model's predictions on all elements of a coupled set [x] using properties of the JSD. However, since we do not have true coupled sets, the consistency loss (3) only minimizes a proxy for this JSD using the augmentations x̂_{Z_y}. Using standard GAN results, the divergence between the true and augmented distributions can be bounded by the loss of a discriminator, and the result follows from metric properties of the JSD. Thus, the CycleGAN augmentations (Stage 1) and our consistency regularizer (Stage 2) combine to provide an upper bound on our MI objective, tying the model patching framework together neatly.
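Under our reading of Theorem 1, the mutual information can be estimated by averaging the generalized JSD of the model's predictions over coupled sets. A small illustrative numpy sketch (names are ours):

```python
import numpy as np

def entropy(p):
    """Shannon entropy of a discrete distribution (strictly positive)."""
    return float(-np.sum(p * np.log(p)))

def jsd(dists):
    """Generalized Jensen-Shannon divergence of a list of distributions:
    entropy of the mean minus the mean of the entropies."""
    dists = np.asarray(dists, dtype=float)
    m = dists.mean(axis=0)
    return entropy(m) - float(np.mean([entropy(p) for p in dists]))

# One "coupled set" with k = 2 subgroups: if the model predicts identically
# on both elements, the JSD (and hence the MI contribution) is zero; the
# more the predictions disagree, the larger the JSD (at most ln 2 for k = 2).
same = jsd([[0.9, 0.1], [0.9, 0.1]])
diff = jsd([[0.9, 0.1], [0.1, 0.9]])
```

Driving this quantity to zero for every coupled set is exactly the invariance Ŷ ⊥ Z | [X] that the theorem describes.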

4. EXPERIMENTS

We demonstrate that CAMEL can take advantage of the learned subgroup augmentations and consistency regularizer to improve robust and aggregate accuracy, while reducing the subgroup performance gap (defined in Table 1). We validate CAMEL against both standard training with no subgroup knowledge (ERM) and other baselines aimed at improving group robustness across 4 datasets. We also conduct extensive ablations to isolate the benefit of the learned inter-subgroup transformations over standard augmentation, and of the subgroup consistency regularizer over prior consistency losses.

Datasets. We briefly describe the datasets used, with details available in Appendix D.1.
MNIST-Correlation. We mix data from MNIST (LeCun et al., 1998) and MNIST-Corrupted (Mu & Gilmer, 2019) to create a controlled setup for analyzing subgroup performance. Digit parity classes Y ∈ {even, odd} are divided into subgroups Z ∈ {clean, zigzag} from MNIST and MNIST-Corrupted respectively. Y and Z are highly correlated, so that most even (odd) digits are clean (zigzag).
CelebA-Undersampled. Following Sagawa et al. (2020), we classify hair color Y ∈ {non-blonde, blonde} in the CelebA dataset (Liu et al., 2015). Subgroups are based on gender Z ∈ {female, male}. We subsample the set of non-blonde women so that most non-blonde (blonde) examples are men (women).
Waterbirds. In this dataset for analyzing spurious correlations (Sagawa et al., 2020), birds Y ∈ {landbird, waterbird} are placed against image backgrounds Z ∈ {land, water}, with waterbirds (landbirds) more commonly appearing against water (land).
ISIC. In this skin cancer dataset (Codella et al., 2018), we classify Y ∈ {benign, malignant} cancers, with bandages Z appearing on ∼50% of benign images only.

Methods. CAMEL instantiates model patching as described in Section 2. We use the original CycleGAN model with default hyperparameters (Appendix D.2).
We compare against ERM and GDRO (Table 1), which respectively minimize the standard risk and the robust risk (over all subgroups) on the training set. On MNIST-Correlation, we additionally compare against the IRM (Arjovsky et al., 2019) and CDAT (Li et al., 2018) baselines, which target invariance assumptions (details in Appendix D.6). All classifiers are fine-tuned using a ResNet-50 architecture with pretrained ImageNet weights. Detailed information about experimental setups is provided in Appendix D. Metrics. We evaluate on aggregate accuracy, robust accuracy, and the subgroup gap for each class y, which are the metrics of interest for ERM, GDRO, and our subgroup robustness setting respectively (Table 1).

4.1. SUBGROUP ROBUSTNESS AND INVARIANCE ON BENCHMARK DATASETS

We first compare all methods on the benchmark datasets, with results summarized in Table 2 . CAMEL increases aggregate and robust accuracy while closing the subgroup gap. On all datasets, CAMEL improves both aggregate and robust accuracy by up to 5.3%, mitigating the tradeoff that other methods experience. CAMEL also balances out the performance of subgroups within each class, e.g., on Waterbirds, reducing this subgroup gap by 10.22% on landbirds compared to GDRO. CAMEL learns subgroup-invariant representations. To measure the invariance of models, we report an estimate of the mutual information defined in Lemma 1, calculated using class-conditional domain prediction heads (Appendix D.5). Table 3 illustrates that CAMEL is the only method that successfully makes the model invariant to subgroups.

4.2. MODEL PATCHING ABLATIONS

We perform ablations on the major components of our framework: (1) substituting learned augmentations with alternatives like heuristic augmentations in Stage 1, and (2) substituting prior consistency losses for our subgroup consistency regularizer in Stage 2.

4.2.1. EFFECT OF LEARNED AUGMENTATIONS

We investigate the interaction between the type of augmentation used and the strength of consistency regularization by varying the consistency loss coefficient λ on Waterbirds (Table 4). We compare to: (i) subgroup pairing, where consistency is directly enforced on subgroup examples from a class without augmentation, and (ii) heuristic augmentations, where the CycleGAN is substituted with state-of-the-art heuristic augmentations.

Strong consistency regularization enables CAMEL's success. As λ increases from 20 to 200, CAMEL's robust accuracy rises by 7% while the subgroup gap falls by 9.37%. For both ablations, performance deteriorates when λ is large. Subgroup pairing is substantially worse (14.84% lower) since it does not use any augmentation and, as expected, does not benefit from increasing λ. Heuristic augmentations (e.g. rotations, flips) are not targeted at subgroups and can distort class information (e.g. color shifts in AugMix), and we observe that strongly enforcing consistency (λ = 200) makes these models much worse. Overall, these results agree with our theoretical analysis.

CAMEL combines flexibly with other augmentations. Empirically, we find that using heuristic augmentations in addition to the CycleGAN (method CAMEL + Heuristic) can actually be beneficial, with a robust accuracy of 90.62% and a subgroup gap that is 1.83% lower than using CAMEL alone.

4.2.2. ANALYZING THE SUBGROUP CONSISTENCY REGULARIZER

Next, we investigate our choice of consistency regularizer by replacing it with (i) a triplet Jensen-Shannon loss (Hendrycks et al., 2019) and (ii) a KL-divergence loss (Xie et al., 2019) in CAMEL (Figure 4). Our goal is to show that our theoretically justified regularizer reduces overfitting and better enforces subgroup invariance.

Consistency regularization reduces overfitting. Figure 4 illustrates the train and validation cross-entropy loss curves for CAMEL and GDRO on the small (landbird, water) Waterbirds subgroup (184 examples). Consistency regularization shrinks the gap between train and validation losses, strongly reducing overfitting compared to GDRO.

4.2.3. PREDICTING SUBGROUP LABELS DIRECTLY

A natural alternative to erasing subgroup information is to train models to directly predict the finer-grained Z label using ERM or GDRO. Evaluation is still performed by comparing against the coarser-grained Y label corresponding to the predicted Z value. We use MNIST-Correlation in the setting of Table 2. In Table 5, we find that while both ERM and GDRO perform well on the larger subgroups, the performance on minority subgroups drops compared to the simpler setting in Table 2 where they are trained only with Y labels. These results are consistent with the intuition that being asked to predict Z labels forces models to learn information about the spurious attributes instead of the real class, and they highlight the importance of erasing subgroup information in order to learn invariant representations.

4.2.4. ADDITIONAL GAN ABLATIONS

As the quality of images generated by GANs continues to improve, previous works highlighted in Appendix B have considered them as data augmentation methods, where training an ERM classifier on GAN-augmented images can improve model accuracy. However, in Appendix D.8, we show that this does not improve robust accuracy. In contrast, we show that CAMEL can flexibly incorporate three other GAN baselines as alternatives to CycleGAN in Stage 1, whereby combining them with Stage 2 in the model patching pipeline improves robust accuracy by over 20 points.

4.3. REAL-WORLD APPLICATION IN SKIN CANCER CLASSIFICATION

We conclude by demonstrating that CAMEL can improve performance substantially on the real-world ISIC (Codella et al., 2018) skin cancer dataset (Table 6). We augment only the benign class, which is split into subgroups due to the presence of a colored bandage (Figure 1), while the malignant class contains no subgroups. We also report AUROC, as is conventional in medical applications. CAMEL substantially improves robust accuracy by 11.7% and, importantly, increases accuracy on the critical malignant cancer class from 65.59% (ERM) and 64.97% (GDRO) to 78.86%. While standard ERM models spuriously correlate the presence of the colored bandage with the benign class, CAMEL reduces the model's dependence on spurious features. We verify this by constructing a modified ISIC subgroup (Appendix D.7) for the malignant class that also contains bandages. Figure 1 illustrates using GradCAM (Selvaraju et al., 2017) that CAMEL removes the model's reliance on the spurious bandage feature, shifting attention to the skin lesion instead.

5. CONCLUSION

Domain experts face a common problem: how can classifiers that exhibit unequal performance on different subgroups of data be fixed? To address this, we introduced model patching, a new framework that improves a classifier's subgroup robustness by encouraging subgroup-feature invariance. Theoretical analysis and empirical validation suggest that model patching can be a useful tool for domain experts in the future.

A GLOSSARY OF NOTATION

We provide a glossary of notation used throughout the paper.

L_c(θ): Total consistency loss (Eq. 3)
L : X² → R: A distance function, used for CycleGAN consistency losses
λ: Hyperparameter controlling the strength of the consistency loss
KL(·): The KL divergence
JS(·): The Jensen-Shannon divergence (Definition 2)
I(·): The mutual information

B EXTENDED RELATED WORK

We provide a comprehensive overview of related work and highlight connections to our work below.

B.1 OVERVIEW OF DATA AUGMENTATION

Data augmentation is widely used for improving the aggregate performance of machine learning models in computer vision (Krizhevsky et al., 2012; Szegedy et al., 2014), natural language processing (Kolomiyets et al., 2011; Sennrich et al., 2015; Zhang et al., 2015) and audio (Cui et al., 2015; Ko et al., 2015). The theoretical motivation for data augmentation is largely based on the tangent propagation formalism (Dao et al., 2018; Simard et al., 1991; 1992; 1998), which expresses the desired invariances induced by a data augmentation as tangent constraints on the directional derivatives of the learned model. Early work considered augmentations as image defects (Baird, 1992) or stroke warping (Yaeger et al., 1996) for character recognition. Since then, augmentation has been considered an essential ingredient in computer vision (LeCun et al., 1998; Simard et al., 2003), with commonly used augmentations including random flips, rotations and crops (He et al., 2016; Krizhevsky et al., 2012; Szegedy et al., 2014). Applications of augmentation in computer vision include object detection (Dwibedi et al., 2017; Zoph et al., 2019) and scene understanding (Dvornik et al., 2018). In natural language processing, common data augmentation techniques include back-translation (Sennrich et al., 2015; Yu et al., 2018), synonym or word substitution (Fadaee et al., 2017; Kobayashi, 2018; Kolomiyets et al., 2011; Wang & Yang, 2015; Zhang et al., 2015), noising (Xie et al., 2017), grammar induction (Jia & Liang, 2016), text editing (Wei & Zou, 2019) and other heuristics (Deschacht & Moens, 2009; Silfverberg et al., 2017). In speech and audio applications, augmentation is also commonly used, through techniques such as vocal tract length warping (Jaitly & Hinton, 2013; Ko et al., 2015) and stochastic feature mapping (Cui et al., 2015; Stylianou et al., 1998).
In this work, we perform an empirical evaluation on image classification tasks although our ideas can be extended to classification of other modalities such as speech and text.

B.2 AUGMENTATION PRIMITIVES AND PIPELINES

Next, we highlight the particular augmentation primitives that have been used in prior work. Our work is differentiated by the use of learned augmentation primitives using CycleGANs (Zhu et al., 2017) , as well as a theoretical justification for this choice. Hand-Crafted Augmentation Primitives. Commonly used primitives are typically heuristic transformations, such as rotations, flips or crops (Krizhevsky et al., 2012; Szegedy et al., 2014) . Recent work has hand-crafted more sophisticated primitives, such as Cutout (DeVries & Taylor, 2017), Mixup (Zhang et al., 2017) , CutMix (Yun et al., 2019) and MixMatch (Berthelot et al., 2019) . While these primitives have culminated in compelling performance gains (Cubuk et al., 2019a; b) , they produce unnatural images and distort image semantics. Assembling Augmentation Pipelines. Recent work has explored learning augmentation policies -the right subset of augmentation primitives, and the order in which they should be applied. The learning algorithm used can be reinforcement learning (Cubuk et al., 2019a; Ratner et al., 2017a) or random sampling (Cubuk et al., 2019b) . More computationally efficient algorithms for learning augmentation policies have also been proposed (Ho et al., 2019; Lim et al., 2019) . These pipelines are primarily derived from the fixed set of generic image transformations we discussed earlier, and do not directly target specific attributes. By contrast, we consider learning augmentation primitives that target subgroup robustness, and additionally demonstrate in Section 4.2.2 that heuristic augmentations can complement CAMEL to yield additional performance gains. Learned Augmentation Primitives. There is substantial prior work in learning image transformations that produce semantic, rather than superficial changes to an image. A common paradigm is to learn a semantically meaningful data representation, and manipulate embeddings in this representation to produce a desired transformation. 
Transformations can then be expressed as vector operations over embeddings (Reed et al., 2015; Upchurch et al., 2017) or manifold traversals (Gardner et al., 2015; Reed et al., 2014). Alternative approaches rely on training conditional generative models (Almahairi et al., 2018; Brock et al., 2016; Choi et al., 2018; Isola et al., 2017; Zhu et al., 2017) that learn a mapping between two or more image distributions. Much of this prior work is motivated by the need for sophisticated image editing tools (Karras et al., 2018; Upchurch et al., 2017), e.g. for creative applications of machine learning (Mazzone & Elgammal, 2019).

Closer to our setting is work that explores the use of these transformations for data augmentation. A prominent use case focuses on imbalanced datasets, where learned augmentations are used to generate examples for underrepresented classes or domains. Examples include BaGAN (Mariani et al., 2018), DAGAN (Antoniou et al., 2017), TransferringGAN (Wang et al., 2018) and others (Beery et al., 2019; Hu et al., 2019; Molano et al., 2018; Mounsaveng et al., 2019; Tran et al., 2017; Zhang et al., 2018). Applications to medical data (Pesteie et al., 2019; Sandfort et al., 2019) and person re-identification (chen Sun et al., 2019) have also been explored. Our model patching framework differs substantially from these papers, since we focus on robustness. We discuss this intersection next.

B.3 DATA AUGMENTATION AND MODEL ROBUSTNESS

Prior work on model robustness has mostly focused on learning models that are robust to bounded $\ell_p$-norm perturbations (Goodfellow et al., 2014b; Moosavi-Dezfooli et al., 2018; Papernot et al., 2015; Szegedy et al., 2013) using ideas such as adversarial training (Madry et al., 2017). A separate line of work considers consistency training (Hendrycks et al., 2019; Kannan et al., 2018; Zheng et al., 2016), where predictions are made invariant to input perturbations, often by minimizing a divergence between the predictions for the original and perturbed examples. Consistency regularization has also been shown to be effective for semi-supervised learning (Xie et al., 2019).

Consistency training. We contrast equation (3) with consistency losses from prior work. Unsupervised Data Augmentation (UDA) (Xie et al., 2019) simply controls an asymmetric divergence between the original example and each augmented example individually,
$$\sum_{z} \mathrm{KL}\left(f(x) \,\|\, f(x_z)\right).$$
AugMix (Hendrycks et al., 2019) uses a Jensen-Shannon divergence
$$\frac{1}{k+1}\left[\mathrm{KL}\left(f(x) \,\|\, m\right) + \sum_{z \in Z_y} \mathrm{KL}\left(f(x_z) \,\|\, m\right)\right], \quad \text{where } m = \frac{1}{k+1}\Big(f(x) + \sum_{i} f(x_i)\Big).$$
This can be seen as a version of our consistency loss, but with different weights and a different mean distribution against which the KL terms are computed. Our loss (3) has an important asymmetry between the original example $x$ and the augmentations $x_i$. One reason to prefer it is that as the number $k$ of subgroups grows, the AugMix loss tends to its second term, and does not control the discrepancy between predictions on the original domain $f(x)$ and the augmented ones $f(x_i)$. Our consistency regularization instead allows us to bound a mutual information objective between variables in the joint subgroup distribution, yielding a tractable and interpretable objective (Section 3). In addition, we compare with these consistency losses and provide empirical results in Section 4.2.2.
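To make the comparison concrete, the sketch below implements the UDA and AugMix losses as written above, plus a plain Jensen-Shannon divergence over augmented predictions in the spirit of our self-consistency term; the helper names and toy distributions are illustrative, not taken from our codebase.

```python
import numpy as np

def kl(p, q):
    # KL(p || q) for categorical distributions (natural log, q > 0 assumed)
    p, q = np.asarray(p, float), np.asarray(q, float)
    return float(np.sum(np.where(p > 0, p * np.log(p / q), 0.0)))

def js(dists):
    # Jensen-Shannon divergence: mean KL of each distribution to the mixture
    dists = [np.asarray(d, float) for d in dists]
    m = np.mean(dists, axis=0)
    return float(np.mean([kl(p, m) for p in dists]))

def uda_loss(f_x, f_aug):
    # UDA: asymmetric KL between the original prediction and each augmentation
    return sum(kl(f_x, f_z) for f_z in f_aug)

def augmix_loss(f_x, f_aug):
    # AugMix: JSD of the original prediction together with all augmentations
    return js([f_x] + list(f_aug))

def self_consistency_loss(f_aug):
    # Symmetric JSD over augmented predictions only (the flavor of our L_s term)
    return js(list(f_aug))
```

All three losses vanish exactly when the original and augmented predictions agree, but only the UDA and self-consistency terms retain the asymmetry/symmetry trade-off discussed above.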
Robustness to more general augmentations has also been explored (Baluja & Fischer, 2017; Engstrom et al., 2017; Kanbak et al., 2017; Odena et al., 2016; Qiu et al., 2019; Song et al., 2018; Xiao et al., 2018) , but there is limited work on making models more robust to semantic data augmentations. The only work we are aware of is AdvMix (Gowal et al., 2019) , which combines a disentangled generative model with adversarial training to improve robustness. Our work contributes to this area by introducing the model patching framework to improve robustness in a targeted fashion. Specifically, under the data-generating model that we introduce, augmentation with a CycleGAN (Zhu et al., 2017) model allows us to learn predictors that are invariant to subgroup identity.

B.4 LEARNING ROBUST PREDICTORS

Recent work (Sagawa et al., 2020) introduced GDRO, a distributionally robust optimization method that improves worst-case accuracy over a set of pre-defined subgroups. However, optimizing the GDRO objective does not necessarily prevent a model from learning subgroup-specific features. Preventing this may instead require strong modeling assumptions on the learned features; e.g., Invariant Risk Minimization (Arjovsky et al., 2019) attempts to learn an invariant predictor through a different regularization term. However, these assumptions are only appropriate for specialized setups where extreme out-of-domain generalization is desired. Unfortunately, these approaches still suffer from standard learning and generalization issues stemming from the small number of examples in the underperforming subgroup(s), even with perfect subgroup information. Additionally, they necessarily trade off average (aggregate) accuracy against a different robust metric.
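For reference, the GDRO objective reduces to a worst-group average loss. The sketch below computes that idealized objective on a batch; Sagawa et al. (2020) optimize it with an online stochastic algorithm and group adjustment terms, which we omit here.

```python
import numpy as np

def worst_group_loss(losses, groups, n_groups):
    # GDRO's idealized objective: the maximum of the per-subgroup
    # average losses, along with the index of the worst subgroup.
    losses = np.asarray(losses, float)
    groups = np.asarray(groups)
    group_means = [losses[groups == g].mean() for g in range(n_groups)]
    return max(group_means), int(np.argmax(group_means))
```

Note that minimizing this quantity equalizes losses across subgroups without constraining which features the model uses to do so, which is the gap model patching targets.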

B.5 RELATIONSHIP TO SUBGROUP FAIRNESS

A common goal in subgroup fairness is to ensure statistical parity of predictions across groups, and a variety of fairness criteria have been proposed (Mehrabi et al., 2019) . Others have considered a theoretical setting where the number of subgroups can be large (possibly infinite) (Kearns et al., 2018) . Typically, this line of work assumes that groups are common across classes, while in our setting, we consider the possibility of different subgroups in each class. The bulk of this work focuses on prediction problems on small datasets, rather than with high-dimensional image data.

C DETAILED ANALYSIS

We begin with background material on the CycleGAN (Appendix C.1) and the Jensen-Shannon divergence (Appendix C.2). Appendix C.3 contains a longer discussion of the modeling assumptions in Section 3, fleshing out the distributional assumptions and the definition of coupled sets. Appendix C.4 and Appendix C.5 complete the proofs of the results in Section 3.

CycleGAN uses a cycle consistency loss to ensure that the mappings F and G are nearly inverses of each other, which biases the model toward learning meaningful cross-domain mappings. An additional identity loss is sometimes used, which encourages the maps F, G to preserve their target domains, i.e. $F(a) \approx a$ for $a \sim P_A$. These cycle consistency and identity losses can be modeled by minimizing $L_{CG}(a, F(G(a)))$ and $L_{CG}(a, F(a))$ respectively, for some function $L_{CG}$ that measures a notion of distance on $A$ (with analogous losses for $B$). The original CycleGAN uses the $\ell_1$ distance $L(a, \tilde{a}) = \|a - \tilde{a}\|_1$. However, many other functions can be used to enforce similarity. In particular, a pair-conditioned discriminator $D : \{a, \tilde{a}\} \to [0, 1]^2$ can also be used, which accepts a coupled pair of original and translated examples and assigns each a probability of being the original example. If the guesses for the true and translated examples are $D_a$ and $D_{\tilde{a}}$ respectively, then the distance is $L(a, \tilde{a}) = \max_D \log D_a + \log(1 - D_{\tilde{a}}) + \log 2$. To sanity check that this has properties of a distance, note that $L$ decreases as $a, \tilde{a}$ become more similar, since the discriminator has trouble telling them apart.
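The cycle-consistency and identity losses described above can be sketched directly; F and G below are stand-in callables, and the $\ell_1$ distance follows the original CycleGAN.

```python
import numpy as np

def l1(a, b):
    # The original CycleGAN's elementwise distance L(a, a~) = ||a - a~||_1
    return float(np.abs(np.asarray(a) - np.asarray(b)).sum())

def cyclegan_aux_losses(F, G, a_batch, b_batch):
    # Cycle consistency: F(G(a)) should reconstruct a (and G(F(b)) -> b);
    # identity: F (resp. G) should leave members of its target domain unchanged.
    cycle = np.mean([l1(a, F(G(a))) for a in a_batch] +
                    [l1(b, G(F(b))) for b in b_batch])
    ident = np.mean([l1(a, F(a)) for a in a_batch] +
                    [l1(b, G(b)) for b in b_batch])
    return cycle, ident
```

With F and G exact inverses the cycle term vanishes even when the identity term does not, which is why both are needed.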

C.1 BACKGROUND: CYCLEGAN

Intuitively, the discriminator loss is a measure of how similar the original and generated distributions are, which will be used in Appendix C.5 to prove our main result.

C.2 BACKGROUND: PROPERTIES OF THE JENSEN-SHANNON DIVERGENCE

We define the Jensen-Shannon divergence (JSD) and the properties that will be used in our method and analysis.

Definition 2. The Jensen-Shannon divergence (JSD) of distributions $P_1, \ldots, P_k$ is
$$JS(P_1, \ldots, P_k) = \frac{1}{k} \sum_{i=1}^{k} \mathrm{KL}(P_i \,\|\, M), \quad \text{where } M = \frac{1}{k} \sum_{i=1}^{k} P_i.$$

We overload the $JS(\cdot)$ function in the following ways. The JSD of random variables $X_1, \ldots, X_k$ is the JSD of their laws (distributions). Additionally, we define the JSD of vector-valued inputs when they represent distributions from context. For example, for a model $f$ that outputs a vector representing a categorical distribution, $JS(f_\theta(x_1), \ldots, f_\theta(x_k))$ is the JSD of those distributions.

We briefly review important properties of the JSD. Unlike the KL divergence and other notions of distributional distance, the JSD can be related to a metric.

Proposition 1. The JSD is the square of a metric. In particular, any three distributions $p, q, r$ satisfy $JS(p, q)^{1/2} + JS(q, r)^{1/2} \geq JS(p, r)^{1/2}$.

The following fact, relating the JSD to the mutual information between a mixture distribution and its indicator variable, will also be useful in our analysis.

Proposition 2. Let $Z$ be a uniform categorical indicator variable with support $[k]$, and let $P_i, i \in [k]$ be distributions. Let $X \sim P_z, z \sim Z$ be the random variable associated with the mixture distribution of the $P_i$ controlled by the indicator $Z$. Then $I(X; Z) = JS(P_1, \ldots, P_k)$.

Finally, we review a standard result from the GAN literature relating the loss of an optimal discriminator to the JSD of two distributions. We include a proof for completeness.

Proposition 3. Consider two domains $A$ and $\tilde{A}$ (i.e., distributions on a common support $\mathcal{A}$), with densities $p(a)$ and $\tilde{p}(a)$ respectively. Consider a discriminator $D : \mathcal{A} \to \mathbb{R}$ optimized to maximize the loss
$$L(D) = \frac{1}{2} \mathbb{E}_{a \sim p(a)} \log D(a) + \frac{1}{2} \mathbb{E}_{a \sim \tilde{p}(a)} \log(1 - D(a)).$$
Then the value of this loss for the optimal discriminator $D^*$ is $JS(A, \tilde{A}) - \log 2$.

Proof. Differentiating the loss with respect to the discriminator's output $D(a)$ for any example $a \in \mathcal{A}$ yields
$$\frac{1}{2} p(a) \frac{1}{D(a)} - \frac{1}{2} \tilde{p}(a) \frac{1}{1 - D(a)},$$
so the loss is maximized at $D^*(a) = \frac{p(a)}{p(a) + \tilde{p}(a)}$. The result follows from plugging this discriminator into the loss and using Definition 2:
$$L(D^*) = \frac{1}{2} \mathbb{E}_{a \sim p(a)} \log \frac{p(a)}{p(a) + \tilde{p}(a)} + \frac{1}{2} \mathbb{E}_{a \sim \tilde{p}(a)} \log \frac{\tilde{p}(a)}{p(a) + \tilde{p}(a)} = \frac{1}{2} \mathrm{KL}\left(A \,\Big\|\, \frac{A + \tilde{A}}{2}\right) + \frac{1}{2} \mathrm{KL}\left(\tilde{A} \,\Big\|\, \frac{A + \tilde{A}}{2}\right) - \log 2 = JS(A, \tilde{A}) - \log 2.$$
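Propositions 1-3 can be sanity-checked numerically on random categorical distributions (natural logarithms throughout; this check is ours and is not part of the proofs):

```python
import numpy as np

def kl(p, q):
    # KL divergence for strictly positive categorical distributions
    return float(np.sum(p * np.log(p / q)))

def js(*dists):
    # Definition 2: mean KL of each P_i to the mixture M
    m = np.mean(dists, axis=0)
    return float(np.mean([kl(p, m) for p in dists]))

def entropy(p):
    return -float(np.sum(p * np.log(p)))

rng = np.random.default_rng(0)
p, q, r = (rng.dirichlet(np.ones(4)) for _ in range(3))

# Proposition 1: the square root of the JSD obeys the triangle inequality.
tri_ok = js(p, q) ** 0.5 + js(q, r) ** 0.5 >= js(p, r) ** 0.5

# Proposition 2 (k = 3): I(X; Z) = H(mixture) - mean entropy = JS(P_1, ..., P_k).
mi = entropy((p + q + r) / 3) - (entropy(p) + entropy(q) + entropy(r)) / 3

# Proposition 3: the optimal discriminator D*(a) = p / (p + q) attains
# loss JS(p, q) - log 2.
d_star = p / (p + q)
disc_loss = 0.5 * np.sum(p * np.log(d_star)) + 0.5 * np.sum(q * np.log(1 - d_star))
```

The second identity uses the standard decomposition $I(X; Z) = H(X) - H(X \mid Z)$ for a uniform mixture indicator.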

C.3 SUBGROUP INVARIANCE USING COUPLED DISTRIBUTIONS

A common framework for treating robustness over discrete groups aims to create invariances, or independencies between the learned model's features and these groups. We review this approach, before defining a new model for the distributional assumptions used in this work. The notion of coupled sets we introduce underlies both stages of the framework and allows for stronger invariance guarantees than previous approaches, which will be analyzed in Appendix C.5.

Class-conditioned Subgroup Invariance.

In order for a model to have the same performance over all values of Z, intuitively it should learn "Z-invariant features", which can be accomplished in a few ways. Invariant Risk Minimization (IRM) (Arjovsky et al., 2019) calls the Z labels environments and aims to induce $(Y \mid \phi(X)) \perp Z$, where $\phi(X)$ are the model's features, so that the classifier does not depend on the environment. Another line of work treats Z as domains and uses adversarial training to induce invariances of the form $(\phi(X) \perp Z) \mid Y$ (Ganin et al., 2016; Li et al., 2018; Long et al., 2018), so that within each class, the model's features look the same across domains. We call this general approach class-conditional domain adversarial training (CDAT), which attaches a domain (Z) prediction head per class Y, and adopts an adversarial min-max objective so that the featurizer $\phi(X)$ erases Z-related information and reduces the model's dependence on Z.

Coupling-conditioned Subgroup Invariance. Although previous works generally make no assumptions on how the data X among the groups Z relate to each other, we note that a common implicit requirement is that there is a "correspondence" between examples among different groups. We codify this distributional assumption explicitly with a notion of coupling, which allows us to define and analyze stronger invariances. In particular, we assume that the underlying subgroups are paired or coupled, so that every example can be translated into the other subgroups. Definition 1 formalizes our distributional notion of coupled sets.

Definition 1. For a given distribution P, a coupled set within class y is a set $\{x_z\}_{z \in Z_y}$ consisting of one example from each subgroup of y, where each example has the same probability.[1] A coupling for a distribution P on (X, Y, Z) is a partition of all examples in $\mathcal{X}$ into coupled sets. For any example $x \in \mathcal{X}$, let $[x]$ denote its coupled set. Let $[x]_1, \ldots, [x]_k$ denote the elements of a coupled set $[x]$ in a class with k subgroups.
Let $[X]$ denote the random variable that samples a coupled set, i.e. taking $[x]$ for a random x sampled from any fixed subgroup z. Additionally, we say that a distribution is subgroup-coupled if it satisfies Definition 1, i.e. it has a coupling.

In the context of subgroups of a class y, this assumption entails that every example can be factored into its subgroup and its coupled set membership. All examples that are members of a particular coupled set can be thought of as sharing a set of common features that signal membership in the class. Separately, examples that are members of a particular subgroup can be thought of as sharing common features that signal subgroup membership. Together, these two pieces of information identify any example from class y. We represent this assumption by letting the (unobserved) random variable $[X]$ represent the "class identity" of an example X, which can be thought of as the class features that are not specific to any subgroup. Thus, the full generating process of the data distribution $(X, Y, Z, [X])$ consists of independently choosing a coupled set $[X]$ and subgroup Z within a class Y, which together control the actual example X. Note that $[X]$ and Z are both more fine-grained, and thus carry more information, than Y. This process is illustrated in Figure 6a. Figure 6b illustrates this concept for the MNIST-Corrupted dataset (Mu & Gilmer, 2019). Given a digit class such as Y = 3, subgroups correspond to corruptions such as zigzags and dotted lines applied to the digits. A coupled set consists of these corruptions applied to a clean digit.

Definition 1 allows us to reason about the following stronger invariances. Given class $y \in \mathcal{Y}$, every example in subgroup $z \in Z_y$ implicitly has corresponding examples in all subgroups $Z_y$ within its class, and the learned features for each of these coupled sets should be identical in order to equalize performance between subgroups.
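A toy sketch of this generating process: within a class, the coupled-set identity $[X]$ and the subgroup Z are drawn independently and together determine the example. The subgroup names follow the MNIST-Corrupted illustration; everything else is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_example(n_coupled_sets=1000, subgroups=("clean", "zigzag")):
    # [X]: shared "class identity" features, drawn independently of the
    # subgroup Z; the pair (coupled_id, z) determines the example X.
    coupled_id = int(rng.integers(n_coupled_sets))
    z = subgroups[int(rng.integers(len(subgroups)))]
    return coupled_id, z
```

Because the two draws are independent, the empirical subgroup frequencies do not depend on which coupled set was chosen, matching the factorization in Figure 6a.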
Thus, instead of the weaker goal $(\phi(X) \perp Z) \mid Y$, we use the stronger coupling-conditioned invariance $(\phi(X) \perp Z) \mid Y, [X] = (\phi(X) \perp Z) \mid [X]$.

C.4 MI BOUNDS FOR CLASS-CONDITIONED INVARIANCE

Recall that the high-level goal of CDAT is to induce independencies between subgroup information and the model's feature representation. In order to induce the desired invariance $(\phi(X) \perp Z) \mid Y$ of class features from subgroup identities, a natural approach is to minimize the conditional mutual information $I(\phi(X); Z \mid Y)$, which is minimized at 0 when the invariance is satisfied and grows when $\phi(X)$ and $Z$ are predictive of each other. This mutual information can be estimated using standard techniques.

Lemma 1. CDAT minimizes a lower bound on the mutual information $I(\phi(X); Z \mid Y)$, where $\phi(X)$ is the feature layer where the domain prediction head is attached.

Proof. We have
$$\begin{aligned}
I(\phi(X); Z \mid Y) &= H(Z \mid Y) - H(Z \mid \phi(X), Y) \\
&= H(Z \mid Y) + \mathbb{E}_{x,y \sim p(x,y)} \, \mathbb{E}_{z \sim p(z \mid \phi(x), y)} \left[\log p(z \mid \phi(x), y)\right] \\
&\geq H(Z \mid Y) + \mathbb{E}_{x,y \sim p(x,y)} \, \mathbb{E}_{z \sim p(z \mid \phi(x), y)} \left[\log p_\psi(z \mid \phi(x), y)\right] \\
&= H(Z \mid Y) + \mathbb{E}_{y,z,\phi(x)} \left[\log p_\psi(z \mid \phi(x), y)\right],
\end{aligned}$$
which bounds the MI variationally through a parametrized conditional model $p_\psi$. Up to the additive term $H(Z \mid Y)$, which is a constant of the data distribution, this is simply the (negative) cross-entropy loss of a model trained on top of the featurizer $\phi$ to predict $Z$ from $\phi(X)$ and $Y$, which coincides with the domain adversarial training approach.

By specializing $\phi(X)$ to $\hat{Y}$, we obtain

Corollary 1. If CDAT attaches a domain prediction head to the prediction layer $\hat{Y}$, it optimizes a lower bound on $I(\hat{Y}; Z \mid Y)$.

Thus, although approaches involving domain adversarial training (Ganin et al., 2016; Li et al., 2018) motivate their approach through alternate concepts such as H-divergences and GAN-based adversarial games, we see that they are implicitly minimizing a simple variational estimate of mutual information. In Section 4, Table 3's reported estimate of the mutual information uses Corollary 1.

We also use the notation $\widetilde{[x]}$ for a generated coupled set and $\widetilde{[x]}_z$ for its realization in subgroup z (a specific augmented example).
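Lemma 1's variational bound can be illustrated on a discrete toy problem; we drop the conditioning on Y by fixing a single class, and the joint table below is purely illustrative.

```python
import numpy as np

def entropy(p):
    p = np.asarray(p, float)
    return -float(np.sum(p * np.log(p)))

# Illustrative joint over (phi(x), z) within one class: rows phi, columns z.
joint = np.array([[0.30, 0.10],
                  [0.10, 0.50]])
p_phi, p_z = joint.sum(axis=1), joint.sum(axis=0)
true_mi = entropy(p_phi) + entropy(p_z) - entropy(joint.ravel())

def variational_bound(p_model):
    # H(Z) + E[log p_model(z | phi)]: a lower bound on I(phi(X); Z),
    # tight exactly when p_model is the true conditional p(z | phi).
    ce = float(np.sum(joint * np.log(p_model)))
    return entropy(p_z) + ce

p_true = joint / joint.sum(axis=1, keepdims=True)   # exact p(z | phi)
p_bad = np.full((2, 2), 0.5)                        # uninformative domain head
```

Any domain head attains at most the true mutual information, with equality for the Bayes-optimal head, which is why training the head to convergence yields the MI estimate used for Table 3.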
Note that $\widetilde{[x]}$ and the notation $\tilde{x}_{Z_y}$ from Section 2.2 refer to the same thing, the set of augmented examples. We can control the difference between the augmented and true subgroup distributions in two ways. First, the translation loss $L_t$ (2) regularizes the average predictions from the augmentations to match those of the original example, constraining the prediction model to ignore general distribution shifts introduced by the generative models. Moreover, the discrepancy between the loss we are minimizing over CycleGAN-augmented examples,
$$L_s = \mathbb{E}_x \, JS\left(f_\theta(\widetilde{[x]}_1), \ldots, f_\theta(\widetilde{[x]}_k)\right), \tag{1}$$
and the true objective $JS(f_\theta([x]_1), \ldots, f_\theta([x]_k))$ can be bounded by the loss of the pair-conditioned CycleGAN discriminators (Section 2.1), via metric properties of the JSD. Models such as CycleGAN directly control the deviation of augmentations from the original examples, via the GAN discriminators and consistency losses. The following lemma says that the CycleGAN discriminator loss is the divergence between the original distribution in subgroup z and the generated distribution of subgroup z, paralleling standard GAN results (Goodfellow et al., 2014a).

Lemma 3. The optimal discriminator between the original subgroup distribution $P_z$ and the augmented subgroup distribution $\tilde{P}_z$ has loss
$$L^{*}_{CG} = \mathbb{E}_{[x] \sim [X]} \, JS\left([x]_z, \widetilde{[x]}_z\right) - \log 2.$$

Proof of Lemma 3. By Proposition 3,
$$\mathbb{E}_{[x] \sim [X]} \, JS\left([x]_z, \widetilde{[x]}_z\right) = \log 2 + \frac{1}{2} \mathbb{E}_{[x] \sim [X]} \log D^z_{[x]}([x]_z) + \frac{1}{2} \mathbb{E}_{[x] \sim [X]} \log\left(1 - D^z_{[x]}(\widetilde{[x]}_z)\right),$$
where $D^z_{[x]}$ is a discriminator for this coupled set (within subgroup z). Instead of training a separate discriminator per example or coupled set, it is enough to train a single discriminator $D$ conditioned on the specific coupled pair $([x]_z, \widetilde{[x]}_z)$. In other words, this is a discriminator whose input is both the original example $[x]_z$ and a generated version $\widetilde{[x]}_z$, and which for each input guesses its chance of being a real example. This is exactly the pair-conditioned discriminator described in Appendix C.1.
We finally put the pieces together to prove the main result, restated here for convenience.

Theorem 1. For a model $f_\theta$ with outputs $\hat{Y}$, the MI $I(\hat{Y}; Z \mid [X])$ is the Jensen-Shannon divergence of predictions on coupled sets, $\mathbb{E}_{[x] \sim [X]} \, JS\left(\{f_\theta(x)\}_{x \in [x]}\right)$. In the case of $k = 2$ subgroups per class, this can be upper bounded by the CycleGAN and consistency losses:
$$\mathbb{E}_{(x,y) \sim (X,Y)} \left( L_s(x; \tilde{x}_{Z_y}; \theta)^{\frac{1}{2}} + \sum_{z \in Z_y} L^z_{CG}(x; \theta)^{\frac{1}{2}} \right)^2.$$
In particular, the global optimum of the trained CAMEL model induces $\hat{Y} \perp Z \mid [X]$.

First, the equivalence of the quantity we care about, $I(\hat{Y}; Z \mid [X])$, and the consistency loss on true coupled sets is given by Lemma 2. It remains to bound $\mathbb{E} \, JS(f_\theta([x]_1), f_\theta([x]_2))$, which can be bounded by the consistency loss on augmented examples $\mathbb{E} \, JS(f_\theta(\widetilde{[x]}_1), f_\theta(\widetilde{[x]}_2))$ and the optimal CycleGAN losses $\mathbb{E} \, JS(f_\theta([x]_i), f_\theta(\widetilde{[x]}_i))$, by metric properties of the JSD.

Proof of Theorem 1. Consider any fixed subgroup $z$ and let $\bar{X}_z$ denote the random variable drawn from the mixture of $P_z$ and $\tilde{P}_z$, i.e. either a true example or an augmented example from subgroup $z$. Let $W$ denote the (binary) indicator of this mixture. Then
$$JS\left(f_\theta([x]_z), f_\theta(\widetilde{[x]}_z)\right) = I\left(W; f_\theta(\bar{X}_z)\right) \leq I(W; \bar{X}_z) = JS\left([x]_z, \widetilde{[x]}_z\right), \tag{9}$$
where the equalities are Proposition 2 and the inequality is an application of the data processing inequality on the Markov chain $W \to \bar{X}_z \to f_\theta(\bar{X}_z)$. Combining equation (9) with Lemma 3, applying the definition of $L^z_{CG}$, and summing over the two subgroups $z_1, z_2$ yields
$$JS\left(f_\theta([x]_1), f_\theta(\widetilde{[x]}_1)\right)^{\frac{1}{2}} + JS\left(f_\theta([x]_2), f_\theta(\widetilde{[x]}_2)\right)^{\frac{1}{2}} \leq L^{z_1}_{CG}(x; \theta)^{\frac{1}{2}} + L^{z_2}_{CG}(x; \theta)^{\frac{1}{2}}, \tag{10}$$
by the definition of the self-consistency loss (1) and Definition 2, for any sample $x$, where $\widetilde{[x]}$ denotes the generated coupled set $\{F_1(x), F_2(x)\}$ as usual.
The self-consistency loss on the generated coupled set satisfies
$$JS\left(f_\theta(\widetilde{[x]}_1), f_\theta(\widetilde{[x]}_2)\right) = L_s(x, \widetilde{[x]}; \theta). \tag{11}$$
Denoting the right-hand side $L_s(x; \theta)$ for shorthand, summing equations (10) and (11), and using the metric property of the JSD (Proposition 1) gives
$$JS\left(f_\theta([x]_1), f_\theta([x]_2)\right)^{\frac{1}{2}} \leq L_s(x; \theta)^{\frac{1}{2}} + L^{z_1}_{CG}(x; \theta)^{\frac{1}{2}} + L^{z_2}_{CG}(x; \theta)^{\frac{1}{2}}.$$
Finally, squaring, averaging over the dataset, and applying Lemma 2 gives the result of Theorem 1:
$$I(\hat{Y}; Z \mid [X]) \leq \mathbb{E}_{x \sim X} \left( L_s(x; \theta)^{\frac{1}{2}} + L^{z_1}_{CG}(x; \theta)^{\frac{1}{2}} + L^{z_2}_{CG}(x; \theta)^{\frac{1}{2}} \right)^2.$$
Together, the GAN-based modeling of subgroups (Stage 1) and the consistency regularizer (Stage 2) thus minimize the desired identity-conditioned mutual information, which completes the proof of Theorem 1.

D EXPERIMENTAL DETAILS

We provide detailed information about our experimental protocol and setup for reproducibility, including dataset information in D.1, architectures and training information in D.3, and hyperparameters in D.4.

D.1 DATASET INFORMATION

We provide details for preprocessing and preparing all datasets in the paper. Table 8 summarizes the sizes of the subgroups present in each dataset. All datasets will be made available for download.

MNIST-Correlation. We mix data from MNIST (LeCun et al., 1998) and MNIST-Corrupted (Mu & Gilmer, 2019) to create a controlled setup. We classify digit parity Y ∈ {even, odd}, where each class is divided into subgroups Z ∈ {clean, zigzag}, drawing digits from MNIST and MNIST-Corrupted (with the zigzag corruption) respectively. To generate the dataset, we use the following procedure:
• Fix a total dataset size N and a desired correlation ρ.
• Sample (ρ+1)N/4 even digits from MNIST, N/2 − (ρ+1)N/4 even digits from MNIST-Corrupted, N/2 − (ρ+1)N/4 odd digits from MNIST, and (ρ+1)N/4 odd digits from MNIST-Corrupted.

ISIC. Train on 100 images each from both benign subgroups (with and without bandaids) for 4000 epochs, with a batch size of 4, a cycle loss coefficient of 10.0 and an identity loss coefficient of 10.0. We flip inputs randomly (with probability 0.5) and randomly crop up to 10% of every image.
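As a concrete sketch, the subgroup sizes implied by the sampling procedure for a given N and ρ (the exact rounding is our assumption):

```python
def mnist_correlation_counts(N, rho):
    # Subgroup sizes for MNIST-Correlation: at rho = 1 parity and corruption
    # are perfectly correlated; at rho = 0 they are independent.
    n = round((rho + 1) * N / 4)
    return {
        ("even", "clean"): n,
        ("even", "zigzag"): N // 2 - n,
        ("odd", "clean"): N // 2 - n,
        ("odd", "zigzag"): n,
    }
```

Both classes keep size N/2, while ρ controls how predictive the corruption subgroup is of the parity label.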

D.3 ARCHITECTURES AND TRAINING INFORMATION

All training code is written in Python with tensorflow-2.0. All models are trained with Stochastic Gradient Descent (SGD), with a momentum of 0.9. In order to isolate the effect of our method, we do not use any data augmentation (such as pad-and-crop operations or random flips) when training the classifier.

MNIST-Correlation. We train a convolutional neural network from scratch, initialized with random weights. The architecture is provided below.

Other datasets. All models are fine-tuned from a ResNet-50 architecture with pretrained ImageNet weights.

D.4 HYPERPARAMETERS

For model selection, we use robust accuracy on the validation set.[2] The selected model's hyperparameters are then run for 3 trials, and the results averaged over these trials are reported in Table 2. Below, we provide details of all hyperparameter sweeps, and in Table 12, we include the best hyperparameters found for each method and dataset.

D.4.1 CELEBA-UNDERSAMPLED

We run sweeps for all methods over 50 epochs.

ERM. Sweep over learning rates {0.0001, 0.00005, 0.00002, 0.00001} with weight decay fixed to 0.05.

GDRO.

Sweep over adjustment coefficients in {1.0, 3.0} and learning rates {0.0001, 0.00005} with weight decay fixed to 0.05.

GDRO. This is our main baseline as described in Section 2, and uses a stochastic optimization method (Sagawa et al., 2020). GDRO uses subgroup information to optimize the worst-case loss over all subgroups. We note that GDRO requires the specification of an adjustment coefficient, and we describe the best found coefficients in Table 12.

CDAT. We use a generic domain adversarial training approach, with a domain prediction head attached to the last feature layer of the model φ(X). The domain head predicts the subgroup identity of the given example, and we use gradient reversal in order to erase domain information from the representation φ(X). We vary the magnitude of the gradient reversal on the domain loss (which we call the domain loss coefficient in Table 12) in order to find the best-performing model.
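The gradient reversal used by CDAT is the identity on the forward pass, while negating (and scaling) the gradient that flows back into the featurizer. A minimal numpy illustration with a hypothetical squared-error domain head (our implementation uses TensorFlow's autodiff instead):

```python
import numpy as np

def domain_head_grad(phi, w, z):
    # Gradient of the domain head's squared-error loss (w . phi - z)^2
    # with respect to the features phi.
    return 2.0 * (w @ phi - z) * w

def reversed_grad(phi, w, z, coeff=1.0):
    # Gradient reversal: the forward pass is the identity, but the gradient
    # reaching the featurizer is negated and scaled by the domain loss
    # coefficient, so the featurizer learns to *erase* subgroup information.
    return -coeff * domain_head_grad(phi, w, z)
```

The domain loss coefficient swept in Table 12 corresponds to `coeff` here.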

IRM.

We implement the IRM penalty (Arjovsky et al., 2019), and treat the subgroups as separate environments across which the model should perform well.

D.6.2 ABLATIONS

Subgroup Pairing. We simply take pairs of examples that lie in different subgroups and enforce consistency on them.

Heuristic Augmentations. We build a pipeline inspired by AugMix (Hendrycks et al., 2019) using the following operations: shearing, translation, rotation, flipping, contrast normalization, pixel inversion, histogram equalization, solarization, posterization, contrast adjustment, color enhancement, brightness adjustment, sharpness adjustment, Cutout and Mixup. We sample between 1 and 3 of these augmentations in a random order and apply them to the image.

We include an additional ablation on the MNIST-Correlation dataset where we vary the consistency penalty coefficient λ in Table 9. Compared to heuristic augmentations, CAMEL provides substantial improvements that are stable across different values of λ.
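The sampling scheme can be sketched as follows, using a small illustrative subset of the listed operations (our actual pipeline applies the full list above):

```python
import numpy as np

rng = np.random.default_rng(0)

def invert(img):
    return 255 - img                                  # pixel inversion

def solarize(img):
    return np.where(img < 128, img, 255 - img)        # invert bright pixels

def flip(img):
    return img[:, ::-1]                               # horizontal flip

OPS = [invert, solarize, flip]

def heuristic_augment(img):
    # Sample between 1 and 3 primitives without replacement, in a random
    # order, and compose them.
    k = rng.integers(1, 4)
    for i in rng.choice(len(OPS), size=k, replace=False):
        img = OPS[i](img)
    return img
```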

D.7 ISIC SPURIOUS CORRELATIONS

For completeness, we include a detailed evaluation for the ISIC dataset in Table 10. Here, we highlight that regardless of whether robust accuracy or AUROC is used as the model selection criterion, CAMEL exceeds the performance of the other methods. For ISIC, we also create an alternate evaluation dataset with artificial images in order to test whether a model spuriously correlates the presence of a bandage with the benign cancer class. To construct this dataset, we use image segmentation to automatically extract images of the bandages from the benign cancer class, and superimpose them on images with malignant cancers. This allows us to generate the artificial subgroup of the malignant cancer class that would contain images with bandages. We use this dataset to highlight how CAMEL reduces the model's dependence on this spurious feature in Figure 1.
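A minimal sketch of the superimposition step, assuming a binary segmentation mask for the extracted bandage; the helper and its signature are illustrative, not our exact pipeline:

```python
import numpy as np

def superimpose(lesion, patch, mask, top, left):
    # Paste a segmented bandage `patch` (H x W x 3, with binary H x W `mask`)
    # onto a malignant-lesion image at (top, left), producing an artificial
    # "malignant with bandage" evaluation example. The input is not mutated.
    out = lesion.copy()
    h, w = mask.shape
    region = out[top:top + h, left:left + w]
    out[top:top + h, left:left + w] = np.where(mask[..., None] > 0, patch, region)
    return out
```

Only pixels inside the mask are replaced, so the surrounding lesion context is preserved when probing the model's reliance on the bandage.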



Footnotes.
[1] Note that this will typically not hold for the training distribution, since some subgroups may be underrepresented, making it much less probable that examples from those subgroups are sampled in a coupled set. However, we are concerned with robustness to a test distribution where the subgroups are of equal importance and equally likely.
[2] For the ISIC dataset, we additionally performed model selection using AUROC, as illustrated in Table 6.
[3] Note that this allows each class to have the same subgroups, or for classes to have overlapping subgroups as special cases. Many of the datasets we consider in our experiments (Section 4) have this property.
[4] The particular model used was taken from https://github.com/qubvel/classification_models.
[5] The consistency penalty is increased linearly on every step, from 0 to λ, with rates 0.002 and 0.005 for λ = 50.0 and λ = 10.0 respectively.



Figure 2: The model patching framework. (Left) The class-subgroup hierarchy with each class Y divided into subgroups (e.g. Y = blonde hair into Z ∈ {male, female}). We learn inter-subgroup augmentations to transform examples between subgroups of a class. (Right) To patch the classifier, we augment examples by changing their subgroup membership and then train with our subgroup consistency loss and robust objective.

Figure 3: Coupled sets for subgroups of the Y = 7 class. This idea of coupled sets underlies both stages of the framework and enables stronger invariance guarantees. Given this notion, all examples x in a coupled set $[x]$ should have identical predictions in order to be robust across subgroups, modeled by the desired invariance $(\hat{Y} \perp Z) \mid [X]$. Instead of Lemma 1, we aim to minimize $I(\hat{Y}; Z \mid [X])$. Note that $I(\hat{Y}; Z \mid [X]) \geq I(\hat{Y}; Z \mid Y)$, which follows from the chain rule for MI (proof in Appendix C), so this is a stronger notion of invariance than CDAT permits. Additionally, the losses from the CycleGAN (Stage 1) and consistency regularizer (Stage 2) combine to form an upper bound on the mutual information rather than a lower bound, so optimizing our loss is more appropriate.

Given two groups A and B, CycleGAN learns mappings F : B → A and G : A → B given unpaired samples a ∼ PA, b ∼ PB. Along with these generators, it has adversarial discriminators DA, DB trained with the standard GAN objective, i.e. DA distinguishes samples a ∼ PA from generated samples F (b), where b ∼ PB. In CAMEL, A and B correspond to data from a pair of subgroups z, z of a class.

Figure 5 visualizes the CycleGAN model.

Definition 1. The sum of the CycleGAN cycle-consistency $L_{CG}(a, F(G(a)))$ and identity $L_{CG}(a, F(a))$ losses on domain A is denoted $L^A_{CG}(a; \theta)$ for overall CycleGAN parameters θ, and similarly for domain B. In the context of Stage 1 of model patching, let $L^z_{CG}(x; \theta)$ denote the loss when the domain is one of the subgroups z.

Figure 5: CycleGAN learns mappings on domains A ∪ B, where F maps examples to A and G maps to B. To model possible distribution shift introduced by the generative model, we denote their images as Im(F) = Ã and Im(G) = B̃ respectively. Semantically consistent mappings are encouraged with the cycle consistency and identity losses, e.g. to ensure that F(a) = a for all a ∈ A.
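Definition 1 can be made concrete with a short sketch. Here `F` and `G` stand in for the learned generators (hypothetical callables), and an L1 reconstruction distance is assumed for L_CG:

```python
import numpy as np

def l1(x: np.ndarray, y: np.ndarray) -> float:
    """Mean absolute error, a common choice for the CycleGAN reconstruction distance."""
    return float(np.mean(np.abs(x - y)))

def domain_loss_A(a: np.ndarray, F, G) -> float:
    """L^A_CG(a): cycle-consistency plus identity loss on domain A.

    F: B -> A and G: A -> B are the generators. The cycle term asks that `a`
    survive the round trip G then F; the identity term asks that F leave
    examples already in A unchanged. The loss on domain B is symmetric.
    """
    cycle = l1(a, F(G(a)))      # L_CG(a, F(G(a)))
    identity = l1(a, F(a))      # L_CG(a, F(a))
    return cycle + identity
```

With generators that act as the identity on their target domain, both terms vanish, matching the intent of the figure.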

(a) Joint distribution of examples X with their class labels Y, subgroup labels Z, and coupled sets [X]. (b) Illustration with the MNIST-Corrupted dataset (Mu & Gilmer, 2019), where subgroups Z are different types of corruptions.

Figure 6: Subgroup-coupled distributions separate the coupled set to which an example belongs (with respect to its class) from its subgroup label.

Note that since features matter only insofar as their effect on the final output Ŷ, it suffices to consider the case φ(X) = Ŷ. We first show in Section C.4 that CDAT methods target the invariance (Ŷ ⊥ Z) | Y by minimizing a lower bound on the conditional mutual information I(Ŷ; Z | Y) (Lemma 1). In Section C.5, we prove our main result: our combined objective function (4) targets the stronger invariance (Ŷ ⊥ Z) | [X] by upper bounding the corresponding MI, which can be interpreted as forcing matching outputs for the examples in every coupled set.

Figure 5 also illustrates the concept of Definition 3: the original domains A, B have corresponding domains Ã, B̃ that are the images of the generators F, G.

Figure 8: Results of inter-subgroup transformations on MNIST-Correlation.

Figure 9: Results of inter-subgroup transformations on CelebA-Undersampled. Generation examples use the CycleGAN trained on the non-blonde class.

The only preprocessing common to all methods is standard ImageNet normalization using µ = [0.485, 0.456, 0.406], σ = [0.229, 0.224, 0.225].
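For concreteness, this normalization is a per-channel affine transform; a minimal sketch for an H × W × 3 image with float values in [0, 1]:

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])

def imagenet_normalize(image: np.ndarray) -> np.ndarray:
    """Subtract the per-channel ImageNet mean and divide by the std.

    `image` is assumed to be H x W x 3 with values scaled to [0, 1].
    """
    return (image - IMAGENET_MEAN) / IMAGENET_STD
```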

Figure 10: Results of inter-subgroup transformations on Waterbirds. Generation examples use the CycleGAN trained on the landbirds class.

Comparison of metrics and losses for classifier training. Here P_z and P_z′ are the marginal distributions of a pair of subgroups z, z′.

A comparison between CAMEL and other methods on 3 benchmark datasets. Evaluation metrics include robust & aggregate accuracy and the subgroup performance gap, calculated on the test set. Results are averaged over 3 trials (one standard deviation indicated in parentheses).

Estimated MI between predictions and subgroups computed on MNIST-Correlation.

Ablation analysis (Section 4.2.1) that varies the consistency penalty coefficient λ. For brevity, we report the maximum subgroup performance gap over all classes.

We also compare to a state-of-the-art heuristic augmentation pipeline (Hendrycks et al., 2019) (Appendix D.6) containing rotations, flips, cutout, etc. Our goal is to validate our theoretical analysis, which suggests that strong consistency training should help most when used with the coupled examples generated by the CycleGAN. We expect that the ablations should benefit less from consistency training since (i) subgroup pairing enforces consistency on examples across subgroups that may not lie in the same coupled set; and (ii) heuristic augmentations may not change subgroup membership at all, and may even change class membership.

MNIST-Correlation results when training to predict Z labels and testing on Y labels. Test robust accuracy bolded.

Consistency loss ablations on Waterbirds. (Left) Loss curves on the (landbird, water) subgroup. The addition of the CAMEL consistency loss to GDRO reduces overfitting. (Right) Robust accuracy decreases with alternate consistency losses (Triplet JS (Hendrycks et al., 2019) and KL (Xie et al., 2019)) on CAMEL-generated data or heuristic augmentations.

Comparison on ISIC. Average of 3 trials (one standard deviation indicated in parentheses).

Summary of notation used throughout this work. Entries include: the class of a subgroup z; f_θ : X → Δ^{|Y|}, the parameterized class prediction model, returning a categorical distribution over Y; and Ŷ, a random variable with support Y indicating a random sample from the output of f_θ.

Number of training, validation and test examples in each dataset.

Ablation analysis (Section 4.2.1) that varies the consistency penalty coefficient λ on the MNIST-Correlation dataset. For brevity, we report the maximum subgroup performance gap over all classes.

ERM. We use standard training with a cross-entropy loss. ERM cannot take advantage of knowledge of the subgroups, so this constitutes a standard baseline that a practitioner might use to solve a task.

The values of the best hyperparameters found for each dataset and method.

ACKNOWLEDGMENTS AND DISCLOSURE OF FUNDING

We thank Pang Wei Koh, Shiori Sagawa, Geoff Angus, Jared Dunnmon, and Nimit Sohoni for assistance with baselines and datasets and useful discussions. We thank members of the Hazy Research group including Mayee Chen, Megan Leszczynski, Sarah Hooper, Laurel Orr, and Sen Wu for useful feedback on previous drafts. KG and AG are grateful for Sofi Tukker's assistance throughout this project. We gratefully acknowledge the support of DARPA under Nos. FA86501827865 (SDH) and FA86501827882 (ASED); NIH under No. U54EB020405 (Mobilize), NSF under Nos. CCF1763315 (Beyond Sparsity), CCF1563078 (Volume to Velocity), and 1937301 (RTML); ONR under No. N000141712266 (Unifying Weak Supervision); the Moore Foundation, NXP, Xilinx, LETI-CEA, Intel, IBM, Microsoft, NEC, Toshiba, TSMC, ARM, Hitachi, BASF, Accenture, Ericsson, Qualcomm, Analog Devices, the Okawa Foundation, American Family Insurance, Google Cloud, Swiss Re, the Salesforce Deep Learning Research grant, the HAI-AWS Cloud Credits for Research program, and members of the Stanford DAWN project: Teradata, Facebook, Google, Ant Financial, NEC, VMWare, and Infosys. The U.S. Government is authorized to reproduce and distribute reprints for Governmental purposes notwithstanding any copyright notation thereon. Any opinions, findings, and conclusions or recommendations expressed in this material are those of the authors and do not necessarily reflect the views, policies, or endorsements, either expressed or implied, of DARPA, NIH, ONR, or the U.S. Government.

C.5 MI BOUNDS FOR COUPLING-CONDITIONED INVARIANCE

The stronger distributional assumptions of Definition 1 allow us to analyze the invariance φ(X) ⊥ Z | [X], which can be interpreted as forcing matching features for the data in every coupled set.

True Coupled Sets. Given a subgroup-coupled distribution, access to coupled sets allows analysis of stronger invariance assumptions. First, we confirm that this is indeed a stronger notion of invariance, that is,

I(Z; φ(X) | [X]) ≥ I(Z; φ(X) | Y).    (5)

This follows from the chain rule for mutual information:

I(Z; φ(X) | [X]) = I(Z; φ(X) | [X], Y) = I(Z; φ(X), [X] | Y) = I(Z; φ(X) | Y) + I(Z; [X] | φ(X), Y) ≥ I(Z; φ(X) | Y).

Here, the first two equalities follow from Definition 1 (in particular, [X] and Z are more fine-grained than Y, so [X] determines Y and I(Z; [X] | Y) = 0), and the last two steps follow from the chain rule for mutual information. In particular, equation (5) quantifies the intuition that conditioning on an example's coupled set reveals more information than just conditioning on its class. Conversely, minimizing the LHS of (5) necessarily minimizes the objective I(Z; φ(X) | Y) in (Li et al., 2018), and an additional non-negative term I(Z; [X] | φ(X), Y) relating the features and identity of examples.

Moreover, the features φ(X) are only relevant insofar as their ability to predict the label. Specializing φ(X), this stronger conditional MI is related to the model's predictions; it is exactly equal to the self-consistency regularizer (1) if the model had access to true coupled sets [x]. Thus, in the case where φ(X) = Ŷ is simply the model's prediction, this MI is simply the Jensen-Shannon divergence of the model's predictions.

Lemma 2. I(Ŷ; Z | [X]) = E_{[x] ∼ P_{[X]}} [JSD(f_θ(x_{z_1}), ..., f_θ(x_{z_k}))], where [x] = {x_{z_1}, ..., x_{z_k}} ranges over coupled sets.

Proof. For any features φ, the mutual information can be written

I(φ(X); Z | [X]) = E_{[x] ∼ P_{[X]}} [I(φ(X); Z | [X] = [x])],

where the random variable E[φ(X) | [X]] denotes the formal conditional expectation and, conditioned on a coupled set [x], Z is uniform over its subgroups. Consider specializing this to the case when φ(X) = Ŷ, i.e. it represents the random variable where an output class prediction Ŷ is sampled from the final class probability predictions f_θ(X) of the model. Since this is distributed as P_{Ŷ | X_z} = f_θ(X_z), we obtain

I(Ŷ; Z | [X]) = E_{[x] ∼ P_{[X]}} [JSD(f_θ(x_{z_1}), ..., f_θ(x_{z_k}))],

where the second equality follows by Proposition 2: the mutual information between a uniform index Z and a sample drawn from f_θ(X_Z) is exactly the Jensen-Shannon divergence of the k predicted distributions.

Augmented Coupled Sets.
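The coupled-set consistency this analysis motivates can be sketched directly: given the model's predictive distributions on the examples of one coupled set, the regularizer is their Jensen-Shannon divergence. This is a sketch; the regularizer (1) in the main text may differ in weighting:

```python
import numpy as np

def entropy(p: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Shannon entropy (nats) along the last axis, with clipping for stability."""
    p = np.clip(p, eps, 1.0)
    return -np.sum(p * np.log(p), axis=-1)

def js_consistency(preds: np.ndarray) -> float:
    """Jensen-Shannon divergence of the predictions within one coupled set.

    preds has shape (k, num_classes): one categorical distribution
    f_theta(x_z) per subgroup-augmented example. The value is zero exactly
    when all k predictions agree, i.e. the invariance the bound targets.
    """
    mean = preds.mean(axis=0)                 # mixture distribution
    return float(entropy(mean) - entropy(preds).mean())
```

Identical predictions give zero loss; maximally disagreeing one-hot predictions give log 2 for two subgroups.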
In practice, we may not have true coupled sets; instead, the subgroup augmentations learned in Stage 1 generate approximate coupled sets.

MNIST-Correlation. This generates a dataset with balanced Y and Z, with N/2 examples each. For our experiments, we use N = 40000, ρ = 0.98. This makes Y and Z highly correlated, so that most even (odd) digits are clean (zigzag). For validation, we use 50% of the training data.

CelebA-Undersampled. We modify the CelebA dataset (Liu et al., 2015) by undersampling the (Y = non-blonde, Z = female) subgroup in the training set. The original dataset contains 71629 examples in this training subgroup, and we keep a random subset of 4054 examples. This number is chosen to make the ratio of subgroup sizes equal in both classes: 4054/66874 ≈ 1387/22880. We do not modify the validation or test datasets. This modification introduces a spurious correlation between hair color and gender, which makes the dataset more appropriate for our setting. We preprocess images by resizing to 128 × 128 × 3 before use.

Waterbirds. We use the Waterbirds dataset (Sagawa et al., 2020) and resize images to 224 × 224 × 3 before use. Note that this differs from the preprocessing used by Sagawa et al. (2020), who first resize to 256 × 256 × 3 and then center-crop the image to 224 × 224 × 3. Their preprocessing makes the task easier, since part of the (spurious) background is cropped out, while we retain the full image.

ISIC. We use the ISIC dataset (Codella et al., 2018) and resize images to 224 × 224 × 3 before use.
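The 4054 figure above comes from a simple ratio match; a sketch of the computation (function name is illustrative):

```python
def matched_subgroup_size(ref_minority: int, ref_majority: int,
                          other_majority: int) -> int:
    """Number of examples to keep so the minority/majority ratio matches
    the reference class: keep / other_majority == ref_minority / ref_majority.
    """
    return round(other_majority * ref_minority / ref_majority)
```

For CelebA-Undersampled, matching the blonde class ratio 1387/22880 against the 66874 majority examples of the non-blonde class gives 4054.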

D.2 CYCLEGAN TRAINING DETAILS

We use the default hyperparameters suggested by Zhu et al. (2017) for CycleGAN training, with batchnorm for layer normalization. We use Adam for optimization (β1 = 0.5) with a constant learning rate of 0.0002 for both generators and both discriminators.

CAMEL. Sweep over consistency penalties in {5.0, 10.0, 20.0, 50.0}. Learning rate is fixed to 0.00005, weight decay is fixed to 0.05, and the adjustment coefficient is fixed to 3.0.

D.4.2 WATERBIRDS

We run sweeps for all methods over 500 epochs.

ERM. Sweep over learning rates {0.001, 0.0001, 0.00001} and weight decays {0.5, 0.001}.

GDRO. Sweep over learning rates {0.00001, 0.00005} and weight decays {0.5, 0.05}, with adjustment coefficient fixed to 1.0 and batch size 24. We also separately swept weight decays {1.0, 0.001} and adjustment coefficients over {1.0, 2.0}.
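The sweeps in this appendix are plain grid searches; a minimal sketch, where `evaluate` is a hypothetical stand-in for training a model with the given configuration and returning robust validation accuracy:

```python
from itertools import product

def grid_sweep(learning_rates, weight_decays, evaluate):
    """Return the (lr, wd) pair with the best score under `evaluate`.

    `evaluate(lr, wd)` is assumed to train and score one configuration;
    here it is only a placeholder for the full training loop.
    """
    return max(product(learning_rates, weight_decays),
               key=lambda cfg: evaluate(*cfg))
```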

CAMEL.

Sweep over consistency penalties in {100.0, 200.0} and learning rates {0.00005, 0.0001}. Weight decay fixed to 0.001 and adjustment coefficient is fixed to 2.0. Separately, we sweep over learning rates {0.00001, 0.00002, 0.00005, 0.0001}, fixing the consistency penalty to 200.0, weight decay to 0.05 and adjustment coefficient to 1.0.

D.4.3 MNIST-CORRELATION

We run sweeps for all methods over 100 epochs.

ERM. Sweep over learning rates {0.0001, 0.0002, 0.0005, 0.001} and weight decays {0.0005, 0.05}.

GDRO. Sweep over learning rates {0.0001, 0.0002, 0.0005, 0.001} and weight decays {0.0005, 0.05}. Adjustment coefficient is fixed to 1.0.

CDAT. Sweep over domain loss coefficients {-0.1, -0.01, 0.1, 1.0}. We fix the learning rate to 0.001 and weight decay to 0.0005. We run CDAT for 400 epochs, since it takes much longer to converge.

IRM.

Sweep over IRM penalties {0.01, 0.1, 1.0, 10, 100, 1000, 10000} and learning rates {0.0005, 0.001}. Weight decay is fixed to 0.0005.

CAMEL.

Sweep over consistency penalty weights {0.0, 2.0, 5.0, 10.0, 50.0}. Learning rate is fixed to 0.001 and weight decay is fixed to 0.0005.

D.4.4 ISIC

We run sweeps for all methods over 75 epochs.

ERM. Sweep over weight decays {0.5, 0.05, 0.00005}. Learning rate is fixed to 0.0001.

GDRO. Sweep over learning rates {0.0001, 0.00001} and weight decays {0.5, 0.05, 0.00005}. Adjustment coefficient is fixed to 0.

CAMEL. Sweep over learning rates {0.0001, 0.00005}, weight decays {0.01, 0.05}, consistency penalties {10.0, 50.0} and annealing rates {0.005, 0.002}.

D.5 MUTUAL INFORMATION MEASUREMENT

For the mutual information measurement experiment on MNIST-Correlation in Section 4.1, we additionally attach a domain prediction head to the final feature layer. This domain prediction head is then used to predict the subgroup z of any example x. Note that this domain prediction head does not pass back gradients to the main model; it merely observes the learned representation and attempts to predict the subgroups from it. Intuitively, this captures how much information about the subgroups is available to be "squeezed out" by the domain prediction head. This constitutes a use of Lemma 1 to estimate the mutual information, and we report the average cross-entropy loss (added to log 2).
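One plausible reading of this estimator can be sketched as follows (an assumption on our part; the head and its training loop are more involved in practice): for a balanced binary subgroup variable, H(Z) = log 2, so the head's cross-entropy yields the Lemma 1 lower bound I ≥ log 2 − CE.

```python
import numpy as np

def mi_lower_bound(subgroup_probs: np.ndarray, z: np.ndarray) -> float:
    """Lemma-1-style lower bound on the MI between features and subgroup Z.

    subgroup_probs: (n, 2) predicted subgroup distributions from the domain
    prediction head; z: (n,) true subgroup labels in {0, 1}. Assumes Z is
    balanced, so H(Z) = log 2 (nats).
    """
    eps = 1e-12
    picked = subgroup_probs[np.arange(len(z)), z]
    cross_entropy = -np.mean(np.log(np.clip(picked, eps, 1.0)))
    return float(np.log(2) - cross_entropy)
```

An uninformative head (uniform predictions) gives an estimate of 0; a perfect head gives log 2, the full entropy of Z.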

D.6 BASELINE COMPARISONS

We describe the baselines that we compare to, with implementations for each of these available in our code release. 

D.8 ALTERNATIVE GAN AUGMENTATION BASELINES

As noted in Section 2.1, Stage 1 of the model patching pipeline can be integrated with alternative domain translation models. As an additional baseline, we compare to alternative GAN augmentation methods. Typically, these methods are used for data augmentation, but are not evaluated on robustness. We consider the Augmented CycleGAN (Almahairi et al., 2018), Data Augmentation GAN (DAGAN) (Antoniou et al., 2017) and StarGAN-v2 (Choi et al., 2020) models, either when used in combination with ERM, or when used as part of the model patching pipeline. When used as part of model patching, we replace the CycleGAN in Stage 1 with the alternative GAN model. We used released code for Augmented CycleGAN and DAGAN to generate data for the Waterbirds dataset. For StarGAN-v2, we used pre-trained models for CelebA. We note that DAGAN is meant to be a self-contained data augmentation pipeline, so we did not consider it in conjunction with model patching. The results of this comparison are shown in Table 11. In particular, these alternate models have poor robust performance when used purely for data augmentation. Their performance improves when integrated in the model patching pipeline.

