COUNTERNET: END-TO-END TRAINING OF PREDICTION-AWARE COUNTERFACTUAL EXPLANATIONS

Abstract

Counterfactual (or CF) explanations are a type of local explanation for Machine Learning (ML) model predictions: they offer a contrastive case by finding the smallest changes (in feature space) to the input data point that lead to a different prediction by the ML model. Existing CF explanation techniques suffer from two major limitations: (i) all of them are post-hoc methods designed for use with proprietary ML models; as a result, their procedure for generating CF explanations is uninformed by the training of the ML model, which leads to misalignment between model predictions and explanations; and (ii) most of them rely on solving separate, time-intensive optimization problems to find CF explanations for each input data point, which negatively impacts their runtime. This work makes a novel departure from the prevalent post-hoc paradigm by presenting CounterNet, an end-to-end learning framework which integrates predictive model training and the generation of counterfactual (CF) explanations into a single pipeline. We adopt a block-wise coordinate descent procedure which helps in effectively training CounterNet's network. Our extensive experiments on multiple real-world datasets show that CounterNet generates high-quality predictions, consistently achieves 100% CF validity and low proximity scores (thereby achieving a well-balanced cost-invalidity trade-off) for any new input instance, and runs 3X faster than existing state-of-the-art baselines.

Existing CF explanation methods suffer from three major limitations. First, to the best of our knowledge, all prior methods belong to the post-hoc explanation paradigm, i.e., they assume a trained black-box ML model as input. This post-hoc assumption has certain advantages, e.g., post-hoc explanation techniques are often agnostic to the particulars of the ML model, and hence, they are generalizable enough to interpret any third-party proprietary ML model.
However, we argue that in many real-world scenarios, the model-agnostic approach provided by post-hoc CF explanation methods is not desirable. With the advent of data regulations that enshrine the "Right to Explanation" (e.g., EU-GDPR (Wachter et al., 2017)), service providers are required by law to communicate both the decision outcome (i.e., the ML model's prediction) and its actionable implications (i.e., a CF explanation for this prediction) to an end-user. In these scenarios, the post-hoc assumption is overly limiting, as service providers can build specialized CF explanation techniques that leverage knowledge of their particular ML model to generate higher-quality CF explanations. Second, in the post-hoc CF explanation paradigm, the optimization procedure that finds CF explanations is completely uninformed by the ML model training procedure (and the resulting decision boundary). Consequently, such a post-hoc procedure does not properly balance the cost-invalidity trade-off (as explained above), causing shortcomings in the quality of the generated CF explanations (as shown in Section 4). Finally, most CF explanation methods are very slow: they search for CF examples by solving a separate time-intensive optimization problem for each input instance (Wachter et al., 2017).

1. INTRODUCTION

A counterfactual (CF) explanation offers a contrastive case: to explain the prediction made by a Machine Learning (ML) model on data point x, CF explanation methods find a new counterfactual example x′, which is similar to x but gets a different (or opposite) prediction from the ML model. From an end-user perspective, CF explanation methods¹ (Wachter et al., 2017) may be preferable (as compared to other methods of explaining ML models), as they can be used to offer recourse to vulnerable groups. For example, if a person applies for a loan and gets rejected by a bank's ML algorithm, CF explanation methods can suggest corrective measures to the loan applicant, which can be incorporated into a future loan application to improve their chances of approval. Generating high-quality CF explanations is a challenging problem because of the need to balance the cost-invalidity trade-off (Rawal et al., 2020) between: (i) the invalidity, i.e., the probability that a CF example is invalid, i.e., it does not achieve the desired (or opposite) prediction from the ML model; and (ii) the cost of change, i.e., the L1 norm distance between input instance x and CF example x′. Figure 1 illustrates this trade-off by showing three different CF examples for an input instance x. If invalidity is ignored (and we optimize only for cost of change), the generated CF example can be trivially set to x itself. Conversely, if cost of change is ignored (and we optimize only for invalidity), the generated CF example can be set to x′_2 (or any sufficiently distant instance with a different label). More generally, CF examples with high (low) invalidity usually imply low (high) cost of change. To optimally balance this trade-off, it is critical for CF explanation methods to have access to the decision boundary of the ML model, without which finding a near-optimal CF explanation (i.e., x′_1) is difficult. For example, it is difficult to distinguish between x′_1 (a valid CF example) and x′_0 (an invalid CF example) without prior knowledge of the decision boundary.

¹ CF explanations are closely related to algorithmic recourse (Ustun et al., 2019) and contrastive explanations (Dhurandhar et al., 2018). Although these terms were proposed in different contexts, their differences from CF explanations have become blurred (Verma et al., 2020; Stepin et al., 2021), i.e., the terms are used interchangeably.

Under review as a conference paper at ICLR 2023

2. RELATED WORK

A line of prior work offers (similar) data samples as model explanations (Guidotti et al., 2018; Molnar et al., 2020; Chen et al., 2019; Koh & Liang, 2017). Our work is most closely related to prior literature on CF explanation methods, which focuses on finding new instances that lead to different predicted outcomes (Wachter et al., 2017; Verma et al., 2020; Karimi et al., 2020; Stepin et al., 2021). CF explanations are preferred by human end-users as these explanations provide actionable recourse in many domains (Binns et al., 2018; Miller, 2019; Bhatt et al., 2020). Almost all prior work in this area belongs to the post-hoc CF explanation paradigm, which we categorize into non-parametric and parametric methods.

Non-parametric methods. Non-parametric methods aim to find a counterfactual explanation without the use of parameterized models. Wachter et al. (2017) proposed VanillaCF, which generates CF explanations by minimizing the distance between the input instance and the CF example, while pushing the new prediction towards the desired class. Other algorithms, built on top of VanillaCF, optimize other aspects, such as recourse cost (Ustun et al., 2019), fairness (Von Kügelgen et al., 2022), diversity (Mothilal et al., 2020), closeness to the data manifold (Van Looveren & Klaise, 2019), causal constraints (Karimi et al., 2021), uncertainty (Schut et al., 2021), and robustness to model shift (Upadhyay et al., 2021).
However, this line of work is inherently post-hoc and relies on solving a separate optimization problem for each input instance. Consequently, running these methods is time-consuming, and their post-hoc nature leads to poor balancing of the cost-invalidity trade-off.

Parametric methods. These methods use parametric models (e.g., a neural network model) to generate CF explanations. For example, Pawelczyk et al. (2020); Joshi et al. (2019) generate CF explanations by perturbing the latent variable of a variational autoencoder (VAE) model. Yang et al. (2021); Singla et al. (2020); Nemirovsky et al. (2022) and Mahajan et al. (2019); Rodríguez et al. (2021); Guyomard et al. (2022) train generative models (GANs and VAEs, respectively) to produce CF explanations for a trained ML model. However, these methods are still post-hoc in nature, and thus, they also suffer from poorly balanced cost-invalidity trade-offs. In contrast, we depart from this post-hoc paradigm, which leads to greater alignment between CounterNet's predictions and CF explanations. Note that Ross et al. (2021) propose a recourse-friendly ML model by integrating recourse training during predictive model training. However, their work does not focus on generating CF explanations; we focus on generating predictions and CF explanations simultaneously.

3. THE PROPOSED FRAMEWORK: COUNTERNET

Unlike prior work, our proposed framework CounterNet relies on a novel integrated architecture which combines predictive model training and counterfactual explanation generation into a single optimization framework. Through this integration, we can simultaneously optimize the accuracy of the trained predictive model and the quality of the generated counterfactual explanations. Formally, given an input instance x ∈ R^d, CounterNet aims to generate two outputs: (i) the ML prediction component outputs a prediction ŷ_x for input instance x; and (ii) the CF explanation generation component produces a CF example x′ ∈ R^d as an explanation for input instance x. Ideally, the CF example x′ should get a different (and often more preferable) prediction ŷ_x′, as compared to the prediction ŷ_x on the original input instance x (i.e., ŷ_x′ ≠ ŷ_x). In particular, if the desired prediction output is binary-valued {0, 1}, then ŷ_x and ŷ_x′ should take on opposite values (i.e., ŷ_x + ŷ_x′ = 1).

3.1. NETWORK ARCHITECTURE

Figure 2 illustrates CounterNet's architecture, which includes three components: (i) an encoder network h(·); (ii) a predictor network f(·); and (iii) a CF generator network g(·). During training, each input instance x ∈ R^d is first passed through the encoder network to generate a dense latent vector representation of x (denoted by z_x = h(x)). Then, this latent representation is passed through both the predictor network and the CF generator network. The predictor network outputs a softmax representation of the prediction ŷ_x = f(z_x). To generate CF examples, the CF generator network takes two pieces of information: (i) the final representation of the predictor network p_x (before it is passed through the softmax layer), and (ii) the latent vector z_x (which contains a dense representation of the input x). These two vectors are concatenated to produce the final latent vector z′_x = p_x ⊕ z_x, which is passed through the CF generator network to produce a CF example x′ = g(z′_x). Note that passing the predictor representation p_x through the CF generator network implicitly conveys information about the decision boundary to the CF generation procedure, which leverages this knowledge to find high-quality CF examples x′. This design aims to achieve a better balance on the cost-invalidity trade-off (as shown in Section 4).

Design of Encoder, Predictor & CF Generator. All three components in CounterNet's architecture consist of a multi-layer perceptron (MLP). The encoder network in CounterNet consists of two feed-forward layers that down-sample to generate a latent vector z ∈ R^k (s.t. k < d). The predictor network passes this latent vector z through two feed-forward layers to produce the predictor representation p. Finally, the predictor network outputs the probability distribution over predictions with a fully-connected layer followed by a softmax layer.
On the other hand, the CF generator network takes the final latent representation z′ = z ⊕ p as input, and up-samples (using two feed-forward layers) to produce CF examples x′ ∈ R^d. Each feed-forward neural network layer inside CounterNet uses a LeakyReLU activation function (Xu et al., 2015) followed by a dropout layer (Srivastava et al., 2014). Note that the number of feed-forward layers, the choice of activation function, etc., were hyperparameters optimized using grid search. See Appendix B.4 for implementation details.

Handling Categorical Features. To handle categorical features, we customize CounterNet's architecture for each dataset. First, we transform all categorical features in each dataset into numeric features via one-hot encoding. In addition, for each categorical feature, we add a softmax layer after the final output layer in the CF generator network (Figure 2), which ensures that the generated CF examples respect the one-hot encoding format (as the output of the softmax layer sums to 1). Finally, we normalize all continuous features to the [0, 1] range before training.
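The forward pass described above (encoder, then predictor and CF generator over the concatenated latent) can be sketched as follows. This is an illustrative numpy sketch with random weights and hypothetical layer sizes (d = 29 and k = 10, following the Adult example in Appendix B.4), not the authors' implementation, and it omits dropout and the per-feature softmax for categorical outputs:

```python
import numpy as np

rng = np.random.default_rng(0)
d, k = 29, 10  # input and latent dimensions (k < d); sizes are hypothetical

def leaky_relu(a, slope=0.01):
    return np.where(a > 0, a, slope * a)

def softmax(a):
    e = np.exp(a - a.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# Random weights standing in for trained parameters.
W_h = rng.normal(size=(d, k))      # encoder h(.)
W_f = rng.normal(size=(k, k))      # predictor body (produces representation p)
W_out = rng.normal(size=(k, 2))    # predictor head (softmax over 2 classes)
W_g = rng.normal(size=(k + k, d))  # CF generator g(.)

def counternet_forward(x):
    z = leaky_relu(x @ W_h)            # z_x = h(x)
    p = leaky_relu(z @ W_f)            # pre-softmax predictor representation p_x
    y_hat = softmax(p @ W_out)         # prediction ŷ_x = f(z_x)
    z_prime = np.concatenate([p, z], axis=-1)  # z'_x = p_x ⊕ z_x
    x_cf = z_prime @ W_g               # CF example x' = g(z'_x)
    return y_hat, x_cf

y_hat, x_cf = counternet_forward(rng.normal(size=(4, d)))
```

The key structural point is that the generator sees both the input's latent code z and the predictor's pre-softmax representation p, which is how decision-boundary information reaches CF generation.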

3.2. COUNTERNET OBJECTIVE FUNCTION

CounterNet has three objectives: (i) predictive accuracy: the predictor network should output accurate predictions ŷ_x; (ii) counterfactual validity: CF examples x′ produced by the CF generator network should be valid, i.e., they get opposite predictions from the predictor network (e.g., ŷ_x + ŷ_x′ = 1); and (iii) minimal cost of change: minimal modifications should be required to change input instance x to CF example x′. Thus, we formulate a multi-objective minimization problem to optimize the parameters θ of the overall network:

argmin_θ (1/N) Σ_{i=1}^{N} [ λ1 · (y_i − ŷ_{x_i})² + λ2 · (ŷ_{x_i} − (1 − ŷ_{x′_i}))² + λ3 · (x_i − x′_i)² ]   (1)
                             Prediction Loss (L1)    Validity Loss (L2)                Proximity Loss (L3)

where N denotes the number of instances in our dataset, and (λ1, λ2, λ3) are hyper-parameters that balance the three loss components. The prediction loss L1 denotes the mean squared error (MSE) between the actual and predicted labels (y_i and ŷ_{x_i} on instance x_i, respectively), which aims to maximize predictive accuracy. Similarly, the validity loss L2 denotes the MSE between the prediction on instance x_i (i.e., ŷ_{x_i}) and the opposite of the prediction received by the corresponding CF example x′_i (i.e., 1 − ŷ_{x′_i}). Intuitively, minimizing L2 maximizes the validity of the generated CF example x′_i by ensuring that the predictions on x′_i and x_i are different. Finally, the proximity loss L3 represents the MSE distance between input instance x_i and the CF example x′_i, which aims to minimize the cost of change. This choice of loss functions is crucial to CounterNet's superior performance, as replacing L1, L2, and L3 with alternative functions leads to degraded performance (as we show in Section 4).
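The three loss terms of Eq. 1 translate directly into code. A minimal sketch (assuming ŷ denotes the scalar probability of the positive class, and the λ defaults follow the values reported in Appendix B.4):

```python
import numpy as np

def counternet_losses(y, y_hat, x, x_cf, y_hat_cf, lambdas=(1.0, 0.2, 0.1)):
    """The three MSE loss terms of Eq. 1.

    y        : true binary labels, shape (n,)
    y_hat    : predicted probability of class 1 on x, shape (n,)
    y_hat_cf : predicted probability of class 1 on the CF example x', shape (n,)
    x, x_cf  : input instances and CF examples
    """
    l1 = np.mean((y - y_hat) ** 2)               # prediction loss L1
    l2 = np.mean((y_hat - (1 - y_hat_cf)) ** 2)  # validity loss L2
    l3 = np.mean((x - x_cf) ** 2)                # proximity loss L3
    total = lambdas[0] * l1 + lambdas[1] * l2 + lambdas[2] * l3
    return l1, l2, l3, total
```

Note how L2 goes to zero exactly when ŷ_x + ŷ_x′ = 1, i.e., when the CF example receives the opposite prediction.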

3.3. TRAINING PROCEDURE

The conventional way to solve the optimization problem in Eq. 1 is gradient descent with backpropagation (BP). However, directly optimizing the objective function (Eq. 1) as-is results in two fundamental issues: (1) poor convergence in training (shown in Lemma 3.1), and (2) proneness to adversarial examples (shown in Lemma 3.2).

Issue I: Poor Convergence. Optimizing Eq. 1 as-is via BP leads to poor convergence. This occurs because Eq. 1 contains two loss objectives with divergent gradients: Lemma 3.1 shows that the gradients of L1 and L2 move in opposite directions. Consequently, the accumulated gradient across all three loss objectives fluctuates drastically, which leads to difficulty in training.

Lemma 3.1 (Divergent Gradient Problem). Let L1 = ‖y − ŷ_x‖² and L2 = ‖ŷ_x − (1 − ŷ_{x′})‖². If x′ → x, 0 < ŷ_x < 1, y is a binary label, and |ŷ_x − y| < 0.5, then ∇L1 · ∇L2 < 0.

Issue II: Adversarial Examples. Our training procedure should generate high-quality CF examples x′ for input instances x without sacrificing the adversarial robustness of the predictor network. Unfortunately, optimizing Eq. 1 as-is is at odds with the goal of achieving adversarial robustness. Lemma 3.2 shows that optimizing L2 with respect to the predictor weights θ_f decreases the robustness of the predictor f(·) (by increasing the Lipschitz constant of f(·) (Hein & Andriushchenko, 2017; Wu et al., 2021)), leading to its increased vulnerability to adversarial examples. Proofs of Lemmas 3.1 and 3.2 can be found in Appendix A.

Lemma 3.2 (Lipschitz Continuity). Suppose f is a locally Lipschitz continuous function parameterized by θ. Then it satisfies |f_θ(x) − f_θ(x′)| ≤ K‖x − x′‖₂, where its Lipschitz constant is K = sup_{x′ ∈ B(x, ε)} ‖∇f_θ(x′)‖₂. Let L2 = ‖f_θ(x) − (1 − f_θ(x′))‖². If x′ → x, 0 < f_θ(·) < 1, f(x) → y, and y is a binary label, then minimizing L2 w.r.t. θ increases the Lipschitz constant K.
We remedy these issues as follows. (1) To handle poor convergence in training, we adopt a block-wise coordinate descent procedure, which divides the problem of optimizing Eq. 1 into two parts: (i) optimizing predictive accuracy (primarily influenced by L1); and (ii) optimizing the validity and proximity of CF generation (primarily influenced by L2 and L3). Specifically, for each mini-batch of m data points {x^(i), y^(i)}_m, we apply two gradient updates to the network through backpropagation. For the first update, we compute θ^(1) = θ^(0) − ∇_{θ^(0)}(λ1 · L1), and for the second update, we compute θ^(2) = θ^(1) − ∇_{θ^(1)}(λ2 · L2 + λ3 · L3). This procedure ensures that gradients for L1 and L2 are calculated separately, which lessens the divergent gradient problem (Lemma 3.1) and leads to significantly better convergence. (2) Moreover, to improve the adversarial robustness of our predictor network, during the second stage of our coordinate descent procedure (when we optimize λ2 · L2 + λ3 · L3), we only update the weights of the CF generator θ_g and freeze gradient updates in both the encoder θ_h and predictor θ_f networks. More formally, instead of updating the weights θ of the entire network during the second update, we only update the CF generator weights θ_g as follows: θ_g^(2) = θ_g^(1) − ∇_{θ_g^(1)}(λ2 · L2 + λ3 · L3). This ensures that the Lipschitz constant of the predictor network does not increase (Lemma 3.2).
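The two-stage update above can be sketched as follows. This is a minimal PyTorch sketch, assuming a tiny stand-in network (`TinyCounterNet`, a hypothetical reduction of the full architecture of Section 3.1) and plain SGD in place of the Adam optimizer the paper uses:

```python
import torch
import torch.nn as nn

class TinyCounterNet(nn.Module):
    """Minimal stand-in: encoder h, predictor f, CF generator g."""
    def __init__(self, d=8, k=4):
        super().__init__()
        self.encoder = nn.Linear(d, k)        # h(.)
        self.predictor = nn.Linear(k, 2)      # f(.)
        self.generator = nn.Linear(k + 2, d)  # g(.)

    def forward(self, x):
        z = torch.relu(self.encoder(x))
        p = self.predictor(z)                 # pre-softmax representation p_x
        y_hat = torch.softmax(p, dim=-1)
        x_cf = self.generator(torch.cat([p, z], dim=-1))  # x' = g(p ⊕ z)
        return y_hat, x_cf

def train_step(net, x, y, lam1=1.0, lam2=0.2, lam3=0.1, lr=1e-3):
    # --- First update: theta <- theta - grad(lam1 * L1), over ALL weights.
    y_hat, _ = net(x)
    loss1 = lam1 * ((y - y_hat[:, 1]) ** 2).mean()
    net.zero_grad()
    loss1.backward()
    with torch.no_grad():
        for p in net.parameters():
            if p.grad is not None:
                p -= lr * p.grad
    # --- Second update: grad(lam2 * L2 + lam3 * L3), but only the CF
    # generator weights theta_g move; encoder and predictor stay frozen,
    # so the predictor's Lipschitz constant cannot grow (Lemma 3.2).
    y_hat, x_cf = net(x)
    y_hat_cf, _ = net(x_cf)
    loss23 = (lam2 * ((y_hat[:, 1] - (1 - y_hat_cf[:, 1])) ** 2).mean()
              + lam3 * ((x - x_cf) ** 2).mean())
    net.zero_grad()
    loss23.backward()
    with torch.no_grad():
        for p in net.generator.parameters():
            if p.grad is not None:
                p -= lr * p.grad
    return loss1.item(), loss23.item()

torch.manual_seed(0)
net = TinyCounterNet()
x = torch.randn(16, 8)
y = torch.randint(0, 2, (16,)).float()
l1, l23 = train_step(net, x, y)
```

Computing the two backward passes separately is what keeps the opposing gradients of L1 and L2 from cancelling within a single accumulated update.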

4. EXPERIMENTAL EVALUATION

We primarily focus our evaluation on heterogeneous tabular datasets for binary classification problems (the most common and reasonable setting for CF explanations (Verma et al., 2020; Stepin et al., 2021)). However, CounterNet can be applied to multi-class classification settings, and it can also be adapted to work with other modalities of data, e.g., images (as shown in Appendices G, I).

Baselines. We compare CounterNet against seven state-of-the-art CF explanation methods, including (i) VanillaCF (Wachter et al., 2017). Unlike CounterNet, all of the post-hoc methods require a trained predictive model as input. Thus, for each dataset, we train a neural network model and use it as the target predictive model for all baselines. For a fair comparison, we keep only the encoder and predictor network inside CounterNet's architecture (Figure 2), and optimize them for predictive accuracy alone (i.e., L1). This combination of encoder and predictor networks is then used as the black-box predictive model for our baselines.

Datasets. To remain consistent with prior work on CF explanations (Verma et al., 2020), we evaluate CounterNet on four benchmark real-world binary classification datasets: (i) Adult (Kohavi & Becker, 1996), which predicts whether an individual's income reaches $50K (Y=1) or not (Y=0); (ii) Credit (Yeh & Lien, 2009), which uses historical payments to predict default of payment (Y=1) or not (Y=0); (iii) HELOC (FICO, 2018), which predicts if a homeowner qualifies for credit (Y=1) or not (Y=0); and (iv) OULAD (Kuzilek et al., 2017), which predicts whether MOOC students drop out (Y=1) or not (Y=0). We also provide experiments on four additional datasets in Appendix E.

Evaluation Metrics. For each input x, CF explanation methods generate two outputs: (i) a prediction ŷ_x; and (ii) a CF example x′. We evaluate the quality of both outputs using separate metrics.
For evaluating predictions, we use predictive accuracy (as all four datasets are fairly class-balanced). For evaluating CF examples, we use five widely used metrics from prior literature (see Appendix B.3 for formal definitions): (i) Validity is the fraction of input instances on which CF explanation methods output valid CF examples, i.e., the fraction of input data points for which ŷ_x + ŷ_x′ = 1. High validity is desirable, as it implies the method's effectiveness at creating valid CF examples (Mothilal et al., 2020; Upadhyay et al., 2021). (ii) Proximity is the L1 norm distance between x and x′ divided by the number of features (Wachter et al., 2017; Mothilal et al., 2020). (iii) Sparsity is the number of feature changes between x and x′ (normalized by the total number of features) (Wachter et al., 2017; Poursabzi-Sangdeh et al., 2021). Proximity and sparsity serve as proxies for measuring the cost of change, as it is desirable to have fewer modifications in the input space to convert an instance into a valid CF example. (iv) Manifold distance is the L1 distance to the k-nearest neighbor of x′ (we use k = 1 based on Verma et al. (2022)). Low manifold distance is desirable, as closeness to the training data manifold indicates realistic CF explanations (Van Looveren & Klaise, 2019; Verma et al., 2022). (v) Finally, we also report the runtime for generating CF examples.

Validity. Table 2 compares the validity achieved by CounterNet and the baselines on all four datasets. We observe that CounterNet, C-CHVAE, and VCNet are the only three methods with 100% validity on all datasets. With respect to the other baselines, CounterNet achieves 8% and 12.3% higher average validity (across all datasets) than VanillaCF and ProtoCF (our next best baselines).

4.1. EVALUATION OF COUNTERNET PERFORMANCE

Proximity & Sparsity. Table 2 compares the proximity/sparsity achieved by all methods. CounterNet achieves at least 3% better proximity than all other baselines on three out of four datasets (Adult, HELOC, and OULAD), and it is the second best performing model on the Credit dataset (where it achieves 7.3% poorer proximity than VanillaCF). In terms of sparsity, CounterNet performs reasonably well: it is the second best performing model on the Adult and HELOC datasets, even though it does not explicitly optimize for sparsity. This shows that CounterNet outperforms all baselines by generating CF examples with the highest validity and best proximity scores.

Cost-Invalidity Trade-off. We illustrate the cost-invalidity trade-off (Rawal et al., 2020) for all methods. Figure 3 shows that CounterNet lies in the bottom left of the figure: it consistently achieves the lowest invalidity and cost on all four datasets. In comparison, VCNet achieves the same perfect validity, but at the expense of ∼34% higher cost than CounterNet. Similarly, C-CHVAE demands ∼71% higher cost than CounterNet to achieve perfect validity. On the other hand, VanillaCF achieves comparable cost to CounterNet (10% higher), but 8% lower validity. This shows that CounterNet's joint training enables it to properly balance the cost-invalidity trade-off.

Adversarial Robustness. We now show that CounterNet does not suffer from the decreased predictor robustness that results from optimizing the validity loss L2 (Lemma 3.2). We compare the robustness of CounterNet's predictor network f(·) against two baselines: (i) the base predictive model described in Table 1; and (ii) CounterNet without freezing the predictor at the second stage of our coordinate descent optimization (CounterNet-NoFreeze).
Figure 4 illustrates the perturbation stability (Wu et al., 2021) of all three variants against adversarial examples (generated via projected gradient descent (Madry et al., 2018)). CounterNet achieves perturbation stability comparable to the base model, which indicates that CounterNet reaches its robustness upper bound (i.e., the robustness of the base model). Moreover, the empirical results in Figure 4 confirm Lemma 3.2, as CounterNet-NoFreeze achieves significantly poorer stability. We observe similar patterns with different attack methods on other datasets (see Appendix C.1). These results show that by freezing the predictor and encoder networks at the second stage of our coordinate descent procedure, CounterNet avoids the increased vulnerability to adversarial examples.

6. ETHICS & REPRODUCIBILITY STATEMENT

Ethics Statement. Although CounterNet is suitable for real-time deployment given its highly aligned CF explanations and speed, one must be aware of the possible negative impacts of its CF explanations on human end-users. It is important to ensure that generated CF examples do not amplify or lend support to narratives resulting from pre-existing race-based and gender-based societal inequities (among others). As we stated in Section 5, one short-term workaround is to have humans in the loop: we can provide CounterNet's explanations as a decision-aid to a well-trained human official, who is in charge of communicating the decisions of ML models to human end-users in a respectful and humane manner. In the long run, further qualitative and quantitative studies are needed to understand the social impacts of CounterNet.

Reproducibility Statement. To aid the reproducibility of this work, we provide the code in the supplementary material, and will make the code public upon acceptance. We also provide the datasets used for evaluation in an anonymous repository. In addition, we outline the choices of hyperparameters in Appendix B.4. A detailed description of our experimental implementation can be found in Appendix B. For the theoretical analysis of our block-wise coordinate descent procedure, we provide complete proofs in Appendix A.

A SUPPLEMENTAL PROOF

A.1 PROOF OF LEMMA 3.1

Proof. We have ∇_θ L1 = ∇_θ ‖y − ŷ_x‖² and ∇_θ L2 = ∇_θ ‖(1 − ŷ_x) − ŷ_{x′}‖². First,

∇_θ L1 = ∇_θ y² − 2y · ∇_θ ŷ_x + ∇_θ ŷ_x² = −2y · ∇_θ ŷ_x + 2ŷ_x · ∇_θ ŷ_x = 2(ŷ_x − y) · ∇_θ ŷ_x

Since x′ → x (we expect the CF example x′ to be close to the original instance x), we can replace x′ with x in L2. Then,

∇_θ L2 = ∇_θ (1 − 2ŷ_x)² = −4 · ∇_θ ŷ_x + 4 · ∇_θ ŷ_x² = −4 · ∇_θ ŷ_x + 8 · ŷ_x · ∇_θ ŷ_x = 4(2ŷ_x − 1) · ∇_θ ŷ_x

Hence,

∇_θ L1 · ∇_θ L2 = 2(ŷ_x − y) · ∇_θ ŷ_x · 4(2ŷ_x − 1) · ∇_θ ŷ_x = 8(ŷ_x − y)(2ŷ_x − 1)(∇_θ ŷ_x)²

Since (∇_θ ŷ_x)² > 0, it suffices to determine the sign of (ŷ_x − y)(2ŷ_x − 1). Given that |ŷ_x − y| < 0.5:
• if y = 1, then 0.5 < ŷ_x < 1, so (ŷ_x − y) < 0 and (2ŷ_x − 1) > 0;
• if y = 0, then 0 < ŷ_x < 0.5, so (ŷ_x − y) > 0 and (2ŷ_x − 1) < 0.
Therefore, (ŷ_x − y)(2ŷ_x − 1) < 0, and hence ∇_θ L1 · ∇_θ L2 < 0. □

A.2 PROOF OF LEMMA 3.2

Proof. Assuming f_θ(x) → y (we expect the predictor network to produce accurate predictions) with y ∈ {0, 1}, we can replace f_θ(x) with y. Then minimizing L2 (in Lemma 3.2) amounts to minimizing ‖y − (1 − f_θ(x′))‖². Since 0 < f_θ(·) < 1, we have

min ‖y − (1 − f_θ(x′))‖² = max ‖y − f_θ(x′)‖²

Replacing y back with f_θ(x), minimizing L2 amounts to maximizing ‖f_θ(x) − f_θ(x′)‖². By definition, the Lipschitz constant K is

K = sup_{x′ ∈ B(x, ε)} ‖∇f_θ(x′)‖₂ = sup_{x′ ∈ B(x, ε)} ‖f_θ(x) − f_θ(x′)‖ / ‖x − x′‖

Since minimizing L2 increases ‖f_θ(x) − f_θ(x′)‖² while x′ → x, the Lipschitz constant K increases. □

B IMPLEMENTATION DETAILS

Here we provide implementation details of CounterNet and the baselines on the datasets listed in Section 4. The code can be found in the supplemental material.

B.1 DATASETS

Here, we reiterate the datasets used for evaluation. Our evaluation is conducted on eight widely-used tabular datasets. Our primary evaluation uses four large-sized datasets (shown in Section 4): Adult, Credit, HELOC, and OULAD, each of which contains at least 10k data instances. In addition, we experiment with four small-sized datasets: Student, Titanic, Cancer, and German. Table 6 summarizes the datasets used for evaluation.

B.3 EVALUATION METRICS

Here, we provide formal definitions of the evaluation metrics.

Predictive Accuracy is defined as the fraction of correct predictions:

Predictive-Accuracy = #|f(x) = y| / n   (2)

Validity is defined as the fraction of input instances on which CF explanation methods output valid CF examples:

Validity = #|f(x′) = 1 − y| / n   (3)

Proximity is defined as the L1 norm distance between x and x′ divided by the number of features:

Proximity = (1/(nd)) Σ_{i=1}^{n} Σ_{j=1}^{d} ‖x_i^(j) − x′_i^(j)‖₁   (4)

Sparsity is defined as the fraction of features changed between x and x′:

Sparsity = (1/(nd)) Σ_{i=1}^{n} Σ_{j=1}^{d} ‖x_i^(j) − x′_i^(j)‖₀   (5)
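These definitions map directly to code. A sketch with numpy, assuming x and x′ are arrays of shape (n, d) and predictions are 0/1 vectors:

```python
import numpy as np

def validity(y_cf_pred, y):
    """Fraction of CF examples achieving the opposite prediction (Eq. 3)."""
    return np.mean(y_cf_pred == 1 - y)

def proximity(x, x_cf):
    """Mean per-feature L1 distance between inputs and CF examples (Eq. 4)."""
    n, d = x.shape
    return np.abs(x - x_cf).sum() / (n * d)

def sparsity(x, x_cf, eps=0.0):
    """Fraction of features changed (Eq. 5); eps tolerates float noise."""
    n, d = x.shape
    return (np.abs(x - x_cf) > eps).sum() / (n * d)
```

For example, changing 1 of 4 feature values by a unit step yields a proximity and sparsity of 0.25 each.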

B.4 COUNTERNET IMPLEMENTATION DETAILS

Across all six datasets, we apply the same settings when training CounterNet. We initialize the weights as in He et al. (2016). We adopt the Adam optimizer with a mini-batch size of 128. For each dataset, we train the models for up to 1 × 10³ iterations. To avoid gradient explosion, we apply gradient clipping with a threshold of 0.5 (i.e., gradients with norm above 0.5 are clipped). We set the dropout rate to 0.3 to prevent overfitting. For all six datasets, we set λ1 = 1.0, λ2 = 0.2, λ3 = 0.1 in Eq. 1. The learning rate is the only hyper-parameter that varies across the six datasets. From our empirical study, we find that the training of CounterNet is sensitive to the learning rate, although a good choice of loss function (e.g., choosing MSE over cross-entropy) can widen the range of "optimal" learning rates. We apply grid search to tune the learning rate; our choices are specified in Table 7. Additionally, we specify the architecture's details (e.g., the dimensions of each layer in the encoder, predictor, and CF generator) in Table 7. The numbers in each bracket represent the dimension of the transformed representation. For example, the encoder dimensions for the Adult dataset are [29, 50, 10], which means that the dimension of the input x ∈ R^d is 29 (i.e., d = 29); the encoder first transforms the input into a 50-dimensional representation, and then down-samples it to generate the latent representation z ∈ R^k, where k = 10.

Next, we describe the implementation of the baseline methods. For VanillaCF and ProtoCF, we follow the authors' instructions as closely as possible, and implement them in PyTorch. For VanillaCF, DiverseCF and ProtoCF, we run a maximum of 1 × 10³ steps. After CF generation, we convert the results to one-hot encoding format for each categorical feature. For training VAE-CF, we follow Mahajan et al. (2019)'s settings, running a maximum of 50 epochs and setting the batch size to 1024. We use the same learning rates as in Table 7 for VAE training.
For training the predictive models for the baseline algorithms, we apply grid search to tune the learning rate, as specified in Table 8. Similar to training CounterNet, we adopt the Adam optimizer with a mini-batch size of 128, and set the dropout rate to 0.3. We train the model for up to 100 iterations with early stopping to avoid overfitting.

C ADDITIONAL EXPERIMENTAL RESULTS

Here, we provide additional results for the experiments in Section 4. These results further demonstrate the effectiveness of CounterNet.

C.1 ADDITIONAL ROBUSTNESS RESULTS

We provide supplementary results on evaluating the robustness of the predictor network on three large datasets (Adult, HELOC, and OULAD). In particular, we implement the FGSM (Goodfellow et al., 2015) and PGD (Madry et al., 2018) attacks for testing the robustness of the predictive models. Figure 5 illustrates that CounterNet achieves perturbation stability (i.e., robustness of the predictive model) comparable to the base model. In addition, Figure 5 supports the findings of Lemma 3.2, since CounterNet-NoFreeze consistently achieves lower stability than both the base models and CounterNet.

Figure 5: Perturbation stability under FGSM (5a-5c) (Goodfellow et al., 2015) and PGD (5d-5f) (Madry et al., 2018) attacks.
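A simplified stability check along these lines can be sketched as follows. For illustration we attack a plain logistic predictor with an analytic input gradient (the BCE gradient with respect to x is (p − y)·w), rather than the neural predictor evaluated in the paper, and report the fraction of predictions that survive the attack as a stability proxy:

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def pgd_attack(w, b, x, y, eps=0.1, alpha=0.02, steps=10):
    """L_inf PGD (Madry et al., 2018) against a logistic model sigma(x@w + b)."""
    x_adv = x.copy()
    for _ in range(steps):
        p = sigmoid(x_adv @ w + b)
        grad = (p - y)[:, None] * w[None, :]      # d(BCE)/dx
        x_adv = x_adv + alpha * np.sign(grad)     # ascent step
        x_adv = x + np.clip(x_adv - x, -eps, eps) # project back to eps-ball
    return x_adv

def perturbation_stability(w, b, x, y, **kw):
    """Fraction of points whose predicted class is unchanged by the attack."""
    x_adv = pgd_attack(w, b, x, y, **kw)
    before = sigmoid(x @ w + b) > 0.5
    after = sigmoid(x_adv @ w + b) > 0.5
    return np.mean(before == after)
```

A more robust predictor keeps this fraction closer to 1 as the attack budget eps grows.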

C.2 TRAINING TIME OF COUNTERNET

Table 9 shows the training time of the base model and CounterNet for each epoch in seconds. CounterNet takes roughly 3X the training time of the base model (which is trained only for predictive performance). Note that training time is a secondary metric for evaluating speed: training occurs only once, whereas inference time (i.e., runtime in Table 2) is the more important metric, as it keeps accruing during the deployment stage.

D.1 COUNTERNET UNDER THE BLACK-BOX SETTING

Note that Eq. 6 looks identical to Eq. 1; the only difference is that y_i in Eq. 1 is replaced by M(x_i):

argmin_θ (1/N) Σ_{i=1}^{N} [ λ1 · (M(x_i) − ŷ_{x_i})² + λ2 · (ŷ_{x_i} − (1 − ŷ_{x′_i}))² + λ3 · (x_i − x′_i)² ]   (6)
                             Prediction Loss (L1)       Validity Loss (L2)                Proximity Loss (L3)

Table 10 shows the performance of CounterNet under the black-box setting (CFNET-BB). CFNET-BB degrades slightly in validity and average L1 cost compared to CounterNet. This is because approximating the black-box model M leads to degraded quality of the generated CF explanations.

D.2 ABLATIONS ON COUNTERNET'S TRAINING

In addition, we provide supplementary ablation results on three large datasets (Adult, HELOC, and OULAD) to understand the design choices in CounterNet's training, shown in Figure 6. Compared to CounterNet's learning curve for L2, the learning curves of CounterNet-BCE and CounterNet-NoSmooth show significantly higher instability, illustrating the importance of the MSE-based loss functions and the label-smoothing technique. Moreover, CounterNet-SingleBP's learning curve for L2 performs poorly in comparison, which illustrates the difficulty of optimizing three divergent objectives in a single backpropagation step; in turn, this illustrates the effectiveness of our block-wise coordinate descent procedure in CounterNet's training. These results show that all design choices made in Section 3 contribute to effective training. In addition, we experiment with alternative loss formulations: we replace the MSE-based L3 loss in Eq. 1 with an l1 norm (CounterNet-l1). Table 11 shows that the l1 formulation leads to degraded performance.
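The two proximity formulations compared in the CounterNet-l1 ablation can be sketched as follows (the feature values are illustrative):

```python
import numpy as np

x    = np.array([1.0, 2.0, 3.0])   # input instance
x_cf = np.array([1.1, 1.5, 3.0])   # CF example

l3_mse = np.mean((x - x_cf) ** 2)   # Eq. 1 formulation (used by CounterNet)
l3_l1  = np.mean(np.abs(x - x_cf))  # l1 variant (CounterNet-l1 ablation)
```

The MSE form penalizes large per-feature changes quadratically, which tends to spread modifications more smoothly across features than the l1 form.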

E EXPERIMENTAL EVALUATION ON SMALL-SIZED DATASETS

In addition to the four large datasets in Section 4, we experiment with four small-sized datasets: (i) Breast Cancer Wisconsin (Blake, 1998), which classifies tumors as malignant (Y=1) or benign (Y=0); (ii) Student Performance (Cortez & Silva, 2008), which predicts whether a student will pass (Y=1) or fail (Y=0) the exam; (iii) Titanic (Kaggle, 2018), which predicts whether a passenger survived (Y=1) the Titanic shipwreck or not (Y=0); and (iv) German Credit (Asuncion & Newman, 2007), which predicts whether a customer's credit score is good (Y=1) or bad (Y=0).

F SECOND-ORDER EVALUATION

We define three additional second-order metrics which attempt to evaluate the usability of CF explanation techniques for human end-users. We posit that negligible feature differences (among continuous features) between instance x and CF example x′ make it difficult for human end-users to act on x′ (as many of the recourse recommendations contained within x′ may not be actionable due to negligible differences). For example, human end-users may find it impossible to increase their Daily_Sugar_Consumed by 0.523 grams (if the value of the Daily_Sugar_Consumed feature is 700 in x and 700.523 in x′). As such, human end-users may be willing to ignore small feature differences between x and x′. To define our usability-related metrics, we construct a user-friendly second-order CF example x′′ by ignoring small feature differences (i.e., where |x_i − x′_i| is less than a threshold b) between x and x′. Formally, let x = {x_1, x_2, ..., x_d} and x′ = {x′_1, x′_2, ..., x′_d} be the features of the input instance and the CF example, respectively. Given a threshold b, we create a new data point

x′′_i = 1[|x_i − x′_i| ≤ b] · x_i + 1[|x_i − x′_i| > b] · x′_i,  for all i ∈ {1, ..., d},

i.e., we replace every feature i in CF example x′ with the corresponding feature of the original input x whenever |x_i − x′_i| ≤ b. Our metrics for CF usability are defined in terms of x and x′′ as follows:

• Second-Order Validity is the fraction of input instances on which x′′ remains a valid CF example. High second-order validity is desirable, because it implies that despite ignoring small feature differences, the second-order CF example x′′ remains valid.

• Second-Order Proximity is the L1 norm distance between x and x′′. Low second-order proximity is desirable because it indicates fewer cumulative modifications in the input space.

• Second-Order Sparsity is the number of feature changes (i.e., the L0 norm) between x and x′′. A sparser x′′ (fewer feature changes) enhances the interpretability of a CF explanation. Note that second-order sparsity is more important than the original sparsity metric, as the second-order CF example x′′ ignores small changes in the continuous features, yielding fewer feature changes in the input space.
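The construction of x′′ can be sketched as follows (a minimal NumPy version; the feature values mirror the Daily_Sugar_Consumed example above and are otherwise illustrative):

```python
import numpy as np

def second_order_cf(x, x_cf, b=2.0):
    """Build x'': keep x_i where the change is negligible (<= b), else take x'_i."""
    keep = np.abs(x - x_cf) <= b   # indicator 1[|x_i - x'_i| <= b]
    return np.where(keep, x, x_cf)

x    = np.array([700.0, 48.0, 12.0])    # e.g., sugar (g), hours/week, a third feature
x_cf = np.array([700.523, 33.5, 12.0])  # sugar +0.523g (negligible), hours 48 -> 33.5
x_dd = second_order_cf(x, x_cf, b=2.0)
# The 0.523g sugar change is ignored; the hours change is kept.
```

With b = 2, `x_dd` equals `[700.0, 33.5, 12.0]`: only the actionable change survives.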

F.1 EXPERIMENTAL RESULTS

The evaluation of counterfactual usability measures the quality of the second-order CF example x′′, which is created by ignoring negligible differences between input instance x and CF example x′. We use a fixed threshold b = 2 to derive the "sparse" second-order CF example x′′, and compute the second-order evaluation metrics.

Second-order validity. Similar to the results in Table 2, CounterNet performs consistently well across all six datasets on the validity metric: it is the only CF explanation method which achieves over 93.7% second-order validity on all six datasets. In particular, CounterNet achieves ∼11% higher second-order validity than C-CHVAE (its closest competitor) on all six datasets. Further, CounterNet is the only CF method which achieves more than 90% second-order validity on the Breast Cancer dataset, whereas all post-hoc baselines perform poorly (none achieve second-order validity higher than 70%), despite the fact that three of these baselines (VanillaCF, DiverseCF, and ProtoCF) achieved more than 99% first-order validity on this dataset. This result demonstrates that CounterNet is much more robust against small perturbations in the continuous feature space.

Second-order sparsity and proximity. In terms of second-order sparsity, CounterNet outperforms the two parametric CF explanation methods (C-CHVAE and VAE-CF), and maintains competitive performance against two non-parametric methods (VanillaCF and ProtoCF). Across all six datasets, CounterNet outperforms C-CHVAE and VAE-CF by ∼10% on this metric. Moreover, the difference in second-order sparsity between CounterNet and VanillaCF (and ProtoCF) is close to 1%, indicating that CounterNet matches the second-order sparsity of these two non-parametric methods. In terms of second-order proximity, CounterNet is highly proximal compared to the baseline methods, achieving the lowest proximity on the HELOC, OULAD, and Titanic datasets (similar to the results in Table 2).
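A minimal sketch of computing the three second-order metrics over a batch, assuming x′′ has been constructed as defined above; the `predict` function and all data values here are hypothetical toys:

```python
import numpy as np

def second_order_metrics(x, x_dd, predict, desired=1):
    """Second-order validity, proximity (L1), and sparsity (L0) over a batch."""
    validity  = np.mean(predict(x_dd) == desired)      # fraction still valid
    proximity = np.mean(np.abs(x - x_dd).sum(axis=1))  # mean L1 distance
    sparsity  = np.mean((x != x_dd).sum(axis=1))       # mean number of changed features
    return validity, proximity, sparsity

x    = np.array([[700.0, 48.0], [10.0, 5.0]])
x_dd = np.array([[700.0, 33.5], [10.0, 5.0]])
predict = lambda z: (z[:, 1] < 40).astype(int)  # toy hard-label classifier
val, prox, spar = second_order_metrics(x, x_dd, predict)
```

In the paper's setting, `predict` would be CounterNet's predictor network (or the black-box model for post-hoc baselines) returning hard labels.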
Cost-invalidity trade-off. Figure 7 shows that CounterNet sits in the bottom left of the figure, illustrating that CounterNet balances the cost-invalidity trade-off in the counterfactual usability evaluation. Notably, CounterNet outperforms all post-hoc methods on the second-order invalidity metric, and maintains the same level of second-order sparsity as VanillaCF and ProtoCF (∼1% difference). Moreover, although DiverseCF achieves a ∼10% lower second-order sparsity value than CounterNet, it has ∼50% higher second-order invalidity. This stems from DiverseCF's inability to balance the trade-off between second-order invalidity and sparsity: its high second-order invalidity hampers its usability, even though it generates sparser explanations.

We illustrate how CounterNet generates interpretable explanations for end-users. Figure 8 shows an actual data point x from the Adult dataset and the corresponding CF explanation x′ generated by CounterNet. The figure shows that x and x′ differ in three features. In addition, CounterNet generates x′′ by ignoring feature changes smaller than the threshold b = 2 (in practice, domain experts can help identify realistic values of b). Due to CounterNet's high second-order validity, x′′ also remains a valid CF example. After this post-processing step, x and x′′ differ in exactly two features, and the end-user is provided with the following natural-language explanation: "If you want the ML model to predict that you will earn more than US$50K, change your education from HS-Grad to Doctorate, and reduce your hours of work/week from 48 to 33.5."

Figure 8: A counterfactual explanation from CounterNet.



CounterNet can work with alternate neuronal blocks (e.g., convolution, attention), although effective training of these blocks demands additional effort (see Appendix H). Note that Yang et al. (2021) propose another parametric post-hoc method, but we exclude it from our baseline comparison because it achieves performance comparable to C-CHVAE (as reported in Yang et al. (2021)).



Figure 1: Illustration of the cost-invalidity trade-off in CF explanations for binary classification problems.

(i) VanillaCF, which generates CF examples by optimizing CF validity and proximity; (ii) DiverseCF (Mothilal et al., 2020), ProtoCF (Van Looveren & Klaise, 2019), and UncertainCF (Schut et al., 2021), which optimize for diversity, consistency with prototypes, and uncertainty, respectively; (iii) VAE-CF (Mahajan et al., 2019), CounteRGAN (Nemirovsky et al., 2022), C-CHVAE (Pawelczyk et al., 2020), and VCNet (Guyomard et al., 2022), which rely on generative models (i.e., VAEs or GANs) to generate CF examples.³

Figure 3: Illustration of the cost-invalidity trade-off across all four datasets. Methods at the bottom left are preferable.

Figure 4: Robustness of the predictor f(·). CounterNet reaches the upper bound of robustness (i.e., comparable to the base model).

Table 4 compares the validity and proximity achieved by CounterNet and five ablations. Importantly, each ablation leads to degraded performance relative to CounterNet, which demonstrates the importance of CounterNet's design choices. CounterNet-BCE and CounterNet-SingleBP perform poorly in comparison, illustrating the importance of the MSE-based loss function and the block-wise coordinate descent procedure. Similarly, CounterNet-Separate and CounterNet-NoPass-p_x achieve degraded validity and proximity scores, which highlights the importance of CounterNet's architecture design. Finally, CounterNet-Posthoc achieves validity comparable to CounterNet but fails to match its proximity. This result demonstrates the importance of CounterNet's joint-training procedure in optimally balancing the cost-invalidity trade-off.

(Figure 5 panels include the robustness of f(·) under the PGD attack on the OULAD and HELOC datasets.)

Figure 5: Robustness of f(·) under the FGSM attack (5a-5c) (Goodfellow et al., 2015) and the PGD attack (5d-5f) (Madry et al., 2018).

Figure 7: Illustration of the trade-off between invalidity and sparsity across six datasets (methods at the bottom left are preferable).

Predictive Accuracy of CounterNet

Evaluation of CF explanations: CounterNet achieves perfect validity (Val.), incurs a comparable (or lower) cost of changes (Prox., Spar.) than baseline methods, and has comparable manifold distance (Man.). Bold and italicized cells highlight the best and second-best performing methods, respectively. CounterNet's row (Val. / Prox. / Spar. / Man. per dataset): 1.00 / .196 / .644 / 0.64; 1.00 / .132 / .912 / 0.56; 1.00 / .125 / .740 / 0.56; 1.00 / .075 / .725 / 0.87. The benefits achieved by CounterNet's joint training of the predictor and CF generator networks do not come at the cost of reduced predictive accuracy.

Runtime comparison (in milliseconds). CounterNet outperforms all baselines in runtime.

Ablation analysis of CounterNet. Each ablation leads to degraded performance, which in turn demonstrates the importance of the different design choices inside CounterNet.

Table 2 shows that CounterNet achieves the second-lowest manifold distance on average (just below VCNet, which explicitly optimizes for the data manifold). In particular, CounterNet achieves the lowest manifold distance on OULAD, and ranks second on Credit and HELOC. This result shows that CounterNet generates highly realistic CF examples that adhere to the data manifold.

Running time. Table 3 shows the average runtime (in milliseconds) of different methods to generate a CF example for a single data point. CounterNet outperforms all seven baselines on every dataset. In particular, CounterNet generates CF examples ∼3X faster than VAE-CF, CounteRGAN, and VCNet, ∼5X faster than C-CHVAE, and three orders of magnitude (>1000X) faster than the other baselines. This result makes CounterNet more suitable for adoption in time-constrained environments.

Impact of the immutable feature constraints in CounterNet. CounterNet generates feasible CF explanations without sacrificing validity and proximity.

Summary of Datasets used for Evaluation

Hyperparameters and architectures for each dataset.

Learning rate of the base predictive models on each dataset.

Training time of the base model and CounterNet per epoch (in seconds).

CounterNet can be adapted to the post-hoc black-box setting, in which CF explanation methods generate explanations for a trained black-box model (with access to the model's output). CounterNet can be used in this setting by forcing the predictor network to act as a surrogate of the black-box model. Specifically, let M : X → Y be a black-box model that outputs predictions; our goal is to train the predictor network to behave like the black-box model (i.e., M(x) = f(x)). The training objective of CounterNet then becomes Eq. 6.

Evaluation of CounterNet under the post-hoc setting. CFNET-BB denotes CounterNet evaluated under the black-box setting; CFNET-PH denotes CounterNet trained in a post-hoc fashion. Both comparisons demonstrate the importance of CounterNet's joint-training procedure.

Ablation analysis of CounterNet. Each ablation leads to degraded performance, which in turn demonstrates the importance of the different design choices inside CounterNet.

This table compares the validity, average L1, and sparsity achieved by CounterNet and the baselines. Similar to the results in Table 2, CounterNet achieves perfect validity. In addition, CounterNet achieves the lowest proximity on three of the four small datasets. This result further demonstrates CounterNet's ability to balance the cost-invalidity trade-off.

Evaluation of counterfactual explanations on four small-sized datasets.

Table 13 compares the second-order validity of CF examples generated by CounterNet and the other baselines on all six datasets.

Evaluation of Usability of Counterfactual Explanations

Results for CF explanation methods on the Forest Cover Type dataset.

We further study the impact of different neural network building blocks. In our experiments, we primarily use multi-layer perceptrons (MLPs), a suitable baseline for tabular data. For comparison, we also implement CounterNet with convolutional building blocks (i.e., replacing the feed-forward layers with convolutional layers) on the Adult dataset. To train this variant, we set the learning rate to 0.03 and λ1 = 1.0, λ2 = 0.4, λ3 = 0.01; the rest of the configuration is identical to training CounterNet with the MLP. Table 15 compares CounterNet with convolutional building blocks (CounterNet-Conv) against the MLP version (CounterNet-MLP). The results indicate that CounterNet-Conv matches the performance of CounterNet-MLP; in fact, it performs slightly worse, because convolutional blocks are not well-suited to tabular datasets. Yet, CounterNet-Conv outperforms all of our post-hoc baselines in validity (with a reasonably good proximity score). This illustrates CounterNet's potential for real-world usage in various settings, as it is agnostic to the network structure.

Results for CounterNet with convolutional layers on the Adult dataset.

CounterNet is designed to generate counterfactual explanations for tabular datasets (the most common use case for CF explanations). We also experiment with CounterNet on image datasets. This experiment uses the MNIST dataset: class "7" is used as the positive label, and class "1" as the negative label. We then apply the same CounterNet training procedure to generate image counterfactuals. Table 16 reports the results: CounterNet achieves 52.4% validity with 0.059 average L1 distance. This result reveals a current limitation: applying CounterNet as-is is ill-suited to generating image counterfactual explanations.

CounterNet on the Image Datasets.

G COUNTERNET UNDER THE MULTI-CLASS SETTINGS

In the prior CF explanation literature, counterfactual explanations are primarily evaluated in binary classification settings (Mothilal et al., 2020; Mahajan et al., 2019; Upadhyay et al., 2021). However, it is worth noting that CF explanation methods (including CounterNet) can be adapted to multi-class classification settings. This section first describes the problem setting for CF explanations in multi-class classification. Next, we describe how to train CounterNet for multi-class predictions and CF explanations. Finally, we present the evaluation setup and the results.

G.1 TRAINING COUNTERNET FOR MULTI-CLASS CLASSIFICATION

Given an input instance x ∈ R^d, CounterNet aims to generate two outputs: (i) a prediction ŷ_x for input instance x; and (ii) a CF example x′ ∈ R^d as an explanation for x. The prediction ŷ_x is one-hot encoded as ŷ_x ∈ {0, 1}^k with Σ_{i=1}^{k} ŷ_x^{(i)} = 1, where k denotes the number of classes. Moreover, we assume that there is a desired outcome y′ for every input instance x; it is then desirable that the prediction y_{x′} for a CF example equals the desired outcome (i.e., y_{x′} = y′). The objective for CounterNet in the multi-class setting remains the same as in the binary setting: we expect CounterNet to achieve high predictive accuracy, counterfactual validity, and proximity. Accordingly, we adjust the loss functions of Eq. 1 as shown in Eq. 7. As in the binary setting, we optimize the parameters θ of the overall network by solving the minimization problem in Eq. 1 (except that we use the loss functions in Eq. 7). Moreover, we adopt the same block-wise coordinate descent procedure: we first update for predictive accuracy, θ′ = θ − ∇_θ(λ1 · L1), and then update for CF explanation quality, θ = θ′ − ∇_{θ′}(λ2 · L2 + λ3 · L3).
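Since Eq. 7 is referenced but not reproduced in this appendix chunk, the sketch below assumes a natural multi-class extension of the Eq. 1 losses: MSE against one-hot targets for the prediction loss, and a validity loss pulling the CF prediction toward the one-hot desired class y′. All values and names are illustrative.

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax over the last axis
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def multiclass_losses(y_hat, y_onehot, y_hat_cf, y_desired_onehot, x, x_cf):
    l1 = np.mean((y_hat - y_onehot) ** 2)             # prediction loss (L1)
    l2 = np.mean((y_hat_cf - y_desired_onehot) ** 2)  # validity loss (L2): hit class y'
    l3 = np.mean((x - x_cf) ** 2)                     # proximity loss (L3)
    return l1, l2, l3

y_hat    = softmax(np.array([2.0, 0.5, 0.1]))  # predictor output over k=3 classes
y_onehot = np.array([1.0, 0.0, 0.0])           # true class
y_hat_cf = softmax(np.array([0.2, 2.5, 0.1]))  # prediction on the CF example
y_prime  = np.array([0.0, 1.0, 0.0])           # one-hot desired class y'
x, x_cf  = np.array([1.0, 2.0]), np.array([1.2, 1.9])
l1, l2, l3 = multiclass_losses(y_hat, y_onehot, y_hat_cf, y_prime, x, x_cf)
# Block-wise coordinate descent then applies two gradient stages per batch:
#   theta' = theta  - grad_theta(lam1 * l1)
#   theta  = theta' - grad_theta'(lam2 * l2 + lam3 * l3)
```

The two-stage update at the end mirrors the block-wise coordinate descent procedure described above; in a real implementation each stage would be a separate backward pass through the network.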

G.2 EXPERIMENTAL EVALUATION

Dataset. We use the Cover Type dataset (Blackard, 1998) for the multi-class classification experiment. The task is to predict the forest cover type from cartographic variables. The dataset contains seven classes (Y=1, Y=2, ..., Y=7) and 10 continuous features. For CF explanation generation, we assume that cover type 5 (Y=5) is the desired class. The original dataset is highly imbalanced, so we sample data instances equally from each class.

Results.

