1

I am extracting decision rules from random forest, and I have read reference link :

How to extract the decision rules of a random forest in Python

this code output is :

TREE: 0
0 NODE: if feature[33] < 2.5 then next=1 else next=4
1 NODE: if feature[38] < 0.5 then next=2 else next=3
2 LEAF: return class=2
3 LEAF: return class=9
4 NODE: if feature[50] < 8.5 then next=5 else next=6
5 LEAF: return class=4
6 LEAF: return class=0
...

but it is not an ideal output. It does not produce rules; it just prints the trees.

ideal output is :

IF weight>80 AND weight<150 AND height<180 THEN figure=fat

I don't know how to generate ideal output. Looking forward to your help!

Rachel
  • 21
  • 1
  • 5
  • Does this answer your question? [How to extract the decision rules from scikit-learn decision-tree?](https://stackoverflow.com/questions/20224526/how-to-extract-the-decision-rules-from-scikit-learn-decision-tree) – ldmtwo Jan 28 '21 at 15:13

2 Answers2

2

Here's the solution according to your requirement. This will give you the decision rules used by each base learner (the value of n_estimators in sklearn's RandomForestClassifier is the number of DecisionTrees used).

from sklearn import metrics, datasets, ensemble
from sklearn.tree import _tree

#Decision Rules to code utility
def dtree_to_code(tree, feature_names, tree_idx):
    """Print one decision tree's rules as a Python-like function definition.

    ``tree`` is a fitted sklearn decision tree (a forest base estimator),
    ``feature_names`` the human-readable feature names, and ``tree_idx``
    the tree's index, used to name the printed ``tree_<idx>`` function.
    """
    tree_ = tree.tree_
    # Map each node index to the name of the feature that node splits on;
    # leaf nodes carry _tree.TREE_UNDEFINED and get a placeholder string.
    names_by_node = [
        "undefined!" if feat == _tree.TREE_UNDEFINED else feature_names[feat]
        for feat in tree_.feature
    ]
    print(f'def tree_{tree_idx}({", ".join(feature_names)}):')

    def walk(node, depth):
        pad = "  " * depth
        if tree_.feature[node] == _tree.TREE_UNDEFINED:
            # Leaf: emit the per-class sample counts stored at this node.
            print(f'{pad}return {tree_.value[node]}')
            return
        split_name = names_by_node[node]
        split_thresh = tree_.threshold[node]
        print(f'{pad}if {split_name} <= {split_thresh}:')
        walk(tree_.children_left[node], depth + 1)
        print(f'{pad}else:  # if {split_name} > {split_thresh}')
        walk(tree_.children_right[node], depth + 1)

    walk(0, 1)
def rf_to_code(rf, feature_names):
    """Print the decision rules of every tree in a fitted random forest.

    Emits one ``tree_<idx>`` function per base estimator in ``rf.estimators_``.
    """
    idx = 0
    for estimator in rf.estimators_:
        dtree_to_code(tree=estimator, feature_names=feature_names, tree_idx=idx)
        idx += 1

I got the decision rules code from here: How to extract the decision rules from scikit-learn decision-tree?

#clf : a fitted RandomForestClassifier(n_estimators=100)
#df :  Iris DataFrame; its columns supply the feature names

# Print every tree of the forest as a tree_<idx> pseudo-function.
rf_to_code(rf=clf,feature_names=df.columns)

If everything goes well expected Output :

def tree_0(sepal length, sepal width, petal length, petal width, species):
  if sepal length <= 5.549999952316284:
    if petal length <= 2.350000023841858:
      return [[40.  0.  0.]]
    else:  # if petal length > 2.350000023841858
      return [[0. 5. 0.]]
  else:  # if sepal length > 5.549999952316284
    if petal length <= 4.75:
      if petal width <= 0.7000000029802322:
        return [[2. 0. 0.]]
      else:  # if petal width > 0.7000000029802322
        return [[ 0. 22.  0.]]
    else:  # if petal length > 4.75
      if sepal width <= 3.049999952316284:
        if petal length <= 5.1499998569488525:
          if sepal length <= 5.950000047683716:
            return [[0. 0. 6.]]
          else:  # if sepal length > 5.950000047683716
            if petal width <= 1.75:
              return [[0. 3. 0.]]
            else:  # if petal width > 1.75
              return [[0. 0. 1.]]
        else:  # if petal length > 5.1499998569488525
          return [[ 0.  0. 15.]]
      else:  # if sepal width > 3.049999952316284
        return [[ 0.  0. 11.]]
