LEARNING AXIS-ALIGNED DECISION TREES WITH GRADIENT DESCENT

Abstract

Decision Trees are commonly used for many machine learning tasks due to their high interpretability. However, learning a decision tree from data is a difficult optimization problem, since it is non-convex and non-differentiable. Therefore, common approaches learn decision trees using a greedy growth algorithm that minimizes the impurity at each internal node. Unfortunately, this greedy procedure can lead to suboptimal trees. In this paper, we present a novel approach for learning hard, axis-aligned decision trees with gradient descent. This is achieved by applying backpropagation with a straight-through operator on a dense decision tree representation that jointly optimizes all decision tree parameters. We show that our gradient-based optimization outperforms existing baselines on several binary classification benchmarks and achieves competitive results for multi-class tasks. To the best of our knowledge, this is the first approach that attempts to learn hard, axis-aligned decision trees with gradient descent without restrictions regarding the structure.

1. INTRODUCTION

Decision trees (DTs) are one of the most popular machine learning models and are still frequently used today. Especially with the increasing interest in explainable artificial intelligence (XAI), DTs have regained popularity. However, learning a DT is a difficult optimization problem, since it is non-convex and non-differentiable. Finding an optimal DT for a specific task is an NP-complete problem (Laurent & Rivest, 1976). Therefore, the common approach to construct a DT based on data is a greedy procedure (Breiman et al., 1984; Quinlan, 1993) that minimizes the impurity at each internal node. The algorithms that are still used today, namely CART (Breiman et al., 1984) and C4.5 (Quinlan, 1993), date back to the 1980s and have since remained mostly unchanged. Unfortunately, a greedy algorithm only yields locally optimal solutions and therefore can lead to suboptimal trees. We illustrate the issues that can arise when learning DTs with a greedy algorithm in the following example:

Example 1. The Echocardiogram dataset[1] deals with predicting one-year survival of patients after a heart attack based on tabular data from an echocardiogram. Figure 1 shows two decision trees. The left tree is learned using a greedy algorithm (CART) and the right tree is learned using our gradient-based approach. We can observe that the greedy procedure leads to a suboptimal tree with a significantly lower performance. Splitting on the wall-motion-score is the locally optimal split (see Figure 1a), but globally, it is beneficial to split based on the wall-motion-score with different values conditioned on the pericardial-effusion in the second level (Figure 1b).

The contribution of this paper is a novel approach for learning hard, axis-aligned DTs based on a joint optimization of all parameters with gradient descent, which we call gradient-based decision trees (GDTs). Using a gradient-based optimization, we can overcome the limitations of greedy approaches, as indicated in Figure 1b.
We propose a suitable dense DT representation that allows a gradient-based optimization of the tree parameters (Section 3.2). We further present an algorithm that allows us to deal with the non-differentiable nature of DTs, while still allowing an efficient optimization using backpropagation with a straight-through (ST) operator (Section 3.3).

2. RELATED WORK

A common relaxation of DTs is to use soft instead of hard splits, or oblique splits that are not aligned with respect to the axis. These adjustments in the tree architecture allow the application of further optimization algorithms, as for instance the expectation-maximization (EM) algorithm for maximum likelihood estimation (Dempster et al., 1977; Jordan & Jacobs, 1994) or gradient descent (Irsoy et al., 2012; Frosst & Hinton, 2017). Blanquero et al. (2020) try to increase the interpretability of oblique trees by optimizing for sparse splits, using fewer predictor variables at each split and simultaneously fewer splits along the whole tree. Tanno et al. (2019) combine the benefits of neural networks and DTs using so-called adaptive neural trees (ANTs) that allow a gradient-based end-to-end optimization. They employ a stochastic routing based on a Bernoulli distribution whose mean is a learned parameter. Furthermore, ANTs utilize one or more non-linear transformer modules at the edges, making the resulting trees oblique. Norouzi et al. (2015) proposed an approach for efficient non-greedy optimization of DTs that overcomes the need for soft decisions to apply gradient-based algorithms. This includes a joint optimization of the split functions at all levels and the leaf parameters by minimizing a convex-concave upper bound on the tree's empirical loss. While this allows the use of hard splits, the approach is still limited to oblique trees. Zantedeschi et al. (2021) use argmin differentiation to simultaneously learn all tree parameters by relaxing a mixed-integer program for the discrete parameters to allow a gradient-based optimization. This allows a deterministic tree routing, which is similar to GDTs.
However, they require a differentiable splitting function (e.g., a linear split function, which results in oblique trees). In contrast to oblique DTs, GDTs make axis-aligned splits that only consider a single feature at each split, which provides a significantly higher interpretability for individual splits. This is supported by Molnar (2020), where the author argues that humans cannot comprehend explanations involving more than three dimensions at once. Yang et al. (2018) propose DNDTs, which realize tree models as neural networks, utilizing a soft binning function for splitting. The resulting trees are therefore soft[2], but axis-aligned, which makes this work closely related to our approach. Since the tree is generated via a Kronecker product of the binning layers, the structure depends on the number of features (and the number of bins). As the authors discuss, this results in poor scalability with the number of features, which can currently only be addressed by using random forests for high-dimensional datasets (≥ 12 features). Our approach, in contrast, scales linearly with the number of features, making it applicable to high-dimensional datasets as well.

3. GRADIENT-BASED AXIS-ALIGNED DECISION TREES

In this section, we propose gradient-based decision trees (GDTs) as a novel approach to learn DTs. Therefore, we introduce a novel DT representation and a corresponding algorithm that allows learning hard, axis-aligned DTs with gradient descent. More specifically, we will use backpropagation with a straight-through (ST) operator (Section 3.3) on a dense DT representation (Section 3.2) to adjust the model parameters during the training, as we will explain in Section 3.4.

3.1. ARITHMETIC DECISION TREE REPRESENTATION

In the following, we will introduce a notation for DTs with respect to their parameters. We formulate DTs as an arithmetic function based on addition and multiplication instead of a nested concatenation of rules. This formulation will be useful for describing the gradient-based learning in Sections 3.2-3.4. Note that we focus on learning fully-grown DTs within this paper. The parameters of a DT of depth d comprise one split threshold and one feature index for each internal node, which we denote as vectors τ ∈ ℝ^{2^d−1} and ι ∈ ℕ^{2^d−1}, respectively, where 2^d − 1 equals the number of internal nodes in a fully-grown DT. Additionally, each leaf node comprises a class membership in the case of a classification task, or a value in the case of a regression task, which we denote as the vector λ ∈ C^{2^d}, where C is the set of classes and 2^d equals the number of leaf nodes in a fully-grown DT. Figure 2 shows our tree representation in comparison to a common tree representation for an exemplary DT. Formally, we can express a DT as a function DT(·|τ, ι, λ): ℝ^n → C with respect to its parameters:

DT(x|τ, ι, λ) = Σ_{l=0}^{2^d−1} λ_l · L(x|l, τ, ι)    (1)

L is a function that indicates whether a sample x ∈ ℝ^n belongs to a leaf l. The indicator function L can be defined as a multiplication of the split functions of the preceding internal nodes. We define a split function S as a Heaviside step function:

S_Heaviside(x|ι, τ) = 1 if x_ι ≥ τ, and 0 otherwise    (2)

where ι is the index of the feature considered at a certain split and τ is the corresponding split threshold. Enumerating the internal nodes of a fully-grown tree of depth d in breadth-first order, we can now define the indicator function L as:

L(x|l, τ, ι) = Π_{j=1}^{d} [(1 − p(l, d, j)) · S(x|τ_{i(l,d,j)}, ι_{i(l,d,j)}) + p(l, d, j) · (1 − S(x|τ_{i(l,d,j)}, ι_{i(l,d,j)}))]    (3)

Here, i is the index of the internal node preceding a specific leaf node l at a certain depth j.
Further, p defines whether the left branch (p = 0) or the right branch (p = 1) was taken at a certain internal node to reach a leaf node l. Both values, i and p, are constant for a certain architecture and can be calculated straightforwardly (see Appendix A.1). As becomes evident, DTs involve non-differentiable operations (Equation 2), which precludes the application of the backpropagation algorithm for learning the parameters. More specifically, there are three challenges we need to solve if we want to use backpropagation to efficiently learn a DT: C1 The index ι for the selection of the considered feature in a certain split is defined as ι ∈ ℕ. However, the index ι is a parameter of the DT, and a standard optimization with gradient descent requires ι ∈ ℝ.

C2

The split function S(x, ι, τ) is a Heaviside step function with an undefined gradient for x_ι = τ and a gradient of 0 elsewhere, which precludes an efficient optimization.

C3

The leaf nodes in a vanilla DT comprise only the class membership and therefore λ ∈ C. To optimize the leaf node parameters, we need λ ∈ ℝ to apply gradient descent and to calculate an informative loss value as a cost function.

The main difference of GDTs to a standard DT algorithm is that we use a dense representation for the internal node parameters (τ and ι) to jointly optimize the split thresholds and the selection of the corresponding features for all internal nodes (Section 3.2). Furthermore, we use a ST operator during backpropagation to allow the use of hard, axis-aligned DTs during training (Section 3.3).
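To make the arithmetic formulation of Section 3.1 concrete, the following plain-Python sketch (helper names are ours) evaluates Equation 1, with i and p computed as in Appendix A.1. The parameters are those of the exemplary DT in Figure 2:

```python
def i_idx(l, d, j):
    # Index of the internal node on the path to leaf l at depth j (Equation 6)
    return 2 ** (j - 1) + l // 2 ** (d - (j - 1)) - 1

def p_dir(l, d, j):
    # Branch indicator (0 = left, 1 = right) on the path to leaf l at depth j (Equation 7)
    return (l // 2 ** (d - j)) % 2

def split(x, feature_idx, threshold):
    # Heaviside split function (Equation 2)
    return 1.0 if x[feature_idx] >= threshold else 0.0

def dt_predict(x, tau, iota, lam, d):
    # Arithmetic DT evaluation (Equations 1 and 3): sum over all leaves of
    # the leaf value times the product of split indicators along the path
    y = 0.0
    for l in range(2 ** d):
        path = 1.0
        for j in range(1, d + 1):
            i = i_idx(l, d, j)
            p = p_dir(l, d, j)
            s = split(x, iota[i], tau[i])
            path *= (1 - p) * s + p * (1 - s)
        y += lam[l] * path
    return y

# Tree from Figure 2: depth 2, 3 features, classes {0, 1}
tau = [2.0, -1.2, 0.9]
iota = [0, 2, 1]
lam = [0, 0, 0, 1]
```

For a sample with x_0 < 2.0 and x_1 < 0.9, exactly one leaf indicator is non-zero and the tree returns the corresponding leaf value from λ.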

3.2. DENSE DECISION TREE REPRESENTATION

In this subsection, we propose a differentiable representation of the feature indices ι to allow a gradient-based optimization. Therefore, we extend the vector ι ∈ ℝ^{2^d−1} to a matrix I ∈ ℝ^{(2^d−1)×n}. This is achieved by one-hot encoding the feature index as ι ∈ ℝ^n for each internal node. This adjustment is necessary for the optimization process to account for the fact that feature indices are categorical instead of ordinal. We further propose a similar representation for the split thresholds, T ∈ ℝ^{(2^d−1)×n}, storing one value for each feature as τ ∈ ℝ^n instead of a single value. This adjustment is designed to support the optimization procedure, since split thresholds are not interchangeable between features: a split threshold that is reasonable for one feature might not be reasonable for another. By storing one split threshold for each feature, we support exploration during the optimization. This representation is inspired by the DT representation of Marton et al. (2022) used for predicting a DT as a numeric output of a neural network. Further, our matrix representation for the feature selection is similar to the one proposed by Popov et al. (2019), but additionally includes a matrix representation for the split thresholds. Besides the previously mentioned advantages, using a dense DT representation allows the use of matrix multiplications for an efficient computation, as we will show in the following. Therefore, we reformulate the Heaviside split function in Equation 2 as follows:

S_logistic(x, ι, τ) = σ(Σ_{i=1}^{n} ι_i x_i − Σ_{i=1}^{n} ι_i τ_i)    (4)

S_logistic-hard(x, ι, τ) = ⌊S_logistic(x, ι, τ)⌉    (5)

where ⌊·⌉ stands for rounding to the closest integer. In our case, where ι is a one-hot encoded vector for the feature index, S_logistic-hard(x, ι, τ) = S_Heaviside(x, ι, τ) holds.
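A minimal NumPy sketch of Equations 4 and 5 (function names are ours) illustrates that, with a one-hot ι, the rounded logistic split coincides with the Heaviside split (up to the measure-zero case x_ι = τ):

```python
import numpy as np

def s_heaviside(x, feature_idx, threshold):
    # Equation 2: hard, axis-aligned split on a single feature
    return 1.0 if x[feature_idx] >= threshold else 0.0

def s_logistic_hard(x, iota, tau):
    # Equations 4-5: sigmoid of the difference between the selected feature
    # value and the selected threshold, rounded to {0, 1}
    soft = 1.0 / (1.0 + np.exp(-(iota @ x - iota @ tau)))
    return np.round(soft)

rng = np.random.default_rng(0)
n = 4
iota = np.eye(n)[2]        # one-hot selection of feature 2
tau = rng.normal(size=n)   # one threshold stored per feature
for _ in range(100):
    x = rng.normal(size=n)
    assert s_logistic_hard(x, iota, tau) == s_heaviside(x, 2, tau[2])
```

The dot products ι·x and ι·τ simply pick out the selected feature and its threshold, which is what makes the dense matrix formulation equivalent to the axis-aligned split.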

3.3. BACKPROPAGATION OF DECISION TREE LOSS

While the dense DT representation introduced in the previous section emphasizes an efficient learning of axis-aligned DTs, it does not solve C1-C3. Therefore, in this subsection, we explain how we solve those challenges by using a ST operator during backpropagation when optimizing with gradient descent. For the function value calculation in the forward pass, we need to assure that ι is a one-hot encoded vector. This can be achieved by applying a hardmax function on the feature index vector of each internal node. However, applying a hardmax is a non-differentiable operation, which precludes gradient computation. To overcome this issue, we use the ST operator (Bengio et al., 2013): for the forward pass, we apply the hardmax as is; for the backward pass, however, we exclude this operation and directly propagate back the gradients of ι. Accordingly, we can optimize the parameters of ι with ι ∈ ℝ while still using axis-aligned splits during training (C1). Similarly, we use the ST operator to assure hard splits (Equation 5) by excluding ⌊·⌉ from the backward pass (C2). Using the logistic (sigmoid) function before applying the ST operator (see Equation 4) utilizes the distance to the split threshold as additional information for the gradient calculation: if the feature considered at an internal node for a specific sample is close to the split threshold, this results in smaller gradients compared to a sample that is more distant to the split threshold. Furthermore, we need to adjust the leaf nodes of the DT to allow an efficient loss calculation (C3). Vanilla DTs comprise the predicted class for each leaf node and are functions DT: ℝ^n → C. We use a probability distribution at each leaf node and therefore define DTs as a function DT: ℝ^n → ℝ^c.
Accordingly, the parameters of the leaf nodes are defined as L ∈ ℝ^{2^d × c} for the whole tree and λ ∈ ℝ^c for a specific leaf node. This adjustment allows the application of standard loss functions, as for instance the cross-entropy, during the optimization. Figure 3 visualizes our dense tree representation in comparison to the standard tree representation for an exemplary DT.
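The forward-pass behavior of the ST operator at a single internal node can be sketched as follows (NumPy, our own naming; in an autodiff framework, both hard operations would typically be written as `hard = soft + stop_gradient(hard - soft)`, so the forward value is hard while gradients flow through the soft surrogate):

```python
import numpy as np

def hardmax(logits):
    # One-hot vector with a one at the argmax position
    one_hot = np.zeros_like(logits)
    one_hot[np.argmax(logits)] = 1.0
    return one_hot

def st_split_forward(feature_logits, tau_row, x):
    # Forward pass of one internal node with the ST operator applied twice:
    # hardmax for the feature selection (C1) and rounding for the hard
    # split (C2). The backward pass would skip both hard operations and
    # propagate gradients through the soft sigmoid value instead.
    iota = hardmax(feature_logits)                              # hard feature choice
    soft = 1.0 / (1.0 + np.exp(-(iota @ x - iota @ tau_row)))   # Equation 4
    return np.round(soft), soft                                 # hard split, soft surrogate

hard, soft = st_split_forward(np.array([0.2, 1.7, -0.5]),
                              np.array([0.0, 0.9, 0.0]),
                              np.array([5.0, 1.0, -2.0]))
# feature 1 is selected; x_1 = 1.0 >= 0.9, so the hard split is 1, while the
# soft value sigma(0.1) stays close to 0.5, reflecting proximity to the threshold
```

Because the soft value depends on the distance x_ι − τ_ι, samples near the threshold induce smaller gradients than distant ones, as described above.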

3.4. TRAINING PROCEDURE

In the previous subsections, we introduced the adjustments that are necessary to apply a gradient-based optimization for DTs. We implement this optimization using a gradient descent algorithm. During the gradient descent optimization, we calculate the gradients using backpropagation (see Algorithm 2). We implement backpropagation using automatic differentiation based on the computation graph of the tree pass function, which is used to calculate the function values. The tree pass function is summarized in Algorithm 1 and utilizes the adjustments introduced in the previous sections. Furthermore, our tree routing allows calculating the tree pass function as a joint matrix operation over all tree nodes and all samples. We also want to note that the dense representation, which is necessary during training, can be converted into an equivalent vanilla DT representation at any point in time. Furthermore, our implementation optimizes the gradient descent algorithm by exploiting common stochastic gradient descent techniques, including mini-batch calculation and momentum using the Adam optimizer (Kingma & Ba, 2014). Moreover, we implement an early stopping procedure and, to avoid bad initial parametrizations, we additionally implement random restarts, where the best model is selected based on the validation loss.
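The early stopping and random restart logic can be sketched as follows (plain Python; `train_once`, the patience value, and the number of restarts are illustrative assumptions, not the exact implementation):

```python
import math

def early_stopping(val_losses, patience=25):
    # Stop once the validation loss has not improved for `patience` epochs;
    # returns the epoch index of the best model seen so far.
    best_epoch, best, wait = 0, math.inf, 0
    for epoch, v in enumerate(val_losses):
        if v < best:
            best_epoch, best, wait = epoch, v, 0
        else:
            wait += 1
            if wait > patience:
                break
    return best_epoch

def train_with_restarts(train_once, n_restarts=3):
    # Random restarts: train several independently initialized models and
    # keep the one with the lowest validation loss. `train_once` is a
    # hypothetical callable returning (model, best_val_loss).
    best_model, best_loss = None, math.inf
    for _ in range(n_restarts):
        model, val_loss = train_once()
        if val_loss < best_loss:
            best_model, best_loss = model, val_loss
    return best_model, best_loss
```

Selecting over restarts by validation loss (rather than training loss) keeps the procedure consistent with the early stopping criterion.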

4. EXPERIMENTAL EVALUATION

The goal of our experiments is to evaluate the predictive performance of GDTs against existing approaches to learning DTs. We focus on axis-aligned DTs and compare GDTs to the following baseline algorithms:
• CART: We use the sklearn (Pedregosa et al., 2011) implementation, which uses an optimized version of the CART algorithm. While CART usually only uses the Gini impurity measure, we also allow information gain/entropy as an option during our hyperparameter optimization (HPO). Since the impurity measure is also the relevant difference between CART and C4.5, we decided not to use both algorithms as benchmarks, but only the optimized CART algorithm, which is also the stronger benchmark to compare against.
• Evolutionary DTs: We use the GeneticTree (Pysiak, 2021) implementation for evolutionary DTs, which implements an efficient learning of DTs using a genetic algorithm.
• DNDT (Yang et al., 2018): We use the official implementation (Yang et al., 2022) for learning DTs with gradient descent.
• DL8.5 (Aglin et al., 2020): We use the official implementation (Aglin et al., 2022) for learning optimal DTs, which includes improvements from MurTree (Demirović et al., 2022) that reduce the runtime significantly. As required, we discretized the data for DL8.5 as described in Appendix A.4.
GDTs are implemented in Python using TensorFlow (Abadi et al., 2015)[3]. The experiments were conducted on several benchmark datasets, mainly from the UCI Machine Learning Repository (Dua & Graff, 2017). The dataset specifications and sources are listed in Table 7.
We use a random 80%/20% train/test split for all datasets. Since GDTs and DNDTs require a validation set for early stopping, we performed another 80%/20% split on the training data. The remaining approaches utilize the complete training data. Further information on the hyperparameters is given in Appendix A.5.
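The two-stage splitting procedure can be sketched as follows (NumPy; the function name and seed handling are ours):

```python
import numpy as np

def split_indices(n, test_frac=0.2, val_frac=0.2, seed=0):
    # 80%/20% train/test split; a further 80%/20% split of the training
    # part yields the validation set used for early stopping.
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n)
    n_test = int(n * test_frac)
    test, train = perm[:n_test], perm[n_test:]
    n_val = int(len(train) * val_frac)
    val, train = train[:n_val], train[n_val:]
    return train, val, test
```

For 100 samples this yields 64 training, 16 validation, and 20 test samples; approaches without early stopping would use the 80 train+validation samples for training.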

4.1. RESULTS

GDTs are competitive to baseline DT learners. First, we evaluated the performance of GDTs against the baseline approaches on the mentioned benchmark datasets in terms of the F1-score. We report the mean reciprocal rank (MRR), similar to Yang et al. (2018). The results are shown in Table 1. Overall, GDTs outperformed state-of-the-art non-greedy DT approaches for binary classification tasks (MRR of 0.648 for GDT vs. 0.454 for GeneticTree as the best non-greedy benchmark) and achieved competitive results for multi-class tasks (0.512 for GDT vs. 0.490 for DNDT). Compared to a greedy DT trained by improved CART, GDTs achieved a slightly higher performance on binary classification tasks (0.648 vs. 0.623) but underperformed on multi-class tasks (0.512 vs. 0.685). Considering the stdev of the reciprocal rank, we can observe that GDTs (0.275 and 0.213) are more robust than CART (0.323 and 0.282). Further, we can observe that for several datasets a greedy optimization achieved very poor results and was significantly outperformed by a non-greedy optimization, for instance on the Wisconsin Diagnostic Breast Cancer, Heart Disease, and Lymphography datasets. In general, it stands out that GDTs achieved the best results for binary classification datasets with lower dimensionality (n ≤ 15). We can explain this by the dense DT representation required for the gradient-based optimization: using our representation, the difficulty of the optimization task increases with the number of features (more parameters at each internal node) and the number of classes (more parameters at each leaf node). Therefore, in future work, we aim to optimize the proposed dense representation, e.g., by using parameter-sharing techniques.

Table 1: Performance Comparison. We report F1-scores (mean ± stdev over 10 trials). We also report the ranking of each approach in brackets. The top part comprises binary classification tasks and the bottom part multi-class datasets. The datasets are sorted by the number of features. Due to scalability issues, DNDTs were only applied to datasets with a maximum of 12 features, as suggested by Yang et al. (2018).

GDTs are robust to overfitting. We can observe that GDTs were more robust and less prone to overfitting than a greedy optimization. We measure overfitting by the difference between the mean train and test performance (see Table 2). This difference was significantly smaller for GDTs (0.026 for binary and 0.038 for multi-class) compared to CART (0.042 and 0.085). An even higher overfitting can be observed for DNDT (0.081 and 0.062) and DL8.5 (0.069 and 0.059). Overfitting can also explain why optimal decision trees do not achieve superior results on test data, which was also reported by Zharmagambetov et al. (2021). Similar to GDTs, GeneticTrees were less prone to overfitting (0.015 and 0.021).

GDTs do not rely on extensive HPO. An advantage of DTs over more sophisticated models, besides their interpretability, is that greedy approaches are not reliant on an extensive HPO to achieve reasonable results. In this experiment, we show that the same holds for GDTs. We used the default parameter settings for each model and fixed the maximum depth to 5 for GDTs and CART. Overall, the performance gain of HPO was rather small for all approaches: the maximum average performance gain for binary tasks was observed for CART, with an increase of only 0.007 (see Table 2). For multi-class tasks, the average performance gain was higher, with a maximum of 0.049 for GDTs.
However, this is mainly due to a small number of datasets where HPO increased the performance significantly, for instance the Segment dataset with a performance increase of 0.111 for GDTs and 0.171 for CART (see Table 3 for details).

Table 2: Summarized Results. (1) The mean difference between the train and test performance of each approach as an overfitting indicator (see Table 6 for complete results) and (2) the impact of HPO, measured as the difference of the model performance with and without HPO for each approach (see Table 3 for complete results).

Probabilities at the GDT leaf nodes provide calibrated uncertainties. Another advantage of GDTs (and similarly DNDTs) is the fact that they provide probabilities as a direct measure of the model's confidence. We compared the probability-based confidence with the purity measure[4] in terms of the ROC AUC score (see Table 5). Comparing the probability-based approaches, GDTs achieved a significantly higher MRR based on the ROC AUC score (0.818 for binary and 0.789 for multi-class tasks) compared to DNDT (0.628 and 0.655). Both probability-based approaches significantly outperformed approaches that learn vanilla DTs, among which CART achieved the highest MRR (0.460 and 0.458).

GDTs can be learned efficiently for large and high-dimensional datasets. Considering the runtime of the different approaches listed in Table 4, it stands out that a greedy optimization was significantly faster than any other approach for each dataset (perfect MRR of 1.00 based on the runtime). Nevertheless, for most datasets, learning a GDT took less than 60 seconds, with a maximum of 122 seconds for the Congressional Voting dataset. DNDTs achieved similar runtimes compared to GDTs. This is also reflected by the MRR, where GDTs achieved 0.254 for binary and 0.257 for multi-class tasks and DNDTs achieved 0.228 and 0.221, respectively.
DL8.5 has a low runtime (< 10 seconds) for most datasets; however, we can observe scalability issues with the number of features and the number of samples, as it requires a significantly higher runtime for certain datasets (e.g., > 330 seconds for Credit Card and > 1800 seconds for Splice).
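For reference, the MRR aggregation reported above follows the standard definition and can be sketched as:

```python
def mean_reciprocal_rank(ranks):
    # MRR over per-dataset ranks (rank 1 = best approach on that dataset);
    # higher is better, with 1.0 meaning an approach ranked first everywhere.
    return sum(1.0 / r for r in ranks) / len(ranks)

# e.g., ranks 1, 2 and 4 on three hypothetical datasets:
mrr = mean_reciprocal_rank([1, 2, 4])  # (1 + 0.5 + 0.25) / 3
```

Unlike the mean rank, the MRR rewards first places disproportionately, which is why it is a common aggregate when comparing learners across many datasets.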

5. CONCLUSION AND FUTURE WORK

In this paper, we proposed a method for learning hard, axis-aligned DTs based on a joint optimization of all parameters with gradient descent. We used backpropagation with a ST operator to deal with the non-differentiable nature of DTs and introduced a dense DT representation that allows an efficient optimization. Using a gradient-based optimization, GDTs are not prone to locally optimal solutions, as is the case for standard, greedy DT induction methods like CART. We empirically showed that GDTs outperform existing baseline methods on several benchmark datasets. Additionally, GDTs provide calibrated uncertainties in terms of probability distributions at the leaf nodes, which also increases the interpretability of the model. Furthermore, a gradient-based optimization provides more flexibility: it is straightforward to use a custom loss function that is well-suited to the specific application scenario. Another advantage of a gradient-based optimization is the possibility to relearn the threshold value as well as the split index. This allows the application of GDTs to dynamic environments, as for instance online learning scenarios, without adjustments. Currently, GDTs are fully-grown. In future work, we want to apply pruning mechanisms to reduce the tree size, for instance through a learnable choice parameter that decides whether a node is pruned, similar to Zantedeschi et al. (2021). While we focused on stand-alone DTs to generate intrinsically interpretable models within this paper, GDTs can easily be extended to random forests as a performance-interpretability trade-off, which is subject to future work.

A APPENDIX

A.1 CALCULATIONS

The index i of an internal node preceding a specific leaf node l at a certain depth j can be calculated as follows:

i(l, d, j) = 2^{j−1} + ⌊l / 2^{d−(j−1)}⌋ − 1    (6)

Additionally, for a certain leaf node l, p defines whether the left branch (p = 0) or the right branch (p = 1) was taken at the internal node i. We can calculate p as follows:

p(l, d, j) = ⌊l / 2^{d−j}⌋ mod 2    (7)

The calculation of the internal node index i as well as the specification of the path position p also involve non-differentiable operations. However, since we only focus on fully-grown trees, the resulting values are constant and can be calculated independently of the optimization, which does not preclude the application of a gradient-based optimization algorithm.

A.2 GRADIENT DESCENT OPTIMIZATION

We use stochastic gradient descent (SGD) to minimize the loss function of GDTs, as outlined in Algorithm 2. We use backpropagation to calculate the gradients in Lines 11-13. Furthermore, our implementation optimizes Algorithm 2 by exploiting common SGD techniques, including mini-batch calculation and momentum using the Adam optimizer (Kingma & Ba, 2014). We also formulate Lines 6-9 as a single matrix multiplication for efficiency reasons.

A.4 DATASETS

The datasets along with their specifications and sources are summarized in Table 7. For all datasets, we performed a standard preprocessing. We applied ordinal encoding to all non-numeric features so that they can be handled by DTs. We further standardized all datasets to zero mean and unit variance. This step, however, is only necessary for our approach and DNDTs: while standard DT learning algorithms are independent of the feature scale, GDTs and DNDTs are sensitive to it due to the gradient updates. For DL8.5, additional preprocessing was necessary since it can only handle binary features. Therefore, we one-hot encoded all nominal and ordinal features and discretized numeric features by one-hot encoding them using quantile binning with 5 bins.

We report the F1-score (mean ± stdev over 10 trials) on the test data. We also report the ranking of each approach for the corresponding dataset in brackets. The top part comprises binary classification tasks and the bottom part multi-class datasets. The datasets are sorted based on the number of features. Due to scalability issues, DNDTs were only applied to datasets with a maximum of 12 features, as suggested by Yang et al. (2018).
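The quantile binarization used for DL8.5 can be sketched as follows (NumPy; the helper name is ours, and np.digitize's default edge handling is one of several reasonable choices):

```python
import numpy as np

def quantile_onehot(column, n_bins=5):
    # Discretize a numeric column into quantile bins and one-hot encode
    # the bin membership, producing n_bins binary features per column.
    edges = np.quantile(column, np.linspace(0, 1, n_bins + 1)[1:-1])
    bins = np.digitize(column, edges)
    return np.eye(n_bins)[bins]

encoded = quantile_onehot(np.arange(10.0))
# 10 samples, 5 bins, exactly one active bin per sample
```

Quantile (rather than equal-width) binning keeps the bins roughly equally populated, which matters for skewed features.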

A.5 HYPERPARAMETERS

In the following, we report the hyperparameters used for each approach. The hyperparameters were selected based on a random search over a predefined parameter range for GDT, CART, and GeneticTree and are summarized in Tables 8-10. All parameters that were considered are noted in the tables. The number of trials was equal for each approach. For GDTs, we did not optimize the batch size or the number of epochs, but used early stopping with a predefined patience. For DL8.5, there are no tunable hyperparameters except the maximum depth. However, the maximum depth strongly impacts the runtime, which is why we fixed it to 4, similar to the maximum depth used in the experiments of the authors (Demirović et al., 2022; Aglin et al., 2020). Running the experiments with a higher depth becomes infeasible for many datasets. In preliminary experiments, we also observed that increasing the depth to 5 results in a decrease in the test performance. For DNDT, the number of cut points is the tunable hyperparameter of the model according to the authors (Yang et al., 2018). However, it has to be restricted to 1 in order to generate binary trees for comparability reasons.
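The random search over a predefined parameter range can be sketched as follows (plain Python; `evaluate`, the parameter space, and the trial count are illustrative assumptions):

```python
import random

def random_search(evaluate, param_ranges, n_trials=30, seed=0):
    # Random search: sample a configuration uniformly from each parameter
    # range, evaluate it, and keep the best. `evaluate` is a hypothetical
    # callable mapping a configuration to a validation score (higher is better).
    rng = random.Random(seed)
    best_cfg, best_score = None, float("-inf")
    for _ in range(n_trials):
        cfg = {name: rng.choice(values) for name, values in param_ranges.items()}
        score = evaluate(cfg)
        if score > best_score:
            best_cfg, best_score = cfg, score
    return best_cfg, best_score
```

Using the same number of trials for every approach, as described above, keeps the HPO budget comparable across learners.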



[1] Available under: https://archive.ics.uci.edu/ml/datasets/echocardiogram (last accessed 16.08.2022)
[2] The authors also propose using ST Gumbel-Softmax as an alternative to generate hard trees, which is similar to our approach.
[3] The code of our implementation is available to all reviewers in the supplementary material to assure anonymity. We will make it publicly available upon acceptance.
[4] Commonly, probabilities are obtained from the purity ratio of each class as a gateway to the model's confidence for vanilla DTs (Pedregosa et al., 2011).



Figure 2: DT Representations. Exemplary DT representations with depth 2 for a dataset comprising 3 variables and 2 classes. The DT on the right can be represented using the vectors ι = [0, 2, 1], τ = [2.0, -1.2, 0.9] and λ = [0, 0, 0, 1].

Figure 3: DT Representations. Exemplary DT representations with depth 2 for a dataset comprising 3 variables and 2 classes. Here, S_lh stands for S_logistic-hard (Equation 5).

Algorithm 1 Tree Pass of a Training Sample
1: function PASS(I, T, L, x)
2:   c₁* ← I_i − hardmax(I_i) for i = 0, …, |I|  ▷ Excluded in the backward pass
3:   I ← I − c₁*  ▷ Excluded in the backward pass
4:   ŷ ← 0
5:   for l = 0, …, 2^d − 1 do
6:     p̂ ← 1
7:     for j = 1, …, d do
8:       i ← 2^{j−1} + ⌊l / 2^{d−(j−1)}⌋ − 1  ▷ Equation 6
9:       p ← ⌊l / 2^{d−j}⌋ mod 2  ▷ Equation 7
10:      s ← σ(Σ_{k=1}^{n} I_{i,k} x_k − Σ_{k=1}^{n} I_{i,k} T_{i,k})  ▷ Equation 4
11:      c₂* ← s − ⌊s⌉  ▷ Excluded in the backward pass
12:      s ← s − c₂*  ▷ Excluded in the backward pass
13:      p̂ ← p̂ · ((1 − p) s + p (1 − s))  ▷ Equation 3
14:    end for
15:    ŷ ← ŷ + L_l · p̂  ▷ Equation 1
16:  end for
17:  return ŷ
18: end function

Algorithm 2 Gradient Descent Training for Decision Trees (excerpt)
1: function TRAINDT(I, T, L, X, y, n, c, d, ξ)


Table 3: Performance Comparison with Default Hyperparameters.

Under review as a conference paper at ICLR 2023

The learning rate and temperature were set as suggested by the authors.

Table 8: GDT Hyperparameters

Table 9: CART Hyperparameters

Table 10: GeneticTree Hyperparameters

(a) Vanilla DT Representation (S = S_Heaviside)

Table 4: Runtime Comparison.

Dataset | GDT | Greedy (CART) | GeneticTree | DNDT | DL8.5
Blood Transfusion | 27.706 ± 9.000 (5) | 0.002 ± 0.000 (1) | 4.243 ± 1.000 (3) | 8.235 ± 4.000 (4) | 0.027 ± 0.000 (2)
Banknote Authentication | 46.912 ± 8.000 (4) | 0.003 ± 0.000 (1) | 1.152 ± 0.000 (3) | 47.474 ± 37.000 (5) | 0.028 ± 0.000 (2)
Titanic | 20.803 ± 10.000 (5) | 0.002 ± 0.000 (1) | 8.056 ± 5.000 (3) | 16.085 ± 4.000 (4) | 0.104 ± 0.000 (2)
Raisins | 42.109 ± 9.000 (5) | 0.003 ± 0.000 (1) | 0.833 ± 0.000 (3) | 22.533 ± 7.000 (4) | 0.231 ± 0.000 (2)
Rice | 36.642 ± 9.000 (4) | 0.008 ± 0.000 (1) | 1.539 ± 0.000 (3) | 71.666 ± 28.000 (5) | 0.498 ± 0.000 (2)
Echocardiogram | 51.105 ± 10.000 (5) | 0.002 ± 0.000 (1) | 0.991 ± 0.000 (3) | 14.556 ± 5.000 (4) | 0.095 ± 0.000 (2)
Wisconsin Diagnostic Breast Cancer | 24.790 ± 7.000 (4) | 0.004 ± 0.000 (1) | 1.987 ± 1.000 (3) | 29.417 ± 12.000 (5) | …

We report runtime in seconds without restarts based on the optimized hyperparameters (mean ± stdev over 10 trials). We also report the ranking of each approach in brackets. The top part comprises binary classification tasks and the bottom part multi-class datasets. The datasets are sorted by the number of features. Due to scalability issues, DNDTs were only applied to datasets with a maximum of 12 features, as suggested by Yang et al. (2018).

We extended the DNDT implementation to use early stopping based on the validation loss, similar to GDTs, to reduce the runtime and prevent overfitting. Further details can be found in the code, which we provide in the supplementary material. Furthermore, we used random restarts only for GDTs. We made this decision because CART is deterministic and therefore does not benefit from additional restarts, and for GeneticTree, increasing the population size would achieve similar results. Therefore, both baseline approaches would suffer from the use of restarts, since less data could be used for the training.
Iris | 0.964 ± 0.016 (2) | 0.964 ± 0.009 (3) | 0.939 ± 0.013 (5) | 0.960 ± 0.014 (4) | 0.996 ± 0.004 (1)
Balance Scale | 0.819 ± 0.025 (2) | 1.000 ± 0.000 (1) | 0.778 ± 0.017 (4) | 0.808 ± 0.018 (3) | 0.768 ± 0.007 (5)
Car | 0.868 ± 0.037 (2) | 0.995 ± 0.001 (1) | 0.720 ± 0.049 (5) | 0.725 ± 0.036 (4) | 0.838 ± 0.007 (3)
Glass | 0.773 ± 0.037 (4) | 0.995 ± 0.006 (1) | 0.593 ± 0.038 (5) | 0.803 ± 0.067 (2) | 0.785 ± 0.011 (3)
Contraceptive | 0.545 ± 0.019 (4) | 0.578 ± 0.008 (2) | 0.509 ± 0.032 (5) | 0.555 ± 0.024 (3) | 0.585 ± 0.004 (1)
Solar Flare | 0.770 ± 0.012 (4) | 0.825 ± 0.007 (1) | 0.766 ± 0.006 (5) | 0.792 ± 0.018 (3) | 0.804 ± 0.006 (2)
Wine | 0.990 ± 0.011 (3) | 1.000 ± 0.000 (1) | 0.947 ± 0.015 (5) | 0.998 ± 0.004 (2) | 0.976 ± 0.005 (4)
Zoo | 0.978 ± 0.011 (3) | 1.000 ± 0.000 (1) | 0.…

