THIS LOOKS LIKE IT RATHER THAN THAT: PROTOKNN FOR SIMILARITY-BASED CLASSIFIERS

Abstract

Among research on the interpretability of deep learning models, the 'this looks like that' framework with ProtoPNet has attracted significant attention. By combining the power of deep learning models with the interpretability of case-based inference, ProtoPNet can achieve high accuracy while keeping its reasoning process interpretable. Many methods based on ProtoPNet have emerged to take advantage of this benefit, but despite their practical usefulness, they run into difficulty when utilizing similarity-based classifiers. This is because ProtoPNet and its variants adopt a training process specific to linear classifiers, which allows the prototypes to represent image features useful for class recognition. Due to this difficulty, the effectiveness of similarity-based classifiers (e.g., k-nearest neighbor (KNN)) in the 'this looks like that' framework has not been sufficiently examined. To alleviate this problem, we propose ProtoKNN, an extension of ProtoPNet that adopts KNN classifiers. Extensive experiments on multiple open datasets demonstrate that the proposed method achieves results competitive with a state-of-the-art method.

1. INTRODUCTION

Deep learning has achieved very high accuracy in a variety of computer vision tasks. However, since the reasoning process of deep learning models is a black box that cannot be interpreted by human operators, it is very difficult to validate their inferences, which impedes their use in high-risk domains. To alleviate this problem, several methods for constructing inherently interpretable models have been proposed. However, inherently interpretable models generally suffer from degraded accuracy compared to black-box models. 'Gray-box' models have thus been proposed (Alvarez-Melis, 2018; Chen, 2019; Koh, 2020) to take advantage of the power of deep learning models while keeping the reasoning process interpretable. Among the gray-box approaches, the 'this looks like that' framework with ProtoPNet (Chen, 2019) has attracted significant attention because it guarantees a transparent reasoning process without any additional supervision. ProtoPNet first calculates the similarity of the input sample to prototypes, each corresponding to an image patch in the training set, and then classifies the sample with an inherently interpretable model on the basis of these similarities. This enables ProtoPNet to explain its reasoning by presenting the training-set patches that the model considers similar to the input sample; interpretability through case-based reasoning is thus achieved. Thanks to this advantage in transparency, many methods based on ProtoPNet have been proposed (Wang, 2021; Nauta, 2021; Rymarczyk, 2021; Donnelly, 2022; Keswani, 2022; Rymarczyk, 2022). When training ProtoPNet, the weights of the linear classifier connecting the prototypes to the class logits are fixed, and the feature vectors corresponding to image patches are linked to a prototype if that prototype makes a positive contribution to the logit of the class to which the image belongs. This enables the prototypes to represent the image patches most useful for class recognition.
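The reasoning chain just described, from feature map to per-prototype max similarity to a fixed linear layer over class logits, can be sketched as follows. All dimensions, the dot-product similarity, and the random prototype-class assignment are illustrative placeholders, not the actual settings of ProtoPNet:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy dimensions (hypothetical): 3 classes, 6 prototypes, a 4x4 feature map of 8-dim patch features.
num_classes, num_protos, d = 3, 6, 8
feature_map = rng.normal(size=(4, 4, d))        # Z = F(x), output of the backbone
prototypes = rng.normal(size=(num_protos, d))   # learned prototype vectors

# Similarity of the image to each prototype: the max over all image-patch features.
patches = feature_map.reshape(-1, d)            # (16, d)
sims = patches @ prototypes.T                   # similarity of every patch to every prototype
profile = sims.max(axis=0)                      # one score per prototype

# Fixed linear layer: each prototype is assigned one class and contributes
# positively only to that class's logit, which is what pushes prototypes to
# specialize to class-discriminative patches during training.
assignments = rng.integers(0, num_classes, num_protos)   # class of each prototype
w = (assignments[None, :] == np.arange(num_classes)[:, None]).astype(float)
logits = w @ profile
pred = int(np.argmax(logits))
```

The max-pooling over patches is what lets the explanation point back to a single most-similar training patch per prototype.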
However, due to this special training process, it is difficult for ProtoPNet to utilize any classifier other than the linear classifier. As an alternative, Nauta (2021) proposed ProtoTree, which uses a decision tree as the final classifier. However, this approach is limited to decision trees, which makes it difficult to use similarity-based classifiers within the 'this looks like that' framework. Similarity-based classifiers perform inference on the basis of similarities (or distances) between samples. As we will demonstrate in the experimental section and in the Appendix (Sec. D.2), interpreting these distances enables us to obtain more fine-grained explanations in a counterfactual manner, which is important for understanding and interpreting the reasoning process of the model. Therefore, in this work, we extend ProtoPNet so that similarity-based classifiers (specifically, k-nearest neighbor (KNN) classifiers) can be utilized in the 'this looks like that' framework. When extending ProtoPNet to a similarity-based classifier, it is no longer possible to pre-define the relationship between the prototypes and the class labels, which is necessary for calculating the cluster loss. As we will discuss in Sec. 2.2.2, it is also difficult to estimate this relationship from only one sample. Therefore, the main difficulty is how to estimate the relationship between the prototypes and the class labels. Our idea for estimating this relationship is to compare the samples in a minibatch with each other and sum up the most distinctive prototypes. This enables us to estimate which prototypes are relatively more related to which samples, and thus to which classes. Our novel cluster loss can then be defined on the basis of this estimation. In summary, our contributions are three-fold:

・ We propose ProtoKNN, an extension of ProtoPNet that can utilize KNN classifiers. This is the first work to examine the effectiveness of similarity-based classifiers in the 'this looks like that' framework.
・ We developed a novel loss function for ProtoKNN that replaces the cluster loss in ProtoPNet. This enables us to train our model without predefining the relationship between classes and prototypes.
・ The proposed method achieved results competitive with a state-of-the-art ProtoPNet variant on multiple open datasets.
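The minibatch-based estimation idea above can be illustrated with a small sketch. The reading of 'most distinctive' used here (similarity above the batch mean) and the per-class summation are hypothetical simplifications for illustration only; they are not the exact definition of the proposed cluster loss, which is given later in Sec. 2.2:

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical minibatch: |B| = 8 samples, |P| = 6 prototypes, 3 classes.
batch_size, num_protos, num_classes = 8, 6, 3
profiles = rng.random(size=(batch_size, num_protos))     # prototype profile s_a of each sample
labels = rng.integers(0, num_classes, size=batch_size)   # class label y_a of each sample

# One reading of "most distinctive": a prototype stands out for a sample when its
# similarity is high relative to the rest of the minibatch (above the batch mean).
distinct = profiles - profiles.mean(axis=0, keepdims=True)   # (|B|, |P|)

# Summing distinctiveness per class yields a soft estimate of which prototypes
# relate to which classes, without any predefined prototype-class assignment.
affinity = np.zeros((num_classes, num_protos))
for c in range(num_classes):
    affinity[c] = distinct[labels == c].sum(axis=0)
```

Because the distinctiveness scores are centered over the batch, a prototype can only gain affinity for one class at the expense of the others, which is the comparative effect the minibatch-level estimation relies on.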

2. METHOD

In the following, we first describe the notation¹ used in this paper. Then, we present the training strategy of the proposed method and explain how the samples are classified. Finally, we demonstrate how to interpret the reasoning process of our method. For context, we also briefly revisit the origins of ProtoPNet and elaborate on why it cannot directly utilize similarity-based classifiers in the Appendix (Sec. A).

2.1. PRELIMINARY

The input image and its class label are denoted as x and y, respectively. Unless otherwise specified, we use subscripts a, b, ... to denote data indices; thus, an input image and its class label are denoted as x_a and y_a, respectively. The index set of the images in a minibatch and its cardinality are denoted as B and |B|, respectively. We use F to denote the feature extractor and Z to denote the feature map output by F, i.e., Z_a = F(x_a). We use z to denote a feature vector contained in one pixel of the feature map Z; we also call these feature vectors 'image patch features' in this paper. After feature extraction, the similarities between the prototypes {p_i}_{i=0,1,...} and the image patch features contained in the feature map Z_a are calculated. The maximum similarity value is defined as the similarity of the input image x_a to the prototype p_i, i.e., s_{a,i} = max_{z∈Z_a} Sim(z, p_i). Here, we denote the similarity vector of the input image x_a as s_a and its component corresponding to p_i as s_{a,i}. Sim is the function that calculates the similarity (cosine similarity in this paper) between a feature vector and a prototype. In the following, we refer to s_a as the 'prototype profile' of the input image x_a, and the index set of the prototypes is denoted as P. The indicator function is denoted as 1(condition), which returns 1 if the condition is true and 0 otherwise.
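The prototype profile defined above can be computed as in the following sketch, where the feature-map shape and the number of prototypes are arbitrary toy values:

```python
import numpy as np

def prototype_profile(Z, prototypes):
    """Return s_a: for each prototype p_i, the maximum cosine similarity
    Sim(z, p_i) over all image-patch features z in the feature map Z (H x W x D)."""
    patches = Z.reshape(-1, Z.shape[-1])
    patches = patches / np.linalg.norm(patches, axis=1, keepdims=True)
    protos = prototypes / np.linalg.norm(prototypes, axis=1, keepdims=True)
    sims = patches @ protos.T        # cosine similarity of every patch to every prototype
    return sims.max(axis=0)          # s_{a,i} = max_{z in Z_a} Sim(z, p_i)

rng = np.random.default_rng(2)
Z_a = rng.normal(size=(7, 7, 16))    # Z_a = F(x_a): a toy 7x7 feature map of 16-dim vectors
P = rng.normal(size=(10, 16))        # ten toy prototypes {p_i}
s_a = prototype_profile(Z_a, P)      # prototype profile of x_a
```

Since Sim is cosine similarity, every component of the prototype profile lies in [-1, 1], which is convenient when the profiles are later compared across samples.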

2.2. TRAINING PROCESS

Originally, ProtoPNet utilized three loss functions: classification loss, cluster loss, and separation loss. In the proposed method, we do not use the separation loss because we expect the prototypes to be shared among samples with different class labels. Instead, we use an auxiliary loss function proposed in the context of deep metric learning to help the backbone model acquire better feature extraction ability. In summary, we train our models with three loss functions: the classification loss L_task, the novel cluster loss L_clst, and the auxiliary loss L_aux. Figure 1 shows the loss scheme of the proposed method. The details of each loss are described in the following.
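As a minimal sketch, the overall objective can be written as a weighted sum of the three terms. The weighting coefficients lambda_clst and lambda_aux below are hypothetical hyperparameters, not values prescribed in this paper, and the three inputs stand in for L_task, L_clst, and L_aux computed elsewhere in the training step:

```python
# A minimal sketch of the overall training objective; the weights are
# illustrative hyperparameters only.
def total_loss(l_task, l_clst, l_aux, lambda_clst=0.8, lambda_aux=0.1):
    return l_task + lambda_clst * l_clst + lambda_aux * l_aux
```

In practice such weights would be tuned per dataset, with the classification loss typically kept as the dominant term.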



¹ We basically follow https://github.com/goodfeli/dlbook_notation/.

