CONCEPT-LEVEL DEBUGGING OF PART-PROTOTYPE NETWORKS

Abstract

Part-Prototype Networks (ProtoPNets) are concept-based classifiers designed to achieve the same performance as black-box models without compromising transparency. ProtoPNets compute predictions based on similarity to class-specific part-prototypes learned to recognize parts of training examples, making it easy to faithfully determine what examples are responsible for any target prediction and why. However, like other models, they are prone to picking up confounders and shortcuts from the data, thus suffering from compromised prediction accuracy and limited generalization. We propose ProtoPDebug, an effective concept-level debugger for ProtoPNets in which a human supervisor, guided by the model's explanations, supplies feedback in the form of what part-prototypes must be forgotten or kept, and the model is fine-tuned to align with this supervision. Our experimental evaluation shows that ProtoPDebug outperforms state-of-the-art debuggers for a fraction of the annotation cost. An online experiment with laypeople confirms the simplicity of the feedback requested of the users and the effectiveness of the collected feedback for learning confounder-free part-prototypes. ProtoPDebug is a promising tool for trustworthy interactive learning in critical applications, as suggested by a preliminary evaluation on a medical decision-making task.

1. INTRODUCTION

Part-Prototype Networks (ProtoPNets) are "gray-box" image classifiers that combine the transparency of case-based reasoning with the flexibility of black-box neural networks (Chen et al., 2019). They compute predictions by matching the input image with a set of learned part-prototypes - i.e., prototypes capturing task-salient elements of the training images, like objects or parts thereof - and then making a decision based on the part-prototype activations only. What makes ProtoPNets appealing is that, despite performing comparably to more opaque predictors, they explain their own predictions in terms of relevant part-prototypes and of the examples that these are sourced from. These explanations are - by design - more faithful than those extracted by post-hoc approaches (Dombrowski et al., 2019; Teso, 2019; Lakkaraju & Bastani, 2020; Sixt et al., 2020) and can effectively help stakeholders simulate and anticipate the model's reasoning (Hase & Bansal, 2020). Despite all these advantages, ProtoPNets are prone - like regular neural networks - to picking up confounders from the training data (e.g., class-correlated watermarks), thus suffering from compromised generalization and out-of-distribution performance (Lapuschkin et al., 2019; Geirhos et al., 2020). This occurs even with well-known data sets, as we will show, and it is especially alarming as it can impact high-stakes applications like COVID-19 diagnosis (DeGrave et al., 2021) and scientific analysis (Schramowski et al., 2020). We tackle this issue by introducing ProtoPDebug, a simple but effective interactive debugger for ProtoPNets that leverages their case-based nature.
ProtoPDebug builds on three key observations: (i) In ProtoPNets, confounders - for instance, textual meta-data in X-ray lung scans (DeGrave et al., 2021) and irrelevant patches of background sky or foliage (Xiao et al., 2020) - end up appearing as part-prototypes; (ii) Sufficiently expert and motivated users can easily indicate which part-prototypes are confounded by inspecting the model's explanations; (iii) Concept-level feedback of this kind is context-independent, and as such it generalizes across instances. In short, ProtoPDebug leverages the explanations naturally output by ProtoPNets to acquire concept-level feedback about confounded (and optionally high-quality) part-prototypes, as illustrated in Fig. 1. Then, it aligns the model using a novel pair of losses that penalize part-prototypes for behaving similarly to confounded concepts, while encouraging the model to remember high-quality concepts, if any. ProtoPDebug is ideally suited for human-in-the-loop explanation-based debugging (Kulesza et al., 2015; Teso & Kersting, 2019), and achieves substantial savings in terms of annotation cost compared to alternatives based on input-level feedback (Barnett et al., 2021). In fact, in contrast to the per-pixel relevance masks used by other debugging strategies (Ross et al., 2017; Teso & Kersting, 2019; Plumb et al., 2020; Barnett et al., 2021), concept-level feedback automatically generalizes across instances, thus speeding up convergence and preventing relapse.
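To give a feel for how such concept-level feedback can be turned into training signals, here is a minimal NumPy sketch of the two kinds of loss terms described above: a "forgetting" term that grows when any current part-prototype activates strongly on a patch the user flagged as confounded, and a "remembering" term that grows when user-approved patches are poorly matched. The function names (`forget_loss`, `remember_loss`) and the exact max/min aggregation are our illustrative assumptions, not the paper's definitions, and NumPy stands in for an autodiff framework.

```python
import numpy as np

def activation(p, q, eps=1e-4):
    """ProtoPNet-style difference-of-logarithms activation of prototype p on latent part q."""
    d2 = np.sum((p - q) ** 2)
    return np.log(d2 + 1.0) - np.log(d2 + eps)

def forget_loss(prototypes, forbidden_parts):
    """Penalize prototypes that behave like user-flagged confounded concepts:
    the loss is the strongest activation of any prototype on any forbidden patch."""
    return max(activation(p, q) for p in prototypes for q in forbidden_parts)

def remember_loss(prototypes, valid_parts):
    """Encourage keeping high-quality concepts: the loss is high when some
    user-approved patch is matched by no prototype (low best activation)."""
    return -min(max(activation(p, q) for p in prototypes) for q in valid_parts)
```

In training, these terms would be added with suitable weights to the standard ProtoPNet objective and minimized by gradient descent; because they are defined on concepts rather than on individual pixels of individual images, the same feedback keeps constraining the model on every future input.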
Our experiments show that ProtoPDebug is effective at correcting existing bugs and at preventing new ones on both synthetic and real-world data, and that it needs less corrective supervision to do so than state-of-the-art alternatives.

Contributions. Summarizing, we: (1) Highlight limitations of existing debuggers for black-box models and ProtoPNets; (2) Introduce ProtoPDebug, a simple but effective strategy for debugging ProtoPNets that drives the model away from using confounded concepts and prevents it from forgetting well-behaved concepts; (3) Present an extensive empirical evaluation showcasing the potential of ProtoPDebug on both synthetic and real-world data sets.

Figure 1: Left: architecture of ProtoPNets. Right: schematic illustration of the ProtoPDebug loop. The model has acquired a confounded part-prototype p (the blue square) that correlates with, but is not truly causal for, the Crested Auklet class, and hence mispredicts both unconfounded images of this class and confounded images of other classes (top row). Upon inspection, an end-user forbids the model to learn part-prototypes similar to p, achieving improved generalization (bottom row). Relevance of all part-prototypes is omitted for readability but assumed positive.

2 PART-PROTOTYPE NETWORKS

ProtoPNets (Chen et al., 2019) classify images into one of v classes using a three-stage process comprising an embedding stage, a part-prototype stage, and an aggregation stage; see Fig. 1 (left).

Embedding stage: Let x be an image of shape w × h × d, where d is the number of channels. The embedding stage passes x through a sequence of (usually pre-trained) convolutional and pooling layers with parameters φ, obtaining a latent representation z = f(x) of shape w′ × h′ × d′, where w′ < w and h′ < h. Let Q(z) be the set of 1 × 1 × d′ subtensors of z. Each such subtensor q ∈ Q(z) encodes a filter in latent space and maps to a rectangular region of the input image x.

Part-prototype stage: This stage memorizes and uses k part-prototypes P = {p_1, . . . , p_k}. Each p_j is a tensor of shape 1 × 1 × d′ explicitly learned - as explained below - so as to capture salient visual concepts appearing in the training images, like heads or wings. The activation of a part-prototype p on a part q ∈ Q(z) is given by a difference-of-logarithms function, defined as (Chen et al., 2019):

    act(p, q) := log(‖p − q‖² + 1) − log(‖p − q‖² + ε) ≥ 0    (1)

Figure 2: Part-prototype activation functions.
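To make the three-stage pipeline and the difference-of-logarithms activation of Eq. (1) concrete, the following is a minimal NumPy sketch. It replaces the embedding network with a pre-computed latent map, and the names (`act`, `protopnet_forward`, `EPS`) and shapes are illustrative assumptions rather than the reference implementation.

```python
import numpy as np

EPS = 1e-4  # small constant ε from Eq. (1); the exact value is implementation-specific

def act(p, q):
    """Activation of part-prototype p on latent part q, as in Eq. (1); always >= 0."""
    d2 = np.sum((p - q) ** 2)  # squared L2 distance ||p - q||^2
    return np.log(d2 + 1.0) - np.log(d2 + EPS)

def protopnet_forward(z, prototypes, W):
    """z: latent map of shape (w', h', d'); prototypes: (k, d'); W: class weights (v, k)."""
    wp, hp, _ = z.shape
    parts = z.reshape(wp * hp, -1)  # Q(z): all 1 x 1 x d' subtensors, one per location
    # score each prototype by its strongest activation over all parts (max-pooling),
    # then aggregate the k scores linearly into v class logits
    scores = np.array([max(act(p, q) for q in parts) for p in prototypes])
    return W @ scores

rng = np.random.default_rng(0)
z = rng.normal(size=(7, 7, 8))   # stand-in for the embedding stage output f(x)
P = rng.normal(size=(3, 8))      # k = 3 part-prototypes of shape 1 x 1 x d'
W = rng.normal(size=(2, 3))      # aggregation weights for v = 2 classes
logits = protopnet_forward(z, P, W)
```

Note that a prototype activates maximally on a part identical to itself (distance zero gives log(1) − log(ε)), and the activation decays toward zero as the distance grows, which is exactly the case-based "how close is this patch to a memorized concept" behavior described above.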

