For this, you can manually traverse the fitted tree, accessing the low-level structure stored in its tree_ attribute rather than the usual predict API.
First, let's get a fitted tree, using the "iris" dataset:
import numpy as np  # only needed for argmax over a leaf's class counts
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt

data = load_iris()
clf = DecisionTreeClassifier(max_depth=3).fit(data['data'], data['target'])
Let's visualize this tree, primarily to debug our final program:
plt.figure(figsize=(10, 8))
plot_tree(clf, feature_names=data['feature_names'], class_names=data['target_names'], filled=True);
which, in my case, outputs a plot of the fitted tree.
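If you prefer a plain-text view for debugging (handy for checking thresholds and leaf classes against the interactive run below), sklearn's export_text prints the same splits:

from sklearn.tree import export_text
print(export_text(clf, feature_names=data['feature_names']))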
Now the main part. From this link, we know that "the binary tree tree_ is represented as a number of parallel arrays; the i-th element of each array holds information about node i."

The arrays we need are feature, threshold, value, and the two children_* arrays. So, starting from the root (i = 0), we collect the feature index and threshold of each node we visit, ask the user for the value of that particular feature, and traverse left or right by comparing the given value with the threshold. When we reach a leaf, we pick the most frequent class in that leaf, and that ends our loop.
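Before writing the loop, it can help to print these arrays for our tree to see the parallel layout (leaves are marked with -2 in feature/threshold and -1 in the children arrays):

print(clf.tree_.feature)         # feature index tested at each node (-2 for leaves)
print(clf.tree_.threshold)       # split threshold at each node (-2.0 for leaves)
print(clf.tree_.children_left)   # index of each node's left child (-1 for leaves)
print(clf.tree_.children_right)  # index of each node's right child (-1 for leaves)
print(clf.tree_.value[0])        # class distribution at the root (counts or fractions, depending on the sklearn version)

With that picture in mind, the interactive traversal looks like this: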
tree = clf.tree_
node = 0  # index of the root node

while True:
    feat, thres = tree.feature[node], tree.threshold[node]
    print(feat, thres)  # debug output: feature index and threshold of the current node
    v = float(input(f"The value of {data['feature_names'][feat]}: "))
    if v <= thres:
        node = tree.children_left[node]
    else:
        node = tree.children_right[node]
    if tree.children_left[node] == tree.children_right[node]:  # both are -1, so this is a leaf
        label = np.argmax(tree.value[node])
        print("We've reached a leaf")
        print(f"Predicted Label is: {data['target_names'][label]}")
        break
An example of such a run for the above tree is:
3 0.800000011920929
The value of petal width (cm): 1
3 1.75
The value of petal width (cm): 1.5
2 4.950000047683716
The value of petal length (cm): 5.96
We've reached a leaf
Predicted Label is: virginica
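If you want to sanity-check the loop without typing answers each time, you can run the same walk over the arrays for a complete sample and compare it with clf.predict. A minimal sketch (the traverse helper is just for illustration), using the first sample of the dataset:

def traverse(sample):
    node = 0
    # internal nodes have distinct child indices; leaves have -1 for both
    while tree.children_left[node] != tree.children_right[node]:
        if sample[tree.feature[node]] <= tree.threshold[node]:
            node = tree.children_left[node]
        else:
            node = tree.children_right[node]
    return np.argmax(tree.value[node])

sample = data['data'][0]
print(data['target_names'][traverse(sample)])           # manual traversal
print(data['target_names'][clf.predict([sample])[0]])   # should agree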