def tree_1(sepal length, sepal width, petal length, petal width, species):
  if petal length <= 2.350000023841858:
    return [[39.  0.  0.]]
  else:  # if petal length > 2.350000023841858
    if petal length <= 4.950000047683716:
      if petal length <= 4.799999952316284:
        return [[ 0. 29.  0.]]
      else:  # if petal length > 4.799999952316284
        if sepal width <= 2.9499999284744263:
          if petal width <= 1.75:
            return [[0. 1. 0.]]
          else:  # if petal width > 1.75
            return [[0. 0. 2.]]
        else:  # if sepal width > 2.9499999284744263
          return [[0. 3. 0.]]
    else:  # if petal length > 4.950000047683716
      return [[ 0.  0. 31.]]
......
def tree_99(sepal length, sepal width, petal length, petal width, species):
  if sepal length <= 5.549999952316284:
    if petal width <= 0.75:
      return [[28.  0.  0.]]
    else:  # if petal width > 0.75
      return [[0. 4. 0.]]
  else:  # if sepal length > 5.549999952316284
    if petal width <= 1.699999988079071:
      if petal length <= 4.950000047683716:
        if petal width <= 0.7000000029802322:
          return [[3. 0. 0.]]
        else:  # if petal width > 0.7000000029802322
          return [[ 0. 42.  0.]]
      else:  # if petal length > 4.950000047683716
        if sepal length <= 6.049999952316284:
          if sepal width <= 2.450000047683716:
            return [[0. 0. 2.]]
          else:  # if sepal width > 2.450000047683716
            return [[0. 1. 0.]]
        else:  # if sepal length > 6.049999952316284
          return [[0. 0. 3.]]
    else:  # if petal width > 1.699999988079071
      return [[ 0.  0. 22.]]

Since n_estimators = 100 you'll get a total of 100 such functions.

0

Based on another answer: cross-compatible, and only uses one variable, X.

import numpy as np

from sklearn import metrics, datasets, ensemble
from sklearn.tree import _tree

#Decision Rules to code utility
def dtree_to_code(fout, tree, variables, feature_names, tree_idx):
        """Append one forest tree to a generated ``predict(X)`` function.

        Parameters
        ----------
        fout : writable text file the generated code is appended to.
        tree : fitted sklearn decision tree (a forest base estimator).
        variables : per-feature access expressions, e.g. ``"X[3]"``.
        feature_names : per-feature display names (used in comments).
        tree_idx : index of this tree in the forest; the ``def predict``
            header is written only for the first tree (``tree_idx <= 0``).
        """
        f = fout
        tree_ = tree.tree_
        # Map node index -> feature access expression / display name.
        # BUG FIX: the original indexed `variables[node]` and
        # `feature_names[node]` directly — i.e. by *node id* instead of by
        # the node's feature id — printing the wrong feature and raising
        # IndexError once node ids exceed the feature count. The per-node
        # lists below (which the original built but never used) are the fix.
        variable_by_node = [
            variables[i] if i != _tree.TREE_UNDEFINED else "undefined!"
            for i in tree_.feature
        ]
        name_by_node = [
            feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
            for i in tree_.feature
        ]
        if tree_idx <= 0:
            f.write('def predict(X):\n\tret = 0\n')

        def recurse(node, depth):
            indent = "\t" * depth
            if tree_.feature[node] != _tree.TREE_UNDEFINED:
                variable = variable_by_node[node]
                name = name_by_node[node]
                threshold = tree_.threshold[node]
                f.write('%sif %s <= %s: # if %s <= %s\n'%(indent, variable, threshold, name, threshold))
                recurse(tree_.children_left[node], depth + 1)
                f.write('%selse:  # if %s > %s\n'%(indent, name, threshold))
                recurse(tree_.children_right[node], depth + 1)
            else:
                # Leaf: add this leaf's majority class to the ensemble sum.
                yhat = np.argmax(tree_.value[node][0])
                if yhat != 0:
                    f.write("%sret += %s\n"%(indent, yhat))
                else:
                    f.write("%spass\n"%(indent))
        recurse(0, 1)
