COUNTERFACTUAL GENERATIVE NETWORKS

Abstract

Neural networks are prone to learning shortcuts: they often model simple correlations while ignoring more complex ones that potentially generalize better. Prior work on image classification shows that instead of learning a connection to object shape, deep classifiers tend to exploit spurious correlations with low-level texture or the background for solving the classification task. In this work, we take a step towards more robust and interpretable classifiers that explicitly expose the task's causal structure. Building on current advances in deep generative modeling, we propose to decompose the image generation process into independent causal mechanisms that we train without direct supervision. By exploiting appropriate inductive biases, these mechanisms disentangle object shape, object texture, and background; hence, they allow for generating counterfactual images. We demonstrate the ability of our model to generate such images on MNIST and ImageNet. Further, we show that the counterfactual images can improve out-of-distribution robustness with only a marginal drop in performance on the original classification task, despite being synthetic. Lastly, our generative model can be trained efficiently on a single GPU, exploiting common pre-trained models as inductive biases.

1. INTRODUCTION

Deep neural networks (DNNs) are the main building blocks of many state-of-the-art machine learning systems that address diverse tasks such as image classification (He et al., 2016), natural language processing (Brown et al., 2020), and autonomous driving (Ohn-Bar et al., 2020). Despite the considerable successes of DNNs, they still struggle in many situations, e.g., classifying images perturbed by an adversary (Szegedy et al., 2013), or failing to recognize known objects in unfamiliar contexts (Rosenfeld et al., 2018) or from unseen poses (Alcorn et al., 2019). Many of these failures can be attributed to dataset biases (Torralba & Efros, 2011) or shortcut learning (Geirhos et al., 2020): the DNN learns the simplest correlations and tends to ignore more complex ones. This characteristic becomes problematic when the simple correlation is spurious, i.e., not present during inference. The motivating example of Beery et al. (2018) considers a DNN that is trained to recognize cows in images. A real-world dataset will typically depict cows on green pastures in most images. The most straightforward correlation a classifier can learn to predict the label "cow" is hence the connection to a green, grass-textured background. Generally, this is not a problem during inference as long as the test data follows the same distribution. However, if we present the classifier with an image depicting a purple cow on the moon, the classifier should still confidently assign the label "cow." Thus, if we want to achieve robust generalization beyond the training data, we need to disentangle possibly spurious correlations from causal relationships. Distinguishing between spurious and causal correlations is one of the core questions in causality research (Pearl, 2009; Peters et al., 2017; Schölkopf, 2019).
One central concept in causality is the assumption of independent mechanisms (IM), which states that a causal generative process is composed of autonomous modules that do not influence each other. In the context of image classification (e.g., on ImageNet), we can interpret the generation of an image as a causal process (Kocaoglu et al., 2018; Goyal et al., 2019; Suter et al., 2019). We decompose this process into separate IMs, each controlling one factor of variation (FoV) of the image. Concretely, we consider three IMs: one generates the object's shape, the second generates the object's texture, and the third generates the background. With access to these IMs, we can produce counterfactual images, i.e., images of unseen combinations of FoVs. We can then train an ensemble of invariant classifiers on the generated counterfactual images, such that every classifier relies on only a single one of those factors. The main idea is illustrated in Figure 1.

Figure 1: Out-of-Domain (OOD) Classification. A classifier focuses on all factors of variation (FoV) in an image. For OOD data, this can be problematic: a FoV might be a spurious correlation, hence impairing the classifier's performance. An ensemble, e.g., a classifier with a common backbone and multiple heads, each head invariant to all but one FoV, increases OOD robustness. [Illustration: for an image of an ostrich shape with ostrich texture on an ostrich background, both the invariant classifier ensemble and a standard classifier predict "an ostrich"; for a generated counterfactual image of an ostrich shape with strawberry texture on a diving background, the ensemble recovers the individual factors, while the standard classifier predicts "a strawberry."]

By exploiting concepts from causality, this paper links two previously distinct domains: disentangled generative models and robust classification. This allows us to scale our experiments beyond the small toy datasets typically used in either domain. The main contributions of our work are as follows:

• We present an approach for generating high-quality counterfactual images with direct control over shape, texture, and background. Supervision is only provided by the class label and certain inductive biases we impose on the learning problem.

• We demonstrate the usefulness of the generated counterfactual images for the downstream task of image classification on both MNIST and ImageNet. Our model improves the classifier's out-of-domain robustness while only marginally degrading its overall accuracy.

• We show that our generative model exhibits interesting emergent properties, such as generating high-quality binary object masks and unsupervised image inpainting.

We release our code at https://github.com/autonomousvision/counterfactual_generative_networks
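The decomposition into shape, texture, and background mechanisms can be illustrated by an analytic composition: a shape mask blends an object texture over a background. Below is a minimal NumPy sketch of this idea; the function name `composite` and the toy inputs are ours, chosen for illustration, not taken from the released code.

```python
import numpy as np

def composite(mask, foreground, background):
    """Analytic composition of three independent mechanisms:
    mask (object shape), foreground (object texture), background.
    All arrays have shape (H, W, 3); mask values lie in [0, 1]."""
    return mask * foreground + (1.0 - mask) * background

# Toy example: a bright square "object" on a darker background.
h = w = 8
mask = np.zeros((h, w, 3))
mask[2:6, 2:6, :] = 1.0                  # object shape: a square
texture = np.full((h, w, 3), 0.9)        # object texture: bright
background = np.full((h, w, 3), 0.3)     # background: dark gray
img = composite(mask, texture, background)
```

Because the three inputs enter the composition independently, any of them can be swapped (e.g., replacing the texture with one from another class) without affecting the other two, which is what enables counterfactual combinations.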

2. STRUCTURAL CAUSAL MODELS FOR IMAGE GENERATION

In this section, we first introduce our ideas on a conceptual level. Concretely, we draw a connection between the areas of causality, disentangled representation learning, and invariant classifiers, and highlight that domain randomization (Tobin et al., 2017) is a particular instance of these ideas. In Section 3, we will then formulate a concrete model that implements these ideas for image classification. Our goals are two-fold: (i) we aim to generate counterfactual images with previously unseen combinations of factors, such as a cat with elephant texture or the proverbial "bull in a china shop"; (ii) we utilize these images to train a classifier that is invariant to chosen factors of variation. In the following, we first formalize the problem setting we address. Second, we describe how this setting can be addressed by structuring a generator network as a structural causal model (SCM). Third, we show how to use the SCM for training robust classifiers.
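The counterfactual sampling underlying goal (i) can be sketched as drawing the class of each factor of variation independently, so that combinations absent from the training data (a cat shape with elephant texture) still occur. The sketch below is a conceptual stand-in: the class names, the `sample_counterfactual` helper, and the use of a lookup list instead of learned generative mechanisms are all our simplifying assumptions.

```python
import random

# Hypothetical class pool; in the actual model, each factor would be
# produced by a learned generative mechanism, not drawn from a list.
CLASSES = ["cat", "elephant", "ostrich"]

def sample_counterfactual(rng):
    """Draw the shape, texture, and background classes independently.
    Each invariant classifier head is then supervised only with the
    label of "its" factor, e.g., the shape head is trained on y_shape."""
    return {
        "shape": rng.choice(CLASSES),
        "texture": rng.choice(CLASSES),
        "background": rng.choice(CLASSES),
    }

rng = random.Random(0)
cf = sample_counterfactual(rng)
```

Because the three labels are sampled independently, a head trained to predict, say, the texture label sees shapes and backgrounds that are uninformative about its target, which encourages invariance to those factors.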

2.1. PROBLEM SETTING

Consider a dataset comprised of (high-dimensional) observations x (e.g., images) and corresponding labels y (e.g., classes). A common assumption is that each x can be described by lower-dimensional, semantically meaningful factors of variation z (e.g., color or shape of objects in the image). If we can disentangle these factors, we are able to control their influence on the classifier's decision. In the disentanglement literature, the factors are often assumed to be statistically independent, i.e., z is distributed according to p(z) = ∏_{i=1}^{n} p(z_i) (Locatello et al., 2018). However, assuming independence is problematic because certain factors might be correlated in the training data, or the combination of some factors may not exist. Consider the colored MNIST dataset (Kim et al., 2019), where both the digit's color and its shape correspond to the label. The simplest decision rule a classifier can learn is to count the number of pixels of a specific color value; no notion of the digit's shape is required. This kind of correlation is not limited to constructed datasets: classifiers trained on ImageNet (Deng et al., 2009) strongly rely on texture for classification, significantly more than on the object's shape (Geirhos et al., 2018). While texture or color is a powerful classification cue, we do not want the

