Here's a solution for your requirement.
It prints the decision rules used by each base learner (the value you pass as n_estimators to sklearn's RandomForestClassifier is the number of DecisionTree base learners, so you get one function per tree).
from sklearn import metrics, datasets, ensemble
from sklearn.tree import _tree

# Utility: print the decision rules of a single tree as code
def dtree_to_code(tree, feature_names, tree_idx):
    """
    Print the decision rules of one decision tree in the form of code.
    """
    tree_ = tree.tree_
    # Map each node's feature index to a readable feature name
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print('def tree_{1}({0}):'.format(", ".join(feature_names), tree_idx))

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            # Internal node: print the split condition and recurse into both children
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print('{0}if {1} <= {2}:'.format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            print('{0}else: # if {1} > {2}'.format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            # Leaf node: print the class counts stored at this leaf
            print('{0}return {1}'.format(indent, tree_.value[node]))

    recurse(0, 1)

def rf_to_code(rf, feature_names):
    """
    Convert the decision rules of every tree in a random forest to code.
    """
    for base_learner_id, base_learner in enumerate(rf.estimators_):
        dtree_to_code(tree=base_learner, feature_names=feature_names, tree_idx=base_learner_id)
I adapted the decision-rules code from this question:
How to extract the decision rules from scikit-learn decision-tree?
#clf : RandomForestClassifier(n_estimators=100) fitted on the Iris data
#df  : Iris DataFrame
rf_to_code(rf=clf, feature_names=df.columns)
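In case clf and df aren't set up yet, a minimal sketch is below (my assumption, not part of the original answer: the short column names are chosen to match the output shown next, and because the full df.columns is passed as feature_names, the target column species also appears in the generated function signatures).

import pandas as pd
from sklearn import datasets, ensemble

iris = datasets.load_iris()
# Assumed column names; pick whatever names you want to see in the printed rules
feature_cols = ["sepal length", "sepal width", "petal length", "petal width"]
df = pd.DataFrame(iris.data, columns=feature_cols)
df["species"] = iris.target  # target column; part of df.columns, hence in the signatures

clf = ensemble.RandomForestClassifier(n_estimators=100)
clf.fit(df[feature_cols], df["species"])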
If everything goes well, the expected output is:
def tree_0(sepal length, sepal width, petal length, petal width, species):
  if sepal length <= 5.549999952316284:
    if petal length <= 2.350000023841858:
      return [[40. 0. 0.]]
    else: # if petal length > 2.350000023841858
      return [[0. 5. 0.]]
  else: # if sepal length > 5.549999952316284
    if petal length <= 4.75:
      if petal width <= 0.7000000029802322:
        return [[2. 0. 0.]]
      else: # if petal width > 0.7000000029802322
        return [[ 0. 22. 0.]]
    else: # if petal length > 4.75
      if sepal width <= 3.049999952316284:
        if petal length <= 5.1499998569488525:
          if sepal length <= 5.950000047683716:
            return [[0. 0. 6.]]
          else: # if sepal length > 5.950000047683716
            if petal width <= 1.75:
              return [[0. 3. 0.]]
            else: # if petal width > 1.75
              return [[0. 0. 1.]]
        else: # if petal length > 5.1499998569488525
          return [[ 0. 0. 15.]]
      else: # if sepal width > 3.049999952316284
        return [[ 0. 0. 11.]]
def tree_1(sepal length, sepal width, petal length, petal width, species):
  if petal length <= 2.350000023841858:
    return [[39. 0. 0.]]
  else: # if petal length > 2.350000023841858
    if petal length <= 4.950000047683716:
      if petal length <= 4.799999952316284:
        return [[ 0. 29. 0.]]
      else: # if petal length > 4.799999952316284
        if sepal width <= 2.9499999284744263:
          if petal width <= 1.75:
            return [[0. 1. 0.]]
          else: # if petal width > 1.75
            return [[0. 0. 2.]]
        else: # if sepal width > 2.9499999284744263
          return [[0. 3. 0.]]
    else: # if petal length > 4.950000047683716
      return [[ 0. 0. 31.]]
......
def tree_99(sepal length, sepal width, petal length, petal width, species):
  if sepal length <= 5.549999952316284:
    if petal width <= 0.75:
      return [[28. 0. 0.]]
    else: # if petal width > 0.75
      return [[0. 4. 0.]]
  else: # if sepal length > 5.549999952316284
    if petal width <= 1.699999988079071:
      if petal length <= 4.950000047683716:
        if petal width <= 0.7000000029802322:
          return [[3. 0. 0.]]
        else: # if petal width > 0.7000000029802322
          return [[ 0. 42. 0.]]
      else: # if petal length > 4.950000047683716
        if sepal length <= 6.049999952316284:
          if sepal width <= 2.450000047683716:
            return [[0. 0. 2.]]
          else: # if sepal width > 2.450000047683716
            return [[0. 1. 0.]]
        else: # if sepal length > 6.049999952316284
          return [[0. 0. 3.]]
    else: # if petal width > 1.699999988079071
      return [[ 0. 0. 22.]]
Since n_estimators = 100, you'll get a total of 100 such functions.
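Because that is a lot of console output, one option is to redirect the prints into a file instead (a sketch under my own assumptions, not from the original answer; forest_rules.py is just a placeholder name):

import contextlib

with open("forest_rules.py", "w") as f, contextlib.redirect_stdout(f):
    rf_to_code(rf=clf, feature_names=df.columns)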