18

I am building a decision tree using

clf = tree.DecisionTreeClassifier()
clf = clf.fit(X_train, Y_train)

This all works fine. However, how do I then explore the decision tree?

For example, how do I find which entries from X_train appear in a particular leaf?

Simd
  • 19,447
  • 42
  • 136
  • 271
  • 3
    Ran into a similar issue. You might find my answer [here](http://stackoverflow.com/questions/20224526/how-to-extract-the-decision-rules-from-scikit-learn-decision-tree/42227468#42227468) (and the walkthrough mentioned there) helpful. It uses a method, `decision_path`, from the 0.18 release. Substitute `X_test` with `X_train` in a few spots if interested in seeing training samples. – Kevin Feb 14 '17 at 14:40
  • One of the best visualizations of decision trees I have seen is here: https://github.com/parrt/dtreeviz/blob/master/notebooks/dtreeviz_sklearn_visualisations.ipynb – shantanu pathak Sep 24 '20 at 12:59

7 Answers

17

You need to use the predict method.

After training the tree, you feed the X values to predict their output.

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier(random_state=0)
iris = load_iris()
tree = clf.fit(iris.data, iris.target)
tree.predict(iris.data) 

output:

>>> tree.predict(iris.data)
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

To get details on the tree structure, we can use tree_.__getstate__()

Tree structure translated into an "ASCII art" picture

              0  
        _____________
        1           2
               ______________
               3            12
            _______      _______
            4     7      13   16
           ___   ______        _____
           5 6   8    9        14 15
                      _____
                      10 11

Tree structure as an array:

In [38]: tree.tree_.__getstate__()['nodes']
Out[38]: 
array([(1, 2, 3, 0.800000011920929, 0.6666666666666667, 150, 150.0),
       (-1, -1, -2, -2.0, 0.0, 50, 50.0),
       (3, 12, 3, 1.75, 0.5, 100, 100.0),
       (4, 7, 2, 4.949999809265137, 0.16803840877914955, 54, 54.0),
       (5, 6, 3, 1.6500000953674316, 0.04079861111111116, 48, 48.0),
       (-1, -1, -2, -2.0, 0.0, 47, 47.0), 
       (-1, -1, -2, -2.0, 0.0, 1, 1.0),
       (8, 9, 3, 1.5499999523162842, 0.4444444444444444, 6, 6.0),
       (-1, -1, -2, -2.0, 0.0, 3, 3.0),
       (10, 11, 2, 5.449999809265137, 0.4444444444444444, 3, 3.0),
       (-1, -1, -2, -2.0, 0.0, 2, 2.0), 
       (-1, -1, -2, -2.0, 0.0, 1, 1.0),
       (13, 16, 2, 4.850000381469727, 0.042533081285444196, 46, 46.0),
       (14, 15, 1, 3.0999999046325684, 0.4444444444444444, 3, 3.0),
       (-1, -1, -2, -2.0, 0.0, 2, 2.0), 
       (-1, -1, -2, -2.0, 0.0, 1, 1.0),
       (-1, -1, -2, -2.0, 0.0, 43, 43.0)], 
      dtype=[('left_child', '<i8'), ('right_child', '<i8'), 
             ('feature', '<i8'), ('threshold', '<f8'), 
             ('impurity', '<f8'), ('n_node_samples', '<i8'), 
             ('weighted_n_node_samples', '<f8')])

Where:

  • The first node [0] is the root node.
  • Internal nodes have left_child and right_child referring to nodes with positive indices greater than that of the current node.
  • Leaves have the value -1 for both the left and right child nodes.
  • Nodes 1, 5, 6, 8, 10, 11, 14, 15 and 16 are leaves.
  • The node structure is built by a depth-first search.
  • The feature field tells us which of the iris.data features was used at the node to decide the path for a sample.
  • The threshold tells us the value against which that feature is compared to choose the direction.
  • Impurity reaches 0 at the leaves, since all the samples in a leaf belong to the same class.
  • n_node_samples tells us how many training samples reach each node.

Using this information, we can trivially track each sample X to the leaf where it eventually lands by following the split rules and thresholds in a script. Additionally, n_node_samples lets us write unit tests checking that each node receives the correct number of samples. Then, using the output of tree.predict, we can map each leaf to its associated class.
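
A rough sketch of that idea (not from the original answer; the helper name find_leaf is made up here), walking each training sample down the fitted iris tree from above using the children_left, children_right, feature and threshold arrays that tree.tree_ exposes:

import numpy as np

def find_leaf(tree_, x):
    # follow the split rules from the root until a leaf node is reached
    node = 0
    while tree_.children_left[node] != -1:  # -1 marks a leaf
        if x[tree_.feature[node]] <= tree_.threshold[node]:
            node = tree_.children_left[node]
        else:
            node = tree_.children_right[node]
    return node

# leaf id for every training sample
leaf_ids = np.array([find_leaf(tree.tree_, x) for x in iris.data])

# the samples that land in a particular leaf, e.g. node 5
samples_in_leaf_5 = iris.data[leaf_ids == 5]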

zneak
  • 134,922
  • 42
  • 253
  • 328
PabTorre
  • 2,878
  • 21
  • 30
  • Thank you. This tells me the class but not which leaf of the decision tree each item is in. If I could just extract the rules needed to get to each leaf somehow I could rerun those rules over the data. – Simd Sep 10 '15 at 17:39
  • When you say you want to see the leafs, do you mean that you want to see the rules that the tree used at each node? if that is the case then maybe this will help: http://stackoverflow.com/questions/20224526/how-to-extract-the-decision-rules-from-scikit-learn-decision-tree – PabTorre Sep 10 '15 at 17:47
  • 1
    For a given leaf I would like to see the training data the decision tree would place at that leaf. In other words, each leaf is associated with a sequence of rules (comparisons). I would like to see the subset of the data you get if you apply those rules. – Simd Sep 10 '15 at 17:50
  • just to make sure we are using the same terms. A decision tree consists of a root node that has no incomming edges. Internal nodes that have incoming and outgoing edges, and leaves (aka terminal or decision nodes) Each leaf is assigned one class. http://www.ise.bgu.ac.il/faculty/liorr/hbchap9.pdf When you say you want to see the leafs and not the classes, you mean that if 2 leaves are assigned the same class you want to distinguish between the different instances in one class that reached that class through different paths? – PabTorre Sep 10 '15 at 18:03
  • An example decision tree is at http://stackoverflow.com/questions/23557545/how-to-explain-the-decision-tree-from-scikit-learn . How would find the data that is mapped to each leaf? – Simd Sep 10 '15 at 18:07
  • 1
    What are the last two columns of tree.tree_.__getstate__()['nodes']? – lars Nov 17 '16 at 17:05
6

NOTE: This is not an answer, only a hint on possible solutions.

I encountered a similar problem recently in my project. My goal is to extract the corresponding chain of decisions for some particular samples. I think your problem is a subset of mine, since you just need to record the last step in the decision chain.

Up to now, it seems the only viable solution is to write a custom predict method in Python to keep track of the decisions along the way. The reason is that the predict method provided by scikit-learn cannot do this out of the box (as far as I know). To make it worse, it is a wrapper around a C implementation, which is quite hard to customize.

Customization is fine for my problem, since I'm dealing with an unbalanced dataset and the samples I care about (positive ones) are rare. So I can filter them out first using sklearn's predict and then get the decision chain using my customization.

However, this may not work for you if you have a large dataset: if you parse the tree and do the prediction in Python, it will run at Python speed and will not (easily) scale. You may have to fall back to customizing the C implementation.
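
(For what it's worth, on scikit-learn 0.18 and later the built-in decision_path method, mentioned in the comments under the question, already records the visited nodes; a minimal sketch, assuming a fitted classifier clf and training data X_train:)

# sparse indicator matrix: row i is nonzero at every node that sample i passes through
node_indicator = clf.decision_path(X_train)

# leaf id reached by each sample
leaf_ids = clf.apply(X_train)

# node ids visited by sample 0, from the root down to its leaf
path_0 = node_indicator.indices[node_indicator.indptr[0]:node_indicator.indptr[1]]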

zaxliu
  • 2,726
  • 1
  • 22
  • 26
4

I've slightly changed what Dr. Drew posted.
The following code, given a data frame and a fitted decision tree, returns:

  • rules_list: a list of rules (one list of split conditions per leaf)
  • values_path: a list of the class counts at each node along the corresponding path

    import numpy as np  
    import pandas as pd  
    from sklearn.tree import DecisionTreeClassifier 
    
    def get_rules(dtc, df):
        rules_list = []
        values_path = []
        values = dtc.tree_.value
    
        def RevTraverseTree(tree, node, rules, pathValues):
            '''
            Traverse an sklearn decision tree from a node (presumably a leaf)
            up to the root, building the decision rules. The rules should be
            passed in as an empty list, which is modified in place. The result
            is a list of rule strings such as 'feature <= threshold'.
            The "tree" is a nested list of simplified tree attributes:
            [split feature, split threshold, left node, right node]
            '''
            # now find the node as either a left or right child of something
            # first try to find it as a left node            
    
            try:
                prevnode = tree[2].index(node)           
                leftright = '<='
                pathValues.append(values[prevnode])
            except ValueError:
                # failed, so find it as a right node - if this also causes an exception, something's really f'd up
                prevnode = tree[3].index(node)
                leftright = '>'
                pathValues.append(values[prevnode])
    
            # now let's get the rule that caused prevnode to -> node
            p1 = df.columns[tree[0][prevnode]]    
            p2 = tree[1][prevnode]    
            rules.append(str(p1) + ' ' + leftright + ' ' + str(p2))
    
            # if we've not yet reached the top, go up the tree one more step
            if prevnode != 0:
                RevTraverseTree(tree, prevnode, rules, pathValues)
    
        # get the nodes which are leaves
        leaves = dtc.tree_.children_left == -1
        leaves = np.arange(0,dtc.tree_.node_count)[leaves]
    
        # build a simpler tree as a nested list: [split feature, split threshold, left node, right node]
        thistree = [dtc.tree_.feature.tolist()]
        thistree.append(dtc.tree_.threshold.tolist())
        thistree.append(dtc.tree_.children_left.tolist())
        thistree.append(dtc.tree_.children_right.tolist())
    
        # get the decision rules for each leaf node & apply them
        for (ind,nod) in enumerate(leaves):
    
            # get the decision rules
            rules = []
            pathValues = []
            RevTraverseTree(thistree, nod, rules, pathValues)
    
            pathValues.insert(0, values[nod])      
            pathValues = list(reversed(pathValues))
    
            rules = list(reversed(rules))
    
            rules_list.append(rules)
            values_path.append(pathValues)
    
        return (rules_list, values_path)
    

Here is an example:

from sklearn.model_selection import train_test_split

df = pd.read_csv('df.csv')

X = df[df.columns[:-1]]
y = df['classification']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

dtc = DecisionTreeClassifier(max_depth=2)
dtc.fit(X_train, y_train)

The fitted classifier generated the following tree (image: "Decision Tree with width 2").

At this point, just calling the function:

get_rules(dtc, df)

This is what the function returns:

rules = [  
    ['first <= 63.5', 'first <= 43.5'],  
    ['first <= 63.5', 'first > 43.5'],  
    ['first > 63.5', 'second <= 19.700000762939453'],  
    ['first > 63.5', 'second > 19.700000762939453']
]

values = [
    [array([[ 1568.,  1569.]]), array([[ 636.,  241.]]), array([[ 284.,  57.]])],
    [array([[ 1568.,  1569.]]), array([[ 636.,  241.]]), array([[ 352.,  184.]])],
    [array([[ 1568.,  1569.]]), array([[  932.,  1328.]]), array([[ 645.,  620.]])],
    [array([[ 1568.,  1569.]]), array([[  932.,  1328.]]), array([[ 287.,  708.]])]
]

Note that in values, each path also includes the values of its leaf.
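
(As a side note, not part of the original answer: each entry of values is a class-count array, so the class predicted at a leaf can be read off with np.argmax; for instance np.argmax(values[0][-1]) gives class 0 for the first path above.)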

Federico Ibba
  • 118
  • 1
  • 9
  • In the function get_rules, I think we need to switch dtc and df? – Nivi Apr 04 '18 at 14:34
  • The other thing is, should it be return(rules_list,values_path) instead of return(r,values_path)? – Nivi Apr 04 '18 at 14:49
  • 1
    Sorry for the late reply, Nivi, I've only just seen the comments. First, in get_rules, yes, you're right, it has to be switched; I've edited it. The second point is true too; I'm sorry for the mistakes, I've updated the answer – Federico Ibba Apr 09 '18 at 12:15
3

The below code should produce a plot of your top ten features:

import numpy as np
import matplotlib.pyplot as plt

importances = clf.feature_importances_
# a single tree has no ensemble of estimators to take a std over,
# so (unlike the random-forest example this was adapted from) there are no error bars
indices = np.argsort(importances)[::-1][:10]  # top ten features

# Print the feature ranking
print("Feature ranking:")

for f, idx in enumerate(indices):
    print("%d. feature %d (%f)" % (f + 1, idx, importances[idx]))

# Plot the feature importances
plt.figure()
plt.title("Feature importances")
plt.bar(range(len(indices)), importances[indices],
        color="r", align="center")
plt.xticks(range(len(indices)), indices)
plt.xlim([-1, len(indices)])
plt.show()

Taken from here and modified slightly to fit the DecisionTreeClassifier.

This doesn't exactly help you explore the tree, but it does tell you about the tree.

Charlie Haley
  • 4,152
  • 4
  • 22
  • 36
  • Thank you but I would like to see which training data fall into each leaf, for example. Currently I have to draw the decision tree, write down the rules, write a script to filter the data using those rules. This can't be the right way! – Simd Sep 10 '15 at 17:28
  • Is your data small enough to run those calculations by hand or in a spreadsheet? I'm assuming this is for a class, in which case it may be better not to just run the algorithm and copy down the structure. That said, I imagine there is some way to get the structure of the tree from sci-kit. Here's the source for DecisionTreeClassifier: https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/tree.py – Charlie Haley Sep 10 '15 at 17:33
  • It's not for a class! I have about 1000000 items so I do it by writing a separate python script. However I don't even know how to extract the rules for each leaf automatically currently. Is there a way? – Simd Sep 10 '15 at 17:35
  • Check out these two links: http://stackoverflow.com/questions/20224526/how-to-extract-the-decision-rules-from-scikit-learn-decision-tree http://scikit-learn.org/stable/modules/generated/sklearn.tree.export_graphviz.html – Charlie Haley Sep 10 '15 at 17:39
  • That is what I currently use to draw the decision tree. I could write a parser for the dot file but it seems very awkward. – Simd Sep 10 '15 at 17:41
  • @CharlieHaley What units make up the y axis of this feature importance plot? Variance explained? Or is "importance" different than variance explained? – Arash Howaida Jun 08 '17 at 05:21
  • @ArashHowaida According to the docs: "The importance of a feature is computed as the (normalized) total reduction of the criterion brought by that feature. It is also known as the Gini importance." And the function is defined here, in cython: https://github.com/scikit-learn/scikit-learn/blob/14031f65d144e3966113d3daec836e443c6d7a5b/sklearn/tree/_tree.pyx#L1036 – Charlie Haley Jun 08 '17 at 14:20
3

This code will do exactly what you want. Here, n is the number of observations in X_train. At the end, the (n, number_of_leaves)-sized array leaf_observations holds, in each column, boolean values for indexing into X_train to get the observations in each leaf. Each column of leaf_observations corresponds to an element in leaves, which holds the node IDs of the leaves.

import numpy as np

n = X_train.shape[0]  # number of observations in the training data

# get the nodes which are leaves
leaves = clf.tree_.children_left == -1
leaves = np.arange(0, clf.tree_.node_count)[leaves]

# loop through each leaf and figure out the data in it
leaf_observations = np.zeros((n,len(leaves)),dtype=bool)
# build a simpler tree as a nested list: [split feature, split threshold, left node, right node]
thistree = [clf.tree_.feature.tolist()]
thistree.append(clf.tree_.threshold.tolist())
thistree.append(clf.tree_.children_left.tolist())
thistree.append(clf.tree_.children_right.tolist())
# get the decision rules for each leaf node & apply them
for (ind,nod) in enumerate(leaves):
    # get the decision rules in numeric list form
    rules = []
    RevTraverseTree(thistree, nod, rules)
    # convert & apply to the data by sequentially &ing the rules
    thisnode = np.ones(n,dtype=bool)
    for rule in rules:
        if rule[1] == 1:
            thisnode = np.logical_and(thisnode,X_train[:,rule[0]] > rule[2])
        else:
            thisnode = np.logical_and(thisnode,X_train[:,rule[0]] <= rule[2])
    # get the observations that obey all the rules - they are the ones in this leaf node
    leaf_observations[:,ind] = thisnode

This needs the helper function defined here, which recursively traverses the tree starting from a specified node to build the decision rules.

def RevTraverseTree(tree, node, rules):
    '''
    Traverse an sklearn decision tree from a node (presumably a leaf node)
    up to the top, building the decision rules. The rules should be
    input as an empty list, which will be modified in place. The result
    is a nested list of tuples: (feature, direction (left=-1), threshold).  
    The "tree" is a nested list of simplified tree attributes:
    [split feature, split threshold, left node, right node]
    '''
    # now find the node as either a left or right child of something
    # first try to find it as a left node
    try:
        prevnode = tree[2].index(node)
        leftright = -1
    except ValueError:
        # failed, so find it as a right node - if this also causes an exception, something's really f'd up
        prevnode = tree[3].index(node)
        leftright = 1
    # now let's get the rule that caused prevnode to -> node
    rules.append((tree[0][prevnode],leftright,tree[1][prevnode]))
    # if we've not yet reached the top, go up the tree one more step
    if prevnode != 0:
        RevTraverseTree(tree, prevnode, rules)
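
With that in place, the training observations in, say, the first leaf can be pulled out with X_train[leaf_observations[:, 0]] (assuming X_train is a numpy array, as in the snippet above).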
Dr. Andrew
  • 2,532
  • 3
  • 26
  • 42
2

I think an easy option would be to use the apply method of the trained decision tree. Train the tree, apply the training data to it, and build a lookup table from the returned leaf indices:

import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris

iris = load_iris()
clf = DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)

# apply training data to decision tree
leaf_indices = clf.apply(iris.data)
lookup = {}

# build lookup table
for i, leaf_index in enumerate(leaf_indices):
    try:
        lookup[leaf_index].append(iris.data[i])
    except KeyError:
        lookup[leaf_index] = []
        lookup[leaf_index].append(iris.data[i])

# test with a new sample
unknown_sample = [[4., 3.1, 6.1, 1.2]]
index = clf.apply(unknown_sample)
print(lookup[index[0]])
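
(A slightly more compact variant of the same idea, not in the original answer: since clf.apply returns the leaf id for every row, boolean indexing such as iris.data[leaf_indices == index[0]] gives the training rows that share a leaf with the unknown sample, without building the dictionary.)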
maltesar
  • 181
  • 1
  • 7
0

Have you tried dumping your DecisionTree into a graphviz .dot file [1] and then loading it with graph_tool [2]?

import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.datasets import load_iris
from graph_tool.all import *

iris = load_iris()
clf = DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)

export_graphviz(clf, out_file='tree.dot')

#load graph with graph_tool and explore structure as you please
g = load_graph('tree.dot')

for v in g.vertices():
   for e in v.out_edges():
       print(e)
   for w in v.out_neighbours():
       print(w)

[1] http://scikit-learn.org/stable/modules/generated/sklearn.tree.export_graphviz.html

[2] https://graph-tool.skewed.de/

roj4s
  • 251
  • 2
  • 8
  • Can you make it beautiful that way? As in http://scikit-learn.org/stable/_images/iris.svg ? – Simd Apr 03 '17 at 16:10
  • Once output with export_graphviz, something like that can be achieved with dot -Tpng tree.dot -o tree.png. – roj4s Apr 18 '17 at 09:23