LEARNING AXIS-ALIGNED DECISION TREES WITH GRADIENT DESCENT

Abstract

Decision trees are commonly used for many machine learning tasks due to their high interpretability. However, learning a decision tree from data is a difficult optimization problem, since it is non-convex and non-differentiable. Therefore, common approaches learn decision trees using a greedy growth algorithm that minimizes the impurity at each internal node. Unfortunately, this greedy procedure can lead to suboptimal trees. In this paper, we present a novel approach for learning hard, axis-aligned decision trees with gradient descent. This is achieved by applying backpropagation with a straight-through operator on a dense decision tree representation that jointly optimizes all decision tree parameters. We show that our gradient-based optimization outperforms existing baselines on several binary classification benchmarks and achieves competitive results for multi-class tasks. To the best of our knowledge, this is the first approach that attempts to learn hard, axis-aligned decision trees with gradient descent without restrictions on the tree structure.

1. INTRODUCTION

Decision trees (DTs) are one of the most popular machine learning models and are still frequently used today. Especially with the increasing interest in explainable artificial intelligence (XAI), DTs have regained popularity. However, learning a DT is a difficult optimization problem, since it is non-convex and non-differentiable. Finding an optimal DT for a specific task is an NP-complete problem (Laurent & Rivest, 1976). Therefore, the common approach to construct a DT from data is a greedy procedure (Breiman et al., 1984; Quinlan, 1993) that minimizes the impurity at each internal node. The algorithms that are still used today, namely CART (Breiman et al., 1984) and C4.5 (Quinlan, 1993), date back to the 1980s and have since remained mostly unchanged. Unfortunately, a greedy algorithm only yields locally optimal solutions and can therefore lead to suboptimal trees. We illustrate the issues that can arise when learning DTs with a greedy algorithm in the following example:

Example 1 The Echocardiogram dataset¹ deals with predicting one-year survival of patients after a heart attack based on tabular data from an echocardiogram. Figure 1 shows two decision trees. The left tree is learned using a greedy algorithm (CART) and the right tree is learned using our gradient-based approach. We can observe that the greedy procedure leads to a suboptimal tree with significantly lower performance. Splitting on the wall-motion-score is the locally optimal split (see Figure 1a), but globally, it is beneficial to split on the wall-motion-score with different values conditioned on the pericardial-effusion in the second level (Figure 1b).

The contribution of this paper is a novel approach for learning hard, axis-aligned DTs based on a joint optimization of all parameters with gradient descent, which we call gradient-based decision trees (GDTs). Using a gradient-based optimization, we can overcome the limitations of greedy approaches, as indicated in Figure 1b.
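For contrast with the gradient-based approach, the greedy impurity minimization performed at each internal node can be sketched as follows. This is a minimal illustration using Gini impurity and an exhaustive threshold search; the function names and structure are ours and not taken from any particular CART implementation:

```python
import numpy as np

def gini(y):
    """Gini impurity of a label vector y."""
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def best_split(X, y):
    """Exhaustively search for the (feature, threshold) pair that
    minimizes the weighted Gini impurity of the two child nodes.
    This local criterion is what makes the procedure greedy: it
    ignores how the split affects deeper levels of the tree."""
    n, d = X.shape
    best_feature, best_threshold, best_score = None, None, np.inf
    for j in range(d):
        for t in np.unique(X[:, j]):
            left = X[:, j] <= t
            right = ~left
            if left.sum() == 0 or right.sum() == 0:
                continue
            score = (left.sum() * gini(y[left])
                     + right.sum() * gini(y[right])) / n
            if score < best_score:
                best_feature, best_threshold, best_score = j, t, score
    return best_feature, best_threshold, best_score
```

Applied recursively to each child node, this yields the standard top-down greedy induction; as Example 1 shows, the locally best split found this way need not be part of the globally best tree.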
We propose a suitable dense DT representation that allows a gradient-based optimization of the tree parameters (Section 3.2). We further present an algorithm that deals with the non-differentiable nature of DTs while still allowing an efficient optimization using backpropagation with a straight-through (ST) operator (Section 3.3). We empirically evaluate GDTs on several real-world benchmark datasets (Section 4). GDTs outperform existing baselines on several benchmark datasets and achieve competitive results for DT learning. Furthermore, the resulting trees are less prone to overfitting and provide a more accurate quantification of uncertainty, since they comprise probability distributions at the leaves. A gradient-based optimization also provides more flexibility, since splits can be adjusted during training. This allows an application of GDTs to dynamic scenarios, such as online learning.
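The details of the ST operator are given in Section 3.3; as a rough intuition, a straight-through scheme uses the hard, non-differentiable split decision in the forward pass while gradients are taken from a smooth surrogate. The following is a minimal, hypothetical NumPy sketch of this idea for a single threshold (it uses a sigmoid surrogate and is not the paper's actual implementation):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def st_split(x, tau, steepness=1.0):
    """Straight-through hard split for one feature value x and
    threshold tau.
    Forward:  hard 0/1 routing decision, as used at inference time.
    Backward: gradient of a sigmoid surrogate w.r.t. tau, so the
    threshold remains trainable despite the hard decision."""
    hard = float(x > tau)                         # non-differentiable forward pass
    soft = sigmoid(steepness * (x - tau))         # differentiable surrogate
    grad_tau = -steepness * soft * (1.0 - soft)   # d soft / d tau
    return hard, grad_tau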

2. RELATED WORK

Greedy DT Algorithms The most prominent and still frequently used algorithms, namely CART (Breiman et al., 1984) and C4.5 (Quinlan, 1993) as an extension of ID3 (Quinlan, 1986), date back to the 1980s, and both follow a greedy procedure to learn a DT. Since then, many variations of these algorithms have been proposed, for instance C5.0 (Kuhn et al., 2013) and GUIDE (Loh, 2002; 2009). However, to date, none of these algorithms has been able to consistently outperform CART and C4.5, as shown by Zharmagambetov et al. (2021).

Lookahead DTs To overcome the issues of a greedy DT induction, many researchers have focused on finding an efficient alternative. One way to mitigate the impact of a greedy procedure is to use methods that look ahead during the induction (Sarkar et al., 1994). However, Murthy (1996) argues that those methods not only suffer from an enormous increase in computational complexity, but also from pathology, i.e., they frequently produce worse trees in terms of accuracy, tree size, and depth. One explanation could be that lookahead trees, especially without regularization, are prone to overfitting.

Optimal DTs Optimal DTs try to optimize an objective (e.g., the purity) using approximate brute-force search to find a globally optimal tree with a certain specification (Zharmagambetov et al., 2021). OCT (Bertsimas & Dunn, 2017) formulates the learning task as a mixed-integer optimization (MIO) problem, which is solved using an MIO solver. In contrast, DL8.5 (Aglin et al., 2020) and GOSDT (Lin et al., 2020) approximate a brute-force search using a branch-and-bound algorithm to remove irrelevant parts of the search space. MurTree (Demirović et al., 2022) further uses dynamic programming to reduce the runtime significantly. However, state-of-the-art approaches still require binary data and therefore a discretization of continuous features (Aglin et al., 2020; Demirović et al., 2022; Bertsimas & Dunn, 2017).
Genetic DTs Another way to learn DTs in a non-greedy fashion is to use evolutionary algorithms for DT induction. Evolutionary algorithms perform a robust global search in the space of candidate solutions based on the concept of survival of the fittest (Barros et al., 2011). This usually results in a better identification of feature interactions compared to a greedy, local search (Freitas, 2002).

Oblique DTs In contrast to vanilla DTs that make hard decisions at each internal node, many approaches to hierarchical mixtures of experts (Jordan & Jacobs, 1994) have been proposed. They make soft decisions, where each branch is associated with a probability (Irsoy et al., 2012; Frosst & Hinton, 2017). The resulting models do not comprise axis-aligned splits, but are oblique.



¹ Available under: https://archive.ics.uci.edu/ml/datasets/echocardiogram (last accessed 16.08.2022)



Figure 1: Greedy versus Gradient-Based DT. Two DTs learned on the Echocardiogram dataset. While the CART DT (left) only makes locally optimal splits, the GDT (right) performs a joint optimization of all parameters, which results in a significantly higher performance.

