CONCEPT-LEVEL DEBUGGING OF PART-PROTOTYPE NETWORKS

Abstract

Part-prototype Networks (ProtoPNets) are concept-based classifiers designed to achieve the same performance as black-box models without compromising transparency. ProtoPNets compute predictions based on similarity to class-specific part-prototypes learned to recognize parts of training examples, making it easy to faithfully determine which examples are responsible for any target prediction and why. However, like other models, they are prone to picking up confounders and shortcuts from the data, thus suffering from compromised prediction accuracy and limited generalization. We propose ProtoPDebug, an effective concept-level debugger for ProtoPNets in which a human supervisor, guided by the model's explanations, supplies feedback in the form of which part-prototypes must be forgotten or kept, and the model is fine-tuned to align with this supervision. Our experimental evaluation shows that ProtoPDebug outperforms state-of-the-art debuggers for a fraction of the annotation cost. An online experiment with laypeople confirms the simplicity of the feedback requested of the users and the effectiveness of the collected feedback for learning confounder-free part-prototypes. ProtoPDebug is a promising tool for trustworthy interactive learning in critical applications, as suggested by a preliminary evaluation on a medical decision-making task.

1. INTRODUCTION

Part-Prototype Networks (ProtoPNets) are "gray-box" image classifiers that combine the transparency of case-based reasoning with the flexibility of black-box neural networks (Chen et al., 2019). They compute predictions by matching the input image with a set of learned part-prototypes, i.e., prototypes capturing task-salient elements of the training images, like objects or parts thereof, and then making a decision based on the part-prototype activations only. What makes ProtoPNets appealing is that, despite performing comparably to more opaque predictors, they explain their own predictions in terms of relevant part-prototypes and of the examples that these are sourced from. These explanations are, by design, more faithful than those extracted by post-hoc approaches (Dombrowski et al., 2019; Teso, 2019; Lakkaraju & Bastani, 2020; Sixt et al., 2020) and can effectively help stakeholders to simulate and anticipate the model's reasoning (Hase & Bansal, 2020).

Despite all these advantages, ProtoPNets are prone, like regular neural networks, to picking up confounders from the training data (e.g., class-correlated watermarks), thus suffering from compromised generalization and out-of-distribution performance (Lapuschkin et al., 2019; Geirhos et al., 2020). This occurs even with well-known data sets, as we will show, and it is especially alarming as it can impact high-stakes applications like COVID-19 diagnosis (DeGrave et al., 2021) and scientific analysis (Schramowski et al., 2020). We tackle this issue by introducing ProtoPDebug, a simple but effective interactive debugger for ProtoPNets that leverages their case-based nature.
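The matching step can be sketched as follows. This is a minimal NumPy illustration in the spirit of the prototype layer of Chen et al. (2019), not the exact implementation: the 1x1 spatial prototype size, the log-based distance-to-similarity transform, and all names are assumptions made for this sketch.

```python
import numpy as np

def prototype_activations(features, prototypes, eps=1e-6):
    """Peak activation of each part-prototype on one image.

    features:   (C, H, W) conv feature map of the input image
    prototypes: (P, C) learned part-prototype vectors (1x1 spatial size)

    Each spatial patch is compared to each prototype via squared L2
    distance; the distance is mapped to a similarity that grows as the
    distance shrinks, and a max-pool over locations keeps, for every
    prototype, its strongest match anywhere in the image.
    """
    C, H, W = features.shape
    patches = features.reshape(C, H * W).T                      # (H*W, C)
    d = ((prototypes[:, None, :] - patches[None, :, :]) ** 2).sum(-1)  # (P, H*W)
    sim = np.log((d + 1.0) / (d + eps))                         # small d -> large sim
    return sim.max(axis=1)                                      # (P,)
```

The final prediction then depends on these P activations only (e.g., through a linear layer with per-class weights), which is what makes it possible to trace any decision back to the prototypes, and the training patches, that triggered it.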
ProtoPDebug builds on three key observations: (i) in ProtoPNets, confounders, for instance textual meta-data in X-ray lung scans (DeGrave et al., 2021) and irrelevant patches of background sky or foliage (Xiao et al., 2020), end up appearing as part-prototypes; (ii) sufficiently expert and motivated users can easily indicate which part-prototypes are confounded by inspecting the model's explanations; (iii) concept-level feedback of this kind is context-independent, and as such it generalizes across instances.
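One simple way such "forget this prototype" feedback could enter fine-tuning is as an extra penalty on the flagged prototypes' activations. The sketch below is an illustrative assumption, not necessarily ProtoPDebug's exact loss: the function name, the max-aggregation, and the representation of feedback as index lists are all choices made for this example.

```python
import numpy as np

def forgetting_penalty(activations, forbidden_idx):
    """Illustrative penalty on user-flagged (confounded) part-prototypes.

    activations:   (P,) prototype activations for one image
    forbidden_idx: indices of prototypes the user asked the model to forget

    Returns the strongest activation among forbidden prototypes, so adding
    this term to the training loss during fine-tuning discourages the model
    from firing on the confounder anywhere in the image.
    """
    if len(forbidden_idx) == 0:
        return 0.0
    return float(np.max(np.asarray(activations)[forbidden_idx]))
```

Because the feedback targets a concept rather than a single input, the same penalty applies unchanged to every image in which the confounded prototype activates, which is what makes concept-level supervision cheap relative to per-instance attribution feedback.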

