JOINT GAUSSIAN MIXTURE MODEL FOR VERSATILE DEEP VISUAL MODEL EXPLANATION

Abstract

Post-hoc explanations of deep neural networks faithfully improve human understanding of the learned representations, the decision-making process and the uncertainty of the model. Explaining deep convolutional neural networks (DCNN) is especially challenging due to the high dimensionality of deep features and the complexity of model inference. Most post-hoc explaining methods serve a single form of explanation, restricting the diversity and consistency of the explanation. This paper proposes the joint Gaussian mixture model (JGMM), a probabilistic model that jointly models inter-layer deep features and produces faithful and consistent post-hoc explanations. JGMM explains deep features by a Gaussian mixture model and inter-layer deep feature relations by the posterior distribution on the latent component variables. JGMM enables a versatile explaining framework that unifies interpretable proxy models with global or local explanatory example generation and mining. Experiments are performed on various DCNN image classifiers in comparison with other explaining methods and show that JGMM efficiently produces versatile, consistent, faithful and understandable explanations.

1. INTRODUCTION

Deep convolutional neural networks (DCNN) are a powerful class of machine learning models for visual recognition tasks. The keys to the power of DCNN are the expressive visual representations and the decision-making mechanism encoded in massive trainable convolution parameters. However, increasing model complexity places a heavier burden on humans trying to understand the learned representations and decision making of the model. The high dimensionality and entanglement of deep features and the complexity of neural network inference are often considered the main hindrances to explaining black-box DCNN. A recent proliferation of studies in post-hoc DCNN explainability has produced several effective and practical explaining methods. Proxy models are interpretable models (e.g. decision trees and linear models) whose decision-making behavior approximates the black-box model inference. Their inference process can be intuitively understood by humans; for example, LIME (proposed by Ribeiro et al. (2016)) fits a linear classifier as a local proxy model. To make sure the proxy model is an accurate surrogate, its faithfulness should be tested: the proxy model's predictions on unseen examples should be close to the black-box model's predictions, even when they differ from the ground truth. Most DCNN explaining methods are single-purpose systems. Employing several explaining methods can give diverse explanations, but there is no guarantee that the different explanations are compatible and consistent with each other. For example, a counterfactual example generation system may suggest that the model is sensitive to a certain feature, while a global explaining method, such as a proxy model, may conflict with that explanation. There is no hard rule to determine which explanation is correct or more understandable. Thus a generic explaining framework that enables various and consistent explanation forms has important value for the explainability of DCNN.
We propose a probabilistic framework for a versatile explaining method for DCNN. Figure 1 demonstrates the pipeline of the framework and the enabled explanation forms. The lower and the higher features are intermediate representations of the DCNN from two different layers. In the computational graph of the DCNN, the higher feature depends on the lower feature and the black-box model (a part of the DCNN) between them. In an image classification setting, the higher feature can be the probabilistic classification prediction of the black-box model, and the lower feature can be the input of the DCNN, i.e. the raw image. The goal of the framework is to explain the learned representations of the higher/lower features as well as the black-box model (either the whole DCNN or a part of it) between them.

3. METHOD

3.1 JOINT GAUSSIAN MIXTURE MODEL

We propose the joint Gaussian mixture model, a probabilistic model based on GMM that estimates the posterior distribution connecting the latent variables behind the deep features. Figure 2 illustrates the model in plate notation. The model includes three parts: the GMM on observation $Y$, the conditioned GMM on observation $X$, and the probabilistic model between the discrete latent variables $z$ and $w$.

The model on observation $Y$ resembles the classical GMM. The observed $D_y$-dimensional real-valued higher feature is generated as $Y \sim \mathcal{N}(\phi_w, \Theta_w)$. Data point $y_n$ is the $n$-th of $N$ sampled higher features. The latent categorical variable $w \in \{1, \ldots, K_y\}$ decides which Gaussian component $Y$ is generated from. The hyper-parameter $K_y$ is the number of mixture components, and the prior probability $\pi$ is a $K_y$-dimensional vector holding the weight of each of the $K_y$ mixture components.

The model on observation $X$ is also a Gaussian mixture model. The observed $D_x$-dimensional lower feature $X$ is generated by the $z$-th Gaussian component: $X \sim \mathcal{N}(\mu_z, \Sigma_z)$. The latent categorical variable $z \in \{1, \ldots, K_x\}$ has hyper-parameter $K_x$, the number of mixture components. The key difference of our model from a traditional GMM, however, is that the prior probability of $z$ depends on the posterior probability $p(z|w)$ parameterized by $Q$. The connection between the latent categorical variables $w$ and $z$ is modeled by the posterior probability matrix $Q \in \mathbb{R}^{K_x \times K_y}$. The entry $Q_{i,j}$ is the conditional probability $p(z = i | w = j)$. In figure 2, $q_k$ is the $k$-th column vector of $Q$, i.e. the distribution of $z$ conditioned on $w = k$, whose entries sum to 1.
By the chain rule of Bayesian inference, the likelihood function of JGMM is ($\mathcal{G}$ denotes the Gaussian density):

$$p(x, y) = \sum_{i=1}^{K_x} \sum_{j=1}^{K_y} p(x|z=i)\, p(y|w=j)\, p(z=i|w=j)\, p(w=j) = \sum_{i=1}^{K_x} \sum_{j=1}^{K_y} \mathcal{G}(x; \mu_i, \Sigma_i)\, \mathcal{G}(y; \phi_j, \Theta_j)\, Q_{i,j}\, \pi_j \quad (1)$$

The maximum likelihood estimation (MLE) problem of the JGMM parameters can be solved by the expectation-maximization (EM) algorithm. The optimization algorithm is straightforwardly derived from the EM framework, and its details are described in appendix A. The complexity of the optimization is similar to separately learning two GMMs for $X$ and $Y$. Equation 1 shows that a trained JGMM is still a GMM with $K_x K_y$ components, preserving the desired properties of GMM.

The discussion below uses the following notation for simplicity. For given observations $x$ and $y$, $g_x \in \mathbb{R}^{K_x}$ and $g_y \in \mathbb{R}^{K_y}$ are the probability densities on each component, and $\Psi \in \mathbb{R}^{K_x \times K_y}$ is the conditional distribution of $w$:

$$[g_x]_i = p(x|\hat\mu_i, \hat\Sigma_i), \quad [g_y]_j = p(y|\hat\phi_j, \hat\Theta_j), \quad \hat\lambda_i = p(z=i) = \hat{Q}_{i,*}\,\hat\pi, \quad \hat\Psi_{i,j} = p(w=j|z=i) = \frac{\hat{Q}_{i,j}\,\hat\pi_j}{\sum_{k=1}^{K_y} \hat{Q}_{i,k}\,\hat\pi_k}, \quad i = 1, \ldots, K_x,\ j = 1, \ldots, K_y \quad (2)$$

where $\hat{Q}$, $\hat\mu$, $\hat\Sigma$, $\hat\phi$, $\hat\Theta$ and $\hat\pi$ are the parameters estimated by MLE, and $\hat{Q}_{i,*}$ is the $i$-th row vector of $\hat{Q}$.

A trained JGMM can work as a bi-directional Bayesian classifier to predict the latent variables $z$ and $w$. Conditioned on a known $x$ or $y$, the distribution of the other feature's latent component variable $w$ or $z$ is:

$$p(w=j|x) = \frac{(g_x^T \hat{Q}_{*,j})\,\hat\pi_j}{g_x^T \hat{Q} \hat\pi}, \quad p(z=i|y) = \frac{(\hat\Psi_{i,*}\, g_y)\,\hat\lambda_i}{\hat\lambda^T \hat\Psi g_y} \quad (3)$$

The posterior probability in equation 3 is naturally a Bayesian classifier that predicts the component category of a given example. This enables a traditional train-test evaluation of the proxy model to assess the faithfulness of the modelling. JGMM also makes a handy example generator and miner that improves human understanding of the model through examples.
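The bi-directional prediction in equation 3 reduces to a few matrix-vector products. The sketch below shows how $p(w|x)$ and $p(z|y)$ can be computed from the quantities in equation 2; all parameter values ($Q$, $\pi$, $g_x$, $g_y$) are toy illustrations, not learned parameters.

```python
import numpy as np

def predict_w_given_x(g_x, Q, pi):
    """Equation 3: p(w=j|x) is proportional to (g_x^T Q[:, j]) * pi[j]."""
    scores = (g_x @ Q) * pi
    return scores / scores.sum()

def predict_z_given_y(g_y, Q, pi):
    """Equation 3: p(z=i|y) proportional to (Psi[i, :] @ g_y) * lam[i],
    with Psi and lam derived from Q and pi as in equation 2."""
    lam = Q @ pi                          # lam[i] = p(z=i) = Q[i, :] @ pi
    Psi = (Q * pi) / lam[:, None]         # Psi[i, j] = p(w=j | z=i)
    scores = (Psi @ g_y) * lam
    return scores / scores.sum()

# Toy parameters: Kx=3 lower components, Ky=2 higher components.
Q = np.array([[0.9, 0.1],
              [0.1, 0.8],
              [0.0, 0.1]])               # columns are q_k and sum to 1
pi = np.array([0.5, 0.5])
g_x = np.array([0.2, 1.5, 0.1])          # densities p(x | mu_i, Sigma_i)
g_y = np.array([1.0, 0.1])               # densities p(y | phi_j, Theta_j)

p_w = predict_w_given_x(g_x, Q, pi)      # distribution over w given x
p_z = predict_z_given_y(g_y, Q, pi)      # distribution over z given y
```

Both outputs are proper distributions; in this toy setting, $x$'s high density on the second lower component pushes the posterior mass toward the second higher component.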
When the conditional distribution of $w$ or $z$ is computed by equation 3, the JGMM on $y$ or $x$ reduces to an ordinary GMM, which is easy to sample from and to interpret.

Prototypes. From a probabilistic viewpoint, a representative example $x$ for $w = j$ lies at the center of the cluster, in a neighborhood with high cumulative probability under $p(x|w = j)$, which is itself a GMM. We therefore define the prototype scoring function of the $j$-th higher feature component as equation 4.

3.2. MINING AND GENERATING EXPLANATORY EXAMPLES

$$f^{(j)}_{\mathrm{proto}}(x) = p(x|w=j) = \mathrm{GMM}(\hat\mu, \hat\Sigma, \hat{q}_j) \quad (4)$$

Sampling prototypes for the $j$-th higher feature component is simply sampling from the Gaussian mixture $p(x|w=j)$.

Criticisms. Criticisms can be considered outliers that are hardly represented by any prototype. We therefore search for examples that are least likely to be sampled from any $w$, which is the opposite of the prototype score, as equation 5:

$$f_{\mathrm{criti}}(x) = 1 - \max_j f^{(j)}_{\mathrm{proto}}(x) \quad (5)$$

Sampling criticisms is not supported in our method, because the realistic data distribution is often very sparse compared to the feature space and most sampled outliers are not understandable examples. We only consider criticisms found in realistic data.
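As a sketch, the two scores can be written directly against a toy two-component JGMM; the means, covariances and $Q$ below are invented for illustration, not learned parameters.

```python
import numpy as np

def gauss_pdf(x, mu, cov):
    """Density of a multivariate normal N(mu, cov) at x."""
    d = len(mu)
    diff = x - mu
    inv = np.linalg.inv(cov)
    norm = 1.0 / np.sqrt((2 * np.pi) ** d * np.linalg.det(cov))
    return norm * np.exp(-0.5 * diff @ inv @ diff)

mus = [np.array([0.0, 0.0]), np.array([4.0, 4.0])]   # toy component means
covs = [np.eye(2), np.eye(2)]
Q = np.array([[1.0, 0.0],
              [0.0, 1.0]])                           # Q[i, j] = p(z=i | w=j)

def proto_score(x, j):
    """Equation 4: f_proto^(j)(x) = p(x|w=j), a GMM with weights q_j = Q[:, j]."""
    return sum(Q[i, j] * gauss_pdf(x, mus[i], covs[i]) for i in range(len(mus)))

def criti_score(x):
    """Equation 5: f_criti(x) = 1 - max_j f_proto^(j)(x); outliers score high."""
    return 1.0 - max(proto_score(x, j) for j in range(Q.shape[1]))

center = np.array([0.0, 0.0])      # at a component mean: good prototype
outlier = np.array([10.0, -10.0])  # far from every component: criticism
```

A point at a component mean scores high as a prototype, while a point far from every component scores high as a criticism, matching the mining criteria above.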

Influential examples.

Influential examples are close to the decision boundary and lead to model decisions with low confidence. With JGMM, we can measure the influence of an example by the entropy of the posterior probability of $w$ conditioned on $x$: the higher the entropy, the higher the uncertainty and the influence. The scoring function is defined as equation 6:

$$f_{\mathrm{influ}}(x) = -\sum_{j=1}^{K_y} p(w=j|x) \ln p(w=j|x) \quad (6)$$

Influential examples can be generated by sampling from each component of $X$ and selecting the samples with the largest $f_{\mathrm{influ}}(x)$. Sampled influential examples, though not actually in the training data of the black-box model, can be useful for exploring examples the black-box model is uncertain about.
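Equation 6 is a plain Shannon entropy over the posterior. A minimal sketch, with illustrative posterior vectors standing in for $p(w|x)$ computed by equation 3:

```python
import numpy as np

def influ_score(p_w, eps=1e-12):
    """Equation 6: f_influ(x) = -sum_j p(w=j|x) ln p(w=j|x).
    eps guards the log against exact zeros in the posterior."""
    p = np.clip(p_w, eps, 1.0)
    return float(-(p * np.log(p)).sum())

confident = np.array([0.98, 0.01, 0.01])   # far from the decision boundary
uncertain = np.array([0.50, 0.45, 0.05])   # near the boundary: influential
```

The near-uniform posterior scores higher, so examples close to the decision boundary are ranked as more influential.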

Counterfactuals.

A counterfactual example is close to the query example $x_{\mathrm{query}}$ in the feature space but has a very different posterior distribution on $w$. The scoring function for counterfactuals is therefore defined as the probability that $x$ and $x_{\mathrm{query}}$ are sampled from the same $z$ but different $w$:

$$f_{\mathrm{count}}(x) = \sum_{i=1}^{K_x} \sum_{j=1}^{K_y} p(w=j|z=i)\, p(z=i|x)\, p(w_{\mathrm{query}} \neq j | z_{\mathrm{query}} = i)\, p(z_{\mathrm{query}} = i | x_{\mathrm{query}}) \quad (7)$$

To maximize the counterfactual score, there exists an optimal distribution of $p(z|x)$. Let this distribution be $\alpha$; the maximization of equation 7 is:

$$\max_\alpha \sum_{i=1}^{K_x} \alpha_i \left( \sum_{j=1}^{K_y} p(w=j|z=i)\, p(w_{\mathrm{query}} \neq j | z_{\mathrm{query}} = i)\, p(z_{\mathrm{query}} = i | x_{\mathrm{query}}) \right), \quad \text{s.t.} \sum_{i=1}^{K_x} \alpha_i = 1 \quad (8)$$

Note that all terms except $\alpha$ are constant, which makes this a linear programming problem with a sum constraint. Though it is still hard to derive the optimal $x$, the solution provides an efficient sampling approach: sample from the GMM under the optimal distribution $\alpha$ and select the best counterfactuals.

Semi-factuals. A semi-factual example is likely sampled from a different $z$ than the query example $x_{\mathrm{query}}$ but has approximately the same posterior distribution on $w$, the reverse of a counterfactual. The scoring function for semi-factuals is therefore defined as equation 9:

$$f_{\mathrm{semi}}(x) = \sum_{i=1}^{K_x} \sum_{j=1}^{K_y} p(w=j|z=i)\, p(z=i|x)\, p(w_{\mathrm{query}} = j | z_{\mathrm{query}} \neq i)\, p(z_{\mathrm{query}} \neq i | x_{\mathrm{query}}) \quad (9)$$

As with counterfactuals, maximizing the semi-factual score leads to an optimal distribution of $p(z|x)$. We sample from the GMM under this optimal distribution and select the best semi-factuals.
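Because equation 8 is linear in $\alpha$ with only a simplex constraint, its optimum concentrates all mass on the component with the largest coefficient, so the sampling step reduces to picking that component. The sketch below uses toy values of $\Psi$ and $p(z_{\mathrm{query}}|x_{\mathrm{query}})$, not learned parameters:

```python
import numpy as np

def optimal_alpha(Psi, p_z_query):
    """Solve equation 8. The objective is linear in alpha on the simplex,
    so the maximum sits at a vertex: all mass on the best component.
    c[i] = sum_j p(w=j|z=i) * p(w_query != j | z_query=i) * p(z_query=i|x_query)."""
    c = (Psi * (1.0 - Psi)).sum(axis=1) * p_z_query
    alpha = np.zeros_like(c)
    alpha[np.argmax(c)] = 1.0
    return alpha

# Toy values: component 0 is confident about w, component 1 straddles classes.
Psi = np.array([[0.9, 0.1],
                [0.5, 0.5]])         # Psi[i, j] = p(w=j | z=i)
p_z_query = np.array([0.5, 0.5])     # p(z_query = i | x_query)
alpha = optimal_alpha(Psi, p_z_query)
# Counterfactual candidates would then be drawn from the selected Gaussian
# component and the best ones kept according to f_count (equation 7).
```

Here the boundary-straddling component wins, which matches the intuition that counterfactuals come from regions where the $z$-to-$w$ mapping is ambiguous.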

3.3. EXTENSION FOR DCNN

The deep feature maps of a DCNN usually have additional spatial dimensions besides the feature dimension, such as the time dimension of signal data or the vertical and horizontal dimensions of 2D images. The feature maps can be formulated with a spatial encoding $s \in S$, i.e. $X^{(s)}$. For example, for a feature map of width $W$ and height $H$, the spatial encoding is $S = \{(i, j) \,|\, 1 \le i \le W, 1 \le j \le H\}$. In the discussion below, $S$ and $T$ respectively denote the spatial dimensions of $X$ and $Y$. The added spatial dimensions change the latent variable relation from one-to-one to many-to-many, i.e. modelling every pair in $\{(s, t) \,|\, s \in S, t \in T\}$. Theoretically, JGMM extends to this scenario by estimating each $Q^{(s,t)}$ and sharing the GMM parameters $\mu$, $\Sigma$, $\phi$, $\Theta$, $\pi$ across all $s \in S$. However, both the E-step and the M-step of the EM algorithm for JGMM require enumerating all possible latent variable assignments (see appendix A), and this combinatorial space ($K_x^{|S|} K_y^{|T|}$ elements) is so large that enumeration is infeasible. Fortunately, most modern DCNN architectures map features translation-invariantly: translating the lower features over the spatial dimensions causes the same (relative) translation of the higher features. Also, though the relative receptive field (the region of the lower features that affects a specific position of the higher feature) of $T$ on $S$ can be very large, Luo et al. (2016) show that the central area (the effective receptive field) has the main influence on the higher feature. We therefore extend the proposed probabilistic model to a one-to-many relation. Let this relative effective receptive field be $R$: $R$ is a region on the lower feature connected to a single position of the higher feature. We train $|R|$ posterior distribution matrices $Q^{(r)}$ between $S$ and $T$. The size of $R$ is a hyper-parameter, and the parameters $\mu$ and $\Sigma$ are shared by all positions.
Figure 3 shows the graphical model of JGMM extended to deep features with spatial dimensions. The hyper-parameter $|R|$ decides how many lower feature positions affect the higher feature. As with convolution hyper-parameters, the size of the effective receptive field should satisfy $|S| + 1 - |T| \ge |R|$. An important property of the extended JGMM is that $X^{(r)}$ is conditionally independent of the other $r' \in R$ given $w$:

$$p(\{x^{(r)} \,|\, r \in R\} \,|\, w) = \prod_{r \in R} p(x^{(r)} | w)$$

So the inference of the extended JGMM decomposes into inference on each of the $|R|$ JGMMs:

$$p(w \,|\, \{x^{(r)} \,|\, r \in R\}) = \frac{p(\{x^{(r)} \,|\, r \in R\} \,|\, w)\, p(w)}{p(\{x^{(r)} \,|\, r \in R\})} = \frac{p(w) \prod_{r \in R} p(x^{(r)} | w)}{p(\{x^{(r)} \,|\, r \in R\})}$$

Therefore, the inference and sampling methods of JGMM are directly applicable to the extended JGMM.
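Under this conditional independence, inference multiplies per-position likelihoods $p(x^{(r)}|w=j) = g_{x^{(r)}}^T Q^{(r)}_{*,j}$. A log-space sketch with toy densities and a shared toy $Q^{(r)}$ (all values illustrative):

```python
import numpy as np

def predict_w_spatial(gx_per_r, Q_per_r, pi):
    """Factorized posterior p(w | {x^(r)}) for the extended JGMM.
    gx_per_r[r][i] = p(x^(r) | mu_i, Sigma_i); Q_per_r[r] is Q^(r).
    The product over the receptive field R is taken in log space."""
    log_post = np.log(pi)
    for g_x, Q in zip(gx_per_r, Q_per_r):
        log_post += np.log(g_x @ Q)   # log p(x^(r) | w=j) = log(g_x^T Q[:, j])
    post = np.exp(log_post - log_post.max())
    return post / post.sum()

# Toy setup: |R| = 2 positions sharing one Q^(r), Kx = Ky = 2.
pi = np.array([0.5, 0.5])
Q = np.array([[0.8, 0.2],
              [0.2, 0.8]])
gx_per_r = [np.array([1.0, 0.1]), np.array([0.9, 0.2])]
post = predict_w_spatial(gx_per_r, [Q, Q], pi)
```

Both positions favor the first lower component, so their product sharpens the posterior toward the first higher component; the log-space accumulation avoids underflow when $|R|$ is large.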

4. EXPERIMENTS

4.1 IMPLEMENTATION DETAILS

JGMM implementation. For a vanilla GMM, the covariance matrices contribute the most parameters to be estimated. To reduce the complexity of JGMM training, we whiten the lower and higher features by PCA (principal component analysis). All PCA components are kept so that the transformation is invertible and all dimensions are empirically uncorrelated. Training JGMM then only requires estimating diagonal covariance matrices, i.e. only the diagonal entries of $\Sigma$ and $\Theta$ are estimated and the rest are fixed to zero.

Features. For a DCNN image classifier, the classification output (without softmax) is always selected as a higher feature to explain, because it directly links the understandable image classes with feature dimensions. For models with a large input image resolution, the feature map is very large, and since JGMM samples from each position, this makes training difficult. To balance data size against training feasibility, the feature maps are average-pooled with 2 × 2 or 3 × 3 kernels to shrink them; JGMM training and explanation then proceed on the pooled features.

Hyper-parameters. Training JGMM requires two hyper-parameters: the numbers of components for the lower and the higher features. The component number of the lower feature is chosen by performance on the validation dataset (split from the training data). When the higher feature is the classification output of the model, the number of components is set to the number of classes; otherwise, the higher feature has already been modelled as the lower feature of another JGMM, and the number of components is inherited from that JGMM. The effective receptive field is set as large as the computation resources allow, to capture as much positional information as possible; but when $|R|$ is too large, JGMM training slows down unacceptably. We set the effective receptive field size from 2 × 2 to 5 × 5.
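The whitening step can be sketched in plain NumPy. This is a minimal illustration (variable names are ours, not the authors' code): PCA-whitening with all components kept, so the transform is invertible and the whitened covariance is the identity, which justifies diagonal $\Sigma$ and $\Theta$.

```python
import numpy as np

def pca_whiten(X, eps=1e-8):
    """PCA-whiten rows of X; return whitened data plus inversion parameters."""
    mean = X.mean(axis=0)
    Xc = X - mean
    cov = Xc.T @ Xc / (len(X) - 1)
    eigval, eigvec = np.linalg.eigh(cov)
    W = eigvec / np.sqrt(eigval + eps)   # all components kept: W is invertible
    return Xc @ W, mean, W

# Toy correlated features standing in for pooled deep features.
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 4)) @ rng.normal(size=(4, 4))
Xw, mean, W = pca_whiten(X)
cov_w = np.cov(Xw.T)   # close to identity: dimensions empirically uncorrelated
```

After whitening, the empirical covariance is (numerically) the identity, so a diagonal-covariance JGMM loses no second-order structure on the training data.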
The dataset splits are the same as the official splits. We choose three feature maps from each network, named 'high', 'middle' and 'low' in the tables and figures below. In the experiments, we train JGMM with 32 components on middle features and 128 components on low features for MNIST; 64 components on middle features and 256 components on low features for CIFAR-10; and 128 components on middle features and 1024 components on low features for CUB-200.

4.2. EXPERIMENTAL SETTINGS

Faithfulness evaluation. We quantitatively evaluate the faithfulness of JGMM and other interpretable proxy models by classification accuracy on the test dataset (note that the ground-truth label for this evaluation is the prediction of the black-box model). As equation 3 shows, JGMM can predict the higher feature latent variable distribution given an observation of the lower feature. The GMM of the higher feature is dependent on JGMM, which would be a biased prediction target, so we only choose the 'high' feature map (associated with the black-box model's classification output) for faithfulness evaluation. By associating each component of the 'high' feature with a dataset class, the JGMM between the 'middle' and the 'high' layers can directly predict the class label with uncertainty. The interpretable proxy models used for comparison are decision tree (DT), logistic regression (LR), k-nearest neighbor (kNN) and Gaussian process classifier (GP).

Counterfactual quality. The quality of counterfactual examples is evaluated by the ratio of prediction change from the query input (higher is better) and the distance between the query input and the generated data (lower is better). To mitigate the conflict between the two metrics, we set different constraints on the $l_2$ distance between the query feature and the generated counterfactuals.

Table 1 presents the proxy model accuracy of JGMM and other interpretable proxy models on various datasets and networks. It shows that JGMM is highly faithful to the black box, which increases the credibility of the explanations. We also find that the accuracy of proxy models decreases when the black-box classifier is tasked with more classes. This is mainly because the black-box model tends to encode more complex deep features, while the fitting capacity of interpretable models is not comparable with that of the deep network.
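The two counterfactual-quality metrics can be sketched as follows; the prediction and feature arrays are toy stand-ins, and the distance budget mirrors the $l_2$ constraint described above:

```python
import numpy as np

def counterfactual_metrics(pred_query, pred_cf, x_query, x_cf, max_dist):
    """Flip ratio (prediction changed within the l2 budget; higher is better)
    and mean distance of valid counterfactuals (lower is better).
    pred_*: class predictions; x_*: feature arrays, one row per example."""
    dist = np.linalg.norm(x_cf - x_query, axis=1)
    valid = dist <= max_dist                  # enforce the distance constraint
    flip = (pred_query != pred_cf) & valid
    mean_dist = dist[valid].mean() if valid.any() else np.nan
    return flip.mean(), mean_dist

# Toy evaluation: three query/counterfactual pairs, budget of 1.0.
pred_q = np.array([0, 0, 1])
pred_c = np.array([1, 0, 1])
xq = np.zeros((3, 2))
xc = np.array([[0.5, 0.0], [10.0, 0.0], [0.1, 0.0]])
flip_ratio, mean_dist = counterfactual_metrics(pred_q, pred_c, xq, xc, max_dist=1.0)
```

Only the first pair both stays within the budget and flips the prediction, so the flip ratio is 1/3 with a mean valid distance of 0.3.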
In table 2, we test the counterfactual generation effectiveness of JGMM and other methods on MobileNet-v2 trained on CUB-200, with the $l_2$ constraint on feature distance varying from 0.1 to 3.0. JGMM's sampling method shows higher counterfactual generation effectiveness than its counterparts, especially in low distance-constraint settings. An advantage of JGMM in counterfactual effectiveness is that we can estimate the distance between the generated example and the query example and reject too-distant examples by equation 7, which significantly reduces the number of distant examples.

Figure 6 presents query images from CIFAR-10 and MNIST as well as some mined counterfactuals and semi-factuals. For example, the first CIFAR-10 query is a side view of a horse, but many deer images have very similar visual features. So most mined counterfactuals are deer viewed from the side, which have a different label prediction than the query.

A. EM ALGORITHM FOR JGMM

The maximum likelihood estimation (MLE) problem of JGMM is the optimization of the log-likelihood of $N$ observations $\{x_1, y_1, \ldots, x_N, y_N\}$:

$$\max_\theta \sum_{n=1}^{N} \ln \left( \sum_{i=1}^{K_x} \sum_{j=1}^{K_y} \mathcal{G}(x_n; \mu_i, \Sigma_i)\, \mathcal{G}(y_n; \phi_j, \Theta_j)\, Q_{i,j}\, \pi_j \right)$$

Let the parameters at the $t$-th iteration be:

$$\theta^{(t)} = \{\mu_i^{(t)}, \Sigma_i^{(t)}, \phi_j^{(t)}, \Theta_j^{(t)}, Q_{i,j}^{(t)}, \pi_j^{(t)} \,|\, 1 \le i \le K_x,\ 1 \le j \le K_y\}$$

The expectation step (E-step) estimates the responsibilities of the latent variables based on $\theta^{(t)}$:

$$\gamma_{n,i,j}^{(t)} = p(z_n = i, w_n = j \,|\, x_n, y_n; \theta^{(t)}) = \frac{\mathcal{G}(x_n; \mu_i^{(t)}, \Sigma_i^{(t)})\, \mathcal{G}(y_n; \phi_j^{(t)}, \Theta_j^{(t)})\, Q_{i,j}^{(t)}\, \pi_j^{(t)}}{\sum_{i'=1}^{K_x} \sum_{j'=1}^{K_y} \mathcal{G}(x_n; \mu_{i'}^{(t)}, \Sigma_{i'}^{(t)})\, \mathcal{G}(y_n; \phi_{j'}^{(t)}, \Theta_{j'}^{(t)})\, Q_{i',j'}^{(t)}\, \pi_{j'}^{(t)}}, \quad g_{n,i}^{(t)} = \sum_{j} \gamma_{n,i,j}^{(t)}, \quad h_{n,j}^{(t)} = \sum_{i} \gamma_{n,i,j}^{(t)}$$

for $1 \le i \le K_x$, $1 \le j \le K_y$. The maximization step (M-step) maximizes the likelihood by updating the estimate $\theta^{(t)}$ to $\theta^{(t+1)}$. The Gaussian parameter updates are:

$$\mu_i^{(t+1)} = \frac{\sum_{n=1}^{N} g_{n,i}^{(t)} x_n}{\sum_{n=1}^{N} g_{n,i}^{(t)}}, \quad \Sigma_i^{(t+1)} = \frac{\sum_{n=1}^{N} g_{n,i}^{(t)} (x_n - \mu_i^{(t+1)})(x_n - \mu_i^{(t+1)})^T}{\sum_{n=1}^{N} g_{n,i}^{(t)}}$$

$$\phi_j^{(t+1)} = \frac{\sum_{n=1}^{N} h_{n,j}^{(t)} y_n}{\sum_{n=1}^{N} h_{n,j}^{(t)}}, \quad \Theta_j^{(t+1)} = \frac{\sum_{n=1}^{N} h_{n,j}^{(t)} (y_n - \phi_j^{(t+1)})(y_n - \phi_j^{(t+1)})^T}{\sum_{n=1}^{N} h_{n,j}^{(t)}}$$

The updates of the latent distribution parameters $Q$ and $\pi$ are:

$$Q_{i,j}^{(t+1)} = \frac{\sum_{n=1}^{N} \gamma_{n,i,j}^{(t)}}{\sum_{n=1}^{N} h_{n,j}^{(t)}}, \quad \pi_j^{(t+1)} = \frac{1}{N} \sum_{n=1}^{N} h_{n,j}^{(t)}$$

This defines the E-step and the M-step of the EM algorithm for JGMM. We first initialize $Q$ and $\pi$ with uniform distributions.
Then, we cluster the data by K-means ($K_x$ clusters for $X$ and $K_y$ clusters for $Y$) and initialize the Gaussian parameters with the cluster sample means and covariances. Finally, the algorithm executes the E-step and the M-step until convergence.
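The E- and M-steps can be condensed into a short sketch. This is an illustrative implementation with diagonal covariances (as used in section 4.1) and random initialization instead of K-means to keep it brief; it is our reconstruction under those assumptions, not the authors' code.

```python
import numpy as np

def diag_gauss_pdf(X, mu, var):
    """Row-wise density of N(mu, diag(var)) for each row of X."""
    norm = np.prod(2 * np.pi * var) ** -0.5
    return norm * np.exp(-0.5 * (((X - mu) ** 2) / var).sum(axis=1))

def jgmm_em(X, Y, Kx, Ky, iters=50, seed=0):
    rng = np.random.default_rng(seed)
    N = len(X)
    mu = X[rng.choice(N, Kx, replace=False)]; Sx = np.ones((Kx, X.shape[1]))
    phi = Y[rng.choice(N, Ky, replace=False)]; Sy = np.ones((Ky, Y.shape[1]))
    Q = np.full((Kx, Ky), 1.0 / Kx); pi = np.full(Ky, 1.0 / Ky)  # uniform init
    for _ in range(iters):
        gx = np.stack([diag_gauss_pdf(X, mu[i], Sx[i]) for i in range(Kx)], 1)
        gy = np.stack([diag_gauss_pdf(Y, phi[j], Sy[j]) for j in range(Ky)], 1)
        # E-step: gamma[n,i,j] proportional to gx[n,i] * gy[n,j] * Q[i,j] * pi[j]
        gamma = gx[:, :, None] * gy[:, None, :] * Q[None] * pi[None, None]
        gamma /= gamma.sum(axis=(1, 2), keepdims=True)
        g = gamma.sum(2); h = gamma.sum(1)     # marginal responsibilities
        # M-step: weighted Gaussian updates, then Q (columns sum to 1) and pi
        mu = (g.T @ X) / g.sum(0)[:, None]
        Sx = (g.T @ X**2) / g.sum(0)[:, None] - mu**2 + 1e-6
        phi = (h.T @ Y) / h.sum(0)[:, None]
        Sy = (h.T @ Y**2) / h.sum(0)[:, None] - phi**2 + 1e-6
        Q = gamma.sum(0) / gamma.sum(0).sum(0, keepdims=True)
        pi = h.mean(0)
    return mu, Sx, phi, Sy, Q, pi

# Toy demo: two aligned clusters in X and Y standing in for deep features.
rng = np.random.default_rng(1)
X = np.vstack([rng.normal(-3, 0.5, (50, 2)), rng.normal(3, 0.5, (50, 2))])
Y = np.vstack([rng.normal(-3, 0.5, (50, 2)), rng.normal(3, 0.5, (50, 2))])
mu, Sx, phi, Sy, Q, pi = jgmm_em(X, Y, Kx=2, Ky=2, iters=30)
```

After training, the columns of $Q$ remain proper conditional distributions and $\pi$ a proper prior, matching the constraints in section 3.1.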



The intermediate representations of DCNN are usually explained globally (not tied to a specific data point) by explanatory examples and associations with semantic concepts. Prototypes, criticisms and influential examples are common types of global explanatory examples. Prototypes are representative examples of a certain pattern of deep features, illustrating a visual concept learned by the model. In contrast, criticisms (proposed by Kim et al. (2016)) are examples not well represented in the deep representations, i.e. outliers of the deep features, revealing flaws in the learned representations. Influential examples are hard examples for model training, having more influence on the final decision boundary than others; from a post-hoc view, they lie close to the decision boundary. An example can be both influential and representative (or unrepresentative). Local explanations are based on a specific query example, showing how changes of the query example's features affect the model prediction. Counterfactual examples offer an actionable recourse for the model decision: they answer 'what if' questions with a minimal change of the query example's features that results in a different model decision, revealing the features the model is sensitive to for the query. Semi-factual examples, in contrast, answer the 'even if' question: they change certain features of the query significantly while keeping the same model prediction, revealing the insensitive features of the query. For example, Kenny & Keane (2021) propose a method to generate counterfactual and semi-factual examples from one system.

Figure 1: Proposed probabilistic framework for a versatile explaining method. The intermediate features (X and Y) are jointly modelled by a probabilistic model, the joint Gaussian mixture model (JGMM).

Figure 2: Graphical model of joint Gaussian mixture model (JGMM) on lower feature X and higher feature Y.

This section introduces how to mine or generate explanatory examples (prototypes, criticisms, influential examples, counterfactuals and semi-factuals) with a trained JGMM. The example mining problem is to search the dataset for the best explanatory examples, which requires a scoring function $f(x)$ to evaluate their quality. The generation problem is to sample data that maximize the scoring function. Below, we define the scoring function for each type of explanatory example and describe how to mine or generate the examples.

Figure 3: Extension of JGMM to relative spatial dimension R.

Networks and data. The experiments are performed on VGG-16/19 (Simonyan & Zisserman (2015)), ResNet-18/50 (He et al. (2016)) and MobileNet-V1/V2 (Sandler et al. (2018)), respectively trained on MNIST (LeCun et al. (1998)), CIFAR-10 (Krizhevsky (2009)) and CUB-200 (Wah et al. (2011)).

Figure 4: Global explaining for VGG-19 trained on MNIST. (a) visualizes the posterior distribution of the latent variable, which interprets how JGMM makes predictions. (b) presents global explanatory examples: prototypes (with their component indices), criticisms and influentials (with their two most likely labels).

Figure 5: Global explaining for ResNet-18 trained on CUB-200.

Figure 5 presents global explanatory examples mined from ResNet-18 trained on CUB-200 for the bird species classification task. The 'middle' feature map encodes different visual concepts into the deep representations, as revealed by the prototypes, such as z = 7 'long pointed beak', z = 16 'yellow body part', z = 26 'branch texture' and z = 82 'red body part'. One influential example x has close probability on w = 5 and w = 48: the bird has red feathers at the top of the head and a pointed beak, but the shapes of the red feathers and the beak are typical of neither w = 5 nor w = 48, which places it near the decision boundary. Criticisms of CUB-200 are often rarely observed sub-species, or images whose features are not sufficient for discrimination.

Figure 6: Mined counterfactuals and semi-factuals for queries from CIFAR-10 and MNIST.


Data splitting. To fairly test the faithfulness of JGMM, we only use the training data of the black-box model to train JGMM and use the test data to evaluate JGMM's accuracy in predicting the black-box model's decisions. The effectiveness test of counter-/semi-factual example generation is also performed on the test data, but the explaining step (the generation of all kinds of explanatory examples) makes use of all data.

Table 1: Proxy model accuracy (%) evaluation on 'middle' and 'high' layer features with various benchmarks and networks.

Figure 4 presents the estimated posterior $\hat{Q}$ and the derived $\hat\Psi$ of JGMM, together with global explanatory examples mined from the test dataset. The visualization of $\hat{Q}$ suggests that each latent component $z$ of the lower feature is mainly mapped to a single higher feature component (though noisily). The visualization of $\hat\Psi$ illustrates that each higher feature component (linked to a digit class label) has multiple patterns of examples. From $\hat\Psi$, we find that components $z \in \{2, 18, 19\}$ are learned patterns for digit "4", and $z \in \{4, 14\}$ are learned patterns for digit "1". The prototypes of those components are visualized; each latent component $z$ encodes a certain style of writing. The mined criticisms in figure 4 can hardly be represented by any latent component, which is intuitively clear from the examples and is technically evaluated by equation 5. The influentials are presented with their two most likely $p(w=j|x)$ from equation 6. For example, the first influential example could be classified as '7' or '9' with similar posterior probability of $w$.

In this paper, the joint Gaussian mixture model is proposed to explain the deep features and model inference of black-box DCNN. As a probabilistic model that links two deep features, JGMM enables consistent and diverse explanation forms, including proxy models and example-based explaining. Experiments show that JGMM can faithfully interpret model inference and efficiently mine or generate explanatory examples. The computational efficiency of JGMM is worth improving, especially for large datasets and high feature dimensionality, where training with the EM algorithm is still very time- and memory-consuming.


