SCHEMA INFERENCE FOR INTERPRETABLE IMAGE CLASSIFICATION

Abstract

In this paper, we study a novel inference paradigm, termed schema inference, that learns to deductively infer explainable predictions by rebuilding the forwarding scheme of prior deep neural networks (DNNs), guided by the prevalent philosophical cognitive concept of the schema. We strive to reformulate the conventional model inference pipeline into a graph matching policy that associates the extracted visual concepts of an image with pre-computed scene impressions, by analogy with the human reasoning mechanism of impression matching. To this end, we devise an elaborate architecture, termed SchemaNet, as a dedicated instantiation of the proposed schema inference concept, which models both the visual semantics of input instances and the learned abstract imaginations of target categories as topological relational graphs. Meanwhile, to capture and leverage the compositional contributions of visual semantics in a global view, we also introduce a universal Feat2Graph scheme in SchemaNet to establish relational graphs that contain abundant interaction information. Both the theoretical analysis and the experimental results on several benchmarks demonstrate that the proposed schema inference achieves encouraging performance and meanwhile yields a clear picture of the deductive process leading to the predictions.

1. INTRODUCTION

"Now this representation of a general procedure of the imagination for providing a concept with its image is what I call the schema for this concept." - Immanuel Kant

Deep neural networks (DNNs) have demonstrated increasingly prevailing capabilities in visual representation as compared to conventional hand-crafted features. Take the visual recognition task as an example. The canonical deep learning (DL) scheme for image recognition is to yield an effective visual representation from a stack of non-linear layers along with a fully-connected (FC) classifier at the end (He et al., 2016; Dosovitskiy et al., 2021; Tolstikhin et al., 2021; Yang et al., 2022a), where the inner-product similarities with each category embedding are computed as the prediction. Despite the great success of DL, existing deep networks are typically required to simultaneously perceive low-level patterns as well as high-level semantics to make predictions (Zeiler & Fergus, 2014; Krizhevsky et al., 2017). As such, both the procedure of computing visual representations and the learned category-specific embeddings are opaque to humans, leading to challenges in safety-critical scenarios such as autonomous driving and healthcare applications. Unlike prior works that merely obtain the targets in a black-box manner, here we strive to devise an innovative and generalized DNN inference paradigm by reformulating the traditional one-shot forwarding scheme into an interpretable DNN reasoning framework, resembling what occurs in deductive human reasoning.
Towards this end, inspired by the schema in Kant's philosophy that describes human cognition as the procedure of associating an image of abstract concepts with a specific sense impression, we propose to formulate DNN inference as an interactive matching procedure between the local visual semantics of an input instance and the abstract category imagination, termed schema inference in this paper, leading to interpretable deductive inference based on visual semantic interactions at the micro-level. To elaborate on the proposed schema inference concept, we take image classification, the most basic task in computer vision, as an example to explain our technical details. At a high level, the devised schema inference scheme leverages a pre-trained DNN to extract feature ingredients, which are, in fact, the semantics represented by a cluster of deep feature vectors from a specific local region in the image domain. Furthermore, the obtained feature ingredients are organized into an ingredient relation graph (IR-Graph) for the sake of modeling their interactions, which are characterized by similarity at the semantic level as well as adjacency at the spatial level. We then implement the category-specific imagination as an ingredient relation atlas (IR-Atlas) for all target categories, induced from observed data samples. As a final step, the graph similarity between an instance-level IR-Graph and the category-level IR-Atlas is computed as the measurement for yielding the target predictions. As such, instead of relying on deep features, the desired outputs of schema inference derive solely from the relationships of visual words, as shown in Figure 1. More specifically, our dedicated schema-based architecture, termed SchemaNet, is based on vision Transformers (ViTs) (Dosovitskiy et al., 2021; Touvron et al., 2021), which are nowadays the most prevalent vision backbones.
To effectively obtain the feature ingredients, we collect the intermediate features of the backbone from probe data samples and cluster them with the k-means algorithm. IR-Graphs are established through a customized Feat2Graph module that maps the discretized ingredient array to graph vertices and meanwhile builds the connections that indicate ingredient interactions, relying on the self-attention mechanism (Vaswani et al., 2017) and spatial adjacency. Eventually, graph similarities are evaluated via a shallow graph convolutional network (GCN). Our work relates to several existing methods that mine semantic-rich visual words from DNN backbones for self-explanation (Brendel & Bethge, 2019; Chen et al., 2019; Nauta et al., 2021; Xue et al., 2022b; Yang et al., 2022b). Particularly, BagNet (Brendel & Bethge, 2019) acts as a visual word extractor that is analogous to the traditional bag-of-visual-words (BoVW) representation (Yang et al., 2007) with SIFT features (Lowe, 2004). Moreover, Chen et al. (2019) develop the ProtoPNet framework, which evaluates the similarity scores between image parts and learned class-specific prototypes for inference. Despite their encouraging performance, all these existing methods lack consideration of the compositional contributions of visual semantics, which has, however, already been proven critical for both human and DNN inference (Deng et al., 2022), and they are consequently not applicable to schema inference. We discuss more related literature in Appendix A. In sum, our contribution is an innovative schema inference paradigm that reformulates the existing black-box network forwarding into a deductive DNN reasoning procedure, allowing for a clear insight into how the class evidences are gathered and deductively lead to the target predictions. This is achieved by building the dedicated IR-Graph and IR-Atlas with the extracted feature ingredients and then performing graph matching between IR-Graph and IR-Atlas to derive the desired predictions.
We also provide a theoretical analysis to explain the interpretability of our schema inference results. Experimental results on CIFAR-10/100, Caltech-101, and ImageNet demonstrate that the proposed schema inference yields results superior to the state-of-the-art interpretable approaches. Further, by transferring to unseen tasks without fine-tuning the matcher, we demonstrate that the learned knowledge of each class is, in fact, stored in the IR-Atlas rather than in the matcher, exhibiting the high interpretability of the proposed method.

2. SCHEMANET

In this section, we introduce our proposed SchemaNet in detail. The overall procedure is illustrated in Figure 2 , including a Feat2Graph module that converts deep features to instance IR-Graphs, a learnable IR-Atlas, and a graph matcher for making predictions. The main idea is to model an input image as an instance-level graph in which nodes are local semantics captured by the DNN backbone and edges represent their interactions. Meanwhile, a category-level graph is maintained as the imagination for each class driven by training samples. Finally, by measuring the similarity between the instance and category graphs, we are able to interpret how predictions are made.

2.1. PRELIMINARY

We first give a brief review of the ViT architecture. The simplest ViT implementation for image classification is proposed by Dosovitskiy et al. (2021), which treats an image as a sequence of independent 16 × 16 patches. The embeddings of all patches are then directly fed to a Transformer-based network (Vaswani et al., 2017), i.e., a stack of multiple encoder layers with the same structure: a multi-head self-attention (MHSA) module followed by a multilayer perceptron (MLP), with residual connections. Additionally, they append a class token (CLS) to the input sequence and classify from it alone rather than from the average-pooled feature. Following this work, Touvron et al. (2021) propose DeiT, which appends another distillation token (DIST) to learn soft decision targets from a teacher model. This paper mainly focuses on DeiT backbones due to their relatively simple yet efficient network architecture.

Let X ∈ R^{(n+ζ)×d} denote the input feature to an MHSA module (a sequence of n visual tokens and ζ auxiliary tokens, e.g., CLS and DIST, each a d-dimensional embedding vector). The output can be simplified as MHSA(X) = Σ_{h=1}^{H} Ψ_h X W_h, where Ψ_h ∈ R^{(n+ζ)×(n+ζ)} denotes the self-attention matrix of head h ∈ {1, ..., H}, normalized by row-wise softmax, and W_h ∈ R^{d×d} is a learnable projection matrix. As the MLP module processes each token individually, the visual tokens interact only in the MHSA modules, facilitating the extraction of their relationships from the perspective of ViTs.

For convenience, we define some notation related to the self-attention matrix. Let Ψ̄ be the average attention matrix over all heads and Ψ be the symmetric attention matrix Ψ = (Ψ̄ + Ψ̄^⊤)/2. We partition Ψ into four submatrices:

Ψ = [ Ψ^*       Ψ^A ]
    [ (Ψ^A)^⊤   Ψ^V ],   (1)

where Ψ^A ∈ R^{ζ×n} is the attention to the auxiliary tokens, Ψ^V ∈ R^{n×n} represents the relationships of visual tokens, and Ψ^* ∈ R^{ζ×ζ}.
Finally, let ψ^CLS ∈ R^n represent the attention values to the CLS token, extracted from Ψ^A.
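As a concrete illustration of the notation above, the following NumPy sketch averages the per-head attention maps, symmetrizes the result, and slices out the Ψ^*, Ψ^A, Ψ^V blocks and ψ^CLS. This is our own minimal sketch, not the paper's code; it assumes the CLS token is the first auxiliary token.

```python
import numpy as np

def partition_attention(attn_heads, n_aux=2):
    """Average per-head attention, symmetrize, and split into the blocks
    used by SchemaNet (illustrative sketch; names are ours).

    attn_heads: array of shape (H, n_aux + n, n_aux + n), row-softmaxed
                self-attention from one MHSA layer.
    Returns (psi_star, psi_A, psi_V, psi_cls).
    """
    psi_bar = attn_heads.mean(axis=0)          # average over heads
    psi = (psi_bar + psi_bar.T) / 2            # symmetric attention matrix
    psi_star = psi[:n_aux, :n_aux]             # aux-to-aux block (zeta x zeta)
    psi_A = psi[:n_aux, n_aux:]                # attention to aux tokens (zeta x n)
    psi_V = psi[n_aux:, n_aux:]                # visual-token interactions (n x n)
    psi_cls = psi_A[0]                         # assumes CLS is the first aux token
    return psi_star, psi_A, psi_V, psi_cls
```

Note that the symmetrization makes Ψ^V symmetric by construction, which is what later allows the IR-Graph to be undirected.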

2.2. FEAT2GRAPH

As shown in Figure 2, given a pre-trained backbone and an input image, the intermediate feature X is converted to an IR-Graph G for inference by the Feat2Graph module, which includes three steps: (1) discretizing X into a sequence of feature ingredients with specific semantics; (2) mapping the ingredients to weighted vertices of the IR-Graph; (3) assigning a weighted edge to each vertex pair indicating their interaction. We start by defining IR-Graph and IR-Atlas mathematically.

Definition 1 (IR-Graph). An IR-Graph is an undirected fully-connected graph G = (E, V), in which the vertex set V is the set of feature ingredients for an instance or a specific category, and the edge set E encodes the interactions of vertex pairs. In particular, to indicate vertex importance (for instance, "bird head" should contribute more to the prediction of a bird than "sky" does in human cognition), a non-negative weight λ ∈ R_+ is assigned to each vertex. Let Λ = {λ_i}_{i=1}^{|V|} ∈ R_+^{|V|} be the collection of all vertex weights. Besides, the interaction between vertices i and j is quantified by a non-negative weight e_{i,j} ∈ R_+. With a slight abuse of notation, we define E ∈ R_+^{|V|×|V|} as the weighted adjacency matrix in the following sections.

The IR-Atlas is defined to represent the imagination of all categories:

Definition 2 (IR-Atlas). An IR-Atlas Ĝ of C classes is a set of category-level IR-Graphs, in which the element Ĝ_c = (Ê_c, V̂_c) for class c has learnable vertex weights Λ̂_c and edge weights Ê_c.

Discretization. The main purpose of feature discretization is to mine the most common patterns appearing in the dataset, each of which corresponds to a human-understandable semantic, named a visual word, such as "car wheel" or "bird head". However, local regions with the same semantic may look different because of scaling, rotation, or even distortion. In our approach, we utilize the DNN backbone to extract a relatively uniform feature instead of the traditional SIFT feature utilized in (Yang et al., 2007). To be specific, a visual vocabulary Ω = {ω_i}_{i=1}^{M} (ω_i ∈ R^d) of size M is constructed by k-means clustering running on the collection X of visual tokens extracted from the probe dataset. Further, a deep feature X (CLS and DIST are removed) is discretized to a sequence of feature ingredients by replacing each element x with the index of the closest visual word:

Ingredient(x) = argmin_{i ∈ {1,...,M}} ||x − ω_i||_2.   (2)

Strictly, the ingredient refers to the index of a visual word ω. With the entire ingredient set M = {1, ..., M}, the discretized sequence is denoted as X̂ = (x̂_1, ..., x̂_n), where x̂_i ∈ M is computed from Equation (2). The detailed settings of M are presented in Appendix E.1.

Feat2Vertex. Given the discretized feature X̂, the vertices are the unique ingredients: V = Unique(X̂). The importance of each vertex is measured by two criteria: its contribution to the DNN's final prediction and its frequency of appearance.
Formally, for vertex v ∈ M, the importance is defined as

λ_v = α_1 λ_v^{CLS} + α_2 λ_v^{bag} = α_1 Σ_{i ∈ Ξ(v|X̂)} ψ_i^{CLS} + α_2 |Ξ(v|X̂)|,   (3)

where Ξ(v|X̂) is the set of all positions at which ingredient v appears in X̂, ψ_i^{CLS} is the attention between the i-th visual token and the CLS token, and α_{1,2} ≥ 0 are learnable weights balancing the two terms. When α_1 = 0, Λ is equivalent to the BoVW representation.
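The discretization and Feat2Vertex steps can be sketched as below, assuming a visual vocabulary already produced by k-means; function names and the α values are illustrative, not taken from the paper.

```python
import numpy as np

def discretize(tokens, vocab):
    """Map each visual token to the index of its nearest visual word,
    as in Equation (2). tokens: (n, d); vocab: (M, d)."""
    d2 = ((tokens[:, None, :] - vocab[None, :, :]) ** 2).sum(-1)  # (n, M)
    return d2.argmin(axis=1)                   # ingredient index per token

def vertex_weights(ingredients, psi_cls, alpha1=0.5, alpha2=0.5):
    """Feat2Vertex sketch: for each unique ingredient v, sum the CLS
    attention over its occurrences (first term) and add its occurrence
    count (BoVW term). alpha values here are illustrative defaults."""
    weights = {}
    for v in np.unique(ingredients):
        pos = np.where(ingredients == v)[0]    # positions of v in the sequence
        weights[int(v)] = alpha1 * psi_cls[pos].sum() + alpha2 * len(pos)
    return weights
```

With alpha1 = 0 the weights reduce to plain occurrence counts, i.e., the BoVW representation mentioned above.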

Feat2Edge.

With the self-attention matrix Ψ^V extracted from Equation (1), the interactions between vertices can be defined and computed efficiently. For any two different vertices u, v ∈ V, the edge weight jointly considers the similarity in the view of the ViT and the spatial adjacency. The first term is the average attention over all repeated pairs Π[(u,v)|X̂]:

e^{attn}_{u,v} = (1 / |Π[(u,v)|X̂]|) Σ_{(i,j) ∈ Π[(u,v)|X̂]} Ψ^V_{i,j},   (4)

where Π[(u,v)|X̂] is the Cartesian product of Ξ(u|X̂) and Ξ(v|X̂), and Ψ^V_{i,j} is the attention between the visual tokens at positions i and j. Furthermore, we define the adjacency term as

e^{adj}_{u,v} = (1 / |Π[(u,v)|X̂]|) Σ_{(i,j) ∈ Π[(u,v)|X̂]} 1 / (ε + ||Pos(i) − Pos(j)||_2),   (5)

where the function Pos(·) returns the original 2D coordinates of the input visual token with regard to the patch array after the patch-embedding module of ViTs. Eventually, the interaction between vertices u and v is the weighted sum of the two components with learnable β_{1,2}:

e_{u,v} = β_1 e^{attn}_{u,v} + β_2 e^{adj}_{u,v}.

Equations (4) and (5) are both invariant when exchanging vertices u and v, so our IR-Graph is equivalent to an undirected graph. It is worth noting that only semantics and their relationships, rather than deep features, are preserved in IR-Graphs for further inference.

After converting an input image to an IR-Graph, the matcher finds the most similar category-level graph in the IR-Atlas. The overall procedure is shown in Figure 3, which is composed of a GCN module and a similarity-computing module (Sim) that generates the final prediction. Detailed settings of the matcher are in Appendix E.3.
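The Feat2Edge computation of Equations (4) and (5) and their weighted combination can be sketched as follows; the β and ε values are illustrative, and `grid_w` (the width of the patch grid) is a hypothetical parameter standing in for the Pos(·) function.

```python
import numpy as np
from itertools import product

def edge_weight(u, v, ingredients, psi_V, grid_w, beta1=0.5, beta2=0.5, eps=1.0):
    """Feat2Edge sketch: attention averaged over all occurrence pairs of
    u and v, plus an inverse-distance spatial adjacency term.
    beta1/beta2/eps are illustrative values, not the paper's."""
    pos_u = np.where(ingredients == u)[0]      # occurrences of u
    pos_v = np.where(ingredients == v)[0]      # occurrences of v
    pairs = list(product(pos_u, pos_v))        # Cartesian product of positions
    e_attn = np.mean([psi_V[i, j] for i, j in pairs])
    def coords(i):                             # token index -> 2D patch position
        return np.array([i // grid_w, i % grid_w], dtype=float)
    e_adj = np.mean([1.0 / (eps + np.linalg.norm(coords(i) - coords(j)))
                     for i, j in pairs])
    return beta1 * e_attn + beta2 * e_adj
```

Because psi_V is symmetric and the distance term depends only on the unordered pair, swapping u and v leaves the edge weight unchanged, matching the undirected-graph claim above.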

2.3. MATCHER

To feed IR-Graphs to the GCN module, we assign each ingredient m ∈ M (i.e., each graph vertex) a trainable embedding vector, initialized as a d_G-dimensional random vector drawn independently from the multivariate Gaussian distribution N(0, I_{d_G}). In the GCN module, we adopt GraphConv (Morris et al., 2019) with slight modifications for weighted edges. Let F ∈ R^{|V|×d_G} be the input features of all vertices to a GraphConv layer; the output is computed as

GraphConv(F) = Norm(σ((I + E) F W)),

where σ denotes a non-linear activation function, Norm denotes feature normalization, I is the identity matrix, and W ∈ R^{d_G×d_G} is a learnable projection matrix. After passing through L_G layers of GraphConv in total, a weighted average pooling (WAP) layer summarizes the vertex embeddings F^{(L_G)} weighted by Λ, yielding the graph representation z = Λ F^{(L_G)}. Now we have z for G and Ẑ = (ẑ_1, ..., ẑ_C) ∈ R^{C×d_G} for all graphs in Ĝ. By computing the inner-product similarities, the final prediction logits are defined as y = Ẑ z^⊤.

Prediction interpretation. Given an instance graph G and a category graph Ĝ, we now analyze how the prediction is made based on their vertices and edges. We start by analyzing the similarity score s = ẑ z^⊤, where z = Λ F^{(L_G)} and ẑ = Λ̂ F̂^{(L_G)} are the weighted sums of the vertex embeddings, respectively. For the sake of further discussion, let f_v^{(l)} denote the embedding vector of vertex v output from the l-th GraphConv layer, and λ_v^{(l)} denote its weight. The original vertex embedding of v is defined as f_v^{(0)}, which is identical for the same ingredient in any two graphs. Besides, let Φ = V̂ ∩ V be the set of shared vertices of G and Ĝ. The computation of s can be expanded as

s = Σ_{(u,v) ∈ V̂×V} λ̂_u λ_v f̂_u^{(L_G)} (f_v^{(L_G)})^⊤.   (8)

Since directly analyzing Equation (8) is hard, we give an approximation and prove it in Appendix B.

Theorem 1.
For a shallow GCN module, Equation (8) can be approximated by

s = Σ_{ϕ ∈ Φ} λ̂_ϕ λ_ϕ f̂_ϕ^{(L_G)} (f_ϕ^{(L_G)})^⊤.

In particular, if L_G = 0 and α_1 = 0, our method is equivalent to BoVW with a linear classifier.

Further, as delineated in Corollary 1, the interpretability of the graph matcher with a shallow GCN can be stated as follows: (1) the final prediction score for a category is the summation of all present class evidence (shared vertices ϕ in Φ) represented by the vertex weights; (2) the shared neighbors (local structure) connected to a shared vertex ϕ also contribute to the final prediction.
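The matcher's forward pass can be sketched as below; ReLU and row-wise L2 normalization are our illustrative choices for σ and Norm, which the paper leaves abstract.

```python
import numpy as np

def graph_conv(F, E, W):
    """One GraphConv layer with weighted edges: Norm(sigma((I + E) F W)).
    sigma = ReLU and Norm = row-wise L2 normalization are illustrative."""
    out = (np.eye(E.shape[0]) + E) @ F @ W     # self loop + weighted aggregation
    out = np.maximum(out, 0.0)                 # sigma = ReLU
    norm = np.linalg.norm(out, axis=1, keepdims=True)
    return out / np.maximum(norm, 1e-12)       # row-wise normalization

def graph_embedding(F, E, Lam, layers):
    """Stack L_G GraphConv layers, then weighted average pooling by the
    vertex weights: z = Lambda F^(L_G)."""
    for W in layers:
        F = graph_conv(F, E, W)
    return Lam @ F                             # graph representation z
```

Running `graph_embedding` on an instance graph and on each category graph, then taking inner products of the resulting vectors, yields the prediction logits y = Ẑ z^⊤.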

2.4. TRAINING SCHEMANET

Besides training SchemaNet with the cross-entropy loss L_CE between the predictions and ground truths, we further constrain the complexity of the IR-Atlas. More formally, the complexity is defined on both the edge and vertex weights of all graphs in the IR-Atlas:

L_v = (1/C) Σ_{c=1}^{C} H(Λ̂_c),   (9)

L_e = (1 / (C|V̂|)) Σ_{c=1}^{C} Σ_{u ∈ V̂} H(Ê_{c,u}),   (10)

where the function H(x) computes the entropy of the input vector x ∈ R_+^k normalized by the sum of its k components, and Ê_{c,u} denotes the weighted edges connected to vertex u in Ĝ_c of class c. The final optimization goal is

L = L_CE + γ_v L_v + γ_e L_e.   (11)

Sparsification. Each graph in the IR-Atlas is initialized as a fully-connected graph with random vertex and edge weights. However, this requires O(C|V|^2) memory for storing and training the whole set of edges. To alleviate this issue, we initialize the IR-Atlas by averaging the instance IR-Graphs of each class and remove the edges connected to vertices whose weights are below a given threshold δ_t = 0.01 from the category-specific graph. This procedure not only dramatically decreases the number of learnable parameters but also boosts the final performance, as shown in Table 1.
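The entropy-based complexity terms L_v and L_e can be sketched as follows; the input containers (a list of per-class vertex-weight vectors and a list of per-class adjacency matrices) are an illustrative interface.

```python
import numpy as np

def entropy(x, eps=1e-12):
    """H(x): entropy of a non-negative vector normalized by its sum."""
    p = x / max(x.sum(), eps)
    p = np.clip(p, eps, 1.0)
    return float(-(p * np.log(p)).sum())

def atlas_complexity(vertex_weights, edge_weights):
    """Sparsity losses L_v and L_e averaged over all C category graphs.
    vertex_weights: list of C vectors; edge_weights: list of C matrices."""
    C = len(vertex_weights)
    L_v = sum(entropy(lam) for lam in vertex_weights) / C
    L_e = sum(entropy(E_c[u]) for E_c in edge_weights
              for u in range(E_c.shape[0])) / (C * edge_weights[0].shape[0])
    return L_v, L_e
```

Minimizing entropy pushes each weight vector towards a peaked distribution, i.e., a sparse category graph with few dominant vertices and edges.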

3. EXPERIMENTS

3.1. IMPLEMENTATION

Datasets. We evaluate our method on CIFAR-10/100 (Krizhevsky et al., 2009), Caltech-101 (Li et al., 2022), and ImageNet (Deng et al., 2009). In particular, Caltech-101 has around 9k images, and we manually split it into a training set with 7.4k images and a test set with 1.3k images.

Table 1: Comparison results on CIFAR-10/100 and Caltech-101 with different backbones. We report the top-1 accuracy, number of learnable parameters, and FLOPs. The results in gray are merely for listing the accuracy of ViTs rather than for making comparisons.

Selection of hyperparameters. Several hyperparameters are involved in our method, including γ_{v,e} in Equation (11) for adjusting the sparsity of the IR-Atlas. We set γ_v = 0.5 and γ_e = 0.75 as the default values for the following evaluation, and the sensitivity analyses are presented in Appendix H. The initial values of the learnable α_{1,2} and β_{1,2} are set to 0.5, and we show their learning curves in Figure 9.

3.2. EXPERIMENTAL RESULTS

We evaluate our method and comparison baselines under the following settings:

• BoVW-SIFT: the traditional BoVW approach that utilizes SIFT features for constructing the visual vocabulary (Yang et al., 2007).
• Base: the base ViT model directly trained on the benchmark datasets with initial weights obtained from the official repository of (Touvron et al., 2021), which is our backbone.
• Backbone-FC: the frozen backbone with an FC layer. The intermediate features extracted from the frozen backbone are fed to a global average pooling layer followed by a linear classifier, similar to the standard CNN protocol.
• BoVW-Deep: the BoVW approach with our extracted visual vocabulary.
• BagNet: the implementation of BagNet (Brendel & Bethge, 2019) with a ViT backbone, constructed following our "Base" setting.
• SchemaNet: our proposed SchemaNet without initialization.
• SchemaNet-Init: our proposed SchemaNet initialized with the averaged instance IR-Graphs.

The comparison results are shown in Table 1, where we report the top-1 accuracy, the number of learnable parameters, and the FLOPs for each setting. Notably, the proposed SchemaNet consistently outperforms the baseline methods. Specifically, with the DeiT-Tiny backbone, ours achieves a significant improvement over BagNet (about 25.7% and 35.3% absolute gain on CIFAR-10 and CIFAR-100, respectively). With a larger backbone such as DeiT-Small or DeiT-Base, though the performance gap slightly shrinks, our SchemaNet with initialization still yields results superior to BagNet, by an average absolute gain of 12.7%. The reason is that, unlike BagNet, which uses summation over the similarity map to obtain the BoVW representation (without feature discretization), "BoVW-Deep" is implemented with the discretized visual words; as such, neither approach considers the visual word interactions, leading to inferior performance, especially when the size of the visual vocabulary expands on CIFAR-100 and Caltech-101. Moreover, compared with "Backbone-FC", which merely relies on intermediate deep features, the proposed SchemaNet also demonstrates encouraging results, achieving about a 1.4% absolute gain. Furthermore, with the initialization and redundant-vertex removal indicated in Section 3.1, our "SchemaNet-Init", with even fewer learnable parameters, still reaches a higher accuracy. The proposed method also delivers gratifying performance, on par with the prevalent deep ViT, especially on CIFAR-10, while giving a clear insight into the reasoning procedure.

Figure 4: Accuracy decay curves of perturbation tests. Curve names ending with "+"/"-" represent positive/negative perturbation, respectively.

We present more experimental results in Appendix F, including comparison results on ImageNet (Appendix F.1), employing other attribution methods (Appendix F.2), attacking with adversarial images (Appendix F.3), the extensibility analysis of our SchemaNet (Appendix F.4), and the inference cost of all components (Appendix F.5).

3.3. EVALUATION OF THE INTERPRETABILITY

To evaluate the interpretability of the model predictions, we adopt the positive and negative perturbation tests presented in (Chefer et al., 2021), giving both quantitative and qualitative assessments. For a fair comparison, the attention values to the CLS token ψ^CLS are extracted as the pixel relevance for BoVW-Deep, BagNet, and SchemaNet-Init with the DeiT-Tiny backbone. During the testing stage, we gradually drop pixels in descending/ascending order of relevance (for positive and negative perturbation, respectively) and measure the top-1 accuracy of the models. We plot the accuracy curves in Figure 4, which show that: (1) in the positive perturbation test, our method behaves similarly to BoVW while outperforming BagNet; (2) in the negative perturbation test, our method achieves better performance by a large margin (the area under the curve of ours is about 5.71% higher than BagNet on CIFAR-10 and 11.23% higher on CIFAR-100).
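The perturbation protocol above can be sketched as follows; `predict_fn` is an assumed model interface, and zeroing out dropped pixels is one illustrative masking choice.

```python
import numpy as np

def perturbation_curve(images, labels, relevance, predict_fn, steps=10, positive=True):
    """Positive/negative perturbation test sketch: mask pixels in descending
    (positive) or ascending (negative) relevance order and track top-1
    accuracy. predict_fn maps a batch of images to predicted labels."""
    accs = []
    for k in range(steps + 1):
        frac = k / steps                       # fraction of pixels dropped
        masked = []
        for img, rel in zip(images, relevance):
            order = np.argsort(rel.ravel())    # ascending relevance
            if positive:
                order = order[::-1]            # drop most relevant first
            drop = order[: int(frac * order.size)]
            out = img.reshape(img.shape[0], -1).copy()
            out[:, drop] = 0.0                 # zero out dropped pixels
            masked.append(out.reshape(img.shape))
        preds = predict_fn(np.stack(masked))
        accs.append(float((preds == labels).mean()))
    return accs
```

A faithful relevance map makes the positive curve drop quickly and the negative curve stay flat, which is what the area-under-the-curve comparison above measures.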

3.4. VISUALIZATION

In Figure 5, we show examples including instance IR-Graphs and the learned IR-Atlas with the DeiT-Tiny backbone trained on the Caltech-101 dataset (for visualization purposes, the number of ingredients is set to 256, all of which are shown in Figure 11). Each row shows several instances of a specific class. The first column includes the category graphs and excerpts of the appearing ingredients from Figure 11 for quicker reference. We further render the vertices, i.e., ingredients, in the category graph with different colors, and the appearing ingredients in the instance images and graphs are colored uniformly in correspondence with the category graph. We interpret the visualization from three perspectives: the interpretability of the ingredients, the interpretability of the edges, and the consistency between instance and category graphs. (1) Thanks to the powerful representation learning capability of DNNs, the extracted ingredients are able to represent explicit semantics (such as "fuselage", "bird legs", etc.) with robustness. Besides, the

A ADDITIONAL RELATED WORK

In this section, we introduce literature that is related to our method.

A.1 VISUAL CONCEPT ANALYSIS

Visual concept analysis aims to interpret the behavior of pre-trained deep models by extracting and tracking intuitive visual concepts during the DNN's inference procedure. Previous works use masks (Zhang et al., 2020b), explanatory graphs (Zhang et al., 2020a), and probabilistic models (Konforti et al., 2020) to interpret the internal layers of convolutional neural networks (CNNs) (Liu et al., 2022; Yu et al., 2023). Further, Deng et al. (2022) propose to capture and analyze the interactions of visual concepts contributing cooperatively to the prediction results of CNNs. However, they use Shapley values to compute the interactions, which is computationally expensive. Consequently, we adopt vision Transformers (ViTs) (Dosovitskiy et al., 2021; Touvron et al., 2021) as the DNN backbone of our method instead of conventional CNNs, thanks to the multi-head self-attention (MHSA) mechanism (Vaswani et al., 2017) that explicitly encodes the interactions of visual tokens at the similarity level. More recently, VRX (Ge et al., 2021) has been proposed to use a graphical model to interpret the predictions of a pre-trained DNN, in which edges are constructed to represent the spatial relations of visual semantics. They further utilize a GCN module to make predictions on the input structural concept graphs by training it to mimic the DNN's predictions. Nevertheless, their approach requires training a graph for every input sample, which is extremely expensive on large-scale datasets. Besides, VRX does not construct the class imagination that carries the category-specific knowledge, meaning that the GCN predictor must learn to memorize and distinguish all the class-specific graphs, which impairs the overall interpretability. Our proposed schema inference, however, explicitly creates an interactive matching procedure between the instances and the category imagination.
Moreover, our GNN model only captures the local structure for gathering class evidence, which is easier to accomplish.

A.2 VISUAL CONCEPT LEARNING

Visual concept learning refers to a range of methods that mine semantic-rich visual words from a DNN, which are further utilized to make interpretable predictions towards a truly self-explanatory model. Concept bottleneck models (CBMs) (Koh et al., 2020; Zarlenga et al., 2022; Deng et al., 2022; Wong & McPherson, 2021) link neurons with human-interpretable semantics explicitly, encouraging the trustworthiness of DNNs. Human intervention on the learned bottleneck layers can fix misclassified concepts to improve model performance. Another line of visual concept learning, part-prototype-based methods, makes predictions on target tasks collectively with a DNN and semantic-rich prototypes, and can be divided into two schools according to their prototypes. (1) Bag-of-Visual-Words (BoVW) approaches (Brendel & Bethge, 2019; Gidaris et al., 2020; Tripathi et al., 2022) obtain prototypes with semantics by passively clustering hand-crafted features (Yang et al., 2007) or deep features into a set of discrete embeddings (visual words) that relate to specific visual semantics. The prediction procedure of these approaches is analogous to the BoW model in natural language processing (NLP), in that an interpretable representation is constructed based on the occurrence statistics of the visual words and then fed to an interpretable classifier, such as a linear classifier or decision trees. (2) In contrast, part-prototypical networks (ProtoPNets) (Chen et al., 2019) and the following works (Nauta et al., 2021; Xue et al., 2022a; Zhang et al., 2022; Peters, 2022; Rymarczyk et al., 2021) jointly train a DNN (as the feature extractor) and parameterized prototypes. Decisions are then made based on a linear combination of the similarity scores between the prototypes and the feature vectors at all spatial locations.
Despite their relatively high performance, as discussed in (Brendel & Bethge, 2019; Hoffmann et al., 2021), prototypes that are learned to be similar to deep features in the embedding space may look significantly different in the input space. As such, we only compare against those approaches whose prototypes are generated by clustering, for the same level of interpretability. To the best of our knowledge, none of the existing works has explored a deductive inference paradigm based on the interactions of visual semantics.

A.3 SCENE GRAPH GENERATION

A scene graph can be generated from an input image to excavate the collection of objects, attributes, and relationships in a whole scene. Typical graph-based representations and learning algorithms (Herzig et al., 2018; Chang et al., 2021; Shit et al., 2022; Zhong et al., 2021; Zareian et al., 2020) adopt graph neural networks (GNNs) (Kipf & Welling, 2016; Jing et al., 2021a; Yang et al., 2020b; Jing et al., 2021b; Yang et al., 2020a; Jing et al., 2022; Yang et al., 2019) to model the relationships between objects in a scene. Similar to scene graphs, our proposed SchemaNet also utilizes GNNs to represent the relationships between ingredients related to local semantics. However, the nodes in GNNs of SchemaNet represent more fine-grained representations, i.e., local semantics of objects, compared to those relating to the whole objects in scene graphs. Moreover, in SchemaNet, the trained GNN is responsible for estimating the similarity between an instance IR-Graph and category IR-Graphs and making classification based on the similarity scores.

A.4 FEATURE ATTRIBUTION FOR EXPLAINABLE AI

In the "Feat2Vertex" module, we extract the attention to the CLS token as one important component of the ingredient importance. Currently, plenty of approaches have been proposed for indicating local relevance (forming saliency maps) to the DNN's prediction, termed feature attribution methods. Most existing works can be roughly divided into three classes: perturbation-based, gradient-based, and decomposition-based approaches. Perturbation-based approaches (Strumbelj & Kononenko, 2010; Zeiler & Fergus, 2014; Ancona et al., 2019; Zintgraf et al., 2017) compute attribution by evaluating output differences via removing or altering input features. Gradient-based approaches (Simonyan et al., 2014; Shrikumar et al., 2017; Sundararajan et al., 2017; Selvaraju et al., 2017; Feng et al., 2022) compute gradients w.r.t. the input feature through backpropagation. Decomposition-based approaches (Bach et al., 2015; Montavon et al., 2017; Chefer et al., 2021) propagate the final prediction to the input following the Deep Taylor Decomposition (Montavon et al., 2017) . Besides, CAM (Zhou et al., 2016) and ABN (Fukui et al., 2019) provide interpretable predictions with a learnable attribution module. As most previous methods focus on CNNs, Chefer et al. (2021) propose ViT-LRP tailored for vision Transformers. However, most of the methods mentioned above are designed to generate a saliency map for a particular class, making them inefficient in our implementation due to the O(C) computational complexity or dependency on gradients.

B PROOF OF THEOREM 1

We first restate the theorem:

Theorem (Theorem 1 restated). For a shallow GCN module, Equation (8) can be approximated by

s = Σ_{ϕ∈Φ} λ̂_ϕ λ_ϕ f_ϕ^(L_G) (f_ϕ^(L_G))^⊤.

In particular, if L_G = 0 and α_1 = 0, our method is equivalent to BoVW with a linear classifier.

For further discussion, we first prove the following lemma:

Lemma 1. For random vectors f, g ∈ R^{d_G} drawn independently from the multivariate Gaussian distribution N(0, I), we have

E_{f,g}[f W^⊤ W g^⊤] = 0,  E_f[f W^⊤ W f^⊤] = ∥W∥²_F,

where ∥·∥_F is the matrix Frobenius norm and W ∈ R^{d_G × d_G} is a projection matrix.

Proof. To begin with, we expand the term f W^⊤ W g^⊤ as

f W^⊤ W g^⊤ = Σ_{r=1}^{d_G} Σ_{s=1}^{d_G} Σ_{t=1}^{d_G} f_s g_t w_{r,s} w_{r,t}.

Therefore, the expectations are

E_{f,g}[f W^⊤ W g^⊤] = Σ_{r=1}^{d_G} Σ_{s=1}^{d_G} Σ_{t=1}^{d_G} w_{r,s} w_{r,t} E_{f_s}[f_s] E_{g_t}[g_t] = 0,

E_f[f W^⊤ W f^⊤] = Σ_{r=1}^{d_G} E_f[Σ_{s=1}^{d_G} Σ_{t=1}^{d_G} w_{r,s} w_{r,t} f_s f_t] = Σ_{r=1}^{d_G} E_f[Σ_{s=1}^{d_G} w²_{r,s} f²_s] = Σ_{r=1}^{d_G} Σ_{s=1}^{d_G} w²_{r,s} = ∥W∥²_F.

In particular, if W is an identity matrix,

E_f[f W^⊤ W f^⊤] = d_G.

Further, if the components of W are drawn i.i.d. from the Gaussian distribution N(0, 1) and are independent of f,

E_{f,W}[f W^⊤ W f^⊤] = d²_G.

Now we prove Theorem 1.

Proof. We start with the simple case where the depth of the GCN module is zero. As G and Ĝ share the same original vertex embeddings, we have

s = Σ_{(u,v)∈V̂×V} λ̂_u λ_v f̂_u^(0) (f_v^(0))^⊤.

According to Lemma 1, the expectation of the same-vertex term, E_f[f f^⊤] = d_G, is more significant than the different-vertex terms λ̂_u λ_v f̂_u^(0) (f_v^(0))^⊤ for u ≠ v when G and Ĝ are constrained with sparsity. In Figure 6(a), we visualize the vertex similarities f̂_u^(0) (f_v^(0))^⊤, showing that the similarities between different vertices are significantly lower than those of the same vertex. Consequently, Equation (8) in the case of L_G = 0 can be simplified as

s = Σ_{ϕ∈Φ} λ̂_ϕ λ_ϕ ∥f_ϕ^(0)∥²₂,

where Φ = V̂ ∩ V is the set of shared vertices of G and Ĝ.
With the extra assumption that α_1 = 0, λ_ϕ represents the count of visual word ϕ. Combining the learnable terms λ̂_ϕ ∥f_ϕ^(0)∥²₂ into w_ϕ, the prediction of the matcher is y = Λ W_FC, where W_FC ∈ R^{|V|×C} is the learnable projection matrix of a linear classifier.

Now we prove the whole theorem. For different vertices u ∈ V̂ and v ∈ V, the summation term s_{u,v} = λ̂_u λ_v f̂_u^(L_G) (f_v^(L_G))^⊤ can be rewritten in the aggregation form with features output from the (L_G − 1)-th layer:

s_{u,v} = λ̂_u λ_v Σ_{(i,j)∈Ṅ_Ĝ(u)×Ṅ_G(v)} ê_{u,i} e_{v,j} f̂_i^(L_G−1) W^(L_G) (W^(L_G))^⊤ (f_j^(L_G−1))^⊤,

where Ṅ_Ĝ(u) denotes the neighbors of vertex u in graph Ĝ together with u itself, and e_{u,u} = 1 is the self-loop. Under the assumption that f_i^(L_G−1) and f_j^(L_G−1) are whitened random vectors (we say a random vector x ∈ R^d is whitened if all its components are zero-mean and independent of each other, which can be achieved by normalization), the summation term f̂_i^(L_G−1) W^(L_G) (W^(L_G))^⊤ (f_j^(L_G−1))^⊤ with vertex i = j will be significantly larger than those with different vertices. We further illustrate the vertex similarities output from the final layer in Figure 6(b), showing that f̂_u^(L_G) (f_v^(L_G))^⊤ is conspicuous for vertices u = v.

Corollary 1. Consequently, s_{u,v} is significant if Ṅ_Ĝ(u) and Ṅ_G(v) share vertices with a relatively large edge-weight product ê_{u,i} e_{v,j} and vertex-weight product λ̂_u λ_v, which means u and v have similar local structures, particularly in the case that u and v are the same vertex. In conclusion, Equation (8) can be simplified as the summation over the joint vertices Φ = V̂ ∩ V of the instance graph G and the category graph Ĝ.
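Lemma 1 can be checked numerically with a small Monte Carlo simulation; the dimension, sample count, and tolerances below are arbitrary illustrative choices.

```python
import random

def quad_form(f, W, g):
    """Compute f W^T W g^T = sum_r (W f)_r (W g)_r for row vectors f, g."""
    d = len(f)
    total = 0.0
    for r in range(d):
        wf = sum(W[r][s] * f[s] for s in range(d))
        wg = sum(W[r][t] * g[t] for t in range(d))
        total += wf * wg
    return total

random.seed(0)
d, n = 8, 5000
W = [[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(d)]
frob_sq = sum(w * w for row in W for w in row)  # ||W||_F^2

cross, same = 0.0, 0.0
for _ in range(n):
    f = [random.gauss(0.0, 1.0) for _ in range(d)]
    g = [random.gauss(0.0, 1.0) for _ in range(d)]
    cross += quad_form(f, W, g)  # E -> 0 (different, independent vectors)
    same += quad_form(f, W, f)   # E -> ||W||_F^2 (same vector)
cross /= n
same /= n
```

The empirical average of the same-vector term concentrates around ∥W∥²_F, while the cross term concentrates around zero, matching the lemma.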

C TRAINING ALGORITHM

The training algorithm of our proposed SchemaNet is shown in Algorithm 1 with initializing IR-Atlas and sparsification.

D EFFECTIVE RECEPTIVE FIELD OF VITS

In this section, we discuss the effective receptive field (ERF) of ViTs, which is crucial because a visual token in the intermediate layers of a ViT may relate to other positions due to MHSA, affecting the interpretability of the ingredients (visual word semantics). Although Luo et al. (2016) propose to measure the ERF for CNNs, their approach cannot be directly applied to Transformer-based models. We therefore propose a relatively simple yet effective approach to measuring the ERF of ViTs.

Definition 3 (Transformer ERF). Suppose a visual sequence X ∈ R^{n×d} is fed into a ViT backbone with N layers, x* is a randomly chosen anchor token in X, and ε ∈ R^d is a random vector drawn from the Gaussian distribution N(0, I) and normalized to a unit vector. The Transformer ERF is the average Euclidean distance in the 2D token array between x* and the closest token y such that, when y is disturbed to ŷ = y + ε, the output change of token x* is less than a predefined threshold δ_r > 0.


Figure 7 shows the visualization heatmaps of the output change of the anchor token for the different ViT backbones adopted in our method. As we can observe, for an input image size of 224 × 224, the output tokens with relatively significant changes are mainly distributed within a circle of radius 25. Therefore, in our method, each ingredient is in charge of a 50 × 50 patch of the input image.
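The measurement procedure of Definition 3 can be sketched as follows. The `toy_forward` backbone (a distance-weighted token mixer) is a hypothetical stand-in for a real ViT, so the resulting number only illustrates the protocol, not the radius-25 finding reported above.

```python
import math, random

def dist(p, q):
    return math.hypot(p[0] - q[0], p[1] - q[1])

def toy_forward(tokens, grid, tau=2.0):
    """Hypothetical stand-in for a ViT backbone: each output token is a
    distance-weighted mixture of all input tokens."""
    n, d = len(tokens), len(tokens[0])
    out = []
    for i in range(n):
        ws = [math.exp(-dist(grid[i], grid[j]) / tau) for j in range(n)]
        z = sum(ws)
        out.append([sum(ws[j] * tokens[j][k] for j in range(n)) / z
                    for k in range(d)])
    return out

def measure_erf(forward, tokens, grid, anchor, delta_r):
    """Distance from the anchor to the closest token whose unit
    perturbation changes the anchor's output by less than delta_r."""
    base = forward(tokens, grid)[anchor]
    d = len(tokens[0])
    best = None
    for y in range(len(tokens)):
        if y == anchor:
            continue
        eps = [random.gauss(0.0, 1.0) for _ in range(d)]
        norm = math.sqrt(sum(e * e for e in eps))
        pert = [list(t) for t in tokens]
        pert[y] = [t + e / norm for t, e in zip(tokens[y], eps)]
        change = math.sqrt(sum((a - b) ** 2 for a, b in
                               zip(forward(pert, grid)[anchor], base)))
        if change < delta_r:
            r = dist(grid[anchor], grid[y])
            best = r if best is None else min(best, r)
    return best

random.seed(1)
grid = [(i, j) for i in range(7) for j in range(7)]
tokens = [[random.gauss(0.0, 1.0) for _ in range(4)] for _ in range(49)]
erf = measure_erf(toy_forward, tokens, grid, anchor=24, delta_r=0.05)
```

For a real backbone, one would average this distance over many random anchors and images, as done for Figure 7.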

E EXPERIMENTAL DETAILS

E.1 VISUAL VOCABULARY

For a probe dataset with D instances, the collection X has n × D visual tokens (ignoring the CLS and DIST tokens).

E.2 GRAPH NORMALIZATION

Before an instance-level graph G and a category-level graph Ĝ are fed to the graph matcher, their vertex weights and edge weights are normalized as follows. The vertex weights in Λ are divided by their sum:

Λ_norm = Λ / Σ_{i=1}^{|V|} λ_i.

The adjacency matrix E is divided by its row-wise summations:

E_norm = [ E_{1,:} / Σ_{i=1}^{|V|} e_{1,i} ; … ; E_{|V|,:} / Σ_{i=1}^{|V|} e_{|V|,i} ].

Finally, the symmetric adjacency matrix E_sym is defined as

E_sym = ½ (E_norm + E_norm^⊤).
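The normalization steps above can be sketched in a few lines of plain Python (the actual implementation operates on tensors):

```python
def normalize_graph(vertex_weights, adj):
    """Normalize vertex weights to sum to 1, row-normalize the adjacency
    matrix, then symmetrize it."""
    total = sum(vertex_weights)
    lam_norm = [w / total for w in vertex_weights]
    n = len(adj)
    e_norm = [[adj[i][j] / sum(adj[i]) for j in range(n)] for i in range(n)]
    e_sym = [[0.5 * (e_norm[i][j] + e_norm[j][i]) for j in range(n)]
             for i in range(n)]
    return lam_norm, e_sym

lam, e = normalize_graph([2.0, 1.0, 1.0],
                         [[1.0, 2.0, 1.0],
                          [2.0, 1.0, 1.0],
                          [0.0, 1.0, 1.0]])
```

After this step the vertex weights form a distribution and the adjacency matrix is symmetric, as the matcher assumes.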

E.3 GCN SETTINGS

Now we describe the detailed settings of the GCN module in the graph matcher. We adopt Lay-erNorm (Ba et al., 2016) as Norm(•) function and rectified linear unit (ReLU) as the activation in Equation ( 7). Besides, the embedding dimension d G is set to 256 for all the experiments.
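Since Equation (7) is not restated here, the following is only a sketch of one GraphConv step consistent with the stated choices, LayerNorm as Norm(·) and ReLU as the activation; the exact aggregation rule and learnable scales of the paper's layer may differ.

```python
import math, random

def graphconv_layer(feats, e_sym, W, eps=1e-5):
    """One sketched GraphConv step: aggregate neighbor features with the
    symmetric adjacency, project with W, then apply LayerNorm and ReLU."""
    n, d_in = len(feats), len(feats[0])
    d_out = len(W[0])
    out = []
    for u in range(n):
        # Aggregate: h_u = sum_v e_sym[u][v] * f_v
        h = [sum(e_sym[u][v] * feats[v][k] for v in range(n))
             for k in range(d_in)]
        # Project: z = h W
        z = [sum(h[k] * W[k][j] for k in range(d_in)) for j in range(d_out)]
        # LayerNorm (no learnable affine in this sketch)
        mu = sum(z) / d_out
        var = sum((x - mu) ** 2 for x in z) / d_out
        z = [(x - mu) / math.sqrt(var + eps) for x in z]
        # ReLU
        out.append([max(0.0, x) for x in z])
    return out

random.seed(0)
feats = [[random.gauss(0.0, 1.0) for _ in range(4)] for _ in range(3)]
e_sym = [[0.5, 0.25, 0.25], [0.25, 0.5, 0.25], [0.25, 0.25, 0.5]]
W = [[random.gauss(0.0, 1.0) for _ in range(4)] for _ in range(4)]
out = graphconv_layer(feats, e_sym, W)
```

Stacking L_G such layers (L_G = 2 in our experiments) lets each vertex embedding absorb its local graph structure before the similarity computation.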

E.4 TRAINING DETAILS

The matcher and IR-Atlas in our method are optimized by AdamW (Loshchilov & Hutter, 2019) with a learning rate of 10^-3, a weight decay of 5 × 10^-4, and cosine annealing as the learning-rate decay schedule. We implement our method with PyTorch (Paszke et al., 2019) and train all the settings for 50 epochs with a batch size of 64 on one NVIDIA Tesla A100 GPU. All input images are resized to 224 × 224 pixels before being fed to our SchemaNet. We adopt ResNet-style data augmentation strategies: random-sized cropping and random horizontal flipping.
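The cosine-annealing schedule used above can be written in one line; `lr_max` and `lr_min` are our own parameter names for this sketch.

```python
import math

def cosine_lr(step, total_steps, lr_max=1e-3, lr_min=0.0):
    """Cosine-annealed learning rate: starts at lr_max and decays
    smoothly to lr_min over total_steps."""
    return lr_min + 0.5 * (lr_max - lr_min) * (
        1.0 + math.cos(math.pi * step / total_steps))
```

For the 50-epoch setting, `cosine_lr(epoch, 50)` yields 10^-3 at the start, half of that midway, and approximately zero at the end.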

F ADDITIONAL RESULTS

This section contains additional experimental results, highlighting the efficiency, robustness, and extendability of our proposed schema inference.

F.1 RESULTS ON IMAGENET

We further implement SchemaNet on mini-ImageNet (Vinyals et al., 2016) and ImageNet-1k (Deng et al., 2009). We adopt DeiT-Small as the backbone, and the visual vocabulary size is set to 1024 for mini-ImageNet and 8000 for ImageNet-1k (due to the memory constraint). In particular, as maintaining 1000 fully-connected category-level IR-Graphs is expensive, we keep at most 500 valid vertices for each class, while all other vertices are pruned during training based on the initialized vertex weights. All other settings are identical to the experiments on Caltech-101. The results are presented in Table 3, drawing conclusions consistent with the main results in Table 1.

F.2 ATTRIBUTION METHODS

In addition to the raw attention used in Equation (3), we further employ feature attribution methods for computing the ingredient importance. Specifically, we compare the results using raw attention, CAM (Zhou et al., 2016), and Transformer-LRP (Chefer et al., 2021) in Table 4. We can observe that using raw attention is superior to the attribution methods in terms of both accuracy and running time. These results can be explained as follows: (1) the attribution methods compute a saliency map for each class (in particular, ViT-LRP conducts C backpropagation passes, consuming enormous time), which is significantly slower than our implementation; (2) as the backbone predicts soft probability distributions rather than one-hot targets, the computed saliency maps for similar categories will be highlighted to some extent, impairing the graph matcher.

F.3 ADVERSARIAL ATTACKS

To analyze the robustness of our proposed SchemaNet, we evaluate the pre-trained SchemaNet (from ImageNet-1k, described in Appendix F.1) on two popular adversarial benchmarks: ImageNet-A (Djolonga et al., 2021) and ImageNet-R (Hendrycks et al., 2021). The comparison results are shown in Table 5, revealing that our method is more robust than the baselines.

F.4 EXTENDABILITY

We evaluate the extendability of our method by extending a trained SchemaNet to unseen tasks while keeping the original framework frozen. As such, only the category-level IR-Graphs for the new tasks are optimized and inserted into the original IR-Atlas. Specifically, the tasks are disjointly drawn from the Caltech-101 dataset. The "Base" task has 21 classes, and tasks 1 to 4 have 20 classes each. The backbone model, i.e., DeiT-Tiny, is trained on the "Base" task and then kept unchanged for the new tasks. The experimental results are presented in Table 6, revealing limited performance degradation on incoming new tasks. These results show that the graph matcher is only responsible for evaluating graph similarity, while the category knowledge is stored in the imagination, i.e., the IR-Atlas.

F.5 FEAT2GRAPH EVALUATION

We here discuss a performance bottleneck of our method, i.e., the Feat2Graph module, in which a feature ingredient array is converted to an IR-Graph without benefiting from GPU parallel acceleration. Specifically, as duplicated visual words may exist after feature discretization, which happens from time to time, we implement this module in C++ to achieve better random-access performance. The experiments are conducted on one NVIDIA Tesla A100 GPU platform with an AMD EPYC 7742 64-core processor. In Table 7, we show the average running time of each component with an input batch size of 64. We can see that, without parallel acceleration, the running time of "Feat2Edge" is significantly longer than the others (about 1.75 ms per image). However, the inference time is still acceptable for real-time applications.

G ABLATION STUDY

G.1 VISUAL VOCABULARY SIZE

Figure 8(a) shows SchemaNet accuracy when using different sizes of the visual vocabulary. We can observe that even with a relatively small size, e.g., M = 256 on CIFAR-100, the performance is still competitive (about 7% absolute degradation). For the case of M = 2048, however, the performance suffers from the low generalizability of the vocabulary.

G.2 NUMBER OF GRAPHCONV LAYERS

The effect of using different depths of GCN is illustrated in Figure 8(b). As L_G increases to 5, the accuracy on CIFAR-100 decreases rapidly, indicating that a deep GCN matcher impairs the graph matching performance; this has been attributed to the over-smoothing caused by stacking many convolutional layers (Zhao & Akoglu, 2020; Li et al., 2018). In our method, a graph matcher with only two GraphConv layers proves adequate for both high performance and interpretability.
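The over-smoothing effect referenced above is easy to reproduce in miniature: repeatedly averaging node features over a row-normalized adjacency shrinks the differences between nodes toward zero. This is a toy, unweighted propagation, not the paper's GraphConv.

```python
def smooth(feats, e_norm, layers):
    """Propagate features `layers` times through a row-normalized
    adjacency (pure neighbor averaging) to expose over-smoothing."""
    n, d = len(feats), len(feats[0])
    for _ in range(layers):
        feats = [[sum(e_norm[u][v] * feats[v][k] for v in range(n))
                  for k in range(d)] for u in range(n)]
    return feats

def spread(feats):
    """Maximum pairwise L1 distance between node features."""
    return max(sum(abs(a - b) for a, b in zip(x, y))
               for x in feats for y in feats)

# Fully connected 3-node graph with self-loops, row-normalized.
e_norm = [[0.6, 0.2, 0.2], [0.2, 0.6, 0.2], [0.2, 0.2, 0.6]]
feats0 = [[1.0, 0.0], [0.0, 1.0], [-1.0, 2.0]]
spreads = [spread(smooth(feats0, e_norm, L)) for L in range(5)]
# spreads shrinks monotonically as depth grows
```

Each propagation step contracts the node-feature deviations toward their common mean, which is why a matcher with many layers loses the vertex-level distinctions it needs.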

G.3 ABLATION OF WEIGHT COMPONENTS

Table 8 : Ablation study of the components defining the vertex and edge weights with DeiT-Tiny as the backbone DNN.

Setting                      CIFAR-10   CIFAR-100
Learnable α_1,2, β_1,2       95.96      78.20
Fixed α_1 = 0, α_2 = 1       95.88      75.91
Fixed α_1 = 1, α_2 = 0       95.68      72.17

The effects of λ_CLS and λ_bag defined in Equation (3), and of e_attn and e_adj defined in Equation (6), are shown in Table 8. By setting the corresponding weights (α_1,2 and β_1,2) to zero one at a time, we are able to analyze the components individually. In general, removing any term leads to varying degrees of performance degradation. Most significantly, removing the term λ_bag decreases the accuracy by 6.03% on CIFAR-100, revealing that the statistics of the ingredients help filter noisy vertices. The learning curves of the learnable weights are shown in Figure 9, from which we can observe that the model tends to adopt a relatively larger α_2 (for the visual word count) while keeping the λ_CLS term for filtering noisy and background ingredients.

H SENSITIVITY ANALYSIS OF HYPERPARAMETERS

H.1 SPARSIFICATION THRESHOLD

Table 9 presents the SchemaNet-Init top-1 accuracy trained with different sparsification thresholds δ_t. We can observe that an appropriate value of δ_t, e.g., 0.01, boosts the model performance.

H.2 REGULARIZATION WEIGHTS

We further analyze the hyperparameters γ_v and γ_e in Equation (11) that constrain the complexity of the IR-Atlas. We evaluate the top-1 accuracy on the CIFAR-100 dataset with DeiT-Tiny as the backbone. The curve of γ_v shows that neither a dense nor an extremely sparse IR-Atlas impairs the performance, while our method is even more robust to changes in γ_e.

I MORE VISUALIZATIONS

I.1 VISUALIZATION AND ANALYSIS OF MISCLASSIFIED EXAMPLES

Figure 10 shows two misclassified examples along with the classification evidence. In Figure 10(a), the object is misclassified as the fine-grained category "rooster" because of a noisy ingredient, #240, which should be #103. Unfortunately, ingredient #240 is a crucial vertex in the rooster's IR-Graph, contributing about 9.5 absolute gains to the logit and leading to the misclassification. Figure 10(b) shows a more complicated example. We can observe that instead of discretizing the human face that appears in the image into the "face" ingredients, the backbone produces features closer to #17, which is more similar to an animal's face. Moreover, part of the object's body is assigned to #240 rather than to the panda's body (#204). This creates a remarkable pattern, i.e., the interaction between #17 and #240, which is the crucial local structure in the llama's graph. As a result, #240 and its local structure contribute about 14.9 gains to the logit. Although our schema inference framework is capable of revealing why an image is misclassified by highlighting the key points, future work should explore a more compatible feature extractor that generates more robust local features.

I.2 VISUALIZATION OF INGREDIENTS

In Figure 11, we visualize the whole set of ingredients on Caltech-101 used for the visualizations in Figure 5 and Figures 12 to 14; we extract 256 ingredients on Caltech-101 for a better view. For each cluster center generated by k-means clustering, we select the top-40 tokens that are closest to it and show the corresponding image patches (50 × 50 pixels, delineated in Appendix D). We can observe that image patches of the same cluster share the same semantics.
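The top-40 selection can be sketched as follows; `top_k_nearest` is our illustrative helper, not the paper's implementation.

```python
def top_k_nearest(tokens, center, k=40):
    """Indices of the k tokens closest (in squared Euclidean distance)
    to a cluster center, as used to visualize each ingredient."""
    d2 = lambda t: sum((a - b) ** 2 for a, b in zip(t, center))
    return sorted(range(len(tokens)), key=lambda i: d2(tokens[i]))[:k]

# Tiny example: two tokens sit near the origin center.
nearest = top_k_nearest([[0.0, 0.0], [1.0, 1.0], [0.1, 0.0], [5.0, 5.0]],
                        center=[0.0, 0.0], k=2)
# nearest == [0, 2]
```

Applying this per cluster center and cropping the corresponding 50 × 50 patches reproduces the layout of Figure 11.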



CONCLUSION AND OUTLOOK

In this paper, we propose a novel inference paradigm, named schema inference, guided by Kant's philosophy of resembling human deductive reasoning by associating the abstract concept image with the specific sense impression. To this end, we reformulate traditional DNN inference as a graph matching scheme that evaluates the similarity between an instance-level IR-Graph and the category-level imagination in a deductive manner. Specifically, the graph vertices are visual semantics represented by common feature vectors from a DNN's intermediate layer, while the edges indicate vertex interactions characterized by semantic similarity and spatial adjacency, which facilitates capturing the compositional contributions to the predictions. Theoretical analysis and experimental results on several benchmarks demonstrate the superiority and interpretability of schema inference. In future work, we will apply schema inference to more complicated vision tasks, such as visual question answering, which enables linking visual semantics to phrases in human language, achieving a more powerful yet interpretable reasoning paradigm.



Figure 1: An example showing how an instance IR-Graph is matched to the class imagination. The vertices of IR-Graphs represent visual semantics, and the edges indicate the vertex interactions. The graph matcher captures the similarity between the local structures of joint vertices (e.g., vertex #198 and vertex #150) by aggregating information from their neighbors as class evidences. The final prediction is defined as the sum of all evidence.

Figure 2: The overall pipeline of our proposed SchemaNet. Firstly, intermediate feature X from the backbone is fed into the Feat2Graph module. The converted IR-Graph is then matched to a set of imaginations, i.e., category-level IR-Atlas induced from observed data, for making the prediction.

Figure 3: Illustration of the matcher.

Equation (11), where γ_v and γ_e are hyperparameters. The overall training procedure is shown in Appendix C.

Figure 5: Examples of the learned IR-Atlas and instance IR-Graphs randomly sampled from three categories ("panda", "flamingo", and "airplanes") on Caltech-101. Please zoom in for a better view.

Figure 6: Vertex embedding similarities between a vertex v in the instance IR-Graph and a vertex u in the category IR-Atlas, which are extracted from the original vertex embeddings and output from the last GraphConv layer. We show the average similarity over 1000 samples on the Caltech-101 dataset with a visual vocabulary size of 256.

Algorithm 1 SchemaNet training with initialization and sparsification.
Input: D = {(x_i, y_i)}_{i=1}^{D}: the training dataset with D samples; Ĝ = {Ĝ_c}_{c=1}^{C}: the initial IR-Atlas; Backbone(·): the ViT backbone; Matcher(·, ·): the graph matcher; Ω: the visual vocabulary; Θ: the set of all trainable parameters; δ_t: the sparsification threshold.
1: procedure INITIALIZATION(D, Ω)
2:   Sample a subset of D as the probe dataset.

end procedure
13: procedure TRAINING(Ĝ, D, Ω)
14:   for (x, y) ∈ D do
15:     for ê_i,j ∈ Ê do ▷ Remove redundant edges
16:       if λ̂_i < δ_t or λ̂_j < δ_t then

Figure 7: The heat maps of the receptive field of an anchor token over its neighborhood region. Averaged visualization results of 64 random images of three ViTs (DeiT-Ti, DeiT-S, and DeiT-B) from Caltech-101 dataset are demonstrated. The full image size is 224 × 224, and we crop a small region with a size of 96 × 96 with the anchor token as the center for better visualization.

SchemaNet accuracy when using GCNs with various depths L_G. Sensitivity analysis of the hyperparameters involved in Equation (11).

Figure 8: Ablation study and sensitivity analysis of hyperparameters of our proposed SchemaNet on CIFAR-100 backed with DeiT-Tiny.

Figure 9: Learning curves of α 1,2 and β 1,2 with DeiT-Tiny backbone on Caltech-101 dataset.



Figure 10: Illustration of misclassified samples. Best viewed in color.

Figure12: Examples of the learned IR-Atlas and instance IR-Graphs randomly sampled from five categories ("bass", "bonsai", "emu", "euphonium", and "grand piano") on Caltech-101.

Figure13: Examples of the learned IR-Atlas and instance IR-Graphs randomly sampled from five categories ("hawksbill", "ibis", "kangaroo", "leopards", and "llama") on Caltech-101.

Figure14: Examples of the learned IR-Atlas and instance IR-Graphs randomly sampled from five categories ("okapi", "rooster", "sunflower", "water lilly", and "wild cat") on Caltech-101.

Visual vocabulary size.

Comparison results (top-1 accuracy) on ImageNet-1k and mini-ImageNet.

Comparison results (top-1 accuracy and running time) of using different attribution methods on CIFAR-10 and CIFAR-100 datasets.

Adversarial attack results (top-1 accuracy) on ImageNet-A and ImageNet-R datasets.

Top-1 accuracy results when extending SchemaNet to unseen tasks. The accuracies are evaluated on the corresponding task, and the average accuracy over all the tasks is given as well.

Time cost (ms) of the components in SchemaNet with an input batch size of 64.

Sensitivity analysis of sparsification threshold δ t .

I.3 VISUALIZATION OF IR-ATLAS AND INSTANCE IR-GRAPHS

We provide more visualization examples on the Caltech-101 dataset, shown in Figures 12 to 14.

Figure 11: Ingredient visualization on the Caltech-101 dataset. Please zoom in for a better view.

ACKNOWLEDGMENTS

This work is supported by National Natural Science Foundation of China (61976186, U20B2066, 62106220), Ningbo Natural Science Foundation (2021J189), the Starry Night Science Fund of Zhejiang University Shanghai Institute for Advanced Study (Grant No. SN-ZJU-SIAS-001), Open Research Projects of Zhejiang Lab (NO. 2019KD0AD01/018) and the Fundamental Research Funds for the Central Universities (2021FZZX001-23).

