VARIATIONAL INFORMATION PURSUIT FOR INTERPRETABLE PREDICTIONS

Abstract

There is a growing interest in the machine learning community in developing predictive algorithms that are interpretable by design. To this end, recent work proposes to sequentially ask interpretable queries about data until a high-confidence prediction can be made based on the answers obtained (the history). To promote short query-answer chains, a greedy procedure called Information Pursuit (IP) is used, which adaptively chooses queries in order of information gain. Generative models are employed to learn the distribution of query-answers and labels, which is in turn used to estimate the most informative query. However, learning and inference with a full generative model of the data are often intractable for complex tasks. In this work, we propose Variational Information Pursuit (V-IP), a variational characterization of IP which bypasses the need to learn generative models. V-IP is based on finding a query selection strategy and a classifier that minimize the expected cross-entropy between true and predicted labels. We prove that the IP strategy is the optimal solution to this problem. Therefore, instead of learning generative models, we can use our optimal strategy to directly pick the most informative query given any history. We then develop a practical algorithm by defining a finite-dimensional parameterization of our strategy and classifier using deep networks and training them end-to-end with our objective. Empirically, V-IP is 10-100x faster than IP on different vision and NLP tasks with competitive performance. Moreover, V-IP finds much shorter query chains than reinforcement learning, which is typically used in sequential decision-making problems. Finally, we demonstrate the utility of V-IP on challenging tasks like medical diagnosis, where its performance is far superior to the generative modeling approach.
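To make the training objective concrete, below is a minimal PyTorch-style sketch of one V-IP update, not the authors' released implementation: the Querier, Classifier, and vip_loss names, the binary answer encoding, and the straight-through selection mechanism are illustrative assumptions. The sketch samples a random history of answered queries, lets the querier propose the next query, reveals its answer, and minimizes the cross-entropy of the classifier's prediction.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Querier(nn.Module):
    """Maps a masked history (answers so far) to scores over the Q possible queries."""
    def __init__(self, num_queries, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2 * num_queries, hidden), nn.ReLU(),
            nn.Linear(hidden, num_queries))

    def forward(self, answers, mask):
        # History is encoded as (masked answers, mask) concatenated.
        return self.net(torch.cat([answers * mask, mask], dim=-1))

class Classifier(nn.Module):
    """Predicts class logits from the same masked-history encoding."""
    def __init__(self, num_queries, num_classes, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2 * num_queries, hidden), nn.ReLU(),
            nn.Linear(hidden, num_classes))

    def forward(self, answers, mask):
        return self.net(torch.cat([answers * mask, mask], dim=-1))

def vip_loss(querier, classifier, answers, labels, tau=1.0):
    """One V-IP step: answers is (B, Q) with entries in {-1, +1}, labels is (B,)."""
    B, Q = answers.shape

    # Sample a random history: each example reveals a random subset of query answers.
    keep_prob = torch.rand(B, 1, device=answers.device)
    mask = (torch.rand(B, Q, device=answers.device) < keep_prob).float()

    # Querier proposes the next query; a straight-through softmax keeps selection differentiable.
    scores = querier(answers, mask)
    soft = F.softmax(scores / tau, dim=-1)
    hard = F.one_hot(soft.argmax(dim=-1), num_classes=Q).float()
    select = hard + soft - soft.detach()

    # Reveal the selected answer and classify from the updated history.
    new_mask = torch.clamp(mask + select, max=1.0)
    logits = classifier(answers, new_mask)

    # Expected cross-entropy between true and predicted labels, the V-IP objective.
    return F.cross_entropy(logits, labels)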

1. INTRODUCTION

Suppose a doctor diagnoses a patient with a particular disease. One would want to know not only the disease but also an evidential explanation of the diagnosis in terms of clinical test results, physiological data, or symptoms experienced by the patient. For practical applications, machine learning methods require an emphasis not only on metrics such as generalization and scalability but also on criteria such as interpretability and transparency. With the shift from traditionally interpretable methods such as decision trees or logistic regression to deep learning, the ability to perform complex tasks such as large-scale image classification now often comes at the cost of interpretability. However, interpretability is important for unveiling potential biases for users with different backgrounds (Yu, 2018) and for gaining users' trust. Most of the prominent work in machine learning that addresses this question of interpretability is based on post hoc analysis of a trained deep network's decisions (Simonyan et al., 2013; Ribeiro et al., 2016; Shrikumar et al., 2017; Zeiler & Fergus, 2014; Selvaraju et al., 2017; Smilkov et al., 2017; Chattopadhyay et al., 2019; Lundberg & Lee, 2017). These methods typically assign importance scores to the features used in a model's decision by measuring the sensitivity of the model's output to those features. However, explanations in terms of importance scores of raw features might not always be as desirable as a description of the reasoning process behind a model's decision. Moreover, there are rarely any guarantees that these post hoc explanations faithfully represent the model's decision-making process (Koh et al., 2020). Consequently, post hoc interpretability has been widely criticized (Adebayo et al., 2018; Kindermans et al., 2019; Rudin,

