
I am using the Random Forest Classifier from scikit-learn. I have trained and tuned my model.

My dataset contains 40 samples, each with 4 features, and there are two classes into which I want to classify my samples.

Now my question is: I want to save the trees formed by this model and load them again in another script to make predictions.

Note: I am aware of the joblib and pickle modules, which can save the model to a file (for example with a ".sav" extension), but I don't want to save the model instance itself.

I found a very interesting way of doing this using scikit-learn's `tree.export_graphviz`. This is the code I used to save the trees:

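```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import export_graphviz

# X: feature matrix (40 samples x 4 features), Y: class labels (2 classes)
model = RandomForestClassifier()
model.fit(X, Y)

# Write each fitted tree in the forest to its own Graphviz .dot file
for i_tree, tree in enumerate(model.estimators_):
    with open('iris_tree_' + str(i_tree) + '.dot', 'w') as my_file:
        # With out_file set, export_graphviz writes to the file and returns None,
        # so its return value should not be assigned back to my_file
        export_graphviz(tree, out_file=my_file)
```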

The problem I am facing is: how do I use these trees to make predictions?

The saved files contain each tree in this format:

```dot
digraph Tree {
node [shape=box] ;
0 [label="X[3] <= 0.4\ngini = 0.4387\nsamples = 20\nvalue = [27, 13]"] ;
1 [label="gini = 0.0\nsamples = 7\nvalue = [0, 13]"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="gini = 0.0\nsamples = 13\nvalue = [27, 0]"] ;
0 -> 2 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
}
```

This data can be rendered as a tree diagram using an online Graphviz viewer.

When rendered, this data looks like this:

[Image: rendered decision tree for the .dot data above]

How can I parse this type of data?

I am mostly interested in the "X[3]<=0.4" values in every block of my tree. I just need to know whether there is any condition like "X[3]<=0.4" in any block of my tree (as the tree can be nested).
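A minimal sketch of parsing the `.dot` text shown above with Python's `re` module and walking the result to get class scores. It assumes the files were written by `export_graphviz` with default options (features labelled `X[i]`, single-output classification), and that for each split node the first edge in the file points to the True (`<=`) branch, which matches the example above but is an assumption for deeper trees. The names `parse_dot` and `predict_proba_one` are hypothetical. Note that `export_graphviz` rounds thresholds to `precision` digits (default 3), so samples very close to a threshold may be routed differently than by the original model.

```python
import re

# Patterns for lines written by export_graphviz (default options assumed):
#   0 [label="X[3] <= 0.4\ngini = ...\nvalue = [27, 13]"] ;
#   0 -> 1 [labeldistance=2.5, ...] ;
NODE_RE = re.compile(r'^(\d+) \[label="([^"]*)"')
EDGE_RE = re.compile(r'^(\d+) -> (\d+)')
SPLIT_RE = re.compile(r'^X\[(\d+)\] <= (-?\d+(?:\.\d+)?)')
VALUE_RE = re.compile(r'value = \[([^\]]*)\]')


def parse_dot(text):
    """Parse export_graphviz output into two dicts.

    nodes: node id -> {"split": (feature, threshold) or None, "value": [...]}
    children: node id -> [first_child, second_child], in file order
    """
    nodes, children = {}, {}
    for line in text.splitlines():
        line = line.strip()
        edge = EDGE_RE.match(line)
        if edge:
            parent, child = int(edge.group(1)), int(edge.group(2))
            children.setdefault(parent, []).append(child)
            continue
        node = NODE_RE.match(line)
        if node:
            label = node.group(2)
            split = SPLIT_RE.match(label)  # None for leaf nodes
            value = VALUE_RE.search(label)
            nodes[int(node.group(1))] = {
                "split": (int(split.group(1)), float(split.group(2))) if split else None,
                "value": [float(v) for v in value.group(1).split(",")] if value else None,
            }
    return nodes, children


def predict_proba_one(nodes, children, x):
    """Walk from the root (node 0) to a leaf; return normalized class scores."""
    node_id = 0
    while nodes[node_id]["split"] is not None:
        feature, threshold = nodes[node_id]["split"]
        left, right = children[node_id]  # assumed: first edge is the True branch
        node_id = left if x[feature] <= threshold else right
    value = nodes[node_id]["value"]
    total = sum(value)
    return [v / total for v in value]


# Usage (file name follows the naming scheme used when saving):
# with open('iris_tree_0.dot') as f:
#     nodes, children = parse_dot(f.read())
# print(predict_proba_one(nodes, children, [5.1, 3.5, 1.4, 0.2]))
```

The split conditions alone (the "X[3] <= 0.4" parts) are the non-`None` `"split"` entries in `nodes`.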

Blessy
  • Is your goal to get some graphical output, a realization of the generated `Tree` graph you quoted in your question? – TomServo Jul 19 '17 at 12:05
  • No, I don't need graphical output. Any format is fine as long as I can get the parameters: the value at which each split occurs, and the left and right children. I need these values so that I can make predictions on an unknown dataset without using sklearn's predict function, since trees are just nested if/else statements on different conditions. – Blessy Jul 19 '17 at 12:24
  • I see. So of the data in the `digraph Tree` example you posted, what exactly do you wish to extract for your purposes? Perhaps I can help with that. – TomServo Jul 19 '17 at 13:04
  • Thank you. I need those "X[3]<=0.4" values in every block of my tree. I have added an image of the tree formed by that data; please have a look at it. I just need to know whether there is any condition like "X[3]<=0.4" in any block of my tree (as the tree can be nested), and if so, what it is. – Blessy Jul 19 '17 at 13:43
  • Please see my regular expression solution below. – TomServo Jul 19 '17 at 14:00
  • Possible duplicate of [how to extract the decision rules from scikit-learn decision-tree?](https://stackoverflow.com/questions/20224526/how-to-extract-the-decision-rules-from-scikit-learn-decision-tree) – zfisher Jul 19 '17 at 14:03
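To illustrate the "trees are just if/else statements" point, here is a minimal sketch of predicting without sklearn's `predict`, assuming each tree has been stored as plain lists in a layout similar to a fitted estimator's `tree_` attribute (`children_left`, `children_right`, `feature`, `threshold`, `value`), where -1 in `children_left` marks a leaf. The values in `EXAMPLE_TREE` are copied by hand from the example tree in the question; with a fitted model they could be read from each `est.tree_` (for example `est.tree_.threshold.tolist()`, and `est.tree_.value[:, 0, :].tolist()` for single-output classification) and saved with `json`. The names `EXAMPLE_TREE`, `tree_proba`, and `forest_predict` are hypothetical.

```python
# Plain-list version of a tree in the `tree_` layout (outputs axis of `value`
# dropped for single-output classification). -1 in children_left marks a leaf;
# leaves have feature -2 and threshold -2.0, as sklearn stores them.
EXAMPLE_TREE = {
    "children_left": [1, -1, -1],
    "children_right": [2, -1, -1],
    "feature": [3, -2, -2],
    "threshold": [0.4, -2.0, -2.0],
    "value": [[27.0, 13.0], [0.0, 13.0], [27.0, 0.0]],
}


def tree_proba(tree, x):
    """Return normalized class scores from the leaf that sample x reaches."""
    node = 0
    while tree["children_left"][node] != -1:  # -1 marks a leaf
        if x[tree["feature"][node]] <= tree["threshold"][node]:
            node = tree["children_left"][node]
        else:
            node = tree["children_right"][node]
    counts = tree["value"][node]
    total = sum(counts)
    return [c / total for c in counts]


def forest_predict(trees, x):
    """Average per-tree class scores; return the index of the highest mean."""
    probas = [tree_proba(t, x) for t in trees]
    mean = [sum(col) / len(probas) for col in zip(*probas)]
    return max(range(len(mean)), key=mean.__getitem__)
```

Averaging class probabilities across trees (rather than counting majority votes) is how scikit-learn's `RandomForestClassifier.predict` combines its trees. The returned index maps to a class label through `model.classes_`.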

1 Answer


If it's truly that small snippet you're looking for, you could consider using a regular expression such as:

```
\D\[\d+\]\s+<=\s+\d+\.\d+
```

That is, "non-digit character, open bracket, some digits, close bracket, whitespace, `<=` symbol, whitespace, some digits, decimal point, some digits." I tested this regex on your text and it matches that snippet and nothing else.
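A minimal sketch of applying this pattern with Python's `re` module. Capture groups are added around the feature index and the threshold so `re.findall` returns both pieces directly; otherwise the pattern is the one above. The function name `find_splits` is hypothetical.

```python
import re

# The pattern above, with capture groups around the feature index and threshold
SPLIT_RE = re.compile(r'\D\[(\d+)\]\s+<=\s+(\d+\.\d+)')


def find_splits(dot_text):
    """Return (feature_index, threshold) for every split condition in the text."""
    return [(int(f), float(t)) for f, t in SPLIT_RE.findall(dot_text)]
```

An empty result means the tree has no split conditions (a single leaf). As written, the pattern does not match negative thresholds such as `X[0] <= -1.5`; allowing an optional minus sign in the second group (`-?\d+\.\d+`) would cover those.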

TomServo