COUNTERFACTUAL GENERATIVE NETWORKS

Abstract

Neural networks are prone to learning shortcuts: they often model simple correlations, ignoring more complex ones that potentially generalize better. Prior work on image classification shows that, instead of learning a connection to object shape, deep classifiers tend to exploit spurious correlations with low-level texture or the background to solve the classification task. In this work, we take a step towards more robust and interpretable classifiers that explicitly expose the task's causal structure. Building on current advances in deep generative modeling, we propose to decompose the image generation process into independent causal mechanisms that we train without direct supervision. By exploiting appropriate inductive biases, these mechanisms disentangle object shape, object texture, and background; hence, they allow for generating counterfactual images. We demonstrate the ability of our model to generate such images on MNIST and ImageNet. Further, we show that the counterfactual images can improve out-of-distribution robustness with only a marginal drop in performance on the original classification task, despite being synthetic. Lastly, our generative model can be trained efficiently on a single GPU, exploiting common pre-trained models as inductive biases.
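
The decomposition into shape, texture, and background mechanisms described above can be made concrete with a minimal sketch. It assumes, for illustration only, that the shape mechanism outputs a soft mask m with values in [0, 1], that the texture and background mechanisms each output an RGB image, and that the three are combined by alpha compositing, x = m ⊙ f + (1 − m) ⊙ b. The function `compose` and the hand-written arrays are hypothetical stand-ins for learned mechanisms, not the model's actual networks.

```python
import numpy as np


def compose(mask, texture, background):
    """Alpha-composite a foreground texture onto a background (assumed rule).

    mask: (H, W) array in [0, 1], equal to 1 where the object is.
    texture, background: (H, W, 3) RGB arrays.
    """
    m = mask[..., None]  # broadcast the mask over the color channels
    return m * texture + (1.0 - m) * background


# Hypothetical outputs of the three mechanisms (tiny 4x4 "images").
H, W = 4, 4
cow_mask = np.zeros((H, W))
cow_mask[1:3, 1:3] = 1.0                              # object shape
cow_texture = np.full((H, W, 3), [0.5, 0.3, 0.1])     # brown texture
purple = np.full((H, W, 3), [0.5, 0.0, 0.5])          # purple texture
grass = np.full((H, W, 3), [0.1, 0.6, 0.1])           # green background
moon = np.full((H, W, 3), [0.2, 0.2, 0.2])            # gray background

# Factual image: a brown cow shape on grass.
factual = compose(cow_mask, cow_texture, grass)
# Counterfactual image: same shape, swapped texture and background.
counterfactual = compose(cow_mask, purple, moon)
```

Because each factor is produced independently, any mask can be paired with any texture and any background, which is what makes combinations never seen in the training data possible.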

1. INTRODUCTION

Deep neural networks (DNNs) are the main building blocks of many state-of-the-art machine learning systems that address diverse tasks such as image classification (He et al., 2016), natural language processing (Brown et al., 2020), and autonomous driving (Ohn-Bar et al., 2020). Despite their considerable successes, DNNs still struggle in many situations, e.g., they misclassify images perturbed by an adversary (Szegedy et al., 2013) and fail to recognize known objects in unfamiliar contexts (Rosenfeld et al., 2018) or from unseen poses (Alcorn et al., 2019). Many of these failures can be attributed to dataset biases (Torralba & Efros, 2011) or shortcut learning (Geirhos et al., 2020): a DNN tends to learn the simplest correlations and to ignore more complex ones. This characteristic becomes problematic when the simple correlation is spurious, i.e., no longer present at inference time. The motivating example of Beery et al. (2018) considers a DNN trained to recognize cows in images. A real-world dataset will typically depict cows on green pastures in most images. The most straightforward correlation a classifier can learn to predict the label "cow" is hence the connection to a green, grass-textured background. Generally, this is not a problem during inference as long as the test data follows the same distribution. However, if we provide the classifier with an image depicting a purple cow on the moon, the classifier should still confidently assign the label "cow." Thus, if we want to achieve robust generalization beyond the training data, we need to disentangle possibly spurious correlations from causal relationships. Distinguishing between spurious and causal correlations is one of the core questions in causality research (Pearl, 2009; Peters et al., 2017; Schölkopf, 2019).

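
The cow example can be reproduced in a deliberately tiny setting. The sketch below is an illustration only, not the paper's method: a hypothetical "simplest cue" learner (a decision stump that picks the single binary feature with the highest training accuracy) stands in for a DNN, and both features are invented for the example: a shape cue that agrees with the label 90% of the time, and a background cue that agrees with it perfectly during training but carries no information at test time.

```python
import numpy as np

rng = np.random.default_rng(0)
FEATURES = ["shape", "background"]


def make_data(n, p_shape, p_background):
    """Binary labels (1 = cow) with two binary cues.

    Each cue equals the label with the given probability, else its negation.
    """
    y = rng.integers(0, 2, size=n)
    shape = np.where(rng.random(n) < p_shape, y, 1 - y)
    background = np.where(rng.random(n) < p_background, y, 1 - y)
    return np.stack([shape, background], axis=1), y


def fit_stump(X, y):
    """Return the index of the single cue that best predicts y on training data."""
    accuracies = (X == y[:, None]).mean(axis=0)
    return int(np.argmax(accuracies))


def accuracy(X, y, feature):
    """Accuracy of predicting y directly from one binary cue."""
    return float((X[:, feature] == y).mean())


# Training: cows always appear on grass (spurious but perfect cue).
X_train, y_train = make_data(10_000, p_shape=0.9, p_background=1.0)
# Test: "cows on the moon", the background no longer tracks the label.
X_test, y_test = make_data(10_000, p_shape=0.9, p_background=0.5)

chosen = fit_stump(X_train, y_train)
print(FEATURES[chosen])                     # background
print(accuracy(X_train, y_train, chosen))   # 1.0
print(accuracy(X_test, y_test, chosen))     # close to 0.5 (chance)
print(accuracy(X_test, y_test, 0))          # close to 0.9 (ignored shape cue)
```

The learner is rewarded for the perfectly predictive background during training, so it never relies on the noisier but causally relevant shape cue, and its accuracy collapses to chance once the spurious correlation disappears.
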
One central concept in causality is the assumption of independent mechanisms (IM), which states that a causal generative process is composed of autonomous modules that do not influence each other. In the context of image classification (e.g., on ImageNet), we can interpret the generation of an image as a causal process (Kocaoglu et al., 2018; Goyal et al., 2019; Suter et al., 2019). We decompose this process into separate IMs, each controlling one factor of variation (FoV) of the image. Concretely, we consider three IMs: the first generates the object's shape, the second generates the object's texture, and the third generates the background. With access to these IMs, we can produce counterfactual images, i.e., images of unseen combinations of FoVs. We can then train an ensemble of invariant classifiers on the generated coun-

