I have trained a decision tree using a dataset. Now I want to see which samples fall under which leaf of the tree.

From the tree below, I want the samples circled in red.

[Image: plot of the decision tree, with the samples of interest circled in red]

I am using scikit-learn's decision tree implementation in Python.

Maximilian Peters
Farshid Rayhan
  • This: https://stackoverflow.com/questions/32506951/how-to-explore-a-decision-tree-built-using-scikit-learn and this: https://stackoverflow.com/questions/20224526/how-to-extract-the-decision-rules-from-scikit-learn-decision-tree/42227468#42227468 may be relevant. – Miriam Farber Jul 30 '17 at 10:40
  • Is the left upper leaf left out on purpose? – Maximilian Peters Jul 30 '17 at 11:15

1 Answer

If you only need the leaf that each sample ends up in, you can use apply, which returns one leaf node index per sample:

clf.apply(iris.data)

array([ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 14, 5, 5, 5, 5, 5, 5, 10, 5, 5, 5, 5, 5, 10, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 16, 16, 16, 16, 16, 16, 6, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 8, 16, 16, 16, 16, 16, 16, 15, 16, 16, 11, 16, 16, 16, 8, 8, 16, 16, 16, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16])
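
To turn that array into "which samples fall under which leaf", the leaf indices can be grouped. The following is a minimal sketch under the same setup as the complete code further down (iris data, random_state=42); the names leaf_ids and samples_per_leaf are only illustrative.

```python
import collections

import sklearn.datasets
import sklearn.tree

iris = sklearn.datasets.load_iris()
clf = sklearn.tree.DecisionTreeClassifier(random_state=42).fit(iris.data, iris.target)

# apply() returns, for each sample, the index of the leaf it ends up in
leaf_ids = clf.apply(iris.data)

# Map leaf index -> list of sample indices (row numbers in iris.data)
samples_per_leaf = collections.defaultdict(list)
for sample_idx, leaf_id in enumerate(leaf_ids):
    samples_per_leaf[int(leaf_id)].append(sample_idx)

# Print how many samples each leaf holds
for leaf_id in sorted(samples_per_leaf):
    print(leaf_id, len(samples_per_leaf[leaf_id]))
```

Each key of samples_per_leaf is a leaf node index, so samples_per_leaf[some_leaf] gives the row numbers that can be used to index iris.data directly.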

If you want to get all samples for each node you could calculate all the decision paths with

dec_paths = clf.decision_path(iris.data)

Then loop over the decision paths, convert each one to a dense array with toarray(), and check for every node whether the sample passes through it. The results are stored in a defaultdict where the key is the node number and the value is the list of indices of the samples that reach that node.

for d, dec in enumerate(dec_paths):
    # Densify each sample's path once: a 0/1 flag per node
    row = dec.toarray()[0]
    for i in range(clf.tree_.node_count):
        if row[i] == 1:
            # Sample d passes through node i
            samples[i].append(d)

Complete code

import sklearn.datasets
import sklearn.tree
import collections

clf = sklearn.tree.DecisionTreeClassifier(random_state=42)
iris = sklearn.datasets.load_iris()
clf = clf.fit(iris.data, iris.target)

samples = collections.defaultdict(list)
dec_paths = clf.decision_path(iris.data)

for d, dec in enumerate(dec_paths):
    # Densify each sample's path once: a 0/1 flag per node
    row = dec.toarray()[0]
    for i in range(clf.tree_.node_count):
        if row[i] == 1:
            # Sample d passes through node i
            samples[i].append(d)

Output

print(samples[13])

[70, 126, 138]
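
Note that the dictionary built this way contains every node a sample passes through, not only leaves. In the apply output above, samples 70, 126 and 138 end up in leaves 14 and 15, so node 13 is an internal node. As a possible alternative, the sketch below reads the (sample, node) pairs straight from the sparse matrix with nonzero() instead of densifying each row, and then keeps only the leaves. The names samples_by_node and samples_by_leaf are illustrative, and the leaf test assumes scikit-learn's convention that a node without children has -1 in tree_.children_left.

```python
import collections

import sklearn.datasets
import sklearn.tree

iris = sklearn.datasets.load_iris()
clf = sklearn.tree.DecisionTreeClassifier(random_state=42).fit(iris.data, iris.target)

# Sparse indicator matrix of shape (n_samples, n_nodes)
dec_paths = clf.decision_path(iris.data)

# nonzero() gives the (sample, node) coordinates of every stored 1
sample_idx, node_idx = dec_paths.nonzero()

# Map node index -> list of sample indices that pass through that node
samples_by_node = collections.defaultdict(list)
for s, n in zip(sample_idx, node_idx):
    samples_by_node[int(n)].append(int(s))

# Leaves have no children; this is marked with -1 in children_left
is_leaf = clf.tree_.children_left == -1
samples_by_leaf = {n: s for n, s in samples_by_node.items() if is_leaf[n]}
```

Here samples_by_node[0] holds every sample, since all paths start at the root, while samples_by_leaf partitions the samples across the leaves.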

Maximilian Peters