COUNTERNET: END-TO-END TRAINING OF PREDICTION-AWARE COUNTERFACTUAL EXPLANATIONS

Abstract

Counterfactual (CF) explanations are a class of local explanations for Machine Learning (ML) model predictions: they explain a prediction by finding the smallest change (in feature space) to the input data point that leads the ML model to a different prediction. Existing CF explanation techniques suffer from two major limitations: (i) all of them are post-hoc methods designed for use with proprietary ML models; as a result, their procedure for generating CF explanations is uninformed by the training of the ML model, which leads to misalignment between model predictions and explanations; and (ii) most of them solve a separate, time-intensive optimization problem to find a CF explanation for each input data point, which severely degrades their runtime. This work makes a novel departure from the prevalent post-hoc paradigm by presenting CounterNet, an end-to-end learning framework which integrates predictive model training and the generation of counterfactual (CF) explanations into a single pipeline. We adopt a block-wise coordinate descent procedure which helps in effectively training CounterNet's network. Our extensive experiments on multiple real-world datasets show that CounterNet generates high-quality predictions, consistently achieves 100% CF validity and low proximity scores (thereby achieving a well-balanced cost-invalidity trade-off) for any new input instance, and runs 3X faster than existing state-of-the-art baselines.

1. INTRODUCTION

A counterfactual (CF) explanation offers a contrastive case: to explain the prediction made by a Machine Learning (ML) model on a data point x, CF explanation methods find a new counterfactual example x′, which is similar to x but receives a different (or opposite) prediction from the ML model. From an end-user perspective, CF explanation methods¹ (Wachter et al., 2017) may be preferable to other methods of explaining ML models, as they can be used to offer recourse to vulnerable groups. For example, if a person applies for a loan and gets rejected by a bank's ML algorithm, CF explanation methods can suggest corrective measures to the loan applicant, which can be incorporated into a future loan application to improve their chances of approval.

¹ CF explanations are closely related to algorithmic recourse (Ustun et al., 2019) and contrastive explanations (Dhurandhar et al., 2018). Although these terms were proposed in different contexts, their differences from CF explanations have become blurred (Verma et al., 2020; Stepin et al., 2021), i.e., the terms are used interchangeably.

Generating high-quality CF explanations is challenging because of the need to balance the cost-invalidity trade-off (Rawal et al., 2020) between: (i) invalidity, i.e., the probability that a CF example is invalid, meaning that it does not achieve the desired (or opposite) prediction from the ML model; and (ii) cost of change, i.e., the L1 norm distance between the input instance x and the CF example x′. Figure 1 illustrates this trade-off with three different CF examples for an input instance x. If invalidity is ignored (and only cost of change is optimized), the generated CF example can trivially be set to x itself. Conversely, if cost of change is ignored (and only invalidity is optimized), the generated CF example can be set to x′2 (or any sufficiently distant instance with a different label). More generally, CF examples with high (low) invalidity usually imply low (high) cost of change. To optimally balance this trade-off, it is critical for CF explanation methods to have access to the decision boundary of the ML model, without which finding a near-optimal CF explanation (i.e., x′1) is difficult. For example, it is difficult to distinguish between x′1 (a valid CF example) and x′0 (an invalid CF example) without prior knowledge of the decision boundary.

Existing CF explanation methods suffer from three major limitations. First, to the best of our knowledge, all prior methods belong to the post-hoc explanation paradigm, i.e., they assume a trained black-box ML model as input. This post-hoc assumption has certain advantages, e.g., post-hoc explanation techniques are often agnostic to the particulars of the ML model, and hence are generalizable enough to interpret any third-party proprietary ML model. However, we argue that in many real-world scenarios, the model-agnostic approach provided by post-hoc CF explanation methods is not desirable. With the advent of data regulations that enshrine the "Right to Explanation" (e.g., EU-GDPR (Wachter et al., 2017)), service providers are required by law to communicate both the decision outcome (i.e., the ML model's prediction) and its actionable implications (i.e., a CF explanation for this prediction) to an end-user. In these scenarios, the post-hoc assumption is overly limiting, as service providers can build specialized CF explanation techniques that leverage knowledge of their particular ML model to generate higher-quality CF explanations. Second, in the post-hoc CF explanation paradigm, the optimization procedure that finds CF explanations is completely uninformed by the ML model's training procedure (and the resulting decision boundary). Consequently, such a post-hoc procedure does not properly balance the cost-invalidity trade-off (as explained above), causing shortcomings in the quality of the generated CF explanations (as shown in Section 4). Finally, most CF explanation methods are very slow: they search for CF examples by solving a separate time-intensive optimization problem for each input instance (Wachter et al., 2017; Mothilal et al., 2020; Karimi et al., 2021), which is not viable in time-constrained environments, as runtime is a critical factor when deployed to end-user facing devices (Zhao et al., 2018; Arapakis et al., 2021).

Contributions. We make a novel departure from the prevalent post-hoc paradigm of generating CF explanations by proposing CounterNet, a learning framework that combines the training of the ML model and the generation of corresponding CF explanations into a single end-to-end pipeline (i.e., from input to prediction to explanation). CounterNet makes three contributions:

• Unlike post-hoc approaches (where CF explanations are generated after the ML model is trained), CounterNet uses a (neural network) model-based CF generation method, enabling the joint training of its CF generator network and its predictor network. At a high level, CounterNet's CF generator network takes as input the learned representations from its predictor network, which is trained jointly with the CF generator. This joint training is key to achieving a well-balanced cost-invalidity trade-off (as we show in Section 4).
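For reference, the per-instance optimization that post-hoc methods rely on can be sketched as follows. This is an illustrative, Wachter-style search (not any specific baseline's implementation): the toy sigmoid model, the loss weight `lam`, and the numerical-gradient loop are all assumptions made for the sketch.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def posthoc_cf(predict_proba, x, y_target=1.0, lam=0.1, lr=0.05, steps=500):
    """Search for a CF example x' by minimizing
    (f(x') - y_target)^2 + lam * ||x' - x||_1 with numerical gradient descent.
    Note: this optimization is re-run from scratch for every input instance,
    which is the runtime bottleneck discussed above."""
    obj = lambda z: (predict_proba(z) - y_target) ** 2 + lam * np.abs(z - x).sum()
    x_cf = x.astype(float).copy()
    eps = 1e-4
    for _ in range(steps):
        grad = np.zeros_like(x_cf)
        for i in range(x_cf.size):
            e = np.zeros_like(x_cf)
            e[i] = eps
            # central-difference estimate of the objective's gradient
            grad[i] = (obj(x_cf + e) - obj(x_cf - e)) / (2 * eps)
        x_cf = x_cf - lr * grad
    return x_cf

# toy black-box classifier with decision boundary x1 + x2 = 1
w, b = np.array([4.0, 4.0]), -4.0
predict_proba = lambda z: sigmoid(w @ z + b)

x = np.array([0.2, 0.3])             # predicted class 0 (p < 0.5)
x_cf = posthoc_cf(predict_proba, x)  # nearby point predicted class 1
```

Even on this two-dimensional toy problem, hundreds of objective evaluations are spent on a single instance, which illustrates why per-instance search scales poorly on end-user facing devices.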

2. RELATED WORK

Prior explanation techniques for ML models include LIME (Ribeiro et al., 2016), SHAP (Lundberg & Lee, 2017), and saliency maps (Selvaraju et al., 2017; Sundararajan et al., 2017; Smilkov et al., 2017), which highlight feature-attribution importance for each data instance. Further, case-based methods provide



Figure 1: Illustration of the cost-invalidity trade-off in CF explanations for binary classification problems.
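The two axes of the trade-off in Figure 1 correspond to simple, commonly used metrics. A minimal sketch, in which the toy linear classifier is a hypothetical stand-in for any ML model:

```python
import numpy as np

def validity(predict, x_cf, y_target):
    """Fraction of CF examples that obtain the desired (flipped) prediction."""
    return float(np.mean(predict(x_cf) == y_target))

def proximity(x, x_cf):
    """Mean L1 distance between inputs and their CF examples (cost of change)."""
    return float(np.mean(np.abs(x - x_cf).sum(axis=1)))

# toy classifier with decision boundary x1 + x2 = 1
predict = lambda z: (z.sum(axis=1) > 1.0).astype(int)

x = np.array([[0.2, 0.3]])     # predicted class 0
x_cf = np.array([[0.5, 0.6]])  # candidate CF example, predicted class 1
print(validity(predict, x_cf, np.array([1])))  # 1.0
print(proximity(x, x_cf))                      # ≈ 0.6
```

A degenerate CF example x′ = x scores perfectly on proximity (0.0) but fails validity, while a far-away instance of the opposite class scores perfectly on validity but poorly on proximity, which is exactly the tension the figure depicts.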

• We theoretically analyze CounterNet's objective function to expose two key challenges in training CounterNet: (i) poor convergence of learning; and (ii) a lack of robustness against adversarial examples. To remedy these issues, we propose a novel block-wise coordinate descent procedure.
• We conduct extensive experiments showing that CounterNet generates CF explanations with ∼100% validity and low cost of change (a ∼9.8% improvement over baselines), which demonstrates that CounterNet balances the cost-invalidity trade-off significantly better than baseline approaches. In addition, the joint-training procedure does not sacrifice CounterNet's predictive accuracy or robustness. Finally, CounterNet runs ∼3X faster than baselines.
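The block-wise coordinate descent mentioned in the contributions alternates updates between parameter blocks while freezing the others. A generic illustration on a toy two-block objective (the objective, step size, and iteration count are illustrative and are not CounterNet's actual losses):

```python
def block_coordinate_descent(grad_u, grad_v, u, v, lr=0.1, steps=200):
    """Alternate gradient steps on each parameter block while the other
    block stays frozen (Gauss-Seidel-style updates)."""
    for _ in range(steps):
        u = u - lr * grad_u(u, v)  # block 1: update u with v frozen
        v = v - lr * grad_v(u, v)  # block 2: update v with the new u frozen
    return u, v

# toy two-block objective f(u, v) = (u - 1)^2 + (v + 2)^2 + 0.1*u*v
f = lambda u, v: (u - 1) ** 2 + (v + 2) ** 2 + 0.1 * u * v
grad_u = lambda u, v: 2 * (u - 1) + 0.1 * v
grad_v = lambda u, v: 2 * (v + 2) + 0.1 * u

u, v = block_coordinate_descent(grad_u, grad_v, 0.0, 0.0)
# u, v converge toward the joint stationary point of f
```

Decoupling the blocks in this way avoids taking a single joint gradient step on competing objectives, which is the intuition behind using it to stabilize the predictor/generator training described above.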