def rf_to_code(f, rf, variables, feature_names):
    """Write a whole random forest as one generated ``predict`` function.

    Appends each base estimator's rules in turn, then emits the final
    ``return`` dividing the accumulated vote sum by the number of trees.
    """
    n_trees = len(rf.estimators_)
    for base_learner_id, base_learner in enumerate(rf.estimators_):
        dtree_to_code(f, tree=base_learner, variables=variables,
                      feature_names=feature_names, tree_idx=base_learner_id)
    # BUG FIX: the original divided by `base_learner_id + 1`, a loop
    # variable that is undefined (NameError) when the forest is empty;
    # the estimator count is known up front.
    f.write('\treturn ret/%s\n' % n_trees)

# Generate a standalone '_model.py' containing a numba-compiled predict(X).
# NOTE(review): `d_q2i` (apparently a word -> feature-index mapping) and
# `estimator` (the fitted RandomForestClassifier) must already be defined
# by the surrounding code — confirm before running.
with open('_model.py', 'w') as f:
    # Emit the numba import and @njit decorator ahead of the generated def.
    f.write('''
from numba import jit,njit
@njit\n''')
    # One display label and one X[i] access expression per feature.
    labels = ['w_%s'%word for word in d_q2i.keys()]
    variables = ['X[%s]'%i for i,word in enumerate(d_q2i.keys())]
    rf_to_code(f,estimator,variables,labels)  

Output looks like this. X is a 1-d vector representing a single instance's features.

from numba import jit,njit
@njit
def predict(X):
    ret = 0
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else:  # if w_reusable > 0.5
                pass
        else:  # if w_mexico > 0.5
            ret += 1
    else:  # if w_pizza > 0.5
        pass
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else:  # if w_reusable > 0.5
                pass
        else:  # if w_mexico > 0.5
            pass
    else:  # if w_pizza > 0.5
        ret += 1
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else:  # if w_reusable > 0.5
                ret += 1
        else:  # if w_mexico > 0.5
            ret += 1
    else:  # if w_pizza > 0.5
        pass
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else:  # if w_reusable > 0.5
                ret += 1
        else:  # if w_mexico > 0.5
            pass
    else:  # if w_pizza > 0.5
        ret += 1
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else:  # if w_reusable > 0.5
                pass
        else:  # if w_mexico > 0.5
            pass
    else:  # if w_pizza > 0.5
        pass
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else:  # if w_reusable > 0.5
                pass
        else:  # if w_mexico > 0.5
            ret += 1
    else:  # if w_pizza > 0.5
        ret += 1
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else:  # if w_reusable > 0.5
                pass
        else:  # if w_mexico > 0.5
            pass
    else:  # if w_pizza > 0.5
        ret += 1
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else:  # if w_reusable > 0.5
                pass
        else:  # if w_mexico > 0.5
            pass
    else:  # if w_pizza > 0.5
        pass
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else:  # if w_reusable > 0.5
                pass
        else:  # if w_mexico > 0.5
            pass
    else:  # if w_pizza > 0.5
        pass
    if X[0] <= 0.5: # if w_pizza <= 0.5
        if X[1] <= 0.5: # if w_mexico <= 0.5
            if X[2] <= 0.5: # if w_reusable <= 0.5
                ret += 1
            else:  # if w_reusable > 0.5
                pass
        else:  # if w_mexico > 0.5
            pass
    else:  # if w_pizza > 0.5
        pass
    return ret/10
ldmtwo
  • 419
  • 5
  • 14