HOLISTICALLY EXPLAINABLE VISION TRANSFORMERS

Abstract

Transformers increasingly dominate the machine learning landscape across many tasks and domains, making it increasingly important to understand their outputs. While their attention modules provide partial insight into their inner workings, the attention scores have been shown to be insufficient for explaining the models as a whole. To address this, we propose B-cos transformers, which inherently provide holistic explanations for their decisions. Specifically, we formulate each model component (the multi-layer perceptrons, attention layers, and the tokenisation module) to be dynamic linear, which allows us to faithfully summarise the entire transformer via a single linear transform. We apply our proposed design to Vision Transformers (ViTs) and show that the resulting models, dubbed B-cos ViTs, are highly interpretable and perform competitively to baseline ViTs on ImageNet.

1. INTRODUCTION

Convolutional neural networks (CNNs) have dominated the last decade of computer vision. Recently, however, they have often been surpassed by transformers (Vaswani et al., 2017), which, if the current development is any indication, will replace CNNs for ever more tasks and domains. Transformers are thus bound to impact many aspects of our lives: from healthcare and judicial decisions to autonomous driving. Given the sensitive nature of such areas, it is of utmost importance to ensure that we can explain the underlying models, which still remains a challenge for transformers. To explain transformers, prior work often focused on the models' attention layers (Jain & Wallace, 2019; Serrano & Smith, 2019; Abnar & Zuidema, 2020; Barkan et al., 2021), as they inherently compute their output in an interpretable manner. However, as transformers consist of many additional components, explanations derived from attention alone have been found insufficient to explain the full models (Bastings & Filippova, 2020; Chefer et al., 2021). To address this, our goal is to develop transformers that inherently provide holistic explanations for their decisions, i.e. explanations that reflect all model components. These model components are: a tokenisation module, a mechanism for providing positional information to the model, multi-layer perceptrons (MLPs), as well as normalisation and attention layers, see Fig. 2a. By addressing the interpretability of each component individually, we obtain transformers that inherently explain their decisions, see, for example, Fig. 1 and Fig. 2b. In detail, our approach is based on the idea of designing each component to be dynamic linear, such that it computes an input-dependent linear transform. This renders the entire model dynamic linear, cf. Böhle et al. (2021; 2022), such that it can be summarised by a single linear transform for each input. In short, we make the following contributions.
(I) We present a novel approach for designing inherently interpretable transformers. For this, (II) we carefully design each model component to be dynamic linear and ensure that their combination remains dynamic linear and interpretable. Specifically, we address (IIa) the tokenisation module, (IIb) the attention layers, (IIc) the MLPs, and (IId) the classification head. (III) Additionally, we introduce a novel mechanism that allows the model to learn attention priors, which breaks the permutation invariance of transformers and thus allows the model to easily leverage positional information. In our experiments, we find that B-cos ViTs with such a learnt 'attention prior' achieve significantly higher classification accuracies. (IV) Finally, we evaluate a wide range of model configurations and show that the proposed B-cos ViTs are not only highly interpretable, but also constitute powerful image classifiers.
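To make the core idea of dynamic linearity concrete, the sketch below is a minimal toy illustration (not the paper's actual architecture): a small ReLU network is already dynamic linear, since ReLU(W1 x) = D(x) W1 x with D(x) a diagonal gating matrix, so the whole network collapses to a single input-dependent matrix W(x) whose product with x exactly reproduces the output; the matrices and shapes here are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
W1, W2 = rng.normal(size=(8, 5)), rng.normal(size=(3, 8))

def forward(x):
    """A tiny two-layer ReLU network."""
    return W2 @ np.maximum(W1 @ x, 0.0)

def effective_linear(x):
    """Collapse the network into a single input-dependent matrix W(x).

    ReLU(W1 x) = D(x) W1 x, where D(x) is diagonal with entries
    1 where the pre-activation is positive and 0 elsewhere.
    """
    gates = (W1 @ x > 0).astype(float)
    return W2 @ (gates[:, None] * W1)

x = rng.normal(size=5)
W_x = effective_linear(x)
# 'Computation is explanation': W(x) x exactly equals the model output.
assert np.allclose(forward(x), W_x @ x)
# Per-dimension contributions c[k, i] = W(x)[k, i] * x[i] sum to the output.
contributions = W_x * x
assert np.allclose(contributions.sum(axis=1), forward(x))
```

The same collapse underlies the paper's design: if every component (tokenisation, attention, MLPs) is dynamic linear, their composition is too, and the resulting W(x) serves as a faithful, holistic explanation.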

2. RELATED WORK

Attention as Explanation. As the name suggests, attention is often thought to give insight into what a model 'pays attention to' for its prediction. As such, various methods for using attention to understand the model output have been proposed, such as visualising the attention of single attention heads, cf. Vaswani et al. (2017). However, especially in deeper layers, the information becomes increasingly distributed and it is thus unclear whether a given token still represents its original position in the input (Serrano & Smith, 2019; Abnar & Zuidema, 2020), which complicates the interpretation of high attention values deep in the network (Serrano & Smith, 2019; Bastings & Filippova, 2020). Therefore, Abnar & Zuidema (2020) proposed 'attention rollout', which summarises the various attention maps throughout the layers. However, this summary still only includes the attention layers and neglects all other network components (Bastings & Filippova, 2020). In response, various improvements over attention rollout have been proposed, such as GradSAM (Barkan et al., 2021) or an LRP-based explanation method (Chefer et al., 2021), which were designed to more accurately reflect the computations of all model components. The significant gains in quantitative interpretability metrics reported by Chefer et al. (2021) highlight the importance of such holistic explanations. Similarly, we also aim to derive holistic explanations for transformers. However, instead of deriving an explanation 'post-hoc' as in Chefer et al. (2021), we explicitly design our models to be holistically explainable. For this, we formulate each component, and thus the full model, to be dynamic linear.

Dynamic Linearity. Plain linear models, i.e. y(x) = Wx, are usually considered interpretable, as y(x) can be decomposed into individual contributions cᵢ = wᵢxᵢ from any input dimension i: y = Σᵢ cᵢ (Alvarez-Melis & Jaakkola, 2018).
However, linear models have a limited capacity, which has led to various works aimed at extending their capacity without losing their interpretability, see Alvarez-Melis & Jaakkola (2018); Brendel & Bethge (2019); Böhle et al. (2021; 2022). An appealing strategy for this is formulating dynamic linear models (Alvarez-Melis & Jaakkola, 2018; Böhle et al., 2021; 2022), i.e. models that transform the input with a data-dependent matrix W(x): y(x) = W(x)x. In this work, we rely on the B-cos framework (Böhle et al., 2022), but instead of focusing on CNNs as in Böhle et al. (2022), we investigate the applicability of this framework to transformers.

Interpretability in DNNs. The question of interpretability extends, of course, beyond transformers, and many methods for explaining DNNs have been proposed. While other approaches exist, cf. Kim
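For reference, the attention rollout baseline discussed above can be sketched in a few lines. This follows the commonly used formulation of Abnar & Zuidema (2020), in which residual connections are modelled by averaging each (head-averaged) attention matrix with the identity before multiplying across layers; the helper name and toy inputs are our own.

```python
import numpy as np

def attention_rollout(attentions):
    """attentions: list of (tokens, tokens) head-averaged attention maps,
    ordered from the first layer to the last."""
    n = attentions[0].shape[0]
    rollout = np.eye(n)
    for attn in attentions:
        # Model the residual connection by mixing in the identity,
        # then renormalise rows to keep valid attention distributions.
        attn = 0.5 * (attn + np.eye(n))
        attn = attn / attn.sum(axis=-1, keepdims=True)
        rollout = attn @ rollout
    return rollout  # rollout[i, j]: attention flow from token j to token i

# Toy usage: two layers over three tokens, rows summing to one.
rng = np.random.default_rng(0)
maps = [rng.dirichlet(np.ones(3), size=3) for _ in range(2)]
flow = attention_rollout(maps)
assert np.allclose(flow.sum(axis=-1), 1.0)  # rows remain distributions
```

As the critique above notes, this summary covers only the attention layers; all other components (MLPs, normalisation, tokenisation) are ignored, which motivates the holistic, dynamic-linear design pursued in this work.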



Fig. 1: Inherent explanations (cols. 2+3) of B-cos ViTs vs. attention explanations (cols. 4+5) for the same model. Note that W(x) faithfully reflects the whole model and yields more detailed and class-specific explanations than attention alone. For a detailed discussion, see the supplement.

Fig. 2: (a) B-cos ViTs. We design each ViT component to be dynamic linear, allowing us to summarise the entire model by a single linear transform W(x), as shown at the bottom. (b) Computation is Explanation. The model output is exactly computed by the linear transform W(x). As a result, we can visualise this effective linear transform either via the corresponding matrix row (center) or via the contributions c_k(x) (right), cf. Eq. (7).

