NEUROCHAINS: EXTRACT LOCAL REASONING CHAINS OF DEEP NEURAL NETS

Abstract

We study how to explain the main steps/chains of inference that a deep neural net (DNN) relies on to produce predictions in a local region of data space. This problem is related to network pruning and interpretable machine learning, but differs in three key respects: (1) fine-tuning of neurons/filters is forbidden: only exact copies are allowed; (2) we target an extremely high pruning rate, e.g., ≥ 95%; (3) the interpretation covers the whole inference process in a local region rather than individual neurons/filters or a single sample. In this paper, we introduce an efficient method, NeuroChains, to extract the local inference chains by optimizing a differentiable sparse scoring of the filters and layers so as to preserve the outputs on given data from a local region. Thereby, NeuroChains can extract an extremely small sub-network composed of filters exactly copied from the original DNN by removing the filters/layers with small scores. We then visualize the sub-network by applying existing interpretation techniques to the retained layers/filters/neurons on any sample from the local region. Its architecture reveals how the inference process stitches and integrates information layer by layer and filter by filter. We provide detailed and insightful case studies together with three quantitative analyses over thousands of trials to demonstrate the quality, sparsity, fidelity, and accuracy of the interpretation within the assigned local regions and over unseen data. In our empirical study, NeuroChains significantly enriches the interpretation and makes the inner mechanism of DNNs more transparent than before.

1. INTRODUCTION

Deep neural networks (DNNs) have greatly reshaped a variety of tasks: object classification, semantic segmentation, natural language processing, speech recognition, robotics, etc. Despite their success on the vast majority of clean data, DNNs are also well known to be sensitive to small adversarial perturbations. The lack of sufficient interpretability about their success or failure is one major bottleneck of applying DNNs to important areas such as medical diagnosis, public health, transportation systems, financial analysis, etc. Interpretable machine learning has attracted growing interest in a variety of areas. The forms of interpretation vary across different methods. For example, attribution methods (Bach et al., 2015; Sundararajan et al., 2017; Shrikumar et al., 2017; Montavon et al., 2017; Kindermans et al., 2017; Smilkov et al., 2017) produce an importance score of each input feature for the output prediction on a given sample, while some other methods (Zeiler & Fergus, 2014; Simonyan et al., 2013; Erhan et al., 2009) aim to explain the general functionality of each neuron/filter or an individual layer regardless of the input sample. Another line of work (Ribeiro et al., 2016; Wu et al., 2018; Hou & Zhou, 2018) explains DNNs in a local region of data space by training a shallow (e.g., linear) and easily interpretable model to approximate the original DNN on locally similar samples, thereby reducing the problem to explaining the shallow model. These methods essentially reveal neuron-to-neuron correlations (e.g., input to output, intermediate layer/neuron to output, etc.), but they cannot provide an overview of the whole inference process occurring inside the complicated structure of DNNs. In this paper, we study a more challenging problem: Can we unveil the major hidden steps of inference in DNNs and present them in a succinct and human-readable form?
Solving this problem helps to answer many significant questions, e.g., which layer(s)/neuron(s) plays the most/least important role in the inference process? Do two similar samples really share most inference steps? Do all samples need the same number of neurons/layers to locate the key information leading to their correct predictions? How/when/where does a failure happen during the inference on a DNN? Which neuron(s)/layer(s)/feature(s) are shared by different samples from the same region even when their labels differ? Are DNNs using entirely different parts for data from different local regions? Some of these questions are related to other problems such as network pruning (Han et al., 2015; Li et al., 2016). We develop an efficient tool called NeuroChains to extract the underlying inference chains of a DNN for a given local data region. Specifically, we aim to extract a much smaller sub-network composed of a subset of neurons/filters exactly copied from the original DNN, whose output for data from the local region stays consistent with that of the original DNN. In experiments, we assume the data from the same classes reside in a local region. While the selected filters explain the key information captured by the original DNN when applied to data from the local region, the architecture of the sub-network stitches this information together sequentially, i.e., step by step and layer by layer, and thus recovers the major steps of inference that lead to the final outputs. Despite its combinatorial nature, we parameterize the sub-network as the original DNN with an additional score multiplied to each filter/layer's output featuremap. Thereby, we formulate the above problem of sub-network extraction as optimizing a differentiable sparse scoring of all the filters and layers in order to preserve the outputs on all given samples. This problem can be solved by efficient back-propagation that only updates the scores while keeping the filter parameters fixed.
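The score parameterization described above can be sketched as follows: each filter's output featuremap is multiplied by a learnable score, the filter weights themselves are frozen, and back-propagation therefore updates only the scores. This is a minimal PyTorch illustration of the idea; the class name `ScoredConv` and all implementation details are our own, not the authors' code.

```python
import torch
import torch.nn as nn

class ScoredConv(nn.Module):
    """A frozen conv layer whose k-th output featuremap is scaled by a
    learnable score s_k; back-propagation updates only the scores."""
    def __init__(self, conv: nn.Conv2d):
        super().__init__()
        self.conv = conv
        for p in self.conv.parameters():
            p.requires_grad_(False)  # filters are exact copies, never fine-tuned
        # one score per output filter, initialized to 1 (identity behavior)
        self.scores = nn.Parameter(torch.ones(conv.out_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # broadcast scores over the batch and spatial dimensions
        return self.conv(x) * self.scores.view(1, -1, 1, 1)

layer = ScoredConv(nn.Conv2d(3, 8, kernel_size=3))
out = layer(torch.randn(2, 3, 16, 16))  # shape: (2, 8, 14, 14)
```

After the scores are optimized under a sparsity penalty, filters whose scores fall below a threshold can simply be removed, since a zero score makes the corresponding featuremap vanish identically.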
The objective is built upon the Kullback-Leibler (KL) divergence between the sub-network's output distribution and that of the original DNN, along with an ℓ1 regularization for sparse scores over filters. We further use a sigmoid gate per layer to decide whether to remove the entire layer. The gate plays an important role in reducing the sub-network size since most local regions do not rely on all the layers. In practice, we further apply thresholding to the sparse scores to obtain an even smaller sub-network and additionally fine-tune the filter scores on the sub-network. We illustrate the sub-network's architecture and visualize its filters and intermediate-layer featuremaps by existing methods (Mundhenk et al., 2019; Erhan et al., 2009). NeuroChains is a novel pruning technique specifically designed for interpreting the local inference chains of DNNs. As aforementioned, it potentially provides an efficient tool to study other problems in related tasks. However, it has several fundamental differences from network pruning and exist-
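The KL-plus-sparsity objective and the per-layer sigmoid gate can be sketched as below. This is a simplified illustration under our own assumptions: logit-valued outputs, a sparsity weight `lam` we introduce for exposition, and a gate that bypasses a layer via its input (which presumes matching shapes); the authors' exact reductions and hyperparameters may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def neurochains_loss(sub_logits, orig_logits, filter_scores, lam=1e-3):
    """KL divergence between the original DNN's output distribution and
    the sub-network's, plus an l1 penalty driving filter scores to zero."""
    kl = F.kl_div(F.log_softmax(sub_logits, dim=1),
                  F.softmax(orig_logits, dim=1),
                  reduction='batchmean')
    l1 = sum(s.abs().sum() for s in filter_scores)
    return kl + lam * l1

class GatedLayer(nn.Module):
    """A sigmoid gate deciding whether a whole layer is kept; as the gate
    closes, the input increasingly bypasses the layer."""
    def __init__(self, layer: nn.Module):
        super().__init__()
        self.layer = layer
        self.gate_logit = nn.Parameter(torch.zeros(1))  # sigmoid(0) = 0.5

    def forward(self, x):
        g = torch.sigmoid(self.gate_logit)
        return g * self.layer(x) + (1.0 - g) * x
```

When the sub-network's logits match the original's, the KL term vanishes and only the ℓ1 penalty remains, so gradient descent on the scores trades output fidelity against sparsity.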



Figure 1: Inference chain by NeuroChains for ResNet-50 (pre-trained on ImageNet) when applied to 20 test images of "dalmatian" and "strawberry". Top: The sub-network retains only 13/67 layers and 75/26560 filters of the ResNet-50. The scores for selected filters are represented using the colormap on the top right. Middle: The per-layer featuremaps generated by SMOE (Mundhenk et al., 2019) show a clear trend of first extracting local patterns (dots on the dalmatian and strawberry) and gradually covering the global shapes of the classes. Bottom: Filters with the largest scores are visualized using the method by Erhan et al. (2009). In shallower layers, L4B1C2 48 and L4B1C3 524 capture a more local black-spot pattern of the dalmatian, L3B1C3 683 captures the eye and nose patterns, and L3B1C3 818 extracts the local color pattern of the strawberry. In the last bottleneck layer, L4B3SC 769 and L4B3SC 1457 capture the global patterns of the dalmatian's black and white fur, while L4B3C3 456 and L4B3SC 511 capture the main shape and color of the strawberry. This yields an inference chain for strawberry: L1B1SC 177 → L2B1SC 75 and L2B1SC 342 → L3B1C3 818 → L4B1C2 328 → L4B3C3 456 and L4B3SC 511.

