You need to use the predict method.
After training the tree, you feed the X values to predict their output.
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier(random_state=0)
iris = load_iris()
tree = clf.fit(iris.data, iris.target)
tree.predict(iris.data)
output:
>>> tree.predict(iris.data)
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
To get details on the tree structure, we can use tree_.__getstate__()
Tree structure translated into an "ASCII art" picture
0
_____________
1 2
______________
3 12
_______ _______
4 7 13 16
___ ______ _____
5 6 8 9 14 15
_____
10 11
tree structure as an array.
In [38]: tree.tree_.__getstate__()['nodes']
Out[38]:
array([(1, 2, 3, 0.800000011920929, 0.6666666666666667, 150, 150.0),
(-1, -1, -2, -2.0, 0.0, 50, 50.0),
(3, 12, 3, 1.75, 0.5, 100, 100.0),
(4, 7, 2, 4.949999809265137, 0.16803840877914955, 54, 54.0),
(5, 6, 3, 1.6500000953674316, 0.04079861111111116, 48, 48.0),
(-1, -1, -2, -2.0, 0.0, 47, 47.0),
(-1, -1, -2, -2.0, 0.0, 1, 1.0),
(8, 9, 3, 1.5499999523162842, 0.4444444444444444, 6, 6.0),
(-1, -1, -2, -2.0, 0.0, 3, 3.0),
(10, 11, 2, 5.449999809265137, 0.4444444444444444, 3, 3.0),
(-1, -1, -2, -2.0, 0.0, 2, 2.0),
(-1, -1, -2, -2.0, 0.0, 1, 1.0),
(13, 16, 2, 4.850000381469727, 0.042533081285444196, 46, 46.0),
(14, 15, 1, 3.0999999046325684, 0.4444444444444444, 3, 3.0),
(-1, -1, -2, -2.0, 0.0, 2, 2.0),
(-1, -1, -2, -2.0, 0.0, 1, 1.0),
(-1, -1, -2, -2.0, 0.0, 43, 43.0)],
dtype=[('left_child', '<i8'), ('right_child', '<i8'),
('feature', '<i8'), ('threshold', '<f8'),
('impurity', '<f8'), ('n_node_samples', '<i8'),
('weighted_n_node_samples', '<f8')])
Where:
- The first node [0] is the root node.
- internal nodes have left_child and right_child refering to nodes with positive values, and greater than the current node.
- leaves have -1 value for the left and right child nodes.
- nodes 1,5,6, 8,10,11,14,15,16 are leaves.
- the node structure is built using the Depth First Search Algorithm.
- the feature field tells us which of the iris.data features was used in the node to determine the path for this sample.
- the threshold tells us the value used to evaluate the direction based on the feature.
- impurity reaches 0 at the leaves... since all the samples are in the same class once you reach the leaf.
- n_node_samples tells us how many samples reach each leaf.
Using this information we could trivially track each sample X to the leaf where it eventually lands by following the classification rules and thresholds on a script. Additionally, the n_node_samples would allow us to perform unit tests ensuring that each node gets the correct number of samples.Then using the output of tree.predict, we could map each leaf to the associated class.