LEARNING AXIS-ALIGNED DECISION TREES WITH GRADIENT DESCENT

Abstract

Decision trees are commonly used for many machine learning tasks due to their high interpretability. However, learning a decision tree from data is a difficult optimization problem, since it is non-convex and non-differentiable. Common approaches therefore learn decision trees with a greedy growth algorithm that minimizes the impurity at each internal node. Unfortunately, this greedy procedure can lead to suboptimal trees. In this paper, we present a novel approach for learning hard, axis-aligned decision trees with gradient descent. This is achieved by applying backpropagation with a straight-through operator on a dense decision tree representation that jointly optimizes all decision tree parameters. We show that our gradient-based optimization outperforms existing baselines on several binary classification benchmarks and achieves competitive results on multi-class tasks. To the best of our knowledge, this is the first approach that learns hard, axis-aligned decision trees with gradient descent without restrictions on the tree structure.

1. INTRODUCTION

Decision trees (DTs) are among the most popular machine learning models and are still frequently used today. Especially with the increasing interest in explainable artificial intelligence (XAI), DTs have regained popularity. However, learning a DT is a difficult optimization problem, since it is non-convex and non-differentiable. Finding an optimal DT for a specific task is an NP-complete problem (Laurent & Rivest, 1976). Therefore, the common approach to construct a DT from data is a greedy procedure (Breiman et al., 1984; Quinlan, 1993) that minimizes the impurity at each internal node. The algorithms still used today, namely CART (Breiman et al., 1984) and C4.5 (Quinlan, 1993), date back to the 1980s and have since remained mostly unchanged. Unfortunately, a greedy algorithm only yields locally optimal solutions and can therefore lead to suboptimal trees. We illustrate the issues that can arise when learning DTs with a greedy algorithm in the following example:

Example 1. The Echocardiogram dataset¹ deals with predicting one-year survival of patients after a heart attack based on tabular data from an echocardiogram. Figure 1 shows two decision trees. The left tree is learned using a greedy algorithm (CART) and the right tree is learned using our gradient-based approach. We can observe that the greedy procedure leads to a suboptimal tree with significantly lower performance. Splitting on the wall-motion-score is the locally optimal split (see Figure 1a), but globally, it is beneficial to split on the wall-motion-score with different values conditioned on the pericardial-effusion in the second level (Figure 1b).

The contribution of this paper is a novel approach for learning hard, axis-aligned DTs based on a joint optimization of all parameters with gradient descent, which we call gradient-based decision trees (GDTs). Using a gradient-based optimization, we can overcome the limitations of greedy approaches, as indicated in Figure 1b.
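For reference, the greedy, impurity-minimizing split selection performed at each internal node by CART-style algorithms can be sketched as follows. This is a minimal illustration, not the exact CART implementation; the function name `best_greedy_split` and the exhaustive scan over observed feature values are our own simplifications:

```python
import numpy as np

def gini(y):
    # Gini impurity of a label vector: 1 - sum of squared class proportions.
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def best_greedy_split(X, y):
    # Exhaustively pick the (feature, threshold) pair minimizing the
    # weighted impurity of the two children. This choice is locally
    # optimal for the current node but may be globally suboptimal.
    best = (None, None, np.inf)
    n = len(y)
    for j in range(X.shape[1]):
        for t in np.unique(X[:, j]):
            left = X[:, j] <= t
            right = ~left
            if left.all() or right.all():
                continue  # degenerate split, skip
            score = (left.sum() * gini(y[left])
                     + right.sum() * gini(y[right])) / n
            if score < best[2]:
                best = (j, t, score)
    return best  # (feature index, threshold, weighted child impurity)
```

The split is chosen only from the statistics of the current node; interactions with splits deeper in the tree, as in Example 1, are never taken into account.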
We propose a suitable dense DT representation that allows a gradient-based optimization of the tree parameters (Section 3.2). We further present an algorithm that deals with the non-differentiable nature of DTs while still permitting efficient optimization using backpropagation with a straight-through (ST) operator (Section 3.3).
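The core of the ST trick can be illustrated on a single axis-aligned split: the forward pass uses the hard, non-differentiable step function, while the backward pass substitutes the gradient of a soft (sigmoid) relaxation. The sketch below is a generic illustration of this idea under our own naming (`st_split`, `temperature`), not the exact operator used in GDTs:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def st_split(x, threshold, temperature=1.0):
    """Hard axis-aligned split with a straight-through surrogate gradient.

    Forward pass: the hard routing decision 1[x > threshold], whose true
    gradient is zero almost everywhere. Backward pass: the gradient of the
    sigmoid relaxation is used instead, so the split threshold still
    receives a useful training signal.
    """
    z = (x - threshold) / temperature
    soft = sigmoid(z)                                # differentiable relaxation
    hard = (z > 0.0).astype(float)                   # value used in the forward pass
    grad_soft = soft * (1.0 - soft) / temperature    # surrogate gradient w.r.t. x
    return hard, grad_soft
```

In an autodiff framework the same effect is typically obtained by composing the hard and soft values so that the hard value is returned while gradients flow through the soft one (e.g. `hard + soft - stop_gradient(soft)` rearranged appropriately).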



¹ Available at: https://archive.ics.uci.edu/ml/datasets/echocardiogram (last accessed 16.08.2022)

