VARIATIONAL INFORMATION PURSUIT FOR INTERPRETABLE PREDICTIONS

Abstract

There is growing interest in the machine learning community in developing predictive algorithms that are interpretable by design. To this end, recent work proposes to sequentially ask interpretable queries about data until a high-confidence prediction can be made based on the answers obtained (the history). To promote short query-answer chains, a greedy procedure called Information Pursuit (IP) is used, which adaptively chooses queries in order of information gain. Generative models are employed to learn the distribution of query-answers and labels, which is in turn used to estimate the most informative query. However, learning and inference with a full generative model of the data is often intractable for complex tasks. In this work, we propose Variational Information Pursuit (V-IP), a variational characterization of IP which bypasses the need to learn generative models. V-IP is based on finding a query-selection strategy and a classifier that minimize the expected cross-entropy between true and predicted labels. We prove that the IP strategy is the optimal solution to this problem. Therefore, instead of learning generative models, we can use our optimal strategy to directly pick the most informative query given any history. We then develop a practical algorithm by defining a finite-dimensional parameterization of our strategy and classifier using deep networks and training them end-to-end with our objective. Empirically, V-IP is 10-100x faster than IP on various vision and NLP tasks with competitive performance. Moreover, V-IP finds much shorter query chains than reinforcement learning, which is typically used in sequential decision-making problems. Finally, we demonstrate the utility of V-IP on challenging tasks like medical diagnosis, where its performance is far superior to that of the generative modeling approach.

1. INTRODUCTION

Suppose a doctor diagnoses a patient with a particular disease. One would want to know not only the diagnosis but also an evidential explanation for it in terms of clinical test results, physiological data, or symptoms experienced by the patient. For practical applications, machine learning methods must therefore be evaluated not only on metrics such as generalization and scalability but also on criteria such as interpretability and transparency. As deep learning methods have displaced traditionally interpretable methods such as decision trees and logistic regression, the ability to perform complex tasks such as large-scale image classification has often come at the cost of interpretability. Yet interpretability is important for unveiling potential biases for users with different backgrounds (Yu, 2018) and for gaining users' trust. Most prominent work in machine learning that addresses interpretability is based on post hoc analysis of a trained deep network's decisions (Simonyan et al., 2013; Ribeiro et al., 2016; Shrikumar et al., 2017; Zeiler & Fergus, 2014; Selvaraju et al., 2017; Smilkov et al., 2017; Chattopadhyay et al., 2019; Lundberg & Lee, 2017). These methods typically assign importance scores to the features used in a model's decision by measuring the sensitivity of the model's output to those features. However, explanations in terms of importance scores of raw features are not always as desirable as a description of the reasoning process behind a model's decision. Moreover, there are rarely any guarantees that these post hoc explanations faithfully represent the model's decision-making process (Koh et al., 2020). Consequently, post hoc interpretability has been widely criticized (Adebayo et al., 2018; Kindermans et al., 2019; Rudin, 2019; Slack et al., 2020; Shah et al., 2021; Yang & Kim, 2019), and there is a need to shift towards ML algorithms that are interpretable by design.
An interesting framework for making interpretable predictions was recently introduced by Chattopadhyay et al. (2022). The authors propose the concept of an interpretable query set Q, a set of user-defined and task-specific functions q : X → A, which map a data point in X to an answer in A, each having a clear interpretation to the end user. For instance, a plausible query set for identifying bird species might query beak shape, head colour, and other visual attributes of birds. Given a query set, their method sequentially asks queries about X until the answers obtained are sufficient for predicting the label/hypothesis Y with high confidence. Notably, since the final prediction is solely a function of this sequence of query-answer pairs, the pairs provide a complete explanation for the prediction. Figure 1 illustrates the framework on a bird classification task. To obtain short explanations (short query-answer chains), the authors use a greedy procedure called Information Pursuit (IP), first introduced in Geman & Jedynak (1996). Given any input x_obs, IP sequentially chooses the query that has the largest mutual information with the label/hypothesis Y given the history of query-answers obtained so far. To compute this mutual information criterion, a generative model is first trained to learn the joint distribution of all query-answers q(X) and Y; in particular, Variational Autoencoders (VAEs) (Kingma & Welling, 2013) are employed. The learnt VAE is then used to construct Monte Carlo estimates of mutual information via MCMC sampling. Unfortunately, the computational cost of MCMC sampling, coupled with the challenge of learning accurate generative models that enable fast inference, limits the application of this framework to simple tasks. As an example, classifying MNIST digits using 3 × 3 overlapping patches as queries¹ with this approach would take weeks!
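To make the greedy selection rule concrete, the following toy sketch implements the IP loop with mutual information estimated empirically from samples consistent with the current history. The setup, variable names, and the plug-in empirical estimator are our illustrative assumptions, not the framework's actual machinery, which estimates mutual information with a learned VAE and MCMC sampling and stops once the prediction is sufficiently confident rather than after a fixed number of steps.

```python
import math
from collections import Counter

def mutual_information(pairs):
    """Plug-in estimate of I(A; Y) from a list of (answer, label) pairs."""
    n = len(pairs)
    p_ay = Counter(pairs)
    p_a = Counter(a for a, _ in pairs)
    p_y = Counter(y for _, y in pairs)
    # I(A; Y) = sum_{a,y} p(a,y) log [ p(a,y) / (p(a) p(y)) ], in nats.
    return sum((c / n) * math.log(c * n / (p_a[a] * p_y[y]))
               for (a, y), c in p_ay.items())

def information_pursuit(x_obs, data, queries, max_steps=3):
    """Greedy IP: at each step, ask the unasked query with the largest
    mutual information with Y among samples consistent with the history."""
    history = []          # list of (query index, answer) pairs
    consistent = data     # samples whose query-answers match the history
    for _ in range(max_steps):
        asked = {qi for qi, _ in history}
        remaining = [i for i in range(len(queries)) if i not in asked]
        if not remaining:
            break
        # argmax_q I(q(X); Y | history), with the conditioning on the
        # history realized by restricting to consistent samples.
        best = max(remaining, key=lambda i: mutual_information(
            [(queries[i](x), y) for x, y in consistent]))
        ans = queries[best](x_obs)
        history.append((best, ans))
        consistent = [(x, y) for x, y in consistent
                      if queries[best](x) == ans]
    return history

# Toy example: Y equals the first bit of X, so the first bit is the
# informative query and IP should ask it first.
data = [((0, 0), 0), ((0, 1), 0), ((1, 0), 1), ((1, 1), 1)]
queries = [lambda x: x[0], lambda x: x[1]]
history = information_pursuit((1, 0), data, queries, max_steps=2)
print(history)
```

In this toy run, query 0 carries one full bit of information about Y while query 1 carries none, so IP asks query 0 first; the returned history is the complete explanation for the eventual prediction.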
In this paper, we question the need to learn a full generative model of all query-answers q(X) and Y, given that at each iteration IP only needs the most informative query given the history. More specifically, we present a variational characterization of IP based on the observation that, given any history, the query q* whose answer minimizes the KL-divergence between the label distribution P(Y | X) and the posterior P(Y | q*(X), history) is exactly the most informative query required by IP. As a result, we propose to minimize this KL-divergence in expectation (over randomly sampled histories) by optimizing over querier functions, parameterized by deep networks, which pick a query from Q given the history. The optimal querier then learns to directly pick the most informative query given any history, bypassing the need to explicitly compute mutual information using generative models. Through extensive experiments, we show that the proposed method is not only faster (since MCMC sampling is no longer needed at inference time), but also achieves competitive performance compared with the generative modeling approach and outperforms other state-of-the-art sequential decision-making methods.
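Schematically, and in our own notation (the formal statement in the method section may differ in details), the objective just described can be written as a joint minimization over a querier network g and a classifier network f:

```latex
\min_{g,\,f}\;\;
\mathbb{E}_{X,\,S}\!\left[
  D_{\mathrm{KL}}\!\left(
    P(Y \mid X)\;\middle\|\;
    \hat{P}_f\!\left(Y \mid q_{g(S)}(X),\, S\right)
  \right)
\right]
```

Here S denotes a randomly sampled history of query-answer pairs, q_{g(S)} ∈ Q is the query picked by the querier given S, and P̂_f is the classifier's posterior over labels. Since P(Y | X) does not depend on g or f, minimizing this KL term is equivalent to minimizing the expected cross-entropy between true and predicted labels, which is how the objective is stated in the abstract.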

Paper Contributions.

(1) We present a variational characterization of IP, termed Variational Information Pursuit (V-IP), and show that the solution to the V-IP objective is exactly the IP strategy. (2) We present a practical algorithm for optimizing this objective using deep networks. (3) Empirically, we show that V-IP achieves performance competitive with the generative modeling approach on various computer vision and NLP tasks with much faster inference. (4) Finally, we compare our approach to Reinforcement Learning (RL) approaches used in sequential decision-making areas like Hard Attention (Mnih et al., 2014) and Symptom Checking (Peng et al., 2018), where the objective is to learn a policy that adaptively chooses a fixed number of queries, one at a time, such that an accurate prediction can be made. In all experiments, V-IP is superior to these RL methods.
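As a rough illustration of contribution (2), the following PyTorch sketch trains a querier and a classifier end-to-end on synthetic query-answers. All architectures, the random-history sampling, and the straight-through trick for differentiating through the discrete query choice are our illustrative assumptions, not the paper's exact parameterization.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
N_QUERIES, N_CLASSES, HID, BATCH = 10, 3, 64, 32

class Querier(nn.Module):
    """Picks the next query given the history (answers revealed so far)."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2 * N_QUERIES, HID), nn.ReLU(),
                                 nn.Linear(HID, N_QUERIES))

    def forward(self, answers, mask):
        # History = masked answers plus the mask saying which were asked.
        logits = self.net(torch.cat([answers * mask, mask], dim=-1))
        logits = logits.masked_fill(mask.bool(), -1e9)  # no repeats
        soft = F.softmax(logits, dim=-1)
        hard = F.one_hot(soft.argmax(-1), N_QUERIES).float()
        # Straight-through: hard one-hot forward, soft gradient backward.
        return hard + soft - soft.detach()

class Classifier(nn.Module):
    """Predicts Y from the partially observed query-answers."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2 * N_QUERIES, HID), nn.ReLU(),
                                 nn.Linear(HID, N_CLASSES))

    def forward(self, answers, mask):
        return self.net(torch.cat([answers * mask, mask], dim=-1))

querier, classifier = Querier(), Classifier()
opt = torch.optim.Adam(
    list(querier.parameters()) + list(classifier.parameters()), lr=1e-3)

# Synthetic stand-ins for precomputed query-answers and labels.
answers = torch.randn(BATCH, N_QUERIES)
labels = torch.randint(0, N_CLASSES, (BATCH,))

# One training step: sample a random history, let the querier append one
# more query-answer, then score the classifier's prediction of Y.
mask = (torch.rand(BATCH, N_QUERIES) < 0.3).float()
choice = querier(answers, mask)
new_mask = torch.clamp(mask + choice, max=1.0)
loss = F.cross_entropy(classifier(answers, new_mask), labels)
opt.zero_grad()
loss.backward()
opt.step()
print(float(loss))
```

The cross-entropy loss here is a surrogate for the expected KL objective (the two coincide up to a term independent of the networks), and the random masks implement the expectation over histories described above.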



¹ Each patch query asks about the pixel intensities observed in that patch for x_obs.



Figure 1: Illustration of the framework of Chattopadhyay et al. (2022) on a bird classification task. The query set consists of questions about the presence or absence of different visual attributes of birds. Given an image x_obs, a sequence of interpretable queries is asked about the image until a prediction can be made with high confidence. The choice of each query depends on the query-answers observed so far.

