LEVERAGING CLASS HIERARCHIES WITH METRIC-GUIDED PROTOTYPE LEARNING

Abstract

In many classification tasks, the set of classes can be organized according to a meaningful hierarchy. This structure can be used to assess the severity of confusing each pair of classes, and can be summarized in the form of a cost matrix, which also defines a finite metric. We propose to integrate this metric into the supervision of a prototypical network in order to model the hierarchical class structure. Our method relies on jointly learning a feature-extracting network and a set of class representations, or prototypes, which incorporate the error metric into their relative arrangement in the embedding space. We show that this simultaneous training consistently reduces the severity of the network's errors with respect to the class hierarchy compared to traditional methods and other prototype-based strategies. Furthermore, when the induced metric contains insight into the data structure, our approach improves overall precision as well. Experiments on four public datasets, from agricultural time series classification to depth image semantic segmentation, validate our approach.

1. INTRODUCTION

Most classification models focus on maximizing prediction accuracy, regardless of the semantic nature of errors. This can lead to high-performing models that nonetheless make puzzling mistakes, such as confusing a tiger and a sofa. Such errors cast doubt on what a model actually understands about the required task and data distribution. Neural networks in particular have been criticized for their tendency to produce improbable yet confident errors, notably when attacked (Akhtar & Mian, 2018). The classes of most classification problems can be organized according to a hierarchical structure. Such a tree-shaped taxonomy of concepts can be produced by domain experts, or generated automatically from class names with the WordNet graph (Miller et al., 1990) or word embeddings (Mikolov et al., 2013). A step towards more reliable and interpretable algorithms would be to explicitly model the difference in gravity between errors, as defined by a hierarchical nomenclature.

For a classification task over a set K of K classes, the hierarchy of errors can be encapsulated by a cost matrix D ∈ R_+^{K×K}, defined such that the cost of predicting class k when the true class is l is D[k, l] ≥ 0, with D[k, k] = 0 for all k = 1, …, K. Among many other options (Kosmopoulos et al., 2015), one can define D[k, l] as the length of the shortest path between the nodes corresponding to classes k and l in the hierarchy tree.

As pointed out by Bertinetto et al. (2020), the first step towards algorithms aware of hierarchical structures is to generalize the use of cost-based metrics. For example, early iterations of the ImageNet challenge (Russakovsky et al., 2015; Deng et al., 2010) proposed to weight errors according to hierarchy-based costs. For a dataset indexed by N, the Average Hierarchical Cost (AHC) between class predictions y ∈ K^N and the true labels z ∈ K^N is defined as:

AHC(y, z) = (1 / |N|) ∑_{n ∈ N} D[y_n, z_n] .

Along with evaluation metrics, loss functions should also take the cost matrix into account. While it is common to focus on retrieving certain classes through weighting schemes, preventing specific class confusions is less straightforward. The cross-entropy loss with one-hot target vectors, for example, singles out the predicted probability of the correct class but treats all other classes equally. Beyond reducing the AHC, another advantage of incorporating the class hierarchy into the learning phase
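The shortest-path cost matrix and the AHC can be made concrete with a small sketch. The toy three-class taxonomy below (tiger/cat under "animal", sofa under "furniture") and all names in it are illustrative assumptions, not the paper's datasets; the distance computation follows the shortest-path definition above.

```python
# Sketch: build a hierarchy-based cost matrix D from a toy class tree,
# then compute the Average Hierarchical Cost (AHC) of some predictions.
import numpy as np
from itertools import combinations

# Toy taxonomy as parent pointers (illustrative assumption):
# root -> {animal, furniture}, animal -> {tiger, cat}, furniture -> {sofa}.
parent = {
    "animal": "root", "furniture": "root",
    "tiger": "animal", "cat": "animal", "sofa": "furniture",
}

def path_to_root(node):
    """List of nodes from `node` up to the root, inclusive."""
    path = [node]
    while node in parent:
        node = parent[node]
        path.append(node)
    return path

def tree_distance(a, b):
    """Length of the shortest path between nodes a and b in the tree."""
    up_a = {n: i for i, n in enumerate(path_to_root(a))}
    for j, n in enumerate(path_to_root(b)):
        if n in up_a:               # first shared node = lowest common ancestor
            return up_a[n] + j
    raise ValueError("nodes are not in the same tree")

classes = ["tiger", "cat", "sofa"]
K = len(classes)
D = np.zeros((K, K))                # symmetric, zero diagonal by construction
for k, l in combinations(range(K), 2):
    D[k, l] = D[l, k] = tree_distance(classes[k], classes[l])

def ahc(y_pred, y_true):
    """Average Hierarchical Cost: mean tree distance of each prediction."""
    return D[y_pred, y_true].mean()

y_true = np.array([0, 1, 2])        # tiger, cat, sofa
y_pred = np.array([1, 1, 0])        # tiger->cat (cost 2), sofa->tiger (cost 4)
print(ahc(y_pred, y_true))          # -> 2.0
```

Confusing a tiger with a cat (distance 2 through "animal") contributes far less to the AHC than confusing a sofa with a tiger (distance 4 through the root), which is exactly the asymmetry a plain error rate ignores.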
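The limitation of one-hot cross-entropy mentioned above can also be checked numerically: the loss depends only on the probability assigned to the true class, so it cannot distinguish which wrong class absorbs the remaining mass. The class names and probability values below are illustrative assumptions.

```python
# Sketch: one-hot cross-entropy is blind to WHICH wrong class is favored.
import numpy as np

def cross_entropy(probs, true_idx):
    """Cross-entropy with a one-hot target: -log p(true class)."""
    return -np.log(probs[true_idx])

# True class: tiger (index 0). Two distributions with the same mass on the
# true class, but favoring different wrong classes.
p_confuse_cat  = np.array([0.2, 0.7, 0.1])  # mass on cat (near in hierarchy)
p_confuse_sofa = np.array([0.2, 0.1, 0.7])  # mass on sofa (far in hierarchy)

print(cross_entropy(p_confuse_cat, 0) == cross_entropy(p_confuse_sofa, 0))  # True
```

Both losses equal -log(0.2): the hierarchically mild and the hierarchically severe confusion are penalized identically, which motivates bringing the cost matrix D into the supervision itself.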

