LEARNING TO ESTIMATE SHAPLEY VALUES WITH VISION TRANSFORMERS

Abstract

Transformers have become a default architecture in computer vision, but understanding what drives their predictions remains a challenging problem. Current explanation approaches rely on attention values or input gradients, but these provide a limited view of a model's dependencies. Shapley values offer a theoretically sound alternative, but their computational cost makes them impractical for large, high-dimensional models. In this work, we aim to make Shapley values practical for vision transformers (ViTs). To do so, we first leverage an attention masking approach to evaluate ViTs with partial information, and we then develop a procedure to generate Shapley value explanations via a separate, learned explainer model. Our experiments compare Shapley values to many baseline methods (e.g., attention rollout, GradCAM, LRP), and we find that our approach provides more accurate explanations than existing methods for ViTs.

1. INTRODUCTION

Transformers (Vaswani et al., 2017) were originally introduced for NLP, but in recent years they have been successfully adapted to a variety of other domains (Wang et al., 2020; Jumper et al., 2021). In computer vision, transformer-based models are now used for problems including image classification, object detection and semantic segmentation (Dosovitskiy et al., 2020; Touvron et al., 2021; Liu et al., 2021), and they achieve state-of-the-art performance in many tasks (Wortsman et al., 2022). The growing use of transformers in computer vision motivates the question of what drives their predictions: understanding a complex model's dependencies is an important problem in many applications, but the field has not settled on a solution for the transformer architecture.

Transformers are composed of alternating self-attention and fully-connected layers, where the self-attention operation associates attention values with every pair of tokens. In vision transformers (ViTs) (Dosovitskiy et al., 2020), the tokens represent non-overlapping image patches, typically a total of 14 × 14 = 196 patches, each of size 16 × 16 pixels. It is intuitive to view attention values as indicators of feature importance (Abnar and Zuidema, 2020; Ethayarajh and Jurafsky, 2021), but interpreting transformer attention in this way is potentially misleading. Recent work has raised questions about the validity of attention as explanation (Serrano and Smith, 2019; Jain and Wallace, 2019; Chefer et al., 2021), arguing that it provides an incomplete picture of a model's dependence on each token.

If attention is not a reliable indicator of feature importance, then what is? We consider the perspective that transformers are no different from any other architecture, and that we can explain their predictions using model-agnostic approaches that are currently used for other architectures.
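To make the patch tokenization concrete, here is a minimal NumPy sketch (with a hypothetical helper name, not code from the paper) of how a standard 224 × 224 input yields 14 × 14 = 196 patch tokens of size 16 × 16:

```python
import numpy as np

def image_to_patches(image, patch_size=16):
    """Split an (H, W, C) image into flattened non-overlapping patches."""
    h, w, c = image.shape
    gh, gw = h // patch_size, w // patch_size  # patch grid, e.g. 14 x 14
    patches = image[:gh * patch_size, :gw * patch_size].reshape(
        gh, patch_size, gw, patch_size, c
    )
    # Reorder so each patch is contiguous, then flatten each patch.
    patches = patches.transpose(0, 2, 1, 3, 4).reshape(gh * gw, -1)
    return patches

image = np.zeros((224, 224, 3))
patches = image_to_patches(image)
print(patches.shape)  # (196, 768): 14*14 tokens, each 16*16*3 values
```

In an actual ViT, each flattened patch is then linearly projected to the model's embedding dimension before entering the transformer layers.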
Among these methods, Shapley values are a theoretically compelling approach, with feature importance scores that are designed to satisfy many desirable properties (Shapley, 1953; Lundberg and Lee, 2017). The main challenge for Shapley values in the transformer context is calculating them efficiently, because a naive calculation has running time exponential in the number of patches. If Shapley values are poorly approximated, they are unlikely to reflect a model's true dependencies, but calculating them with high accuracy is currently too slow to be practical. Thus, our work aims to make Shapley values practical for transformers, and for ViTs in particular. Our contributions include:

1. We investigate several approaches for withholding input features from vision transformers, which is a key operation for computing Shapley values. We find that ViTs can accommodate missing image patches by masking attention values for held-out tokens, and that training with random masking is important for models to properly handle partial information.

2. We develop a learning-based approach to estimate Shapley values efficiently and accurately. Our approach involves fine-tuning an existing ViT using a loss function designed specifically for Shapley values (Jethani et al., 2021b), and we prove that our loss bounds the estimation error without requiring ground truth values during training. Once trained, our explainer model provides a significant speedup over methods like KernelSHAP (Lundberg and Lee, 2017).

3. Our experiments compare the Shapley value-based approach to several groups of competing methods: attention-based, gradient-based and removal-based explanations. We find that our approach provides the best overall performance, correctly identifying influential and non-influential patches for both target and non-target classes. We verify this using three image datasets, and the results are consistent across multiple metrics.
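The attention-masking idea can be illustrated with a single-head attention sketch; the function and variable names here are hypothetical, and the paper's exact implementation may differ:

```python
import numpy as np

def masked_attention(q, k, v, keep):
    """Single-head attention that withholds tokens by masking their keys.

    `keep` is a boolean vector over tokens; held-out tokens receive an
    attention logit of -inf, so the softmax assigns them exactly zero
    weight. (An illustrative sketch, not the paper's implementation.)
    """
    d = q.shape[-1]
    logits = q @ k.T / np.sqrt(d)       # (n, n) attention logits
    logits[:, ~keep] = -np.inf          # hide held-out key tokens
    weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
n, d = 6, 8
q, k, v = rng.normal(size=(3, n, d))
keep = np.array([True, True, False, True, False, True])
out = masked_attention(q, k, v, keep)
```

Setting key logits to -inf before the softmax is equivalent to running attention over only the kept tokens, which is why this form of masking lets a ViT evaluate arbitrary patch subsets without retraining its architecture.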
Overall, our work shows that Shapley values can be made practical for vision transformers, and that they provide a compelling alternative to current attention- and gradient-based approaches.
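To make the efficiency obstacle explicit, a brute-force sketch of the exact Shapley computation (with a toy additive value function, not the paper's method) shows why enumeration over all 2^n subsets is infeasible for 196 patches:

```python
import itertools
import math
import numpy as np

def shapley_values(value_fn, n):
    """Exact Shapley values by enumerating all subsets of features.

    `value_fn` maps a tuple of present feature indices to a scalar.
    This exponential-time computation is only feasible for tiny n,
    which is why practical methods rely on sampling or learning.
    """
    phi = np.zeros(n)
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley kernel weight for subsets of this size.
            weight = (math.factorial(size) * math.factorial(n - size - 1)
                      / math.factorial(n))
            for subset in itertools.combinations(others, size):
                phi[i] += weight * (value_fn(subset + (i,)) - value_fn(subset))
    return phi

# Toy additive game: each present feature contributes a fixed amount.
contrib = np.array([2.0, -1.0, 0.5])
f = lambda s: float(sum(contrib[list(s)]))
phi = shapley_values(f, 3)
print(phi)  # for an additive game, Shapley values recover the contributions
```

In explanation settings, `value_fn` would be the model's prediction with the complementary patches withheld, so each evaluation requires a forward pass, making the exponential subset count the dominant cost.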

2. RELATED WORK

Understanding neural network predictions is a challenging problem that has been actively researched for the last decade (Simonyan et al., 2013; Zeiler and Fergus, 2014; Ribeiro et al., 2016). We focus on feature attribution, or identifying the specific input features that influence a prediction, but prior work has also considered other problems (Olah et al., 2017; Kim et al., 2018). The various techniques that have been developed can be grouped into several categories, which we describe below.

Attention-based explanations. Transformers use self-attention to associate weights with each pair of tokens (Vaswani et al., 2017), and a natural idea is to assess which tokens receive the most attention (Clark et al., 2019; Rogers et al., 2020; Vig et al., 2020). There are several versions of this approach, including attention rollout and attention flow (Abnar and Zuidema, 2020), which analyze attention across multiple layers. Attention is a popular interpretation tool, but it is only one component in a sequence of nonlinear operations, so it provides an incomplete picture of a model's dependencies (Serrano and Smith, 2019; Jain and Wallace, 2019; Wiegreffe and Pinter, 2019), and direct usage of attention weights has not been shown to perform well in vision tasks (Chefer et al., 2021).

Gradient-based methods. For other deep learning models such as CNNs, gradient-based explanations are a popular family of approaches. There are many variations on the idea of calculating input gradients, including methods that modify the input (e.g., SmoothGrad, IntGrad) (Smilkov et al., 2017; Sundararajan et al., 2017; Xu et al., 2020), operate at intermediate network layers (GradCAM) (Selvaraju et al., 2017), or design modified backpropagation rules (e.g., LRP, DeepLift) (Bach et al., 2015; Shrikumar et al., 2016; Chefer et al., 2021).
Although they are efficient to compute for arbitrary network architectures, gradient-based explanations achieve mixed results in quantitative benchmarks, including for object localization and the removal of influential features (Petsiuk et al., 2018; Hooker et al., 2019; Saporta et al., 2021; Jethani et al., 2021b), and they have been shown to be insensitive to the randomization of model parameters (Adebayo et al., 2018).

Removal-based explanations. Finally, removal-based explanations are those that quantify feature importance by explicitly withholding inputs from the model (Covert et al., 2021). For models that require all features to make predictions, options for removing features include setting them to default values (Zeiler and Fergus, 2014), sampling replacement values (Agarwal and Nguyen, 2020) and blurring images (Fong and Vedaldi, 2017). These methods work with any model type, but they tend to be slow because they require making many predictions. The simplest approach of removing individual features (known as leave-one-out; Ethayarajh and Jurafsky, 2021) is relatively fast, but the computational cost increases as we examine more feature subsets.

Shapley values (Shapley, 1953) are an influential approach within the removal-based explanation framework. By examining all feature subsets, they provide a nuanced view of each feature's influence and satisfy many desirable properties (Lundberg and Lee, 2017). They are approximated in practice using methods like TreeSHAP and KernelSHAP (Lundberg et al., 2020; Covert and Lee, 2021), but these approaches either are not applicable or do not scale to large ViTs. Recent work highlighted
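The leave-one-out approach mentioned above can be sketched as follows; `predict`, `x`, and `baseline` are illustrative placeholders, with "removal" implemented as substitution of a default value:

```python
import numpy as np

def leave_one_out(predict, x, baseline):
    """Leave-one-out importance: the drop in the model's output when
    each feature is individually replaced by a baseline value.

    One prediction per feature, so the cost is linear in the number of
    features, unlike subset-based methods such as Shapley values.
    """
    full = predict(x)
    scores = np.zeros(len(x))
    for i in range(len(x)):
        x_masked = x.copy()
        x_masked[i] = baseline[i]      # "remove" feature i
        scores[i] = full - predict(x_masked)
    return scores

# Toy linear model for illustration.
w = np.array([1.0, -2.0, 0.5])
predict = lambda x: float(w @ x)
x = np.array([1.0, 1.0, 1.0])
scores = leave_one_out(predict, x, baseline=np.zeros(3))
print(scores)  # [ 1.  -2.   0.5]
```

For a linear model, leave-one-out recovers each feature's exact contribution, but for models with feature interactions it can miss dependencies that only appear when several features are removed together, which is what subset-based methods address.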

