INVESTIGATING AND SIMPLIFYING MASKING-BASED SALIENCY MAP METHODS FOR MODEL INTERPRETABILITY

Abstract

Saliency maps that identify the most informative regions of an image for a classifier are valuable for model interpretability. A common approach to creating saliency maps involves generating input masks that mask out portions of an image to maximally deteriorate classification performance, or mask in portions of an image to preserve classification performance. Many variants of this approach have been proposed in the literature, such as counterfactual generation and optimizing over a Gumbel-Softmax distribution. Using a general formulation of masking-based saliency methods, we conduct an extensive evaluation study of a number of recently proposed variants to understand which elements of these methods meaningfully improve performance. Surprisingly, we find that a well-tuned, relatively simple formulation of a masking-based saliency model outperforms many more complex approaches. We find that the most important ingredients for high-quality saliency map generation are (1) using both masked-in and masked-out objectives and (2) training the classifier alongside the masking model. Strikingly, we show that a masking model can be trained with as few as 10 examples per class and still generate saliency maps with only a 0.7-point increase in localization error.

1. INTRODUCTION

The success of CNNs (Krizhevsky et al., 2012; Szegedy et al., 2015; He et al., 2016; Tan & Le, 2019) has prompted interest in improving understanding of how these models make their predictions. Particularly in applications such as medical diagnosis, having models explain their predictions can improve trust in them. The main line of work on model interpretability has focused on the creation of saliency maps: overlays to an input image that highlight the regions most salient to the model in making its predictions. Among these, the most prominent are gradient-based methods (Simonyan et al., 2013; Sundararajan et al., 2017; Selvaraju et al., 2018) and masking-based methods (Fong & Vedaldi, 2017; Dabkowski & Gal, 2017; Fong & Vedaldi, 2018; Petsiuk et al., 2018; Chang et al., 2019; Zintgraf et al., 2017). In recent years, research along both directions has grown rapidly. With a variety of approaches being proposed, framed, and evaluated in different ways, it has become difficult to assess and fairly compare their additive contributions. In this work, we investigate the class of masking-based saliency methods, in which a masking model is trained to generate saliency maps based on an explicit optimization objective. Using a general formulation, we iteratively evaluate the extent to which recently proposed ideas in the literature improve performance. In addition to evaluating our models against the commonly used Weakly Supervised Object Localization (WSOL) metrics, the Saliency Metric (SM), and the more recently introduced Pixel Average Precision (PxAP; Choe et al., 2020), we also test our final models against a suite of "sanity checks" for saliency methods (Adebayo et al., 2018; Hooker et al., 2018). Concretely, we make four major contributions.
(1) We find that incorporating both masked-in classification maximization and masked-out entropy maximization objectives leads to the best saliency maps, and continually training the classifier improves the quality of generated maps. (2) We find that the masking model requires only the top layers of the classifier to effectively generate saliency maps. (3) Our final model outperforms other masking-based methods on WSOL and PxAP metrics. (4) We find that a small number of examples, as few as ten per class, is sufficient to train a masker to within the ballpark of our best-performing model.
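The combined objective in contribution (1) can be sketched numerically. The following is a minimal illustration rather than the paper's implementation: the function name `saliency_objective`, the `lam` weighting, and the use of a log-likelihood term for the masked-in branch are our assumptions. It scores a candidate mask by rewarding a confident masked-in prediction for the target class together with a high-entropy (uninformative) masked-out prediction.

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def entropy(p):
    """Shannon entropy of a probability vector (nats)."""
    return -(p * np.log(p + 1e-12)).sum(axis=-1)

def saliency_objective(logits_masked_in, logits_masked_out, target, lam=1.0):
    """Score to be MAXIMIZED by the masker (hypothetical formulation):
    masked-in target log-likelihood + lam * masked-out prediction entropy."""
    masked_in_term = np.log(softmax(logits_masked_in)[..., target] + 1e-12)
    masked_out_term = entropy(softmax(logits_masked_out))
    return masked_in_term + lam * masked_out_term
```

A good mask yields confident masked-in logits and near-uniform masked-out logits, so the objective prefers that configuration over its reverse.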

2. RELATED WORK

Interpretability of machine learning models has been an ongoing topic of research (Ribeiro et al., 2016; Doshi-Velez & Kim, 2017; Samek et al., 2017; Lundberg et al., 2018). In this work, we focus on interpretability methods that involve generating saliency maps for image classification models. An overwhelming majority of the methods for generating saliency maps for image classifiers can be assigned to two broad families: gradient-based methods and masking-based methods. Gradient-based methods, such as using backpropagated gradients (Simonyan et al., 2013), Guided Backprop (Springenberg et al., 2015), Integrated Gradients (Sundararajan et al., 2017), Grad-CAM (Selvaraju et al., 2018), SmoothGrad (Smilkov et al., 2017), and many more, directly use the gradients backpropagated through the classifier to the input to generate saliency maps.



Figure 1: (A) Overview of the training setup for our final model. The masker is trained to maximize masked-in classification accuracy and masked-out prediction entropy. (B) Masker architecture. The masker takes as input the hidden activations of different layers of the ResNet-50 and produces a mask of the same resolution as the input image. (C) Few-shot training of the masker. Performance drops only slightly when trained on far fewer examples than the full training procedure uses.
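The idea in panel (B), reading activations at several depths of the classifier and fusing them into a full-resolution mask, can be sketched as follows. This is a simplified stand-in, not the actual masker architecture: the nearest-neighbor upsampling, the hypothetical per-layer 1x1-conv weight vectors `weights`, and the sigmoid output are our assumptions.

```python
import numpy as np

def upsample_nearest(feat, factor):
    """Nearest-neighbor upsampling of a (C, H, W) feature map."""
    return feat.repeat(factor, axis=1).repeat(factor, axis=2)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def masker_head(feats, weights, out_size):
    """Fuse multi-scale classifier activations into one soft mask.
    feats: list of (C_i, H_i, W_i) activations from different layers.
    weights: list of (C_i,) vectors acting as hypothetical 1x1 convs.
    Returns an (out_size, out_size) mask with values in (0, 1)."""
    acc = np.zeros((out_size, out_size))
    for f, w in zip(feats, weights):
        proj = np.tensordot(w, f, axes=([0], [0]))  # collapse channels -> (H_i, W_i)
        acc += upsample_nearest(proj[None], out_size // f.shape[1])[0]
    return sigmoid(acc)
```

In the real model these projections would be learned jointly with the objective in panel (A); the sketch only shows the multi-scale fusion structure.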

Masking-based methods modify input images to alter classifier behavior and use the regions of modification as the saliency map. Within this class of methods, one line of work optimizes over masks directly: Fong & Vedaldi (2017) optimize a perturbation mask for an image, Petsiuk et al. (2018) aggregate over randomly sampled masks, Fong & Vedaldi (2018) perform an extensive search for masks of a given size, and Chang et al. (2019) include a counterfactual mask-infilling model to make the masking objective more challenging. The other line of work trains a separate masking model to produce saliency maps: Dabkowski & Gal (2017) train a model that optimizes objectives similar to those of Fong & Vedaldi (2017), Zolna et al. (2020) use a continually trained pool of classifiers and an adversarial masker to generate model-agnostic saliency maps, and Fan et al. (2017) identify super-pixels in the image and then train the masker in a similarly adversarial manner.
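The mask-optimization methods above share a common blending primitive: a soft mask in [0, 1] interpolates each pixel between the image and a reference baseline (e.g. a blurred copy, a mean color, or zeros), and a smoothness penalty such as total variation discourages fragmented masks. A minimal sketch, with function names of our choosing:

```python
import numpy as np

def apply_mask(image, mask, baseline=0.0):
    """Soft blending: mask=1 keeps a pixel (masked-in), mask=0 replaces
    it with the baseline (masked-out). `baseline` may be a scalar or an
    image-shaped array such as a blurred copy of the input."""
    return mask * image + (1.0 - mask) * baseline

def total_variation(mask):
    """L1 total variation: sum of absolute differences between adjacent
    mask values, penalizing high-frequency, fragmented masks."""
    return (np.abs(np.diff(mask, axis=0)).sum()
            + np.abs(np.diff(mask, axis=1)).sum())
```

An all-ones mask reproduces the image and incurs zero total variation; direct-optimization methods descend on the mask under a classifier loss plus such a regularizer.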

Salient Object Detection (Borji et al., 2014; Wang et al., 2019) is a related line of work that concerns identifying salient objects within an image as an end in itself, rather than for the purpose of model interpretability. While it is not uncommon for these methods to incorporate a pretrained image classification model to extract learned visual features, they often also incorporate techniques for improving the quality of saliency maps that are orthogonal to model interpretability. Salient object detection methods trained on only image-level labels bear the closest similarity to saliency map generation methods for model interpretability. Hsu et al. (2017) and the follow-up Hsu et al. (2019) train a masking model to confuse a binary image-classification model that predicts whether an image contains an object or is a 'background' image. Wang et al. (2017) apply a smooth pooling operation

