LEVERAGING CLASS HIERARCHIES WITH METRIC-GUIDED PROTOTYPE LEARNING

Abstract

In many classification tasks, the set of classes can be organized according to a meaningful hierarchy. This structure can be used to assess the severity of confusing each pair of classes, and summarized in the form of a cost matrix, which also defines a finite metric. We propose to integrate this metric into the supervision of a prototypical network in order to model the hierarchical class structure. Our method relies on jointly learning a feature-extracting network and a set of class representations, or prototypes, which incorporate the error metric into their relative arrangement in the embedding space. We show that this simultaneous training consistently reduces the severity of the network's errors with regard to the class hierarchy when compared to traditional methods and other prototype-based strategies. Furthermore, when the induced metric contains insight on the data structure, our approach improves the overall precision as well. Experiments on four different public datasets, ranging from agricultural time series classification to depth image semantic segmentation, validate our approach.

1. INTRODUCTION

Most classification models focus on maximizing prediction accuracy, regardless of the semantic nature of errors. This can lead to high-performing models that nonetheless make puzzling errors, such as confusing a tiger and a sofa, casting doubt on what a model actually understands about the required task and the data distribution. Neural networks in particular have been criticized for their tendency to produce improbable yet confident errors, notably when attacked (Akhtar & Mian, 2018).

The classes of most classification problems can be organized according to a hierarchical structure. Such a tree-shaped taxonomy of concepts can be produced by domain experts, or generated automatically from class names with the WordNet graph (Miller et al., 1990) or word embeddings (Mikolov et al., 2013). A step towards more reliable and interpretable algorithms would be to explicitly model the difference in gravity between errors, as defined by a hierarchical nomenclature. For a classification task over a set $\mathcal{K}$ of $K$ classes, the hierarchy of errors can be encapsulated by a cost matrix $D \in \mathbb{R}_+^{K \times K}$, defined such that the cost of predicting class $k$ when the true class is $l$ is $D[k, l] \geq 0$, with $D[k, k] = 0$ for all $k = 1, \cdots, K$. Among many other options (Kosmopoulos et al., 2015), one can define $D[k, l]$ as the length of the shortest path between the nodes corresponding to classes $k$ and $l$.

As pointed out by Bertinetto et al. (2020), the first step towards algorithms aware of hierarchical structures would be to generalize the use of cost-based metrics. For example, early iterations of the ImageNet challenge (Russakovsky et al., 2015; Deng et al., 2010) proposed to weight errors according to hierarchy-based costs. For a dataset indexed by $N$, the Average Hierarchical Cost (AHC) between class predictions $y \in \mathcal{K}^N$ and the true labels $z \in \mathcal{K}^N$ is defined as:
$$\mathrm{AHC}(y, z) = \frac{1}{|N|} \sum_{n \in N} D[y_n, z_n]~.$$
Along with evaluation metrics, loss functions should also take the cost matrix into account.
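As a concrete illustration, the shortest-path cost matrix and the AHC can be computed as follows. The class names and tree below are hypothetical toy examples (not from any of the paper's datasets); only the definitions of $D$ and AHC follow the text above.

```python
from collections import deque

# Hypothetical toy taxonomy: each node maps to its neighbours in an
# undirected tree whose leaves are the classes.
tree = {
    "root": ["animal", "furniture"],
    "animal": ["root", "tiger", "cat"],
    "furniture": ["root", "sofa", "chair"],
    "tiger": ["animal"], "cat": ["animal"],
    "sofa": ["furniture"], "chair": ["furniture"],
}
classes = ["tiger", "cat", "sofa", "chair"]

def shortest_path_length(tree, src, dst):
    """Breadth-first search: edge count of the shortest path in the tree."""
    dist, frontier = {src: 0}, deque([src])
    while frontier:
        node = frontier.popleft()
        if node == dst:
            return dist[node]
        for nb in tree[node]:
            if nb not in dist:
                dist[nb] = dist[node] + 1
                frontier.append(nb)
    raise ValueError("nodes are not connected")

# Cost matrix: D[k][l] is the tree distance between classes k and l,
# so D[k][k] = 0 and sibling classes incur a lower cost than distant ones.
D = [[shortest_path_length(tree, k, l) for l in classes] for k in classes]

def ahc(pred, true):
    """Average Hierarchical Cost of predictions vs. ground-truth labels."""
    idx = {c: i for i, c in enumerate(classes)}
    return sum(D[idx[p]][idx[t]] for p, t in zip(pred, true)) / len(true)
```

With this toy tree, confusing a cat with a tiger (cost 2) is penalized less than confusing a cat with a sofa (cost 4), whereas a 0/1 error rate would treat both mistakes identically.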
While it is common to focus on retrieving certain classes through weighting schemes, preventing specific class confusions is less straightforward. The cross-entropy with one-hot target vectors, for example, singles out the prediction with respect to the correct class, but treats all other classes equally. Beyond reducing the AHC, another advantage of incorporating the class hierarchy into the learning phase is that $D$ may contain information about the structure of the data as well. Though this is not always the case, co-hyponyms (i.e. siblings) in a class hierarchy tend to share some structural properties. Encouraging such classes to have similar representations could lead to more efficient learning, e.g. by pooling common feature detectors. Such priors on the class structure may be especially crucial when dealing with a large taxonomy, as noted by Deng et al. (2010).

In this paper, we introduce a method to integrate the class hierarchy into a classification algorithm. We propose a new scale-free, distortion-based regularizer for prototypical networks (Yang et al., 2018; Chen et al., 2019). This penalty drives the network to learn prototypes organized such that their relative distances in the embedding space reflect their distances in the class hierarchy. The contributions of this paper are as follows:

• We introduce a scale-independent formulation of the distortion between two metric spaces, and an associated smooth regularizer.

• This formulation allows us to incorporate knowledge of the class hierarchy into a neural network at no extra cost in trainable parameters and computation.

• We show on four public datasets (CIFAR100, NYUDv2, S2-Agri, and iNaturalist-19) that our approach decreases the average cost of the predictions of standard backbones.

• As illustrated in Figure 1, we show that our approach can also lead to a better (unweighted) precision, which we attribute to useful priors contained in the taxonomy of classes.
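The exact scale-free distortion is defined in Section 3.2; as a rough sketch of the general idea under our own assumptions, one scale-invariant mismatch between the pairwise prototype distances and the hierarchy costs can be obtained by fitting the best global scale in closed form before comparing the two metrics. This is an illustrative formulation, not the paper's formula; the cost matrix below is a made-up symmetric example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical setup: K class prototypes in a d-dimensional embedding space.
K, d = 4, 8
prototypes = rng.normal(size=(K, d))

# Hierarchy-derived cost matrix (symmetric, zero diagonal), e.g. tree distances.
D = np.array([[0, 2, 4, 4],
              [2, 0, 4, 4],
              [4, 4, 0, 2],
              [4, 4, 2, 0]], dtype=float)

def scale_free_distortion(protos, D):
    """Mismatch between prototype distances and hierarchy costs, made
    invariant to a global rescaling of the embedding space.
    (Illustrative sketch only; see Section 3.2 for the actual formulation.)"""
    diff = protos[:, None, :] - protos[None, :, :]
    dist = np.sqrt((diff ** 2).sum(-1))      # pairwise Euclidean distances
    iu = np.triu_indices(len(protos), k=1)   # each unordered pair once
    dvec, Dvec = dist[iu], D[iu]
    s = (dvec @ Dvec) / (dvec @ dvec)        # least-squares optimal scale
    return np.mean((s * dvec - Dvec) ** 2)
```

Because the scale $s$ is re-fitted for every prototype configuration, multiplying all prototypes by a constant leaves the penalty unchanged, which is the sense in which such a regularizer is "scale-free": it constrains only the relative arrangement of the prototypes.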

2. RELATED WORK

Prototypical Networks: Our approach builds on the growing corpus of work on prototypical networks. These models are deep learning analogues of nearest-centroid classifiers (Tibshirani et al., 2002) and Learning Vector Quantization networks (Sato & Yamada, 1995; Kohonen, 1995), which associate to each class a representation, or prototype, and classify observations according to the nearest prototype. These networks have been successfully used for few-shot learning (Snell et al., 2017; Dong & Xing, 2018), zero-shot learning (Jetley et al., 2015), and supervised classification (Guerriero et al., 2018; Yang et al., 2018; Mettes et al., 2019; Chen et al., 2019). In most approaches, the prototypes are directly defined as the centroids of the learnt representations of the samples of their classes, and updated at each episode (Snell et al., 2017) or iteration (Guerriero et al., 2018).
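The centroid-based scheme described above can be sketched in a few lines; this is a generic nearest-centroid classifier over precomputed embeddings, not the training loop of any cited method, and the function names are our own.

```python
import numpy as np

def fit_prototypes(embeddings, labels, num_classes):
    """Define each class prototype as the centroid (mean) of the
    embeddings of that class's samples."""
    return np.stack([embeddings[labels == k].mean(axis=0)
                     for k in range(num_classes)])

def predict(embeddings, prototypes):
    """Assign each sample to the class of its nearest prototype."""
    d2 = ((embeddings[:, None, :] - prototypes[None, :, :]) ** 2).sum(-1)
    return d2.argmin(axis=1)
```

In an episodic or iterative setting, `fit_prototypes` would be re-run on the current network embeddings at each update, which is what makes the prototypes track the feature extractor as it trains.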
Figure 1: Mean class representations, prototypes, and 2-dimensional embeddings learnt on perturbed MNIST by a 3-layer convolutional net with three different classification modules: (a) cross-entropy, (b) learnt prototypes, and (c) learnt prototypes guided by a tree-shaped taxonomy (constructed according to the authors' perceived visual similarity between digits). The guided prototypes (c) embed the class hierarchy more faithfully: classes with low error cost are closer. This is associated with a decrease in the Average Hierarchical Cost (AHC), as well as the Error Rate (ER), indicating that our taxonomy may contain useful information for learning better visual features. For a formal definition of our scale-free distortion, see Section 3.2; for the cross-entropy module, the class representations are computed as the means of the class embeddings.

