
I use sklearn.tree.DecisionTreeClassifier to build a decision tree. With the optimal parameter settings, I get a tree that has unnecessary leaves (see the example picture below - since I do not need probabilities, the leaf nodes marked in red are an unnecessary split).

[Tree image: the fitted decision tree, with the redundant leaf splits marked in red]

Is there any third-party library for pruning these unnecessary nodes? Or a code snippet? I could write one, but I can't really imagine that I am the first person with this problem...

Code to replicate:

from sklearn.tree import DecisionTreeClassifier
from sklearn import datasets

# Fit a small tree on the iris data to reproduce the redundant leaves
iris = datasets.load_iris()
X = iris.data
y = iris.target
mdl = DecisionTreeClassifier(max_leaf_nodes=8)
mdl.fit(X, y)

PS: I have tried multiple keyword searches and am kind of surprised to find nothing - is there really no post-pruning in sklearn at all?
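Update: newer scikit-learn versions (0.22+) do ship one built-in form of post-pruning: minimal cost-complexity pruning via the ccp_alpha parameter. It prunes by an impurity/complexity trade-off rather than by identical decisions, so it is not exactly what I am after, but for reference, a minimal sketch (the alpha value chosen here is purely illustrative):

from sklearn.tree import DecisionTreeClassifier
from sklearn import datasets

iris = datasets.load_iris()

# Compute the pruning path, then refit with a small alpha
path = DecisionTreeClassifier(max_leaf_nodes=8).cost_complexity_pruning_path(iris.data, iris.target)
mdl = DecisionTreeClassifier(max_leaf_nodes=8, ccp_alpha=path.ccp_alphas[1])
mdl.fit(iris.data, iris.target)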

PPS: In response to the possible duplicate: while the suggested question might help me when coding the pruning algorithm myself, it answers a different question - I want to get rid of leaves that do not change the final decision, while the other question asks for a minimum threshold for splitting nodes.

PPPS: The tree shown is only an example to illustrate my problem. I am aware that the parameter settings used to create it are suboptimal. I am not asking about optimizing this specific tree; I need post-pruning to get rid of leaves that might be helpful if one needs class probabilities, but are not helpful if one is only interested in the most likely class.

Thomas
  • Possible duplicate of [Pruning Decision Trees](https://stackoverflow.com/questions/49428469/pruning-decision-trees) – piman314 Jul 18 '18 at 08:34
  • @ncfirth: While the question is also about pruning, it tries to do something else - see my edit. – Thomas Jul 18 '18 at 08:38
  • @ncfirth: However, thank you for providing the link, it helped me write my own code ([see my answer below](https://stackoverflow.com/a/51398390/4629950)) for post-pruning. – Thomas Jul 19 '18 at 07:48

3 Answers


Using ncfirth's link, I was able to modify the code there so that it fits my problem:

from sklearn.tree._tree import TREE_LEAF

def is_leaf(inner_tree, index):
    # Check whether node is leaf node
    return (inner_tree.children_left[index] == TREE_LEAF and 
            inner_tree.children_right[index] == TREE_LEAF)

def prune_index(inner_tree, decisions, index=0):
    # Start pruning from the bottom - if we start from the top, we might miss
    # nodes that become leaves during pruning.
    # Do not use this directly - use prune_duplicate_leaves instead.
    if not is_leaf(inner_tree, inner_tree.children_left[index]):
        prune_index(inner_tree, decisions, inner_tree.children_left[index])
    if not is_leaf(inner_tree, inner_tree.children_right[index]):
        prune_index(inner_tree, decisions, inner_tree.children_right[index])

    # Prune children if both children are leaves now and make the same decision:     
    if (is_leaf(inner_tree, inner_tree.children_left[index]) and
        is_leaf(inner_tree, inner_tree.children_right[index]) and
        (decisions[index] == decisions[inner_tree.children_left[index]]) and 
        (decisions[index] == decisions[inner_tree.children_right[index]])):
        # turn node into a leaf by "unlinking" its children
        inner_tree.children_left[index] = TREE_LEAF
        inner_tree.children_right[index] = TREE_LEAF
        # print("Pruned {}".format(index))

def prune_duplicate_leaves(mdl):
    # Remove sibling leaves if both make the same decision as their parent
    decisions = mdl.tree_.value.argmax(axis=2).flatten().tolist()  # majority class for each node
    prune_index(mdl.tree_, decisions)

Using this on a DecisionTreeClassifier clf:

prune_duplicate_leaves(clf)

Edit: Fixed a bug for more complex trees
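
For a quick end-to-end check, here is a sketch using the iris tree from the question (random_state is only there to make the run repeatable). Since the tree is modified in place, comparing predictions before and after shows that the most likely class is unchanged:

from sklearn.tree import DecisionTreeClassifier
from sklearn import datasets

iris = datasets.load_iris()
clf = DecisionTreeClassifier(max_leaf_nodes=8, random_state=0)
clf.fit(iris.data, iris.target)

before = clf.predict(iris.data)
prune_duplicate_leaves(clf)
after = clf.predict(iris.data)
print((before == after).all())  # True: the predicted classes are unchanged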

Thomas
  • Note that this code will modify the tree in place. This is not bad per se, just good to know in case you want to compare the tree pre/post pruning. – gire Dec 28 '18 at 06:20
  • @Thomas Can you change this code to prune by number of samples? Say, pruning when n_node_samples is less than 5. I tried replacing the decisions[index] comparisons with a check of n_node_samples against a threshold, but it's not working. I've lost a few weeks already on this problem, and tried adapting other solutions but no luck so far. – Kaluk May 20 '21 at 02:40
  • You don't need to prune for that: you can ensure every leaf has at least 5 samples by simply setting min_samples_leaf=5 during training. See for example here: https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html – Thomas May 21 '21 at 08:56

DecisionTreeClassifier(max_leaf_nodes=8) specifies at most 8 leaves, so unless the tree builder has another reason to stop, it will hit that maximum.

In the example shown, 5 of the 8 leaves contain very few samples (<= 3) compared to the other 3 leaves (> 50), a possible sign of over-fitting. Instead of pruning the tree after training, one can specify either min_samples_leaf or min_samples_split to better guide the training, which will likely get rid of the problematic leaves. For instance, use the value 0.05 to require at least 5% of samples per leaf, as in the sketch below.
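
A minimal sketch of that suggestion, reusing the setup from the question (the 0.05 value is illustrative):

from sklearn.tree import DecisionTreeClassifier
from sklearn import datasets

iris = datasets.load_iris()
# Require every leaf to contain at least 5% of the training samples;
# float values are interpreted as fractions of the total sample count
mdl = DecisionTreeClassifier(max_leaf_nodes=8, min_samples_leaf=0.05)
mdl.fit(iris.data, iris.target)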

Jon Nordby
  • This is simply a reproducible example meant to show my problem and obviously not my real code... I am aware of the various settings of the decision tree; nevertheless, sklearn as of now is simply missing any post-pruning options. – Thomas Jul 19 '18 at 07:42

I had a problem with the code posted here, so I revised it and had to add a small section (it deals with the case where both children make the same decision, but the pruned node still carries a feature comparison):

from sklearn.tree._tree import TREE_LEAF, TREE_UNDEFINED

def is_leaf(inner_tree, index):
    # Check whether node is leaf node
    return (inner_tree.children_left[index] == TREE_LEAF and 
            inner_tree.children_right[index] == TREE_LEAF)

def prune_index(inner_tree, decisions, index=0):
    # Start pruning from the bottom - if we start from the top, we might miss
    # nodes that become leaves during pruning.
    # Do not use this directly - use prune_duplicate_leaves instead.
    if not is_leaf(inner_tree, inner_tree.children_left[index]):
        prune_index(inner_tree, decisions, inner_tree.children_left[index])
    if not is_leaf(inner_tree, inner_tree.children_right[index]):
        prune_index(inner_tree, decisions, inner_tree.children_right[index])

    # Prune children if both children are leaves now and make the same decision:     
    if (is_leaf(inner_tree, inner_tree.children_left[index]) and
        is_leaf(inner_tree, inner_tree.children_right[index]) and
        (decisions[index] == decisions[inner_tree.children_left[index]]) and 
        (decisions[index] == decisions[inner_tree.children_right[index]])):
        # turn node into a leaf by "unlinking" its children
        inner_tree.children_left[index] = TREE_LEAF
        inner_tree.children_right[index] = TREE_LEAF
        inner_tree.feature[index] = TREE_UNDEFINED  # clear the feature so the node no longer looks like a split
        # print("Pruned {}".format(index))

def prune_duplicate_leaves(mdl):
    # Remove sibling leaves if both make the same decision as their parent
    decisions = mdl.tree_.value.argmax(axis=2).flatten().tolist()  # majority class for each node
    prune_index(mdl.tree_, decisions)
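
The extra line matters for the tree exporters: as far as I can tell, sklearn.tree.export_text decides whether a node is a split or a leaf by looking at tree_.feature, so without clearing it a pruned node would still be rendered as a comparison. A quick sketch to inspect the pruned tree, reusing the iris setup from the question:

from sklearn.tree import DecisionTreeClassifier, export_text
from sklearn import datasets

iris = datasets.load_iris()
clf = DecisionTreeClassifier(max_leaf_nodes=8, random_state=0)
clf.fit(iris.data, iris.target)

prune_duplicate_leaves(clf)
# Pruned nodes now print as plain leaves instead of comparisons
print(export_text(clf, feature_names=list(iris.feature_names)))
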
mnz